Evaluates a lattice using hypercube interpolation.
tfl.lattice_lib.evaluate_with_hypercube_interpolation(
inputs, kernel, units, lattice_sizes, clip_inputs
)
The lattice function is multilinearly interpolated between the 2^d vertices of a
hypercube, where d = len(lattice_sizes) is the number of lattice inputs. This
interpolation method is typically slower than simplex interpolation, since each
value is interpolated from 2^d hypercube corners rather than d+1 simplex
corners. For details, see e.g. "Dissection of the hypercube into simplices",
D.G. Mead, Proceedings of the AMS, 76:2, Sep. 1979.
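To make the 2^d-corner computation concrete, here is a minimal pure-Python sketch of multilinear interpolation for a single output unit. It is not the TFL implementation: the helper names (`corner_weights`, `flat_index`, `interpolate`) are hypothetical, plain lists stand in for tensors, and the row-major vertex ordering of the flat parameter list is an assumption. It also assumes every lattice size is at least 2 and that the point already lies inside the lattice.

```python
import itertools
import math


def corner_weights(point, lattice_sizes):
    """Return (vertex, weight) pairs for the 2^d corners of the cell holding point.

    Assumes every lattice size is >= 2 and point[i] is in [0, lattice_sizes[i] - 1].
    """
    bases, fracs = [], []
    for x, size in zip(point, lattice_sizes):
        # Lower corner of the cell; a point on the top edge uses the last cell.
        base = min(math.floor(x), size - 2)
        bases.append(base)
        fracs.append(x - base)
    pairs = []
    for bits in itertools.product((0, 1), repeat=len(lattice_sizes)):
        # Weight is the product of per-dimension linear weights.
        weight = 1.0
        for bit, frac in zip(bits, fracs):
            weight *= frac if bit else 1.0 - frac
        vertex = tuple(b + bit for b, bit in zip(bases, bits))
        pairs.append((vertex, weight))
    return pairs


def flat_index(vertex, lattice_sizes):
    """Row-major (last dimension fastest) flattening; an assumed layout."""
    index = 0
    for v, size in zip(vertex, lattice_sizes):
        index = index * size + v
    return index


def interpolate(point, lattice_sizes, params):
    """Multilinear interpolation of a single-unit lattice with flat params."""
    return sum(
        weight * params[flat_index(vertex, lattice_sizes)]
        for vertex, weight in corner_weights(point, lattice_sizes)
    )


# A 2 x 2 lattice storing f(v0, v1) = v0 + 2 * v1 at its vertices, row-major:
# (0,0) -> 0, (0,1) -> 2, (1,0) -> 1, (1,1) -> 3.
params = [0.0, 2.0, 1.0, 3.0]
print(len(corner_weights((0.5, 0.25), [2, 2])))  # 4 corners = 2^2
print(interpolate((0.5, 0.25), [2, 2], params))  # 1.0
```

Because the corner weights sum to 1 and are products of per-dimension linear weights, any function that is linear in the inputs is reproduced exactly; a simplex scheme would instead combine only d+1 of these corners.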
Args:
  inputs: Tensor representing points to apply lattice interpolation to. If
    units = 1, the tensor should be of shape (batch_size, ...,
    len(lattice_sizes)) or a list of len(lattice_sizes) tensors of the same
    shape (batch_size, ..., 1). If units > 1, the tensor should be of shape
    (batch_size, ..., units, len(lattice_sizes)) or a list of
    len(lattice_sizes) tensors of the same shape (batch_size, ..., units, 1).
    A typical shape is (batch_size, len(lattice_sizes)).
  kernel: Lattice kernel of shape (num_params_per_lattice, units), where
    num_params_per_lattice is the product of lattice_sizes (one parameter per
    lattice vertex).
  units: Output dimension of the lattice.
  lattice_sizes: List or tuple of integers representing the lattice sizes of
    the layer for which interpolation is being computed.
  clip_inputs: Whether inputs should be clipped to the input range of the
    lattice, i.e. [0, lattice_sizes[i] - 1] along dimension i.
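The argument conventions can be sketched in plain Python. Nested lists stand in for tensors, and the helper names (`stack_inputs`, `clip_to_lattice`, `kernel_shape`) are hypothetical: the first mirrors concatenating the units = 1 list form of inputs (each item of shape (batch_size, 1), with no extra "..." dimensions) along its last axis, the second clips each coordinate into [0, lattice_sizes[i] - 1], and the third computes the kernel shape assuming num_params_per_lattice equals the number of lattice vertices, i.e. the product of lattice_sizes.

```python
import math


def stack_inputs(inputs_list):
    """Turn the list form (d items, each [batch][1]) into [batch][d].

    Mirrors concatenating the per-dimension tensors along the last axis.
    """
    batch_size = len(inputs_list[0])
    return [[column[b][0] for column in inputs_list] for b in range(batch_size)]


def clip_to_lattice(point, lattice_sizes):
    """Clip each coordinate into [0, lattice_sizes[i] - 1]."""
    return [min(max(x, 0.0), size - 1.0) for x, size in zip(point, lattice_sizes)]


def kernel_shape(lattice_sizes, units):
    """(num_params_per_lattice, units): one parameter per vertex per unit."""
    return (math.prod(lattice_sizes), units)


lattice_sizes = [3, 2, 4]
# Three per-dimension "tensors", each of shape (batch_size=2, 1).
inputs_list = [[[0.5], [2.7]], [[1.2], [-0.3]], [[3.5], [1.0]]]
stacked = stack_inputs(inputs_list)
print(stacked)  # [[0.5, 1.2, 3.5], [2.7, -0.3, 1.0]]
print([clip_to_lattice(p, lattice_sizes) for p in stacked])  # [[0.5, 1.0, 3.0], [2.0, 0.0, 1.0]]
print(kernel_shape(lattice_sizes, units=2))  # (24, 2)
```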
Returns:
  Tensor of shape (batch_size, ..., units).
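Putting the pieces together, the sketch below evaluates a batch of points for several output units and returns nested lists of shape (batch_size, units). It is a hypothetical pure-Python stand-in (`evaluate_hypercube` is not a TFL function) that handles only the (batch_size, len(lattice_sizes)) and (batch_size, units, len(lattice_sizes)) input layouts, clips to [0, lattice_sizes[i] - 1] when clip_inputs is set, and assumes a row-major vertex ordering for the rows of kernel.

```python
import itertools
import math


def evaluate_hypercube(inputs, kernel, units, lattice_sizes, clip_inputs=True):
    """Sketch: inputs [batch][units][d] (or [batch][d] when units == 1),
    kernel [num_params][units]; returns [batch][units]."""
    d = len(lattice_sizes)
    outputs = []
    for example in inputs:
        # With units == 1 the unit axis is absent; add it so both cases share code.
        per_unit = [example] if units == 1 else example
        row = []
        for u, point in enumerate(per_unit):
            if clip_inputs:
                point = [min(max(x, 0.0), s - 1.0) for x, s in zip(point, lattice_sizes)]
            # Lower cell corner per dimension, kept inside the valid cell range.
            bases = [max(0, min(math.floor(x), s - 2)) for x, s in zip(point, lattice_sizes)]
            fracs = [x - b for x, b in zip(point, bases)]
            value = 0.0
            for bits in itertools.product((0, 1), repeat=d):
                weight = 1.0
                index = 0
                for bit, frac, base, size in zip(bits, fracs, bases, lattice_sizes):
                    weight *= frac if bit else 1.0 - frac
                    index = index * size + base + bit  # assumed row-major vertex layout
                value += weight * kernel[index][u]
            row.append(value)
        outputs.append(row)
    return outputs


# 2 x 2 lattice, 2 units: unit 0 stores v0 + 2 * v1, unit 1 is constant 1.
kernel = [[0.0, 1.0], [2.0, 1.0], [1.0, 1.0], [3.0, 1.0]]
# Shape (batch_size=2, units=2, len(lattice_sizes)=2); (-1.0, 2.0) gets clipped.
inputs = [[[0.5, 0.25], [1.0, 1.0]], [[-1.0, 2.0], [0.0, 0.0]]]
print(evaluate_hypercube(inputs, kernel, units=2, lattice_sizes=[2, 2]))
# [[1.0, 1.0], [2.0, 1.0]]
```

Each unit reads its own column of kernel, which is why the kernel's second dimension must equal units and the output gains a trailing units axis.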