运行环境

  • MacOS 10.13.6
  • Spark 2.3.1
  • Scala 2.11.8
  • JDK 8
  • IDEA + Maven

Code

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.feature.HashingTF
import org.apache.spark.mllib.regression.LabeledPoint

/**
 * Minimal spam classifier built on Spark MLlib.
 *
 * Each line of the two input text files is treated as one e-mail. The text is
 * split on single spaces, hashed into a fixed-size term-frequency vector, and
 * a logistic regression model (L-BFGS) is trained with spam lines labelled
 * 1.0 and normal lines labelled 0.0. Two sample sentences are then classified
 * and their predicted labels printed to stdout.
 */
object Email {

  /** Number of hash buckets used for the term-frequency vectors. */
  private val NumFeatures = 10000

  /** Input locations used when no command-line arguments are supplied. */
  private val DefaultSpamPath = "/Users/mfcheer/IdeaProjects/sparkML/data/spam.txt"
  private val DefaultNormalPath = "/Users/mfcheer/IdeaProjects/sparkML/data/normal.txt"

  /**
   * Trains the model and prints the predictions for two sample messages.
   *
   * @param args optional overrides: `args(0)` is the spam file path and
   *             `args(1)` is the normal-mail file path; when absent, the
   *             default paths above are used
   */
  def main(args: Array[String]): Unit = {
    val spamPath = args.lift(0).getOrElse(DefaultSpamPath)
    val normalPath = args.lift(1).getOrElse(DefaultNormalPath)

    val conf = new SparkConf().setAppName("Email").setMaster("local")
    val sc = new SparkContext(conf)

    // Always stop the context, even if loading or training throws, so the
    // local Spark runtime is released.
    try {
      val spam = sc.textFile(spamPath)
      val normal = sc.textFile(normalPath)

      // HashingTF maps the words of each e-mail to a vector with NumFeatures entries.
      val tf = new HashingTF(NumFeatures)
      val spamFeatures = spam.map(line => tf.transform(line.split(" ")))
      val normalFeatures = normal.map(line => tf.transform(line.split(" ")))

      // Spam is the positive class (1.0); normal mail is the negative class (0.0).
      val positiveExamples = spamFeatures.map(features => LabeledPoint(1.0, features))
      val negativeExamples = normalFeatures.map(features => LabeledPoint(0.0, features))
      val trainingData = positiveExamples.union(negativeExamples)
      // Persist the training set before handing it to the optimizer.
      trainingData.persist()

      val model = new LogisticRegressionWithLBFGS().run(trainingData)

      // Spam sample.
      println(model.predict(tf.transform("O M G GET cheap stuff by give my to ...".split(" "))))
      // Normal-mail sample.
      println(model.predict(tf.transform("Hi Dad, I started studying Spark the other ...".split(" "))))
    } finally {
      sc.stop()
    }
  }

}

输出:

1.0
0.0