1

次のように、x の戻り値であるデータの INDArray があります。

    private static INDArray createDataSet(String path)throws Exception {

    List<String> lines = IOUtils.readLines(new FileInputStream(path), StandardCharsets.UTF_8);

    double[] position = new double[lines.size()];
    double[] year = new double[lines.size()];
    double[] month = new double[lines.size()];
    double[] day = new double[lines.size()];
    double[] close = new double[lines.size()];

    int linecount = 0;
    Iterator<String> it = lines.iterator();
    while(it.hasNext()) {
        String line = it.next();

            String[] parts = line.split(",");

            position[linecount] = linecount;
            year[linecount] =  Double.valueOf(parts[0]);
            month[linecount] =  Double.valueOf(parts[1]);
            day[linecount] =  Double.valueOf(parts[2]);
            close[linecount] = Double.valueOf(parts[5]);

            linecount++;
    }//endloop


    double[][] arr2D = new double[][] {position, year, month, day, close};
    INDArray x = Nd4j.createFromArray(arr2D);

    return x;

}

csvplotter の例をコピーして、単一の入出力ネットワークで線形回帰を実行しようとしています。

配列行(0)を機能として、配列行(4)をラベルとしてロードするにはどうすればよいですか?

もう少し情報:

    System.out.println(ds.rank());
    long[] l  = ds.shape();
    System.out.println(l[0] + " , " + l[1] + "  -  " + l.length);
    System.out.println(ds.length());

結果:

2,
5, 1260 -2
6300

明確にするために、ここに私の問題があります:

       for (int i = 0; i < nEpochs; i++) {

       net.fit(d);
    }

データを追加しようとする方法に応じて、さまざまなエラーが発生します

4

1 に答える 1

0

答えは出ていませんが、自分の問題に気づきました。csv プロッターの例のコメントに基づいて、indarray の行が入力に配信されると仮定しました。ただし、実際に入力に配信されるのは列です。

INDArray を転置して 2 つの列を追加することで、ネットワークがデータを処理する必要がありました。

INDArray ds;
     ds = ds.transpose();
        DataSet ddd = new DataSet();
        ddd.setFeatures(ds.getColumn(0, true)); //true maintains matrix instead of vector
        ddd.setLabels(ds.getColumn(4, true));
        ddd.dataSetBatches(500);
        System.out.println(ddd);

私の印刷物:

===========INPUT===================
[[0], 
 [1.0000], 
 [2.0000], 
  ..., 
 [1257.0000], 
 [1258.0000], 
 [1259.0000]]
=================OUTPUT==================
[[540.3100], 
 [536.7000], 
 [533.3300], 
  ..., 
 [1431.8199], 
 [1439.2200], 
 [1436.3800]]

トレーニングは失敗しましたが、これは私の元の質問に答えています。

于 2020-06-04T21:59:50.360 に答える