diff options
Diffstat (limited to 'ip-learning-spark/src/main/scala/cn/ac/iie/utils/SparkSessionUtil.scala')
| -rw-r--r-- | ip-learning-spark/src/main/scala/cn/ac/iie/utils/SparkSessionUtil.scala | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/ip-learning-spark/src/main/scala/cn/ac/iie/utils/SparkSessionUtil.scala b/ip-learning-spark/src/main/scala/cn/ac/iie/utils/SparkSessionUtil.scala index 8f7661a..5bececa 100644 --- a/ip-learning-spark/src/main/scala/cn/ac/iie/utils/SparkSessionUtil.scala +++ b/ip-learning-spark/src/main/scala/cn/ac/iie/utils/SparkSessionUtil.scala @@ -1,7 +1,9 @@ package cn.ac.iie.utils import cn.ac.iie.config.ApplicationConfig +import org.apache.spark.SparkContext import org.apache.spark.sql.SparkSession +import org.apache.spark.util.LongAccumulator import org.slf4j.LoggerFactory object SparkSessionUtil { @@ -9,6 +11,8 @@ object SparkSessionUtil { val spark: SparkSession = getSparkSession + var sparkContext: SparkContext = getContext + private def getSparkSession: SparkSession ={ val spark: SparkSession = SparkSession .builder() @@ -26,10 +30,27 @@ object SparkSessionUtil { spark } + def getContext: SparkContext = { + @transient var sc: SparkContext = null + if (sparkContext == null) sc = spark.sparkContext + sc + } + + def getLongAccumulator(name: String): LongAccumulator ={ + if (sparkContext == null){ + sparkContext = getContext + } + sparkContext.longAccumulator(name) + + } + def closeSpark(): Unit ={ if (spark != null){ spark.stop() } + if (sparkContext != null){ + sparkContext.stop() + } } } |
