はい、ウィキペディアの KD ツリーでの NN (Nearest Neighbour) 検索の説明は、理解するのが少し難しいです。NN KD Tree 検索で上位に表示される Google 検索結果の多くが単純に間違っていることは役に立ちません。
これを正しく行う方法を示す C++ コードを次に示します。
template <class T, std::size_t N>
void KDTree<T,N>::nearest (
const const KDNode<T,N> &node,
const std::array<T, N> &point, // looking for closest node to this point
const KDPoint<T,N> &closest, // closest node (so far)
double &minDist,
const uint depth) const
{
if (node->isLeaf()) {
const double dist = distance(point, node->leaf->point);
if (dist < minDist) {
minDist = dist;
closest = node->leaf;
}
} else {
const T dim = depth % N;
if (point[dim] < node->splitVal) {
// search left first
nearest(node->left, point, closest, minDist, depth + 1);
if (point[dim] + minDist >= node->splitVal)
nearest(node->right, point, closest, minDist, depth + 1);
} else {
// search right first
nearest(node->right, point, closest, minDist, depth + 1);
if (point[dim] - minDist <= node->splitVal)
nearest(node->left, point, closest, minDist, depth + 1);
}
}
}
KD ツリーでの NN 検索の API:
// Nearest neighbour
template <class T, std::size_t N>
const KDPoint<T,N> KDTree<T,N>::nearest (const std::array<T, N> &point) const {
const KDPoint<T,N> closest;
double minDist = std::numeric_limits<double>::max();
nearest(root, point, closest, minDist);
return closest;
}
デフォルトの距離関数:
template <class T, std::size_t N>
double distance (const std::array<T, N> &p1, const std::array<T, N> &p2) {
double d = 0.0;
for (uint i = 0; i < N; ++i) {
d += pow(p1[i] - p2[i], 2.0);
}
return sqrt(d);
}
編集: 一部の人々は、(NN アルゴリズムだけでなく) データ構造についても助けを求めているので、ここで私が使用したものを示します。目的によっては、データ構造を少し変更したい場合があります。(注: ただし、ほとんどの場合、NN アルゴリズムを変更したくないでしょう。)
KD ポイント クラス:
template <class T, std::size_t N>
class KDPoint {
public:
KDPoint<T,N> (std::array<T,N> &&t) : point(std::move(t)) { };
virtual ~KDPoint<T,N> () = default;
std::array<T, N> point;
};
KDNode クラス:
template <class T, std::size_t N>
class KDNode
{
public:
KDNode () = delete;
KDNode (const KDNode &) = delete;
KDNode & operator = (const KDNode &) = delete;
~KDNode () = default;
// branch node
KDNode (const T split,
std::unique_ptr<const KDNode> &lhs,
std::unique_ptr<const KDNode> &rhs) : splitVal(split), left(std::move(lhs)), right(std::move(rhs)) { };
// leaf node
KDNode (std::shared_ptr<const KDPoint<T,N>> p) : splitVal(0), leaf(p) { };
bool isLeaf (void) const { return static_cast<bool>(leaf); }
// data members
const T splitVal;
const std::unique_ptr<const KDNode<T,N>> left, right;
const std::shared_ptr<const KDPoint<T,N>> leaf;
};
KDTree クラス: (注: ツリーを構築/埋めるためにメンバー関数を追加する必要があります。)
template <class T, std::size_t N>
class KDTree {
public:
KDTree () = delete;
KDTree (const KDTree &) = delete;
KDTree (KDTree &&t) : root(std::move(const_cast<std::unique_ptr<const KDNode<T,N>>&>(t.root))) { };
KDTree & operator = (const KDTree &) = delete;
~KDTree () = default;
const KDPoint<T,N> nearest (const std::array<T, N> &point) const;
// Nearest neighbour search - runs in O(log n)
void nearest (const std::unique_ptr<const KDNode<T,N>> &node,
const std::array<T, N> &point,
std::shared_ptr<const KDPoint<T,N>> &closest,
double &minDist,
const uint depth = 0) const;
// data members
const std::unique_ptr<const KDNode<T,N>> root;
};