How to prefetch data using a custom Python function in TensorFlow

This is a common use case, and most implementations use TensorFlow’s queues to decouple the preprocessing code from the training code. There is a tutorial on how to use queues, but the main steps are as follows:

  1. Define a queue, q, that will buffer the preprocessed data. TensorFlow supports the simple tf.FIFOQueue that produces elements in the order they were enqueued, and the more advanced tf.RandomShuffleQueue that produces elements in a random order. A queue element is a tuple of one or more tensors (which can have different types and shapes). All queues support single-element (enqueue, dequeue) and batch (enqueue_many, dequeue_many) operations, but to use the batch operations you must specify the shapes of each tensor in a queue element when constructing the queue.

  2. Build a subgraph that enqueues preprocessed elements into the queue. One way to do this would be to define some tf.placeholder() ops for tensors corresponding to a single input example, then pass them to q.enqueue(). (If your preprocessing produces a batch at once, you should use q.enqueue_many() instead.) You might also include TensorFlow ops in this subgraph.

  3. Build a subgraph that performs training. This will look like a regular TensorFlow graph, but will get its input by calling q.dequeue_many(BATCH_SIZE).

  4. Start your session.

  5. Create one or more threads that execute your preprocessing logic, then execute the enqueue op, feeding in the preprocessed data. You may find the tf.train.Coordinator and tf.train.QueueRunner utility classes useful for this.

  6. Run your training graph (optimizer, etc.) as normal.
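The six steps above can be sketched without TensorFlow at all, as a rough analogy built only on Python's standard library. Here queue.Queue stands in for tf.FIFOQueue, a background thread stands in for the enqueue loop, and the consumer loop stands in for the training graph. The names preprocess, producer, dequeue_many and run are hypothetical helpers invented for this illustration, not TensorFlow API, and the preprocessing itself is a placeholder.

```python
import queue
import threading

BATCH_SIZE = 4
_SENTINEL = object()  # marks the end of the input stream


def preprocess(raw):
    """Hypothetical preprocessing step: returns a (feature, label) pair."""
    return (float(raw) * 0.5, raw % 2)


def producer(q, raw_items):
    """Preprocess each item and enqueue it; put() blocks while the queue is full."""
    for raw in raw_items:
        q.put(preprocess(raw))
    q.put(_SENTINEL)


def dequeue_many(q, n):
    """Collect up to n elements from q; returns (batch, done)."""
    batch = []
    while len(batch) < n:
        item = q.get()  # blocks until the producer has enqueued something
        if item is _SENTINEL:
            return batch, True
        batch.append(item)
    return batch, False


def run(num_items=10):
    """Start the producer thread, then consume batches until the data runs out."""
    q = queue.Queue(maxsize=100)  # bounded buffer, like a queue with capacity 100
    t = threading.Thread(target=producer, args=(q, range(num_items)))
    t.start()
    batches = []
    done = False
    while not done:
        batch, done = dequeue_many(q, BATCH_SIZE)
        if batch:
            batches.append(batch)  # a real training step would run here
    t.join()
    return batches
```

Calling run(10) yields two full batches followed by one partial batch of two elements. Returning that partial batch is a choice made in this sketch, not a statement about how TensorFlow's queues behave.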

EDIT: Here’s a simple load_and_enqueue() function and code fragment to get you started:

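# Uses the TensorFlow 1.x graph-mode API (tf.placeholder, tf.Session).
import threading

import numpy
import tensorflow as tf

# Features are length-100 vectors of floats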
feature_input = tf.placeholder(tf.float32, shape=[100])
# Labels are scalar integers.
label_input = tf.placeholder(tf.int32, shape=[])

# Alternatively, could do:
# feature_batch_input = tf.placeholder(tf.float32, shape=[None, 100])
# label_batch_input = tf.placeholder(tf.int32, shape=[None])

q = tf.FIFOQueue(100, [tf.float32, tf.int32], shapes=[[100], []])
enqueue_op = q.enqueue([feature_input, label_input])

# For batch input, do:
# enqueue_op = q.enqueue_many([feature_batch_input, label_batch_input])

feature_batch, label_batch = q.dequeue_many(BATCH_SIZE)
# Build rest of model taking label_batch, feature_batch as input.
# [...]
train_op = ...

sess = tf.Session()

def load_and_enqueue():
  with open(...) as feature_file, open(...) as label_file:
    while True:
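      feature_array = numpy.fromfile(feature_file, numpy.float32, 100)
      # Test the size explicitly: `not feature_array` raises ValueError
      # for a multi-element NumPy array.
      if feature_array.size < 100:
        # End of file (or a truncated record): stop enqueuing.
        return
      label_value = numpy.fromfile(label_file, numpy.int32, 1)[0]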

      sess.run(enqueue_op, feed_dict={feature_input: feature_array,
                                      label_input: label_value})

# Start a thread to enqueue data asynchronously, and hide I/O latency.
t = threading.Thread(target=load_and_enqueue)
t.start()

# Each sess.run(train_op) call dequeues one batch of BATCH_SIZE examples.
for _ in range(TRAINING_EPOCHS):
  sess.run(train_op)
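
The end-of-file check in load_and_enqueue() can be exercised on its own, without a TensorFlow session. The following is a minimal sketch, assuming features and labels are stored as raw binary float32 and int32 values in two separate files (the layout implied by the numpy.fromfile() calls above). The names read_examples, demo and FEATURE_LEN, the file names, and the synthetic data are all hypothetical and exist only for this illustration. It tests the array's size because the truth value of a multi-element NumPy array is ambiguous and raises ValueError.

```python
import os
import tempfile

import numpy

FEATURE_LEN = 100  # length of one feature vector, as in the fragment above


def read_examples(feature_path, label_path):
    """Yield (feature_array, label_value) pairs until either file runs out."""
    with open(feature_path, "rb") as feature_file, \
         open(label_path, "rb") as label_file:
        while True:
            feature_array = numpy.fromfile(feature_file, numpy.float32, FEATURE_LEN)
            label_array = numpy.fromfile(label_file, numpy.int32, 1)
            # A short read means end of file or a truncated record.
            if feature_array.size < FEATURE_LEN or label_array.size < 1:
                return
            yield feature_array, label_array[0]


def demo():
    """Write three synthetic examples to temporary files and read them back."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        feature_path = os.path.join(tmp_dir, "features.bin")
        label_path = os.path.join(tmp_dir, "labels.bin")
        numpy.arange(3 * FEATURE_LEN, dtype=numpy.float32).tofile(feature_path)
        numpy.array([7, 8, 9], dtype=numpy.int32).tofile(label_path)
        return list(read_examples(feature_path, label_path))
```

Inside load_and_enqueue(), each pair yielded by such a generator would be passed to sess.run(enqueue_op, feed_dict=...) instead of being collected into a list.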
