ナイーブベイズ - RDDベースAPI
ナイーブベイズは、各特徴量のペア間の独立性を仮定した単純な多クラス分類アルゴリズムです。ナイーブベイズは非常に効率的に学習できます。トレーニングデータへの単一パス内で、ラベルを所与とした各特徴量の条件付き確率分布を計算し、次にベイズの定理を適用して、観測値を所与としたラベルの条件付き確率分布を計算し、予測に使用します。
spark.mllib
は多項ナイーブベイズとベルヌーイナイーブベイズをサポートしています。これらのモデルは通常、文書分類に使用されます。このコンテキストでは、各観測値は文書であり、各特徴量は、その値が多項ナイーブベイズでは用語の頻度、ベルヌーイナイーブベイズでは用語が文書に見つかったかどうかを示す0または1である用語を表します。特徴量の値は非負でなければなりません。「multinomial」または「bernoulli」のオプションパラメータでモデルタイプを選択し、「multinomial」がデフォルトです。加算スムージングは、パラメータλ(デフォルトは1.0)を設定することで使用できます。文書分類では、入力特徴ベクトルは通常スパースであり、スパース性の利点を活用するために、スパースベクトルを入力として提供する必要があります。トレーニングデータは一度しか使用されないため、キャッシュする必要はありません。
例
NaiveBayesは多項ナイーブベイズを実装します。LabeledPointのRDDとオプションのスムージングパラメータlambda
を入力として受け取り、評価と予測に使用できるNaiveBayesModelを出力します。
Python APIはまだモデルの保存/読み込みをサポートしていませんが、将来的にはサポートされる予定です。
APIの詳細については、NaiveBayes
PythonドキュメントとNaiveBayesModel
Pythonドキュメントを参照してください。
from pyspark.mllib.classification import NaiveBayes, NaiveBayesModel
from pyspark.mllib.util import MLUtils
# Load and parse the data file.
data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
# Split data approximately into training (60%) and test (40%)
training, test = data.randomSplit([0.6, 0.4])
# Train a naive Bayes model.
model = NaiveBayes.train(training, 1.0)
# Make prediction and test accuracy.
predictionAndLabel = test.map(lambda p: (model.predict(p.features), p.label))
accuracy = 1.0 * predictionAndLabel.filter(lambda pl: pl[0] == pl[1]).count() / test.count()
print('model accuracy {}'.format(accuracy))
# Save and load model
output_dir = 'target/tmp/myNaiveBayesModel'
shutil.rmtree(output_dir, ignore_errors=True)
model.save(sc, output_dir)
sameModel = NaiveBayesModel.load(sc, output_dir)
predictionAndLabel = test.map(lambda p: (sameModel.predict(p.features), p.label))
accuracy = 1.0 * predictionAndLabel.filter(lambda pl: pl[0] == pl[1]).count() / test.count()
print('sameModel accuracy {}'.format(accuracy))
NaiveBayesは多項ナイーブベイズを実装します。LabeledPointのRDDとオプションのスムージングパラメータlambda
、オプションのモデルタイプパラメータ(デフォルトは「multinomial」)を入力として受け取り、評価と予測に使用できるNaiveBayesModelを出力します。
APIの詳細については、NaiveBayes
ScalaドキュメントとNaiveBayesModel
Scalaドキュメントを参照してください。
import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel}
import org.apache.spark.mllib.util.MLUtils
// Load and parse the data file.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// Split data into training (60%) and test (40%).
val Array(training, test) = data.randomSplit(Array(0.6, 0.4))
val model = NaiveBayes.train(training, lambda = 1.0, modelType = "multinomial")
val predictionAndLabel = test.map(p => (model.predict(p.features), p.label))
val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count()
// Save and load model
model.save(sc, "target/tmp/myNaiveBayesModel")
val sameModel = NaiveBayesModel.load(sc, "target/tmp/myNaiveBayesModel")
NaiveBayesは多項ナイーブベイズを実装します。LabeledPointのScala RDDとオプションのスムージングパラメータlambda
を入力として受け取り、評価と予測に使用できるNaiveBayesModelを出力します。
APIの詳細については、NaiveBayes
JavaドキュメントとNaiveBayesModel
Javaドキュメントを参照してください。
import scala.Tuple2;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.classification.NaiveBayes;
import org.apache.spark.mllib.classification.NaiveBayesModel;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
String path = "data/mllib/sample_libsvm_data.txt";
JavaRDD<LabeledPoint> inputData = MLUtils.loadLibSVMFile(jsc.sc(), path).toJavaRDD();
JavaRDD<LabeledPoint>[] tmp = inputData.randomSplit(new double[]{0.6, 0.4});
JavaRDD<LabeledPoint> training = tmp[0]; // training set
JavaRDD<LabeledPoint> test = tmp[1]; // test set
NaiveBayesModel model = NaiveBayes.train(training.rdd(), 1.0);
JavaPairRDD<Double, Double> predictionAndLabel =
test.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label()));
double accuracy =
predictionAndLabel.filter(pl -> pl._1().equals(pl._2())).count() / (double) test.count();
// Save and load model
model.save(jsc.sc(), "target/tmp/myNaiveBayesModel");
NaiveBayesModel sameModel = NaiveBayesModel.load(jsc.sc(), "target/tmp/myNaiveBayesModel");