Ottimizza XLSR-Wav2Vec2 per l’ASR a bassa risorsa con i Transformers di 🤗

'Ottimizza XLSR-Wav2Vec2 per l'ASR con risorse limitate utilizzando i Transformers di 🤗.'

Nuovo (11/2021): Questo post del blog è stato aggiornato per includere il successore di XLSR, chiamato XLS-R.

Wav2Vec2 è un modello preaddestrato per il riconoscimento automatico del parlato (ASR) ed è stato rilasciato nel settembre 2020 da Alexei Baevski, Michael Auli e Alex Conneau. Poco dopo è stata dimostrata la performance superiore di Wav2Vec2 su uno dei dataset più popolari in inglese per l’ASR, chiamato LibriSpeech, Facebook AI ha presentato una versione multilingue di Wav2Vec2, chiamata XLSR. XLSR sta per rappresentazioni del parlato cross-linguistiche e si riferisce alla capacità del modello di apprendere rappresentazioni del parlato utili in più lingue.

Il successore di XLSR, semplicemente chiamato XLS-R (che si riferisce a “XLM-R for Speech”), è stato rilasciato nel novembre 2021 da Arun Babu, Changhan Wang, Andros Tjandra e altri. XLS-R ha utilizzato quasi mezzo milione di ore di dati audio in 128 lingue per l’auto-pre-addestramento e ha dimensioni che vanno da 300 milioni fino a due miliardi di parametri. Puoi trovare i checkpoint preaddestrati su 🤗 Hub:

  • Wav2Vec2-XLS-R-300M
  • Wav2Vec2-XLS-R-1B
  • Wav2Vec2-XLS-R-2B

Similmente all’obiettivo di modellazione del linguaggio mascherato di BERT, XLS-R apprende rappresentazioni del parlato contestualizzate mascherando casualmente i vettori di caratteristiche prima di passarli a una rete transformer durante l’auto-pre-addestramento (ovvero il diagramma a sinistra di seguito).

Per il fine-tuning, viene aggiunto uno strato lineare singolo sopra la rete preaddestrata per addestrare il modello su dati etichettati di attività audio come il riconoscimento del parlato, la traduzione del parlato e la classificazione audio (ovvero il diagramma a destra di seguito).

XLS-R mostra miglioramenti impressionanti rispetto ai risultati precedenti di state-of-the-art sia nel riconoscimento del parlato, nella traduzione del parlato che nell’identificazione del parlante/lingua, confronta con la Tabella 3-6, Tabella 7-10 e Tabella 11-12 rispettivamente del paper ufficiale.

Configurazione

In questo blog, forniremo una spiegazione dettagliata su come XLS-R – più specificamente il checkpoint preaddestrato Wav2Vec2-XLS-R-300M – può essere sottoposto a fine-tuning per l’ASR.

A scopo dimostrativo, sottoponiamo il modello a fine-tuning sul dataset ASR a bassa risorsa di Common Voice, che contiene solo circa 4 ore di dati di addestramento convalidati.

XLS-R viene sottoposto a fine-tuning utilizzando la Connectionist Temporal Classification (CTC), che è un algoritmo utilizzato per addestrare reti neurali per problemi di sequenza-su-sequenza, come l’ASR e il riconoscimento della scrittura a mano.

Raccomando vivamente la lettura del ben scritto post del blog Sequence Modeling with CTC (2017) di Awni Hannun.

Prima di iniziare, installiamo datasets e transformers. Inoltre, abbiamo bisogno di torchaudio per caricare file audio e jiwer per valutare il nostro modello sottoposto a fine-tuning utilizzando la metrica del tasso di errore delle parole (WER) 1 {}^1 1.

!pip install datasets==1.18.3
!pip install transformers==4.11.3
!pip install huggingface_hub==0.1
!pip install torchaudio
!pip install librosa
!pip install jiwer

Suggeriamo vivamente di caricare direttamente i tuoi checkpoint di addestramento su Hugging Face Hub durante l’addestramento. Hugging Face Hub ha integrato il controllo delle versioni in modo da poter essere sicuro che nessun checkpoint del modello venga perso durante l’addestramento.

Per farlo, devi memorizzare il tuo token di autenticazione dal sito web di Hugging Face (registrati qui se non l’hai ancora fatto!)

from huggingface_hub import notebook_login

notebook_login()

Output di stampa:

    Accesso riuscito
    Il tuo token è stato salvato in /root/.huggingface/token

Successivamente è necessario installare Git-LFS per caricare i checkpoint del modello:

apt install git-lfs

1 {}^1 1 Nel paper, il modello è stato valutato utilizzando il tasso di errore dei fonemi (PER), ma il metrico più comune nell’ASR è il tasso di errore delle parole (WER). Per mantenere questo notebook il più generale possibile, abbiamo deciso di valutare il modello utilizzando il WER.

Prepara i Dati, il Tokenizer, l’estrattore di Caratteristiche

I modelli ASR trascrivono il parlato in testo, il che significa che abbiamo bisogno sia di un estrattore di caratteristiche che elabori il segnale del parlato nel formato di input del modello, ad esempio un vettore di caratteristiche, sia di un tokenizer che elabori il formato di output del modello in testo.

In 🤗 Transformers, il modello XLS-R è accompagnato sia da un tokenizer, chiamato Wav2Vec2CTCTokenizer, sia da un estrattore di caratteristiche, chiamato Wav2Vec2FeatureExtractor.

Iniziamo creando il tokenizer per decodificare le classi di output previste nella trascrizione di output.

Crea Wav2Vec2CTCTokenizer

Un modello XLS-R pre-addestrato mappa il segnale del parlato in una sequenza di rappresentazioni di contesto come illustrato nella figura sopra. Tuttavia, per il riconoscimento del parlato il modello deve mappare questa sequenza di rappresentazioni di contesto alla sua corrispondente trascrizione, il che significa che deve essere aggiunto uno strato lineare sopra il blocco del trasformatore (mostrato in giallo nel diagramma sopra). Questo strato lineare viene utilizzato per classificare ogni rappresentazione di contesto in una classe di token, in modo analogo a come viene aggiunto uno strato lineare sopra gli embedding di BERT per ulteriori classificazioni dopo il pre-addestramento (cf. con la sezione ‘BERT’ del seguente post sul blog). Dopo il pre-addestramento viene aggiunto uno strato lineare sopra gli embedding di BERT per ulteriori classificazioni – cf. con la sezione ‘BERT’ di questo post sul blog.

La dimensione di output di questo strato corrisponde al numero di token nel vocabolario, che non dipende dal compito di pre-addestramento di XLS-R, ma solo dal dataset etichettato utilizzato per il fine-tuning. Quindi nel primo passaggio, daremo un’occhiata al dataset scelto di Common Voice e definiremo un vocabolario basato sulle trascrizioni.

Per prima cosa, andiamo sul sito web ufficiale di Common Voice e scegliamo una lingua per il fine-tuning di XLS-R. Per questo notebook, useremo il turco.

Per ciascun dataset specifico della lingua, è possibile trovare un codice lingua corrispondente alla lingua scelta. Su Common Voice, cercare il campo “Versione”. Il codice lingua corrisponde quindi al prefisso prima dell’underscore. Per il turco, ad esempio, il codice lingua è "tr".

Grande, ora possiamo utilizzare l’API semplice di 🤗 Datasets per scaricare i dati. Il nome del dataset è "common_voice", il nome della configurazione corrisponde al codice lingua, che è "tr" nel nostro caso.

Common Voice ha molti split diversi, incluso invalidated, che si riferisce ai dati che non sono stati valutati come “abbastanza puliti” per essere considerati utili. In questo notebook, utilizzeremo solo gli split "train", "validation" e "test".

Perché il dataset turco è così piccolo, uniremo i dati di validazione e di addestramento in un dataset di addestramento e utilizzeremo solo i dati di test per la validazione.

from datasets import load_dataset, load_metric, Audio

common_voice_train = load_dataset("common_voice", "tr", split="train+validation")
common_voice_test = load_dataset("common_voice", "tr", split="test")

Molti dataset ASR forniscono solo il testo di destinazione, 'sentence' per ogni array audio 'audio' e file 'path'. Common Voice fornisce effettivamente molte altre informazioni su ciascun file audio, come l’'accent', ecc. Per mantenere il notebook il più generale possibile, consideriamo solo il testo trascritto per il fine-tuning.

common_voice_train = common_voice_train.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])
common_voice_test = common_voice_test.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])

Scriviamo una breve funzione per visualizzare alcuni campioni casuali del dataset e la eseguiamo un paio di volte per avere un’idea delle trascrizioni.

from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

def mostra_elementi_casuali(dataset, num_esempi=10):
    assert num_esempi <= len(dataset), "Non è possibile selezionare più elementi di quanti presenti nel dataset."
    selezioni = []
    for _ in range(num_esempi):
        selezione = random.randint(0, len(dataset)-1)
        while selezione in selezioni:
            selezione = random.randint(0, len(dataset)-1)
        selezioni.append(selezione)
    
    df = pd.DataFrame(dataset[selezioni])
    display(HTML(df.to_html()))

Output di stampa:

Ok! Le trascrizioni sembrano abbastanza pulite. Avendo tradotto le frasi trascritte, sembra che la lingua corrisponda più al testo scritto che al dialogo rumoroso. Ha senso considerando che Common Voice è un corpus di lettura orale collaborativo.

Possiamo vedere che le trascrizioni contengono alcuni caratteri speciali, come ,.?!;: . Senza un modello linguistico, è molto più difficile classificare frammenti di discorso in tali caratteri speciali perché non corrispondono realmente a un’unità sonora caratteristica. Ad esempio, la lettera "s" ha un suono più o meno chiaro, mentre il carattere speciale "." non lo ha. Inoltre, per comprendere il significato di un segnale vocale, di solito non è necessario includere caratteri speciali nella trascrizione.

Semplicemente rimuoviamo tutti i caratteri che non contribuiscono al significato di una parola e che non possono essere rappresentati realmente da un suono acustico e normalizziamo il testo.

import re
regex_caratteri_speciali = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�\']'

def rimuovi_caratteri_speciali(batch):
    batch["sentence"] = re.sub(regex_caratteri_speciali, '', batch["sentence"]).lower()
    return batch

common_voice_train = common_voice_train.map(rimuovi_caratteri_speciali)
common_voice_test = common_voice_test.map(rimuovi_caratteri_speciali)

Rivediamo ancora una volta le etichette di testo elaborate.

mostra_elementi_casuali(common_voice_train.remove_columns(["path","audio"]))

Output di stampa:

Bene! Questo sembra migliore. Abbiamo rimosso la maggior parte dei caratteri speciali dalle trascrizioni e le abbiamo normalizzate in minuscolo.

Prima di finalizzare la pre-elaborazione, è sempre vantaggioso consultare un madrelingua della lingua di destinazione per verificare se il testo può essere ulteriormente semplificato. Per questo post del blog, Merve è stata così gentile da dare un’occhiata veloce e ha notato che i caratteri “cappellati” – come â – non vengono più utilizzati in turco e possono essere sostituiti dal loro equivalente “non cappellato”, ad esempio a.

Ciò significa che dovremmo sostituire una frase come "yargı sistemi hâlâ sağlıksız" con "yargı sistemi hala sağlıksız".

Scriviamo un’altra breve funzione di mappatura per semplificare ulteriormente le etichette di testo. Ricordiamo che più semplici sono le etichette di testo, più facile è per il modello imparare a prevedere tali etichette.

def sostituisci_caratteri_cappellati(batch):
    batch["sentence"] = re.sub('[â]', 'a', batch["sentence"])
    batch["sentence"] = re.sub('[î]', 'i', batch["sentence"])
    batch["sentence"] = re.sub('[ô]', 'o', batch["sentence"])
    batch["sentence"] = re.sub('[û]', 'u', batch["sentence"])
    return batch

common_voice_train = common_voice_train.map(sostituisci_caratteri_cappellati)
common_voice_test = common_voice_test.map(sostituisci_caratteri_cappellati)

Nel CTC, è comune classificare i frammenti di discorso in lettere, quindi faremo lo stesso qui. Estrarremo tutte le lettere distinte dai dati di addestramento e di test e costruiremo il nostro vocabolario da questo insieme di lettere.

Scriviamo una funzione di mappatura che concatena tutte le trascrizioni in una sola trascrizione lunga e quindi trasforma la stringa in un insieme di caratteri. È importante passare l’argomento batched=True alla funzione map(...) in modo che la funzione di mappatura abbia accesso a tutte le trascrizioni contemporaneamente.

def estrai_tutti_i_caratteri(batch):
  tutto_il_testo = " ".join(batch["sentence"])
  vocabolario = list(set(tutto_il_testo))
  return {"vocabolario": [vocabolario], "tutto_il_testo": [tutto_il_testo]}

vocabolario_addestramento = common_voice_train.map(estrai_tutti_i_caratteri, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names)
vocabolario_test = common_voice_test.map(estrai_tutti_i_caratteri, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)

Ora, creiamo l’unione di tutte le lettere distinte nel dataset di addestramento e nel dataset di test e convertiamo la lista risultante in un dizionario enumerato.

vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))

vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}
vocab_dict

Stampa l’output:

{
 ' ': 0,
 'a': 1,
 'b': 2,
 'c': 3,
 'd': 4,
 'e': 5,
 'f': 6,
 'g': 7,
 'h': 8,
 'i': 9,
 'j': 10,
 'k': 11,
 'l': 12,
 'm': 13,
 'n': 14,
 'o': 15,
 'p': 16,
 'q': 17,
 'r': 18,
 's': 19,
 't': 20,
 'u': 21,
 'v': 22,
 'w': 23,
 'x': 24,
 'y': 25,
 'z': 26,
 'ç': 27,
 'ë': 28,
 'ö': 29,
 'ü': 30,
 'ğ': 31,
 'ı': 32,
 'ş': 33,
 '̇': 34
}

Interessante, vediamo che tutte le lettere dell’alfabeto compaiono nel dataset (il che non è davvero sorprendente) e abbiamo anche estratto i caratteri speciali "" e ' . Nota che non abbiamo escluso quei caratteri speciali perché:

Il modello deve imparare a predire quando una parola è finita, altrimenti la previsione del modello sarebbe sempre una sequenza di caratteri che renderebbe impossibile separare le parole l’una dall’altra.

Bisogna sempre tenere presente che la pre-elaborazione è un passaggio molto importante prima di addestrare il proprio modello. Ad esempio, non vogliamo che il nostro modello differenzi tra a e A solo perché abbiamo dimenticato di normalizzare i dati. La differenza tra a e A non dipende affatto dal “suono” della lettera, ma più dalle regole grammaticali – ad esempio, usare una lettera maiuscola all’inizio della frase. Pertanto, ha senso rimuovere la differenza tra lettere maiuscole e minuscole in modo che il modello abbia più facilità nell’apprendere la trascrizione del discorso.

Per rendere più chiaro che " " ha la propria classe di token, gli assegnamo un carattere più visibile | . Inoltre, aggiungiamo anche un token “sconosciuto” in modo che il modello possa gestire successivamente i caratteri non incontrati nel set di addestramento di Common Voice.

vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]

Infine, aggiungiamo anche un token di riempimento che corrisponde al “token vuoto” di CTC. Il “token vuoto” è un componente fondamentale dell’algoritmo CTC. Per ulteriori informazioni, si prega di dare un’occhiata alla sezione “Allineamento” qui .

vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
len(vocab_dict)

Bene, ora il nostro vocabolario è completo e consiste di 39 token, il che significa che lo strato lineare che aggiungeremo sopra il checkpoint XLS-R preaddestrato avrà una dimensione di output di 39.

Salviamo ora il vocabolario come file json.

import json
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

In un ultimo passaggio, utilizziamo il file json per caricare il vocabolario in un’istanza della classe Wav2Vec2CTCTokenizer.

from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")

Se si desidera riutilizzare il tokenizer appena creato con il modello sintonizzato di questo notebook, è vivamente consigliato caricare il tokenizer sul Hugging Face Hub . Chiamiamo il repository a cui caricheremo i file "wav2vec2-large-xlsr-turkish-demo-colab" :

repo_name = "wav2vec2-large-xls-r-300m-tr-colab"

e carica il tokenizer su 🤗 Hub .

tokenizer.push_to_hub(repo_name)

Grande, puoi vedere il repository appena creato su https://huggingface.co/<your-username>/wav2vec2-large-xls-r-300m-tr-colab

Crea Wav2Vec2FeatureExtractor

Il parlato è un segnale continuo e, per essere trattato dai computer, deve prima essere discretizzato, cosa che di solito viene chiamata campionamento. Il tasso di campionamento svolge un ruolo importante poiché definisce quanti punti di dati del segnale vocale vengono misurati al secondo. Pertanto, il campionamento con un tasso di campionamento più elevato produce una migliore approssimazione del vero segnale vocale ma richiede anche più valori al secondo.

Un checkpoint preaddestrato si aspetta che i dati in input siano stati campionati più o meno dalla stessa distribuzione dei dati su cui è stato addestrato. Segnali vocali campionati con due diversi tassi di campionamento hanno una distribuzione molto diversa. Ad esempio, raddoppiando il tasso di campionamento si ottengono punti dati due volte più lunghi. Pertanto, prima di effettuare il fine-tuning di un checkpoint preaddestrato di un modello ASR, è fondamentale verificare che il tasso di campionamento dei dati utilizzati per il preaddestramento del modello corrisponda al tasso di campionamento del dataset utilizzato per il fine-tuning del modello.

XLS-R è stato preaddestrato su dati audio di Babel, Multilingual LibriSpeech (MLS), Common Voice, VoxPopuli e VoxLingua107 con un tasso di campionamento di 16kHz. Common Voice, nella sua forma originale, ha un tasso di campionamento di 48kHz, quindi dovremo ridimensionare i dati di fine-tuning a 16kHz.

Un oggetto Wav2Vec2FeatureExtractor richiede i seguenti parametri per essere istanziato:

  • feature_size: I modelli vocali prendono una sequenza di vettori di caratteristiche come input. Anche se la lunghezza di questa sequenza varia ovviamente, la dimensione delle caratteristiche non dovrebbe farlo. Nel caso di Wav2Vec2, la dimensione delle caratteristiche è 1 perché il modello è stato addestrato sul segnale vocale grezzo 2 {}^2 2 .
  • sampling_rate: Il tasso di campionamento su cui è stato addestrato il modello.
  • padding_value: Per l’inferenza in batch, gli input più corti devono essere riempiti con un valore specifico.
  • do_normalize: Se l’input deve essere normalizzato a zero media e varianza unitaria o meno. Di solito, i modelli vocali funzionano meglio quando normalizzano l’input.
  • return_attention_mask: Se il modello deve utilizzare una attention_mask per l’inferenza in batch. In generale, i checkpoint dei modelli XLS-R dovrebbero sempre utilizzare la attention_mask.
from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)

Ottimo, la pipeline di estrazione delle caratteristiche di XLS-R è così completamente definita!

Per una maggiore facilità d’uso, l’estrattore di caratteristiche e il tokenizer sono racchiusi in una singola classe Wav2Vec2Processor in modo che sia necessario solo un oggetto model e processor.

from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

Successivamente, possiamo preparare il dataset.

Preprocessa i dati

Fino ad ora, non abbiamo guardato ai valori effettivi del segnale vocale ma solo alla trascrizione. Oltre alla frase, i nostri dataset includono altri due nomi di colonne, path e audio. path indica il percorso assoluto del file audio. Diamo un’occhiata.

common_voice_train[0]["path"]

XLS-R si aspetta l’input nel formato di un array monodimensionale a 16 kHz. Ciò significa che il file audio deve essere caricato e ridimensionato.

Fortunatamente, datasets lo fa automaticamente chiamando l’altra colonna audio. Proviamolo.

common_voice_train[0]["audio"]

    {'array': array([ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
            -8.8930130e-05, -3.8027763e-05, -2.9146671e-05], dtype=float32),
     'path': '/root/.cache/huggingface/datasets/downloads/extracted/05be0c29807a73c9b099873d2f5975dae6d05e9f7d577458a2466ecb9a2b0c6b/cv-corpus-6.1-2020-12-11/tr/clips/common_voice_tr_21921195.mp3',
     'sampling_rate': 48000}

Ottimo, possiamo vedere che il file audio è stato caricato automaticamente. Questo è grazie alla nuova funzionalità “Audio” introdotta in datasets == 1.18.3, che carica e campiona i file audio al volo quando viene chiamata.

Nell’esempio sopra possiamo vedere che i dati audio vengono caricati con un tasso di campionamento di 48kHz, mentre il modello si aspetta 16kHz. Possiamo impostare la funzione audio al tasso di campionamento corretto utilizzando “cast_column” :

common_voice_train = common_voice_train.cast_column("audio", Audio(sampling_rate=16_000))
common_voice_test = common_voice_test.cast_column("audio", Audio(sampling_rate=16_000))

Diamo un’occhiata nuovamente a “audio”.

common_voice_train[0]["audio"]

    {'array': array([ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
            -7.4556941e-05, -1.4621433e-05, -5.7861507e-05], dtype=float32),
     'path': '/root/.cache/huggingface/datasets/downloads/extracted/05be0c29807a73c9b099873d2f5975dae6d05e9f7d577458a2466ecb9a2b0c6b/cv-corpus-6.1-2020-12-11/tr/clips/common_voice_tr_21921195.mp3',
     'sampling_rate': 16000}

Sembra che abbia funzionato! Ascoltiamo alcuni file audio per capire meglio il dataset e verificare che l’audio sia stato caricato correttamente.

import IPython.display as ipd
import numpy as np
import random

rand_int = random.randint(0, len(common_voice_train)-1)

print(common_voice_train[rand_int]["sentence"])
ipd.Audio(data=common_voice_train[rand_int]["audio"]["array"], autoplay=True, rate=16000)

Output della stampa:

    sunulan bütün teklifler i̇ngilizce idi

Sembra che i dati siano ora caricati e campionati correttamente.

Può essere ascoltato che i parlanti cambiano insieme alla loro velocità di parlato, all’accento e all’ambiente di sfondo, ecc. Nel complesso, le registrazioni suonano abbastanza chiare, come ci si può aspettare da un corpus di lettura di massa basato sulla folla.

Facciamo un ultimo controllo che i dati siano preparati correttamente, stampando la forma dell’input vocale, la sua trascrizione e il tasso di campionamento corrispondente.

rand_int = random.randint(0, len(common_voice_train)-1)

print("Testo target:", common_voice_train[rand_int]["sentence"])
print("Forma dell'array di input:", common_voice_train[rand_int]["audio"]["array"].shape)
print("Tasso di campionamento:", common_voice_train[rand_int]["audio"]["sampling_rate"])

Output della stampa:

    Testo target: makedonya bu yıl otuz adet tyetmiş iki tankı aldı
    Forma dell'array di input: (71040,)
    Tasso di campionamento: 16000

Perfetto! Tutto sembra in ordine: i dati sono un array unidimensionale, il tasso di campionamento corrisponde sempre a 16kHz e il testo target è normalizzato.

Infine, possiamo utilizzare Wav2Vec2Processor per elaborare i dati nel formato richiesto da Wav2Vec2ForCTC per l’addestramento. Per farlo, utilizziamo la funzione “map(…)” di Dataset.

Per prima cosa, carichiamo e ridimensioniamo i dati audio, semplicemente chiamando batch["audio"]. In secondo luogo, estraiamo i input_values dal file audio caricato. Nel nostro caso, il Wav2Vec2Processor normalizza solo i dati. Per altri modelli di speech, tuttavia, questo passaggio può includere una estrazione delle caratteristiche più complessa, come l’estrazione delle caratteristiche Log-Mel. In terzo luogo, codifichiamo le trascrizioni in identificatori di etichetta.

Nota: Questa funzione di mappatura è un buon esempio di come dovrebbe essere utilizzata la classe Wav2Vec2Processor. In un contesto “normale”, chiamare processor(...) viene reindirizzato al metodo di chiamata di Wav2Vec2FeatureExtractor. Quando invece si avvolge il processore nel contesto di as_target_processor, lo stesso metodo viene reindirizzato al metodo di chiamata di Wav2Vec2CTCTokenizer. Per ulteriori informazioni, consulta la documentazione.

def prepare_dataset(batch):
    audio = batch["audio"]

    # L'output in batch viene "scomposto"
    batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    batch["input_length"] = len(batch["input_values"])
    
    with processor.as_target_processor():
        batch["labels"] = processor(batch["sentence"]).input_ids
    return batch

Applichiamo la funzione di preparazione dei dati a tutti gli esempi.

common_voice_train = common_voice_train.map(prepare_dataset, remove_columns=common_voice_train.column_names)
common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names)

Nota: Attualmente, i datasets utilizzano torchaudio e librosa per il caricamento e il ridimensionamento dell’audio. Se desideri implementare il tuo caricamento/ridimensionamento dei dati personalizzato, puoi semplicemente utilizzare la colonna "path" e ignorare la colonna "audio".

Le sequenze di input lunghe richiedono molta memoria. XLS-R si basa sull’auto-attenzione. Il requisito di memoria aumenta quadraticamente con la lunghezza di input per sequenze di input lunghe (cf. con questo post di Reddit). Nel caso in cui questa demo si blocchi con un errore di “Memoria esaurita”, potresti voler rimuovere il commento dalle seguenti righe per filtrare tutte le sequenze più lunghe di 5 secondi per l’addestramento.

#max_input_length_in_sec = 5.0
#common_voice_train = common_voice_train.filter(lambda x: x < max_input_length_in_sec * processor.feature_extractor.sampling_rate, input_columns=["input_length"])

Fantastico, ora siamo pronti per iniziare l’addestramento!

Addestramento

I dati sono elaborati in modo che siamo pronti per configurare la pipeline di addestramento. Faremo uso del Trainer di 🤗 per il quale dobbiamo essenzialmente fare quanto segue:

  • Definire un data collator. A differenza della maggior parte dei modelli NLP, XLS-R ha una lunghezza di input molto maggiore rispetto alla lunghezza di output. Ad esempio, un campione di lunghezza di input 50000 ha una lunghezza di output di al massimo 100. Date le grandi dimensioni di input, è molto più efficiente riempire dinamicamente i batch di addestramento, il che significa che tutti i campioni di addestramento dovrebbero essere riempiti solo fino al campione più lungo nel loro batch e non al campione più lungo in assoluto. Pertanto, il fine-tuning di XLS-R richiede un data collator di riempimento speciale che definiremo di seguito

  • Metrica di valutazione. Durante l’addestramento, il modello dovrebbe essere valutato sul tasso di errore delle parole. Dovremmo definire una funzione compute_metrics di conseguenza

  • Caricare un checkpoint preaddestrato. Dobbiamo caricare un checkpoint preaddestrato e configurarlo correttamente per l’addestramento.

  • Definire la configurazione dell’addestramento.

Dopo aver effettuato il fine-tuning del modello, lo valuteremo correttamente sui dati di test e verificheremo che abbia effettivamente imparato a trascrivere correttamente il discorso.

Configurazione del Trainer

Iniziamo definendo il data collator. Il codice del data collator è stato copiato da questo esempio.

Senza entrare troppo nei dettagli, a differenza dei data collator comuni, questo data collator tratta gli input_values e le labels in modo diverso e quindi applica a ognuno di essi funzioni di riempimento separate (utilizzando ancora una volta il gestore di contesto del processore XLS-R). Questo è necessario perché nell’input e nell’output del discorso sono presenti diverse modalità, il che significa che non devono essere trattati dalla stessa funzione di riempimento. Analogamente ai data collator comuni, i token di riempimento nelle etichette sono impostati su -100 in modo che quei token non siano considerati nel calcolo della perdita.

import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator che effettuerà un padding dinamico agli input ricevuti.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            Il processore utilizzato per elaborare i dati.
        padding (:obj:`bool`, :obj:`str` o :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `opzionale`, predefinito a :obj:`True`):
            Seleziona una strategia per effettuare il padding delle sequenze restituite (in base al lato di padding del modello e all'indice di padding)
            tra:
            * :obj:`True` o :obj:`'longest'`: Effettua il padding alla sequenza più lunga nel batch (o nessun padding se viene fornita solo una sequenza).
            * :obj:`'max_length'`: Effettua il padding a una lunghezza massima specificata dall'argomento :obj:`max_length` o alla lunghezza massima di input accettabile per il modello se quell'argomento non viene fornito.
            * :obj:`False` o :obj:`'do_not_pad'` (predefinito): Nessun padding (ossia può generare un batch con sequenze di lunghezze diverse).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # divide input e label poiché devono avere lunghezze diverse e necessitano
        # di metodi di padding diversi
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                return_tensors="pt",
            )

        # sostituisci il padding con -100 per ignorare correttamente la loss
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

In seguito, viene definita la metrica di valutazione. Come accennato in precedenza, la metrica predominante in ASR è il tasso di errore delle parole (WER), quindi la utilizzeremo anche in questo notebook.

wer_metric = load_metric("wer")

Il modello restituirà una sequenza di vettori di logit: y 1 , … , y m \mathbf{y}_1, \ldots, \mathbf{y}_m y 1 ​ , … , y m ​ con y 1 = f θ ( x 1 , … , x n ) [ 0 ] \mathbf{y}_1 = f_{\theta}(x_1, \ldots, x_n)[0] y 1 ​ = f θ ​ ( x 1 ​ , … , x n ​ ) [ 0 ] e n > > m n >> m n > > m .

Un vettore di logit y 1 \mathbf{y}_1 y 1 ​ contiene le log-odds per ogni parola nel vocabolario che abbiamo definito in precedenza, quindi len ( y i ) = \text{len}(\mathbf{y}_i) = len ( y i ​ ) = config.vocab_size . Siamo interessati alla previsione più probabile del modello e quindi prendiamo l’argomento massimo (argmax(...)) dei logit. Inoltre, trasformiamo le etichette codificate nella stringa originale sostituendo -100 con pad_token_id e decodifichiamo gli id assicurandoci che i token consecutivi non siano raggruppati nello stesso token nello stile CTC 1 {}^1 1 .

def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # non vogliamo raggruppare i token durante il calcolo delle metriche
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

Ora possiamo caricare il checkpoint preaddestrato di Wav2Vec2-XLS-R-300M . L’pad_token_id del tokenizer deve essere definito come l’pad_token_id del modello o nel caso di Wav2Vec2ForCTC anche come il token vuoto di CTC 2 {}^2 2 . Per risparmiare memoria GPU, abilitiamo il checkpointing dei gradienti di PyTorch e impostiamo anche la riduzione della loss su “mean”.

Poiché il dataset è piuttosto piccolo (~6h di dati di addestramento) e poiché Common Voice è piuttosto rumoroso, sembra che la messa a punto del checkpoint wav2vec2-xls-r-300m di Facebook richieda un po’ di messa a punto degli iperparametri. Pertanto, ho dovuto sperimentare con diversi valori di dropout, tasso di dropout di mascheramento di SpecAugment, dropout di livello e learning rate fino a quando l’addestramento sembrava abbastanza stabile.

Nota: Quando si utilizza questo notebook per addestrare XLS-R su un’altra lingua di Common Voice, tali impostazioni degli iperparametri potrebbero non funzionare molto bene. Sentiti libero di adattarli a seconda del tuo caso d’uso.

from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-xls-r-300m", 
    attention_dropout=0.0,
    hidden_dropout=0.0,
    feat_proj_dropout=0.0,
    mask_time_prob=0.05,
    layerdrop=0.0,
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
)

Il primo componente di XLS-R è costituito da uno stack di livelli CNN che vengono utilizzati per estrarre caratteristiche acusticamente significative – ma contestualmente indipendenti – dal segnale audio grezzo. Questa parte del modello è già stata addestrata a sufficienza durante il preaddestramento e, come indicato nel paper, non è necessario eseguire ulteriori messa a punto. Pertanto, possiamo impostare requires_grad su False per tutti i parametri della parte di estrazione delle caratteristiche.

model.freeze_feature_extractor()

In un ultimo passaggio, definiamo tutti i parametri relativi all’addestramento. Per fornire ulteriori spiegazioni su alcuni dei parametri:

  • group_by_length rende l’addestramento più efficiente raggruppando campioni di addestramento di lunghezza simile in un unico batch. Ciò può velocizzare significativamente il tempo di addestramento riducendo notevolmente il numero complessivo di token di padding inutili che vengono passati attraverso il modello
  • learning_rate e weight_decay sono stati sintonizzati euristicamente fino a quando la messa a punto è diventata stabile. Nota che questi parametri dipendono fortemente dal dataset di Common Voice e potrebbero non essere ottimali per altri dataset di speech.

Per ulteriori spiegazioni su altri parametri, è possibile consultare la documentazione.

Durante l’addestramento, un checkpoint verrà caricato in modo asincrono nell’Hub ogni 400 passi di addestramento. Ciò consente anche di sperimentare con il widget demo anche mentre il modello è ancora in fase di addestramento.

Nota: Se non si desidera caricare i checkpoint del modello nell’Hub, è sufficiente impostare push_to_hub=False.

from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir=repo_name,
  group_by_length=True,
  per_device_train_batch_size=16,
  gradient_accumulation_steps=2,
  evaluation_strategy="steps",
  num_train_epochs=30,
  gradient_checkpointing=True,
  fp16=True,
  save_steps=400,
  eval_steps=400,
  logging_steps=400,
  learning_rate=3e-4,
  warmup_steps=500,
  save_total_limit=2,
  push_to_hub=True,
)

Ora, tutte le istanze possono essere passate a Trainer e siamo pronti per iniziare l’addestramento!

from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=common_voice_train,
    eval_dataset=common_voice_test,
    tokenizer=processor.feature_extractor,
)

1 {}^1 1 Per consentire ai modelli di essere indipendenti dalla velocità del parlante, in CTC, i token consecutivi che sono identici vengono semplicemente raggruppati come un singolo token. Tuttavia, le etichette codificate non dovrebbero essere raggruppate durante la decodifica poiché non corrispondono ai token previsti dal modello, motivo per cui il parametro group_tokens=False deve essere passato. Se non passassimo questo parametro, una parola come "hello" sarebbe erroneamente codificata e decodificata come "helo". 2 {}^2 2 Il token blank consente al modello di prevedere una parola, come "hello", costringendolo a inserire il token blank tra le due l. Una previsione CTC-conforme di "hello" del nostro modello sarebbe [PAD] [PAD] "h" "e" "e" "l" "l" [PAD] "l" "o" "o" [PAD].

Allenamento

L’allenamento richiederà diverse ore a seconda della GPU allocata per questo notebook. Sebbene il modello addestrato produca risultati soddisfacenti sui dati di test di Common Voice in lingua turca, non è assolutamente un modello ottimizzato. Lo scopo di questo notebook è solo quello di dimostrare come ottimizzare XLS-R XLSR-Wav2Vec2 su un dataset ASR.

A seconda della GPU allocata al tuo Google Colab, potrebbe essere possibile che tu stia vedendo un errore “out-of-memory” qui. In tal caso, probabilmente è meglio ridurre per_device_train_batch_size a 8 o anche meno e aumentare gradient_accumulation.

trainer.train()

Stampa output:

La perdita di allenamento e la WER di convalida diminuiscono in modo soddisfacente.

Ora puoi caricare il risultato dell’allenamento nell’Hub, esegui semplicemente questa istruzione:

trainer.push_to_hub()

Ora puoi condividere questo modello con tutti i tuoi amici, familiari, animali domestici preferiti: tutti possono caricarlo con l’identificatore “tuo-nome/il-nome-che-hai-scelto”, ad esempio:

from transformers import AutoModelForCTC, Wav2Vec2Processor

model = AutoModelForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xls-r-300m-tr-colab")
processor = Wav2Vec2Processor.from_pretrained("patrickvonplaten/wav2vec2-large-xls-r-300m-tr-colab")

Per ulteriori esempi su come ottimizzare XLS-R, consulta gli esempi ufficiali di 🤗 Transformers.

Valutazione

Come ultimo controllo, carichiamo il modello e verifichiamo che abbia effettivamente imparato a trascrivere il discorso in turco.

Carichiamo prima il checkpoint pre-addestrato.

model = Wav2Vec2ForCTC.from_pretrained(repo_name).to("cuda")
processor = Wav2Vec2Processor.from_pretrained(repo_name)

Ora, prenderemo solo il primo esempio dell’insieme di test, lo passeremo attraverso il modello e prenderemo l’argmax(...) dei logit per recuperare gli ID dei token predetti.

input_dict = processor(common_voice_test[0]["input_values"], return_tensors="pt", padding=True)

logits = model(input_dict.input_values.to("cuda")).logits

pred_ids = torch.argmax(logits, dim=-1)[0]

È vivamente consigliato passare l’argomento sampling_rate a questa funzione. Non farlo può causare errori silenziosi che potrebbero essere difficili da individuare.

Abbiamo adattato abbastanza l’insieme di test common_voice_test in modo che l’istanza del dataset non contenga più l’etichetta della frase originale. Pertanto, riutilizziamo il dataset originale per ottenere l’etichetta del primo esempio.

common_voice_test_transcription = load_dataset("common_voice", "tr", data_dir="./cv-corpus-6.1-2020-12-11", split="test")

Infine, possiamo decodificare l’esempio.

print("Predizione:")
print(processor.decode(pred_ids))

print("\nRiferimento:")
print(common_voice_test_transcription[0]["sentence"].lower())

Stampa output:

Ok! La trascrizione può sicuramente essere riconosciuta dalla nostra predizione, ma non è ancora perfetta. Allenare il modello un po’ più a lungo, dedicare più tempo alla pre-elaborazione dei dati e, in particolare, utilizzare un modello di linguaggio per la decodifica, migliorerebbe certamente le prestazioni complessive del modello.

Per un modello dimostrativo su una lingua a bassa risorsa, i risultati sono comunque accettabili 🤗.