-1

Pytorch を使用してトレーニングしたモデルを読み込もうとしていますが、次のエラーが発生し続けます。

ファイル「convert.py」、12 行目、model.load_state_dict(torch.load('model/model_vgg2d_2.pth')) 内 ファイル「/usr/local/lib/python3.5/dist-packages/torch/nn/modules」 /module.py"、490 行目、load_state_dict .format(name)) KeyError: 'state_dict の予期しないキー「module.features.0.weight」'

以下は私のコードです:

import torch.onnx
import torch.nn as nn

class TempModel(nn.Module):
    def __init__(self):
        super(TempModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 5, (3, 3))
    def forward(self, inp):
        return self.conv1(inp)

model = nn.DataParallel(TempModel())
model.load_state_dict(torch.load('model/model_vgg2d_2.pth'))
dummy_input = Variable(torch.randn(1, 3, 224, 224))
torch.onnx.export(model, dummy_input, "model_onnx/model_vgg2d_0.onnx")

モデルのトレーニングに使用したのと同じマシン (複数の GPU を搭載) で作業しています。私が間違っていることは何ですか?

4

1 に答える 1

-1

ロードするときは、同じstate_dictモデルの である必要があります。VGG モデルの をまったく異なる にロードすることはできません。state_dictstate_dictBasicModel


古い回答
モデルに適用 せずにモデルを保存しましたnn.DataParallelが、これを追加した後にロードしようとしています。試す

model = TempModel()
model.load_state_dict(torch.load('model/model_vgg2d_2.pth'))
model = nn.DataParallel(model)  # parallel AFTER load
于 2018-10-23T09:05:19.913 に答える