To test this multi-class classifier, we can try it on a handwritten digit recognition problem. Get the handwritten digits data from here.
// Uses SVMMultiClassWithSGD from this fork (it is not part of mainline Spark MLlib):
// https://github.com/Bekbolatov/spark/commit/463d73323d5f08669d5ae85dc9791b036637c966
import org.apache.spark.mllib.classification.SVMMultiClassWithSGD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors

// Each pendigits line holds 16 comma-separated features followed by the class label (0-9).
def parse(line: String): LabeledPoint = {
  val values = line.split(",").map(_.trim.toDouble)
  LabeledPoint(values.last, Vectors.dense(values.init))  // label = last column, features = first 16
}

val digits_train = sc.textFile("/data/pendigits.tra").map(parse).cache()
val digits_test = sc.textFile("/data/pendigits.tes").map(parse).cache()

// Second argument: number of SGD iterations (as in MLlib's SVMWithSGD.train)
val model = SVMMultiClassWithSGD.train(digits_train, 100)

// Pair each prediction with the true label and compute test accuracy
val predictionAndLabel = digits_test.map(p => (model.predict(p.features), p.label))
val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / digits_test.count()
println(s"Test accuracy: $accuracy")

// Inspect a few (prediction, label) pairs
predictionAndLabel.take(5)
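The post doesn't show how the fork's SVMMultiClassWithSGD turns binary SVMs into a multiclass model. One common reduction is one-vs-rest: train one binary classifier per class (that class against all others) and predict the class whose classifier returns the largest margin. The sketch below illustrates that idea in plain Scala without Spark; it is an assumption for illustration, not the fork's implementation, and the object name `OneVsRestSketch` and its parameters (`epochs`, `step`, `reg`) are hypothetical.

```scala
// Hypothetical one-vs-rest sketch in plain Scala (no Spark).
// Assumption: illustrates the general reduction, NOT the code behind SVMMultiClassWithSGD.
object OneVsRestSketch {
  type Vec = Array[Double]

  def dot(w: Vec, x: Vec): Double = {
    var s = 0.0
    var i = 0
    while (i < w.length) { s += w(i) * x(i); i += 1 }
    s
  }

  // Binary linear SVM trained by SGD on the regularized hinge loss.
  // ys must be +1.0 / -1.0; a constant 1.0 feature is appended so the last weight acts as a bias.
  def trainBinary(xs: Seq[Vec], ys: Seq[Double], epochs: Int, step: Double, reg: Double): Vec = {
    val d = xs.head.length + 1
    val w = Array.fill(d)(0.0)
    for (epoch <- 1 to epochs; (x, y) <- xs.zip(ys)) {
      val xb = x :+ 1.0                              // features plus bias term
      val margin = y * dot(w, xb)                    // functional margin y * (w . x)
      val eta = step / math.sqrt(epoch.toDouble)     // decaying step size
      for (i <- 0 until d) {
        // subgradient of (reg / 2) * |w|^2 + max(0, 1 - margin)
        val grad = reg * w(i) - (if (margin < 1.0) y * xb(i) else 0.0)
        w(i) -= eta * grad
      }
    }
    w
  }

  // One binary classifier per class k: label +1 for class k, -1 for every other class.
  def trainOneVsRest(xs: Seq[Vec], labels: Seq[Int], numClasses: Int,
                     epochs: Int = 50, step: Double = 0.1, reg: Double = 0.01): Array[Vec] =
    Array.tabulate(numClasses) { k =>
      val ys = labels.map(l => if (l == k) 1.0 else -1.0)
      trainBinary(xs, ys, epochs, step, reg)
    }

  // Predict the class whose binary classifier gives the largest raw margin.
  def predict(models: Array[Vec], x: Vec): Int = {
    val xb = x :+ 1.0
    models.indices.maxBy(k => dot(models(k), xb))
  }
}
```

For pendigits this would be called with `numClasses = 10` on the 16-feature vectors. Because the ten scorers are trained independently, their margins are not calibrated against each other, which is one possible reason a reduction like this can trail a natively multiclass model.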
For comparison, here are some results with tree classifiers. With a RandomForest (30 trees, Gini impurity, maximum depth 7), accuracy goes up to 93%. Adding second-order feature interactions (Spark doesn't support kernels in classification yet, but a simple feature transformation can add the second-order interactions explicitly) and increasing the maximum tree depth to 15 brings accuracy to 97%. So there is a lot of room for improvement in the multiclass-to-binary classifier reduction.
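The exact transformation used for the 97% result isn't shown. A minimal sketch, assuming "second-order interactions" means appending every pairwise product x(i) * x(j) with i <= j (squares included) to the original features, could look like this; `SecondOrderFeatures.expand` is a hypothetical helper, not a Spark API.

```scala
// Hypothetical degree-2 feature expansion (assumption: squares and all pairwise products, i <= j).
object SecondOrderFeatures {
  // Returns the original features followed by x(i) * x(j) for every i <= j,
  // so n inputs become n + n * (n + 1) / 2 outputs (16 -> 152 for pendigits).
  def expand(x: Array[Double]): Array[Double] = {
    val products = for {
      i <- x.indices
      j <- i until x.length
    } yield x(i) * x(j)
    x ++ products.toArray
  }
}
```

In the Spark code, the expansion could be applied per point before training, e.g. `digits_train.map(p => LabeledPoint(p.label, Vectors.dense(SecondOrderFeatures.expand(p.features.toArray))))`. Tree splits compare one feature against a threshold, so the larger scale of the product features does not matter to the random forest.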