TFRecords are TensorFlow's native binary data format and the recommended way to store your data for streaming. Reading them back with a TFRecordDataset is also a very convenient way to subsequently get these records into your model.
The data
We will use the well-known MNIST dataset for handwritten digit recognition as a sample. It is easily retrieved through TensorFlow via:
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets(
    "/tmp/tensorflow/mnist/input_data",
    reshape=False
)
We then have the mnist.validation, mnist.train and mnist.test data sets.
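Passing reshape=False keeps each image as a 28×28×1 array rather than a flat 784-vector, which the code below relies on. A quick check of the array shapes (these are the standard MNIST split sizes) confirms this:

print(mnist.train.images.shape)       # (55000, 28, 28, 1)
print(mnist.validation.images.shape)  # (5000, 28, 28, 1)
print(mnist.test.images.shape)        # (10000, 28, 28, 1)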
Creating TFRecords
TFRecord files contain Example instances for each data point, and each Example contains some Features. To write these out to disk we use a TFRecordWriter. Let's use a single MNIST data sample to show an example:
import tensorflow as tf

image = mnist.train.images[0]
image_label = mnist.train.labels[0]
_, rows, cols, depth = mnist.train.images.shape

filename = '/tmp/tensorflow/mnist/single-example.tfrecords'  # any output path will do

with tf.python_io.TFRecordWriter(filename) as writer:
    image_raw = image.tostring()
    example = tf.train.Example(features=tf.train.Features(feature={
        'height': _int64_feature(rows),
        'width': _int64_feature(cols),
        'depth': _int64_feature(depth),
        'label': _int64_feature(int(image_label)),
        'image_raw': _bytes_feature(image_raw)
    }))
    writer.write(example.SerializeToString())
Feature entries are protobuf instances. The TF documentation doesn't go into too many details, but links to the feature.proto definition, where a Feature holds one of a BytesList, FloatList or Int64List. Here we are only using bytes and int features, so creating two helper functions to simplify the record creation is handy:
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
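To sanity-check what was written, we can read the file straight back and decode the raw protobuf. This isn't part of the original sample, but tf.python_io.tf_record_iterator pairs naturally with the writer above:

for record in tf.python_io.tf_record_iterator(filename):
    # Parse the serialized bytes back into an Example proto
    example = tf.train.Example.FromString(record)
    print(example.features.feature['height'])  # int64_list { value: 28 }
    print(example.features.feature['label'])
    break  # only inspect the first record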
The single example highlights the specific code, but here's the entire sample to demonstrate sharding the data as well.
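The sample below also calls a small _data_path helper to build the output file path. Its exact implementation lives in the full code linked at the end, but a minimal sketch consistent with the .tfrecords file names used later would be:

import os

def _data_path(data_directory: str, name: str) -> str:
    """Construct a full path to a .tfrecords file in data_directory,
    creating the directory if needed. (A sketch; see the linked full code.)"""
    if not os.path.isdir(data_directory):
        os.makedirs(data_directory)
    return os.path.join(data_directory, f'{name}.tfrecords')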
import sys
import numpy as np

def convert_to(data_set, name: str, data_directory: str, num_shards: int = 1):
    """Convert the dataset into TFRecords on disk

    Args:
        data_set: The MNIST data set to convert
        name: The name of the data set
        data_directory: The directory where records will be stored
        num_shards: The number of files on disk to separate records into
    """
    num_examples, rows, cols, depth = data_set.images.shape
    data_set = list(zip(data_set.images, data_set.labels))

    def _process_examples(example_dataset, filename: str):
        print(f'Processing {filename} data')
        dataset_length = len(example_dataset)
        with tf.python_io.TFRecordWriter(filename) as writer:
            for index, (image, label) in enumerate(example_dataset):
                sys.stdout.write(f"\rProcessing sample {index+1} of {dataset_length}")
                sys.stdout.flush()

                image_raw = image.tostring()
                example = tf.train.Example(features=tf.train.Features(feature={
                    'height': _int64_feature(rows),
                    'width': _int64_feature(cols),
                    'depth': _int64_feature(depth),
                    'label': _int64_feature(int(label)),
                    'image_raw': _bytes_feature(image_raw)
                }))
                writer.write(example.SerializeToString())
        print()

    if num_shards == 1:
        _process_examples(data_set, _data_path(data_directory, name))
    else:
        sharded_dataset = np.array_split(data_set, num_shards)
        for shard, dataset in enumerate(sharded_dataset):
            _process_examples(dataset, _data_path(data_directory, f'{name}-{shard+1}'))
data_directory = '/tmp/tensorflow/mnist/tfrecords'  # any output directory will do

convert_to(mnist.validation, 'validation', data_directory)
convert_to(mnist.train, 'train', data_directory, num_shards=10)
convert_to(mnist.test, 'test', data_directory)
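With the helper above appending a .tfrecords extension, the training split now lives in ten shard files. Counting the records across them is a quick way to confirm nothing was lost in the split (again an add-on check, not part of the original sample):

import glob

train_files = glob.glob(os.path.join(data_directory, 'train-*.tfrecords'))
total = sum(1 for f in train_files for _ in tf.python_io.tf_record_iterator(f))
print(len(train_files), total)  # 10 55000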
Reading the data
If you are using the recommended Dataset API, you can use the TFRecordDataset to read in one or more TFRecord files, as shown in the example below. The main difference from any other use of the Dataset API is how we parse out the sample: we tell tf.parse_single_example which features and types we want retrieved, and then get them into an appropriate format for our model.
def data_input_fn(filenames, batch_size=1000, shuffle=False):

    def _parse(record):
        features = {
            'label': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string)
        }
        parsed_record = tf.parse_single_example(record, features)
        image = tf.decode_raw(parsed_record['image_raw'], tf.float32)

        label = tf.cast(parsed_record['label'], tf.int32)

        return image, label

    def _input_fn():
        dataset = (tf.contrib.data.TFRecordDataset(filenames)
            .map(_parse))

        if shuffle:
            dataset = dataset.shuffle(buffer_size=10_000)

        dataset = dataset.repeat(None)  # Infinite iterations: let experiment determine num_epochs
        dataset = dataset.batch(batch_size)

        iterator = dataset.make_one_shot_iterator()
        features, labels = iterator.get_next()

        return features, labels

    return _input_fn
We can then create input functions for the train and validation sets:
train_input_fn = data_input_fn(glob.glob('/path/to/data/train-*.tfrecords'), shuffle=True)
validation_input_fn = data_input_fn('/path/to/data/validation.tfrecords')
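The returned functions build the graph lazily, so they plug straight into tf.estimator. To sanity-check the pipeline on its own, you can also pull a single batch through a session (note that tf.decode_raw returns each image flattened to 784 floats):

features_op, labels_op = train_input_fn()

with tf.Session() as sess:
    images, labels = sess.run([features_op, labels_op])
    print(images.shape, labels.shape)  # (1000, 784) (1000,)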
Full code available at https://github.com/damienpontifex/BlogCodeSamples/blob/master/DataToTfRecords/mnist-to-tfrecords.py