An arbitrary already-batched computation, a 'primitive operation'.
tfp.experimental.auto_batching.instructions.PrimOp(
    vars_in, vars_out, function, skip_push_mask
)
These are the items of work on which auto-batching is applied. The
`function` must accept and produce Tensors with a batch dimension,
and is free to stage any (batched) computation it wants.
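For concreteness, here is a minimal sketch of wrapping a batched
computation as a `PrimOp`. The variable names `'x'`, `'y'`, `'z'` and the
empty `frozenset` for `skip_push_mask` are illustrative assumptions, not
part of the documented contract:

```python
import tensorflow_probability as tfp

instructions = tfp.experimental.auto_batching.instructions

# A batched primitive: inputs and outputs all carry a leading batch
# dimension, and the body is ordinary (batched) TensorFlow.
def batched_add(x, y):
  return x + y  # elementwise add preserves the leading batch dimension

# Hypothetical wiring: 'x', 'y', 'z' are assumed variable names, and an
# empty frozenset is assumed to mean "push all outputs".
op = instructions.PrimOp(
    vars_in=['x', 'y'],
    vars_out=['z'],
    function=batched_add,
    skip_push_mask=frozenset())
```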
Restriction: the `function` must use the same computation substrate
as the VM backend. That is, if the VM is staging to XLA, the
`function` will see XLA Tensor handles; if the VM is staging to
graph-mode TensorFlow, the `function` will see TensorFlow Tensors;
etc.
The current values of the `vars_out` are saved on their respective
stacks, and the results written to the new top.
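Conceptually, the update behaves like the following pseudocode
(hypothetical `env` API, shown only to illustrate the push semantics;
not the actual VM implementation):

```python
# Pseudocode sketch of PrimOp execution, assuming one result per
# output variable.
def run_prim_op(op, env):
  args = [env.read(name) for name in op.vars_in]  # fetch current tops
  results = op.function(*args)
  for name, value in zip(op.vars_out, results):
    env.push(name, value)  # previous value stays on the stack, below the new top
```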
The exact contract for `function` is as follows (a sketch illustrating
the contract appears after this list):

- It will be invoked with a list of positional (only) arguments,
  parallel to `vars_in`.
- Each argument will be a pattern of Tensors (meaning, either one
  Tensor or a (potentially nested) list or tuple of Tensors),
  corresponding to the `Type` of that variable.
- Each Tensor in the argument will have the `dtype` and `shape`
  given in the corresponding `TensorType`, and an additional leading
  batch dimension.
- Some indices in the batch dimension may contain junk data, if the
  corresponding threads are not executing this instruction [this is
  subject to change based on the batch execution strategy].
- The `function` must return a pattern of Tensors, or objects
  convertible to Tensors.
- The returned pattern must be compatible with the `Type`s of
  `vars_out`.
- The Tensors in the returned pattern must have `dtype` and `shape`
  compatible with the corresponding `TensorType`s of `vars_out`.
- The returned Tensors will be broadcast into their respective
  positions if necessary. The broadcasting includes the batch
  dimension: thus, a returned Tensor of insufficient rank (e.g., a
  constant) will be broadcast across batch members. In particular, a
  Tensor that carries the intended batch size but whose sub-batch
  shape is too low rank will broadcast incorrectly and result in an
  error.
- If the `function` raises an exception, it will propagate and abort
  the entire computation.
- Even in the TensorFlow backend, the `function` will be staged
  several times: at least twice during type inference (to ascertain
  the shapes of the Tensors it returns, as a function of the shapes
  of the Tensors it is given), and exactly once during executable
  graph construction.
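As an illustration of these rules, the following sketch returns a
pattern of two Tensors, the second of which relies on broadcasting
across the batch dimension (the variable `Type`s are assumptions):

```python
import tensorflow as tf

def halve_and_flag(x):
  # x: Tensor of shape [batch_size] + event_shape (leading batch dim).
  halved = x / 2.           # keeps the leading batch dimension
  flag = tf.constant(True)  # rank 0: broadcast across batch members
  return (halved, flag)     # a pattern parallel to two vars_out
```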
Methods
replace
replace(
    vars_out=None
)
Return a copy of `self` with `vars_out` replaced.
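For example (assuming `op` is a `PrimOp` as constructed above, and
`'w'` is a hypothetical variable name):

```python
# Rebind the op's outputs; all other fields are preserved.
new_op = op.replace(vars_out=['w'])
```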