TensorFlow.org で表示 | Google Colab で実行 | GitHub でソースを表示 | ノートブックをダウンロード |
概要
TensorFlow の型昇格には 4 つのオプションがあります。
- デフォルトでは、混合型の演算に対し、TensorFlow は型を昇格する代わりにエラーを発します。
tf.numpy.experimental_enable_numpy_behavior()
を実行すると、TensorFlow が NumPy 型の昇格ルールを使用するように切り替えられます。- このドキュメントでは、TensorFlow 2.15 で提供予定の新しい 2 つのオプションについて説明します(現在は、
tf-nightly
で提供されています)。
pip install -q tf_nightly
注意: experimental_enable_numpy_behavior
は、TensorFlow のすべての動作を変更します。
セットアップ
import numpy as np
import tensorflow as tf
import tensorflow.experimental.numpy as tnp
print("Using TensorFlow version %s" % tf.__version__)
Using TensorFlow version 2.16.0-dev20240110
新しい型昇格の有効化
JAX のような型昇格を TF-Numpy で使用するには、TensorFlow で NumPy の動作を有効にする際に、dtype 変換モードとして 'all'
または 'safe'
のいずれかを指定します。
この新しい組(dtype_conversion_mode="all"
を使用)は結合的で可換であり、最終的にどのような幅の浮動小数になるかを制御するのが簡単になります(自動的により幅の広い float に変換しません)。ただし、オーバーフローと精度損失のリスクがいくらか導入されますが、dtype_conversion_mode="safe"
によってそれらのケースの明示的な処理が強制されます。2 つのモードについては、次のセクションで詳しく説明されています。
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
2 つのモード: ALL モードと SAFE モード
新しい型昇格システムでは、ALL
モードと SAFE
モードの 2 つのモードが導入されています。SAFE
モードは精度損失またはビット拡張となる「リスクのある」昇格の懸念を緩和するために使用されます。
Dtype
簡潔さの目的で、以下の略語を使用します。
b
はtf.bool
ですu8
はtf.uint8
ですi16
はtf.int16
ですi32
はtf.int32
ですbf16
はtf.bfloat16
ですf32
はtf.float32
ですf64
はtf.float64
ですi32*
は Python のint
または弱く型付けされたi32
ですf32*
は Python のfloat
または弱く型付けされたf32
ですc128*
は Python のcomplex
または弱く型付けされたc128
です
アスタリスク(*)は、対応する方が「弱い」ことを示します。そのような dtype は一時的にシステムによって推論されるため、他の dtype に従う可能性があります。この概念は、こちらでより詳しく説明されています。
精度を損失する演算の例
次の例では、i32
+ f32
は ALL
モードでは可能ですが、精度損失のリスクにより、SAFE
モードでは行えません。
# i32 + f32 returns a f32 result in ALL mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.constant(10, dtype = tf.int32)
b = tf.constant(5.0, dtype = tf.float32)
a + b # <tf.Tensor: shape=(), dtype=float32, numpy=15.0>
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet. <tf.Tensor: shape=(), dtype=float32, numpy=15.0>
# This promotion is not allowed in SAFE mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="safe")
a = tf.constant(10, dtype = tf.int32)
b = tf.constant(5.0, dtype = tf.float32)
try:
a + b
except TypeError as e:
print(f'{type(e)}: {e}') # TypeError: explicitly specify the dtype or switch to ALL mode.
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet. <class 'TypeError'>: In promotion mode PromoMode.SAFE, implicit dtype promotion between (<dtype: 'int32'>, weak=False) and (<dtype: 'float32'>, weak=False) is disallowed. You need to explicitly specify the dtype in your op, or relax your dtype promotion rules (such as from SAFE mode to ALL mode).
ビット拡張の演算の例
次の例において、i8
+ u32
は ALL
モードでは可能ですが、入力のビットの数よりも多いビットを使用するビット拡張により、SAFE
モードでは行えません。新しい型昇格セマンティクスでは必要なビット拡張のみが許可されることに注意してください。
# i8 + u32 returns an i64 result in ALL mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.constant(10, dtype = tf.int8)
b = tf.constant(5, dtype = tf.uint32)
a + b
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet. <tf.Tensor: shape=(), dtype=int64, numpy=15>
# This promotion is not allowed in SAFE mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="safe")
a = tf.constant(10, dtype = tf.int8)
b = tf.constant(5, dtype = tf.uint32)
try:
a + b
except TypeError as e:
print(f'{type(e)}: {e}') # TypeError: explicitly specify the dtype or switch to ALL mode.
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet. <class 'TypeError'>: In promotion mode PromoMode.SAFE, implicit dtype promotion between (<dtype: 'int8'>, weak=False) and (<dtype: 'uint32'>, weak=False) is disallowed. You need to explicitly specify the dtype in your op, or relax your dtype promotion rules (such as from SAFE mode to ALL mode).
格子に基づくシステム
型昇格格子
新しい型昇格の動作は、次の型昇格の格子を通じて決定されます。
より具体的には、2 つの型の間の昇格は、2 つのノード(ノード事態を含む)の最初の共通の子を見つけて決定されます。
たとえば、上のダイアグラムの場合、i8
と i32
の最初の共通の子は i32
です。この 2 つのノードは矢印の方向に進む際に最初に i32
で交差するためです。
もう 1 つの例とも同様に、u64
と f16
の間の昇格の結果の方は f16
となります。
型昇格テーブル
格子に従うと、以下のバイナリ昇格テーブルが生成されます。
注意: SAFE
モードでは、ハイライトされたセルは許可されません。ALL
モードではすべてのケースが許可されます。
新しい型昇格のメリット
新しい型昇格には、以下のメリットのある JAX のような格子ベースのシステムを採用します。
格子ベースのシステムのメリット
まず、格子ベースのシステムを使用することで、3 つの非常に重要な特性が確保されます。
- 存在: あらゆる型の組み合わせに固有の結果昇格型があります。
- 可換性:
a + b = b + a
- 結合性:
a + (b + c) = (a + b) + c
これらの 3 つの特性は、一貫性と予測可能性を備えた型昇格セマンティクスを構築する上で重要な特性です。
JAX のような格子系のメリット
JAX のような格子系のもう 1 つの重大なメリットは、符号なしの int の外側では、必要以上に広範なプロモーションをすべて回避することです。つまり、64 ビットの入力なしに 64 ビットの結果を取得することはできません。これは以前の型昇格で頻繁であった不要な 64 ビット値を回避するため、特にアクセラレータで処理する際に大きなメリットがあります。
ただし、これにはトレードオフがあります。float/integer が混合する昇格には精度損失の非常に強い傾向があることです。たとえば、下の例では i64
+ f16
は i64
を f16
に昇格してしまいます。
# The first input is promoted to f16 in ALL mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
tf.constant(1, tf.int64) + tf.constant(3.2, tf.float16) # <tf.Tensor: shape=(), dtype=float16, numpy=4.2>
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet. <tf.Tensor: shape=(), dtype=float16, numpy=4.2>
このような懸念を緩和するために、こういった「リスクのある」昇格を許可しない SAFE
モードを導入しました。
注意: 格子系の構築における設計上の考慮点については、JAX の型昇格セマンティクスの設計をご覧ください。
WeakTensor
概要
弱いテンソルとは、JAX における概念に似た「弱く型付けされた」テンソルです。
WeakTensor
の dtype は一時的にシステムによって推論され、他の dtype に従う可能性があります。この概念は、TF 値と、Python のスカラーリテラルのように明示的にユーザーが指定した型がない値の間で行われるバイナリ演算内で不要な型昇格が行われないようにするために、新しい型昇格に導入されています。
たとえば下の例では、tf.constant(1.2)
には特定の dtype がないため、「弱い」と見なされます。したがって、tf.constant(1.2)
は tf.constant(3.1, tf.float16)
の型に従い、f16
の出力結果が得られます。
tf.constant(1.2) + tf.constant(3.1, tf.float16) # <tf.Tensor: shape=(), dtype=float16, numpy=4.3>
<tf.Tensor: shape=(), dtype=float16, numpy=4.3>
WeakTensor の構造
WeakTensor は、dtype を指定せずにテンソルを作成した場合に作成され、その結果として WeakTensor となります。テンソルが「弱い」かどうかは、テンソルの文字列表現の最後にある weak 属性をチェックすることでわかります。
最初のケース: tf.constant
が、ユーザー指定の dtype のない入力で呼び出された場合。
tf.constant(5) # <tf.Tensor: shape=(), dtype=int32, numpy=5, weak=True>
<tf.Tensor: shape=(), dtype=int32, numpy=5, weak=True>
tf.constant([5.0, 10.0, 3]) # <tf.Tensor: shape=(3,), dtype=float32, numpy=array([ 5., 10., 3.], dtype=float32), weak=True>
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([ 5., 10., 3.], dtype=float32), weak=True>
# A normal Tensor is created when dtype arg is specified.
tf.constant(5, tf.int32) # <tf.Tensor: shape=(), dtype=int32, numpy=5>
<tf.Tensor: shape=(), dtype=int32, numpy=5>
2 つ目のケース: ユーザー指定の dtype のない入力が WeakTensor をサポートする API に渡された場合。
tf.math.abs([100.0, 4.0]) # <tf.Tensor: shape=(2,), dtype=float32, numpy=array([100., 4.], dtype=float32), weak=True>
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([100., 4.], dtype=float32), weak=True>
新しい型昇格をオンにした効果
以下は、新しい型昇格をオンにしたことによる変更の部分リストです。
- より一貫性のある予測可能な昇格結果。
- ビット拡張のリスクの軽減。
tf.Tensor
の数学的ダンダーメソッドでは、新しい型の昇格が使用されます。tf.constant
はWeakTensor
を戻せます。tf.constant
は、dtype
引数とは異なる dtype を持つテンソル入力が渡された場合に、暗黙的な変換を行えます。tf.Variable
インプレース演算(assign
、assign-add
、assign-sub
)で暗黙の変換が可能です。tnp.array(1)
とtnp.array(1.0)
は 32 ビット WeakTensor を返します。WeakTensor
が作成され、WeakTensor がサポートする単項およびバイナリ API に使用されます。
より一貫性のある予測可能な昇格結果
格子ベースのシステムを使用することで、新しい型昇格により、一貫性のある予測可能な型昇格結果を生成することができます。
以前の型昇格
以前の型昇格を使用すると、演算の順序の変更によって結果にばらつきが生じます。
# Setup
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="legacy")
a = np.array(1, dtype=np.int8)
b = tf.constant(1)
c = np.array(1, dtype=np.float16)
# (a + b) + c throws an InvalidArgumentError.
try:
tf.add(tf.add(a, b), c)
except tf.errors.InvalidArgumentError as e:
print(f'{type(e)}: {e}') # InvalidArgumentError
<class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>: cannot compute AddV2 as input #1(zero-based) was expected to be a int8 tensor but is a int32 tensor [Op:AddV2] name:
# (b + a) + c returns an i32 result.
tf.add(tf.add(b, a), c) # <tf.Tensor: shape=(), dtype=int32, numpy=3>
<tf.Tensor: shape=(), dtype=int32, numpy=3>
新しい型昇格
新しい型昇格では、順序に関係なく一貫した結果を得られます。
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = np.array(1, dtype=np.int8)
b = tf.constant(1)
c = np.array(1, dtype=np.float16)
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
# (a + b) + c returns a f16 result.
tf.add(tf.add(a, b), c) # <tf.Tensor: shape=(), dtype=float16, numpy=3.0>
<tf.Tensor: shape=(), dtype=float16, numpy=3.0>
# (b + a) + c also returns a f16 result.
tf.add(tf.add(b, a), c) # <tf.Tensor: shape=(), dtype=float16, numpy=3.0>
<tf.Tensor: shape=(), dtype=float16, numpy=3.0>
ビット拡張のリスクの軽減
以前の型昇格
以前の型昇格では、64 ビットの結果が生成されることがありました。
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="legacy")
np.array(3.2, np.float16) + tf.constant(1, tf.int8) + tf.constant(50) # <tf.Tensor: shape=(), dtype=float64, numpy=54.19921875>
<tf.Tensor: shape=(), dtype=float64, numpy=54.19921875>
新しい型昇格
新しい型昇格では、必要最小限のビット数で結果が返されます。
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.
np.array(3.2, np.float16) + tf.constant(1, tf.int8) + tf.constant(50) # <tf.Tensor: shape=(), dtype=float16, numpy=54.2>
<tf.Tensor: shape=(), dtype=float16, numpy=54.2>
tf.Tensor の数学的ダンダーメソッド
すべての tf.Tensor
数学的ダンダーメソッドは、新しい型昇格に従います。
-tf.constant(5) # <tf.Tensor: shape=(), dtype=int32, numpy=-5, weak=True>
<tf.Tensor: shape=(), dtype=int32, numpy=-5, weak=True>
tf.constant(5, tf.int16) - tf.constant(1, tf.float32) # <tf.Tensor: shape=(), dtype=float32, numpy=4.0>
<tf.Tensor: shape=(), dtype=float32, numpy=4.0>
tf.Variable インプレース演算
tf.Variable
インプレース演算では、暗黙的変換が可能です。
注意: 変数の元の dtype とは異なる dtype を生成する昇格は許可されません。これは、tf.Variable
がその dtype を変更できないためです。
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.Variable(10, tf.int32)
a.assign_add(tf.constant(5, tf.int16)) # <tf.Variable shape=() dtype=int32, numpy=15>
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet. <tf.Variable 'UnreadVariable' shape=() dtype=int32, numpy=15>
tf.constant の暗黙的変換
以前の型昇格の場合、tf.constant
では、入力テンソルに dtype 引数と同じ dtype が使用されている必要がありましたが、新しい型昇格では、テンソルが指定された dtype に暗黙的に変換されます。
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.constant(10, tf.int16)
tf.constant(a, tf.float32) # <tf.Tensor: shape=(), dtype=float32, numpy=10.0>
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet. <tf.Tensor: shape=(), dtype=float32, numpy=10.0>
TF-NumPy の配列
新しい型昇格では、Python の tnp.array
はデフォルトで i32*
と f32*
になります。
tnp.array(1) # <tf.Tensor: shape=(), dtype=int32, numpy=1, weak=True>
<tf.Tensor: shape=(), dtype=int32, numpy=1, weak=True>
tnp.array(1.0) # <tf.Tensor: shape=(), dtype=int32, numpy=1, weak=True>
<tf.Tensor: shape=(), dtype=float32, numpy=1.0, weak=True>
入力の型推論
新しい型昇格では、異なる入力の型は以下のようにして推論されます。
tf.Tensor
:tf.Tensor
には dtype プロパティがあるため、それ以上の推論は行われません。- NumPy 型: これには
np.array(1)
、np.int16(1)
、np.float
などの型が含まれます。NumPy 型入力にも dtype プロパティが含まれているため、その dtype プロパティが結果の推論型として使用されます。NumPy はデフォルトでi64
とf64
になることに注意してください。 - Python スカラー/ネスト型: これには
1
、[1, 2, 3]
、(1.0, 2.0)
などの型が含まれます。- Python
int
はi32*
として推論されます。 - Python
float
はf32*
として推論されます。 - Python
complex
はc128*
として推論されます。
- Python
- 入力が上記のいずれのカテゴリにも当てはまらない場合でも dtype プロパティがある場合には、その dtype プロパティが結果の推論型として使用されます。
その他の資料
新しい型昇格は JAC-NumPy の型昇格に非常によく似ています。新しい型昇格とその設計上の選択についての詳細は、以下のリソースをご覧ください。
参考資料
WeakTensor をサポートしている API
以下は、WeakTensor
をサポートしている API のリストです。
単項演算については、ユーザー指定の型がない入力が渡されると、WeakTensor
を返します。
バイナリ演算については、こちらの昇格テーブルに従います。2 つの入力の昇格結果に応じて、WeakTensor
が返される場合とそうでない場合があります。
注意: すべての数学的演算(+
、-
、*
など)がサポートされています。
tf.bitwise.invert
tf.clip_by_value
tf.debugging.check_numerics
tf.expand_dims
tf.identity
tf.image.adjust_brightness
tf.image.adjust_gamma
tf.image.extract_patches
tf.image.random_brightness
tf.image.stateless_random_brightness
tf.linalg.diag
tf.linalg.diag_part
tf.linalg.matmul
tf.linalg.matrix_transpose
tf.linalg.tensor_diag_part
tf.linalg.trace
tf.math.abs
tf.math.acos
tf.math.acosh
tf.math.add
tf.math.angle
tf.math.asin
tf.math.asinh
tf.math.atan
tf.math.atanh
tf.math.ceil
tf.math.conj
tf.math.cos
tf.math.cosh
tf.math.digamma
tf.math.divide_no_nan
tf.math.divide
tf.math.erf
tf.math.erfc
tf.math.erfcinv
tf.math.erfinv
tf.math.exp
tf.math.expm1
tf.math.floor
tf.math.floordiv
tf.math.floormod
tf.math.imag
tf.math.lgamma
tf.math.log1p
tf.math.log_sigmoid
tf.math.log
tf.math.multiply_no_nan
tf.math.multiply
tf.math.ndtri
tf.math.negative
tf.math.pow
tf.math.real
tf.math.real
tf.math.reciprocal_no_nan
tf.math.reciprocal
tf.math.reduce_euclidean_norm
tf.math.reduce_logsumexp
tf.math.reduce_max
tf.math.reduce_mean
tf.math.reduce_min
tf.math.reduce_prod
tf.math.reduce_std
tf.math.reduce_sum
tf.math.reduce_variance
tf.math.rint
tf.math.round
tf.math.rsqrt
tf.math.scalar_mul
tf.math.sigmoid
tf.math.sign
tf.math.sin
tf.math.sinh
tf.math.softplus
tf.math.special.bessel_i0
tf.math.special.bessel_i0e
tf.math.special.bessel_i1
tf.math.special.bessel_i1e
tf.math.special.bessel_j0
tf.math.special.bessel_j1
tf.math.special.bessel_k0
tf.math.special.bessel_k0e
tf.math.special.bessel_k1
tf.math.special.bessel_k1e
tf.math.special.bessel_y0
tf.math.special.bessel_y1
tf.math.special.dawsn
tf.math.special.expint
tf.math.special.fresnel_cos
tf.math.special.fresnel_sin
tf.math.special.spence
tf.math.sqrt
tf.math.square
tf.math.subtract
tf.math.tan
tf.math.tanh
tf.nn.depth_to_space
tf.nn.elu
tf.nn.gelu
tf.nn.leaky_relu
tf.nn.log_softmax
tf.nn.relu6
tf.nn.relu
tf.nn.selu
tf.nn.softsign
tf.nn.space_to_depth
tf.nn.swish
tf.ones_like
tf.realdiv
tf.reshape
tf.squeeze
tf.stop_gradient
tf.transpose
tf.truncatediv
tf.truncatemod
tf.zeros_like
tf.experimental.numpy.abs
tf.experimental.numpy.absolute
tf.experimental.numpy.amax
tf.experimental.numpy.amin
tf.experimental.numpy.angle
tf.experimental.numpy.arange
tf.experimental.numpy.arccos
tf.experimental.numpy.arccosh
tf.experimental.numpy.arcsin
tf.experimental.numpy.arcsinh
tf.experimental.numpy.arctan
tf.experimental.numpy.arctanh
tf.experimental.numpy.around
tf.experimental.numpy.array
tf.experimental.numpy.asanyarray
tf.experimental.numpy.asarray
tf.experimental.numpy.ascontiguousarray
tf.experimental.numpy.average
tf.experimental.numpy.bitwise_not
tf.experimental.numpy.cbrt
tf.experimental.numpy.ceil
tf.experimental.numpy.conj
tf.experimental.numpy.conjugate
tf.experimental.numpy.copy
tf.experimental.numpy.cos
tf.experimental.numpy.cosh
tf.experimental.numpy.cumprod
tf.experimental.numpy.cumsum
tf.experimental.numpy.deg2rad
tf.experimental.numpy.diag
tf.experimental.numpy.diagflat
tf.experimental.numpy.diagonal
tf.experimental.numpy.diff
tf.experimental.numpy.empty_like
tf.experimental.numpy.exp2
tf.experimental.numpy.exp
tf.experimental.numpy.expand_dims
tf.experimental.numpy.expm1
tf.experimental.numpy.fabs
tf.experimental.numpy.fix
tf.experimental.numpy.flatten
tf.experimental.numpy.flip
tf.experimental.numpy.fliplr
tf.experimental.numpy.flipud
tf.experimental.numpy.floor
tf.experimental.numpy.full_like
tf.experimental.numpy.imag
tf.experimental.numpy.log10
tf.experimental.numpy.log1p
tf.experimental.numpy.log2
tf.experimental.numpy.log
tf.experimental.numpy.max
tf.experimental.numpy.mean
tf.experimental.numpy.min
tf.experimental.numpy.moveaxis
tf.experimental.numpy.nanmean
tf.experimental.numpy.negative
tf.experimental.numpy.ones_like
tf.experimental.numpy.positive
tf.experimental.numpy.prod
tf.experimental.numpy.rad2deg
tf.experimental.numpy.ravel
tf.experimental.numpy.real
tf.experimental.numpy.reciprocal
tf.experimental.numpy.repeat
tf.experimental.numpy.reshape
tf.experimental.numpy.rot90
tf.experimental.numpy.round
tf.experimental.numpy.signbit
tf.experimental.numpy.sin
tf.experimental.numpy.sinc
tf.experimental.numpy.sinh
tf.experimental.numpy.sort
tf.experimental.numpy.sqrt
tf.experimental.numpy.square
tf.experimental.numpy.squeeze
tf.experimental.numpy.std
tf.experimental.numpy.sum
tf.experimental.numpy.swapaxes
tf.experimental.numpy.tan
tf.experimental.numpy.tanh
tf.experimental.numpy.trace
tf.experimental.numpy.transpose
tf.experimental.numpy.triu
tf.experimental.numpy.vander
tf.experimental.numpy.var
tf.experimental.numpy.zeros_like