Wav2Vec2 com ajuste fino com uma cabeça LM

Ver no TensorFlow.org Executar no Google Colab Ver no GitHub Baixar caderno Veja o modelo TF Hub

Neste caderno, vamos carregar o modelo wav2vec2 pré-treinados desde TFHub e vai afinar-lo na LibriSpeech conjunto de dados anexando cabeça Idioma Modeling (LM) sobre a parte superior do nosso modelo pré-treinados. A tarefa subjacente é a construção de um modelo de Reconhecimento Automático de Fala ou seja dado algum discurso, o modelo deve ser capaz de transcrevê-lo em texto.

Configurando

Antes de executar este notebook, certifique-se de que você está no GPU tempo de execução ( Runtime > Change runtime type > GPU ). A célula seguinte irá instalar gsoc-wav2vec2 pacote & suas dependências.

pip3 install -q git+https://github.com/vasudevgupta7/gsoc-wav2vec2@main
sudo apt-get install -y libsndfile1-dev
pip3 install -q SoundFile
The following packages were automatically installed and are no longer required:
  linux-gcp-5.4-headers-5.4.0-1040 linux-gcp-5.4-headers-5.4.0-1043
  linux-gcp-5.4-headers-5.4.0-1044 linux-gcp-5.4-headers-5.4.0-1049
  linux-headers-5.4.0-1049-gcp linux-image-5.4.0-1049-gcp
  linux-modules-5.4.0-1049-gcp linux-modules-extra-5.4.0-1049-gcp
Use 'sudo apt autoremove' to remove them.
The following additional packages will be installed:
  libflac-dev libogg-dev libvorbis-dev libvorbisfile3
The following NEW packages will be installed:
  libflac-dev libogg-dev libsndfile1-dev libvorbis-dev libvorbisfile3
0 upgraded, 5 newly installed, 0 to remove and 143 not upgraded.
Need to get 1040 kB of archives.
After this operation, 4481 kB of additional disk space will be used.
Get:1 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic/main amd64 libogg-dev amd64 1.3.2-1 [156 kB]
Get:2 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic/main amd64 libflac-dev amd64 1.3.2-1 [260 kB]
Get:3 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic/main amd64 libvorbisfile3 amd64 1.3.5-4.2 [16.0 kB]
Get:4 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic/main amd64 libvorbis-dev amd64 1.3.5-4.2 [321 kB]
Get:5 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic-updates/main amd64 libsndfile1-dev amd64 1.0.28-4ubuntu0.18.04.2 [287 kB]
Fetched 1040 kB in 1s (1041 kB/s)
Selecting previously unselected package libogg-dev:amd64.
(Reading database ... 282211 files and directories currently installed.)
Preparing to unpack .../libogg-dev_1.3.2-1_amd64.deb ...
Unpacking libogg-dev:amd64 (1.3.2-1) ...
Selecting previously unselected package libflac-dev:amd64.
Preparing to unpack .../libflac-dev_1.3.2-1_amd64.deb ...
Unpacking libflac-dev:amd64 (1.3.2-1) ...
Selecting previously unselected package libvorbisfile3:amd64.
Preparing to unpack .../libvorbisfile3_1.3.5-4.2_amd64.deb ...
Unpacking libvorbisfile3:amd64 (1.3.5-4.2) ...
Selecting previously unselected package libvorbis-dev:amd64.
Preparing to unpack .../libvorbis-dev_1.3.5-4.2_amd64.deb ...
Unpacking libvorbis-dev:amd64 (1.3.5-4.2) ...
Selecting previously unselected package libsndfile1-dev.
Preparing to unpack .../libsndfile1-dev_1.0.28-4ubuntu0.18.04.2_amd64.deb ...
Unpacking libsndfile1-dev (1.0.28-4ubuntu0.18.04.2) ...
Setting up libvorbisfile3:amd64 (1.3.5-4.2) ...
Setting up libogg-dev:amd64 (1.3.2-1) ...
Setting up libvorbis-dev:amd64 (1.3.5-4.2) ...
Setting up libflac-dev:amd64 (1.3.2-1) ...
Setting up libsndfile1-dev (1.0.28-4ubuntu0.18.04.2) ...
Processing triggers for libc-bin (2.27-3ubuntu1.2) ...

Configuração do modelo usando TFHub

Começaremos importando algumas bibliotecas / módulos.

import os

import tensorflow as tf
import tensorflow_hub as hub
from wav2vec2 import Wav2Vec2Config

config = Wav2Vec2Config()

print("TF version:", tf.__version__)
TF version: 2.7.0

Primeiro, vamos baixar o nosso modelo de TFHub e vai envolver a nossa assinatura modelo com hub.KerasLayer de ser capaz de usar este modelo como qualquer outra camada Keras. Felizmente, hub.KerasLayer pode fazer as duas coisas em apenas 1 linha.

pretrained_layer = hub.KerasLayer("https://tfhub.dev/vasudevgupta7/wav2vec2/1", trainable=True)

Você pode se referir a este roteiro no caso de você estiver interessado no script exportar modelo. Objeto pretrained_layer é a versão freezed de Wav2Vec2Model . Estes pesos pré-formados foram convertidos de HuggingFace PyTorch pesos pré-formados utilizando este certificado .

Originalmente, wav2vec2 foi pré-treinado com uma abordagem de modelagem de linguagem mascarada com o objetivo de identificar a verdadeira representação de fala latente quantizada para um intervalo de tempo mascarado. Você pode ler mais sobre o objetivo do treinamento no papel- wav2vec 2.0: Uma Estrutura para a aprendizagem auto-Supervisionado das Representações da fala .

Agora definiremos algumas constantes e hiperparâmetros que serão úteis nas próximas células. AUDIO_MAXLEN é intencionalmente definido para 246000 como a assinatura modelo só aceita comprimento seqüência estática de 246000 .

AUDIO_MAXLEN = 246000
LABEL_MAXLEN = 256
BATCH_SIZE = 2

No seguinte célula, vamos embrulhar pretrained_layer e uma densa camada (cabeça LM) com a API Funcional do Keras .

inputs = tf.keras.Input(shape=(AUDIO_MAXLEN,))
hidden_states = pretrained_layer(inputs)
outputs = tf.keras.layers.Dense(config.vocab_size)(hidden_states)

model = tf.keras.Model(inputs=inputs, outputs=outputs)

A camada densa (definido acima) está a ter uma dimensão de saída vocab_size como queremos prever probabilidades de cada símbolo no vocabulário em cada etapa de tempo.

Configurando o estado de treinamento

Em TensorFlow, pesos modelo só é construído quando model.call ou model.build é chamado pela primeira vez, de modo que a célula seguinte irá construir os pesos modelo para nós. Além disso, estaremos correndo model.summary() para verificar o número total de parâmetros treináveis.

model(tf.random.uniform(shape=(BATCH_SIZE, AUDIO_MAXLEN)))
model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 246000)]          0         
                                                                 
 keras_layer (KerasLayer)    (None, 768, 768)          94371712  
                                                                 
 dense (Dense)               (None, 768, 32)           24608     
                                                                 
=================================================================
Total params: 94,396,320
Trainable params: 94,396,320
Non-trainable params: 0
_________________________________________________________________

Agora, precisamos definir o loss_fn e otimizador para ser capaz de treinar o modelo. A célula a seguir fará isso por nós. Nós estaremos usando o Adam otimizador para a simplicidade. CTCLoss é um tipo comum que a perda é utilizada para tarefas (como ASR ) onde sub-partes de entrada não possa ser facilmente alinhados com saída sub-partes. Você pode ler mais sobre CTC-perda do incrível post .

CTCLoss (de gsoc-wav2vec2 pacote) aceita 3 argumentos: config , model_input_shape & division_factor . Se division_factor=1 , então a perda vai simplesmente se resumiu, então passar division_factor acordo para obter lote mais de média.

from wav2vec2 import CTCLoss

LEARNING_RATE = 5e-5

loss_fn = CTCLoss(config, (BATCH_SIZE, AUDIO_MAXLEN), division_factor=BATCH_SIZE)
optimizer = tf.keras.optimizers.Adam(LEARNING_RATE)

Carregando e pré-processando dados

Vamos agora fazer o download do conjunto de dados LibriSpeech do site oficial e configurá-lo.

wget https://www.openslr.org/resources/12/dev-clean.tar.gz -P ./data/train/
tar -xf ./data/train/dev-clean.tar.gz -C ./data/train/
--2021-11-05 11:43:09--  https://www.openslr.org/resources/12/dev-clean.tar.gz
Resolving www.openslr.org (www.openslr.org)... 46.101.158.64
Connecting to www.openslr.org (www.openslr.org)|46.101.158.64|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 337926286 (322M) [application/x-gzip]
Saving to: ‘./data/train/dev-clean.tar.gz’

dev-clean.tar.gz    100%[===================>] 322.27M  11.6MB/s    in 31s     

2021-11-05 11:43:42 (10.3 MB/s) - ‘./data/train/dev-clean.tar.gz’ saved [337926286/337926286]
ls ./data/train/
LibriSpeech/  dev-clean.tar.gz

Nosso conjunto de dados está no diretório LibriSpeech. Vamos explorar esses arquivos.

data_dir = "./data/train/LibriSpeech/dev-clean/2428/83705/"
all_files = os.listdir(data_dir)

flac_files = [f for f in all_files if f.endswith(".flac")]
txt_files = [f for f in all_files if f.endswith(".txt")]

print("Transcription files:", txt_files, "\nSound files:", flac_files)
Transcription files: ['2428-83705.trans.txt'] 
Sound files: ['2428-83705-0015.flac', '2428-83705-0004.flac', '2428-83705-0006.flac', '2428-83705-0026.flac', '2428-83705-0023.flac', '2428-83705-0001.flac', '2428-83705-0005.flac', '2428-83705-0040.flac', '2428-83705-0038.flac', '2428-83705-0042.flac', '2428-83705-0008.flac', '2428-83705-0019.flac', '2428-83705-0021.flac', '2428-83705-0002.flac', '2428-83705-0039.flac', '2428-83705-0034.flac', '2428-83705-0028.flac', '2428-83705-0000.flac', '2428-83705-0029.flac', '2428-83705-0041.flac', '2428-83705-0035.flac', '2428-83705-0032.flac', '2428-83705-0020.flac', '2428-83705-0025.flac', '2428-83705-0010.flac', '2428-83705-0014.flac', '2428-83705-0003.flac', '2428-83705-0031.flac', '2428-83705-0017.flac', '2428-83705-0027.flac', '2428-83705-0012.flac', '2428-83705-0043.flac', '2428-83705-0030.flac', '2428-83705-0022.flac', '2428-83705-0016.flac', '2428-83705-0037.flac', '2428-83705-0011.flac', '2428-83705-0036.flac', '2428-83705-0009.flac', '2428-83705-0013.flac', '2428-83705-0007.flac', '2428-83705-0018.flac', '2428-83705-0024.flac', '2428-83705-0033.flac']

Tudo bem, então cada sub-diretório tem muitos .flac arquivos e um .txt arquivo. O .txt arquivo contém transcrições de texto para todas as amostras de fala (ou seja .flac arquivos) presente nessa sub-diretório.

Podemos carregar esses dados de texto da seguinte maneira:

def read_txt_file(f):
  with open(f, "r") as f:
    samples = f.read().split("\n")
    samples = {s.split()[0]: " ".join(s.split()[1:]) for s in samples if len(s.split()) > 2}
  return samples

Da mesma forma, vamos definir uma função para carregar uma amostra de fala de um .flac arquivo.

REQUIRED_SAMPLE_RATE está definido para 16000 como wav2vec2 foi pré-treinado com 16K frequência e é recomendado para afinar-lo sem qualquer grande mudança na distribuição de dados devido à frequência.

import soundfile as sf

REQUIRED_SAMPLE_RATE = 16000

def read_flac_file(file_path):
  with open(file_path, "rb") as f:
      audio, sample_rate = sf.read(f)
  if sample_rate != REQUIRED_SAMPLE_RATE:
      raise ValueError(
          f"sample rate (={sample_rate}) of your files must be {REQUIRED_SAMPLE_RATE}"
      )
  file_id = os.path.split(file_path)[-1][:-len(".flac")]
  return {file_id: audio}

Agora, vamos escolher algumas amostras aleatórias e tentaremos visualizá-las.

from IPython.display import Audio
import random

file_id = random.choice([f[:-len(".flac")] for f in flac_files])
flac_file_path, txt_file_path = os.path.join(data_dir, f"{file_id}.flac"), os.path.join(data_dir, "2428-83705.trans.txt")

print("Text Transcription:", read_txt_file(txt_file_path)[file_id], "\nAudio:")
Audio(filename=flac_file_path)
Text Transcription: HE HAS GIVEN US FREE PASSES ALL THE WAY TO THE END OF OUR JOURNEY AND ALL THE WAY BACK AGAIN AND COUPONS FOR FREE BOARD AND LODGING AT THE HOTEL IT'S A WEDDING PRESENT 
Audio:

Agora, combinaremos todos os exemplos de fala e texto e definiremos a função (na próxima célula) para esse propósito.

def fetch_sound_text_mapping(data_dir):
  all_files = os.listdir(data_dir)

  flac_files = [os.path.join(data_dir, f) for f in all_files if f.endswith(".flac")]
  txt_files = [os.path.join(data_dir, f) for f in all_files if f.endswith(".txt")]

  txt_samples = {}
  for f in txt_files:
    txt_samples.update(read_txt_file(f))

  speech_samples = {}
  for f in flac_files:
    speech_samples.update(read_flac_file(f))

  assert len(txt_samples) == len(speech_samples)

  samples = [(speech_samples[file_id], txt_samples[file_id]) for file_id in speech_samples.keys() if len(speech_samples[file_id]) < AUDIO_MAXLEN]
  return samples

É hora de dar uma olhada em algumas amostras ...

samples = fetch_sound_text_mapping(data_dir)
samples[:5]
[(array([ 6.10351562e-05,  9.15527344e-05,  9.15527344e-05, ...,
         -3.05175781e-04, -5.79833984e-04, -8.23974609e-04]),
  'WHEN SHE HEARD OF MY ENGAGEMENT WITH MARY ANN SHE WROTE AND SUGGESTED THAT WE SHOULD SPEND OUR HONEYMOON IN HER COTTAGE OR PIGSTYE AND THAT I SHOULD PAY HER RENT FOR IT'),
 (array([-0.00112915, -0.00131226, -0.00158691, ...,  0.00067139,
          0.00091553,  0.00100708]),
  "IT MIGHT JUST AS WELL BE SOME ONE ELSE'S WEDDING SO UNIMPORTANT IS THE PART WHICH I AM SET TO PLAY IN IT"),
 (array([ 3.05175781e-05, -6.10351562e-05,  2.13623047e-04, ...,
         -5.18798828e-04, -2.13623047e-04, -2.74658203e-04]),
  'THE ACCIDENT IN QUESTION OCCURRED UPON THE SUNDAY EVENING'),
 (array([ 3.05175781e-04,  3.05175781e-05, -1.83105469e-04, ...,
          7.62939453e-04,  6.10351562e-04,  5.79833984e-04]),
  "OF COURSE THERE ARE SOME PEOPLE WITH WHOM YOU CAN'T BE PERFECTLY PLAIN BUT I SHALL BE AS PLAIN AS I CAN THERE'S A WAY AND A MANNER OF DOING THAT KIND OF THING"),
 (array([ 6.10351562e-05, -3.05175781e-05,  0.00000000e+00, ...,
         -3.66210938e-04, -7.93457031e-04, -1.19018555e-03]),
  'I KNOW WHAT MAMMA CAN AFFORD TO GIVE AND I WILL SEE SHE GIVES IT')]

Vamos pré-processar os dados agora !!!

Vamos primeiro definir o tokenizer & processador usando gsoc-wav2vec2 pacote. Então, faremos um pré-processamento muito simples. processor irá normalizar discurso cru wrto quadros eixo e tokenizer irá converter os nossos resultados do modelo para a string (usando o vocabulário definido) e vai cuidar da remoção de tokens especiais (dependendo da configuração tokenizer).

from wav2vec2 import Wav2Vec2Processor
tokenizer = Wav2Vec2Processor(is_tokenizer=True)
processor = Wav2Vec2Processor(is_tokenizer=False)

def preprocess_text(text):
  label = tokenizer(text)
  return tf.constant(label, dtype=tf.int32)

def preprocess_speech(audio):
  audio = tf.constant(audio, dtype=tf.float32)
  return processor(tf.transpose(audio))
Downloading `vocab.json` from https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/vocab.json ... DONE

Agora, vamos definir o gerador python para chamar as funções de pré-processamento que definimos nas células acima.

def inputs_generator():
  for speech, text in samples:
    yield preprocess_speech(speech), preprocess_text(text)

Configurando tf.data.Dataset

Seguinte configuração vontade célula tf.data.Dataset objeto usando sua .from_generator(...) método. Nós estaremos usando o generator objeto, definimos na célula acima.

Você pode se referir a esse script para obter mais detalhes sobre como converter dados LibriSpeech em tfrecords.

output_signature = (
    tf.TensorSpec(shape=(None),  dtype=tf.float32),
    tf.TensorSpec(shape=(None), dtype=tf.int32),
)

dataset = tf.data.Dataset.from_generator(inputs_generator, output_signature=output_signature)
BUFFER_SIZE = len(flac_files)
SEED = 42

dataset = dataset.shuffle(BUFFER_SIZE, seed=SEED)

Vamos passar o conjunto de dados em vários lotes, então vamos preparar os lotes na célula a seguir. Agora, todas as sequências em um lote devem ser preenchidas com um comprimento constante. Nós vamos usar o .padded_batch(...) método para o efeito.

dataset = dataset.padded_batch(BATCH_SIZE, padded_shapes=(AUDIO_MAXLEN, LABEL_MAXLEN), padding_values=(0.0, 0))

Aceleradores (como GPUs / TPUs) são muito rápidos e muitas vezes o carregamento de dados (e pré-processamento) torna-se o gargalo durante o treinamento, pois a parte de carregamento de dados acontece nas CPUs. Isso pode aumentar o tempo de treinamento significativamente, especialmente quando há muito pré-processamento online envolvido ou os dados são transmitidos online a partir de buckets do GCS. Para lidar com essas questões, tf.data.Dataset oferece a .prefetch(...) método. Este método ajuda a preparar os próximos lotes em paralelo (em CPUs) enquanto o modelo faz previsões (em GPUs / TPUs) no lote atual.

dataset = dataset.prefetch(tf.data.AUTOTUNE)

Uma vez que este notebook é feito para fins de demonstração, estaremos dando os primeiros num_train_batches e vai executar o treinamento mais só isso. Você é encorajado a treinar em todo o conjunto de dados. Da mesma forma, iremos avaliar apenas num_val_batches .

num_train_batches = 10
num_val_batches = 4

train_dataset = dataset.take(num_train_batches)
val_dataset = dataset.skip(num_train_batches).take(num_val_batches)

Treinamento de modelo

Para treinar o nosso modelo, que será diretamente chamando .fit(...) método depois de compilar o nosso modelo com .compile(...) .

model.compile(optimizer, loss=loss_fn)

A célula acima configurará nosso estado de treinamento. Agora podemos iniciar o treinamento com a .fit(...) método.

history = model.fit(train_dataset, validation_data=val_dataset, epochs=3)
history.history
Epoch 1/3
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/ctc_ops.py:1447: alias_inplace_add (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_add, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/ctc_ops.py:1447: alias_inplace_add (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_add, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/ctc_ops.py:1430: alias_inplace_update (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_update, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/ctc_ops.py:1430: alias_inplace_update (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_update, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss`argument?
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss`argument?
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss`argument?
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss`argument?
10/10 [==============================] - 32s 2s/step - loss: 649.3215 - val_loss: 315.0721
Epoch 2/3
10/10 [==============================] - 17s 2s/step - loss: 242.1202 - val_loss: 336.5721
Epoch 3/3
10/10 [==============================] - 17s 2s/step - loss: 222.1239 - val_loss: 253.0467
{'loss': [649.321533203125, 242.1201629638672, 222.1239013671875],
 'val_loss': [315.0721435546875, 336.5721130371094, 253.0466766357422]}

Vamos salvar o nosso modelo com .save(...) método para ser capaz de realizar inferência mais tarde. Você também pode exportar este SavedModel para TFHub seguindo documentação TFHub .

save_dir = "finetuned-wav2vec2"
model.save(save_dir, include_optimizer=False)
2021-11-05 11:44:54.280793: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
WARNING:absl:Found untraced functions such as restored_function_body, restored_function_body, restored_function_body, restored_function_body, restored_function_body while saving (showing 5 of 855). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: finetuned-wav2vec2/assets
INFO:tensorflow:Assets written to: finetuned-wav2vec2/assets

Avaliação

Agora vamos calcular a taxa de erro do Word sobre o conjunto de dados de validação

Taxa de erro Word (WER) é uma métrica comum para medir o desempenho de um sistema de reconhecimento automático de fala. O WER é derivado da distância de Levenshtein, trabalhando no nível da palavra. A taxa de erro de palavra pode então ser calculada como: WER = (S + D + I) / N = (S + D + I) / (S + D + C) onde S é o número de substituições, D é o número de deleções , I é o número de inserções, C é o número de palavras corretas, N é o número de palavras na referência (N = S + D + C). Este valor indica a porcentagem de palavras que foram previstas incorretamente.

Você pode se referir a este documento para saber mais sobre WER.

Usaremos load_metric(...) função de HuggingFace conjuntos de dados biblioteca. Vamos primeiro instalar o datasets da biblioteca usando pip e, em seguida, definir a metric objeto.

!pip3 install -q datasets

from datasets import load_metric
metric = load_metric("wer")
Downloading:   0%|          | 0.00/1.95k [00:00<?, ?B/s]
@tf.function(jit_compile=True)
def eval_fwd(batch):
  logits = model(batch, training=False)
  return tf.argmax(logits, axis=-1)

É hora de executar a avaliação dos dados de validação agora.

from tqdm.auto import tqdm

for speech, labels in tqdm(val_dataset, total=num_val_batches):
    predictions  = eval_fwd(speech)
    predictions = [tokenizer.decode(pred) for pred in predictions.numpy().tolist()]
    references = [tokenizer.decode(label, group_tokens=False) for label in labels.numpy().tolist()]
    metric.add_batch(references=references, predictions=predictions)
0%|          | 0/4 [00:00<?, ?it/s]
2021-11-05 11:45:11.575128: W tensorflow/compiler/tf2xla/kernels/random_ops.cc:57] Warning: Using tf.random.uniform with XLA compilation will ignore seeds; consider using tf.random.stateless_uniform instead if reproducible behavior is desired. model/keras_layer/StatefulPartitionedCall/StatefulPartitionedCall/wav2vec2/encoder/layers/0/stochastic_depth/random_uniform/RandomUniform

Estamos usando o tokenizer.decode(...) método para decodificar as nossas previsões e etiquetas de volta para o texto e adicioná-los à métrica para WER computação mais tarde.

Agora, vamos calcular o valor da métrica na seguinte célula:

metric.compute()
1.0

Inferência

Agora que estamos satisfeitos com o processo de formação e salvou o modelo em save_dir , vamos ver como este modelo pode ser usado para inferência.

Primeiro, vamos carregar o nosso modelo usando tf.keras.models.load_model(...) .

finetuned_model = tf.keras.models.load_model(save_dir)
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.

Vamos baixar alguns exemplos de fala para realizar inferências. Você também pode substituir o exemplo a seguir pelo seu exemplo de fala.

wget https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/SA2.wav
--2021-11-05 11:45:28--  https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/SA2.wav
Resolving github.com (github.com)... 13.114.40.48
Connecting to github.com (github.com)|13.114.40.48|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/vasudevgupta7/gsoc-wav2vec2/main/data/SA2.wav [following]
--2021-11-05 11:45:28--  https://raw.githubusercontent.com/vasudevgupta7/gsoc-wav2vec2/main/data/SA2.wav
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 94252 (92K) [audio/wav]
Saving to: ‘SA2.wav’

SA2.wav             100%[===================>]  92.04K  --.-KB/s    in 0.02s   

2021-11-05 11:45:29 (5.38 MB/s) - ‘SA2.wav’ saved [94252/94252]

Agora, vamos ler a amostra de fala usando soundfile.read(...) e uma almofada para AUDIO_MAXLEN para satisfazer a assinatura modelo. Então, vamos normalizar que amostras de fala usando o Wav2Vec2Processor instância e vai alimentá-lo para o modelo.

import numpy as np

speech, _ = sf.read("SA2.wav")
speech = np.pad(speech, (0, AUDIO_MAXLEN - len(speech)))
speech = tf.expand_dims(processor(tf.constant(speech)), 0)

outputs = finetuned_model(speech)
outputs
<tf.Tensor: shape=(1, 768, 32), dtype=float32, numpy=
array([[[ 5.5087714 , -1.0872856 , -1.0728477 , ..., -1.3125695 ,
         -0.7992846 , -0.94512135],
        [ 5.508977  , -1.0873723 , -1.0727195 , ..., -1.3125291 ,
         -0.79928476, -0.9449429 ],
        [ 5.5091047 , -1.0871643 , -1.0728203 , ..., -1.312533  ,
         -0.7992611 , -0.94483167],
        ...,
        [ 5.5094743 , -1.0874028 , -1.0729864 , ..., -1.3126655 ,
         -0.7994431 , -0.9449925 ],
        [ 5.509465  , -1.0873648 , -1.072943  , ..., -1.3126557 ,
         -0.79943836, -0.94500387],
        [ 5.509408  , -1.0872416 , -1.0728781 , ..., -1.3125473 ,
         -0.7993649 , -0.9449776 ]]], dtype=float32)>

Números de decodificação de Let cópias em seqüência de texto usando o Wav2Vec2tokenizer exemplo, nós definido acima.

predictions = tf.argmax(outputs, axis=-1)
predictions = [tokenizer.decode(pred) for pred in predictions.numpy().tolist()]
predictions
['']

Esta previsão é bastante aleatória, pois o modelo nunca foi treinado em grandes dados neste notebook (já que este notebook não foi feito para fazer um treinamento completo). Você obterá boas previsões se treinar este modelo no conjunto de dados LibriSpeech completo.

Finalmente, chegamos ao fim deste caderno. Mas não é um fim de aprender TensorFlow para tarefas relacionadas com a fala, este repositório contém alguns tutoriais mais surpreendentes. No caso de você encontrou qualquer bug nesta notebook, por favor, criar um problema aqui .