diff options
| author | 侯晋川 <[email protected]> | 2024-10-25 14:17:14 +0800 |
|---|---|---|
| committer | 侯晋川 <[email protected]> | 2024-10-25 14:17:14 +0800 |
| commit | 7ab2ffecf20dd0a39c9bc63ff4f879bceb3ca704 (patch) | |
| tree | 94c43d418e47f853cbca6a412711d4d054f5ac10 /groot-core/src | |
| parent | 505b04ea10f1e3e37410f5ef1b0721e6f23caebb (diff) | |
[feature][core]新增Encrypt和HMAC函数
Diffstat (limited to 'groot-core/src')
8 files changed, 739 insertions, 0 deletions
diff --git a/groot-core/src/main/java/com/geedgenetworks/core/pojo/KmsKey.java b/groot-core/src/main/java/com/geedgenetworks/core/pojo/KmsKey.java new file mode 100644 index 0000000..2690254 --- /dev/null +++ b/groot-core/src/main/java/com/geedgenetworks/core/pojo/KmsKey.java @@ -0,0 +1,19 @@ +package com.geedgenetworks.core.pojo; + + +import lombok.Data; + +@Data +public class KmsKey { + + private byte[] keyData; + private int keyVersion; + + public KmsKey() { + } + + public KmsKey(byte[] keyData, int keyVersion) { + this.keyData = keyData; + this.keyVersion = keyVersion; + } +}
\ No newline at end of file diff --git a/groot-core/src/main/java/com/geedgenetworks/core/udf/Encrypt.java b/groot-core/src/main/java/com/geedgenetworks/core/udf/Encrypt.java new file mode 100644 index 0000000..cc05397 --- /dev/null +++ b/groot-core/src/main/java/com/geedgenetworks/core/udf/Encrypt.java @@ -0,0 +1,159 @@ +package com.geedgenetworks.core.udf; + +import cn.hutool.core.util.ArrayUtil; +import com.alibaba.fastjson2.JSON; +import com.geedgenetworks.common.Constants; +import com.geedgenetworks.common.Event; +import com.geedgenetworks.common.config.CommonConfig; +import com.geedgenetworks.common.config.KmsConfig; +import com.geedgenetworks.common.config.SSLConfig; +import com.geedgenetworks.common.exception.CommonErrorCode; +import com.geedgenetworks.common.exception.GrootStreamRuntimeException; +import com.geedgenetworks.common.udf.ScalarFunction; +import com.geedgenetworks.common.udf.UDFContext; +import com.geedgenetworks.core.pojo.KmsKey; +import com.geedgenetworks.core.udf.encrypt.EncryptionAlgorithm; +import com.geedgenetworks.core.utils.*; +import com.geedgenetworks.utils.StringUtil; +import io.github.jopenlibs.vault.VaultException; +import lombok.extern.slf4j.Slf4j; +import org.apache.flink.api.common.functions.RuntimeContext; +import org.apache.flink.configuration.Configuration; + +import java.util.Arrays; +import java.util.Map; + +@Slf4j +public class Encrypt implements ScalarFunction { + + private String lookupFieldName; + private String outputFieldName; + private String identifier; + private String defaultVal; + private String type; + private transient SingleValueMap.Data<LoadIntervalDataUtil<String[]>> sensitiveFieldsData; + private transient SingleValueMap.Data<LoadIntervalDataUtil<KmsKey>> kmsKeyData; + private transient EncryptionAlgorithm encryptionAlgorithm; + + @Override + public void open(RuntimeContext runtimeContext, UDFContext udfContext) { + checkUdfContext(udfContext); + if (udfContext.getParameters().containsKey("default_val")) { + this.defaultVal = udfContext.getParameters().get("default_val").toString(); + } + this.lookupFieldName = udfContext.getLookup_fields().get(0); + this.outputFieldName = udfContext.getOutput_fields().get(0); + this.identifier = udfContext.getParameters().get("identifier").toString(); + + Configuration configuration = (Configuration) runtimeContext.getExecutionConfig().getGlobalJobParameters(); + CommonConfig commonConfig = JSON.parseObject(configuration.toMap().get(Constants.SYSPROP_GROOTSTREAM_CONFIG), CommonConfig.class); + Map<String, KmsConfig> kmsConfigs = commonConfig.getKmsConfig(); + if (kmsConfigs.isEmpty()) { + throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "Global parameter kms type is not null!"); + } else if (kmsConfigs.size() > 1) { + throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "Global parameter kms type is repeated!"); + } + KmsConfig kmsConfig = kmsConfigs.values().iterator().next(); + SSLConfig sslConfig = commonConfig.getSslConfig(); + Map<String, String> propertiesConfig = commonConfig.getPropertiesConfig(); + type = kmsConfig.getType(); + try { + encryptionAlgorithm = EncryptionAlgorithmUtils.getEncryptionAlgorithm(identifier); + if (encryptionAlgorithm == null) { + throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "Parameters identifier is illegal!"); + } + kmsKeyData = SingleValueMap.acquireData("kmsKeyData", + () -> LoadIntervalDataUtil.newInstance(() -> getKmsKey(kmsConfig, sslConfig, identifier), + LoadIntervalDataOptions.defaults("kmsKeyData", 60000)), LoadIntervalDataUtil::stop); + sensitiveFieldsData = SingleValueMap.acquireData("sensitiveFields", + () -> LoadIntervalDataUtil.newInstance(() -> getEncryptFields(propertiesConfig.get("projection.encrypt.schema.registry.uri")), + LoadIntervalDataOptions.defaults("sensitiveFields", 60000)), LoadIntervalDataUtil::stop); + KmsKey kmsKey = kmsKeyData.getData().data(); + if (encryptionAlgorithm.getSecretKeyLength() == kmsKey.getKeyData().length) { + encryptionAlgorithm.setKmsKey(kmsKey); + } else { + throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "Global parameter kms secret Key requires " + encryptionAlgorithm.getSecretKeyLength() + " bytes!"); + } + } catch (Exception e) { + throw new GrootStreamRuntimeException(CommonErrorCode.UNSUPPORTED_OPERATION, "Initialization UDF Encrypt failed!", e); + } + } + + @Override + public Event evaluate(Event event) { + try { + KmsKey kmsKey = kmsKeyData.getData().data(); + if (kmsKey.getKeyVersion() != encryptionAlgorithm.getKmsKey().getKeyVersion() || !Arrays.equals(kmsKey.getKeyData(), encryptionAlgorithm.getKmsKey().getKeyData())) { + encryptionAlgorithm.setKmsKey(kmsKey); + } + if (ArrayUtil.contains(sensitiveFieldsData.getData().data(), lookupFieldName) && event.getExtractedFields().containsKey(lookupFieldName)) { + String value = (String) event.getExtractedFields().get(lookupFieldName); + if (StringUtil.isNotBlank(value)) { + String encryptResult = encryptionAlgorithm.encrypt(value); + if (StringUtil.isEmpty(encryptResult)) { + event.getExtractedFields().put(outputFieldName, StringUtil.isNotBlank(defaultVal) ? defaultVal : value); + } else { + if (KmsUtils.KMS_TYPE_VAULT.equals(type)) { + encryptResult = "vault:v" + encryptionAlgorithm.getKmsKey().getKeyVersion() + ":" + encryptResult; + } + event.getExtractedFields().put(outputFieldName, encryptResult); + } + } + } + } catch (Exception e) { + throw new RuntimeException(e); + } + return event; + } + + @Override + public String functionName() { + return "ENCRYPT"; + } + + @Override + public void close() { + if (sensitiveFieldsData != null) { + sensitiveFieldsData.release(); + } + if (kmsKeyData != null) { + kmsKeyData.release(); + } + } + + private void checkUdfContext(UDFContext udfContext) { + if (udfContext.getParameters() == null) { + throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "Missing required parameters"); + } + if (udfContext.getLookup_fields().size() != 1) { + throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "The function lookup fields only support 1 value"); + } + if (udfContext.getOutput_fields().size() != 1) { + throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "The function output fields only support 1 value"); + } + if (!udfContext.getParameters().containsKey("identifier")) { + throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "Parameters must contains identifier"); + } + } + + private KmsKey getKmsKey(KmsConfig kmsConfig, SSLConfig sslConfig, String identifier) throws VaultException { + KmsKey kmsKey = null; + if (KmsUtils.KMS_TYPE_VAULT.equals(kmsConfig.getType())) { + kmsKey = KmsUtils.getVaultKey(kmsConfig, sslConfig, identifier); + } else if (KmsUtils.KMS_TYPE_LOCAL.equals(kmsConfig.getType())) { + kmsKey = new KmsKey(kmsConfig.getSecretKey().getBytes(), 1); + } + return kmsKey; + } + + private String[] getEncryptFields(String url) { + String[] encryptFields = new String[]{"phone_number", "server_ip"}; +// try { +// String s = HttpClientPoolUtil.getInstance().httpGet(URI.create(URLUtil.normalize(url))); +// encryptFields = s.split(","); +// } catch (Exception e) { +// log.error("Get encrypt fields error! " + e.getMessage()); +// } + return encryptFields; + } +} diff --git a/groot-core/src/main/java/com/geedgenetworks/core/udf/Hmac.java b/groot-core/src/main/java/com/geedgenetworks/core/udf/Hmac.java new file mode 100644 index 0000000..0d2e1ca --- /dev/null +++ b/groot-core/src/main/java/com/geedgenetworks/core/udf/Hmac.java @@ -0,0 +1,104 @@ +package com.geedgenetworks.core.udf; + +import cn.hutool.crypto.digest.HMac; +import cn.hutool.crypto.digest.HmacAlgorithm; +import com.geedgenetworks.common.Event; +import com.geedgenetworks.common.exception.CommonErrorCode; +import com.geedgenetworks.common.exception.GrootStreamRuntimeException; +import com.geedgenetworks.common.udf.ScalarFunction; +import com.geedgenetworks.common.udf.UDFContext; +import com.geedgenetworks.utils.StringUtil; +import lombok.extern.slf4j.Slf4j; +import org.apache.flink.api.common.functions.RuntimeContext; + +@Slf4j +public class Hmac implements ScalarFunction { + + private String lookupFieldName; + private String outputFieldName; + private String outputFormat; + private HMac hMac; + + @Override + public void open(RuntimeContext runtimeContext, UDFContext udfContext) { + checkUdfContext(udfContext); + String secretKey = udfContext.getParameters().get("secret_key").toString(); + String algorithm = "sha256"; + if (udfContext.getParameters().containsKey("algorithm")) { + algorithm = udfContext.getParameters().get("algorithm").toString(); + } + this.hMac = new HMac(getHmacAlgorithm(algorithm), secretKey.getBytes()); + this.lookupFieldName = udfContext.getLookup_fields().get(0); + this.outputFieldName = udfContext.getOutput_fields().get(0); + this.outputFormat = "base64"; + if (udfContext.getParameters().containsKey("output_format")) { + this.outputFormat = udfContext.getParameters().get("output_format").toString(); + } + } + + @Override + public Event evaluate(Event event) { + String encodeResult = ""; + String message = (String) event.getExtractedFields().get(lookupFieldName); + if (StringUtil.isNotBlank(message)) { + switch (outputFormat) { + case "hex": + encodeResult = hMac.digestHex(message); + break; + case "base64": + encodeResult = hMac.digestBase64(message, false); + break; + default: + encodeResult = hMac.digestBase64(message, false); + break; + } + } + event.getExtractedFields().put(outputFieldName, encodeResult); + return event; + } + + @Override + public String functionName() { + return "HMAC"; + } + + @Override + public void close() { + + } + + private void checkUdfContext(UDFContext udfContext) { + if (udfContext.getParameters() == null || udfContext.getOutput_fields() == null) { + throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "Missing required parameters"); + } + if (udfContext.getLookup_fields().size() != 1) { + throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "The function lookup fields only support 1 value"); + } + if (udfContext.getOutput_fields().size() != 1) { + throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "The function output fields only support 1 value"); + } + if (!udfContext.getParameters().containsKey("secret_key")) { + throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "parameters must contains secret_key"); + } + } + + private String getHmacAlgorithm(String algorithm) { + if (StringUtil.containsIgnoreCase(algorithm, "sha256")) { + return HmacAlgorithm.HmacSHA256.getValue(); + } else if (StringUtil.containsIgnoreCase(algorithm, "sha1")) { + return HmacAlgorithm.HmacSHA1.getValue(); + } else if (StringUtil.containsIgnoreCase(algorithm, "md5")) { + return HmacAlgorithm.HmacMD5.getValue(); + } else if (StringUtil.containsIgnoreCase(algorithm, "sha384")) { + return HmacAlgorithm.HmacSHA384.getValue(); + } else if (StringUtil.containsIgnoreCase(algorithm, "sha512")) { + return HmacAlgorithm.HmacSHA512.getValue(); + } else if (StringUtil.containsIgnoreCase(algorithm, "sm3")) { + return HmacAlgorithm.HmacSM3.getValue(); + } else if (StringUtil.containsIgnoreCase(algorithm, "sm4")) { + return HmacAlgorithm.SM4CMAC.getValue(); + } else { + return HmacAlgorithm.HmacSHA256.getValue(); + } + } +} diff --git a/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/AES128GCM96Algorithm.java b/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/AES128GCM96Algorithm.java new file mode 100644 index 0000000..74be5a8 --- /dev/null +++ b/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/AES128GCM96Algorithm.java @@ -0,0 +1,83 @@ +package com.geedgenetworks.core.udf.encrypt; + +import cn.hutool.core.util.RandomUtil; +import com.geedgenetworks.core.pojo.KmsKey; + +import javax.crypto.Cipher; +import javax.crypto.spec.GCMParameterSpec; +import javax.crypto.spec.SecretKeySpec; +import java.util.Base64; + +public class AES128GCM96Algorithm implements EncryptionAlgorithm { + private static final String IDENTIFIER = "aes-128-gcm96"; + private static final String ALGORITHM = "AES"; + private static final String TRANSFORMATION = "AES/GCM/NoPadding"; + private static final int GCM_TAG_LENGTH = 128; + private static final int GCM_96_NONCE_LENGTH = 12; + private static final int SECRET_KEY_LENGTH = 16; + + private final Cipher cipher; + private KmsKey kmsKey; + + public AES128GCM96Algorithm() throws Exception { + this.cipher = Cipher.getInstance(TRANSFORMATION); + } + + @Override + public String getIdentifier() { + return IDENTIFIER; + } + + @Override + public int getSecretKeyLength() { + return SECRET_KEY_LENGTH; + } + + @Override + public KmsKey getKmsKey() { + return kmsKey; + } + + @Override + public void setKmsKey(KmsKey kmsKey) { + this.kmsKey = kmsKey; + } + + @Override + public String encrypt(String content) { + String encryptedString = ""; + try { + byte[] nonce = RandomUtil.randomBytes(GCM_96_NONCE_LENGTH); + GCMParameterSpec gcmSpec = new GCMParameterSpec(GCM_TAG_LENGTH, nonce); + cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(kmsKey.getKeyData(), ALGORITHM), gcmSpec); + byte[] encryptedBytes = cipher.doFinal(content.getBytes()); + byte[] combinedBytes = new byte[GCM_96_NONCE_LENGTH + encryptedBytes.length]; + System.arraycopy(nonce, 0, combinedBytes, 0, GCM_96_NONCE_LENGTH); + System.arraycopy(encryptedBytes, 0, combinedBytes, GCM_96_NONCE_LENGTH, encryptedBytes.length); + encryptedString = Base64.getEncoder().encodeToString(combinedBytes); + } catch (Exception e) { + throw new RuntimeException(e); + } + return encryptedString; + } + + @Override + public String decrypt(String content) { + String decryptedString = ""; + try { + byte[] nonce = RandomUtil.randomBytes(GCM_96_NONCE_LENGTH); + Cipher cipher = Cipher.getInstance(TRANSFORMATION); + GCMParameterSpec gcmSpec = new GCMParameterSpec(GCM_TAG_LENGTH, nonce); + cipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(kmsKey.getKeyData(), ALGORITHM), gcmSpec); + byte[] combined = Base64.getDecoder().decode(content); + byte[] encryptedBytes = new byte[combined.length - GCM_96_NONCE_LENGTH]; + System.arraycopy(combined, 0, nonce, 0, GCM_96_NONCE_LENGTH); + System.arraycopy(combined, GCM_96_NONCE_LENGTH, encryptedBytes, 0, encryptedBytes.length); + byte[] decryptedBytes = cipher.doFinal(encryptedBytes); + decryptedString = new String(decryptedBytes); + } catch (Exception e) { + throw new RuntimeException(e); + } + return decryptedString; + } +} diff --git a/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/AES256GCM96Algorithm.java b/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/AES256GCM96Algorithm.java new file mode 100644 index 0000000..64d88d9 --- /dev/null +++ b/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/AES256GCM96Algorithm.java @@ -0,0 +1,83 @@ +package com.geedgenetworks.core.udf.encrypt; + +import cn.hutool.core.util.RandomUtil; +import com.geedgenetworks.core.pojo.KmsKey; + +import javax.crypto.Cipher; +import javax.crypto.spec.GCMParameterSpec; +import javax.crypto.spec.SecretKeySpec; +import java.util.Base64; + +public class AES256GCM96Algorithm implements EncryptionAlgorithm { + private static final String IDENTIFIER = "aes-256-gcm96"; + private static final String ALGORITHM = "AES"; + private static final String TRANSFORMATION = "AES/GCM/NoPadding"; + private static final int GCM_TAG_LENGTH = 128; + private static final int GCM_96_NONCE_LENGTH = 12; + private static final int SECRET_KEY_LENGTH = 32; + + private final Cipher cipher; + private KmsKey kmsKey; + + public AES256GCM96Algorithm() throws Exception { + this.cipher = Cipher.getInstance(TRANSFORMATION); + } + + @Override + public String getIdentifier() { + return IDENTIFIER; + } + + @Override + public int getSecretKeyLength() { + return SECRET_KEY_LENGTH; + } + + @Override + public KmsKey getKmsKey() { + return kmsKey; + } + + @Override + public void setKmsKey(KmsKey kmsKey) { + this.kmsKey = kmsKey; + } + + @Override + public String encrypt(String content) { + String encryptedString = ""; + try { + byte[] nonce = RandomUtil.randomBytes(GCM_96_NONCE_LENGTH); + GCMParameterSpec gcmSpec = new GCMParameterSpec(GCM_TAG_LENGTH, nonce); + cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(kmsKey.getKeyData(), ALGORITHM), gcmSpec); + byte[] encryptedBytes = cipher.doFinal(content.getBytes()); + byte[] combinedBytes = new byte[GCM_96_NONCE_LENGTH + encryptedBytes.length]; + System.arraycopy(nonce, 0, combinedBytes, 0, GCM_96_NONCE_LENGTH); + System.arraycopy(encryptedBytes, 0, combinedBytes, GCM_96_NONCE_LENGTH, encryptedBytes.length); + encryptedString = Base64.getEncoder().encodeToString(combinedBytes); + } catch (Exception e) { + throw new RuntimeException(e); + } + return encryptedString; + } + + @Override + public String decrypt(String content) { + String decryptedString = ""; + try { + byte[] nonce = RandomUtil.randomBytes(GCM_96_NONCE_LENGTH); + Cipher cipher = Cipher.getInstance(TRANSFORMATION); + GCMParameterSpec gcmSpec = new GCMParameterSpec(GCM_TAG_LENGTH, nonce); + cipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(kmsKey.getKeyData(), ALGORITHM), gcmSpec); + byte[] combined = Base64.getDecoder().decode(content); + byte[] encryptedBytes = new byte[combined.length - GCM_96_NONCE_LENGTH]; + System.arraycopy(combined, 0, nonce, 0, GCM_96_NONCE_LENGTH); + System.arraycopy(combined, GCM_96_NONCE_LENGTH, encryptedBytes, 0, encryptedBytes.length); + byte[] decryptedBytes = cipher.doFinal(encryptedBytes); + decryptedString = new String(decryptedBytes); + } catch (Exception e) { + throw new RuntimeException(e); + } + return decryptedString; + } +} diff --git a/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/SM4GCM96Algorithm.java b/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/SM4GCM96Algorithm.java new file mode 100644 index 0000000..3c13820 --- /dev/null +++ b/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/SM4GCM96Algorithm.java @@ -0,0 +1,83 @@ +package com.geedgenetworks.core.udf.encrypt; + +import cn.hutool.core.util.RandomUtil; +import com.geedgenetworks.core.pojo.KmsKey; + +import javax.crypto.Cipher; +import javax.crypto.spec.GCMParameterSpec; +import javax.crypto.spec.SecretKeySpec; +import java.util.Base64; + +public class SM4GCM96Algorithm implements EncryptionAlgorithm { + private static final String IDENTIFIER = "sm4-gcm96"; + private static final String ALGORITHM = "SM4"; + private static final String TRANSFORMATION = "SM4/GCM/NoPadding"; + private static final int GCM_TAG_LENGTH = 128; + private static final int GCM_96_NONCE_LENGTH = 12; + private static final int SECRET_KEY_LENGTH = 16; + + private final Cipher cipher; + private KmsKey kmsKey; + + public SM4GCM96Algorithm() throws Exception { + this.cipher = Cipher.getInstance(TRANSFORMATION); + } + + @Override + public String getIdentifier() { + return IDENTIFIER; + } + + @Override + public int getSecretKeyLength() { + return SECRET_KEY_LENGTH; + } + + @Override + public KmsKey getKmsKey() { + return kmsKey; + } + + @Override + public void setKmsKey(KmsKey kmsKey) { + this.kmsKey = kmsKey; + } + + @Override + public String encrypt(String content) { + String encryptedString = ""; + try { + byte[] nonce = RandomUtil.randomBytes(GCM_96_NONCE_LENGTH); + GCMParameterSpec gcmSpec = new GCMParameterSpec(GCM_TAG_LENGTH, nonce); + cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(kmsKey.getKeyData(), ALGORITHM), gcmSpec); + byte[] encryptedBytes = cipher.doFinal(content.getBytes()); + byte[] combinedBytes = new byte[GCM_96_NONCE_LENGTH + encryptedBytes.length]; + System.arraycopy(nonce, 0, combinedBytes, 0, GCM_96_NONCE_LENGTH); + System.arraycopy(encryptedBytes, 0, combinedBytes, GCM_96_NONCE_LENGTH, encryptedBytes.length); + encryptedString = Base64.getEncoder().encodeToString(combinedBytes); + } catch (Exception e) { + throw new RuntimeException(e); + } + return encryptedString; + } + + @Override + public String decrypt(String content) { + String decryptedString = ""; + try { + byte[] nonce = RandomUtil.randomBytes(GCM_96_NONCE_LENGTH); + Cipher cipher = Cipher.getInstance(TRANSFORMATION); + GCMParameterSpec gcmSpec = new GCMParameterSpec(GCM_TAG_LENGTH, nonce); + cipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(kmsKey.getKeyData(), ALGORITHM), gcmSpec); + byte[] combined = Base64.getDecoder().decode(content); + byte[] encryptedBytes = new byte[combined.length - GCM_96_NONCE_LENGTH]; + System.arraycopy(combined, 0, nonce, 0, GCM_96_NONCE_LENGTH); + System.arraycopy(combined, GCM_96_NONCE_LENGTH, encryptedBytes, 0, encryptedBytes.length); + byte[] decryptedBytes = cipher.doFinal(encryptedBytes); + decryptedString = new String(decryptedBytes); + } catch (Exception e) { + throw new RuntimeException(e); + } + return decryptedString; + } +} diff --git a/groot-core/src/main/java/com/geedgenetworks/core/utils/KmsUtils.java b/groot-core/src/main/java/com/geedgenetworks/core/utils/KmsUtils.java new file mode 100644 index 0000000..8e6a345 --- /dev/null +++ b/groot-core/src/main/java/com/geedgenetworks/core/utils/KmsUtils.java @@ -0,0 +1,72 @@ +package com.geedgenetworks.core.utils; + +import cn.hutool.core.util.StrUtil; +import com.geedgenetworks.common.config.KmsConfig; +import com.geedgenetworks.common.config.SSLConfig; +import com.geedgenetworks.core.pojo.KmsKey; +import io.github.jopenlibs.vault.SslConfig; +import io.github.jopenlibs.vault.Vault; +import io.github.jopenlibs.vault.VaultConfig; +import io.github.jopenlibs.vault.VaultException; +import io.github.jopenlibs.vault.json.JsonObject; +import io.github.jopenlibs.vault.response.AuthResponse; +import io.github.jopenlibs.vault.response.LogicalResponse; +import lombok.extern.slf4j.Slf4j; + +import java.io.File; +import java.util.Base64; + +@Slf4j +public class KmsUtils { + public static final String KMS_TYPE_LOCAL = "local"; + public static final String KMS_TYPE_VAULT = "vault"; + + public static KmsKey getVaultKey(KmsConfig kmsConfig, SSLConfig sslConfig, String identifier) throws VaultException { + Vault vault = getVaultClient(kmsConfig, sslConfig); + String exportKeyPath; + if (EncryptionAlgorithmUtils.ALGORITHM_SM4_GCM96_NAME.equals(identifier)) { + exportKeyPath = kmsConfig.getPluginKeyPath() + "/export/encryption-key/" + identifier; + } else { + exportKeyPath = kmsConfig.getDefaultKeyPath() + "/export/encryption-key/" + identifier; + } + LogicalResponse exportResponse = vault.logical().read(exportKeyPath); + if (exportResponse.getRestResponse().getStatus() == 200) { + JsonObject keys = exportResponse.getDataObject().get("keys").asObject(); + return new KmsKey(Base64.getDecoder().decode(StrUtil.trim(keys.get(keys.size() + "").asString(), '"')), keys.size()); + } else { + log.error("Get kms key error! code: {} body: {}", exportResponse.getRestResponse().getStatus(), new String(exportResponse.getRestResponse().getBody())); + return null; + } + } + + public static Vault getVaultClient(KmsConfig kmsConfig, SSLConfig sslConfig) throws VaultException { + String username = kmsConfig.getUsername(); + String password = kmsConfig.getPassword(); + String url = kmsConfig.getUrl(); + boolean skipVerification = true; + String caCertificatePath = null; + String certificatePath = null; + String privateKeyPath = null; + if (sslConfig != null) { + skipVerification = sslConfig.getSkipVerification(); + caCertificatePath = sslConfig.getCaCertificatePath(); + certificatePath = sslConfig.getCertificatePath(); + privateKeyPath = sslConfig.getPrivateKeyPath(); + } + SslConfig vaultSslConfig = new SslConfig().verify(!skipVerification).build(); + if (!skipVerification) { + vaultSslConfig.pemFile(new File(caCertificatePath)) + .clientPemFile(new File(certificatePath)) + .clientKeyPemFile(new File(privateKeyPath)) + .build(); + } + VaultConfig config = new VaultConfig() + .address(url) + .engineVersion(1) + .sslConfig(vaultSslConfig) + .build(); + AuthResponse authResponse = Vault.create(config).auth().loginByUserPass(username, password); + config.token(authResponse.getAuthClientToken()); + return Vault.create(config); + } +} diff --git a/groot-core/src/test/java/com/geedgenetworks/core/udf/test/simple/HmacFunctionTest.java b/groot-core/src/test/java/com/geedgenetworks/core/udf/test/simple/HmacFunctionTest.java new file mode 100644 index 0000000..d2219d8 --- /dev/null +++ b/groot-core/src/test/java/com/geedgenetworks/core/udf/test/simple/HmacFunctionTest.java @@ -0,0 +1,136 @@ +package com.geedgenetworks.core.udf.test.simple; + +import com.geedgenetworks.common.Event; +import com.geedgenetworks.common.exception.GrootStreamRuntimeException; +import com.geedgenetworks.common.udf.UDFContext; +import com.geedgenetworks.core.udf.Hmac; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class HmacFunctionTest { + + private static final String SECRET_KEY = ".geedgenetworks."; + private static final String DATA = "13812345678"; + private static UDFContext udfContext; + + @BeforeAll + public static void setUp() { + udfContext = new UDFContext(); + udfContext.setLookup_fields(Collections.singletonList("phone_number")); + udfContext.setOutput_fields(Collections.singletonList("phone_number_mac")); + } + + @Test + public void testHmacAsBase64() { + Map<String, Object> map = new HashMap<>(); + map.put("secret_key", SECRET_KEY); + map.put("algorithm", "sha256"); + map.put("output_format", "base64"); + udfContext.setParameters(map); + Hmac hmac = new Hmac(); + hmac.open(null, udfContext); + Event event = new Event(); + Map<String, Object> extractedFields = new HashMap<>(); + extractedFields.put("phone_number", DATA); + event.setExtractedFields(extractedFields); + Event result1 = hmac.evaluate(event); + assertEquals("zaj6UKovIsDahIBeRZ2PmgPIfDEr900F2xWu+iQfFrw=", result1.getExtractedFields().get("phone_number_mac")); + } + + @Test + public void testHmacAsHex() { + Map<String, Object> map = new HashMap<>(); + map.put("secret_key", SECRET_KEY); + map.put("algorithm", "sha256"); + map.put("output_format", "hex"); + udfContext.setParameters(map); + Hmac hmac = new Hmac(); + hmac.open(null, udfContext); + Event event = new Event(); + Map<String, Object> extractedFields = new HashMap<>(); + extractedFields.put("phone_number", DATA); + event.setExtractedFields(extractedFields); + Event result1 = hmac.evaluate(event); + assertEquals("cda8fa50aa2f22c0da84805e459d8f9a03c87c312bf74d05db15aefa241f16bc", result1.getExtractedFields().get("phone_number_mac")); + } + + @Test + public void testHmacAlgorithm() { + Map<String, Object> map = new HashMap<>(); + map.put("secret_key", SECRET_KEY); + map.put("algorithm", "sm4"); + map.put("output_format", "base64"); + udfContext.setParameters(map); + Hmac hmac = new Hmac(); + hmac.open(null, udfContext); + Event event = new Event(); + Map<String, Object> extractedFields = new HashMap<>(); + extractedFields.put("phone_number", DATA); + event.setExtractedFields(extractedFields); + Event result = hmac.evaluate(event); + assertEquals("QX1q4Y7y3quYCDje9BuSjg==", result.getExtractedFields().get("phone_number_mac")); + + map = new HashMap<>(); + map.put("secret_key", SECRET_KEY); + map.put("algorithm", "sha1"); + map.put("output_format", "base64"); + udfContext.setParameters(map); + hmac = new Hmac(); + hmac.open(null, udfContext); + event.setExtractedFields(extractedFields); + result = hmac.evaluate(event); + assertEquals("NB1b1TsVZ95/0sE+d/6kdtyUFh0=", result.getExtractedFields().get("phone_number_mac")); + + map = new HashMap<>(); + map.put("secret_key", SECRET_KEY); + map.put("algorithm", "sm3"); + map.put("output_format", "base64"); + udfContext.setParameters(map); + hmac = new Hmac(); + hmac.open(null, udfContext); + event.setExtractedFields(extractedFields); + result = hmac.evaluate(event); + assertEquals("BbQNpwLWE3rkaI1WlPBJgYeD14UyL2OwTxiEoTNA3UU=", result.getExtractedFields().get("phone_number_mac")); + + map = new HashMap<>(); + map.put("secret_key", SECRET_KEY); + map.put("algorithm", "md5"); + map.put("output_format", "base64"); + udfContext.setParameters(map); + hmac = new Hmac(); + hmac.open(null, udfContext); + event.setExtractedFields(extractedFields); + result = hmac.evaluate(event); + assertEquals("BQZzRqD3ZR/nJsDIOO4dBg==", result.getExtractedFields().get("phone_number_mac")); + + map = new HashMap<>(); + map.put("secret_key", SECRET_KEY); + map.put("algorithm", "sha512"); + map.put("output_format", "base64"); + udfContext.setParameters(map); + hmac = new Hmac(); + hmac.open(null, udfContext); + event.setExtractedFields(extractedFields); + result = hmac.evaluate(event); + assertEquals("DWrndzlcqf2qvFTbuDC1gZCGmRhuAUayfsxEqr2ZlpY/QOr9HgGUZNOfytRfA4VT8OZK0BwHwcAg5pgGBvPQ4A==", result.getExtractedFields().get("phone_number_mac")); + } + + @Test + public void testHmacError() { + Map<String, Object> map = new HashMap<>(); + map.put("secret_key", SECRET_KEY); + map.put("algorithm", "sha256"); + map.put("output_format", "hex"); + udfContext.setParameters(map); + Hmac hmac = new Hmac(); + udfContext.getParameters().remove("secret_key"); + assertThrows(GrootStreamRuntimeException.class, () -> hmac.open(null, udfContext)); + } +} |
