Segmentazione delle immagini senza etichette con CLIPSeg

Segmentazione immagini senza etichette con CLIPSeg.

Questa guida mostra come puoi utilizzare CLIPSeg, un modello di segmentazione delle immagini a zero-shot, utilizzando 🤗 transformers. CLIPSeg crea maschere di segmentazione approssimative che possono essere utilizzate per la percezione dei robot, il riempimento delle immagini e molti altri compiti. Se hai bisogno di maschere di segmentazione più precise, ti mostreremo come puoi affinare i risultati di CLIPSeg su Segments.ai.

La segmentazione delle immagini è un compito ben noto nel campo della visione artificiale. Consente a un computer di non solo sapere cosa c’è in un’immagine (classificazione), dove si trovano gli oggetti nell’immagine (rilevamento), ma anche quali sono i contorni di quegli oggetti. Conoscere i contorni degli oggetti è essenziale in campi come la robotica e la guida autonoma. Ad esempio, un robot deve conoscere la forma di un oggetto per prenderlo correttamente. La segmentazione può anche essere combinata con l’inpainting delle immagini per consentire agli utenti di descrivere quale parte dell’immagine desiderano sostituire.

Una limitazione della maggior parte dei modelli di segmentazione delle immagini è che funzionano solo con un elenco fisso di categorie. Ad esempio, non puoi semplicemente utilizzare un modello di segmentazione addestrato su arance per segmentare mele. Per insegnare al modello di segmentazione una categoria aggiuntiva, devi etichettare i dati della nuova categoria e addestrare un nuovo modello, il che può essere costoso e richiedere molto tempo. Ma cosa succederebbe se ci fosse un modello che potesse già segmentare quasi ogni tipo di oggetto, senza alcun ulteriore addestramento? Ecco esattamente ciò che CLIPSeg, un modello di segmentazione a zero-shot, riesce a fare.

Attualmente, CLIPSeg ha ancora le sue limitazioni. Ad esempio, il modello utilizza immagini di 352 x 352 pixel, quindi l’output è di bassa risoluzione. Ciò significa che non possiamo aspettarci risultati perfetti al livello dei pixel quando lavoriamo con immagini provenienti da telecamere moderne. Se vogliamo segmentazioni più precise, possiamo raffinare un modello di segmentazione all’avanguardia, come mostrato nel nostro precedente post sul blog. In quel caso, possiamo comunque utilizzare CLIPSeg per generare alcune etichette approssimative, e poi perfezionarle in uno strumento di etichettatura come Segments.ai. Prima di descrivere come fare, diamo un’occhiata a come funziona CLIPSeg.

CLIP: il modello magico dietro CLIPSeg

CLIP, che sta per Contrastive Language- Image Pre-training, è un modello sviluppato da OpenAI nel 2021. Puoi fornire a CLIP un’immagine o un pezzo di testo e CLIP restituirà una rappresentazione astratta del tuo input. Questa rappresentazione astratta, chiamata anche embedding, è semplicemente un vettore (una lista di numeri). Puoi pensare a questo vettore come a un punto nello spazio ad alta dimensionalità. CLIP è addestrato in modo che le rappresentazioni di immagini e testi simili siano anche simili. Ciò significa che se inseriamo un’immagine e una descrizione testuale che si adatta a quell’immagine, le rappresentazioni dell’immagine e del testo saranno simili (ossia i punti ad alta dimensionalità saranno vicini tra loro).

All’inizio, potrebbe non sembrare molto utile, ma è in realtà molto potente. Ad esempio, diamo un’occhiata veloce a come CLIP può essere utilizzato per classificare le immagini senza aver mai ricevuto un addestramento su quel compito. Per classificare un’immagine, inseriamo l’immagine e le diverse categorie tra cui vogliamo scegliere a CLIP (ad esempio, inseriamo un’immagine e le parole “mela”, “arancia”, …). CLIP quindi ci restituisce un embedding dell’immagine e di ogni categoria. Ora, dobbiamo semplicemente verificare quale embedding di categoria è più vicino all’embedding dell’immagine, et voilà! Sembra magia, vero?

Esempio di classificazione delle immagini utilizzando CLIP (fonte).

Inoltre, CLIP non è utile solo per la classificazione, ma può anche essere utilizzato per la ricerca di immagini (puoi vedere come questo è simile alla classificazione?), modelli di testo-immagine (DALL-E 2 è alimentato da CLIP), rilevamento degli oggetti (OWL-ViT) e, soprattutto per noi, la segmentazione delle immagini. Ora capisci perché CLIP è stato veramente una svolta nell’apprendimento automatico.

Il motivo per cui CLIP funziona così bene è che il modello è stato addestrato su un enorme dataset di immagini con didascalie testuali. Il dataset conteneva ben 400 milioni di coppie immagine-testo prese da Internet. Queste immagini contengono una vasta varietà di oggetti e concetti, e CLIP è bravo a creare una rappresentazione per ognuno di essi.

CLIPSeg: segmentazione delle immagini con CLIP

CLIPSeg è un modello che utilizza le rappresentazioni CLIP per creare maschere di segmentazione delle immagini. È stato pubblicato da Timo Lüddecke e Alexander Ecker. Hanno ottenuto la segmentazione delle immagini senza addestramento, addestrando un decodificatore basato su Transformer in cima al modello CLIP, che viene mantenuto congelato. Il decodificatore prende in input la rappresentazione CLIP di un’immagine e la rappresentazione CLIP dell’oggetto che si desidera segmentare. Utilizzando questi due input, il decodificatore CLIPSeg crea una maschera di segmentazione binaria. Per essere più precisi, il decodificatore non utilizza solo la rappresentazione CLIP finale dell’immagine che vogliamo segmentare, ma utilizza anche le uscite di alcuni strati di CLIP.

Fonte

Il decodificatore è addestrato sul dataset PhraseCut, che contiene oltre 340.000 frasi con maschere di segmentazione delle immagini corrispondenti. Gli autori hanno anche sperimentato diverse augmentations per espandere la dimensione del dataset. L’obiettivo qui non è solo quello di essere in grado di segmentare le categorie presenti nel dataset, ma anche di segmentare categorie non viste. Gli esperimenti mostrano infatti che il decodificatore può generalizzare a categorie non viste.

Una caratteristica interessante di CLIPSeg è che sia la query (l’immagine che vogliamo segmentare) che il prompt (la cosa che vogliamo segmentare nell’immagine) vengono inseriti come embedding CLIP. L’embedding CLIP per il prompt può provenire da un pezzo di testo (il nome della categoria), o da un’altra immagine. Ciò significa che è possibile segmentare le arance in una foto fornendo a CLIPSeg un’immagine di esempio di un’arancia.

Questa tecnica, chiamata “visual prompting”, è molto utile quando la cosa che si desidera segmentare è difficile da descrivere. Ad esempio, se si desidera segmentare un logo in una foto di una maglietta, non è facile descrivere la forma del logo, ma CLIPSeg consente di utilizzare semplicemente l’immagine del logo come prompt.

Il paper di CLIPSeg contiene alcuni suggerimenti per migliorare l’efficacia del visual prompting. Si è scoperto che ritagliare l’immagine della query (in modo che contenga solo l’oggetto che si desidera segmentare) aiuta molto. Sfocare e scurire lo sfondo dell’immagine della query aiuta anche un po’. Nella prossima sezione, mostreremo come è possibile provare il visual prompting utilizzando 🤗 transformers.

Utilizzo di CLIPSeg con Hugging Face Transformers

Utilizzando Hugging Face Transformers, è possibile scaricare ed eseguire facilmente un modello CLIPSeg pre-addestrato sulle proprie immagini. Iniziamo installando transformers.

!pip install -q transformers

Per scaricare il modello, basta istanziarlo.

from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

Adesso possiamo caricare un’immagine per provare la segmentazione. Sceglieremo una foto di una deliziosa colazione scattata da Calum Lewis.

from PIL import Image
import requests

url = "https://unsplash.com/photos/8Nc_oQsc2qQ/download?ixid=MnwxMjA3fDB8MXxhbGx8fHx8fHx8fHwxNjcxMjAwNzI0&force=true&w=640"
image = Image.open(requests.get(url, stream=True).raw)
image

Text prompting

Iniziamo definendo alcune categorie di testo che vogliamo segmentare.

prompts = ["posate", "pancakes", "mirtilli", "succo d'arancia"]

Ora che abbiamo i nostri input, possiamo elaborarli e inserirli nel modello.

import torch

inputs = processor(text=prompts, images=[image] * len(prompts), padding="max_length", return_tensors="pt")
# predict
with torch.no_grad():
  outputs = model(**inputs)
preds = outputs.logits.unsqueeze(1)

Infine, visualizziamo l’output.

import matplotlib.pyplot as plt

_, ax = plt.subplots(1, len(prompts) + 1, figsize=(3*(len(prompts) + 1), 4))
[a.axis('off') for a in ax.flatten()]
ax[0].imshow(image)
[ax[i+1].imshow(torch.sigmoid(preds[i][0])) for i in range(len(prompts))];
[ax[i+1].text(0, -15, prompt) for i, prompt in enumerate(prompts)];

Suggerimenti visivi

Come già accennato, possiamo anche utilizzare immagini come suggerimenti di input (cioè al posto dei nomi delle categorie). Questo può essere particolarmente utile se non è facile descrivere ciò che si vuole segmentare. Per questo esempio, useremo una foto di una tazza di caffè scattata da Daniel Hooper.

url = "https://unsplash.com/photos/Ki7sAc8gOGE/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8MTJ8fGNvZmZlJTIwdG8lMjBnb3xlbnwwfHx8fDE2NzExOTgzNDQ&force=true&w=640"
prompt = Image.open(requests.get(url, stream=True).raw)
prompt

Ora possiamo elaborare l’immagine di input e l’immagine di prompt e inserirle nel modello.

encoded_image = processor(images=[image], return_tensors="pt")
encoded_prompt = processor(images=[prompt], return_tensors="pt")
# predici
with torch.no_grad():
  outputs = model(**encoded_image, conditional_pixel_values=encoded_prompt.pixel_values)
preds = outputs.logits.unsqueeze(1)
preds = torch.transpose(preds, 0, 1)

Successivamente, possiamo visualizzare i risultati come prima.

_, ax = plt.subplots(1, 2, figsize=(6, 4))
[a.axis('off') for a in ax.flatten()]
ax[0].imshow(image)
ax[1].imshow(torch.sigmoid(preds[0]))

Proviamo un’ultima volta utilizzando i suggerimenti visivi descritti nel documento, cioè ritagliando l’immagine e scurendo lo sfondo.

url = "https://i.imgur.com/mRSORqz.jpg"
alternative_prompt = Image.open(requests.get(url, stream=True).raw)
alternative_prompt

encoded_alternative_prompt = processor(images=[alternative_prompt], return_tensors="pt")
# predici
with torch.no_grad():
  outputs = model(**encoded_image, conditional_pixel_values=encoded_alternative_prompt.pixel_values)
preds = outputs.logits.unsqueeze(1)
preds = torch.transpose(preds, 0, 1)

_, ax = plt.subplots(1, 2, figsize=(6, 4))
[a.axis('off') for a in ax.flatten()]
ax[0].imshow(image)
ax[1].imshow(torch.sigmoid(preds[0]))

In questo caso, il risultato è praticamente lo stesso. Probabilmente perché la tazza di caffè era già ben separata dallo sfondo nell’immagine originale.

Utilizzo di CLIPSeg per etichettare preimmagini su Segments.ai

Come puoi vedere, i risultati di CLIPSeg sono un po’ sfocati e di bassa risoluzione. Se vogliamo ottenere risultati migliori, è possibile perfezionare un modello di segmentazione all’avanguardia, come spiegato nel nostro precedente post sul blog. Per perfezionare il modello, avremo bisogno di dati etichettati. In questa sezione, ti mostreremo come puoi utilizzare CLIPSeg per creare alcune maschere di segmentazione approssimative e poi perfezionarle su Segments.ai, una piattaforma di etichettatura con strumenti intelligenti per la segmentazione delle immagini.

Prima di tutto, crea un account su https://segments.ai/join e installa il Segments Python SDK. Successivamente, puoi inizializzare il client Python di Segments.ai utilizzando una chiave API. Questa chiave può essere trovata nella pagina dell’account.

!pip install -q segments-ai

from segments import SegmentsClient
from getpass import getpass

api_key = getpass('Inserisci la tua chiave API: ')
segments_client = SegmentsClient(api_key)

Successivamente, carichiamo un’immagine da un dataset utilizzando il client Segments. Utilizzeremo il dataset di guida autonoma a2d2. Puoi anche creare il tuo dataset seguendo queste istruzioni.

samples = segments_client.get_samples("admin-tobias/clipseg")

# Utilizza l'ultima immagine come esempio
sample = samples[1]
image = Image.open(requests.get(sample.attributes.image.url, stream=True).raw)
image

Dobbiamo anche ottenere i nomi delle categorie dagli attributi del dataset.

dataset = segments_client.get_dataset("admin-tobias/clipseg")
category_names = [category.name for category in dataset.task_attributes.categories]

Ora possiamo utilizzare CLIPSeg sull’immagine come prima. Questa volta, ingrandiremo anche le uscite in modo che corrispondano alle dimensioni dell’immagine di input.

from torch import nn

inputs = processor(text=category_names, images=[image] * len(category_names), padding="max_length", return_tensors="pt")

# predici
con torch.no_grad():
  outputs = model(**inputs)

# ridimensiona gli output
preds = nn.functional.interpolate(
    outputs.logits.unsqueeze(1),
    size=(image.size[1], image.size[0]),
    mode="bilinear"
)

E possiamo visualizzare nuovamente i risultati.

len_cats = len(category_names)
_, ax = plt.subplots(1, len_cats + 1, figsize=(3*(len_cats + 1), 4))
[a.axis('off') for a in ax.flatten()]
ax[0].imshow(image)
[ax[i+1].imshow(torch.sigmoid(preds[i][0])) for i in range(len_cats)];
[ax[i+1].text(0, -15, category_name) for i, category_name in enumerate(category_names)];

Ora dobbiamo combinare le predizioni in un’unica immagine segmentata. Faremo semplicemente prendendo la categoria con il valore di sigmoid maggiore per ogni patch. Ci assicureremo anche che tutti i valori al di sotto di una certa soglia non vengano conteggiati.

soglia = 0.1

flat_preds = torch.sigmoid(preds.squeeze()).reshape((preds.shape[0], -1))

# Inizializza una maschera fittizia "non etichettata" con la soglia
flat_preds_with_treshold = torch.full((preds.shape[0] + 1, flat_preds.shape[-1]), soglia)
flat_preds_with_treshold[1:preds.shape[0]+1,:] = flat_preds

# Ottieni l'indice della maschera superiore per ogni pixel
inds = torch.topk(flat_preds_with_treshold, 1, dim=0).indices.reshape((preds.shape[-2], preds.shape[-1]))

Visualizziamo rapidamente il risultato.

plt.imshow(inds)

Infine, possiamo caricare la predizione su Segments.ai. Per farlo, convertiremo prima la bitmap in un file png, poi caricheremo questo file su Segments, e infine aggiungeremo l’etichetta al campione.

from segments.utils import bitmap2file
import numpy as np

inds_np = inds.numpy().astype(np.uint32)
unique_inds = np.unique(inds_np).tolist()
f = bitmap2file(inds_np, is_segmentation_bitmap=True)

asset = segments_client.upload_asset(f, "clipseg_prediction.png")

attributes = {
      'format_version': '0.1',
      'annotations': [{"id": i, "category_id": i} for i in unique_inds if i != 0],
      'segmentation_bitmap': { 'url': asset.url },
  }

segments_client.add_label(sample.uuid, 'ground-truth', attributes)

Se dai un’occhiata alla predizione caricata su Segments.ai, puoi vedere che non è perfetta. Tuttavia, puoi correggere manualmente gli errori più grandi, e poi puoi utilizzare il dataset corretto per addestrare un modello migliore rispetto a CLIPSeg.

Conclusioni

CLIPSeg è un modello di segmentazione zero-shot che funziona con prompt di testo e immagini. Il modello aggiunge un decoder a CLIP e può segmentare praticamente qualsiasi cosa. Tuttavia, le maschere di segmentazione prodotte sono ancora di bassa risoluzione per ora, quindi probabilmente vorrai comunque affinare un modello di segmentazione diverso se l’accuratezza è importante.

Nota che attualmente sono in corso ulteriori ricerche sulla segmentazione zero-shot, quindi è possibile che vengano aggiunti altri modelli in futuro. Un esempio è GroupViT, che è già disponibile in 🤗 Transformers. Per rimanere aggiornato sulle ultime novità nella ricerca sulla segmentazione, puoi seguirci su Twitter: @TobiasCornille , @NielsRogge e @huggingface .

Se sei interessato a imparare come affinare un modello di segmentazione all’avanguardia, dai un’occhiata al nostro precedente post sul blog: https://huggingface.co/blog/fine-tune-segformer .