tf.data.TFRecordDataset Count

Count the number of records in a TFRecordDataset

Fri, Nov 27, 2020

There are two ways to determine the length of a TFRecordDataset: count the records, which requires one full pass over the data, or store the record count in metadata when the file is written. I’ll cover both approaches.

Iterate over the dataset

Iterating over the dataset to get the count is useful, for example, when determining the steps_per_epoch and validation_steps parameters of model.fit(). Because it reads every record, it can be slow for datasets larger than a few gigabytes; in that case, consider the metadata approach instead.
See the full example at https://github.com/reasonedpenguin/tensorflow-examples

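import tensorflow as tf

def countRecords(ds: tf.data.Dataset) -> int:
  """Count the elements of ds by iterating over it once.

  If ds has been batched, this counts batches, not records.
  """
  count = 0

  if tf.executing_eagerly():
    # TF v2, or TF v1 with eager execution enabled
    for _ in ds:
      count += 1
  else:
    # TF v1 graph (non-eager) mode: pull elements through a session
    # until the iterator is exhausted
    iterator = tf.compat.v1.data.make_one_shot_iterator(ds)
    next_batch = iterator.get_next()
    with tf.compat.v1.Session() as sess:
      try:
        while True:
          sess.run(next_batch)
          count += 1
      except tf.errors.OutOfRangeError:
        pass  # end of the dataset reached

  return count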
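In TF 2.x (eager mode), the same single pass can also be written with tf.data.Dataset.reduce, which folds over the dataset inside TensorFlow instead of in a Python loop. This is a minimal sketch assuming eager execution; the function name count_records_reduce is hypothetical. It still reads every record, so it has the same cost as countRecords().

```python
import numpy as np
import tensorflow as tf

def count_records_reduce(ds: tf.data.Dataset) -> int:
  # Start from an int64 zero and add 1 for every element;
  # the element itself is ignored.
  total = ds.reduce(np.int64(0), lambda count, _: count + 1)
  return int(total.numpy())

# Usage (filenames is a path or list of paths to TFRecord files):
# count = count_records_reduce(tf.data.TFRecordDataset(filenames))
```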

Save the metadata

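We’ll extend the tf.io.TFRecordWriter and tf.data.TFRecordDataset classes to keep track of the number of records in the dataset. The writer saves the count as JSON in a file named filename + '.meta' next to each TFRecord file, and the dataset reads it back.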

import json
import tensorflow as tf

class TFRecordWriterExtended(tf.io.TFRecordWriter):
  """TFRecordWriter that also writes <path>.meta with the record count."""

  def __init__(self, path, options=None):
    super().__init__(path, options)
    self.filename = path  # assumed to be a str path
    self.recordCount = 0
    self._closed = False

  def __del__(self):
    # Make sure the metadata is written even if close() was never called
    if not getattr(self, '_closed', True):
      self.close()

  def write(self, record):
    super().write(record)
    self.recordCount += 1

  def close(self):
    if self._closed:
      return  # avoid closing twice and rewriting the metadata
    super().close()
    self._closed = True
    self.writeMetadata()

  def getMetadata(self):
    # Add more metadata as necessary
    return {'recordCount': self.recordCount}

  def writeMetadata(self):
    metadataFile = self.filename + '.meta'
    with open(metadataFile, 'w') as outfile:
      json.dump(self.getMetadata(), outfile)

You can then use the new writer just like tf.io.TFRecordWriter:

writer = tfds_util.TFRecordWriterExtended(filename)
for i in range(count):
  ...  # Build the feature dict for record i
  example = tf.train.Example(features=tf.train.Features(feature=feature))

  # Serialize to a string and write it to the file
  writer.write(example.SerializeToString())
writer.close()  # also writes filename + '.meta'
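The `...` above is where each record's feature dict is built. Purely as an illustration, a record with two hypothetical fields, an int64 `index` and a float `value`, could be encoded as follows; make_example and both field names are made up for this example.

```python
import tensorflow as tf

def make_example(i: int) -> bytes:
  # Hypothetical schema: one int64 feature and one float feature per record
  feature = {
      'index': tf.train.Feature(int64_list=tf.train.Int64List(value=[i])),
      'value': tf.train.Feature(float_list=tf.train.FloatList(value=[i * 0.5])),
  }
  example = tf.train.Example(features=tf.train.Features(feature=feature))
  return example.SerializeToString()
```

Each returned string can be passed to writer.write(), which also increments recordCount.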

We are also going to extend the tf.data.TFRecordDataset class with a new method, recordCount().

import json
import tensorflow as tf

class TFRecordDatasetExtended(tf.data.TFRecordDataset):
  """TFRecordDataset that can report its size from the .meta files."""

  def __init__(self, filenames, compression_type=None, buffer_size=None,
               num_parallel_reads=None):
    self.filenames = filenames
    # Forward the caller's arguments instead of replacing them with None
    super().__init__(filenames, compression_type=compression_type,
                     buffer_size=buffer_size,
                     num_parallel_reads=num_parallel_reads)

  def recordCount(self):
    # filenames is expected to be a single str path or a list of paths
    if isinstance(self.filenames, str):
      return self.recordCountForFile(self.filenames)
    return sum(self.recordCountForFile(f) for f in self.filenames)

  def recordCountForFile(self, filename):
    # Read the count saved next to the file by TFRecordWriterExtended
    metadataFile = filename + '.meta'
    with open(metadataFile, 'r') as metafile:
      return json.load(metafile)['recordCount']
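If some files were written with a plain tf.io.TFRecordWriter, there is no .meta file and recordCountForFile() raises FileNotFoundError. One possible extension, sketched here as a hypothetical standalone helper (recordCountOrScan is not part of the original code), combines both approaches: use the metadata when it exists and fall back to a full scan otherwise. The scan assumes eager execution.

```python
import json
import os

import tensorflow as tf

def recordCountOrScan(filename: str) -> int:
  """Return the count from <filename>.meta, or scan the file if it is missing."""
  metadataFile = filename + '.meta'
  if os.path.exists(metadataFile):
    # Fast path: count saved by the extended writer
    with open(metadataFile, 'r') as metafile:
      return json.load(metafile)['recordCount']
  # Slow path: read every record once (assumes eager execution)
  count = 0
  for _ in tf.data.TFRecordDataset(filename):
    count += 1
  return count
```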

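Again, this class can be used like a regular tf.data.TFRecordDataset; when we need the number of records, we call recordCount(), which reads the .meta files instead of iterating over the data. Note that transformations such as map() or batch() return new tf.data.Dataset objects that do not have recordCount(), so call it on the original dataset.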

ds = tfds_util.TFRecordDatasetExtended(filename)
count = ds.recordCount()
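Either count can then be turned into the steps_per_epoch or validation_steps value for model.fit(). This is a small sketch assuming the dataset is later batched with a hypothetical batch_size; whether the last partial batch counts as a step depends on the drop_remainder argument of ds.batch(). The helper name stepsForCount is hypothetical.

```python
def stepsForCount(recordCount: int, batch_size: int,
                  drop_remainder: bool = False) -> int:
  if drop_remainder:
    # The final partial batch is discarded
    return recordCount // batch_size
  # The final partial batch is still one step (integer ceiling division)
  return (recordCount + batch_size - 1) // batch_size

# Usage:
# steps_per_epoch = stepsForCount(ds.recordCount(), 32)
```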