Un’immersione approfondita nell’algoritmo FlashAttention – parte 3

Esplorazione approfondita dell'algoritmo FlashAttention - Parte 3

fonte: articolo FlashAttention

Benvenuti alla terza parte della nostra serie su Flash Attention! In questo segmento, approfondiremo il funzionamento interno dell’algoritmo FlashAttention V1, analizzando i suoi concetti e principi fondamentali. Se sei nuovo sull’argomento o desideri saperne di più sulle GPU e su come FlashAttention funziona a livello avanzato, assicurati di consultare la Comprensione delle GPU e gli avanzamenti nell’accelerazione GPU in questa serie.

Per iniziare, chiariremo che le ottimizzazioni e le migliorie di velocità di FlashAttention si concentrano principalmente sulle GPU. Sebbene l’articolo faccia riferimento alla cache L1 e L2, queste ottimizzazioni sono fondamentalmente incentrate sulle prestazioni delle GPU e non sulla RAM o su altri componenti di memoria.

Un Riassunto Veloce

In una tipica architettura delle GPU, i dati vengono memorizzati sul disco rigido, ma per qualsiasi calcolo significativo occorre spostare i dati nella RAM. Da lì, i dati attraversano diverse gerarchie di memoria prima di raggiungere la GPU. L’algoritmo FlashAttention è sintonizzato per sfruttare le capacità dei tensor core nelle GPU moderne. Questo è particolarmente importante poiché, durante l’addestramento di modelli come GPT-3, si è scoperto che i tensor core rimanevano inattivi per circa il 50% del tempo.

FlashAttention è un algoritmo notevole per due ragioni principali: la divisione in blocchi (tiling) e la ricomputazione (recomputation).

La divisione in blocchi è una tecnica che suddivide le matrici Q, K e V in blocchi più piccoli. Questa divisione consente all’algoritmo di leggere e processare queste matrici blocco per blocco anziché caricare tutto in memoria della GPU in una volta sola.

La ricomputazione, invece, si occupa della retropropagazione, un aspetto essenziale dell’addestramento dei modelli. Invece di memorizzare i valori nella memoria ad alta larghezza di banda (HBM) e accedere ripetutamente a questa memoria, Flash Attention ricalcola i valori quando necessario. Sebbene la ricomputazione aumenti il numero di operazioni in virgola mobile, riduce significativamente il tempo dedicato all’accesso alla memoria.

Ora, entriamo nei dettagli tecnici dell’algoritmo menzionato nell’articolo:

Preferisci una spiegazione visuale? Dai un’occhiata al mio video sull’algoritmo FlashAttention V1

Algoritmo FlashAttention e Calcolo Normalizzatore Online per Softmax

Algoritmo FlashAttention

Fonte: articolo FlashAttention

Concetto di Divisione in Blocchi

Il primo concetto fondamentale di FlashAttention è la divisione in blocchi (tiling). Ogni token in un modello trasformatore ha matrici associate per Q, K e V. Il processo di divisione in blocchi suddivide queste matrici in blocchi gestibili per il processo. La dimensione del blocco è tipicamente impostata a 128, come indicato dagli autori.

Per iniziare, è necessario determinare le dimensioni dei blocchi per le matrici Q, K e V. Inoltre, iniziamo con la creazione di variabili intermedie come “l” e “m” per memorizzare i risultati. L’output finale è il prodotto di tutte le variabili intermedie. Questa divisione in blocchi e la memorizzazione dei risultati intermedi sono essenziali per combinare efficientemente questi risultati successivamente nel processo.

Softmax Sicura

Il cuore di FlashAttention risiede nella sua implementazione della funzione Softmax. La funzione Softmax standard spesso si scontra con sfide legate a overflow e underflow, in quanto i valori esponenziali possono diventare eccessivamente grandi o piccoli. FlashAttention utilizza una “Softmax sicura” per mitigare questi problemi.

La Softmax sicura opera individuando il valore massimo nell’array di input e sottraendo questo massimo da ogni elemento nell’array prima dell’elevazione a potenza. Questo aggiustamento evita potenziali problemi di overflow o underflow, rendendo i calcoli numericamente stabili.

La formula per la Softmax sicura:

Fonte: Immagine dell'autore

SafeSoftmax con Calcolo Normalizzatore Online

Il Safe Softmax di FlashAttention trae ispirazione da un concetto noto come “calcolo normalizzatore online”. Questo metodo, descritto in un articolo di NVIDIA, Calcolo normalizzatore online per Softmax, fornisce un modo per calcolare la Softmax senza eseguire accessi di memoria ridondanti. L’approccio è progettato per ridurre il numero di operazioni di memoria, rendendo l’algoritmo più efficiente dal punto di vista della memoria.

Fonte: Immagine dell'autore.

In questo processo, l’algoritmo mantiene una somma in esecuzione e modifica il calcolo Softmax utilizzando valori intermedi calcolati al volo. Lo fa annullando le modifiche apportate durante le iterazioni precedenti e applicando gli adeguamenti necessari durante l’elaborazione di nuovi elementi. Questo approccio permette di effettuare calcoli Softmax senza dover rivisitare tutti gli elementi dell’array di input, riducendo così significativamente l’accesso alla memoria.

Combinazione di Tiling e Safe Softmax

FlashAttention combina abilmente i concetti di tiling e Safe Softmax per massimizzare l’efficienza del suo meccanismo di attenzione. Suddividendo l’input in blocchi gestibili e applicando Safe Softmax a livello di blocco, FlashAttention minimizza l’accesso alla memoria mantenendo la stabilità numerica. Questo approccio garantisce che l’algoritmo operi in modo ottimale con i tensor core delle GPU, spesso sottoutilizzati durante l’addestramento di modelli di deep learning o di modelli linguistici.

Conclusione

In sintesi, FlashAttention V1 è un meccanismo di attenzione altamente efficiente e numericamente stabile che sfrutta la potenza delle GPU e dei tensor core. Il suo uso innovativo di tiling e Safe Softmax garantisce che i collo di bottiglia dell’accesso alla memoria siano ridotti al minimo, consentendo una formazione più efficace dei modelli di linguaggio come il GPT-3.

L’integrazione di queste tecniche mette in mostra la brillante sinergia tra concetti matematici, ingegno algoritmico e ottimizzazione hardware nel campo del deep learning.

Riferimenti

  1. Articolo su FlashAttention
  2. Articolo FlashAttention Annotato
  3. Calcolo normalizzatore online per Softmax
  4. Guida sulle Prestazioni delle GPU per Deep Learning di NVIDIA​