Ottimizza ViT per la classificazione delle immagini con 🤗 Transformers

Ottimizza ViT per classificazione immagini con 🤗 Transformers.

Come i modelli basati su trasformatori hanno rivoluzionato l’NLP, stiamo assistendo ora a una vera e propria esplosione di articoli che li applicano a tutti i tipi di domini diversi. Uno dei più rivoluzionari di questi è stato il Vision Transformer (ViT), presentato in giugno 2021 da un team di ricercatori di Google Brain.

Questo articolo ha esplorato come è possibile tokenizzare le immagini, proprio come si farebbe con le frasi, in modo che possano essere passate ai modelli di trasformatori per l’addestramento. È un concetto abbastanza semplice, in realtà…

  1. Dividi un’immagine in una griglia di sotto-immagini
  2. Incorpora ogni sotto-immagine con una proiezione lineare
  3. Ogni sotto-immagine incorporata diventa un token e la sequenza risultante di sotto-immagini incorporate è la sequenza che si passa al modello.

Risulta che una volta che hai fatto quanto sopra, puoi pre-addestrare e affinare i trasformatori proprio come sei abituato a fare con i compiti di NLP. Piuttosto interessante 😎.


In questo post del blog, vedremo come sfruttare 🤗 datasets per scaricare e elaborare set di dati per la classificazione delle immagini e poi utilizzarli per affinare un ViT pre-addestrato con 🤗 transformers.

Per iniziare, installiamo entrambi i pacchetti.

pip install datasets transformers

Carica un dataset

Iniziamo caricando un piccolo dataset di classificazione delle immagini e dando un’occhiata alla sua struttura.

Utilizzeremo il dataset beans, che è una collezione di immagini di foglie di fagioli sane e malate. 🍃

from datasets import load_dataset

ds = load_dataset('beans')
ds

Diamo un’occhiata al 400° esempio dello split 'train' del dataset beans. Noterai che ogni esempio del dataset ha 3 caratteristiche:

  1. image: Un’immagine PIL
  2. image_file_path: Il percorso str del file immagine che è stato caricato come image
  3. labels: Una caratteristica datasets.ClassLabel, che è una rappresentazione intera dell’etichetta. (Successivamente vedrai come ottenere i nomi di classe come stringhe, non preoccuparti!)
ex = ds['train'][400]
ex

{
  'image': <PIL.JpegImagePlugin ...>,
  'image_file_path': '/root/.cache/.../bean_rust_train.4.jpg',
  'labels': 1
}

Diamo un’occhiata all’immagine 👀

image = ex['image']
image

È sicuramente una foglia! Ma di che tipo? 😅

Dato che la caratteristica 'labels' di questo dataset è una datasets.features.ClassLabel, possiamo utilizzarla per cercare il nome corrispondente all’ID dell’etichetta di questo esempio.

Prima di tutto, accediamo alla definizione della caratteristica per 'labels'.

labels = ds['train'].features['labels']
labels

ClassLabel(num_classes=3, names=['angular_leaf_spot', 'bean_rust', 'healthy'], names_file=None, id=None)

Ora, stampiamo l’etichetta di classe per il nostro esempio. Puoi farlo utilizzando la funzione int2str di ClassLabel, che, come suggerisce il nome, consente di passare la rappresentazione intera della classe per cercare l’etichetta come stringa.

labels.int2str(ex['labels'])

'bean_rust'

Risulta che la foglia mostrata sopra è infetta da Ruggine del fagiolo, una grave malattia delle piante di fagiolo. 😢

Scriviamo una funzione che visualizzerà una griglia di esempi per ogni classe per avere una migliore idea di ciò con cui stai lavorando.

import random
from PIL import ImageDraw, ImageFont, Image

def show_examples(ds, seed: int = 1234, examples_per_class: int = 3, size=(350, 350)):

    w, h = size
    labels = ds['train'].features['labels'].names
    grid = Image.new('RGB', size=(examples_per_class * w, len(labels) * h))
    draw = ImageDraw.Draw(grid)
    font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationMono-Bold.ttf", 24)

    for label_id, label in enumerate(labels):

        # Filtra il dataset per una singola etichetta, mescola e prendi alcuni esempi
        ds_slice = ds['train'].filter(lambda ex: ex['labels'] == label_id).shuffle(seed).select(range(examples_per_class))

        # Disegna gli esempi di questa etichetta lungo una riga
        for i, example in enumerate(ds_slice):
            image = example['image']
            idx = examples_per_class * label_id + i
            box = (idx % examples_per_class * w, idx // examples_per_class * h)
            grid.paste(image.resize(size), box=box)
            draw.text(box, label, (255, 255, 255), font=font)

    return grid

show_examples(ds, seed=random.randint(0, 1337), examples_per_class=3)

Una griglia di alcuni esempi da ogni classe nel dataset

Dalla mia osservazione:

  • Macchia fogliare angolare: ha macchie marroni irregolari
  • Ruggine del fagiolo: ha macchie marroni circolari circondate da un anello giallo chiaro
  • Sano: …sembra sano. 🤷‍♂️

Caricamento dell’estrattore di caratteristiche ViT

Ora sappiamo come sono fatte le nostre immagini e comprendiamo meglio il problema che stiamo cercando di risolvere. Vediamo come possiamo preparare queste immagini per il nostro modello!

Quando i modelli ViT vengono addestrati, vengono applicate specifiche trasformazioni alle immagini che vengono loro fornite. Utilizzare trasformazioni errate sull’immagine farà sì che il modello non comprenda ciò che sta vedendo! 🖼 ➡️ 🔢

Per assicurarci di applicare le corrette trasformazioni, utilizzeremo un ViTFeatureExtractor inizializzato con una configurazione che è stata salvata insieme al modello preaddestrato che intendiamo utilizzare. Nel nostro caso, utilizzeremo il modello google/vit-base-patch16-224-in21k, quindi carichiamo il suo estrattore di caratteristiche dall’Hugging Face Hub.

from transformers import ViTFeatureExtractor

model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)

Puoi visualizzare la configurazione dell’estrattore di caratteristiche stampandola.

ViTFeatureExtractor {
  "do_normalize": true,
  "do_resize": true,
  "feature_extractor_type": "ViTFeatureExtractor",
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "size": 224
}

Per elaborare un’immagine, basta passarla alla funzione di chiamata dell’estrattore di caratteristiche. Questo restituirà un dizionario contenente valori dei pixel, che rappresentano la rappresentazione numerica da passare al modello.

Di default, ottieni un array NumPy, ma se aggiungi l’argomento return_tensors='pt', otterrai invece tensori torch.

feature_extractor(image, return_tensors='pt')

Dovrebbe darti qualcosa del genere…

{
  'pixel_values': tensor([[[[ 0.2706,  0.3255,  0.3804,  ...]]]])
}

…dove la forma del tensore è (1, 3, 224, 224).

Elaborazione del dataset

Ora che sai come leggere le immagini e trasformarle in input, scriviamo una funzione che metterà insieme queste due cose per elaborare un singolo esempio del dataset.

def process_example(example):
    inputs = feature_extractor(example['image'], return_tensors='pt')
    inputs['labels'] = example['labels']
    return inputs

process_example(ds['train'][0])

{
  'pixel_values': tensor([[[[-0.6157, -0.6000, -0.6078,  ...]]]]),
  'labels': 0
}

Anche se potresti chiamare ds.map e applicarlo a ogni esempio in una volta sola, questo può essere molto lento, specialmente se utilizzi un dataset più grande. Invece, puoi applicare una trasformazione al dataset. Le trasformazioni vengono applicate solo agli esempi quando vengono indicizzati.

Tuttavia, dovrai aggiornare l’ultima funzione per accettare un batch di dati, poiché è ciò che si aspetta ds.with_transform.

ds = load_dataset('beans')

def transform(example_batch):
    # Prendi una lista di immagini PIL e trasformale in valori dei pixel
    inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')

    # Non dimenticare di includere le etichette!
    inputs['labels'] = example_batch['labels']
    return inputs

Puoi applicare direttamente questo al dataset usando ds.with_transform(transform).

prepared_ds = ds.with_transform(transform)

Ora, ogni volta che ottieni un esempio dal dataset, la trasformazione verrà applicata in tempo reale (sia su campioni che su fette, come mostrato di seguito)

prepared_ds['train'][0:2]

Questa volta, il tensore risultante pixel_values avrà una forma (2, 3, 224, 224).

{
  'pixel_values': tensor([[[[-0.6157, -0.6000, -0.6078,  ..., ]]]]),
  'labels': [0, 0]
}

I dati sono stati elaborati e sei pronto per iniziare a configurare la pipeline di addestramento. Questo post utilizza il Trainer di 🤗, ma ciò richiede alcune operazioni preliminari:

  • Definire una funzione di raccolta dei dati (collate function).

  • Definire una metrica di valutazione. Durante l’addestramento, il modello dovrebbe essere valutato in base alla sua accuratezza nelle previsioni. È necessario definire una funzione compute_metrics di conseguenza.

  • Caricare un checkpoint preaddestrato. È necessario caricare un checkpoint preaddestrato e configurarlo correttamente per l’addestramento.

  • Definire la configurazione di addestramento.

Dopo il fine-tuning del modello, lo valuterai correttamente sui dati di valutazione e verificherai che abbia imparato a classificare correttamente le immagini.

Definire il nostro data collator

I batch arrivano come liste di dizionari, quindi è possibile scomporli e unirli in tensori di batch.

Poiché la funzione collate_fn restituirà un dizionario di batch, è possibile scomporre gli input al modello in seguito utilizzando **unpack. ✨

import torch

def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }

Definire una metrica di valutazione

La metrica di accuratezza di datasets può essere facilmente utilizzata per confrontare le previsioni con le etichette. Di seguito, puoi vedere come utilizzarla all’interno di una funzione compute_metrics che sarà utilizzata dal Trainer.

import numpy as np
from datasets import load_metric

metric = load_metric("accuracy")
def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

Carichiamo il modello preaddestrato. Aggiungeremo num_labels durante l’inizializzazione in modo che il modello crei una testa di classificazione con il numero corretto di unità. Includeremo anche i mapping id2label e label2id per avere etichette leggibili dall’utente nel widget Hub (se si sceglie di push_to_hub).

from transformers import ViTForImageClassification

labels = ds['train'].features['labels'].names

model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)

Quasi pronto per l’addestramento! L’ultima cosa necessaria prima di ciò è configurare la configurazione di addestramento definendo TrainingArguments.

La maggior parte di queste sono abbastanza autoesplicative, ma una che è piuttosto importante qui è remove_unused_columns=False. Questa opzione eliminerà qualsiasi caratteristica non utilizzata dalla funzione di chiamata del modello. Per impostazione predefinita è True perché di solito è ideale eliminare le colonne di caratteristiche non utilizzate, semplificando lo scomporre gli input nella funzione di chiamata del modello. Ma, nel nostro caso, abbiamo bisogno delle caratteristiche non utilizzate (‘image’ in particolare) per creare ‘pixel_values’.

Quello che sto cercando di dire è che avrai problemi se dimentichi di impostare remove_unused_columns=False.

from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir="./vit-base-beans",
  per_device_train_batch_size=16,
  evaluation_strategy="steps",
  num_train_epochs=4,
  fp16=True,
  save_steps=100,
  eval_steps=100,
  logging_steps=10,
  learning_rate=2e-4,
  save_total_limit=2,
  remove_unused_columns=False,
  push_to_hub=False,
  report_to='tensorboard',
  load_best_model_at_end=True,
)

Ora, tutte le istanze possono essere passate al Trainer e siamo pronti per iniziare l’addestramento!

from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    tokenizer=feature_extractor,
)

Train 🚀

train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

Evaluate 📊

metrics = trainer.evaluate(prepared_ds['validation'])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

Ecco i risultati della mia valutazione – Fagioli fantastici! Scusate, dovevo dirlo.

***** metriche di valutazione *****
  epoca                   =        4.0
  accuracy di valutazione =      0.985
  perdita di valutazione  =     0.0637
  tempo di valutazione    = 0:00:02.13
  campioni al secondo di valutazione =     62.356
  passi al secondo di valutazione   =       7.97

Infine, se vuoi, puoi caricare il tuo modello sul hub. Qui, lo caricheremo se hai specificato push_to_hub=True nella configurazione dell’addestramento. Nota che per caricare sul hub, dovrai avere git-lfs installato e essere loggato nel tuo account Hugging Face (cosa che puoi fare tramite huggingface-cli login).

kwargs = {
    "finetuned_from": model.config._name_or_path,
    "tasks": "image-classification",
    "dataset": 'beans',
    "tags": ['image-classification'],
}

if training_args.push_to_hub:
    trainer.push_to_hub('🍻 cheers', **kwargs)
else:
    trainer.create_model_card(**kwargs)

Il modello risultante è stato condiviso su nateraw/vit-base-beans. Presumo che non hai foto di foglie di fagiolo in giro, quindi ne ho aggiunte alcune esempi per darti la possibilità di provarlo! 🚀