1

平均シフト クラスタリングに問題があります。クラスター数が少ない場合 (2、3、4) は非常に高速に動作し、正しい結果を出力しますが、クラスター数が増えると失敗します。

たとえば、3 つのクラスターが正常に検出されます。 クラスタの成功

しかし、数が増えると失敗します: クラスター センターの障害 クラスターが失敗する

完全なコード リストは次のとおりです。

#!/usr/bin/env python

import sys
import logging

import numpy as np

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plot

from sklearn.cluster import estimate_bandwidth, MeanShift, get_bin_seeds
from sklearn.datasets.samples_generator import make_blobs


def test_mean_shift():
    logging.debug('Generating mixture')
    count = 5000
    blocks = 7
    std_error = 0.5
    mixture, clusters = make_blobs(n_samples=count, centers=blocks, cluster_std=std_error)

    logging.debug('Measuring bendwith')
    bandwidth = estimate_bandwidth(mixture)
    logging.debug('Bandwidth: %r' % bandwidth)

    mean_shift = MeanShift(bandwidth=bandwidth)

    logging.debug('Clustering')
    mean_shift.fit(mixture)

    shifted = mean_shift.cluster_centers_
    guess = mean_shift.labels_

    logging.debug('Centers: %r' % shifted)

    def draw_mixture(mixture, clusters, output='mixture.png'):
        plot.clf()
        plot.scatter(mixture[:, 0], mixture[:, 1],
                     c=clusters,
                     cmap=plot.cm.coolwarm)
        plot.savefig(output)

    def draw_mixture_shifted(mixture, shifted, output='mixture_shifted.png'):
        plot.clf()
        plot.scatter(mixture[:, 0], mixture[:, 1], c='r')
        plot.scatter(shifted[:, 0], shifted[:, 1], c='b')
        plot.savefig(output)

    logging.debug('Drawing')
    draw_mixture_shifted(mixture, shifted)
    draw_mixture(mixture, guess)


if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)

    test_mean_shift()

私は何を間違っていますか?

4

1 に答える 1

1

おそらく、より狭い帯域幅を選択する必要があります。私は、ヒューリスティックによって帯域幅が選択される方法にあまり詳しくありません。したがって、ここでの「問題」は、実際のアルゴリズムではなく、ヒューリスティックにあります。

于 2013-01-28T20:38:41.240 に答える