TFRecord et tf.train.Example

Voir sur TensorFlow.org Exécuter dans Google Colab Voir la source sur GitHub Télécharger le cahier

Le format TFRecord est un format simple pour stocker une séquence d'enregistrements binaires.

Les tampons de protocole sont une bibliothèque multiplateforme et multilangage pour une sérialisation efficace des données structurées.

Les messages de protocole sont définis par des fichiers .proto , ceux-ci sont souvent le moyen le plus simple de comprendre un type de message.

Le message tf.train.Example (ou protobuf) est un type de message flexible qui représente un mappage {"string": value} . Il est conçu pour être utilisé avec TensorFlow et est utilisé dans les API de niveau supérieur telles que TFX .

Ce bloc-notes montre comment créer, analyser et utiliser le message tf.train.Example , puis sérialiser, écrire et lire des messages tf.train.Example vers et depuis des fichiers .tfrecord .

Installer

import tensorflow as tf

import numpy as np
import IPython.display as display

tf.train.Example

Types de données pour tf.train.Example

Fondamentalement, un tf.train.Example est un mappage {"string": tf.train.Feature} .

Le type de message tf.train.Feature peut accepter l'un des trois types suivants (voir le fichier .proto pour référence). La plupart des autres types génériques peuvent être contraints à l'un de ceux-ci :

  1. tf.train.BytesList (les types suivants peuvent être forcés)

    • string
    • byte
  2. tf.train.FloatList (les types suivants peuvent être forcés)

    • float ( float32 )
    • double ( float64 )
  3. tf.train.Int64List (les types suivants peuvent être forcés)

    • bool
    • enum
    • int32
    • uint32
    • int64
    • uint64

Afin de convertir un type TensorFlow standard en un tf.train.Example -compatible tf.train.Feature , vous pouvez utiliser les fonctions de raccourci ci-dessous. Notez que chaque fonction prend une valeur d'entrée scalaire et renvoie un tf.train.Feature contenant l'un des trois types de list ci-dessus :

# The following functions can be used to convert a value to a type compatible
# with tf.train.Example.

def _bytes_feature(value):
  """Returns a bytes_list from a string / byte."""
  if isinstance(value, type(tf.constant(0))):
    value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
  """Returns a float_list from a float / double."""
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
  """Returns an int64_list from a bool / enum / int / uint."""
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

Vous trouverez ci-dessous quelques exemples du fonctionnement de ces fonctions. Notez les différents types d'entrée et les types de sortie normalisés. Si le type d'entrée d'une fonction ne correspond pas à l'un des types coercibles indiqués ci-dessus, la fonction lèvera une exception (par exemple _int64_feature(1.0) une erreur car 1.0 est un flottant - par conséquent, il doit être utilisé avec la fonction _float_feature à la place ):

print(_bytes_feature(b'test_string'))
print(_bytes_feature(u'test_bytes'.encode('utf-8')))

print(_float_feature(np.exp(1)))

print(_int64_feature(True))
print(_int64_feature(1))
bytes_list {
  value: "test_string"
}

bytes_list {
  value: "test_bytes"
}

float_list {
  value: 2.7182817459106445
}

int64_list {
  value: 1
}

int64_list {
  value: 1
}

Tous les messages proto peuvent être sérialisés en une chaîne binaire à l'aide de la méthode .SerializeToString :

feature = _float_feature(np.exp(1))

feature.SerializeToString()
b'\x12\x06\n\x04T\xf8-@'

Création d'un message tf.train.Example

Supposons que vous souhaitiez créer un message tf.train.Example à partir de données existantes. En pratique, le jeu de données peut provenir de n'importe où, mais la procédure de création du message tf.train.Example à partir d'une seule observation sera la même :

  1. Dans chaque observation, chaque valeur doit être convertie en un tf.train.Feature contenant l'un des 3 types compatibles, en utilisant l'une des fonctions ci-dessus.

  2. Vous créez une carte (dictionnaire) à partir de la chaîne du nom de l'entité vers la valeur de l'entité encodée produite en #1.

  3. La carte produite à l'étape 2 est convertie en un Features d'entités.

Dans ce notebook, vous allez créer un ensemble de données à l'aide de NumPy.

Ce jeu de données aura 4 fonctionnalités :

  • une caractéristique booléenne, False ou True avec une probabilité égale
  • une caractéristique entière uniformément choisie au hasard parmi [0, 5]
  • une caractéristique de chaîne générée à partir d'une table de chaînes en utilisant la caractéristique d'entier comme index
  • une fonction flottante à partir d'une distribution normale standard

Considérons un échantillon composé de 10 000 observations distribuées de manière indépendante et identique à partir de chacune des distributions ci-dessus :

# The number of observations in the dataset.
n_observations = int(1e4)

# Boolean feature, encoded as False or True.
feature0 = np.random.choice([False, True], n_observations)

# Integer feature, random from 0 to 4.
feature1 = np.random.randint(0, 5, n_observations)

# String feature.
strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])
feature2 = strings[feature1]

# Float feature, from a standard normal distribution.
feature3 = np.random.randn(n_observations)

Chacune de ces fonctionnalités peut être convertie en un type compatible avec tf.train.Example en utilisant l'un des _bytes_feature , _float_feature , _int64_feature . Vous pouvez ensuite créer un message tf.train.Example à partir de ces fonctionnalités encodées :

def serialize_example(feature0, feature1, feature2, feature3):
  """
  Creates a tf.train.Example message ready to be written to a file.
  """
  # Create a dictionary mapping the feature name to the tf.train.Example-compatible
  # data type.
  feature = {
      'feature0': _int64_feature(feature0),
      'feature1': _int64_feature(feature1),
      'feature2': _bytes_feature(feature2),
      'feature3': _float_feature(feature3),
  }

  # Create a Features message using tf.train.Example.

  example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
  return example_proto.SerializeToString()

Par exemple, supposons que vous ayez une seule observation de l'ensemble de données, [False, 4, bytes('goat'), 0.9876] . Vous pouvez créer et imprimer le message tf.train.Example pour cette observation en utilisant create_message() . Chaque observation unique sera écrite sous la forme d'un message de Features conformément à ce qui précède. Notez que le message tf.train.Example n'est qu'un wrapper autour du message Features :

# This is an example observation from the dataset.

example_observation = []

serialized_example = serialize_example(False, 4, b'goat', 0.9876)
serialized_example
b'\nR\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04[\xd3|?'

Pour décoder le message, utilisez la méthode tf.train.Example.FromString .

example_proto = tf.train.Example.FromString(serialized_example)
example_proto
features {
  feature {
    key: "feature0"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "feature1"
    value {
      int64_list {
        value: 4
      }
    }
  }
  feature {
    key: "feature2"
    value {
      bytes_list {
        value: "goat"
      }
    }
  }
  feature {
    key: "feature3"
    value {
      float_list {
        value: 0.9876000285148621
      }
    }
  }
}

Détails du format TFRecords

Un fichier TFRecord contient une séquence d'enregistrements. Le fichier ne peut être lu que séquentiellement.

Chaque enregistrement contient une chaîne d'octets, pour la charge utile des données, plus la longueur des données, et des hachages CRC-32C (CRC 32 bits utilisant le polynôme de Castagnoli ) pour la vérification de l'intégrité.

Chaque enregistrement est stocké dans les formats suivants :

uint64 length
uint32 masked_crc32_of_length
byte   data[length]
uint32 masked_crc32_of_data

Les enregistrements sont concaténés pour produire le fichier. Les CRC sont décrits ici , et le masque d'un CRC est :

masked_crc = ((crc >> 15) | (crc << 17)) + 0xa282ead8ul

Fichiers TFRecord utilisant tf.data

Le module tf.data fournit également des outils pour lire et écrire des données dans TensorFlow.

Ecriture d'un fichier TFRecord

Le moyen le plus simple d'intégrer les données dans un jeu de données consiste à utiliser la méthode from_tensor_slices .

Appliqué à un tableau, il renvoie un ensemble de données de scalaires :

tf.data.Dataset.from_tensor_slices(feature1)
<TensorSliceDataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)>

Appliqué à un tuple de tableaux, il renvoie un jeu de données de tuples :

features_dataset = tf.data.Dataset.from_tensor_slices((feature0, feature1, feature2, feature3))
features_dataset
<TensorSliceDataset element_spec=(TensorSpec(shape=(), dtype=tf.bool, name=None), TensorSpec(shape=(), dtype=tf.int64, name=None), TensorSpec(shape=(), dtype=tf.string, name=None), TensorSpec(shape=(), dtype=tf.float64, name=None))>
# Use `take(1)` to only pull one example from the dataset.
for f0,f1,f2,f3 in features_dataset.take(1):
  print(f0)
  print(f1)
  print(f2)
  print(f3)
tf.Tensor(False, shape=(), dtype=bool)
tf.Tensor(4, shape=(), dtype=int64)
tf.Tensor(b'goat', shape=(), dtype=string)
tf.Tensor(0.5251196235602504, shape=(), dtype=float64)

Utilisez la méthode tf.data.Dataset.map pour appliquer une fonction à chaque élément d'un Dataset .

La fonction mappée doit fonctionner en mode graphique TensorFlow. Elle doit fonctionner sur et renvoyer tf.Tensors . Une fonction non tenseur, comme serialize_example , peut être enveloppée avec tf.py_function pour la rendre compatible.

L'utilisation de tf.py_function nécessite de spécifier les informations de forme et de type qui ne sont pas disponibles autrement :

def tf_serialize_example(f0,f1,f2,f3):
  tf_string = tf.py_function(
    serialize_example,
    (f0, f1, f2, f3),  # Pass these args to the above function.
    tf.string)      # The return type is `tf.string`.
  return tf.reshape(tf_string, ()) # The result is a scalar.
tf_serialize_example(f0, f1, f2, f3)
<tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04=n\x06?\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04'>

Appliquez cette fonction à chaque élément de l'ensemble de données :

serialized_features_dataset = features_dataset.map(tf_serialize_example)
serialized_features_dataset
<MapDataset element_spec=TensorSpec(shape=(), dtype=tf.string, name=None)>
def generator():
  for features in features_dataset:
    yield serialize_example(*features)
serialized_features_dataset = tf.data.Dataset.from_generator(
    generator, output_types=tf.string, output_shapes=())
serialized_features_dataset
<FlatMapDataset element_spec=TensorSpec(shape=(), dtype=tf.string, name=None)>

Et écrivez-les dans un fichier TFRecord :

filename = 'test.tfrecord'
writer = tf.data.experimental.TFRecordWriter(filename)
writer.write(serialized_features_dataset)
WARNING:tensorflow:From /tmp/ipykernel_25215/3575438268.py:2: TFRecordWriter.__init__ (from tensorflow.python.data.experimental.ops.writers) is deprecated and will be removed in a future version.
Instructions for updating:
To write TFRecords to disk, use `tf.io.TFRecordWriter`. To save and load the contents of a dataset, use `tf.data.experimental.save` and `tf.data.experimental.load`

Lire un fichier TFRecord

Vous pouvez également lire le fichier TFRecord à l'aide de la classe tf.data.TFRecordDataset .

Vous trouverez plus d'informations sur l'utilisation des fichiers TFRecord à l'aide tf.data dans le guide tf.data : Build TensorFlow input pipelines .

L'utilisation TFRecordDataset s peut être utile pour normaliser les données d'entrée et optimiser les performances.

filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)
raw_dataset
<TFRecordDatasetV2 element_spec=TensorSpec(shape=(), dtype=tf.string, name=None)>

À ce stade, l'ensemble de données contient des messages tf.train.Example sérialisés. Lorsqu'il est itéré, il les renvoie sous forme de tenseurs de chaîne scalaires.

Utilisez la méthode .take pour afficher uniquement les 10 premiers enregistrements.

for raw_record in raw_dataset.take(10):
  print(repr(raw_record))
<tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04=n\x06?'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x9d\xfa\x98\xbe\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x13\n\x08feature2\x12\x07\n\x05\n\x03dog\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x01\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04a\xc0r?\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x00\n\x13\n\x08feature2\x12\x07\n\x05\n\x03cat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x92Q(?'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04>\xc0\xe5>\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nU\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04I!\xde\xbe\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x17\n\x08feature2\x12\x0b\n\t\n\x07chicken'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xe0\x1a\xab\xbf\n\x13\n\x08feature2\x12\x07\n\x05\n\x03cat'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x13\n\x08feature2\x12\x07\n\x05\n\x03cat\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x87\xb2\xd7?\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x00'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04n\xe19>\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x1as\xd9\xbf\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat'>

Ces tenseurs peuvent être analysés à l'aide de la fonction ci-dessous. Notez que la feature_description est nécessaire ici car les tf.data.Dataset utilisent l'exécution de graphes et ont besoin de cette description pour construire leur forme et leur signature de type :

# Create a description of the features.
feature_description = {
    'feature0': tf.io.FixedLenFeature([], tf.int64, default_value=0),
    'feature1': tf.io.FixedLenFeature([], tf.int64, default_value=0),
    'feature2': tf.io.FixedLenFeature([], tf.string, default_value=''),
    'feature3': tf.io.FixedLenFeature([], tf.float32, default_value=0.0),
}

def _parse_function(example_proto):
  # Parse the input `tf.train.Example` proto using the dictionary above.
  return tf.io.parse_single_example(example_proto, feature_description)

Vous pouvez également utiliser tf.parse example pour analyser l'ensemble du lot en une seule fois. Appliquez cette fonction à chaque élément du jeu de données à l'aide de la méthode tf.data.Dataset.map :

parsed_dataset = raw_dataset.map(_parse_function)
parsed_dataset
<MapDataset element_spec={'feature0': TensorSpec(shape=(), dtype=tf.int64, name=None), 'feature1': TensorSpec(shape=(), dtype=tf.int64, name=None), 'feature2': TensorSpec(shape=(), dtype=tf.string, name=None), 'feature3': TensorSpec(shape=(), dtype=tf.float32, name=None)}>

Utilisez une exécution hâtive pour afficher les observations dans le jeu de données. Il y a 10 000 observations dans ce jeu de données, mais vous n'afficherez que les 10 premières. Les données sont affichées sous la forme d'un dictionnaire d'entités. Chaque élément est un tf.Tensor , et l'élément numpy de ce tenseur affiche la valeur de la fonctionnalité :

for parsed_record in parsed_dataset.take(10):
  print(repr(parsed_record))
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=4>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'goat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.5251196>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=4>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'goat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-0.29878703>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'dog'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.94824797>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'cat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.65749466>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=4>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'goat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.44873232>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=2>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'chicken'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-0.4338477>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'cat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-1.3367577>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'cat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=1.6851357>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=4>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'goat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.18152401>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=4>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'goat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-1.6988251>}

Ici, la fonction tf.parse_example décompresse les champs tf.train.Example en tenseurs standard.

Fichiers TFRecord en Python

Le module tf.io contient également des fonctions en pur Python pour lire et écrire des fichiers TFRecord.

Ecriture d'un fichier TFRecord

Ensuite, écrivez les 10 000 observations dans le fichier test.tfrecord . Chaque observation est convertie en un message tf.train.Example , puis écrite dans un fichier. Vous pouvez ensuite vérifier que le fichier test.tfrecord a bien été créé :

# Write the `tf.train.Example` observations to the file.
with tf.io.TFRecordWriter(filename) as writer:
  for i in range(n_observations):
    example = serialize_example(feature0[i], feature1[i], feature2[i], feature3[i])
    writer.write(example)
du -sh {filename}
984K    test.tfrecord

Lire un fichier TFRecord

Ces tenseurs sérialisés peuvent être facilement analysés à l'aide tf.train.Example.ParseFromString :

filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)
raw_dataset
<TFRecordDatasetV2 element_spec=TensorSpec(shape=(), dtype=tf.string, name=None)>
for raw_record in raw_dataset.take(1):
  example = tf.train.Example()
  example.ParseFromString(raw_record.numpy())
  print(example)
features {
  feature {
    key: "feature0"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "feature1"
    value {
      int64_list {
        value: 4
      }
    }
  }
  feature {
    key: "feature2"
    value {
      bytes_list {
        value: "goat"
      }
    }
  }
  feature {
    key: "feature3"
    value {
      float_list {
        value: 0.5251196026802063
      }
    }
  }
}

Cela renvoie un proto tf.train.Example qui est difficile à utiliser tel quel, mais c'est fondamentalement une représentation de a:

Dict[str,
     Union[List[float],
           List[int],
           List[str]]]

Le code suivant convertit manuellement l' Example en un dictionnaire de tableaux NumPy, sans utiliser TensorFlow Ops. Reportez-vous au fichier PROTO pour les détails.

result = {}
# example.features.feature is the dictionary
for key, feature in example.features.feature.items():
  # The values are the Feature objects which contain a `kind` which contains:
  # one of three fields: bytes_list, float_list, int64_list

  kind = feature.WhichOneof('kind')
  result[key] = np.array(getattr(feature, kind).value)

result
{'feature3': array([0.5251196]),
 'feature1': array([4]),
 'feature0': array([0]),
 'feature2': array([b'goat'], dtype='|S4')}

Procédure pas à pas : lecture et écriture de données d'image

Il s'agit d'un exemple de bout en bout de la lecture et de l'écriture de données d'image à l'aide de TFRecords. En utilisant une image comme données d'entrée, vous écrirez les données dans un fichier TFRecord, puis relirez le fichier et afficherez l'image.

Cela peut être utile si, par exemple, vous souhaitez utiliser plusieurs modèles sur le même jeu de données d'entrée. Au lieu de stocker les données d'image brutes, elles peuvent être prétraitées au format TFRecords, et cela peut être utilisé dans tous les traitements et modélisations ultérieurs.

Tout d'abord, téléchargeons cette image d'un chat dans la neige et cette photo du pont de Williamsburg, NYC en construction.

Récupérer les images

cat_in_snow  = tf.keras.utils.get_file(
    '320px-Felis_catus-cat_on_snow.jpg',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg')

williamsburg_bridge = tf.keras.utils.get_file(
    '194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg')
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg
24576/17858 [=========================================] - 0s 0us/step
32768/17858 [=======================================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg
16384/15477 [===============================] - 0s 0us/step
24576/15477 [===============================================] - 0s 0us/step
display.display(display.Image(filename=cat_in_snow))
display.display(display.HTML('Image cc-by: <a "href=https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg">Von.grzanka</a>'))

JPEG

display.display(display.Image(filename=williamsburg_bridge))
display.display(display.HTML('<a "href=https://commons.wikimedia.org/wiki/File:New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg">From Wikimedia</a>'))

JPEG

Ecrire le fichier TFRecord

Comme précédemment, encodez les fonctionnalités en tant que types compatibles avec tf.train.Example . Cela stocke la fonctionnalité de chaîne d'image brute, ainsi que la hauteur, la largeur, la profondeur et la fonctionnalité d' label arbitraire. Ce dernier est utilisé lorsque vous écrivez le fichier pour faire la distinction entre l'image du chat et l'image du pont. Utilisez 0 pour l'image du chat et 1 pour l'image du pont :

image_labels = {
    cat_in_snow : 0,
    williamsburg_bridge : 1,
}
# This is an example, just using the cat image.
image_string = open(cat_in_snow, 'rb').read()

label = image_labels[cat_in_snow]

# Create a dictionary with features that may be relevant.
def image_example(image_string, label):
  image_shape = tf.io.decode_jpeg(image_string).shape

  feature = {
      'height': _int64_feature(image_shape[0]),
      'width': _int64_feature(image_shape[1]),
      'depth': _int64_feature(image_shape[2]),
      'label': _int64_feature(label),
      'image_raw': _bytes_feature(image_string),
  }

  return tf.train.Example(features=tf.train.Features(feature=feature))

for line in str(image_example(image_string, label)).split('\n')[:15]:
  print(line)
print('...')
features {
  feature {
    key: "depth"
    value {
      int64_list {
        value: 3
      }
    }
  }
  feature {
    key: "height"
    value {
      int64_list {
        value: 213
      }
...

Notez que toutes les fonctionnalités sont maintenant stockées dans le message tf.train.Example . Ensuite, fonctionnalisez le code ci-dessus et écrivez les exemples de messages dans un fichier nommé images.tfrecords :

# Write the raw image files to `images.tfrecords`.
# First, process the two images into `tf.train.Example` messages.
# Then, write to a `.tfrecords` file.
record_file = 'images.tfrecords'
with tf.io.TFRecordWriter(record_file) as writer:
  for filename, label in image_labels.items():
    image_string = open(filename, 'rb').read()
    tf_example = image_example(image_string, label)
    writer.write(tf_example.SerializeToString())
du -sh {record_file}
36K images.tfrecords

Lire le fichier TFRecord

Vous avez maintenant le fichier - images.tfrecords - et pouvez maintenant parcourir les enregistrements qu'il contient pour relire ce que vous avez écrit. Étant donné que dans cet exemple, vous ne reproduisez que l'image, la seule fonctionnalité dont vous aurez besoin est la chaîne d'image brute. Extrayez-le à l'aide des getters décrits ci-dessus, à savoir example.features.feature['image_raw'].bytes_list.value[0] . Vous pouvez également utiliser les étiquettes pour déterminer quel enregistrement est le chat et lequel est le pont :

raw_image_dataset = tf.data.TFRecordDataset('images.tfrecords')

# Create a dictionary describing the features.
image_feature_description = {
    'height': tf.io.FixedLenFeature([], tf.int64),
    'width': tf.io.FixedLenFeature([], tf.int64),
    'depth': tf.io.FixedLenFeature([], tf.int64),
    'label': tf.io.FixedLenFeature([], tf.int64),
    'image_raw': tf.io.FixedLenFeature([], tf.string),
}

def _parse_image_function(example_proto):
  # Parse the input tf.train.Example proto using the dictionary above.
  return tf.io.parse_single_example(example_proto, image_feature_description)

parsed_image_dataset = raw_image_dataset.map(_parse_image_function)
parsed_image_dataset
<MapDataset element_spec={'depth': TensorSpec(shape=(), dtype=tf.int64, name=None), 'height': TensorSpec(shape=(), dtype=tf.int64, name=None), 'image_raw': TensorSpec(shape=(), dtype=tf.string, name=None), 'label': TensorSpec(shape=(), dtype=tf.int64, name=None), 'width': TensorSpec(shape=(), dtype=tf.int64, name=None)}>

Récupérez les images du fichier TFRecord :

for image_features in parsed_image_dataset:
  image_raw = image_features['image_raw'].numpy()
  display.display(display.Image(data=image_raw))

JPEG

JPEG