ここでPythonでMXNetの手書き数字認識を理解しようとしています
トレーニング データとラベル データを作成するコードを以下に示します。
def read_data(label_url, image_url):
with gzip.open(download_data(label_url)) as flbl:
magic, num = struct.unpack(">II", flbl.read(8))
label = np.fromstring(flbl.read(), dtype=np.int8)
with gzip.open(download_data(image_url), 'rb') as fimg:
magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
image = np.fromstring(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)
return (label, image)
次に、以下のコードを使用して数値を予測します。
prob = model.predict(val_img[0:1].astype(np.float32)/255)[0]
assert max(prob) > 0.99, "Low prediction accuracy."
print 'Classified as %d with probability %f' % (prob.argmax(), max(prob))
出力は - 確率 0.999391 で 7 に分類されます。私の質問は、argmax 関数によって返されたインデックスがラベル -7 に対応することを MXNet がどのように判断できたかです。