以下に、BNTツールボックスを使用して単純ベイズネットを構築する方法を示す完全な例を示します。車のデータセットのサブセットを使用しています。離散属性と連続属性の両方が含まれています。
便宜上、統計ツールボックスを必要とするいくつかの関数を使用しています。
データセットを準備することから始めます。
%# load dataset
D = load('carsmall');
%# keep only features of interest
D = rmfield(D, {'Mfg','Horsepower','Displacement','Model'});
%# filter the rows to keep only two classes
idx = ismember(D.Origin, {'USA' 'Japan'});
D = structfun(@(x)x(idx,:), D, 'UniformOutput',false);
numInst = sum(idx);
%# replace missing values with mean
D.MPG(isnan(D.MPG)) = nanmean(D.MPG);
%# convert discrete attributes to numeric indices 1:mx
[D.Origin,~,gnOrigin] = grp2idx( cellstr(D.Origin) );
[D.Cylinders,~,gnCylinders] = grp2idx( D.Cylinders );
[D.Model_Year,~,gnModel_Year] = grp2idx( D.Model_Year );
次に、グラフィカルモデルを作成します。
%# info about the nodes
nodeNames = fieldnames(D);
numNodes = numel(nodeNames);
node = [nodeNames num2cell((1:numNodes)')]';
node = struct(node{:});
dNodes = [node.Origin node.Cylinders node.Model_Year];
cNodes = [node.MPG node.Weight node.Acceleration];
depNodes = [node.MPG node.Cylinders node.Weight ...
node.Acceleration node.Model_Year];
vals = cell(1,numNodes);
vals(dNodes) = cellfun(@(f) unique(D.(f)), nodeNames(dNodes), 'Uniform',false);
nodeSize = ones(1,numNodes);
nodeSize(dNodes) = cellfun(@numel, vals(dNodes));
%# DAG
dag = false(numNodes);
dag(node.Origin, depNodes) = true;
%# create naive bayes net
bnet = mk_bnet(dag, nodeSize, 'discrete',dNodes, 'names',nodeNames, ...
'observed',depNodes);
for i=1:numel(dNodes)
name = nodeNames{dNodes(i)};
bnet.CPD{dNodes(i)} = tabular_CPD(bnet, node.(name), ...
'prior_type','dirichlet');
end
for i=1:numel(cNodes)
name = nodeNames{cNodes(i)};
bnet.CPD{cNodes(i)} = gaussian_CPD(bnet, node.(name));
end
%# visualize the graph
[~,~,h] = draw_graph(bnet.dag, nodeNames);
hTxt = h(:,1); hNodes = h(:,2);
set(hTxt(node.Origin), 'FontWeight','bold', 'Interpreter','none')
set(hNodes(node.Origin), 'FaceColor','g')
set(hTxt(depNodes), 'Color','k', 'Interpreter','none')
set(hNodes(depNodes), 'FaceColor','y')
次に、データをトレーニング/テストに分割します。
%# build samples as cellarray
data = num2cell(cell2mat(struct2cell(D)')');
%# split train/test: 1/3 for testing, 2/3 for training
cv = cvpartition(D.Origin, 'HoldOut',1/3);
trainData = data(:,cv.training);
testData = data(:,cv.test);
testData(1,:) = {[]}; %# remove class
最後に、トレーニングセットからパラメーターを学習し、テストデータのクラスを予測します。
%# training
bnet = learn_params(bnet, trainData);
%# testing
prob = zeros(nodeSize(node.Origin), sum(cv.test));
engine = jtree_inf_engine(bnet); %# Inference engine
for i=1:size(testData,2)
[engine,loglik] = enter_evidence(engine, testData(:,i));
marg = marginal_nodes(engine, node.Origin);
prob(:,i) = marg.T;
end
[~,pred] = max(prob);
actual = D.Origin(cv.test)';
%# confusion matrix
predInd = full(sparse(1:numel(pred),pred,1));
actualInd = full(sparse(1:numel(actual),actual,1));
conffig(predInd, actualInd); %# confmat
%# ROC plot and AUC
figure
[~,~,auc] = plotROC(max(prob), pred==actual, 'b')
title(sprintf('Area Under the Curve = %g',auc))
set(findobj(gca, 'type','line'), 'LineWidth',2)
結果:

そして、各ノードでCPTと平均/シグマを抽出できます。
cellfun(@(x)dispcpt(struct(x).CPT), bnet.CPD(dNodes), 'Uniform',false)
celldisp(cellfun(@(x)struct(x).mean, bnet.CPD(cNodes), 'Uniform',false))
celldisp(cellfun(@(x)struct(x).cov, bnet.CPD(cNodes), 'Uniform',false))