Use particle filtering to sample from the posterior over trajectories.
tfp.experimental.mcmc.infer_trajectories(
    observations,
    initial_state_prior,
    transition_fn,
    observation_fn,
    num_particles,
    initial_state_proposal=None,
    proposal_fn=None,
    resample_fn=tfp.experimental.mcmc.resample_systematic,
    resample_criterion_fn=tfp.experimental.mcmc.ess_below_threshold,
    unbiased_gradients=True,
    rejuvenation_kernel_fn=None,
    num_transitions_per_observation=1,
    seed=None,
    name=None
)
Each latent state is a Tensor or nested structure of Tensors, as defined by the initial_state_prior.
The transition_fn and proposal_fn args, if specified, have signature next_state_dist = fn(step, state), where step is an int Tensor index of the current time step (beginning at zero), and state represents the latent state at time step. The return value is a tfd.Distribution instance over the state at time step + 1.
Similarly, the observation_fn has signature observation_dist = observation_fn(step, state), where the return value is a distribution over the value(s) observed at time step.
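To make the calling convention concrete, here is a minimal stdlib-only sketch that forward-simulates a model written against these signatures. It is not TFP code: the Normal class is a hypothetical stand-in for a tfd.Distribution that only supports sampling, and simulate is a hypothetical helper. It shows transition_fn(step, state) being called with the state at step and yielding the state at step + 1, while observation_fn(step, state) is evaluated at the same step as its state.

```python
import random


class Normal:
    """Hypothetical stand-in for tfd.Normal that only supports sampling."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def sample(self, rng):
        return rng.gauss(self.loc, self.scale)


def transition_fn(step, state):
    # Distribution over the state at step + 1, given the state at step.
    return Normal(loc=state, scale=0.1)


def observation_fn(step, state):
    # Distribution over the value observed at step, given the state at step.
    return Normal(loc=state, scale=0.5)


def simulate(num_steps, initial_state, seed=0):
    """Hypothetical helper: samples states and observations for steps 0..num_steps - 1."""
    rng = random.Random(seed)
    states = [initial_state]
    observations = [observation_fn(0, initial_state).sample(rng)]
    for step in range(num_steps - 1):
        # Called with the state at `step`; the sample is the state at `step + 1`.
        next_state = transition_fn(step, states[-1]).sample(rng)
        states.append(next_state)
        # The new state is observed at its own step index, step + 1.
        observations.append(observation_fn(step + 1, next_state).sample(rng))
    return states, observations


states, observations = simulate(num_steps=5, initial_state=0.0)
print(len(states), len(observations))  # 5 5
```

The particle filter performs the same sequence of calls, except that it evaluates the observation distribution's log-probability at the given observations instead of sampling from it.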
Args | |
---|---|
observations | a (structure of) Tensors, each of shape concat([[num_observation_steps, b1, ..., bN], event_shape]), with optional batch dimensions b1, ..., bN. |
initial_state_prior | a (joint) distribution over the initial latent state, with optional batch shape [b1, ..., bN]. |
transition_fn | callable returning a (joint) distribution over the next latent state. |
observation_fn | callable returning a (joint) distribution over the current observation. |
num_particles | int Tensor number of particles. |
initial_state_proposal | a (joint) distribution over the initial latent state, with optional batch shape [b1, ..., bN]. If None, the initial particles are proposed from the initial_state_prior. Default value: None. |
proposal_fn | callable returning a (joint) proposal distribution over the next latent state. If None, the dynamics model is used (proposal_fn == transition_fn). Default value: None. |
resample_fn | Python callable to generate the indices of resampled particles, given their weights. Generally one of tfp.experimental.mcmc.resample_independent or tfp.experimental.mcmc.resample_systematic, or any function with the same signature, resampled_indices = f(log_probs, event_size, sample_shape, seed). Default value: tfp.experimental.mcmc.resample_systematic. |
resample_criterion_fn | optional Python callable with signature do_resample = resample_criterion_fn(log_weights), where log_weights is a float Tensor of shape [b1, ..., bN, num_particles] containing log (unnormalized) weights for all particles at the current step. The return value do_resample determines whether particles are resampled at the current step. If resample_criterion_fn is None, particles are resampled at every step. The default behavior resamples particles when the current effective sample size falls below half the total number of particles. Default value: tfp.experimental.mcmc.ess_below_threshold. |
unbiased_gradients | If True, use the stop-gradient resampling trick of Scibior, Masrani, and Wood [2] to correct for gradient bias introduced by the discrete resampling step. This will generally increase the variance of stochastic gradients. Default value: True. |
rejuvenation_kernel_fn | optional Python callable with signature transition_kernel = rejuvenation_kernel_fn(target_log_prob_fn), where target_log_prob_fn is a provided callable evaluating p(x[t] \| y[t], x[t-1]) at each step t, and transition_kernel should be an instance of tfp.mcmc.TransitionKernel. Default value: None. |
num_transitions_per_observation | scalar positive int Tensor number of state transitions between regular observation points. A value of 1 indicates that there is an observation at every timestep, 2 that every other step is observed, and so on. Values greater than 1 may be used with an appropriately chosen transition function to approximate continuous-time dynamics. The initial and final steps (steps 0 and num_timesteps - 1) are always observed. Default value: 1. |
seed | PRNG seed; see tfp.random.sanitize_seed for details. |
name | Python str name for ops created by this method. Default value: None (i.e., 'infer_trajectories'). |
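The resample_fn and resample_criterion_fn arguments can be illustrated with a small stdlib-only sketch. It computes the effective sample size (ESS) from log weights as (sum w)^2 / sum w^2, applies the documented default rule of resampling when the ESS falls below half the number of particles, and draws ancestor indices by systematic resampling (a single uniform offset shared by N evenly spaced points). The function names are hypothetical, the sketch handles a single batch member, and its signature is simplified relative to f(log_probs, event_size, sample_shape, seed); it is not TFP's implementation.

```python
import bisect
import math
import random


def log_sum_exp(xs):
    # Numerically stable log(sum(exp(x) for x in xs)).
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def effective_sample_size(log_weights):
    # ESS = (sum w)^2 / sum w^2, evaluated in log space.
    return math.exp(2.0 * log_sum_exp(log_weights)
                    - log_sum_exp([2.0 * lw for lw in log_weights]))


def ess_below_half(log_weights):
    # Hypothetical analogue of the default rule: resample when ESS < N / 2.
    return effective_sample_size(log_weights) < 0.5 * len(log_weights)


def systematic_resample(log_weights, rng):
    # Hypothetical helper: returns len(log_weights) ancestor indices.
    n = len(log_weights)
    total = log_sum_exp(log_weights)
    # Cumulative normalized weights; pin the last entry to exactly 1.
    cdf, acc = [], 0.0
    for lw in log_weights:
        acc += math.exp(lw - total)
        cdf.append(acc)
    cdf[-1] = 1.0
    # A single uniform offset shared by n evenly spaced points in [0, 1).
    u0 = rng.random() / n
    # min() guards against rounding pushing a point up to exactly 1.0.
    return [min(bisect.bisect_right(cdf, u0 + i / n), n - 1)
            for i in range(n)]


rng = random.Random(0)
log_w = [0.0, -1.0, -2.0, -3.0]
print(effective_sample_size(log_w), ess_below_half(log_w))
print(systematic_resample(log_w, rng))
```

Systematic resampling copies each particle either floor(N * w) or ceil(N * w) times, which is why it typically adds less resampling noise than drawing the N indices independently.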
Returns | |
---|---|
trajectories | a (structure of) Tensor(s) matching the latent state, each of shape concat([[num_timesteps, num_particles, b1, ..., bN], event_shape]), representing unbiased samples from the posterior distribution p(latent_states \| observations). |
incremental_log_marginal_likelihoods | float Tensor of shape [num_observation_steps, b1, ..., bN], giving the natural logarithm of an unbiased estimate of p(observations[t] \| observations[:t]) at each timestep t. Note that (by Jensen's inequality) this is smaller in expectation than the true log p(observations[t] \| observations[:t]). |
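The leading dimension of the trajectories (num_timesteps) differs from that of the observations (num_observation_steps) whenever num_transitions_per_observation is greater than 1. With k = num_transitions_per_observation, observations fall at steps 0, k, 2k, ..., and because the final step is always observed, num_timesteps = 1 + k * (num_observation_steps - 1). The helpers below are a hypothetical sketch of that bookkeeping, derived from the argument descriptions above rather than taken from TFP.

```python
def num_timesteps(num_observation_steps, num_transitions_per_observation=1):
    # Observations sit at steps 0, k, 2k, ...; the last one is the final step.
    return 1 + num_transitions_per_observation * (num_observation_steps - 1)


def observed_steps(num_observation_steps, num_transitions_per_observation=1):
    # Latent steps at which an observation is available.
    return [i * num_transitions_per_observation
            for i in range(num_observation_steps)]


def trajectory_shape(num_observation_steps, num_particles, batch_shape=(),
                     event_shape=(), num_transitions_per_observation=1):
    # concat([[num_timesteps, num_particles, b1, ..., bN], event_shape])
    steps = num_timesteps(num_observation_steps,
                          num_transitions_per_observation)
    return (steps, num_particles) + tuple(batch_shape) + tuple(event_shape)


print(trajectory_shape(40, 1000))  # (40, 1000)
print(observed_steps(4, num_transitions_per_observation=3))  # [0, 3, 6, 9]
print(num_timesteps(4, num_transitions_per_observation=3))  # 10
```

With the default of one transition per observation, the two leading dimensions coincide, as in the example below, where 40 observations yield trajectories of shape [40, 1000].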
Examples
Tracking unknown position and velocity: Let's consider tracking an object moving in a one-dimensional space. We'll define a dynamical system by specifying an initial_state_prior, a transition_fn, and an observation_fn.
The structure of the latent state space is determined by the prior distribution. Here, we'll define a state space that includes the object's current position and velocity:
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
initial_state_prior = tfd.JointDistributionNamed({
    'position': tfd.Normal(loc=0., scale=1.),
    'velocity': tfd.Normal(loc=0., scale=0.1)})
The transition_fn specifies the evolution of the system. It should return a distribution over latent states of the same structure as the prior. Here, we'll assume that the position evolves according to the velocity, with a small random drift, and the velocity also changes slowly, following a random drift:
def transition_fn(_, previous_state):
    # Position moves by the current velocity; both components drift slightly.
    return tfd.JointDistributionNamed({
        'position': tfd.Normal(
            loc=previous_state['position'] + previous_state['velocity'],
            scale=0.1),
        'velocity': tfd.Normal(loc=previous_state['velocity'], scale=0.01)})
The observation_fn specifies the process by which the system is observed at each time step. Let's suppose we observe only a noisy version of the current position.
def observation_fn(_, state):
    return tfd.Normal(loc=state['position'], scale=0.1)
Now let's track our object. Suppose we've been given observations corresponding to an initial position of 0.4 and constant velocity of 0.01:
# Generate 40 simulated observations: positions 0.4, 0.41, ..., 0.79 plus noise.
observed_positions = tfd.Normal(loc=tf.linspace(0.4, 0.79, 40),
                                scale=0.1).sample()
# Run particle filtering to sample plausible trajectories.
(trajectories,  # {'position': [40, 1000], 'velocity': [40, 1000]}
 lps) = tfp.experimental.mcmc.infer_trajectories(
     observations=observed_positions,
     initial_state_prior=initial_state_prior,
     transition_fn=transition_fn,
     observation_fn=observation_fn,
     num_particles=1000)
For all i, trajectories['position'][:, i] is a sample from the posterior over position sequences, given the observations: p(state[0:T] | observations[0:T]). Often, the sampled trajectories will be highly redundant in their earlier timesteps, because most of the initial particles have been discarded through resampling (this problem is known as 'particle degeneracy'; see section 3.5 of Doucet and Johansen [1]).
In such cases it may be useful to also consider the series of filtering distributions p(state[t] | observations[:t+1]), in which each latent state is inferred conditioned only on observations up to that point in time; these may be computed using tfp.experimental.mcmc.particle_filter.
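Particle degeneracy is easy to reproduce in a toy setting. The stdlib-only sketch below runs a bootstrap particle filter (proposals drawn from the transition model, multinomial resampling at every step) on a hypothetical one-dimensional Gaussian random walk, tracks each particle's ancestry, and counts how many distinct ancestors from each step survive in the final set of trajectories. The model, its noise scales, and the helper name are assumptions made for illustration; this is not TFP's implementation, and TFP's default resampler is systematic rather than multinomial.

```python
import math
import random


def bootstrap_filter_ancestry(observations, num_particles, seed=0):
    """Hypothetical helper: returns one ancestry path of particle ids per particle."""
    rng = random.Random(seed)
    # Step 0: draw particles from a standard normal prior.
    particles = [rng.gauss(0.0, 1.0) for _ in range(num_particles)]
    # paths[i][t] is the id of particle i's ancestor at step t.
    paths = [[i] for i in range(num_particles)]
    next_id = num_particles
    for t, y in enumerate(observations):
        if t > 0:
            # Propagate through the random-walk transition N(x, 0.5^2).
            particles = [rng.gauss(x, 0.5) for x in particles]
            paths = [path + [next_id + i] for i, path in enumerate(paths)]
            next_id += num_particles
        # Weight by the observation likelihood N(y; x, 0.3^2), up to a constant.
        log_w = [-0.5 * ((y - x) / 0.3) ** 2 for x in particles]
        m = max(log_w)
        weights = [math.exp(lw - m) for lw in log_w]
        # Multinomial resampling copies particles together with their histories.
        idx = rng.choices(range(num_particles), weights=weights,
                          k=num_particles)
        particles = [particles[i] for i in idx]
        paths = [paths[i] for i in idx]
    return paths


# Hypothetical data: noisy observations of a steadily drifting position.
data_rng = random.Random(1)
observations = [0.1 * t + data_rng.gauss(0.0, 0.3) for t in range(30)]
paths = bootstrap_filter_ancestry(observations, num_particles=200)
# Number of distinct ancestors at each step among the final trajectories.
distinct = [len({path[t] for path in paths}) for t in range(len(observations))]
print(distinct[0], distinct[-1])
```

Because every particle has exactly one ancestor at each earlier step, the distinct counts can only shrink as t moves back toward 0; in runs like this one, the step-0 count is typically a small fraction of the particle count.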
References
[1] Arnaud Doucet and Adam M. Johansen. A tutorial on particle filtering and smoothing: Fifteen years later. Handbook of Nonlinear Filtering, 12(656-704), 2009. https://www.stats.ox.ac.uk/~doucet/doucet_johansen_tutorialPF2011.pdf
[2] Adam Scibior, Vaden Masrani, and Frank Wood. Differentiable Particle Filtering without Modifying the Forward Pass. arXiv preprint arXiv:2106.10314, 2021. https://arxiv.org/abs/2106.10314