Sets the global dtype policy.
tf.keras.mixed_precision.set_global_policy(
    policy
)
The global policy is the default tf.keras.mixed_precision.Policy used for layers, if no policy is passed to the layer constructor.
>>> tf.keras.mixed_precision.set_global_policy('mixed_float16')
>>> tf.keras.mixed_precision.global_policy()
<Policy "mixed_float16">
>>> tf.keras.layers.Dense(10).dtype_policy
<Policy "mixed_float16">
>>> # Global policy is not used if a policy
>>> # is directly passed to constructor
>>> tf.keras.layers.Dense(10, dtype='float64').dtype_policy
<Policy "float64">
>>> tf.keras.mixed_precision.set_global_policy('float32')
If no global policy is set, layers will instead default to a Policy constructed from tf.keras.backend.floatx().
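The fallback to floatx() can be checked directly. The following is a minimal sketch, assuming a fresh Python process in which set_global_policy has not yet been called; the variable names are illustrative only.

```python
import tensorflow as tf

# Default float type configured for Keras (usually 'float32').
default_float = tf.keras.backend.floatx()

# With no global policy set explicitly, the reported global policy
# should be the one derived from floatx().
policy = tf.keras.mixed_precision.global_policy()

print('floatx:', default_float)
print('global policy:', policy.name)
```

If the two printed names differ, a global policy has probably been set earlier in the same process.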
To use mixed precision, the global policy should be set to 'mixed_float16' or 'mixed_bfloat16', so that every layer uses a 16-bit compute dtype and float32 variable dtype by default.
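To see what "16-bit compute dtype and float32 variable dtype" means in practice, a Policy can be inspected without changing any global state. This is an illustrative sketch; the loop and variable names are not part of the API.

```python
import tensorflow as tf

# Inspect both mixed policies without touching the global policy.
for name in ('mixed_float16', 'mixed_bfloat16'):
    policy = tf.keras.mixed_precision.Policy(name)
    # compute_dtype is the dtype used for layer computations;
    # variable_dtype is the dtype in which layer weights are stored.
    print(name, policy.compute_dtype, policy.variable_dtype)

mixed = tf.keras.mixed_precision.Policy('mixed_float16')
```

Here mixed.compute_dtype is 'float16' while mixed.variable_dtype is 'float32', which is the split that set_global_policy('mixed_float16') applies to every layer by default.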
Only floating point policies can be set as the global policy, such as 'float32' and 'mixed_float16'. Non-floating point policies such as 'int32' and 'complex64' cannot be set as the global policy because most layers do not support such policies.
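The restriction can be probed with a small helper. This is only a sketch: try_set_global_policy is a hypothetical function written for this example, and it assumes that a rejected policy is reported as a ValueError, which may not hold in every Keras version.

```python
import tensorflow as tf

def try_set_global_policy(name):
    """Hypothetical helper: True if `name` is accepted as the global policy."""
    try:
        tf.keras.mixed_precision.set_global_policy(name)
        return True
    except ValueError:
        # Assumption: rejected policies are reported as ValueError.
        return False
    finally:
        # Always restore the default so later code is unaffected.
        tf.keras.mixed_precision.set_global_policy('float32')

float_ok = try_set_global_policy('float32')
int_ok = try_set_global_policy('int32')
print('float32 accepted:', float_ok)
print('int32 accepted:', int_ok)
```

On releases that behave as described above, 'int32' is expected to be rejected. The helper resets the policy to 'float32' in either case.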
See tf.keras.mixed_precision.Policy for more information.
Args | |
---|---|
policy | A Policy, or a string that will be converted to a Policy. Can also be None, in which case the global policy will be constructed from tf.keras.backend.floatx(). |
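As the table notes, the argument may be a Policy object rather than a string. A minimal sketch of that form follows; the alias mp and the variable names are illustrative, and the final call restores 'float32' so the change does not leak into other code.

```python
import tensorflow as tf

mp = tf.keras.mixed_precision  # local alias, for brevity only

# Pass a Policy instance instead of a string.
bf16_policy = mp.Policy('mixed_bfloat16')
mp.set_global_policy(bf16_policy)
name_after_set = mp.global_policy().name

# Layers created without an explicit dtype pick up the global policy.
layer_policy_name = tf.keras.layers.Dense(4).dtype_policy.name

# Restore the default policy.
mp.set_global_policy('float32')
name_after_reset = mp.global_policy().name

print(name_after_set, layer_policy_name, name_after_reset)
```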