summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorwangkuan <[email protected]>2024-11-21 16:47:56 +0800
committerwangkuan <[email protected]>2024-11-21 16:47:56 +0800
commit86b7acc211fe325867303299bfd4cfacc9b66da4 (patch)
treeef9755cc470c894f310c63f4410d27a52aa4742c
parent30c7d561189236529810bd2d16aa246c1a9aa4c4 (diff)
[feature][core]CN-1730 拓展CollectList和CollectSet,增加collect_type配置项用于区分对每个array元素或整个object聚合
-rw-r--r--groot-bootstrap/src/test/java/com/geedgenetworks/bootstrap/main/simple/JobAggTest.java4
-rw-r--r--groot-bootstrap/src/test/java/com/geedgenetworks/bootstrap/main/simple/JobDosTest.java2
-rw-r--r--groot-bootstrap/src/test/resources/grootstream_job_agg_test.yaml8
-rw-r--r--groot-core/src/main/java/com/geedgenetworks/core/udf/udaf/CollectList.java80
-rw-r--r--groot-core/src/main/java/com/geedgenetworks/core/udf/udaf/CollectSet.java74
-rw-r--r--groot-core/src/test/java/com/geedgenetworks/core/udf/test/aggregate/CollectListTest.java139
-rw-r--r--groot-core/src/test/java/com/geedgenetworks/core/udf/test/aggregate/CollectSetTest.java145
7 files changed, 285 insertions, 167 deletions
diff --git a/groot-bootstrap/src/test/java/com/geedgenetworks/bootstrap/main/simple/JobAggTest.java b/groot-bootstrap/src/test/java/com/geedgenetworks/bootstrap/main/simple/JobAggTest.java
index fa9c2dd..5945e51 100644
--- a/groot-bootstrap/src/test/java/com/geedgenetworks/bootstrap/main/simple/JobAggTest.java
+++ b/groot-bootstrap/src/test/java/com/geedgenetworks/bootstrap/main/simple/JobAggTest.java
@@ -38,7 +38,7 @@ public class JobAggTest {
.build());
@Test
- public void testSplitForAgg() {
+ public void testAgg() {
CollectSink.values.clear();
String[] args ={"--target", "test", "-c", ".\\grootstream_job_agg_test.yaml"};
@@ -71,7 +71,9 @@ public class JobAggTest {
Assert.assertEquals("3", CollectSink.values.get(1).getExtractedFields().get("pktsmin").toString());
List<String> list = (List<String>) CollectSink.values.get(1).getExtractedFields().get("client_ip_list");
Set<String> set = (Set<String>) CollectSink.values.get(1).getExtractedFields().get("server_ip_set");
+ Set<String> set2 = (Set<String>) CollectSink.values.get(1).getExtractedFields().get("client_ips_set");
Assert.assertEquals(1, set.size());
+ Assert.assertEquals(3, set2.size());
Assert.assertEquals(2, list.size());
Assert.assertEquals("2", CollectSink.values.get(1).getExtractedFields().get("count").toString());
diff --git a/groot-bootstrap/src/test/java/com/geedgenetworks/bootstrap/main/simple/JobDosTest.java b/groot-bootstrap/src/test/java/com/geedgenetworks/bootstrap/main/simple/JobDosTest.java
index bd4f9d8..8016e48 100644
--- a/groot-bootstrap/src/test/java/com/geedgenetworks/bootstrap/main/simple/JobDosTest.java
+++ b/groot-bootstrap/src/test/java/com/geedgenetworks/bootstrap/main/simple/JobDosTest.java
@@ -36,7 +36,7 @@ public class JobDosTest {
.build());
@Test
- public void testSplit() {
+ public void testDos() {
CollectSink.values.clear();
String[] args ={"--target", "test", "-c", ".\\grootstream_job_dos_test.yaml"};
diff --git a/groot-bootstrap/src/test/resources/grootstream_job_agg_test.yaml b/groot-bootstrap/src/test/resources/grootstream_job_agg_test.yaml
index 36a9ad3..1ccaf3d 100644
--- a/groot-bootstrap/src/test/resources/grootstream_job_agg_test.yaml
+++ b/groot-bootstrap/src/test/resources/grootstream_job_agg_test.yaml
@@ -3,7 +3,7 @@ sources:
type : inline
fields: # [array of object] Field List, if not set, all fields(Map<String, Object>) will be output.
properties: # record 3,4 will be aggreated
- data: '[{"pkts":1,"sessions":1,"log_id": 1, "recv_time":"1724925692000", "client_ip":"192.168.0.2","server_ip":"2600:1015:b002::"},{"pkts":1,"sessions":1,"decoded_as":null,"log_id": 1, "recv_time":"1724925692000", "client_ip":"192.168.0.1","server_ip":"2600:1015:b002::"},{"pkts":2,"sessions":1,"decoded_as":"HTTP","log_id": 2, "recv_time":"1724925692000", "client_ip":"192.168.0.2","server_ip":"2600:1015:b002::"},{"sessions":1,"decoded_as":"DNS","log_id": 2, "recv_time":"1724925692000", "client_ip":"192.168.0.2","pkts":3,"server_ip":"2600:1015:b002::"},{"sessions":1,"decoded_as":"DNS","log_id": 1, "recv_time":"1724936692000", "client_ip":"192.168.0.2","pkts":4,"server_ip":"2600:1015:b002::"},{"sessions":1,"decoded_as":"HTTP","log_id": 1, "recv_time":"1724937692000", "client_ip":"192.168.0.2","pkts":5,"server_ip":"2600:1015:b002::"}]'
+ data: '[{"pkts":1,"sessions":1,"log_id": 1, "recv_time":"1724925692000","client_ips":["192.168.0.2","192.168.0.1"],"client_ip":"192.168.0.2","server_ip":"2600:1015:b002::"},{"pkts":1,"sessions":1,"decoded_as":null,"log_id": 1, "recv_time":"1724925692000","client_ips":["192.168.0.2","192.168.0.1"], "client_ip":"192.168.0.1","server_ip":"2600:1015:b002::"},{"pkts":2,"sessions":1,"decoded_as":"HTTP","log_id": 2, "recv_time":"1724925692000","client_ips":["192.168.0.2","192.168.0.3"], "client_ip":"192.168.0.2","server_ip":"2600:1015:b002::"},{"sessions":1,"decoded_as":"DNS","log_id": 2, "recv_time":"1724925692000","client_ips":["192.168.0.2","192.168.0.1"], "client_ip":"192.168.0.2","pkts":3,"server_ip":"2600:1015:b002::"},{"sessions":1,"decoded_as":"DNS","log_id": 1,"client_ips":["192.168.0.2","192.168.0.3"], "recv_time":"1724936692000", "client_ip":"192.168.0.2","pkts":4,"server_ip":"2600:1015:b002::"},{"sessions":1,"decoded_as":"HTTP","log_id": 1, "recv_time":"1724937692000", "client_ip":"192.168.0.2","pkts":5,"server_ip":"2600:1015:b002::"}]'
interval.per.row: 1s # 可选
repeat.count: 1 # 可选
format: json
@@ -51,7 +51,11 @@ postprocessing_pipelines:
- function: LAST_VALUE
lookup_fields: [ log_id ]
output_fields: [ log_id_last ]
-
+ - function: COLLECT_SET
+ lookup_fields: [ client_ips ]
+ output_fields: [ client_ips_set ]
+ parameters:
+ collect_type: array
application: # [object] Application Configuration
env: # [object] Environment Variables
name: groot-stream-job # [string] Job Name
diff --git a/groot-core/src/main/java/com/geedgenetworks/core/udf/udaf/CollectList.java b/groot-core/src/main/java/com/geedgenetworks/core/udf/udaf/CollectList.java
index b585fdb..bd6b76a 100644
--- a/groot-core/src/main/java/com/geedgenetworks/core/udf/udaf/CollectList.java
+++ b/groot-core/src/main/java/com/geedgenetworks/core/udf/udaf/CollectList.java
@@ -1,14 +1,13 @@
package com.geedgenetworks.core.udf.udaf;
-import com.geedgenetworks.common.config.Accumulator;
-import com.geedgenetworks.common.exception.CommonErrorCode;
-import com.geedgenetworks.common.exception.GrootStreamRuntimeException;
import com.geedgenetworks.api.common.udf.AggregateFunction;
import com.geedgenetworks.api.common.udf.UDFContext;
import com.geedgenetworks.api.event.Event;
+import com.geedgenetworks.common.config.Accumulator;
+import com.geedgenetworks.common.exception.CommonErrorCode;
+import com.geedgenetworks.common.exception.GrootStreamRuntimeException;
-import java.util.ArrayList;
-import java.util.List;
+import java.util.*;
/**
* Collects elements within a group and returns the list of aggregated objects
@@ -17,19 +16,21 @@ public class CollectList implements AggregateFunction {
private String lookupField;
private String outputField;
+ private String collectType;
@Override
public void open(UDFContext udfContext) {
- if (udfContext.getLookupFields() == null) {
- throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "Missing required parameters");
+ // Validate input fields
+ if (udfContext.getLookupFields() == null || udfContext.getLookupFields().isEmpty()) {
+ throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "Missing required lookup field parameter");
}
this.lookupField = udfContext.getLookupFields().get(0);
- if (udfContext.getOutputFields() != null && !udfContext.getOutputFields().isEmpty()) {
- this.outputField = udfContext.getOutputFields().get(0);
- } else {
- outputField = lookupField;
- }
-
+ this.outputField = udfContext.getOutputFields() == null || udfContext.getOutputFields().isEmpty()
+ ? lookupField
+ : udfContext.getOutputFields().get(0);
+ this.collectType = udfContext.getParameters() == null
+ ? "object"
+ : udfContext.getParameters().getOrDefault("collect_type", "object").toString();
}
@Override
@@ -41,18 +42,28 @@ public class CollectList implements AggregateFunction {
@Override
public Accumulator add(Event event, Accumulator acc) {
Object valueObj = event.getExtractedFields().get(lookupField);
- if (valueObj != null) {
- Object object = valueObj;
- List<Object> aggregate = (List<Object>) acc.getMetricsFields().get(outputField);
- aggregate.add(object);
- acc.getMetricsFields().put(outputField, aggregate);
+ if (valueObj == null) {
+ return acc;
+ }
+ if (collectType.equals("array")) {
+ if (valueObj instanceof List<?>) {
+ List<?> valueList = (List<?>) valueObj;
+ List<Object> aggregateList = getOrInitList(acc);
+ aggregateList.addAll(valueList);
+ }
+ } else {
+ getOrInitList(acc).add(valueObj);
}
+
return acc;
}
@Override
- public String functionName() {
- return "COLLECT_LIST";
+ public Accumulator merge(Accumulator firstAcc, Accumulator secondAcc) {
+ Object firstValueObj = firstAcc.getMetricsFields().get(outputField);
+ Object secondValueObj = secondAcc.getMetricsFields().get(outputField);
+ mergeLists(firstAcc, firstValueObj, secondValueObj);
+ return firstAcc;
}
@Override
@@ -61,18 +72,25 @@ public class CollectList implements AggregateFunction {
}
@Override
- public Accumulator merge(Accumulator firstAcc, Accumulator secondAcc) {
- Object firstValueObj = firstAcc.getMetricsFields().get(outputField);
- Object secondValueObj = secondAcc.getMetricsFields().get(outputField);
- if (firstValueObj != null && secondValueObj != null) {
- List<Object> firstValue = (List<Object>) firstValueObj;
- List<Object> secondValue = (List<Object>) secondValueObj;
- firstValue.addAll(secondValue);
- } else if (firstValueObj == null && secondValueObj != null) {
- List<Object> secondValue = (List<Object>) secondValueObj;
- firstAcc.getMetricsFields().put(outputField, secondValue);
+ public String functionName() {
+ return "COLLECT_LIST";
+ }
+
+ private List<Object> getOrInitList(Accumulator acc) {
+ return (List<Object>) acc.getMetricsFields()
+ .computeIfAbsent(outputField, k -> new ArrayList<>());
+ }
+
+
+ @SuppressWarnings("unchecked")
+ private void mergeLists(Accumulator acc, Object firstValueObj, Object secondValueObj) {
+ if (firstValueObj instanceof List<?> && secondValueObj instanceof List<?>) {
+ ((List<Object>) firstValueObj).addAll((List<Object>) secondValueObj);
+ } else if (firstValueObj == null && secondValueObj instanceof List<?>) {
+ acc.getMetricsFields().put(outputField, secondValueObj);
}
- return firstAcc;
}
+
+
}
diff --git a/groot-core/src/main/java/com/geedgenetworks/core/udf/udaf/CollectSet.java b/groot-core/src/main/java/com/geedgenetworks/core/udf/udaf/CollectSet.java
index 34789a7..ddca8d1 100644
--- a/groot-core/src/main/java/com/geedgenetworks/core/udf/udaf/CollectSet.java
+++ b/groot-core/src/main/java/com/geedgenetworks/core/udf/udaf/CollectSet.java
@@ -9,8 +9,7 @@ import com.geedgenetworks.api.common.udf.UDFContext;
import com.geedgenetworks.api.event.Event;
-import java.util.HashSet;
-import java.util.Set;
+import java.util.*;
/**
* Collects elements within a group and returns the list of aggregated objects
@@ -19,19 +18,21 @@ public class CollectSet implements AggregateFunction {
private String lookupField;
private String outputField;
-
+ private String collectType;
@Override
public void open(UDFContext udfContext) {
- if (udfContext.getLookupFields() == null) {
- throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "Missing required parameters");
+ // 校验输入字段
+ if (udfContext.getLookupFields() == null || udfContext.getLookupFields().isEmpty()) {
+ throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "Missing required lookup field parameter");
}
this.lookupField = udfContext.getLookupFields().get(0);
- if (udfContext.getOutputFields() != null && !udfContext.getOutputFields().isEmpty()) {
- this.outputField = udfContext.getOutputFields().get(0);
- } else {
- outputField = lookupField;
- }
+ this.outputField = udfContext.getOutputFields() == null || udfContext.getOutputFields().isEmpty()
+ ? lookupField
+ : udfContext.getOutputFields().get(0);
+ this.collectType = udfContext.getParameters() == null
+ ? "object"
+ : udfContext.getParameters().getOrDefault("collect_type", "object").toString();
}
@Override
@@ -43,18 +44,29 @@ public class CollectSet implements AggregateFunction {
@Override
public Accumulator add(Event event, Accumulator acc) {
Object valueObj = event.getExtractedFields().get(lookupField);
- if (valueObj != null) {
- Object object = valueObj;
- Set<Object> aggregate = (Set<Object>) acc.getMetricsFields().get(outputField);
- aggregate.add(object);
- acc.getMetricsFields().put(outputField, aggregate);
+ if (valueObj == null) {
+ return acc;
}
+
+ if (collectType.equals("array")) {
+ if (valueObj instanceof List<?>) {
+ List<?> valueList = (List<?>) valueObj;
+ Set<Object> aggregateSet = getOrInitSet(acc);
+ aggregateSet.addAll(valueList);
+ }
+ } else {
+ getOrInitSet(acc).add(valueObj);
+ }
+
return acc;
}
@Override
- public String functionName() {
- return "COLLECT_SET";
+ public Accumulator merge(Accumulator firstAcc, Accumulator secondAcc) {
+ Object firstValueObj = firstAcc.getMetricsFields().get(outputField);
+ Object secondValueObj = secondAcc.getMetricsFields().get(outputField);
+ mergeSets(firstAcc, firstValueObj, secondValueObj);
+ return firstAcc;
}
@Override
@@ -63,17 +75,23 @@ public class CollectSet implements AggregateFunction {
}
@Override
- public Accumulator merge(Accumulator firstAcc, Accumulator secondAcc) {
- Object firstValueObj = firstAcc.getMetricsFields().get(outputField);
- Object secondValueObj = secondAcc.getMetricsFields().get(outputField);
- if (firstValueObj != null && secondValueObj != null) {
- Set<Object> firstValue = (Set<Object>) firstValueObj;
- Set<Object> secondValue = (Set<Object>) secondValueObj;
- firstValue.addAll(secondValue);
- } else if (firstValueObj == null && secondValueObj !=null) {
- Set<Object> secondValue = (Set<Object>)secondValueObj;
- firstAcc.getMetricsFields().put(outputField, secondValue);
+ public String functionName() {
+ return "COLLECT_SET";
+ }
+
+ private Set<Object> getOrInitSet(Accumulator acc) {
+ return (Set<Object>) acc.getMetricsFields()
+ .computeIfAbsent(outputField, k -> new HashSet<>());
+ }
+
+ @SuppressWarnings("unchecked")
+ private void mergeSets(Accumulator acc, Object firstValueObj, Object secondValueObj) {
+ if (firstValueObj instanceof Set<?> && secondValueObj instanceof Set<?>) {
+ ((Set<Object>) firstValueObj).addAll((Set<Object>) secondValueObj);
+ } else if (firstValueObj == null && secondValueObj instanceof Set<?>) {
+ acc.getMetricsFields().put(outputField, secondValueObj);
}
- return firstAcc;
}
+
+
}
diff --git a/groot-core/src/test/java/com/geedgenetworks/core/udf/test/aggregate/CollectListTest.java b/groot-core/src/test/java/com/geedgenetworks/core/udf/test/aggregate/CollectListTest.java
index e27ae01..a3a0487 100644
--- a/groot-core/src/test/java/com/geedgenetworks/core/udf/test/aggregate/CollectListTest.java
+++ b/groot-core/src/test/java/com/geedgenetworks/core/udf/test/aggregate/CollectListTest.java
@@ -7,89 +7,118 @@ import com.geedgenetworks.api.common.udf.UDFContext;
import com.geedgenetworks.api.event.Event;
import org.junit.jupiter.api.Test;
-import java.text.ParseException;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
+import java.util.*;
-import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.*;
public class CollectListTest {
@Test
- public void test() throws ParseException {
+ public void testObjectType() {
+ // 测试默认 "object" 类型
+ List<String> arr = List.of("192.168.1.1", "192.168.1.2", "192.168.1.1");
+ UDFContext udfContext = createUDFContext("object");
+ CollectList collectList = new CollectList();
+ collectList.open(udfContext);
- List<String> arr = List.of("192.168.1.1", "192.168.1.2", "192.168.1.3", "192.168.1.4");
- List<String> arr2 = List.of("192.168.1.5", "192.168.1.6", "192.168.1.3", "192.168.1.4");
- testMerge(arr,arr2);
- testGetResult(arr);
+ Accumulator accumulator = initializeAccumulator(collectList, udfContext);
+ for (String ip : arr) {
+ accumulator = addEvent(collectList, udfContext, accumulator, "field", ip);
+ }
+ Accumulator result = collectList.getResult(accumulator);
+ List<String> aggregated = (List<String>) result.getMetricsFields().get("field_list");
+ assertEquals(aggregated.size(), 3);
+ assertEquals("192.168.1.1", aggregated.get(0));
}
- private void testMerge(List<String> arr,List<String> arr2) {
-
- UDFContext udfContext = new UDFContext();
- udfContext.setLookupFields(List.of("field"));
- udfContext.setOutputFields(Collections.singletonList("field_list"));
+ @Test
+ public void testArrayType() {
+ // 测试 "array" 类型
+ List<List<String>> arrays = List.of(
+ List.of("192.168.1.1", "192.168.1.2"),
+ List.of("192.168.1.3", "192.168.1.4")
+ );
+ UDFContext udfContext = createUDFContext("array");
CollectList collectList = new CollectList();
- Map<String, Object> metricsFields = new HashMap<>();
- Accumulator accumulator = new Accumulator();
- accumulator.setMetricsFields(metricsFields);
collectList.open(udfContext);
- Accumulator result1 = getMiddleResult(udfContext,arr);
- Accumulator result2 = getMiddleResult(udfContext,arr2);
- Accumulator result = collectList.getResult(collectList.merge(result1,result2));
- List<String> vals = (List<String>) result.getMetricsFields().get("field_list");
- assertEquals(vals.size(),8);
- assertEquals("192.168.1.6",vals.get(5).toString());
+ Accumulator accumulator = initializeAccumulator(collectList, udfContext);
+ for (List<String> array : arrays) {
+ accumulator = addEvent(collectList, udfContext, accumulator, "field", array);
+ }
+
+ Accumulator result = collectList.getResult(accumulator);
+ List<List<String>> aggregated = (List<List<String>>) result.getMetricsFields().get("field_list");
+ assertEquals(aggregated.size(), 4);
+ assertEquals("192.168.1.1", aggregated.get(0));
}
- private Accumulator getMiddleResult(UDFContext udfContext,List<String> arr) {
+ @Test
+ public void testMerge() {
+ // 测试合并逻辑
+ List<String> arr1 = List.of("192.168.1.1", "192.168.1.2");
+ List<String> arr2 = List.of("192.168.1.3", "192.168.1.4");
+ UDFContext udfContext = createUDFContext("object");
CollectList collectList = new CollectList();
- Map<String, Object> metricsFields = new HashMap<>();
- Accumulator accumulator = new Accumulator();
- accumulator.setMetricsFields(metricsFields);
collectList.open(udfContext);
- Accumulator agg = collectList.initAccumulator(accumulator);
- for (String o : arr) {
- Event event = new Event();
- Map<String, Object> extractedFields = new HashMap<>();
- extractedFields.put("field", o);
- event.setExtractedFields(extractedFields);
- agg = collectList.add(event, agg);
+ Accumulator result1 = createAccumulatorFromList(collectList, udfContext, arr1);
+ Accumulator result2 = createAccumulatorFromList(collectList, udfContext, arr2);
- }
- return agg;
+ Accumulator merged = collectList.merge(result1, result2);
+ List<String> aggregated = (List<String>) merged.getMetricsFields().get("field_list");
+
+ assertEquals(aggregated.size(), 4);
+ assertEquals("192.168.1.4", aggregated.get(3));
}
- private void testGetResult(List<String> arr) throws ParseException {
+ @Test
+ public void testEmptyInput() {
+ // 测试空输入
+ UDFContext udfContext = createUDFContext("object");
+ CollectList collectList = new CollectList();
+ collectList.open(udfContext);
+
+ Accumulator accumulator = initializeAccumulator(collectList, udfContext);
+ Accumulator result = collectList.getResult(accumulator);
+
+ List<String> aggregated = (List<String>) result.getMetricsFields().get("field_list");
+ assertTrue(aggregated.isEmpty());
+ }
+
+
+ private UDFContext createUDFContext(String collectType) {
UDFContext udfContext = new UDFContext();
udfContext.setLookupFields(List.of("field"));
udfContext.setOutputFields(Collections.singletonList("field_list"));
- CollectList collectList = new CollectList();
+ Map<String, Object> parameters = new HashMap<>();
+ parameters.put("collect_type", collectType);
+ udfContext.setParameters(parameters);
+ return udfContext;
+ }
+
+ private Accumulator initializeAccumulator(CollectList collectList, UDFContext udfContext) {
Map<String, Object> metricsFields = new HashMap<>();
Accumulator accumulator = new Accumulator();
accumulator.setMetricsFields(metricsFields);
- collectList.open(udfContext);
- Accumulator agg = collectList.initAccumulator(accumulator);
-
- for (String o : arr) {
- Event event = new Event();
- Map<String, Object> extractedFields = new HashMap<>();
- extractedFields.put("field", o);
- event.setExtractedFields(extractedFields);
- agg = collectList.add(event, agg);
+ return collectList.initAccumulator(accumulator);
+ }
+ private Accumulator createAccumulatorFromList(CollectList collectList, UDFContext udfContext, List<String> values) {
+ Accumulator accumulator = initializeAccumulator(collectList, udfContext);
+ for (String value : values) {
+ accumulator = addEvent(collectList, udfContext, accumulator, "field", value);
}
- Accumulator result = collectList.getResult(agg);
- List<String> vals = (List<String>) result.getMetricsFields().get("field_list");
- assertEquals(vals.size(),4);
+ return accumulator;
}
-
-
-} \ No newline at end of file
+ private Accumulator addEvent(CollectList collectList, UDFContext udfContext, Accumulator accumulator, String field, Object value) {
+ Event event = new Event();
+ Map<String, Object> extractedFields = new HashMap<>();
+ extractedFields.put(field, value);
+ event.setExtractedFields(extractedFields);
+ return collectList.add(event, accumulator);
+ }
+}
diff --git a/groot-core/src/test/java/com/geedgenetworks/core/udf/test/aggregate/CollectSetTest.java b/groot-core/src/test/java/com/geedgenetworks/core/udf/test/aggregate/CollectSetTest.java
index defd6a7..694e482 100644
--- a/groot-core/src/test/java/com/geedgenetworks/core/udf/test/aggregate/CollectSetTest.java
+++ b/groot-core/src/test/java/com/geedgenetworks/core/udf/test/aggregate/CollectSetTest.java
@@ -1,90 +1,137 @@
package com.geedgenetworks.core.udf.test.aggregate;
-
import com.geedgenetworks.common.config.Accumulator;
import com.geedgenetworks.core.udf.udaf.CollectSet;
import com.geedgenetworks.api.common.udf.UDFContext;
import com.geedgenetworks.api.event.Event;
import org.junit.jupiter.api.Test;
-import java.text.ParseException;
import java.util.*;
-import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.*;
public class CollectSetTest {
@Test
- public void test() throws ParseException {
+ public void testObjectType() {
+ // 测试 "object" 类型,验证去重
+ List<String> arr = List.of("192.168.1.1", "192.168.1.2", "192.168.1.1");
+ UDFContext udfContext = createUDFContext("object");
+ CollectSet collectSet = new CollectSet();
+ collectSet.open(udfContext);
- List<String> arr = List.of("192.168.1.1", "192.168.1.2", "192.168.1.3", "192.168.1.4","192.168.1.4");
- List<String> arr2 = List.of("192.168.1.5", "192.168.1.6", "192.168.1.3", "192.168.1.4");
- testMerge(arr,arr2);
- testGetResult(arr);
+ Accumulator accumulator = initializeAccumulator(collectSet, udfContext);
+ for (String ip : arr) {
+ accumulator = addEvent(collectSet, udfContext, accumulator, "field", ip);
+ }
+ Accumulator result = collectSet.getResult(accumulator);
+ Set<String> aggregated = (Set<String>) result.getMetricsFields().get("field_list");
+ assertEquals(aggregated.size(), 2);
+ assertTrue(aggregated.contains("192.168.1.1"));
+ assertTrue(aggregated.contains("192.168.1.2"));
}
- private void testMerge(List<String> arr,List<String> arr2) {
+ @Test
+ public void testEmptyInput() {
+ // 测试空输入
+ UDFContext udfContext = createUDFContext("object");
+ CollectSet collectSet = new CollectSet();
+ collectSet.open(udfContext);
- UDFContext udfContext = new UDFContext();
- udfContext.setLookupFields(List.of("field"));
- udfContext.setOutputFields(Collections.singletonList("field_list"));
+ Accumulator accumulator = initializeAccumulator(collectSet, udfContext);
+ Accumulator result = collectSet.getResult(accumulator);
+
+ Set<String> aggregated = (Set<String>) result.getMetricsFields().get("field_list");
+ assertNotNull(aggregated);
+ assertTrue(aggregated.isEmpty());
+ }
+
+ @Test
+ public void testMerge() {
+ // 测试合并逻辑
+ List<String> arr1 = List.of("192.168.1.1", "192.168.1.2");
+ List<String> arr2 = List.of("192.168.1.3", "192.168.1.1");
+
+ UDFContext udfContext = createUDFContext("object");
CollectSet collectSet = new CollectSet();
- Map<String, Object> metricsFields = new HashMap<>();
- Accumulator accumulator = new Accumulator();
- accumulator.setMetricsFields(metricsFields);
collectSet.open(udfContext);
- Accumulator result1 = getMiddleResult(udfContext,arr);
- Accumulator result2 = getMiddleResult(udfContext,arr2);
- Accumulator result = collectSet.getResult(collectSet.merge(result1,result2));
- Set<String> vals = (Set<String>) result.getMetricsFields().get("field_list");
- assertEquals(vals.size(),6);
+ Accumulator acc1 = createAccumulatorFromList(collectSet, udfContext, arr1);
+ Accumulator acc2 = createAccumulatorFromList(collectSet, udfContext, arr2);
+
+ Accumulator merged = collectSet.merge(acc1, acc2);
+ Set<String> aggregated = (Set<String>) merged.getMetricsFields().get("field_list");
+
+ assertEquals(aggregated.size(), 3);
+ assertTrue(aggregated.containsAll(List.of("192.168.1.1", "192.168.1.2", "192.168.1.3")));
}
- private Accumulator getMiddleResult(UDFContext udfContext,List<String> arr) {
+ @Test
+ public void testNullValues() {
+ // 测试字段为 null 的场景
+ UDFContext udfContext = createUDFContext("object");
+ CollectSet collectSet = new CollectSet();
+ collectSet.open(udfContext);
+
+ Accumulator accumulator = initializeAccumulator(collectSet, udfContext);
+ accumulator = addEvent(collectSet, udfContext, accumulator, "field", null);
+
+ Accumulator result = collectSet.getResult(accumulator);
+ Set<String> aggregated = (Set<String>) result.getMetricsFields().get("field_list");
+
+ assertNotNull(aggregated);
+ assertTrue(aggregated.isEmpty());
+ }
+
+ @Test
+ public void testDuplicateValues() {
+ // 测试重复值是否被正确去重
+ List<String> arr = List.of("192.168.1.1", "192.168.1.1", "192.168.1.1");
+ UDFContext udfContext = createUDFContext("object");
CollectSet collectSet = new CollectSet();
- Map<String, Object> metricsFields = new HashMap<>();
- Accumulator accumulator = new Accumulator();
- accumulator.setMetricsFields(metricsFields);
collectSet.open(udfContext);
- Accumulator agg = collectSet.initAccumulator(accumulator);
- for (String o : arr) {
- Event event = new Event();
- Map<String, Object> extractedFields = new HashMap<>();
- extractedFields.put("field", o);
- event.setExtractedFields(extractedFields);
- agg = collectSet.add(event, agg);
+ Accumulator accumulator = createAccumulatorFromList(collectSet, udfContext, arr);
- }
- return agg;
+ Accumulator result = collectSet.getResult(accumulator);
+ Set<String> aggregated = (Set<String>) result.getMetricsFields().get("field_list");
+
+ assertEquals(aggregated.size(), 1);
+ assertTrue(aggregated.contains("192.168.1.1"));
}
- private static void testGetResult(List<String> arr) throws ParseException {
+ private UDFContext createUDFContext(String collectType) {
UDFContext udfContext = new UDFContext();
udfContext.setLookupFields(List.of("field"));
udfContext.setOutputFields(Collections.singletonList("field_list"));
- CollectSet collectSet = new CollectSet();
+ Map<String, Object> parameters = new HashMap<>();
+ parameters.put("collect_type", collectType);
+ udfContext.setParameters(parameters);
+ return udfContext;
+ }
+
+ private Accumulator initializeAccumulator(CollectSet collectSet, UDFContext udfContext) {
Map<String, Object> metricsFields = new HashMap<>();
Accumulator accumulator = new Accumulator();
accumulator.setMetricsFields(metricsFields);
- collectSet.open(udfContext);
- Accumulator agg = collectSet.initAccumulator(accumulator);
- for (String o : arr) {
- Event event = new Event();
- Map<String, Object> extractedFields = new HashMap<>();
- extractedFields.put("field", o);
- event.setExtractedFields(extractedFields);
- agg = collectSet.add(event, agg);
+ return collectSet.initAccumulator(accumulator);
+ }
+ private Accumulator createAccumulatorFromList(CollectSet collectSet, UDFContext udfContext, List<String> values) {
+ Accumulator accumulator = initializeAccumulator(collectSet, udfContext);
+ for (String value : values) {
+ accumulator = addEvent(collectSet, udfContext, accumulator, "field", value);
}
- Accumulator result = collectSet.getResult(agg);
- Set<String> vals = (Set<String>) result.getMetricsFields().get("field_list");
- assertEquals(vals.size(),4);
+ return accumulator;
}
-
-
-} \ No newline at end of file
+ private Accumulator addEvent(CollectSet collectSet, UDFContext udfContext, Accumulator accumulator, String field, Object value) {
+ Event event = new Event();
+ Map<String, Object> extractedFields = new HashMap<>();
+ extractedFields.put(field, value);
+ event.setExtractedFields(extractedFields);
+ return collectSet.add(event, accumulator);
+ }
+}