1

実際には線形モデルを使用して一連の「sin」データに適合させたいのですが、反復ごとに損失関数が大きくなることがわかりました。以下のコードに問題はありますか? (勾配降下法)

これがMatlabの私のコードです

m=20;
rate = 0.1;
x = linspace(0,2*pi,20);
x = [ones(1,length(x));x]
y = sin(x);
w = rand(1,2);
for i=1:500
    h = w*x;
    loss = sum((h-y).^2)/m/2 
    total_loss = [total_loss loss];
    **gradient = (h-y)*x'./m ;**
    w = w - rate.*gradient;
end

ここに私が適合させたいデータがあります y=sin(x)

4

1 に答える 1

1

コードに問題はありません。現在のフレームワークで、 の形式でデータを定義できる場合y = m*x + b、このコードで十分です。実際に、線の方程式を定義し、それにガウス ランダム ノイズを追加するいくつかのテストを実行しました (振幅 = 0.1、平均 = 0、標準偏差 = 1)。

ただし、私が言及する 1 つの問題は、正弦波データを見ると、 の間にドメインを定義することです[0,2*pi]。ご覧のとおりx、同じ値にマップされるyが大きさが異なる複数の値があります。たとえば、atx = pi/2は 1 ですが、atx = -3*pi/2は -1 です。この高い変動性は、線形回帰の前兆とはならないため、ドメインを制限することをお勧めします... [0, pi]. おそらく収束しないもう 1 つの理由は、選択した学習率が高すぎることです。のような低い値に設定し0.01ます。コメントで述べたように、あなたはすでにそれを理解しています!

ただし、線形回帰を使用して非線形データを当てはめたい場合は、変動性を考慮してより高次の項を含める必要があります。そのため、2 次および/または 3 次の項を含めてみてください。xこれは、次のようにマトリックスを変更することで簡単に実行できます。

x = [ones(1,length(x)); x; x.^2; x.^3];

思い出すと、仮説関数は線形項の和として表すことができます。

h(x) = theta0 + theta1*x1 + theta2*x2 + ... + thetan*xn

私たちの場合、各項thetaは多項式の高次項を構築します。 とx2なるでしょう。したがって、ここでも線形回帰に勾配降下の定義を使用できます。x^2x3x^3

また、ランダム生成シードを ( 経由でrng) 制御して、得られたのと同じ結果を生成できるようにします。

clear all; 
close all;
rng(123123);
total_loss = [];
m = 20;
x = linspace(0,pi,m); %// Change
y = sin(x);
w = rand(1,4); %// Change
rate = 0.01; %// Change
x = [ones(1,length(x)); x; x.^2; x.^3]; %// Change - Second and third order terms
for i=1:500
    h = w*x;
    loss = sum((h-y).^2)/m/2;
    total_loss = [total_loss loss];
    % gradient is now in a different expression
    gradient = (h-y)*x'./m ; % sum all in each iteration, it's a batch gradient
    w = w - rate.*gradient;
end

これを試すと、w(パラメータ)が得られます。

>> format long g;
>> w


w =

  Columns 1 through 3

         0.128369521905694         0.819533906064327       -0.0944622478526915

  Column 4

       -0.0596638117151464

この時点以降の私の最終的な損失は次のとおりです。

loss =

       0.00154350916582836

これは、直線の式が次のようになることを意味します。

y = 0.12 + 0.819x - 0.094x^2 - 0.059x^3

この直線の方程式を正弦波データでプロットすると、次のようになります。

xval = x(2,:);
plot(xval, y, xval, polyval(fliplr(w), xval))
legend('Original', 'Fitted');

ここに画像の説明を入力

于 2015-03-09T15:17:16.003 に答える