TensorFlow.org で表示 | Google Colab で実行 | GitHub でソースを表示 | ノートブックをダウンロード |
セットアップ
!pip install -q tf_nightly
import tensorflow as tf
import numpy as np
from typing import Tuple, List, Mapping, Union, Optional
import tempfile
2022-12-14 20:32:52.787665: E tensorflow/tsl/lib/monitoring/collection_registry.cc:81] Cannot register 2 metrics with the same name: /tensorflow/core/bfc_allocator_delay
拡張型
ユーザー定義型を使用すると、プロジェクトが読みやすくなり、モジュール化され、保守しやすくなります。ただし、ほとんどの TensorFlow API では、ユーザー定義の Python 型が限定的にしかサポートされていません。これには、高レベル API(Keras、tf.function、tf.SavedModel
など)と低レベル API(tf.while_loop
や tf.concat
など)の両方が含まれます。 TensorFlow 拡張型を使用して、TensorFlow の API とシームレスに連携するユーザー定義のオブジェクト指向型を作成できます。拡張型を作成するには、単純に tf.experimental.ExtensionType
をベースとして Python クラスを定義し、型注釈を使用して各フィールドの型を指定します。
class TensorGraph(tf.experimental.ExtensionType):
"""A collection of labeled nodes connected by weighted edges."""
edge_weights: tf.Tensor # shape=[num_nodes, num_nodes]
node_labels: Mapping[str, tf.Tensor] # shape=[num_nodes]; dtype=any
class MaskedTensor(tf.experimental.ExtensionType):
"""A tensor paired with a boolean mask, indicating which values are valid."""
values: tf.Tensor
mask: tf.Tensor # shape=values.shape; false for missing/invalid values.
class CSRSparseMatrix(tf.experimental.ExtensionType):
"""Compressed sparse row matrix (https://en.wikipedia.org/wiki/Sparse_matrix)."""
values: tf.Tensor # shape=[num_nonzero]; dtype=any
col_index: tf.Tensor # shape=[num_nonzero]; dtype=int64
row_index: tf.Tensor # shape=[num_rows+1]; dtype=int64
tf.experimental.ExtensionType
基底クラスは、標準の Python ライブラリの typing.NamedTuple
および @dataclasses.dataclass
と同じように機能します。特に、フィールド型の注釈に基づいて、コンストラクタと特別なメソッド(__repr__
や __eq__
など)が自動的に追加されます。
通常、拡張型は次の 2 つのカテゴリのいずれかに分類される傾向があります。
データ構造。関連する値のコレクションをグループ化し、それらの値に基づいて役立つ演算を提供できます。データ構造は汎用性が高い場合(上記の
TensorGraph
の例など)、または特定のモデルに合わせて高度にカスタマイズされている場合があります。テンソルのような型。「テンソル」の概念を特殊化または拡張します。このカテゴリの型には、
rank
、shape
、そして通常はdtype
があります。 テンソル演算 (tf.stack
、tf.add
、またはtf.matmul
など)でそれらを使用することは合理的です。MaskedTensor
とCSRSparseMatrix
は、テンソルのような型の例です。
サポートされている API
拡張型は以下の TensorFlow API でサポートされています。
- Keras: 拡張型は Keras
Models
とLayers
の入出力として使用できます。 tf.data.Dataset
: 拡張型は、データセットDatasets
に含むことができ、データセットIterators
で返すことができます。- TensorFlow Hub: 拡張型は
tf.hub
の入出力として使用できます。 - SavedModel: 拡張型は
SavedModel
関数の入出力として使用できます。 tf.function
: 拡張型は、@tf.function
デコレータでラップされた関数の引数および戻り値として使用できます。- While ループ: 拡張型は
tf.while_loop
でループ変数として使用でき、while ループの本体の引数および戻り値として使用できます。 - 条件付き:
tf.cond
およびtf.case
を使用して、拡張型を条件付きで選択できます。 tf.py_function
: 拡張型は引数として使用でき、tf.py_function
へのfunc
引数の値を返します。- テンソル演算: テンソルの入力(
tf.matmul
、tf.gather
、およびtf.reduce_sum
など)を受け入れるほとんどの TensorFlow 演算をサポートするために拡張型を拡張できます。詳細については、以下の「ディスパッチ」セクションに移動してください。 - 分散ストラテジー: 拡張型はレプリカごとの値として使用できます。
詳細については、以下の「ExtensionTypes をサポートする TensorFlow API」のセクションをご覧ください。
要件
フィールド型
すべてのフィールド(インスタンス変数)を宣言する必要があり、各フィールドに型注釈を指定する必要があります。次の型注釈がサポートされています。
型 | 例 |
---|---|
Python 整数 | i: int |
Python フロート | f: float |
Python 文字列 | s: str |
Python ブール値 | b: bool |
Python None |
n: None |
テンソル形状 | shape: tf.TensorShape |
テンソル dtype |
dtype: tf.DType |
テンソル | t: tf.Tensor |
拡張型 | mt: MyMaskedTensor |
不規則なテンソル | rt: tf.RaggedTensor |
スパーステンソル | st: tf.SparseTensor |
インデックススライス | s: tf.IndexedSlices |
オプションのテンソル | o: tf.experimental.Optional |
型結合 | int_or_float: typing.Union[int, float] |
タプル | params: typing.Tuple[int, float, tf.Tensor, int] |
可変長タプル | lengths: typing.Tuple[int, ...] |
マッピング | tags: typing.Mapping[str, tf.Tensor] |
オプションの値 | weight: typing.Optional[tf.Tensor] |
可変性
拡張型は不変である必要があります。これにより、TensorFlow のグラフトレースメカニズムによって適切に追跡できるようになります。拡張型の値を変更する場合は、代わりに値を変換するメソッドを定義することを検討してください。たとえば、MaskedTensor
を変更する set_mask
メソッドを定義するのではなく、新しい MaskedTensor
を返す replace_mask
メソッドを定義できます。
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
def replace_mask(self, new_mask):
self.values.shape.assert_is_compatible_with(new_mask.shape)
return MaskedTensor(self.values, new_mask)
ExtensionType
によって追加される機能
ExtensionType
基底クラスは、次の機能を提供します。
- コンストラクタ(
__init__
)。 - 出力可能な表現メソッド(
__repr__
)。 - 等価演算子と不等価演算子(
__eq__
)。 - 検証メソッド(
__validate__
)。 - 不変性の強制。
- ネストされた
TypeSpec
。 - テンソル API ディスパッチのサポート。
この機能のカスタマイズの詳細については、以下の「 ExtensionType
のカスタマイズ」セクションに移動してください。
コンストラクタ
ExtensionType
によって追加されたコンストラクタは、各フィールドを名前付き引数として(クラス定義にリストされている順序で)受け取ります。このコンストラクタは、各パラメーターを型チェックし、必要に応じて変換します。特に、Tensor
フィールドは tf.convert_to_tensor
を使用して変換されます。 Tuple
フィールドは tuple
に変換されます。 Mapping
フィールドは不変の dict に変換されます。
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
# Constructor takes one parameter for each field.
mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
mask=[[True, True, False], [True, False, True]])
# Fields are type-checked and converted to the declared types.
# For example, `mt.values` is converted to a Tensor.
print(mt.values)
tf.Tensor( [[1 2 3] [4 5 6]], shape=(2, 3), dtype=int32)
フィールド値を宣言された型に変換できない場合、コンストラクタは TypeError
を発生させます。
try:
MaskedTensor([1, 2, 3], None)
except TypeError as e:
print(f"Got expected TypeError: {e}")
Got expected TypeError: mask: expected a Tensor, got 'NoneType'
フィールドのデフォルト値は、クラスレベルで値を設定することによって指定できます。
class Pencil(tf.experimental.ExtensionType):
color: str = "black"
has_erasor: bool = True
length: tf.Tensor = 1.0
Pencil()
Pencil(color='black', has_erasor=True, length=<tf.Tensor: shape=(), dtype=float32, numpy=1.0>)
Pencil(length=0.5, color="blue")
Pencil(color='blue', has_erasor=True, length=<tf.Tensor: shape=(), dtype=float32, numpy=0.5>)
出力可能な表現
ExtensionType
は、クラス名と各フィールドの値を含むデフォルトの出力可能な表現メソッド(__repr__
)を追加します。
print(MaskedTensor(values=[1, 2, 3], mask=[True, True, False]))
MaskedTensor(values=<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>, mask=<tf.Tensor: shape=(3,), dtype=bool, numpy=array([ True, True, False])>)
等値演算子
ExtensionType
は、2 つの値が同じ型を持ち、すべてのフィールドが等しい場合に等しいと見なすデフォルトの等価演算子(__eq__
および __ne__
)を追加します。テンソルフィールドは、同じ形状を持ち、すべての要素に対して要素ごとに等しい場合、等しいと見なされます。
a = MaskedTensor([1, 2], [True, False])
b = MaskedTensor([[3, 4], [5, 6]], [[False, True], [True, True]])
print(f"a == a: {a==a}")
print(f"a == b: {a==b}")
print(f"a == a.values: {a==a.values}")
a == a: True a == b: False a == a.values: False
注意: いずれかのフィールドに Tensor
が含まれている場合、__eq__
は(Python ブール値ではなく)スカラーブール値 Tensor
を返す場合があります。
検証メソッド
ExtensionType
は、フィールドの検証チェックを実行するためにオーバーライドできる __validate__
メソッドを追加します。コンストラクタが呼び出された後、フィールドが型チェックされ、宣言された型に変換された後に実行されるため、すべてのフィールドの型は宣言された型であると想定できます。
次の例では、MaskedTensor
を更新して、そのフィールドの shape
と dtype
を検証します。
class MaskedTensor(tf.experimental.ExtensionType):
"""A tensor paired with a boolean mask, indicating which values are valid."""
values: tf.Tensor
mask: tf.Tensor
def __validate__(self):
self.values.shape.assert_is_compatible_with(self.mask.shape)
assert self.mask.dtype.is_bool, 'mask.dtype must be bool'
try:
MaskedTensor([1, 2, 3], [0, 1, 0]) # Wrong `dtype` for mask.
except AssertionError as e:
print(f"Got expected AssertionError: {e}")
Got expected AssertionError: mask.dtype must be bool
try:
MaskedTensor([1, 2, 3], [True, False]) # shapes don't match.
except ValueError as e:
print(f"Got expected ValueError: {e}")
Got expected ValueError: Shapes (3,) and (2,) are incompatible
不変性の強制
ExtensionType
は __setattr__
と __delattr__
メソッドをオーバーライドして突然変異を防ぎ、拡張型の値が不変であることを保証します。
mt = MaskedTensor([1, 2, 3], [True, False, True])
try:
mt.mask = [True, True, True]
except AttributeError as e:
print(f"Got expected AttributeError: {e}")
Got expected AttributeError: Cannot mutate attribute `mask` outside the custom constructor of ExtensionType.
try:
mt.mask[0] = False
except TypeError as e:
print(f"Got expected TypeError: {e}")
Got expected TypeError: 'tensorflow.python.framework.ops.EagerTensor' object does not support item assignment
try:
del mt.mask
except AttributeError as e:
print(f"Got expected AttributeError: {e}")
Got expected AttributeError: Cannot mutate attribute `mask` outside the custom constructor of ExtensionType.
ネストされた TypeSpec
各 ExtensionType
クラスには対応する TypeSpec
クラスがあり、これは自動的に作成され、<extension_type_name>.Spec
として保存されます。
このクラスは、ネストされたテンソルの値以外の値からすべての情報を取得します。特に、値の TypeSpec
は、ネストされたテンソル、ExtensionType、または CompositeTensor をその TypeSpec
に置き換えることによって作成されます。
class Player(tf.experimental.ExtensionType):
name: tf.Tensor
attributes: Mapping[str, tf.Tensor]
anne = Player("Anne", {"height": 8.3, "speed": 28.1})
anne_spec = tf.type_spec_from_value(anne)
print(anne_spec.name) # Records `dtype` and `shape`, but not the string value.
print(anne_spec.attributes) # Records keys and TensorSpecs for values.
TensorSpec(shape=(), dtype=tf.string, name=None) ImmutableDict({'height': TensorSpec(shape=(), dtype=tf.float32, name=None), 'speed': TensorSpec(shape=(), dtype=tf.float32, name=None)})
TypeSpec
値は明示的に構築することも、 tf.type_spec_from_value
を使用して ExtensionType
値から構築することもできます。
spec1 = Player.Spec(name=tf.TensorSpec([], tf.float32), attributes={})
spec2 = tf.type_spec_from_value(anne)
TypeSpec
は、値を静的コンポーネントと動的コンポーネントに分割するために TensorFlow によって使用されます。
- 静的コンポーネント(グラフ構築時に固定される)は
tf.TypeSpec
でエンコードされます。 - 動的コンポーネント(グラフが実行されるたびに変化する可能性があります)は、
tf.Tensor
のリストとしてエンコードされます。
たとえば、tf.function
は、引数に以前は見られなかった TypeSpec
があるときはいつでも、そのラップされた関数を再トレースします。
@tf.function
def anonymize_player(player):
print("<<TRACING>>")
return Player("<anonymous>", player.attributes)
# Function gets traced (first time the function has been called):
anonymize_player(Player("Anne", {"height": 8.3, "speed": 28.1}))
<<TRACING>> Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=8.3>, 'speed': <tf.Tensor: shape=(), dtype=float32, numpy=28.1>}))
# Function does NOT get traced (same TypeSpec: just tensor values changed)
anonymize_player(Player("Bart", {"height": 8.1, "speed": 25.3}))
Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=8.1>, 'speed': <tf.Tensor: shape=(), dtype=float32, numpy=25.3>}))
# Function gets traced (new TypeSpec: keys for attributes changed):
anonymize_player(Player("Chuck", {"height": 11.0, "jump": 5.3}))
<<TRACING>> Player(name=<tf.Tensor: shape=(), dtype=string, numpy=b'<anonymous>'>, attributes=ImmutableDict({'height': <tf.Tensor: shape=(), dtype=float32, numpy=11.0>, 'jump': <tf.Tensor: shape=(), dtype=float32, numpy=5.3>}))
詳細については、tf.function ガイドをご覧ください。
ExtensionType
のカスタマイズ
単純にフィールドとその型を宣言するだけでなく、拡張型は次のことができます。
- デフォルトの出力可能な表現(
__repr__
)をオーバーライドします。 - メソッドを定義します。
classmethod
とstaticmethod
を定義します。- プロパティを定義します。
- デフォルトのコンストラクタ(
__init__
)をオーバーライドします。 - デフォルトの等価演算子(
__eq__
)をオーバーライドします。 - 演算子を定義します(
__add__
や__lt__
など)。 - フィールドのデフォルト値を宣言します。
- サブクラスを定義します。
デフォルトの印刷可能な表現のオーバーライド
拡張型のこのデフォルトの文字列変換演算子をオーバーライドできます。次の例では、MaskedTensor
クラスを更新して、値が Eager モードで出力されるときに、より読みやすい文字列表現を生成します。
class MaskedTensor(tf.experimental.ExtensionType):
"""A tensor paired with a boolean mask, indicating which values are valid."""
values: tf.Tensor
mask: tf.Tensor # shape=values.shape; false for invalid values.
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
def masked_tensor_str(values, mask):
if isinstance(values, tf.Tensor):
if hasattr(values, 'numpy') and hasattr(mask, 'numpy'):
return f'<MaskedTensor {masked_tensor_str(values.numpy(), mask.numpy())}>'
else:
return f'MaskedTensor(values={values}, mask={mask})'
if len(values.shape) == 1:
items = [repr(v) if m else '_' for (v, m) in zip(values, mask)]
else:
items = [masked_tensor_str(v, m) for (v, m) in zip(values, mask)]
return '[%s]' % ', '.join(items)
mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],
mask=[[True, True, False], [True, False, True]])
print(mt)
<MaskedTensor [[1, 2, _], [4, _, 6]]>
メソッドの定義
拡張型は、通常の Python クラスと同様に、メソッドを定義できます。たとえば、MaskedTensor
型は、指定された default
値に置き換えられたマスクされた値を持つ self
のコピーを返す with_default
メソッドを定義できます。メソッドには、オプションで @tf.function
デコレータで注釈を付けることができます。
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
def with_default(self, default):
return tf.where(self.mask, self.values, default)
MaskedTensor([1, 2, 3], [True, False, True]).with_default(0)
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 0, 3], dtype=int32)>
classmethod
と staticmethod
の定義
拡張型は、@classmethod
および @staticmethod
デコレータを使用してメソッドを定義できます。たとえば、MaskedTensor
型は、任意の要素を特定の値でマスクするファクトリメソッドを定義できます。
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
@staticmethod
def from_tensor_and_value_to_mask(values, value_to_mask):
return MaskedTensor(values, values != value_to_mask)
x = tf.constant([[1, 0, 2], [3, 0, 0]])
MaskedTensor.from_tensor_and_value_to_mask(x, 0)
<MaskedTensor [[1, _, 2], [3, _, _]]>
プロパティの定義
拡張型は、通常の Python クラスと同様に、@property
デコレータを使用してプロパティを定義できます。たとえば、MaskedTensor
型は、値の dtype
の短縮形である dtype
プロパティを定義できます。
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
@property
def dtype(self):
return self.values.dtype
MaskedTensor([1, 2, 3], [True, False, True]).dtype
tf.int32
デフォルトのコンストラクタのオーバーライド
拡張型の既定のコンストラクタをオーバーライドできます。カスタムコンストラクタは、宣言されたフィールドごとに値を設定する必要があります。カスタムコンストラクタが戻った後、すべてのフィールドが型チェックされ、値が上記のように変換されます。
class Toy(tf.experimental.ExtensionType):
name: str
price: tf.Tensor
def __init__(self, name, price, discount=0):
self.name = name
self.price = price * (1 - discount)
print(Toy("ball", 5.0, discount=0.2)) # On sale -- 20% off!
Toy(name='ball', price=<tf.Tensor: shape=(), dtype=float32, numpy=4.0>)
または、デフォルトのコンストラクタをそのままにして、1 つ以上のファクトリメソッドを追加することも検討できます。例えば、次のとおりです。
class Toy(tf.experimental.ExtensionType):
name: str
price: tf.Tensor
@staticmethod
def new_toy_with_discount(name, price, discount):
return Toy(name, price * (1 - discount))
print(Toy.new_toy_with_discount("ball", 5.0, discount=0.2))
Toy(name='ball', price=<tf.Tensor: shape=(), dtype=float32, numpy=4.0>)
デフォルトの等価演算子(__eq__
)のオーバーライド
拡張型のデフォルトの __eq__
演算子をオーバーライドできます。次の例では、等しいかどうかを比較するときにマスクされた要素を無視するように MaskedTensor
を更新します。
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
def __eq__(self, other):
result = tf.math.equal(self.values, other.values)
result = result | ~(self.mask & other.mask)
return tf.reduce_all(result)
x = MaskedTensor([1, 2, 3, 4], [True, True, False, True])
y = MaskedTensor([5, 2, 0, 4], [False, True, False, True])
print(x == y)
tf.Tensor(True, shape=(), dtype=bool)
注意: 通常、__ne__
をオーバーライドする必要はありません。デフォルトの実装では単に __eq__
を呼び出して結果を否定するだけだからです。
前方参照の使用
フィールドの型がまだ定義されていない場合は、代わりに型の名前を含む文字列を使用できます。次の例では、Node
型がまだ(完全に)定義されていないため、文字列 "Node"
を使用して children
フィールドに注釈を付けています。
class Node(tf.experimental.ExtensionType):
value: tf.Tensor
children: Tuple["Node", ...] = ()
Node(3, [Node(5), Node(2)])
Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=3>, children=(Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=5>, children=()), Node(value=<tf.Tensor: shape=(), dtype=int32, numpy=2>, children=())))
サブクラスの定義
拡張型は、標準の Python 構文を使用してサブクラス化できます。拡張型のサブクラスは、新しいフィールド、メソッド、およびプロパティを追加できます。コンストラクタ、出力可能な表現、および等値演算子をオーバーライドする場合があります。次の例では、3 つの Tensor
フィールドを使用してノード間の一連のエッジをエンコードする基本的な TensorGraph
クラスを定義します。次に、Tensor
フィールドを追加して各ノードの「特徴量値」を記録するサブクラスを定義します。サブクラスは、特徴量値をエッジに沿って伝播するメソッドも定義します。
class TensorGraph(tf.experimental.ExtensionType):
num_nodes: tf.Tensor
edge_src: tf.Tensor # edge_src[e] = index of src node for edge e.
edge_dst: tf.Tensor # edge_dst[e] = index of dst node for edge e.
class TensorGraphWithNodeFeature(TensorGraph):
node_features: tf.Tensor # node_features[n] = feature value for node n.
def propagate_features(self, weight=1.0) -> 'TensorGraphWithNodeFeature':
updates = tf.gather(self.node_features, self.edge_src) * weight
new_node_features = tf.tensor_scatter_nd_add(
self.node_features, tf.expand_dims(self.edge_dst, 1), updates)
return TensorGraphWithNodeFeature(
self.num_nodes, self.edge_src, self.edge_dst, new_node_features)
g = TensorGraphWithNodeFeature( # Edges: 0->1, 4->3, 2->2, 2->1
num_nodes=5, edge_src=[0, 4, 2, 2], edge_dst=[1, 3, 2, 1],
node_features=[10.0, 0.0, 2.0, 5.0, -1.0, 0.0])
print("Original features:", g.node_features)
print("After propagating:", g.propagate_features().node_features)
Original features: tf.Tensor([10. 0. 2. 5. -1. 0.], shape=(6,), dtype=float32) After propagating: tf.Tensor([10. 12. 4. 4. -1. 0.], shape=(6,), dtype=float32)
プライベートフィールドの定義
拡張型のフィールドは、アンダースコアを(標準の Python 規則に従って)プレフィックスとして付けることにより、非公開としてマークすることができます。これは、TensorFlow がフィールドを処理する方法にはまったく影響しません。これらのフィールドがプライベートであることを拡張型のユーザーに通知するだけです。
ExtensionType
の TypeSpec
のカスタマイズ
各 ExtensionType
クラスには対応する TypeSpec
クラスがあり、これは自動的に作成され、<extension_type_name>.Spec
として保存されます。詳細については、上記の「ネストされた TypeSpec」セクションをご覧ください。
TypeSpec
をカスタマイズするには、Spec
という名前の独自のネストされたクラスを定義するだけで、ExtensionType
は自動的に構築された TypeSpec
の基礎としてそれを使用します。次の方法で Spec
クラスをカスタマイズできます。
- デフォルトの出力可能な表現のオーバーライド。
- デフォルトのコンストラクタのオーバーライド。
- メソッド、
classmethod
、staticmethod
、およびプロパティの定義。
次の例では、MaskedTensor.Spec
クラスをカスタマイズして使いやすくしています。
class MaskedTensor(tf.experimental.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
shape = property(lambda self: self.values.shape)
dtype = property(lambda self: self.values.dtype)
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
def with_values(self, new_values):
return MaskedTensor(new_values, self.mask)
class Spec:
def __init__(self, shape, dtype=tf.float32):
self.values = tf.TensorSpec(shape, dtype)
self.mask = tf.TensorSpec(shape, tf.bool)
def __repr__(self):
return f"MaskedTensor.Spec(shape={self.shape}, dtype={self.dtype})"
shape = property(lambda self: self.values.shape)
dtype = property(lambda self: self.values.dtype)
注意: カスタム Spec
クラスは、元の ExtensionType
で宣言されなかったインスタンス変数を使用することはできません。
テンソル API ディスパッチ
拡張型は、tf.Tensor
型によって定義されたインターフェースを特殊化または拡張するという意味で、「テンソルのような」ものにすることができます。テンソルのような拡張型の例には、RaggedTensor
、SparseTensor
、および MaskedTensor
が含まれます。ディスパッチデコレータは、テンソルのような拡張型に適用された場合に、TensorFlow 演算のデフォルトの動作をオーバーライドするために使用できます。 TensorFlow は現在、3 つのディスパッチデコレータを定義しています。
@tf.experimental.dispatch_for_api(tf_api)
@tf.experimental.dispatch_for_unary_elementwise_apis(x_type)
@tf.experimental.dispatch_for_binary_elementwise_apis(x_type, y_type)
単一の API のディスパッチ
tf.experimental.dispatch_for_api
デコレータは、指定されたシグネチャで呼び出されると、指定された TensorFlow 演算のデフォルトの動作をオーバーライドします。たとえば、このデコレータを使用して、tf.stack
が MaskedTensor
値を処理する方法を指定できます。
@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack(values: List[MaskedTensor], axis = 0):
return MaskedTensor(tf.stack([v.values for v in values], axis),
tf.stack([v.mask for v in values], axis))
これは、MaskedTensor
値のリストで呼び出されるたびに、tf.stack
のデフォルトの実装をオーバーライドします(values
引数には、typing.List[MaskedTensor]
で注釈が付けられているためです)。
x = MaskedTensor([1, 2, 3], [True, True, False])
y = MaskedTensor([4, 5, 6], [False, True, True])
tf.stack([x, y])
<MaskedTensor [[1, 2, _], [_, 5, 6]]>
tf.stack
が混在した MaskedTensor
値と Tensor
値のリストを処理できるようにするには、values
パラメータの型注釈を設定し直し、関数の本体を適切に更新します。
tf.experimental.unregister_dispatch_for(masked_stack)
def convert_to_masked_tensor(x):
if isinstance(x, MaskedTensor):
return x
else:
return MaskedTensor(x, tf.ones_like(x, tf.bool))
@tf.experimental.dispatch_for_api(tf.stack)
def masked_stack_v2(values: List[Union[MaskedTensor, tf.Tensor]], axis = 0):
values = [convert_to_masked_tensor(v) for v in values]
return MaskedTensor(tf.stack([v.values for v in values], axis),
tf.stack([v.mask for v in values], axis))
x = MaskedTensor([1, 2, 3], [True, True, False])
y = tf.constant([4, 5, 6])
tf.stack([x, y, x])
<MaskedTensor [[1, 2, _], [4, 5, 6], [1, 2, _]]>
オーバーライドできる API のリストについては、tf.experimental.dispatch_for_api
の API ドキュメントをご覧ください。
すべての単項要素ごとの API のディスパッチ
tf.experimental.dispatch_for_unary_elementwise_apis
デコレータは、最初の引数(通常は x
という名前)の値が型注釈 x_type
と一致する場合はいつでも、すべての単項要素ごとの演算(tf.math.cos
など)のデフォルトの動作をオーバーライドします。装飾された関数は、次の 2 つの引数を取る必要があります。
api_func
: 単一のパラメータを取り、要素ごとの演算を実行する関数(たとえば、tf.abs
)。x
: 要素ごとの演算の最初の引数。
次の例では、MaskedTensor
型を処理するためにすべての単項要素ごとの演算を更新します。
@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def masked_tensor_unary_elementwise_api_handler(api_func, x):
return MaskedTensor(api_func(x.values), x.mask)
MaskedTensor
で単項要素ごとの演算が呼び出されるたびに、この関数が使用されるようになりました。
x = MaskedTensor([1, -2, -3], [True, False, True])
print(tf.abs(x))
<MaskedTensor [1, _, 3]>
print(tf.ones_like(x, dtype=tf.float32))
<MaskedTensor [1.0, _, 1.0]>
バイナリのすべての要素ごとの API のディスパッチ
同様に、tf.experimental.dispatch_for_binary_elementwise_apis
を使用して、MaskedTensor
型を処理するためにすべてのバイナリ要素ごとの演算を更新できます。
@tf.experimental.dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
def masked_tensor_binary_elementwise_api_handler(api_func, x, y):
return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)
x = MaskedTensor([1, -2, -3], [True, False, True])
y = MaskedTensor([[4], [5]], [[True], [False]])
tf.math.add(x, y)
<MaskedTensor [[5, _, 1], [_, _, _]]>
オーバーライドされる要素ごとの API のリストについては、tf.experimental.dispatch_for_unary_elementwise_apis
および tf.experimental.dispatch_for_binary_elementwise_apis
の API ドキュメントをご覧ください。
バッチ処理可能な ExtensionType
1 つのインスタンスを使用して値のバッチを表すことができる場合、ExtensionType
はバッチ可能です。通常、これはネストされたすべての Tensor
にバッチディメンションを追加することによって実現されます。次の TensorFlow API では、拡張型の入力がバッチ可能である必要があります。
tf.data.Dataset
(batch
、unbatch
、from_tensor_slices
)tf.keras
(fit
、evaluate
、predict
)tf.map_fn
デフォルトでは、BatchableExtensionType
は、ネストされた Tensor
、CompositeTensor
、およびExtensionType
をバッチ処理することにより、バッチ処理された値を作成します。これがクラスに適していない場合は、tf.experimental.ExtensionTypeBatchEncoder
を使用してこのデフォルトの動作をオーバーライドする必要があります。例えば、個々のスパーステンソルの values
、indices
、および dense_shape
を単純にスタックしてtf.SparseTensor
値のバッチを作成することは適切ではありません。ほとんどの場合、これらのテンソルの形状には互換性がないため、スタックできません。たとえできたとしても、結果は有効な SparseTensor
にはなりません。
注意: BatchableExtensionType
は、tf.stack
、tf.concat
、tf.slice
などのディスパッチャを自動的に定義しません。クラスをこれらの API でサポートする必要がある場合は、上記のディスパッチデコレータを使用してください。
BatchableExtensionType
の例: Network
例として、負荷分散に使用される単純な Network
クラスを考えてみましょう。これは、各ノードで実行するために残っている作業の量と、ノード間で作業を移動するために使用できる帯域幅を追跡します。
class Network(tf.experimental.ExtensionType): # This version is not batchable.
work: tf.Tensor # work[n] = work left to do at node n
bandwidth: tf.Tensor # bandwidth[n1, n2] = bandwidth from n1->n2
net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])
この型をバッチ処理可能にするには、ベースタイプを BatchableExtensionType
に変更し、各フィールドの形状を調整してオプションのバッチの次元を含めます。次の例では、バッチ形状を追跡するための shape
フィールドも追加します。この shape
フィールドは tf.data.Dataset
または tf.map_fn
では必要ありませんが、tf.keras
では必要です。
class Network(tf.experimental.BatchableExtensionType):
shape: tf.TensorShape # batch shape. A single network has shape=[].
work: tf.Tensor # work[*shape, n] = work left to do at node n
bandwidth: tf.Tensor # bandwidth[*shape, n1, n2] = bandwidth from n1->n2
def __init__(self, work, bandwidth):
self.work = tf.convert_to_tensor(work)
self.bandwidth = tf.convert_to_tensor(bandwidth)
work_batch_shape = self.work.shape[:-1]
bandwidth_batch_shape = self.bandwidth.shape[:-2]
self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)
def __repr__(self):
return network_repr(self)
def network_repr(network):
work = network.work
bandwidth = network.bandwidth
if hasattr(work, 'numpy'):
work = ' '.join(str(work.numpy()).split())
if hasattr(bandwidth, 'numpy'):
bandwidth = ' '.join(str(bandwidth.numpy()).split())
return (f"<Network shape={network.shape} work={work} bandwidth={bandwidth}>")
net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])
net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])
batch_of_networks = Network(
work=tf.stack([net1.work, net2.work]),
bandwidth=tf.stack([net1.bandwidth, net2.bandwidth]))
print(f"net1={net1}")
print(f"net2={net2}")
print(f"batch={batch_of_networks}")
net1=<Network shape=() work=[5. 3. 8.] bandwidth=[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]]> net2=<Network shape=() work=[3. 4. 2.] bandwidth=[[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]> batch=<Network shape=(2,) work=[[5. 3. 8.] [3. 4. 2.]] bandwidth=[[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]] [[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]]>
その後、tf.data.Dataset
を使用して、ネットワークのバッチを反復処理できます。
dataset = tf.data.Dataset.from_tensor_slices(batch_of_networks)
for i, network in enumerate(dataset):
print(f"Batch element {i}: {network}")
Batch element 0: <Network shape=() work=[5. 3. 8.] bandwidth=[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]]> Batch element 1: <Network shape=() work=[3. 4. 2.] bandwidth=[[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]>
また、map_fn
を使用して、関数を各バッチ要素に適用することもできます。
def balance_work_greedy(network):
delta = (tf.expand_dims(network.work, -1) - tf.expand_dims(network.work, -2))
delta /= 4
delta = tf.maximum(tf.minimum(delta, network.bandwidth), -network.bandwidth)
new_work = network.work + tf.reduce_sum(delta, -1)
return Network(new_work, network.bandwidth)
tf.map_fn(balance_work_greedy, batch_of_networks)
<Network shape=(2,) work=[[5.5 1.25 9.25] [3. 4.75 1.25]] bandwidth=[[[0. 2. 0.] [2. 0. 3.] [0. 3. 0.]] [[0. 2. 2.] [2. 0. 2.] [2. 2. 0.]]]>
ExtensionType
をサポートする TensorFlow API
@tf.function
tf.function
は、Python 関数の TensorFlow グラフを事前計算するデコレータで、TensorFlow コードのパフォーマンスを大幅に改善できます。拡張型の値は、@tf.function
でデコレートされた関数で透過的に使用できます。
class Pastry(tf.experimental.ExtensionType):
sweetness: tf.Tensor # 2d embedding that encodes sweetness
chewiness: tf.Tensor # 2d embedding that encodes chewiness
@tf.function
def combine_pastry_features(x: Pastry):
return (x.sweetness + x.chewiness) / 2
cookie = Pastry(sweetness=[1.2, 0.4], chewiness=[0.8, 0.2])
combine_pastry_features(cookie)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1. , 0.3], dtype=float32)>
tf.function
の input_signature
を明示的に指定する場合は、拡張型の TypeSpec
を使用して指定できます。
pastry_spec = Pastry.Spec(tf.TensorSpec([2]), tf.TensorSpec(2))
@tf.function(input_signature=[pastry_spec])
def increase_sweetness(x: Pastry, delta=1.0):
return Pastry(x.sweetness + delta, x.chewiness)
increase_sweetness(cookie)
Pastry(sweetness=<tf.Tensor: shape=(2,), dtype=float32, numpy=array([2.2, 1.4], dtype=float32)>, chewiness=<tf.Tensor: shape=(2,), dtype=float32, numpy=array([0.8, 0.2], dtype=float32)>)
具象関数
具象関数は、tf.function
で構築された個別のトレース済みグラフをカプセル化します。拡張型は、具象関数で透過的に使用できます。
cf = combine_pastry_features.get_concrete_function(pastry_spec)
cf(cookie)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1. , 0.3], dtype=float32)>
制御フロー演算
拡張型は以下の TensorFlow の制御フロー演算でサポートされています。
# Example: using tf.cond to select between two MaskedTensors. Note that the
# two MaskedTensors don't need to have the same shape.
a = MaskedTensor([1., 2, 3], [True, False, True])
b = MaskedTensor([22., 33, 108, 55], [True, True, True, False])
condition = tf.constant(True)
print(tf.cond(condition, lambda: a, lambda: b))
<MaskedTensor [1.0, _, 3.0]>
# Example: using tf.while_loop with MaskedTensor.
cond = lambda i, _: i < 10
def body(i, mt):
return i + 1, mt.with_values(mt.values + 3 / 7)
print(tf.while_loop(cond, body, [0, b])[1])
<MaskedTensor [26.285717, 37.285698, 112.285736, _]>
Autograph 制御フロー
拡張型は、tf.function
の制御フローステートメントでもサポートされます(autograph を使用)。次の例では、if
ステートメントと for
ステートメントは、拡張型をサポートする tf.cond
および tf.while_loop
演算に自動的に変換されます。
@tf.function
def fn(x, b):
if b:
x = MaskedTensor(x, tf.less(x, 0))
else:
x = MaskedTensor(x, tf.greater(x, 0))
for i in tf.range(5 if b else 7):
x = x.with_values(x.values + 1 / 2)
return x
print(fn(tf.constant([1., -2, 3]), tf.constant(True)))
print(fn(tf.constant([1., -2, 3]), tf.constant(False)))
<MaskedTensor [_, 0.5, _]> <MaskedTensor [4.5, _, 6.5]>
Keras
tf.keras は、ディープラーニングモデルを構築およびトレーニングするための TensorFlow の高レベル API です。拡張型は、入力として Keras モデルに渡され、Keras レイヤー間で渡され、Keras モデルによって返されます。 Keras は現在、拡張型に対して 2 つの要件を課しています。
- バッチ可能でなければならない(上記の「バッチ処理可能な
ExtensionType
」に移動してください)。 shape
という名前のフィールドまたはプロパティが必要である。shape[0]
はバッチの次元と見なされます。
次の 2 つのサブセクションでは、Keras で拡張型を使用する方法についての例を示します。
Keras の例: Network
最初の例として、上記の「バッチ処理可能な ExtensionType
」セクションで定義された Network
クラスを考えてみましょう。これは、ノード間の作業の負荷分散に使用できます。ここではその定義が繰り返されています。
class Network(tf.experimental.BatchableExtensionType):
shape: tf.TensorShape # batch shape. A single network has shape=[].
work: tf.Tensor # work[*shape, n] = work left to do at node n
bandwidth: tf.Tensor # bandwidth[*shape, n1, n2] = bandwidth from n1->n2
def __init__(self, work, bandwidth):
self.work = tf.convert_to_tensor(work)
self.bandwidth = tf.convert_to_tensor(bandwidth)
work_batch_shape = self.work.shape[:-1]
bandwidth_batch_shape = self.bandwidth.shape[:-2]
self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)
def __repr__(self):
return network_repr(self)
single_network = Network( # A single network with 4 nodes.
work=[8.0, 5, 12, 2],
bandwidth=[[0.0, 1, 2, 2], [1, 0, 0, 2], [2, 0, 0, 1], [2, 2, 1, 0]])
batch_of_networks = Network( # Batch of 2 networks, each w/ 2 nodes.
work=[[8.0, 5], [3, 2]],
bandwidth=[[[0.0, 1], [1, 0]], [[0, 2], [2, 0]]])
Network
を処理する新しい Keras レイヤーを定義できます。
class BalanceNetworkLayer(tf.keras.layers.Layer):
"""Layer that balances work between nodes in a network.
Shifts work from more busy nodes to less busy nodes, constrained by bandwidth.
"""
def call(self, inputs):
# This function is defined above in the "Batchable `ExtensionType`s" section.
return balance_work_greedy(inputs)
次に、これらのレイヤーを使用して単純なモデルを作成できます。ExtensionType
をモデルにフィードするには、type_spec
が拡張型の TypeSpec
に設定された tf.keras.layer.Input
レイヤーを使用できます。Keras モデルをバッチで使用する場合、type_spec
をバッチの次元に含める必要があります。
input_spec = Network.Spec(shape=None,
work=tf.TensorSpec(None, tf.float32),
bandwidth=tf.TensorSpec(None, tf.float32))
model = tf.keras.Sequential([
tf.keras.layers.Input(type_spec=input_spec),
BalanceNetworkLayer(),
])
最後に、モデルを単一のネットワークとネットワークのバッチに適用できます。
model(single_network)
<Network shape=() work=[ 9.25 5. 14. -1.25] bandwidth=[[0. 1. 2. 2.] [1. 0. 0. 2.] [2. 0. 0. 1.] [2. 2. 1. 0.]]>
model(batch_of_networks)
<Network shape=(2,) work=[[8.75 4.25] [3.25 1.75]] bandwidth=[[[0. 1.] [1. 0.]] [[0. 2.] [2. 0.]]]>
Keras の例: MaskedTensor
この例では、MaskedTensor
が Keras
をサポートするように拡張されています。shape
は、values
フィールドから計算されるプロパティとして定義されます。Keras では、このプロパティを拡張型とその TypeSpec
の両方に追加する必要があります。MaskedTensor
は __name__
変数も定義します。これは、SavedModel
のシリアル化(下記)に必要です。
class MaskedTensor(tf.experimental.BatchableExtensionType):
# __name__ is required for serialization in SavedModel; see below for details.
__name__ = 'extension_type_colab.MaskedTensor'
values: tf.Tensor
mask: tf.Tensor
shape = property(lambda self: self.values.shape)
dtype = property(lambda self: self.values.dtype)
def with_default(self, default):
return tf.where(self.mask, self.values, default)
def __repr__(self):
return masked_tensor_str(self.values, self.mask)
class Spec:
def __init__(self, shape, dtype=tf.float32):
self.values = tf.TensorSpec(shape, dtype)
self.mask = tf.TensorSpec(shape, tf.bool)
shape = property(lambda self: self.values.shape)
dtype = property(lambda self: self.values.dtype)
def with_shape(self):
return MaskedTensor.Spec(tf.TensorSpec(shape, self.values.dtype),
tf.TensorSpec(shape, self.mask.dtype))
次に、ディスパッチデコレータを使用して、いくつかの TensorFlow API のデフォルトの動作をオーバーライドします。これらの API は標準の Keras レイヤー(Dense
レイヤーなど)で使用されるため、これらをオーバーライドすると、これらのレイヤーを MaskedTensor
で使用できるようになります。この例では、マスクされたテンソルの matmul
は、マスクされた値をゼロとして扱う(つまり、それらを積に含めない)ように定義されています。
@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def unary_elementwise_op_handler(op, x):
return MaskedTensor(op(x.values), x.mask)
@tf.experimental.dispatch_for_binary_elementwise_apis(
Union[MaskedTensor, tf.Tensor],
Union[MaskedTensor, tf.Tensor])
def binary_elementwise_op_handler(op, x, y):
x = convert_to_masked_tensor(x)
y = convert_to_masked_tensor(y)
return MaskedTensor(op(x.values, y.values), x.mask & y.mask)
@tf.experimental.dispatch_for_api(tf.matmul)
def masked_matmul(a: MaskedTensor, b,
transpose_a=False, transpose_b=False,
adjoint_a=False, adjoint_b=False,
a_is_sparse=False, b_is_sparse=False,
output_type=None):
if isinstance(a, MaskedTensor):
a = a.with_default(0)
if isinstance(b, MaskedTensor):
b = b.with_default(0)
return tf.matmul(a, b, transpose_a, transpose_b, adjoint_a,
adjoint_b, a_is_sparse, b_is_sparse, output_type)
次に、標準の Keras レイヤーを使用して、MaskedTensor
入力を受け入れる Kerasモデルを構築できます。
input_spec = MaskedTensor.Spec([None, 2], tf.float32)
masked_tensor_model = tf.keras.Sequential([
tf.keras.layers.Input(type_spec=input_spec),
tf.keras.layers.Dense(16, activation="relu"),
tf.keras.layers.Dense(1)])
masked_tensor_model.compile(loss='binary_crossentropy', optimizer='rmsprop')
a = MaskedTensor([[1., 2], [3, 4], [5, 6]],
[[True, False], [False, True], [True, True]])
masked_tensor_model.fit(a, tf.constant([[1], [0], [1]]), epochs=3)
print(masked_tensor_model(a))
Epoch 1/3 1/1 [==============================] - 1s 1s/step - loss: 0.7110 Epoch 2/3 1/1 [==============================] - 0s 6ms/step - loss: 0.6215 Epoch 3/3 1/1 [==============================] - 0s 5ms/step - loss: 0.5670 tf.Tensor( [[ 0.20307031] [-0.32614586] [ 1.0157624 ]], shape=(3, 1), dtype=float32)
SavedModel
SavedModel は、シリアル化された TensorFlow プログラムで、重みと計算の両方が含まれます。Keras モデルまたはカスタムモデルから構築できます。いずれの場合でも、拡張型は SavedModel によって定義された関数とメソッドで透過的に使用できます。
SavedModel は、拡張型に __name__
フィールドがある限り、拡張型を処理するモデル、レイヤー、および関数を保存できます。この名前は拡張型を登録するために使用されるため、モデルを読み込む際に見つけることができます。
例: Keras モデルを保存する
拡張型を使用する Keras モデルは、SavedModel
を使用して保存できます。
masked_tensor_model_path = tempfile.mkdtemp()
tf.saved_model.save(masked_tensor_model, masked_tensor_model_path)
imported_model = tf.saved_model.load(masked_tensor_model_path)
imported_model(a)
WARNING:absl:Function `_wrapped_model` contains input name(s) args_0 with unsupported characters which will be renamed to args_0_1 in the SavedModel. INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp3_ax5e_0/assets INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp3_ax5e_0/assets <tf.Tensor: shape=(3, 1), dtype=float32, numpy= array([[ 0.20307031], [-0.32614586], [ 1.0157624 ]], dtype=float32)>
例: カスタムモデルを保存する
SavedModel は、拡張型を処理する関数を持つカスタム tf.Module
サブクラスを保存するためにも使用できます。
class CustomModule(tf.Module):
def __init__(self, variable_value):
super().__init__()
self.v = tf.Variable(variable_value)
@tf.function
def grow(self, x: MaskedTensor):
"""Increase values in `x` by multiplying them by `self.v`."""
return MaskedTensor(x.values * self.v, x.mask)
module = CustomModule(100.0)
module.grow.get_concrete_function(MaskedTensor.Spec(shape=None,
dtype=tf.float32))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
imported_model.grow(MaskedTensor([1., 2, 3], [False, True, False]))
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp9apon9h2/assets INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp9apon9h2/assets <MaskedTensor [_, 200.0, _]>
ExtensionType
が利用できない場合に SavedModel を読み込む
ExtensionType
を使用する SavedModel
を読み込んだけれども、その ExtensionType
が利用できない(つまり、インポートされていない)場合、警告が表示され、TensorFlow は「匿名拡張型」オブジェクトの使用にフォールバックします。このオブジェクトには元の型と同じフィールドがありますが、カスタムメソッドやプロパティなど、型に追加したカスタマイズはありません。
TensorFlow Serving で ExtensionType
を使用する
現在、TensorFlow Serving(および SavedModel の「シグネチャ」ディクショナリの他のコンシューマー)は、すべての入力と出力が生のテンソルである必要があります。拡張型を使用するモデルで TensorFlow Serving を使用する場合は、テンソルから拡張型の値を構成または分解するラッパーメソッドを追加できます。例えば、次のとおりです。
class CustomModuleWrapper(tf.Module):
def __init__(self, variable_value):
super().__init__()
self.v = tf.Variable(variable_value)
@tf.function
def var_weighted_mean(self, x: MaskedTensor):
"""Mean value of unmasked values in x, weighted by self.v."""
x = MaskedTensor(x.values * self.v, x.mask)
return (tf.reduce_sum(x.with_default(0)) /
tf.reduce_sum(tf.cast(x.mask, x.dtype)))
@tf.function()
def var_weighted_mean_wrapper(self, x_values, x_mask):
"""Raw tensor wrapper for var_weighted_mean."""
return self.var_weighted_mean(MaskedTensor(x_values, x_mask))
module = CustomModuleWrapper([3., 2., 8., 5.])
module.var_weighted_mean_wrapper.get_concrete_function(
tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.bool))
custom_module_path = tempfile.mkdtemp()
tf.saved_model.save(module, custom_module_path)
imported_model = tf.saved_model.load(custom_module_path)
x = MaskedTensor([1., 2., 3., 4.], [False, True, False, True])
imported_model.var_weighted_mean_wrapper(x.values, x.mask)
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpts_yhd88/assets INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpts_yhd88/assets <tf.Tensor: shape=(), dtype=float32, numpy=12.0>
Dataset
tf.data は、単純で再利用可能なピースから複雑な入力パイプラインを構築できる API です。そのコアデータ構造は tf.data.Dataset
で、一連の要素を表し、その各要素には 1 つ以上のコンポーネントが含まれます。
拡張型を使用した Dataset
の構築
Dataset.from_tensors
、Dataset.from_tensor_slices
、または Dataset.from_generator
を使用して、拡張型の値からデータセットを構築できます。
ds = tf.data.Dataset.from_tensors(Pastry(5, 5))
iter(ds).next()
Pastry(sweetness=<tf.Tensor: shape=(), dtype=int32, numpy=5>, chewiness=<tf.Tensor: shape=(), dtype=int32, numpy=5>)
mt = MaskedTensor(tf.reshape(range(20), [5, 4]), tf.ones([5, 4]))
ds = tf.data.Dataset.from_tensor_slices(mt)
for value in ds:
print(value)
<MaskedTensor [0, 1, 2, 3]> <MaskedTensor [4, 5, 6, 7]> <MaskedTensor [8, 9, 10, 11]> <MaskedTensor [12, 13, 14, 15]> <MaskedTensor [16, 17, 18, 19]>
def value_gen():
for i in range(2, 7):
yield MaskedTensor(range(10), [j%i != 0 for j in range(10)])
ds = tf.data.Dataset.from_generator(
value_gen, output_signature=MaskedTensor.Spec(shape=[10], dtype=tf.int32))
for value in ds:
print(value)
<MaskedTensor [_, 1, _, 3, _, 5, _, 7, _, 9]> <MaskedTensor [_, 1, 2, _, 4, 5, _, 7, 8, _]> <MaskedTensor [_, 1, 2, 3, _, 5, 6, 7, _, 9]> <MaskedTensor [_, 1, 2, 3, 4, _, 6, 7, 8, 9]> <MaskedTensor [_, 1, 2, 3, 4, 5, _, 7, 8, 9]>
拡張型を使用した Dataset
のバッチ処理とバッチ処理解除
拡張型を持つデータセットは、Dataset.batch
および Dataset.unbatch
を使用してバッチおよびバッチ解除できます。
batched_ds = ds.batch(2)
for value in batched_ds:
print(value)
<MaskedTensor [[_, 1, _, 3, _, 5, _, 7, _, 9], [_, 1, 2, _, 4, 5, _, 7, 8, _]]> <MaskedTensor [[_, 1, 2, 3, _, 5, 6, 7, _, 9], [_, 1, 2, 3, 4, _, 6, 7, 8, 9]]> <MaskedTensor [[_, 1, 2, 3, 4, 5, _, 7, 8, 9]]>
unbatched_ds = batched_ds.unbatch()
for value in unbatched_ds:
print(value)
<MaskedTensor [_, 1, _, 3, _, 5, _, 7, _, 9]> <MaskedTensor [_, 1, 2, _, 4, 5, _, 7, 8, _]> <MaskedTensor [_, 1, 2, 3, _, 5, 6, 7, _, 9]> <MaskedTensor [_, 1, 2, 3, 4, _, 6, 7, 8, 9]> <MaskedTensor [_, 1, 2, 3, 4, 5, _, 7, 8, 9]>