34 | 34 | if __name__ == "__main__":
35 | 35 |     sc = SparkContext(appName="SimpleTextClassificationPipeline")
36 | 36 |     sqlCtx = SQLContext(sc)
   | 37 | +
   | 38 | +    # Prepare training documents, which are labeled.
37 | 39 |     LabeledDocument = Row('id', 'text', 'label')
38 | 40 |     training = sqlCtx.inferSchema(
39 | 41 |         sc.parallelize([(0L, "a b c d e spark", 1.0),
40 | 42 |                         (1L, "b d", 0.0),
41 | 43 |                         (2L, "spark f g h", 1.0),
42 | 44 |                         (3L, "hadoop mapreduce", 0.0)])
43 | 45 |         .map(lambda x: LabeledDocument(*x)))
44 | 46 |
   | 47 | +    # Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
45 | 48 |     tokenizer = Tokenizer() \
46 | 49 |         .setInputCol("text") \
47 | 50 |         .setOutputCol("words")
48 | 51 |     hashingTF = HashingTF() \
49 | 52 |         .setInputCol(tokenizer.getOutputCol()) \
50 | 53 |         .setOutputCol("features")
51 | 54 |     lr = LogisticRegression() \
52 | 55 |         .setMaxIter(10) \
53 | 56 |         .setRegParam(0.01)
54 | 57 |     pipeline = Pipeline() \
55 | 58 |         .setStages([tokenizer, hashingTF, lr])
56 | 59 |
   | 60 | +    # Fit the pipeline to training documents.
57 | 61 |     model = pipeline.fit(training)
58 | 62 |
   | 63 | +    # Prepare test documents, which are unlabeled.
59 | 64 |     Document = Row('id', 'text')
60 | 65 |     test = sqlCtx.inferSchema(
61 | 66 |         sc.parallelize([(4L, "spark i j k"),
62 | 67 |                         (5L, "l m n"),
63 | 68 |                         (6L, "mapreduce spark"),
64 | 69 |                         (7L, "apache hadoop")])
65 | 70 |         .map(lambda x: Document(*x)))
66 | 71 |
   | 72 | +    # Make predictions on test documents and print columns of interest.
67 | 73 |     prediction = model.transform(test)
68 |    | -
69 | 74 |     prediction.registerTempTable("prediction")
70 | 75 |     selected = sqlCtx.sql("SELECT id, text, prediction from prediction")
71 | 76 |     for row in selected.collect():
72 | 77 |         print row
   | 78 | +
   | 79 | +    sc.stop()
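For context, the file in this diff targets the Spark 1.x / Python 2 API: the 0L long literals, the print statement, SQLContext.inferSchema, and registerTempTable have all since been superseded. Below is a minimal sketch of the same pipeline against the modern DataFrame-based pyspark.ml API; it assumes Spark 2.x or later with Python 3 (SparkSession and createDataFrame stand in for the SQLContext/inferSchema calls above, and none of this is part of the commit itself):

from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import HashingTF, Tokenizer
from pyspark.ml.classification import LogisticRegression

# SparkSession replaces the SparkContext + SQLContext pair used in the diff.
spark = SparkSession.builder.appName("SimpleTextClassificationPipeline").getOrCreate()

# Prepare training documents, which are labeled.
training = spark.createDataFrame([
    (0, "a b c d e spark", 1.0),
    (1, "b d", 0.0),
    (2, "spark f g h", 1.0),
    (3, "hadoop mapreduce", 0.0)
], ["id", "text", "label"])

# Configure an ML pipeline with three stages: tokenizer, hashingTF, and lr.
tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
lr = LogisticRegression(maxIter=10, regParam=0.01)
pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])

# Fit the pipeline to the training documents.
model = pipeline.fit(training)

# Prepare test documents, which are unlabeled.
test = spark.createDataFrame([
    (4, "spark i j k"),
    (5, "l m n"),
    (6, "mapreduce spark"),
    (7, "apache hadoop")
], ["id", "text"])

# Make predictions and print the columns of interest (no temp table / SQL needed).
prediction = model.transform(test)
for row in prediction.select("id", "text", "prediction").collect():
    print(row)

spark.stop()

Either version exercises the same three-stage flow (Tokenizer, then HashingTF's hashed term-frequency features, then LogisticRegression); with this training set one would expect the test documents containing "spark" (ids 4 and 6) to come out with prediction 1.0 and the others 0.0, though the exact output depends on the fitted model.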