Regolazione fine di Wav2Vec2 con una testina LM

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza su GitHub Scarica taccuino Vedi il modello del mozzo TF

In questo notebook, caricheremo il modello wav2vec2 pre-addestrati da TFHub e perfezionerà su LibriSpeech dataset aggiungendo testa Lingua Modeling (LM) sopra la parte superiore del nostro modello di pre-addestrato. Il compito di fondo è quello di costruire un modello per il riconoscimento vocale automatico cioè dato un po 'il discorso, il modello dovrebbe essere in grado di trascrivere in testo.

Impostare

Prima di eseguire questo notebook, assicurarsi che siete sulla GPU runtime ( Runtime > Change runtime type > GPU ). Il seguente cellule installerà gsoc-wav2vec2 pacchetto e le sue dipendenze.

pip3 install -q git+https://github.com/vasudevgupta7/gsoc-wav2vec2@main
sudo apt-get install -y libsndfile1-dev
pip3 install -q SoundFile
The following packages were automatically installed and are no longer required:
  linux-gcp-5.4-headers-5.4.0-1040 linux-gcp-5.4-headers-5.4.0-1043
  linux-gcp-5.4-headers-5.4.0-1044 linux-gcp-5.4-headers-5.4.0-1049
  linux-headers-5.4.0-1049-gcp linux-image-5.4.0-1049-gcp
  linux-modules-5.4.0-1049-gcp linux-modules-extra-5.4.0-1049-gcp
Use 'sudo apt autoremove' to remove them.
The following additional packages will be installed:
  libflac-dev libogg-dev libvorbis-dev libvorbisfile3
The following NEW packages will be installed:
  libflac-dev libogg-dev libsndfile1-dev libvorbis-dev libvorbisfile3
0 upgraded, 5 newly installed, 0 to remove and 143 not upgraded.
Need to get 1040 kB of archives.
After this operation, 4481 kB of additional disk space will be used.
Get:1 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic/main amd64 libogg-dev amd64 1.3.2-1 [156 kB]
Get:2 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic/main amd64 libflac-dev amd64 1.3.2-1 [260 kB]
Get:3 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic/main amd64 libvorbisfile3 amd64 1.3.5-4.2 [16.0 kB]
Get:4 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic/main amd64 libvorbis-dev amd64 1.3.5-4.2 [321 kB]
Get:5 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic-updates/main amd64 libsndfile1-dev amd64 1.0.28-4ubuntu0.18.04.2 [287 kB]
Fetched 1040 kB in 1s (1041 kB/s)
Selecting previously unselected package libogg-dev:amd64.
(Reading database ... 282211 files and directories currently installed.)
Preparing to unpack .../libogg-dev_1.3.2-1_amd64.deb ...
Unpacking libogg-dev:amd64 (1.3.2-1) ...
Selecting previously unselected package libflac-dev:amd64.
Preparing to unpack .../libflac-dev_1.3.2-1_amd64.deb ...
Unpacking libflac-dev:amd64 (1.3.2-1) ...
Selecting previously unselected package libvorbisfile3:amd64.
Preparing to unpack .../libvorbisfile3_1.3.5-4.2_amd64.deb ...
Unpacking libvorbisfile3:amd64 (1.3.5-4.2) ...
Selecting previously unselected package libvorbis-dev:amd64.
Preparing to unpack .../libvorbis-dev_1.3.5-4.2_amd64.deb ...
Unpacking libvorbis-dev:amd64 (1.3.5-4.2) ...
Selecting previously unselected package libsndfile1-dev.
Preparing to unpack .../libsndfile1-dev_1.0.28-4ubuntu0.18.04.2_amd64.deb ...
Unpacking libsndfile1-dev (1.0.28-4ubuntu0.18.04.2) ...
Setting up libvorbisfile3:amd64 (1.3.5-4.2) ...
Setting up libogg-dev:amd64 (1.3.2-1) ...
Setting up libvorbis-dev:amd64 (1.3.5-4.2) ...
Setting up libflac-dev:amd64 (1.3.2-1) ...
Setting up libsndfile1-dev (1.0.28-4ubuntu0.18.04.2) ...
Processing triggers for libc-bin (2.27-3ubuntu1.2) ...

Messa a punto del modello utilizzando TFHub

Inizieremo importando alcune librerie/moduli.

import os

import tensorflow as tf
import tensorflow_hub as hub
from wav2vec2 import Wav2Vec2Config

config = Wav2Vec2Config()

print("TF version:", tf.__version__)
TF version: 2.7.0

In primo luogo, abbiamo scaricherà il nostro modello da TFHub e si concluderà la nostra firma modello con hub.KerasLayer per essere in grado di utilizzare questo modello come qualsiasi altro livello Keras. Fortunatamente, hub.KerasLayer può fare entrambe le cose in appena 1 riga.

pretrained_layer = hub.KerasLayer("https://tfhub.dev/vasudevgupta7/wav2vec2/1", trainable=True)

È possibile fare riferimento a questo scritto nel caso siate interessati nello script esportazione del modello. Oggetto pretrained_layer è la versione freezed di Wav2Vec2Model . Questi pesi pre-addestrati sono stati convertiti da HuggingFace PyTorch pesi pre-addestrato utilizzando questo script .

In origine, wav2vec2 è stato pre-addestrato con un approccio di modellazione del linguaggio mascherato con l'obiettivo di identificare la vera rappresentazione del parlato latente quantizzata per un passaggio temporale mascherato. Si può leggere di più l'obiettivo di formazione nel carta- wav2vec 2.0: un quadro per auto-apprendimento supervisionato di parola Rappresentazioni .

Ora definiremo alcune costanti e iperparametri che saranno utili nelle prossime celle. AUDIO_MAXLEN è intenzionalmente impostato 246000 come firma modello accetta solo lunghezza della sequenza statica di 246000 .

AUDIO_MAXLEN = 246000
LABEL_MAXLEN = 256
BATCH_SIZE = 2

Nella cella seguente, ci sarà avvolgere pretrained_layer e un fitto strato (testa LM) con l' API funzionale del Keras .

inputs = tf.keras.Input(shape=(AUDIO_MAXLEN,))
hidden_states = pretrained_layer(inputs)
outputs = tf.keras.layers.Dense(config.vocab_size)(hidden_states)

model = tf.keras.Model(inputs=inputs, outputs=outputs)

Lo strato denso (definito sopra) sta avendo una dimensione di uscita vocab_size come vogliamo predire probabilità di ogni token nel vocabolario ad ogni passo.

Impostazione dello stato di allenamento

In tensorflow, pesi modello sono costruite solo quando model.call o model.build è chiamato per la prima volta, in modo da cella seguente costruirà i pesi modello per noi. Ulteriormente, saremo esecuzione model.summary() per controllare il numero totale di parametri addestrabili.

model(tf.random.uniform(shape=(BATCH_SIZE, AUDIO_MAXLEN)))
model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 246000)]          0         
                                                                 
 keras_layer (KerasLayer)    (None, 768, 768)          94371712  
                                                                 
 dense (Dense)               (None, 768, 32)           24608     
                                                                 
=================================================================
Total params: 94,396,320
Trainable params: 94,396,320
Non-trainable params: 0
_________________________________________________________________

Ora, abbiamo bisogno di definire il loss_fn e ottimizzatore di essere in grado di formare il modello. La cella seguente lo farà per noi. Saremo utilizzando la Adam ottimizzatore per semplicità. CTCLoss è un tipo comune perdita che viene utilizzato per compiti (come ASR ) dove ingresso sotto-parti non possono essere facilmente allineati con uscita sotto-parti. Si può leggere di più su CTC-perdita da questo incredibile post sul blog .

CTCLoss (da gsoc-wav2vec2 pacchetto) accetta 3 argomenti config , model_input_shape & division_factor . Se division_factor=1 , allora la perdita sarà semplicemente ottenere sommati, quindi passare division_factor di conseguenza per ottenere lotti medio su.

from wav2vec2 import CTCLoss

LEARNING_RATE = 5e-5

loss_fn = CTCLoss(config, (BATCH_SIZE, AUDIO_MAXLEN), division_factor=BATCH_SIZE)
optimizer = tf.keras.optimizers.Adam(LEARNING_RATE)

Caricamento e pre-elaborazione dei dati

Diamo ora scaricare il dataset LibriSpeech dal sito ufficiale e configurarlo.

wget https://www.openslr.org/resources/12/dev-clean.tar.gz -P ./data/train/
tar -xf ./data/train/dev-clean.tar.gz -C ./data/train/
--2021-11-05 11:43:09--  https://www.openslr.org/resources/12/dev-clean.tar.gz
Resolving www.openslr.org (www.openslr.org)... 46.101.158.64
Connecting to www.openslr.org (www.openslr.org)|46.101.158.64|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 337926286 (322M) [application/x-gzip]
Saving to: ‘./data/train/dev-clean.tar.gz’

dev-clean.tar.gz    100%[===================>] 322.27M  11.6MB/s    in 31s     

2021-11-05 11:43:42 (10.3 MB/s) - ‘./data/train/dev-clean.tar.gz’ saved [337926286/337926286]
ls ./data/train/
LibriSpeech/  dev-clean.tar.gz

Il nostro set di dati si trova nella directory LibriSpeech. Esploriamo questi file.

data_dir = "./data/train/LibriSpeech/dev-clean/2428/83705/"
all_files = os.listdir(data_dir)

flac_files = [f for f in all_files if f.endswith(".flac")]
txt_files = [f for f in all_files if f.endswith(".txt")]

print("Transcription files:", txt_files, "\nSound files:", flac_files)
Transcription files: ['2428-83705.trans.txt'] 
Sound files: ['2428-83705-0015.flac', '2428-83705-0004.flac', '2428-83705-0006.flac', '2428-83705-0026.flac', '2428-83705-0023.flac', '2428-83705-0001.flac', '2428-83705-0005.flac', '2428-83705-0040.flac', '2428-83705-0038.flac', '2428-83705-0042.flac', '2428-83705-0008.flac', '2428-83705-0019.flac', '2428-83705-0021.flac', '2428-83705-0002.flac', '2428-83705-0039.flac', '2428-83705-0034.flac', '2428-83705-0028.flac', '2428-83705-0000.flac', '2428-83705-0029.flac', '2428-83705-0041.flac', '2428-83705-0035.flac', '2428-83705-0032.flac', '2428-83705-0020.flac', '2428-83705-0025.flac', '2428-83705-0010.flac', '2428-83705-0014.flac', '2428-83705-0003.flac', '2428-83705-0031.flac', '2428-83705-0017.flac', '2428-83705-0027.flac', '2428-83705-0012.flac', '2428-83705-0043.flac', '2428-83705-0030.flac', '2428-83705-0022.flac', '2428-83705-0016.flac', '2428-83705-0037.flac', '2428-83705-0011.flac', '2428-83705-0036.flac', '2428-83705-0009.flac', '2428-83705-0013.flac', '2428-83705-0007.flac', '2428-83705-0018.flac', '2428-83705-0024.flac', '2428-83705-0033.flac']

Va bene, in modo che ogni sotto-directory ha molti .flac file e un .txt file. Il .txt file contiene le trascrizioni di testo per tutti i campioni vocali (cioè .flac file) presente in quella sub-directory.

Possiamo caricare questi dati di testo come segue:

def read_txt_file(f):
  with open(f, "r") as f:
    samples = f.read().split("\n")
    samples = {s.split()[0]: " ".join(s.split()[1:]) for s in samples if len(s.split()) > 2}
  return samples

Allo stesso modo, ci sarà definire una funzione per il caricamento di un campione di discorso da un .flac file.

REQUIRED_SAMPLE_RATE è impostato su 16000 come wav2vec2 è stato pre-addestrato con 16K frequenza e si consiglia di mettere a punto senza alcun cambiamento importante nella distribuzione dei dati a causa di frequenza.

import soundfile as sf

REQUIRED_SAMPLE_RATE = 16000

def read_flac_file(file_path):
  with open(file_path, "rb") as f:
      audio, sample_rate = sf.read(f)
  if sample_rate != REQUIRED_SAMPLE_RATE:
      raise ValueError(
          f"sample rate (={sample_rate}) of your files must be {REQUIRED_SAMPLE_RATE}"
      )
  file_id = os.path.split(file_path)[-1][:-len(".flac")]
  return {file_id: audio}

Ora, selezioneremo alcuni campioni casuali e proveremo a visualizzarli.

from IPython.display import Audio
import random

file_id = random.choice([f[:-len(".flac")] for f in flac_files])
flac_file_path, txt_file_path = os.path.join(data_dir, f"{file_id}.flac"), os.path.join(data_dir, "2428-83705.trans.txt")

print("Text Transcription:", read_txt_file(txt_file_path)[file_id], "\nAudio:")
Audio(filename=flac_file_path)
Text Transcription: HE HAS GIVEN US FREE PASSES ALL THE WAY TO THE END OF OUR JOURNEY AND ALL THE WAY BACK AGAIN AND COUPONS FOR FREE BOARD AND LODGING AT THE HOTEL IT'S A WEDDING PRESENT 
Audio:

Ora combineremo tutti i campioni vocali e di testo e definiremo la funzione (nella cella successiva) a tale scopo.

def fetch_sound_text_mapping(data_dir):
  all_files = os.listdir(data_dir)

  flac_files = [os.path.join(data_dir, f) for f in all_files if f.endswith(".flac")]
  txt_files = [os.path.join(data_dir, f) for f in all_files if f.endswith(".txt")]

  txt_samples = {}
  for f in txt_files:
    txt_samples.update(read_txt_file(f))

  speech_samples = {}
  for f in flac_files:
    speech_samples.update(read_flac_file(f))

  assert len(txt_samples) == len(speech_samples)

  samples = [(speech_samples[file_id], txt_samples[file_id]) for file_id in speech_samples.keys() if len(speech_samples[file_id]) < AUDIO_MAXLEN]
  return samples

È il momento di dare un'occhiata ad alcuni campioni...

samples = fetch_sound_text_mapping(data_dir)
samples[:5]
[(array([ 6.10351562e-05,  9.15527344e-05,  9.15527344e-05, ...,
         -3.05175781e-04, -5.79833984e-04, -8.23974609e-04]),
  'WHEN SHE HEARD OF MY ENGAGEMENT WITH MARY ANN SHE WROTE AND SUGGESTED THAT WE SHOULD SPEND OUR HONEYMOON IN HER COTTAGE OR PIGSTYE AND THAT I SHOULD PAY HER RENT FOR IT'),
 (array([-0.00112915, -0.00131226, -0.00158691, ...,  0.00067139,
          0.00091553,  0.00100708]),
  "IT MIGHT JUST AS WELL BE SOME ONE ELSE'S WEDDING SO UNIMPORTANT IS THE PART WHICH I AM SET TO PLAY IN IT"),
 (array([ 3.05175781e-05, -6.10351562e-05,  2.13623047e-04, ...,
         -5.18798828e-04, -2.13623047e-04, -2.74658203e-04]),
  'THE ACCIDENT IN QUESTION OCCURRED UPON THE SUNDAY EVENING'),
 (array([ 3.05175781e-04,  3.05175781e-05, -1.83105469e-04, ...,
          7.62939453e-04,  6.10351562e-04,  5.79833984e-04]),
  "OF COURSE THERE ARE SOME PEOPLE WITH WHOM YOU CAN'T BE PERFECTLY PLAIN BUT I SHALL BE AS PLAIN AS I CAN THERE'S A WAY AND A MANNER OF DOING THAT KIND OF THING"),
 (array([ 6.10351562e-05, -3.05175781e-05,  0.00000000e+00, ...,
         -3.66210938e-04, -7.93457031e-04, -1.19018555e-03]),
  'I KNOW WHAT MAMMA CAN AFFORD TO GIVE AND I WILL SEE SHE GIVES IT')]

Pre-processiamo i dati ora!!!

Per prima definire il tokenizzatore & processore utilizzando gsoc-wav2vec2 pacchetto. Quindi, eseguiremo una pre-elaborazione molto semplice. processor normalizzerà discorso crudo wrto cornici assi e tokenizer convertiranno le nostre uscite modello nella stringa (usando il vocabolario definito) e si occuperanno della rimozione dei gettoni speciali (a seconda della configurazione tokenizer).

from wav2vec2 import Wav2Vec2Processor
tokenizer = Wav2Vec2Processor(is_tokenizer=True)
processor = Wav2Vec2Processor(is_tokenizer=False)

def preprocess_text(text):
  label = tokenizer(text)
  return tf.constant(label, dtype=tf.int32)

def preprocess_speech(audio):
  audio = tf.constant(audio, dtype=tf.float32)
  return processor(tf.transpose(audio))
Downloading `vocab.json` from https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/vocab.json ... DONE

Ora, definiremo il generatore Python per chiamare le funzioni di pre-elaborazione che abbiamo definito nelle celle sopra.

def inputs_generator():
  for speech, text in samples:
    yield preprocess_speech(speech), preprocess_text(text)

Impostazione tf.data.Dataset

Dopo l'installazione verrà cella tf.data.Dataset oggetto utilizzando la sua .from_generator(...) metodo. Useremo il generator oggetto, abbiamo definito nella cella sopra.

È possibile fare riferimento a questo script per maggiori dettagli su come convertire i dati in LibriSpeech tfrecords.

output_signature = (
    tf.TensorSpec(shape=(None),  dtype=tf.float32),
    tf.TensorSpec(shape=(None), dtype=tf.int32),
)

dataset = tf.data.Dataset.from_generator(inputs_generator, output_signature=output_signature)
BUFFER_SIZE = len(flac_files)
SEED = 42

dataset = dataset.shuffle(BUFFER_SIZE, seed=SEED)

Passeremo il set di dati in più batch, quindi prepariamo i batch nella cella seguente. Ora, tutte le sequenze in un batch dovrebbero essere riempite a una lunghezza costante. Useremo il .padded_batch(...) il metodo a tal fine.

dataset = dataset.padded_batch(BATCH_SIZE, padded_shapes=(AUDIO_MAXLEN, LABEL_MAXLEN), padding_values=(0.0, 0))

Gli acceleratori (come GPU/TPU) sono molto veloci e spesso il caricamento dei dati (e la pre-elaborazione) diventa il collo di bottiglia durante l'addestramento poiché la parte del caricamento dei dati avviene sulle CPU. Ciò può aumentare significativamente il tempo di addestramento, specialmente quando sono coinvolte molte pre-elaborazione online o i dati vengono trasmessi in streaming online dai bucket GCS. Per gestire questi problemi, tf.data.Dataset offre la .prefetch(...) metodo. Questo metodo aiuta a preparare i successivi batch in parallelo (su CPU) mentre il modello sta effettuando previsioni (su GPU/TPU) sul batch corrente.

dataset = dataset.prefetch(tf.data.AUTOTUNE)

Dal momento che questo notebook è fatto per scopi dimostrativi, ci porteremo primi num_train_batches e si esibiranno nel corso di formazione solo. Tuttavia, sei incoraggiato ad allenarti sull'intero set di dati. Allo stesso modo, valuteremo solo num_val_batches .

num_train_batches = 10
num_val_batches = 4

train_dataset = dataset.take(num_train_batches)
val_dataset = dataset.skip(num_train_batches).take(num_val_batches)

Formazione modello

Per la formazione del nostro modello, saremo direttamente chiamando .fit(...) metodo dopo aver compilato il nostro modello con .compile(...) .

model.compile(optimizer, loss=loss_fn)

La cella sopra imposterà il nostro stato di addestramento. Ora siamo in grado di avviare la formazione con il .fit(...) metodo.

history = model.fit(train_dataset, validation_data=val_dataset, epochs=3)
history.history
Epoch 1/3
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/ctc_ops.py:1447: alias_inplace_add (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_add, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/ctc_ops.py:1447: alias_inplace_add (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_add, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/ctc_ops.py:1430: alias_inplace_update (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_update, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/ctc_ops.py:1430: alias_inplace_update (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_update, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss`argument?
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss`argument?
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss`argument?
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss`argument?
10/10 [==============================] - 32s 2s/step - loss: 649.3215 - val_loss: 315.0721
Epoch 2/3
10/10 [==============================] - 17s 2s/step - loss: 242.1202 - val_loss: 336.5721
Epoch 3/3
10/10 [==============================] - 17s 2s/step - loss: 222.1239 - val_loss: 253.0467
{'loss': [649.321533203125, 242.1201629638672, 222.1239013671875],
 'val_loss': [315.0721435546875, 336.5721130371094, 253.0466766357422]}

Salviamo il nostro modello con .save(...) il metodo per essere in grado di eseguire l'inferenza tardi. È inoltre possibile esportare questo SavedModel a TFHub seguendo la documentazione TFHub .

save_dir = "finetuned-wav2vec2"
model.save(save_dir, include_optimizer=False)
2021-11-05 11:44:54.280793: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
WARNING:absl:Found untraced functions such as restored_function_body, restored_function_body, restored_function_body, restored_function_body, restored_function_body while saving (showing 5 of 855). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: finetuned-wav2vec2/assets
INFO:tensorflow:Assets written to: finetuned-wav2vec2/assets

Valutazione

Ora calcoleremo Word Error Rate sul set di dati di convalida

Parola tasso di errore (WER) è una metrica comune per misurare le prestazioni di un sistema automatico di riconoscimento vocale. Il WER deriva dalla distanza di Levenshtein, lavorando a livello di parola. Il tasso di errore di parola può quindi essere calcolato come: WER = (S + D + I) / N = (S + D + I) / (S + D + C) dove S è il numero di sostituzioni, D è il numero di eliminazioni , I è il numero di inserimenti, C è il numero di parole corrette, N è il numero di parole nel riferimento (N=S+D+C). Questo valore indica la percentuale di parole che sono state previste in modo errato.

È possibile fare riferimento a questo documento per ulteriori informazioni su WER.

Useremo load_metric(...) la funzione da HuggingFace dataset biblioteca. Diamo prima installare il datasets libreria utilizzando pip e quindi definiscono l' metric oggetto.

!pip3 install -q datasets

from datasets import load_metric
metric = load_metric("wer")
Downloading:   0%|          | 0.00/1.95k [00:00<?, ?B/s]
@tf.function(jit_compile=True)
def eval_fwd(batch):
  logits = model(batch, training=False)
  return tf.argmax(logits, axis=-1)

È ora di eseguire la valutazione sui dati di convalida ora.

from tqdm.auto import tqdm

for speech, labels in tqdm(val_dataset, total=num_val_batches):
    predictions  = eval_fwd(speech)
    predictions = [tokenizer.decode(pred) for pred in predictions.numpy().tolist()]
    references = [tokenizer.decode(label, group_tokens=False) for label in labels.numpy().tolist()]
    metric.add_batch(references=references, predictions=predictions)
0%|          | 0/4 [00:00<?, ?it/s]
2021-11-05 11:45:11.575128: W tensorflow/compiler/tf2xla/kernels/random_ops.cc:57] Warning: Using tf.random.uniform with XLA compilation will ignore seeds; consider using tf.random.stateless_uniform instead if reproducible behavior is desired. model/keras_layer/StatefulPartitionedCall/StatefulPartitionedCall/wav2vec2/encoder/layers/0/stochastic_depth/random_uniform/RandomUniform

Stiamo usando il tokenizer.decode(...) il metodo per decodificare le nostre previsioni e le etichette di nuovo nel testo e li aggiungerà alla metrica per WER calcolo più tardi.

Ora, calcoliamo il valore della metrica nella seguente cella:

metric.compute()
1.0

Inferenza

Ora che siamo soddisfatti con il processo di formazione e abbiamo salvato il modello in save_dir , vedremo come questo modello può essere utilizzato per l'inferenza.

In primo luogo, verrà caricato il nostro modello utilizzando tf.keras.models.load_model(...) .

finetuned_model = tf.keras.models.load_model(save_dir)
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.

Scarichiamo alcuni esempi di parlato per eseguire l'inferenza. Puoi anche sostituire il seguente esempio con il tuo esempio vocale.

wget https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/SA2.wav
--2021-11-05 11:45:28--  https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/SA2.wav
Resolving github.com (github.com)... 13.114.40.48
Connecting to github.com (github.com)|13.114.40.48|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/vasudevgupta7/gsoc-wav2vec2/main/data/SA2.wav [following]
--2021-11-05 11:45:28--  https://raw.githubusercontent.com/vasudevgupta7/gsoc-wav2vec2/main/data/SA2.wav
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 94252 (92K) [audio/wav]
Saving to: ‘SA2.wav’

SA2.wav             100%[===================>]  92.04K  --.-KB/s    in 0.02s   

2021-11-05 11:45:29 (5.38 MB/s) - ‘SA2.wav’ saved [94252/94252]

Ora, leggeremo il campione vocale tramite soundfile.read(...) e pad a AUDIO_MAXLEN per soddisfare la firma del modello. Poi ci sarà normalizzare il campione vocale tramite il Wav2Vec2Processor istanza & Manderemo nel modello.

import numpy as np

speech, _ = sf.read("SA2.wav")
speech = np.pad(speech, (0, AUDIO_MAXLEN - len(speech)))
speech = tf.expand_dims(processor(tf.constant(speech)), 0)

outputs = finetuned_model(speech)
outputs
<tf.Tensor: shape=(1, 768, 32), dtype=float32, numpy=
array([[[ 5.5087714 , -1.0872856 , -1.0728477 , ..., -1.3125695 ,
         -0.7992846 , -0.94512135],
        [ 5.508977  , -1.0873723 , -1.0727195 , ..., -1.3125291 ,
         -0.79928476, -0.9449429 ],
        [ 5.5091047 , -1.0871643 , -1.0728203 , ..., -1.312533  ,
         -0.7992611 , -0.94483167],
        ...,
        [ 5.5094743 , -1.0874028 , -1.0729864 , ..., -1.3126655 ,
         -0.7994431 , -0.9449925 ],
        [ 5.509465  , -1.0873648 , -1.072943  , ..., -1.3126557 ,
         -0.79943836, -0.94500387],
        [ 5.509408  , -1.0872416 , -1.0728781 , ..., -1.3125473 ,
         -0.7993649 , -0.9449776 ]]], dtype=float32)>

Numeri decodificare andiamo indietro nella sequenza di testo utilizzando la Wav2Vec2tokenizer esempio, abbiamo definito in precedenza.

predictions = tf.argmax(outputs, axis=-1)
predictions = [tokenizer.decode(pred) for pred in predictions.numpy().tolist()]
predictions
['']

Questa previsione è abbastanza casuale in quanto il modello non è mai stato addestrato su dati di grandi dimensioni in questo notebook (poiché questo notebook non è pensato per eseguire un training completo). Otterrai buone previsioni se esegui il training di questo modello su un set di dati completo di LibriSpeech.

Finalmente siamo giunti alla fine di questo notebook. Ma non è la fine di imparare tensorflow per compiti vocali legati, questo repository contiene alcuni tutorial più sorprendente. Nel caso in cui hai incontrato un errore in questa notebook, crei un problema qui .