2

一般的なデータセットを使用して、データセットで発生したエラーを再現しようとしています。何か不足している場合は修正してください。

Classificationを使用してツリーをフィッティングした後library(party)、各ノードでツリーの分割条件を取得しようとしています。バグが見つかるまで、問題なく動作していると信じていたコードを書くことができました。誰かがそれを解決するのを手伝ってくれますか?

私のコード:

 require(party)
 iris$Petal.Width <- as.factor(iris$Petal.Width)#imp to convert to factorial
 (ct <- ctree(Species ~ ., data = iris))
 plot(ct)
 #print(ct)
  a <- ct #convert it to s4 object
  t <- a@tree

 #recursive function to traverse the tree and get the splitting conditions

 recurse_tree <- function(tree,ret_list=list(),sub_list=list()){
   if(!tree$terminal){
     sub_list$assign <-list(tree$psplit$splitpoint,tree$psplit$variableName,class(tree$psplit))
     names(sub_list)[which(names(sub_list)=="assign")] <- paste("node",tree$nodeID,sep="")
     ret_list <- recurse_tree(tree$left, ret_list, sub_list)
     ret_list <- recurse_tree(tree$right, ret_list, sub_list)

  }
  if(tree$terminal){
    ret_list$assign <- c(sub_list, tree$prediction)
    names(ret_list)[which(names(ret_list)=="assign")] <- paste("node",tree$nodeID,sep="")
    return(ret_list)
   }
   return(ret_list)
  }

  result <- recurse_tree(t) #call to the functions

今、結果は私にすべてのノードと分割条件と予測のリストを与えます(私は仮定しました)。しかし、分割条件を確認すると、

  • Node5で予想される出力: {1.1, 1.2, 1.6, 1.7 }# from printing the tree print(ct), I get this
  • 私の関数から Node5に出力: {"1" , "1.3" ,"1.4" ,"1.5" } これは基本的に Node6 の分割条件であり、これは間違っています。どうやってこれを手に入れたのですか?

    z <- result[2] #I know node5 is second in the list

    z <- unlist(z,recursive = F,use.names = T) #unlist
    levels(z[[3]][[1]]) [which((z[[3]][[1]])==0)] #to find levels of corresponding values
    

詳細についてはプロット

私が疑問に思っているのは、私の function( recurse_tree) は常に、左側のノードではなく、右側のターミナル ノードの分割条件を示していることです。どんな助けでも大歓迎です。

4

0 に答える 0