Veja no TensorFlow.org | Executar no Google Colab | Ver fonte no GitHub | Baixar caderno |
Configurar
!pip install -q tf_nightly
import tensorflow as tf
import numpy as np
from typing import Tuple, List, Mapping, Union, Optional
import tempfile
Tipos de extensão
Os tipos definidos pelo usuário podem tornar os projetos mais legíveis, modulares e de fácil manutenção. No entanto, a maioria das APIs do TensorFlow tem suporte muito limitado para tipos Python definidos pelo usuário. Isso inclui APIs de alto nível (como Keras , tf.function , tf.SavedModel ) e APIs de nível inferior (como tf.while_loop
e tf.concat
). Os tipos de extensão do TensorFlow podem ser usados para criar tipos orientados a objetos definidos pelo usuário que funcionam perfeitamente com as APIs do TensorFlow. Para criar um tipo de extensão, simplesmente defina uma classe Python com tf.experimental.ExtensionType
como base e use anotações de tipo para especificar o tipo de cada campo.
class TensorGraph(tf.experimental.ExtensionType):
"""A collection of labeled nodes connected by weighted edges."""
edge_weights: tf.Tensor # shape=[num_nodes, num_nodes]
node_labels: Mapping[str, tf.Tensor] # shape=[num_nodes]; dtype=any
class MaskedTensor(tf.experimental.ExtensionType):
"""A tensor paired with a boolean mask, indicating which values are valid."""
values: tf.Tensor
mask: tf.Tensor # shape=values.shape; false for missing/invalid values.
class CSRSparseMatrix(tf.experimental.ExtensionType):
"""Compressed sparse row matrix (https://en.wikipedia.org/wiki/Sparse_matrix)."""
values: tf.Tensor # shape=[num_nonzero]; dtype=any
col_index: tf.Tensor # shape=[num_nonzero]; dtype=int64
row_index: tf.Tensor # shape=[num_rows+1]; dtype=int64
A classe base tf.experimental.ExtensionType
funciona de forma semelhante à typing.NamedTuple
e @dataclasses.dataclass
da biblioteca padrão do Python. Em particular, ele adiciona automaticamente um construtor e métodos especiais (como __repr__
e __eq__
) com base nas anotações do tipo de campo.
Normalmente, os tipos de extensão tendem a se enquadrar em uma das duas categorias:
Estruturas de dados , que agrupam uma coleção de valores relacionados e podem fornecer operações úteis com base nesses valores. As estruturas de dados podem ser bastante gerais (como o exemplo do
TensorGraph
acima); ou podem ser altamente personalizados para um modelo específico.Tipos do tipo tensor , que especializam ou estendem o conceito de "tensor". Tipos nesta categoria têm uma
rank
, umashape
e geralmente umdtype
; e faz sentido usá-los com operações de tensor (comotf.stack
,tf.add
outf.matmul
).MaskedTensor
eCSRSparseMatrix
são exemplos de tipos do tipo tensor.
APIs compatíveis
Os tipos de extensão são compatíveis com as seguintes APIs do TensorFlow:
- Keras : Os tipos de extensão podem ser usados como entradas e saídas para
Models
eLayers
Keras. - tf.data.Dataset : Os tipos de extensão podem ser incluídos em
Datasets
e retornados porIterators
de dataset. - Hub Tensorflow : Os tipos de extensão podem ser usados como entradas e saídas para módulos
tf.hub
. - SavedModel : Os tipos de extensão podem ser usados como entradas e saídas para funções
SavedModel
. - tf.function : Os tipos de extensão podem ser usados como argumentos e valores de retorno para funções agrupadas com o decorador
@tf.function
. - while loops : Os tipos de extensão podem ser usados como variáveis de loop em
tf.while_loop
, e podem ser usados como argumentos e valores de retorno para o corpo do while-loop. - condicionais : Os tipos de extensão podem ser selecionados condicionalmente usando
tf.cond
etf.case
. - py_function : Os tipos de extensão podem ser usados como argumentos e valores de retorno para o argumento
func
paratf.py_function
. - Operações de tensor : os tipos de extensão podem ser estendidos para oferecer suporte à maioria das operações de TensorFlow que aceitam entradas de tensor (por exemplo,
tf.matmul
,tf.gather
etf.reduce_sum
). Consulte a seção " Despacho " abaixo para obter mais informações. - estratégia de distribuição : os tipos de extensão podem ser usados como valores por réplica.
Para obter mais detalhes, consulte a seção "APIs do TensorFlow compatíveis com ExtensionTypes" abaixo.
Requisitos
Tipos de campo
Todos os campos (também conhecidos como variáveis de instância) devem ser declarados e uma anotação de tipo deve ser fornecida para cada campo. As seguintes anotações de tipo são compatíveis:
Modelo | Exemplo |
---|---|
Inteiros Python | i: int |
Python flutua | f: float |
Strings Python | s: str |
Booleanos Python | b: bool |
Python Nenhum | n: None |
Formas de tensor | shape: tf.TensorShape |
Tipos de tensor | dtype: tf.DType |
Tensores | t: tf.Tensor |
Tipos de extensão | mt: MyMaskedTensor |
Tensores irregulares | rt: tf.RaggedTensor |
Tensores Esparsos | st: tf.SparseTensor |
Fatias Indexadas | s: tf.IndexedSlices |
Tensores Opcionais | o: tf.experimental.Optional |
Tipo de uniões | int_or_float: typing.Union[int, float] |
Tuplas | params: typing.Tuple[int, float, tf.Tensor, int] |
Tuplas de comprimento de var | lengths: typing.Tuple[int, ...] |
Mapeamentos | tags: typing.Mapping[str, tf.Tensor] |
Valores opcionais | weight: typing.Optional[tf.Tensor] |
Mutabilidade
Os tipos de extensão devem ser imutáveis. Isso garante que eles possam ser rastreados adequadamente pelos mecanismos de rastreamento de gráficos do TensorFlow. Se você quiser alterar um valor de tipo de extensão, considere definir métodos que transformam valores. Por exemplo, em vez de definir um método set_mask
para alterar um MaskedTensor
, você pode definir um método replace_mask
que retorna um novo MaskedTensor
:
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
def replace_mask(self, new_mask):
self.values.shape.assert_is_compatible_with(new_mask.shape)
return MaskedTensor(self.values, new_mask)
Funcionalidade adicionada por ExtensionType
A classe base ExtensionType
fornece a seguinte funcionalidade:
- Um construtor (
__init__
). - Um método de representação imprimível (
__repr__
). - Operadores de igualdade e desigualdade (
__eq__
). - Um método de validação (
__validate__
). - Imutabilidade imposta.
- Um
TypeSpec
aninhado. - Suporte de envio da API do tensor.
Consulte a seção "Personalizando ExtensionTypes" abaixo para obter mais informações sobre como personalizar essa funcionalidade.
Construtor
O construtor adicionado por ExtensionType
recebe cada campo como um argumento nomeado (na ordem em que foram listados na definição da classe). Esse construtor verificará cada parâmetro e os converterá quando necessário. Em particular, os campos de Tensor
são convertidos usando tf.convert_to_tensor
; Tuple
campos de tupla são convertidos em tuple
s; e os campos de Mapping
são convertidos em dicts imutáveis.
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
# Constructor takes one parameter for each field.
mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
mask=[[True, True, False], [True, False, True]])
# Fields are type-checked and converted to the declared types.
# E.g., mt.values is converted to a Tensor.
print(mt.values)
tf.Tensor( [[1 2 3] [4 5 6]], shape=(2, 3), dtype=int32)
O construtor gera um TypeError
se um valor de campo não puder ser convertido em seu tipo declarado:
try:
MaskedTensor([1, 2, 3], None)
except TypeError as e:
print(f"Got expected TypeError: {e}")
Got expected TypeError: mask: expected a Tensor, got None
O valor padrão para um campo pode ser especificado definindo seu valor no nível de classe:
class Pencil(tf.experimental.ExtensionType):
color: str = "black"
has_erasor: bool = True
length: tf.Tensor = 1.0
Pencil()
Pencil(color='black', has_erasor=True, length=<tf.Tensor: shape=(), dtype=float32, numpy=1.0>)
Pencil(length=0.5, color="blue")
Pencil(color='blue', has_erasor=True, length=<tf.Tensor: shape=(), dtype=float32, numpy=0.5>)
Representação para impressão
ExtensionType
adiciona um método de representação padrão para impressão ( __repr__
) que inclui o nome da classe e o valor de cada campo:
print(MaskedTensor(values=[1, 2, 3], mask=[True, True, False]))
MaskedTensor(values=<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>, mask=<tf.Tensor: shape=(3,), dtype=bool, numpy=array([ True, True, False])>)
Operadores de igualdade
ExtensionType
adiciona operadores de igualdade padrão ( __eq__
e __ne__
) que consideram dois valores iguais se tiverem o mesmo tipo e todos os seus campos forem iguais. Os campos de tensor são considerados iguais se tiverem a mesma forma e forem iguais em todos os elementos.
a = MaskedTensor([1, 2], [True, False])
b = MaskedTensor([[3, 4], [5, 6]], [[False, True], [True, True]])
print(f"a == a: {a==a}")
print(f"a == b: {a==b}")
print(f"a == a.values: {a==a.values}")
a == a: True a == b: False a == a.values: False
Método de validação
ExtensionType
adiciona um método __validate__
, que pode ser substituído para realizar verificações de validação em campos. Ele é executado depois que o construtor é chamado e depois que os campos são verificados e convertidos em seus tipos declarados, para que possa assumir que todos os campos têm seus tipos declarados.
O exemplo a seguir atualiza o MaskedTensor
para validar a shape
dtype
s de seus campos:
class MaskedTensor(tf.experimental.ExtensionType):
"""A tensor paired with a boolean mask, indicating which values are valid."""
values: tf.Tensor
mask: tf.Tensor
def __validate__(self):
self.values.shape.assert_is_compatible_with(self.mask.shape)
assert self.mask.dtype.is_bool, 'mask.dtype must be bool'
try:
MaskedTensor([1, 2, 3], [0, 1, 0]) # wrong dtype for mask.
except AssertionError as e:
print(f"Got expected AssertionError: {e}")
Got expected AssertionError: mask.dtype must be bool
try:
MaskedTensor([1, 2, 3], [True, False]) # shapes don't match.
except ValueError as e:
print(f"Got expected ValueError: {e}")
Got expected ValueError: Shapes (3,) and (2,) are incompatible
Imutabilidade imposta
ExtensionType
substitui os métodos __setattr__
e __delattr__
para evitar a mutação, garantindo que os valores do tipo de extensão sejam imutáveis.
mt = MaskedTensor([1, 2, 3], [True, False, True])
try:
mt.mask = [True, True, True]
except AttributeError as e:
print(f"Got expected AttributeError: {e}")
Got expected AttributeError: Cannot mutate attribute `mask` outside the custom constructor of ExtensionType.
try:
mt.mask[0] = False
except TypeError as e:
print(f"Got expected TypeError: {e}")
Got expected TypeError: 'tensorflow.python.framework.ops.EagerTensor' object does not support item assignment
try:
del mt.mask
except AttributeError as e:
print(f"Got expected AttributeError: {e}")
Got expected AttributeError: Cannot mutate attribute `mask` outside the custom constructor of ExtensionType.
Especificação de tipo aninhado
Cada classe ExtensionType
tem uma classe TypeSpec
correspondente, que é criada automaticamente e armazenada como <extension_type_name>.Spec
.
Essa classe captura todas as informações de um valor, exceto os valores de quaisquer tensores aninhados. Em particular, o TypeSpec
para um valor é criado substituindo qualquer Tensor, ExtensionType ou CompositeTensor aninhado por seu TypeSpec
.
class Player(tf.experimental.ExtensionType):
name: tf.Tensor
attributes: Mapping[str, tf.Tensor]
anne = Player("Anne", {"height": 8.3, "speed": 28.1})
anne_spec = tf.type_spec_from_value(anne)
print(anne_spec.name) # Records dtype and shape, but not the string value.
print(anne_spec.attributes) # Records keys and TensorSpecs for values.
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class 'tensorflow.python.framework.immutable_dict.ImmutableDict'> TensorSpec(shape=(), dtype=tf.string, name=None) ImmutableDict({'height': TensorSpec(shape=(), dtype=tf.float32, name=None), 'speed': TensorSpec(shape=(), dtype=tf.float32, name=None)})
Os valores TypeSpec
podem ser construídos explicitamente ou podem ser construídos a partir de um valor ExtensionType
usando tf.type_spec_from_value
:
spec1 = Player.Spec(name=tf.TensorSpec([], tf.float32), attributes={})
spec2 = tf.type_spec_from_value(anne)
TypeSpec
s são usados pelo TensorFlow para dividir valores em um componente estático e um componente dinâmico :
- O componente estático (que é fixado no tempo de construção do gráfico) é codificado com um
tf.TypeSpec
. - O componente dinâmico (que pode variar cada vez que o gráfico é executado) é codificado como uma lista de
tf.Tensor
s.
Por exemplo, tf.function
refaz sua função encapsulada sempre que um argumento tem um TypeSpec
não visto anteriormente:
@tf.function
def anonymize_player(player):
print("<<TRACING>>")
return Player("<anonymous>", player.attributes)
# Function gets traced (first time the function has been called):
anonymize_player(Player("Anne", {"height": 8.3, "speed": 28.1}))
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class 'tensorflow.python.framework.immutable_dict.ImmutableDict'> WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class 'tensorflow.python.framework.immutable_dict.ImmutableDict'> <<TRACING>> Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=8.3>, 'speed': <tf.Tensor: shape=(), dtype=float32, numpy=28.1>}))
# Function does NOT get traced (same TypeSpec: just tensor values changed)
anonymize_player(Player("Bart", {"height": 8.1, "speed": 25.3}))
Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=8.1>, 'speed': <tf.Tensor: shape=(), dtype=float32, numpy=25.3>}))
# Function gets traced (new TypeSpec: keys for attributes changed):
anonymize_player(Player("Chuck", {"height": 11.0, "jump": 5.3}))
<<TRACING>> Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=11.0>, 'jump': <tf.Tensor: shape=(), dtype=float32, numpy=5.3>}))
Para obter mais informações, consulte o Guia tf.function .
Personalizando tipos de extensão
Além de simplesmente declarar campos e seus tipos, os tipos de extensão podem:
- Substitua a representação imprimível padrão (
__repr__
). - Defina métodos.
- Defina métodos de classe e métodos estáticos.
- Defina propriedades.
- Substitua o construtor padrão (
__init__
). - Substitua o operador de igualdade padrão (
__eq__
). - Defina operadores (como
__add__
e__lt__
). - Declare valores padrão para campos.
- Defina subclasses.
Substituindo a representação padrão para impressão
Você pode substituir esse operador de conversão de string padrão para tipos de extensão. O exemplo a seguir atualiza a classe MaskedTensor
para gerar uma representação de cadeia de caracteres mais legível quando os valores são impressos no modo Eager.
class MaskedTensor(tf.experimental.ExtensionType):
"""A tensor paired with a boolean mask, indicating which values are valid."""
values: tf.Tensor
mask: tf.Tensor # shape=values.shape; false for invalid values.
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
def masked_tensor_str(values, mask):
if isinstance(values, tf.Tensor):
if hasattr(values, 'numpy') and hasattr(mask, 'numpy'):
return f'<MaskedTensor {masked_tensor_str(values.numpy(), mask.numpy())}>'
else:
return f'MaskedTensor(values={values}, mask={mask})'
if len(values.shape) == 1:
items = [repr(v) if m else '_' for (v, m) in zip(values, mask)]
else:
items = [masked_tensor_str(v, m) for (v, m) in zip(values, mask)]
return '[%s]' % ', '.join(items)
mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
mask=[[True, True, False], [True, False, True]])
print(mt)
<MaskedTensor [[1, 2, _], [4, _, 6]]>
Definindo métodos
Os tipos de extensão podem definir métodos, assim como qualquer classe Python normal. Por exemplo, o tipo MaskedTensor
pode definir um método with_default
que retorna uma cópia de self
com valores mascarados substituídos por um determinado valor default
. Os métodos podem opcionalmente ser anotados com o decorador @tf.function
.
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
def with_default(self, default):
return tf.where(self.mask, self.values, default)
MaskedTensor([1, 2, 3], [True, False, True]).with_default(0)
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 0, 3], dtype=int32)>
Definindo métodos de classe e métodos estáticos
Os tipos de extensão podem definir métodos usando os decoradores @classmethod
e @staticmethod
. Por exemplo, o tipo MaskedTensor
pode definir um método de fábrica que mascara qualquer elemento com um determinado valor:
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
@staticmethod
def from_tensor_and_value_to_mask(values, value_to_mask):
return MaskedTensor(values, values == value_to_mask)
x = tf.constant([[1, 0, 2], [3, 0, 0]])
MaskedTensor.from_tensor_and_value_to_mask(x, 0)
<MaskedTensor [[_, 0, _], [_, 0, 0]]>
Definindo propriedades
Os tipos de extensão podem definir propriedades usando o decorador @property
, assim como qualquer classe normal do Python. Por exemplo, o tipo MaskedTensor
pode definir uma propriedade dtype
que é um atalho para o dtype dos valores:
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
@property
def dtype(self):
return self.values.dtype
MaskedTensor([1, 2, 3], [True, False, True]).dtype
tf.int32
Substituindo o construtor padrão
Você pode substituir o construtor padrão para tipos de extensão. Construtores personalizados devem definir um valor para cada campo declarado; e depois que o construtor personalizado retornar, todos os campos serão verificados por tipo e os valores serão convertidos conforme descrito acima.
class Toy(tf.experimental.ExtensionType):
name: str
price: tf.Tensor
def __init__(self, name, price, discount=0):
self.name = name
self.price = price * (1 - discount)
print(Toy("ball", 5.0, discount=0.2)) # On sale -- 20% off!
Toy(name='ball', price=<tf.Tensor: shape=(), dtype=float32, numpy=4.0>)
Como alternativa, você pode considerar deixar o construtor padrão como está, mas adicionar um ou mais métodos de fábrica. Por exemplo:
class Toy(tf.experimental.ExtensionType):
name: str
price: tf.Tensor
@staticmethod
def new_toy_with_discount(name, price, discount):
return Toy(name, price * (1 - discount))
print(Toy.new_toy_with_discount("ball", 5.0, discount=0.2))
Toy(name='ball', price=<tf.Tensor: shape=(), dtype=float32, numpy=4.0>)
Substituindo o operador de igualdade padrão ( __eq__
)
Você pode substituir o operador __eq__
padrão para tipos de extensão. O exemplo a seguir atualiza MaskedTensor
para ignorar elementos mascarados ao comparar a igualdade.
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
def __eq__(self, other):
result = tf.math.equal(self.values, other.values)
result = result | ~(self.mask & other.mask)
return tf.reduce_all(result)
x = MaskedTensor([1, 2, 3, 4], [True, True, False, True])
y = MaskedTensor([5, 2, 0, 4], [False, True, False, True])
print(x == y)
tf.Tensor(True, shape=(), dtype=bool)
Usando referências diretas
Se o tipo de um campo ainda não foi definido, você pode usar uma string contendo o nome do tipo. No exemplo a seguir, a string "Node"
é usada para anotar o campo children
porque o tipo Node
ainda não foi (totalmente) definido.
class Node(tf.experimental.ExtensionType):
value: tf.Tensor
children: Tuple["Node", ...] = ()
Node(3, [Node(5), Node(2)])
Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=3>, children=(Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=5>, children=()), Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=2>, children=())))
Definindo subclasses
Os tipos de extensão podem ser subclassificados usando a sintaxe padrão do Python. As subclasses do tipo de extensão podem adicionar novos campos, métodos e propriedades; e pode substituir o construtor, a representação imprimível e o operador de igualdade. O exemplo a seguir define uma classe TensorGraph
básica que usa três campos Tensor
para codificar um conjunto de arestas entre nós. Em seguida, ele define uma subclasse que adiciona um campo Tensor
para registrar um "valor de recurso" para cada nó. A subclasse também define um método para propagar os valores dos recursos ao longo das arestas.
class TensorGraph(tf.experimental.ExtensionType):
num_nodes: tf.Tensor
edge_src: tf.Tensor # edge_src[e] = index of src node for edge e.
edge_dst: tf.Tensor # edge_dst[e] = index of dst node for edge e.
class TensorGraphWithNodeFeature(TensorGraph):
node_features: tf.Tensor # node_features[n] = feature value for node n.
def propagate_features(self, weight=1.0) -> 'TensorGraphWithNodeFeature':
updates = tf.gather(self.node_features, self.edge_src) * weight
new_node_features = tf.tensor_scatter_nd_add(
self.node_features, tf.expand_dims(self.edge_dst, 1), updates)
return TensorGraphWithNodeFeature(
self.num_nodes, self.edge_src, self.edge_dst, new_node_features)
g = TensorGraphWithNodeFeature( # Edges: 0->1, 4->3, 2->2, 2->1
num_nodes=5, edge_src=[0, 4, 2, 2], edge_dst=[1, 3, 2, 1],
node_features=[10.0, 0.0, 2.0, 5.0, -1.0, 0.0])
print("Original features:", g.node_features)
print("After propagating:", g.propagate_features().node_features)
Original features: tf.Tensor([10. 0. 2. 5. -1. 0.], shape=(6,), dtype=float32) After propagating: tf.Tensor([10. 12. 4. 4. -1. 0.], shape=(6,), dtype=float32)
Definindo campos privados
Os campos de um tipo de extensão podem ser marcados como privados prefixando-os com um sublinhado (seguindo as convenções padrão do Python). Isso não afeta a maneira como o TensorFlow trata os campos de forma alguma; mas simplesmente serve como um sinal para qualquer usuário do tipo de extensão de que esses campos são privados.
Personalizando o TypeSpec
do ExtensionType
Cada classe ExtensionType
tem uma classe TypeSpec
correspondente, que é criada automaticamente e armazenada como <extension_type_name>.Spec
. Para obter mais informações, consulte a seção "Nested TypeSpec" acima.
Para personalizar o TypeSpec
, simplesmente defina sua própria classe aninhada chamada Spec
, e ExtensionType
usará isso como base para o TypeSpec
construído automaticamente. Você pode personalizar a classe Spec
por:
- Substituindo a representação imprimível padrão.
- Substituindo o construtor padrão.
- Definindo métodos, métodos de classe, métodos estáticos e propriedades.
O exemplo a seguir personaliza a classe MaskedTensor.Spec
para facilitar o uso:
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
shape = property(lambda self: self.values.shape)
dtype = property(lambda self: self.values.dtype)
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
def with_values(self, new_values):
return MaskedTensor(new_values, self.mask)
class Spec:
def __init__(self, shape, dtype=tf.float32):
self.values = tf.TensorSpec(shape, dtype)
self.mask = tf.TensorSpec(shape, tf.bool)
def __repr__(self):
return f"MaskedTensor.Spec(shape={self.shape}, dtype={self.dtype})"
shape = property(lambda self: self.values.shape)
dtype = property(lambda self: self.values.dtype)
Despacho da API do tensor
Os tipos de extensão podem ser "tipo tensor", no sentido de que se especializam ou estendem a interface definida pelo tipo tf.Tensor
. Exemplos de tipos de extensão do tipo tensor incluem RaggedTensor
, SparseTensor
e MaskedTensor
. Decoradores de despacho podem ser usados para substituir o comportamento padrão das operações do TensorFlow quando aplicados a tipos de extensão do tipo tensor. Atualmente, o TensorFlow define três decoradores de despacho:
-
@tf.experimental.dispatch_for_api(tf_api)
-
@tf.experimental.dispatch_for_unary_elementwise_api(x_type)
-
@tf.experimental.dispatch_for_binary_elementwise_apis(x_type, y_type)
Despacho para uma única API
O decorador tf.experimental.dispatch_for_api
substitui o comportamento padrão de uma operação TensorFlow especificada quando é chamada com a assinatura especificada. Por exemplo, você pode usar este decorador para especificar como o tf.stack
deve processar os valores do MaskedTensor
:
@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack(values: List[MaskedTensor], axis = 0):
return MaskedTensor(tf.stack([v.values for v in values], axis),
tf.stack([v.mask for v in values], axis))
Isso substitui a implementação padrão para tf.stack
sempre que é chamado com uma lista de valores MaskedTensor
(já que o argumento values
é anotado com typing.List[MaskedTensor]
):
x = MaskedTensor([1, 2, 3], [True, True, False])
y = MaskedTensor([4, 5, 6], [False, True, True])
tf.stack([x, y])
<MaskedTensor [[1, 2, _], [_, 5, 6]]>
Para permitir que o tf.stack
com listas de valores mistos de MaskedTensor
e Tensor
, você pode refinar a anotação de tipo para o parâmetro de values
e atualizar o corpo da função adequadamente:
tf.experimental.unregister_dispatch_for(masked_stack)
def convert_to_masked_tensor(x):
if isinstance(x, MaskedTensor):
return x
else:
return MaskedTensor(x, tf.ones_like(x, tf.bool))
@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack_v2(values: List[Union[MaskedTensor, tf.Tensor]], axis = 0):
values = [convert_to_masked_tensor(v) for v in values]
return MaskedTensor(tf.stack([v.values for v in values], axis),
tf.stack([v.mask for v in values], axis))
x = MaskedTensor([1, 2, 3], [True, True, False])
y = tf.constant([4, 5, 6])
tf.stack([x, y, x])
<MaskedTensor [[1, 2, _], [4, 5, 6], [1, 2, _]]>
Para obter uma lista de APIs que podem ser substituídas, consulte a documentação da API para tf.experimental.dispatch_for_api
.
Despacho para todas as APIs elementwise unárias
O decorador tf.experimental.dispatch_for_unary_elementwise_apis
substitui o comportamento padrão de todas as operações elementares unárias (como tf.math.cos
) sempre que o valor do primeiro argumento (normalmente denominado x
) corresponder à anotação de tipo x_type
. A função decorada deve receber dois argumentos:
-
api_func
: Uma função que recebe um único parâmetro e executa a operação elementwise (por exemplo,tf.abs
). -
x
: O primeiro argumento para a operação elementwise.
O exemplo a seguir atualiza todas as operações elementares unárias para lidar com o tipo MaskedTensor
:
@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def masked_tensor_unary_elementwise_api_handler(api_func, x):
return MaskedTensor(api_func(x.values), x.mask)
Esta função agora será usada sempre que uma operação elementar unária for chamada em um MaskedTensor
.
x = MaskedTensor([1, -2, -3], [True, False, True])
print(tf.abs(x))
<MaskedTensor [1, _, 3]>
print(tf.ones_like(x, dtype=tf.float32))
<MaskedTensor [1.0, _, 1.0]>
Despacho para todas as APIs elementwise binárias
Da mesma forma, tf.experimental.dispatch_for_binary_elementwise_apis
pode ser usado para atualizar todas as operações elementares binárias para lidar com o tipo MaskedTensor
:
@tf.experimental.dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
def masked_tensor_binary_elementwise_api_handler(api_func, x, y):
return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)
x = MaskedTensor([1, -2, -3], [True, False, True])
y = MaskedTensor([[4], [5]], [[True], [False]])
tf.math.add(x, y)
<MaskedTensor [[5, _, 1], [_, _, _]]>
Para obter uma lista das APIs elementwise que foram substituídas, consulte a documentação da API para tf.experimental.dispatch_for_unary_elementwise_apis
e tf.experimental.dispatch_for_binary_elementwise_apis
.
Tipos de extensão em lote
Um ExtensionType
é em lote se uma única instância puder ser usada para representar um lote de valores. Normalmente, isso é feito adicionando dimensões de lote a todos os Tensor
aninhados. As seguintes APIs do TensorFlow exigem que todas as entradas de tipo de extensão sejam em lote:
-
tf.data.Dataset
(batch
,unbatch
,from_tensor_slices
) -
tf.Keras
(fit
,evaluate
,predict
) -
tf.map_fn
Por padrão, BatchableExtensionType
cria valores em lote agrupando quaisquer Tensor
s, CompositeTensor
s e ExtensionType
s aninhados. Se isso não for apropriado para sua classe, você precisará usar tf.experimental.ExtensionTypeBatchEncoder
para substituir esse comportamento padrão. Por exemplo, não seria apropriado criar um lote de valores tf.SparseTensor
simplesmente empilhando values
de tensores esparsos individuais , indices
e campos de dense_shape
-- na maioria dos casos, você não pode empilhar esses tensores, pois eles têm formas incompatíveis ; e mesmo que você pudesse, o resultado não seria um SparseTensor
válido.
Exemplo de BatchableExtensionType: rede
Como exemplo, considere uma classe de Network
simples usada para balanceamento de carga, que rastreia quanto trabalho resta a ser feito em cada nó e quanta largura de banda está disponível para mover o trabalho entre nós:
class Network(tf.experimental.ExtensionType): # This version is not batchable.
work: tf.Tensor # work[n] = work left to do at node n
bandwidth: tf.Tensor # bandwidth[n1, n2] = bandwidth from n1->n2
net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])
Para tornar esse tipo em lote, altere o tipo base para BatchableExtensionType
e ajuste a forma de cada campo para incluir dimensões de lote opcionais. O exemplo a seguir também adiciona um campo de shape
para acompanhar a forma do lote. Este campo de shape
não é requerido por tf.data.Dataset
ou tf.map_fn
, mas é requerido por tf.Keras
.
class Network(tf.experimental.BatchableExtensionType):
shape: tf.TensorShape # batch shape. A single network has shape=[].
work: tf.Tensor # work[*shape, n] = work left to do at node n
bandwidth: tf.Tensor # bandwidth[*shape, n1, n2] = bandwidth from n1->n2
def __init__(self, work, bandwidth):
self.work = tf.convert_to_tensor(work)
self.bandwidth = tf.convert_to_tensor(bandwidth)
work_batch_shape = self.work.shape[:-1]
bandwidth_batch_shape = self.bandwidth.shape[:-2]
self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)
def __repr__(self):
return network_repr(self)
def network_repr(network):
work = network.work
bandwidth = network.bandwidth
if hasattr(work, 'numpy'):
work = ' '.join(str(work.numpy()).split())
if hasattr(bandwidth, 'numpy'):
bandwidth = ' '.join(str(bandwidth.numpy()).split())
return (f"<Network shape={network.shape} work={work} bandwidth={bandwidth}>")
net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])
batch_of_networks = Network(
work=tf.stack([net1.work, net2.work]),
bandwidth=tf.stack([net1.bandwidth, net2.bandwidth]))
print(f"net1={net1}")
print(f"net2={net2}")
print(f"batch={batch_of_networks}")
net1=<Network shape=() work=[5. 3. 8.] bandwidth=[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]]> net2=<Network shape=() work=[3. 4. 2.] bandwidth=[[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]> batch=<Network shape=(2,) work=[[5. 3. 8.] [3. 4. 2.]] bandwidth=[[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]] [[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]]>
Você pode usar tf.data.Dataset
para iterar por meio de um lote de redes:
dataset = tf.data.Dataset.from_tensor_slices(batch_of_networks)
for i, network in enumerate(dataset):
print(f"Batch element {i}: {network}")
Batch element 0: <Network shape=() work=[5. 3. 8.] bandwidth=[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]]> Batch element 1: <Network shape=() work=[3. 4. 2.] bandwidth=[[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]>
E você também pode usar map_fn
para aplicar uma função a cada elemento de lote:
def balance_work_greedy(network):
delta = (tf.expand_dims(network.work, -1) - tf.expand_dims(network.work, -2))
delta /= 4
delta = tf.maximum(tf.minimum(delta, network.bandwidth), -network.bandwidth)
new_work = network.work + tf.reduce_sum(delta, -1)
return Network(new_work, network.bandwidth)
tf.map_fn(balance_work_greedy, batch_of_networks)
<Network shape=(2,) work=[[5.5 1.25 9.25] [3. 4.75 1.25]] bandwidth=[[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]] [[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]]>
APIs do TensorFlow compatíveis com ExtensionTypes
@tf.função
tf.function é um decorador que pré-computa gráficos do TensorFlow para funções do Python, o que pode melhorar substancialmente o desempenho do seu código do TensorFlow. Os valores de tipo de extensão podem ser usados de forma transparente com funções @tf.function
-decorated.
class Pastry(tf.experimental.ExtensionType):
sweetness: tf.Tensor # 2d embedding that encodes sweetness
chewiness: tf.Tensor # 2d embedding that encodes chewiness
@tf.function
def combine_pastry_features(x: Pastry):
return (x.sweetness + x.chewiness) / 2
cookie = Pastry(sweetness=[1.2, 0.4], chewiness=[0.8, 0.2])
combine_pastry_features(cookie)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1. , 0.3], dtype=float32)>
Se você deseja especificar explicitamente o input_signature
para tf.function
, pode fazê-lo usando TypeSpec
do tipo de extensão.
pastry_spec = Pastry.Spec(tf.TensorSpec([2]), tf.TensorSpec(2))
@tf.function(input_signature=[pastry_spec])
def increase_sweetness(x: Pastry, delta=1.0):
return Pastry(x.sweetness + delta, x.chewiness)
increase_sweetness(cookie)
Pastry(sweetness=<tf.Tensor: shape=(2,), dtype=float32, numpy=array([2.2, 1.4], dtype=float32)>, chewiness=<tf.Tensor: shape=(2,), dtype=float32, numpy=array([0.8, 0.2], dtype=float32)>)
Funções concretas
As funções concretas encapsulam gráficos rastreados individuais que são construídos por tf.function
. Os tipos de extensão podem ser usados de forma transparente com funções concretas.
cf = combine_pastry_features.get_concrete_function(pastry_spec)
cf(cookie)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1. , 0.3], dtype=float32)>
Operações de fluxo de controle
Os tipos de extensão são compatíveis com as operações de fluxo de controle do TensorFlow:
# Example: using tf.cond to select between two MaskedTensors. Note that the
# two MaskedTensors don't need to have the same shape.
a = MaskedTensor([1., 2, 3], [True, False, True])
b = MaskedTensor([22., 33, 108, 55], [True, True, True, False])
condition = tf.constant(True)
print(tf.cond(condition, lambda: a, lambda: b))
<MaskedTensor [1.0, _, 3.0]>
# Example: using tf.while_loop with MaskedTensor.
cond = lambda i, _: i < 10
def body(i, mt):
return i + 1, mt.with_values(mt.values + 3 / 7)
print(tf.while_loop(cond, body, [0, b])[1])
<MaskedTensor [26.285717, 37.285698, 112.285736, _]>
Fluxo de controle de autógrafos
Os tipos de extensão também são suportados por instruções de fluxo de controle em tf.function (usando autógrafo). No exemplo a seguir, as instruções if
e for
são convertidas automaticamente em operações tf.cond
e tf.while_loop
, que suportam tipos de extensão.
@tf.function
def fn(x, b):
if b:
x = MaskedTensor(x, tf.less(x, 0))
else:
x = MaskedTensor(x, tf.greater(x, 0))
for i in tf.range(5 if b else 7):
x = x.with_values(x.values + 1 / 2)
return x
print(fn(tf.constant([1., -2, 3]), tf.constant(True)))
print(fn(tf.constant([1., -2, 3]), tf.constant(False)))
<MaskedTensor [_, 0.5, _]> <MaskedTensor [4.5, _, 6.5]>
Keras
tf.keras é a API de alto nível do TensorFlow para criar e treinar modelos de aprendizado profundo. Os tipos de extensão podem ser passados como entradas para um modelo Keras, passados entre camadas Keras e retornados por modelos Keras. Keras atualmente coloca dois requisitos nos tipos de extensão:
- Eles devem ser em lote (consulte "Extensões em lote" acima).
- O deve ter um campo ou propriedade denominado
shape
.shape[0]
é considerado a dimensão do lote.
As duas subseções a seguir fornecem exemplos mostrando como os tipos de extensão podem ser usados com o Keras.
Exemplo Keras: Network
Para o primeiro exemplo, considere a classe Network
definida na seção "Batchable ExtensionTypes" acima, que pode ser usada para trabalho de balanceamento de carga entre nós. Sua definição é repetida aqui:
class Network(tf.experimental.BatchableExtensionType):
shape: tf.TensorShape # batch shape. A single network has shape=[].
work: tf.Tensor # work[*shape, n] = work left to do at node n
bandwidth: tf.Tensor # bandwidth[*shape, n1, n2] = bandwidth from n1->n2
def __init__(self, work, bandwidth):
self.work = tf.convert_to_tensor(work)
self.bandwidth = tf.convert_to_tensor(bandwidth)
work_batch_shape = self.work.shape[:-1]
bandwidth_batch_shape = self.bandwidth.shape[:-2]
self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)
def __repr__(self):
return network_repr(self)
single_network = Network( # A single network w/ 4 nodes.
work=[8.0, 5, 12, 2],
bandwidth=[[0.0, 1, 2, 2], [1, 0, 0, 2], [2, 0, 0, 1], [2, 2, 1, 0]])
batch_of_networks = Network( # Batch of 2 networks, each w/ 2 nodes.
work=[[8.0, 5], [3, 2]],
bandwidth=[[[0.0, 1], [1, 0]], [[0, 2], [2, 0]]])
Você pode definir uma nova camada Keras que processa Network
s.
class BalanceNetworkLayer(tf.keras.layers.Layer):
"""Layer that balances work between nodes in a network.
Shifts work from more busy nodes to less busy nodes, constrained by bandwidth.
"""
def call(self, inputs):
# This function is defined above, in "Batchable ExtensionTypes" section.
return balance_work_greedy(inputs)
Você pode então usar essas camadas para criar um modelo simples. Para alimentar um ExtensionType
em um modelo, você pode usar uma camada tf.keras.layer.Input
com type_spec
definido como TypeSpec
do tipo de extensão. Se o modelo Keras for usado para processar lotes, o type_spec
deverá incluir a dimensão do lote.
input_spec = Network.Spec(shape=None,
work=tf.TensorSpec(None, tf.float32),
bandwidth=tf.TensorSpec(None, tf.float32))
model = tf.keras.Sequential([
tf.keras.layers.Input(type_spec=input_spec),
BalanceNetworkLayer(),
])
Por fim, você pode aplicar o modelo a uma única rede e a um lote de redes.
model(single_network)
<Network shape=() work=[ 9.25 5. 14. -1.25] bandwidth=[[0. 1. 2. 2.] [1. 0. 0. 2.] [2. 0. 0. 1.] [2. 2. 1. 0.]]>
model(batch_of_networks)
<Network shape=(2,) work=[[8.75 4.25] [3.25 1.75]] bandwidth=[[[0. 1.] [1. 0.]] [[0. 2.] [2. 0.]]]>
Exemplo Keras: MaskedTensor
Neste exemplo, MaskedTensor
é estendido para dar suporte a Keras
. shape
é definido como uma propriedade que é calculada a partir do campo de values
. Keras requer que você adicione essa propriedade ao tipo de extensão e seu TypeSpec
. MaskedTensor
também define uma variável __name__
, que será necessária para a serialização de SavedModel
(abaixo).
class MaskedTensor(tf.experimental.BatchableExtensionType):
# __name__ is required for serialization in SavedModel; see below for details.
__name__ = 'extension_type_colab.MaskedTensor'
values: tf.Tensor
mask: tf.Tensor
shape = property(lambda self: self.values.shape)
dtype = property(lambda self: self.values.dtype)
def with_default(self, default):
return tf.where(self.mask, self.values, default)
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
class Spec:
def __init__(self, shape, dtype=tf.float32):
self.values = tf.TensorSpec(shape, dtype)
self.mask = tf.TensorSpec(shape, tf.bool)
shape = property(lambda self: self.values.shape)
dtype = property(lambda self: self.values.dtype)
def with_shape(self):
return MaskedTensor.Spec(tf.TensorSpec(shape, self.values.dtype),
tf.TensorSpec(shape, self.mask.dtype))
Em seguida, os decoradores de expedição são usados para substituir o comportamento padrão de várias APIs do TensorFlow. Como essas APIs são usadas por camadas Keras padrão (como a camada Dense
), substituí-las nos permitirá usar essas camadas com MaskedTensor
. Para os propósitos deste exemplo, matmul
para tensores mascarados é definido para tratar os valores mascarados como zeros (ou seja, para não incluí-los no produto).
@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def unary_elementwise_op_handler(op, x):
return MaskedTensor(op(x.values), x.mask)
@tf.experimental.dispatch_for_binary_elementwise_apis(
Union[MaskedTensor, tf.Tensor],
Union[MaskedTensor, tf.Tensor])
def binary_elementwise_op_handler(op, x, y):
x = convert_to_masked_tensor(x)
y = convert_to_masked_tensor(y)
return MaskedTensor(op(x.values, y.values), x.mask & y.mask)
@tf.experimental.dispatch_for_api(tf.matmul)
def masked_matmul(a: MaskedTensor, b,
transpose_a=False, transpose_b=False,
adjoint_a=False, adjoint_b=False,
a_is_sparse=False, b_is_sparse=False,
output_type=None):
if isinstance(a, MaskedTensor):
a = a.with_default(0)
if isinstance(b, MaskedTensor):
b = b.with_default(0)
return tf.matmul(a, b, transpose_a, transpose_b, adjoint_a,
adjoint_b, a_is_sparse, b_is_sparse, output_type)
Você pode então construir um modelo Keras que aceite entradas MaskedTensor
, usando camadas Keras padrão:
input_spec = MaskedTensor.Spec([None, 2], tf.float32)
masked_tensor_model = tf.keras.Sequential([
tf.keras.layers.Input(type_spec=input_spec),
tf.keras.layers.Dense(16, activation="relu"),
tf.keras.layers.Dense(1)])
masked_tensor_model.compile(loss='binary_crossentropy', optimizer='rmsprop')
a = MaskedTensor([[1., 2], [3, 4], [5, 6]],
[[True, False], [False, True], [True, True]])
masked_tensor_model.fit(a, tf.constant([[1], [0], [1]]), epochs=3)
print(masked_tensor_model(a))
Epoch 1/3 1/1 [==============================] - 1s 955ms/step - loss: 10.2833 Epoch 2/3 1/1 [==============================] - 0s 5ms/step - loss: 10.2833 Epoch 3/3 1/1 [==============================] - 0s 5ms/step - loss: 10.2833 tf.Tensor( [[-0.09944128] [-0.7225147 ] [-1.3020657 ]], shape=(3, 1), dtype=float32)
Modelo salvo
Um SavedModel é um programa TensorFlow serializado, incluindo pesos e computação. Ele pode ser construído a partir de um modelo Keras ou de um modelo personalizado. Em ambos os casos, os tipos de extensão podem ser usados de forma transparente com as funções e métodos definidos por um SavedModel.
SavedModel pode salvar modelos, camadas e funções que processam tipos de extensão, desde que os tipos de extensão tenham um campo __name__
. Este nome é usado para registrar o tipo de extensão, para que possa ser localizado quando o modelo for carregado.
Exemplo: salvar um modelo Keras
Os modelos Keras que usam tipos de extensão podem ser salvos usando SavedModel
.
masked_tensor_model_path = tempfile.mkdtemp()
tf.saved_model.save(masked_tensor_model, masked_tensor_model_path)
imported_model = tf.saved_model.load(masked_tensor_model_path)
imported_model(a)
2021-11-06 01:25:14.285250: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them. WARNING:absl:Function `_wrapped_model` contains input name(s) args_0 with unsupported characters which will be renamed to args_0_1 in the SavedModel. INFO:tensorflow:Assets written to: /tmp/tmp3ceuupv9/assets INFO:tensorflow:Assets written to: /tmp/tmp3ceuupv9/assets <tf.Tensor: shape=(3, 1), dtype=float32, numpy= array([[-0.09944128], [-0.7225147 ], [-1.3020657 ]], dtype=float32)>
Exemplo: salvar um modelo personalizado
SavedModel também pode ser usado para salvar subclasses tf.Module
personalizadas com funções que processam tipos de extensão.
class CustomModule(tf.Module):
def __init__(self, variable_value):
super().__init__()
self.v = tf.Variable(variable_value)
@tf.function
def grow(self, x: MaskedTensor):
"""Increase values in `x` by multiplying them by `self.v`."""
return MaskedTensor(x.values * self.v, x.mask)
module = CustomModule(100.0)
module.grow.get_concrete_function(MaskedTensor.Spec(shape=None,
dtype=tf.float32))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
imported_model.grow(MaskedTensor([1., 2, 3], [False, True, False]))
INFO:tensorflow:Assets written to: /tmp/tmp2x8zq5kb/assets INFO:tensorflow:Assets written to: /tmp/tmp2x8zq5kb/assets <MaskedTensor [_, 200.0, _]>
Carregando um SavedModel quando o ExtensionType não está disponível
Se você carregar um SavedModel
que usa um ExtensionType
, mas esse ExtensionType
não está disponível (ou seja, não foi importado), você verá um aviso e o TensorFlow voltará a usar um objeto "tipo de extensão anônimo". Esse objeto terá os mesmos campos que o tipo original, mas não terá qualquer personalização adicional que você tenha adicionado ao tipo, como métodos ou propriedades personalizados.
Como usar ExtensionTypes com veiculação do TensorFlow
Atualmente, a veiculação do TensorFlow (e outros consumidores do dicionário de "assinaturas" SavedModel) exige que todas as entradas e saídas sejam tensores brutos. Se você deseja usar o TensorFlow servindo com um modelo que usa tipos de extensão, você pode adicionar métodos wrapper que compõem ou decompõem valores de tipo de extensão de tensores. Por exemplo:
class CustomModuleWrapper(tf.Module):
def __init__(self, variable_value):
super().__init__()
self.v = tf.Variable(variable_value)
@tf.function
def var_weighted_mean(self, x: MaskedTensor):
"""Mean value of unmasked values in x, weighted by self.v."""
x = MaskedTensor(x.values * self.v, x.mask)
return (tf.reduce_sum(x.with_default(0)) /
tf.reduce_sum(tf.cast(x.mask, x.dtype)))
@tf.function()
def var_weighted_mean_wrapper(self, x_values, x_mask):
"""Raw tensor wrapper for var_weighted_mean."""
return self.var_weighted_mean(MaskedTensor(x_values, x_mask))
module = CustomModuleWrapper([3., 2., 8., 5.])
module.var_weighted_mean_wrapper.get_concrete_function(
tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.bool))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
x = MaskedTensor([1., 2., 3., 4.], [False, True, False, True])
imported_model.var_weighted_mean_wrapper(x.values, x.mask)
INFO:tensorflow:Assets written to: /tmp/tmpxhh4zh0i/assets INFO:tensorflow:Assets written to: /tmp/tmpxhh4zh0i/assets <tf.Tensor: shape=(), dtype=float32, numpy=12.0>
Conjuntos de dados
tf.data é uma API que permite criar pipelines de entrada complexos a partir de peças simples e reutilizáveis. Sua estrutura de dados central é tf.data.Dataset
, que representa uma sequência de elementos, na qual cada elemento consiste em um ou mais componentes.
Como criar conjuntos de dados com tipos de extensão
Os conjuntos de dados podem ser criados a partir de valores de tipo de extensão usando Dataset.from_tensors
, Dataset.from_tensor_slices
ou Dataset.from_generator
:
ds = tf.data.Dataset.from_tensors(Pastry(5, 5))
iter(ds).next()
Pastry(sweetness=<tf.Tensor: shape=(), dtype=int32, numpy=5>, chewiness=<tf.Tensor: shape=(), dtype=int32, numpy=5>)
mt = MaskedTensor(tf.reshape(range(20), [5, 4]), tf.ones([5, 4]))
ds = tf.data.Dataset.from_tensor_slices(mt)
for value in ds:
print(value)
<MaskedTensor [0, 1, 2, 3]> <MaskedTensor [4, 5, 6, 7]> <MaskedTensor [8, 9, 10, 11]> <MaskedTensor [12, 13, 14, 15]> <MaskedTensor [16, 17, 18, 19]>
def value_gen():
for i in range(2, 7):
yield MaskedTensor(range(10), [j%i != 0 for j in range(10)])
ds = tf.data.Dataset.from_generator(
value_gen, output_signature=MaskedTensor.Spec(shape=[10], dtype=tf.int32))
for value in ds:
print(value)
<MaskedTensor [_, 1, _, 3, _, 5, _, 7, _, 9]> <MaskedTensor [_, 1, 2, _, 4, 5, _, 7, 8, _]> <MaskedTensor [_, 1, 2, 3, _, 5, 6, 7, _, 9]> <MaskedTensor [_, 1, 2, 3, 4, _, 6, 7, 8, 9]> <MaskedTensor [_, 1, 2, 3, 4, 5, _, 7, 8, 9]>
Lote e unbatch de conjuntos de dados com tipos de extensão
Conjuntos de dados com tipos de extensão podem ser batche unbatch usando Dataset.batch
e Dataset.unbatch
.
batched_ds = ds.batch(2)
for value in batched_ds:
print(value)
<MaskedTensor [[_, 1, _, 3, _, 5, _, 7, _, 9], [_, 1, 2, _, 4, 5, _, 7, 8, _]]> <MaskedTensor [[_, 1, 2, 3, _, 5, 6, 7, _, 9], [_, 1, 2, 3, 4, _, 6, 7, 8, 9]]> <MaskedTensor [[_, 1, 2, 3, 4, 5, _, 7, 8, 9]]>
unbatched_ds = batched_ds.unbatch()
for value in unbatched_ds:
print(value)
<MaskedTensor [_, 1, _, 3, _, 5, _, 7, _, 9]> <MaskedTensor [_, 1, 2, _, 4, 5, _, 7, 8, _]> <MaskedTensor [_, 1, 2, 3, _, 5, 6, 7, _, 9]> <MaskedTensor [_, 1, 2, 3, 4, _, 6, 7, 8, 9]> <MaskedTensor [_, 1, 2, 3, 4, 5, _, 7, 8, 9]>