Perfeziona Llama 2 con DPO

Make Llama 2 perfect with DPO

Introduzione

Il Reinforcement Learning from Human Feedback (RLHF) è diventato l’ultimo passo di addestramento de facto degli LLM come GPT-4 o Claude per garantire che gli output del modello di linguaggio siano allineati alle aspettative umane come la loquacità o le funzionalità di sicurezza. Tuttavia, porta con sé parte della complessità del RL nel NLP: dobbiamo costruire una buona funzione di ricompensa, addestrare il modello a stimare il valore di uno stato e contemporaneamente fare attenzione a non allontanarci troppo dal modello originale e produrre invece testo senza senso. Tale processo è piuttosto complesso e richiede una serie di componenti mobili complesse in cui non è sempre facile fare le cose nel modo giusto.

Il recente articolo “Direct Preference Optimization” di Rafailov, Sharma, Mitchell et al. propone di trasformare l’obiettivo basato su RL utilizzato dai metodi esistenti in un obiettivo che può essere direttamente ottimizzato tramite una semplice perdita di entropia incrociata binaria, semplificando così notevolmente questo processo di raffinamento degli LLM.

Questo post del blog introduce il metodo di ottimizzazione delle preferenze dirette (DPO), ora disponibile nella libreria TRL, e mostra come è possibile ottimizzare il recente modello Llama v2 con 7 miliardi di parametri sul dataset di preferenze di stack-exchange, che contiene risposte classificate alle domande sui diversi portali di stack-exchange.

DPO vs PPO

Nel modello tradizionale di ottimizzazione delle preferenze derivate dall’uomo tramite RL, il metodo più diffuso è stato utilizzare un modello di ricompensa ausiliario e affinare il modello di interesse in modo che massimizzi questa ricompensa data tramite il meccanismo del RL. In modo intuitivo, utilizziamo il modello di ricompensa per fornire un feedback al modello che stiamo ottimizzando in modo che generi campioni ad alta ricompensa più spesso e campioni a bassa ricompensa meno spesso. Allo stesso tempo, utilizziamo un modello di riferimento congelato per assicurarci che ciò che viene generato non si discosti troppo e continui a mantenere una certa diversità di generazione. Questo di solito viene fatto aggiungendo una penalità KL all’obiettivo di massimizzazione della ricompensa completa tramite un modello di riferimento, che serve a impedire al modello di imparare a imbrogliare o sfruttare il modello di ricompensa.

La formulazione DPO evita il passaggio della modellazione della ricompensa e ottimizza direttamente il modello di linguaggio sui dati delle preferenze tramite una chiave di lettura: ovvero una mappatura analitica dalla funzione di ricompensa alla politica RL ottimale che consente agli autori di trasformare la perdita del RL sui modelli di ricompensa e di riferimento in una perdita solo sul modello di riferimento! Questa mappatura misura in modo intuitivo quanto bene una data funzione di ricompensa si allinea con i dati delle preferenze forniti. DPO inizia quindi con la soluzione ottimale per la perdita RLHF e, mediante un cambio di variabili, deriva una perdita solo sul modello di riferimento!

Di conseguenza, questo obiettivo di probabilità diretta può essere ottimizzato senza la necessità di un modello di ricompensa o del bisogno di eseguire l’ottimizzazione basata su RL potenzialmente complicata.

Come addestrare con TRL

Come accennato, tipicamente il flusso di lavoro RLHF consiste in queste parti distinte:

  1. un passaggio di affinamento supervisionato (SFT)
  2. il processo di annotazione dei dati con etichette di preferenza
  3. l’addestramento di un modello di ricompensa sui dati di preferenza
  4. e il passaggio di ottimizzazione RL

La libreria TRL fornisce assistenti per tutte queste parti, tuttavia l’addestramento DPO elimina la necessità della modellazione della ricompensa e del RL (passaggi 3 e 4) e ottimizza direttamente l’oggetto DPO sui dati con annotazione di preferenza.

In questo senso, dovremmo comunque eseguire il passaggio 1, ma anziché i passaggi 3 e 4, dobbiamo fornire a DPOTrainer in TRL i dati di preferenza dal passaggio 2, che hanno un formato molto specifico, ovvero un dizionario con le seguenti tre chiavi:

  • prompt che consiste nel prompt del contesto che viene fornito a un modello al momento dell’infrazione per la generazione di testo
  • chosen contiene la risposta generata preferita rispetto al prompt corrispondente
  • rejected contiene la risposta che non è preferita o non dovrebbe essere la risposta campionata rispetto al prompt dato

Ad esempio, per il dataset di coppie di preferenze di stack-exchange, possiamo mappare le voci del dataset per restituire il dizionario desiderato tramite l’assistente seguente e rimuovere tutte le colonne originali:

def return_prompt_and_responses(samples) -> Dict[str, str, str]:
    return {
        "prompt": [
            "Domanda: " + question + "\n\nRisposta: "
            for question in samples["question"]
        ],
        "chosen": samples["response_j"],   # valutato meglio di k
        "rejected": samples["response_k"], # valutato peggiore di j
    }

dataset = load_dataset(
    "lvwerra/stack-exchange-paired",
    split="train",
    data_dir="data/rl"
)
original_columns = dataset.column_names

dataset.map(
    return_prompt_and_responses,
    batched=True,
    remove_columns=original_columns
)

Una volta che abbiamo ordinato il dataset, la perdita DPO è essenzialmente una perdita supervisionata che ottiene una ricompensa implicita tramite un modello di riferimento e quindi a un livello elevato il DPOTrainer richiede il modello di base che desideriamo ottimizzare e un modello di riferimento:

dpo_trainer = DPOTrainer(
    model,                 # modello di base dalla pipeline SFT
    model_ref,             # tipicamente una copia del modello di base SFT allenato
    beta=0.1,              # iperparametro di temperatura di DPO
    train_dataset=dataset, # dataset preparato in precedenza
    tokenizer=tokenizer,   # tokenizer
    args=training_args,    # argomenti di allenamento come dimensione del batch, lr, ecc.
)

dove l’iperparametro beta è il parametro di temperatura per la perdita DPO, tipicamente nell’intervallo 0.1 a 0.5. Questo controlla quanto prestiamo attenzione al modello di riferimento nel senso che man mano che beta diventa più piccolo, ignoriamo sempre di più il modello di riferimento. Una volta che il nostro trainer è inizializzato, possiamo allenarlo sul dataset con gli training_args specificati semplicemente chiamando:

dpo_trainer.train()

Sperimenta con Llama v2

Il vantaggio di implementare il trainer DPO in TRL è che si può sfruttare tutti i vantaggi aggiuntivi dell’allenamento di grandi LLMs che arrivano con TRL e le sue librerie dipendenti come Peft e Accelerate. Con queste librerie siamo persino in grado di allenare un modello Llama v2 utilizzando la tecnica QLoRA fornita dalla libreria bitsandbytes.

Supervised Fine Tuning

Il processo come introdotto sopra coinvolge la fase di fine-tuning supervisionato utilizzando QLoRA sul modello Llama v2 da 7B di dati tramite il SFTTrainer di TRL:

# carica il modello di base con quantizzazione a 4 bit
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    script_args.model_name,        # "meta-llama/Llama-2-7b-hf"
    quantization_config=bnb_config,
    device_map={"": 0},
    trust_remote_code=True,
    use_auth_token=True,
)
base_model.config.use_cache = False

# aggiungi strati LoRA sopra il modello di base quantizzato
peft_config = LoraConfig(
    r=script_args.lora_r,
    lora_alpha=script_args.lora_alpha,
    lora_dropout=script_args.lora_dropout,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
...
trainer = SFTTrainer(
    model=base_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    packing=True,
    max_seq_length=None,
    tokenizer=tokenizer,
    args=training_args,         # argomenti di HF Trainer
)
trainer.train()

Allenamento DPO

Una volta che SFT è terminato, possiamo salvare il modello risultante e passare all’allenamento DPO. Come di consueto, utilizzeremo il modello salvato dal passaggio precedente di SFT sia come modello di base che come modello di riferimento di DPO. Quindi possiamo utilizzarli per allenare il modello con l’obiettivo DPO sui dati di preferenza di stack-exchange mostrati sopra. Poiché i modelli sono stati allenati tramite adattatori LoRa, carichiamo i modelli tramite gli assistenti AutoPeftModelForCausalLM di Peft:

model = AutoPeftModelForCausalLM.from_pretrained(
    script_args.model_name_or_path, # posizione del modello SFT salvato
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
    is_trainable=True,
)
model_ref = AutoPeftModelForCausalLM.from_pretrained(
    script_args.model_name_or_path,  # stesso modello del principale
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
)
...
dpo_trainer = DPOTrainer(
    model,
    model_ref,
    args=training_args,
    beta=script_args.beta,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
)
dpo_trainer.train()
dpo_trainer.save_model()

Come si può vedere, carichiamo il modello nella configurazione a 4 bit e poi lo addestriamo tramite il metodo QLora utilizzando gli argomenti peft_config. Il trainer valuterà anche i progressi durante l’addestramento rispetto all’insieme di dati di valutazione e fornirà diverse metriche chiave come la ricompensa implicita, che può essere registrata e visualizzata tramite WandB, ad esempio. Possiamo quindi caricare il modello addestrato finale su HuggingFace Hub.

Conclusioni

Il codice sorgente completo degli script di addestramento per SFT e DPO è disponibile nella seguente directory degli esempi/stack_llama_2 e il modello addestrato con gli adattatori uniti può essere trovato su HF Hub qui.

I log di WandB per l’esecuzione dell’addestramento DPO possono essere trovati qui, dove durante l’addestramento e la valutazione il DPOTrainer registra le seguenti metriche di ricompensa:

  • rewards/chosen: la differenza media tra le log probabilità del modello di politica e del modello di riferimento per le risposte scelte, scalata da beta
  • rewards/rejected: la differenza media tra le log probabilità del modello di politica e del modello di riferimento per le risposte respinte, scalata da beta
  • rewards/accuracies: media di quanto spesso le ricompense scelte sono superiori alle ricompense respinte corrispondenti
  • rewards/margins: la differenza media tra le ricompense scelte e le ricompense respinte corrispondenti.

In modo intuitivo, durante l’addestramento vogliamo che i margini aumentino e le accuratezze si avvicinino a 1.0, o in altre parole, che la ricompensa scelta sia più alta della ricompensa respinta (o il margine sia maggiore di zero). Queste metriche possono quindi essere calcolate su un certo insieme di dati di valutazione.

Speriamo che con il rilascio del codice si abbassi la barriera di accesso per voi lettori per provare questo metodo di allineamento di grandi modelli di linguaggio sui vostri stessi insiemi di dati e non vediamo l’ora di vedere cosa costruirete! E se volete provare il modello voi stessi, potete farlo qui: trl-lib/stack-llama.