14

Scikit Learn を使用して Python でランダム フォレストを使用するのに苦労しています。私の問題は、テキスト分類 (3 つのクラス - ポジティブ/ネガティブ/ニュートラル) に使用し、抽出する特徴は主に単語/ユニグラムであるため、これらを数値特徴に変換する必要があることです。DictVectorizerでそれを行う方法を見つけましたfit_transform

from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report
from sklearn.feature_extraction import DictVectorizer

vec = DictVectorizer(sparse=False)
rf = RandomForestClassifier(n_estimators = 100)
trainFeatures1 = vec.fit_transform(trainFeatures)

# Fit the training data to the training output and create the decision trees
rf = rf.fit(trainFeatures1.toarray(), LabelEncoder().fit_transform(trainLabels))

testFeatures1 = vec.fit_transform(testFeatures)
# Take the same decision trees and run on the test data
Output = rf.score(testFeatures1.toarray(), LabelEncoder().fit_transform(testLabels))

print "accuracy: " + str(Output)

私の問題は、fit_transformメソッドが約 8000 のインスタンスを含むトレーニング データセットで動作していることですが、テスト セットを約 80000 のインスタンスである数値機能にも変換しようとすると、次のようなメモリ エラーが発生します。

testFeatures1 = vec.fit_transform(testFeatures)
File "C:\Python27\lib\site-packages\sklearn\feature_extraction\dict_vectorizer.py", line 143, in fit_transform
return self.transform(X)
File "C:\Python27\lib\site-packages\sklearn\feature_extraction\dict_vectorizer.py", line 251, in transform
Xa = np.zeros((len(X), len(vocab)), dtype=dtype)
MemoryError

何が原因で、回避策はありますか? どうもありがとう!

4

1 に答える 1

15

fit_transformテストデータに対して行うべきではありませんが、 transform. そうしないと、トレーニング中に使用されたベクトル化とは異なるベクトル化が得られます。

メモリの問題についてはTfIdfVectorizer、次元を減らすための多数のオプションがある (まれなユニグラムなどを削除することにより) をお勧めします。

アップデート

唯一の問題がテストデータのフィッティングである場合は、単純に小さなチャンクに分割します。のようなものの代わりに

x=vect.transform(test)
eval(x)

できるよ

K=10
for i in range(K):
    size=len(test)/K
    x=vect.transform(test[ i*size : (i+1)*size ])
    eval(x)

結果/統計を記録し、後で分析します。

特に

predictions = []

K=10
for i in range(K):
    size=len(test)/K
    x=vect.transform(test[ i*size : (i+1)*size ])
    predictions += rf.predict(x) # assuming it retuns a list of labels, otherwise - convert it to list

print accuracy_score( predictions, true_labels )
于 2014-02-25T06:50:07.300 に答える