デコードAPI

TensorFlow.orgで表示GoogleColabで実行GitHubで表示 ノートブックをダウンロード

概要

最近では、自己回帰モデルを使用した言語生成に関する多くの研究が行われています。自己回帰言語生成では、時間ステップKでのトークンの確率分布は、ステップK-1までのモデルのトークン予測に依存しています。これらのモデルでは、そのようなビームサーチ、貪欲、トップ-P、およびトップ-Kなどのデコードの戦略は、モデルの重要なコンポーネントであり、主に与えられた時間ステップKで生成された出力トークンのスタイル/自然に影響を与えます。

例えば、ビームサーチは、各時間ステップでの仮説の最も可能性の高いnum_beamsを維持し、最終的に全体の確率が最も高い仮説を選択することにより、欠落している隠された高い確率トークンのリスクを低減します。マレーら。 (2018)およびYangら。 (2018)は、ビームサーチは、機械翻訳の作業にうまく機能することを示しています。ビームサーチ貪欲な戦略の両方が繰り返しのトークンを生成する可能性があります。

ファンらら(2018) K可能性が最も高いトークンを濾過し、確率質量をのみKトークン間で再分配されるトップ-Kサンプリングを導入しました。

アリホルツマン他ら(2019)確率P点で最大加算累積確率でトークンの最小の可能なセットから選択トップPサンプリングを導入しました。次に、確率質量がこのセット間で再配分されます。このようにして、トークンのセットのサイズを動的に増減できます。トップ-P、トップ-kは、一般的に、このような話世代などのタスクに使用されています。

Decoding APIは、自己回帰モデルでさまざまなデコード戦略を試すためのインターフェイスを提供します。

  1. 次のサンプリング戦略は、基本のDecodingクラスから継承するsampling_module.pyで提供されます。

  2. ビーム検索はbeam_search.py​​で提供されます。 github

設定

pip install -q -U tensorflow-text
pip install -q tf-models-nightly
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf

from official import nlp
from official.nlp.modeling.ops import sampling_module
from official.nlp.modeling.ops import beam_search
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/pkg_resources/__init__.py:119: PkgResourcesDeprecationWarning: 0.18ubuntu0.18.04.1 is an invalid version and will not be supported in a future release
  PkgResourcesDeprecationWarning,

TF-NLPでサンプリングモジュールを初期化します。

  • symbols_to_logits_fn:使用この閉鎖はのためlogitsを予測するモデルを呼び出すためのindex+1ステップ。このクロージャの入力と出力は次のとおりです。
Args:
  1] ids : Current decoded sequences. int tensor with shape (batch_size, index + 1 or 1 if padded_decode is True)],
  2] index [scalar] : current decoded step,
  3] cache [nested dictionary of tensors] : Only used for faster decoding to store pre-computed attention hidden states for keys and values. More explanation in the cell below.
Returns:
  1] tensor for next-step logits [batch_size, vocab]
  2] the updated_cache [nested dictionary of tensors].

キャッシュは、デコードを高速化するために使用されます。ここで、参照上の閉鎖のための実装。

  • length_normalization_fn:長正規化パラメータを返すために、このクロージャを使用してください。
Args: 
  1] length : scalar for decoded step index.
  2] dtype : data-type of output tensor
Returns:
  1] value of length normalization factor.
  • vocab_size:出力語彙サイズ。

  • max_decode_length:復号ステップの合計数のスカラー。

  • eos_id:バッチ内のすべての出力デコードされたIDがこのeos_idを持っている場合、復号が停止します。

  • padded_decode:TPU上で実行されている場合は、Trueに設定します。これがTrueの場合、テンソルはmax_decoding_lengthに埋め込まれます。

  • top_k:この値は> 1であればtop_kが有効になっています。

  • top_p:この値は> 0かつ<1.0であればtop_pが有効になっています

  • sampling_temperature:これは再見積もりソフトマックス出力するために使用されます。温度は、分布を高確率トークンに偏らせ、裾の分布の質量を下げます。値は正でなければなりません。低温は欲張りと同等であり、分布がより鮮明になり、高温はより平坦になります。

  • enable_greedy:デフォルトでは、これが真であると貪欲デコードが有効になっています。他の戦略を試すには、これをFalseに設定してください。

モデルのハイパーパラメータを初期化します

params = {}
params['num_heads'] = 2
params['num_layers'] = 2
params['batch_size'] = 2
params['n_dims'] = 256
params['max_decode_length'] = 4

トランスベースのような自己回帰アーキテクチャではエンコーダ・デコーダのモデルで、キャッシュは、高速逐次復号のために使用されています。これは、すべてのレイヤーの事前に計算された非表示状態(自己アテンションブロックとクロスアテンションブロックのキーと値)を格納するネストされた辞書です。

キャッシュを初期化します。

cache = {
    'layer_%d' % layer: {
        'k': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], int(params['n_dims']/params['num_heads'])], dtype=tf.float32),
        'v': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], int(params['n_dims']/params['num_heads'])], dtype=tf.float32)
        } for layer in range(params['num_layers'])
    }
print("cache key shape for layer 1 :", cache['layer_1']['k'].shape)
cache key shape for layer 1 : (2, 4, 2, 128)

必要に応じて、長さの正規化のクロージャを定義します。

これは、生成されたシーケンスの最終スコアを正規化するために使用され、オプションです

def length_norm(length, dtype):
  """Return length normalization factor."""
  return tf.pow(((5. + tf.cast(length, dtype)) / 6.), 0.0)

model_fnを作成します

実際には、これは以下のような実際のモデルの実装に置き換えられます、ここで

Args:
i : Step that is being decoded.
Returns:
  logit probabilities of size [batch_size, 1, vocab_size]
probabilities = tf.constant([[[0.3, 0.4, 0.3], [0.3, 0.3, 0.4],
                              [0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                            [[0.2, 0.5, 0.3], [0.2, 0.7, 0.1],
                              [0.1, 0.1, 0.8], [0.1, 0.1, 0.8]]])
def model_fn(i):
  return probabilities[:, i, :]

symbol_to_logits_fnを初期化します

def _symbols_to_logits_fn():
  """Calculates logits of the next tokens."""
  def symbols_to_logits_fn(ids, i, temp_cache):
    del ids
    logits = tf.cast(tf.math.log(model_fn(i)), tf.float32)
    return logits, temp_cache
  return symbols_to_logits_fn

貪欲

:貪欲復号は次のIDとして最も高い確率でトークンIDを選択 \(id_t = argmax_{w}P(id | id_{1:t-1})\) 各タイムステップにおける \(t\)。次のスケッチは、貪欲なデコードを示しています。

greedy_obj = sampling_module.SamplingModule(
    length_normalization_fn=None,
    dtype=tf.float32,
    symbols_to_logits_fn=_symbols_to_logits_fn(),
    vocab_size=3,
    max_decode_length=params['max_decode_length'],
    eos_id=10,
    padded_decode=False)
ids, _ = greedy_obj.generate(
    initial_ids=tf.constant([9, 1]), initial_cache=cache)
print("Greedy Decoded Ids:", ids)
Greedy Decoded Ids: tf.Tensor(
[[9 1 2 2 2]
 [1 1 1 2 2]], shape=(2, 5), dtype=int32)

top_kサンプリング

トップ-Kのサンプリングでは、Kは、最も可能性の高い次のトークンのIDを濾過され、確率質量はのみK ID中に再分配されます。

top_k_obj = sampling_module.SamplingModule(
    length_normalization_fn=length_norm,
    dtype=tf.float32,
    symbols_to_logits_fn=_symbols_to_logits_fn(),
    vocab_size=3,
    max_decode_length=params['max_decode_length'],
    eos_id=10,
    sample_temperature=tf.constant(1.0),
    top_k=tf.constant(3),
    padded_decode=False,
    enable_greedy=False)
ids, _ = top_k_obj.generate(
    initial_ids=tf.constant([9, 1]), initial_cache=cache)
print("top-k sampled Ids:", ids)
top-k sampled Ids: tf.Tensor(
[[9 1 0 2 2]
 [1 0 1 2 2]], shape=(2, 5), dtype=int32)

top_pサンプリング

代わりに、サンプリングのみから最も可能性の高いKトークンのID、累積確率確率pを超えるIDの最小の可能なセットから選択をサンプリングトップ-Pです。

top_p_obj = sampling_module.SamplingModule(
    length_normalization_fn=length_norm,
    dtype=tf.float32,
    symbols_to_logits_fn=_symbols_to_logits_fn(),
    vocab_size=3,
    max_decode_length=params['max_decode_length'],
    eos_id=10,
    sample_temperature=tf.constant(1.0),
    top_p=tf.constant(0.9),
    padded_decode=False,
    enable_greedy=False)
ids, _ = top_p_obj.generate(
    initial_ids=tf.constant([9, 1]), initial_cache=cache)
print("top-p sampled Ids:", ids)
top-p sampled Ids: tf.Tensor(
[[9 1 1 2 2]
 [1 1 1 0 2]], shape=(2, 5), dtype=int32)

ビームサーチデコーディング

ビーム検索は、各タイムステップで最も可能性の高い仮説のnum_beamsを保持し、最終的に全体的に最も高い確率を持つ仮説を選択することにより、隠れた高確率トークンIDを見逃すリスクを軽減します。

beam_size = 2
params['batch_size'] = 1
beam_cache = {
    'layer_%d' % layer: {
        'k': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims']], dtype=tf.float32),
        'v': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims']], dtype=tf.float32)
        } for layer in range(params['num_layers'])
    }
print("cache key shape for layer 1 :", beam_cache['layer_1']['k'].shape)
ids, _ = beam_search.sequence_beam_search(
    symbols_to_logits_fn=_symbols_to_logits_fn(),
    initial_ids=tf.constant([9], tf.int32),
    initial_cache=beam_cache,
    vocab_size=3,
    beam_size=beam_size,
    alpha=0.6,
    max_decode_length=params['max_decode_length'],
    eos_id=10,
    padded_decode=False,
    dtype=tf.float32)
print("Beam search ids:", ids)
cache key shape for layer 1 : (1, 4, 2, 256)
Beam search ids: tf.Tensor(
[[[9 0 1 2 2]
  [9 1 2 2 2]]], shape=(1, 2, 5), dtype=int32)