これは、matlab 言語での勾配降下アルゴリズムの私自身の実装です。
m = height(data_training); % number of samples
cols = {'x1', 'x2', 'x3', 'x4', 'x5', 'x6',...
'x7', 'x8','x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15'};
y = data_training{:, {'y'}}';
X = [ones(m,1) data_training{:,cols}]';
theta = zeros(1,width(data_training));
alpha = 1e-2; % learning rate
iter = 400;
dJ = zeros(1,width(data_training));
J_seq = zeros(1, iter);
for n = 1:iter
err = (theta*X - y);
for j = 1:width(data_training)
dJ(j) = 1/m*sum(err*X(j,:)');
end
J = 1/(2*m)*sum((theta*X-y).^2);
theta = theta - alpha.*dJ;
J_seq(n) = J;
if mod(n,100) == 0
plot(1:iter, J_seq);
end
end
作業アルゴリズムを編集
このアルゴリズムを次のトレーニング データセットに適用しました。最後の列は出力変数です。ここには 15 の異なる機能があります。
理由は不明ですが、収束に向かっているかどうかを確認するために 50 回の反復後にコスト関数 J をプロットすると、収束していないことがわかります。理解するのを手伝ってもらえますか? 実装が間違っているのでしょうか、それとも何かを作る必要がありますか?
36 27 71 8.1 3.34 11.4 81.5 3243 8.8 42.6 11.7 21 15 59 59 921.87
35 23 72 11.1 3.14 11 78.8 4281 3.6 50.7 14.4 8 10 39 57 997.88
44 29 74 10.4 3.21 9.8 81.6 4260 0.8 39.4 12.4 6 6 33 54 962.35
47 45 79 6.5 3.41 11.1 77.5 3125 27.1 50.2 20.6 18 8 24 56 982.29
43 35 77 7.6 3.44 9.6 84.6 6441 24.4 43.7 14.3 43 38 206 55 1071.3
53 45 80 7.7 3.45 10.2 66.8 3325 38.5 43.1 25.5 30 32 72 54 1030.4
43 30 74 10.9 3.23 12.1 83.9 4679 3.5 49.2 11.3 21 32 62 56 934.7
45 30 73 9.3 3.29 10.6 86 2140 5.3 40.4 10.5 6 4 4 56 899.53
36 24 70 9 3.31 10.5 83.2 6582 8.1 42.5 12.6 18 12 37 61 1001.9
36 27 72 9.5 3.36 10.7 79.3 4213 6.7 41 13.2 12 7 20 59 912.35
52 42 79 7.7 3.39 9.6 69.2 2302 22.2 41.3 24.2 18 8 27 56 1017.6
33 26 76 8.6 3.2 10.9 83.4 6122 16.3 44.9 10.7 88 63 278 58 1024.9
40 34 77 9.2 3.21 10.2 77 4101 13 45.7 15.1 26 26 146 57 970.47
35 28 71 8.8 3.29 11.1 86.3 3042 14.7 44.6 11.4 31 21 64 60 985.95
37 31 75 8 3.26 11.9 78.4 4259 13.1 49.6 13.9 23 9 15 58 958.84
35 46 85 7.1 3.22 11.8 79.9 1441 14.8 51.2 16.1 1 1 1 54 860.1
36 30 75 7.5 3.35 11.4 81.9 4029 12.4 44 12 6 4 16 58 936.23
15 30 73 8.2 3.15 12.2 84.2 4824 4.7 53.1 12.7 17 8 28 38 871.77
31 27 74 7.2 3.44 10.8 87 4834 15.8 43.5 13.6 52 35 124 59 959.22
30 24 72 6.5 3.53 10.8 79.5 3694 13.1 33.8 12.4 11 4 11 61 941.18
31 45 85 7.3 3.22 11.4 80.7 1844 11.5 48.1 18.5 1 1 1 53 891.71
31 24 72 9 3.37 10.9 82.8 3226 5.1 45.2 12.3 5 3 10 61 871.34
42 40 77 6.1 3.45 10.4 71.8 2269 22.7 41.4 19.5 8 3 5 53 971.12
43 27 72 9 3.25 11.5 87.1 2909 7.2 51.6 9.5 7 3 10 56 887.47
46 55 84 5.6 3.35 11.4 79.7 2647 21 46.9 17.9 6 5 1 59 952.53
39 29 76 8.7 3.23 11.4 78.6 4412 15.6 46.6 13.2 13 7 33 60 968.66
35 31 81 9.2 3.1 12 78.3 3262 12.6 48.6 13.9 7 4 4 55 919.73
43 32 74 10.1 3.38 9.5 79.2 3214 2.9 43.7 12 11 7 32 54 844.05
11 53 68 9.2 2.99 12.1 90.6 4700 7.8 48.9 12.3 648 319 130 47 861.83
30 35 71 8.3 3.37 9.9 77.4 4474 13.1 42.6 17.7 38 37 193 57 989.26
50 42 82 7.3 3.49 10.4 72.5 3497 36.7 43.3 26.4 15 10 34 59 1006.5
60 67 82 10 2.98 11.5 88.6 4657 13.6 47.3 22.4 3 1 1 60 861.44
30 20 69 8.8 3.26 11.1 85.4 2934 5.8 44 9.4 33 23 125 64 929.15
25 12 73 9.2 3.28 12.1 83.1 2095 2 51.9 9.8 20 11 26 50 857.62
45 40 80 8.3 3.32 10.1 70.3 2682 21 46.1 24.1 17 14 78 56 961.01
46 30 72 10.2 3.16 11.3 83.2 3327 8.8 45.3 12.2 4 3 8 58 923.23