
Why I'm lukewarm on graph neural networks


TL;DR: GNNs can provide wins over simpler embedding methods, but we’re at a point where other research directions matter more.

I’m only lukewarm on Graph Neural Networks (GNNs). There, I said it.

It might sound crazy: GNNs are one of the hottest fields in machine learning right now. There were at least four review papers in the last few months alone. I think some progress can come of this research, but we’re also focusing on some of the wrong places.

But first, let’s take a step back and go over the basics.

Models are about compression

We say graphs are a “non-euclidean” data type, but that’s not really true. An ordinary graph [footnote]Nodes and (un)directed edges with real-valued edge weights[/footnote] is just another way to think about a particular flavor of square matrix called the adjacency matrix.

It’s weird: we look at a run-of-the-mill matrix full of real numbers and decide to call it “non-euclidean”.

This is for practical reasons. Most graphs are fairly sparse, so the matrix is mostly zeros. At that point, where the non-zero numbers sit matters most, which makes the problem closer to (computationally hard) discrete math than to (easy) continuous, gradient-friendly math.
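As a concrete illustration of the matrix view, here is a minimal sketch, assuming NumPy, that turns a weighted edge list into a dense adjacency matrix and reports how much of it is zero. The helper name `adjacency_from_edges` and the toy edge list are made up for illustration. Even on six nodes most of the matrix is empty, and real graphs are far sparser.

```python
import numpy as np

def adjacency_from_edges(edges, n_nodes, directed=False):
    """Build a dense adjacency matrix from (source, destination, weight) triples."""
    A = np.zeros((n_nodes, n_nodes))
    for src, dst, w in edges:
        A[src, dst] = w
        if not directed:
            # An undirected edge appears on both sides of the diagonal.
            A[dst, src] = w
    return A

# A toy 6-node graph with real-valued edge weights.
edges = [(0, 1, 1.0), (1, 2, 0.5), (2, 0, 2.0), (3, 4, 1.5), (4, 5, 1.0)]
A = adjacency_from_edges(edges, n_nodes=6)
sparsity = 1.0 - np.count_nonzero(A) / A.size
print(A)
print(f"{sparsity:.0%} of entries are zero")
```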

If you had the full matrix, life would be easy

If we step out of the pesky realm of physics for a minute, and assume carrying the full adjacency matrix around isn’t a problem, we solve a bunch of problems.

First, network node embeddings [footnote]where you represent a node as a vector of numbers to bypass the annoyances of discrete math[/footnote] aren’t a thing anymore. A node is just a row in the matrix, so it’s already a vector of numbers.
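Under that full-matrix assumption, ordinary vector operations apply to nodes directly. A minimal sketch with NumPy on a toy graph, comparing nodes by the cosine similarity of their adjacency rows (`row_cosine` is a hypothetical helper name):

```python
import numpy as np

# Toy undirected graph: nodes 0 and 1 have the same neighbours (2 and 3).
A = np.array([
    [0, 0, 1, 1],
    [0, 0, 1, 1],
    [1, 1, 0, 0],
    [1, 1, 0, 0],
], dtype=float)

def row_cosine(A, i, j):
    """Cosine similarity between the adjacency rows of nodes i and j."""
    a, b = A[i], A[j]
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

print(row_cosine(A, 0, 1))  # identical neighbourhoods: about 1
print(row_cosine(A, 0, 2))  # no shared neighbours: 0
```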

Second, all network prediction problems are solved. A powerful enough and well-tuned model [footnote]Physics isn’t a thing, remember[/footnote] will simply extract all the information between the network and whichever target variable we’re attaching to nodes.

NLP is also just fancy matrix compression

Let’s take a tangent away from graphs to NLP. As we’ll see, most NLP we do can be thought of in terms of graphs, so it’s not a big digression.

First, note that Ye Olde word embedding models like Word2Vec and GloVe are just matrix factorization.

The GloVe algorithm works on a variation of the old bag-of-words matrix. It goes through the sentences and creates an (implicit) co-occurrence graph where nodes are words and edges are weighted by how often the words appear together in a sentence.
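A minimal sketch of building that co-occurrence graph as a weighted edge list, in plain Python. It uses the simple whole-sentence rule described above; GloVe itself counts co-occurrences inside a sliding context window and down-weights distant words, so treat this as an illustration of the idea, not GloVe’s preprocessing. The function name is hypothetical.

```python
from collections import Counter
from itertools import combinations

def cooccurrence_graph(sentences):
    """Count how often each unordered pair of distinct words shares a sentence.

    Returns a Counter mapping (word_a, word_b) -> count, i.e. a weighted edge list
    whose nodes are words.
    """
    edges = Counter()
    for sentence in sentences:
        words = sorted(set(sentence.lower().split()))
        for a, b in combinations(words, 2):
            edges[(a, b)] += 1
    return edges

sentences = [
    "the cat sat on the mat",
    "the dog sat on the log",
    "the cat chased the dog",
]
graph = cooccurrence_graph(sentences)
for (a, b), w in graph.most_common(3):
    print(a, b, w)
```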

GloVe then does matrix factorization [footnote]I’ll use the word “factorization” loosely for the rest of this post to refer to any method that compresses the number of dimensions in a matrix[/footnote] on the matrix representation of that co-occurrence graph. Word2Vec is essentially equivalent [footnote]see section 3.1 in the GloVe paper for confirmation[/footnote].
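GloVe’s actual objective is a weighted least-squares fit to log co-occurrence counts. As a simpler stand-in that shows the same idea of compressing a co-occurrence matrix into low-dimensional word vectors, here is a truncated SVD of a log-scaled count matrix with NumPy. The counts are made up, and this is an illustrative sketch, not GloVe itself.

```python
import numpy as np

words = ["cat", "dog", "mat", "log", "sat"]
# Symmetric word-word co-occurrence counts (toy numbers for illustration).
C = np.array([
    [0, 1, 3, 0, 2],
    [1, 0, 0, 3, 2],
    [3, 0, 0, 0, 1],
    [0, 3, 0, 0, 1],
    [2, 2, 1, 1, 0],
], dtype=float)

def embed_by_svd(C, dim):
    """Compress a co-occurrence matrix into `dim`-dimensional word vectors."""
    # Log-scale the counts (GloVe also works with log counts); log1p avoids log(0).
    M = np.log1p(C)
    U, S, _ = np.linalg.svd(M)  # M = U @ diag(S) @ Vt
    # Keep only the top `dim` singular directions.
    return U[:, :dim] * np.sqrt(S[:dim])

vectors = embed_by_svd(C, dim=2)
for word, vec in zip(words, vectors):
    print(word, np.round(vec, 3))
```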

You can read more on this in my post on embeddings and the one (with code) on word embeddings.

Even language models are just matrix compression

Language models are all the rage. They dominate most of the state of the art in NLP.

Let’s take BERT as our main example [footnote]the analysis holds for GPT-3 or any other modern language model[/footnote]. BERT predicts a masked word given the context of the rest of the sentence.

This grows the matrix we’re factoring from flat co-occurrences on pairs of words to co-occurrences conditional on the sentence’s context [footnote]Alternatively, we’re changing from a graph of word co-occurrence to a hypergraph with directed hyperedges denoting conditional co-occurrence[/footnote].

We’re growing the “ideal matrix” we’re factoring combinatorially [footnote]if there’s one thing to be learned from combinatorics, it’s that you make huge numbers really fast[/footnote]. As noted by Hahn & Futrell:

[…] human language—and language modelling—has infinite statistical complexity but that it can be approximated well at lower levels. This observation has two implications: 1) We can obtain good results with comparatively small models; and 2) there is a lot of potential for scaling up our models.
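To put rough numbers on that combinatorial growth, here is a deliberately crude count of entries in the “ideal matrix”: one row per possible context of the other words in a sentence, one column per predicted word. The vocabulary size of 30,000 is an assumption (roughly the size of BERT’s tokenizer vocabulary), and the count ignores which position is masked.

```python
VOCAB = 30_000  # assumed vocabulary size, roughly BERT-sized

def ideal_matrix_entries(vocab, sentence_len):
    """Entries in a matrix with one row per possible context of
    (sentence_len - 1) words and one column per predicted word.

    Deliberately crude: treats every word sequence as possible and
    ignores which position in the sentence is masked.
    """
    contexts = vocab ** (sentence_len - 1)
    return contexts * vocab

for n in (2, 3, 5, 10):
    label = "word pairs (Word2Vec/GloVe)" if n == 2 else f"{n}-word sentences"
    print(f"{label:>28}: {float(ideal_matrix_entries(VOCAB, n)):.1e} entries")
```

Going from word pairs to five-word contexts multiplies the size by a factor of 30,000³, about 2.7e13.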

Language models tackle such a large problem space that they probably approximate a compression of the entire language in the Kolmogorov Complexity sense. It’s also possible that huge language models just memorize a lot of it rather than compress the information, for what it’s worth.

Can we upsample any graph like language models do?

We’re already doing it.

Let’s call a first-order embedding of a graph a method that works by directly factoring the graph’s adjacency matrix or Laplacian matrix. If you embed a graph using Laplacian Eigenmaps or by taking the principal components of the Laplacian, that’s first-order. Similarly, GloVe is a first-order method on the graph of word co-occurrences. One of my favorite first-order methods for graphs is ProNE, which works as well as most methods while being two orders of magnitude faster.
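A minimal sketch of a first-order spectral embedding with NumPy. Laplacian Eigenmaps proper solves the generalized problem L y = λ D y; for brevity this sketch uses the plain eigenvectors of the unnormalized Laplacian L = D - A, skipping the constant one. On two cliques joined by a bridge, the first embedding dimension separates the cliques by sign.

```python
import numpy as np

def laplacian_embedding(A, dim):
    """First-order spectral embedding from the unnormalized Laplacian L = D - A.

    Eigenvectors with the smallest eigenvalues vary slowly across edges, so
    connected nodes land close together. The first eigenvector is constant
    for a connected graph and is skipped.
    """
    L = np.diag(A.sum(axis=1)) - A
    eigvals, eigvecs = np.linalg.eigh(L)  # eigenvalues in ascending order
    return eigvecs[:, 1:dim + 1]

# Two 4-node cliques (nodes 0-3 and 4-7) joined by a single bridge edge 3 -- 4.
A = np.zeros((8, 8))
for block in (range(0, 4), range(4, 8)):
    for i in block:
        for j in block:
            if i != j:
                A[i, j] = 1.0
A[3, 4] = A[4, 3] = 1.0

emb = laplacian_embedding(A, dim=2)
print(np.sign(emb[:, 0]))  # the two cliques get opposite signs
```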

A higher-order method embeds the original matrix plus connections of neighbours-of-neighbours (2nd degree) and deeper k-step connections. GraRep shows you can always generate higher-order representations from first-order methods by augmenting the graph matrix [footnote]You normalize the adjacency matrix into a random walk transition probability matrix. Then, you take successive dot products of the transition matrix with itself[/footnote].
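The augmentation in that footnote can be sketched in a few lines of NumPy: row-normalize the adjacency matrix into a transition matrix P, then take powers of it. GraRep then factorizes (a log-transformed version of) each power; this sketch stops at building the k-step matrices, and the function name is hypothetical.

```python
import numpy as np

def transition_powers(A, max_order):
    """Return [P, P^2, ..., P^max_order] for the random-walk matrix P = D^-1 A.

    Entry (i, j) of P^k is the probability that a k-step random walk from i
    ends at j, so higher powers encode neighbours-of-neighbours and beyond.
    """
    P = A / A.sum(axis=1, keepdims=True)  # row-normalize: each row sums to 1
    powers = [P]
    for _ in range(max_order - 1):
        powers.append(powers[-1] @ P)
    return powers

# Path graph 0 - 1 - 2 - 3.
A = np.array([
    [0, 1, 0, 0],
    [1, 0, 1, 0],
    [0, 1, 0, 1],
    [0, 0, 1, 0],
], dtype=float)

P1, P2, P3 = transition_powers(A, max_order=3)
# Node 2 is unreachable from node 0 in one step but reachable in two.
print(P1[0, 2], P2[0, 2])  # 0.0 0.5
```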

Higher-order methods are the “upsampling” we do on graphs. GNNs that sample large neighborhoods and random-walk based methods like node2vec are doing higher-order embeddings.
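Random-walk methods get their higher-order signal by turning the graph into a corpus of walks. A minimal sketch of DeepWalk-style uniform walks in plain Python; the adjacency-list format and function name are illustrative assumptions. The resulting walks would then be fed to a word2vec-style model as if they were sentences.

```python
import random

def random_walks(adj, walks_per_node, walk_length, seed=0):
    """Generate uniform random walks (DeepWalk-style) over an adjacency list.

    Each walk is a "sentence" of node ids; a word2vec model trained on these
    sentences picks up k-step co-occurrences between nodes.
    """
    rng = random.Random(seed)
    walks = []
    for _ in range(walks_per_node):
        for start in adj:
            walk = [start]
            while len(walk) < walk_length:
                neighbours = adj[walk[-1]]
                if not neighbours:  # dead end: stop this walk early
                    break
                walk.append(rng.choice(neighbours))
            walks.append(walk)
    return walks

adj = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
for walk in random_walks(adj, walks_per_node=1, walk_length=5):
    print(walk)
```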

Where are the performance gains?

Most GNN papers in the last 5 years present empirical numbers that are useless to practitioners deciding what to use.

As noted in the Open Graph Benchmark (OGB) paper, GNN papers do their empirical section on a handful of tiny graphs (Cora, CiteSeer, PubMed) with 2,000-20,000 nodes. These datasets can’t seriously differentiate between methods.

Recent efforts are directly fixing this, but the reasons why researchers focused on tiny, useless datasets for so long are worth discussing.

Performance matters by task

One fact that surprises a lot of people is that even though language models have the best performance in a lot of NLP tasks, if all you’re doing is cramming sentence embeddings into a downstream model, there isn’t much gained from language model embeddings over simple methods like summing the individual Word2Vec word embeddings [footnote]This makes sense, because the full context of the sentence is captured in the sentence co-occurrence matrix that is generating the Word2Vec embeddings[/footnote].
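The “simple method” really is just a sum. A minimal sketch with NumPy, where a toy dictionary of 3-dimensional vectors (made-up numbers) stands in for trained Word2Vec embeddings and unknown words are skipped:

```python
import numpy as np

def sentence_embedding(sentence, word_vectors, dim):
    """Sum the vectors of known words; unknown words are skipped."""
    total = np.zeros(dim)
    for word in sentence.lower().split():
        if word in word_vectors:
            total += word_vectors[word]
    return total

# Hypothetical 3-d vectors standing in for trained Word2Vec embeddings.
word_vectors = {
    "graphs": np.array([0.9, 0.1, 0.0]),
    "are": np.array([0.0, 0.2, 0.1]),
    "matrices": np.array([0.8, 0.0, 0.3]),
}
print(sentence_embedding("Graphs are matrices", word_vectors, dim=3))
```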

Similarly, I find that for many graphs, simple first-order methods perform just as well as higher-order embedding methods on graph clustering and node label prediction tasks. In fact, higher-order methods are massively computationally wasteful for these use cases [footnote]I’ll generalize this to the OGB dataset ASAP[/footnote].

Recommended first-order embedding methods are ProNE and my GGVec with order=1 [footnote]It’s an adaptation of GloVe with negative sampling for general graphs. I’ll get to writing the blog post/paper on it one day[/footnote].

Higher-order methods normally perform better on link prediction tasks [footnote]I’m not the only one to find this. In the BioNEV paper, they find: “A large GraRep order value for link prediction tasks (e.g. 3, 4); a small value for node classification tasks (e.g. 1, 2)” (p.9).[/footnote].

Interestingly, the gap in link prediction performance is nonexistent for artificially created graphs. This suggests higher-order methods do learn some of the structure intrinsic to real-world graphs.

For visualization, first-order methods are better. Visualizations of higher-order methods tend to have artifacts of their sampling. For instance, Node2Vec visualizations tend to have elongated, filament-like structures, which are artifacts of embeddings built from long, single-strand random walks. Owen Cornec’s visualizations illustrate this [footnote]The visualizations are created by first embedding the graph to 32-300 dimensions using a node embedding algorithm, then mapping this to 2d or 3d with the excellent UMAP algorithm[/footnote].

Lastly, sometimes simple methods soundly beat higher-order methods (there’s an instance of it in the OGB paper).

The problem here is that we don’t know when any method is better than another and we definitely don’t know the reason.

There’s definitely a reason different graph types respond better or worse to being represented by various methods, but what that reason is remains an open question.

A big part of why is that the research space is inundated with useless new algorithms because…

Academic incentives work against progress

Here’s the cynic’s view of how machine learning papers are made:

  1. Take an existing algorithm
  2. Add some new layer/hyperparameter, make a cute mathematical story for why it matters
  3. Gridsearch your hyperparameters until you beat baselines from the original paper you aped
  4. Absolutely don’t gridsearch stuff you’re comparing against in your results section
  5. Make a cute ACRONYM for your new method, put impossible-to-use Python 2 code on GitHub [footnote]or no code at all![/footnote] and bask in the citations

I’m not the only one with these views on the state of reproducible research. At least it’s gotten slightly better in the last 2 years.

A side project of mine is a node embedding library and the most popular method in it is by far Node2Vec. Don’t use Node2Vec.

Node2Vec with p=1 and q=1 is the Deepwalk algorithm. Deepwalk is an actual innovation.
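To see why p=1, q=1 collapses to Deepwalk, here is a minimal sketch of node2vec’s second-order search bias on an unweighted graph. The unnormalized weight for stepping from `current` to a neighbour is 1/p if it returns to the previous node, 1 if the neighbour is also adjacent to the previous node, and 1/q otherwise. With p = q = 1 every weight is 1, which is Deepwalk’s uniform step. The helper name is hypothetical.

```python
def node2vec_weights(adj, prev, current, p, q):
    """Unnormalized node2vec transition weights out of `current`, given that the
    walk arrived from `prev` (unweighted graph, adjacency stored as sets)."""
    weights = {}
    for nxt in adj[current]:
        if nxt == prev:            # step back to where we came from
            weights[nxt] = 1.0 / p
        elif nxt in adj[prev]:     # stay at distance 1 from prev
            weights[nxt] = 1.0
        else:                      # move further away from prev
            weights[nxt] = 1.0 / q
    return weights

adj = {0: {1, 2}, 1: {0, 2, 3}, 2: {0, 1}, 3: {1}}
print(node2vec_weights(adj, prev=0, current=1, p=1, q=1))    # all 1.0: uniform, like Deepwalk
print(node2vec_weights(adj, prev=0, current=1, p=4, q=0.5))  # biased toward moving outward
```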

The Node2Vec authors closely followed steps 1-5, including bonus points on step 5 for getting word2vec name recognition.

This is not academic fraud – the hyperparameters do help a tiny bit if you gridsearch really hard. But it’s the presentable-to-your-parents sister of making the ML community worse off to advance your academic career. And Node2Vec certainly doesn’t deserve 7500 citations.

Progress is all about practical issues

We’ve known how to train neural networks for well over 40 years. Yet they only exploded in popularity with AlexNet in 2012. This is because implementations and hardware came to a point where deep learning was practical.

Similarly, we’ve known about factoring word co-occurrence matrices into word embeddings for at least 20 years.

But word embeddings only exploded in 2013 with Word2Vec. The breakthrough here was that minibatch-based methods let you train a Wikipedia-scale embedding model on commodity hardware [footnote]In retrospect, GloVe-style methods were always available, but we didn’t know how useful word embeddings became when trained at large scale. So the motivation to invent something like GloVe may have been missing until Word2Vec showed the utility[/footnote].
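The point about minibatches is that you never need the whole co-occurrence matrix in memory: you stream (word, context) pairs and update only the rows they touch. As a minimal sketch of that idea, here is SGD for skip-gram with negative sampling (the Word2Vec variant) in NumPy on random toy data. Function names and hyperparameters are illustrative assumptions; a real implementation would also subsample frequent words and draw negatives from a smoothed unigram distribution.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sgns_loss(W_in, W_out, centers, contexts, negatives):
    """Mean skip-gram negative-sampling loss over a minibatch."""
    v = W_in[centers]
    pos = np.log(sigmoid(np.sum(W_out[contexts] * v, axis=1)))
    neg = np.log(sigmoid(-np.einsum("bkd,bd->bk", W_out[negatives], v))).sum(axis=1)
    return float(-(pos + neg).mean())

def sgns_step(W_in, W_out, centers, contexts, negatives, lr):
    """One in-place SGD step on a minibatch of (center, context) pairs.

    centers, contexts: int arrays of shape (B,); negatives: shape (B, K).
    Only rows touched by the batch change; no full matrix is materialized.
    """
    v = W_in[centers]                                   # (B, d)
    u_pos = W_out[contexts]                             # (B, d)
    u_neg = W_out[negatives]                            # (B, K, d)

    # Gradient of -log sigma(u_pos . v) - sum_k log sigma(-u_neg_k . v)
    g_pos = sigmoid(np.sum(u_pos * v, axis=1)) - 1.0    # (B,)
    g_neg = sigmoid(np.einsum("bkd,bd->bk", u_neg, v))  # (B, K)

    grad_v = g_pos[:, None] * u_pos + np.einsum("bk,bkd->bd", g_neg, u_neg)
    grad_pos = g_pos[:, None] * v
    grad_neg = g_neg[:, :, None] * v[:, None, :]

    # np.add.at accumulates correctly when an index repeats within the batch.
    np.add.at(W_in, centers, -lr * grad_v)
    np.add.at(W_out, contexts, -lr * grad_pos)
    np.add.at(W_out, negatives, -lr * grad_neg)

rng = np.random.default_rng(0)
vocab_size, dim = 50, 8
W_in = rng.normal(scale=0.1, size=(vocab_size, dim))
W_out = rng.normal(scale=0.1, size=(vocab_size, dim))
centers = np.array([0, 1, 2, 3])
contexts = np.array([4, 5, 6, 7])
negatives = rng.integers(8, vocab_size, size=(4, 5))

before = sgns_loss(W_in, W_out, centers, contexts, negatives)
for _ in range(100):
    sgns_step(W_in, W_out, centers, contexts, negatives, lr=0.1)
after = sgns_loss(W_in, W_out, centers, contexts, negatives)
print(f"loss before: {before:.3f}, after: {after:.3f}")
```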

It’s hard for methods in a field to make progress if training on a small amount of data takes days or weeks. You’re disincentivized to explore new methods. If you want progress, your stuff has to run in reasonable time on commodity hardware. Even Google’s original search algorithm initially ran on commodity hardware [footnote]Before anyone interjects about the recent massive language models: there’s not much innovation in making models huge, even though there are performance gains. The innovation was discovering LMs in the first place, and you can train a reasonable language model on commodity hardware. If you want to train a full BERT model, yes, that will be expensive[/footnote].

Efficiency is paramount to progress

The reason deep learning research took off the way it did is because of improvements in efficiency as well as much better libraries and hardware support [footnote]Anyone who installed Theano or TensorFlow in 2015 can attest to that[/footnote].

Academic code is terrible

Any amount of time you spend gridsearching Node2Vec on p and q is better spent gridsearching Deepwalk itself [footnote]On number of walks, length of walks, or word2vec hyperparameters[/footnote]. The problem is that people don’t gridsearch over Deepwalk because the implementations are all terrible.

I wrote the Nodevectors library to have a fast Deepwalk implementation because it took 32 hours to embed a graph with a measly 150,000 nodes using the reference Node2Vec implementation (the same takes 3 minutes with Nodevectors). It’s no wonder people don’t gridsearch on Deepwalk: a gridsearch would take weeks with the terrible reference implementations.

To give an example, in the original GraphSAGE paper, the authors compare their algorithm to DeepWalk with walk lengths of 5 [footnote]Which is horrid if you’ve ever hyperparameter-tuned a deepwalk algorithm[/footnote]. From their paper:

We did observe DeepWalk’s performance could improve with further training, and in some cases it could become competitive with the unsupervised GraphSAGE approaches (but not the supervised approaches) if we let it run for >1000× longer than the other approaches (in terms of wall clock time for prediction on the test set)

I don’t even think the GraphSAGE authors had bad intent – Deepwalk implementations are simply so awful that they were turned away from using it properly. It’s like trying to do deep learning with 2002 deep learning libraries and hardware.

Your architectures don’t really matter

One of the more important papers this year was OpenAI’s “Scaling laws” paper, which found that the raw number of parameters in your model is the most predictive feature of overall performance. This was noted even in the original BERT paper and drives 2020’s increase in absolutely massive language models.

This is really just Sutton’s Bitter Lesson in action:

General methods that leverage computation are ultimately the most effective, and by a large margin

Transformers might be replacing convolution, too. As Yannic Kilcher said, transformers are ruining everything. They work on graphs as well: in fact, graph transformers are one of the recent approaches, and seem to be among the more successful ones when benchmarked.

Researchers seem to be putting so much effort into architecture, but it doesn’t matter much in the end because you can approximate anything by stacking more layers.

Efficiency wins are great – but neural net architectures are just one way to achieve that, and by tremendously over-researching this area we’re leaving a lot of huge gains elsewhere on the table.

Current Graph Data Structure Implementations suck

NetworkX is a bad library. I mean, it’s good if you’re working on tiny graphs for babies, but for anything serious it chokes and forces you to rewrite everything in… what library, really?

At this point most people working on large graphs end up hand-rolling some data structure. This is tough because your computer’s memory is a 1-dimensional array of 1’s and 0’s and a graph has no obvious 1-d mapping.

This is even harder when we take updating the graph (adding/removing some nodes/edges) into account. Here are a few options:

Disconnected networks of pointers

NetworkX is the best example. Here, every node is an object with a list of pointers to other nodes (the node’s edges).

This layout is like a linked list. Linked lists are the root of all performance evil.

Linked lists go completely against how modern computers are designed. Fetching things from RAM is slow, while operating on data already in the CPU cache is fast (by two orders of magnitude). Whenever you do anything in this layout, you make a roundtrip to RAM. It’s slow by design: you can write this in Ruby or C or assembly and it’ll be slow regardless, because memory fetches are slow in hardware.

The main advantage of this layout is that adding a new node is O(1). So if you’re maintaining a massive graph where adding and removing nodes happens as often as reading from the graph, it makes sense.

Another advantage of this layout is that it “scales”. Because everything is decoupled from everything else, you can put this data structure on a cluster. However, you’re really creating a complex solution for a problem you created for yourself [footnote]Moreover, the edgelist representation can also be distributed[/footnote].

Sparse Adjacency Matrix

This layout is great for read-only graphs. I use it as the backend in my nodevectors library, and many other library writers use the SciPy CSR matrix; you can see graph algorithms implemented on it in SciPy’s csgraph module.

The most popular layout for this use is the CSR format, where 3 arrays hold the graph: one for edge destinations, one for edge weights, and an “index pointer” array that says which edges come from which node.
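A minimal sketch of the three CSR arrays, using SciPy’s `csr_matrix` on a made-up directed edge list. The `neighbours` helper (a hypothetical name) reads a node’s edges straight out of the index pointer, which is what makes lookups fast.

```python
import numpy as np
from scipy.sparse import csr_matrix

# Directed weighted edges: (source, destination, weight).
src = np.array([0, 0, 1, 2, 2, 2])
dst = np.array([1, 2, 2, 0, 1, 3])
wts = np.array([1.0, 0.5, 2.0, 1.0, 1.0, 3.0])

G = csr_matrix((wts, (src, dst)), shape=(4, 4))

print("indptr :", G.indptr)   # node i's edges live in slice indptr[i]:indptr[i+1]
print("indices:", G.indices)  # edge destinations
print("data   :", G.data)     # edge weights

def neighbours(G, node):
    """Destinations and weights of a node's outgoing edges, via the index pointer."""
    start, end = G.indptr[node], G.indptr[node + 1]
    return G.indices[start:end], G.data[start:end]

print(neighbours(G, 2))
```

The same arrays also show the update problem: inserting one edge for node 0 means shifting every later entry of `indices` and `data` and bumping every later `indptr` value.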

Because the CSR layout is simply 3 arrays, it scales on a single computer: a CSR matrix can be laid out on disk instead of in memory. You simply memory-map the 3 arrays and use them on-disk from there.
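A minimal sketch of that on-disk idea, assuming the three arrays are saved as `.npy` files and reopened with NumPy’s `mmap_mode="r"`, so pages are read from disk only when a lookup touches them. The file layout and function names are illustrative assumptions, not how nodevectors or any other library stores graphs.

```python
import os
import tempfile

import numpy as np
from scipy.sparse import csr_matrix

NAMES = ("indptr", "indices", "data")

def save_csr(G, directory):
    """Write the three CSR arrays to .npy files."""
    for name in NAMES:
        np.save(os.path.join(directory, f"{name}.npy"), getattr(G, name))

def load_csr_mmap(directory):
    """Memory-map the three arrays instead of loading them into RAM."""
    return {
        name: np.load(os.path.join(directory, f"{name}.npy"), mmap_mode="r")
        for name in NAMES
    }

def neighbours(arrays, node):
    """Copy one node's destinations and weights out of the memory-mapped arrays."""
    start, end = arrays["indptr"][node], arrays["indptr"][node + 1]
    return np.array(arrays["indices"][start:end]), np.array(arrays["data"][start:end])

G = csr_matrix(np.array([[0, 1, 0], [1, 0, 2], [0, 2, 0]], dtype=float))
with tempfile.TemporaryDirectory() as tmp:
    save_csr(G, tmp)
    arrays = load_csr_mmap(tmp)
    result = neighbours(arrays, 1)
    print(result)
    del arrays  # drop the memory maps before the directory is removed
```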

With modern NVMe drives, random seeks aren’t slow anymore; they’re much faster than the distributed network calls you make when scaling the linked-list-based graph. I haven’t seen anyone actually implement this yet, but it’s on the roadmap for my implementation at least.

The problem with this representation is that adding a node or edge means rebuilding the whole data structure.

Edgelist representations

This representation is three arrays: one for the edge sources, one for the edge destinations, and one for edge weights. DGL uses this representation internally.

This is a simple and compact layout which can be good for analysis.

The problem compared to CSR Graphs is some seek operations are slower. Say you want all the edges for node #4243. You can’t jump there without maintaining an index pointer array.

So either you maintain sorted order and binary search your way there (O(log n)), or keep the list unsorted and search linearly (O(n)).
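A minimal sketch of both lookups with NumPy on a made-up edge list: `np.searchsorted` finds a node’s edge range in a source-sorted list in O(log n), while the unsorted list needs a full O(n) scan. Function names are hypothetical.

```python
import numpy as np

def edges_of_sorted(src, dst, node):
    """Binary search a source-sorted edge list: O(log n) to find the range."""
    start = np.searchsorted(src, node, side="left")
    end = np.searchsorted(src, node, side="right")
    return dst[start:end]

def edges_of_unsorted(src, dst, node):
    """Linear scan of an unsorted edge list: O(n)."""
    return dst[src == node]

# Unsorted edge list, as you'd get from appending edges as they arrive.
src = np.array([4, 1, 4243, 7, 4243, 1])
dst = np.array([2, 3, 10, 8, 11, 5])

# Sort once by source node (stable, so edge order within a node is preserved).
order = np.argsort(src, kind="stable")
src_sorted, dst_sorted = src[order], dst[order]

print(edges_of_unsorted(src, dst, 4243))
print(edges_of_sorted(src_sorted, dst_sorted, 4243))
```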

This data structure can also work on memory-mapped disk arrays, and node appends are fast in the unsorted version (they’re slow in the sorted version).

Global methods are a dead end

Methods that work on the entire graph at once can’t leverage computation, because they run out of RAM at a certain scale.

So any method that wants a chance of becoming the new standard needs to be able to update piecemeal on parts of the graph.

Sampling-based methods

Sampling efficiency will matter more in the future

The problem with this approach is that it’s hard to use them for higher-order methods. The advantage is that they easily scale even on one computer. Also, incrementally adding a new node is as simple as taking the existing embeddings, adding a new one, and doing another epoch over the data.

But this does scale: for instance, Instagram uses it to feed their recommendation system models.

It’s currently used by Pinterest’s recommendation algorithms.


Here are a few interesting questions:

On the other hand, we should stop focusing on adding spicy new layers to test on the same tiny datasets. No one cares.

Originally published on by Matt Ranger