フィードフォワード逆伝播オートエンコーダー (勾配降下によるトレーニング) を実装しようとしており、勾配を正しく計算していることを確認したいと考えていました。このチュートリアルでは、各パラメーターの導関数を一度に 1 つずつ計算することを提案していますgrad_i(theta) = (J(theta_i+epsilon) - J(theta_i-epsilon)) / (2*epsilon)
。これを行うためにMatlabでサンプルコードを書きましたが、あまり運がありませんでした.導関数から計算された勾配と数値的に検出された勾配の差は大きくなる傾向があります(>>有効数字4桁)。
誰かが提案を提供できる場合は、(勾配の計算またはチェックの実行方法のいずれかで)助けていただければ幸いです。読みやすくするためにコードを大幅に単純化したため、バイアスを含めず、重み行列を結び付けていません。
まず、変数を初期化します。
numHidden = 200;
numVisible = 784;
low = -4*sqrt(6./(numHidden + numVisible));
high = 4*sqrt(6./(numHidden + numVisible));
encoder = low + (high-low)*rand(numVisible, numHidden);
decoder = low + (high-low)*rand(numHidden, numVisible);
次に、いくつかの入力 imagex
を指定して、フィードフォワード伝搬を行います。
a = sigmoid(x*encoder);
z = sigmoid(a*decoder); % (reconstruction of x)
私が使用している損失関数は、標準の Σ(0.5*(z - x)^2)) です。
% first calculate the error by finding the derivative of sum(0.5*(z-x).^2),
% which is (f(h)-x)*f'(h), where z = f(h), h = a*decoder, and
% f = sigmoid(x). However, since the derivative of the sigmoid is
% sigmoid*(1 - sigmoid), we get:
error_0 = (z - x).*z.*(1-z);
% The gradient \Delta w_{ji} = error_j*a_i
gDecoder = error_0'*a;
% not important, but included for completeness
% do back-propagation one layer down
error_1 = (error_0*encoder).*a.*(1-a);
gEncoder = error_1'*x;
最後に、勾配が正しいことを確認します (この場合は、デコーダーに対してのみ行います)。
epsilon = 10e-5;
check = gDecoder(:); % the values we obtained above
for i = 1:size(decoder(:), 1)
% calculate J+
theta = decoder(:); % unroll
theta(i) = theta(i) + epsilon;
decoderp = reshape(theta, size(decoder)); % re-roll
a = sigmoid(x*encoder);
z = sigmoid(a*decoderp);
Jp = sum(0.5*(z - x).^2);
% calculate J-
theta = decoder(:);
theta(i) = theta(i) - epsilon;
decoderp = reshape(theta, size(decoder));
a = sigmoid(x*encoder);
z = sigmoid(a*decoderp);
Jm = sum(0.5*(z - x).^2);
grad_i = (Jp - Jm) / (2*epsilon);
diff = abs(grad_i - check(i));
fprintf('%d: %f <=> %f: %f\n', i, grad_i, check(i), diff);
end
これを MNIST データセット (最初のエントリ) で実行すると、次のような結果が得られます。
2: 0.093885 <=> 0.028398: 0.065487
3: 0.066285 <=> 0.031096: 0.035189
5: 0.053074 <=> 0.019839: 0.033235
6: 0.108249 <=> 0.042407: 0.065843
7: 0.091576 <=> 0.009014: 0.082562