0

私はラザニアの Conv3DDNNLayer を使用しており、(N x 1 x 9 x 9 x 9) の入力次元を持っています。ここで、各 9x9x9 キューブは分類対象のサンプルです。

したがって、各エントリがキューブに対応する (N x 1) のターゲット ディメンションがあります。これはエラーを引き起こしています:

 Bad input argument to theano function with name "Conv_Net_1.py:45"  at index 1(0-based)', 'Wrong number of dimensions: expected 1,
 got 2 with shape (324640, 1).')´

この場合、どのディメンションにターゲットを設定する必要がありますか?

 11 dtensor5 = TensorType('float32', (False,)*5)
 12 input_var = dtensor5('X_Train')
 13 target_var = T.ivector('Y_train')
 14 
 15 X_train, Y_train = DP.data_gen( '/home/Upload/Smalls', 9)

 16 print X_train.shape
 17 print Y_train.shape

 18 # Build Neural Network:
 19 input = lasagne.layers.InputLayer((None, 1, 9, 9, 9), input_var=input_var)
 20 
 21 l_conv_1 = lasagne.layers.dnn.Conv3DDNNLayer(input, 20, (2,2,2))
 22 
 29 l_hidden1 = lasagne.layers.DenseLayer(l_conv_1, num_units=256,nonlinearity=lasagne.nonlinearities.rectify,W=l    asagne.init.HeNormal(gain='relu'))
 30 
 31 l_hidden1_dropout = lasagne.layers.DropoutLayer(l_hidden1, p=0.5)
 32 
 33 output = lasagne.layers.DenseLayer(l_hidden1_dropout, num_units=2, nonlinearity = lasagne.nonlinearities.soft    max) 
 34 
 35 ##
 36 prediction = lasagne.layers.get_output(output)
 37 loss = T.mean(lasagne.objectives.categorical_crossentropy(prediction, target_var)
 39 
 40 # Get list of all trainable parameters in the network.
 41 params = lasagne.layers.get_all_params(output, trainable=True)
 42 updates = lasagne.updates.nesterov_momentum(loss, params, learning_rate=0.01, momentum=0.3)
 43 
 44 ##
 45 train_fn = theano.function([input_var, target_var], loss, updates=updates)
 46 
 47 ##
 48 for epoch in range(500):
 49     print('training')
 50     loss = train_fn(X_train, Y_train)
 51     print(loss.type)
 52     print("Epoch %d: Loss %g" % (epoch + 1, loss))
 53 
 54 
 55 ##
 56 test_prediction = lasagne.layers.get_output(output, deterministic=True)
 57 predict_fn = theano.function([input_var], T.argmax(test_prediction, axis=1))

編集 - コードを追加

ありがとう!

4

1 に答える 1

0

興味のある方のために説明すると、それはデータが (N, 1) ではなく (N, ) だったからです。

問題が解決したようです!- 次へ..

于 2016-03-23T10:18:47.737 に答える