Constructs a function that broadcasts inputs over named axes.
tfp.experimental.distribute.make_pbroadcast_function(
fn, in_axes, out_axes, out_dtype
)
Given a function fn, make_pbroadcast_function returns a new function that applies pbroadcast to input terms according to the axis names provided in in_axes and out_axes. For each output axis in each term of the output of fn, inputs that do not have that output axis are pbroadcast over it before that term is computed.
Args:
  fn: A callable to be transformed to have pbroadcasts at its inputs.
  in_axes: A structure of axis names that should match the structure of the input to fn. If the set of input axes for an input value does not match the output axes of a particular output value, the gradient of that output value with respect to the input value will be psum-ed over the axes present in the output but not the input.
  out_axes: A structure of axis names that should match the structure of the output of fn. The inputs to fn will be pbroadcast before computing output terms according to their output axes.
  out_dtype: A structure of dtypes that matches the output of fn.
Returns:
  A new function that applies pbroadcasts to the inputs of the original function.
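The gradient behavior described under in_axes can be illustrated with a small simulation. This is a hypothetical sketch, not TFP code: Python lists stand in for the shards of one named axis, psum is a simulated all-reduce, and the toy model is a replicated scalar weight w multiplied by a sharded input x, with the total output being the sum over shards of w * x_k. The only property of pbroadcast assumed here is the one stated above: identity in the forward pass, with the gradient psum-ed over the missing axis.

```python
from typing import List


# Simulated psum (hypothetical, NOT a TFP API): every shard gets the total.
def psum(values: List[float]) -> List[float]:
    total = sum(values)
    return [total] * len(values)


def grad_wrt_replicated_weight(
    x_shards: List[float], use_pbroadcast: bool
) -> List[float]:
    """Per-shard gradient of sum_k (w * x_k) with respect to a replicated w.

    On shard k the local derivative of w * x_k is x_k. With pbroadcast, that
    cotangent is psum-ed over the axis, so each shard sees the full gradient.
    """
    local_grads = list(x_shards)
    if use_pbroadcast:
        return psum(local_grads)
    return local_grads


x = [1.0, 2.0, 3.0]
print(grad_wrt_replicated_weight(x, use_pbroadcast=True))   # → [6.0, 6.0, 6.0]
print(grad_wrt_replicated_weight(x, use_pbroadcast=False))  # → [1.0, 2.0, 3.0]
```

Without the psum, each shard would hold only its own partial derivative, and the replicated copies of w would receive different, incomplete updates.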