0

Kinect から取得した深度画像を使ってシャムネットワーク（Siamese network）をトレーニングしたいと考えています。損失関数にはコントラスティブ損失（contrastive loss）を使いたいのですが、MXNet には組み込みのコントラスティブ損失関数が見つかりませんでした。自分で実装したものは次のとおりです。

def LossFunc(distance, label, margin):
    """Per-pair contrastive loss.

    Convention used here: ``label == 1`` marks a similar pair (penalised by
    its squared distance) and ``label == 0`` a dissimilar pair (penalised by
    the squared hinge ``max(margin - distance, 0)``).

    Parameters
    ----------
    distance : array
        Pair distances; reshaped to ``label.shape`` before use.
    label : array
        Binary pair labels (1 = similar, 0 = dissimilar).
    margin : float
        Distance beyond which dissimilar pairs contribute no loss.

    Returns
    -------
    array
        ``0.5 * (label * distance)**2 + 0.5 * ((1 - label) * hinge)**2``,
        with the same shape as ``label``.
    """
    distance = distance.reshape(label.shape)

    dis_positive = distance * label

    # Hinge max(margin - distance, 0) written with plain array operators, so it
    # works for MXNet NDArrays and NumPy arrays of any rank. The previous
    # concat-along-axis-1 + max trick required 2-D labels and an undefined
    # global ``ctx``.
    hinge = margin - distance
    hinge = hinge * (hinge > 0)
    dis_negative = (1 - label) * hinge

    return 0.5 * dis_positive ** 2 + 0.5 * dis_negative ** 2

この実装で正しいでしょうか？

4

回答 1 件

4

Gluon API を使ったコントラスティブ損失（contrastive loss）の実装は次のとおりです。

class ContrastiveLoss(Loss):
    """Contrastive loss for Siamese networks, averaged over the batch.

    For each pair of embeddings with Euclidean distance ``d``:

    - ``label == 0`` (same class):      ``d ** 2``
    - ``label == 1`` (different class): ``max(margin - d, 0) ** 2``

    Parameters
    ----------
    margin : float
        Distance beyond which dissimilar pairs contribute no loss.
    weight : float or None
        Global scalar multiplier applied to the per-pair loss.
    batch_axis : int
        Forwarded to the ``Loss`` base class.
    """
    def __init__(self, margin=2.0, weight=None, batch_axis=0, **kwargs):
        super(ContrastiveLoss, self).__init__(weight, batch_axis, **kwargs)
        self.margin = margin

    def hybrid_forward(self, F, output1, output2, label):
        # Sum the squared differences over the feature axis *before* the
        # square root; sqrt(square(a - b)) alone is just elementwise |a - b|.
        squared_distance = F.sum(F.square(output1 - output2), axis=1)
        # Small epsilon keeps the sqrt gradient finite when embeddings coincide.
        euclidean_distance = F.sqrt(squared_distance + 1e-12)
        # Flatten (N, 1) labels to (N,) so they combine elementwise with the
        # per-pair distances instead of broadcasting.
        label = F.reshape(label, shape=(-1,))
        hinge = F.relu(self.margin - euclidean_distance)
        loss = (1 - label) * squared_distance + label * F.square(hinge)
        if self._weight is not None:
            loss = loss * self._weight
        return F.mean(loss)

こちら（here）で公開されている、シャムネットワークの使い方を示す PyTorch のサンプルをもとに実装しました。

PyTorch と MXNet にはかなりの違いがあるため、試してみたい方のために、そのまま実行できる完全な例を以下に示します。ただし、MXNet は標準では .pgm 画像の読み込みに対応していないので、AT&T の顔データセットをダウンロードしたうえで、画像を JPEG に変換しておく必要があります。

import functools
import random

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import PIL.ImageOps
import mxnet as mx
from mxnet import autograd
from mxnet.base import numeric_types
from mxnet.gluon import nn, HybridBlock, Trainer
from mxnet.gluon.data import DataLoader
from mxnet.gluon.data.vision.datasets import ImageFolderDataset
from mxnet.gluon.loss import Loss


def imshow(img, text=None, should_save=False):
    """Display a channel-first image, optionally with a caption.

    Parameters
    ----------
    img : MXNet NDArray, PyTorch tensor or array-like
        Image in (channels, height, width) layout.
    text : str, optional
        Caption drawn in an italic bold box near the top-left corner.
    should_save : bool, optional
        Currently unused; kept for signature compatibility.
    """
    # MXNet NDArrays expose ``asnumpy()``; ``numpy()`` is the PyTorch name.
    if hasattr(img, "asnumpy"):
        npimg = img.asnumpy()
    elif hasattr(img, "numpy"):
        npimg = img.numpy()
    else:
        npimg = np.asarray(img)
    npimg = np.transpose(npimg, (1, 2, 0))
    # Matplotlib rejects (H, W, 1) arrays; show grayscale images as 2-D.
    if npimg.shape[-1] == 1:
        npimg = npimg[..., 0]

    plt.axis("off")
    if text:
        plt.text(75, 8, text, style='italic', fontweight='bold',
                 bbox={'facecolor': 'white', 'alpha': 0.8, 'pad': 10})
    plt.imshow(npimg)
    plt.show()


def show_plot(iteration, loss):
    """Draw ``loss`` against ``iteration`` as a line chart and display it."""
    x_values, y_values = iteration, loss
    plt.plot(x_values, y_values)
    plt.show()


class Config:
    """Fixed settings shared by the training and prediction routines."""

    # Root folders of the training and testing images.
    training_dir: str = "./faces/training/"
    testing_dir: str = "./faces/testing/"

    # Mini-batch size (also used by the test loader) and number of passes.
    train_batch_size: int = 5
    train_number_epochs: int = 100


class SiameseNetworkDataset(ImageFolderDataset):
    """Dataset that yields random image pairs for Siamese-network training.

    ``__getitem__`` ignores ``index``: it draws a random anchor image and a
    partner that, with probability 1/2, is forced to share the anchor's class
    (otherwise the partner is drawn uniformly and may still share it).

    Each item is ``(img0, img1, label)``. The images are converted from HWC to
    CHW layout. ``label`` is a 1-element NDArray that is 1 when the two images
    come from different classes and 0 otherwise.
    """

    def __init__(self, root, transform=None):
        # flag=0 asks ImageFolderDataset for single-channel images.
        super().__init__(root, flag=0, transform=transform)
        self.root = root
        # NOTE(review): this instance attribute shadows ``Dataset.transform``
        # in MXNet versions that define that method; kept for compatibility.
        self.transform = transform
        # Positions of every image, grouped by class label, so a same-class
        # partner can be drawn directly instead of by rejection sampling.
        self._indices_by_label = {}
        for position, (_, class_label) in enumerate(self.items):
            self._indices_by_label.setdefault(class_label, []).append(position)

    def __getitem__(self, index):
        img0_index = random.randrange(len(self.items))
        img0_label = self.items[img0_index][1]
        # About 50% of the pairs are guaranteed to share a class.
        if random.randint(0, 1):
            img1_index = random.choice(self._indices_by_label[img0_label])
        else:
            img1_index = random.randrange(len(self.items))
        img1_label = self.items[img1_index][1]

        img0 = super().__getitem__(img0_index)[0]
        img1 = super().__getitem__(img1_index)[0]

        # HWC -> CHW. A bare ``transpose()`` reverses *all* axes (giving CWH),
        # which silently swaps height and width.
        return (img0.transpose((2, 0, 1)),
                img1.transpose((2, 0, 1)),
                mx.nd.array([int(img1_label != img0_label)]))

    def __len__(self):
        return super().__len__()


class ReflectionPad2D(HybridBlock):
    r"""Pad a 4-D (N, C, H, W) tensor using the reflection of its boundary.

    Parameters
    ----------
    padding : int or sequence of 8 ints
        An int pads H and W by that amount on both sides. A sequence is passed
        through as ``pad_width`` to ``F.pad`` (before/after pairs for the
        N, C, H and W axes) and must therefore have length 8.

    Shape
    -----
    - Input: :math:`(N, C, H_{in}, W_{in})`
    - Output: :math:`(N, C, H_{out}, W_{out})` where
      :math:`H_{out} = H_{in} + 2 \cdot padding` and
      :math:`W_{out} = W_{in} + 2 \cdot padding`

    Raises
    ------
    ValueError
        If ``padding`` is a sequence whose length is not 8.
    """
    def __init__(self, padding=0, **kwargs):
        super(ReflectionPad2D, self).__init__(**kwargs)
        if isinstance(padding, numeric_types):
            # Leave N and C untouched; pad H and W on both sides.
            padding = (0, 0, 0, 0, padding, padding, padding, padding)
        padding = tuple(padding)
        # An explicit raise (unlike ``assert``) survives ``python -O``.
        if len(padding) != 8:
            raise ValueError(
                "padding must be an int or a sequence of 8 ints, got %r" % (padding,))
        self._padding = padding

    def hybrid_forward(self, F, x, *args, **kwargs):
        return F.pad(x, mode='reflect', pad_width=self._padding)


class SiameseNetwork(HybridBlock):
    """Twin-branch network that embeds two images with shared weights.

    ``cnn1`` stacks three stages of reflection padding, 3x3 convolution, ReLU
    and BatchNorm (channels 1 -> 4 -> 8 -> 8). ``fc1`` maps the (flattened)
    feature map to a 5-dimensional embedding. Both inputs run through the
    same layers, so the two branches share parameters.
    """

    def __init__(self):
        super(SiameseNetwork, self).__init__()
        # Create children inside this block's name scope so their parameter
        # names are prefixed consistently with the parent block.
        with self.name_scope():
            self.cnn1 = nn.HybridSequential()
            for in_channels, channels in ((1, 4), (4, 8), (8, 8)):
                self.cnn1.add(ReflectionPad2D(padding=1))
                self.cnn1.add(nn.Conv2D(in_channels=in_channels,
                                        channels=channels, kernel_size=3))
                self.cnn1.add(nn.Activation('relu'))
                self.cnn1.add(nn.BatchNorm())

            # The fully connected head belongs in fc1; it was previously
            # appended to cnn1 by mistake, leaving fc1 empty.
            self.fc1 = nn.HybridSequential()
            self.fc1.add(nn.Dense(500), nn.Activation('relu'),
                         nn.Dense(500), nn.Activation('relu'),
                         nn.Dense(5))

    def hybrid_forward(self, F, input1, input2):
        output1 = self._forward_once(input1)
        output2 = self._forward_once(input2)
        return output1, output2

    def _forward_once(self, x):
        # Dense flattens its input by default, so no explicit reshape is needed.
        output = self.cnn1(x)
        output = self.fc1(output)
        return output


class ContrastiveLoss(Loss):
    """Contrastive loss for Siamese networks, averaged over the batch.

    For each pair of embeddings with Euclidean distance ``d``:

    - ``label == 0`` (same class):      ``d ** 2``
    - ``label == 1`` (different class): ``max(margin - d, 0) ** 2``

    Parameters
    ----------
    margin : float
        Distance beyond which dissimilar pairs contribute no loss.
    weight : float or None
        Global scalar multiplier applied to the per-pair loss.
    batch_axis : int
        Forwarded to the ``Loss`` base class.
    """
    def __init__(self, margin=2.0, weight=None, batch_axis=0, **kwargs):
        super(ContrastiveLoss, self).__init__(weight, batch_axis, **kwargs)
        self.margin = margin

    def hybrid_forward(self, F, output1, output2, label):
        # Sum the squared differences over the feature axis *before* the
        # square root; sqrt(square(a - b)) alone is just elementwise |a - b|.
        squared_distance = F.sum(F.square(output1 - output2), axis=1)
        # Small epsilon keeps the sqrt gradient finite when embeddings coincide.
        euclidean_distance = F.sqrt(squared_distance + 1e-12)
        # Flatten (N, 1) labels to (N,) so they combine elementwise with the
        # per-pair distances instead of broadcasting.
        label = F.reshape(label, shape=(-1,))
        hinge = F.relu(self.margin - euclidean_distance)
        loss = (1 - label) * squared_distance + label * F.square(hinge)
        if self._weight is not None:
            loss = loss * self._weight
        return F.mean(loss)


@functools.lru_cache(maxsize=8)
def _augmenters(data_shape):
    """Build the MXNet augmenter chain for ``data_shape`` (cached per shape)."""
    return tuple(mx.image.CreateAugmenter(data_shape=data_shape))


def aug_transform(data, label, data_shape=(1, 100, 100)):
    """Run ``data`` through MXNet's default augmenters; pass ``label`` through.

    Parameters
    ----------
    data : NDArray
        Decoded image as produced by ``ImageFolderDataset``.
    label : int
        Class label, returned unchanged.
    data_shape : tuple of int, optional
        Target ``(channels, height, width)`` given to ``CreateAugmenter``.

    Returns
    -------
    tuple
        ``(augmented_data, label)``.
    """
    # The augmenter list used to be rebuilt for every sample; it only depends
    # on data_shape, so it is constructed once and reused.
    for aug in _augmenters(tuple(data_shape)):
        data = aug(data)
    return data, label


def run_training(plot=False):
    """Train a SiameseNetwork on ``Config.training_dir`` with contrastive loss.

    Parameters
    ----------
    plot : bool, optional
        When True, plot the sampled loss history with ``show_plot`` after
        training. Defaults to False, which matches the previous behaviour.

    Returns
    -------
    SiameseNetwork
        The trained network.
    """
    siamese_dataset = SiameseNetworkDataset(root=Config.training_dir, transform=aug_transform)
    train_dataloader = DataLoader(siamese_dataset, shuffle=True, num_workers=1,
                                  batch_size=Config.train_batch_size)

    counter = []
    loss_history = []
    iteration_number = 0

    net = SiameseNetwork()
    net.initialize(init=mx.init.Xavier())
    trainer = Trainer(net.collect_params(), 'adam', {'learning_rate': 0.0005})
    loss = ContrastiveLoss(margin=2.0)

    for epoch in range(Config.train_number_epochs):
        for i, (img0, img1, label) in enumerate(train_dataloader):
            with autograd.record():
                output1, output2 = net(img0, img1)
                loss_contrastive = loss(output1, output2, label)
            loss_contrastive.backward()
            # ContrastiveLoss already averages over the batch, so the
            # gradients must not be divided by the batch size a second time.
            trainer.step(1)

            if i % 10 == 0:
                # Store a plain float so the history does not keep NDArrays
                # alive and the printout shows a number, not an array repr.
                current_loss = float(loss_contrastive.mean().asscalar())
                print("Epoch number {}\n Current loss {}\n".format(epoch, current_loss))
                iteration_number += 10
                counter.append(iteration_number)
                loss_history.append(current_loss)

    if plot:
        show_plot(counter, loss_history)
    return net


def run_predict(net):
    """Print the embedding distance between two images from the test set.

    Draws two batches from ``Config.testing_dir``. It compares the first image
    of the first batch (``x0``) with the first image of the second batch
    (``x1``).
    """
    folder_dataset_test = SiameseNetworkDataset(root=Config.testing_dir, transform=aug_transform)
    test_dataloader = DataLoader(folder_dataset_test, shuffle=True, num_workers=1,
                                 batch_size=Config.train_batch_size)

    dataiter = iter(test_dataloader)
    x0, _, _ = next(dataiter)
    _, x1, _ = next(dataiter)
    output1, output2 = net(x0, x1)
    # Euclidean distance per pair: sum the squared differences over the
    # embedding axis before the square root (result has one value per pair).
    euclidean_distance = mx.ndarray.sqrt(
        mx.ndarray.sum(mx.ndarray.square(output1 - output2), axis=1))
    print('x0 vs x1 dissimilarity is {}'.format(euclidean_distance[0].asscalar()))


if __name__ == '__main__':
    # Train first, then report one embedding distance on the test images.
    trained_net = run_training()
    run_predict(trained_net)
2018-03-07T00:26:43.877 に回答