diff options
| author | wangkuan <[email protected]> | 2024-11-21 16:47:56 +0800 |
|---|---|---|
| committer | wangkuan <[email protected]> | 2024-11-21 16:47:56 +0800 |
| commit | 86b7acc211fe325867303299bfd4cfacc9b66da4 (patch) | |
| tree | ef9755cc470c894f310c63f4410d27a52aa4742c | |
| parent | 30c7d561189236529810bd2d16aa246c1a9aa4c4 (diff) | |
[feature][core]CN-1730 拓展CollectList和CollectSet,增加collect_type配置项用于区分对每个array元素或整个object聚合
7 files changed, 285 insertions, 167 deletions
diff --git a/groot-bootstrap/src/test/java/com/geedgenetworks/bootstrap/main/simple/JobAggTest.java b/groot-bootstrap/src/test/java/com/geedgenetworks/bootstrap/main/simple/JobAggTest.java index fa9c2dd..5945e51 100644 --- a/groot-bootstrap/src/test/java/com/geedgenetworks/bootstrap/main/simple/JobAggTest.java +++ b/groot-bootstrap/src/test/java/com/geedgenetworks/bootstrap/main/simple/JobAggTest.java @@ -38,7 +38,7 @@ public class JobAggTest { .build()); @Test - public void testSplitForAgg() { + public void testAgg() { CollectSink.values.clear(); String[] args ={"--target", "test", "-c", ".\\grootstream_job_agg_test.yaml"}; @@ -71,7 +71,9 @@ public class JobAggTest { Assert.assertEquals("3", CollectSink.values.get(1).getExtractedFields().get("pktsmin").toString()); List<String> list = (List<String>) CollectSink.values.get(1).getExtractedFields().get("client_ip_list"); Set<String> set = (Set<String>) CollectSink.values.get(1).getExtractedFields().get("server_ip_set"); + Set<String> set2 = (Set<String>) CollectSink.values.get(1).getExtractedFields().get("client_ips_set"); Assert.assertEquals(1, set.size()); + Assert.assertEquals(3, set2.size()); Assert.assertEquals(2, list.size()); Assert.assertEquals("2", CollectSink.values.get(1).getExtractedFields().get("count").toString()); diff --git a/groot-bootstrap/src/test/java/com/geedgenetworks/bootstrap/main/simple/JobDosTest.java b/groot-bootstrap/src/test/java/com/geedgenetworks/bootstrap/main/simple/JobDosTest.java index bd4f9d8..8016e48 100644 --- a/groot-bootstrap/src/test/java/com/geedgenetworks/bootstrap/main/simple/JobDosTest.java +++ b/groot-bootstrap/src/test/java/com/geedgenetworks/bootstrap/main/simple/JobDosTest.java @@ -36,7 +36,7 @@ public class JobDosTest { .build()); @Test - public void testSplit() { + public void testDos() { CollectSink.values.clear(); String[] args ={"--target", "test", "-c", ".\\grootstream_job_dos_test.yaml"}; diff --git a/groot-bootstrap/src/test/resources/grootstream_job_agg_test.yaml b/groot-bootstrap/src/test/resources/grootstream_job_agg_test.yaml index 36a9ad3..1ccaf3d 100644 --- a/groot-bootstrap/src/test/resources/grootstream_job_agg_test.yaml +++ b/groot-bootstrap/src/test/resources/grootstream_job_agg_test.yaml @@ -3,7 +3,7 @@ sources: type : inline fields: # [array of object] Field List, if not set, all fields(Map<String, Object>) will be output. properties: # record 3,4 will be aggreated - data: '[{"pkts":1,"sessions":1,"log_id": 1, "recv_time":"1724925692000", "client_ip":"192.168.0.2","server_ip":"2600:1015:b002::"},{"pkts":1,"sessions":1,"decoded_as":null,"log_id": 1, "recv_time":"1724925692000", "client_ip":"192.168.0.1","server_ip":"2600:1015:b002::"},{"pkts":2,"sessions":1,"decoded_as":"HTTP","log_id": 2, "recv_time":"1724925692000", "client_ip":"192.168.0.2","server_ip":"2600:1015:b002::"},{"sessions":1,"decoded_as":"DNS","log_id": 2, "recv_time":"1724925692000", "client_ip":"192.168.0.2","pkts":3,"server_ip":"2600:1015:b002::"},{"sessions":1,"decoded_as":"DNS","log_id": 1, "recv_time":"1724936692000", "client_ip":"192.168.0.2","pkts":4,"server_ip":"2600:1015:b002::"},{"sessions":1,"decoded_as":"HTTP","log_id": 1, "recv_time":"1724937692000", "client_ip":"192.168.0.2","pkts":5,"server_ip":"2600:1015:b002::"}]' + data: '[{"pkts":1,"sessions":1,"log_id": 1, "recv_time":"1724925692000","client_ips":["192.168.0.2","192.168.0.1"],"client_ip":"192.168.0.2","server_ip":"2600:1015:b002::"},{"pkts":1,"sessions":1,"decoded_as":null,"log_id": 1, "recv_time":"1724925692000","client_ips":["192.168.0.2","192.168.0.1"], "client_ip":"192.168.0.1","server_ip":"2600:1015:b002::"},{"pkts":2,"sessions":1,"decoded_as":"HTTP","log_id": 2, "recv_time":"1724925692000","client_ips":["192.168.0.2","192.168.0.3"], "client_ip":"192.168.0.2","server_ip":"2600:1015:b002::"},{"sessions":1,"decoded_as":"DNS","log_id": 2, "recv_time":"1724925692000","client_ips":["192.168.0.2","192.168.0.1"], "client_ip":"192.168.0.2","pkts":3,"server_ip":"2600:1015:b002::"},{"sessions":1,"decoded_as":"DNS","log_id": 1,"client_ips":["192.168.0.2","192.168.0.3"], "recv_time":"1724936692000", "client_ip":"192.168.0.2","pkts":4,"server_ip":"2600:1015:b002::"},{"sessions":1,"decoded_as":"HTTP","log_id": 1, "recv_time":"1724937692000", "client_ip":"192.168.0.2","pkts":5,"server_ip":"2600:1015:b002::"}]' interval.per.row: 1s # 可选 repeat.count: 1 # 可选 format: json @@ -51,7 +51,11 @@ postprocessing_pipelines: - function: LAST_VALUE lookup_fields: [ log_id ] output_fields: [ log_id_last ] - + - function: COLLECT_SET + lookup_fields: [ client_ips ] + output_fields: [ client_ips_set ] + parameters: + collect_type: array application: # [object] Application Configuration env: # [object] Environment Variables name: groot-stream-job # [string] Job Name diff --git a/groot-core/src/main/java/com/geedgenetworks/core/udf/udaf/CollectList.java b/groot-core/src/main/java/com/geedgenetworks/core/udf/udaf/CollectList.java index b585fdb..bd6b76a 100644 --- a/groot-core/src/main/java/com/geedgenetworks/core/udf/udaf/CollectList.java +++ b/groot-core/src/main/java/com/geedgenetworks/core/udf/udaf/CollectList.java @@ -1,14 +1,13 @@ package com.geedgenetworks.core.udf.udaf; -import com.geedgenetworks.common.config.Accumulator; -import com.geedgenetworks.common.exception.CommonErrorCode; -import com.geedgenetworks.common.exception.GrootStreamRuntimeException; import com.geedgenetworks.api.common.udf.AggregateFunction; import com.geedgenetworks.api.common.udf.UDFContext; import com.geedgenetworks.api.event.Event; +import com.geedgenetworks.common.config.Accumulator; +import com.geedgenetworks.common.exception.CommonErrorCode; +import com.geedgenetworks.common.exception.GrootStreamRuntimeException; -import java.util.ArrayList; -import java.util.List; +import java.util.*; /** * Collects elements within a group and returns the list of aggregated objects @@ -17,19 +16,21 @@ public class CollectList implements AggregateFunction { private String lookupField; private String outputField; + private String collectType; @Override public void open(UDFContext udfContext) { - if (udfContext.getLookupFields() == null) { - throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "Missing required parameters"); + // Validate input fields + if (udfContext.getLookupFields() == null || udfContext.getLookupFields().isEmpty()) { + throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "Missing required lookup field parameter"); } this.lookupField = udfContext.getLookupFields().get(0); - if (udfContext.getOutputFields() != null && !udfContext.getOutputFields().isEmpty()) { - this.outputField = udfContext.getOutputFields().get(0); - } else { - outputField = lookupField; - } - + this.outputField = udfContext.getOutputFields() == null || udfContext.getOutputFields().isEmpty() + ? lookupField + : udfContext.getOutputFields().get(0); + this.collectType = udfContext.getParameters() == null + ? "object" + : udfContext.getParameters().getOrDefault("collect_type", "object").toString(); } @Override @@ -41,18 +42,28 @@ public class CollectList implements AggregateFunction { @Override public Accumulator add(Event event, Accumulator acc) { Object valueObj = event.getExtractedFields().get(lookupField); - if (valueObj != null) { - Object object = valueObj; - List<Object> aggregate = (List<Object>) acc.getMetricsFields().get(outputField); - aggregate.add(object); - acc.getMetricsFields().put(outputField, aggregate); + if (valueObj == null) { + return acc; + } + if (collectType.equals("array")) { + if (valueObj instanceof List<?>) { + List<?> valueList = (List<?>) valueObj; + List<Object> aggregateList = getOrInitList(acc); + aggregateList.addAll(valueList); + } + } else { + getOrInitList(acc).add(valueObj); } + return acc; } @Override - public String functionName() { - return "COLLECT_LIST"; + public Accumulator merge(Accumulator firstAcc, Accumulator secondAcc) { + Object firstValueObj = firstAcc.getMetricsFields().get(outputField); + Object secondValueObj = secondAcc.getMetricsFields().get(outputField); + mergeLists(firstAcc, firstValueObj, secondValueObj); + return firstAcc; } @Override @@ -61,18 +72,25 @@ public class CollectList implements AggregateFunction { } @Override - public Accumulator merge(Accumulator firstAcc, Accumulator secondAcc) { - Object firstValueObj = firstAcc.getMetricsFields().get(outputField); - Object secondValueObj = secondAcc.getMetricsFields().get(outputField); - if (firstValueObj != null && secondValueObj != null) { - List<Object> firstValue = (List<Object>) firstValueObj; - List<Object> secondValue = (List<Object>) secondValueObj; - firstValue.addAll(secondValue); - } else if (firstValueObj == null && secondValueObj != null) { - List<Object> secondValue = (List<Object>) secondValueObj; - firstAcc.getMetricsFields().put(outputField, secondValue); + public String functionName() { + return "COLLECT_LIST"; + } + + private List<Object> getOrInitList(Accumulator acc) { + return (List<Object>) acc.getMetricsFields() + .computeIfAbsent(outputField, k -> new ArrayList<>()); + } + + + @SuppressWarnings("unchecked") + private void mergeLists(Accumulator acc, Object firstValueObj, Object secondValueObj) { + if (firstValueObj instanceof List<?> && secondValueObj instanceof List<?>) { + ((List<Object>) firstValueObj).addAll((List<Object>) secondValueObj); + } else if (firstValueObj == null && secondValueObj instanceof List<?>) { + acc.getMetricsFields().put(outputField, secondValueObj); } - return firstAcc; } + + } diff --git a/groot-core/src/main/java/com/geedgenetworks/core/udf/udaf/CollectSet.java b/groot-core/src/main/java/com/geedgenetworks/core/udf/udaf/CollectSet.java index 34789a7..ddca8d1 100644 --- a/groot-core/src/main/java/com/geedgenetworks/core/udf/udaf/CollectSet.java +++ b/groot-core/src/main/java/com/geedgenetworks/core/udf/udaf/CollectSet.java @@ -9,8 +9,7 @@ import com.geedgenetworks.api.common.udf.UDFContext; import com.geedgenetworks.api.event.Event; -import java.util.HashSet; -import java.util.Set; +import java.util.*; /** * Collects elements within a group and returns the list of aggregated objects @@ -19,19 +18,21 @@ public class CollectSet implements AggregateFunction { private String lookupField; private String outputField; - + private String collectType; @Override public void open(UDFContext udfContext) { - if (udfContext.getLookupFields() == null) { - throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "Missing required parameters"); + // 校验输入字段 + if (udfContext.getLookupFields() == null || udfContext.getLookupFields().isEmpty()) { + throw new GrootStreamRuntimeException(CommonErrorCode.ILLEGAL_ARGUMENT, "Missing required lookup field parameter"); } this.lookupField = udfContext.getLookupFields().get(0); - if (udfContext.getOutputFields() != null && !udfContext.getOutputFields().isEmpty()) { - this.outputField = udfContext.getOutputFields().get(0); - } else { - outputField = lookupField; - } + this.outputField = udfContext.getOutputFields() == null || udfContext.getOutputFields().isEmpty() + ? lookupField + : udfContext.getOutputFields().get(0); + this.collectType = udfContext.getParameters() == null + ? "object" + : udfContext.getParameters().getOrDefault("collect_type", "object").toString(); } @Override @@ -43,18 +44,29 @@ public class CollectSet implements AggregateFunction { @Override public Accumulator add(Event event, Accumulator acc) { Object valueObj = event.getExtractedFields().get(lookupField); - if (valueObj != null) { - Object object = valueObj; - Set<Object> aggregate = (Set<Object>) acc.getMetricsFields().get(outputField); - aggregate.add(object); - acc.getMetricsFields().put(outputField, aggregate); + if (valueObj == null) { + return acc; } + + if (collectType.equals("array")) { + if (valueObj instanceof List<?>) { + List<?> valueList = (List<?>) valueObj; + Set<Object> aggregateSet = getOrInitSet(acc); + aggregateSet.addAll(valueList); + } + } else { + getOrInitSet(acc).add(valueObj); + } + return acc; } @Override - public String functionName() { - return "COLLECT_SET"; + public Accumulator merge(Accumulator firstAcc, Accumulator secondAcc) { + Object firstValueObj = firstAcc.getMetricsFields().get(outputField); + Object secondValueObj = secondAcc.getMetricsFields().get(outputField); + mergeSets(firstAcc, firstValueObj, secondValueObj); + return firstAcc; } @Override @@ -63,17 +75,23 @@ public class CollectSet implements AggregateFunction { } @Override - public Accumulator merge(Accumulator firstAcc, Accumulator secondAcc) { - Object firstValueObj = firstAcc.getMetricsFields().get(outputField); - Object secondValueObj = secondAcc.getMetricsFields().get(outputField); - if (firstValueObj != null && secondValueObj != null) { - Set<Object> firstValue = (Set<Object>) firstValueObj; - Set<Object> secondValue = (Set<Object>) secondValueObj; - firstValue.addAll(secondValue); - } else if (firstValueObj == null && secondValueObj !=null) { - Set<Object> secondValue = (Set<Object>)secondValueObj; - firstAcc.getMetricsFields().put(outputField, secondValue); + public String functionName() { + return "COLLECT_SET"; + } + + private Set<Object> getOrInitSet(Accumulator acc) { + return (Set<Object>) acc.getMetricsFields() + .computeIfAbsent(outputField, k -> new HashSet<>()); + } + + @SuppressWarnings("unchecked") + private void mergeSets(Accumulator acc, Object firstValueObj, Object secondValueObj) { + if (firstValueObj instanceof Set<?> && secondValueObj instanceof Set<?>) { + ((Set<Object>) firstValueObj).addAll((Set<Object>) secondValueObj); + } else if (firstValueObj == null && secondValueObj instanceof Set<?>) { + acc.getMetricsFields().put(outputField, secondValueObj); } - return firstAcc; } + + } diff --git a/groot-core/src/test/java/com/geedgenetworks/core/udf/test/aggregate/CollectListTest.java b/groot-core/src/test/java/com/geedgenetworks/core/udf/test/aggregate/CollectListTest.java index e27ae01..a3a0487 100644 --- a/groot-core/src/test/java/com/geedgenetworks/core/udf/test/aggregate/CollectListTest.java +++ b/groot-core/src/test/java/com/geedgenetworks/core/udf/test/aggregate/CollectListTest.java @@ -7,89 +7,118 @@ import com.geedgenetworks.api.common.udf.UDFContext; import com.geedgenetworks.api.event.Event; import org.junit.jupiter.api.Test; -import java.text.ParseException; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; -import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.*; public class CollectListTest { @Test - public void test() throws ParseException { + public void testObjectType() { + // 测试默认 "object" 类型 + List<String> arr = List.of("192.168.1.1", "192.168.1.2", "192.168.1.1"); + UDFContext udfContext = createUDFContext("object"); + CollectList collectList = new CollectList(); + collectList.open(udfContext); - List<String> arr = List.of("192.168.1.1", "192.168.1.2", "192.168.1.3", "192.168.1.4"); - List<String> arr2 = List.of("192.168.1.5", "192.168.1.6", "192.168.1.3", "192.168.1.4"); - testMerge(arr,arr2); - testGetResult(arr); + Accumulator accumulator = initializeAccumulator(collectList, udfContext); + for (String ip : arr) { + accumulator = addEvent(collectList, udfContext, accumulator, "field", ip); + } + Accumulator result = collectList.getResult(accumulator); + List<String> aggregated = (List<String>) result.getMetricsFields().get("field_list"); + assertEquals(aggregated.size(), 3); + assertEquals("192.168.1.1", aggregated.get(0)); } - private void testMerge(List<String> arr,List<String> arr2) { - - UDFContext udfContext = new UDFContext(); - udfContext.setLookupFields(List.of("field")); - udfContext.setOutputFields(Collections.singletonList("field_list")); + @Test + public void testArrayType() { + // 测试 "array" 类型 + List<List<String>> arrays = List.of( + List.of("192.168.1.1", "192.168.1.2"), + List.of("192.168.1.3", "192.168.1.4") + ); + UDFContext udfContext = createUDFContext("array"); CollectList collectList = new CollectList(); - Map<String, Object> metricsFields = new HashMap<>(); - Accumulator accumulator = new Accumulator(); - accumulator.setMetricsFields(metricsFields); collectList.open(udfContext); - Accumulator result1 = getMiddleResult(udfContext,arr); - Accumulator result2 = getMiddleResult(udfContext,arr2); - Accumulator result = collectList.getResult(collectList.merge(result1,result2)); - List<String> vals = (List<String>) result.getMetricsFields().get("field_list"); - assertEquals(vals.size(),8); - assertEquals("192.168.1.6",vals.get(5).toString()); + Accumulator accumulator = initializeAccumulator(collectList, udfContext); + for (List<String> array : arrays) { + accumulator = addEvent(collectList, udfContext, accumulator, "field", array); + } + + Accumulator result = collectList.getResult(accumulator); + List<List<String>> aggregated = (List<List<String>>) result.getMetricsFields().get("field_list"); + assertEquals(aggregated.size(), 4); + assertEquals("192.168.1.1", aggregated.get(0)); } - private Accumulator getMiddleResult(UDFContext udfContext,List<String> arr) { + @Test + public void testMerge() { + // 测试合并逻辑 + List<String> arr1 = List.of("192.168.1.1", "192.168.1.2"); + List<String> arr2 = List.of("192.168.1.3", "192.168.1.4"); + UDFContext udfContext = createUDFContext("object"); CollectList collectList = new CollectList(); - Map<String, Object> metricsFields = new HashMap<>(); - Accumulator accumulator = new Accumulator(); - accumulator.setMetricsFields(metricsFields); collectList.open(udfContext); - Accumulator agg = collectList.initAccumulator(accumulator); - for (String o : arr) { - Event event = new Event(); - Map<String, Object> extractedFields = new HashMap<>(); - extractedFields.put("field", o); - event.setExtractedFields(extractedFields); - agg = collectList.add(event, agg); + Accumulator result1 = createAccumulatorFromList(collectList, udfContext, arr1); + Accumulator result2 = createAccumulatorFromList(collectList, udfContext, arr2); - } - return agg; + Accumulator merged = collectList.merge(result1, result2); + List<String> aggregated = (List<String>) merged.getMetricsFields().get("field_list"); + + assertEquals(aggregated.size(), 4); + assertEquals("192.168.1.4", aggregated.get(3)); } - private void testGetResult(List<String> arr) throws ParseException { + @Test + public void testEmptyInput() { + // 测试空输入 + UDFContext udfContext = createUDFContext("object"); + CollectList collectList = new CollectList(); + collectList.open(udfContext); + + Accumulator accumulator = initializeAccumulator(collectList, udfContext); + Accumulator result = collectList.getResult(accumulator); + + List<String> aggregated = (List<String>) result.getMetricsFields().get("field_list"); + assertTrue(aggregated.isEmpty()); + } + + + private UDFContext createUDFContext(String collectType) { UDFContext udfContext = new UDFContext(); udfContext.setLookupFields(List.of("field")); udfContext.setOutputFields(Collections.singletonList("field_list")); - CollectList collectList = new CollectList(); + Map<String, Object> parameters = new HashMap<>(); + parameters.put("collect_type", collectType); + udfContext.setParameters(parameters); + return udfContext; + } + + private Accumulator initializeAccumulator(CollectList collectList, UDFContext udfContext) { Map<String, Object> metricsFields = new HashMap<>(); Accumulator accumulator = new Accumulator(); accumulator.setMetricsFields(metricsFields); - collectList.open(udfContext); - Accumulator agg = collectList.initAccumulator(accumulator); - - for (String o : arr) { - Event event = new Event(); - Map<String, Object> extractedFields = new HashMap<>(); - extractedFields.put("field", o); - event.setExtractedFields(extractedFields); - agg = collectList.add(event, agg); + return collectList.initAccumulator(accumulator); + } + private Accumulator createAccumulatorFromList(CollectList collectList, UDFContext udfContext, List<String> values) { + Accumulator accumulator = initializeAccumulator(collectList, udfContext); + for (String value : values) { + accumulator = addEvent(collectList, udfContext, accumulator, "field", value); } - Accumulator result = collectList.getResult(agg); - List<String> vals = (List<String>) result.getMetricsFields().get("field_list"); - assertEquals(vals.size(),4); + return accumulator; } - - -}
\ No newline at end of file + private Accumulator addEvent(CollectList collectList, UDFContext udfContext, Accumulator accumulator, String field, Object value) { + Event event = new Event(); + Map<String, Object> extractedFields = new HashMap<>(); + extractedFields.put(field, value); + event.setExtractedFields(extractedFields); + return collectList.add(event, accumulator); + } +} diff --git a/groot-core/src/test/java/com/geedgenetworks/core/udf/test/aggregate/CollectSetTest.java b/groot-core/src/test/java/com/geedgenetworks/core/udf/test/aggregate/CollectSetTest.java index defd6a7..694e482 100644 --- a/groot-core/src/test/java/com/geedgenetworks/core/udf/test/aggregate/CollectSetTest.java +++ b/groot-core/src/test/java/com/geedgenetworks/core/udf/test/aggregate/CollectSetTest.java @@ -1,90 +1,137 @@ package com.geedgenetworks.core.udf.test.aggregate; - import com.geedgenetworks.common.config.Accumulator; import com.geedgenetworks.core.udf.udaf.CollectSet; import com.geedgenetworks.api.common.udf.UDFContext; import com.geedgenetworks.api.event.Event; import org.junit.jupiter.api.Test; -import java.text.ParseException; import java.util.*; -import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.*; public class CollectSetTest { @Test - public void test() throws ParseException { + public void testObjectType() { + // 测试 "object" 类型,验证去重 + List<String> arr = List.of("192.168.1.1", "192.168.1.2", "192.168.1.1"); + UDFContext udfContext = createUDFContext("object"); + CollectSet collectSet = new CollectSet(); + collectSet.open(udfContext); - List<String> arr = List.of("192.168.1.1", "192.168.1.2", "192.168.1.3", "192.168.1.4","192.168.1.4"); - List<String> arr2 = List.of("192.168.1.5", "192.168.1.6", "192.168.1.3", "192.168.1.4"); - testMerge(arr,arr2); - testGetResult(arr); + Accumulator accumulator = initializeAccumulator(collectSet, udfContext); + for (String ip : arr) { + accumulator = addEvent(collectSet, udfContext, accumulator, "field", ip); + } + Accumulator result = collectSet.getResult(accumulator); + Set<String> aggregated = (Set<String>) result.getMetricsFields().get("field_list"); + assertEquals(aggregated.size(), 2); + assertTrue(aggregated.contains("192.168.1.1")); + assertTrue(aggregated.contains("192.168.1.2")); } - private void testMerge(List<String> arr,List<String> arr2) { + @Test + public void testEmptyInput() { + // 测试空输入 + UDFContext udfContext = createUDFContext("object"); + CollectSet collectSet = new CollectSet(); + collectSet.open(udfContext); - UDFContext udfContext = new UDFContext(); - udfContext.setLookupFields(List.of("field")); - udfContext.setOutputFields(Collections.singletonList("field_list")); + Accumulator accumulator = initializeAccumulator(collectSet, udfContext); + Accumulator result = collectSet.getResult(accumulator); + + Set<String> aggregated = (Set<String>) result.getMetricsFields().get("field_list"); + assertNotNull(aggregated); + assertTrue(aggregated.isEmpty()); + } + + @Test + public void testMerge() { + // 测试合并逻辑 + List<String> arr1 = List.of("192.168.1.1", "192.168.1.2"); + List<String> arr2 = List.of("192.168.1.3", "192.168.1.1"); + + UDFContext udfContext = createUDFContext("object"); CollectSet collectSet = new CollectSet(); - Map<String, Object> metricsFields = new HashMap<>(); - Accumulator accumulator = new Accumulator(); - accumulator.setMetricsFields(metricsFields); collectSet.open(udfContext); - Accumulator result1 = getMiddleResult(udfContext,arr); - Accumulator result2 = getMiddleResult(udfContext,arr2); - Accumulator result = collectSet.getResult(collectSet.merge(result1,result2)); - Set<String> vals = (Set<String>) result.getMetricsFields().get("field_list"); - assertEquals(vals.size(),6); + Accumulator acc1 = createAccumulatorFromList(collectSet, udfContext, arr1); + Accumulator acc2 = createAccumulatorFromList(collectSet, udfContext, arr2); + + Accumulator merged = collectSet.merge(acc1, acc2); + Set<String> aggregated = (Set<String>) merged.getMetricsFields().get("field_list"); + + assertEquals(aggregated.size(), 3); + assertTrue(aggregated.containsAll(List.of("192.168.1.1", "192.168.1.2", "192.168.1.3"))); } - private Accumulator getMiddleResult(UDFContext udfContext,List<String> arr) { + @Test + public void testNullValues() { + // 测试字段为 null 的场景 + UDFContext udfContext = createUDFContext("object"); + CollectSet collectSet = new CollectSet(); + collectSet.open(udfContext); + + Accumulator accumulator = initializeAccumulator(collectSet, udfContext); + accumulator = addEvent(collectSet, udfContext, accumulator, "field", null); + + Accumulator result = collectSet.getResult(accumulator); + Set<String> aggregated = (Set<String>) result.getMetricsFields().get("field_list"); + + assertNotNull(aggregated); + assertTrue(aggregated.isEmpty()); + } + + @Test + public void testDuplicateValues() { + // 测试重复值是否被正确去重 + List<String> arr = List.of("192.168.1.1", "192.168.1.1", "192.168.1.1"); + UDFContext udfContext = createUDFContext("object"); CollectSet collectSet = new CollectSet(); - Map<String, Object> metricsFields = new HashMap<>(); - Accumulator accumulator = new Accumulator(); - accumulator.setMetricsFields(metricsFields); collectSet.open(udfContext); - Accumulator agg = collectSet.initAccumulator(accumulator); - for (String o : arr) { - Event event = new Event(); - Map<String, Object> extractedFields = new HashMap<>(); - extractedFields.put("field", o); - event.setExtractedFields(extractedFields); - agg = collectSet.add(event, agg); + Accumulator accumulator = createAccumulatorFromList(collectSet, udfContext, arr); - } - return agg; + Accumulator result = collectSet.getResult(accumulator); + Set<String> aggregated = (Set<String>) result.getMetricsFields().get("field_list"); + + assertEquals(aggregated.size(), 1); + assertTrue(aggregated.contains("192.168.1.1")); } - private static void testGetResult(List<String> arr) throws ParseException { + private UDFContext createUDFContext(String collectType) { UDFContext udfContext = new UDFContext(); udfContext.setLookupFields(List.of("field")); udfContext.setOutputFields(Collections.singletonList("field_list")); - CollectSet collectSet = new CollectSet(); + Map<String, Object> parameters = new HashMap<>(); + parameters.put("collect_type", collectType); + udfContext.setParameters(parameters); + return udfContext; + } + + private Accumulator initializeAccumulator(CollectSet collectSet, UDFContext udfContext) { Map<String, Object> metricsFields = new HashMap<>(); Accumulator accumulator = new Accumulator(); accumulator.setMetricsFields(metricsFields); - collectSet.open(udfContext); - Accumulator agg = collectSet.initAccumulator(accumulator); - for (String o : arr) { - Event event = new Event(); - Map<String, Object> extractedFields = new HashMap<>(); - extractedFields.put("field", o); - event.setExtractedFields(extractedFields); - agg = collectSet.add(event, agg); + return collectSet.initAccumulator(accumulator); + } + private Accumulator createAccumulatorFromList(CollectSet collectSet, UDFContext udfContext, List<String> values) { + Accumulator accumulator = initializeAccumulator(collectSet, udfContext); + for (String value : values) { + accumulator = addEvent(collectSet, udfContext, accumulator, "field", value); } - Accumulator result = collectSet.getResult(agg); - Set<String> vals = (Set<String>) result.getMetricsFields().get("field_list"); - assertEquals(vals.size(),4); + return accumulator; } - - -}
\ No newline at end of file + private Accumulator addEvent(CollectSet collectSet, UDFContext udfContext, Accumulator accumulator, String field, Object value) { + Event event = new Event(); + Map<String, Object> extractedFields = new HashMap<>(); + extractedFields.put(field, value); + event.setExtractedFields(extractedFields); + return collectSet.add(event, accumulator); + } +} |
