API giải mã

Xem trên TensorFlow.org Chạy trong Google Colab Xem trên GitHub Tải xuống sổ ghi chép

Tổng quat

Trong quá khứ gần đây, đã có rất nhiều nghiên cứu trong việc tạo ngôn ngữ với các mô hình tự động rút lui. Trong thế hệ ngôn ngữ tự động thoái lui, các phân bố xác suất của thẻ tại bước thời gian K phụ thuộc vào mô hình của token dự đoán cho đến khi bước K-1. Đối với các mô hình này, các chiến lược giải mã như tia tìm kiếm, tham lam, Top-p, và Top-k là những thành phần quan trọng của mô hình và phần lớn ảnh hưởng đến phong cách / tính chất của sản lượng tạo ra mã thông báo tại một thời gian nhất định bước K.

Ví dụ, tia tìm kiếm sẽ giảm nguy cơ bị mất tích ẩn thẻ xác suất cao bằng cách giữ num_beams khả dĩ nhất của giả thuyết tại mỗi bước thời gian và cuối cùng lựa chọn giả thuyết rằng có xác suất tổng thể cao nhất. Murray và cộng sự. (2018)Yang et al. (2018) cho thấy rằng tìm kiếm chùm hoạt động tốt trong nhiệm vụ Máy dịch. Cả hai tìm kiếm tia & chiến lược tham lam có khả năng tạo ra thẻ lặp đi lặp lại.

Fan et. al (2018) giới thiệu Top-K lấy mẫu, trong đó K thẻ rất có thể sẽ được lọc và khối lượng xác xuất được phân phối lại trong chỉ những thẻ K.

Ari Holtzman et. al (2019) giới thiệu Top-p lấy mẫu, mà chọn từ phần thiết lập nhỏ nhất có thể của thẻ với xác suất tích lũy mà thêm tối đa xác suất p. Khối lượng xác suất sau đó được phân phối lại giữa tập hợp này. Bằng cách này, kích thước của tập hợp mã thông báo có thể tự động tăng và giảm. Top-p, Top-k thường được sử dụng trong các nhiệm vụ như câu chuyện thế hệ.

API giải mã cung cấp một giao diện để thử nghiệm các chiến lược giải mã khác nhau trên các mô hình tự động hồi quy.

  1. Các chiến lược lấy mẫu sau được cung cấp trong sampling_module.py, chiến lược này kế thừa từ lớp Giải mã cơ sở:

  2. Tìm kiếm chùm được cung cấp trong beam_search.py. github

Thành lập

pip install -q -U tensorflow-text
pip install -q tf-models-nightly
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf

from official import nlp
from official.nlp.modeling.ops import sampling_module
from official.nlp.modeling.ops import beam_search
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/pkg_resources/__init__.py:119: PkgResourcesDeprecationWarning: 0.18ubuntu0.18.04.1 is an invalid version and will not be supported in a future release
  PkgResourcesDeprecationWarning,

Khởi tạo mô-đun lấy mẫu trong TF-NLP.

  • symbols_to_logits_fn: Sử dụng đóng cửa này để gọi các mô hình để dự đoán logits cho index+1 bước. Đầu vào và đầu ra cho quá trình đóng này như sau:
Args:
  1] ids : Current decoded sequences. int tensor with shape (batch_size, index + 1 or 1 if padded_decode is True)],
  2] index [scalar] : current decoded step,
  3] cache [nested dictionary of tensors] : Only used for faster decoding to store pre-computed attention hidden states for keys and values. More explanation in the cell below.
Returns:
  1] tensor for next-step logits [batch_size, vocab]
  2] the updated_cache [nested dictionary of tensors].

Bộ nhớ đệm được sử dụng để giải mã nhanh hơn. Đây là một tài liệu tham khảo thực hiện cho việc đóng cửa trên.

  • length_normalization_fn: Sử dụng đóng cửa này trả lại thông số chiều dài bình thường.
Args: 
  1] length : scalar for decoded step index.
  2] dtype : data-type of output tensor
Returns:
  1] value of length normalization factor.
  • vocab_size: Output kích thước từ vựng.

  • max_decode_length: Scalar cho tổng số bước giải mã.

  • eos_id: Giải mã sẽ dừng lại nếu tất cả các đầu ra được giải mã id trong đợt có eos_id này.

  • padded_decode: Thiết lập này là True nếu chạy trên TPU. Hàng chục được đệm thành max_decoding_length nếu điều này là Đúng.

  • top_k: top_k được kích hoạt nếu giá trị này là> 1.

  • top_p: top_p được kích hoạt nếu giá trị này là> 0 và <1.0

  • sampling_temperature: Đây được sử dụng để tái ước tính sản lượng softmax. Nhiệt độ làm lệch phân phối về phía các mã có xác suất cao và làm giảm khối lượng trong phân phối đuôi. Giá trị phải tích cực. Nhiệt độ thấp tương đương với tham lam và làm cho phân bố sắc nét hơn, trong khi nhiệt độ cao làm cho nó phẳng hơn.

  • enable_greedy: Theo mặc định, đây là True và giải mã tham lam được kích hoạt. Để thử nghiệm với các chiến lược khác, vui lòng đặt giá trị này thành Sai.

Khởi tạo siêu tham số Mô hình

params = {}
params['num_heads'] = 2
params['num_layers'] = 2
params['batch_size'] = 2
params['n_dims'] = 256
params['max_decode_length'] = 4

Trong kiến trúc tự động thụt lùi như Transformer dựa mã hóa-Decoder mô hình, bộ nhớ cache được sử dụng để giải mã tuần tự nhanh. Nó là một từ điển lồng nhau lưu trữ các trạng thái ẩn được tính toán trước (khóa và giá trị trong các khối tự chú ý và trong khối chú ý chéo) cho mọi lớp.

Khởi tạo bộ nhớ cache.

cache = {
    'layer_%d' % layer: {
        'k': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], int(params['n_dims']/params['num_heads'])], dtype=tf.float32),
        'v': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], int(params['n_dims']/params['num_heads'])], dtype=tf.float32)
        } for layer in range(params['num_layers'])
    }
print("cache key shape for layer 1 :", cache['layer_1']['k'].shape)
cache key shape for layer 1 : (2, 4, 2, 128)

Xác định sự đóng cửa để chuẩn hóa độ dài nếu cần.

Điều này được sử dụng để chuẩn hóa điểm số cuối cùng của các chuỗi được tạo và là tùy chọn

def length_norm(length, dtype):
  """Return length normalization factor."""
  return tf.pow(((5. + tf.cast(length, dtype)) / 6.), 0.0)

Tạo model_fn

Trong thực tế, điều này sẽ được thay thế bằng một thực hiện mô hình thực tế như ở đây

Args:
i : Step that is being decoded.
Returns:
  logit probabilities of size [batch_size, 1, vocab_size]
probabilities = tf.constant([[[0.3, 0.4, 0.3], [0.3, 0.3, 0.4],
                              [0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                            [[0.2, 0.5, 0.3], [0.2, 0.7, 0.1],
                              [0.1, 0.1, 0.8], [0.1, 0.1, 0.8]]])
def model_fn(i):
  return probabilities[:, i, :]

Khởi tạo Symbol_to_logits_fn

def _symbols_to_logits_fn():
  """Calculates logits of the next tokens."""
  def symbols_to_logits_fn(ids, i, temp_cache):
    del ids
    logits = tf.cast(tf.math.log(model_fn(i)), tf.float32)
    return logits, temp_cache
  return symbols_to_logits_fn

Tham

Giải mã tham lam chọn id mã thông báo với xác suất cao nhất như id tiếp theo của nó: \(id_t = argmax_{w}P(id | id_{1:t-1})\) tại mỗi bước thời gian \(t\). Bản phác thảo sau đây cho thấy sự giải mã tham lam.

greedy_obj = sampling_module.SamplingModule(
    length_normalization_fn=None,
    dtype=tf.float32,
    symbols_to_logits_fn=_symbols_to_logits_fn(),
    vocab_size=3,
    max_decode_length=params['max_decode_length'],
    eos_id=10,
    padded_decode=False)
ids, _ = greedy_obj.generate(
    initial_ids=tf.constant([9, 1]), initial_cache=cache)
print("Greedy Decoded Ids:", ids)
Greedy Decoded Ids: tf.Tensor(
[[9 1 2 2 2]
 [1 1 1 2 2]], shape=(2, 5), dtype=int32)

lấy mẫu top_k

Trong mẫu Top-K, K hầu như id tới thẻ được lọc và khối lượng xác xuất được phân phối lại trong chỉ những id K.

top_k_obj = sampling_module.SamplingModule(
    length_normalization_fn=length_norm,
    dtype=tf.float32,
    symbols_to_logits_fn=_symbols_to_logits_fn(),
    vocab_size=3,
    max_decode_length=params['max_decode_length'],
    eos_id=10,
    sample_temperature=tf.constant(1.0),
    top_k=tf.constant(3),
    padded_decode=False,
    enable_greedy=False)
ids, _ = top_k_obj.generate(
    initial_ids=tf.constant([9, 1]), initial_cache=cache)
print("top-k sampled Ids:", ids)
top-k sampled Ids: tf.Tensor(
[[9 1 0 2 2]
 [1 0 1 2 2]], shape=(2, 5), dtype=int32)

lấy mẫu top_p

Thay vì lấy mẫu chỉ từ khả năng nhất K thẻ id, trong Top-p lấy mẫu sẽ chọn từ tập thể nhỏ nhất của id mà xác suất tích lũy vượt quá p xác suất.

top_p_obj = sampling_module.SamplingModule(
    length_normalization_fn=length_norm,
    dtype=tf.float32,
    symbols_to_logits_fn=_symbols_to_logits_fn(),
    vocab_size=3,
    max_decode_length=params['max_decode_length'],
    eos_id=10,
    sample_temperature=tf.constant(1.0),
    top_p=tf.constant(0.9),
    padded_decode=False,
    enable_greedy=False)
ids, _ = top_p_obj.generate(
    initial_ids=tf.constant([9, 1]), initial_cache=cache)
print("top-p sampled Ids:", ids)
top-p sampled Ids: tf.Tensor(
[[9 1 1 2 2]
 [1 1 1 0 2]], shape=(2, 5), dtype=int32)

Giải mã tìm kiếm chùm

Tìm kiếm theo chùm làm giảm nguy cơ thiếu id mã thông báo có xác suất cao ẩn bằng cách giữ lại số lượng giả thuyết có khả năng xảy ra cao nhất ở mỗi bước thời gian và cuối cùng chọn giả thuyết có xác suất tổng thể cao nhất.

beam_size = 2
params['batch_size'] = 1
beam_cache = {
    'layer_%d' % layer: {
        'k': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims']], dtype=tf.float32),
        'v': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims']], dtype=tf.float32)
        } for layer in range(params['num_layers'])
    }
print("cache key shape for layer 1 :", beam_cache['layer_1']['k'].shape)
ids, _ = beam_search.sequence_beam_search(
    symbols_to_logits_fn=_symbols_to_logits_fn(),
    initial_ids=tf.constant([9], tf.int32),
    initial_cache=beam_cache,
    vocab_size=3,
    beam_size=beam_size,
    alpha=0.6,
    max_decode_length=params['max_decode_length'],
    eos_id=10,
    padded_decode=False,
    dtype=tf.float32)
print("Beam search ids:", ids)
cache key shape for layer 1 : (1, 4, 2, 256)
Beam search ids: tf.Tensor(
[[[9 0 1 2 2]
  [9 1 2 2 2]]], shape=(1, 2, 5), dtype=int32)