How do I set TensorFlow RNN state when state_is_tuple=True?

One problem with a TensorFlow placeholder is that you can only feed it a Python list or a NumPy array (I think), so you can't carry the state between runs as a tuple of LSTMStateTuple objects.

I solved this by storing the state of all layers in a single NumPy array like this:

import numpy as np

# Zero state for every layer: [num_layers, 2 (c, h), batch_size, state_size]
initial_state = np.zeros((num_layers, 2, batch_size, state_size))

Each LSTM layer has two state components, the cell state c and the hidden state h, which is where the "2" comes from. LSTMStateTuple stores them in that order, (c, h). (This article is great: https://arxiv.org/pdf/1506.00019.pdf)
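After each run, the state comes back as a tuple with one LSTMStateTuple(c, h) per layer, and it has to be turned back into the [num_layers, 2, batch_size, state_size] layout before it can be fed again. Below is a minimal pure-Python sketch of that conversion. It assumes a namedtuple as a stand-in for the TensorFlow class and uses nested lists in place of NumPy arrays; `pack_state` and `shape_of` are hypothetical helper names.

```python
from collections import namedtuple

# Hypothetical stand-in for tf.nn.rnn_cell.LSTMStateTuple; fields are (c, h), in that order
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])

def pack_state(layer_states):
    """Convert a tuple of per-layer (c, h) pairs into nested lists
    shaped [num_layers, 2, batch_size, state_size]."""
    return [[list(map(list, s.c)), list(map(list, s.h))] for s in layer_states]

def shape_of(nested):
    """Return the shape of a regular nested list, e.g. (2, 2, 3, 4)."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0] if nested else None
    return tuple(shape)

num_layers, batch_size, state_size = 2, 3, 4
zeros = [[0.0] * state_size for _ in range(batch_size)]
state = tuple(LSTMStateTuple(c=zeros, h=zeros) for _ in range(num_layers))

packed = pack_state(state)
print(shape_of(packed))  # (2, 2, 3, 4)
```

Because a namedtuple is a tuple, calling np.array on the real fetched state should perform the same nesting in one step; the sketch spells it out so the axis order (layer, then c/h, then batch, then units) is explicit.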

When building the graph, you unpack the placeholder and rebuild the tuple state like this:

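import tensorflow as tf  # TensorFlow 1.x API

# Placeholder for the state of all layers: [num_layers, 2 (c, h), batch_size, state_size]
state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
# Split along the layer axis (this function was called tf.unpack before TensorFlow 1.0)
l = tf.unstack(state_placeholder, axis=0)
# One LSTMStateTuple(c, h) per layer
rnn_tuple_state = tuple(
    [tf.nn.rnn_cell.LSTMStateTuple(l[idx][0], l[idx][1])
     for idx in range(num_layers)]
)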
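The graph code splits the placeholder along axis 0 and builds one LSTMStateTuple per layer, so `l[idx][0]` is the cell state and `l[idx][1]` the hidden state of layer `idx`. Here is the same indexing as a plain-Python sketch on nested lists. The namedtuple is a stand-in for the TensorFlow class, and `unpack_state` is a hypothetical name.

```python
from collections import namedtuple

# Hypothetical stand-in for tf.nn.rnn_cell.LSTMStateTuple(c, h)
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])

def unpack_state(packed, num_layers):
    """Mirror of tf.unstack(..., axis=0) followed by one LSTMStateTuple per layer."""
    layers = [packed[idx] for idx in range(num_layers)]  # split along the layer axis
    return tuple(LSTMStateTuple(layers[idx][0], layers[idx][1])
                 for idx in range(num_layers))

# Two layers, batch of 1, state_size 2; the values mark layer and component
packed = [
    [[[0.1, 0.1]], [[0.2, 0.2]]],  # layer 0: c, h
    [[[1.1, 1.1]], [[1.2, 1.2]]],  # layer 1: c, h
]
state = unpack_state(packed, num_layers=2)
print(state[1].c)  # [[1.1, 1.1]]
print(state[0].h)  # [[0.2, 0.2]]
```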

Then you build the cell and get the new state the usual way:

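# Create a separate cell object for each layer; newer TensorFlow 1.x releases
# do not allow one cell object to be reused across layers via [cell] * num_layers
cells = [tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
         for _ in range(num_layers)]
cell = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)

# series_batch_input: input tensor of shape [batch_size, time_steps, input_size]
outputs, state = tf.nn.dynamic_rnn(cell, series_batch_input, initial_state=rnn_tuple_state)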
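What the placeholder buys you is a loop where the state fetched from one run is fed back in as the initial state of the next. The sketch below shows that loop in plain Python. `fake_run` is a hypothetical stand-in for `sess.run(state, feed_dict={state_placeholder: current_state})` that just adds 1 to every value, so the carried state is easy to follow; it is not TensorFlow code.

```python
from collections import namedtuple

# Hypothetical stand-in for tf.nn.rnn_cell.LSTMStateTuple(c, h)
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])

num_layers, batch_size, state_size = 2, 1, 3

def zeros_state():
    # Same layout as np.zeros((num_layers, 2, batch_size, state_size))
    return [[[[0.0] * state_size for _ in range(batch_size)] for _ in range(2)]
            for _ in range(num_layers)]

def fake_run(feed_state):
    """Hypothetical stand-in for sess.run: returns one LSTMStateTuple per layer,
    with every value of the fed state incremented by 1."""
    return tuple(
        LSTMStateTuple(
            c=[[v + 1.0 for v in row] for row in feed_state[i][0]],
            h=[[v + 1.0 for v in row] for row in feed_state[i][1]],
        )
        for i in range(num_layers)
    )

def pack_state(layer_states):
    # Back to the [num_layers, 2, batch_size, state_size] layout for the next feed
    return [[s.c, s.h] for s in layer_states]

current_state = zeros_state()
for _ in range(3):
    new_state = fake_run(current_state)     # tuple of LSTMStateTuple
    current_state = pack_state(new_state)   # re-packed for the placeholder

print(current_state[0][0][0])  # [3.0, 3.0, 3.0]
```

After three steps every value is 3.0, which confirms that each run started from the state produced by the previous one rather than from zeros.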

It shouldn't have to be like this… perhaps the TensorFlow developers are working on a solution.
