summaryrefslogtreecommitdiff
path: root/ip-learning-spark/src/main/scala/cn/ac/iie/utils/SparkSessionUtil.scala
diff options
context:
space:
mode:
Diffstat (limited to 'ip-learning-spark/src/main/scala/cn/ac/iie/utils/SparkSessionUtil.scala')
-rw-r--r--ip-learning-spark/src/main/scala/cn/ac/iie/utils/SparkSessionUtil.scala21
1 files changed, 21 insertions, 0 deletions
diff --git a/ip-learning-spark/src/main/scala/cn/ac/iie/utils/SparkSessionUtil.scala b/ip-learning-spark/src/main/scala/cn/ac/iie/utils/SparkSessionUtil.scala
index 8f7661a..5bececa 100644
--- a/ip-learning-spark/src/main/scala/cn/ac/iie/utils/SparkSessionUtil.scala
+++ b/ip-learning-spark/src/main/scala/cn/ac/iie/utils/SparkSessionUtil.scala
@@ -1,7 +1,9 @@
package cn.ac.iie.utils
import cn.ac.iie.config.ApplicationConfig
+import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSession
+import org.apache.spark.util.LongAccumulator
import org.slf4j.LoggerFactory
object SparkSessionUtil {
@@ -9,6 +11,8 @@ object SparkSessionUtil {
val spark: SparkSession = getSparkSession
+ var sparkContext: SparkContext = getContext
+
private def getSparkSession: SparkSession ={
val spark: SparkSession = SparkSession
.builder()
@@ -26,10 +30,27 @@ object SparkSessionUtil {
spark
}
+ def getContext: SparkContext = {
+ @transient var sc: SparkContext = null
+ if (sparkContext == null) sc = spark.sparkContext
+ sc
+ }
+
+ def getLongAccumulator(name: String): LongAccumulator ={
+ if (sparkContext == null){
+ sparkContext = getContext
+ }
+ sparkContext.longAccumulator(name)
+
+ }
+
def closeSpark(): Unit ={
if (spark != null){
spark.stop()
}
+ if (sparkContext != null){
+ sparkContext.stop()
+ }
}
}