Tipos de extensão

Veja no TensorFlow.org Executar no Google Colab Ver fonte no GitHubBaixar caderno

Configurar

!pip install -q tf_nightly
import tensorflow as tf
import numpy as np
from typing import Tuple, List, Mapping, Union, Optional
import tempfile

Tipos de extensão

Os tipos definidos pelo usuário podem tornar os projetos mais legíveis, modulares e de fácil manutenção. No entanto, a maioria das APIs do TensorFlow tem suporte muito limitado para tipos Python definidos pelo usuário. Isso inclui APIs de alto nível (como Keras , tf.function , tf.SavedModel ) e APIs de nível inferior (como tf.while_loop e tf.concat ). Os tipos de extensão do TensorFlow podem ser usados ​​para criar tipos orientados a objetos definidos pelo usuário que funcionam perfeitamente com as APIs do TensorFlow. Para criar um tipo de extensão, simplesmente defina uma classe Python com tf.experimental.ExtensionType como base e use anotações de tipo para especificar o tipo de cada campo.

class TensorGraph(tf.experimental.ExtensionType):
  """A collection of labeled nodes connected by weighted edges."""
  edge_weights: tf.Tensor               # shape=[num_nodes, num_nodes]
  node_labels: Mapping[str, tf.Tensor]  # shape=[num_nodes]; dtype=any

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor       # shape=values.shape; false for missing/invalid values.

class CSRSparseMatrix(tf.experimental.ExtensionType):
  """Compressed sparse row matrix (https://en.wikipedia.org/wiki/Sparse_matrix)."""
  values: tf.Tensor     # shape=[num_nonzero]; dtype=any
  col_index: tf.Tensor  # shape=[num_nonzero]; dtype=int64
  row_index: tf.Tensor  # shape=[num_rows+1]; dtype=int64

A classe base tf.experimental.ExtensionType funciona de forma semelhante à typing.NamedTuple e @dataclasses.dataclass da biblioteca padrão do Python. Em particular, ele adiciona automaticamente um construtor e métodos especiais (como __repr__ e __eq__ ) com base nas anotações do tipo de campo.

Normalmente, os tipos de extensão tendem a se enquadrar em uma das duas categorias:

  • Estruturas de dados , que agrupam uma coleção de valores relacionados e podem fornecer operações úteis com base nesses valores. As estruturas de dados podem ser bastante gerais (como o exemplo do TensorGraph acima); ou podem ser altamente personalizados para um modelo específico.

  • Tipos do tipo tensor , que especializam ou estendem o conceito de "tensor". Tipos nesta categoria têm uma rank , uma shape e geralmente um dtype ; e faz sentido usá-los com operações de tensor (como tf.stack , tf.add ou tf.matmul ). MaskedTensor e CSRSparseMatrix são exemplos de tipos do tipo tensor.

APIs compatíveis

Os tipos de extensão são compatíveis com as seguintes APIs do TensorFlow:

  • Keras : Os tipos de extensão podem ser usados ​​como entradas e saídas para Models e Layers Keras.
  • tf.data.Dataset : Os tipos de extensão podem ser incluídos em Datasets e retornados por Iterators de dataset.
  • Hub Tensorflow : Os tipos de extensão podem ser usados ​​como entradas e saídas para módulos tf.hub .
  • SavedModel : Os tipos de extensão podem ser usados ​​como entradas e saídas para funções SavedModel .
  • tf.function : Os tipos de extensão podem ser usados ​​como argumentos e valores de retorno para funções agrupadas com o decorador @tf.function .
  • while loops : Os tipos de extensão podem ser usados ​​como variáveis ​​de loop em tf.while_loop , e podem ser usados ​​como argumentos e valores de retorno para o corpo do while-loop.
  • condicionais : Os tipos de extensão podem ser selecionados condicionalmente usando tf.cond e tf.case .
  • py_function : Os tipos de extensão podem ser usados ​​como argumentos e valores de retorno para o argumento func para tf.py_function .
  • Operações de tensor : os tipos de extensão podem ser estendidos para oferecer suporte à maioria das operações de TensorFlow que aceitam entradas de tensor (por exemplo, tf.matmul , tf.gather e tf.reduce_sum ). Consulte a seção " Despacho " abaixo para obter mais informações.
  • estratégia de distribuição : os tipos de extensão podem ser usados ​​como valores por réplica.

Para obter mais detalhes, consulte a seção "APIs do TensorFlow compatíveis com ExtensionTypes" abaixo.

Requisitos

Tipos de campo

Todos os campos (também conhecidos como variáveis ​​de instância) devem ser declarados e uma anotação de tipo deve ser fornecida para cada campo. As seguintes anotações de tipo são compatíveis:

Modelo Exemplo
Inteiros Python i: int
Python flutua f: float
Strings Python s: str
Booleanos Python b: bool
Python Nenhum n: None
Formas de tensor shape: tf.TensorShape
Tipos de tensor dtype: tf.DType
Tensores t: tf.Tensor
Tipos de extensão mt: MyMaskedTensor
Tensores irregulares rt: tf.RaggedTensor
Tensores Esparsos st: tf.SparseTensor
Fatias Indexadas s: tf.IndexedSlices
Tensores Opcionais o: tf.experimental.Optional
Tipo de uniões int_or_float: typing.Union[int, float]
Tuplas params: typing.Tuple[int, float, tf.Tensor, int]
Tuplas de comprimento de var lengths: typing.Tuple[int, ...]
Mapeamentos tags: typing.Mapping[str, tf.Tensor]
Valores opcionais weight: typing.Optional[tf.Tensor]

Mutabilidade

Os tipos de extensão devem ser imutáveis. Isso garante que eles possam ser rastreados adequadamente pelos mecanismos de rastreamento de gráficos do TensorFlow. Se você quiser alterar um valor de tipo de extensão, considere definir métodos que transformam valores. Por exemplo, em vez de definir um método set_mask para alterar um MaskedTensor , você pode definir um método replace_mask que retorna um novo MaskedTensor :

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def replace_mask(self, new_mask):
      self.values.shape.assert_is_compatible_with(new_mask.shape)
      return MaskedTensor(self.values, new_mask)

Funcionalidade adicionada por ExtensionType

A classe base ExtensionType fornece a seguinte funcionalidade:

  • Um construtor ( __init__ ).
  • Um método de representação imprimível ( __repr__ ).
  • Operadores de igualdade e desigualdade ( __eq__ ).
  • Um método de validação ( __validate__ ).
  • Imutabilidade imposta.
  • Um TypeSpec aninhado.
  • Suporte de envio da API do tensor.

Consulte a seção "Personalizando ExtensionTypes" abaixo para obter mais informações sobre como personalizar essa funcionalidade.

Construtor

O construtor adicionado por ExtensionType recebe cada campo como um argumento nomeado (na ordem em que foram listados na definição da classe). Esse construtor verificará cada parâmetro e os converterá quando necessário. Em particular, os campos de Tensor são convertidos usando tf.convert_to_tensor ; Tuple campos de tupla são convertidos em tuple s; e os campos de Mapping são convertidos em dicts imutáveis.

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

# Constructor takes one parameter for each field.
mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
                  mask=[[True, True, False], [True, False, True]])

# Fields are type-checked and converted to the declared types.
# E.g., mt.values is converted to a Tensor.
print(mt.values)
tf.Tensor(
[[1 2 3]
 [4 5 6]], shape=(2, 3), dtype=int32)

O construtor gera um TypeError se um valor de campo não puder ser convertido em seu tipo declarado:

try:
  MaskedTensor([1, 2, 3], None)
except TypeError as e:
  print(f"Got expected TypeError: {e}")
Got expected TypeError: mask: expected a Tensor, got None

O valor padrão para um campo pode ser especificado definindo seu valor no nível de classe:

class Pencil(tf.experimental.ExtensionType):
  color: str = "black"
  has_erasor: bool = True
  length: tf.Tensor = 1.0

Pencil()
Pencil(color='black', has_erasor=True, length=<tf.Tensor: shape=(), dtype=float32, numpy=1.0>)
Pencil(length=0.5, color="blue")
Pencil(color='blue', has_erasor=True, length=<tf.Tensor: shape=(), dtype=float32, numpy=0.5>)

Representação para impressão

ExtensionType adiciona um método de representação padrão para impressão ( __repr__ ) que inclui o nome da classe e o valor de cada campo:

print(MaskedTensor(values=[1, 2, 3], mask=[True, True, False]))
MaskedTensor(values=<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>, mask=<tf.Tensor: shape=(3,), dtype=bool, numpy=array([ True,  True, False])>)

Operadores de igualdade

ExtensionType adiciona operadores de igualdade padrão ( __eq__ e __ne__ ) que consideram dois valores iguais se tiverem o mesmo tipo e todos os seus campos forem iguais. Os campos de tensor são considerados iguais se tiverem a mesma forma e forem iguais em todos os elementos.

a = MaskedTensor([1, 2], [True, False])
b = MaskedTensor([[3, 4], [5, 6]], [[False, True], [True, True]])
print(f"a == a: {a==a}")
print(f"a == b: {a==b}")
print(f"a == a.values: {a==a.values}")
a == a: True
a == b: False
a == a.values: False

Método de validação

ExtensionType adiciona um método __validate__ , que pode ser substituído para realizar verificações de validação em campos. Ele é executado depois que o construtor é chamado e depois que os campos são verificados e convertidos em seus tipos declarados, para que possa assumir que todos os campos têm seus tipos declarados.

O exemplo a seguir atualiza o MaskedTensor para validar a shape dtype s de seus campos:

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor
  def __validate__(self):
    self.values.shape.assert_is_compatible_with(self.mask.shape)
    assert self.mask.dtype.is_bool, 'mask.dtype must be bool'
try:
  MaskedTensor([1, 2, 3], [0, 1, 0])  # wrong dtype for mask.
except AssertionError as e:
  print(f"Got expected AssertionError: {e}")
Got expected AssertionError: mask.dtype must be bool
try:
  MaskedTensor([1, 2, 3], [True, False])  # shapes don't match.
except ValueError as e:
  print(f"Got expected ValueError: {e}")
Got expected ValueError: Shapes (3,) and (2,) are incompatible

Imutabilidade imposta

ExtensionType substitui os métodos __setattr__ e __delattr__ para evitar a mutação, garantindo que os valores do tipo de extensão sejam imutáveis.

mt = MaskedTensor([1, 2, 3], [True, False, True])
try:
  mt.mask = [True, True, True]
except AttributeError as e:
  print(f"Got expected AttributeError: {e}")
Got expected AttributeError: Cannot mutate attribute `mask` outside the custom constructor of ExtensionType.
try:
  mt.mask[0] = False
except TypeError as e:
  print(f"Got expected TypeError: {e}")
Got expected TypeError: 'tensorflow.python.framework.ops.EagerTensor' object does not support item assignment
try:
  del mt.mask
except AttributeError as e:
  print(f"Got expected AttributeError: {e}")
Got expected AttributeError: Cannot mutate attribute `mask` outside the custom constructor of ExtensionType.

Especificação de tipo aninhado

Cada classe ExtensionType tem uma classe TypeSpec correspondente, que é criada automaticamente e armazenada como <extension_type_name>.Spec .

Essa classe captura todas as informações de um valor, exceto os valores de quaisquer tensores aninhados. Em particular, o TypeSpec para um valor é criado substituindo qualquer Tensor, ExtensionType ou CompositeTensor aninhado por seu TypeSpec .

class Player(tf.experimental.ExtensionType):
  name: tf.Tensor
  attributes: Mapping[str, tf.Tensor]

anne = Player("Anne", {"height": 8.3, "speed": 28.1})
anne_spec = tf.type_spec_from_value(anne)
print(anne_spec.name)  # Records dtype and shape, but not the string value.
print(anne_spec.attributes)  # Records keys and TensorSpecs for values.
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class 'tensorflow.python.framework.immutable_dict.ImmutableDict'>
TensorSpec(shape=(), dtype=tf.string, name=None)
ImmutableDict({'height': TensorSpec(shape=(), dtype=tf.float32, name=None), 'speed': TensorSpec(shape=(), dtype=tf.float32, name=None)})

Os valores TypeSpec podem ser construídos explicitamente ou podem ser construídos a partir de um valor ExtensionType usando tf.type_spec_from_value :

spec1 = Player.Spec(name=tf.TensorSpec([], tf.float32), attributes={})
spec2 = tf.type_spec_from_value(anne)

TypeSpec s são usados ​​pelo TensorFlow para dividir valores em um componente estático e um componente dinâmico :

  • O componente estático (que é fixado no tempo de construção do gráfico) é codificado com um tf.TypeSpec .
  • O componente dinâmico (que pode variar cada vez que o gráfico é executado) é codificado como uma lista de tf.Tensor s.

Por exemplo, tf.function refaz sua função encapsulada sempre que um argumento tem um TypeSpec não visto anteriormente:

@tf.function
def anonymize_player(player):
  print("<<TRACING>>")
  return Player("<anonymous>", player.attributes)
# Function gets traced (first time the function has been called):
anonymize_player(Player("Anne", {"height": 8.3, "speed": 28.1}))
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class 'tensorflow.python.framework.immutable_dict.ImmutableDict'>
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class 'tensorflow.python.framework.immutable_dict.ImmutableDict'>
<<TRACING>>
Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=8.3>, 'speed': <tf.Tensor: shape=(), dtype=float32, numpy=28.1>}))
# Function does NOT get traced (same TypeSpec: just tensor values changed)
anonymize_player(Player("Bart", {"height": 8.1, "speed": 25.3}))
Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=8.1>, 'speed': <tf.Tensor: shape=(), dtype=float32, numpy=25.3>}))
# Function gets traced (new TypeSpec: keys for attributes changed):
anonymize_player(Player("Chuck", {"height": 11.0, "jump": 5.3}))
<<TRACING>>
Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=11.0>, 'jump': <tf.Tensor: shape=(), dtype=float32, numpy=5.3>}))

Para obter mais informações, consulte o Guia tf.function .

Personalizando tipos de extensão

Além de simplesmente declarar campos e seus tipos, os tipos de extensão podem:

  • Substitua a representação imprimível padrão ( __repr__ ).
  • Defina métodos.
  • Defina métodos de classe e métodos estáticos.
  • Defina propriedades.
  • Substitua o construtor padrão ( __init__ ).
  • Substitua o operador de igualdade padrão ( __eq__ ).
  • Defina operadores (como __add__ e __lt__ ).
  • Declare valores padrão para campos.
  • Defina subclasses.

Substituindo a representação padrão para impressão

Você pode substituir esse operador de conversão de string padrão para tipos de extensão. O exemplo a seguir atualiza a classe MaskedTensor para gerar uma representação de cadeia de caracteres mais legível quando os valores são impressos no modo Eager.

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor       # shape=values.shape; false for invalid values.

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

def masked_tensor_str(values, mask):
  if isinstance(values, tf.Tensor):
    if hasattr(values, 'numpy') and hasattr(mask, 'numpy'):
      return f'<MaskedTensor {masked_tensor_str(values.numpy(), mask.numpy())}>'
    else:
      return f'MaskedTensor(values={values}, mask={mask})'
  if len(values.shape) == 1:
    items = [repr(v) if m else '_' for (v, m) in zip(values, mask)]
  else:
    items = [masked_tensor_str(v, m) for (v, m) in zip(values, mask)]
  return '[%s]' % ', '.join(items)

mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
                  mask=[[True, True, False], [True, False, True]])
print(mt)
<MaskedTensor [[1, 2, _], [4, _, 6]]>

Definindo métodos

Os tipos de extensão podem definir métodos, assim como qualquer classe Python normal. Por exemplo, o tipo MaskedTensor pode definir um método with_default que retorna uma cópia de self com valores mascarados substituídos por um determinado valor default . Os métodos podem opcionalmente ser anotados com o decorador @tf.function .

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def with_default(self, default):
    return tf.where(self.mask, self.values, default)

MaskedTensor([1, 2, 3], [True, False, True]).with_default(0)
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 0, 3], dtype=int32)>

Definindo métodos de classe e métodos estáticos

Os tipos de extensão podem definir métodos usando os decoradores @classmethod e @staticmethod . Por exemplo, o tipo MaskedTensor pode definir um método de fábrica que mascara qualquer elemento com um determinado valor:

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  @staticmethod
  def from_tensor_and_value_to_mask(values, value_to_mask):
    return MaskedTensor(values, values == value_to_mask)

x = tf.constant([[1, 0, 2], [3, 0, 0]])
MaskedTensor.from_tensor_and_value_to_mask(x, 0)
<MaskedTensor [[_, 0, _], [_, 0, 0]]>

Definindo propriedades

Os tipos de extensão podem definir propriedades usando o decorador @property , assim como qualquer classe normal do Python. Por exemplo, o tipo MaskedTensor pode definir uma propriedade dtype que é um atalho para o dtype dos valores:

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  @property
  def dtype(self):
    return self.values.dtype

MaskedTensor([1, 2, 3], [True, False, True]).dtype
tf.int32

Substituindo o construtor padrão

Você pode substituir o construtor padrão para tipos de extensão. Construtores personalizados devem definir um valor para cada campo declarado; e depois que o construtor personalizado retornar, todos os campos serão verificados por tipo e os valores serão convertidos conforme descrito acima.

class Toy(tf.experimental.ExtensionType):
  name: str
  price: tf.Tensor
  def __init__(self, name, price, discount=0):
    self.name = name
    self.price = price * (1 - discount)

print(Toy("ball", 5.0, discount=0.2))  # On sale -- 20% off!
Toy(name='ball', price=<tf.Tensor: shape=(), dtype=float32, numpy=4.0>)

Como alternativa, você pode considerar deixar o construtor padrão como está, mas adicionar um ou mais métodos de fábrica. Por exemplo:

class Toy(tf.experimental.ExtensionType):
  name: str
  price: tf.Tensor

  @staticmethod
  def new_toy_with_discount(name, price, discount):
    return Toy(name, price * (1 - discount))

print(Toy.new_toy_with_discount("ball", 5.0, discount=0.2))
Toy(name='ball', price=<tf.Tensor: shape=(), dtype=float32, numpy=4.0>)

Substituindo o operador de igualdade padrão ( __eq__ )

Você pode substituir o operador __eq__ padrão para tipos de extensão. O exemplo a seguir atualiza MaskedTensor para ignorar elementos mascarados ao comparar a igualdade.

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  def __eq__(self, other):
    result = tf.math.equal(self.values, other.values)
    result = result | ~(self.mask & other.mask)
    return tf.reduce_all(result)

x = MaskedTensor([1, 2, 3, 4], [True, True, False, True])
y = MaskedTensor([5, 2, 0, 4], [False, True, False, True])
print(x == y)
tf.Tensor(True, shape=(), dtype=bool)

Usando referências diretas

Se o tipo de um campo ainda não foi definido, você pode usar uma string contendo o nome do tipo. No exemplo a seguir, a string "Node" é usada para anotar o campo children porque o tipo Node ainda não foi (totalmente) definido.

class Node(tf.experimental.ExtensionType):
  value: tf.Tensor
  children: Tuple["Node", ...] = ()

Node(3, [Node(5), Node(2)])
Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=3>, children=(Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=5>, children=()), Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=2>, children=())))

Definindo subclasses

Os tipos de extensão podem ser subclassificados usando a sintaxe padrão do Python. As subclasses do tipo de extensão podem adicionar novos campos, métodos e propriedades; e pode substituir o construtor, a representação imprimível e o operador de igualdade. O exemplo a seguir define uma classe TensorGraph básica que usa três campos Tensor para codificar um conjunto de arestas entre nós. Em seguida, ele define uma subclasse que adiciona um campo Tensor para registrar um "valor de recurso" para cada nó. A subclasse também define um método para propagar os valores dos recursos ao longo das arestas.

class TensorGraph(tf.experimental.ExtensionType):
  num_nodes: tf.Tensor
  edge_src: tf.Tensor   # edge_src[e] = index of src node for edge e.
  edge_dst: tf.Tensor   # edge_dst[e] = index of dst node for edge e.

class TensorGraphWithNodeFeature(TensorGraph):
  node_features: tf.Tensor  # node_features[n] = feature value for node n.

  def propagate_features(self, weight=1.0) -> 'TensorGraphWithNodeFeature':
    updates = tf.gather(self.node_features, self.edge_src) * weight
    new_node_features = tf.tensor_scatter_nd_add(
        self.node_features, tf.expand_dims(self.edge_dst, 1), updates)
    return TensorGraphWithNodeFeature(
        self.num_nodes, self.edge_src, self.edge_dst, new_node_features)

g = TensorGraphWithNodeFeature(  # Edges: 0->1, 4->3, 2->2, 2->1
    num_nodes=5, edge_src=[0, 4, 2, 2], edge_dst=[1, 3, 2, 1],
    node_features=[10.0, 0.0, 2.0, 5.0, -1.0, 0.0])

print("Original features:", g.node_features)
print("After propagating:", g.propagate_features().node_features)
Original features: tf.Tensor([10.  0.  2.  5. -1.  0.], shape=(6,), dtype=float32)
After propagating: tf.Tensor([10. 12.  4.  4. -1.  0.], shape=(6,), dtype=float32)

Definindo campos privados

Os campos de um tipo de extensão podem ser marcados como privados prefixando-os com um sublinhado (seguindo as convenções padrão do Python). Isso não afeta a maneira como o TensorFlow trata os campos de forma alguma; mas simplesmente serve como um sinal para qualquer usuário do tipo de extensão de que esses campos são privados.

Personalizando o TypeSpec do ExtensionType

Cada classe ExtensionType tem uma classe TypeSpec correspondente, que é criada automaticamente e armazenada como <extension_type_name>.Spec . Para obter mais informações, consulte a seção "Nested TypeSpec" acima.

Para personalizar o TypeSpec , simplesmente defina sua própria classe aninhada chamada Spec , e ExtensionType usará isso como base para o TypeSpec construído automaticamente. Você pode personalizar a classe Spec por:

  • Substituindo a representação imprimível padrão.
  • Substituindo o construtor padrão.
  • Definindo métodos, métodos de classe, métodos estáticos e propriedades.

O exemplo a seguir personaliza a classe MaskedTensor.Spec para facilitar o uso:

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  shape = property(lambda self: self.values.shape)
  dtype = property(lambda self: self.values.dtype)

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  def with_values(self, new_values):
    return MaskedTensor(new_values, self.mask)

  class Spec:
    def __init__(self, shape, dtype=tf.float32):
      self.values = tf.TensorSpec(shape, dtype)
      self.mask = tf.TensorSpec(shape, tf.bool)

    def __repr__(self):
      return f"MaskedTensor.Spec(shape={self.shape}, dtype={self.dtype})"

    shape = property(lambda self: self.values.shape)
    dtype = property(lambda self: self.values.dtype)

Despacho da API do tensor

Os tipos de extensão podem ser "tipo tensor", no sentido de que se especializam ou estendem a interface definida pelo tipo tf.Tensor . Exemplos de tipos de extensão do tipo tensor incluem RaggedTensor , SparseTensor e MaskedTensor . Decoradores de despacho podem ser usados ​​para substituir o comportamento padrão das operações do TensorFlow quando aplicados a tipos de extensão do tipo tensor. Atualmente, o TensorFlow define três decoradores de despacho:

Despacho para uma única API

O decorador tf.experimental.dispatch_for_api substitui o comportamento padrão de uma operação TensorFlow especificada quando é chamada com a assinatura especificada. Por exemplo, você pode usar este decorador para especificar como o tf.stack deve processar os valores do MaskedTensor :

@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack(values: List[MaskedTensor], axis = 0):
  return MaskedTensor(tf.stack([v.values for v in values], axis),
                      tf.stack([v.mask for v in values], axis))

Isso substitui a implementação padrão para tf.stack sempre que é chamado com uma lista de valores MaskedTensor (já que o argumento values é anotado com typing.List[MaskedTensor] ):

x = MaskedTensor([1, 2, 3], [True, True, False])
y = MaskedTensor([4, 5, 6], [False, True, True])
tf.stack([x, y])
<MaskedTensor [[1, 2, _], [_, 5, 6]]>

Para permitir que o tf.stack com listas de valores mistos de MaskedTensor e Tensor , você pode refinar a anotação de tipo para o parâmetro de values e atualizar o corpo da função adequadamente:

tf.experimental.unregister_dispatch_for(masked_stack)

def convert_to_masked_tensor(x):
  if isinstance(x, MaskedTensor):
    return x
  else:
    return MaskedTensor(x, tf.ones_like(x, tf.bool))

@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack_v2(values: List[Union[MaskedTensor, tf.Tensor]], axis = 0):
  values = [convert_to_masked_tensor(v) for v in values]
  return MaskedTensor(tf.stack([v.values for v in values], axis),
                      tf.stack([v.mask for v in values], axis))
x = MaskedTensor([1, 2, 3], [True, True, False])
y = tf.constant([4, 5, 6])
tf.stack([x, y, x])
<MaskedTensor [[1, 2, _], [4, 5, 6], [1, 2, _]]>

Para obter uma lista de APIs que podem ser substituídas, consulte a documentação da API para tf.experimental.dispatch_for_api .

Despacho para todas as APIs elementwise unárias

O decorador tf.experimental.dispatch_for_unary_elementwise_apis substitui o comportamento padrão de todas as operações elementares unárias (como tf.math.cos ) sempre que o valor do primeiro argumento (normalmente denominado x ) corresponder à anotação de tipo x_type . A função decorada deve receber dois argumentos:

  • api_func : Uma função que recebe um único parâmetro e executa a operação elementwise (por exemplo, tf.abs ).
  • x : O primeiro argumento para a operação elementwise.

O exemplo a seguir atualiza todas as operações elementares unárias para lidar com o tipo MaskedTensor :

@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
 def masked_tensor_unary_elementwise_api_handler(api_func, x):
   return MaskedTensor(api_func(x.values), x.mask)

Esta função agora será usada sempre que uma operação elementar unária for chamada em um MaskedTensor .

x = MaskedTensor([1, -2, -3], [True, False, True])
 print(tf.abs(x))
<MaskedTensor [1, _, 3]>
print(tf.ones_like(x, dtype=tf.float32))
<MaskedTensor [1.0, _, 1.0]>

Despacho para todas as APIs elementwise binárias

Da mesma forma, tf.experimental.dispatch_for_binary_elementwise_apis pode ser usado para atualizar todas as operações elementares binárias para lidar com o tipo MaskedTensor :

@tf.experimental.dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
def masked_tensor_binary_elementwise_api_handler(api_func, x, y):
  return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)
x = MaskedTensor([1, -2, -3], [True, False, True])
y = MaskedTensor([[4], [5]], [[True], [False]])
tf.math.add(x, y)
<MaskedTensor [[5, _, 1], [_, _, _]]>

Para obter uma lista das APIs elementwise que foram substituídas, consulte a documentação da API para tf.experimental.dispatch_for_unary_elementwise_apis e tf.experimental.dispatch_for_binary_elementwise_apis .

Tipos de extensão em lote

Um ExtensionType é em lote se uma única instância puder ser usada para representar um lote de valores. Normalmente, isso é feito adicionando dimensões de lote a todos os Tensor aninhados. As seguintes APIs do TensorFlow exigem que todas as entradas de tipo de extensão sejam em lote:

Por padrão, BatchableExtensionType cria valores em lote agrupando quaisquer Tensor s, CompositeTensor s e ExtensionType s aninhados. Se isso não for apropriado para sua classe, você precisará usar tf.experimental.ExtensionTypeBatchEncoder para substituir esse comportamento padrão. Por exemplo, não seria apropriado criar um lote de valores tf.SparseTensor simplesmente empilhando values de tensores esparsos individuais , indices e campos de dense_shape -- na maioria dos casos, você não pode empilhar esses tensores, pois eles têm formas incompatíveis ; e mesmo que você pudesse, o resultado não seria um SparseTensor válido.

Exemplo de BatchableExtensionType: rede

Como exemplo, considere uma classe de Network simples usada para balanceamento de carga, que rastreia quanto trabalho resta a ser feito em cada nó e quanta largura de banda está disponível para mover o trabalho entre nós:

class Network(tf.experimental.ExtensionType):  # This version is not batchable.
  work: tf.Tensor       # work[n] = work left to do at node n
  bandwidth: tf.Tensor  # bandwidth[n1, n2] = bandwidth from n1->n2

net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])

Para tornar esse tipo em lote, altere o tipo base para BatchableExtensionType e ajuste a forma de cada campo para incluir dimensões de lote opcionais. O exemplo a seguir também adiciona um campo de shape para acompanhar a forma do lote. Este campo de shape não é requerido por tf.data.Dataset ou tf.map_fn , mas é requerido por tf.Keras .

class Network(tf.experimental.BatchableExtensionType):
  shape: tf.TensorShape  # batch shape.  A single network has shape=[].
  work: tf.Tensor        # work[*shape, n] = work left to do at node n
  bandwidth: tf.Tensor   # bandwidth[*shape, n1, n2] = bandwidth from n1->n2

  def __init__(self, work, bandwidth):
    self.work = tf.convert_to_tensor(work)
    self.bandwidth = tf.convert_to_tensor(bandwidth)
    work_batch_shape = self.work.shape[:-1]
    bandwidth_batch_shape = self.bandwidth.shape[:-2]
    self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)

  def __repr__(self):
    return network_repr(self)

def network_repr(network):
  work = network.work
  bandwidth = network.bandwidth
  if hasattr(work, 'numpy'):
    work = ' '.join(str(work.numpy()).split())
  if hasattr(bandwidth, 'numpy'):
    bandwidth = ' '.join(str(bandwidth.numpy()).split())
  return (f"<Network shape={network.shape} work={work} bandwidth={bandwidth}>")
net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])
batch_of_networks = Network(
    work=tf.stack([net1.work, net2.work]),
    bandwidth=tf.stack([net1.bandwidth, net2.bandwidth]))
print(f"net1={net1}")
print(f"net2={net2}")
print(f"batch={batch_of_networks}")
net1=<Network shape=() work=[5. 3. 8.] bandwidth=[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]]>
net2=<Network shape=() work=[3. 4. 2.] bandwidth=[[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]>
batch=<Network shape=(2,) work=[[5. 3. 8.] [3. 4. 2.]] bandwidth=[[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]] [[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]]>

Você pode usar tf.data.Dataset para iterar por meio de um lote de redes:

dataset = tf.data.Dataset.from_tensor_slices(batch_of_networks)
for i, network in enumerate(dataset):
  print(f"Batch element {i}: {network}")
Batch element 0: <Network shape=() work=[5. 3. 8.] bandwidth=[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]]>
Batch element 1: <Network shape=() work=[3. 4. 2.] bandwidth=[[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]>

E você também pode usar map_fn para aplicar uma função a cada elemento de lote:

def balance_work_greedy(network):
  delta = (tf.expand_dims(network.work, -1) - tf.expand_dims(network.work, -2))
  delta /= 4
  delta = tf.maximum(tf.minimum(delta, network.bandwidth), -network.bandwidth)
  new_work = network.work + tf.reduce_sum(delta, -1)
  return Network(new_work, network.bandwidth)

tf.map_fn(balance_work_greedy, batch_of_networks)
<Network shape=(2,) work=[[5.5 1.25 9.25] [3. 4.75 1.25]] bandwidth=[[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]] [[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]]>

APIs do TensorFlow compatíveis com ExtensionTypes

@tf.função

tf.function é um decorador que pré-computa gráficos do TensorFlow para funções do Python, o que pode melhorar substancialmente o desempenho do seu código do TensorFlow. Os valores de tipo de extensão podem ser usados ​​de forma transparente com funções @tf.function -decorated.

class Pastry(tf.experimental.ExtensionType):
  sweetness: tf.Tensor  # 2d embedding that encodes sweetness
  chewiness: tf.Tensor  # 2d embedding that encodes chewiness

@tf.function
def combine_pastry_features(x: Pastry):
  return (x.sweetness + x.chewiness) / 2

cookie = Pastry(sweetness=[1.2, 0.4], chewiness=[0.8, 0.2])
combine_pastry_features(cookie)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1. , 0.3], dtype=float32)>

Se você deseja especificar explicitamente o input_signature para tf.function , pode fazê-lo usando TypeSpec do tipo de extensão.

pastry_spec = Pastry.Spec(tf.TensorSpec([2]), tf.TensorSpec(2))

@tf.function(input_signature=[pastry_spec])
def increase_sweetness(x: Pastry, delta=1.0):
  return Pastry(x.sweetness + delta, x.chewiness)

increase_sweetness(cookie)
Pastry(sweetness=<tf.Tensor: shape=(2,), dtype=float32, numpy=array([2.2, 1.4], dtype=float32)>, chewiness=<tf.Tensor: shape=(2,), dtype=float32, numpy=array([0.8, 0.2], dtype=float32)>)

Funções concretas

As funções concretas encapsulam gráficos rastreados individuais que são construídos por tf.function . Os tipos de extensão podem ser usados ​​de forma transparente com funções concretas.

cf = combine_pastry_features.get_concrete_function(pastry_spec)
cf(cookie)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1. , 0.3], dtype=float32)>

Operações de fluxo de controle

Os tipos de extensão são compatíveis com as operações de fluxo de controle do TensorFlow:

# Example: using tf.cond to select between two MaskedTensors.  Note that the
# two MaskedTensors don't need to have the same shape.
a = MaskedTensor([1., 2, 3], [True, False, True])
b = MaskedTensor([22., 33, 108, 55], [True, True, True, False])
condition = tf.constant(True)
print(tf.cond(condition, lambda: a, lambda: b))
<MaskedTensor [1.0, _, 3.0]>
# Example: using tf.while_loop with MaskedTensor.
cond = lambda i, _: i < 10
def body(i, mt):
  return i + 1, mt.with_values(mt.values + 3 / 7)
print(tf.while_loop(cond, body, [0, b])[1])
<MaskedTensor [26.285717, 37.285698, 112.285736, _]>

Fluxo de controle de autógrafos

Os tipos de extensão também são suportados por instruções de fluxo de controle em tf.function (usando autógrafo). No exemplo a seguir, as instruções if e for são convertidas automaticamente em operações tf.cond e tf.while_loop , que suportam tipos de extensão.

@tf.function
def fn(x, b):
  if b:
    x = MaskedTensor(x, tf.less(x, 0))
  else:
    x = MaskedTensor(x, tf.greater(x, 0))
  for i in tf.range(5 if b else 7):
    x = x.with_values(x.values + 1 / 2)
  return x

print(fn(tf.constant([1., -2, 3]), tf.constant(True)))
print(fn(tf.constant([1., -2, 3]), tf.constant(False)))
<MaskedTensor [_, 0.5, _]>
<MaskedTensor [4.5, _, 6.5]>

Keras

tf.keras é a API de alto nível do TensorFlow para criar e treinar modelos de aprendizado profundo. Os tipos de extensão podem ser passados ​​como entradas para um modelo Keras, passados ​​entre camadas Keras e retornados por modelos Keras. Keras atualmente coloca dois requisitos nos tipos de extensão:

  • Eles devem ser em lote (consulte "Extensões em lote" acima).
  • O deve ter um campo ou propriedade denominado shape . shape[0] é considerado a dimensão do lote.

As duas subseções a seguir fornecem exemplos mostrando como os tipos de extensão podem ser usados ​​com o Keras.

Exemplo Keras: Network

Para o primeiro exemplo, considere a classe Network definida na seção "Batchable ExtensionTypes" acima, que pode ser usada para trabalho de balanceamento de carga entre nós. Sua definição é repetida aqui:

class Network(tf.experimental.BatchableExtensionType):
  shape: tf.TensorShape  # batch shape.  A single network has shape=[].
  work: tf.Tensor        # work[*shape, n] = work left to do at node n
  bandwidth: tf.Tensor   # bandwidth[*shape, n1, n2] = bandwidth from n1->n2

  def __init__(self, work, bandwidth):
    self.work = tf.convert_to_tensor(work)
    self.bandwidth = tf.convert_to_tensor(bandwidth)
    work_batch_shape = self.work.shape[:-1]
    bandwidth_batch_shape = self.bandwidth.shape[:-2]
    self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)

  def __repr__(self):
    return network_repr(self)
single_network = Network(  # A single network w/ 4 nodes.
    work=[8.0, 5, 12, 2],
    bandwidth=[[0.0, 1, 2, 2], [1, 0, 0, 2], [2, 0, 0, 1], [2, 2, 1, 0]])

batch_of_networks = Network(  # Batch of 2 networks, each w/ 2 nodes.
    work=[[8.0, 5], [3, 2]],
    bandwidth=[[[0.0, 1], [1, 0]], [[0, 2], [2, 0]]])

Você pode definir uma nova camada Keras que processa Network s.

class BalanceNetworkLayer(tf.keras.layers.Layer):
  """Layer that balances work between nodes in a network.

  Shifts work from more busy nodes to less busy nodes, constrained by bandwidth.
  """
  def call(self, inputs):
    # This function is defined above, in "Batchable ExtensionTypes" section.
    return balance_work_greedy(inputs)

Você pode então usar essas camadas para criar um modelo simples. Para alimentar um ExtensionType em um modelo, você pode usar uma camada tf.keras.layer.Input com type_spec definido como TypeSpec do tipo de extensão. Se o modelo Keras for usado para processar lotes, o type_spec deverá incluir a dimensão do lote.

input_spec = Network.Spec(shape=None,
                          work=tf.TensorSpec(None, tf.float32),
                          bandwidth=tf.TensorSpec(None, tf.float32))
model = tf.keras.Sequential([
    tf.keras.layers.Input(type_spec=input_spec),
    BalanceNetworkLayer(),
    ])

Por fim, você pode aplicar o modelo a uma única rede e a um lote de redes.

model(single_network)
<Network shape=() work=[ 9.25 5. 14. -1.25] bandwidth=[[0. 1. 2. 2.] [1. 0. 0. 2.] [2. 0. 0. 1.] [2. 2. 1. 0.]]>
model(batch_of_networks)
<Network shape=(2,) work=[[8.75 4.25] [3.25 1.75]] bandwidth=[[[0. 1.] [1. 0.]] [[0. 2.] [2. 0.]]]>

Exemplo Keras: MaskedTensor

Neste exemplo, MaskedTensor é estendido para dar suporte a Keras . shape é definido como uma propriedade que é calculada a partir do campo de values . Keras requer que você adicione essa propriedade ao tipo de extensão e seu TypeSpec . MaskedTensor também define uma variável __name__ , que será necessária para a serialização de SavedModel (abaixo).

class MaskedTensor(tf.experimental.BatchableExtensionType):
  # __name__ is required for serialization in SavedModel; see below for details.
  __name__ = 'extension_type_colab.MaskedTensor'

  values: tf.Tensor
  mask: tf.Tensor

  shape = property(lambda self: self.values.shape)
  dtype = property(lambda self: self.values.dtype)

  def with_default(self, default):
    return tf.where(self.mask, self.values, default)

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  class Spec:
    def __init__(self, shape, dtype=tf.float32):
      self.values = tf.TensorSpec(shape, dtype)
      self.mask = tf.TensorSpec(shape, tf.bool)

    shape = property(lambda self: self.values.shape)
    dtype = property(lambda self: self.values.dtype)

    def with_shape(self):
      return MaskedTensor.Spec(tf.TensorSpec(shape, self.values.dtype),
                               tf.TensorSpec(shape, self.mask.dtype))

Em seguida, os decoradores de expedição são usados ​​para substituir o comportamento padrão de várias APIs do TensorFlow. Como essas APIs são usadas por camadas Keras padrão (como a camada Dense ), substituí-las nos permitirá usar essas camadas com MaskedTensor . Para os propósitos deste exemplo, matmul para tensores mascarados é definido para tratar os valores mascarados como zeros (ou seja, para não incluí-los no produto).

@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def unary_elementwise_op_handler(op, x):
 return MaskedTensor(op(x.values), x.mask)

@tf.experimental.dispatch_for_binary_elementwise_apis(
    Union[MaskedTensor, tf.Tensor],
    Union[MaskedTensor, tf.Tensor])
def binary_elementwise_op_handler(op, x, y):
  x = convert_to_masked_tensor(x)
  y = convert_to_masked_tensor(y)
  return MaskedTensor(op(x.values, y.values), x.mask & y.mask)

@tf.experimental.dispatch_for_api(tf.matmul)
def masked_matmul(a: MaskedTensor, b,
                  transpose_a=False, transpose_b=False,
                  adjoint_a=False, adjoint_b=False,
                  a_is_sparse=False, b_is_sparse=False,
                  output_type=None):
  if isinstance(a, MaskedTensor):
    a = a.with_default(0)
  if isinstance(b, MaskedTensor):
    b = b.with_default(0)
  return tf.matmul(a, b, transpose_a, transpose_b, adjoint_a,
                   adjoint_b, a_is_sparse, b_is_sparse, output_type)

Você pode então construir um modelo Keras que aceite entradas MaskedTensor , usando camadas Keras padrão:

input_spec = MaskedTensor.Spec([None, 2], tf.float32)

masked_tensor_model = tf.keras.Sequential([
    tf.keras.layers.Input(type_spec=input_spec),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1)])
masked_tensor_model.compile(loss='binary_crossentropy', optimizer='rmsprop')
a = MaskedTensor([[1., 2], [3, 4], [5, 6]],
                  [[True, False], [False, True], [True, True]])
masked_tensor_model.fit(a, tf.constant([[1], [0], [1]]), epochs=3)
print(masked_tensor_model(a))
Epoch 1/3
1/1 [==============================] - 1s 955ms/step - loss: 10.2833
Epoch 2/3
1/1 [==============================] - 0s 5ms/step - loss: 10.2833
Epoch 3/3
1/1 [==============================] - 0s 5ms/step - loss: 10.2833
tf.Tensor(
[[-0.09944128]
 [-0.7225147 ]
 [-1.3020657 ]], shape=(3, 1), dtype=float32)

Modelo salvo

Um SavedModel é um programa TensorFlow serializado, incluindo pesos e computação. Ele pode ser construído a partir de um modelo Keras ou de um modelo personalizado. Em ambos os casos, os tipos de extensão podem ser usados ​​de forma transparente com as funções e métodos definidos por um SavedModel.

SavedModel pode salvar modelos, camadas e funções que processam tipos de extensão, desde que os tipos de extensão tenham um campo __name__ . Este nome é usado para registrar o tipo de extensão, para que possa ser localizado quando o modelo for carregado.

Exemplo: salvar um modelo Keras

Os modelos Keras que usam tipos de extensão podem ser salvos usando SavedModel .

masked_tensor_model_path = tempfile.mkdtemp()
tf.saved_model.save(masked_tensor_model, masked_tensor_model_path)
imported_model = tf.saved_model.load(masked_tensor_model_path)
imported_model(a)
2021-11-06 01:25:14.285250: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
WARNING:absl:Function `_wrapped_model` contains input name(s) args_0 with unsupported characters which will be renamed to args_0_1 in the SavedModel.
INFO:tensorflow:Assets written to: /tmp/tmp3ceuupv9/assets
INFO:tensorflow:Assets written to: /tmp/tmp3ceuupv9/assets
<tf.Tensor: shape=(3, 1), dtype=float32, numpy=
array([[-0.09944128],
       [-0.7225147 ],
       [-1.3020657 ]], dtype=float32)>

Exemplo: salvar um modelo personalizado

SavedModel também pode ser usado para salvar subclasses tf.Module personalizadas com funções que processam tipos de extensão.

class CustomModule(tf.Module):
  def __init__(self, variable_value):
    super().__init__()
    self.v = tf.Variable(variable_value)

  @tf.function
  def grow(self, x: MaskedTensor):
    """Increase values in `x` by multiplying them by `self.v`."""
    return MaskedTensor(x.values * self.v, x.mask)

module = CustomModule(100.0)

module.grow.get_concrete_function(MaskedTensor.Spec(shape=None,
                                                    dtype=tf.float32))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
imported_model.grow(MaskedTensor([1., 2, 3], [False, True, False]))
INFO:tensorflow:Assets written to: /tmp/tmp2x8zq5kb/assets
INFO:tensorflow:Assets written to: /tmp/tmp2x8zq5kb/assets
<MaskedTensor [_, 200.0, _]>

Carregando um SavedModel quando o ExtensionType não está disponível

Se você carregar um SavedModel que usa um ExtensionType , mas esse ExtensionType não está disponível (ou seja, não foi importado), você verá um aviso e o TensorFlow voltará a usar um objeto "tipo de extensão anônimo". Esse objeto terá os mesmos campos que o tipo original, mas não terá qualquer personalização adicional que você tenha adicionado ao tipo, como métodos ou propriedades personalizados.

Como usar ExtensionTypes com veiculação do TensorFlow

Atualmente, a veiculação do TensorFlow (e outros consumidores do dicionário de "assinaturas" SavedModel) exige que todas as entradas e saídas sejam tensores brutos. Se você deseja usar o TensorFlow servindo com um modelo que usa tipos de extensão, você pode adicionar métodos wrapper que compõem ou decompõem valores de tipo de extensão de tensores. Por exemplo:

class CustomModuleWrapper(tf.Module):
  def __init__(self, variable_value):
    super().__init__()
    self.v = tf.Variable(variable_value)

  @tf.function
  def var_weighted_mean(self, x: MaskedTensor):
    """Mean value of unmasked values in x, weighted by self.v."""
    x = MaskedTensor(x.values * self.v, x.mask)
    return (tf.reduce_sum(x.with_default(0)) /
            tf.reduce_sum(tf.cast(x.mask, x.dtype)))

  @tf.function()
  def var_weighted_mean_wrapper(self, x_values, x_mask):
    """Raw tensor wrapper for var_weighted_mean."""
    return self.var_weighted_mean(MaskedTensor(x_values, x_mask))

module = CustomModuleWrapper([3., 2., 8., 5.])

module.var_weighted_mean_wrapper.get_concrete_function(
    tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.bool))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
x = MaskedTensor([1., 2., 3., 4.], [False, True, False, True])
imported_model.var_weighted_mean_wrapper(x.values, x.mask)
INFO:tensorflow:Assets written to: /tmp/tmpxhh4zh0i/assets
INFO:tensorflow:Assets written to: /tmp/tmpxhh4zh0i/assets
<tf.Tensor: shape=(), dtype=float32, numpy=12.0>

Conjuntos de dados

tf.data é uma API que permite criar pipelines de entrada complexos a partir de peças simples e reutilizáveis. Sua estrutura de dados central é tf.data.Dataset , que representa uma sequência de elementos, na qual cada elemento consiste em um ou mais componentes.

Como criar conjuntos de dados com tipos de extensão

Os conjuntos de dados podem ser criados a partir de valores de tipo de extensão usando Dataset.from_tensors , Dataset.from_tensor_slices ou Dataset.from_generator :

ds = tf.data.Dataset.from_tensors(Pastry(5, 5))
iter(ds).next()
Pastry(sweetness=<tf.Tensor: shape=(), dtype=int32, numpy=5>, chewiness=<tf.Tensor: shape=(), dtype=int32, numpy=5>)
mt = MaskedTensor(tf.reshape(range(20), [5, 4]), tf.ones([5, 4]))
ds = tf.data.Dataset.from_tensor_slices(mt)
for value in ds:
  print(value)
<MaskedTensor [0, 1, 2, 3]>
<MaskedTensor [4, 5, 6, 7]>
<MaskedTensor [8, 9, 10, 11]>
<MaskedTensor [12, 13, 14, 15]>
<MaskedTensor [16, 17, 18, 19]>
def value_gen():
  for i in range(2, 7):
    yield MaskedTensor(range(10), [j%i != 0 for j in range(10)])

ds = tf.data.Dataset.from_generator(
    value_gen, output_signature=MaskedTensor.Spec(shape=[10], dtype=tf.int32))
for value in ds:
  print(value)
<MaskedTensor [_, 1, _, 3, _, 5, _, 7, _, 9]>
<MaskedTensor [_, 1, 2, _, 4, 5, _, 7, 8, _]>
<MaskedTensor [_, 1, 2, 3, _, 5, 6, 7, _, 9]>
<MaskedTensor [_, 1, 2, 3, 4, _, 6, 7, 8, 9]>
<MaskedTensor [_, 1, 2, 3, 4, 5, _, 7, 8, 9]>

Lote e unbatch de conjuntos de dados com tipos de extensão

Conjuntos de dados com tipos de extensão podem ser batche unbatch usando Dataset.batch e Dataset.unbatch .

batched_ds = ds.batch(2)
for value in batched_ds:
  print(value)
<MaskedTensor [[_, 1, _, 3, _, 5, _, 7, _, 9], [_, 1, 2, _, 4, 5, _, 7, 8, _]]>
<MaskedTensor [[_, 1, 2, 3, _, 5, 6, 7, _, 9], [_, 1, 2, 3, 4, _, 6, 7, 8, 9]]>
<MaskedTensor [[_, 1, 2, 3, 4, 5, _, 7, 8, 9]]>
unbatched_ds = batched_ds.unbatch()
for value in unbatched_ds:
  print(value)
<MaskedTensor [_, 1, _, 3, _, 5, _, 7, _, 9]>
<MaskedTensor [_, 1, 2, _, 4, 5, _, 7, 8, _]>
<MaskedTensor [_, 1, 2, 3, _, 5, 6, 7, _, 9]>
<MaskedTensor [_, 1, 2, 3, 4, _, 6, 7, 8, 9]>
<MaskedTensor [_, 1, 2, 3, 4, 5, _, 7, 8, 9]>