import org.apache.spark.sql.{SparkSession, DataFrame}
import org.apache.spark.ml.classification.{DecisionTreeClassifier, DecisionTreeClassificationModel}
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.functions._
object ActiveLearningExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("ActiveLearningExample")
      .master("local[*]")
      .config("spark.driver.host", "localhost")
      .getOrCreate()

    // Print the current Spark configuration
    println("Current Spark Configuration:")
    spark.conf.getAll.foreach(println)

    // Load the cleaned data
    val data = spark.read.parquet("/home/hadoop/桌面/cleaned_data")
    println("Initial Data:")
    data.show(5)
    println(s"Count of initial data: ${data.count()}")

    // Check for nulls and drop any rows that contain them
    val dataWithoutNulls = data.na.drop()
    println("Data without nulls:")
    dataWithoutNulls.show(5)
    println(s"Count after removing nulls: ${dataWithoutNulls.count()}")

    // Cast the feature and label columns to numeric types
    val dataWithNumeric = dataWithoutNulls
      .withColumn("sepal_length", col("sepal_length").cast("double"))
      .withColumn("sepal_width", col("sepal_width").cast("double"))
      .withColumn("petal_length", col("petal_length").cast("double"))
      .withColumn("petal_width", col("petal_width").cast("double"))
      .withColumn("species", col("species").cast("double"))
println("Data with Numeric Columns:")
dataWithNumeric.show(5)
println(s"Count after type conversion: ${dataWithNumeric.count()}")
// 准备特征向量
val featureCols = Array("sepal_length", "sepal_width", "petal_length", "petal_width")
val assembler = new VectorAssembler().setInputCols(featureCols).setOutputCol("features")
val assembledData = assembler.transform(dataWithNumeric)
println("Data with Features Vector:")
assembledData.select("features", "species").show(5)
println(s"Count after assembling features: ${assembledData.count()}")
// 将标签列转换为数值类型
val labeledData = assembledData.withColumn("label", col("species"))
val Array(trainingData, testData) = labeledData.randomSplit(Array(0.8, 0.2), seed = 1234L)
println(s"Training data count: ${trainingData.count()}")
println(s"Test data count: ${testData.count()}")
// 训练决策树分类器
val dt = new DecisionTreeClassifier().setFeaturesCol("features").setLabelCol("label")
val dtModel = dt.fit(trainingData)
// 打印模型信息
println("Decision Tree Model:")
println(s"Number of nodes: ${dtModel.numNodes}")
println(s"Depth of tree: ${dtModel.depth}")
    // Run the active learning strategy: pick the samples whose predictions
    // are most uncertain and hand them off for labeling
    val samplesToLabel = ActiveLearningStrategy.selectSamplesForLabeling(testData, dtModel, 5)

    // Print the selected samples
    println("Selected Samples for Labeling:")
    samplesToLabel.show()

    spark.stop()
  }
}
object ActiveLearningStrategy {
  import org.apache.spark.ml.linalg.Vector // make sure the ML Vector type is in scope

  // Shannon entropy of a class-probability vector; higher entropy means the
  // model is less certain about the sample
  def calculateEntropy(probabilities: Vector): Double = {
    probabilities.toArray.map(p => if (p == 0) 0.0 else -p * Math.log(p)).sum
  }
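  // An alternative query heuristic (an addition, not in the original gist):
  // margin sampling scores a sample by the gap between its two most likely
  // classes. Smaller margins mean more uncertainty, so samples would be
  // selected in ascending order of this score.
  def calculateMargin(probabilities: Vector): Double = {
    val sorted = probabilities.toArray.sorted(Ordering[Double].reverse)
    if (sorted.length < 2) 1.0 else sorted(0) - sorted(1)
  }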
  // Score each sample by prediction entropy and return the k most uncertain
  def selectSamplesForLabeling(data: DataFrame, model: DecisionTreeClassificationModel, k: Int): DataFrame = {
    val predictions = model.transform(data)

    // Print the prediction results
    println("Predictions:")
    predictions.select("features", "probability", "prediction").show(5)

    val entropyUDF = udf((probability: Vector) => calculateEntropy(probability))
    val dataWithEntropy = predictions.withColumn("entropy", entropyUDF(col("probability")))

    // Print the data together with its entropy values
    println("Data with Entropy:")
    dataWithEntropy.select("features", "probability", "entropy").show(5)

    dataWithEntropy.orderBy(desc("entropy")).limit(k)
  }
}
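The code above runs a single query round. A full active learning workflow alternates between querying, labeling, and retraining. Below is a minimal sketch of that loop, assuming it lives in the same file as the imports above; the labelOracle parameter is a hypothetical stand-in for a human annotator (it is not part of the original code) that returns the queried rows with a trusted "label" column.

// Minimal active-learning loop sketch (an addition, not in the original gist)
def activeLearningLoop(
    initialTrain: DataFrame,
    pool: DataFrame,
    rounds: Int,
    k: Int,
    labelOracle: DataFrame => DataFrame): DecisionTreeClassificationModel = {
  val dt = new DecisionTreeClassifier().setFeaturesCol("features").setLabelCol("label")
  var train = initialTrain
  var unlabeled = pool
  var model = dt.fit(train)
  for (_ <- 1 to rounds) {
    // Query the k most uncertain samples, keeping only the pool's own columns
    // so that union/except below see identical schemas
    val queried = ActiveLearningStrategy
      .selectSamplesForLabeling(unlabeled, model, k)
      .select(unlabeled.columns.map(col): _*)
    train = train.union(labelOracle(queried))   // add newly labeled rows
    unlabeled = unlabeled.except(queried)       // remove them from the pool
    model = dt.fit(train)                       // retrain on the enlarged set
  }
  model
}

Note that except performs a set difference over every column, including the feature vector; for large pools, tagging rows with an explicit id and removing queried rows via an anti-join is usually more robust.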