2

これは、matlab 言語での勾配降下アルゴリズムの私自身の実装です。

 m = height(data_training); % number of samples
cols = {'x1', 'x2', 'x3', 'x4', 'x5', 'x6',...
    'x7', 'x8','x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15'}; 

y = data_training{:, {'y'}}';
X = [ones(m,1) data_training{:,cols}]'; 

theta = zeros(1,width(data_training));

alpha = 1e-2; % learning rate
iter = 400;

dJ = zeros(1,width(data_training));

J_seq = zeros(1, iter);

for n = 1:iter

    err = (theta*X - y);

    for j = 1:width(data_training)
        dJ(j) = 1/m*sum(err*X(j,:)');
    end

    J = 1/(2*m)*sum((theta*X-y).^2);

    theta = theta - alpha.*dJ;

    J_seq(n) = J;

    if mod(n,100) == 0
        plot(1:iter, J_seq);
    end
end

作業アルゴリズムを編集

このアルゴリズムを次のトレーニング データセットに適用しました。最後の列は出力変数です。ここには 15 の異なる機能があります。

理由は不明ですが、収束に向かっているかどうかを確認するために 50 回の反復後にコスト関数 J をプロットすると、収束していないことがわかります。理解するのを手伝ってもらえますか? 実装が間違っているのでしょうか、それとも何かを作る必要がありますか?

36    27    71     8.1    3.34    11.4    81.5    3243     8.8    42.6    11.7     21     15     59    59     921.87
35    23    72    11.1    3.14      11    78.8    4281     3.6    50.7    14.4      8     10     39    57     997.88
44    29    74    10.4    3.21     9.8    81.6    4260     0.8    39.4    12.4      6      6     33    54     962.35
47    45    79     6.5    3.41    11.1    77.5    3125    27.1    50.2    20.6     18      8     24    56     982.29
43    35    77     7.6    3.44     9.6    84.6    6441    24.4    43.7    14.3     43     38    206    55     1071.3
53    45    80     7.7    3.45    10.2    66.8    3325    38.5    43.1    25.5     30     32     72    54     1030.4
43    30    74    10.9    3.23    12.1    83.9    4679     3.5    49.2    11.3     21     32     62    56      934.7
45    30    73     9.3    3.29    10.6      86    2140     5.3    40.4    10.5      6      4      4    56     899.53
36    24    70       9    3.31    10.5    83.2    6582     8.1    42.5    12.6     18     12     37    61     1001.9
36    27    72     9.5    3.36    10.7    79.3    4213     6.7      41    13.2     12      7     20    59     912.35
52    42    79     7.7    3.39     9.6    69.2    2302    22.2    41.3    24.2     18      8     27    56     1017.6
33    26    76     8.6     3.2    10.9    83.4    6122    16.3    44.9    10.7     88     63    278    58     1024.9
40    34    77     9.2    3.21    10.2      77    4101      13    45.7    15.1     26     26    146    57     970.47
35    28    71     8.8    3.29    11.1    86.3    3042    14.7    44.6    11.4     31     21     64    60     985.95
37    31    75       8    3.26    11.9    78.4    4259    13.1    49.6    13.9     23      9     15    58     958.84
35    46    85     7.1    3.22    11.8    79.9    1441    14.8    51.2    16.1      1      1      1    54      860.1
36    30    75     7.5    3.35    11.4    81.9    4029    12.4      44      12      6      4     16    58     936.23
15    30    73     8.2    3.15    12.2    84.2    4824     4.7    53.1    12.7     17      8     28    38     871.77
31    27    74     7.2    3.44    10.8      87    4834    15.8    43.5    13.6     52     35    124    59     959.22
30    24    72     6.5    3.53    10.8    79.5    3694    13.1    33.8    12.4     11      4     11    61     941.18
31    45    85     7.3    3.22    11.4    80.7    1844    11.5    48.1    18.5      1      1      1    53     891.71
31    24    72       9    3.37    10.9    82.8    3226     5.1    45.2    12.3      5      3     10    61     871.34
42    40    77     6.1    3.45    10.4    71.8    2269    22.7    41.4    19.5      8      3      5    53     971.12
43    27    72       9    3.25    11.5    87.1    2909     7.2    51.6     9.5      7      3     10    56     887.47
46    55    84     5.6    3.35    11.4    79.7    2647      21    46.9    17.9      6      5      1    59     952.53
39    29    76     8.7    3.23    11.4    78.6    4412    15.6    46.6    13.2     13      7     33    60     968.66
35    31    81     9.2     3.1      12    78.3    3262    12.6    48.6    13.9      7      4      4    55     919.73
43    32    74    10.1    3.38     9.5    79.2    3214     2.9    43.7      12     11      7     32    54     844.05
11    53    68     9.2    2.99    12.1    90.6    4700     7.8    48.9    12.3    648    319    130    47     861.83
30    35    71     8.3    3.37     9.9    77.4    4474    13.1    42.6    17.7     38     37    193    57     989.26
50    42    82     7.3    3.49    10.4    72.5    3497    36.7    43.3    26.4     15     10     34    59     1006.5
60    67    82      10    2.98    11.5    88.6    4657    13.6    47.3    22.4      3      1      1    60     861.44
30    20    69     8.8    3.26    11.1    85.4    2934     5.8      44     9.4     33     23    125    64     929.15
25    12    73     9.2    3.28    12.1    83.1    2095       2    51.9     9.8     20     11     26    50     857.62
45    40    80     8.3    3.32    10.1    70.3    2682      21    46.1    24.1     17     14     78    56     961.01
46    30    72    10.2    3.16    11.3    83.2    3327     8.8    45.3    12.2      4      3      8    58     923.23
4

1 に答える 1