0

私は tflearn を使って簡単なオートエンコーダーを書いています。

net = tflearn.input_data (shape=[None, train.shape [1]])   
net = tflearn.fully_connected (net, 500, activation  = 'tanh', regularizer = None, name = 'fc_en_1')

#hidden state
net = tflearn.fully_connected (net, 100, activation  = 'tanh', regularizer = 'L1', name = 'fc_en_2', weight_decay = 0.0001)    

net = tflearn.fully_connected (net, 500, activation  = 'tanh', regularizer = None, name = 'fc_de_1')    
net = tflearn.fully_connected (net, train.shape [1], activation  = 'linear', name = 'fc_de_2')       
net = tflearn.regression(net, optimizer='adam', learning_rate=0.01, loss='mean_square', metric='default')

model = tflearn.DNN (net)

モデルは適切にトレーニングされていますが、トレーニング後、エンコーダーとデコーダーを別々に使用したいと考えています。

どうすればいいですか?現在、入力を復元できます。入力を非表示表現に変換し、任意の非表示表現から入力を復元できるようにしたいと考えています。

4

2 に答える 2

1

エンコーダーとデコーダーの入力/出力の名前を保存するだけです。

つまり(INPUT、HIDDEN_STATE、OUTPUTを追加):

net = tflearn.input_data (shape=[None, train.shape [1]])   
INPUT = net
net = tflearn.fully_connected (net, 500, activation  = 'tanh', regularizer = None, name = 'fc_en_1')

#hidden state
net = tflearn.fully_connected (net, 100, activation  = 'tanh', regularizer = 'L1', name = 'fc_en_2', weight_decay = 0.0001)    
HIDDEN_STATE = net

net = tflearn.fully_connected (net, 500, activation  = 'tanh', regularizer = None, name = 'fc_de_1')    
net = tflearn.fully_connected (net, train.shape [1], activation  = 'linear', name = 'fc_de_2')  
OUTPUT = net     
net = tflearn.regression(net, optimizer='adam', learning_rate=0.01, loss='mean_square', metric='default')

model = tflearn.DNN (net)

そして、そのような関数を使用してエンコード/デコードします。

def encode (X):    
    if len (X.shape) < 2:
        X = X.reshape (1, -1)

    tflearn.is_training (False, model.session)
    res = model.session.run (HIDDEN_STATE, feed_dict={INPUT.name:X})    
    return res    

def decode (X):
    if len (X.shape) < 2:
        X = X.reshape (1, -1)

    #just to pass something to place_holder
    zeros = np.zeros ((X.shape [0], train.shape [1]))

    tflearn.is_training (False, model.session)
    res = model.session.run (OUTPUT, feed_dict={INPUT.name:zeros, HIDDEN_STATE.name:X})    
    return res
于 2016-05-06T21:02:19.430 に答える