torch7で実装されたネットワークを高速化しようとしていますが、 nn.DataParallelTableを使用しようとするとエラーが発生します。これは私がやろうとしていることです:
m1, m2 = createModel(8,48), createModel(8,48)
--8 # of GPUs, 48 hidden unit in the last layer
m2:share(m1,'weight', 'bias') ----THE ERROR IS HERE
prl = nn.ParallelTable()
prl:add(m1)
prl:add(m2)
prl:cuda()
mlp = nn.Sequential()
mlp:add(prl)
mlp:cuda()
crit = nn.CosineEmbeddingCriterion():cuda()
関数は次のとおりです。
function createModel(nGPU,bot)
local features = nn.Concat(2)
local fb1 = nn.Sequential() -- branch 1
fb1:add(nn.SpatialConvolution(1,48,3,3,1,1,1,1))
fb1:add(nn.ReLU(true))
fb1:add(nn.SpatialConvolution(48,128,3,3,1,1,1,1))
fb1:add(nn.ReLU(true))
fb1:add(nn.SpatialConvolution(128,192,3,3,1,1,1,1))
fb1:add(nn.ReLU(true))
fb1:add(nn.SpatialConvolution(192,192,3,3,1,1,1,1))
fb1:add(nn.ReLU(true))
fb1:add(nn.SpatialConvolution(192,128,3,3,1,1,1,1))
fb1:add(nn.ReLU(true))
fb1:add(nn.SpatialMaxPooling(2,2,2,2))
view = 12
local fb2 = fb1:clone() -- branch 2
for k,v in ipairs(fb2:findModules('nn.SpatialConvolution')) do
v:reset() -- reset branch 2's weights
end
features:add(fb1) features:add(fb2) features:cuda()
--------------the error is at this line-----------
features = makeDataParallel(features, nGPU)
local classifier = nn.Sequential()
classifier:add(nn.View(256viewview))
classifier:add(nn.Dropout(0.5))
classifier:add(nn.Linear(256viewview, 4096))
classifier:add(nn.Dropout(0.5))
classifier:add(nn.Linear(4096, 4096))
classifier:add(nn.Tanh())
classifier:add(nn.Linear(4096, bot))
classifier:add(nn.Tanh())
classifier:cuda()
local model = nn.Sequential():add(features):add(classifier)
return model
end
もう1つは次のとおりです。
function makeDataParallel(model, nGPU)
if nGPU > 1 then
print('converting module to nn.DataParallelTable')
assert(nGPU <= cutorch.getDeviceCount(), 'number of GPUs less than nGPU specified')
local model_single = model
model = nn.DataParallelTable(1)
for i=1, nGPU do
cutorch.setDevice(i)
model:add(model_single:clone():cuda(), i)
end
end
cutorch.setDevice(1)
return model
end
私が得るエラーは次のとおりです。
[C]: in function 'error'
...a/torch/install/share/lua/5.1/cunn/DataParallelTable.lua:337: in function 'share'
/home/andrea/torch/install/share/lua/5.1/nn/Container.lua:97: in function 'share'
main.lua:123: in main chunk
[C]: at 0x00406670
エラーがどこにあるのか、おそらくわかりますか?申し訳ありませんが、私はこれが初めてで、それを理解する方法が見つかりません。もちろん、私はネット構造が間違っていると考えています。前もって感謝します。