Diffstat (limited to 'ip-learning-spark/src/main/scala/cn/ac')
9 files changed, 579 insertions, 196 deletions
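A note on the refactor in the diff below: the mergeVertex*/mergeRelation* methods in MergeDataFrame no longer return bare RDD[Row]. Each one now keys the freshly aggregated ClickHouse rows, loads the existing ArangoDB collection with ArangoSpark.load, and full-outer-joins the two keyed RDDs, so the update step can decide per key whether to create a new document or update an existing one. A minimal sketch of that join pattern, relying only on the project classes as they appear in this diff; the object and method names below are hypothetical:

import cn.ac.iie.config.ApplicationConfig
import cn.ac.iie.spark.ArangoSpark
import cn.ac.iie.spark.rdd.ReadOptions
import cn.ac.iie.utils.SparkSessionUtil
import com.arangodb.entity.BaseDocument
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row

object JoinWithHistorySketch {
  // newRows: ClickHouse aggregates already keyed the same way the ArangoDB _key is built
  // (e.g. the FQDN itself, or "fqdn-serverIp" for the edge collections).
  def joinWithHistory(newRows: RDD[(String, Row)], collection: String)
      : RDD[(String, (Option[BaseDocument], Option[Row]))] = {
    val options = ReadOptions(ApplicationConfig.ARANGODB_DB_NAME)
    // Load the current state of the collection and key it by document _key.
    val hisDocs = ArangoSpark.load[BaseDocument](SparkSessionUtil.sparkContext, collection, options)
    // Full outer join: (Some(doc), Some(row)) = update in place, (None, Some(row)) = new document,
    // (Some(doc), None) = history only, nothing new to apply.
    hisDocs.map(doc => (doc.getKey, doc)).fullOuterJoin(newRows)
  }
}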
diff --git a/ip-learning-spark/src/main/scala/cn/ac/iie/config/ApplicationConfig.scala b/ip-learning-spark/src/main/scala/cn/ac/iie/config/ApplicationConfig.scala index 395ea6b..687bdd5 100644 --- a/ip-learning-spark/src/main/scala/cn/ac/iie/config/ApplicationConfig.scala +++ b/ip-learning-spark/src/main/scala/cn/ac/iie/config/ApplicationConfig.scala @@ -36,12 +36,7 @@ object ApplicationConfig { val READ_CLICKHOUSE_MAX_TIME: Long = config.getLong("read.clickhouse.max.time") val READ_CLICKHOUSE_MIN_TIME: Long = config.getLong("read.clickhouse.min.time") - val ARANGO_TIME_LIMIT_TYPE: Int = config.getInt("arango.time.limit.type") - - val READ_ARANGO_MAX_TIME: Long = config.getLong("read.arango.max.time") - val READ_ARANGO_MIN_TIME: Long = config.getLong("read.arango.min.time") - - val ARANGODB_READ_LIMIT: String = config.getString("arangoDB.read.limit") + val ARANGODB_READ_LIMIT: Int = config.getInt("arangoDB.read.limit") val UPDATE_ARANGO_BATCH: Int = config.getInt("update.arango.batch") val RECENT_COUNT_HOUR: Int = config.getInt("recent.count.hour") val DISTINCT_CLIENT_IP_NUM: Int = config.getInt("distinct.client.ip.num") diff --git a/ip-learning-spark/src/main/scala/cn/ac/iie/dao/BaseClickhouseData.scala b/ip-learning-spark/src/main/scala/cn/ac/iie/dao/BaseClickhouseData.scala index 952c30c..eb6a736 100644 --- a/ip-learning-spark/src/main/scala/cn/ac/iie/dao/BaseClickhouseData.scala +++ b/ip-learning-spark/src/main/scala/cn/ac/iie/dao/BaseClickhouseData.scala @@ -11,7 +11,7 @@ object BaseClickhouseData { val currentHour: Long = System.currentTimeMillis / (60 * 60 * 1000) * 60 * 60 private val timeLimit: (Long, Long) = getTimeLimit - private def initClickhouseData(sql:String): Unit ={ + private def initClickhouseData(sql:String): DataFrame ={ val dataFrame: DataFrame = spark.read.format("jdbc") .option("url", ApplicationConfig.SPARK_READ_CLICKHOUSE_URL) @@ -28,6 +28,8 @@ object BaseClickhouseData { .load() dataFrame.printSchema() dataFrame.createOrReplaceGlobalTempView("dbtable") + + dataFrame } def loadConnectionDataFromCk(): Unit ={ @@ -68,41 +70,7 @@ object BaseClickhouseData { initClickhouseData(sql) } - def getVertexFqdnDf: DataFrame ={ - loadConnectionDataFromCk() - val sql = - """ - |SELECT - | FQDN,MAX( LAST_FOUND_TIME ) AS LAST_FOUND_TIME,MIN( FIRST_FOUND_TIME ) AS FIRST_FOUND_TIME - |FROM - | ( - | (SELECT - | ssl_sni AS FQDN,MAX( common_recv_time ) AS LAST_FOUND_TIME,MIN( common_recv_time ) AS FIRST_FOUND_TIME - | FROM - | global_temp.dbtable - | WHERE - | common_schema_type = 'SSL' GROUP BY ssl_sni - | ) - | UNION ALL - | (SELECT - | http_host AS FQDN,MAX( common_recv_time ) AS LAST_FOUND_TIME,MIN( common_recv_time ) AS FIRST_FOUND_TIME - | FROM - | global_temp.dbtable - | WHERE - | common_schema_type = 'HTTP' GROUP BY http_host - | ) - | ) - |GROUP BY - | FQDN - |HAVING - | FQDN != '' - """.stripMargin - LOG.warn(sql) - val vertexFqdnDf = spark.sql(sql) - vertexFqdnDf.printSchema() - vertexFqdnDf - } - + /* def getVertexIpDf: DataFrame ={ loadConnectionDataFromCk() val sql = @@ -190,6 +158,149 @@ object BaseClickhouseData { relationFqdnLocateIpDf.printSchema() relationFqdnLocateIpDf } + */ + + def getVertexFqdnDf: DataFrame ={ + val sql = + """ + |(SELECT + | FQDN,MAX( LAST_FOUND_TIME ) AS LAST_FOUND_TIME,MIN( FIRST_FOUND_TIME ) AS FIRST_FOUND_TIME + |FROM + | ((SELECT + | ssl_sni AS FQDN,MAX( common_recv_time ) AS LAST_FOUND_TIME,MIN( common_recv_time ) AS FIRST_FOUND_TIME + | FROM tsg_galaxy_v3.connection_record_log + | WHERE common_schema_type = 'SSL' GROUP BY ssl_sni + | 
)UNION ALL + | (SELECT + | http_host AS FQDN,MAX( common_recv_time ) AS LAST_FOUND_TIME,MIN( common_recv_time ) AS FIRST_FOUND_TIME + | FROM tsg_galaxy_v3.connection_record_log + | WHERE common_schema_type = 'HTTP' GROUP BY http_host)) + |GROUP BY FQDN HAVING FQDN != '') as dbtable + """.stripMargin + LOG.warn(sql) + val frame = initClickhouseData(sql) + frame.printSchema() + frame + } + + def getVertexIpDf: DataFrame ={ + val where = "common_recv_time >= " + timeLimit._2 + " AND common_recv_time < " + timeLimit._1 + val sql = + s""" + |(SELECT * FROM + |((SELECT common_client_ip AS IP,MIN(common_recv_time) AS FIRST_FOUND_TIME, + |MAX(common_recv_time) AS LAST_FOUND_TIME, + |count(*) as SESSION_COUNT, + |SUM(common_c2s_byte_num+common_s2c_byte_num) as BYTES_SUM, + |groupUniqArray(2)(common_link_info_c2s)[2] as common_link_info, + |'client' as ip_type + |FROM tsg_galaxy_v3.connection_record_log + |where $where + |group by common_client_ip) + |UNION ALL + |(SELECT common_server_ip AS IP, + |MIN(common_recv_time) AS FIRST_FOUND_TIME, + |MAX(common_recv_time) AS LAST_FOUND_TIME, + |count(*) as SESSION_COUNT, + |SUM(common_c2s_byte_num+common_s2c_byte_num) as BYTES_SUM, + |groupUniqArray(2)(common_link_info_s2c)[2] as common_link_info, + |'server' as ip_type + |FROM tsg_galaxy_v3.connection_record_log + |where $where + |group by common_server_ip))) as dbtable + """.stripMargin + LOG.warn(sql) + val frame = initClickhouseData(sql) + frame.printSchema() + frame + } + + + def getRelationFqdnLocateIpDf: DataFrame ={ + val where = "common_recv_time >= " + timeLimit._2 + " AND common_recv_time < " + timeLimit._1 + val sql = + s""" + |(SELECT * FROM + |((SELECT ssl_sni AS FQDN,common_server_ip,MAX(common_recv_time) AS LAST_FOUND_TIME,MIN(common_recv_time) AS FIRST_FOUND_TIME,COUNT(*) AS COUNT_TOTAL, + |toString(groupUniqArray(${ApplicationConfig.DISTINCT_CLIENT_IP_NUM})(common_client_ip)) AS DIST_CIP_RECENT,'TLS' AS schema_type + |FROM tsg_galaxy_v3.connection_record_log + |WHERE $where and common_schema_type = 'SSL' GROUP BY ssl_sni,common_server_ip) + |UNION ALL + |(SELECT http_host AS FQDN,common_server_ip,MAX(common_recv_time) AS LAST_FOUND_TIME,MIN(common_recv_time) AS FIRST_FOUND_TIME,COUNT(*) AS COUNT_TOTAL, + |toString(groupUniqArray(${ApplicationConfig.DISTINCT_CLIENT_IP_NUM})(common_client_ip)) AS DIST_CIP_RECENT,'HTTP' AS schema_type + |FROM tsg_galaxy_v3.connection_record_log + |WHERE $where and common_schema_type = 'HTTP' GROUP BY http_host,common_server_ip)) + |WHERE FQDN != '') as dbtable + """.stripMargin + LOG.warn(sql) + val frame = initClickhouseData(sql) + frame.printSchema() + frame + } + + def getRelationSubidLocateIpDf: DataFrame ={ + val where = + s""" + | common_recv_time >= ${timeLimit._2} + | AND common_recv_time < ${timeLimit._1} + | AND common_subscriber_id != '' + | AND radius_framed_ip != '' + """.stripMargin + val sql = + s""" + |( + |SELECT common_subscriber_id,radius_framed_ip,MAX(common_recv_time) as LAST_FOUND_TIME,MIN(common_recv_time) as FIRST_FOUND_TIME + |FROM radius_record_log + |WHERE $where GROUP BY common_subscriber_id,radius_framed_ip + |) as dbtable + """.stripMargin + LOG.warn(sql) + val frame = initClickhouseData(sql) + frame.printSchema() + frame + } + + def getVertexSubidDf: DataFrame ={ + val where = + s""" + | common_recv_time >= ${timeLimit._2} + | AND common_recv_time < ${timeLimit._1} + | AND common_subscriber_id != '' + | AND radius_framed_ip != '' + """.stripMargin + val sql = + s""" + |( + |SELECT common_subscriber_id,MAX(common_recv_time) as 
LAST_FOUND_TIME,MIN(common_recv_time) as FIRST_FOUND_TIME FROM radius_record_log + |WHERE $where GROUP BY common_subscriber_id + |)as dbtable + """.stripMargin + LOG.warn(sql) + val frame = initClickhouseData(sql) + frame.printSchema() + frame + } + + def getVertexFramedIpDf: DataFrame ={ + val where = + s""" + | common_recv_time >= ${timeLimit._2} + | AND common_recv_time < ${timeLimit._1} + | AND common_subscriber_id != '' + | AND radius_framed_ip != '' + """.stripMargin + val sql = + s""" + |( + |SELECT DISTINCT radius_framed_ip,common_recv_time as LAST_FOUND_TIME FROM radius_record_log WHERE $where + |)as dbtable + """.stripMargin + LOG.warn(sql) + val frame = initClickhouseData(sql) + frame.printSchema() + frame + } + private def getTimeLimit: (Long,Long) ={ var maxTime = 0L diff --git a/ip-learning-spark/src/main/scala/cn/ac/iie/service/transform/MergeDataFrame.scala b/ip-learning-spark/src/main/scala/cn/ac/iie/service/transform/MergeDataFrame.scala index 460caed..3094691 100644 --- a/ip-learning-spark/src/main/scala/cn/ac/iie/service/transform/MergeDataFrame.scala +++ b/ip-learning-spark/src/main/scala/cn/ac/iie/service/transform/MergeDataFrame.scala @@ -4,37 +4,57 @@ import java.util.regex.Pattern import cn.ac.iie.config.ApplicationConfig import cn.ac.iie.dao.BaseClickhouseData +import cn.ac.iie.spark.ArangoSpark import cn.ac.iie.spark.partition.CustomPartitioner +import cn.ac.iie.spark.rdd.ReadOptions +import com.arangodb.entity.{BaseDocument, BaseEdgeDocument} import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row import org.apache.spark.sql.functions._ import org.slf4j.LoggerFactory +import cn.ac.iie.utils.SparkSessionUtil._ object MergeDataFrame { private val LOG = LoggerFactory.getLogger(MergeDataFrame.getClass) private val pattern = Pattern.compile("^[\\d]*$") + private val options = ReadOptions(ApplicationConfig.ARANGODB_DB_NAME) - def mergeVertexFqdn(): RDD[Row] ={ - BaseClickhouseData.getVertexFqdnDf - .rdd.filter(row => isDomain(row.getAs[String](0))).map(row => (row.get(0),row)) - .partitionBy(new CustomPartitioner(ApplicationConfig.SPARK_SQL_SHUFFLE_PARTITIONS)).values + def mergeVertexFqdn(): RDD[(String, (Option[BaseDocument], Option[Row]))] ={ + val fqdnAccmu = getLongAccumulator("FQDN Accumulator") + val fqdnRddRow = BaseClickhouseData.getVertexFqdnDf + .rdd.filter(row => isDomain(row.getAs[String](0))).map(row => { + fqdnAccmu.add(1) + (row.getAs[String]("FQDN"), row) + }).partitionBy(new CustomPartitioner(ApplicationConfig.SPARK_SQL_SHUFFLE_PARTITIONS)) + fqdnRddRow.cache() + val fqdnRddDoc = ArangoSpark.load[BaseDocument](sparkContext,"FQDN",options) + + fqdnRddDoc.map(doc => (doc.getKey, doc)).fullOuterJoin(fqdnRddRow) } - def mergeVertexIp(): RDD[Row]={ + def mergeVertexIp(): RDD[(String, (Option[BaseDocument], Option[Row]))]={ + val ipAccum = getLongAccumulator("IP Accumulator") val vertexIpDf = BaseClickhouseData.getVertexIpDf val frame = vertexIpDf.groupBy("IP").agg( min("FIRST_FOUND_TIME").alias("FIRST_FOUND_TIME"), max("LAST_FOUND_TIME").alias("LAST_FOUND_TIME"), collect_list("SESSION_COUNT").alias("SESSION_COUNT_LIST"), collect_list("BYTES_SUM").alias("BYTES_SUM_LIST"), - collect_list("ip_type").alias("ip_type_list") + collect_list("ip_type").alias("ip_type_list"), + last("common_link_info").alias("common_link_info") ) - val values = frame.rdd.map(row => (row.get(0), row)) - .partitionBy(new CustomPartitioner(ApplicationConfig.SPARK_SQL_SHUFFLE_PARTITIONS)).values - values + val ipRddRow = frame.rdd.map(row => { + ipAccum.add(1) + 
(row.getAs[String]("IP"), row) + }).partitionBy(new CustomPartitioner(ApplicationConfig.SPARK_SQL_SHUFFLE_PARTITIONS)) + val ipRddDoc = ArangoSpark.load[BaseDocument](sparkContext,"IP",options) + + ipRddDoc.map(doc => (doc.getKey, doc)).fullOuterJoin(ipRddRow) + } - def mergeRelationFqdnLocateIp(): RDD[Row] ={ + def mergeRelationFqdnLocateIp(): RDD[(String, (Option[BaseEdgeDocument], Option[Row]))] ={ + val fqdnLocIpAccum = getLongAccumulator("R_LOCATE_FQDN2IP Accumulator") val frame = BaseClickhouseData.getRelationFqdnLocateIpDf.filter(row => isDomain(row.getAs[String]("FQDN"))) .groupBy("FQDN", "common_server_ip") .agg( @@ -44,28 +64,72 @@ object MergeDataFrame { collect_list("schema_type").alias("schema_type_list"), collect_set("DIST_CIP_RECENT").alias("DIST_CIP_RECENT") ) - frame.rdd.map(row => { + val fqdnLocIpRddRow = frame.rdd.map(row => { val fqdn = row.getAs[String]("FQDN") val serverIp = row.getAs[String]("common_server_ip") - val key = fqdn.concat("-"+serverIp) - (key,row) - }).partitionBy(new CustomPartitioner(ApplicationConfig.SPARK_SQL_SHUFFLE_PARTITIONS)).values + val key = fqdn.concat("-" + serverIp) + fqdnLocIpAccum.add(1) + (key, row) + }).partitionBy(new CustomPartitioner(ApplicationConfig.SPARK_SQL_SHUFFLE_PARTITIONS)) + val fqdnLocIpRddDoc = ArangoSpark.load[BaseEdgeDocument](sparkContext,"R_LOCATE_FQDN2IP",options) + + fqdnLocIpRddDoc.map(doc => (doc.getKey, doc)).fullOuterJoin(fqdnLocIpRddRow) } + def mergeRelationSubidLocateIp(): RDD[(String, (Option[BaseEdgeDocument], Option[Row]))] ={ + val subidLocIpAccum = getLongAccumulator("R_LOCATE_SUBSCRIBER2IP Accumulator") + val subidLocIpRddRow = BaseClickhouseData.getRelationSubidLocateIpDf + .repartition(ApplicationConfig.SPARK_SQL_SHUFFLE_PARTITIONS) + .rdd.map(row => { + val commonSubscriberId = row.getAs[String]("common_subscriber_id") + val ip = row.getAs[String]("radius_framed_ip") + val key = commonSubscriberId.concat("-" + ip) + subidLocIpAccum.add(1) + (key, row) + }).partitionBy(new CustomPartitioner(ApplicationConfig.SPARK_SQL_SHUFFLE_PARTITIONS)) + val subidLocIpRddDoc = ArangoSpark.load[BaseEdgeDocument](sparkContext,"R_LOCATE_SUBSCRIBER2IP",options) + + subidLocIpRddDoc.map(doc => (doc.getKey, doc)).fullOuterJoin(subidLocIpRddRow) + } + + def mergeVertexSubid(): RDD[(String, (Option[BaseDocument], Option[Row]))] ={ + val subidAccum = getLongAccumulator("SUBSCRIBER Accumulator") + val subidRddRow = BaseClickhouseData.getVertexSubidDf + .repartition(ApplicationConfig.SPARK_SQL_SHUFFLE_PARTITIONS) + .rdd.map(row => { + val commonSubscriberId = row.getAs[String]("common_subscriber_id") + subidAccum.add(1) + (commonSubscriberId, row) + }).partitionBy(new CustomPartitioner(ApplicationConfig.SPARK_SQL_SHUFFLE_PARTITIONS)) + + val subidRddDoc = ArangoSpark.load[BaseDocument](sparkContext,"SUBSCRIBER",options) + + subidRddDoc.map(doc => (doc.getKey, doc)).fullOuterJoin(subidRddRow) + + } + + def mergeVertexFrameIp: RDD[Row] ={ + val framedIpAccum = getLongAccumulator("framed ip Accumulator") + val values = BaseClickhouseData.getVertexFramedIpDf + .repartition(ApplicationConfig.SPARK_SQL_SHUFFLE_PARTITIONS) + .rdd.map(row => { + val ip = row.getAs[String]("radius_framed_ip") + framedIpAccum.add(1) + (ip, row) + }).partitionBy(new CustomPartitioner(ApplicationConfig.SPARK_SQL_SHUFFLE_PARTITIONS)).values + values + } + private def isDomain(fqdn: String): Boolean = { try { if (fqdn == null || fqdn.length == 0) { return false } - if (fqdn.contains(":")) { - val s = fqdn.split(":")(0) - if (s.contains(":")){ - return 
false - } - } - val fqdnArr = fqdn.split("\\.") - if (fqdnArr.length < 4 || fqdnArr.length > 4){ + + val fqdnArr = fqdn.split(":")(0).split("\\.") + + if (fqdnArr.length != 4){ return true } for (f <- fqdnArr) { @@ -83,6 +147,7 @@ object MergeDataFrame { LOG.error("解析域名 " + fqdn + " 失败:\n" + e.toString) } false + } } diff --git a/ip-learning-spark/src/main/scala/cn/ac/iie/service/update/UpdateDocHandler.scala b/ip-learning-spark/src/main/scala/cn/ac/iie/service/update/UpdateDocHandler.scala index bdf8120..a275ab3 100644 --- a/ip-learning-spark/src/main/scala/cn/ac/iie/service/update/UpdateDocHandler.scala +++ b/ip-learning-spark/src/main/scala/cn/ac/iie/service/update/UpdateDocHandler.scala @@ -1,7 +1,8 @@ package cn.ac.iie.service.update -import java.lang +import java.util +import scala.collection.JavaConversions._ import cn.ac.iie.config.ApplicationConfig import cn.ac.iie.service.read.ReadHistoryArangoData @@ -14,16 +15,24 @@ object UpdateDocHandler { val PROTOCOL_SET: Set[String] = Set("HTTP","TLS","DNS") def updateMaxAttribute(hisDoc: BaseDocument,newAttribute:Long,attributeName:String): Unit ={ - var hisAttritube = hisDoc.getAttribute(attributeName).toString.toLong - if (newAttribute > hisAttritube){ - hisAttritube = newAttribute + if(hisDoc.getProperties.containsKey(attributeName)){ + var hisAttritube = hisDoc.getAttribute(attributeName).toString.toLong + if (newAttribute > hisAttritube){ + hisAttritube = newAttribute + } + hisDoc.addAttribute(attributeName,hisAttritube) } - hisDoc.addAttribute(attributeName,hisAttritube) } def updateSumAttribute(hisDoc: BaseDocument,newAttribute:Long,attributeName:String): Unit ={ - val hisAttritube = hisDoc.getAttribute(attributeName).toString.toLong - hisDoc.addAttribute(attributeName,newAttribute+hisAttritube) + if (hisDoc.getProperties.containsKey(attributeName)){ + val hisAttritube = hisDoc.getAttribute(attributeName).toString.toLong + hisDoc.addAttribute(attributeName,newAttribute+hisAttritube) + } + } + + def replaceAttribute(hisDoc: BaseDocument,newAttribute:String,attributeName:String): Unit ={ + hisDoc.addAttribute(attributeName,newAttribute) } def separateAttributeByIpType(ipTypeList:ofRef[String], @@ -62,19 +71,21 @@ object UpdateDocHandler { } def updateProtocolAttritube(hisDoc:BaseEdgeDocument, protocolMap: Map[String, Long]): Unit ={ - var protocolType = hisDoc.getAttribute("PROTOCOL_TYPE").toString - protocolMap.foreach(t => { - if (t._2 > 0 && !protocolType.contains(t._1)){ - protocolType = protocolType.concat(","+ t._1) - } - val cntTotalName = t._1.concat("_CNT_TOTAL") - val cntRecentName = t._1.concat("_CNT_RECENT") - val cntRecent: Array[lang.Long] = hisDoc.getAttribute(cntRecentName).asInstanceOf[Array[java.lang.Long]] - cntRecent.update(0,t._2) - updateSumAttribute(hisDoc,t._2,cntTotalName) - hisDoc.addAttribute(cntRecentName,cntRecent) - }) - hisDoc.addAttribute("PROTOCOL_TYPE",protocolType) + if (hisDoc.getProperties.containsKey("PROTOCOL_TYPE")){ + var protocolType = hisDoc.getAttribute("PROTOCOL_TYPE").toString + protocolMap.foreach((t: (String, Long)) => { + if (t._2 > 0 && !protocolType.contains(t._1)){ + protocolType = protocolType.concat(","+ t._1) + } + val cntTotalName = t._1.concat("_CNT_TOTAL") + val cntRecentName = t._1.concat("_CNT_RECENT") + val cntRecent = hisDoc.getAttribute(cntRecentName).asInstanceOf[Array[Long]] + cntRecent.update(0,t._2) + updateSumAttribute(hisDoc,t._2,cntTotalName) + hisDoc.addAttribute(cntRecentName,cntRecent) + }) + hisDoc.addAttribute("PROTOCOL_TYPE",protocolType) + } } def 
putProtocolAttritube(doc:BaseEdgeDocument, protocolMap: Map[String, Long]): Unit ={ @@ -93,10 +104,30 @@ object UpdateDocHandler { doc.addAttribute("PROTOCOL_TYPE",protocolTypeBuilder.toString().replaceFirst(",","")) } - def mergeDistinctIp(distCipRecent:ofRef[ofRef[String]]): Array[String] ={ - distCipRecent.flatten.distinct.take(ApplicationConfig.DISTINCT_CLIENT_IP_NUM).toArray + def updateProtocolDocument(doc: BaseEdgeDocument): Unit = { + if (doc.getProperties.containsKey("PROTOCOL_TYPE")) { + for (protocol <- PROTOCOL_SET) { + val protocolRecent = protocol + "_CNT_RECENT" + val cntRecent: util.ArrayList[Long] = doc.getAttribute(protocolRecent).asInstanceOf[util.ArrayList[Long]] + val cntRecentsSrc = cntRecent.toArray().map(_.toString.toLong) + val cntRecentsDst = new Array[Long](24) + System.arraycopy(cntRecentsSrc, 0, cntRecentsDst, 1, cntRecentsSrc.length - 1) + cntRecentsDst(0) = 0L + doc.addAttribute(protocolRecent, cntRecentsDst) + } + } + } + + def mergeDistinctIp(distCipRecent:ofRef[String]): Array[String] ={ + distCipRecent.flatMap(str => { + str.replaceAll("\\[","") + .replaceAll("\\]","") + .replaceAll("\\'","") + .split(",") + }).distinct.take(ApplicationConfig.DISTINCT_CLIENT_IP_NUM).toArray } + def putDistinctIp(doc:BaseEdgeDocument,newDistinctIp:Array[String]): Unit ={ val map = newDistinctIp.map(ip => { (ip, ReadHistoryArangoData.currentHour) @@ -106,17 +137,19 @@ object UpdateDocHandler { } def updateDistinctIp(hisDoc:BaseEdgeDocument,newDistinctIp:Array[String]): Unit ={ - val hisDistCip = hisDoc.getAttribute("DIST_CIP").asInstanceOf[Array[String]] - val hisDistCipTs = hisDoc.getAttribute("DIST_CIP_TS").asInstanceOf[Array[Long]] - if (hisDistCip.length == hisDistCipTs.length){ - val distCipToTsMap: Map[String, Long] = hisDistCip.zip(hisDistCipTs).toMap - val muDistCipToTsMap: mutable.Map[String, Long] = mutable.Map(distCipToTsMap.toSeq:_*) - newDistinctIp.foreach(cip => { - muDistCipToTsMap.put(cip,ReadHistoryArangoData.currentHour) - }) - val resultMap = muDistCipToTsMap.toList.sortBy(-_._2).take(ApplicationConfig.DISTINCT_CLIENT_IP_NUM).toMap - hisDoc.addAttribute("DIST_CIP",resultMap.keys.toArray) - hisDoc.addAttribute("DIST_CIP_TS",resultMap.values.toArray) + if (hisDoc.getProperties.containsKey("DIST_CIP") && hisDoc.getProperties.containsKey("DIST_CIP_TS")){ + val hisDistCip = hisDoc.getAttribute("DIST_CIP").asInstanceOf[util.ArrayList[String]] + val hisDistCipTs = hisDoc.getAttribute("DIST_CIP_TS").asInstanceOf[util.ArrayList[Long]] + if (hisDistCip.length == hisDistCipTs.length){ + val distCipToTsMap: Map[String, Long] = hisDistCip.zip(hisDistCipTs).toMap + val muDistCipToTsMap: mutable.Map[String, Long] = mutable.Map(distCipToTsMap.toSeq:_*) + newDistinctIp.foreach(cip => { + muDistCipToTsMap.put(cip,ReadHistoryArangoData.currentHour) + }) + val resultMap = muDistCipToTsMap.toList.sortBy(-_._2).take(ApplicationConfig.DISTINCT_CLIENT_IP_NUM).toMap + hisDoc.addAttribute("DIST_CIP",resultMap.keys.toArray) + hisDoc.addAttribute("DIST_CIP_TS",resultMap.values.toArray) + } } } diff --git a/ip-learning-spark/src/main/scala/cn/ac/iie/service/update/UpdateDocument.scala b/ip-learning-spark/src/main/scala/cn/ac/iie/service/update/UpdateDocument.scala index b7d4875..5162834 100644 --- a/ip-learning-spark/src/main/scala/cn/ac/iie/service/update/UpdateDocument.scala +++ b/ip-learning-spark/src/main/scala/cn/ac/iie/service/update/UpdateDocument.scala @@ -1,17 +1,12 @@ package cn.ac.iie.service.update import java.util -import java.util.concurrent.ConcurrentHashMap import 
cn.ac.iie.config.ApplicationConfig -import cn.ac.iie.dao.BaseArangoData -import cn.ac.iie.dao.BaseArangoData._ import cn.ac.iie.service.transform.MergeDataFrame._ import cn.ac.iie.service.update.UpdateDocHandler._ -import cn.ac.iie.utils.{ArangoDBConnect, ExecutorThreadPool, SparkSessionUtil} -import cn.ac.iie.utils.SparkSessionUtil.spark +import cn.ac.iie.utils.{ArangoDBConnect, SparkSessionUtil} import com.arangodb.entity.{BaseDocument, BaseEdgeDocument} -import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row import org.slf4j.LoggerFactory @@ -19,44 +14,47 @@ import org.slf4j.LoggerFactory import scala.collection.mutable.WrappedArray.ofRef object UpdateDocument { - private val pool = ExecutorThreadPool.getInstance private val arangoManger: ArangoDBConnect = ArangoDBConnect.getInstance() private val LOG = LoggerFactory.getLogger(UpdateDocument.getClass) - private val baseArangoData = new BaseArangoData() def update(): Unit = { try { - updateDocument("FQDN", historyVertexFqdnMap, getVertexFqdnRow, classOf[BaseDocument], mergeVertexFqdn) - updateDocument("IP", historyVertexIpMap, getVertexIpRow, classOf[BaseDocument], mergeVertexIp) - updateDocument("R_LOCATE_FQDN2IP", historyRelationFqdnAddressIpMap, getRelationFqdnLocateIpRow, classOf[BaseEdgeDocument], mergeRelationFqdnLocateIp) + updateDocument("FQDN", getVertexFqdnRow, mergeVertexFqdn) + + updateDocument("SUBSCRIBER",getVertexSubidRow,mergeVertexSubid) + + insertFrameIp() + + updateDocument("R_LOCATE_SUBSCRIBER2IP",getRelationSubidLocateIpRow,mergeRelationSubidLocateIp) + + updateDocument("R_LOCATE_FQDN2IP", getRelationFqdnLocateIpRow, mergeRelationFqdnLocateIp) + + updateDocument("IP", getVertexIpRow, mergeVertexIp) + } catch { case e: Exception => e.printStackTrace() } finally { - pool.shutdown() arangoManger.clean() SparkSessionUtil.closeSpark() + System.exit(0) } } private def updateDocument[T <: BaseDocument](collName: String, - historyMap: ConcurrentHashMap[Integer, ConcurrentHashMap[String, T]], - getDocumentRow: (Row, ConcurrentHashMap[String, T]) => T, - clazz: Class[T], - getNewDataRdd: () => RDD[Row] + getDocumentRow: ((String, (Option[T], Option[Row]))) => T, + getJoinRdd: () => RDD[(String, (Option[T], Option[Row]))] ): Unit = { - baseArangoData.readHistoryData(collName, historyMap, clazz) - val hisBc = spark.sparkContext.broadcast(historyMap) try { val start = System.currentTimeMillis() - val newDataRdd = getNewDataRdd() - newDataRdd.foreachPartition(iter => { - val partitionId: Int = TaskContext.get.partitionId - val dictionaryMap: ConcurrentHashMap[String, T] = hisBc.value.get(partitionId) + val joinRdd = getJoinRdd() + joinRdd.foreachPartition(iter => { val resultDocumentList = new util.ArrayList[T] var i = 0 iter.foreach(row => { - val document = getDocumentRow(row, dictionaryMap) - resultDocumentList.add(document) + val document = getDocumentRow(row) + if (document != null){ + resultDocumentList.add(document) + } i += 1 if (i >= ApplicationConfig.UPDATE_ARANGO_BATCH) { arangoManger.overwrite(resultDocumentList, collName) @@ -73,88 +71,238 @@ object UpdateDocument { LOG.warn(s"更新$collName 时间:${last - start}") } catch { case e: Exception => e.printStackTrace() - } finally { - hisBc.destroy() } } - private def getVertexFqdnRow(row: Row, dictionaryMap: ConcurrentHashMap[String, BaseDocument]): BaseDocument = { - val fqdn = row.getAs[String]("FQDN") - val lastFoundTime = row.getAs[Long]("LAST_FOUND_TIME") - val firstFoundTime = row.getAs[Long]("FIRST_FOUND_TIME") - var 
document: BaseDocument = dictionaryMap.getOrDefault(fqdn, null) - if (document != null) { - updateMaxAttribute(document, lastFoundTime, "LAST_FOUND_TIME") - } else { - document = new BaseDocument - document.setKey(fqdn) - document.addAttribute("FQDN_NAME", fqdn) - document.addAttribute("FIRST_FOUND_TIME", firstFoundTime) - document.addAttribute("LAST_FOUND_TIME", lastFoundTime) - } + private def insertFrameIp(): Unit ={ + mergeVertexFrameIp.foreachPartition(iter => { + val resultDocumentList = new util.ArrayList[BaseDocument] + var i = 0 + iter.foreach(row => { + val document = getVertexFrameipRow(row) + resultDocumentList.add(document) + i += 1 + if (i >= ApplicationConfig.UPDATE_ARANGO_BATCH) { + arangoManger.overwrite(resultDocumentList, "IP") + LOG.warn(s"更新:IP" + i) + i = 0 + } + }) + if (i != 0) { + arangoManger.overwrite(resultDocumentList, "IP") + LOG.warn(s"更新IP:" + i) + } + }) + } + + private def getVertexFrameipRow(row: Row): BaseDocument ={ + val ip = row.getAs[String]("radius_framed_ip") + val document = new BaseDocument() + document.setKey(ip) + document.addAttribute("IP",ip) document } - private def getVertexIpRow(row: Row, dictionaryMap: ConcurrentHashMap[String, BaseDocument]): BaseDocument = { - val ip = row.getAs[String]("IP") - val firstFoundTime = row.getAs[Long]("FIRST_FOUND_TIME") - val lastFoundTime = row.getAs[Long]("LAST_FOUND_TIME") - val sessionCountList = row.getAs[ofRef[AnyRef]]("SESSION_COUNT_LIST") - val bytesSumList = row.getAs[ofRef[AnyRef]]("BYTES_SUM_LIST") - val ipTypeList = row.getAs[ofRef[String]]("ip_type_list") - val sepAttributeTuple = separateAttributeByIpType(ipTypeList, sessionCountList, bytesSumList) - - var document = dictionaryMap.getOrDefault(ip, null) - if (document != null) { - updateMaxAttribute(document, lastFoundTime, "LAST_FOUND_TIME") - updateSumAttribute(document, sepAttributeTuple._1, "SERVER_SESSION_COUNT") - updateSumAttribute(document, sepAttributeTuple._2, "SERVER_BYTES_SUM") - updateSumAttribute(document, sepAttributeTuple._3, "CLIENT_SESSION_COUNT") - updateSumAttribute(document, sepAttributeTuple._4, "CLIENT_BYTES_SUM") - } else { - document = new BaseDocument - document.setKey(ip) - document.addAttribute("IP", ip) - document.addAttribute("FIRST_FOUND_TIME", firstFoundTime) - document.addAttribute("LAST_FOUND_TIME", lastFoundTime) - document.addAttribute("SERVER_SESSION_COUNT", sepAttributeTuple._1) - document.addAttribute("SERVER_BYTES_SUM", sepAttributeTuple._2) - document.addAttribute("CLIENT_SESSION_COUNT", sepAttributeTuple._3) - document.addAttribute("CLIENT_BYTES_SUM", sepAttributeTuple._4) - document.addAttribute("COMMON_LINK_INFO", "") + private def getRelationSubidLocateIpRow(joinRow: (String, (Option[BaseEdgeDocument], Option[Row]))): BaseEdgeDocument ={ + + val subidLocIpDocOpt = joinRow._2._1 + var subidLocIpDoc = subidLocIpDocOpt match { + case Some(doc) => doc + case None => null } - document + + val subidLocIpRowOpt = joinRow._2._2 + + val subidLocIpRow = subidLocIpRowOpt match { + case Some(r) => r + case None => null + } + + if (subidLocIpRow != null){ + val subId = subidLocIpRow.getAs[String]("common_subscriber_id") + val ip = subidLocIpRow.getAs[String]("radius_framed_ip") + val lastFoundTime = subidLocIpRow.getAs[Long]("LAST_FOUND_TIME") + val firstFoundTime = subidLocIpRow.getAs[Long]("FIRST_FOUND_TIME") + + val key = subId.concat("-"+ip) + if (subidLocIpDoc != null){ + updateMaxAttribute(subidLocIpDoc,lastFoundTime,"LAST_FOUND_TIME") + } else { + subidLocIpDoc = new BaseEdgeDocument() + 
subidLocIpDoc.setKey(key) + subidLocIpDoc.setFrom("SUBSCRIBER/" + subId) + subidLocIpDoc.setTo("IP/" + ip) + subidLocIpDoc.addAttribute("SUBSCRIBER",subId) + subidLocIpDoc.addAttribute("IP",ip) + subidLocIpDoc.addAttribute("FIRST_FOUND_TIME",firstFoundTime) + subidLocIpDoc.addAttribute("LAST_FOUND_TIME",lastFoundTime) + } + } + subidLocIpDoc } - private def getRelationFqdnLocateIpRow(row: Row, dictionaryMap: ConcurrentHashMap[String, BaseEdgeDocument]): BaseEdgeDocument = { - val fqdn = row.getAs[String]("FQDN") - val serverIp = row.getAs[String]("common_server_ip") - val firstFoundTime = row.getAs[Long]("FIRST_FOUND_TIME") - val lastFoundTime = row.getAs[Long]("LAST_FOUND_TIME") - val countTotalList = row.getAs[ofRef[AnyRef]]("COUNT_TOTAL_LIST") - val schemaTypeList = row.getAs[ofRef[AnyRef]]("schema_type_list") - val distCipRecent = row.getAs[ofRef[ofRef[String]]]("DIST_CIP_RECENT") - - val sepAttritubeMap: Map[String, Long] = separateAttributeByProtocol(schemaTypeList, countTotalList) - val distinctIp: Array[String] = mergeDistinctIp(distCipRecent) - - val key = fqdn.concat("-" + serverIp) - var document = dictionaryMap.getOrDefault(key, null) - if (document != null) { - updateMaxAttribute(document, lastFoundTime, "LAST_FOUND_TIME") - updateProtocolAttritube(document, sepAttritubeMap) - updateDistinctIp(document, distinctIp) - } else { - document = new BaseEdgeDocument() - document.setKey(key) - document.setFrom("FQDN/" + fqdn) - document.setTo("IP/" + serverIp) - document.addAttribute("FIRST_FOUND_TIME", firstFoundTime) - document.addAttribute("LAST_FOUND_TIME", lastFoundTime) - putProtocolAttritube(document, sepAttritubeMap) - putDistinctIp(document, distinctIp) + private def getVertexSubidRow(joinRow: (String, (Option[BaseDocument], Option[Row]))): BaseDocument ={ + val subidDocOpt = joinRow._2._1 + var subidDoc = subidDocOpt match { + case Some(doc) => doc + case None => null } - document + + val subidRowOpt = joinRow._2._2 + + val subidRow = subidRowOpt match { + case Some(r) => r + case None => null + } + + if (subidRow != null){ + val subId = subidRow.getAs[String]("common_subscriber_id") + val subLastFoundTime = subidRow.getAs[Long]("LAST_FOUND_TIME") + val subFirstFoundTime = subidRow.getAs[Long]("FIRST_FOUND_TIME") + if (subidDoc != null){ + updateMaxAttribute(subidDoc,subLastFoundTime,"LAST_FOUND_TIME") + } else { + subidDoc = new BaseDocument() + subidDoc.setKey(subId) + subidDoc.addAttribute("SUBSCRIBER",subId) + subidDoc.addAttribute("FIRST_FOUND_TIME",subFirstFoundTime) + subidDoc.addAttribute("LAST_FOUND_TIME",subLastFoundTime) + } + } + + subidDoc + } + + private def getVertexFqdnRow(joinRow: (String, (Option[BaseDocument], Option[Row]))): BaseDocument = { + val fqdnDocOpt = joinRow._2._1 + var fqdnDoc = fqdnDocOpt match { + case Some(doc) => doc + case None => null + } + + val fqdnRowOpt = joinRow._2._2 + + val fqdnRow = fqdnRowOpt match { + case Some(r) => r + case None => null + } + + if (fqdnRow != null){ + val fqdn = fqdnRow.getAs[String]("FQDN") + val lastFoundTime = fqdnRow.getAs[Long]("LAST_FOUND_TIME") + val firstFoundTime = fqdnRow.getAs[Long]("FIRST_FOUND_TIME") + if (fqdnDoc != null) { + updateMaxAttribute(fqdnDoc, lastFoundTime, "LAST_FOUND_TIME") + } else { + fqdnDoc = new BaseDocument + fqdnDoc.setKey(fqdn) + fqdnDoc.addAttribute("FQDN_NAME", fqdn) + fqdnDoc.addAttribute("FIRST_FOUND_TIME", firstFoundTime) + fqdnDoc.addAttribute("LAST_FOUND_TIME", lastFoundTime) + } + } + + fqdnDoc + } + + private def getVertexIpRow(joinRow: (String, 
(Option[BaseDocument], Option[Row]))): BaseDocument = { + val ipDocOpt = joinRow._2._1 + var ipDoc = ipDocOpt match { + case Some(doc) => doc + case None => null + } + + val ipRowOpt = joinRow._2._2 + + val ipRow = ipRowOpt match { + case Some(r) => r + case None => null + } + + if (ipRow != null){ + val ip = ipRow.getAs[String]("IP") + val firstFoundTime = ipRow.getAs[Long]("FIRST_FOUND_TIME") + val lastFoundTime = ipRow.getAs[Long]("LAST_FOUND_TIME") + val sessionCountList = ipRow.getAs[ofRef[AnyRef]]("SESSION_COUNT_LIST") + val bytesSumList = ipRow.getAs[ofRef[AnyRef]]("BYTES_SUM_LIST") + val ipTypeList = ipRow.getAs[ofRef[String]]("ip_type_list") + val linkInfo = ipRow.getAs[String]("common_link_info") + val sepAttributeTuple = separateAttributeByIpType(ipTypeList, sessionCountList, bytesSumList) + + if (ipDoc != null) { + updateMaxAttribute(ipDoc, lastFoundTime, "LAST_FOUND_TIME") + updateSumAttribute(ipDoc, sepAttributeTuple._1, "SERVER_SESSION_COUNT") + updateSumAttribute(ipDoc, sepAttributeTuple._2, "SERVER_BYTES_SUM") + updateSumAttribute(ipDoc, sepAttributeTuple._3, "CLIENT_SESSION_COUNT") + updateSumAttribute(ipDoc, sepAttributeTuple._4, "CLIENT_BYTES_SUM") + replaceAttribute(ipDoc,linkInfo,"COMMON_LINK_INFO") + } else { + ipDoc = new BaseDocument + ipDoc.setKey(ip) + ipDoc.addAttribute("IP", ip) + ipDoc.addAttribute("FIRST_FOUND_TIME", firstFoundTime) + ipDoc.addAttribute("LAST_FOUND_TIME", lastFoundTime) + ipDoc.addAttribute("SERVER_SESSION_COUNT", sepAttributeTuple._1) + ipDoc.addAttribute("SERVER_BYTES_SUM", sepAttributeTuple._2) + ipDoc.addAttribute("CLIENT_SESSION_COUNT", sepAttributeTuple._3) + ipDoc.addAttribute("CLIENT_BYTES_SUM", sepAttributeTuple._4) + ipDoc.addAttribute("COMMON_LINK_INFO", "") + } + } + + ipDoc + } + + private def getRelationFqdnLocateIpRow(joinRow: (String, (Option[BaseEdgeDocument], Option[Row]))): BaseEdgeDocument = { + + val fqdnLocIpDocOpt = joinRow._2._1 + var fqdnLocIpDoc = fqdnLocIpDocOpt match { + case Some(doc) => doc + case None => null + } + + val fqdnLocIpRowOpt = joinRow._2._2 + + val fqdnLocIpRow = fqdnLocIpRowOpt match { + case Some(r) => r + case None => null + } + + if (fqdnLocIpDoc != null){ + updateProtocolDocument(fqdnLocIpDoc) + } + + if (fqdnLocIpRow != null){ + val fqdn = fqdnLocIpRow.getAs[String]("FQDN") + val serverIp = fqdnLocIpRow.getAs[String]("common_server_ip") + val firstFoundTime = fqdnLocIpRow.getAs[Long]("FIRST_FOUND_TIME") + val lastFoundTime = fqdnLocIpRow.getAs[Long]("LAST_FOUND_TIME") + val countTotalList = fqdnLocIpRow.getAs[ofRef[AnyRef]]("COUNT_TOTAL_LIST") + val schemaTypeList = fqdnLocIpRow.getAs[ofRef[AnyRef]]("schema_type_list") + val distCipRecent = fqdnLocIpRow.getAs[ofRef[String]]("DIST_CIP_RECENT") + + val sepAttritubeMap: Map[String, Long] = separateAttributeByProtocol(schemaTypeList, countTotalList) + val distinctIp: Array[String] = mergeDistinctIp(distCipRecent) + + val key = fqdn.concat("-" + serverIp) + + if (fqdnLocIpDoc != null) { + updateMaxAttribute(fqdnLocIpDoc, lastFoundTime, "LAST_FOUND_TIME") + updateProtocolAttritube(fqdnLocIpDoc, sepAttritubeMap) + updateDistinctIp(fqdnLocIpDoc, distinctIp) + } else { + fqdnLocIpDoc = new BaseEdgeDocument() + fqdnLocIpDoc.setKey(key) + fqdnLocIpDoc.setFrom("FQDN/" + fqdn) + fqdnLocIpDoc.setTo("IP/" + serverIp) + fqdnLocIpDoc.addAttribute("FIRST_FOUND_TIME", firstFoundTime) + fqdnLocIpDoc.addAttribute("LAST_FOUND_TIME", lastFoundTime) + putProtocolAttritube(fqdnLocIpDoc, sepAttritubeMap) + putDistinctIp(fqdnLocIpDoc, distinctIp) + } + } + + fqdnLocIpDoc 
} } diff --git a/ip-learning-spark/src/main/scala/cn/ac/iie/spark/ArangoSpark.scala b/ip-learning-spark/src/main/scala/cn/ac/iie/spark/ArangoSpark.scala index b492f9a..e1a4060 100644 --- a/ip-learning-spark/src/main/scala/cn/ac/iie/spark/ArangoSpark.scala +++ b/ip-learning-spark/src/main/scala/cn/ac/iie/spark/ArangoSpark.scala @@ -24,6 +24,7 @@ package cn.ac.iie.spark import cn.ac.iie.spark.rdd.{ArangoRdd, ReadOptions, WriteOptions} import cn.ac.iie.spark.vpack.VPackUtils +import com.arangodb.model.DocumentCreateOptions import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} @@ -76,7 +77,6 @@ object ArangoSpark { * * @param dataframe the dataframe with data to save * @param collection the collection to save in - * @param options additional write options */ def saveDF(dataframe: DataFrame, collection: String): Unit = saveRDD[Row](dataframe.rdd, collection, WriteOptions(), (x: Iterator[Row]) => x.map { y => VPackUtils.rowToVPack(y) }) @@ -102,6 +102,11 @@ object ArangoSpark { case WriteOptions.INSERT => col.insertDocuments(docs) case WriteOptions.UPDATE => col.updateDocuments(docs) case WriteOptions.REPLACE => col.replaceDocuments(docs) + case WriteOptions.OVERWRITE => + val documentCreateOptions = new DocumentCreateOptions + documentCreateOptions.overwrite(true) + documentCreateOptions.silent(true) + col.insertDocuments(docs, documentCreateOptions) } arangoDB.shutdown() @@ -123,7 +128,7 @@ object ArangoSpark { * * @param sparkContext the sparkContext containing the ArangoDB configuration * @param collection the collection to load data from - * @param additional read options + * @param options read options */ def load[T: ClassTag](sparkContext: SparkContext, collection: String, options: ReadOptions): ArangoRdd[T] = new ArangoRdd[T](sparkContext, createReadOptions(options, sparkContext.getConf).copy(collection = collection)) diff --git a/ip-learning-spark/src/main/scala/cn/ac/iie/spark/rdd/ArangoRdd.scala b/ip-learning-spark/src/main/scala/cn/ac/iie/spark/rdd/ArangoRdd.scala index 4162e76..ab77299 100644 --- a/ip-learning-spark/src/main/scala/cn/ac/iie/spark/rdd/ArangoRdd.scala +++ b/ip-learning-spark/src/main/scala/cn/ac/iie/spark/rdd/ArangoRdd.scala @@ -56,7 +56,7 @@ class ArangoRdd[T: ClassTag](@transient override val sparkContext: SparkContext, override def repartition(numPartitions: Int)(implicit ord: Ordering[T]): RDD[T] = super.repartition(numPartitions) private def getPartition(idx: Int, countTotal: Long): QueryArangoPartition = { - val sepNum = countTotal / ApplicationConfig.THREAD_POOL_NUMBER + 1 + val sepNum = countTotal / ApplicationConfig.SPARK_SQL_SHUFFLE_PARTITIONS + 1 val offsetNum = idx * sepNum new QueryArangoPartition(idx, offsetNum, sepNum) } diff --git a/ip-learning-spark/src/main/scala/cn/ac/iie/spark/rdd/WriteOptions.scala b/ip-learning-spark/src/main/scala/cn/ac/iie/spark/rdd/WriteOptions.scala index 46f3c80..d659cb5 100644 --- a/ip-learning-spark/src/main/scala/cn/ac/iie/spark/rdd/WriteOptions.scala +++ b/ip-learning-spark/src/main/scala/cn/ac/iie/spark/rdd/WriteOptions.scala @@ -116,4 +116,9 @@ object WriteOptions { */ case object REPLACE extends Method + /** + * save documents by overwrite + */ + case object OVERWRITE extends Method + }
\ No newline at end of file diff --git a/ip-learning-spark/src/main/scala/cn/ac/iie/utils/SparkSessionUtil.scala b/ip-learning-spark/src/main/scala/cn/ac/iie/utils/SparkSessionUtil.scala index 8f7661a..5bececa 100644 --- a/ip-learning-spark/src/main/scala/cn/ac/iie/utils/SparkSessionUtil.scala +++ b/ip-learning-spark/src/main/scala/cn/ac/iie/utils/SparkSessionUtil.scala @@ -1,7 +1,9 @@ package cn.ac.iie.utils import cn.ac.iie.config.ApplicationConfig +import org.apache.spark.SparkContext import org.apache.spark.sql.SparkSession +import org.apache.spark.util.LongAccumulator import org.slf4j.LoggerFactory object SparkSessionUtil { @@ -9,6 +11,8 @@ object SparkSessionUtil { val spark: SparkSession = getSparkSession + var sparkContext: SparkContext = getContext + private def getSparkSession: SparkSession ={ val spark: SparkSession = SparkSession .builder() @@ -26,10 +30,27 @@ object SparkSessionUtil { spark } + def getContext: SparkContext = { + @transient var sc: SparkContext = null + if (sparkContext == null) sc = spark.sparkContext + sc + } + + def getLongAccumulator(name: String): LongAccumulator ={ + if (sparkContext == null){ + sparkContext = getContext + } + sparkContext.longAccumulator(name) + + } + def closeSpark(): Unit ={ if (spark != null){ spark.stop() } + if (sparkContext != null){ + sparkContext.stop() + } } } |
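For reference, the WriteOptions.OVERWRITE method introduced above maps to the ArangoDB Java driver's insertDocuments with DocumentCreateOptions.overwrite(true): an insert that replaces any existing document with the same _key, i.e. an upsert, which is what the per-partition batches in UpdateDocument rely on. A minimal standalone sketch of that driver call, assuming a reachable ArangoDB instance; host, credentials, database and collection names are placeholders:

import java.util

import com.arangodb.ArangoDB
import com.arangodb.entity.BaseDocument
import com.arangodb.model.DocumentCreateOptions

object OverwriteSketch {
  def main(args: Array[String]): Unit = {
    val arango = new ArangoDB.Builder().host("127.0.0.1", 8529).user("root").password("").build()
    val col = arango.db("ip_learning").collection("FQDN")

    val doc = new BaseDocument()
    doc.setKey("example.com")
    doc.addAttribute("LAST_FOUND_TIME", System.currentTimeMillis / 1000)

    val batch = new util.ArrayList[BaseDocument]()
    batch.add(doc)

    // overwrite(true) turns the insert into an upsert on _key; silent(true) skips
    // returning created/replaced document metadata, matching the new OVERWRITE branch.
    val opts = new DocumentCreateOptions().overwrite(true).silent(true)
    col.insertDocuments(batch, opts)
    arango.shutdown()
  }
}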