コードに問題はありません。現在のフレームワークで、 の形式でデータを定義できる場合y = m*x + b
、このコードで十分です。実際に、線の方程式を定義し、それにガウス ランダム ノイズを追加するいくつかのテストを実行しました (振幅 = 0.1、平均 = 0、標準偏差 = 1)。
ただし、私が言及する 1 つの問題は、正弦波データを見ると、 の間にドメインを定義することです[0,2*pi]
。ご覧のとおりx
、同じ値にマップされるy
が大きさが異なる複数の値があります。たとえば、atx = pi/2
は 1 ですが、atx = -3*pi/2
は -1 です。この高い変動性は、線形回帰の前兆とはならないため、ドメインを制限することをお勧めします... [0, pi]
. おそらく収束しないもう 1 つの理由は、選択した学習率が高すぎることです。のような低い値に設定し0.01
ます。コメントで述べたように、あなたはすでにそれを理解しています!
ただし、線形回帰を使用して非線形データを当てはめたい場合は、変動性を考慮してより高次の項を含める必要があります。そのため、2 次および/または 3 次の項を含めてみてください。x
これは、次のようにマトリックスを変更することで簡単に実行できます。
x = [ones(1,length(x)); x; x.^2; x.^3];
思い出すと、仮説関数は線形項の和として表すことができます。
h(x) = theta0 + theta1*x1 + theta2*x2 + ... + thetan*xn
私たちの場合、各項theta
は多項式の高次項を構築します。 とx2
なるでしょう。したがって、ここでも線形回帰に勾配降下の定義を使用できます。x^2
x3
x^3
また、ランダム生成シードを ( 経由でrng
) 制御して、得られたのと同じ結果を生成できるようにします。
clear all;
close all;
rng(123123);
total_loss = [];
m = 20;
x = linspace(0,pi,m); %// Change
y = sin(x);
w = rand(1,4); %// Change
rate = 0.01; %// Change
x = [ones(1,length(x)); x; x.^2; x.^3]; %// Change - Second and third order terms
for i=1:500
h = w*x;
loss = sum((h-y).^2)/m/2;
total_loss = [total_loss loss];
% gradient is now in a different expression
gradient = (h-y)*x'./m ; % sum all in each iteration, it's a batch gradient
w = w - rate.*gradient;
end
これを試すと、w
(パラメータ)が得られます。
>> format long g;
>> w
w =
Columns 1 through 3
0.128369521905694 0.819533906064327 -0.0944622478526915
Column 4
-0.0596638117151464
この時点以降の私の最終的な損失は次のとおりです。
loss =
0.00154350916582836
これは、直線の式が次のようになることを意味します。
y = 0.12 + 0.819x - 0.094x^2 - 0.059x^3
この直線の方程式を正弦波データでプロットすると、次のようになります。
xval = x(2,:);
plot(xval, y, xval, polyval(fliplr(w), xval))
legend('Original', 'Fitted');
