Builds a variational posterior by linearly transforming base distributions.
tfp.experimental.vi.build_affine_surrogate_posterior_from_base_distribution(
*args, seed=None, **kwargs
)
This function builds a surrogate posterior by applying a trainable
transformation to a base distribution (typically a tfd.JointDistribution) or
nested structure of base distributions, and constraining the samples with
bijector. Note that the distributions must have event shapes corresponding
to the pretransformed surrogate posterior: that is, if bijector contains
a shape-changing bijector, then the corresponding base distribution event
shape is the inverse event shape of the bijector applied to the desired
surrogate posterior shape. The surrogate posterior is constructed as follows:
- Flatten the base distribution event shapes to vectors, and pack the base
distributions into a tfd.JointDistribution.
- Apply a trainable blockwise LinearOperator bijector to the joint base
distribution.
- Apply the constraining bijectors and return the resulting trainable
tfd.TransformedDistribution instance.
Args

base_distribution: tfd.Distribution instance (typically a
tfd.JointDistribution), or a nested structure of tfd.Distribution
instances.

operators: Either a string or a list/tuple containing LinearOperator
subclasses, LinearOperator instances, or callables returning
LinearOperator instances. Supported string values are "diag" (to create
a mean-field surrogate posterior) and "tril" (to create a full-covariance
surrogate posterior). A list/tuple may be passed to induce other
posterior covariance structures. If the list is flat, a
tf.linalg.LinearOperatorBlockDiag instance will be created and applied
to the base distribution. Otherwise the list must be singly-nested and
have a first element of length 1, second element of length 2, etc.; the
elements of the outer list are interpreted as rows of a lower-triangular
block structure, and a tf.linalg.LinearOperatorBlockLowerTriangular
instance is created. For complete documentation and examples, see
tfp.experimental.vi.util.build_trainable_linear_operator_block, which
receives the operators arg if it is list-like.
Default value: "diag".
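For concreteness, here is how a singly-nested operators list maps onto a block lower-triangular matrix. This sketch uses tf.linalg directly rather than the trainable operators built internally, and the block sizes (2 and 3) are illustrative:

```python
import tensorflow as tf

# Rows of a lower-triangular block matrix: [[A], [B, C]] acts on a pair of
# vectors (x1, x2) as (A x1, B x1 + C x2).
op = tf.linalg.LinearOperatorBlockLowerTriangular([
    [tf.linalg.LinearOperatorLowerTriangular(2. * tf.eye(2))],  # A: 2x2
    [tf.linalg.LinearOperatorFullMatrix(tf.ones([3, 2])),       # B: 3x2
     tf.linalg.LinearOperatorDiag(tf.ones([3]))]])              # C: 3x3 diag

# Blockwise matvec on a list of vectors returns a list of vectors:
# y[0] = A x1 = [2., 2.];  y[1] = B x1 + C x2 = [3., 3., 3.]
y = op.matvec([tf.ones([2]), tf.ones([3])])
```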
bijector: tfb.Bijector instance, or nested structure of tfb.Bijector
instances, that maps (nested) values in R^n to the support of the
posterior. (This can be the experimental_default_event_space_bijector of
the distribution over the prior latent variables.)
Default value: None (i.e., the posterior is over R^n).

initial_unconstrained_loc_fn: Optional Python callable with signature
initial_loc = initial_unconstrained_loc_fn(shape, dtype, seed) used to
sample real-valued initializations for the unconstrained location of
each variable.
Default value: functools.partial(tf.random.stateless_uniform,
minval=-2., maxval=2., dtype=tf.float32).

seed: PRNG seed; see tfp.random.sanitize_seed for details.

validate_args: Python bool. Whether to validate input with asserts. This
imposes a runtime cost. If validate_args is False, and the inputs are
invalid, correct behavior is not guaranteed.
Default value: False.

name: Python str name prefixed to ops created by this function.
Default value: None (i.e.,
'build_affine_surrogate_posterior_from_base_distribution').

Returns

A trainable tfd.TransformedDistribution instance parameterized by
trainable tf.Variables.
Examples
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# Fit a multivariate Normal surrogate posterior on the Eight Schools model
# [1].
treatment_effects = [28., 8., -3., 7., -1., 1., 18., 12.]
treatment_stddevs = [15., 10., 16., 11., 9., 11., 10., 18.]

def model_fn():
  avg_effect = yield tfd.Normal(loc=0., scale=10., name='avg_effect')
  log_stddev = yield tfd.Normal(loc=5., scale=1., name='log_stddev')
  school_effects = yield tfd.Sample(
      tfd.Normal(loc=avg_effect, scale=tf.exp(log_stddev)),
      sample_shape=[8],
      name='school_effects')
  treatment_effects = yield tfd.Independent(
      tfd.Normal(loc=school_effects, scale=treatment_stddevs),
      reinterpreted_batch_ndims=1,
      name='treatment_effects')
model = tfd.JointDistributionCoroutineAutoBatched(model_fn)

# Pin the observed values in the model.
target_model = model.experimental_pin(treatment_effects=treatment_effects)

# Define the base distribution: standard Normals with the same structure and
# event shapes as the three unpinned latent variables.
base_distribution = tfd.JointDistributionSequential(
    [tfd.Sample(tfd.Normal(0., 1.), s) for s in target_model.event_shape])

# Define a lower-triangular structure of `LinearOperator` subclasses that
# models full covariance among latent variables except for the 8 dimensions
# of `school_effects`, which are modeled as independent (using
# `LinearOperatorDiag`).
operators = [
    [tf.linalg.LinearOperatorLowerTriangular],
    [tf.linalg.LinearOperatorFullMatrix,
     tf.linalg.LinearOperatorLowerTriangular],
    [tf.linalg.LinearOperatorFullMatrix, tf.linalg.LinearOperatorFullMatrix,
     tf.linalg.LinearOperatorDiag]]

# Constrain the posterior values to the support of the prior.
bijector = target_model.experimental_default_event_space_bijector()

# Build a full-covariance surrogate posterior.
surrogate_posterior = (
    tfp.experimental.vi.build_affine_surrogate_posterior_from_base_distribution(
        base_distribution=base_distribution,
        operators=operators,
        bijector=bijector))

# Fit the model.
losses = tfp.vi.fit_surrogate_posterior(
    target_model.unnormalized_log_prob,
    surrogate_posterior,
    num_steps=100,
    optimizer=tf.optimizers.Adam(0.1),
    sample_size=10)
References
[1] Andrew Gelman, John Carlin, Hal Stern, David Dunson, Aki Vehtari, and Donald Rubin. Bayesian Data Analysis, Third Edition. Chapman and Hall/CRC, 2013.