TFRecord files are TensorFlow's recommended data format, but their binary nature makes them difficult to inspect. Being able to examine the contents of existing record files, and to verify that the data flowing through your input pipeline is what you expect, is a valuable technique to have.
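To see why the files look opaque in a text editor, it helps to know the on-disk framing: each record is a little-endian uint64 length, a masked CRC-32C of that length, the serialized payload, and a masked CRC-32C of the payload. The sketch below iterates raw payloads with only the standard library; note it skips CRC verification, and the toy writer fills in zero CRCs, so TensorFlow's own readers (which verify checksums) would likely reject its output. The helper names write_records and iter_records are illustrative, not TensorFlow APIs.

```python
import struct

# TFRecord on-disk framing (per record):
#   uint64  length (little-endian)
#   uint32  masked CRC-32C of the length bytes
#   byte    data[length]
#   uint32  masked CRC-32C of the data

def write_records(path, payloads):
    """Write payloads in TFRecord framing, with zeroed placeholder CRCs."""
    with open(path, "wb") as f:
        for data in payloads:
            f.write(struct.pack("<Q", len(data)))  # length
            f.write(struct.pack("<I", 0))          # length CRC (placeholder)
            f.write(data)
            f.write(struct.pack("<I", 0))          # data CRC (placeholder)

def iter_records(path):
    """Yield raw serialized payloads (tf.train.Example bytes in practice)."""
    with open(path, "rb") as f:
        while True:
            header = f.read(8)
            if not header:
                return
            (length,) = struct.unpack("<Q", header)
            f.read(4)             # skip length CRC
            yield f.read(length)  # the serialized record
            f.read(4)             # skip data CRC
```

Because iter_records ignores the checksums, it will walk genuine TFRecord files too, which makes it a quick way to count records without importing TensorFlow.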
Inspecting TFRecord values
The first trick is reading TFRecord files and inspecting their values in Python. As you'd expect, the TensorFlow API allows this (although it's a little tucked away). The small code snippet below uses tf.python_io.tf_record_iterator
to inspect the examples in your record file. Replace 'label' and 'text_label' as appropriate for your features; the point is that you can dot-access into the property values.
import tensorflow as tf

for example in tf.python_io.tf_record_iterator("data/train-1.tfrecord"):
    result = tf.train.Example.FromString(example)
    print(result.features.feature['label'].int64_list.value)
    print(result.features.feature['text_label'].bytes_list.value)
Debugging Dataset input functions
The tf.data
module released in TensorFlow 1.4 simplifies the handling of data input, but like any software component it's great to know how to test it in isolation. I follow the TF Experiment pattern of having an input_fn
that returns a (features, labels)
iterator tuple. The example below shows the function that constructs the input_fn to read and map our TFRecord files.
def tfrec_data_input_fn(filenames, num_epochs=1, batch_size=64, shuffle=False):

    def _input_fn():
        def _parse_record(tf_record):
            features = {
                'image': tf.FixedLenFeature([], dtype=tf.string),
                'label': tf.FixedLenFeature([], dtype=tf.int64)
            }
            record = tf.parse_single_example(tf_record, features)

            image_raw = tf.decode_raw(record['image'], tf.float32)
            image_raw = tf.reshape(image_raw, shape=(224, 224, 3))

            label = tf.cast(record['label'], tf.int32)
            label = tf.one_hot(label, depth=2)
            return {'image': image_raw}, label

        # For the TF Dataset blog post, see https://developers.googleblog.com/2017/09/introducing-tensorflow-datasets.html
        dataset = tf.data.TFRecordDataset(filenames)
        dataset = dataset.map(_parse_record)
        if shuffle:
            dataset = dataset.shuffle(buffer_size=256)
        dataset = dataset.repeat(num_epochs)
        dataset = dataset.batch(batch_size)

        iterator = dataset.make_one_shot_iterator()
        features, labels = iterator.get_next()
        return features, labels

    return _input_fn
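The shape of this factory is worth noting on its own: the outer function binds the configuration (filenames, batch size), and the zero-argument callable it returns is what the Estimator machinery invokes later. A stdlib-only sketch of the same pattern, with make_input_fn and the toy batching purely illustrative (not TensorFlow APIs):

```python
# Illustrative sketch of the input_fn factory pattern: the outer
# function captures configuration, and the inner zero-argument
# callable produces (features, labels) batches when finally invoked.

def make_input_fn(records, batch_size=2):
    def _input_fn():
        features, labels = [], []
        for value, label in records:
            features.append(value)
            labels.append(label)
        # Group into fixed-size batches, dropping any ragged remainder,
        # loosely mirroring dataset.batch(batch_size).
        return [
            (features[i:i + batch_size], labels[i:i + batch_size])
            for i in range(0, len(features) - batch_size + 1, batch_size)
        ]
    return _input_fn  # returned uncalled, just like tfrec_data_input_fn
```

Calling make_input_fn([(0.1, 0), (0.2, 1)], batch_size=2) gives you a callable you can hand off and invoke later, which is exactly how tfrec_data_input_fn is used in the next snippet.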
This function takes filenames of TFRecord files and returns a callable. Below we call it and run the result through a typical session to ensure our data features are coming through in the shape we expect. It's handy to be able to do this in a short script rather than debugging the input pipeline when it's buried in the middle of the model.
import numpy as np
import matplotlib.pyplot as plt

tfrec_dev_input_fn = tfrec_data_input_fn(["data/train-1.tfrecord"], batch_size=3)
features, labels = tfrec_dev_input_fn()

with tf.Session() as sess:
    img, label = sess.run([features['image'], labels])
    print(img.shape, label.shape)

    # Loop over each example in the batch
    for i in range(img.shape[0]):
        plt.imshow(img[i])
        plt.show()
        print('Class label ' + str(np.argmax(label[i])))
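The np.argmax call simply recovers the integer class from the one-hot label produced by tf.one_hot in _parse_record. The round trip is easy to sanity-check in plain Python; one_hot and argmax below are illustrative stand-ins, not the TensorFlow or NumPy functions:

```python
# Illustrative stand-ins for the one-hot round trip used above:
# tf.one_hot(label, depth=2) encodes, np.argmax(label[i]) decodes.

def one_hot(index, depth):
    """Encode an integer class as a one-hot list, like tf.one_hot."""
    return [1.0 if i == index else 0.0 for i in range(depth)]

def argmax(values):
    """Index of the largest value, like np.argmax on a 1-D array."""
    return max(range(len(values)), key=values.__getitem__)
```

If the printed class labels ever look wrong, checking this round trip on a single label value is a quick way to rule out an off-by-one in your label encoding.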
Conclusion
It can be frustrating to track down why your model isn't working as expected, but hopefully the two techniques above will let you rule out your data input pipeline and your TFRecord files as the source of the problem.