0

tf.Saverweightbiasおよびstate変数が保存されている単純な再帰型ネットワークの例があります。

この例をオプションなしで実行すると、状態ベクトルがゼロを含むように初期化されますが、オプションを渡し、状態ベクトルの最後の値を呼び出しload_modelのフィードとして使用したいと考えています。session.run

私が目にするすべてのドキュメントは、変数から格納された値を取得するには呼び出す必要があると主張していsession.runますが、この場合、状態変数を初期化できるように値を取得したいと考えています。初期値を取得するためだけに別のグラフを作成する必要がありますか?

以下のコード例:

import tensorflow as tf
import math
import numpy as np

INPUTS = 10
HIDDEN_1 = 2
BATCH_SIZE = 3

def batch_vm2(m, x):
  [input_size, output_size] = m.get_shape().as_list()

  input_shape = tf.shape(x)
  batch_rank = input_shape.get_shape()[0].value - 1
  batch_shape = input_shape[:batch_rank]
  output_shape = tf.concat(0, [batch_shape, [output_size]])

  x = tf.reshape(x, [-1, input_size])
  y = tf.matmul(x, m)

  y = tf.reshape(y, output_shape)

  return y

def get_weight_and_biases():
    with tf.variable_scope(network_scope, reuse = True) as scope:
        weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0 / math.sqrt(float(INPUTS))))
        biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0))
    return weights, biases

def get_saver():
    with tf.variable_scope('h1') as scope:
        weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0 / math.sqrt(float(INPUTS))))
        biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0))
        state = tf.get_variable('state', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0), trainable=False)
        saver = tf.train.Saver([weights, biases, state])
    return saver, scope


def load(sess, saver, checkpoint_dir = './'):

        print("loading a session")
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            raise Exception("no checkpoint found")
        return

iteration = None

def iterate_state(prev_state_tuple, input):
    with tf.variable_scope(network_scope, reuse = True) as scope:
        weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0 / math.sqrt(float(INPUTS))))
        biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0))
        state = tf.get_variable('state', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0), trainable=False)
        print("input: ",input.get_shape())
        matmuladd = batch_vm2(weights, input) + biases
        matmulpri = tf.Print(matmuladd,[matmuladd, weights], message=" malmul -> %i, weights " % iteration)
        print("prev state: ",prev_state_tuple.get_shape())
        unpacked_state, unpacked_out = tf.split(0,2,prev_state_tuple)
        prev_state = 0.99* unpacked_state
        prev_state = tf.Print(prev_state, [unpacked_state, matmuladd], message=" -> prevstate, matmulpri ")
        state = state.assign( prev_state + 0.01*matmulpri )
        #output = tf.nn.relu(state)
        output = tf.nn.tanh(state)
        state = tf.Print(state, [state], message=" state -> ")
        output = tf.Print(output, [output], message=" output -> ")
        print(" state: ", state.get_shape())
        print(" output: ", output.get_shape())
        concat_result = tf.concat(0,[state, output])
        print (" concat return: ", concat_result.get_shape())
        return concat_result

def data_iter():
    while True:
        idxs = np.random.rand(BATCH_SIZE, INPUTS)
        yield idxs

flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_boolean('load_model', False, 'If true, uses model files '
                     'to restore.')


network_scope = None

with tf.Graph().as_default():
    inputs = tf.placeholder(tf.float32, shape=(BATCH_SIZE, INPUTS))
    iteration = -1
    saver, network_scope = get_saver()
    initial_state = tf.placeholder(tf.float32, shape=(HIDDEN_1))
    initial_out = tf.zeros([HIDDEN_1],
                             name='initial_out')
    concat_tensor = tf.concat(0,[initial_state, initial_out])
    print(" init state: ",initial_state.get_shape())
    print(" init out: ",initial_out.get_shape())
    print(" concat: ",concat_tensor.get_shape())
    scanout = tf.scan(iterate_state, inputs, initializer=concat_tensor, name='state_scan')
    print ("scanout shape: ", scanout.get_shape())
    state, output = tf.split(1,2,scanout, name='split_scan_output')
    print(" end state: ",state.get_shape())
    print(" end out: ",output.get_shape())


    sess = tf.Session()
    # Run the Op to initialize the variables.

    sess.run(tf.initialize_all_variables())
    tf.train.write_graph(sess.graph_def, './tenIrisSave/logsd','graph.pbtxt')
    tf_weight, tf_bias = get_weight_and_biases()
    tf.histogram_summary('weights', tf_weight)
    tf.histogram_summary('bias', tf_bias)
    tf.histogram_summary('state', state)
    tf.histogram_summary('out', output)
    summary_op = tf.merge_all_summaries()
    summary_writer = tf.train.SummaryWriter('./tenIrisSave/summary',sess.graph_def)
    if FLAGS.load_model:
        load(sess, saver)
        # HOW DO I LOAD restored state values??????
        #st = state[BATCH_SIZE - 1,:]
        #st = sess.run([state], feed_dict={})
        print("LOADED last state vec: ", st)
    else:
        st = np.array([0.0 , 0.0])
    iter_ = data_iter()
    for i in xrange(0, 1):
        print ("iteration: ",i)
        iteration = i
        input_data = iter_.next()
        out,st,so,summary_str = sess.run([output,state,scanout,summary_op], feed_dict={ inputs: input_data, initial_state: st })
        saver.save(sess, 'my-model', global_step=1+i)
        summary_writer.add_summary(summary_str, i)
        summary_writer.flush()
        print("input vec: ", input_data)
        print("state vec: ", st)
        st = st[-1]
        print("last state vec: ", st)
        print("output vec: ", out)
        print(" end state (runtime): ",st.shape)
        print(" end out (runtime): ",out.shape)
        print(" end scanout (runtime): ",so.shape)

124 ~ 126 行目で、フィード ディクショナリの値を初期化しようとした方法についてコメントされた行に注意してください。それらのどれも機能しません。

4

1 に答える 1