Compiles a function into a callable TensorFlow graph. (deprecated arguments)
tf.function(
    func=None, input_signature=None, autograph=True, jit_compile=None,
    experimental_implements=None, experimental_autograph_options=None,
    experimental_relax_shapes=False, experimental_compile=None,
    experimental_follow_type_hints=None
)
tf.function constructs a callable that executes a TensorFlow graph
(tf.Graph) created by trace-compiling the TensorFlow operations in func,
effectively executing func as a TensorFlow graph.
Example usage:
@tf.function
def f(x, y):
  return x ** 2 + y
x = tf.constant([2, 3])
y = tf.constant([3, -2])
f(x, y)
<tf.Tensor: ... numpy=array([7, 7], ...)>
Features
func may use data-dependent control flow, including if, for, while, break,
continue and return statements:
@tf.function
def f(x):
  if tf.reduce_sum(x) > 0:
    return x * x
  else:
    return -x // 2
f(tf.constant(-2))
<tf.Tensor: ... numpy=1>
func's closure may include tf.Tensor and tf.Variable objects:
@tf.function
def f():
  return x ** 2 + y
x = tf.constant([-2, -3])
y = tf.Variable([3, -2])
f()
<tf.Tensor: ... numpy=array([7, 7], ...)>
func may also use ops with side effects, such as tf.print, tf.Variable
and others:
v = tf.Variable(1)
@tf.function
def f(x):
  for i in tf.range(x):
    v.assign_add(i)
f(3)
v
<tf.Variable ... numpy=4>
Python side effects, such as appending to a list, happen only once, when
func is traced:
l = []
@tf.function
def f(x):
  for i in x:
    l.append(i + 1)  # Caution! Will only happen once when tracing
f(tf.constant([1, 2, 3]))
l
[<tf.Tensor ...>]
Instead, use TensorFlow collections like tf.TensorArray:
@tf.function
def f(x):
  ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
  for i in range(len(x)):
    ta = ta.write(i, x[i] + 1)
  return ta.stack()
f(tf.constant([1, 2, 3]))
<tf.Tensor: ..., numpy=array([2, 3, 4], ...)>
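The difference between Python side effects and TensorFlow ops can be checked directly. The following is a minimal sketch, not part of the original reference: the names python_calls, graph_calls and step are made up for illustration. It records one side effect in a Python list and one in a tf.Variable, then calls the function three times with Tensors of the same dtype and shape.

```python
import tensorflow as tf

python_calls = []             # plain Python state, touched only while tracing
graph_calls = tf.Variable(0)  # TensorFlow state, created outside the function

@tf.function
def step(x):
  python_calls.append(1)      # Python side effect: runs once per trace
  graph_calls.assign_add(1)   # TensorFlow op: runs on every execution
  return x + 1

results = [int(step(tf.constant(value)).numpy()) for value in (1, 2, 3)]

print(results)                   # values returned by the three calls
print(len(python_calls))         # number of traces
print(int(graph_calls.numpy()))  # number of executions
```

Under these assumptions the list should grow once while the variable is incremented on each call, which is why state that must change on every call belongs in TensorFlow ops.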
tf.function is polymorphic
Internally, tf.function can build more than one graph, to support arguments
with different data types or shapes, since TensorFlow can build more
efficient graphs that are specialized on shapes and dtypes. tf.function
also treats any pure Python value as an opaque object, and builds a separate
graph for each set of Python arguments that it encounters.
To obtain an individual graph, use the get_concrete_function method of
the callable created by tf.function. It can be called with the same
arguments as func and returns a concrete function whose graph attribute is
a tf.Graph object:
@tf.function
def f(x):
  return x + 1
isinstance(f.get_concrete_function(1).graph, tf.Graph)
True
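Passing Python scalars as arguments builds a new graph for each distinct
value. To avoid this, pass numerical arguments as Tensors whenever possible:
@tf.function
def f(x):
  return tf.abs(x)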
f1 = f.get_concrete_function(1)
f2 = f.get_concrete_function(2) # Slow - builds new graph
f1 is f2
False
f1 = f.get_concrete_function(tf.constant(1))
f2 = f.get_concrete_function(tf.constant(2)) # Fast - reuses f1
f1 is f2
True
Python numerical arguments should only be used when they take few distinct values, such as hyperparameters like the number of layers in a neural network.
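The cost of Python numerical arguments can be made visible by counting traces. This sketch is illustrative only: traced_with and double are hypothetical names, and the Python list is used purely as a tracing probe, relying on the fact that Python side effects run only while tracing.

```python
import tensorflow as tf

traced_with = []  # filled only while tracing

@tf.function
def double(x):
  traced_with.append(x)  # Python side effect, used here to observe traces
  return x * 2

# Three distinct Python ints: each value is baked into its own graph.
for n in (1, 2, 3):
  double(n)
traces_after_ints = len(traced_with)

# Three int32 scalar Tensors: one graph is traced and then reused.
for n in (1, 2, 3):
  double(tf.constant(n))
traces_after_tensors = len(traced_with)

print(traces_after_ints, traces_after_tensors)
```

Three Python ints should produce three traces, while the three Tensor calls should add only one more.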
Input signatures
For Tensor arguments, tf.function instantiates a separate graph for every
unique set of input shapes and datatypes. The example below creates two
separate graphs, each specialized to a different shape:
@tf.function
def f(x):
  return x + 1
vector = tf.constant([1.0, 1.0])
matrix = tf.constant([[3.0]])
f.get_concrete_function(vector) is f.get_concrete_function(matrix)
False
An "input signature" can be optionally provided to tf.function to control
the graphs traced. The input signature specifies the shape and type of each
Tensor argument to the function using a tf.TensorSpec object. More general
shapes can be used. This is useful to avoid creating multiple graphs when
Tensors have dynamic shapes. It also restricts the shape and datatype of
Tensors that can be used:
@tf.function(
    input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
def f(x):
  return x + 1
vector = tf.constant([1.0, 1.0])
matrix = tf.constant([[3.0]])
f.get_concrete_function(vector) is f.get_concrete_function(matrix)
True
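Both effects of an input signature, fewer traces and stricter inputs, can be sketched as follows. The names traces and add_one are hypothetical, and shape=[None] (a vector of any length) is chosen for illustration. The exception type raised for an incompatible input may differ between TensorFlow releases, so the sketch catches both TypeError and ValueError.

```python
import tensorflow as tf

traces = []  # grows only while tracing

@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
def add_one(x):
  traces.append(1)
  return x + 1

# Vectors of different lengths share the single graph traced for shape [None].
short = add_one(tf.constant([1.0, 2.0]))
long = add_one(tf.constant([1.0, 2.0, 3.0, 4.0]))

# A Tensor with the wrong dtype is rejected instead of triggering a new trace.
try:
  add_one(tf.constant([1, 2]))
  rejected = False
except (TypeError, ValueError):
  rejected = True

print(len(traces), rejected)
```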
Variables may only be created once
tf.function only allows creating new tf.Variable objects when it is called
for the first time:
class MyModule(tf.Module):
  def __init__(self):
    self.v = None

  @tf.function
  def __call__(self, x):
    if self.v is None:
      self.v = tf.Variable(tf.ones_like(x))
    return self.v * x
In general, it is recommended to create stateful objects like tf.Variable
outside of tf.function and pass them as arguments.
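A minimal sketch of that recommendation, with hypothetical names counter and increment: the variable is created once in eager code and handed to the function as an ordinary argument, so no variable is ever created during tracing.

```python
import tensorflow as tf

counter = tf.Variable(0)  # created once, outside any tf.function

@tf.function
def increment(var, amount):
  var.assign_add(amount)   # stateful op on the variable that was passed in
  return var.read_value()

first = increment(counter, tf.constant(2))
second = increment(counter, tf.constant(3))

print(first.numpy(), second.numpy(), counter.numpy())
```

Because the function only receives the variable, the same traced graph can be reused with other variables of the same dtype and shape.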
Using type annotations to improve performance
experimental_follow_type_hints can be used along with type annotations to
improve performance by reducing the number of expensive graph retracings.
For example, an argument annotated with tf.Tensor is converted to a Tensor
even when the input is a non-Tensor value.
@tf.function(experimental_follow_type_hints=True)
def f_with_hints(x: tf.Tensor):
  print('Tracing')
  return x
@tf.function(experimental_follow_type_hints=False)
def f_no_hints(x: tf.Tensor):
  print('Tracing')
  return x
f_no_hints(1)
Tracing
<tf.Tensor: shape=(), dtype=int32, numpy=1>
f_no_hints(2)
Tracing
<tf.Tensor: shape=(), dtype=int32, numpy=2>
f_with_hints(1)
Tracing
<tf.Tensor: shape=(), dtype=int32, numpy=1>
f_with_hints(2)
<tf.Tensor: shape=(), dtype=int32, numpy=2>
Args | |
---|---|
func | The function to be compiled. If func is None, tf.function returns a decorator that can be invoked with a single argument, func. In other words, tf.function(input_signature=...)(func) is equivalent to tf.function(func, input_signature=...). The former can be used as a decorator. |
input_signature | A possibly nested sequence of tf.TensorSpec objects specifying the shapes and dtypes of the Tensors that will be supplied to this function. If None, a separate function is instantiated for each inferred input signature. If input_signature is specified, every input to func must be a Tensor, and func cannot accept **kwargs. |
autograph | Whether autograph should be applied on func before tracing a graph. Data-dependent control flow requires autograph=True. For more information, see the tf.function and AutoGraph guide. |
jit_compile | If True, compiles the function using XLA. XLA performs compiler optimizations, such as fusion, and attempts to emit more efficient code. This may drastically improve performance. If set to True, the whole function needs to be compilable by XLA, or an errors.InvalidArgumentError is raised. If None (default), compiles the function with XLA when running on TPU and goes through the regular function execution path when running on other devices. If False, executes the function without XLA compilation. Set this value to False when directly running a multi-device function on TPUs (e.g. two TPU cores, one TPU core and its host CPU). Not all functions are compilable; see the list of sharp corners. |
experimental_implements | If provided, contains the name of a "known" function that this function implements, for example "mycompany.my_recurrent_cell". This is stored as an attribute in the inference function, which can then be detected when processing the serialized function. See standardizing composite ops for details. For an example of utilizing this attribute, see this example. The code in that example automatically detects and substitutes a function that implements "embedded_matmul" and allows TFLite to substitute its own implementation. For instance, a TensorFlow user can use this attribute to mark that their function also implements embedded_matmul (perhaps more efficiently!) by specifying it using this parameter: @tf.function(experimental_implements="embedded_matmul"). This can either be specified as just the string name of the function or as a NameAttrList corresponding to a list of key-value attributes associated with the function name. The name of the function will be in the 'name' field of the NameAttrList. To define a formal TF op for the function this implements, try the experimental composite TF project. |
experimental_autograph_options | Optional tuple of tf.autograph.experimental.Feature values. |
experimental_relax_shapes | When True, tf.function may generate fewer graphs that are less specialized on input shapes. |
experimental_compile | Deprecated alias for jit_compile. |
experimental_follow_type_hints | When True, the function may use type annotations from func to optimize the tracing performance. For example, arguments annotated with tf.Tensor will automatically be converted to a Tensor. |
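The autograph requirement for data-dependent control flow can be illustrated with a small sketch (python_if and explicit_cond are hypothetical names). With autograph=False, a Python if on a symbolic Tensor is expected to fail during tracing with a TypeError subclass, whereas the same branch written explicitly with tf.cond still works.

```python
import tensorflow as tf

@tf.function(autograph=False)
def python_if(x):
  if tf.reduce_sum(x) > 0:  # a symbolic Tensor cannot be used as a Python bool
    return x * x
  return -x // 2

@tf.function(autograph=False)
def explicit_cond(x):
  # tf.cond builds both branches into the graph and selects one at run time.
  return tf.cond(tf.reduce_sum(x) > 0, lambda: x * x, lambda: -x // 2)

try:
  python_if(tf.constant(-2))
  python_if_failed = False
except TypeError:
  python_if_failed = True

negative_case = int(explicit_cond(tf.constant(-2)).numpy())
positive_case = int(explicit_cond(tf.constant(3)).numpy())

print(python_if_failed, negative_case, positive_case)
```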
Returns | |
---|---|
If func is not None, returns a callable that will execute the compiled function (and return zero or more tf.Tensor objects). If func is None, returns a decorator that, when invoked with a single func argument, returns a callable equivalent to the case above. |
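The two return modes can be compared side by side. In this sketch the names add_one, signature, direct and decorated are chosen for illustration: the first form passes func directly, and the second omits func to obtain a decorator and then applies it.

```python
import tensorflow as tf

def add_one(x):
  return x + 1

signature = [tf.TensorSpec(shape=[], dtype=tf.int32)]

# Form 1: pass func directly and get the compiled callable back.
direct = tf.function(add_one, input_signature=signature)

# Form 2: omit func to get a decorator, then apply it to func.
decorator = tf.function(input_signature=signature)
decorated = decorator(add_one)

direct_result = int(direct(tf.constant(1)).numpy())
decorated_result = int(decorated(tf.constant(1)).numpy())

print(direct_result, decorated_result)
```

The second form is what the @tf.function(...) decorator syntax expands to.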
Raises | |
---|---|
ValueError when attempting to use jit_compile=True, but XLA support is not linked. |