11

Pythonで次の関数を作成しました:

def cross_validate(algorithms, data, labels, cv=4, n_jobs=-1):
    print "Cross validation using: "
    for alg, predictors in algorithms:
        print alg
        print
        # Compute the accuracy score for all the cross validation folds. 
        scores = cross_val_score(alg, data, labels, cv=cv, n_jobs=n_jobs)
        # Take the mean of the scores (because we have one for each fold)
        print scores
        print("Cross validation mean score = " + str(scores.mean()))

        name = re.split('\(', str(alg))
        filename = str('%0.5f' %scores.mean()) + "_" + name[0] + ".pkl"
        # We might use this another time 
        joblib.dump(alg, filename, compress=1, cache_size=1e9)  
        filenameL.append(filename)
        try:
            move(filename, "pkl")
        except:
            os.remove(filename) 

        print 
    return

相互検証を行うには、sklearn が関数に適合する必要があると考えました。

ただし、後で使用しようとすると(fは上記で保存したpklファイルですjoblib.dump(alg, filename, compress=1, cache_size=1e9))

alg = joblib.load(f)  
predictions = alg.predict_proba(train_data[predictors]).astype(float)

最初の行ではエラーは発生しません (ロードが機能しているように見えます) が、次の行ではNotFittedError: Estimator not fitted, call適切であることがわかりbefore exploiting the model.ます。

私は何を間違っていますか?適合したモデルを再利用して交差検証を計算することはできませんか? scikits Learn で cross_val_score を使用する場合は適合パラメーターを保持するを見ましたが、答えが理解できないか、探しているものではありません。私が望むのは、モデル全体を joblib で保存して、後で再調整せずに使用できるようにすることです。

4

3 に答える 3

13

交差検証がモデルに適合しなければならないというのは正しくありません。むしろ、k 分割交差検証は、部分的なデータ セットでモデルを k 回適合します。モデル自体が必要な場合は、実際にはデータセット全体にモデルを再度適合させる必要があります。これは実際には相互検証プロセスの一部ではありません。したがって、実際に呼び出すのは冗長ではありません

alg.fit(data, labels)

交差検証後にモデルに適合します。

別のアプローチは、特殊な関数 を使用するのではなく、cross_val_scoreこれをクロス検証グリッド検索の特殊なケースと考えることができます (パラメーター空間内の単一のポイントを使用)。この場合GridSearchCV、デフォルトで、データセット全体にわたってモデルをリフィットし (パラメータがありますrefit=True)、またAPIに メソッドpredictとメソッドがあります。predict_proba

于 2016-07-24T23:03:04.450 に答える
0

Cross_val_score は適合モデルを保持しません Cross_val_predict は保持します cross_val_predict_proba はありませんが、これを行うことができます

交差検証済みモデルの predict_proba

于 2016-07-24T22:50:27.253 に答える