Jenis ekstensi

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHubUnduh buku catatan

Mempersiapkan

!pip install -q tf_nightly
import tensorflow as tf
import numpy as np
from typing import Tuple, List, Mapping, Union, Optional
import tempfile

Jenis ekstensi

Tipe yang ditentukan pengguna dapat membuat proyek lebih mudah dibaca, modular, dan dapat dipelihara. Namun, sebagian besar API TensorFlow memiliki dukungan yang sangat terbatas untuk jenis Python yang ditentukan pengguna. Ini mencakup API tingkat tinggi (seperti Keras , tf.function , tf.SavedModel ) dan API tingkat rendah (seperti tf.while_loop dan tf.concat ). Jenis ekstensi TensorFlow dapat digunakan untuk membuat jenis berorientasi objek yang ditentukan pengguna yang bekerja secara lancar dengan API TensorFlow. Untuk membuat tipe ekstensi, cukup definisikan kelas Python dengan tf.experimental.ExtensionType sebagai dasarnya, dan gunakan anotasi tipe untuk menentukan tipe untuk setiap bidang.

class TensorGraph(tf.experimental.ExtensionType):
  """A collection of labeled nodes connected by weighted edges."""
  edge_weights: tf.Tensor               # shape=[num_nodes, num_nodes]
  node_labels: Mapping[str, tf.Tensor]  # shape=[num_nodes]; dtype=any

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor       # shape=values.shape; false for missing/invalid values.

class CSRSparseMatrix(tf.experimental.ExtensionType):
  """Compressed sparse row matrix (https://en.wikipedia.org/wiki/Sparse_matrix)."""
  values: tf.Tensor     # shape=[num_nonzero]; dtype=any
  col_index: tf.Tensor  # shape=[num_nonzero]; dtype=int64
  row_index: tf.Tensor  # shape=[num_rows+1]; dtype=int64

Kelas dasar tf.experimental.ExtensionType bekerja mirip dengan typing.NamedTuple dan @dataclasses.dataclass dari pustaka Python standar. Secara khusus, secara otomatis menambahkan konstruktor dan metode khusus (seperti __repr__ dan __eq__ ) berdasarkan anotasi jenis bidang.

Biasanya, jenis ekstensi cenderung jatuh ke dalam salah satu dari dua kategori:

  • Struktur data , yang mengelompokkan kumpulan nilai terkait, dan dapat menyediakan operasi yang berguna berdasarkan nilai tersebut. Struktur data mungkin cukup umum (seperti contoh TensorGraph di atas); atau mereka mungkin sangat disesuaikan dengan model tertentu.

  • Jenis seperti tensor , yang mengkhususkan atau memperluas konsep "Tensor". Tipe dalam kategori ini memiliki rank , shape , dan biasanya tipe dtype ; dan masuk akal untuk menggunakannya dengan operasi Tensor (seperti tf.stack , tf.add , atau tf.matmul ). MaskedTensor dan CSRSparseMatrix adalah contoh tipe seperti tensor.

API yang didukung

Jenis ekstensi didukung oleh TensorFlow API berikut:

  • Keras : Jenis ekstensi dapat digunakan sebagai input dan output untuk Models dan Layers Keras.
  • tf.data.Dataset : Jenis ekstensi dapat dimasukkan dalam Datasets , dan dikembalikan oleh Dataset Iterators .
  • Tensorflow hub : Jenis ekstensi dapat digunakan sebagai input dan output untuk modul tf.hub .
  • SavedModel : Jenis ekstensi dapat digunakan sebagai input dan output untuk fungsi SavedModel .
  • tf.function : Jenis ekstensi dapat digunakan sebagai argumen dan mengembalikan nilai untuk fungsi yang dibungkus dengan dekorator @tf.function .
  • while loop : Jenis ekstensi dapat digunakan sebagai variabel loop di tf.while_loop , dan dapat digunakan sebagai argumen dan mengembalikan nilai untuk badan while-loop.
  • conditional : Jenis ekstensi dapat dipilih secara kondisional menggunakan tf.cond dan tf.case .
  • py_function : Jenis ekstensi dapat digunakan sebagai argumen dan mengembalikan nilai argumen func ke tf.py_function .
  • Operasi Tensor : Jenis ekstensi dapat diperluas untuk mendukung sebagian besar operasi TensorFlow yang menerima input Tensor (misalnya, tf.matmul , tf.gather , dan tf.reduce_sum ). Lihat bagian " Pengiriman " di bawah ini untuk informasi lebih lanjut.
  • strategi distribusi : Jenis ekstensi dapat digunakan sebagai nilai per replika.

Untuk detail selengkapnya, lihat bagian "TensorFlow API yang mendukung ExtensionTypes" di bawah.

Persyaratan

Jenis bidang

Semua bidang (alias variabel instan) harus dideklarasikan, dan anotasi jenis harus disediakan untuk setiap bidang. Jenis anotasi berikut didukung:

Jenis Contoh
bilangan bulat python i: int
Python mengapung f: float
String python s: str
Boolean Python b: bool
Python Tidak Ada n: None
Bentuk tensor shape: tf.TensorShape
Tipe d tensor dtype: tf.DType
Tensor t: tf.Tensor
Jenis ekstensi mt: MyMaskedTensor
Tensor Ragged rt: tf.RaggedTensor
Tensor Jarang st: tf.SparseTensor
Irisan Terindeks s: tf.IndexedSlices
Tensor Opsional o: tf.experimental.Optional
Ketik serikat pekerja int_or_float: typing.Union[int, float]
Tuple params: typing.Tuple[int, float, tf.Tensor, int]
Tupel panjang-var lengths: typing.Tuple[int, ...]
Pemetaan tags: typing.Mapping[str, tf.Tensor]
Nilai opsional weight: typing.Optional[tf.Tensor]

Mutabilitas

Jenis ekstensi harus tidak dapat diubah. Ini memastikan bahwa mereka dapat dilacak dengan benar oleh mekanisme pelacakan grafik TensorFlow. Jika Anda ingin mengubah nilai tipe ekstensi, pertimbangkan untuk mendefinisikan metode yang mengubah nilai. Misalnya, daripada mendefinisikan metode set_mask untuk mengubah MaskedTensor , Anda bisa mendefinisikan metode replace_mask yang mengembalikan MaskedTensor baru :

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def replace_mask(self, new_mask):
      self.values.shape.assert_is_compatible_with(new_mask.shape)
      return MaskedTensor(self.values, new_mask)

Fungsionalitas ditambahkan oleh ExtensionType

Kelas dasar ExtensionType menyediakan fungsionalitas berikut:

  • Sebuah konstruktor ( __init__ ).
  • Metode representasi yang dapat dicetak ( __repr__ ).
  • Operator persamaan dan ketidaksamaan ( __eq__ ).
  • Metode validasi ( __validate__ ).
  • Kekekalan yang dipaksakan.
  • TypeSpec bersarang.
  • Dukungan pengiriman Tensor API.

Lihat bagian "Menyesuaikan Jenis Ekstensi" di bawah untuk informasi selengkapnya tentang menyesuaikan fungsi ini.

Konstruktor

Konstruktor yang ditambahkan oleh ExtensionType mengambil setiap bidang sebagai argumen bernama (dalam urutan mereka terdaftar dalam definisi kelas). Konstruktor ini akan mengetik-memeriksa setiap parameter, dan mengonversinya jika perlu. Secara khusus, bidang Tensor dikonversi menggunakan tf.convert_to_tensor ; Bidang Tuple dikonversi ke tuple s; dan bidang Mapping dikonversi menjadi dikte yang tidak dapat diubah.

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

# Constructor takes one parameter for each field.
mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
                  mask=[[True, True, False], [True, False, True]])

# Fields are type-checked and converted to the declared types.
# E.g., mt.values is converted to a Tensor.
print(mt.values)
tf.Tensor(
[[1 2 3]
 [4 5 6]], shape=(2, 3), dtype=int32)

Konstruktor memunculkan TypeError jika nilai bidang tidak dapat dikonversi ke tipe yang dideklarasikan:

try:
  MaskedTensor([1, 2, 3], None)
except TypeError as e:
  print(f"Got expected TypeError: {e}")
Got expected TypeError: mask: expected a Tensor, got None

Nilai default untuk bidang dapat ditentukan dengan menetapkan nilainya di tingkat kelas:

class Pencil(tf.experimental.ExtensionType):
  color: str = "black"
  has_erasor: bool = True
  length: tf.Tensor = 1.0

Pencil()
Pencil(color='black', has_erasor=True, length=<tf.Tensor: shape=(), dtype=float32, numpy=1.0>)
Pencil(length=0.5, color="blue")
Pencil(color='blue', has_erasor=True, length=<tf.Tensor: shape=(), dtype=float32, numpy=0.5>)

Representasi yang dapat dicetak

ExtensionType menambahkan metode representasi cetak default ( __repr__ ) yang menyertakan nama kelas dan nilai untuk setiap bidang:

print(MaskedTensor(values=[1, 2, 3], mask=[True, True, False]))
MaskedTensor(values=<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>, mask=<tf.Tensor: shape=(3,), dtype=bool, numpy=array([ True,  True, False])>)

Operator kesetaraan

ExtensionType menambahkan operator kesetaraan default ( __eq__ dan __ne__ ) yang menganggap dua nilai sama jika mereka memiliki tipe yang sama dan semua bidangnya sama. Medan tensor dianggap sama jika memiliki bentuk yang sama dan elemen yang sama untuk semua elemen.

a = MaskedTensor([1, 2], [True, False])
b = MaskedTensor([[3, 4], [5, 6]], [[False, True], [True, True]])
print(f"a == a: {a==a}")
print(f"a == b: {a==b}")
print(f"a == a.values: {a==a.values}")
a == a: True
a == b: False
a == a.values: False

Metode validasi

ExtensionType menambahkan metode __validate__ , yang dapat diganti untuk melakukan pemeriksaan validasi pada bidang. Itu dijalankan setelah konstruktor dipanggil, dan setelah bidang diperiksa tipenya dan dikonversi ke tipe yang dideklarasikan, sehingga dapat diasumsikan bahwa semua bidang memiliki tipe yang dideklarasikan.

contoh berikut memperbarui MaskedTensor untuk memvalidasi shape s dan dtype s bidangnya:

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor
  def __validate__(self):
    self.values.shape.assert_is_compatible_with(self.mask.shape)
    assert self.mask.dtype.is_bool, 'mask.dtype must be bool'
try:
  MaskedTensor([1, 2, 3], [0, 1, 0])  # wrong dtype for mask.
except AssertionError as e:
  print(f"Got expected AssertionError: {e}")
Got expected AssertionError: mask.dtype must be bool
try:
  MaskedTensor([1, 2, 3], [True, False])  # shapes don't match.
except ValueError as e:
  print(f"Got expected ValueError: {e}")
Got expected ValueError: Shapes (3,) and (2,) are incompatible

Kekekalan yang dipaksakan

ExtensionType menimpa metode __setattr__ dan __delattr__ untuk mencegah mutasi, memastikan bahwa nilai tipe ekstensi tidak dapat diubah.

mt = MaskedTensor([1, 2, 3], [True, False, True])
try:
  mt.mask = [True, True, True]
except AttributeError as e:
  print(f"Got expected AttributeError: {e}")
Got expected AttributeError: Cannot mutate attribute `mask` outside the custom constructor of ExtensionType.
try:
  mt.mask[0] = False
except TypeError as e:
  print(f"Got expected TypeError: {e}")
Got expected TypeError: 'tensorflow.python.framework.ops.EagerTensor' object does not support item assignment
try:
  del mt.mask
except AttributeError as e:
  print(f"Got expected AttributeError: {e}")
Got expected AttributeError: Cannot mutate attribute `mask` outside the custom constructor of ExtensionType.

Tipe BersarangSpec

Setiap kelas ExtensionType memiliki kelas TypeSpec yang sesuai, yang dibuat secara otomatis dan disimpan sebagai <extension_type_name>.Spec .

Kelas ini menangkap semua informasi dari suatu nilai kecuali untuk nilai tensor bersarang apa pun. Secara khusus, TypeSpec untuk suatu nilai dibuat dengan mengganti Tensor, ExtensionType, atau CompositeTensor bersarang dengan TypeSpec -nya.

class Player(tf.experimental.ExtensionType):
  name: tf.Tensor
  attributes: Mapping[str, tf.Tensor]

anne = Player("Anne", {"height": 8.3, "speed": 28.1})
anne_spec = tf.type_spec_from_value(anne)
print(anne_spec.name)  # Records dtype and shape, but not the string value.
print(anne_spec.attributes)  # Records keys and TensorSpecs for values.
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class 'tensorflow.python.framework.immutable_dict.ImmutableDict'>
TensorSpec(shape=(), dtype=tf.string, name=None)
ImmutableDict({'height': TensorSpec(shape=(), dtype=tf.float32, name=None), 'speed': TensorSpec(shape=(), dtype=tf.float32, name=None)})

Nilai TypeSpec dapat dibangun secara eksplisit, atau dapat dibangun dari nilai ExtensionType menggunakan tf.type_spec_from_value :

spec1 = Player.Spec(name=tf.TensorSpec([], tf.float32), attributes={})
spec2 = tf.type_spec_from_value(anne)

TypeSpec s digunakan oleh TensorFlow untuk membagi nilai menjadi komponen statis dan komponen dinamis :

  • Komponen statis (yang ditetapkan pada waktu pembuatan grafik) dikodekan dengan tf.TypeSpec .
  • Komponen dinamis (yang dapat bervariasi setiap kali grafik dijalankan) dikodekan sebagai daftar tf.Tensor s.

Misalnya, tf.function menelusuri kembali fungsi yang dibungkusnya setiap kali argumen memiliki TypeSpec yang sebelumnya tidak terlihat :

@tf.function
def anonymize_player(player):
  print("<<TRACING>>")
  return Player("<anonymous>", player.attributes)
# Function gets traced (first time the function has been called):
anonymize_player(Player("Anne", {"height": 8.3, "speed": 28.1}))
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class 'tensorflow.python.framework.immutable_dict.ImmutableDict'>
WARNING:tensorflow:Mapping types may not work well with tf.nest. Prefer using MutableMapping for <class 'tensorflow.python.framework.immutable_dict.ImmutableDict'>
<<TRACING>>
Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=8.3>, 'speed': <tf.Tensor: shape=(), dtype=float32, numpy=28.1>}))
# Function does NOT get traced (same TypeSpec: just tensor values changed)
anonymize_player(Player("Bart", {"height": 8.1, "speed": 25.3}))
Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=8.1>, 'speed': <tf.Tensor: shape=(), dtype=float32, numpy=25.3>}))
# Function gets traced (new TypeSpec: keys for attributes changed):
anonymize_player(Player("Chuck", {"height": 11.0, "jump": 5.3}))
<<TRACING>>
Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=11.0>, 'jump': <tf.Tensor: shape=(), dtype=float32, numpy=5.3>}))

Untuk informasi lebih lanjut, lihat Panduan tf.function .

Menyesuaikan Jenis Ekstensi

Selain hanya mendeklarasikan bidang dan jenisnya, jenis ekstensi dapat:

  • Ganti representasi default yang dapat dicetak ( __repr__ ).
  • Tentukan metode.
  • Tentukan metode kelas dan metode statis.
  • Tentukan properti.
  • Ganti konstruktor default ( __init__ ).
  • Ganti operator kesetaraan default ( __eq__ ).
  • Tentukan operator (seperti __add__ dan __lt__ ).
  • Deklarasikan nilai default untuk bidang.
  • Tentukan subclass.

Mengganti representasi default yang dapat dicetak

Anda dapat mengganti operator konversi string default ini untuk jenis ekstensi. Contoh berikut memperbarui kelas MaskedTensor untuk menghasilkan representasi string yang lebih mudah dibaca saat nilai dicetak dalam mode Eager.

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor       # shape=values.shape; false for invalid values.

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

def masked_tensor_str(values, mask):
  if isinstance(values, tf.Tensor):
    if hasattr(values, 'numpy') and hasattr(mask, 'numpy'):
      return f'<MaskedTensor {masked_tensor_str(values.numpy(), mask.numpy())}>'
    else:
      return f'MaskedTensor(values={values}, mask={mask})'
  if len(values.shape) == 1:
    items = [repr(v) if m else '_' for (v, m) in zip(values, mask)]
  else:
    items = [masked_tensor_str(v, m) for (v, m) in zip(values, mask)]
  return '[%s]' % ', '.join(items)

mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
                  mask=[[True, True, False], [True, False, True]])
print(mt)
<MaskedTensor [[1, 2, _], [4, _, 6]]>

Mendefinisikan metode

Jenis ekstensi dapat menentukan metode, sama seperti kelas Python normal lainnya. Misalnya, tipe MaskedTensor bisa mendefinisikan metode with_default yang mengembalikan salinan self dengan nilai mask yang diganti dengan nilai default yang diberikan. Metode opsional dapat dijelaskan dengan dekorator @tf.function .

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def with_default(self, default):
    return tf.where(self.mask, self.values, default)

MaskedTensor([1, 2, 3], [True, False, True]).with_default(0)
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 0, 3], dtype=int32)>

Mendefinisikan classmethods dan staticmethods

Jenis ekstensi dapat menentukan metode menggunakan dekorator @classmethod dan @staticmethod . Misalnya, jenis MaskedTensor dapat mendefinisikan metode pabrik yang menutupi elemen apa pun dengan nilai yang diberikan:

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  @staticmethod
  def from_tensor_and_value_to_mask(values, value_to_mask):
    return MaskedTensor(values, values == value_to_mask)

x = tf.constant([[1, 0, 2], [3, 0, 0]])
MaskedTensor.from_tensor_and_value_to_mask(x, 0)
<MaskedTensor [[_, 0, _], [_, 0, 0]]>

Mendefinisikan properti

Jenis ekstensi dapat mendefinisikan properti menggunakan dekorator @property , sama seperti kelas Python biasa. Misalnya, tipe MaskedTensor bisa mendefinisikan properti dtype yang merupakan singkatan dari dtype nilai:

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  @property
  def dtype(self):
    return self.values.dtype

MaskedTensor([1, 2, 3], [True, False, True]).dtype
tf.int32

Mengganti konstruktor default

Anda dapat mengganti konstruktor default untuk jenis ekstensi. Konstruktor khusus harus menetapkan nilai untuk setiap bidang yang dideklarasikan; dan setelah konstruktor kustom kembali, semua bidang akan diperiksa jenisnya, dan nilainya akan dikonversi seperti dijelaskan di atas.

class Toy(tf.experimental.ExtensionType):
  name: str
  price: tf.Tensor
  def __init__(self, name, price, discount=0):
    self.name = name
    self.price = price * (1 - discount)

print(Toy("ball", 5.0, discount=0.2))  # On sale -- 20% off!
Toy(name='ball', price=<tf.Tensor: shape=(), dtype=float32, numpy=4.0>)

Atau, Anda dapat mempertimbangkan untuk membiarkan konstruktor default apa adanya, tetapi menambahkan satu atau beberapa metode pabrik. Misalnya:

class Toy(tf.experimental.ExtensionType):
  name: str
  price: tf.Tensor

  @staticmethod
  def new_toy_with_discount(name, price, discount):
    return Toy(name, price * (1 - discount))

print(Toy.new_toy_with_discount("ball", 5.0, discount=0.2))
Toy(name='ball', price=<tf.Tensor: shape=(), dtype=float32, numpy=4.0>)

Mengganti operator kesetaraan default ( __eq__ )

Anda dapat mengganti operator __eq__ default untuk jenis ekstensi. Contoh berikut memperbarui MaskedTensor untuk mengabaikan elemen bertopeng saat membandingkan kesetaraan.

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  def __eq__(self, other):
    result = tf.math.equal(self.values, other.values)
    result = result | ~(self.mask & other.mask)
    return tf.reduce_all(result)

x = MaskedTensor([1, 2, 3, 4], [True, True, False, True])
y = MaskedTensor([5, 2, 0, 4], [False, True, False, True])
print(x == y)
tf.Tensor(True, shape=(), dtype=bool)

Menggunakan referensi maju

Jika jenis untuk bidang belum ditentukan, Anda dapat menggunakan string yang berisi nama jenis sebagai gantinya. Dalam contoh berikut, string "Node" digunakan untuk membubuhi keterangan pada bidang children karena tipe Node belum (sepenuhnya) didefinisikan.

class Node(tf.experimental.ExtensionType):
  value: tf.Tensor
  children: Tuple["Node", ...] = ()

Node(3, [Node(5), Node(2)])
Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=3>, children=(Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=5>, children=()), Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=2>, children=())))

Mendefinisikan subclass

Jenis ekstensi dapat disubklasifikasikan menggunakan sintaks Python standar. Subkelas tipe ekstensi dapat menambahkan bidang, metode, dan properti baru; dan dapat menimpa konstruktor, representasi yang dapat dicetak, dan operator kesetaraan. Contoh berikut mendefinisikan kelas TensorGraph dasar yang menggunakan tiga bidang Tensor untuk mengkodekan sekumpulan tepi di antara node. Ini kemudian mendefinisikan subclass yang menambahkan bidang Tensor untuk merekam "nilai fitur" untuk setiap node. Subclass juga mendefinisikan metode untuk menyebarkan nilai fitur di sepanjang tepi.

class TensorGraph(tf.experimental.ExtensionType):
  num_nodes: tf.Tensor
  edge_src: tf.Tensor   # edge_src[e] = index of src node for edge e.
  edge_dst: tf.Tensor   # edge_dst[e] = index of dst node for edge e.

class TensorGraphWithNodeFeature(TensorGraph):
  node_features: tf.Tensor  # node_features[n] = feature value for node n.

  def propagate_features(self, weight=1.0) -> 'TensorGraphWithNodeFeature':
    updates = tf.gather(self.node_features, self.edge_src) * weight
    new_node_features = tf.tensor_scatter_nd_add(
        self.node_features, tf.expand_dims(self.edge_dst, 1), updates)
    return TensorGraphWithNodeFeature(
        self.num_nodes, self.edge_src, self.edge_dst, new_node_features)

g = TensorGraphWithNodeFeature(  # Edges: 0->1, 4->3, 2->2, 2->1
    num_nodes=5, edge_src=[0, 4, 2, 2], edge_dst=[1, 3, 2, 1],
    node_features=[10.0, 0.0, 2.0, 5.0, -1.0, 0.0])

print("Original features:", g.node_features)
print("After propagating:", g.propagate_features().node_features)
Original features: tf.Tensor([10.  0.  2.  5. -1.  0.], shape=(6,), dtype=float32)
After propagating: tf.Tensor([10. 12.  4.  4. -1.  0.], shape=(6,), dtype=float32)

Mendefinisikan bidang pribadi

Bidang jenis ekstensi dapat ditandai pribadi dengan mengawalinya dengan garis bawah (mengikuti konvensi standar Python). Ini tidak memengaruhi cara TensorFlow memperlakukan bidang dengan cara apa pun; tetapi hanya berfungsi sebagai sinyal bagi pengguna jenis ekstensi mana pun bahwa bidang tersebut bersifat pribadi.

Menyesuaikan TypeSpec dari ExtensionType

Setiap kelas ExtensionType memiliki kelas TypeSpec yang sesuai, yang dibuat secara otomatis dan disimpan sebagai <extension_type_name>.Spec . Untuk informasi selengkapnya, lihat bagian "Spec Tipe Bersarang" di atas.

Untuk menyesuaikan TypeSpec , cukup tentukan kelas bersarang Anda sendiri bernama Spec , dan ExtensionType akan menggunakannya sebagai dasar untuk TypeSpec yang dibuat secara otomatis . Anda dapat menyesuaikan kelas Spec dengan:

  • Mengganti representasi default yang dapat dicetak.
  • Mengganti konstruktor default.
  • Mendefinisikan metode, metode kelas, metode statis, dan properti.

Contoh berikut mengkustomisasi kelas MaskedTensor.Spec agar lebih mudah digunakan:

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

  shape = property(lambda self: self.values.shape)
  dtype = property(lambda self: self.values.dtype)

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  def with_values(self, new_values):
    return MaskedTensor(new_values, self.mask)

  class Spec:
    def __init__(self, shape, dtype=tf.float32):
      self.values = tf.TensorSpec(shape, dtype)
      self.mask = tf.TensorSpec(shape, tf.bool)

    def __repr__(self):
      return f"MaskedTensor.Spec(shape={self.shape}, dtype={self.dtype})"

    shape = property(lambda self: self.values.shape)
    dtype = property(lambda self: self.values.dtype)

Pengiriman Tensor API

Jenis ekstensi dapat "seperti tensor", dalam arti bahwa mereka mengkhususkan atau memperluas antarmuka yang ditentukan oleh jenis tf.Tensor . Contoh jenis ekstensi seperti tensor termasuk RaggedTensor , SparseTensor , dan MaskedTensor . Dekorator pengiriman dapat digunakan untuk mengganti perilaku default operasi TensorFlow saat diterapkan ke jenis ekstensi seperti tensor. TensorFlow saat ini mendefinisikan tiga dekorator pengiriman:

Pengiriman untuk satu API

Dekorator tf.experimental.dispatch_for_api menimpa perilaku default dari operasi TensorFlow tertentu saat dipanggil dengan tanda tangan yang ditentukan. Misalnya, Anda dapat menggunakan dekorator ini untuk menentukan bagaimana tf.stack harus memproses nilai MaskedTensor :

@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack(values: List[MaskedTensor], axis = 0):
  return MaskedTensor(tf.stack([v.values for v in values], axis),
                      tf.stack([v.mask for v in values], axis))

Ini mengesampingkan implementasi default untuk tf.stack setiap kali dipanggil dengan daftar nilai MaskedTensor (karena argumen values dianotasi dengan typing.List[MaskedTensor] ):

x = MaskedTensor([1, 2, 3], [True, True, False])
y = MaskedTensor([4, 5, 6], [False, True, True])
tf.stack([x, y])
<MaskedTensor [[1, 2, _], [_, 5, 6]]>

Untuk mengizinkan tf.stack menangani daftar nilai campuran MaskedTensor dan Tensor , Anda dapat menyaring anotasi tipe untuk parameter values dan memperbarui isi fungsi dengan tepat:

tf.experimental.unregister_dispatch_for(masked_stack)

def convert_to_masked_tensor(x):
  if isinstance(x, MaskedTensor):
    return x
  else:
    return MaskedTensor(x, tf.ones_like(x, tf.bool))

@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack_v2(values: List[Union[MaskedTensor, tf.Tensor]], axis = 0):
  values = [convert_to_masked_tensor(v) for v in values]
  return MaskedTensor(tf.stack([v.values for v in values], axis),
                      tf.stack([v.mask for v in values], axis))
x = MaskedTensor([1, 2, 3], [True, True, False])
y = tf.constant([4, 5, 6])
tf.stack([x, y, x])
<MaskedTensor [[1, 2, _], [4, 5, 6], [1, 2, _]]>

Untuk daftar API yang dapat diganti, lihat dokumentasi API untuk tf.experimental.dispatch_for_api .

Pengiriman untuk semua API elemen unary

Dekorator tf.experimental.dispatch_for_unary_elementwise_apis menimpa perilaku default dari semua operasi elemen unary (seperti tf.math.cos ) setiap kali nilai untuk argumen pertama (biasanya bernama x ) cocok dengan anotasi tipe x_type . Fungsi yang didekorasi harus mengambil dua argumen:

  • api_func : Sebuah fungsi yang mengambil parameter tunggal dan melakukan operasi elemen (misalnya, tf.abs ).
  • x : Argumen pertama untuk operasi elementwise.

Contoh berikut memperbarui semua operasi elemen unary untuk menangani jenis MaskedTensor :

@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
 def masked_tensor_unary_elementwise_api_handler(api_func, x):
   return MaskedTensor(api_func(x.values), x.mask)

Fungsi ini sekarang akan digunakan setiap kali operasi elemen unary dipanggil pada MaskedTensor .

x = MaskedTensor([1, -2, -3], [True, False, True])
 print(tf.abs(x))
<MaskedTensor [1, _, 3]>
print(tf.ones_like(x, dtype=tf.float32))
<MaskedTensor [1.0, _, 1.0]>

Pengiriman untuk biner semua API elemen

Demikian pula, tf.experimental.dispatch_for_binary_elementwise_apis dapat digunakan untuk memperbarui semua operasi elemen biner untuk menangani jenis MaskedTensor :

@tf.experimental.dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
def masked_tensor_binary_elementwise_api_handler(api_func, x, y):
  return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)
x = MaskedTensor([1, -2, -3], [True, False, True])
y = MaskedTensor([[4], [5]], [[True], [False]])
tf.math.add(x, y)
<MaskedTensor [[5, _, 1], [_, _, _]]>

Untuk daftar API elementwise yang diganti, lihat dokumentasi API untuk tf.experimental.dispatch_for_unary_elementwise_apis dan tf.experimental.dispatch_for_binary_elementwise_apis .

Jenis Ekstensi yang Dapat Dibatch

ExtensionType adalah batchable jika satu instance dapat digunakan untuk mewakili sekumpulan nilai. Biasanya, ini dilakukan dengan menambahkan dimensi batch ke semua Tensor bersarang. TensorFlow API berikut mengharuskan input jenis ekstensi apa pun dapat di-batch:

Secara default, BatchableExtensionType membuat nilai batch dengan mengelompokkan Tensor s, CompositeTensor s, dan ExtensionType s yang bersarang. Jika ini tidak sesuai untuk kelas Anda, maka Anda perlu menggunakan tf.experimental.ExtensionTypeBatchEncoder untuk mengganti perilaku default ini. Misalnya, tidak akan tepat untuk membuat kumpulan nilai tf.SparseTensor hanya dengan menumpuk masing-masing values tensor sparse , indices , dan bidang dense_shape -- dalam banyak kasus, Anda tidak dapat menumpuk tensor ini, karena memiliki bentuk yang tidak kompatibel ; dan bahkan jika Anda bisa, hasilnya tidak akan menjadi SparseTensor yang valid.

Contoh BatchableExtensionType: Jaringan

Sebagai contoh, pertimbangkan kelas Network sederhana yang digunakan untuk penyeimbangan beban, yang melacak berapa banyak pekerjaan yang tersisa untuk dilakukan di setiap node, dan berapa banyak bandwidth yang tersedia untuk memindahkan pekerjaan antar node:

class Network(tf.experimental.ExtensionType):  # This version is not batchable.
  work: tf.Tensor       # work[n] = work left to do at node n
  bandwidth: tf.Tensor  # bandwidth[n1, n2] = bandwidth from n1->n2

net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])

Untuk membuat tipe ini dapat dikelompokkan, ubah tipe dasar menjadi BatchableExtensionType , dan sesuaikan bentuk setiap bidang untuk menyertakan dimensi kumpulan opsional. Contoh berikut juga menambahkan bidang shape untuk melacak bentuk kumpulan. Bidang shape ini tidak diperlukan oleh tf.data.Dataset atau tf.map_fn , tetapi diperlukan oleh tf.Keras .

class Network(tf.experimental.BatchableExtensionType):
  shape: tf.TensorShape  # batch shape.  A single network has shape=[].
  work: tf.Tensor        # work[*shape, n] = work left to do at node n
  bandwidth: tf.Tensor   # bandwidth[*shape, n1, n2] = bandwidth from n1->n2

  def __init__(self, work, bandwidth):
    self.work = tf.convert_to_tensor(work)
    self.bandwidth = tf.convert_to_tensor(bandwidth)
    work_batch_shape = self.work.shape[:-1]
    bandwidth_batch_shape = self.bandwidth.shape[:-2]
    self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)

  def __repr__(self):
    return network_repr(self)

def network_repr(network):
  work = network.work
  bandwidth = network.bandwidth
  if hasattr(work, 'numpy'):
    work = ' '.join(str(work.numpy()).split())
  if hasattr(bandwidth, 'numpy'):
    bandwidth = ' '.join(str(bandwidth.numpy()).split())
  return (f"<Network shape={network.shape} work={work} bandwidth={bandwidth}>")
net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])
batch_of_networks = Network(
    work=tf.stack([net1.work, net2.work]),
    bandwidth=tf.stack([net1.bandwidth, net2.bandwidth]))
print(f"net1={net1}")
print(f"net2={net2}")
print(f"batch={batch_of_networks}")
net1=<Network shape=() work=[5. 3. 8.] bandwidth=[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]]>
net2=<Network shape=() work=[3. 4. 2.] bandwidth=[[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]>
batch=<Network shape=(2,) work=[[5. 3. 8.] [3. 4. 2.]] bandwidth=[[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]] [[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]]>

Anda kemudian dapat menggunakan tf.data.Dataset untuk beralih melalui sekumpulan jaringan:

dataset = tf.data.Dataset.from_tensor_slices(batch_of_networks)
for i, network in enumerate(dataset):
  print(f"Batch element {i}: {network}")
Batch element 0: <Network shape=() work=[5. 3. 8.] bandwidth=[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]]>
Batch element 1: <Network shape=() work=[3. 4. 2.] bandwidth=[[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]>

Dan Anda juga dapat menggunakan map_fn untuk menerapkan fungsi ke setiap elemen batch:

def balance_work_greedy(network):
  delta = (tf.expand_dims(network.work, -1) - tf.expand_dims(network.work, -2))
  delta /= 4
  delta = tf.maximum(tf.minimum(delta, network.bandwidth), -network.bandwidth)
  new_work = network.work + tf.reduce_sum(delta, -1)
  return Network(new_work, network.bandwidth)

tf.map_fn(balance_work_greedy, batch_of_networks)
<Network shape=(2,) work=[[5.5 1.25 9.25] [3. 4.75 1.25]] bandwidth=[[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]] [[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]]>

API TensorFlow yang mendukung ExtensionTypes

@tf.fungsi

tf.function adalah dekorator yang menghitung grafik TensorFlow untuk fungsi Python, yang secara substansial dapat meningkatkan kinerja kode TensorFlow Anda. Nilai tipe ekstensi dapat digunakan secara transparan dengan fungsi yang didekorasi dengan @tf.function .

class Pastry(tf.experimental.ExtensionType):
  sweetness: tf.Tensor  # 2d embedding that encodes sweetness
  chewiness: tf.Tensor  # 2d embedding that encodes chewiness

@tf.function
def combine_pastry_features(x: Pastry):
  return (x.sweetness + x.chewiness) / 2

cookie = Pastry(sweetness=[1.2, 0.4], chewiness=[0.8, 0.2])
combine_pastry_features(cookie)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1. , 0.3], dtype=float32)>

Jika Anda ingin secara eksplisit menentukan input_signature untuk tf.function , maka Anda dapat melakukannya menggunakan TypeSpec dari tipe ekstensi.

pastry_spec = Pastry.Spec(tf.TensorSpec([2]), tf.TensorSpec(2))

@tf.function(input_signature=[pastry_spec])
def increase_sweetness(x: Pastry, delta=1.0):
  return Pastry(x.sweetness + delta, x.chewiness)

increase_sweetness(cookie)
Pastry(sweetness=<tf.Tensor: shape=(2,), dtype=float32, numpy=array([2.2, 1.4], dtype=float32)>, chewiness=<tf.Tensor: shape=(2,), dtype=float32, numpy=array([0.8, 0.2], dtype=float32)>)

Fungsi konkret

Fungsi konkret merangkum grafik terlacak individu yang dibangun oleh tf.function . Jenis ekstensi dapat digunakan secara transparan dengan fungsi konkret.

cf = combine_pastry_features.get_concrete_function(pastry_spec)
cf(cookie)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1. , 0.3], dtype=float32)>

Kontrol aliran operasi

Jenis ekstensi didukung oleh operasi aliran kontrol TensorFlow:

# Example: using tf.cond to select between two MaskedTensors.  Note that the
# two MaskedTensors don't need to have the same shape.
a = MaskedTensor([1., 2, 3], [True, False, True])
b = MaskedTensor([22., 33, 108, 55], [True, True, True, False])
condition = tf.constant(True)
print(tf.cond(condition, lambda: a, lambda: b))
<MaskedTensor [1.0, _, 3.0]>
# Example: using tf.while_loop with MaskedTensor.
cond = lambda i, _: i < 10
def body(i, mt):
  return i + 1, mt.with_values(mt.values + 3 / 7)
print(tf.while_loop(cond, body, [0, b])[1])
<MaskedTensor [26.285717, 37.285698, 112.285736, _]>

Aliran kontrol tanda tangan

Jenis ekstensi juga didukung oleh pernyataan aliran kontrol di tf.function (menggunakan tanda tangan). Dalam contoh berikut, pernyataan if dan pernyataan for secara otomatis dikonversi ke operasi tf.cond dan tf.while_loop , yang mendukung jenis ekstensi.

@tf.function
def fn(x, b):
  if b:
    x = MaskedTensor(x, tf.less(x, 0))
  else:
    x = MaskedTensor(x, tf.greater(x, 0))
  for i in tf.range(5 if b else 7):
    x = x.with_values(x.values + 1 / 2)
  return x

print(fn(tf.constant([1., -2, 3]), tf.constant(True)))
print(fn(tf.constant([1., -2, 3]), tf.constant(False)))
<MaskedTensor [_, 0.5, _]>
<MaskedTensor [4.5, _, 6.5]>

Keras

tf.keras adalah API tingkat tinggi TensorFlow untuk membangun dan melatih model pembelajaran mendalam. Jenis ekstensi dapat diteruskan sebagai input ke model Keras, diteruskan di antara lapisan Keras, dan dikembalikan oleh model Keras. Keras saat ini menempatkan dua persyaratan pada jenis ekstensi:

  • Mereka harus dapat di-batch (lihat "Jenis Ekstensi yang Dapat Dibatch" di atas).
  • Harus memiliki bidang atau properti bernama shape . shape[0] diasumsikan sebagai dimensi batch.

Dua subbagian berikut memberikan contoh yang menunjukkan bagaimana jenis ekstensi dapat digunakan dengan Keras.

Keras contoh: Network

Untuk contoh pertama, pertimbangkan kelas Network yang didefinisikan di bagian "Tipe Ekstensi Batchable" di atas, yang dapat digunakan untuk pekerjaan penyeimbangan beban antar node. Definisinya diulang di sini:

class Network(tf.experimental.BatchableExtensionType):
  shape: tf.TensorShape  # batch shape.  A single network has shape=[].
  work: tf.Tensor        # work[*shape, n] = work left to do at node n
  bandwidth: tf.Tensor   # bandwidth[*shape, n1, n2] = bandwidth from n1->n2

  def __init__(self, work, bandwidth):
    self.work = tf.convert_to_tensor(work)
    self.bandwidth = tf.convert_to_tensor(bandwidth)
    work_batch_shape = self.work.shape[:-1]
    bandwidth_batch_shape = self.bandwidth.shape[:-2]
    self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)

  def __repr__(self):
    return network_repr(self)
single_network = Network(  # A single network w/ 4 nodes.
    work=[8.0, 5, 12, 2],
    bandwidth=[[0.0, 1, 2, 2], [1, 0, 0, 2], [2, 0, 0, 1], [2, 2, 1, 0]])

batch_of_networks = Network(  # Batch of 2 networks, each w/ 2 nodes.
    work=[[8.0, 5], [3, 2]],
    bandwidth=[[[0.0, 1], [1, 0]], [[0, 2], [2, 0]]])

Anda dapat menentukan lapisan Keras baru yang memproses Network s.

class BalanceNetworkLayer(tf.keras.layers.Layer):
  """Layer that balances work between nodes in a network.

  Shifts work from more busy nodes to less busy nodes, constrained by bandwidth.
  """
  def call(self, inputs):
    # This function is defined above, in "Batchable ExtensionTypes" section.
    return balance_work_greedy(inputs)

Anda kemudian dapat menggunakan lapisan ini untuk membuat model sederhana. Untuk memasukkan ExtensionType ke dalam model, Anda dapat menggunakan lapisan tf.keras.layer.Input dengan type_spec disetel ke TypeSpec jenis ekstensi. Jika model Keras akan digunakan untuk memproses batch, maka type_spec harus menyertakan dimensi batch.

input_spec = Network.Spec(shape=None,
                          work=tf.TensorSpec(None, tf.float32),
                          bandwidth=tf.TensorSpec(None, tf.float32))
model = tf.keras.Sequential([
    tf.keras.layers.Input(type_spec=input_spec),
    BalanceNetworkLayer(),
    ])

Terakhir, Anda dapat menerapkan model ke satu jaringan dan ke kumpulan jaringan.

model(single_network)
<Network shape=() work=[ 9.25 5. 14. -1.25] bandwidth=[[0. 1. 2. 2.] [1. 0. 0. 2.] [2. 0. 0. 1.] [2. 2. 1. 0.]]>
model(batch_of_networks)
<Network shape=(2,) work=[[8.75 4.25] [3.25 1.75]] bandwidth=[[[0. 1.] [1. 0.]] [[0. 2.] [2. 0.]]]>

Contoh Keras: MaskedTensor

Dalam contoh ini, MaskedTensor diperluas untuk mendukung Keras . shape didefinisikan sebagai properti yang dihitung dari bidang values . Keras mengharuskan Anda menambahkan properti ini ke jenis ekstensi dan TypeSpec -nya. MaskedTensor juga mendefinisikan variabel __name__ , yang akan diperlukan untuk serialisasi SavedModel (di bawah).

class MaskedTensor(tf.experimental.BatchableExtensionType):
  # __name__ is required for serialization in SavedModel; see below for details.
  __name__ = 'extension_type_colab.MaskedTensor'

  values: tf.Tensor
  mask: tf.Tensor

  shape = property(lambda self: self.values.shape)
  dtype = property(lambda self: self.values.dtype)

  def with_default(self, default):
    return tf.where(self.mask, self.values, default)

  def __repr__(self):
    return masked_tensor_str(self.values, self.mask)

  class Spec:
    def __init__(self, shape, dtype=tf.float32):
      self.values = tf.TensorSpec(shape, dtype)
      self.mask = tf.TensorSpec(shape, tf.bool)

    shape = property(lambda self: self.values.shape)
    dtype = property(lambda self: self.values.dtype)

    def with_shape(self):
      return MaskedTensor.Spec(tf.TensorSpec(shape, self.values.dtype),
                               tf.TensorSpec(shape, self.mask.dtype))

Selanjutnya, dekorator pengiriman digunakan untuk mengganti perilaku default beberapa TensorFlow API. Karena API ini digunakan oleh lapisan Keras standar (seperti lapisan Dense ), menimpanya akan memungkinkan kita untuk menggunakan lapisan tersebut dengan MaskedTensor . Untuk tujuan contoh ini, matmul untuk tensor bertopeng didefinisikan untuk memperlakukan nilai terselubung sebagai nol (yaitu, untuk tidak menyertakannya dalam produk).

@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def unary_elementwise_op_handler(op, x):
 return MaskedTensor(op(x.values), x.mask)

@tf.experimental.dispatch_for_binary_elementwise_apis(
    Union[MaskedTensor, tf.Tensor],
    Union[MaskedTensor, tf.Tensor])
def binary_elementwise_op_handler(op, x, y):
  x = convert_to_masked_tensor(x)
  y = convert_to_masked_tensor(y)
  return MaskedTensor(op(x.values, y.values), x.mask & y.mask)

@tf.experimental.dispatch_for_api(tf.matmul)
def masked_matmul(a: MaskedTensor, b,
                  transpose_a=False, transpose_b=False,
                  adjoint_a=False, adjoint_b=False,
                  a_is_sparse=False, b_is_sparse=False,
                  output_type=None):
  if isinstance(a, MaskedTensor):
    a = a.with_default(0)
  if isinstance(b, MaskedTensor):
    b = b.with_default(0)
  return tf.matmul(a, b, transpose_a, transpose_b, adjoint_a,
                   adjoint_b, a_is_sparse, b_is_sparse, output_type)

Anda kemudian dapat membuat model Keras yang menerima input MaskedTensor , menggunakan lapisan Keras standar:

input_spec = MaskedTensor.Spec([None, 2], tf.float32)

masked_tensor_model = tf.keras.Sequential([
    tf.keras.layers.Input(type_spec=input_spec),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1)])
masked_tensor_model.compile(loss='binary_crossentropy', optimizer='rmsprop')
a = MaskedTensor([[1., 2], [3, 4], [5, 6]],
                  [[True, False], [False, True], [True, True]])
masked_tensor_model.fit(a, tf.constant([[1], [0], [1]]), epochs=3)
print(masked_tensor_model(a))
Epoch 1/3
1/1 [==============================] - 1s 955ms/step - loss: 10.2833
Epoch 2/3
1/1 [==============================] - 0s 5ms/step - loss: 10.2833
Epoch 3/3
1/1 [==============================] - 0s 5ms/step - loss: 10.2833
tf.Tensor(
[[-0.09944128]
 [-0.7225147 ]
 [-1.3020657 ]], shape=(3, 1), dtype=float32)

Model Tersimpan

SavedModel adalah program TensorFlow serial, termasuk bobot dan komputasi. Itu dapat dibangun dari model Keras atau dari model khusus. Dalam kedua kasus, jenis ekstensi dapat digunakan secara transparan dengan fungsi dan metode yang ditentukan oleh SavedModel.

SavedModel dapat menyimpan model, lapisan, dan fungsi yang memproses jenis ekstensi, selama jenis ekstensi memiliki bidang __name__ . Nama ini digunakan untuk mendaftarkan jenis ekstensi, sehingga dapat ditemukan saat model dimuat.

Contoh: menyimpan model Keras

Model keras yang menggunakan jenis ekstensi dapat disimpan menggunakan SavedModel .

masked_tensor_model_path = tempfile.mkdtemp()
tf.saved_model.save(masked_tensor_model, masked_tensor_model_path)
imported_model = tf.saved_model.load(masked_tensor_model_path)
imported_model(a)
2021-11-06 01:25:14.285250: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
WARNING:absl:Function `_wrapped_model` contains input name(s) args_0 with unsupported characters which will be renamed to args_0_1 in the SavedModel.
INFO:tensorflow:Assets written to: /tmp/tmp3ceuupv9/assets
INFO:tensorflow:Assets written to: /tmp/tmp3ceuupv9/assets
<tf.Tensor: shape=(3, 1), dtype=float32, numpy=
array([[-0.09944128],
       [-0.7225147 ],
       [-1.3020657 ]], dtype=float32)>

Contoh: menyimpan model khusus

SavedModel juga dapat digunakan untuk menyimpan subkelas tf.Module kustom dengan fungsi yang memproses tipe ekstensi.

class CustomModule(tf.Module):
  def __init__(self, variable_value):
    super().__init__()
    self.v = tf.Variable(variable_value)

  @tf.function
  def grow(self, x: MaskedTensor):
    """Increase values in `x` by multiplying them by `self.v`."""
    return MaskedTensor(x.values * self.v, x.mask)

module = CustomModule(100.0)

module.grow.get_concrete_function(MaskedTensor.Spec(shape=None,
                                                    dtype=tf.float32))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
imported_model.grow(MaskedTensor([1., 2, 3], [False, True, False]))
INFO:tensorflow:Assets written to: /tmp/tmp2x8zq5kb/assets
INFO:tensorflow:Assets written to: /tmp/tmp2x8zq5kb/assets
<MaskedTensor [_, 200.0, _]>

Memuat SavedModel saat ExtensionType tidak tersedia

Jika Anda memuat SavedModel yang menggunakan ExtensionType , tetapi ExtensionType itu tidak tersedia (yaitu, belum diimpor), maka Anda akan melihat peringatan dan TensorFlow akan kembali menggunakan objek "jenis ekstensi anonim". Objek ini akan memiliki bidang yang sama dengan tipe aslinya, tetapi tidak memiliki penyesuaian lebih lanjut yang telah Anda tambahkan untuk tipe tersebut, seperti metode atau properti kustom.

Menggunakan ExtensionTypes dengan penyajian TensorFlow

Saat ini, penyajian TensorFlow (dan konsumen lain dari kamus "tanda tangan" SavedModel) mengharuskan semua input dan output menjadi tensor mentah. Jika Anda ingin menggunakan penyajian TensorFlow dengan model yang menggunakan jenis ekstensi, Anda dapat menambahkan metode pembungkus yang menyusun atau menguraikan nilai jenis ekstensi dari tensor. Misalnya:

class CustomModuleWrapper(tf.Module):
  def __init__(self, variable_value):
    super().__init__()
    self.v = tf.Variable(variable_value)

  @tf.function
  def var_weighted_mean(self, x: MaskedTensor):
    """Mean value of unmasked values in x, weighted by self.v."""
    x = MaskedTensor(x.values * self.v, x.mask)
    return (tf.reduce_sum(x.with_default(0)) /
            tf.reduce_sum(tf.cast(x.mask, x.dtype)))

  @tf.function()
  def var_weighted_mean_wrapper(self, x_values, x_mask):
    """Raw tensor wrapper for var_weighted_mean."""
    return self.var_weighted_mean(MaskedTensor(x_values, x_mask))

module = CustomModuleWrapper([3., 2., 8., 5.])

module.var_weighted_mean_wrapper.get_concrete_function(
    tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.bool))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
x = MaskedTensor([1., 2., 3., 4.], [False, True, False, True])
imported_model.var_weighted_mean_wrapper(x.values, x.mask)
INFO:tensorflow:Assets written to: /tmp/tmpxhh4zh0i/assets
INFO:tensorflow:Assets written to: /tmp/tmpxhh4zh0i/assets
<tf.Tensor: shape=(), dtype=float32, numpy=12.0>

Kumpulan data

tf.data adalah API yang memungkinkan Anda membangun saluran input yang kompleks dari bagian sederhana yang dapat digunakan kembali. Struktur data intinya adalah tf.data.Dataset , yang mewakili urutan elemen, di mana setiap elemen terdiri dari satu atau lebih komponen.

Membangun Dataset dengan tipe ekstensi

Kumpulan data dapat dibangun dari nilai tipe ekstensi menggunakan Dataset.from_tensors , Dataset.from_tensor_slices , atau Dataset.from_generator :

ds = tf.data.Dataset.from_tensors(Pastry(5, 5))
iter(ds).next()
Pastry(sweetness=<tf.Tensor: shape=(), dtype=int32, numpy=5>, chewiness=<tf.Tensor: shape=(), dtype=int32, numpy=5>)
mt = MaskedTensor(tf.reshape(range(20), [5, 4]), tf.ones([5, 4]))
ds = tf.data.Dataset.from_tensor_slices(mt)
for value in ds:
  print(value)
<MaskedTensor [0, 1, 2, 3]>
<MaskedTensor [4, 5, 6, 7]>
<MaskedTensor [8, 9, 10, 11]>
<MaskedTensor [12, 13, 14, 15]>
<MaskedTensor [16, 17, 18, 19]>
def value_gen():
  for i in range(2, 7):
    yield MaskedTensor(range(10), [j%i != 0 for j in range(10)])

ds = tf.data.Dataset.from_generator(
    value_gen, output_signature=MaskedTensor.Spec(shape=[10], dtype=tf.int32))
for value in ds:
  print(value)
<MaskedTensor [_, 1, _, 3, _, 5, _, 7, _, 9]>
<MaskedTensor [_, 1, 2, _, 4, 5, _, 7, 8, _]>
<MaskedTensor [_, 1, 2, 3, _, 5, 6, 7, _, 9]>
<MaskedTensor [_, 1, 2, 3, 4, _, 6, 7, 8, 9]>
<MaskedTensor [_, 1, 2, 3, 4, 5, _, 7, 8, 9]>

Batching dan unbatching Dataset dengan tipe ekstensi

Kumpulan data dengan tipe ekstensi dapat berupa batchand dan unbatch menggunakan Dataset.batch dan Dataset.unbatch .

batched_ds = ds.batch(2)
for value in batched_ds:
  print(value)
<MaskedTensor [[_, 1, _, 3, _, 5, _, 7, _, 9], [_, 1, 2, _, 4, 5, _, 7, 8, _]]>
<MaskedTensor [[_, 1, 2, 3, _, 5, 6, 7, _, 9], [_, 1, 2, 3, 4, _, 6, 7, 8, 9]]>
<MaskedTensor [[_, 1, 2, 3, 4, 5, _, 7, 8, 9]]>
unbatched_ds = batched_ds.unbatch()
for value in unbatched_ds:
  print(value)
<MaskedTensor [_, 1, _, 3, _, 5, _, 7, _, 9]>
<MaskedTensor [_, 1, 2, _, 4, 5, _, 7, 8, _]>
<MaskedTensor [_, 1, 2, 3, _, 5, 6, 7, _, 9]>
<MaskedTensor [_, 1, 2, 3, 4, _, 6, 7, 8, 9]>
<MaskedTensor [_, 1, 2, 3, 4, 5, _, 7, 8, 9]>