TFP Probabilistic Layers: Variational Auto Encoder


In this example we show how to fit a Variational Autoencoder using TFP's "probabilistic layers".

Dependencies and prerequisites

Imports
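
The import cell is collapsed in the rendered notebook. Below is a minimal set of imports that makes the snippets in this tutorial runnable, using the usual TFP aliases (assumed here, since the original cell is hidden):

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_probability as tfp

tfk = tf.keras
tfkl = tf.keras.layers
tfpl = tfp.layers
tfd = tfp.distributions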

Make things Fast!

Before we dive in, let's make sure we're using a GPU for this demo.

To do this, select "Runtime" -> "Change runtime type" -> "Hardware accelerator" -> "GPU".

The following snippet will verify that we have access to a GPU.

if tf.test.gpu_device_name() != '/device:GPU:0':
  print('WARNING: GPU device not found.')
else:
  print('SUCCESS: Found GPU: {}'.format(tf.test.gpu_device_name()))
SUCCESS: Found GPU: /device:GPU:0

Load Dataset

datasets, datasets_info = tfds.load(name='mnist',
                                    with_info=True,
                                    as_supervised=False)

def _preprocess(sample):
  image = tf.cast(sample['image'], tf.float32) / 255.  # Scale to unit interval.
  image = image < tf.random.uniform(tf.shape(image))   # Randomly binarize.
  return image, image

train_dataset = (datasets['train']
                 .map(_preprocess)
                 .batch(256)
                 .prefetch(tf.data.AUTOTUNE)
                 .shuffle(int(10e3)))
eval_dataset = (datasets['test']
                .map(_preprocess)
                .batch(256)
                .prefetch(tf.data.AUTOTUNE))
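
As a quick sanity check (this cell is not in the original notebook), each dataset element should be an (image, image) pair of boolean image batches:

print(train_dataset.element_spec)
# Expect a pair of TensorSpec(shape=(None, 28, 28, 1), dtype=tf.bool) entries:
# the binarized image serves as both input and target.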

Note that _preprocess() above returns image, image rather than just image because Keras is set up for discriminative models with an (example, label) input format, i.e., \(p_\theta(y|x)\). Since the goal of the VAE is to recover the input \(x\) from \(x\) itself (i.e., \(p_\theta(x|x)\)), the data pair is (example, example).

VAE Code Golf

Specify the model.

input_shape = datasets_info.features['image'].shape
encoded_size = 16
base_depth = 32
prior = tfd.Independent(tfd.Normal(loc=tf.zeros(encoded_size), scale=1),
                        reinterpreted_batch_ndims=1)
encoder = tfk.Sequential([
    tfkl.InputLayer(input_shape=input_shape),
    tfkl.Lambda(lambda x: tf.cast(x, tf.float32) - 0.5),
    tfkl.Conv2D(base_depth, 5, strides=1,
                padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2D(base_depth, 5, strides=2,
                padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2D(2 * base_depth, 5, strides=1,
                padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2D(2 * base_depth, 5, strides=2,
                padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2D(4 * encoded_size, 7, strides=1,
                padding='valid', activation=tf.nn.leaky_relu),
    tfkl.Flatten(),
    tfkl.Dense(tfpl.MultivariateNormalTriL.params_size(encoded_size),
               activation=None),
    tfpl.MultivariateNormalTriL(
        encoded_size,
        activity_regularizer=tfpl.KLDivergenceRegularizer(prior)),
])
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py:158: calling LinearOperator.__init__ (from tensorflow.python.ops.linalg.linear_operator) with graph_parents is deprecated and will be removed in a future version.
Instructions for updating:
Do not pass `graph_parents`.  They will  no longer be used.
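
Since its final layer is a tfpl.MultivariateNormalTriL, calling the encoder on a batch of images returns a distribution object rather than a plain tensor, and the KLDivergenceRegularizer silently adds the KL-to-prior penalty to the model's losses. A minimal sanity check (the variable names here are illustrative, not from the original):

x_batch = next(iter(train_dataset))[0]  # one batch of binarized images
q_z = encoder(x_batch)                  # a tfd.MultivariateNormalTriL
assert isinstance(q_z, tfd.Distribution)
print(q_z.sample().shape)               # (256, 16): one latent code per image
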
decoder = tfk.Sequential([
    tfkl.InputLayer(input_shape=[encoded_size]),
    tfkl.Reshape([1, 1, encoded_size]),
    tfkl.Conv2DTranspose(2 * base_depth, 7, strides=1,
                         padding='valid', activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(2 * base_depth, 5, strides=1,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(2 * base_depth, 5, strides=2,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(base_depth, 5, strides=1,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(base_depth, 5, strides=2,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(base_depth, 5, strides=1,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2D(filters=1, kernel_size=5, strides=1,
                padding='same', activation=None),
    tfkl.Flatten(),
    tfpl.IndependentBernoulli(input_shape, tfd.Bernoulli.logits),
])
vae = tfk.Model(inputs=encoder.inputs,
                outputs=decoder(encoder.outputs[0]))

Do inference.
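
Training minimizes negloglik, the negative log-likelihood of each image under the decoder's output distribution. Since the KLDivergenceRegularizer attached to the encoder's final layer already adds the KL term to the model's losses, the overall objective is a single-sample Monte Carlo estimate of the negative ELBO:

\( -\text{ELBO}(x) = -\mathbb{E}_{z \sim q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] + \text{KL}\big(q_\phi(z \mid x) \,\|\, p(z)\big) \)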

negloglik = lambda x, rv_x: -rv_x.log_prob(x)

vae.compile(optimizer=tf.optimizers.Adam(learning_rate=1e-3),
            loss=negloglik)

_ = vae.fit(train_dataset,
            epochs=15,
            validation_data=eval_dataset)
Epoch 1/15
235/235 [==============================] - 14s 61ms/step - loss: 206.5541 - val_loss: 163.1924
Epoch 2/15
235/235 [==============================] - 14s 59ms/step - loss: 151.1891 - val_loss: 143.6748
Epoch 3/15
235/235 [==============================] - 14s 58ms/step - loss: 141.3275 - val_loss: 137.9188
Epoch 4/15
235/235 [==============================] - 14s 58ms/step - loss: 136.7453 - val_loss: 133.2726
Epoch 5/15
235/235 [==============================] - 14s 58ms/step - loss: 132.3803 - val_loss: 131.8343
Epoch 6/15
235/235 [==============================] - 14s 58ms/step - loss: 129.2451 - val_loss: 127.1935
Epoch 7/15
235/235 [==============================] - 14s 59ms/step - loss: 126.0975 - val_loss: 123.6789
Epoch 8/15
235/235 [==============================] - 14s 58ms/step - loss: 124.0565 - val_loss: 122.5058
Epoch 9/15
235/235 [==============================] - 14s 58ms/step - loss: 122.9974 - val_loss: 121.9544
Epoch 10/15
235/235 [==============================] - 14s 58ms/step - loss: 121.7349 - val_loss: 120.8735
Epoch 11/15
235/235 [==============================] - 14s 58ms/step - loss: 121.0856 - val_loss: 120.1340
Epoch 12/15
235/235 [==============================] - 14s 58ms/step - loss: 120.2232 - val_loss: 121.3554
Epoch 13/15
235/235 [==============================] - 14s 58ms/step - loss: 119.8123 - val_loss: 119.2351
Epoch 14/15
235/235 [==============================] - 14s 58ms/step - loss: 119.2685 - val_loss: 118.2133
Epoch 15/15
235/235 [==============================] - 14s 59ms/step - loss: 118.8895 - val_loss: 119.4771

Look Ma, No Tensors!

# We'll just examine ten random digits.
x = next(iter(eval_dataset))[0][:10]
xhat = vae(x)
assert isinstance(xhat, tfd.Distribution)

Image plot util
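
The display_imgs helper lives in a collapsed cell of the original notebook. A minimal stand-in, assuming a simple row-of-images layout with matplotlib (the exact plotting details are an assumption):

import matplotlib.pyplot as plt
import numpy as np

def display_imgs(x):
  # Draw a batch of images shaped (n, 28, 28, 1) in a single row.
  x = np.asarray(x, dtype=np.float32)
  axs = np.atleast_1d(plt.subplots(1, x.shape[0], figsize=(x.shape[0], 1))[1])
  for ax, img in zip(axs, x):
    ax.imshow(img.squeeze(), interpolation='none', cmap='gray')
    ax.axis('off')
  plt.show()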

print('Originals:')
display_imgs(x)

print('Decoded Random Samples:')
display_imgs(xhat.sample())

print('Decoded Modes:')
display_imgs(xhat.mode())

print('Decoded Means:')
display_imgs(xhat.mean())
Originals:
(image)

Decoded Random Samples:
(image)

Decoded Modes:
(image)

Decoded Means:
(image)

# Now, let's generate ten never-before-seen digits.
z = prior.sample(10)
xtilde = decoder(z)
assert isinstance(xtilde, tfd.Distribution)
print('Randomly Generated Samples:')
display_imgs(xtilde.sample())

print('Randomly Generated Modes:')
display_imgs(xtilde.mode())

print('Randomly Generated Means:')
display_imgs(xtilde.mean())
Randomly Generated Samples:
(image)

Randomly Generated Modes:
(image)

Randomly Generated Means:
(image)