How to extract data/labels back from a TensorFlow dataset

If your tf.data.Dataset is batched and yields (x, y) pairs, the following code retrieves all the y labels:

import numpy as np

y = np.concatenate([y for x, y in ds], axis=0)
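Here is a minimal runnable sketch of the one-liner. It assumes a small toy dataset built with tf.data.Dataset.from_tensor_slices. The arrays, the names `features` and `labels`, and the batch size of 4 are made up for illustration. The same comprehension, returning x instead of y, pulls the input data back out.

```python
import numpy as np
import tensorflow as tf

# Toy data (assumed for illustration): 10 samples with 3 features, labels 0..9.
features = np.arange(30, dtype=np.float32).reshape(10, 3)
labels = np.arange(10)

# Batched dataset yielding (x, y) pairs; batch sizes are 4, 4 and 2.
ds = tf.data.Dataset.from_tensor_slices((features, labels)).batch(4)

# Pull every label back out, then every feature row in the same way.
y = np.concatenate([y for x, y in ds], axis=0)
x = np.concatenate([x for x, y in ds], axis=0)

print(y)        # [0 1 2 3 4 5 6 7 8 9]
print(x.shape)  # (10, 3)
```

Each comprehension iterates over the dataset separately. If ds shuffles on every iteration, which is the default for `shuffle`, x and y collected in two separate passes will not line up row for row.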

Quick explanation: [y for x, y in ds] is a list comprehension in Python. If the dataset is batched, this expression loops through each batch, puts that batch's y (a 1-D TF tensor) into a list, and returns the list. Iterating over ds like this requires eager execution, which is the default in TensorFlow 2.x. np.concatenate then takes this list of 1-D tensors, implicitly converting each one to a NumPy array, and joins them end to end along axis 0 to produce a single long vector. In short, it turns a bunch of small 1-D vectors into one long vector.
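The two steps can be inspected separately. This sketch uses a hypothetical dataset of 10 labels in batches of 4. It prints the per-batch shapes produced by the comprehension and shows why np.concatenate, and not np.stack, is the right join: the last batch is shorter, and np.stack requires equal shapes.

```python
import numpy as np
import tensorflow as tf

# Hypothetical dataset: 10 integer labels in batches of 4, so sizes are 4, 4, 2.
ds = tf.data.Dataset.from_tensor_slices((np.zeros((10, 3)), np.arange(10))).batch(4)

# Step 1: the list comprehension yields one label tensor per batch.
batches = [y for x, y in ds]
print([tuple(b.shape) for b in batches])  # [(4,), (4,), (2,)]

# Step 2: np.concatenate converts each tensor to a NumPy array and joins
# them along the existing axis 0.
all_labels = np.concatenate(batches, axis=0)
print(all_labels.shape)  # (10,)

# np.stack would add a new axis and needs equal shapes, so it fails here.
stack_failed = False
try:
    np.stack(batches)
except ValueError:
    stack_failed = True
print(stack_failed)  # True
```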

Note: if your y is more complex (for example, a dict or tuple of tensors, as with a multi-output model), this one-liner will need some modification.
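As one example of such a modification, here is a sketch that assumes y is a dict of tensors. The keys `class_id` and `score` are hypothetical. The idea is to collect the per-batch label dicts once and concatenate each key separately.

```python
import numpy as np
import tensorflow as tf

# Hypothetical multi-output labels: each y is a dict of tensors.
features = np.zeros((6, 2), dtype=np.float32)
labels = {"class_id": np.arange(6), "score": np.linspace(0.0, 1.0, 6)}
ds = tf.data.Dataset.from_tensor_slices((features, labels)).batch(4)

# Collect the per-batch label dicts in a single pass over the dataset.
label_batches = [y for _, y in ds]

# Concatenate each key's batches along axis 0.
y = {key: np.concatenate([batch[key] for batch in label_batches], axis=0)
     for key in label_batches[0]}

print(y["class_id"])     # [0 1 2 3 4 5]
print(y["score"].shape)  # (6,)
```

For an unbatched dataset, each y is a scalar, and np.concatenate rejects zero-dimensional arrays. In that case, np.stack([y for x, y in ds]) is one alternative.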
