Setup
First install packages used in this demo.
```
pip install -q dm-sonnet
```
Imports (tf, tfp with adjoint trick, etc)
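The import cell is collapsed in the notebook. A minimal sketch of what it contains, assuming the `tfb`, `tfd`, and `skd` aliases used by the code below:

```python
import numpy as np
import matplotlib.pyplot as plt
import tqdm

import sklearn.datasets as skd
import sonnet as snt
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

tf.enable_v2_behavior()

# Shorthand aliases used throughout this demo.
tfb = tfp.bijectors
tfd = tfp.distributions
```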
Helper functions for visualization
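The visualization helpers are collapsed as well. A minimal sketch, assuming a `plot_panel` signature that matches the calls below and a regular `grid` of points used to visualize the warp; the notebook's actual helpers are more elaborate:

```python
# Regular lattice of points; its image under the bijector visualizes the warp.
x = np.linspace(-2.0, 2.0, 20, dtype=np.float32)
grid = np.stack(np.meshgrid(x, x), axis=-1).reshape(-1, 2)

def plot_panel(grid, base_samples, transformed_grid, transformed_samples,
               data, axarray, limits=True):
  """Scatters base samples, the warped grid, model samples, and the data.

  `grid` is accepted for signature compatibility with the calls below;
  this sketch only draws its transformed image.
  """
  panels = [(base_samples, 'base samples'),
            (transformed_grid, 'transformed grid'),
            (transformed_samples, 'transformed samples'),
            (data, 'data')]
  for ax, (points, title) in zip(axarray, panels):
    ax.scatter(points[:, 0], points[:, 1], s=3, alpha=0.5)
    ax.set_title(title)
    if limits:
      ax.set_xlim(-2.5, 2.5)
      ax.set_ylim(-2.5, 2.5)
```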
FFJORD bijector
In this colab we demonstrate the FFJORD bijector, originally proposed in Grathwohl et al., "FFJORD: Free-Form Continuous Dynamics for Scalable Reversible Generative Models" (arXiv:1810.01367).
In a nutshell, the idea behind this approach is to establish a correspondence between a known base distribution and the data distribution.
To establish this connection, we need to
- Define a bijective map \(\mathcal{T}_{\theta}:\mathbf{x} \rightarrow \mathbf{y}\), \(\mathcal{T}_{\theta}^{-1}:\mathbf{y} \rightarrow \mathbf{x}\) between the space \(\mathcal{Y}\) on which the base distribution is defined and the space \(\mathcal{X}\) of the data domain.
- Efficiently keep track of the deformations we perform to transfer the notion of probability onto \(\mathcal{X}\).
The second condition is formalized in the following expression for the probability distribution defined on \(\mathcal{X}\):
\[ \log p_{\mathbf{x} }(\mathbf{x})=\log p_{\mathbf{y} }(\mathbf{y})-\log \operatorname{det}\left|\frac{\partial \mathcal{T}_{\theta}(\mathbf{y})}{\partial \mathbf{y} }\right| \]
FFJORD bijector accomplishes this by defining a transformation
\[ \mathcal{T_{\theta} }: \mathbf{x} = \mathbf{z}(t_{0}) \rightarrow \mathbf{y} = \mathbf{z}(t_{1}) \quad : \quad \frac{d \mathbf{z} }{dt} = \mathbf{f}(t, \mathbf{z}, \theta) \]
This transformation is invertible as long as the function \(\mathbf{f}\) describing the evolution of the state \(\mathbf{z}\) is well behaved, and the `log_det_jacobian` can be calculated by integrating the following expression:
\[ \log \operatorname{det}\left|\frac{\partial \mathcal{T}_{\theta}(\mathbf{y})}{\partial \mathbf{y} }\right| = -\int_{t_{0} }^{t_{1} } \operatorname{Tr}\left(\frac{\partial \mathbf{f}(t, \mathbf{z}, \theta)}{\partial \mathbf{z}(t)}\right) d t \]
In this demo we will train a FFJORD bijector to warp a Gaussian distribution onto the distribution defined by the moons dataset. This will be done in 3 steps:
- Define base distribution
- Define FFJORD bijector
- Minimize exact log-likelihood of the dataset
First, we load the data.
Dataset
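The data cell is hidden; a sketch using scikit-learn's `make_moons` (the dataset size, batch size, and noise level here are assumptions):

```python
DATASET_SIZE = 1024 * 8
BATCH_SIZE = 256

# Two interleaving half-circles with a little Gaussian noise.
moons = skd.make_moons(n_samples=DATASET_SIZE, noise=0.06)[0].astype(np.float32)

moons_ds = tf.data.Dataset.from_tensor_slices(moons)
moons_ds = moons_ds.cache().shuffle(DATASET_SIZE).batch(BATCH_SIZE)
```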
Next, we instantiate a base distribution.
```python
base_loc = np.array([0.0, 0.0]).astype(np.float32)
base_sigma = np.array([0.8, 0.8]).astype(np.float32)
base_distribution = tfd.MultivariateNormalDiag(base_loc, base_sigma)
```
We use a multi-layer perceptron to model `state_derivative_fn`. While not necessary for this dataset, it is often beneficial to make `state_derivative_fn` dependent on time. Here we achieve this by concatenating `t` to the inputs of our network.
```python
class MLP_ODE(snt.Module):
  """Multi-layer NN ode_fn."""

  def __init__(self, num_hidden, num_layers, num_output, name='mlp_ode'):
    super(MLP_ODE, self).__init__(name=name)
    self._num_hidden = num_hidden
    self._num_output = num_output
    self._num_layers = num_layers
    self._modules = []
    for _ in range(self._num_layers - 1):
      self._modules.append(snt.Linear(self._num_hidden))
      self._modules.append(tf.math.tanh)
    self._modules.append(snt.Linear(self._num_output))
    self._model = snt.Sequential(self._modules)

  def __call__(self, t, inputs):
    # Make the dynamics time-dependent by appending t to every input point.
    inputs = tf.concat([tf.broadcast_to(t, inputs.shape), inputs], -1)
    return self._model(inputs)
```
Model and training parameters
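This cell is also collapsed. The values below are assumptions, chosen to be consistent with the training loop further down (which runs `NUM_EPOCHS // 2` iterations):

```python
LR = 1e-2             # initial learning rate
NUM_EPOCHS = 80

STACKED_FFJORDS = 4   # number of FFJORD bijectors in the chain
NUM_HIDDEN = 8        # width of each hidden layer in the MLP
NUM_LAYERS = 3        # total number of Linear layers per MLP
NUM_OUTPUT = 2        # dimensionality of the data

SAMPLE_SIZE = DATASET_SIZE
```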
Now we construct a stack of FFJORD bijectors. Each bijector is provided with `ode_solve_fn` and `trace_augmentation_fn` and its own `state_derivative_fn` model, so that they represent a sequence of different transformations.
Building bijector
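A sketch of that cell, using TFP's Dormand-Prince solver and exact trace computation; the solver tolerance is an assumption:

```python
solver = tfp.math.ode.DormandPrince(atol=1e-5)
ode_solve_fn = solver.solve
trace_augmentation_fn = tfb.ffjord.trace_jacobian_exact

bijectors = []
for _ in range(STACKED_FFJORDS):
  mlp_model = MLP_ODE(NUM_HIDDEN, NUM_LAYERS, NUM_OUTPUT)
  next_ffjord = tfb.FFJORD(
      state_time_derivative_fn=mlp_model,
      ode_solve_fn=ode_solve_fn,
      trace_augmentation_fn=trace_augmentation_fn)
  bijectors.append(next_ffjord)

# Chain applies bijectors right-to-left; reverse to keep the order above.
stacked_ffjord = tfb.Chain(bijectors[::-1])
```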
Now we can use `TransformedDistribution`, which is the result of warping `base_distribution` with the `stacked_ffjord` bijector.
```python
transformed_distribution = tfd.TransformedDistribution(
    distribution=base_distribution, bijector=stacked_ffjord)
```
Now we define our training procedure. We simply minimize the negative log-likelihood of the data.
Training
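A sketch of the hidden training step: a `tf.function`-compiled gradient step on the batch negative log-likelihood, applied with Sonnet's optimizer API:

```python
@tf.function
def train_step(optimizer, target_sample):
  """One gradient step on the negative log-likelihood of a batch."""
  with tf.GradientTape() as tape:
    loss = -tf.reduce_mean(transformed_distribution.log_prob(target_sample))
  variables = tape.watched_variables()
  gradients = tape.gradient(loss, variables)
  optimizer.apply(gradients, variables)
  return loss
```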
Samples
Plot samples from base and transformed distributions.
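The sampling helpers are hidden too; a sketch, assuming `SAMPLE_SIZE` from the parameters above:

```python
@tf.function
def get_samples():
  base_distribution_samples = base_distribution.sample(SAMPLE_SIZE)
  transformed_samples = transformed_distribution.sample(SAMPLE_SIZE)
  return base_distribution_samples, transformed_samples

@tf.function
def get_transformed_grid():
  # Push the regular grid through the chained FFJORD bijectors.
  return stacked_ffjord.forward(grid)
```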
```python
evaluation_samples = []
base_samples, transformed_samples = get_samples()
transformed_grid = get_transformed_grid()
evaluation_samples.append((base_samples, transformed_samples, transformed_grid))
```
```python
panel_id = 0
panel_data = evaluation_samples[panel_id]
fig, axarray = plt.subplots(1, 4, figsize=(16, 6))
plot_panel(
    grid, panel_data[0], panel_data[2], panel_data[1], moons, axarray, False)
plt.tight_layout()
```
```python
learning_rate = tf.Variable(LR, trainable=False)
optimizer = snt.optimizers.Adam(learning_rate)

for epoch in tqdm.trange(NUM_EPOCHS // 2):
  base_samples, transformed_samples = get_samples()
  transformed_grid = get_transformed_grid()
  evaluation_samples.append(
      (base_samples, transformed_samples, transformed_grid))
  for batch in moons_ds:
    _ = train_step(optimizer, batch)
```
100%|██████████| 40/40 [07:00<00:00, 10.52s/it]
```python
panel_id = -1
panel_data = evaluation_samples[panel_id]
fig, axarray = plt.subplots(1, 4, figsize=(16, 6))
plot_panel(grid, panel_data[0], panel_data[2], panel_data[1], moons, axarray)
plt.tight_layout()
```
Training it for longer with learning-rate decay results in further improvements.
Not covered in this example: the FFJORD bijector also supports Hutchinson's stochastic trace estimator. The particular estimator can be provided via `trace_augmentation_fn`. Similarly, alternative integrators can be used by defining a custom `ode_solve_fn`.
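For example, a single FFJORD bijector with the stochastic estimator could be built like this (a sketch reusing names defined above):

```python
ffjord_hutchinson = tfb.FFJORD(
    state_time_derivative_fn=MLP_ODE(NUM_HIDDEN, NUM_LAYERS, NUM_OUTPUT),
    ode_solve_fn=tfp.math.ode.DormandPrince(atol=1e-5).solve,
    trace_augmentation_fn=tfb.ffjord.trace_jacobian_hutchinson)
```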