必要なすべてのライブラリをインポートします
from pyspark.mllib.classification import LogisticRegressionWithSGD, LogisticRegressionModel
from pyspark.mllib.linalg import SparseVector
from pyspark.mllib.regression import LabeledPoint
import re
データをRDDにロードする
msgs = [("I love Star Wars but I can't watch it today", 1.0),
("I don't love Star Wars and people want to watch it today", 0.0),
("I dislike not being able to watch Star Wars", 1.0),
("People who love Star Wars are my friends", 1.0),
("I preffer to watch Star Wars on Netflix", 0.0),
("George Lucas shouldn't have sold the franchise", 1.0),
("Disney makes better movies than everyone else", 0.0)]
rdd = sc.parallelize(msgs)
データをトークン化します ( MLを使用する場合は、より簡単になる可能性があります)。
rdd = rdd.map(lambda (text, label): ([w.lower() for w in re.split(" +", text)], label))
不要な単語 (ストップ ワードとして広く知られている) と記号をすべて削除します。,.&
commons = ["and", "but", "to"]
rdd = rdd.map(lambda (tokens, label): (filter(lambda token: token not in commons, tokens), label))
すべてのデータセットdistinct
内のすべての単語を含む辞書を作成します。巨大に聞こえますが、期待するほど多くはありません。マスターノードに収まるはずです (ただし、これにアプローチする方法は他にもありますが、簡単にするためにここではこのままにしてください)。
# finds different words
words = rdd.flatMap(lambda (tokens, label): tokens).distinct().collect()
diffwords = len(words)
あなたfeatures
をDenseVectorまたはSparseVectorに変換します。通常、 a を表すのに必要なスペースが少なくて済むため、明らかに 2 番目の方法をお勧めしSparseVector
ますが、それはデータに依存します。のようなより良い代替手段があることに注意してくださいhashing
。その後、をLabeledPointtuple
に変換します
def sparsify(length, tokens):
indices = [words.index(t) for t in set(tokens)]
quantities = [tokens.count(words[i]) for i in indices]
return SparseVector(length, [(indices[i], quantities[i]) for i in xrange(len(indices))])
rdd = rdd.map(lambda (tokens, label): LabeledPoint(label, sparsify(diffwords, tokens)))
お気に入りのモデルに合わせてください。この場合、下心のためにLogisticRegressionWithSGDを使用しました。
lrm = LogisticRegressionWithSGD.train(rdd)
モデルを保存します。
lrm.save(sc, "mylovelymodel.model")
LogisticRegressionModelを別のアプリケーションにロードします。
lrm = LogisticRegressionModel.load(sc, "mylovelymodel.model")
カテゴリを予測します。
lrm.predict(SparseVector(37,[2,4,5,13,15,19,23,26,27,29],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0]))
# outputs 0