Посмотреть на TensorFlow.org | Запустить в Google Colab | Посмотреть исходный код на GitHub | Скачать блокнот |
Настраивать
!pip install -q tf_nightly
import tensorflow as tf
import numpy as np
from typing import Tuple, List, Mapping, Union, Optional
import tempfile
Типы расширений
Определяемые пользователем типы могут сделать проекты более читабельными, модульными и удобными в сопровождении. Однако большинство API-интерфейсов TensorFlow имеют очень ограниченную поддержку определяемых пользователем типов Python. Сюда входят как высокоуровневые API (такие как Keras , tf.function , tf.SavedModel ), так и более низкоуровневые API (такие как tf.while_loop
и tf.concat
). Типы расширения TensorFlow можно использовать для создания определяемых пользователем объектно-ориентированных типов, которые без проблем работают с API-интерфейсами TensorFlow. Чтобы создать тип расширения, просто определите класс Python с tf.experimental.ExtensionType
в качестве его основы и используйте аннотации типа , чтобы указать тип для каждого поля.
class TensorGraph(tf.experimental.ExtensionType):
"""A collection of labeled nodes connected by weighted edges."""
edge_weights: tf.Tensor # shape=[num_nodes, num_nodes]
node_labels: Mapping[str, tf.Tensor] # shape=[num_nodes]; dtype=any
class MaskedTensor(tf.experimental.ExtensionType):
"""A tensor paired with a boolean mask, indicating which values are valid."""
values: tf.Tensor
mask: tf.Tensor # shape=values.shape; false for missing/invalid values.
class CSRSparseMatrix(tf.experimental.ExtensionType):
"""Compressed sparse row matrix (https://en.wikipedia.org/wiki/Sparse_matrix)."""
values: tf.Tensor # shape=[num_nonzero]; dtype=any
col_index: tf.Tensor # shape=[num_nonzero]; dtype=int64
row_index: tf.Tensor # shape=[num_rows+1]; dtype=int64
Базовый класс tf.experimental.ExtensionType
работает аналогично typing.NamedTuple
и @dataclasses.dataclass
из стандартной библиотеки Python. В частности, он автоматически добавляет конструктор и специальные методы (такие как __repr__
и __eq__
) на основе аннотаций типа поля.
Как правило, типы расширений попадают в одну из двух категорий:
Структуры данных , которые объединяют набор связанных значений и могут выполнять полезные операции на основе этих значений. Структуры данных могут быть довольно общими (например, пример
TensorGraph
выше); или они могут быть сильно адаптированы к конкретной модели.Тензороподобные типы , которые специализируются или расширяют понятие «тензор». Типы в этой категории имеют
rank
,shape
и обычноdtype
; и имеет смысл использовать их с тензорными операциями (такими какtf.stack
,tf.add
илиtf.matmul
).MaskedTensor
иCSRSparseMatrix
являются примерами тензороподобных типов.
Поддерживаемые API
Типы расширений поддерживаются следующими API-интерфейсами TensorFlow:
- Keras : типы расширений могут использоваться в качестве входных и выходных данных для
Models
иLayers
. - tf.data.Dataset : типы расширений могут быть включены в
Datasets
и возвращеныIterators
наборов данных. - Концентратор Tensorflow : типы расширений можно использовать в качестве входных и выходных данных для модулей
tf.hub
. - SavedModel : типы расширений можно использовать в качестве входных и выходных данных для функций
SavedModel
. - tf.function : типы расширений могут использоваться в качестве аргументов и возвращаемых значений для функций, обернутых декоратором
@tf.function
. - while loops : типы расширений могут использоваться как переменные цикла в
tf.while_loop
и могут использоваться как аргументы и возвращаемые значения для тела цикла while. - условные : типы расширений могут быть выбраны условно с помощью
tf.cond
иtf.case
. - py_function : типы расширений могут использоваться в качестве аргументов и возвращаемых значений для аргумента
func
дляtf.py_function
. - Тензорные операции : типы расширений могут быть расширены для поддержки большинства операций TensorFlow, которые принимают входные данные Tensor (например,
tf.matmul
,tf.gather
иtf.reduce_sum
). См. раздел « Отправка » ниже для получения дополнительной информации. - стратегия распределения : типы расширений можно использовать в качестве значений для каждой реплики.
Дополнительные сведения см. в разделе «API-интерфейсы TensorFlow, поддерживающие ExtensionTypes» ниже.
Требования
Типы полей
Все поля (также известные как переменные экземпляра) должны быть объявлены, и для каждого поля должна быть предоставлена аннотация типа. Поддерживаются следующие типы аннотаций:
Тип | Пример |
---|---|
Целые числа Python | i: int |
Питон плавает | f: float |
Строки Python | s: str |
логические значения Python | b: bool |
Python | n: None |
Тензорные формы | shape: tf.TensorShape |
Тензорные типы | dtype: tf.DType |
Тензоры | t: tf.Tensor |
Типы расширений | mt: MyMaskedTensor |
Рваные тензоры | rt: tf.RaggedTensor |
Разреженные тензоры | st: tf.SparseTensor |
Индексированные фрагменты | s: tf.IndexedSlices |
Дополнительные тензоры | o: tf.experimental.Optional |
Союзы типов | int_or_float: typing.Union[int, float] |
Кортежи | params: typing.Tuple[int, float, tf.Tensor, int] |
Кортежи длины Var | lengths: typing.Tuple[int, ...] |
Сопоставления | tags: typing.Mapping[str, tf.Tensor] |
Дополнительные значения | weight: typing.Optional[tf.Tensor] |
Изменчивость
Типы расширений должны быть неизменяемыми. Это гарантирует, что они могут быть правильно отслежены механизмами отслеживания графов TensorFlow. Если вы обнаружите, что хотите изменить значение типа расширения, рассмотрите вместо этого определение методов, которые преобразуют значения. Например, вместо определения метода set_mask
для MaskedTensor
вы можете определить метод replace_mask
, который возвращает новый MaskedTensor
:
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
def replace_mask(self, new_mask):
self.values.shape.assert_is_compatible_with(new_mask.shape)
return MaskedTensor(self.values, new_mask)
Функциональность, добавленная ExtensionType
Базовый класс ExtensionType
предоставляет следующие функциональные возможности:
- Конструктор (
__init__
). - Метод печатного представления (
__repr__
). - Операторы равенства и неравенства (
__eq__
). - Метод проверки (
__validate__
). - Принудительная неизменность.
- Вложенный
TypeSpec
. - Поддержка отправки Tensor API.
Дополнительные сведения о настройке этой функции см. в разделе «Настройка типов расширений» ниже.
Конструктор
Конструктор, добавленный ExtensionType
, принимает каждое поле в качестве именованного аргумента (в том порядке, в котором они перечислены в определении класса). Этот конструктор проверит тип каждого параметра и при необходимости преобразует их. В частности, поля Tensor
преобразуются с помощью tf.convert_to_tensor
; Поля Tuple
преобразуются в tuple
; Поля Mapping
преобразуются в неизменяемые словари.
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
# Constructor takes one parameter for each field.
mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
mask=[[True, True, False], [True, False, True]])
# Fields are type-checked and converted to the declared types.
# E.g., mt.values is converted to a Tensor.
print(mt.values)
tf.Tensor( [[1 2 3] [4 5 6]], shape=(2, 3), dtype=int32)
Конструктор вызывает TypeError
, если значение поля не может быть преобразовано в его объявленный тип:
try:
MaskedTensor([1, 2, 3], None)
except TypeError as e:
print(f"Got expected TypeError: {e}")
Got expected TypeError: mask: expected a Tensor, got None
Значение по умолчанию для поля можно указать, установив его значение на уровне класса:
class Pencil(tf.experimental.ExtensionType):
color: str = "black"
has_erasor: bool = True
length: tf.Tensor = 1.0
Pencil()
Pencil(color='black', has_erasor=True, length=<tf.Tensor: shape=(), dtype=float32, numpy=1.0>)
Pencil(length=0.5, color="blue")
Pencil(color='blue', has_erasor=True, length=<tf.Tensor: shape=(), dtype=float32, numpy=0.5>)
Печатное представление
ExtensionType
добавляет метод представления для печати по умолчанию ( __repr__
), который включает имя класса и значение для каждого поля:
print(MaskedTensor(values=[1, 2, 3], mask=[True, True, False]))
MaskedTensor(values=<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>, mask=<tf.Tensor: shape=(3,), dtype=bool, numpy=array([ True, True, False])>)
Операторы равенства
ExtensionType
добавляет операторы равенства по умолчанию ( __eq__
и __ne__
), которые считают два значения равными, если они имеют одинаковый тип и все их поля равны. Тензорные поля считаются равными, если они имеют одинаковую форму и поэлементно равны для всех элементов.
a = MaskedTensor([1, 2], [True, False])
b = MaskedTensor([[3, 4], [5, 6]], [[False, True], [True, True]])
print(f"a == a: {a==a}")
print(f"a == b: {a==b}")
print(f"a == a.values: {a==a.values}")
a == a: True a == b: False a == a.values: False
Метод проверки
ExtensionType
добавляет метод __validate__
, который можно переопределить для выполнения проверок полей. Он запускается после вызова конструктора и после того, как поля проверены и преобразованы в их объявленные типы, поэтому он может предположить, что все поля имеют свои объявленные типы.
Следующий пример обновляет MaskedTensor
, чтобы проверить shape
s и dtype
его полей:
class MaskedTensor(tf.experimental.ExtensionType):
"""A tensor paired with a boolean mask, indicating which values are valid."""
values: tf.Tensor
mask: tf.Tensor
def __validate__(self):
self.values.shape.assert_is_compatible_with(self.mask.shape)
assert self.mask.dtype.is_bool, 'mask.dtype must be bool'
try:
MaskedTensor([1, 2, 3], [0, 1, 0]) # wrong dtype for mask.
except AssertionError as e:
print(f"Got expected AssertionError: {e}")
Got expected AssertionError: mask.dtype must be bool
try:
MaskedTensor([1, 2, 3], [True, False]) # shapes don't match.
except ValueError as e:
print(f"Got expected ValueError: {e}")
Got expected ValueError: Shapes (3,) and (2,) are incompatible
Принудительная неизменность
ExtensionType
переопределяет методы __setattr__
и __delattr__
для предотвращения изменения, гарантируя неизменность значений типа расширения.
mt = MaskedTensor([1, 2, 3], [True, False, True])
try:
mt.mask = [True, True, True]
except AttributeError as e:
print(f"Got expected AttributeError: {e}")
Got expected AttributeError: Cannot mutate attribute `mask` outside the custom constructor of ExtensionType.
try:
mt.mask[0] = False
except TypeError as e:
print(f"Got expected TypeError: {e}")
Got expected TypeError: 'tensorflow.python.framework.ops.EagerTensor' object does not support item assignment
try:
del mt.mask
except AttributeError as e:
print(f"Got expected AttributeError: {e}")
Got expected AttributeError: Cannot mutate attribute `mask` outside the custom constructor of ExtensionType.
Вложенный TypeSpec
Каждый класс ExtensionType
имеет соответствующий класс TypeSpec
, который создается автоматически и сохраняется как <extension_type_name>.Spec
.
Этот класс собирает всю информацию из значения, кроме значений любых вложенных тензоров. В частности, TypeSpec
для значения создается путем замены любого вложенного Tensor, ExtensionType или CompositeTensor его TypeSpec
.
class Player(tf.experimental.ExtensionType):
name: tf.Tensor
attributes: Mapping[str, tf.Tensor]
anne = Player("Anne", {"height": 8.3, "speed": 28.1})
anne_spec = tf.type_spec_from_value(anne)
print(anne_spec.name) # Records dtype and shape, but not the string value.
print(anne_spec.attributes) # Records keys and TensorSpecs for values.
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class 'tensorflow.python.framework.immutable_dict.ImmutableDict'> TensorSpec(shape=(), dtype=tf.string, name=None) ImmutableDict({'height': TensorSpec(shape=(), dtype=tf.float32, name=None), 'speed': TensorSpec(shape=(), dtype=tf.float32, name=None)})
Значения TypeSpec
могут быть созданы явно или из значения ExtensionType
с использованием tf.type_spec_from_value
:
spec1 = Player.Spec(name=tf.TensorSpec([], tf.float32), attributes={})
spec2 = tf.type_spec_from_value(anne)
TypeSpec
используются TensorFlow для разделения значений на статический компонент и динамический компонент :
- Статический компонент (который фиксируется во время построения графика) кодируется с помощью
tf.TypeSpec
. - Динамический компонент (который может меняться при каждом запуске графа) кодируется как список
tf.Tensor
s.
Например, tf.function
свою обернутую функцию всякий раз, когда аргумент имеет ранее невиданный TypeSpec
:
@tf.function
def anonymize_player(player):
print("<<TRACING>>")
return Player("<anonymous>", player.attributes)
# Function gets traced (first time the function has been called):
anonymize_player(Player("Anne", {"height": 8.3, "speed": 28.1}))
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class 'tensorflow.python.framework.immutable_dict.ImmutableDict'> WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class 'tensorflow.python.framework.immutable_dict.ImmutableDict'> <<TRACING>> Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=8.3>, 'speed': <tf.Tensor: shape=(), dtype=float32, numpy=28.1>}))
# Function does NOT get traced (same TypeSpec: just tensor values changed)
anonymize_player(Player("Bart", {"height": 8.1, "speed": 25.3}))
Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=8.1>, 'speed': <tf.Tensor: shape=(), dtype=float32, numpy=25.3>}))
# Function gets traced (new TypeSpec: keys for attributes changed):
anonymize_player(Player("Chuck", {"height": 11.0, "jump": 5.3}))
<<TRACING>> Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=11.0>, 'jump': <tf.Tensor: shape=(), dtype=float32, numpy=5.3>}))
Для получения дополнительной информации см. Руководство по tf.function .
Настройка типов расширений
Помимо простого объявления полей и их типов, типы расширений могут:
- Переопределите печатное представление по умолчанию (
__repr__
). - Определить методы.
- Определите методы класса и статические методы.
- Определить свойства.
- Переопределить конструктор по умолчанию (
__init__
). - Переопределите оператор равенства по умолчанию (
__eq__
). - Определите операторы (такие как
__add__
и__lt__
). - Объявите значения по умолчанию для полей.
- Определите подклассы.
Переопределение печатного представления по умолчанию
Вы можете переопределить этот оператор преобразования строк по умолчанию для типов расширений. В следующем примере класс MaskedTensor
для создания более удобочитаемого строкового представления, когда значения печатаются в режиме Eager.
class MaskedTensor(tf.experimental.ExtensionType):
"""A tensor paired with a boolean mask, indicating which values are valid."""
values: tf.Tensor
mask: tf.Tensor # shape=values.shape; false for invalid values.
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
def masked_tensor_str(values, mask):
if isinstance(values, tf.Tensor):
if hasattr(values, 'numpy') and hasattr(mask, 'numpy'):
return f'<MaskedTensor {masked_tensor_str(values.numpy(), mask.numpy())}>'
else:
return f'MaskedTensor(values={values}, mask={mask})'
if len(values.shape) == 1:
items = [repr(v) if m else '_' for (v, m) in zip(values, mask)]
else:
items = [masked_tensor_str(v, m) for (v, m) in zip(values, mask)]
return '[%s]' % ', '.join(items)
mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
mask=[[True, True, False], [True, False, True]])
print(mt)
<MaskedTensor [[1, 2, _], [4, _, 6]]>
Определение методов
Типы расширений могут определять методы, как и любой обычный класс Python. Например, тип MaskedTensor
может определить метод with_default
, который возвращает копию self
с замаскированными значениями, замененными заданным значением default
. При желании методы могут быть аннотированы декоратором @tf.function
.
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
def with_default(self, default):
return tf.where(self.mask, self.values, default)
MaskedTensor([1, 2, 3], [True, False, True]).with_default(0)
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 0, 3], dtype=int32)>
Определение методов классов и статических методов
Типы расширений могут определять методы с помощью декораторов @classmethod
и @staticmethod
. Например, тип MaskedTensor
может определять фабричный метод, который маскирует любой элемент с заданным значением:
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
@staticmethod
def from_tensor_and_value_to_mask(values, value_to_mask):
return MaskedTensor(values, values == value_to_mask)
x = tf.constant([[1, 0, 2], [3, 0, 0]])
MaskedTensor.from_tensor_and_value_to_mask(x, 0)
<MaskedTensor [[_, 0, _], [_, 0, 0]]>
Определение свойств
Типы расширений могут определять свойства с помощью декоратора @property
, как и в любом обычном классе Python. Например, тип MaskedTensor
может определить свойство dtype
, которое является сокращением для dtype значений:
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
@property
def dtype(self):
return self.values.dtype
MaskedTensor([1, 2, 3], [True, False, True]).dtype
tf.int32
Переопределение конструктора по умолчанию
Вы можете переопределить конструктор по умолчанию для типов расширений. Пользовательские конструкторы должны устанавливать значение для каждого объявленного поля; и после того, как пользовательский конструктор вернется, все поля будут проверены на тип, а значения будут преобразованы, как описано выше.
class Toy(tf.experimental.ExtensionType):
name: str
price: tf.Tensor
def __init__(self, name, price, discount=0):
self.name = name
self.price = price * (1 - discount)
print(Toy("ball", 5.0, discount=0.2)) # On sale -- 20% off!
Toy(name='ball', price=<tf.Tensor: shape=(), dtype=float32, numpy=4.0>)
В качестве альтернативы вы можете оставить конструктор по умолчанию как есть, но добавить один или несколько фабричных методов. Например:
class Toy(tf.experimental.ExtensionType):
name: str
price: tf.Tensor
@staticmethod
def new_toy_with_discount(name, price, discount):
return Toy(name, price * (1 - discount))
print(Toy.new_toy_with_discount("ball", 5.0, discount=0.2))
Toy(name='ball', price=<tf.Tensor: shape=(), dtype=float32, numpy=4.0>)
Переопределение оператора равенства по умолчанию ( __eq__
)
Вы можете переопределить оператор __eq__
по умолчанию для типов расширений. В следующем примере MaskedTensor
обновляется, чтобы игнорировать маскированные элементы при сравнении на равенство.
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
def __eq__(self, other):
result = tf.math.equal(self.values, other.values)
result = result | ~(self.mask & other.mask)
return tf.reduce_all(result)
x = MaskedTensor([1, 2, 3, 4], [True, True, False, True])
y = MaskedTensor([5, 2, 0, 4], [False, True, False, True])
print(x == y)
tf.Tensor(True, shape=(), dtype=bool)
Использование прямых ссылок
Если тип поля еще не определен, вы можете вместо этого использовать строку, содержащую имя типа. В следующем примере строка "Node"
используется для аннотации children
поля, поскольку тип Node
еще не (полностью) определен.
class Node(tf.experimental.ExtensionType):
value: tf.Tensor
children: Tuple["Node", ...] = ()
Node(3, [Node(5), Node(2)])
Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=3>, children=(Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=5>, children=()), Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=2>, children=())))
Определение подклассов
Типы расширений могут быть подклассами с использованием стандартного синтаксиса Python. Подклассы типа расширения могут добавлять новые поля, методы и свойства; и может переопределить конструктор, печатное представление и оператор равенства. В следующем примере определяется базовый класс TensorGraph
, который использует три поля Tensor
для кодирования набора ребер между узлами. Затем он определяет подкласс, который добавляет поле Tensor
для записи «значения функции» для каждого узла. Подкласс также определяет метод для распространения значений признаков по краям.
class TensorGraph(tf.experimental.ExtensionType):
num_nodes: tf.Tensor
edge_src: tf.Tensor # edge_src[e] = index of src node for edge e.
edge_dst: tf.Tensor # edge_dst[e] = index of dst node for edge e.
class TensorGraphWithNodeFeature(TensorGraph):
node_features: tf.Tensor # node_features[n] = feature value for node n.
def propagate_features(self, weight=1.0) -> 'TensorGraphWithNodeFeature':
updates = tf.gather(self.node_features, self.edge_src) * weight
new_node_features = tf.tensor_scatter_nd_add(
self.node_features, tf.expand_dims(self.edge_dst, 1), updates)
return TensorGraphWithNodeFeature(
self.num_nodes, self.edge_src, self.edge_dst, new_node_features)
g = TensorGraphWithNodeFeature( # Edges: 0->1, 4->3, 2->2, 2->1
num_nodes=5, edge_src=[0, 4, 2, 2], edge_dst=[1, 3, 2, 1],
node_features=[10.0, 0.0, 2.0, 5.0, -1.0, 0.0])
print("Original features:", g.node_features)
print("After propagating:", g.propagate_features().node_features)
Original features: tf.Tensor([10. 0. 2. 5. -1. 0.], shape=(6,), dtype=float32) After propagating: tf.Tensor([10. 12. 4. 4. -1. 0.], shape=(6,), dtype=float32)
Определение частных полей
Поля типа расширения могут быть помечены как закрытые, если перед ними поставить знак подчеркивания (в соответствии со стандартными соглашениями Python). Это никак не влияет на то, как TensorFlow обрабатывает поля; но просто служит сигналом для любых пользователей типа расширения, что эти поля являются частными.
Настройка TypeSpec
ExtensionType
Каждый класс ExtensionType
имеет соответствующий класс TypeSpec
, который создается автоматически и сохраняется как <extension_type_name>.Spec
. Дополнительные сведения см. в разделе «Вложенный TypeSpec» выше.
Чтобы настроить TypeSpec
, просто определите свой собственный вложенный класс с именем Spec
, и ExtensionType
будет использовать его в качестве основы для автоматически TypeSpec
. Вы можете настроить класс Spec
следующим образом:
- Переопределение печатного представления по умолчанию.
- Переопределение конструктора по умолчанию.
- Определение методов, методов классов, статических методов и свойств.
В следующем примере класс MaskedTensor.Spec
настраивается, чтобы упростить его использование:
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
shape = property(lambda self: self.values.shape)
dtype = property(lambda self: self.values.dtype)
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
def with_values(self, new_values):
return MaskedTensor(new_values, self.mask)
class Spec:
def __init__(self, shape, dtype=tf.float32):
self.values = tf.TensorSpec(shape, dtype)
self.mask = tf.TensorSpec(shape, tf.bool)
def __repr__(self):
return f"MaskedTensor.Spec(shape={self.shape}, dtype={self.dtype})"
shape = property(lambda self: self.values.shape)
dtype = property(lambda self: self.values.dtype)
Отправка тензорного API
Типы расширения могут быть «тензороподобными» в том смысле, что они специализируются или расширяют интерфейс, определенный типом tf.Tensor
. Примеры тензороподобных типов расширений включают RaggedTensor
, SparseTensor
и MaskedTensor
. Декораторы отправки можно использовать для переопределения поведения операций TensorFlow по умолчанию при применении к тензороподобным типам расширений. В настоящее время TensorFlow определяет три декоратора отправки:
-
@tf.experimental.dispatch_for_api(tf_api)
-
@tf.experimental.dispatch_for_unary_elementwise_api(x_type)
-
@tf.experimental.dispatch_for_binary_elementwise_apis(x_type, y_type)
Диспетчеризация для одного API
Декоратор tf.experimental.dispatch_for_api
переопределяет поведение по умолчанию указанной операции TensorFlow, когда она вызывается с указанной сигнатурой. Например, вы можете использовать этот декоратор, чтобы указать, как tf.stack
должен обрабатывать значения MaskedTensor
:
@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack(values: List[MaskedTensor], axis = 0):
return MaskedTensor(tf.stack([v.values for v in values], axis),
tf.stack([v.mask for v in values], axis))
Это переопределяет реализацию по умолчанию для tf.stack
всякий раз, когда он вызывается со списком значений MaskedTensor
(поскольку аргумент values
аннотируется с помощью typing.List[MaskedTensor]
):
x = MaskedTensor([1, 2, 3], [True, True, False])
y = MaskedTensor([4, 5, 6], [False, True, True])
tf.stack([x, y])
<MaskedTensor [[1, 2, _], [_, 5, 6]]>
Чтобы позволить tf.stack
обрабатывать списки смешанных MaskedTensor
и Tensor
, вы можете уточнить аннотацию типа для параметра values
и соответствующим образом обновить тело функции:
tf.experimental.unregister_dispatch_for(masked_stack)
def convert_to_masked_tensor(x):
if isinstance(x, MaskedTensor):
return x
else:
return MaskedTensor(x, tf.ones_like(x, tf.bool))
@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack_v2(values: List[Union[MaskedTensor, tf.Tensor]], axis = 0):
values = [convert_to_masked_tensor(v) for v in values]
return MaskedTensor(tf.stack([v.values for v in values], axis),
tf.stack([v.mask for v in values], axis))
x = MaskedTensor([1, 2, 3], [True, True, False])
y = tf.constant([4, 5, 6])
tf.stack([x, y, x])
<MaskedTensor [[1, 2, _], [4, 5, 6], [1, 2, _]]>
Список API, которые можно переопределить, см. в документации API для tf.experimental.dispatch_for_api
.
Отправка для всех унарных поэлементных API
Декоратор tf.experimental.dispatch_for_unary_elementwise_apis
переопределяет поведение по умолчанию всех унарных поэлементных операций (таких как tf.math.cos
) всякий раз, когда значение первого аргумента (обычно называемого x
) соответствует аннотации типа x_type
. Декорированная функция должна принимать два аргумента:
-
api_func
: функция, которая принимает один параметр и выполняет поэлементную операцию (например,tf.abs
). -
x
: первый аргумент поэлементной операции.
В следующем примере все унарные поэлементные операции обновляются для обработки типа MaskedTensor
:
@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def masked_tensor_unary_elementwise_api_handler(api_func, x):
return MaskedTensor(api_func(x.values), x.mask)
Эта функция теперь будет использоваться всякий раз, когда для MaskedTensor
унарная поэлементная операция.
x = MaskedTensor([1, -2, -3], [True, False, True])
print(tf.abs(x))
<MaskedTensor [1, _, 3]>
print(tf.ones_like(x, dtype=tf.float32))
<MaskedTensor [1.0, _, 1.0]>
Отправка для бинарных всех поэлементных API
Точно так же tf.experimental.dispatch_for_binary_elementwise_apis
можно использовать для обновления всех бинарных поэлементных операций для обработки типа MaskedTensor
:
@tf.experimental.dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
def masked_tensor_binary_elementwise_api_handler(api_func, x, y):
return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)
x = MaskedTensor([1, -2, -3], [True, False, True])
y = MaskedTensor([[4], [5]], [[True], [False]])
tf.math.add(x, y)
<MaskedTensor [[5, _, 1], [_, _, _]]>
Список поэлементных API, которые переопределяются, см. в документации по API для tf.experimental.dispatch_for_unary_elementwise_apis
и tf.experimental.dispatch_for_binary_elementwise_apis
.
Пакетные типы расширений
ExtensionType
является пакетным , если один экземпляр может использоваться для представления пакета значений. Как правило, это достигается путем добавления пакетных измерений ко всем вложенным Tensor
s. Следующие API-интерфейсы TensorFlow требуют, чтобы любые входные данные типа расширения были пакетными:
-
tf.data.Dataset
(batch
,unbatch
,from_tensor_slices
) -
tf.Keras
(fit
,evaluate
,predict
) -
tf.map_fn
По умолчанию BatchableExtensionType
создает пакетные значения, объединяя любые вложенные Tensor
s, CompositeTensor
s и ExtensionType
s. Если это не подходит для вашего класса, вам нужно будет использовать tf.experimental.ExtensionTypeBatchEncoder
, чтобы переопределить это поведение по умолчанию. Например, было бы нецелесообразно создавать пакет значений tf.SparseTensor
, просто складывая отдельные values
разреженных тензоров, indices
и поля dense_shape
— в большинстве случаев вы не можете складывать эти тензоры, поскольку они имеют несовместимые формы. ; и даже если бы вы могли, результат не был бы действительным SparseTensor
.
Пример BatchableExtensionType: сеть
В качестве примера рассмотрим простой класс Network
, используемый для балансировки нагрузки, который отслеживает, сколько работы осталось сделать на каждом узле, и какая полоса пропускания доступна для перемещения работы между узлами:
class Network(tf.experimental.ExtensionType): # This version is not batchable.
work: tf.Tensor # work[n] = work left to do at node n
bandwidth: tf.Tensor # bandwidth[n1, n2] = bandwidth from n1->n2
net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])
Чтобы сделать этот тип пакетным, измените базовый тип на BatchableExtensionType
и настройте форму каждого поля, чтобы включить дополнительные измерения пакета. В следующем примере также добавляется поле shape
для отслеживания формы пакета. Это поле shape
не требуется для tf.data.Dataset
или tf.map_fn
, но оно требуется для tf.Keras
.
class Network(tf.experimental.BatchableExtensionType):
shape: tf.TensorShape # batch shape. A single network has shape=[].
work: tf.Tensor # work[*shape, n] = work left to do at node n
bandwidth: tf.Tensor # bandwidth[*shape, n1, n2] = bandwidth from n1->n2
def __init__(self, work, bandwidth):
self.work = tf.convert_to_tensor(work)
self.bandwidth = tf.convert_to_tensor(bandwidth)
work_batch_shape = self.work.shape[:-1]
bandwidth_batch_shape = self.bandwidth.shape[:-2]
self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)
def __repr__(self):
return network_repr(self)
def network_repr(network):
work = network.work
bandwidth = network.bandwidth
if hasattr(work, 'numpy'):
work = ' '.join(str(work.numpy()).split())
if hasattr(bandwidth, 'numpy'):
bandwidth = ' '.join(str(bandwidth.numpy()).split())
return (f"<Network shape={network.shape} work={work} bandwidth={bandwidth}>")
net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])
batch_of_networks = Network(
work=tf.stack([net1.work, net2.work]),
bandwidth=tf.stack([net1.bandwidth, net2.bandwidth]))
print(f"net1={net1}")
print(f"net2={net2}")
print(f"batch={batch_of_networks}")
net1=<Network shape=() work=[5. 3. 8.] bandwidth=[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]]> net2=<Network shape=() work=[3. 4. 2.] bandwidth=[[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]> batch=<Network shape=(2,) work=[[5. 3. 8.] [3. 4. 2.]] bandwidth=[[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]] [[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]]>
Затем вы можете использовать tf.data.Dataset
для перебора пакетов сетей:
dataset = tf.data.Dataset.from_tensor_slices(batch_of_networks)
for i, network in enumerate(dataset):
print(f"Batch element {i}: {network}")
Batch element 0: <Network shape=() work=[5. 3. 8.] bandwidth=[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]]> Batch element 1: <Network shape=() work=[3. 4. 2.] bandwidth=[[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]>
И вы также можете использовать map_fn
для применения функции к каждому пакетному элементу:
def balance_work_greedy(network):
delta = (tf.expand_dims(network.work, -1) - tf.expand_dims(network.work, -2))
delta /= 4
delta = tf.maximum(tf.minimum(delta, network.bandwidth), -network.bandwidth)
new_work = network.work + tf.reduce_sum(delta, -1)
return Network(new_work, network.bandwidth)
tf.map_fn(balance_work_greedy, batch_of_networks)
<Network shape=(2,) work=[[5.5 1.25 9.25] [3. 4.75 1.25]] bandwidth=[[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]] [[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]]>
API-интерфейсы TensorFlow, поддерживающие ExtensionTypes
@tf.function
tf.function — это декоратор, который предварительно вычисляет графики TensorFlow для функций Python, что может существенно повысить производительность вашего кода TensorFlow. Значения типа расширения можно использовать прозрачно с @tf.function
функциями.
class Pastry(tf.experimental.ExtensionType):
sweetness: tf.Tensor # 2d embedding that encodes sweetness
chewiness: tf.Tensor # 2d embedding that encodes chewiness
@tf.function
def combine_pastry_features(x: Pastry):
return (x.sweetness + x.chewiness) / 2
cookie = Pastry(sweetness=[1.2, 0.4], chewiness=[0.8, 0.2])
combine_pastry_features(cookie)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1. , 0.3], dtype=float32)>
Если вы хотите явно указать input_signature
для tf.function
, вы можете сделать это с помощью TypeSpec
типа расширения.
pastry_spec = Pastry.Spec(tf.TensorSpec([2]), tf.TensorSpec(2))
@tf.function(input_signature=[pastry_spec])
def increase_sweetness(x: Pastry, delta=1.0):
return Pastry(x.sweetness + delta, x.chewiness)
increase_sweetness(cookie)
Pastry(sweetness=<tf.Tensor: shape=(2,), dtype=float32, numpy=array([2.2, 1.4], dtype=float32)>, chewiness=<tf.Tensor: shape=(2,), dtype=float32, numpy=array([0.8, 0.2], dtype=float32)>)
Конкретные функции
Конкретные функции инкапсулируют отдельные отслеживаемые графы, построенные с помощью tf.function
. Типы расширений можно использовать прозрачно с конкретными функциями.
cf = combine_pastry_features.get_concrete_function(pastry_spec)
cf(cookie)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1. , 0.3], dtype=float32)>
Операции потока управления
Типы расширений поддерживаются операциями потока управления TensorFlow:
# Example: using tf.cond to select between two MaskedTensors. Note that the
# two MaskedTensors don't need to have the same shape.
a = MaskedTensor([1., 2, 3], [True, False, True])
b = MaskedTensor([22., 33, 108, 55], [True, True, True, False])
condition = tf.constant(True)
print(tf.cond(condition, lambda: a, lambda: b))
<MaskedTensor [1.0, _, 3.0]>
# Example: using tf.while_loop with MaskedTensor.
cond = lambda i, _: i < 10
def body(i, mt):
return i + 1, mt.with_values(mt.values + 3 / 7)
print(tf.while_loop(cond, body, [0, b])[1])
<MaskedTensor [26.285717, 37.285698, 112.285736, _]>
Поток управления автографом
Типы расширения также поддерживаются операторами потока управления в tf.function (с использованием автографа). В следующем примере операторы if
и for
автоматически преобразуются в tf.cond
и tf.while_loop
, которые поддерживают типы расширения.
@tf.function
def fn(x, b):
if b:
x = MaskedTensor(x, tf.less(x, 0))
else:
x = MaskedTensor(x, tf.greater(x, 0))
for i in tf.range(5 if b else 7):
x = x.with_values(x.values + 1 / 2)
return x
print(fn(tf.constant([1., -2, 3]), tf.constant(True)))
print(fn(tf.constant([1., -2, 3]), tf.constant(False)))
<MaskedTensor [_, 0.5, _]> <MaskedTensor [4.5, _, 6.5]>
Керас
tf.keras — это высокоуровневый API TensorFlow для создания и обучения моделей глубокого обучения. Типы расширений могут передаваться в качестве входных данных для модели Keras, передаваться между слоями Keras и возвращаться моделями Keras. В настоящее время Keras предъявляет два требования к типам расширений:
- Они должны быть пакетными (см. «Пакетные типы расширений» выше).
- Должен иметь поле или свойство с именем
shape
.shape[0]
считается размером пакета.
В следующих двух подразделах приведены примеры, показывающие, как типы расширений можно использовать с Keras.
Пример Keras: Network
В качестве первого примера рассмотрим класс Network
, определенный в разделе «Пакетные типы расширений» выше, который можно использовать для балансировки нагрузки между узлами. Его определение повторяется здесь:
class Network(tf.experimental.BatchableExtensionType):
shape: tf.TensorShape # batch shape. A single network has shape=[].
work: tf.Tensor # work[*shape, n] = work left to do at node n
bandwidth: tf.Tensor # bandwidth[*shape, n1, n2] = bandwidth from n1->n2
def __init__(self, work, bandwidth):
self.work = tf.convert_to_tensor(work)
self.bandwidth = tf.convert_to_tensor(bandwidth)
work_batch_shape = self.work.shape[:-1]
bandwidth_batch_shape = self.bandwidth.shape[:-2]
self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)
def __repr__(self):
return network_repr(self)
single_network = Network( # A single network w/ 4 nodes.
work=[8.0, 5, 12, 2],
bandwidth=[[0.0, 1, 2, 2], [1, 0, 0, 2], [2, 0, 0, 1], [2, 2, 1, 0]])
batch_of_networks = Network( # Batch of 2 networks, each w/ 2 nodes.
work=[[8.0, 5], [3, 2]],
bandwidth=[[[0.0, 1], [1, 0]], [[0, 2], [2, 0]]])
Вы можете определить новый слой Keras, который обрабатывает Network
s.
class BalanceNetworkLayer(tf.keras.layers.Layer):
"""Layer that balances work between nodes in a network.
Shifts work from more busy nodes to less busy nodes, constrained by bandwidth.
"""
def call(self, inputs):
# This function is defined above, in "Batchable ExtensionTypes" section.
return balance_work_greedy(inputs)
Затем вы можете использовать эти слои для создания простой модели. Чтобы передать ExtensionType
в модель, вы можете использовать слой tf.keras.layer.Input
с type_spec
, установленным на TypeSpec
типа расширения. Если модель Keras будет использоваться для обработки пакетов, то type_spec
должен включать размер пакета.
input_spec = Network.Spec(shape=None,
work=tf.TensorSpec(None, tf.float32),
bandwidth=tf.TensorSpec(None, tf.float32))
model = tf.keras.Sequential([
tf.keras.layers.Input(type_spec=input_spec),
BalanceNetworkLayer(),
])
Наконец, вы можете применить модель к одной сети и к группе сетей.
model(single_network)
<Network shape=() work=[ 9.25 5. 14. -1.25] bandwidth=[[0. 1. 2. 2.] [1. 0. 0. 2.] [2. 0. 0. 1.] [2. 2. 1. 0.]]>
model(batch_of_networks)
<Network shape=(2,) work=[[8.75 4.25] [3.25 1.75]] bandwidth=[[[0. 1.] [1. 0.]] [[0. 2.] [2. 0.]]]>
Пример Keras: MaskedTensor
В этом примере MaskedTensor
расширен для поддержки Keras
. shape
определяется как свойство, вычисляемое из поля values
. Keras требует, чтобы вы добавили это свойство как к типу расширения, так и к его TypeSpec
. MaskedTensor
также определяет переменную __name__
, которая потребуется для сериализации SavedModel
(ниже).
class MaskedTensor(tf.experimental.BatchableExtensionType):
# __name__ is required for serialization in SavedModel; see below for details.
__name__ = 'extension_type_colab.MaskedTensor'
values: tf.Tensor
mask: tf.Tensor
shape = property(lambda self: self.values.shape)
dtype = property(lambda self: self.values.dtype)
def with_default(self, default):
return tf.where(self.mask, self.values, default)
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
class Spec:
def __init__(self, shape, dtype=tf.float32):
self.values = tf.TensorSpec(shape, dtype)
self.mask = tf.TensorSpec(shape, tf.bool)
shape = property(lambda self: self.values.shape)
dtype = property(lambda self: self.values.dtype)
def with_shape(self):
return MaskedTensor.Spec(tf.TensorSpec(shape, self.values.dtype),
tf.TensorSpec(shape, self.mask.dtype))
Затем декораторы отправки используются для переопределения поведения по умолчанию нескольких API-интерфейсов TensorFlow. Поскольку эти API используются стандартными слоями Keras (такими как слой Dense
), их переопределение позволит нам использовать эти слои с MaskedTensor
. Для целей этого примера matmul
для замаскированных тензоров определен так, чтобы обрабатывать замаскированные значения как нули (т. е. не включать их в произведение).
@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def unary_elementwise_op_handler(op, x):
return MaskedTensor(op(x.values), x.mask)
@tf.experimental.dispatch_for_binary_elementwise_apis(
Union[MaskedTensor, tf.Tensor],
Union[MaskedTensor, tf.Tensor])
def binary_elementwise_op_handler(op, x, y):
x = convert_to_masked_tensor(x)
y = convert_to_masked_tensor(y)
return MaskedTensor(op(x.values, y.values), x.mask & y.mask)
@tf.experimental.dispatch_for_api(tf.matmul)
def masked_matmul(a: MaskedTensor, b,
transpose_a=False, transpose_b=False,
adjoint_a=False, adjoint_b=False,
a_is_sparse=False, b_is_sparse=False,
output_type=None):
if isinstance(a, MaskedTensor):
a = a.with_default(0)
if isinstance(b, MaskedTensor):
b = b.with_default(0)
return tf.matmul(a, b, transpose_a, transpose_b, adjoint_a,
adjoint_b, a_is_sparse, b_is_sparse, output_type)
Затем вы можете построить модель Keras, которая принимает входные данные MaskedTensor
, используя стандартные слои Keras:
input_spec = MaskedTensor.Spec([None, 2], tf.float32)
masked_tensor_model = tf.keras.Sequential([
tf.keras.layers.Input(type_spec=input_spec),
tf.keras.layers.Dense(16, activation="relu"),
tf.keras.layers.Dense(1)])
masked_tensor_model.compile(loss='binary_crossentropy', optimizer='rmsprop')
a = MaskedTensor([[1., 2], [3, 4], [5, 6]],
[[True, False], [False, True], [True, True]])
masked_tensor_model.fit(a, tf.constant([[1], [0], [1]]), epochs=3)
print(masked_tensor_model(a))
Epoch 1/3 1/1 [==============================] - 1s 955ms/step - loss: 10.2833 Epoch 2/3 1/1 [==============================] - 0s 5ms/step - loss: 10.2833 Epoch 3/3 1/1 [==============================] - 0s 5ms/step - loss: 10.2833 tf.Tensor( [[-0.09944128] [-0.7225147 ] [-1.3020657 ]], shape=(3, 1), dtype=float32)
Сохраненная модель
SavedModel — это сериализованная программа TensorFlow, включающая как веса, так и вычисления. Он может быть построен из модели Keras или из пользовательской модели. В любом случае типы расширений можно использовать прозрачно с функциями и методами, определенными в SavedModel.
SavedModel может сохранять модели, слои и функции, обрабатывающие типы расширений, если у типов расширений есть поле __name__
. Это имя используется для регистрации типа расширения, чтобы его можно было найти при загрузке модели.
Пример: сохранение модели Keras
Модели Keras, которые используют типы расширения, могут быть сохранены с помощью SavedModel
.
masked_tensor_model_path = tempfile.mkdtemp()
tf.saved_model.save(masked_tensor_model, masked_tensor_model_path)
imported_model = tf.saved_model.load(masked_tensor_model_path)
imported_model(a)
2021-11-06 01:25:14.285250: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them. WARNING:absl:Function `_wrapped_model` contains input name(s) args_0 with unsupported characters which will be renamed to args_0_1 in the SavedModel. INFO:tensorflow:Assets written to: /tmp/tmp3ceuupv9/assets INFO:tensorflow:Assets written to: /tmp/tmp3ceuupv9/assets <tf.Tensor: shape=(3, 1), dtype=float32, numpy= array([[-0.09944128], [-0.7225147 ], [-1.3020657 ]], dtype=float32)>
Пример: сохранение пользовательской модели
SavedModel также можно использовать для сохранения пользовательских подклассов tf.Module
с функциями, которые обрабатывают типы расширений.
class CustomModule(tf.Module):
def __init__(self, variable_value):
super().__init__()
self.v = tf.Variable(variable_value)
@tf.function
def grow(self, x: MaskedTensor):
"""Increase values in `x` by multiplying them by `self.v`."""
return MaskedTensor(x.values * self.v, x.mask)
module = CustomModule(100.0)
module.grow.get_concrete_function(MaskedTensor.Spec(shape=None,
dtype=tf.float32))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
imported_model.grow(MaskedTensor([1., 2, 3], [False, True, False]))
INFO:tensorflow:Assets written to: /tmp/tmp2x8zq5kb/assets INFO:tensorflow:Assets written to: /tmp/tmp2x8zq5kb/assets <MaskedTensor [_, 200.0, _]>
Загрузка SavedModel, когда ExtensionType недоступен
Если вы загрузите SavedModel
, который использует ExtensionType
, но этот ExtensionType
недоступен (т. е. не был импортирован), вы увидите предупреждение, и TensorFlow вернется к использованию объекта «анонимного типа расширения». Этот объект будет иметь те же поля, что и исходный тип, но не будет иметь каких-либо дополнительных настроек, которые вы добавили для типа, таких как пользовательские методы или свойства.
Использование ExtensionTypes с обслуживанием TensorFlow
В настоящее время обслуживание TensorFlow (и другие потребители словаря «сигнатур» SavedModel) требуют, чтобы все входные и выходные данные были необработанными тензорами. Если вы хотите использовать обслуживание TensorFlow с моделью, использующей типы расширений, вы можете добавить методы-оболочки, которые составляют или разлагают значения типов расширений из тензоров. Например:
class CustomModuleWrapper(tf.Module):
def __init__(self, variable_value):
super().__init__()
self.v = tf.Variable(variable_value)
@tf.function
def var_weighted_mean(self, x: MaskedTensor):
"""Mean value of unmasked values in x, weighted by self.v."""
x = MaskedTensor(x.values * self.v, x.mask)
return (tf.reduce_sum(x.with_default(0)) /
tf.reduce_sum(tf.cast(x.mask, x.dtype)))
@tf.function()
def var_weighted_mean_wrapper(self, x_values, x_mask):
"""Raw tensor wrapper for var_weighted_mean."""
return self.var_weighted_mean(MaskedTensor(x_values, x_mask))
module = CustomModuleWrapper([3., 2., 8., 5.])
module.var_weighted_mean_wrapper.get_concrete_function(
tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.bool))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
x = MaskedTensor([1., 2., 3., 4.], [False, True, False, True])
imported_model.var_weighted_mean_wrapper(x.values, x.mask)
INFO:tensorflow:Assets written to: /tmp/tmpxhh4zh0i/assets INFO:tensorflow:Assets written to: /tmp/tmpxhh4zh0i/assets <tf.Tensor: shape=(), dtype=float32, numpy=12.0>
Наборы данных
tf.data — это API, который позволяет создавать сложные конвейеры ввода из простых, повторно используемых частей. Его основная структура данных — tf.data.Dataset
, которая представляет собой последовательность элементов, в которой каждый элемент состоит из одного или нескольких компонентов.
Создание наборов данных с типами расширений
Наборы данных могут быть созданы из значений типа расширения с использованием Dataset.from_tensors
, Dataset.from_tensor_slices
или Dataset.from_generator
:
ds = tf.data.Dataset.from_tensors(Pastry(5, 5))
iter(ds).next()
Pastry(sweetness=<tf.Tensor: shape=(), dtype=int32, numpy=5>, chewiness=<tf.Tensor: shape=(), dtype=int32, numpy=5>)
mt = MaskedTensor(tf.reshape(range(20), [5, 4]), tf.ones([5, 4]))
ds = tf.data.Dataset.from_tensor_slices(mt)
for value in ds:
print(value)
<MaskedTensor [0, 1, 2, 3]> <MaskedTensor [4, 5, 6, 7]> <MaskedTensor [8, 9, 10, 11]> <MaskedTensor [12, 13, 14, 15]> <MaskedTensor [16, 17, 18, 19]>
def value_gen():
for i in range(2, 7):
yield MaskedTensor(range(10), [j%i != 0 for j in range(10)])
ds = tf.data.Dataset.from_generator(
value_gen, output_signature=MaskedTensor.Spec(shape=[10], dtype=tf.int32))
for value in ds:
print(value)
<MaskedTensor [_, 1, _, 3, _, 5, _, 7, _, 9]> <MaskedTensor [_, 1, 2, _, 4, 5, _, 7, 8, _]> <MaskedTensor [_, 1, 2, 3, _, 5, 6, 7, _, 9]> <MaskedTensor [_, 1, 2, 3, 4, _, 6, 7, 8, 9]> <MaskedTensor [_, 1, 2, 3, 4, 5, _, 7, 8, 9]>
Пакетирование и разделение наборов данных с типами расширения
Наборы данных с типами расширения могут быть пакетными и непакетными с помощью Dataset.batch
и Dataset.unbatch
.
batched_ds = ds.batch(2)
for value in batched_ds:
print(value)
<MaskedTensor [[_, 1, _, 3, _, 5, _, 7, _, 9], [_, 1, 2, _, 4, 5, _, 7, 8, _]]> <MaskedTensor [[_, 1, 2, 3, _, 5, 6, 7, _, 9], [_, 1, 2, 3, 4, _, 6, 7, 8, 9]]> <MaskedTensor [[_, 1, 2, 3, 4, 5, _, 7, 8, 9]]>
unbatched_ds = batched_ds.unbatch()
for value in unbatched_ds:
print(value)
<MaskedTensor [_, 1, _, 3, _, 5, _, 7, _, 9]> <MaskedTensor [_, 1, 2, _, 4, 5, _, 7, 8, _]> <MaskedTensor [_, 1, 2, 3, _, 5, 6, 7, _, 9]> <MaskedTensor [_, 1, 2, 3, 4, _, 6, 7, 8, 9]> <MaskedTensor [_, 1, 2, 3, 4, 5, _, 7, 8, 9]>