7

ツリーから生成された1つのルールをプログラムでテストしたいと思います。ツリーでは、ルートとリーフ(ターミナルノード)の間のパスは、原則として解釈できます。

Rでは、rpartパッケージを使用して次のことを行うことができます:(この投稿ではiris、例としてのみデータセットを使用します)

library(rpart)
model <- rpart(Species ~ ., data=iris)

この2行modelで、クラスがrpart.objectrpartドキュメント、21ページ)であるという名前のツリーを取得しました。このオブジェクトには多くの情報があり、さまざまなメソッドをサポートしています。特に、オブジェクトにはframe変数(標準的な方法でアクセスできます: model$frame)(idem)とメソッドpath.rpathrpartドキュメント、7ページ)があり、ルートノードから目的のノードへのパスを提供しますnode(働き)

変数のrow.namesframeは、ツリーのノード番号が含まれています。このvar列には、ノード内の分割変数、yval近似値とyval2クラス確率、およびその他の情報が表示されます。

> model$frame
           var   n  wt dev yval complexity ncompete nsurrogate     yval2.1     yval2.2     yval2.3     yval2.4     yval2.5     yval2.6     yval2.7
1 Petal.Length 150 150 100    1       0.50        3          3  1.00000000 50.00000000 50.00000000 50.00000000  0.33333333  0.33333333  0.33333333
2       <leaf>  50  50   0    1       0.01        0          0  1.00000000 50.00000000  0.00000000  0.00000000  1.00000000  0.00000000  0.00000000
3  Petal.Width 100 100  50    2       0.44        3          3  2.00000000  0.00000000 50.00000000 50.00000000  0.00000000  0.50000000  0.50000000
6       <leaf>  54  54   5    2       0.00        0          0  2.00000000  0.00000000 49.00000000  5.00000000  0.00000000  0.90740741  0.09259259
7       <leaf>  46  46   1    3       0.01        0          0  3.00000000  0.00000000  1.00000000 45.00000000  0.00000000  0.02173913  0.97826087

ただし<leaf>var列でマークされているのはターミナルノード(リーフ)のみです。この場合、ノードは2、6、および7です。

上記のようにpath.rpart、ルールを抽出する方法を使用できます(この手法は、 rattleパッケージおよびSharmaクレジットスコアの記事で次のように使用されます。

さらに、モデルは予測値の値を

predicted.levels <- attr(model, "ylevels")

yvalこの値は、model$frameデータセットの列に対応しています。

ノード番号7(行番号5)のリーフの場合、予測値は次のようになります。

> ylevels[model$frame[5, ]$yval]
[1] "virginica"

ルールは

> rule <- path.rpart(model, nodes = 7)

 node number: 7 
   root
   Petal.Length>=2.45
   Petal.Width>=1.75

したがって、ルールは次のように読み取ることができます

If Petal.Length >= 2.45 AND Petal.Width >= 1.75 THEN Species = Virginica

このルールに対して真陽性がいくつあるかをテストできることを知っています(テストデータセットでは、アイリスデータセットを再度使用します)。新しいデータセットを次のようにサブセット化します。

> hits <- subset(iris, Petal.Length >= 2.45 & Petal.Width >= 1.75)

次に混同行列を計算します

> table(hits$Species, hits$Species == "virginica")

             FALSE TRUE
  setosa         0    0
  versicolor     1    0
  virginica      0   45

(注:テストと同じアイリスデータセットを使用しました)

プログラムでルールを評価するにはどうすればよいですか?次のようにルールから条件を抽出できます

> unlist(rule, use.names = FALSE)[-1]
[1] "Petal.Length>=2.45" "Petal.Width>=1.75" 

しかし、どうすればここから続けることができますか?subset機能が使えない

前もって感謝します

注: この質問は、わかりやすくするために大幅に編集されています

4

3 に答える 3

3

私はこれを次のように解決することができます

免責事項:明らかにこれを解決するためのより良い方法である必要がありますが、このハックは機能し、私が望むことを実行します...(私はそれをあまり誇りに思っていません...ハックですが、機能します)

じゃあ始めよう。基本的にアイデアはパッケージを使用することですsqldf

質問をチェックすると、コードの最後の部分が、ツリーのパスのすべての部分をリストに入れます。だから、そこから始めます

        library(sqldf)
        library(stringr)

        # Transform to a character vector
        rule.v <- unlist(rule, use.names=FALSE)[-1]
        # Remove all the dots, sqldf doesn't handles dots in names 
        rule.v <- str_replace_all(rule.v, pattern="([a-zA-Z])\\.([a-zA-Z])", replacement="\\1_\\2")
        # We have to remove all the equal signs to 'in ('
        rule.v <- str_replace_all(rule.v, pattern="([a-zA-Z0-9])=", replacement="\\1 in ('")
        # Embrace all the elements in the lists of values with " ' " 
        # The last element couldn't be modified in this way (Any ideas?) 
        rule.v <- str_replace_all(rule.v, pattern=",", replacement="','")

        # Close the last element with apostrophe and a ")" 
        for (i in which(!is.na(str_extract(pattern="in", string=rule.v)))) {
          rule.v[i] <- paste(append(rule.v[i], "')"), collapse="")
        }

        # Collapse all the list in one string joined by " AND "
        rule.v <- paste(rule.v, collapse = " AND ")

        # Generate the query
        # Use any metric that you can get from the data frame
        query <- paste("SELECT Species, count(Species) FROM iris WHERE ", rule.v, " group by Species", sep="")

        # For debug only...
        print(query)

        # Execute and print the results
        print(sqldf(query))

そしてそれがすべてです!

私はあなたに警告しました、それはハックでした...

うまくいけば、これは他の誰かを助ける...

すべての助けと提案をありがとう!

于 2012-08-14T19:47:40.283 に答える
2

一般的には使用をお勧めしませんeval(parse(...))が、この場合は機能するようです。

ルールを抽出します。

rule <- unname(unlist(path.rpart(model, nodes=7)))[-1]

 node number: 7 
   root
   Petal.Length>=2.45
   Petal.Width>=1.75
rule
[1] "Petal.Length>=2.45" "Petal.Width>=1.75" 

ルールを使用してデータを抽出します。

node_data <- with(iris, iris[eval(parse(text=paste(rule, collapse=" & "))), ])
head(node_data)

    Sepal.Length Sepal.Width Petal.Length Petal.Width    Species
71           5.9         3.2          4.8         1.8 versicolor
101          6.3         3.3          6.0         2.5  virginica
102          5.8         2.7          5.1         1.9  virginica
103          7.1         3.0          5.9         2.1  virginica
104          6.3         2.9          5.6         1.8  virginica
105          6.5         3.0          5.8         2.2  virginica
于 2014-07-14T07:59:06.093 に答える
1

で始まります

Rule number: 16 [yval=bad cover=220 N=121 Y=99 (37%) prob=0.04]
checking< 2.5
afford< 54
history< 3.5
coapp< 2.5

すべてゼロで始まる「prob」ベクトルがあり、rule16で更新できます。

prob <- ifelse( dat[['checking']] < 2.5 &
                dat[['afford']]  < 54
                dat[['history']] < 3.5
                dat[['coapp']]  < 2.5) , 0.04, prob )

次に、他のすべてのルールを実行する必要があります(ツリーは互いに素な推定値である必要があるため、この場合の確率は変更されません)。予測を作成するには、これよりもはるかに効率的な方法があります。たとえば...predict.rpart関数。

于 2012-08-06T18:59:27.890 に答える