Checkpointer et PolicySaver

Voir sur TensorFlow.org Exécuter dans Google Colab Voir la source sur GitHub Télécharger le cahier

introduction

tf_agents.utils.common.Checkpointer est un utilitaire pour sauvegarder / charger l'état de la formation, l' état de la politique et de l' état de replay_buffer vers / depuis un stockage local.

tf_agents.policies.policy_saver.PolicySaver est un outil pour enregistrer / charger uniquement la politique, et est plus léger que Checkpointer . Vous pouvez utiliser PolicySaver pour déployer le modèle et sans aucune connaissance du code qui a créé la politique.

Dans ce tutoriel, nous utiliserons DQN pour former un modèle, puis utilisez Checkpointer et PolicySaver pour montrer comment nous pouvons stocker et charger de manière interactive les états et le modèle. Notez que nous allons utiliser de nouveaux outils de saved_model de TF2.0 et le format pour PolicySaver .

Installer

Si vous n'avez pas installé les dépendances suivantes, exécutez :

sudo apt-get update
sudo apt-get install -y xvfb ffmpeg python-opengl
pip install pyglet
pip install 'imageio==2.4.0'
pip install 'xvfbwrapper==0.2.9'
pip install tf-agents[reverb]
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import base64
import imageio
import io
import matplotlib
import matplotlib.pyplot as plt
import os
import shutil
import tempfile
import tensorflow as tf
import zipfile
import IPython

try:
  from google.colab import files
except ImportError:
  files = None
from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import q_network
from tf_agents.policies import policy_saver
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common

tempdir = os.getenv("TEST_TMPDIR", tempfile.gettempdir())
# Set up a virtual display for rendering OpenAI gym environments.
import xvfbwrapper
xvfbwrapper.Xvfb(1400, 900, 24).start()

Agent DQN

Nous allons mettre en place l'agent DQN, comme dans le précédent colab. Les détails sont masqués par défaut car ils ne font pas partie intégrante de cette collaboration, mais vous pouvez cliquer sur « AFFICHER LE CODE » pour voir les détails.

Hyperparamètres

env_name = "CartPole-v1"

collect_steps_per_iteration = 100
replay_buffer_capacity = 100000

fc_layer_params = (100,)

batch_size = 64
learning_rate = 1e-3
log_interval = 5

num_eval_episodes = 10
eval_interval = 1000

Environnement

train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

Agent

Collecte de données

WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py:383: ReplayBuffer.get_next (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version.
Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.

Former l'agent

Génération vidéo

Générer une vidéo

Vérifiez les performances de la politique en générant une vidéo.

print ('global_step:')
print (global_step)
run_episodes_and_create_video(agent.policy, eval_env, eval_py_env)
global_step:
<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>

gif

Configurer Checkpointer et PolicySaver

Nous sommes maintenant prêts à utiliser Checkpointer et PolicySaver.

Point de contrôle

checkpoint_dir = os.path.join(tempdir, 'checkpoint')
train_checkpointer = common.Checkpointer(
    ckpt_dir=checkpoint_dir,
    max_to_keep=1,
    agent=agent,
    policy=agent.policy,
    replay_buffer=replay_buffer,
    global_step=global_step
)

Économiseur de politique

policy_dir = os.path.join(tempdir, 'policy')
tf_policy_saver = policy_saver.PolicySaver(agent.policy)
2022-01-20 12:15:14.054931: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.

Former une itération

print('Training one iteration....')
train_one_iteration()
Training one iteration....
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:1096: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:1096: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
iteration: 1 loss: 1.0214563608169556

Enregistrer au point de contrôle

train_checkpointer.save(global_step)

Restaurer le point de contrôle

Pour que cela fonctionne, l'ensemble des objets doit être recréé de la même manière que lors de la création du point de contrôle.

train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()

Enregistrez également la politique et exportez vers un emplacement

tf_policy_saver.save(policy_dir)
WARNING:absl:Function `function_with_signature` contains input name(s) 0/step_type, 0/reward, 0/discount, 0/observation with unsupported characters which will be renamed to step_type, reward, discount, observation in the SavedModel.
WARNING:absl:Found untraced functions such as QNetwork_layer_call_fn, QNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, dense_1_layer_call_fn while saving (showing 5 of 25). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/policy/assets
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:561: UserWarning: Encoding a StructuredValue with type tf_agents.policies.greedy_policy.DeterministicWithLogProb_ACTTypeSpec; loading this StructuredValue will require that this type be imported and registered.
  "imported and registered." % type_spec_class_name)
INFO:tensorflow:Assets written to: /tmp/policy/assets

La politique peut être chargée sans avoir aucune connaissance de l'agent ou du réseau utilisé pour la créer. Cela rend le déploiement de la politique beaucoup plus facile.

Chargez la politique enregistrée et vérifiez ses performances

saved_policy = tf.saved_model.load(policy_dir)
run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)

gif

Exporter et importer

Le reste de la collaboration vous aidera à exporter/importer des répertoires de point de contrôle et de stratégie de sorte que vous puissiez continuer la formation à un stade ultérieur et déployer le modèle sans avoir à vous entraîner à nouveau.

Vous pouvez maintenant revenir à « Entraîner une itération » et vous entraîner plusieurs fois de manière à comprendre la différence plus tard. Une fois que vous commencez à voir des résultats légèrement meilleurs, continuez ci-dessous.

Créez un fichier zip et téléchargez le fichier zip (double-cliquez pour voir le code)

Créez un fichier compressé à partir du répertoire de point de contrôle.

train_checkpointer.save(global_step)
checkpoint_zip_filename = create_zip_file(checkpoint_dir, os.path.join(tempdir, 'exported_cp'))

Téléchargez le fichier zip.

if files is not None:
  files.download(checkpoint_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469

Après une formation pendant un certain temps (10 à 15 fois), téléchargez le fichier zip du point de contrôle et accédez à « Exécution > Redémarrer et exécuter tout » pour réinitialiser la formation et revenir à cette cellule. Vous pouvez maintenant télécharger le fichier zip téléchargé et continuer la formation.

upload_and_unzip_file_to(checkpoint_dir)
train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()

Une fois que vous avez téléchargé le répertoire de point de contrôle, revenez à « Train one iteration » pour continuer la formation ou revenez à « Générer une vidéo » pour vérifier les performances de la politique chargée.

Vous pouvez également enregistrer la stratégie (modèle) et la restaurer. Contrairement à checkpointer, vous ne pouvez pas continuer la formation, mais vous pouvez toujours déployer le modèle. Notez que le fichier téléchargé est beaucoup plus petit que celui du point de contrôle.

tf_policy_saver.save(policy_dir)
policy_zip_filename = create_zip_file(policy_dir, os.path.join(tempdir, 'exported_policy'))
WARNING:absl:Function `function_with_signature` contains input name(s) 0/step_type, 0/reward, 0/discount, 0/observation with unsupported characters which will be renamed to step_type, reward, discount, observation in the SavedModel.
WARNING:absl:Found untraced functions such as QNetwork_layer_call_fn, QNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, dense_1_layer_call_fn while saving (showing 5 of 25). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/policy/assets
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:561: UserWarning: Encoding a StructuredValue with type tf_agents.policies.greedy_policy.DeterministicWithLogProb_ACTTypeSpec; loading this StructuredValue will require that this type be imported and registered.
  "imported and registered." % type_spec_class_name)
INFO:tensorflow:Assets written to: /tmp/policy/assets
if files is not None:
  files.download(policy_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469

Téléchargez le répertoire de stratégie téléchargé (exported_policy.zip) et vérifiez les performances de la stratégie enregistrée.

upload_and_unzip_file_to(policy_dir)
saved_policy = tf.saved_model.load(policy_dir)
run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)

gif

EnregistréModèlePyTFEagerPolitique

Si vous ne souhaitez pas utiliser la politique de TF, vous pouvez également utiliser le saved_model directement avec env Python par l'utilisation de py_tf_eager_policy.SavedModelPyTFEagerPolicy .

Notez que cela ne fonctionne que lorsque le mode impatient est activé.

eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
    policy_dir, eval_py_env.time_step_spec(), eval_py_env.action_spec())

# Note that we're passing eval_py_env not eval_env.
run_episodes_and_create_video(eager_py_policy, eval_py_env, eval_py_env)

gif

Convertir la politique en TFLite

Voir convertisseur tensorflow Lite pour plus de détails.

converter = tf.lite.TFLiteConverter.from_saved_model(policy_dir, signature_keys=["action"])
tflite_policy = converter.convert()
with open(os.path.join(tempdir, 'policy.tflite'), 'wb') as f:
  f.write(tflite_policy)
2022-01-20 12:15:59.646042: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format.
2022-01-20 12:15:59.646082: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency.
2022-01-20 12:15:59.646088: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:372] Ignored change_concat_input_ranges.
WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded

Exécuter l'inférence sur le modèle TFLite

Voir tensorflow Lite Inference pour plus de détails.

import numpy as np
interpreter = tf.lite.Interpreter(os.path.join(tempdir, 'policy.tflite'))

policy_runner = interpreter.get_signature_runner()
print(policy_runner._inputs)
{'0/discount': 1, '0/observation': 2, '0/reward': 3, '0/step_type': 0}
policy_runner(**{
    '0/discount':tf.constant(0.0),
    '0/observation':tf.zeros([1,4]),
    '0/reward':tf.constant(0.0),
    '0/step_type':tf.constant(0)})
{'action': array([0])}