Decorator that allows a NumPy function to be used in both Eager and Graph modes.
```python
tf_agents.utils.eager_utils.np_function(
    func=None, output_dtypes=None
)
```
Similar to `tf.py_func` and `tf.py_function`, but it does not require the inputs or the output dtypes to be defined a priori.
In Eager mode, it converts the `tf.Tensor` inputs to `np.array`s before passing them to `func`, and then converts the outputs from `np.array`s back to `tf.Tensor`s.
In Graph mode, it creates a separate `tf.py_function` for each combination of input dtypes and caches them for reuse.
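The per-dtype caching can be illustrated without TensorFlow. The class below is a hypothetical sketch, not the tf_agents implementation: it keys a cache on the tuple of input dtypes and builds one wrapper per new combination. The wrapper stands in for the `tf.py_function` that would be created, and `DtypeCachedFunction`, `build_count` and `wrappers` are names invented for this sketch. To stay TensorFlow-free, it learns each output dtype from the first call with a given dtype combination; in real Graph mode the input values are not available when the op is built, so treat that part as a simplification.

```python
import numpy as np


class DtypeCachedFunction:
  """Hypothetical sketch: builds one wrapper per combination of input dtypes."""

  def __init__(self, func):
    self.func = func
    self.wrappers = {}  # tuple of input dtypes -> cached wrapper
    self.build_count = 0  # number of wrappers built so far

  def _build_wrapper(self, arrays):
    # Stand-in for creating a tf.py_function. For simplicity, the output
    # dtype ("Tout") is learned from this first call with these dtypes.
    self.build_count += 1
    tout = np.asarray(self.func(*arrays)).dtype

    def wrapper(*args):
      # Every call with the same input dtypes is cast to the same Tout.
      return np.asarray(self.func(*args)).astype(tout)

    return wrapper

  def __call__(self, *args):
    arrays = [np.asarray(a) for a in args]
    key = tuple(a.dtype for a in arrays)
    if key not in self.wrappers:
      self.wrappers[key] = self._build_wrapper(arrays)
    return self.wrappers[key](*arrays)


np_sum = DtypeCachedFunction(np.sum)
a = np_sum(np.array([3.0, 4.0], dtype=np.float32))  # builds the float32 wrapper
b = np_sum(np.array([1.0, 2.0], dtype=np.float32))  # reuses the cached wrapper
c = np_sum(np.array([3.0, 4.0], dtype=np.float64))  # builds a float64 wrapper
print(np_sum.build_count)  # 2
```

Keying on dtypes rather than values means the number of wrappers grows with the number of distinct dtype combinations seen, not with the number of calls.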
Instead of doing:
```python
import numpy as np
import tensorflow as tf

def sum(x):
  return np.sum(x)

inputs = tf.constant([3, 4])
outputs = tf.py_function(sum, [inputs], Tout=[tf.int64])
inputs = tf.constant([3., 4.])
outputs = tf.py_function(sum, [inputs], Tout=[tf.float32])
```
#### Do:
```python
import numpy as np
import tensorflow as tf
from tensorflow.python.eager import context
from tf_agents.utils import eager_utils

@eager_utils.np_function
def sum(x):
  return np.sum(x)

inputs = tf.constant([3, 4])
outputs = sum(inputs)  # Infers that Tout is tf.int64
inputs = tf.constant([3., 4.])
outputs = sum(inputs)  # Infers that Tout is tf.float32

# Output dtype is always float32 for valid input dtypes.
@eager_utils.np_function(output_dtypes=np.float32)
def mean(x):
  return np.mean(x)

# Output dtype depends on the input dtype.
@eager_utils.np_function(output_dtypes=lambda x: (x, x))
def repeat(x):
  return x, x

with context.graph_mode():
  outputs = sum(tf.constant([3, 4]))
  outputs2 = sum(tf.constant([3., 4.]))
  with tf.compat.v1.Session() as sess:
    sess.run(outputs)  # np.array(7)
    sess.run(outputs2)  # np.array(7.)

with context.eager_mode():
  outputs = sum(tf.constant([3, 4]))  # tf.Tensor(7)
  outputs = sum(tf.constant([3., 4.]))  # tf.Tensor(7.)
```
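The signature `np_function(func=None, output_dtypes=None)` is what allows the decorator to be written both bare (`@eager_utils.np_function`) and with arguments (`@eager_utils.np_function(output_dtypes=...)`), as in the examples above. A common Python idiom for this is sketched below with `functools.partial`. The name `optional_args_decorator` is hypothetical, its wrapper performs no tensor conversion, and this is not claimed to be how tf_agents implements `np_function`.

```python
import functools


def optional_args_decorator(func=None, output_dtypes=None):
  """Hypothetical decorator usable as @deco and as @deco(output_dtypes=...)."""
  if func is None:
    # Called with keyword arguments only: return a decorator that still
    # needs func, with output_dtypes already bound.
    return functools.partial(optional_args_decorator, output_dtypes=output_dtypes)

  @functools.wraps(func)
  def wrapper(*args, **kwargs):
    # np_function would presumably convert tensors here; this sketch just calls func.
    return func(*args, **kwargs)

  wrapper.output_dtypes = output_dtypes  # recorded only to show what was passed
  return wrapper


@optional_args_decorator
def double(x):
  return 2 * x


@optional_args_decorator(output_dtypes="float32")
def triple(x):
  return 3 * x


print(double(2), triple(2))  # 4 6
```

In the bare form Python passes the decorated function as `func`; in the keyword form `func` is `None`, so the call returns a partial that receives the function on the next call.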
Args:
func: A NumPy function that takes NumPy arrays as inputs and returns NumPy arrays as outputs.
output_dtypes: Optional list of dtypes, or a function that maps input dtypes to output dtypes. Examples: `output_dtypes=[tf.float32]`, `output_dtypes=lambda x: x` (outputs have the same dtype as inputs). If it is not provided in Graph mode, `func` is called to infer the output dtypes.
| Returns |
|---|
| A wrapped function that can be used with TF code. |
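The three documented forms of `output_dtypes` (a fixed dtype or list of dtypes, a function of the input dtypes, or nothing, in which case `func` is called to infer the dtypes) can be summarized in a small helper. `resolve_output_dtypes` is a hypothetical, NumPy-only sketch; in particular, calling `func` on zero-filled arrays of shape `(1,)` to infer the dtypes is an assumption, not documented tf_agents behavior.

```python
import numpy as np


def resolve_output_dtypes(func, input_dtypes, output_dtypes=None):
  """Hypothetical helper: maps the documented output_dtypes forms to dtypes."""
  if output_dtypes is None:
    # Not provided: call func on dummy inputs and read the result dtypes.
    # Zero-filled arrays of shape (1,) are an assumption of this sketch.
    dummies = [np.zeros((1,), dtype=d) for d in input_dtypes]
    outputs = func(*dummies)
    if isinstance(outputs, tuple):
      return [np.asarray(o).dtype for o in outputs]
    return np.asarray(outputs).dtype
  if callable(output_dtypes) and not isinstance(output_dtypes, type):
    # A function of the input dtypes, e.g. lambda x: (x, x).
    # Dtype classes such as np.float32 are callable too, so they are excluded.
    return output_dtypes(*input_dtypes)
  # A fixed dtype or list of dtypes, independent of the inputs.
  return output_dtypes


f32 = np.dtype(np.float32)
print(resolve_output_dtypes(np.sum, (f32,)))  # float32
print(resolve_output_dtypes(None, (f32,), lambda x: (x, x)))
print(resolve_output_dtypes(None, (f32,), [np.float64]))
```

The `isinstance(output_dtypes, type)` check matters because `output_dtypes=np.float32`, as in the `mean` example, is a class and therefore callable; without the check it would be mistaken for a dtype-mapping function.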