Dynamically Rewired Delayed Message Passing GNNs

DRDMP GNNs.

Delayed Message Passing

Le reti neurali a grafo basate sul message-passing (MPNNs) tendono a soffrire del fenomeno di “over-squashing”, causando una riduzione delle prestazioni per i compiti che dipendono da interazioni a lungo raggio. Questo può essere in gran parte attribuito al fatto che il message-passing avviene solo localmente, sui vicini immediati di un nodo. Le tecniche tradizionali di riconfigurazione statica del grafo cercano tipicamente di contrastare questo effetto permettendo ai nodi distanti di comunicare istantaneamente (e nel caso estremo dei Transformers, rendendo tutti i nodi accessibili ad ogni livello). Tuttavia, ciò comporta un costo computazionale e viene a spese della rottura del bias induttivo fornito dalla struttura del grafo di input. In questo post, descriviamo due nuovi meccanismi per superare l’over-squashing mentre si compensano gli effetti collaterali delle tecniche di riconfigurazione statica: la riconfigurazione dinamica e il message-passing ritardato. Queste tecniche possono essere incorporate in qualsiasi MPNN e portare a prestazioni migliori rispetto ai Transformers sui compiti a lungo raggio.

Immagine: basata su Shutterstock.

Questo post è stato scritto in collaborazione con Francesco Di Giovanni e Ben Gutteridge e si basa sul paper di B. Gutteridge et al., DRew: Dynamically rewired message passing with delay (2023), ICML.

Le reti neurali a grafo basate sul message-passing classico (MPNNs) operano aggregando informazioni dai vicini a 1 hop di ogni nodo. Di conseguenza, i compiti di apprendimento che richiedono interazioni a lungo raggio (cioè esiste un nodo v la cui rappresentazione deve tenere conto delle informazioni contenute in un nodo u a distanza di cammino più breve (geodesica) d(u, v) = r>1) richiedono MPNN profonde con più strati di message-passing. Se la struttura del grafo è tale che il campo recettivo si espande esponenzialmente veloce con la distanza hop [1], potrebbe essere necessario “spremere” troppi messaggi in un vettore di feature del nodo fisso – un fenomeno noto come over-squashing [2].

Over-squashing

Nelle nostre precedenti opere [3-4], abbiamo formalizzato l’over-squashing come la mancanza di sensibilità dell’output MPNN in un nodo u all’input in un nodo a distanza r. Ciò può essere quantificato da un limite sulla derivata parziale (Jacobiano) della forma

|∂xᵤ⁽ʳ⁾/∂xᵥ⁽⁰⁾|<c(Aʳ)ᵤᵥ.

Qui c è una costante dipendente dall’architettura MPNN (ad esempio, regolarità Lipschitz della funzione di attivazione, profondità, ecc.) e A è la matrice di adiacenza normalizzata del grafo. L’over-squashing si verifica quando le voci di Aʳ decadono esponenzialmente veloci con la distanza r. Infatti, è ora noto che l’over-squashing è più generalmente un fenomeno che può essere correlato alla struttura locale del grafo (come la curvatura negativa [3]), o alla sua struttura globale al di là della distanza di cammino più breve (ad esempio, il tempo di percorrenza o la resistenza effettiva [4, 5]).

I poteri A ʳ nella precedente espressione riflettono il fatto che la comunicazione tra i nodi u e v a distanza r in un MPNN è una sequenza di interazioni tra nodi adiacenti che comprendono diversi percorsi che collegano u e v. Di conseguenza, i nodi u e v scambiano informazioni solo dalla r-esima layer in poi, e con una latenza pari alla loro distanza r. L’over-squashing è causato dal fatto che queste informazioni vengono “diluite” attraverso il passaggio ripetuto di messaggi su nodi intermedi lungo questi percorsi.

Ricollegamento del grafo

La questione dell’over-squashing può essere affrontata decouplando parzialmente la struttura del grafico di input da quella utilizzata come supporto per il calcolo dei messaggi, una procedura nota come ricollegamento del grafo [6]. Tipicamente, il ricollegamento viene eseguito come passaggio di pre-elaborazione in cui il grafo di input G viene sostituito con un altro grafo G’ che è “più amichevole” per il passaggio di messaggi, secondo una qualche misura di connettività spaziale o spettrale.

Il modo più semplice per raggiungere questo scopo consiste nel collegare tutti i nodi entro una certa distanza, consentendo loro di scambiare informazioni direttamente. Questa è l’idea alla base del sistema di passaggio di messaggi multi-hop [7]. I Graph Transformer [8] portano questo concetto all’estremo, collegando tutte le coppie di nodi attraverso un bordo ponderato da attenzione.

In questo modo, l’informazione non viene più “mescolata” con quella degli altri nodi lungo il percorso e l’over-squashing può essere evitato. Tuttavia, un tale ricollegamento rende il grafo molto più denso dalla prima layer, aumentando l’impronta computazionale e compromettendo in parte il bias induttivo garantito dal grafo di input, poiché i nodi locali e globali interagiscono in modo identico e istantaneo ad ogni layer.

Ricollegamento dinamico del grafo

Osservando il nostro precedente esempio di due nodi u e v a distanza r > 1, in un classico MPNN, bisogna aspettare r layer prima che u e v possano interagire, e questa interazione non è mai diretta. Sosteniamo invece che una volta raggiunta la layer r, i due nodi hanno ora aspettato “abbastanza a lungo” e possono quindi essere autorizzati a interagire direttamente (attraverso un bordo extra inserito, senza passare attraverso i vicini intermedi).

Di conseguenza, alla prima layer propaghiamo messaggi solo sui bordi del grafo di input (come nei classici MPNN), ma ad ogni layer successiva il campo recettivo del nodo u si espande di un hop [9]. Ciò consente ai nodi distanti di scambiare informazioni senza passaggi intermedi, preservando il bias induttivo garantito dalla topologia del grafo di input: il grafo viene gradualmente densificato nelle layer più profonde in base alla distanza.

Chiamiamo questo meccanismo ricollegamento dinamico del grafo, o DRew per brevità [10]. Le DRew-MPNN possono essere considerate il “punto di mezzo” tra i classici MPNN che agiscono localmente sul grafo di input e i Graph Transformer che considerano tutte le interazioni a coppie contemporaneamente.

Passaggio di messaggi ritardato

Nel classico MPNN, due nodi u e v a distanza r interagiscono sempre con un ritardo costante di r layer, il tempo minimo necessario per raggiungere un nodo dall’altro. Di conseguenza, il nodo v “vede” lo stato del nodo u (mescolato con le caratteristiche degli altri nodi) da r layer fa. Nelle DRew-MPNN invece, quando due nodi interagiscono, lo fanno istantaneamente, attraverso un bordo inserito, utilizzando il loro stato attuale.

Il passaggio di messaggi ritardato è un compromesso tra questi due casi estremi: aggiungiamo un ritardo globale (un iperparametro 𝝼) per i messaggi inviati tra i nodi.

Per semplicità, consideriamo qui due casi semplici: o nessun ritardo (come in DRew), o il caso del ritardo massimo, in cui due nodi u e v a distanza r interagiscono direttamente dalla layer r in poi, ma con un ritardo costante di r (come nei classici MPNN): alla layer r, il nodo u può scambiare informazioni con lo stato del nodo v come se fosse r layer prima [11].

Il ritardo controlla la velocità con cui le informazioni si propagano sul grafo. Nessun ritardo significa che i messaggi viaggiano più velocemente, con i nodi lontani che interagiscono istantaneamente una volta che viene aggiunto un bordo; viceversa, maggiore è il ritardo, più lenta è la propagazione delle informazioni, con i nodi lontani che accedono agli stati passati quando viene aggiunto un bordo.

A comparison of DRew and its delayed variant 𝝼DRew. On the left, nodes at distance r exchange information through an additional edge from layer r onwards, instantaneously. On the right, we show the case of maximal delay (in our paper corresponding to the case 𝝼 = 1), where the delay between two nodes coincides with their distance; the newly added edge between nodes at distance (layer) r looks “in the past” to access the state of a node as it was r layers ago.

Il framework 𝝼DRew

Chiamiamo architettura che combina il ricollegamento dinamico con il passaggio di messaggi ritardato 𝝼DRew (pronunciato “Andrew”).

Un modo per vedere 𝝼DRew è come un’architettura con skip-connection sparse, che consente ai messaggi di viaggiare non solo “orizontalmente” (tra i nodi del grafo all’interno della stessa layer, come nei classici MPNN) ma anche “verticalmente” (attraverso diverse layer). L’idea di affidarsi ai bordi verticali nei GNN non è nuova, e infatti si può pensare alle connessioni residue come a collegamenti verticali che connettono ciascun nodo allo stesso nodo alla layer precedente.

Il meccanismo di ritardo estende questo approccio creando bordi verticali che collegano un nodo u e un nodo v diverso in qualche layer precedente a seconda della distanza grafica tra u e v. In questo modo, possiamo sfruttare i vantaggi intrinseci delle skip-connection per le reti neurali profonde, condizionandole alle informazioni geometriche extra che abbiamo a disposizione sotto forma di distanza grafica.

𝝼DRew allevia l’over-squashing poiché i nodi lontani hanno ora accesso a percorsi multipli (più brevi) per scambiare informazioni, evitando la “diluizione delle informazioni” del passaggio di messaggi locali ripetuti. Diversamente dal ricollegamento statico, 𝝼DRew raggiunge questo effetto rallentando la densificazione del grafo e rendendolo dipendente dalla layer, riducendo così l’impronta di memoria.

𝝼DRew è adatto per esplorare il grafo a diverse velocità, gestire interazioni a lungo raggio e migliorare in generale la potenza di GNN molto profonde. Poiché 𝝼DRew determina dove e quando vengono scambiati i messaggi, ma non come, può essere visto come una meta-architettura che può aumentare i MPNN esistenti.

Risultati sperimentali

Nel nostro articolo [10], forniamo un’ampia comparazione di 𝝼DRew con i baselines dei classici MPNN, il ricollegamento statico e le architetture di tipo Transformer, utilizzando un budget di parametro fisso. Nel recente benchmark a lungo raggio (LRGB) introdotto da Vijay Dwivedi e coautori [11], 𝝼DRew batte nella maggior parte dei casi tutti i modelli sopracitati.

Comparison of various classical MPNNs (GCN, GINE, etc.), static graph rewiring (MixHop-GCN, DIGL), and graph Transformer-type architectures (Transformer, SAN, GraphGPS, including positional Laplacian encoding) with 𝝼DRew-MPNN variants on four Long-Range Graph Benchmark (LRGB) tasks. Green, orange, and purple represent first-, second-, and third-best models.

Uno studio di ablation su uno dei task LRGB rivela un’altra importante contribuzione del nostro framework: la capacità di regolare 𝝼 per adattarlo al task. Notiamo che maggiore è il ritardo utilizzato (valore più basso di 𝝼), migliori sono le performance per un grande numero di livelli L, mentre utilizzando meno ritardo (𝝼 alto) si assicura un riempimento più veloce del grafo computazionale e una maggiore densità di connessioni dopo meno livelli. Di conseguenza, nelle architetture superficiali (L piccolo), rimuovere completamente il ritardo (𝝼 = ∞) funziona meglio. Al contrario, nelle architetture profonde (L grande), più ritardo (𝝼 piccolo) “rallenta” la densificazione del grafo di passaggio dei messaggi, portando a migliori prestazioni.

Performance di 𝝼DRew-MPNNs con diverso numero di livelli L e diverso parametro di ritardo 𝝼. Mentre la riconnessione dinamica aiuta per i task a lungo raggio in tutti i regimi, il ritardo migliora significativamente le prestazioni sui modelli più profondi. Il nostro framework può anche essere controllato per il budget di calcolo/memoria a seconda dell'applicazione, ad esempio in situazioni in cui i Transformers sono computazionalmente intrattabili.

Le architetture di tipo MPNN tradizionale differiscono nel modo in cui vengono scambiati i messaggi [12]. Le tecniche di riconnessione del grafo aggiungono un livello di controllo su dove vengono inviati nel grafo. Il nostro nuovo approccio di riconnessione dinamica del grafo con passaggio di messaggi ritardato consente di controllare ulteriormente quando vengono scambiati i messaggi.

Questo approccio sembra essere molto potente e il nostro lavoro è solo un primo tentativo di sfruttare l’idea di accedere agli stati passati in una rete neurale grafica a seconda delle “proprietà geometriche” dei dati sottostanti. Speriamo che questo nuovo paradigma possa portare a framework più teoricamente fondati e sfidare l’idea che gli MPNN non siano in grado di risolvere i task a lungo raggio a meno che non siano integrati con strati di attenzione quadratici.

[1] Questo è tipico ad esempio nei grafi “a mondo piccolo” come le reti sociali.

[2] U. Alon and E. Yahav, Sull’aspetto bottleneck delle reti neurali grafiche e le sue implicazioni pratiche (2021), ICLR.

[3] J. Topping et al., Comprensione dell’oversquashing e dei bottleneck sui grafi attraverso la curvatura (2022), ICLR.

[4] Il tempo di percorrenza è il tempo atteso per un random walk per andare dal nodo v al nodo u e tornare. Vedere F. Di Giovanni et al., Sull’oversquashing nelle reti neurali di passaggio dei messaggi: l’impatto di larghezza, profondità e topologia (2023), ICML.

[5] Vedere Teorema 4.3 in F. Di Giovanni et al., Come l’oversquashing influenza la potenza delle GNN? (2023), arXiv:2306.03589.

[6] La riconnessione del grafo è una tecnica in qualche modo controversa nella comunità delle GNN poiché alcuni credono che il grafo di input sia sacro e non debba essere toccato. De facto, la maggior parte delle GNN moderne utilizza una forma di riconnessione del grafo, sia in modo esplicito (come passaggio di pre-elaborazione) o implicito (ad esempio, campionando i vicini o utilizzando nodi virtuali).

[7] R. Abboud, R. Dimitrov, e I. Ceylan, Reti del cammino più breve per la predizione delle proprietà del grafo (2022), arXiv:2206.01003.

[8] Vedi ad esempio V. P. Dwivedi e X. Bresson, Una generalizzazione delle reti Transformer ai grafi (2021), arXiv:2012.09699 e C. Ying et al., I Transformer hanno davvero prestazioni scadenti per la rappresentazione del grafo? (2021), NeurIPS.

[9] La riconnessione dinamica comporta la seguente formula di passaggio dei messaggi:

m ᵤ ⁽ ˡ ᵏ ⁾=AGG({ x ᵥ ⁽ ˡ ⁾ : v∈𝒩ₖ ( u )}) con 1 ≤ k ≤ l +1

x ᵤ ⁽ ˡ ⁺¹⁾=UP( x ᵤ ⁽ ˡ ⁾, m ᵤ ⁽ ˡ ¹⁾,…, m ᵤ ⁽ ˡ ˡ ⁺¹⁾)

dove AGG è un operatore di aggregazione invariante alla permutazione, 𝒩ₖ ( u ) è il vicinato di k-hop del nodo u e UP è un’operazione di aggiornamento che riceve messaggi da ogni k-hop separatamente. Vedere l’equazione 5 in [10].

[10] B. Gutteridge et al., DRew: Dynamically rewired message passing with delay (2023), ICML .

[11] Il passaggio dei messaggi ritardato assume la forma

m ᵤ ⁽ ˡ ᵏ ⁾=AGG({ x ᵥ ⁽ ˡ ᐨ ˢ ⁽ ᵏ ⁾ ⁾ : v∈𝒩ₖ ( u )}) con 1 ≤ k ≤ l +1

x ᵤ ⁽ ˡ ⁺¹⁾=UP( x ᵤ ⁽ ˡ ⁾, m ᵤ ⁽ ˡ ¹⁾,…, m ᵤ ⁽ ˡ ˡ ⁺¹⁾)

dove s ( k )=max{0, k ﹣𝝼}, vedere l’equazione 6 in [10]. La scelta 𝝼=∞ corrisponde a nessun ritardo (come in DRew) e 𝝼 = 1 corrisponde a MPNN classico (due nodi u e v a distanza r interagiscono direttamente dalla layer r in avanti, ma con un ritardo costante di r .

[11] V. P. Dwivedi et al. , Long range graph benchmark (2022), arXiv :2206.08164.

[12] Nel nostro proto-libro M. M. Bronstein et al., Geometric Deep Learning: Grids, Groups, Graphs, Geodesics, and Gauges (2021), distinguamo tra tre “gusti” di MPNN: convoluzionale, attentional e message passing generico.

Siamo grati a Federico Barbero , Fabrizio Frasca e Emanuele Rossi per la revisione di questo post e per i commenti illuminanti. Per ulteriori articoli su deep learning sui grafi, vedere gli altri post di Michael in Towards Data Science, iscriversi ai suoi post e al suo canale YouTube , ottenere la membership di Nisoo o seguirlo su Twitter .