diff options
| author | 侯晋川 <[email protected]> | 2024-10-28 18:03:42 +0800 |
|---|---|---|
| committer | 侯晋川 <[email protected]> | 2024-10-28 18:03:42 +0800 |
| commit | 8055b40a031833562308e7d7fcae9c923eec9880 (patch) | |
| tree | 1a8fbb8cefc3ba14245d207d17705536829c08b8 | |
| parent | 152fa429b30717cbb5964973f43c8ca0d5a22218 (diff) | |
[feature][core] 优化Encrypt和HMAC函数。新增Encrypt单元测试feature/udf-encrypt
14 files changed, 326 insertions, 74 deletions
diff --git a/config/grootstream_job_example.yaml b/config/grootstream_job_example.yaml index 8f27609..8c7a1b1 100644 --- a/config/grootstream_job_example.yaml +++ b/config/grootstream_job_example.yaml @@ -66,6 +66,8 @@ application: env: name: example-inline-to-print parallelism: 3 + shade.identifier: sm4 + kms.type: vault pipeline: object-reuse: true execution: @@ -76,6 +78,7 @@ application: hos.bucket.name.http_file: traffic_http_file_bucket hos.bucket.name.eml_file: traffic_eml_file_bucket hos.bucket.name.policy_capture_file: traffic_policy_capture_file_bucket + projection.encrypt.schema.registry.uri: 192.168.44.12:9999/v1/schema/session_record?option=encrypt_fields topology: - name: inline_source downstream: [decoded_as_split] diff --git a/groot-bootstrap/src/main/java/com/geedgenetworks/bootstrap/utils/EnvironmentUtil.java b/groot-bootstrap/src/main/java/com/geedgenetworks/bootstrap/utils/EnvironmentUtil.java index 13db3d4..8028608 100644 --- a/groot-bootstrap/src/main/java/com/geedgenetworks/bootstrap/utils/EnvironmentUtil.java +++ b/groot-bootstrap/src/main/java/com/geedgenetworks/bootstrap/utils/EnvironmentUtil.java @@ -1,8 +1,10 @@ package com.geedgenetworks.bootstrap.utils; import com.geedgenetworks.bootstrap.execution.ExecutionConfigKeyName; +import com.geedgenetworks.common.Constants; import com.geedgenetworks.common.config.CheckResult; import com.typesafe.config.Config; +import com.typesafe.config.ConfigUtil; import com.typesafe.config.ConfigValue; import lombok.extern.slf4j.Slf4j; import org.apache.flink.api.common.ExecutionConfig; @@ -16,7 +18,7 @@ import java.util.concurrent.TimeUnit; @Slf4j public final class EnvironmentUtil { - private EnvironmentUtil() { + private EnvironmentUtil() { throw new UnsupportedOperationException("EnvironmentUtil is a utility class and cannot be instantiated"); } @@ -30,10 +32,13 @@ public final class EnvironmentUtil { configuration.setString( PipelineOptions.CLASSPATHS.key(), pipeline.getString("classpaths")); } - if(pipeline.hasPath("object-reuse")) { + if (pipeline.hasPath("object-reuse")) { configuration.setBoolean(PipelineOptions.OBJECT_REUSE.key(), pipeline.getBoolean("object-reuse")); } } + if (envConfig.hasPath(ConfigUtil.joinPath(Constants.SYSPROP_KMS_TYPE_CONFIG))) { + configuration.setString(Constants.SYSPROP_KMS_TYPE_CONFIG, envConfig.getString(ConfigUtil.joinPath(Constants.SYSPROP_KMS_TYPE_CONFIG))); + } String prefixConf = "flink."; if (!envConfig.isEmpty()) { for (Map.Entry<String, ConfigValue> entryConfKey : envConfig.entrySet()) { @@ -117,5 +122,4 @@ public final class EnvironmentUtil { } - } diff --git a/groot-common/src/main/java/com/geedgenetworks/common/Constants.java b/groot-common/src/main/java/com/geedgenetworks/common/Constants.java index b523591..27ce8fb 100644 --- a/groot-common/src/main/java/com/geedgenetworks/common/Constants.java +++ b/groot-common/src/main/java/com/geedgenetworks/common/Constants.java @@ -2,7 +2,7 @@ package com.geedgenetworks.common; public final class Constants { - public static final String DEFAULT_JOB_NAME="groot-stream-job"; + public static final String DEFAULT_JOB_NAME = "groot-stream-job"; public static final String SOURCES = "sources"; public static final String FILTERS = "filters"; public static final String PREPROCESSING_PIPELINES = "preprocessing_pipelines"; @@ -14,7 +14,7 @@ public final class Constants { public static final String PROPERTIES = "properties"; public static final String SPLITS = "splits"; - public static final String APPLICATION_ENV ="env"; + public static final String APPLICATION_ENV = "env"; public static final String APPLICATION_TOPOLOGY = "topology"; public static final String JOB_NAME = "name"; public static final String GROOT_LOGO = "\n" + @@ -49,6 +49,8 @@ public final class Constants { public static final String SLIDING_PROCESSING_TIME = "sliding_processing_time"; public static final String SLIDING_EVENT_TIME = "sliding_event_time"; - + public static final String SYSPROP_KMS_TYPE_CONFIG = "kms.type"; + public static final String SYSPROP_ENCRYPT_KMS_KEY_SCHEDULER_INTERVAL_NAME = "scheduler.encrypt.update.kms.key.minutes"; + public static final String SYSPROP_ENCRYPT_SENSITIVE_FIELDS_SCHEDULER_INTERVAL_NAME = "scheduler.encrypt.update.sensitive.fields.minutes"; } diff --git a/groot-common/src/main/java/com/geedgenetworks/common/config/CommonConfigDomProcessor.java b/groot-common/src/main/java/com/geedgenetworks/common/config/CommonConfigDomProcessor.java index 51e2ff0..b3b17e8 100644 --- a/groot-common/src/main/java/com/geedgenetworks/common/config/CommonConfigDomProcessor.java +++ b/groot-common/src/main/java/com/geedgenetworks/common/config/CommonConfigDomProcessor.java @@ -117,8 +117,6 @@ public class CommonConfigDomProcessor extends AbstractDomConfigProcessor { String name = cleanNodeName(node); if (CommonConfigOptions.KMS_TYPE.key().equals(name)) { kmsConfig.setType(getTextContent(node)); - } else if (CommonConfigOptions.KMS_SECRET_KEY.key().equals(name)) { - kmsConfig.setSecretKey(getTextContent(node)); } else if (CommonConfigOptions.KMS_URL.key().equals(name)) { kmsConfig.setUrl(getTextContent(node)); } else if (CommonConfigOptions.KMS_USERNAME.key().equals(name)) { diff --git a/groot-common/src/main/java/com/geedgenetworks/common/config/CommonConfigOptions.java b/groot-common/src/main/java/com/geedgenetworks/common/config/CommonConfigOptions.java index 1c3f4d0..167fcba 100644 --- a/groot-common/src/main/java/com/geedgenetworks/common/config/CommonConfigOptions.java +++ b/groot-common/src/main/java/com/geedgenetworks/common/config/CommonConfigOptions.java @@ -70,11 +70,6 @@ public class CommonConfigOptions { .defaultValue("local") .withDescription("The type of KMS."); - public static final Option<String> KMS_SECRET_KEY = Options.key("secret_key") - .stringType() - .defaultValue("") - .withDescription("The type of KMS."); - public static final Option<String> KMS_URL = Options.key("url") .stringType() .defaultValue("") diff --git a/groot-common/src/main/java/com/geedgenetworks/common/config/KmsConfig.java b/groot-common/src/main/java/com/geedgenetworks/common/config/KmsConfig.java index 75a5b4c..f0e213f 100644 --- a/groot-common/src/main/java/com/geedgenetworks/common/config/KmsConfig.java +++ b/groot-common/src/main/java/com/geedgenetworks/common/config/KmsConfig.java @@ -7,7 +7,6 @@ import java.io.Serializable; @Data public class KmsConfig implements Serializable { private String type = CommonConfigOptions.KMS_TYPE.defaultValue(); - private String secretKey = CommonConfigOptions.KMS_TYPE.defaultValue(); private String url = CommonConfigOptions.KMS_URL.defaultValue(); private String username = CommonConfigOptions.KMS_USERNAME.defaultValue(); private String password = CommonConfigOptions.KMS_PASSWORD.defaultValue(); diff --git a/groot-common/src/main/resources/grootstream.yaml b/groot-common/src/main/resources/grootstream.yaml index d7818ab..26752e3 100644 --- a/groot-common/src/main/resources/grootstream.yaml +++ b/groot-common/src/main/resources/grootstream.yaml @@ -13,9 +13,8 @@ grootstream: - 004390bc-3135-4a6f-a492-3662ecb9e289 kms: -# local: -# type: local -# secret_key: .geedgenetworks. + local: + type: local vault: type: vault url: https://192.168.40.223:8200 diff --git a/groot-core/src/main/java/com/geedgenetworks/core/udf/Encrypt.java b/groot-core/src/main/java/com/geedgenetworks/core/udf/Encrypt.java index cc05397..b20ff18 100644 --- a/groot-core/src/main/java/com/geedgenetworks/core/udf/Encrypt.java +++ b/groot-core/src/main/java/com/geedgenetworks/core/udf/Encrypt.java @@ -1,6 +1,9 @@ package com.geedgenetworks.core.udf; -import cn.hutool.core.util.ArrayUtil; +import cn.hutool.core.util.URLUtil; +import cn.hutool.json.JSONArray; +import cn.hutool.json.JSONObject; +import cn.hutool.json.JSONUtil; import com.alibaba.fastjson2.JSON; import com.geedgenetworks.common.Constants; import com.geedgenetworks.common.Event; @@ -15,13 +18,17 @@ import com.geedgenetworks.core.pojo.KmsKey; import com.geedgenetworks.core.udf.encrypt.EncryptionAlgorithm; import com.geedgenetworks.core.utils.*; import com.geedgenetworks.utils.StringUtil; -import io.github.jopenlibs.vault.VaultException; import lombok.extern.slf4j.Slf4j; import org.apache.flink.api.common.functions.RuntimeContext; import org.apache.flink.configuration.Configuration; +import java.io.IOException; +import java.net.URI; import java.util.Arrays; import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; @Slf4j public class Encrypt implements ScalarFunction { @@ -31,7 +38,7 @@ public class Encrypt implements ScalarFunction { private String identifier; private String defaultVal; private String type; - private transient SingleValueMap.Data<LoadIntervalDataUtil<String[]>> sensitiveFieldsData; + private transient SingleValueMap.Data<LoadIntervalDataUtil<Set<String>>> sensitiveFieldsData; private transient SingleValueMap.Data<LoadIntervalDataUtil<KmsKey>> kmsKeyData; private transient EncryptionAlgorithm encryptionAlgorithm; @@ -44,35 +51,37 @@ public class Encrypt implements ScalarFunction { this.lookupFieldName = udfContext.getLookup_fields().get(0); this.outputFieldName = udfContext.getOutput_fields().get(0); this.identifier = udfContext.getParameters().get("identifier").toString(); - Configuration configuration = (Configuration) runtimeContext.getExecutionConfig().getGlobalJobParameters(); CommonConfig commonConfig = JSON.parseObject(configuration.toMap().get(Constants.SYSPROP_GROOTSTREAM_CONFIG), CommonConfig.class); - Map<String, KmsConfig> kmsConfigs = commonConfig.getKmsConfig(); - if (kmsConfigs.isEmpty()) { - throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "Global parameter kms type is not null!"); - } else if (kmsConfigs.size() > 1) { - throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "Global parameter kms type is repeated!"); - } - KmsConfig kmsConfig = kmsConfigs.values().iterator().next(); + KmsConfig kmsConfig = commonConfig.getKmsConfig().get(configuration.toMap().get(Constants.SYSPROP_KMS_TYPE_CONFIG)); SSLConfig sslConfig = commonConfig.getSslConfig(); Map<String, String> propertiesConfig = commonConfig.getPropertiesConfig(); type = kmsConfig.getType(); try { - encryptionAlgorithm = EncryptionAlgorithmUtils.getEncryptionAlgorithm(identifier); + encryptionAlgorithm = EncryptionAlgorithmUtils.createEncryptionAlgorithm(identifier); if (encryptionAlgorithm == null) { throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "Parameters identifier is illegal!"); } - kmsKeyData = SingleValueMap.acquireData("kmsKeyData", - () -> LoadIntervalDataUtil.newInstance(() -> getKmsKey(kmsConfig, sslConfig, identifier), - LoadIntervalDataOptions.defaults("kmsKeyData", 60000)), LoadIntervalDataUtil::stop); - sensitiveFieldsData = SingleValueMap.acquireData("sensitiveFields", - () -> LoadIntervalDataUtil.newInstance(() -> getEncryptFields(propertiesConfig.get("projection.encrypt.schema.registry.uri")), - LoadIntervalDataOptions.defaults("sensitiveFields", 60000)), LoadIntervalDataUtil::stop); - KmsKey kmsKey = kmsKeyData.getData().data(); - if (encryptionAlgorithm.getSecretKeyLength() == kmsKey.getKeyData().length) { + if (!type.equals(KmsUtils.KMS_TYPE_LOCAL)) { + kmsKeyData = SingleValueMap.acquireData("kmsKeyData", + () -> LoadIntervalDataUtil.newInstance(() -> KmsUtils.getVaultKey(kmsConfig, sslConfig, identifier), + LoadIntervalDataOptions.defaults("kmsKeyData", Integer.parseInt(propertiesConfig.getOrDefault(Constants.SYSPROP_ENCRYPT_KMS_KEY_SCHEDULER_INTERVAL_NAME, "5")) * 60000L)), + LoadIntervalDataUtil::stop); + KmsKey kmsKey = kmsKeyData.getData().data(); + if (kmsKey == null) { + throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "Initialization UDF Encrypt failed!"); + } + if (encryptionAlgorithm.getSecretKeyLength() != kmsKey.getKeyData().length) { + throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "Global parameter kms secret Key requires " + encryptionAlgorithm.getSecretKeyLength() + " bytes!"); + } encryptionAlgorithm.setKmsKey(kmsKey); - } else { - throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "Global parameter kms secret Key requires " + encryptionAlgorithm.getSecretKeyLength() + " bytes!"); + } + sensitiveFieldsData = SingleValueMap.acquireData("sensitiveFields", + () -> LoadIntervalDataUtil.newInstance(() -> getSensitiveFields(propertiesConfig.get("projection.encrypt.schema.registry.uri")), + LoadIntervalDataOptions.defaults("sensitiveFields", Integer.parseInt(propertiesConfig.getOrDefault(Constants.SYSPROP_ENCRYPT_SENSITIVE_FIELDS_SCHEDULER_INTERVAL_NAME, "5")) * 60000L)), + LoadIntervalDataUtil::stop); + if (sensitiveFieldsData.getData().data() == null) { + throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "Initialization UDF Encrypt failed!"); } } catch (Exception e) { throw new GrootStreamRuntimeException(CommonErrorCode.UNSUPPORTED_OPERATION, "Initialization UDF Encrypt failed!", e); @@ -82,11 +91,13 @@ public class Encrypt implements ScalarFunction { @Override public Event evaluate(Event event) { try { - KmsKey kmsKey = kmsKeyData.getData().data(); - if (kmsKey.getKeyVersion() != encryptionAlgorithm.getKmsKey().getKeyVersion() || !Arrays.equals(kmsKey.getKeyData(), encryptionAlgorithm.getKmsKey().getKeyData())) { - encryptionAlgorithm.setKmsKey(kmsKey); + if (!type.equals(KmsUtils.KMS_TYPE_LOCAL)) { + KmsKey kmsKey = kmsKeyData.getData().data(); + if (kmsKey.getKeyVersion() != encryptionAlgorithm.getKmsKey().getKeyVersion() || !Arrays.equals(kmsKey.getKeyData(), encryptionAlgorithm.getKmsKey().getKeyData())) { + encryptionAlgorithm.setKmsKey(kmsKey); + } } - if (ArrayUtil.contains(sensitiveFieldsData.getData().data(), lookupFieldName) && event.getExtractedFields().containsKey(lookupFieldName)) { + if (sensitiveFieldsData.getData().data().contains(lookupFieldName) && event.getExtractedFields().containsKey(lookupFieldName)) { String value = (String) event.getExtractedFields().get(lookupFieldName); if (StringUtil.isNotBlank(value)) { String encryptResult = encryptionAlgorithm.encrypt(value); @@ -136,24 +147,18 @@ public class Encrypt implements ScalarFunction { } } - private KmsKey getKmsKey(KmsConfig kmsConfig, SSLConfig sslConfig, String identifier) throws VaultException { - KmsKey kmsKey = null; - if (KmsUtils.KMS_TYPE_VAULT.equals(kmsConfig.getType())) { - kmsKey = KmsUtils.getVaultKey(kmsConfig, sslConfig, identifier); - } else if (KmsUtils.KMS_TYPE_LOCAL.equals(kmsConfig.getType())) { - kmsKey = new KmsKey(kmsConfig.getSecretKey().getBytes(), 1); + public Set<String> getSensitiveFields(String url) throws IOException { + Set<String> sensitiveFieldsSet; + String sensitiveFieldsStr = HttpClientPoolUtil.getInstance().httpGet(URI.create(URLUtil.normalize(url))); + JSONObject sensitiveFieldsJson = JSONUtil.parseObj(sensitiveFieldsStr); + if (sensitiveFieldsJson.getInt("status", 500) == 200) { + JSONArray sensitiveFieldsJsonArr = sensitiveFieldsJson.getJSONArray("data"); + sensitiveFieldsSet = IntStream.range(0, sensitiveFieldsJsonArr.size()) + .mapToObj(sensitiveFieldsJsonArr::getStr) + .collect(Collectors.toSet()); + } else { + throw new GrootStreamRuntimeException(CommonErrorCode.UNSUPPORTED_OPERATION, "Get encrypt fields error! Error message: " + sensitiveFieldsStr); } - return kmsKey; - } - - private String[] getEncryptFields(String url) { - String[] encryptFields = new String[]{"phone_number", "server_ip"}; -// try { -// String s = HttpClientPoolUtil.getInstance().httpGet(URI.create(URLUtil.normalize(url))); -// encryptFields = s.split(","); -// } catch (Exception e) { -// log.error("Get encrypt fields error! " + e.getMessage()); -// } - return encryptFields; + return sensitiveFieldsSet; } } diff --git a/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/AES128GCM96Algorithm.java b/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/AES128GCM96Algorithm.java index 74be5a8..db4369e 100644 --- a/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/AES128GCM96Algorithm.java +++ b/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/AES128GCM96Algorithm.java @@ -15,12 +15,14 @@ public class AES128GCM96Algorithm implements EncryptionAlgorithm { private static final int GCM_TAG_LENGTH = 128; private static final int GCM_96_NONCE_LENGTH = 12; private static final int SECRET_KEY_LENGTH = 16; + private static final byte[] DEFAULT_SECRET_KEY = ".geedgenetworks.".getBytes(); private final Cipher cipher; private KmsKey kmsKey; public AES128GCM96Algorithm() throws Exception { this.cipher = Cipher.getInstance(TRANSFORMATION); + this.kmsKey = new KmsKey(DEFAULT_SECRET_KEY, 1); } @Override @@ -66,13 +68,12 @@ public class AES128GCM96Algorithm implements EncryptionAlgorithm { String decryptedString = ""; try { byte[] nonce = RandomUtil.randomBytes(GCM_96_NONCE_LENGTH); - Cipher cipher = Cipher.getInstance(TRANSFORMATION); - GCMParameterSpec gcmSpec = new GCMParameterSpec(GCM_TAG_LENGTH, nonce); - cipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(kmsKey.getKeyData(), ALGORITHM), gcmSpec); byte[] combined = Base64.getDecoder().decode(content); byte[] encryptedBytes = new byte[combined.length - GCM_96_NONCE_LENGTH]; System.arraycopy(combined, 0, nonce, 0, GCM_96_NONCE_LENGTH); System.arraycopy(combined, GCM_96_NONCE_LENGTH, encryptedBytes, 0, encryptedBytes.length); + GCMParameterSpec gcmSpec = new GCMParameterSpec(GCM_TAG_LENGTH, nonce); + cipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(kmsKey.getKeyData(), ALGORITHM), gcmSpec); byte[] decryptedBytes = cipher.doFinal(encryptedBytes); decryptedString = new String(decryptedBytes); } catch (Exception e) { diff --git a/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/AES256GCM96Algorithm.java b/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/AES256GCM96Algorithm.java index 64d88d9..dec7e01 100644 --- a/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/AES256GCM96Algorithm.java +++ b/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/AES256GCM96Algorithm.java @@ -15,12 +15,14 @@ public class AES256GCM96Algorithm implements EncryptionAlgorithm { private static final int GCM_TAG_LENGTH = 128; private static final int GCM_96_NONCE_LENGTH = 12; private static final int SECRET_KEY_LENGTH = 32; + private static final byte[] DEFAULT_SECRET_KEY = ".........geedgenetworks.........".getBytes(); private final Cipher cipher; private KmsKey kmsKey; public AES256GCM96Algorithm() throws Exception { this.cipher = Cipher.getInstance(TRANSFORMATION); + this.kmsKey = new KmsKey(DEFAULT_SECRET_KEY, 1); } @Override @@ -66,13 +68,12 @@ public class AES256GCM96Algorithm implements EncryptionAlgorithm { String decryptedString = ""; try { byte[] nonce = RandomUtil.randomBytes(GCM_96_NONCE_LENGTH); - Cipher cipher = Cipher.getInstance(TRANSFORMATION); - GCMParameterSpec gcmSpec = new GCMParameterSpec(GCM_TAG_LENGTH, nonce); - cipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(kmsKey.getKeyData(), ALGORITHM), gcmSpec); byte[] combined = Base64.getDecoder().decode(content); byte[] encryptedBytes = new byte[combined.length - GCM_96_NONCE_LENGTH]; System.arraycopy(combined, 0, nonce, 0, GCM_96_NONCE_LENGTH); System.arraycopy(combined, GCM_96_NONCE_LENGTH, encryptedBytes, 0, encryptedBytes.length); + GCMParameterSpec gcmSpec = new GCMParameterSpec(GCM_TAG_LENGTH, nonce); + cipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(kmsKey.getKeyData(), ALGORITHM), gcmSpec); byte[] decryptedBytes = cipher.doFinal(encryptedBytes); decryptedString = new String(decryptedBytes); } catch (Exception e) { diff --git a/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/SM4GCM96Algorithm.java b/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/SM4GCM96Algorithm.java index 3c13820..e13cb40 100644 --- a/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/SM4GCM96Algorithm.java +++ b/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/SM4GCM96Algorithm.java @@ -15,12 +15,14 @@ public class SM4GCM96Algorithm implements EncryptionAlgorithm { private static final int GCM_TAG_LENGTH = 128; private static final int GCM_96_NONCE_LENGTH = 12; private static final int SECRET_KEY_LENGTH = 16; + private static final byte[] DEFAULT_SECRET_KEY = ".geedgenetworks.".getBytes(); private final Cipher cipher; private KmsKey kmsKey; public SM4GCM96Algorithm() throws Exception { this.cipher = Cipher.getInstance(TRANSFORMATION); + this.kmsKey = new KmsKey(DEFAULT_SECRET_KEY, 1); } @Override @@ -66,13 +68,12 @@ public class SM4GCM96Algorithm implements EncryptionAlgorithm { String decryptedString = ""; try { byte[] nonce = RandomUtil.randomBytes(GCM_96_NONCE_LENGTH); - Cipher cipher = Cipher.getInstance(TRANSFORMATION); - GCMParameterSpec gcmSpec = new GCMParameterSpec(GCM_TAG_LENGTH, nonce); - cipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(kmsKey.getKeyData(), ALGORITHM), gcmSpec); byte[] combined = Base64.getDecoder().decode(content); byte[] encryptedBytes = new byte[combined.length - GCM_96_NONCE_LENGTH]; System.arraycopy(combined, 0, nonce, 0, GCM_96_NONCE_LENGTH); System.arraycopy(combined, GCM_96_NONCE_LENGTH, encryptedBytes, 0, encryptedBytes.length); + GCMParameterSpec gcmSpec = new GCMParameterSpec(GCM_TAG_LENGTH, nonce); + cipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(kmsKey.getKeyData(), ALGORITHM), gcmSpec); byte[] decryptedBytes = cipher.doFinal(encryptedBytes); decryptedString = new String(decryptedBytes); } catch (Exception e) { diff --git a/groot-core/src/main/java/com/geedgenetworks/core/utils/EncryptionAlgorithmUtils.java b/groot-core/src/main/java/com/geedgenetworks/core/utils/EncryptionAlgorithmUtils.java index 7041c73..0a0fe33 100644 --- a/groot-core/src/main/java/com/geedgenetworks/core/utils/EncryptionAlgorithmUtils.java +++ b/groot-core/src/main/java/com/geedgenetworks/core/utils/EncryptionAlgorithmUtils.java @@ -15,7 +15,7 @@ public final class EncryptionAlgorithmUtils { public static final String ALGORITHM_AES_256_GCM96_NAME = "aes-256-gcm96"; public static final String ALGORITHM_SM4_GCM96_NAME = "sm4-gcm96"; - public static EncryptionAlgorithm getEncryptionAlgorithm(String identifier) throws Exception { + public static EncryptionAlgorithm createEncryptionAlgorithm(String identifier) throws Exception { switch (identifier) { case ALGORITHM_AES_128_GCM96_NAME: return new AES128GCM96Algorithm(); diff --git a/groot-core/src/main/java/com/geedgenetworks/core/utils/KmsUtils.java b/groot-core/src/main/java/com/geedgenetworks/core/utils/KmsUtils.java index 8e6a345..9519dd5 100644 --- a/groot-core/src/main/java/com/geedgenetworks/core/utils/KmsUtils.java +++ b/groot-core/src/main/java/com/geedgenetworks/core/utils/KmsUtils.java @@ -21,7 +21,7 @@ public class KmsUtils { public static final String KMS_TYPE_LOCAL = "local"; public static final String KMS_TYPE_VAULT = "vault"; - public static KmsKey getVaultKey(KmsConfig kmsConfig, SSLConfig sslConfig, String identifier) throws VaultException { + public static KmsKey getVaultKey(KmsConfig kmsConfig, SSLConfig sslConfig, String identifier) throws Exception { Vault vault = getVaultClient(kmsConfig, sslConfig); String exportKeyPath; if (EncryptionAlgorithmUtils.ALGORITHM_SM4_GCM96_NAME.equals(identifier)) { @@ -34,8 +34,7 @@ public class KmsUtils { JsonObject keys = exportResponse.getDataObject().get("keys").asObject(); return new KmsKey(Base64.getDecoder().decode(StrUtil.trim(keys.get(keys.size() + "").asString(), '"')), keys.size()); } else { - log.error("Get kms key error! code: {} body: {}", exportResponse.getRestResponse().getStatus(), new String(exportResponse.getRestResponse().getBody())); - return null; + throw new RuntimeException("Get vault key error! code: " + exportResponse.getRestResponse().getStatus() + " body: " + new String(exportResponse.getRestResponse().getBody())); } } diff --git a/groot-core/src/test/java/com/geedgenetworks/core/udf/test/simple/EncryptFunctionTest.java b/groot-core/src/test/java/com/geedgenetworks/core/udf/test/simple/EncryptFunctionTest.java new file mode 100644 index 0000000..a83d853 --- /dev/null +++ b/groot-core/src/test/java/com/geedgenetworks/core/udf/test/simple/EncryptFunctionTest.java @@ -0,0 +1,245 @@ +package com.geedgenetworks.core.udf.test.simple; + +import cn.hutool.core.util.RandomUtil; +import com.alibaba.fastjson2.JSON; +import com.geedgenetworks.common.Constants; +import com.geedgenetworks.common.Event; +import com.geedgenetworks.common.config.CommonConfig; +import com.geedgenetworks.common.config.KmsConfig; +import com.geedgenetworks.common.config.SSLConfig; +import com.geedgenetworks.common.exception.GrootStreamRuntimeException; +import com.geedgenetworks.common.udf.UDFContext; +import com.geedgenetworks.core.pojo.KmsKey; +import com.geedgenetworks.core.udf.Encrypt; +import com.geedgenetworks.core.udf.encrypt.EncryptionAlgorithm; +import com.geedgenetworks.core.utils.EncryptionAlgorithmUtils; +import com.geedgenetworks.core.utils.HttpClientPoolUtil; +import com.geedgenetworks.core.utils.KmsUtils; +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.functions.RuntimeContext; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.metrics.MetricGroup; +import org.apache.flink.metrics.SimpleCounter; +import org.apache.flink.runtime.metrics.groups.OperatorMetricGroup; +import org.bouncycastle.jce.provider.BouncyCastleProvider; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.mockito.*; + +import java.io.IOException; +import java.security.Security; +import java.util.*; + +import static org.junit.jupiter.api.Assertions.*; + +public class EncryptFunctionTest { + private static UDFContext udfContext; + private static MockedStatic<HttpClientPoolUtil> httpClientPoolUtilMockedStatic; + private static final String DATA = "13812345678"; + + @BeforeAll + public static void setUp() throws IOException { + Security.addProvider(new BouncyCastleProvider()); + udfContext = new UDFContext(); + udfContext.setLookup_fields(Collections.singletonList("phone_number")); + udfContext.setOutput_fields(Collections.singletonList("phone_number")); + httpClientPoolUtilMockedStatic = mockSensitiveFields(); + } + + @AfterAll + public static void after() { + httpClientPoolUtilMockedStatic.close(); + } + + @Test + public void testEncryptByVault() throws Exception { + String secretKey = RandomUtil.randomString(32); + MockedStatic<KmsUtils> kmsUtilsMockedStatic = Mockito.mockStatic(KmsUtils.class); + Mockito.when(KmsUtils.getVaultKey(ArgumentMatchers.any(), ArgumentMatchers.any(), ArgumentMatchers.any())).thenReturn(new KmsKey(secretKey.getBytes(), 1)); + RuntimeContext runtimeContext = mockVaultRuntimeContext(); + Map<String, Object> map = new HashMap<>(); + map.put("identifier", EncryptionAlgorithmUtils.ALGORITHM_AES_256_GCM96_NAME); + udfContext.setParameters(map); + Encrypt encrypt = new Encrypt(); + encrypt.open(runtimeContext, udfContext); + Event event = new Event(); + Map<String, Object> extractedFields = new HashMap<>(); + extractedFields.put("phone_number", DATA); + event.setExtractedFields(extractedFields); + Event result = encrypt.evaluate(event); + EncryptionAlgorithm encryptionAlgorithm = EncryptionAlgorithmUtils.createEncryptionAlgorithm(EncryptionAlgorithmUtils.ALGORITHM_AES_256_GCM96_NAME); + assertNotNull(encryptionAlgorithm); + encryptionAlgorithm.setKmsKey(new KmsKey(secretKey.getBytes(), 1)); + String encrypted = result.getExtractedFields().get("phone_number").toString(); + assertTrue(encrypted.contains("vault:v1:")); + String decrypted = encryptionAlgorithm.decrypt(encrypted.split(":")[2]); + assertEquals(DATA, decrypted); + encrypt.close(); + kmsUtilsMockedStatic.close(); + } + + @Test + public void testEncryptByLocal() throws Exception { + byte[] secretKey = ".........geedgenetworks.........".getBytes(); + RuntimeContext runtimeContext = mockLocalRuntimeContext(); + Map<String, Object> map = new HashMap<>(); + map.put("identifier", EncryptionAlgorithmUtils.ALGORITHM_AES_256_GCM96_NAME); + udfContext.setParameters(map); + Encrypt encrypt = new Encrypt(); + encrypt.open(runtimeContext, udfContext); + Event event = new Event(); + Map<String, Object> extractedFields = new HashMap<>(); + extractedFields.put("phone_number", DATA); + event.setExtractedFields(extractedFields); + Event result = encrypt.evaluate(event); + EncryptionAlgorithm encryptionAlgorithm = EncryptionAlgorithmUtils.createEncryptionAlgorithm(EncryptionAlgorithmUtils.ALGORITHM_AES_256_GCM96_NAME); + assertNotNull(encryptionAlgorithm); + encryptionAlgorithm.setKmsKey(new KmsKey(secretKey, 1)); + String decrypted = encryptionAlgorithm.decrypt((String) result.getExtractedFields().get("phone_number")); + assertEquals(DATA, decrypted); + encrypt.close(); + } + + @Test + public void testEncryptByIdentifier() { + Map<String, Object> map = new HashMap<>(); + map.put("identifier", EncryptionAlgorithmUtils.ALGORITHM_AES_256_GCM96_NAME); + udfContext.setParameters(map); + Encrypt encrypt1 = new Encrypt(); + assertDoesNotThrow(() -> encrypt1.open(mockLocalRuntimeContext(), udfContext)); + encrypt1.close(); + + Encrypt encrypt2 = new Encrypt(); + map.put("identifier", EncryptionAlgorithmUtils.ALGORITHM_AES_128_GCM96_NAME); + udfContext.setParameters(map); + assertDoesNotThrow(() -> encrypt2.open(mockLocalRuntimeContext(), udfContext)); + encrypt2.close(); + + Encrypt encrypt3 = new Encrypt(); + map.put("identifier", EncryptionAlgorithmUtils.ALGORITHM_SM4_GCM96_NAME); + udfContext.setParameters(map); + assertDoesNotThrow(() -> encrypt3.open(mockLocalRuntimeContext(), udfContext)); + encrypt3.close(); + } + + @Test + public void testEncryptionAlgorithm() throws Exception { + EncryptionAlgorithm encryptionAlgorithm = EncryptionAlgorithmUtils.createEncryptionAlgorithm(EncryptionAlgorithmUtils.ALGORITHM_AES_128_GCM96_NAME); + assertNotNull(encryptionAlgorithm); + encryptionAlgorithm.setKmsKey(new KmsKey("aaaaaaaaaaaaaaaa".getBytes(), 1)); + String encryptData = encryptionAlgorithm.encrypt(DATA); + String decryptData = encryptionAlgorithm.decrypt(encryptData); + assertEquals(DATA, decryptData); + + encryptionAlgorithm = EncryptionAlgorithmUtils.createEncryptionAlgorithm(EncryptionAlgorithmUtils.ALGORITHM_AES_256_GCM96_NAME); + assertNotNull(encryptionAlgorithm); + encryptionAlgorithm.setKmsKey(new KmsKey("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".getBytes(), 1)); + encryptData = encryptionAlgorithm.encrypt(DATA); + decryptData = encryptionAlgorithm.decrypt(encryptData); + assertEquals(DATA, decryptData); + + encryptionAlgorithm = EncryptionAlgorithmUtils.createEncryptionAlgorithm(EncryptionAlgorithmUtils.ALGORITHM_SM4_GCM96_NAME); + assertNotNull(encryptionAlgorithm); + encryptionAlgorithm.setKmsKey(new KmsKey("aaaaaaaaaaaaaaaa".getBytes(), 1)); + encryptData = encryptionAlgorithm.encrypt(DATA); + decryptData = encryptionAlgorithm.decrypt(encryptData); + assertEquals(DATA, decryptData); + + encryptionAlgorithm = EncryptionAlgorithmUtils.createEncryptionAlgorithm("sm4"); + assertNull(encryptionAlgorithm); + } + + @Test + public void testEncryptError() { + RuntimeContext runtimeContext = mockLocalRuntimeContext(); + Encrypt encrypt = new Encrypt(); + udfContext.setParameters(null); + assertThrows(GrootStreamRuntimeException.class, () -> encrypt.open(runtimeContext, udfContext)); + + Map<String, Object> map = new HashMap<>(); + udfContext.setParameters(map); + assertThrows(GrootStreamRuntimeException.class, () -> encrypt.open(runtimeContext, udfContext)); + + map.put("identifier", "aes"); + udfContext.setParameters(map); + assertThrows(GrootStreamRuntimeException.class, () -> encrypt.open(runtimeContext, udfContext)); + } + + static RuntimeContext mockLocalRuntimeContext() { + RuntimeContext runtimeContext = Mockito.mock(RuntimeContext.class); + ExecutionConfig executionConfig = Mockito.mock(ExecutionConfig.class); + Mockito.when(runtimeContext.getExecutionConfig()).thenReturn(executionConfig); + MetricGroup metricGroup = Mockito.mock(OperatorMetricGroup.class); + Mockito.when(runtimeContext.getMetricGroup()).thenReturn(metricGroup); + Mockito.when(metricGroup.addGroup(Mockito.anyString())).thenReturn(metricGroup); + Mockito.when(metricGroup.counter(Mockito.anyString())).thenReturn(new SimpleCounter()); + Configuration configuration = new Configuration(); + CommonConfig commonConfig = new CommonConfig(); + Map<String, KmsConfig> kmsConfigs = new HashMap<>(); + KmsConfig kmsConfig = new KmsConfig(); + kmsConfig.setType(KmsUtils.KMS_TYPE_LOCAL); + kmsConfigs.put(KmsUtils.KMS_TYPE_LOCAL, kmsConfig); + kmsConfig = new KmsConfig(); + kmsConfig.setType(KmsUtils.KMS_TYPE_VAULT); + kmsConfigs.put(KmsUtils.KMS_TYPE_VAULT, kmsConfig); + SSLConfig sslConfig = new SSLConfig(); + sslConfig.setSkipVerification(true); + Map<String, String> propertiesConfig = new HashMap<>(); + propertiesConfig.put("projection.encrypt.schema.registry.uri", "127.0.0.1:9999/v1/schema/session_record?option=encrypt_fields"); + commonConfig.setKmsConfig(kmsConfigs); + commonConfig.setSslConfig(sslConfig); + commonConfig.setPropertiesConfig(propertiesConfig); + configuration.setString(Constants.SYSPROP_GROOTSTREAM_CONFIG, JSON.toJSONString(commonConfig)); + configuration.setString(Constants.SYSPROP_KMS_TYPE_CONFIG, KmsUtils.KMS_TYPE_LOCAL); + Mockito.when(executionConfig.getGlobalJobParameters()).thenReturn(configuration); + return runtimeContext; + } + + static RuntimeContext mockVaultRuntimeContext() { + RuntimeContext runtimeContext = Mockito.mock(RuntimeContext.class); + ExecutionConfig executionConfig = Mockito.mock(ExecutionConfig.class); + Mockito.when(runtimeContext.getExecutionConfig()).thenReturn(executionConfig); + MetricGroup metricGroup = Mockito.mock(OperatorMetricGroup.class); + Mockito.when(runtimeContext.getMetricGroup()).thenReturn(metricGroup); + Mockito.when(metricGroup.addGroup(Mockito.anyString())).thenReturn(metricGroup); + Mockito.when(metricGroup.counter(Mockito.anyString())).thenReturn(new SimpleCounter()); + Configuration configuration = new Configuration(); + CommonConfig commonConfig = new CommonConfig(); + Map<String, KmsConfig> kmsConfigs = new HashMap<>(); + KmsConfig kmsConfig = new KmsConfig(); + kmsConfig.setType(KmsUtils.KMS_TYPE_VAULT); + kmsConfigs.put(KmsUtils.KMS_TYPE_VAULT, kmsConfig); + kmsConfig = new KmsConfig(); + kmsConfig.setType(KmsUtils.KMS_TYPE_LOCAL); + kmsConfigs.put(KmsUtils.KMS_TYPE_LOCAL, kmsConfig); + SSLConfig sslConfig = new SSLConfig(); + sslConfig.setSkipVerification(true); + Map<String, String> propertiesConfig = new HashMap<>(); + propertiesConfig.put("projection.encrypt.schema.registry.uri", "127.0.0.1:9999/v1/schema/session_record?option=encrypt_fields"); + commonConfig.setKmsConfig(kmsConfigs); + commonConfig.setSslConfig(sslConfig); + commonConfig.setPropertiesConfig(propertiesConfig); + configuration.setString(Constants.SYSPROP_GROOTSTREAM_CONFIG, JSON.toJSONString(commonConfig)); + configuration.setString(Constants.SYSPROP_KMS_TYPE_CONFIG, KmsUtils.KMS_TYPE_VAULT); + Mockito.when(executionConfig.getGlobalJobParameters()).thenReturn(configuration); + return runtimeContext; + } + + static MockedStatic<HttpClientPoolUtil> mockSensitiveFields() throws IOException { + String sensitiveFieldsStr = "{\n" + + " \"status\": 200,\n" + + " \"success\": true,\n" + + " \"message\": \"Success\",\n" + + " \"data\": [\n" + + " \"phone_number\",\n" + + " \"server_ip\"\n" + + " ]\n" + + "}"; + HttpClientPoolUtil instance = Mockito.mock(HttpClientPoolUtil.class); + Mockito.when(instance.httpGet(ArgumentMatchers.any())).thenReturn(sensitiveFieldsStr); + MockedStatic<HttpClientPoolUtil> httpClientPoolUtilMockedStatic = Mockito.mockStatic(HttpClientPoolUtil.class); + Mockito.when(HttpClientPoolUtil.getInstance()).thenReturn(instance); + return httpClientPoolUtilMockedStatic; + } +} |
