15

マルチクラス分類の問題を解決し、一般化ブースト モデル (R の gbm パッケージ) を使用しようとしています。私が直面した問題: キャレットのtrain機能はmethod="gbm"、マルチクラス データを適切に処理しないようです。簡単な例を以下に示します。

library(gbm)
library(caret)
data(iris)
fitControl <- trainControl(method="repeatedcv",
                           number=5,
                           repeats=1,
                           verboseIter=TRUE)
set.seed(825)
gbmFit <- train(Species ~ ., data=iris,
                method="gbm",
                trControl=fitControl,
                verbose=FALSE)
gbmFit

出力は

+ Fold1.Rep1: interaction.depth=1, shrinkage=0.1, n.trees=150 
predictions failed for Fold1.Rep1: interaction.depth=1, shrinkage=0.1, n.trees=150 
- Fold1.Rep1: interaction.depth=1, shrinkage=0.1, n.trees=150 
+ Fold1.Rep1: interaction.depth=2, shrinkage=0.1, n.trees=150 
...
+ Fold5.Rep1: interaction.depth=3, shrinkage=0.1, n.trees=150 
predictions failed for Fold5.Rep1: interaction.depth=3, shrinkage=0.1, n.trees=150 
- Fold5.Rep1: interaction.depth=3, shrinkage=0.1, n.trees=150 
Aggregating results
Selecting tuning parameters
Fitting interaction.depth = numeric(0), n.trees = numeric(0), shrinkage = numeric(0) on full training set
Error in if (interaction.depth < 1) { : argument is of length zero

それでも、キャレット ラッパーなしで gbm を使用しようとすると、良い結果が得られます。

set.seed(1365)
train <- createDataPartition(iris$Species, p=0.7, list=F)
train.iris <- iris[train,]
valid.iris <- iris[-train,]
gbm.fit.iris <- gbm(Species ~ ., data=train.iris, n.trees=200, verbose=FALSE)
gbm.pred <- predict(gbm.fit.iris, valid.iris, n.trees=200, type="response")
gbm.pred <- as.factor(colnames(gbm.pred)[max.col(gbm.pred)]) ##!
confusionMatrix(gbm.pred, valid.iris$Species)$overall

参考までに、 でマークされた行のコードは、##!によって返されたクラス確率の行列predict.gbmを最も確率の高いクラスの係数に変換します。出力は

      Accuracy          Kappa  AccuracyLower  AccuracyUpper   AccuracyNull AccuracyPValue  McnemarPValue 
  9.111111e-01   8.666667e-01   7.877883e-01   9.752470e-01   3.333333e-01   8.467252e-16            NaN 

マルチクラスデータの gbm でキャレットを適切に機能させる方法について何か提案はありますか?

更新:

sessionInfo()
R version 2.15.3 (2013-03-01)
Platform: x86_64-pc-linux-gnu (64-bit)

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C               LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8    LC_PAPER=C                 LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C             LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       

attached base packages:
[1] splines   stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] e1071_1.6-1      class_7.3-5      gbm_2.0-8        survival_2.36-14 caret_5.15-61    reshape2_1.2.2   plyr_1.8        
 [8] lattice_0.20-13  foreach_1.4.0    cluster_1.14.3   compare_0.2-3   

loaded via a namespace (and not attached):
[1] codetools_0.2-8 compiler_2.15.3 grid_2.15.3     iterators_1.0.6 stringr_0.6.2   tools_2.15.3   
4

3 に答える 3

6

これは私が現在取り組んでいる問題です。

sessionInfo() の結果を投稿していただけると助かります。

また、 https: //code.google.com/p/gradientboostedmodels/ から最新の gbm を取得すると、問題が解決する場合があります。

マックス

于 2013-03-23T21:08:39.790 に答える
4

更新:キャレットはマルチクラス分類を実行できます。

クラス ラベルが英数字形式 (文字で始まる) であることを確認する必要があります。

例: データにラベル "1"、"2"、"3" がある場合、これらを "Seg1"、"Seg2"、"Seg3" に変更します。そうでない場合、キャレットは失敗します。

于 2014-08-10T06:15:32.737 に答える