Inspecting TFRecord files and debugging TensorFlow data input

TFRecord files are TensorFlow’s recommended data format, but their binary encoding makes them hard to inspect directly. Being able to check the contents of existing record files, and to confirm that your input pipeline produces the data you expect, is a useful debugging skill to have.

Inspecting TFRecord values

The first trick is to read the TFRecord files back in Python and inspect their values. As you’d expect, the TensorFlow API allows this, although the function is a little hidden. The snippet below uses tf.python_io.tf_record_iterator to iterate over the serialized ‘examples’ in a record file and parses each one into a tf.train.Example. Replace ‘label’ and ‘text_label’ with your own feature names; the point is that you can use dot access to reach the property values.

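import tensorflow as tf

# TensorFlow 1.x API: iterate over the serialized records in the file
for example in tf.python_io.tf_record_iterator("data/train-1.tfrecord"):
  # Each record is a serialized tf.train.Example protocol buffer
  result = tf.train.Example.FromString(example)
  # Read the value list that matches each feature's type
  print(result.features.feature['label'].int64_list.value)
  print(result.features.feature['text_label'].bytes_list.value)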
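Because a TFRecord file is just a sequence of length-prefixed records, you can also sanity-check a file without TensorFlow, for example to confirm that it is not empty or truncated. The minimal sketch below assumes the framing described in the TensorFlow documentation: each record is a little-endian uint64 length, a masked CRC32C of those 8 length bytes, the payload bytes, and a masked CRC32C of the payload. The helper names (crc32c, masked_crc, iter_records, encode_records) are hypothetical, and the pure-Python CRC is slow, so treat it as a debugging aid rather than a replacement for tf_record_iterator.

```python
import struct

# CRC32C (Castagnoli) lookup table for the reflected polynomial 0x82F63B78.
_TABLE = []
for _i in range(256):
    _c = _i
    for _ in range(8):
        _c = (_c >> 1) ^ 0x82F63B78 if _c & 1 else _c >> 1
    _TABLE.append(_c)


def crc32c(data: bytes) -> int:
    """Plain CRC32C checksum of `data`."""
    crc = 0xFFFFFFFF
    for b in data:
        crc = _TABLE[(crc ^ b) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """TFRecord's masked CRC: rotate right by 15 bits, then add a constant."""
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + 0xA282EAD8) & 0xFFFFFFFF


def iter_records(buf: bytes):
    """Yield each record payload in `buf`, verifying both checksums."""
    pos = 0
    while pos < len(buf):
        if pos + 12 > len(buf):
            raise ValueError(f"truncated header at byte {pos}")
        header = buf[pos:pos + 8]
        (length,) = struct.unpack("<Q", header)
        (length_crc,) = struct.unpack("<I", buf[pos + 8:pos + 12])
        if length_crc != masked_crc(header):
            raise ValueError(f"bad length checksum at byte {pos}")
        start, end = pos + 12, pos + 12 + length
        if end + 4 > len(buf):
            raise ValueError(f"truncated record at byte {pos}")
        payload = buf[start:end]
        (data_crc,) = struct.unpack("<I", buf[end:end + 4])
        if data_crc != masked_crc(payload):
            raise ValueError(f"bad data checksum at byte {pos}")
        yield payload
        pos = end + 4


def encode_records(payloads) -> bytes:
    """Frame payloads the same way (handy for testing iter_records)."""
    out = bytearray()
    for p in payloads:
        header = struct.pack("<Q", len(p))
        out += header + struct.pack("<I", masked_crc(header))
        out += p + struct.pack("<I", masked_crc(p))
    return bytes(out)
```

To count the records in a file, read it in binary mode and pass the bytes to iter_records. Each yielded payload is the same serialized bytes that tf_record_iterator returns, so it can be handed to tf.train.Example.FromString exactly as in the snippet above.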

Debugging Dataset input functions

The tf.data module, which became part of core TensorFlow in version 1.4, simplifies the handling of data input, but as with any software component it’s great to be able to test it in isolation. I follow the TF Experiment pattern of having an input_fn that returns a (features, labels) tuple taken from a dataset iterator. The function below constructs such an input_fn, which reads and parses our TFRecord files.

def tfrec_data_input_fn(filenames, num_epochs=1, batch_size=64, shuffle=False):

    def _input_fn():
        def _parse_record(tf_record):
            features = {
                'image': tf.FixedLenFeature([], dtype=tf.string),
                'label': tf.FixedLenFeature([], dtype=tf.int64)
            }
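            # Deserialize one tf.train.Example using the feature spec above
            record = tf.parse_single_example(tf_record, features)

            # The image is expected as raw float32 bytes; restore its 224x224x3 shape
            image_raw = tf.decode_raw(record['image'], tf.float32)
            image_raw = tf.reshape(image_raw, shape=(224, 224, 3))

            # Turn the integer class id into a one-hot vector over 2 classes
            label = tf.cast(record['label'], tf.int32)
            label = tf.one_hot(label, depth=2)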

            return { 'image': image_raw }, label

        # For TF dataset blog post, see https://developers.googleblog.com/2017/09/introducing-tensorflow-datasets.html
        dataset = tf.data.TFRecordDataset(filenames)
        dataset = dataset.map(_parse_record)

        if shuffle:
            dataset = dataset.shuffle(buffer_size=256)

        dataset = dataset.repeat(num_epochs)
        dataset = dataset.batch(batch_size)

        iterator = dataset.make_one_shot_iterator()
        features, labels = iterator.get_next()

        return features, labels

    return _input_fn
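The order of the dataset calls matters. Because repeat() comes before batch() in _input_fn, tf.data treats the repeated epochs as one continuous stream, so when num_epochs is greater than 1 a batch can mix the end of one epoch with the start of the next, and only the very last batch can be short. The pure-Python sketch below only mimics that ordering on a plain list; it ignores shuffling, assumes an integer num_epochs, uses hypothetical function names, and is not how tf.data is implemented.

```python
def repeat_then_batch(examples, num_epochs, batch_size):
    """Mimic dataset.repeat(num_epochs).batch(batch_size) on a plain list."""
    # repeat() concatenates the epochs into one continuous stream...
    stream = [x for _ in range(num_epochs) for x in examples]
    # ...and batch() cuts that stream into chunks; only the last may be short.
    return [stream[i:i + batch_size] for i in range(0, len(stream), batch_size)]


def batch_then_repeat(examples, num_epochs, batch_size):
    """Mimic dataset.batch(batch_size).repeat(num_epochs) for comparison."""
    # Batching first means every epoch ends with its own (possibly short) batch.
    batches = [examples[i:i + batch_size] for i in range(0, len(examples), batch_size)]
    return [b for _ in range(num_epochs) for b in batches]


if __name__ == "__main__":
    data = [0, 1, 2, 3, 4]
    print(repeat_then_batch(data, num_epochs=2, batch_size=3))
    print(batch_then_repeat(data, num_epochs=2, batch_size=3))
```

With five examples, two epochs and a batch size of 3, repeat_then_batch gives [[0, 1, 2], [3, 4, 0], [1, 2, 3], [4]], while batch_then_repeat gives [[0, 1, 2], [3, 4], [0, 1, 2], [3, 4]]. With the default num_epochs=1 the two orderings produce the same batches.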

This function takes a list of TFRecord filenames and returns a callable. Below we call it and run the resulting tensors through a typical session to check that our features come through with the shapes we expect. Running this in a short standalone script is much easier than debugging the pipeline in the middle of a model.

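import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# Build the input_fn for a single file and call it to get the batch tensors
tfrec_dev_input_fn = tfrec_data_input_fn(["data/train-1.tfrecord"], batch_size=3)
features, labels = tfrec_dev_input_fn()

with tf.Session() as sess:
  # Pull one batch through the pipeline
  img, label = sess.run([features['image'], labels])
  print(img.shape, label.shape)  # a full batch gives (3, 224, 224, 3) and (3, 2)

  # Loop over each example in batch
  for i in range(img.shape[0]):
    plt.imshow(img[i])
    plt.show()
    print('Class label ' + str(np.argmax(label[i])))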
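If sess.run fails inside _parse_record with a reshape error, one common cause is that the stored image bytes do not match the dtype passed to tf.decode_raw: 224 × 224 × 3 float32 values need exactly 602,112 bytes. The minimal sketch below checks a raw payload’s length in plain Python before it ever reaches TensorFlow. EXPECTED_SHAPE and check_image_payload are hypothetical names, and the shape and 4-byte float32 element size are assumptions copied from _parse_record above.

```python
from array import array

# Assumed to match _parse_record: float32 values reshaped to 224x224x3.
EXPECTED_SHAPE = (224, 224, 3)
FLOAT32_SIZE = 4  # bytes per tf.float32 element


def check_image_payload(raw, shape=EXPECTED_SHAPE, itemsize=FLOAT32_SIZE):
    """Raise ValueError unless `raw` holds exactly prod(shape) elements."""
    elements = 1
    for dim in shape:
        elements *= dim
    expected = elements * itemsize
    if len(raw) == expected:
        return
    msg = f"payload has {len(raw)} bytes, expected {expected} for shape {shape}"
    if raw and len(raw) % elements == 0:
        # A whole number of bytes per element suggests a dtype mix-up.
        msg += f"; looks like {len(raw) // elements}-byte elements, not {itemsize}-byte"
    raise ValueError(msg)


if __name__ == "__main__":
    # array("f") holds C floats, which are 4 bytes on mainstream platforms.
    good = array("f", [0.5] * (224 * 224 * 3)).tobytes()
    check_image_payload(good)  # passes silently

    try:
        check_image_payload(bytes(224 * 224 * 3))  # one byte per value, e.g. uint8
    except ValueError as err:
        print(err)
```

Combined with the first snippet, you could pass result.features.feature['image'].bytes_list.value[0] to check_image_payload. Note that tf.decode_raw reads little-endian values by default, while array.tobytes() uses the machine’s native byte order; the two agree on little-endian hardware.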

Conclusion

Tracking down why your model isn’t working as expected can be frustrating, but these two techniques should help you rule out problems in your TFRecord files and your data input pipeline.

Licensed under CC BY-NC-SA 4.0