API penguraian kode

Lihat di TensorFlow.org Jalankan di Google Colab Lihat di GitHub Unduh buku catatan

Ringkasan

Di masa lalu, telah ada banyak penelitian dalam generasi bahasa dengan model auto-regresif. Dalam auto-regresif generasi bahasa, distribusi probabilitas token pada saat langkah K tergantung pada model token-prediksi sampai langkah K-1. Untuk model ini, decoding strategi seperti Beam pencarian, Greedy, Top-p, dan Top-k adalah komponen penting dari model dan sebagian besar mempengaruhi gaya / sifat output yang dihasilkan tanda pada langkah waktu K diberikan.

Misalnya, Beam pencarian mengurangi risiko hilang disembunyikan token probabilitas tinggi dengan menjaga num_beams paling mungkin dari hipotesis pada setiap langkah waktu dan akhirnya memilih hipotesis yang memiliki probabilitas keseluruhan tertinggi. Murray dkk. (2018) dan Yang et al. (2018) menunjukkan bahwa pencarian balok bekerja dengan baik dalam tugas-tugas Machine Translation. Kedua Beam pencarian & strategi Greedy memiliki kemungkinan menghasilkan mengulangi token.

kipas angin et. al (2018) memperkenalkan Top-K sampling, di mana K token kemungkinan besar akan disaring dan massa probabilitas didistribusikan di antara hanya mereka token K.

Ari Holtzman et. al (2019) memperkenalkan Top-p sampling, yang memilih dari set terkecil yang mungkin dari token dengan probabilitas kumulatif yang menambahkan upto probabilitas p. Massa probabilitas kemudian didistribusikan kembali di antara himpunan ini. Dengan cara ini, ukuran set token dapat meningkat dan menurun secara dinamis. Top-p, top-k umumnya digunakan dalam tugas-tugas seperti cerita-generasi.

Decoding API menyediakan antarmuka untuk bereksperimen dengan berbagai strategi decoding pada model auto-regressive.

  1. Strategi pengambilan sampel berikut disediakan di sampling_module.py, yang diturunkan dari kelas Decoding dasar:

  2. Pencarian balok disediakan di beam_search.py. github

Mempersiapkan

pip install -q -U tensorflow-text
pip install -q tf-models-nightly
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf

from official import nlp
from official.nlp.modeling.ops import sampling_module
from official.nlp.modeling.ops import beam_search
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/pkg_resources/__init__.py:119: PkgResourcesDeprecationWarning: 0.18ubuntu0.18.04.1 is an invalid version and will not be supported in a future release
  PkgResourcesDeprecationWarning,

Inisialisasi Modul Pengambilan Sampel di TF-NLP.

  • symbols_to_logits_fn: Gunakan penutupan ini untuk memanggil model untuk memprediksi logits untuk index+1 langkah. Masukan dan keluaran untuk penutupan ini adalah sebagai berikut:
Args:
  1] ids : Current decoded sequences. int tensor with shape (batch_size, index + 1 or 1 if padded_decode is True)],
  2] index [scalar] : current decoded step,
  3] cache [nested dictionary of tensors] : Only used for faster decoding to store pre-computed attention hidden states for keys and values. More explanation in the cell below.
Returns:
  1] tensor for next-step logits [batch_size, vocab]
  2] the updated_cache [nested dictionary of tensors].

Cache digunakan untuk decoding lebih cepat. Berikut ini adalah referensi implementasi untuk penutupan di atas.

  • length_normalization_fn: Gunakan penutupan ini untuk kembali parameter panjang normalisasi.
Args: 
  1] length : scalar for decoded step index.
  2] dtype : data-type of output tensor
Returns:
  1] value of length normalization factor.
  • vocab_size: ukuran kosakata Output.

  • max_decode_length: skalar untuk jumlah langkah decoding.

  • eos_id: Decoding akan berhenti jika semua keluaran diterjemahkan ids dalam batch memiliki eos_id ini.

  • padded_decode: Set ini ke True jika berjalan di TPU. Tensor diisi ke max_decoding_length jika ini Benar.

  • top_k: top_k diaktifkan jika nilai ini> 1.

  • top_p: top_p diaktifkan jika nilai ini> 0 dan <1.0

  • sampling_temperature: ini digunakan untuk kembali memperkirakan-output Softmax. Suhu mencondongkan distribusi ke token probabilitas tinggi dan menurunkan massa dalam distribusi ekor. Nilai harus positif. Suhu rendah sama dengan serakah dan membuat distribusi lebih tajam, sedangkan suhu tinggi membuatnya lebih rata.

  • enable_greedy: Secara default, ini adalah Benar dan decoding serakah diaktifkan. Untuk bereksperimen dengan strategi lain, harap setel ini ke False.

Inisialisasi model hyperparameters

params = {}
params['num_heads'] = 2
params['num_layers'] = 2
params['batch_size'] = 2
params['n_dims'] = 256
params['max_decode_length'] = 4

Dalam arsitektur auto-regresif seperti berbasis Transformer Encoder-Decoder model, Cache digunakan untuk decoding sekuensial cepat. Ini adalah kamus bersarang yang menyimpan keadaan tersembunyi yang telah dihitung sebelumnya (kunci dan nilai dalam blok perhatian-diri dan dalam blok perhatian-silang) untuk setiap lapisan.

Inisialisasi cache.

cache = {
    'layer_%d' % layer: {
        'k': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], int(params['n_dims']/params['num_heads'])], dtype=tf.float32),
        'v': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], int(params['n_dims']/params['num_heads'])], dtype=tf.float32)
        } for layer in range(params['num_layers'])
    }
print("cache key shape for layer 1 :", cache['layer_1']['k'].shape)
cache key shape for layer 1 : (2, 4, 2, 128)

Tentukan penutupan untuk normalisasi panjang jika diperlukan.

Ini digunakan untuk menormalkan skor akhir dari urutan yang dihasilkan dan bersifat opsional

def length_norm(length, dtype):
  """Return length normalization factor."""
  return tf.pow(((5. + tf.cast(length, dtype)) / 6.), 0.0)

Buat model_fn

Dalam prakteknya, ini akan digantikan oleh implementasi model yang sebenarnya seperti disini

Args:
i : Step that is being decoded.
Returns:
  logit probabilities of size [batch_size, 1, vocab_size]
probabilities = tf.constant([[[0.3, 0.4, 0.3], [0.3, 0.3, 0.4],
                              [0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                            [[0.2, 0.5, 0.3], [0.2, 0.7, 0.1],
                              [0.1, 0.1, 0.8], [0.1, 0.1, 0.8]]])
def model_fn(i):
  return probabilities[:, i, :]

Inisialisasi simbol_to_logits_fn

def _symbols_to_logits_fn():
  """Calculates logits of the next tokens."""
  def symbols_to_logits_fn(ids, i, temp_cache):
    del ids
    logits = tf.cast(tf.math.log(model_fn(i)), tf.float32)
    return logits, temp_cache
  return symbols_to_logits_fn

Tamak

Decoding serakah memilih id tanda dengan probabilitas tertinggi sebagai id berikutnya: \(id_t = argmax_{w}P(id | id_{1:t-1})\) di setiap timestep \(t\). Sketsa berikut menunjukkan decoding serakah.

greedy_obj = sampling_module.SamplingModule(
    length_normalization_fn=None,
    dtype=tf.float32,
    symbols_to_logits_fn=_symbols_to_logits_fn(),
    vocab_size=3,
    max_decode_length=params['max_decode_length'],
    eos_id=10,
    padded_decode=False)
ids, _ = greedy_obj.generate(
    initial_ids=tf.constant([9, 1]), initial_cache=cache)
print("Greedy Decoded Ids:", ids)
Greedy Decoded Ids: tf.Tensor(
[[9 1 2 2 2]
 [1 1 1 2 2]], shape=(2, 5), dtype=int32)

pengambilan sampel top_k

Dalam Top-K sampling, K yang paling mungkin id Token berikutnya disaring dan massa probabilitas didistribusikan di antara hanya mereka ids K.

top_k_obj = sampling_module.SamplingModule(
    length_normalization_fn=length_norm,
    dtype=tf.float32,
    symbols_to_logits_fn=_symbols_to_logits_fn(),
    vocab_size=3,
    max_decode_length=params['max_decode_length'],
    eos_id=10,
    sample_temperature=tf.constant(1.0),
    top_k=tf.constant(3),
    padded_decode=False,
    enable_greedy=False)
ids, _ = top_k_obj.generate(
    initial_ids=tf.constant([9, 1]), initial_cache=cache)
print("top-k sampled Ids:", ids)
top-k sampled Ids: tf.Tensor(
[[9 1 0 2 2]
 [1 0 1 2 2]], shape=(2, 5), dtype=int32)

pengambilan sampel top_p

Alih-alih sampel hanya dari yang paling mungkin K tanda ids, di Top-p sampel lagi memilih dari set kemungkinan terkecil id yang probabilitas kumulatif melebihi probabilitas p.

top_p_obj = sampling_module.SamplingModule(
    length_normalization_fn=length_norm,
    dtype=tf.float32,
    symbols_to_logits_fn=_symbols_to_logits_fn(),
    vocab_size=3,
    max_decode_length=params['max_decode_length'],
    eos_id=10,
    sample_temperature=tf.constant(1.0),
    top_p=tf.constant(0.9),
    padded_decode=False,
    enable_greedy=False)
ids, _ = top_p_obj.generate(
    initial_ids=tf.constant([9, 1]), initial_cache=cache)
print("top-p sampled Ids:", ids)
top-p sampled Ids: tf.Tensor(
[[9 1 1 2 2]
 [1 1 1 0 2]], shape=(2, 5), dtype=int32)

Dekode pencarian balok

Pencarian balok mengurangi risiko kehilangan id token probabilitas tinggi yang tersembunyi dengan menjaga jumlah hipotesis yang paling mungkin pada setiap langkah waktu dan akhirnya memilih hipotesis yang memiliki probabilitas tertinggi secara keseluruhan.

beam_size = 2
params['batch_size'] = 1
beam_cache = {
    'layer_%d' % layer: {
        'k': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims']], dtype=tf.float32),
        'v': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims']], dtype=tf.float32)
        } for layer in range(params['num_layers'])
    }
print("cache key shape for layer 1 :", beam_cache['layer_1']['k'].shape)
ids, _ = beam_search.sequence_beam_search(
    symbols_to_logits_fn=_symbols_to_logits_fn(),
    initial_ids=tf.constant([9], tf.int32),
    initial_cache=beam_cache,
    vocab_size=3,
    beam_size=beam_size,
    alpha=0.6,
    max_decode_length=params['max_decode_length'],
    eos_id=10,
    padded_decode=False,
    dtype=tf.float32)
print("Beam search ids:", ids)
cache key shape for layer 1 : (1, 4, 2, 256)
Beam search ids: tf.Tensor(
[[[9 0 1 2 2]
  [9 1 2 2 2]]], shape=(1, 2, 5), dtype=int32)