summaryrefslogtreecommitdiff
path: root/groot-core/src
diff options
context:
space:
mode:
author侯晋川 <[email protected]>2024-10-25 14:17:14 +0800
committer侯晋川 <[email protected]>2024-10-25 14:17:14 +0800
commit7ab2ffecf20dd0a39c9bc63ff4f879bceb3ca704 (patch)
tree94c43d418e47f853cbca6a412711d4d054f5ac10 /groot-core/src
parent505b04ea10f1e3e37410f5ef1b0721e6f23caebb (diff)
[feature][core]新增Encrypt和HMAC函数
Diffstat (limited to 'groot-core/src')
-rw-r--r--groot-core/src/main/java/com/geedgenetworks/core/pojo/KmsKey.java19
-rw-r--r--groot-core/src/main/java/com/geedgenetworks/core/udf/Encrypt.java159
-rw-r--r--groot-core/src/main/java/com/geedgenetworks/core/udf/Hmac.java104
-rw-r--r--groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/AES128GCM96Algorithm.java83
-rw-r--r--groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/AES256GCM96Algorithm.java83
-rw-r--r--groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/SM4GCM96Algorithm.java83
-rw-r--r--groot-core/src/main/java/com/geedgenetworks/core/utils/KmsUtils.java72
-rw-r--r--groot-core/src/test/java/com/geedgenetworks/core/udf/test/simple/HmacFunctionTest.java136
8 files changed, 739 insertions, 0 deletions
diff --git a/groot-core/src/main/java/com/geedgenetworks/core/pojo/KmsKey.java b/groot-core/src/main/java/com/geedgenetworks/core/pojo/KmsKey.java
new file mode 100644
index 0000000..2690254
--- /dev/null
+++ b/groot-core/src/main/java/com/geedgenetworks/core/pojo/KmsKey.java
@@ -0,0 +1,19 @@
+package com.geedgenetworks.core.pojo;
+
+
+import lombok.Data;
+
+@Data
+public class KmsKey {
+
+ private byte[] keyData;
+ private int keyVersion;
+
+ public KmsKey() {
+ }
+
+ public KmsKey(byte[] keyData, int keyVersion) {
+ this.keyData = keyData;
+ this.keyVersion = keyVersion;
+ }
+} \ No newline at end of file
diff --git a/groot-core/src/main/java/com/geedgenetworks/core/udf/Encrypt.java b/groot-core/src/main/java/com/geedgenetworks/core/udf/Encrypt.java
new file mode 100644
index 0000000..cc05397
--- /dev/null
+++ b/groot-core/src/main/java/com/geedgenetworks/core/udf/Encrypt.java
@@ -0,0 +1,159 @@
+package com.geedgenetworks.core.udf;
+
+import cn.hutool.core.util.ArrayUtil;
+import com.alibaba.fastjson2.JSON;
+import com.geedgenetworks.common.Constants;
+import com.geedgenetworks.common.Event;
+import com.geedgenetworks.common.config.CommonConfig;
+import com.geedgenetworks.common.config.KmsConfig;
+import com.geedgenetworks.common.config.SSLConfig;
+import com.geedgenetworks.common.exception.CommonErrorCode;
+import com.geedgenetworks.common.exception.GrootStreamRuntimeException;
+import com.geedgenetworks.common.udf.ScalarFunction;
+import com.geedgenetworks.common.udf.UDFContext;
+import com.geedgenetworks.core.pojo.KmsKey;
+import com.geedgenetworks.core.udf.encrypt.EncryptionAlgorithm;
+import com.geedgenetworks.core.utils.*;
+import com.geedgenetworks.utils.StringUtil;
+import io.github.jopenlibs.vault.VaultException;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.configuration.Configuration;
+
+import java.util.Arrays;
+import java.util.Map;
+
+@Slf4j
+public class Encrypt implements ScalarFunction {
+
+ private String lookupFieldName;
+ private String outputFieldName;
+ private String identifier;
+ private String defaultVal;
+ private String type;
+ private transient SingleValueMap.Data<LoadIntervalDataUtil<String[]>> sensitiveFieldsData;
+ private transient SingleValueMap.Data<LoadIntervalDataUtil<KmsKey>> kmsKeyData;
+ private transient EncryptionAlgorithm encryptionAlgorithm;
+
+ @Override
+ public void open(RuntimeContext runtimeContext, UDFContext udfContext) {
+ checkUdfContext(udfContext);
+ if (udfContext.getParameters().containsKey("default_val")) {
+ this.defaultVal = udfContext.getParameters().get("default_val").toString();
+ }
+ this.lookupFieldName = udfContext.getLookup_fields().get(0);
+ this.outputFieldName = udfContext.getOutput_fields().get(0);
+ this.identifier = udfContext.getParameters().get("identifier").toString();
+
+ Configuration configuration = (Configuration) runtimeContext.getExecutionConfig().getGlobalJobParameters();
+ CommonConfig commonConfig = JSON.parseObject(configuration.toMap().get(Constants.SYSPROP_GROOTSTREAM_CONFIG), CommonConfig.class);
+ Map<String, KmsConfig> kmsConfigs = commonConfig.getKmsConfig();
+ if (kmsConfigs.isEmpty()) {
+ throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "Global parameter kms type is not null!");
+ } else if (kmsConfigs.size() > 1) {
+ throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "Global parameter kms type is repeated!");
+ }
+ KmsConfig kmsConfig = kmsConfigs.values().iterator().next();
+ SSLConfig sslConfig = commonConfig.getSslConfig();
+ Map<String, String> propertiesConfig = commonConfig.getPropertiesConfig();
+ type = kmsConfig.getType();
+ try {
+ encryptionAlgorithm = EncryptionAlgorithmUtils.getEncryptionAlgorithm(identifier);
+ if (encryptionAlgorithm == null) {
+ throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "Parameters identifier is illegal!");
+ }
+ kmsKeyData = SingleValueMap.acquireData("kmsKeyData",
+ () -> LoadIntervalDataUtil.newInstance(() -> getKmsKey(kmsConfig, sslConfig, identifier),
+ LoadIntervalDataOptions.defaults("kmsKeyData", 60000)), LoadIntervalDataUtil::stop);
+ sensitiveFieldsData = SingleValueMap.acquireData("sensitiveFields",
+ () -> LoadIntervalDataUtil.newInstance(() -> getEncryptFields(propertiesConfig.get("projection.encrypt.schema.registry.uri")),
+ LoadIntervalDataOptions.defaults("sensitiveFields", 60000)), LoadIntervalDataUtil::stop);
+ KmsKey kmsKey = kmsKeyData.getData().data();
+ if (encryptionAlgorithm.getSecretKeyLength() == kmsKey.getKeyData().length) {
+ encryptionAlgorithm.setKmsKey(kmsKey);
+ } else {
+ throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "Global parameter kms secret Key requires " + encryptionAlgorithm.getSecretKeyLength() + " bytes!");
+ }
+ } catch (Exception e) {
+ throw new GrootStreamRuntimeException(CommonErrorCode.UNSUPPORTED_OPERATION, "Initialization UDF Encrypt failed!", e);
+ }
+ }
+
+ @Override
+ public Event evaluate(Event event) {
+ try {
+ KmsKey kmsKey = kmsKeyData.getData().data();
+ if (kmsKey.getKeyVersion() != encryptionAlgorithm.getKmsKey().getKeyVersion() || !Arrays.equals(kmsKey.getKeyData(), encryptionAlgorithm.getKmsKey().getKeyData())) {
+ encryptionAlgorithm.setKmsKey(kmsKey);
+ }
+ if (ArrayUtil.contains(sensitiveFieldsData.getData().data(), lookupFieldName) && event.getExtractedFields().containsKey(lookupFieldName)) {
+ String value = (String) event.getExtractedFields().get(lookupFieldName);
+ if (StringUtil.isNotBlank(value)) {
+ String encryptResult = encryptionAlgorithm.encrypt(value);
+ if (StringUtil.isEmpty(encryptResult)) {
+ event.getExtractedFields().put(outputFieldName, StringUtil.isNotBlank(defaultVal) ? defaultVal : value);
+ } else {
+ if (KmsUtils.KMS_TYPE_VAULT.equals(type)) {
+ encryptResult = "vault:v" + encryptionAlgorithm.getKmsKey().getKeyVersion() + ":" + encryptResult;
+ }
+ event.getExtractedFields().put(outputFieldName, encryptResult);
+ }
+ }
+ }
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ return event;
+ }
+
+ @Override
+ public String functionName() {
+ return "ENCRYPT";
+ }
+
+ @Override
+ public void close() {
+ if (sensitiveFieldsData != null) {
+ sensitiveFieldsData.release();
+ }
+ if (kmsKeyData != null) {
+ kmsKeyData.release();
+ }
+ }
+
+ private void checkUdfContext(UDFContext udfContext) {
+ if (udfContext.getParameters() == null) {
+ throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "Missing required parameters");
+ }
+ if (udfContext.getLookup_fields().size() != 1) {
+ throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "The function lookup fields only support 1 value");
+ }
+ if (udfContext.getOutput_fields().size() != 1) {
+ throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "The function output fields only support 1 value");
+ }
+ if (!udfContext.getParameters().containsKey("identifier")) {
+ throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "Parameters must contains identifier");
+ }
+ }
+
+ private KmsKey getKmsKey(KmsConfig kmsConfig, SSLConfig sslConfig, String identifier) throws VaultException {
+ KmsKey kmsKey = null;
+ if (KmsUtils.KMS_TYPE_VAULT.equals(kmsConfig.getType())) {
+ kmsKey = KmsUtils.getVaultKey(kmsConfig, sslConfig, identifier);
+ } else if (KmsUtils.KMS_TYPE_LOCAL.equals(kmsConfig.getType())) {
+ kmsKey = new KmsKey(kmsConfig.getSecretKey().getBytes(), 1);
+ }
+ return kmsKey;
+ }
+
+ private String[] getEncryptFields(String url) {
+ String[] encryptFields = new String[]{"phone_number", "server_ip"};
+// try {
+// String s = HttpClientPoolUtil.getInstance().httpGet(URI.create(URLUtil.normalize(url)));
+// encryptFields = s.split(",");
+// } catch (Exception e) {
+// log.error("Get encrypt fields error! " + e.getMessage());
+// }
+ return encryptFields;
+ }
+}
diff --git a/groot-core/src/main/java/com/geedgenetworks/core/udf/Hmac.java b/groot-core/src/main/java/com/geedgenetworks/core/udf/Hmac.java
new file mode 100644
index 0000000..0d2e1ca
--- /dev/null
+++ b/groot-core/src/main/java/com/geedgenetworks/core/udf/Hmac.java
@@ -0,0 +1,104 @@
+package com.geedgenetworks.core.udf;
+
+import cn.hutool.crypto.digest.HMac;
+import cn.hutool.crypto.digest.HmacAlgorithm;
+import com.geedgenetworks.common.Event;
+import com.geedgenetworks.common.exception.CommonErrorCode;
+import com.geedgenetworks.common.exception.GrootStreamRuntimeException;
+import com.geedgenetworks.common.udf.ScalarFunction;
+import com.geedgenetworks.common.udf.UDFContext;
+import com.geedgenetworks.utils.StringUtil;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.flink.api.common.functions.RuntimeContext;
+
+@Slf4j
+public class Hmac implements ScalarFunction {
+
+ private String lookupFieldName;
+ private String outputFieldName;
+ private String outputFormat;
+ private HMac hMac;
+
+ @Override
+ public void open(RuntimeContext runtimeContext, UDFContext udfContext) {
+ checkUdfContext(udfContext);
+ String secretKey = udfContext.getParameters().get("secret_key").toString();
+ String algorithm = "sha256";
+ if (udfContext.getParameters().containsKey("algorithm")) {
+ algorithm = udfContext.getParameters().get("algorithm").toString();
+ }
+ this.hMac = new HMac(getHmacAlgorithm(algorithm), secretKey.getBytes());
+ this.lookupFieldName = udfContext.getLookup_fields().get(0);
+ this.outputFieldName = udfContext.getOutput_fields().get(0);
+ this.outputFormat = "base64";
+ if (udfContext.getParameters().containsKey("output_format")) {
+ this.outputFormat = udfContext.getParameters().get("output_format").toString();
+ }
+ }
+
+ @Override
+ public Event evaluate(Event event) {
+ String encodeResult = "";
+ String message = (String) event.getExtractedFields().get(lookupFieldName);
+ if (StringUtil.isNotBlank(message)) {
+ switch (outputFormat) {
+ case "hex":
+ encodeResult = hMac.digestHex(message);
+ break;
+ case "base64":
+ encodeResult = hMac.digestBase64(message, false);
+ break;
+ default:
+ encodeResult = hMac.digestBase64(message, false);
+ break;
+ }
+ }
+ event.getExtractedFields().put(outputFieldName, encodeResult);
+ return event;
+ }
+
+ @Override
+ public String functionName() {
+ return "HMAC";
+ }
+
+ @Override
+ public void close() {
+
+ }
+
+ private void checkUdfContext(UDFContext udfContext) {
+ if (udfContext.getParameters() == null || udfContext.getOutput_fields() == null) {
+ throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "Missing required parameters");
+ }
+ if (udfContext.getLookup_fields().size() != 1) {
+ throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "The function lookup fields only support 1 value");
+ }
+ if (udfContext.getOutput_fields().size() != 1) {
+ throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "The function output fields only support 1 value");
+ }
+ if (!udfContext.getParameters().containsKey("secret_key")) {
+ throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "parameters must contains secret_key");
+ }
+ }
+
+ private String getHmacAlgorithm(String algorithm) {
+ if (StringUtil.containsIgnoreCase(algorithm, "sha256")) {
+ return HmacAlgorithm.HmacSHA256.getValue();
+ } else if (StringUtil.containsIgnoreCase(algorithm, "sha1")) {
+ return HmacAlgorithm.HmacSHA1.getValue();
+ } else if (StringUtil.containsIgnoreCase(algorithm, "md5")) {
+ return HmacAlgorithm.HmacMD5.getValue();
+ } else if (StringUtil.containsIgnoreCase(algorithm, "sha384")) {
+ return HmacAlgorithm.HmacSHA384.getValue();
+ } else if (StringUtil.containsIgnoreCase(algorithm, "sha512")) {
+ return HmacAlgorithm.HmacSHA512.getValue();
+ } else if (StringUtil.containsIgnoreCase(algorithm, "sm3")) {
+ return HmacAlgorithm.HmacSM3.getValue();
+ } else if (StringUtil.containsIgnoreCase(algorithm, "sm4")) {
+ return HmacAlgorithm.SM4CMAC.getValue();
+ } else {
+ return HmacAlgorithm.HmacSHA256.getValue();
+ }
+ }
+}
diff --git a/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/AES128GCM96Algorithm.java b/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/AES128GCM96Algorithm.java
new file mode 100644
index 0000000..74be5a8
--- /dev/null
+++ b/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/AES128GCM96Algorithm.java
@@ -0,0 +1,83 @@
+package com.geedgenetworks.core.udf.encrypt;
+
+import cn.hutool.core.util.RandomUtil;
+import com.geedgenetworks.core.pojo.KmsKey;
+
+import javax.crypto.Cipher;
+import javax.crypto.spec.GCMParameterSpec;
+import javax.crypto.spec.SecretKeySpec;
+import java.util.Base64;
+
+public class AES128GCM96Algorithm implements EncryptionAlgorithm {
+ private static final String IDENTIFIER = "aes-128-gcm96";
+ private static final String ALGORITHM = "AES";
+ private static final String TRANSFORMATION = "AES/GCM/NoPadding";
+ private static final int GCM_TAG_LENGTH = 128;
+ private static final int GCM_96_NONCE_LENGTH = 12;
+ private static final int SECRET_KEY_LENGTH = 16;
+
+ private final Cipher cipher;
+ private KmsKey kmsKey;
+
+ public AES128GCM96Algorithm() throws Exception {
+ this.cipher = Cipher.getInstance(TRANSFORMATION);
+ }
+
+ @Override
+ public String getIdentifier() {
+ return IDENTIFIER;
+ }
+
+ @Override
+ public int getSecretKeyLength() {
+ return SECRET_KEY_LENGTH;
+ }
+
+ @Override
+ public KmsKey getKmsKey() {
+ return kmsKey;
+ }
+
+ @Override
+ public void setKmsKey(KmsKey kmsKey) {
+ this.kmsKey = kmsKey;
+ }
+
+ @Override
+ public String encrypt(String content) {
+ String encryptedString = "";
+ try {
+ byte[] nonce = RandomUtil.randomBytes(GCM_96_NONCE_LENGTH);
+ GCMParameterSpec gcmSpec = new GCMParameterSpec(GCM_TAG_LENGTH, nonce);
+ cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(kmsKey.getKeyData(), ALGORITHM), gcmSpec);
+ byte[] encryptedBytes = cipher.doFinal(content.getBytes());
+ byte[] combinedBytes = new byte[GCM_96_NONCE_LENGTH + encryptedBytes.length];
+ System.arraycopy(nonce, 0, combinedBytes, 0, GCM_96_NONCE_LENGTH);
+ System.arraycopy(encryptedBytes, 0, combinedBytes, GCM_96_NONCE_LENGTH, encryptedBytes.length);
+ encryptedString = Base64.getEncoder().encodeToString(combinedBytes);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ return encryptedString;
+ }
+
+ @Override
+ public String decrypt(String content) {
+ String decryptedString = "";
+ try {
+ byte[] nonce = RandomUtil.randomBytes(GCM_96_NONCE_LENGTH);
+ Cipher cipher = Cipher.getInstance(TRANSFORMATION);
+ GCMParameterSpec gcmSpec = new GCMParameterSpec(GCM_TAG_LENGTH, nonce);
+ cipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(kmsKey.getKeyData(), ALGORITHM), gcmSpec);
+ byte[] combined = Base64.getDecoder().decode(content);
+ byte[] encryptedBytes = new byte[combined.length - GCM_96_NONCE_LENGTH];
+ System.arraycopy(combined, 0, nonce, 0, GCM_96_NONCE_LENGTH);
+ System.arraycopy(combined, GCM_96_NONCE_LENGTH, encryptedBytes, 0, encryptedBytes.length);
+ byte[] decryptedBytes = cipher.doFinal(encryptedBytes);
+ decryptedString = new String(decryptedBytes);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ return decryptedString;
+ }
+}
diff --git a/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/AES256GCM96Algorithm.java b/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/AES256GCM96Algorithm.java
new file mode 100644
index 0000000..64d88d9
--- /dev/null
+++ b/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/AES256GCM96Algorithm.java
@@ -0,0 +1,83 @@
+package com.geedgenetworks.core.udf.encrypt;
+
+import cn.hutool.core.util.RandomUtil;
+import com.geedgenetworks.core.pojo.KmsKey;
+
+import javax.crypto.Cipher;
+import javax.crypto.spec.GCMParameterSpec;
+import javax.crypto.spec.SecretKeySpec;
+import java.util.Base64;
+
+public class AES256GCM96Algorithm implements EncryptionAlgorithm {
+ private static final String IDENTIFIER = "aes-256-gcm96";
+ private static final String ALGORITHM = "AES";
+ private static final String TRANSFORMATION = "AES/GCM/NoPadding";
+ private static final int GCM_TAG_LENGTH = 128;
+ private static final int GCM_96_NONCE_LENGTH = 12;
+ private static final int SECRET_KEY_LENGTH = 32;
+
+ private final Cipher cipher;
+ private KmsKey kmsKey;
+
+ public AES256GCM96Algorithm() throws Exception {
+ this.cipher = Cipher.getInstance(TRANSFORMATION);
+ }
+
+ @Override
+ public String getIdentifier() {
+ return IDENTIFIER;
+ }
+
+ @Override
+ public int getSecretKeyLength() {
+ return SECRET_KEY_LENGTH;
+ }
+
+ @Override
+ public KmsKey getKmsKey() {
+ return kmsKey;
+ }
+
+ @Override
+ public void setKmsKey(KmsKey kmsKey) {
+ this.kmsKey = kmsKey;
+ }
+
+ @Override
+ public String encrypt(String content) {
+ String encryptedString = "";
+ try {
+ byte[] nonce = RandomUtil.randomBytes(GCM_96_NONCE_LENGTH);
+ GCMParameterSpec gcmSpec = new GCMParameterSpec(GCM_TAG_LENGTH, nonce);
+ cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(kmsKey.getKeyData(), ALGORITHM), gcmSpec);
+ byte[] encryptedBytes = cipher.doFinal(content.getBytes());
+ byte[] combinedBytes = new byte[GCM_96_NONCE_LENGTH + encryptedBytes.length];
+ System.arraycopy(nonce, 0, combinedBytes, 0, GCM_96_NONCE_LENGTH);
+ System.arraycopy(encryptedBytes, 0, combinedBytes, GCM_96_NONCE_LENGTH, encryptedBytes.length);
+ encryptedString = Base64.getEncoder().encodeToString(combinedBytes);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ return encryptedString;
+ }
+
+ @Override
+ public String decrypt(String content) {
+ String decryptedString = "";
+ try {
+ byte[] nonce = RandomUtil.randomBytes(GCM_96_NONCE_LENGTH);
+ Cipher cipher = Cipher.getInstance(TRANSFORMATION);
+ GCMParameterSpec gcmSpec = new GCMParameterSpec(GCM_TAG_LENGTH, nonce);
+ cipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(kmsKey.getKeyData(), ALGORITHM), gcmSpec);
+ byte[] combined = Base64.getDecoder().decode(content);
+ byte[] encryptedBytes = new byte[combined.length - GCM_96_NONCE_LENGTH];
+ System.arraycopy(combined, 0, nonce, 0, GCM_96_NONCE_LENGTH);
+ System.arraycopy(combined, GCM_96_NONCE_LENGTH, encryptedBytes, 0, encryptedBytes.length);
+ byte[] decryptedBytes = cipher.doFinal(encryptedBytes);
+ decryptedString = new String(decryptedBytes);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ return decryptedString;
+ }
+}
diff --git a/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/SM4GCM96Algorithm.java b/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/SM4GCM96Algorithm.java
new file mode 100644
index 0000000..3c13820
--- /dev/null
+++ b/groot-core/src/main/java/com/geedgenetworks/core/udf/encrypt/SM4GCM96Algorithm.java
@@ -0,0 +1,83 @@
+package com.geedgenetworks.core.udf.encrypt;
+
+import cn.hutool.core.util.RandomUtil;
+import com.geedgenetworks.core.pojo.KmsKey;
+
+import javax.crypto.Cipher;
+import javax.crypto.spec.GCMParameterSpec;
+import javax.crypto.spec.SecretKeySpec;
+import java.util.Base64;
+
+public class SM4GCM96Algorithm implements EncryptionAlgorithm {
+ private static final String IDENTIFIER = "sm4-gcm96";
+ private static final String ALGORITHM = "SM4";
+ private static final String TRANSFORMATION = "SM4/GCM/NoPadding";
+ private static final int GCM_TAG_LENGTH = 128;
+ private static final int GCM_96_NONCE_LENGTH = 12;
+ private static final int SECRET_KEY_LENGTH = 16;
+
+ private final Cipher cipher;
+ private KmsKey kmsKey;
+
+ public SM4GCM96Algorithm() throws Exception {
+ this.cipher = Cipher.getInstance(TRANSFORMATION);
+ }
+
+ @Override
+ public String getIdentifier() {
+ return IDENTIFIER;
+ }
+
+ @Override
+ public int getSecretKeyLength() {
+ return SECRET_KEY_LENGTH;
+ }
+
+ @Override
+ public KmsKey getKmsKey() {
+ return kmsKey;
+ }
+
+ @Override
+ public void setKmsKey(KmsKey kmsKey) {
+ this.kmsKey = kmsKey;
+ }
+
+ @Override
+ public String encrypt(String content) {
+ String encryptedString = "";
+ try {
+ byte[] nonce = RandomUtil.randomBytes(GCM_96_NONCE_LENGTH);
+ GCMParameterSpec gcmSpec = new GCMParameterSpec(GCM_TAG_LENGTH, nonce);
+ cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(kmsKey.getKeyData(), ALGORITHM), gcmSpec);
+ byte[] encryptedBytes = cipher.doFinal(content.getBytes());
+ byte[] combinedBytes = new byte[GCM_96_NONCE_LENGTH + encryptedBytes.length];
+ System.arraycopy(nonce, 0, combinedBytes, 0, GCM_96_NONCE_LENGTH);
+ System.arraycopy(encryptedBytes, 0, combinedBytes, GCM_96_NONCE_LENGTH, encryptedBytes.length);
+ encryptedString = Base64.getEncoder().encodeToString(combinedBytes);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ return encryptedString;
+ }
+
+ @Override
+ public String decrypt(String content) {
+ String decryptedString = "";
+ try {
+ byte[] nonce = RandomUtil.randomBytes(GCM_96_NONCE_LENGTH);
+ Cipher cipher = Cipher.getInstance(TRANSFORMATION);
+ GCMParameterSpec gcmSpec = new GCMParameterSpec(GCM_TAG_LENGTH, nonce);
+ cipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(kmsKey.getKeyData(), ALGORITHM), gcmSpec);
+ byte[] combined = Base64.getDecoder().decode(content);
+ byte[] encryptedBytes = new byte[combined.length - GCM_96_NONCE_LENGTH];
+ System.arraycopy(combined, 0, nonce, 0, GCM_96_NONCE_LENGTH);
+ System.arraycopy(combined, GCM_96_NONCE_LENGTH, encryptedBytes, 0, encryptedBytes.length);
+ byte[] decryptedBytes = cipher.doFinal(encryptedBytes);
+ decryptedString = new String(decryptedBytes);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ return decryptedString;
+ }
+}
diff --git a/groot-core/src/main/java/com/geedgenetworks/core/utils/KmsUtils.java b/groot-core/src/main/java/com/geedgenetworks/core/utils/KmsUtils.java
new file mode 100644
index 0000000..8e6a345
--- /dev/null
+++ b/groot-core/src/main/java/com/geedgenetworks/core/utils/KmsUtils.java
@@ -0,0 +1,72 @@
+package com.geedgenetworks.core.utils;
+
+import cn.hutool.core.util.StrUtil;
+import com.geedgenetworks.common.config.KmsConfig;
+import com.geedgenetworks.common.config.SSLConfig;
+import com.geedgenetworks.core.pojo.KmsKey;
+import io.github.jopenlibs.vault.SslConfig;
+import io.github.jopenlibs.vault.Vault;
+import io.github.jopenlibs.vault.VaultConfig;
+import io.github.jopenlibs.vault.VaultException;
+import io.github.jopenlibs.vault.json.JsonObject;
+import io.github.jopenlibs.vault.response.AuthResponse;
+import io.github.jopenlibs.vault.response.LogicalResponse;
+import lombok.extern.slf4j.Slf4j;
+
+import java.io.File;
+import java.util.Base64;
+
+@Slf4j
+public class KmsUtils {
+ public static final String KMS_TYPE_LOCAL = "local";
+ public static final String KMS_TYPE_VAULT = "vault";
+
+ public static KmsKey getVaultKey(KmsConfig kmsConfig, SSLConfig sslConfig, String identifier) throws VaultException {
+ Vault vault = getVaultClient(kmsConfig, sslConfig);
+ String exportKeyPath;
+ if (EncryptionAlgorithmUtils.ALGORITHM_SM4_GCM96_NAME.equals(identifier)) {
+ exportKeyPath = kmsConfig.getPluginKeyPath() + "/export/encryption-key/" + identifier;
+ } else {
+ exportKeyPath = kmsConfig.getDefaultKeyPath() + "/export/encryption-key/" + identifier;
+ }
+ LogicalResponse exportResponse = vault.logical().read(exportKeyPath);
+ if (exportResponse.getRestResponse().getStatus() == 200) {
+ JsonObject keys = exportResponse.getDataObject().get("keys").asObject();
+ return new KmsKey(Base64.getDecoder().decode(StrUtil.trim(keys.get(keys.size() + "").asString(), '"')), keys.size());
+ } else {
+ log.error("Get kms key error! code: {} body: {}", exportResponse.getRestResponse().getStatus(), new String(exportResponse.getRestResponse().getBody()));
+ return null;
+ }
+ }
+
+ public static Vault getVaultClient(KmsConfig kmsConfig, SSLConfig sslConfig) throws VaultException {
+ String username = kmsConfig.getUsername();
+ String password = kmsConfig.getPassword();
+ String url = kmsConfig.getUrl();
+ boolean skipVerification = true;
+ String caCertificatePath = null;
+ String certificatePath = null;
+ String privateKeyPath = null;
+ if (sslConfig != null) {
+ skipVerification = sslConfig.getSkipVerification();
+ caCertificatePath = sslConfig.getCaCertificatePath();
+ certificatePath = sslConfig.getCertificatePath();
+ privateKeyPath = sslConfig.getPrivateKeyPath();
+ }
+ SslConfig vaultSslConfig = new SslConfig().verify(!skipVerification).build();
+ if (!skipVerification) {
+ vaultSslConfig.pemFile(new File(caCertificatePath))
+ .clientPemFile(new File(certificatePath))
+ .clientKeyPemFile(new File(privateKeyPath))
+ .build();
+ }
+ VaultConfig config = new VaultConfig()
+ .address(url)
+ .engineVersion(1)
+ .sslConfig(vaultSslConfig)
+ .build();
+ AuthResponse authResponse = Vault.create(config).auth().loginByUserPass(username, password);
+ config.token(authResponse.getAuthClientToken());
+ return Vault.create(config);
+ }
+}
diff --git a/groot-core/src/test/java/com/geedgenetworks/core/udf/test/simple/HmacFunctionTest.java b/groot-core/src/test/java/com/geedgenetworks/core/udf/test/simple/HmacFunctionTest.java
new file mode 100644
index 0000000..d2219d8
--- /dev/null
+++ b/groot-core/src/test/java/com/geedgenetworks/core/udf/test/simple/HmacFunctionTest.java
@@ -0,0 +1,136 @@
+package com.geedgenetworks.core.udf.test.simple;
+
+import com.geedgenetworks.common.Event;
+import com.geedgenetworks.common.exception.GrootStreamRuntimeException;
+import com.geedgenetworks.common.udf.UDFContext;
+import com.geedgenetworks.core.udf.Hmac;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+public class HmacFunctionTest {
+
+ private static final String SECRET_KEY = ".geedgenetworks.";
+ private static final String DATA = "13812345678";
+ private static UDFContext udfContext;
+
+ @BeforeAll
+ public static void setUp() {
+ udfContext = new UDFContext();
+ udfContext.setLookup_fields(Collections.singletonList("phone_number"));
+ udfContext.setOutput_fields(Collections.singletonList("phone_number_mac"));
+ }
+
+ @Test
+ public void testHmacAsBase64() {
+ Map<String, Object> map = new HashMap<>();
+ map.put("secret_key", SECRET_KEY);
+ map.put("algorithm", "sha256");
+ map.put("output_format", "base64");
+ udfContext.setParameters(map);
+ Hmac hmac = new Hmac();
+ hmac.open(null, udfContext);
+ Event event = new Event();
+ Map<String, Object> extractedFields = new HashMap<>();
+ extractedFields.put("phone_number", DATA);
+ event.setExtractedFields(extractedFields);
+ Event result1 = hmac.evaluate(event);
+ assertEquals("zaj6UKovIsDahIBeRZ2PmgPIfDEr900F2xWu+iQfFrw=", result1.getExtractedFields().get("phone_number_mac"));
+ }
+
+ @Test
+ public void testHmacAsHex() {
+ Map<String, Object> map = new HashMap<>();
+ map.put("secret_key", SECRET_KEY);
+ map.put("algorithm", "sha256");
+ map.put("output_format", "hex");
+ udfContext.setParameters(map);
+ Hmac hmac = new Hmac();
+ hmac.open(null, udfContext);
+ Event event = new Event();
+ Map<String, Object> extractedFields = new HashMap<>();
+ extractedFields.put("phone_number", DATA);
+ event.setExtractedFields(extractedFields);
+ Event result1 = hmac.evaluate(event);
+ assertEquals("cda8fa50aa2f22c0da84805e459d8f9a03c87c312bf74d05db15aefa241f16bc", result1.getExtractedFields().get("phone_number_mac"));
+ }
+
+ @Test
+ public void testHmacAlgorithm() {
+ Map<String, Object> map = new HashMap<>();
+ map.put("secret_key", SECRET_KEY);
+ map.put("algorithm", "sm4");
+ map.put("output_format", "base64");
+ udfContext.setParameters(map);
+ Hmac hmac = new Hmac();
+ hmac.open(null, udfContext);
+ Event event = new Event();
+ Map<String, Object> extractedFields = new HashMap<>();
+ extractedFields.put("phone_number", DATA);
+ event.setExtractedFields(extractedFields);
+ Event result = hmac.evaluate(event);
+ assertEquals("QX1q4Y7y3quYCDje9BuSjg==", result.getExtractedFields().get("phone_number_mac"));
+
+ map = new HashMap<>();
+ map.put("secret_key", SECRET_KEY);
+ map.put("algorithm", "sha1");
+ map.put("output_format", "base64");
+ udfContext.setParameters(map);
+ hmac = new Hmac();
+ hmac.open(null, udfContext);
+ event.setExtractedFields(extractedFields);
+ result = hmac.evaluate(event);
+ assertEquals("NB1b1TsVZ95/0sE+d/6kdtyUFh0=", result.getExtractedFields().get("phone_number_mac"));
+
+ map = new HashMap<>();
+ map.put("secret_key", SECRET_KEY);
+ map.put("algorithm", "sm3");
+ map.put("output_format", "base64");
+ udfContext.setParameters(map);
+ hmac = new Hmac();
+ hmac.open(null, udfContext);
+ event.setExtractedFields(extractedFields);
+ result = hmac.evaluate(event);
+ assertEquals("BbQNpwLWE3rkaI1WlPBJgYeD14UyL2OwTxiEoTNA3UU=", result.getExtractedFields().get("phone_number_mac"));
+
+ map = new HashMap<>();
+ map.put("secret_key", SECRET_KEY);
+ map.put("algorithm", "md5");
+ map.put("output_format", "base64");
+ udfContext.setParameters(map);
+ hmac = new Hmac();
+ hmac.open(null, udfContext);
+ event.setExtractedFields(extractedFields);
+ result = hmac.evaluate(event);
+ assertEquals("BQZzRqD3ZR/nJsDIOO4dBg==", result.getExtractedFields().get("phone_number_mac"));
+
+ map = new HashMap<>();
+ map.put("secret_key", SECRET_KEY);
+ map.put("algorithm", "sha512");
+ map.put("output_format", "base64");
+ udfContext.setParameters(map);
+ hmac = new Hmac();
+ hmac.open(null, udfContext);
+ event.setExtractedFields(extractedFields);
+ result = hmac.evaluate(event);
+ assertEquals("DWrndzlcqf2qvFTbuDC1gZCGmRhuAUayfsxEqr2ZlpY/QOr9HgGUZNOfytRfA4VT8OZK0BwHwcAg5pgGBvPQ4A==", result.getExtractedFields().get("phone_number_mac"));
+ }
+
+ @Test
+ public void testHmacError() {
+ Map<String, Object> map = new HashMap<>();
+ map.put("secret_key", SECRET_KEY);
+ map.put("algorithm", "sha256");
+ map.put("output_format", "hex");
+ udfContext.setParameters(map);
+ Hmac hmac = new Hmac();
+ udfContext.getParameters().remove("secret_key");
+ assertThrows(GrootStreamRuntimeException.class, () -> hmac.open(null, udfContext));
+ }
+}