MATLAB で自己組織化マップ (SOM) アルゴリズムを実装しました。各データ ポイントは 2 次元空間で表されているとします。問題は、トレーニング段階での各データ ポイントの動きを視覚化したいということです。つまり、アルゴリズムの進行中に、たとえば一定のエポック数ごとに、ポイントがどのように移動し、最終的にクラスターを形成するかを確認したいのです。これは MATLAB のシミュレーションで実現できると思いますが、視覚化のための MATLAB コードをどのように組み込めばよいかわかりません。
回答 1 件
2 次元へのあらゆるデータ射影を使って、多次元のクラスタリング データを視覚化するコード例を作成しました。特に次元数が大きい場合、これは視覚化の最良の方法ではないかもしれません (この目的のために開発された手法があり、SOM 自体もその用途に使用できます)。しかし、可能な射影の数 (n-1)! がそれほど多くない場合には、十分に良いビジュアライザーになります。
クラスターアルゴリズム
各反復のクラスター平均とクラスター ラベルを保存できるようにコードにアクセスする必要があったため、Mo ChenによるFEX で入手可能な高速 kmeans アルゴリズムを使用しましたが、このアクセスができるように適応させる必要がありました。適応されたコードは次のとおりです。
function [label,m] = litekmeans(X, k)
% LITEKMEANS Perform k-means clustering and keep the per-iteration history.
%
%   [label,m] = litekmeans(X, k)
%
% Inputs:
%   X: d x n data matrix (one sample per column).
%   k: number of seeds (initial number of clusters).
%
% Outputs:
%   label: cell array of 1 x n label vectors. label{1} is the random
%          initial assignment (values in 1..k); label{i} for i > 1 is
%          the assignment of every sample to the nearest center stored
%          in m{i-1}.
%   m:     cell array of center matrices, one column per cluster.
%          m{i} holds the means of the groups found in label{i} after
%          empty clusters were dropped and the remaining ones renumbered.
%          The last cell repeats the converged centers, so
%          numel(m) == numel(label).
%
% NOTE: label{i} is stored BEFORE the renumbering that removes empty
% clusters, so it may contain indices that are not columns of m{i}. Use
% the commented snippet at the bottom to renumber the labels if needed.
%
% Written by Michael Chen (sth4nth@gmail.com); adapted to record the
% labels and centers of every iteration.
n = size(X,2);
last = 0;
iter = 1;
label{iter} = ceil(k*rand(1,n)); % random initialization
checkLabel = label{iter};
m = {};
% With zero samples the loop below never executes; start from an empty
% center matrix so that the final "m{iter} = curM" stays valid instead
% of failing on an undefined variable.
curM = zeros(size(X,1),0);
while any(checkLabel ~= last)
    [u,~,checkLabel] = unique(checkLabel);   % remove empty clusters
    k = length(u);
    E = sparse(1:n,checkLabel,1,n,k,n);  % transform label into indicator matrix
    curM = X*(E*spdiags(1./sum(E,1)',0,k,k));    % compute m of each cluster
    m{iter} = curM;
    last = checkLabel';
    % Maximizing c'*x - |c|^2/2 is equivalent to minimizing |x - c|^2:
    [~,checkLabel] = max(bsxfun(@minus,curM'*X,dot(curM,curM,1)'/2),[],1); % assign samples to the nearest centers
    iter = iter + 1;
    label{iter} = checkLabel;
end
% Get last clusters centers (same centers as the previous cell; keeps
% label and m the same length):
m{iter} = curM;
% If to remove empty clusters:
%for k=1:iter
%  [~,~,label{k}] = unique(label{k});
%end
GIF作成
gif の作成には、@Amro の Matlab ビデオ チュートリアルも使用しました。
識別可能な色
クラスターの色を簡単に区別できるようにするために、Tim Holyによるこの素晴らしい FEXを使用しました。
結果のコード
私の結果のコードは次のとおりです。反復ごとにクラスターの数が変化するため、散布図を更新すると、エラーを一切出さずにすべてのクラスター中心が消えてしまうという問題がありました。それに気付かなかったので、ウェブで見つけたあらゆるあいまいな方法で scatter 関数を回避しようとしていました (ちなみに、ここで本当に素晴らしい散布図の代替を見つけました)。しかし幸いなことに、今日この問題に立ち返って、何が起きていたのかがわかりました。これが私が作成したコードです。自由に使用したり改変したりしてかまいませんが、使用する場合は私への参照を残してください。
function varargout=kmeans_test(data,nClusters,plotOpts,dimLabels,...
    bigXDim,bigYDim,gifName)
%
% [label,m,figH,handles]=kmeans_test(data,nClusters,plotOpts,...
%     dimLabels,bigXDim,bigYDim,gifName)
%   Demonstrate kmeans algorithm iterative progress and save it as an
%   animated gif. Every pair of dimensions is drawn on a small axes of a
%   lower-triangular grid; one chosen pair is repeated on a big axes.
%   Any input may be omitted (or given as empty) to use its default,
%   shown in parentheses. Inputs are:
%
%    -> data (7 random blobs in 5 dimensions, 20 samples each): the data
%    to use, one sample per column (nDim x nSamples).
%
%    -> nClusters (7): number of clusters to use.
%
%    -> plotOpts: struct holding the following fields:
%
%      o leftBase: the percentage distance from the left
%
%      o rightBase: the percentage distance from the right
%
%      o bottomBase: the percentage distance from the bottom
%
%      o topBase: the percentage distance from the top
%
%      o FontSize: FontSize for axes labels.
%
%      o widthUsableArea: Total width occupied by axes
%
%      o heigthUsableArea: Total heigth occupied by axes
%
%    -> dimLabels ('Dim 1', 'Dim 2', ...): cell array with one label per
%    dimension.
%
%    -> bigXDim (1): the big subplot x dimension
%
%    -> bigYDim (2): the big subplot y dimension. The big subplot is only
%    filled when bigYDim > bigXDim, since only those pairs are iterated.
%
%    -> gifName ('kmeansClusterization.gif'): gif file name to save
%
%   Outputs are:
%
%    -> label: Sample cluster center number for each iteration
%
%    -> m: cluster center mean for each iteration
%
%    -> figH: figure handle
%
%    -> handles: axes handles
%
%   Requires litekmeans (adapted to return per-iteration cells) and
%   distinguishable_colors on the MATLAB path.
%
%
% - Creation Date: Fri, 13 Sep 2013
% - Last Modified: Mon, 16 Sep 2013
% - Author(s):
%   - W.S.Freund <wsfreund_at_gmail_dot_com>
%
% TODO List (?):
%
% - Use input parser
% - Adapt it to be able to cluster any algorithm function.
% - Use arrows indicating cluster centers movement before moving them.
% - Drag and drop small axes to big axes.
%

  % Pre-start: resolve defaults starting from the FIRST argument, because
  % the dimLabels default is derived from data (so data must exist
  % before it is computed).
  if nargin < 1 || isempty(data)
    centers = [eye(5) [0; 0; 0; 0; 1.5] [0; 0; 0; 1.5; 1]];
    spread = [.5 .5 .5 .5 .5 .2 .2];
    data = zeros(5,20*numel(spread));
    for curCenter = 1:numel(spread)
      data(:,(curCenter-1)*20+(1:20)) = bsxfun(@plus,...
        centers(:,curCenter),spread(curCenter)*rand(5,20));
    end
  end
  nDim = size(data,1);
  if nargin < 2 || isempty(nClusters)
    nClusters = 7;
  end
  if nargin < 3 || isempty(plotOpts)
    plotOpts = struct('leftBase',.05,'rightBase',.02,...
      'bottomBase',.05,'topBase',.02,'FontSize',10,...
      'widthUsableArea',.87,'heigthUsableArea',.87);
  end
  if nargin < 4 || isempty(dimLabels)
    % Zero-padded labels so that all of them have the same width:
    maxDigits = numel(num2str(nDim));
    labelFormat = ['Dim %0' num2str(maxDigits) 'd'];
    dimLabels = arrayfun(@(curDim) sprintf(labelFormat,curDim),...
      1:nDim,'UniformOutput',false);
  end
  if nargin < 5 || isempty(bigXDim)
    bigXDim = 1;
  end
  if nargin < 6 || isempty(bigYDim)
    bigYDim = 2;
  end
  if nargin < 7 || isempty(gifName)
    gifName = 'kmeansClusterization.gif';
  end

  % NOTE of advice: It seems that Matlab does not test while on
  % refreshdata if the dimension of the inputs are equivalent for the
  % XData, YData and CData while using scatter. Because of this I wasted
  % a lot of time trying to debug what was the problem, trying many
  % workaround since my cluster centers would disappear for no reason.

  % Draw axes:
  figH = figure;
  set(figH,'Units', 'normalized', 'Position',...
    [0, 0, 1, 1],'Color','w','Name',...
    'k-means example','NumberTitle','Off',...
    'MenuBar','none','Toolbar','figure',...
    'Renderer','zbuffer');

  % Create distinguishable colors matrix (one RGB row per cluster):
  colorMatrix = distinguishable_colors(nClusters);

  % Create axes, deploy them on handles matrix more or less how they
  % will be positioned. The grid must be nDim x nDim, since the drawing
  % loop below indexes handles(curLine,curColumn) up to nDim:
  [handles,horSpace,vertSpace] = ...
    createAxesGrid(nDim,nDim,plotOpts,dimLabels);

  % Add main axes: it takes the bigSubSize x bigSubSize block of subplot
  % cells on the top-right corner, which never overlaps the strictly
  % lower-triangular small axes.
  bigSubSize = ceil(nDim/2);
  firstBigColumn = nDim-bigSubSize+1;
  bigSubVec = zeros(1,bigSubSize^2);
  for k = 0:bigSubSize-1
    bigSubVec(k*bigSubSize+1:(k+1)*bigSubSize) = ...
      firstBigColumn+nDim*k:nDim*(k+1);
  end
  handles(bigSubSize,bigSubSize) = subplot(nDim,nDim,bigSubVec,...
    'FontSize',plotOpts.FontSize,'box','on');
  bigSubplotH = handles(bigSubSize,bigSubSize);
  horSpace(bigSubSize,bigSubSize) = bigSubSize;
  vertSpace(bigSubSize,bigSubSize) = bigSubSize;
  set(bigSubplotH,'NextPlot','add',...
    'FontSize',plotOpts.FontSize,'box','on',...
    'XAxisLocation','top','YAxisLocation','right');

  % Squeeze axes through space to optimize space usage and improve
  % visualization capability:
  [leftPos,botPos,subplotWidth,subplotHeight]=setCustomPlotArea(...
    handles,plotOpts,horSpace,vertSpace);

  % Color legend on the bottom-right grid cell:
  pColorAxes = axes('Position',[leftPos(end) botPos(end) ...
    subplotWidth subplotHeight],'Parent',figH);
  drawClusterLegend(pColorAxes,colorMatrix,plotOpts.FontSize);

  % Now iterate through data and get cluster information:
  [label,m]=litekmeans(data,nClusters);
  nIters = numel(m)-1;

  % scatterColors, clusterColors and the dimData_%d variables are read by
  % refreshdata through the *DataSource properties (evaluated in this
  % workspace), which is why they look unused and why eval is needed.
  scatterColors = colorMatrix(label{1},:); %#ok<NASGU>
  annH = annotation('textbox',[leftPos(1),botPos(1) subplotWidth ...
    subplotHeight],'String',sprintf('Start Conditions'),'EdgeColor',...
    'none','FontSize',18);

  % Creates dimData_%d variables for first iteration:
  for curDim=1:nDim
    eval(sprintf('dimData_%d = m{1}(curDim,:);',curDim));
  end

  % clusterColors will hold the colors for the total number of clusters
  % on each iteration. litekmeans drops clusters that are empty, so even
  % the first iteration may have fewer centers than nClusters; the color
  % rows must match the number of centers or scatter fails:
  clusterColors = colorMatrix(1:size(m{1},2),:);

  % Draw cluster information for first iteration:
  for curColumn=1:nDim
    xSourceName = sprintf('dimData_%d',curColumn);
    for curLine=curColumn+1:nDim
      ySourceName = sprintf('dimData_%d',curLine);
      % Every pair goes to its small subplot; the chosen pair is drawn
      % on the big subplot first as well:
      targetAxes = handles(curLine,curColumn);
      if curColumn == bigXDim && curLine == bigYDim
        targetAxes = [bigSubplotH targetAxes]; %#ok<AGROW>
      end
      for curAxes = targetAxes
        % Draw data:
        curScatter = scatter(curAxes,data(curColumn,:),...
          data(curLine,:),16,'filled');
        set(curScatter,'CDataSource','scatterColors');
        % Draw cluster centers
        curScatter = scatter(curAxes,m{1}(curColumn,:),...
          m{1}(curLine,:),100,clusterColors,'^','filled');
        set(curScatter,'XDataSource',xSourceName,'YDataSource',...
          ySourceName,'CDataSource','clusterColors');
      end
      % Only the outer small subplots carry labels and limit ticks:
      curAxes = handles(curLine,curColumn);
      if curLine==nDim
        xlabel(curAxes,dimLabels{curColumn});
        set(curAxes,'XTick',xlim(curAxes));
      end
      if curColumn==1
        ylabel(curAxes,dimLabels{curLine});
        set(curAxes,'YTick',ylim(curAxes));
      end
    end
  end

  refreshdata(figH,'caller');

  % Preallocate gif frame. From Amro's tutorial here:
  % https://stackoverflow.com/a/11054155/1162884
  % Layout: 1 start frame + nIters iteration frames + 3 still frames at
  % convergence = nIters+4 frames.
  f = getframe(figH);
  [f,map] = rgb2ind(f.cdata, 256, 'nodither');
  mov = repmat(f, [1 1 1 nIters+4]);

  % Add one frame at start conditions:
  curFrame = 1;
  f = getframe(figH);
  mov(:,:,1,curFrame) = rgb2ind(f.cdata, map, 'nodither');

  for curIter = 1:nIters
    curFrame = curFrame+1;
    % Change label text
    set(annH,'String',sprintf('Iteration %d',curIter));
    % Update cluster point colors
    scatterColors = colorMatrix(label{curIter+1},:); %#ok<NASGU>
    % Update cluster centers:
    for curDim=1:nDim
      eval(sprintf('dimData_%d = m{curIter+1}(curDim,:);',curDim));
    end
    % Update cluster colors (must be updated together with the centers,
    % see the NOTE of advice above):
    nClusterIter = size(m{curIter+1},2);
    clusterColors = colorMatrix(1:nClusterIter,:);
    % Update graphics:
    refreshdata(figH,'caller');
    % If number of cluster of current iteration differs from previous
    % iteration we redraw the color legend:
    if nClusterIter~=size(m{curIter},2)
      drawClusterLegend(pColorAxes,clusterColors,plotOpts.FontSize);
    end
    f = getframe(figH);
    mov(:,:,1,curFrame) = rgb2ind(f.cdata, map, 'nodither');
  end

  set(annH,'String','Convergence Conditions');

  % Add three frames without movement at convergence conditions. They
  % go AFTER the last iteration frame (nIters+1), filling the remaining
  % preallocated frames.
  for curFrame = nIters+2:nIters+4
    f = getframe(figH);
    mov(:,:,1,curFrame) = rgb2ind(f.cdata, map, 'nodither');
  end

  imwrite(mov, map, gifName, 'DelayTime',.5, 'LoopCount',inf)

  % Hand out only as many outputs as were requested:
  outputs = {label,m,figH,handles};
  varargout = cell(1,nargout);
  nOut = min(nargout,numel(outputs));
  varargout(1:nOut) = outputs(1:nOut);

end

function drawClusterLegend(legendAxes,colors,fontSize)
% Paint one colored cell per row of colors on legendAxes and label each
% cell with its cluster number.

  nColors = size(colors,1);
  % pcolor draws cell k between x = k and x = k+1, using value k:
  pcolor(legendAxes,[1:nColors+1;1:nColors+1]);
  colormap(legendAxes,colors);
  % Ticks on the cell centers, so that tick k is simply labeled k:
  set(legendAxes,'XTick',1.5:1:nColors+.5,'XTickLabel',1:nColors,...
    'YTick',[]);
  xlabel(legendAxes,'Clusters Colors','FontSize',fontSize);

end
function [leftPos,botPos,subplotWidth,subplotHeight] = ...
    setCustomPlotArea(handles,plotOpts,horSpace,vertSpace)
%
% Compute the position of every cell of an axes grid and move the
% existing axes to their cells.
%
%   -> handles: axes handles, one matrix entry per grid cell (rows are
%   grid lines counted from the top, columns are grid columns counted
%   from the left). Entries equal to zero are empty cells.
%
%   -> plotOpts: struct holding the following fields:
%
%     o leftBase: the percentage distance from the left
%
%     o rightBase: the percentage distance from the right
%
%     o bottomBase: the percentage distance from the bottom
%
%     o topBase: the percentage distance from the top
%
%     o widthUsableArea: Total width occupied by axes
%
%     o heigthUsableArea: Total heigth occupied by axes
%
%   -> horSpace (all ones): the axes units size (integers only) that
%   current axes should occupy in the horizontal (considering that other
%   occupied axes handles are empty)
%
%   -> vertSpace (all ones): the axes units size (integers only) that
%   current axes should occupy in the vertical (considering that other
%   occupied axes handles are empty)
%
% Outputs are:
%
%   -> leftPos: left coordinate of each grid column (1 x nColumns)
%
%   -> botPos: bottom coordinate of each grid line (1 x nLines); the
%   first line is the top one.
%
%   -> subplotWidth, subplotHeight: size of one grid cell.
%

  % Lines (rows) stack vertically, columns spread horizontally:
  nLines = size(handles,1);
  nColumns = size(handles,2);

  % Each missing size matrix defaults to one unit per cell:
  if nargin < 3
    horSpace = ones(nLines,nColumns);
  end
  if nargin < 4
    vertSpace = ones(nLines,nColumns);
  end

  subplotWidth = plotOpts.widthUsableArea/nColumns;
  subplotHeight = plotOpts.heigthUsableArea/nLines;

  totalWidth = (1-plotOpts.rightBase) - plotOpts.leftBase;
  totalHeight = (1-plotOpts.topBase) - plotOpts.bottomBase;

  % Area not used by the axes is split evenly as gaps between cells:
  gapHeigthSpace = (totalHeight - ...
    plotOpts.heigthUsableArea)/nLines;
  gapWidthSpace = (totalWidth - ...
    plotOpts.widthUsableArea)/nColumns;

  % Half a gap of margin before the first cell on each direction (the
  % vertical origin uses the vertical gap, the horizontal origin the
  % horizontal gap):
  botPos = zeros(1,nLines);
  leftPos = zeros(1,nColumns);
  botPos(nLines) = plotOpts.bottomBase + gapHeigthSpace/2;
  leftPos(1) = plotOpts.leftBase + gapWidthSpace/2;

  % The last line is the bottom one, so walk upwards from it:
  botPos(nLines-1:-1:1) = botPos(nLines) + (subplotHeight +...
    gapHeigthSpace)*(1:nLines-1);
  leftPos(2:nColumns) = leftPos(1) + (subplotWidth +...
    gapWidthSpace)*(1:nColumns-1);

  for curLine=1:nLines
    for curColumn=1:nColumns
      if handles(curLine,curColumn) % zero means empty cell
        set(handles(curLine,curColumn),'Position',[leftPos(curColumn)...
          botPos(curLine) horSpace(curLine,curColumn)*subplotWidth ...
          vertSpace(curLine,curColumn)*subplotHeight]);
      end
    end
  end

end
function [handles,horSpace,vertSpace] = ...
    createAxesGrid(nLines,nColumns,plotOpts,dimLabels,colorMatrix)
%
% Create the strictly lower-triangular grid of small axes (one for each
% pair line > column) on the current figure.
%
%   -> nLines, nColumns: grid size.
%
%   -> plotOpts: struct; only its FontSize field is used here.
%
%   -> dimLabels: cell array of labels. The bottom line of axes gets
%   dimLabels{column} as x label, the first column gets
%   dimLabels{line} as y label; the other axes get no ticks on that
%   direction.
%
%   -> colorMatrix (optional): colormap to apply to each axes when
%   running on HG2 graphics. It is skipped when omitted or empty.
%
% Outputs are:
%
%   -> handles: nLines x nColumns matrix with the axes handles; cells
%   without axes are zero.
%
%   -> horSpace, vertSpace: nLines x nColumns matrices with the size, in
%   grid units, of each axes (1 where axes exist, 0 elsewhere).
%

  handles = zeros(nLines,nColumns);
  % Those hold the axes size units:
  horSpace = zeros(nLines,nColumns);
  vertSpace = zeros(nLines,nColumns);

  % colorMatrix used to be read here without being in scope of this
  % function, which raised an error on HG2; only use it when the caller
  % actually hands it over.
  applyColormap = nargin >= 5 && ~isempty(colorMatrix);

  for curColumn=1:nColumns
    for curLine=curColumn+1:nLines
      % subplot numbers its cells row by row:
      handles(curLine,curColumn) = subplot(nLines,...
        nColumns,curColumn+(curLine-1)*nColumns);
      horSpace(curLine,curColumn) = 1;
      vertSpace(curLine,curColumn) = 1;
      curAxes = handles(curLine,curColumn);
      if applyColormap && feature('UseHG2')
        colormap(handle(curAxes),colorMatrix);
      end
      % 'add' keeps later scatter calls from resetting the axes:
      set(curAxes,'NextPlot','add',...
        'FontSize',plotOpts.FontSize,'box','on');
      if curLine==nLines
        xlabel(curAxes,dimLabels{curColumn});
      else
        set(curAxes,'XTick',[]);
      end
      if curColumn==1
        ylabel(curAxes,dimLabels{curLine});
      else
        set(curAxes,'YTick',[]);
      end
    end
  end

end
例
コードを使用して、5 つのディメンションを使用する例を次に示します。
% Cluster centers of the 5-dimensional example: one on each axis, plus
% two extra ones close to the 4th and 5th axes.
center1 = [1 0 0 0 0].';
center2 = [0 1 0 0 0].';
center3 = [0 0 1 0 0].';
center4 = [0 0 0 1 0].';
center5 = [0 0 0 0 1].';
center6 = [0 0 0 0 1.5].';
center7 = [0 0 0 1.5 1].';
% 20 samples per center, each center shifted by uniform noise; the last
% two groups are tighter (.2 instead of .5).
data = [ ...
  (repmat(center1,1,20) + .5*rand(5,20)), ...
  (repmat(center2,1,20) + .5*rand(5,20)), ...
  (repmat(center3,1,20) + .5*rand(5,20)), ...
  (repmat(center4,1,20) + .5*rand(5,20)), ...
  (repmat(center5,1,20) + .5*rand(5,20)), ...
  (repmat(center6,1,20) + .2*rand(5,20)), ...
  (repmat(center7,1,20) + .2*rand(5,20)) ...
  ];
% Start from 20 clusters and animate the k-means iterations:
[label,m,figH,handles]=kmeans_test(data,20);