1

私は非常に大きな列車セットを持っているので、Matlab. そして、大規模なトレーニングを行う必要があります。

トレーニングセットを部分に分割し、ネットワークを繰り返しトレーニングし、反復ごとに上書きする代わりに「ネット」を更新することは可能ですか?

以下のコードはアイデアを示しており、機能しません。各反復で、トレーニングされたデータセットのみに応じてネットを更新します。

TF1 = 'tansig';TF2 = 'tansig'; TF3 = 'tansig';% layers of the transfer function , TF3 transfer function for the output layers

net = newff(trainSamples.P,trainSamples.T,[NodeNum1,NodeNum2,NodeOutput],{TF1 TF2 TF3},'traingdx');% Network created

net.trainfcn = 'traingdm' ; %'traingdm';
net.trainParam.epochs   = 1000;
net.trainParam.min_grad = 0;
net.trainParam.max_fail = 2000; %large value for infinity

while(1) // iteratively takes 10 data point at a time.
 p %=> get updated with following 10 new data points
 t %=> get updated with following 10 new data points

 [net,tr]             = train(net, p, t,[], []);

end
4

2 に答える 2

2

関数をまだ見る機会がありませんがadapt、上書きではなく更新していると思われます。このステートメントを検証するには、最初のデータ チャンクのサブセットをトレーニングの 2 番目のチャンクとして選択する必要がある場合があります。上書きしている場合、サブセットでトレーニング済みのネットを使用して最初のデータ チャンクをテストすると、サブセットに属さないデータの予測が不十分になるはずです。

非常に単純なプログラムでテストしました: train the curve y=x^2。最初のトレーニング プロセスで、次のデータ セットを学習しました[1,3,5,7,9]

   m=6;
   P=[1 3 5 7 9];
   T=P.^2;
   [Pn,minP,maxP,Tn,minT,maxT] = premnmx(P,T);
   clear net
   net.IW{1,1}=zeros(m,1);
   net.LW{2,1}=zeros(1,m);
   net.b{1,1}=zeros(m,1);
   net.b{2,1}=zeros(1,1);
   net=newff(minmax(Pn),[m,1],{'logsig','purelin'},'trainlm');
   net.trainParam.show =100;
   net.trainParam.lr = 0.09;
   net.trainParam.epochs =1000;
   net.trainParam.goal = 1e-3; 
   [net,tr]=train(net,Pn,Tn);
   Tn_predicted= sim(net,Pn)
   Tn

結果 (出力は同じ基準でスケーリングされることに注意してください。標準の正規化を行っている場合は、最初のトレーニング セットの平均値と標準値を残りのすべてのセットに常に適用するようにしてください):

Tn_predicted =

   -1.0000   -0.8000   -0.4000    0.1995    1.0000


Tn =

   -1.0000   -0.8000   -0.4000    0.2000    1.0000

次に、トレーニング データを使用して 2 番目のトレーニング プロセスを実装します[1,9]

   Pt=[1 9];
   Tt=Pt.^2;
   n=length(Pt);
   Ptn = tramnmx(Pt,minP,maxP);
   Ttn = tramnmx(Tt,minT,maxT);


   [net,tr]=train(net,Ptn,Ttn);
   Tn_predicted= sim(net,Pn)
   Tn

結果:

Tn_predicted =

   -1.0000   -0.8000   -0.4000    0.1995    1.0000


Tn =

   -1.0000   -0.8000   -0.4000    0.2000    1.0000

x=[3,5,7];のデータは依然として正確に予測されていることに注意してください。

ただし、x=[1,9];最初からのみトレーニングすると、次のようになります。

   clear net
   net.IW{1,1}=zeros(m,1);
   net.LW{2,1}=zeros(1,m);
   net.b{1,1}=zeros(m,1);
   net.b{2,1}=zeros(1,1);
   net=newff(minmax(Ptn),[m,1],{'logsig','purelin'},'trainlm');
   net.trainParam.show =100;
   net.trainParam.lr = 0.09;
   net.trainParam.epochs =1000;
   net.trainParam.goal = 1e-3; 
   [net,tr]=train(net,Ptn,Ttn);
   Tn_predicted= sim(net,Pn)
   Tn

結果を見る:

Tn_predicted =

   -1.0071   -0.6413    0.5281    0.6467    0.9922


Tn =

   -1.0000   -0.8000   -0.4000    0.2000    1.0000

訓練されたネットがうまく機能しなかったことに注意してくださいx=[3,5,7];

上記のテストは、トレーニングが再起動ではなく、以前のネットに基づいていることを示しています。パフォーマンスが低下する理由は、各データ チャンク (バッチ勾配降下ではなく確率的勾配降下) に対して 1 回しか実装しないため、合計誤差曲線がまだ収束しない可能性があるためです。データ チャンクが 2 つしかない場合、チャンク 2 のトレーニングが完了したらチャンク 1 を再トレーニングし、次にチャンク 2 を再トレーニングし、次にチャンク 1 というように、いくつかの条件が満たされるまで再トレーニングする必要がある場合があります。チャンクがはるかに多い場合は、1 番目のトレーニング効果に比べて 2 番目の効果について心配する必要はないかもしれません。オンライン学習は、更新された重みがそれらのパフォーマンスを損なうかどうかに関係なく、以前のデータ セットを削除するだけです。

于 2014-01-18T04:59:25.120 に答える
1

ここでは、matlab で NN を反復的にトレーニングする方法の例 (ミニ バッチ) を示します。

おもちゃのデータセットを作成するだけです

[ x,t] = building_dataset;

ミニバッチのサイズと数

M = 420 
imax = 10;

直接トレーニングとミニバッチ トレーニングを確認してみましょう

net = feedforwardnet(70,'trainscg');
dnet = feedforwardnet(70,'trainscg');

ここでの標準トレーニング: データ全体を使用した 1 回の呼び出し

dnet.trainParam.epochs=100;
[ dnet tr y ] = train( dnet, x, t ,'useGPU','only','showResources','no');

エラーの尺度:MEA、MSEまたはその他の測定が容易

dperf = mean(mean(abs(t-dnet(x))))

これは反復部分です: 呼び出しごとに 1 エポック

net.trainParam.epochs=1;
e=1;

エポック比較のために、前のメソッドエラーに到達するまで

while perf(end)>dperf

各エポックでデータをランダム化することは非常に重要です!!

    idx = randperm(size(x,2));

すべてのデータ チャンクを使用して反復的にトレーニングする

    for i=1:imax
        k = idx(1+M*(i-1) : M*i);
        [ net tr ] = train( net, x( : , k ), t( : , k ) );
    end

各エポックでのパフォーマンスを計算する

    perf(e) = mean(mean(abs(t-net(x))))
    e=e+1;
end

パフォーマンスをチェックしてください。準滑らかで exp(-x) のような曲線が必要です

plot(perf)
于 2016-03-30T15:01:54.853 に答える