これを認めるのは少し恥ずかしいですが、単純なプログラミングの問題であるべきものにかなり困惑しているようです。私は意思決定ツリーの実装を構築しており、再帰を使用してラベル付きサンプルのリストを取得し、リストを再帰的に半分に分割してツリーに変換しています。
残念ながら、深いツリーではスタック オーバーフロー エラーが発生するので (ha!)、最初に考えたのは、継続を使用して末尾再帰に変換することでした。残念ながら、Scala はそのような TCO をサポートしていないため、唯一の解決策はトランポリンを使用することです。トランポリンはちょっと効率が悪いようで、この問題に対する単純なスタックベースの必須の解決策があることを望んでいましたが、それを見つけるのに苦労しています。
再帰的なバージョンは次のようになります (簡略化):
private def trainTree(samples: Seq[Sample], usedFeatures: Set[Int]): DTree = {
if (shouldStop(samples)) {
DTLeaf(makeProportions(samples))
} else {
val featureIdx = getSplittingFeature(samples, usedFeatures)
val (statsWithFeature, statsWithoutFeature) = samples.partition(hasFeature(featureIdx, _))
DTBranch(
trainTree(statsWithFeature, usedFeatures + featureIdx),
trainTree(statsWithoutFeature, usedFeatures + featureIdx),
featureIdx)
}
}
したがって、基本的には、データのいくつかの機能に応じてリストを再帰的に 2 つに分割し、使用する機能のリストを通過させるので、繰り返しません。これはすべて「getSplittingFeature」関数で処理されるため、無視できます。コードは本当に簡単です!それでも、クロージャーを使用するだけでなく、効果的にトランポリンになるスタックベースのソリューションを見つけるのに苦労しています。少なくとも、スタック内の引数の小さな「フレーム」を維持する必要があることはわかっていますが、クロージャー呼び出しは避けたいと思います。
コールスタックとプログラムカウンターが暗黙的に再帰ソリューションで処理するものを明示的に書き出す必要があることがわかりましたが、継続なしでそれを行うのに問題があります。現時点では、効率性についてはほとんど問題ではありません。ただ興味があるだけです。時期尚早の最適化は諸悪の根源であり、トランポリン ベースのソリューションはおそらく問題なく機能することを思い出してください。私はそれがおそらくそうであることを知っています - これは基本的にそれ自体のためのパズルです.
この種の標準的なwhileループとスタックベースのソリューションが何であるかを誰か教えてもらえますか?
更新: Thipor Kong の優れたソリューションに基づいて、再帰バージョンの直接翻訳であるはずのアルゴリズムの while-loops/stacks/hashtable ベースの実装をコード化しました。これはまさに私が探していたものです:
最終更新: 順次整数インデックスを使用し、パフォーマンスのためにすべてをマップの代わりに配列に戻し、maxDepth サポートを追加し、最終的に再帰バージョンと同じパフォーマンスのソリューションを手に入れました (メモリ使用量についてはわかりませんが、私は少ないと思います):
private def trainTreeNoMaxDepth(startingSamples: Seq[Sample], startingMaxDepth: Int): DTree = {
// Use arraybuffer as dense mutable int-indexed map - no IndexOutOfBoundsException, just expand to fit
type DenseIntMap[T] = ArrayBuffer[T]
def updateIntMap[@specialized T](ab: DenseIntMap[T], idx: Int, item: T, dfault: T = null.asInstanceOf[T]) = {
if (ab.length <= idx) {ab.insertAll(ab.length, Iterable.fill(idx - ab.length + 1)(dfault)) }
ab.update(idx, item)
}
var currentChildId = 0 // get childIdx or create one if it's not there already
def child(childMap: DenseIntMap[Int], heapIdx: Int) =
if (childMap.length > heapIdx && childMap(heapIdx) != -1) childMap(heapIdx)
else {currentChildId += 1; updateIntMap(childMap, heapIdx, currentChildId, -1); currentChildId }
// go down
val leftChildren, rightChildren = new DenseIntMap[Int]() // heapIdx -> childHeapIdx
val todo = Stack((startingSamples, Set.empty[Int], startingMaxDepth, 0)) // samples, usedFeatures, maxDepth, heapIdx
val branches = new Stack[(Int, Int)]() // heapIdx, featureIdx
val nodes = new DenseIntMap[DTree]() // heapIdx -> node
while (!todo.isEmpty) {
val (samples, usedFeatures, maxDepth, heapIdx) = todo.pop()
if (shouldStop(samples) || maxDepth == 0) {
updateIntMap(nodes, heapIdx, DTLeaf(makeProportions(samples)))
} else {
val featureIdx = getSplittingFeature(samples, usedFeatures)
val (statsWithFeature, statsWithoutFeature) = samples.partition(hasFeature(featureIdx, _))
todo.push((statsWithFeature, usedFeatures + featureIdx, maxDepth - 1, child(leftChildren, heapIdx)))
todo.push((statsWithoutFeature, usedFeatures + featureIdx, maxDepth - 1, child(rightChildren, heapIdx)))
branches.push((heapIdx, featureIdx))
}
}
// go up
while (!branches.isEmpty) {
val (heapIdx, featureIdx) = branches.pop()
updateIntMap(nodes, heapIdx, DTBranch(nodes(child(leftChildren, heapIdx)), nodes(child(rightChildren, heapIdx)), featureIdx))
}
nodes(0)
}