Count the number of records in a TFRecordDataset
Fri, Nov 27, 2020
There are two ways to determine the length of a TFRecordDataset: count the records, which requires one full iteration over the data, or store the number of records in a metadata file. I’ll cover both approaches.
Iterating over the dataset to get the count can be useful when you are determining the steps_per_epoch and validation_steps parameters to the model.fit() function. If your dataset is more than a few gigabytes, you may want to consider the metadata approach instead.
See the full example at https://github.com/reasonedpenguin/tensorflow-examples
def countRecords(ds: tf.data.Dataset):
    """Count the records in a dataset by iterating over it once.

    Args:
        ds: The dataset to measure. Every element is consumed, so this costs
            one complete pass over the data.

    Returns:
        The number of elements the dataset yielded.
    """
    if tf.executing_eagerly():
        # TF v2 or v1 in eager mode: the dataset can be iterated directly.
        return sum(1 for _ in ds)

    # TF v1 in non-eager mode: elements have to be pulled through a session
    # until the iterator reports that it is exhausted.
    total = 0
    element = tf.compat.v1.data.make_one_shot_iterator(ds).get_next()
    with tf.compat.v1.Session() as session:
        try:
            while True:
                session.run(element)
                total += 1
        except tf.errors.OutOfRangeError:
            # Raised when there is nothing left to read; ends the loop.
            pass
    return total
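For the metadata approach, we’ll extend the tf.io.TFRecordWriter and tf.data.TFRecordDataset classes to keep track of the number of records in the dataset. The writer stores the record count in a '.meta' JSON file next to the TFRecord file, and the dataset class reads it back.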
class TFRecordWriterExtended(tf.io.TFRecordWriter):
    """A TFRecordWriter that counts records and saves the count on close.

    Closing the writer produces a JSON sidecar file named ``<path>.meta``
    containing the metadata returned by ``getMetadata()`` (currently just
    ``recordCount``).

    Attributes:
        filename: The path passed to the constructor.
        recordCount: Number of records successfully written so far.
    """

    # Class-level default, kept so the attribute stays readable on the class.
    recordCount = 0

    def __init__(self, path, options=None):
        """Open ``path`` for writing.

        Args:
            path: Path of the TFRecord file; ``'.meta'`` is appended to it to
                name the metadata file.
            options: Passed through unchanged to ``tf.io.TFRecordWriter``.
        """
        self.filename = path
        self.recordCount = 0
        super().__init__(path, options)
        # Set only after the underlying writer was created, so that __del__
        # never tries to close a writer whose construction failed.
        self._closed = False

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the writer (and write the metadata) when a with-block ends."""
        # Routed through close() explicitly rather than relying on the base
        # class's __exit__ to call this subclass's override.
        self.close()

    def __del__(self):
        """Close the writer on garbage collection if it is still open."""
        # Without this guard an already-closed writer would rewrite the
        # metadata file at collection time, possibly clobbering newer
        # metadata written for the same path.
        if not getattr(self, '_closed', True):
            self.close()

    def write(self, record):
        """Write one serialized record and count it.

        Args:
            record: The serialized record, passed to the base ``write``.
        """
        super().write(record)
        # Only reached when the base write did not raise.
        self.recordCount += 1

    def close(self):
        """Close the TFRecord file and write the metadata file.

        Calling this more than once is safe; only the first call has any
        effect.
        """
        if getattr(self, '_closed', False):
            return
        super().close()
        self._closed = True
        self.writeMetadata()

    def getMetadata(self):
        """Return the metadata dictionary that is saved on close."""
        # Add more metadata as necessary
        return {'recordCount': self.recordCount}

    def writeMetadata(self):
        """Dump ``getMetadata()`` as JSON to ``<filename>.meta``."""
        metadataFile = self.filename + '.meta'
        data = self.getMetadata()
        with open(metadataFile, 'w') as outfile:
            json.dump(data, outfile)
You can then use the new record writer the same as before:
# The extended writer is opened exactly like a plain TFRecordWriter.
writer = tfds_util.TFRecordWriterExtended(filename)
for i in range(count):
    ... # Create feature
    features = tf.train.Features(feature=feature)
    example = tf.train.Example(features=features)
    # Serialize to string and write on the file
    serialized = example.SerializeToString()
    writer.write(serialized)
# Closing the writer also writes the '.meta' file holding the record count.
writer.close()
We are also going to extend the tf.data.TFRecordDataset class to add a new method, recordCount().
class TFRecordDatasetExtended(tf.data.TFRecordDataset):
    """A TFRecordDataset that can report its record count from metadata.

    The count is read from the ``<filename>.meta`` JSON file that sits next
    to each TFRecord file, so no pass over the data is needed.

    Attributes:
        filenames: The ``filenames`` argument exactly as it was passed in.
    """

    def __init__(self, filenames, compression_type=None, buffer_size=None,
                 num_parallel_reads=None):
        """Create the dataset.

        Args:
            filenames: A single file name or an iterable of file names.
            compression_type: Forwarded to ``tf.data.TFRecordDataset``.
            buffer_size: Forwarded to ``tf.data.TFRecordDataset``.
            num_parallel_reads: Forwarded to ``tf.data.TFRecordDataset``.
        """
        self.filenames = filenames
        # Forward the caller's settings; hard-coding None here would silently
        # discard them (e.g. a compressed file would be read as uncompressed).
        super().__init__(
            filenames,
            compression_type=compression_type,
            buffer_size=buffer_size,
            num_parallel_reads=num_parallel_reads,
        )

    def recordCount(self):
        """Return the total number of records across all files.

        Returns:
            The sum of the ``recordCount`` values stored in each file's
            metadata; 0 when ``filenames`` is an empty iterable.

        Raises:
            OSError: If a metadata file cannot be opened.
            KeyError: If a metadata file has no ``recordCount`` entry.
        """
        # A single string must not be iterated character by character.
        if isinstance(self.filenames, str):
            return self.recordCountForFile(self.filenames)
        return sum(self.recordCountForFile(f) for f in self.filenames)

    def recordCountForFile(self, filename):
        """Return the record count stored in ``<filename>.meta``.

        Args:
            filename: Path of the TFRecord file (a string).
        """
        metadataFile = filename + '.meta'
        with open(metadataFile, 'r') as metafile:
            data = json.load(metafile)
        return data['recordCount']
Again, we can use this class the same as before; if we need to check the number of records, we can call recordCount().
# Open the dataset through the extended class, passing the file name(s) as usual.
ds = tfds_util.TFRecordDatasetExtended(filename)
# Read the stored count from the '.meta' file(s) instead of iterating the data.
count = ds.recordCount()