`RandomVariable` supports random variable semantics for TFP distributions.
Inherits From: DeferredTensor
tfp.experimental.nn.util.RandomVariable(
    distribution,
    convert_to_tensor_fn=tfp.distributions.Distribution.sample,
    dtype=None,
    shape=None,
    name=None
)
The `RandomVariable` class memoizes concretizations of TFP distribution-like objects so that random draws can be re-triggered on demand, i.e., by calling `reset`. For more details, type `help(tfp.util.DeferredTensor)`.
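The memoize-then-reset contract described above can be pictured with a dependency-free toy model. `MemoizedDraw` and its `value()` method are hypothetical names chosen for this sketch; they are not part of TFP, whose class is concretized through `tf.convert_to_tensor` instead of a `value()` call.

```python
import random


class MemoizedDraw:
    """Toy model of memoize-until-reset; hypothetical, not TFP's class."""

    _UNSET = object()  # sentinel: no memoized value yet

    def __init__(self, draw_fn):
        # `draw_fn` plays the role of `convert_to_tensor_fn(distribution)`.
        self._draw_fn = draw_fn
        self._value = self._UNSET

    def is_unset(self):
        return self._value is self._UNSET

    def value(self):
        # Concretize lazily on the first read, then reuse that same draw.
        if self.is_unset():
            self._value = self._draw_fn()
        return self._value

    def reset(self):
        # Forget the memoized draw so the next read triggers a fresh one.
        self._value = self._UNSET


x = MemoizedDraw(lambda: random.gauss(0.0, 1.0))
x_ = x.value()
assert x.value() + 1.0 == x_ + 1.0  # same draw until reset
x.reset()
assert x.is_unset()  # the next read will draw again
```

The sentinel object, rather than `None`, keeps the sketch correct even when the draw function legitimately returns `None`.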
Examples
# In this example we see the memoization semantics in action.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfn = tfp.experimental.nn
x = tfn.util.RandomVariable(tfd.Normal(0, 1))
x_ = tf.convert_to_tensor(x)
x_ + 1. == x + 1.
# ==> True; `x` always has the same value until reset.
x.reset()
tf.convert_to_tensor(x) == x_
# ==> False; `x` was reset which triggers a new sample.
# In this example we see how to concretize with different semantics.
tfd = tfp.distributions
tfn = tfp.experimental.nn
x = tfn.util.RandomVariable(
    tfd.Bernoulli(probs=[[0.25], [0.5]]),
    convert_to_tensor_fn=tfd.Distribution.mean,
    dtype=tf.float32,
    shape=[2, 1],
    name='x')
tf.convert_to_tensor(x)
# ==> [[0.25], [0.5]]
x.shape
# ==> [2, 1]
x.dtype
# ==> tf.float32
x.name
# ==> 'x'
# In this example we see a common pitfall: accessing the memoized value from a
# different graph context.
tfd = tfp.distributions
tfn = tfp.experimental.nn
x = tfn.util.RandomVariable(tfd.Normal(0, 1))
@tf.function(autograph=False, jit_compile=True)
def run():
  return tf.convert_to_tensor(x)
first = run()
second = tf.convert_to_tensor(x)
# raises ValueError:
# "You are attempting to access a memoized value from a different
# graph context. Please call `this.reset()` before accessing a
# memoized value from a different graph context."
x.reset()
third = tf.convert_to_tensor(x)
# ==> No exception.
first == third
# ==> False
Args | |
---|---|
`distribution` | TFP distribution-like object which is passed into the `convert_to_tensor_fn` whenever this object is evaluated in `Tensor`-like contexts. |
`convert_to_tensor_fn` | Python callable which takes one argument, the `distribution`, and returns a `Tensor` of type `dtype` and shape `shape`. Default value: `tfp.distributions.Distribution.sample`. |
`dtype` | TF `dtype` equivalent to what would otherwise be `convert_to_tensor_fn(distribution).dtype`. Default value: `None` (i.e., `distribution.dtype`). |
`shape` | `tf.TensorShape`-like object compatible with what would otherwise be `convert_to_tensor_fn(distribution).shape`. Default value: `None` (i.e., unspecified static shape). |
`name` | Python `str` representing this object's `name`; used only in graph mode. Default value: `None` (i.e., `distribution.name`). |
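Because `convert_to_tensor_fn` receives the distribution as its single argument, an unbound method such as `tfp.distributions.Distribution.sample` or `tfd.Distribution.mean` can be passed directly: calling `Distribution.mean(d)` is the same as calling `d.mean()`. The pure-Python sketch below shows that pattern together with the documented `dtype` and `name` defaults. `ToyBernoulli` and `resolve` are hypothetical, and string dtypes stand in for TF dtypes. It also suggests why the Bernoulli example passes `dtype=tf.float32`: a Bernoulli mean is fractional, while `tfd.Bernoulli` reports an integer `dtype` (`tf.int32` by default) for its samples.

```python
import random


class ToyBernoulli:
    """Hypothetical stand-in for a TFP distribution-like object."""

    dtype = "int32"  # stands in for `distribution.dtype`
    name = "ToyBernoulli"

    def __init__(self, probs):
        self.probs = list(probs)

    def sample(self):
        # One 0/1 draw per probability.
        return [1 if random.random() < p else 0 for p in self.probs]

    def mean(self):
        return list(self.probs)


def resolve(distribution, convert_to_tensor_fn=ToyBernoulli.sample,
            dtype=None, name=None):
    """Sketch of how the documented defaults fit together (assumption)."""
    value = convert_to_tensor_fn(distribution)  # called with one argument
    if dtype is None:
        dtype = distribution.dtype  # documented default for `dtype`
    if name is None:
        name = distribution.name    # documented default for `name`
    return value, dtype, name


d = ToyBernoulli([0.25, 0.5])
print(resolve(d, convert_to_tensor_fn=ToyBernoulli.mean,
              dtype="float32", name="x"))
# ([0.25, 0.5], 'float32', 'x')
```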
Attributes | |
---|---|
`also_track` | Additional variables tracked by `tf.Module` in `self.trainable_variables`. |
`convert_to_tensor_fn` | |
`distribution` | |
`dtype` | Represents the type of the elements in a `Tensor`. |
`name` | The string name of this object. |
`name_scope` | Returns a `tf.name_scope` instance for this class. |
`non_trainable_variables` | Sequence of non-trainable variables owned by this module and its submodules. |
`pretransformed_input` | Input to `transform_fn`. |
`shape` | Represents the shape of a `Tensor`. |
`submodules` | Sequence of all sub-modules. Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on). |
`trainable_variables` | Sequence of trainable variables owned by this module and its submodules. |
`transform_fn` | Function which characterizes the `Tensor`ization of this object. |
`variables` | Sequence of variables owned by this module and its submodules. |
Methods
is_unset
is_unset()
Returns `True` if there is no memoized value and `False` otherwise.
numpy
numpy()
Returns a copy of the deferred value as a NumPy array or scalar.
reset
reset()
Removes the memoized value, which triggers re-evaluation on subsequent reads.
set_shape
set_shape(shape)
Updates the shape of this `pretransformed_input`.
This method can be called multiple times, and will merge the given shape
with the current shape of this object. It can be used to provide additional
information about the shape of this object that cannot be inferred from the
graph alone.
Args | |
---|---|
`shape` | A `TensorShape` representing the shape of this `pretransformed_input`, a `TensorShapeProto`, a list, a tuple, or `None`. |

Raises | |
---|---|
`ValueError` | If `shape` is not compatible with the current shape of this `pretransformed_input`. |
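The merge rule described for `set_shape` (repeated calls refine the shape; incompatible shapes raise `ValueError`) can be sketched in pure Python. This is a simplified model of the documented behavior of `tf.TensorShape.merge_with`, not TFP code; `merge_shapes` is a hypothetical helper, a shape is `None` (unknown rank) or a list of dims, and a dim of `None` is unknown.

```python
def merge_shapes(current, new):
    """Hypothetical, simplified model of merging two static shapes."""
    if current is None:          # unknown rank: adopt the other shape
        return None if new is None else list(new)
    if new is None:
        return list(current)
    if len(current) != len(new):
        raise ValueError(f"Shapes {current} and {new} are not compatible")
    merged = []
    for a, b in zip(current, new):
        if a is None:            # unknown dim takes the other's value
            merged.append(b)
        elif b is None or a == b:
            merged.append(a)
        else:                    # two different known dims conflict
            raise ValueError(f"Shapes {current} and {new} are not compatible")
    return merged


shape = merge_shapes(None, [2, None])   # [2, None]
shape = merge_shapes(shape, [None, 1])  # [2, 1]
```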
with_name_scope
@classmethod
with_name_scope(method)
Decorator to automatically enter the module name scope.
import tensorflow as tf
class MyModule(tf.Module):
  @tf.Module.with_name_scope
  def __call__(self, x):
    if not hasattr(self, 'w'):
      self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
    return tf.matmul(x, self.w)
Using the above module would produce `tf.Variable`s and `tf.Tensor`s whose names include the module name:
mod = MyModule()
mod(tf.ones([1, 2]))
<tf.Tensor: shape=(1, 3), dtype=float32, numpy=array(..., dtype=float32)>
mod.w
<tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32, numpy=array(..., dtype=float32)>
Args | |
---|---|
`method` | The method to wrap. |

Returns | |
---|---|
The original method wrapped such that it enters the module's name scope. |
__abs__
__abs__(*args, **kwargs)
__add__
__add__(*args, **kwargs)
__and__
__and__(*args, **kwargs)
__array__
__array__(dtype=None)
__bool__
__bool__()
Dummy method to prevent a tensor from being used as a Python `bool`.
This overload raises a `TypeError` when the user inadvertently treats a `Tensor` as a boolean (most commonly in an `if` or `while` statement), in code that was not converted by AutoGraph. For example:
if tf.constant(True):  # Will raise.
  # ...
if tf.constant(5) < tf.constant(7):  # Will raise.
  # ...
Raises | |
---|---|
`TypeError`. |
__div__
__div__(*args, **kwargs)
__floordiv__
__floordiv__(*args, **kwargs)
__ge__
__ge__(*args, **kwargs)
__getitem__
__getitem__(*args, **kwargs)
__gt__
__gt__(*args, **kwargs)
__invert__
__invert__(*args, **kwargs)
__iter__
__iter__(*args, **kwargs)
__le__
__le__(*args, **kwargs)
__lt__
__lt__(*args, **kwargs)
__matmul__
__matmul__(*args, **kwargs)
__mod__
__mod__(*args, **kwargs)
__mul__
__mul__(*args, **kwargs)
__neg__
__neg__(*args, **kwargs)
__nonzero__
__nonzero__()
Dummy method to prevent a tensor from being used as a Python `bool`.
This is the Python 2.x counterpart to `__bool__()` above.
Raises | |
---|---|
`TypeError`. |
__or__
__or__(*args, **kwargs)
__pow__
__pow__(*args, **kwargs)
__radd__
__radd__(*args, **kwargs)
__rand__
__rand__(*args, **kwargs)
__rdiv__
__rdiv__(*args, **kwargs)
__rfloordiv__
__rfloordiv__(*args, **kwargs)
__rmatmul__
__rmatmul__(*args, **kwargs)
__rmod__
__rmod__(*args, **kwargs)
__rmul__
__rmul__(*args, **kwargs)
__ror__
__ror__(*args, **kwargs)
__rpow__
__rpow__(*args, **kwargs)
__rsub__
__rsub__(*args, **kwargs)
__rtruediv__
__rtruediv__(*args, **kwargs)
__rxor__
__rxor__(*args, **kwargs)
__sub__
__sub__(*args, **kwargs)
__truediv__
__truediv__(*args, **kwargs)
__xor__
__xor__(*args, **kwargs)
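The operator overloads listed above are what let an object like `x` appear in expressions such as `x + 1.` in the first example. A minimal pure-Python sketch of the forwarding pattern, assuming each operator concretizes the deferred object and then applies the ordinary operator. `DeferredValue` is hypothetical and does not reproduce TFP's internals; it also recomputes on every use, whereas `RandomVariable` memoizes, which is what makes `x_ + 1. == x + 1.` hold.

```python
class DeferredValue:
    """Hypothetical toy: forwards arithmetic to a lazily computed value."""

    def __init__(self, compute_fn):
        self._compute_fn = compute_fn

    def _concrete(self):
        # Concretize on demand (no memoization in this toy).
        return self._compute_fn()

    def __add__(self, other):
        return self._concrete() + other

    def __radd__(self, other):
        # Used for `other + self` when `other` does not know this type.
        return other + self._concrete()

    def __mul__(self, other):
        return self._concrete() * other

    def __rmul__(self, other):
        return other * self._concrete()


x = DeferredValue(lambda: 3)
total = 1 + x  # int.__add__ returns NotImplemented, so x.__radd__ runs
```

The reflected variants (`__radd__`, `__rmul__`, and so on) are needed because Python first asks the left operand; without them, `1. + x` would raise `TypeError` even though `x + 1.` works.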