tf.wrap_function

View source on GitHub

Wraps the TF 1.x function fn into a graph function.

The python function fn will be called once with symbolic arguments specified in the signature, traced, and turned into a graph function. Any variables created by fn will be owned by the object returned by wrap_function. The resulting graph function can be called with tensors which match the signature.

def f(x, do_add):
  v = tf.Variable(5.0)
  if do_add:
    op = v.assign_add(x)
  else:
    op = v.assign_sub(x)
  with tf.control_dependencies([op]):
    return v.read_value()

f_add = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), True])

assert float(f_add(1.0)) == 6.0
assert float(f_add(1.0)) == 7.0

# Can call tf.compat.v1.wrap_function again to get a new trace, a new set
# of variables, and possibly different non-template arguments.
f_sub= tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), False])

assert float(f_sub(1.0)) == 4.0
assert float(f_sub(1.0)) == 3.0

Both tf.compat.v1.wrap_function and tf.function create a callable TensorFlow graph. But while tf.function runs all stateful operations (e.g. tf.print) and sequences operations to provide the same semantics as eager execution, wrap_function is closer to the behavior of session.run in TensorFlow 1.x. It will not run any operations unless they are required to compute the function's outputs, either through a data dependency or a control dependency. Nor will it sequence operations.

Unlike tf.function, wrap_function will only trace the Python function once. As with placeholders in TF 1.x, shapes and dtypes must be provided to wrap_function's signature argument.

Since it is only traced once, variables and state may be created inside the function and owned by the function wrapper object.

fn python function to be wrapped
signature the placeholder and python arguments to be passed to the wrapped function
name Optional. The name of the function.

the wrapped graph function.