From 961e3acfeff4b0655bae8a652535cff8f6131586 Mon Sep 17 00:00:00 2001 From: gujinkai Date: Mon, 15 Apr 2024 16:50:41 +0800 Subject: [Feature][core] file format adapt to aes --- groot-core/pom.xml | 7 + .../AbstractSingleKnowledgeBaseHandler.java | 81 ++++++----- .../cn/AbstractSingleKnowledgeBaseHandlerTest.java | 155 +++++++++++++++++++++ .../core/udf/cn/HighCsvReaderTest.java | 25 ++++ .../core/udf/cn/LookupTestUtils.java | 4 + 5 files changed, 240 insertions(+), 32 deletions(-) create mode 100644 groot-core/src/test/java/com/geedgenetworks/core/udf/cn/AbstractSingleKnowledgeBaseHandlerTest.java create mode 100644 groot-core/src/test/java/com/geedgenetworks/core/udf/cn/HighCsvReaderTest.java diff --git a/groot-core/pom.xml b/groot-core/pom.xml index f19e4b1..08ccffe 100644 --- a/groot-core/pom.xml +++ b/groot-core/pom.xml @@ -13,6 +13,13 @@ + + org.mock-server + mockserver-netty + 5.11.2 + test + + org.mockito mockito-core diff --git a/groot-core/src/main/java/com/geedgenetworks/core/udf/knowlegdebase/handler/AbstractSingleKnowledgeBaseHandler.java b/groot-core/src/main/java/com/geedgenetworks/core/udf/knowlegdebase/handler/AbstractSingleKnowledgeBaseHandler.java index c460961..3869569 100644 --- a/groot-core/src/main/java/com/geedgenetworks/core/udf/knowlegdebase/handler/AbstractSingleKnowledgeBaseHandler.java +++ b/groot-core/src/main/java/com/geedgenetworks/core/udf/knowlegdebase/handler/AbstractSingleKnowledgeBaseHandler.java @@ -1,24 +1,16 @@ package com.geedgenetworks.core.udf.knowlegdebase.handler; -import com.alibaba.fastjson2.JSON; + import com.geedgenetworks.common.config.KnowledgeBaseConfig; import com.geedgenetworks.common.exception.CommonErrorCode; import com.geedgenetworks.common.exception.GrootStreamRuntimeException; import com.geedgenetworks.core.pojo.KnowLedgeBaseFileMeta; +import com.geedgenetworks.crypt.AESUtil; import lombok.Data; -import org.apache.http.HttpEntity; -import org.apache.http.client.methods.CloseableHttpResponse; -import 
org.apache.http.client.methods.HttpGet; import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.impl.client.HttpClients; -import org.apache.http.util.EntityUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.IOException; -import java.util.Collections; -import java.util.List; -import java.util.stream.Collectors; - /** * @author gujinkai * @version 1.0 @@ -33,6 +25,8 @@ public abstract class AbstractSingleKnowledgeBaseHandler extends AbstractKnowled protected KnowLedgeBaseFileMeta knowledgeMetedataCache; private static final CloseableHttpClient HTTP_CLIENT = HttpClients.createMinimal(); + private static final String AES_KEY = "86cf0e2ffba3f541a6c6761313e5cc7e"; + @Override public boolean initKnowledgeBase(KnowledgeBaseConfig knowledgeBaseConfig) { this.knowledgeBaseConfig = knowledgeBaseConfig; @@ -56,14 +50,54 @@ public abstract class AbstractSingleKnowledgeBaseHandler extends AbstractKnowled protected abstract Boolean buildKnowledgeBase(); + /** + * Downloads the knowledge-base file for the configured fs type ("http" or "local"; anything else throws GrootStreamRuntimeException) and passes the bytes to decrypt. + * decrypt applies the format handling and also covers a failed download that returned null. NOTE(review): switching on a null fs type throws NullPointerException, unlike the replaced "http".equals(...) checks; confirm it is never null. + * + * @return the bytes produced by decrypt; an empty array when the format is unknown or decryption throws + */ public byte[] downloadFile() { - if ("http".equals(knowledgeBaseConfig.getFsType())) { - return downloadFile(knowledgeMetedataCache.getPath(), knowledgeMetedataCache.getIsValid()); + byte[] data; + switch (knowledgeBaseConfig.getFsType()) { + case "http": + data = downloadFile(knowledgeMetedataCache.getPath(), knowledgeMetedataCache.getIsValid()); + break; + case "local": + data = getFileFromLocal(knowledgeBaseConfig.getFsPath() + knowledgeBaseConfig.getFiles().get(0)); + break; + default: + throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, knowledgeBaseConfig.getFsType() + " is illegal"); } - if ("local".equals(knowledgeBaseConfig.getFsType())) { - return getFileFromLocal(knowledgeBaseConfig.getFsPath() + knowledgeBaseConfig.getFiles().get(0)); + return decrypt(data); + } + + /** + * Decrypts the downloaded bytes according to the format in knowledgeMetedataCache. + * Supported file formats: csv (returned unchanged) and aes (decrypted via AESUtil with AES_KEY). NOTE(review): AES_KEY is hard-coded in this class; consider supplying it through configuration. + * + * @param data downloaded bytes; null is replaced by an empty array + * @return the resulting bytes; an empty array when the format is unknown or an exception is caught (both cases are only logged) + */ + private byte[] 
decrypt(byte[] data) { + byte[] result = new byte[0]; + try { + if (data == null) { + data = new byte[0]; + } + switch (knowledgeMetedataCache.getFormat()) { + case "aes": + result = AESUtil.decrypt(data, AES_KEY); + break; + case "csv": + result = data; + break; + default: + logger.error("unknown format: " + knowledgeMetedataCache.getFormat()); + } + } catch (Exception e) { + logger.error("decrypt error", e); } - throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, knowledgeBaseConfig.getFsType() + " is illegal"); + return result; } protected Boolean ifNeedUpdate() { @@ -83,23 +117,6 @@ public abstract class AbstractSingleKnowledgeBaseHandler extends AbstractKnowled } } - public List getMetadata(String url) { - final HttpGet httpGet = new HttpGet(url); - httpGet.addHeader("Accept", "application/json"); - try { - CloseableHttpResponse response = HTTP_CLIENT.execute(httpGet); - HttpEntity entity = response.getEntity(); - if (entity != null) { - String content = EntityUtils.toString(entity, "UTF-8"); - KnowledgeResponse knowledgeResponse = JSON.parseObject(content, KnowledgeResponse.class); - return JSON.parseArray(knowledgeResponse.data, KnowLedgeBaseFileMeta.class).stream().filter(metadata -> "latest".equals(metadata.getVersion()) && metadata.getIsValid() == 1).collect(Collectors.toList()); - } - } catch (IOException e) { - logger.error("fetch knowledge metadata error", e); - } - return Collections.singletonList(null); - } - @Data private static final class KnowledgeResponse { private int status; diff --git a/groot-core/src/test/java/com/geedgenetworks/core/udf/cn/AbstractSingleKnowledgeBaseHandlerTest.java b/groot-core/src/test/java/com/geedgenetworks/core/udf/cn/AbstractSingleKnowledgeBaseHandlerTest.java new file mode 100644 index 0000000..e259654 --- /dev/null +++ b/groot-core/src/test/java/com/geedgenetworks/core/udf/cn/AbstractSingleKnowledgeBaseHandlerTest.java @@ -0,0 +1,155 @@ +package com.geedgenetworks.core.udf.cn; + +import 
com.alibaba.fastjson2.JSON; +import com.geedgenetworks.common.config.KnowledgeBaseConfig; +import com.geedgenetworks.core.pojo.KnowLedgeBaseFileMeta; +import com.geedgenetworks.core.udf.knowlegdebase.handler.AbstractSingleKnowledgeBaseHandler; +import com.geedgenetworks.crypt.AESUtil; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockserver.client.MockServerClient; +import org.mockserver.integration.ClientAndServer; +import org.mockserver.model.HttpRequest; +import org.mockserver.model.HttpResponse; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class AbstractSingleKnowledgeBaseHandlerTest { + + private static KnowledgeBaseConfig knowledgeBaseConfig; + + private ClientAndServer mockGatewayServer; + + private MockServerClient mockHosServer; + + @BeforeEach + void beforeEach() { + knowledgeBaseConfig = new KnowledgeBaseConfig(); + knowledgeBaseConfig.setFsPath("http://localhost:9999/v1/knowledge_base"); + knowledgeBaseConfig.setFsType("http"); + knowledgeBaseConfig.setFiles(List.of("1")); + } + + @Test + void downloadCsvFile() { + KnowLedgeBaseFileMeta knowLedgeBaseFileMeta = new KnowLedgeBaseFileMeta(); + knowLedgeBaseFileMeta.setPath("http://localhost:9098/hos/knowledge_base_bucket/1_latest"); + knowLedgeBaseFileMeta.setIsValid(1); + knowLedgeBaseFileMeta.setFormat("csv"); + knowLedgeBaseFileMeta.setVersion("latest"); + Map> gatewayResponse = new HashMap<>(); + gatewayResponse.put("data", List.of(knowLedgeBaseFileMeta)); + + mockGatewayServer = ClientAndServer.startClientAndServer(9999); + MockServerClient gatewayClient = new MockServerClient("localhost", 9999); + + // Define the mock gateway behavior: GET /v1/knowledge_base with kb_id=1 returns the metadata JSON + gatewayClient.when( + HttpRequest.request() + .withMethod("GET") + .withPath("/v1/knowledge_base") + .withQueryStringParameter("kb_id", "1") + ).respond( + HttpResponse.response() + .withStatusCode(200) 
+ .withBody(JSON.toJSONString(gatewayResponse)) + ); + + mockHosServer = ClientAndServer.startClientAndServer(9098); + MockServerClient hosClient = new MockServerClient("localhost", 9098); + + // Define the mock file server behavior: GET on the file path returns the plain content "test" + hosClient.when( + HttpRequest.request() + .withMethod("GET") + .withPath("/hos/knowledge_base_bucket/1_latest") + ).respond( + HttpResponse.response() + .withStatusCode(200) + .withBody("test") + ); + + AbstractSingleKnowledgeBaseHandler baseHandler = new AbstractSingleKnowledgeBaseHandler() { + @Override + protected Boolean buildKnowledgeBase() { + byte[] bytes = downloadFile(); + assertEquals("test", new String(bytes)); + return true; + } + + @Override + public void close() { + + } + }; + baseHandler.initKnowledgeBase(knowledgeBaseConfig); + } + + @Test + void downloadAesFile() throws Exception { + KnowLedgeBaseFileMeta knowLedgeBaseFileMeta = new KnowLedgeBaseFileMeta(); + knowLedgeBaseFileMeta.setPath("http://localhost:9098/hos/knowledge_base_bucket/1_latest"); + knowLedgeBaseFileMeta.setIsValid(1); + knowLedgeBaseFileMeta.setFormat("aes"); + knowLedgeBaseFileMeta.setVersion("latest"); + Map> gatewayResponse = new HashMap<>(); + gatewayResponse.put("data", List.of(knowLedgeBaseFileMeta)); + + mockGatewayServer = ClientAndServer.startClientAndServer(9999); + MockServerClient gatewayClient = new MockServerClient("localhost", 9999); + + // Define the mock gateway behavior: GET /v1/knowledge_base with kb_id=1 returns the metadata JSON + gatewayClient.when( + HttpRequest.request() + .withMethod("GET") + .withPath("/v1/knowledge_base") + .withQueryStringParameter("kb_id", "1") + ).respond( + HttpResponse.response() + .withStatusCode(200) + .withBody(JSON.toJSONString(gatewayResponse)) + ); + + mockHosServer = ClientAndServer.startClientAndServer(9098); + MockServerClient hosClient = new MockServerClient("localhost", 9098); + + // Define the mock file server behavior: GET on the file path returns "test" encrypted with AESUtil + hosClient.when( + HttpRequest.request() + .withMethod("GET") + .withPath("/hos/knowledge_base_bucket/1_latest") + ).respond( + HttpResponse.response() + .withStatusCode(200) + 
.withBody(AESUtil.encrypt("test".getBytes(), "86cf0e2ffba3f541a6c6761313e5cc7e")) + ); + + AbstractSingleKnowledgeBaseHandler baseHandler = new AbstractSingleKnowledgeBaseHandler() { + @Override + protected Boolean buildKnowledgeBase() { + byte[] bytes = downloadFile(); + assertEquals("test", new String(bytes)); + return true; + } + + @Override + public void close() { + + } + }; + baseHandler.initKnowledgeBase(knowledgeBaseConfig); + } + + @AfterEach + void afterEach() { + mockGatewayServer.stop(); + mockGatewayServer = null; + mockHosServer.stop(); + mockHosServer = null; + } +} \ No newline at end of file diff --git a/groot-core/src/test/java/com/geedgenetworks/core/udf/cn/HighCsvReaderTest.java b/groot-core/src/test/java/com/geedgenetworks/core/udf/cn/HighCsvReaderTest.java new file mode 100644 index 0000000..fdb61f8 --- /dev/null +++ b/groot-core/src/test/java/com/geedgenetworks/core/udf/cn/HighCsvReaderTest.java @@ -0,0 +1,25 @@ +package com.geedgenetworks.core.udf.cn; + +import com.geedgenetworks.core.utils.cn.csv.HighCsvReader; +import org.junit.jupiter.api.Test; + +import java.io.ByteArrayInputStream; +import java.io.InputStreamReader; +import java.util.ArrayList; +import java.util.List; + +class HighCsvReaderTest { + + @Test + void inputTest() { + List needColumns = new ArrayList<>(); + needColumns.add("test"); + byte[] content = new byte[0]; + HighCsvReader highCsvReader = new HighCsvReader(new InputStreamReader(new ByteArrayInputStream(content)), needColumns); + System.out.println(highCsvReader.getLineNumber()); + HighCsvReader.CsvIterator iterator = highCsvReader.getIterator(); + while (iterator.hasNext()) { + System.out.println(iterator.next()); + } + } +} \ No newline at end of file diff --git a/groot-core/src/test/java/com/geedgenetworks/core/udf/cn/LookupTestUtils.java b/groot-core/src/test/java/com/geedgenetworks/core/udf/cn/LookupTestUtils.java index b70edcc..200b420 100644 --- 
a/groot-core/src/test/java/com/geedgenetworks/core/udf/cn/LookupTestUtils.java +++ b/groot-core/src/test/java/com/geedgenetworks/core/udf/cn/LookupTestUtils.java @@ -37,6 +37,8 @@ public class LookupTestUtils { private static String fsType = "http"; private static int isValid = 1; + + private static String format = "csv"; private static List fsFiles = Arrays.asList("testFile"); public static String kbName = "testKbName"; private static String downloadPath = "testDownloadPath"; @@ -80,6 +82,7 @@ public class LookupTestUtils { KnowLedgeBaseFileMeta knowLedgeBaseFileMeta = new KnowLedgeBaseFileMeta(); knowLedgeBaseFileMeta.setPath(downloadPath); knowLedgeBaseFileMeta.setIsValid(isValid); + knowLedgeBaseFileMeta.setFormat(format); abstractKnowledgeBaseHandlerMockedStatic.when(() -> AbstractKnowledgeBaseHandler.getMetadata(fsType, fsPath, fsFiles.get(0))).thenReturn(knowLedgeBaseFileMeta); abstractKnowledgeBaseHandlerMockedStatic.when(() -> AbstractKnowledgeBaseHandler.downloadFile(downloadPath, 1)).thenReturn(downloadContent.getBytes()); } @@ -90,6 +93,7 @@ public class LookupTestUtils { KnowLedgeBaseFileMeta.setKb_id("1"); KnowLedgeBaseFileMeta.setPath(downloadPath); KnowLedgeBaseFileMeta.setIsValid(isValid); + KnowLedgeBaseFileMeta.setFormat(format); Map KnowLedgeBaseFileMetaMap = new HashMap<>(); KnowLedgeBaseFileMetaMap.put("1", KnowLedgeBaseFileMeta); abstractMultipleKnowledgeBaseHandlerMockedStatic.when(() -> AbstractMultipleKnowledgeBaseHandler.getMetadata(fsPath)).thenReturn(KnowLedgeBaseFileMetaMap); -- cgit v1.2.3