# Getting started with deep learning on graphs

*This post introduces deep learning on graphs by mapping its central concept - message passing - to minimal usage patterns of PyTorch Geometric’s foundation-laying MessagePassing class. Exploring those patterns, we gain some basic, very concrete insights into how graph DL works.*

If, in the deep-learning world, the first half of the last decade was the age of images, and the second, that of language, one could say that we are now living in the age of graphs. At least, that’s what commonly cited research metrics suggest. But as we’re all aware, deep-learning research is anything but an ivory tower. To see real-world implications, it suffices to reflect on how many things can be modeled as graphs. Some things quite naturally “are” graphs, in the sense of having nodes and edges: neurons, underground stations, social networks. Other things can fruitfully be modeled as graphs: molecules, for example; or language, concepts, three-dimensional shapes … If deep learning on graphs is desirable, what are the challenges, and what do we get for free?

## What’s so special about deep learning on graphs?

Graphs are different from images, language, as well as tabular data in that node numbering does not matter. In other words, graphs are permutation-invariant. Already this means that architectures established in other domains cannot be transferred verbatim. (The *ideas* underlying them can be transferred though. Thus, in the graph neural network (henceforth: GNN) model zoo you’ll see lots of allusions to “convolutional”, “attention”, and other established terms.) Put very simply, and in concordance with common sense, whatever algorithm is used, it will fundamentally be based on how nodes are connected: the edges, that is.
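To make "node numbering does not matter" concrete, here is a minimal sketch in plain Python, without any graph library. The toy graph and the helper name `neighbor_sums` are made up for illustration. It sums neighbor features per node, relabels the nodes, and checks that every node ends up with the same summary under its new name.

```python
# Toy undirected graph: node -> scalar feature, plus an edge list.
features = {0: 1.0, 1: 2.0, 2: 3.0}
edges = [(0, 1), (1, 2)]

def neighbor_sums(features, edges):
    """For each node, sum the features of its neighbors."""
    sums = {node: 0.0 for node in features}
    for a, b in edges:
        sums[a] += features[b]
        sums[b] += features[a]
    return sums

# Relabel the nodes: 0 -> 2, 1 -> 0, 2 -> 1.
perm = {0: 2, 1: 0, 2: 1}
features_p = {perm[n]: f for n, f in features.items()}
edges_p = [(perm[a], perm[b]) for a, b in edges]

original = neighbor_sums(features, edges)
relabeled = neighbor_sums(features_p, edges_p)

# Each node gets the same summary under its new name ...
same_per_node = all(original[n] == relabeled[perm[n]] for n in features)
# ... and a graph-level readout (sum over all nodes) does not change at all.
same_readout = sum(original.values()) == sum(relabeled.values())
print(same_per_node, same_readout)
```

The per-node check is what a node-level task needs; the readout check is the graph-level counterpart.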

When relationships are modeled as graphs, both nodes and edges can have features. This, too, adds complexity. But not everything is harder with graphs. Think of how cumbersome it can be to obtain labeled data for supervised learning. With graphs, often an astonishingly small amount of labeled data is needed. More surprisingly still, a graph can be constructed when not a single edge is present. Put differently, learning on *sets* can morph into learning on *graphs*.
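As a rough illustration of how a set can turn into a graph, the sketch below connects each point to its nearest neighbor. This is only one possible construction and an assumption on my part: the post does not name a method, and the coordinates and the helper `knn_edges` are invented for the example.

```python
import math

# A set of 2-D points with no edges given (made-up coordinates).
points = [(0.0, 0.0), (0.0, 1.0), (2.0, 0.0), (5.0, 5.0), (5.0, 6.0)]

def knn_edges(points, k):
    """Connect each point i to its k nearest neighbors j (directed edges j -> i)."""
    edges = []
    for i, p in enumerate(points):
        others = [j for j in range(len(points)) if j != i]
        others.sort(key=lambda j: math.dist(p, points[j]))
        edges.extend((j, i) for j in others[:k])
    return edges

edges = knn_edges(points, k=1)
print(edges)
```

Once such edges exist, the points can be treated as nodes, and everything said about message passing below applies.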

At this point, let me switch gears and move on to the practical part: the raison d’être of this post.

## Matching concepts and code: PyTorch Geometric

In this post (and future ones), we’ll use PyTorch Geometric (from here on: PyG). At the time of writing, it is the most popular library dedicated to graph DL, and the fastest-growing in terms of both functionality and user base.

Deep learning on graphs, in its most general form, is usually characterized by the term *message passing*. Messages are passed between nodes that are linked by an edge: If node $A$ has three neighbors, it will receive three messages. Those messages have to be summarized in some meaningful way. Finally – GNNs consisting of consecutive layers – the node will have to decide how to modify its previous-layer features (a.k.a. embeddings) based on that summary.

Together, these make up a three-step sequence: collect messages; aggregate; update. What about the “learning” in deep learning, though? There are two places where learning can happen. Firstly, in message collection: incoming messages could be transformed by an MLP, for example. Secondly, as part of the update step. All in all, this yields mathematical formulae like this one, given in the PyG documentation:

$$x_i^{(k)} = \gamma^{(k)} \left( x_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(x_i^{(k-1)}, x_j^{(k-1)},e_{j,i}\right) \right)$$

Scary though this looks, once we read it from the right, we see that it nicely fits the conceptual description. The $(x_i^{(k-1)},x_j^{(k-1)},e_{j,i})$ are the three types of incoming messages a node can receive: its own state at the previous layer, the states of its neighbors (the nodes $j \in \mathcal{N}(i)$) at the previous layer, and features/embeddings associated with the edge in question. (I’m leaving out edge features in this discussion completely, so as not to add further complexity.) These messages are (optionally) transformed by the neural network $\phi$, and whatever comes out is summarized by the aggregator function $\square$. Finally, a node will update itself based on that summary as well as its own previous-layer state, possibly by applying the neural network $\gamma$.
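The formula can also be read as a small program. Below is a minimal sketch in plain Python, not PyG code: the function `message_passing_step` and its arguments `phi`, `aggregate` and `gamma` are hypothetical names mirroring the symbols, edge features are left out, and the concrete choices passed in at the bottom are mine.

```python
def message_passing_step(x, edges, phi, aggregate, gamma):
    """One layer. x maps node -> feature; edges are (j, i) pairs meaning j -> i."""
    new_x = {}
    for i in x:
        # phi: one message per incoming edge (no edge features here)
        messages = [phi(x[i], x[j]) for (j, target) in edges if target == i]
        # square: order-independent summary of the messages
        summary = aggregate(messages)
        # gamma: combine the previous state with the summary
        new_x[i] = gamma(x[i], summary)
    return new_x

x = {0: 1.0, 1: 2.0, 2: 3.0}
edges = [(1, 0), (2, 0), (0, 1), (0, 2)]  # node 0 is linked to nodes 1 and 2

out = message_passing_step(
    x,
    edges,
    phi=lambda x_i, x_j: x_j,             # send the neighbor's state unchanged
    aggregate=sum,                        # add up the messages
    gamma=lambda x_i, m: (x_i + m) / 2,   # average old state and summary
)
print(out)
```

Swapping in other functions for `phi`, `aggregate` or `gamma` changes the layer without touching the loop, which is the division of labor the formula expresses.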

Now that we have this conceptual/mathematical representation, how does it map to code we see, or would like to write? PyG has excellent, extensive documentation, including at the beginner level. But here, I’d like to spell things out in detail – pedantically, if you like, but in a way that tells us a lot about how GNNs work.

Let’s start with the information given in one of the key documentation pages, Creating message passing networks:

PyG provides the `MessagePassing` base class, which helps in creating such kinds of message passing graph neural networks by automatically taking care of message propagation. The user only has to define the functions $\phi$, i.e. `message()`, and $\gamma$, i.e. `update()`, as well as the aggregation scheme to use, i.e. `aggr="add"`, `aggr="mean"` or `aggr="max"`.

Scrolling down that page and looking at the two example implementations, however, we see that an implementation of `update()` does not have to be provided; and from inspecting the source code, it is clear that, technically, the same holds for `message()`. (And unless we want a form of aggregation different from the default `add`, we do not even need to specify that, either.)

Thus, the question becomes: What happens if we code the *minimal PyG GNN*? To find out, we first need to create a minimal graph, one minimal enough for us to track what is going on.

### A minimal graph

Now, a basic `Data` object is created from three tensors. The first holds the node features: two features each for five nodes. (Both features are identical on purpose, for “cognitive ease” – on our, not the algorithm’s, part.)

```
import torch
x = torch.tensor([[1, 1], [2, 2], [3, 3], [11, 11], [12, 12]], dtype=torch.float)
```

The second specifies existing connections. For undirected graphs (like ours), each edge appears twice. The tensor you see here is specified in one-edge-per-line form for convenience reasons; to the `Data()` constructor we’ll pass its transpose instead.

```
edge_index = torch.tensor([
    [0, 1],
    [1, 0],
    [0, 2],
    [2, 0],
    [1, 2],
    [2, 1],
    [2, 3],
    [3, 2],
    [2, 4],
    [4, 2],
    [3, 4],
    [4, 3]
], dtype=torch.long)
```
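To see what the transpose does without involving tensors, the following plain-Python sketch turns the same edge list into the two-row layout and checks that every edge also appears in reverse. The variable names are mine.

```python
# One-edge-per-row list, as written above: (source, target) pairs.
edge_list = [
    (0, 1), (1, 0), (0, 2), (2, 0), (1, 2), (2, 1),
    (2, 3), (3, 2), (2, 4), (4, 2), (3, 4), (4, 3),
]

# Transposing gives the two-row layout: row 0 = sources, row 1 = targets.
sources, targets = (list(row) for row in zip(*edge_list))
print(sources)
print(targets)

# In an undirected graph, every edge must also appear reversed.
edge_set = set(edge_list)
is_symmetric = all((t, s) in edge_set for s, t in edge_list)
print(is_symmetric)
```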

The third tensor holds the node labels. (The task will be one of node – not edge, not graph – classification.)

```
y = torch.tensor([[0], [0], [0], [1], [1]], dtype=torch.float)
```

Constructing and inspecting the resulting graph, we have:

```
from torch_geometric.data import Data
data = Data(x = x, edge_index = edge_index.t().contiguous(), y = y)
data.x
data.edge_index
data.y
```

```
tensor([[ 1., 1.],
[ 2., 2.],
[ 3., 3.],
[11., 11.],
[12., 12.]])
tensor([[0, 1, 0, 2, 1, 2, 2, 3, 2, 4, 3, 4],
[1, 0, 2, 0, 2, 1, 3, 2, 4, 2, 4, 3]])
tensor([[0.],
[0.],
[0.],
[1.],
[1.]])
```

For our upcoming experiments, it’s more helpful, though, to visualize the graph:

```
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import to_networkx

def visualize_graph(G, color, labels):
    plt.figure(figsize=(7,7))
    plt.axis('off')
    nx.draw_networkx(
        G,
        pos = nx.spring_layout(G, seed = 777),
        labels = labels,
        node_color = color,
        cmap = "Set3"
    )
    plt.show()

G = to_networkx(data, to_undirected = True, node_attrs = ["x"])
labels = nx.get_node_attributes(G, "x")
visualize_graph(G, color = data.y, labels = labels)
```

Although our experiments won’t be about training performance (how could they be, with just five nodes), let me remark in passing that this graph is small, but not boring: The middle node is equally connected to both “sides”, yet feature-wise, it would pretty clearly appear to belong on just one of them. (Which is true, given the provided class labels). Such a constellation is interesting because, in the majority of networks, edges indicate similarity.
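The remark about the middle node can be checked by counting. The sketch below, in plain Python, computes the share of edges whose endpoints carry the same label (often called edge homophily, a term this post does not otherwise use) and then looks at node `2` specifically.

```python
# Undirected edges of the five-node graph, each listed once.
edges = [(0, 1), (0, 2), (1, 2), (2, 3), (2, 4), (3, 4)]
labels = [0, 0, 0, 1, 1]

# Fraction of edges whose endpoints share a label.
same = sum(labels[a] == labels[b] for a, b in edges)
homophily = same / len(edges)

# Per-node view: how many of node 2's neighbors share its label?
nbrs_of_2 = [b if a == 2 else a for a, b in edges if 2 in (a, b)]
agree_2 = sum(labels[n] == labels[2] for n in nbrs_of_2)

print(round(homophily, 3), nbrs_of_2, agree_2)
```

Four of six edges link same-label nodes, while node `2` has two neighbors on each side.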

### A minimal GNN

Now, we code and run the minimal GNN. We’re not interested in class labels (yet); we just want to see each node’s embeddings after a single pass.

```
from torch_geometric.nn import MessagePassing

class IAmLazy(MessagePassing):
    def forward(self, x, edge_index):
        out = self.propagate(edge_index, x = x)
        return out

module = IAmLazy()
out = module(data.x, data.edge_index)
out
```

```
tensor([[ 5., 5.],
[ 4., 4.],
[26., 26.],
[15., 15.],
[14., 14.]])
```

Evidently, we just had to start the process – but what process, exactly? From what we know about the three stages of message passing, an essential question is what nodes do with the information that flows over the edges. Our first experiment, then, is to inspect the incoming messages.

### Poking into `message()`

In `message()`, we have access to a structure named `x_j`. This tensor holds, for each node $i$, the embeddings of all nodes $j$ connected to it via incoming edges. We’ll print them, and then, just return them, unchanged.

```
class IAmMyOthers(MessagePassing):
    def forward(self, x, edge_index):
        out = self.propagate(edge_index, x = x)
        return out

    def message(self, x_j):
        print("in message, x_j is")
        print(x_j)
        return x_j

module = IAmMyOthers()
out = module(data.x, data.edge_index)
print("result is:")
out
```

```
in message, x_j is
tensor([[ 1., 1.],
[ 2., 2.],
[ 1., 1.],
[ 3., 3.],
[ 2., 2.],
[ 3., 3.],
[ 3., 3.],
[11., 11.],
[ 3., 3.],
[12., 12.],
[11., 11.],
[12., 12.]])
result is:
tensor([[ 5., 5.],
[ 4., 4.],
[26., 26.],
[15., 15.],
[14., 14.]])
```

Let me spell this out. In `data.edge_index`, repeated here for convenience:

```
tensor([[0, 1, 0, 2, 1, 2, 2, 3, 2, 4, 3, 4],
[1, 0, 2, 0, 2, 1, 3, 2, 4, 2, 4, 3]])
```

the first pair denotes the edge from node `0` (that had features `(1, 1)`) to node `1`. This information is found in `x_j`’s first row. Then the second row holds the information flowing in the opposite direction, namely, the features associated with node `1`. And so on.

Interestingly, since we’re passing through this module just once, we can see the messages that will be sent without even running it.

Namely, since `data.edge_index[0]` designates the *source* nodes for each edge:

```
data.edge_index[0]
```

we can index into `data.x` to pick up what will be the *incoming* features for each connection.

```
data.x[data.edge_index[0]]
```

```
tensor([[ 1., 1.],
[ 2., 2.],
[ 1., 1.],
[ 3., 3.],
[ 2., 2.],
[ 3., 3.],
[ 3., 3.],
[11., 11.],
[ 3., 3.],
[12., 12.],
[11., 11.],
[12., 12.]])
```

Now, what does this tell us? Node `0`, for example, received messages from nodes `1` and `2`: `(2, 2)` and `(3, 3)`, respectively. We know that the default aggregation mode is `add`; and so, would expect an outcome of `(5, 5)`. Indeed, this is the new embedding for node `0`.
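The same bookkeeping can be done for all five nodes at once. This is a plain-Python sketch of the two steps just described (pick up source features, add them up per target), not PyG's actual implementation; the names `src` and `dst` are mine.

```python
# Node features and the two rows of data.edge_index, copied from above.
x = [[1, 1], [2, 2], [3, 3], [11, 11], [12, 12]]
src = [0, 1, 0, 2, 1, 2, 2, 3, 2, 4, 3, 4]
dst = [1, 0, 2, 0, 2, 1, 3, 2, 4, 2, 4, 3]

# Step 1 (message): one row per edge, holding the source node's features.
x_j = [x[s] for s in src]

# Step 2 (aggregate, "add"): sum each message into its target node's slot.
out = [[0, 0] for _ in x]
for msg, t in zip(x_j, dst):
    out[t] = [acc + m for acc, m in zip(out[t], msg)]

print(out)
```

The printed rows agree with the embeddings returned by the module above.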

In a nutshell, thus, the minimal GNN updates every node’s embedding so as to prototypically reflect the node’s neighborhood. Take care though: Nodes represent their neighborhoods, but themselves, they count for nothing. We will change that now.

### Adding self loops

All we need to do is modify the adjacency matrix to include edges going from each node back to itself.

```
from torch_geometric.utils import add_self_loops

class IAmMyOthersAndMyselfAsWell(MessagePassing):
    def forward(self, x, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes = x.size(0))
        print("in forward, augmented edge index now has shape")
        print(edge_index.shape)
        out = self.propagate(edge_index, x = x)
        return out

    def message(self, x_j):
        return x_j

module = IAmMyOthersAndMyselfAsWell()
out = module(data.x, data.edge_index)
print("result is:")
out
```

```
in forward, augmented edge index now has shape
torch.Size([2, 17])
result is:
tensor([[ 6., 6.],
[ 6., 6.],
[29., 29.],
[26., 26.],
[26., 26.]])
```

As expected, the neighborhood summary at each node now includes a contribution from each node itself.
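A quick plain-Python check of that statement, using one feature per node for brevity. The helper `add_sum` is a made-up stand-in for "add" aggregation; the point is that self loops add each node's own feature to its neighbor sum, nothing more.

```python
x = [1, 2, 3, 11, 12]  # one feature per node is enough here
src = [0, 1, 0, 2, 1, 2, 2, 3, 2, 4, 3, 4]
dst = [1, 0, 2, 0, 2, 1, 3, 2, 4, 2, 4, 3]

def add_sum(x, src, dst):
    """Sum source features into target slots ("add" aggregation)."""
    out = [0] * len(x)
    for s, t in zip(src, dst):
        out[t] += x[s]
    return out

# Append one edge i -> i per node.
loops = list(range(len(x)))
src_loops, dst_loops = src + loops, dst + loops

without_loops = add_sum(x, src, dst)
with_loops = add_sum(x, src_loops, dst_loops)

print(len(src_loops))  # number of edges after adding self loops
print(with_loops)
# Self loops add each node's own feature to its neighbor sum.
print(with_loops == [n + own for n, own in zip(without_loops, x)])
```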

Now that we know how to access the messages, we’d like to aggregate them in a non-standard way.

### Customizing `aggregate()`

Instead of `message()`, we now override `aggregate()`. If we wanted to use another of the “standard” aggregation modes (`mean`, `mul`, `min`, or `max`), we could just override `__init__()`, like so:

```
def __init__(self):
    super().__init__(aggr = "mean")
```
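For orientation, here is what `mean` and `max` would amount to on our graph, computed by hand in plain Python (without self loops, one feature per node). I have not run the PyG module with these settings; the numbers follow from the definitions of mean and maximum alone.

```python
x = [1, 2, 3, 11, 12]
src = [0, 1, 0, 2, 1, 2, 2, 3, 2, 4, 3, 4]
dst = [1, 0, 2, 0, 2, 1, 3, 2, 4, 2, 4, 3]

n = len(x)
sums = [0] * n
in_degree = [0] * n
maxes = [float("-inf")] * n

for s, t in zip(src, dst):
    sums[t] += x[s]
    in_degree[t] += 1
    maxes[t] = max(maxes[t], x[s])

# "mean" is the "add" result divided by the number of incoming messages.
means = [total / deg for total, deg in zip(sums, in_degree)]

print(in_degree)
print(means)
print(maxes)
```

Compared with `add`, the mean no longer grows with the number of neighbors, which matters for node `2` with its four incoming edges.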

To implement custom summaries, however, we make use of `torch_scatter` (one of PyG’s installation prerequisites) for optimal performance. Let me show this by means of a simple example.

```
from torch_scatter import scatter

class IAmJustTheOppositeReally(MessagePassing):
    def forward(self, x, edge_index):
        out = self.propagate(edge_index, x = x)
        return out

    def aggregate(self, inputs, index):
        print("in aggregate, inputs is")
        # same as x_j (incoming node features)
        print(inputs)
        print("in aggregate, index is")
        # this is data.edge_index[1]
        print(index)
        # see https://pytorch-scatter.readthedocs.io/en/1.3.0/index.html
        # for other aggregation modes
        # default dim is -1
        return - scatter(inputs, index, dim = 0, reduce = "add")

module = IAmJustTheOppositeReally()
out = module(data.x, data.edge_index)
print("result is:")
out
```

```
in aggregate, inputs is
tensor([[ 1., 1.],
[ 2., 2.],
[ 1., 1.],
[ 3., 3.],
[ 2., 2.],
[ 3., 3.],
[ 3., 3.],
[11., 11.],
[ 3., 3.],
[12., 12.],
[11., 11.],
[12., 12.]])
in aggregate, index is
tensor([1, 0, 2, 0, 2, 1, 3, 2, 4, 2, 4, 3])
result is:
tensor([[ -5., -5.],
[ -4., -4.],
[-26., -26.],
[-15., -15.],
[-14., -14.]])
```

In `aggregate()`, we have two types of tensors to work with. One, `inputs`, holds what was returned from `message()`. In our case, this is identical to `x_j`, since we didn’t make any modifications to the default behavior. The second, `index`, holds the recipe for where in the aggregation those features should go. Here, the very first tuple, `(1, 1)`, will contribute to the summary for node `1`; the second, `(2, 2)`, to that for node `0` – and so on. By the way, just like `x_j` (in a single-layer, single-pass setup) is “just” `data.x[data.edge_index[0]]`, that `index` is “just” `data.edge_index[1]`. Meaning, this is the list of target nodes connected to the edges in question.

At this point, all kinds of manipulations could be done on either `inputs` or `index`; however, we content ourselves with just passing them through to `torch_scatter.scatter()`, and returning the negated sums. We’ve successfully built a network of contrarians.
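To illustrate what `index` makes possible, the sketch below groups the messages by target node in plain Python and then applies two summaries: the negated sum from above, and the spread (maximum minus minimum) of each neighborhood. The spread is a summary I picked for illustration only; neither it nor the grouping code comes from `torch_scatter`.

```python
from collections import defaultdict

# Messages (one feature each) and the target node of every message.
inputs = [1, 2, 1, 3, 2, 3, 3, 11, 3, 12, 11, 12]
index = [1, 0, 2, 0, 2, 1, 3, 2, 4, 2, 4, 3]

# Group messages by the node they are addressed to.
groups = defaultdict(list)
for msg, target in zip(inputs, index):
    groups[target].append(msg)

num_nodes = max(index) + 1
negated_sums = [-sum(groups[i]) for i in range(num_nodes)]
spreads = [max(groups[i]) - min(groups[i]) for i in range(num_nodes)]

print(negated_sums)
print(spreads)
```

A Python loop like this is far slower than a scatter operation on tensors; it is only meant to show the semantics.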

By now, we’ve played with `message()` as well as `aggregate()`. What about `update()`?

### Adding memory to `update()`

There’s one really strange thing about what we’re doing. It isn’t obvious, since we’re not simulating a real training phase; we’ve been calling the layer just once. Had we called it repeatedly, we’d have noticed that at every call, the nodes happily forget who they were before, dutifully assuming the new identities assigned. In reality, we probably want them to evolve in a more consistent way.

For example:

```
class IDoEvolveOverTime(MessagePassing):
    def forward(self, x, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        out = self.propagate(edge_index, x = x)
        return out

    def update(self, inputs, x):
        print("in update, inputs is")
        print(inputs)
        print("in update, x is")
        print(x)
        return (inputs + x)/2

module = IDoEvolveOverTime()
out = module(data.x, data.edge_index)
print("result is:")
out
```

```
in update, inputs is
tensor([[ 6., 6.],
[ 6., 6.],
[29., 29.],
[26., 26.],
[26., 26.]])
in update, x is
tensor([[ 1., 1.],
[ 2., 2.],
[ 3., 3.],
[11., 11.],
[12., 12.]])
result is:
tensor([[ 3.5000, 3.5000],
[ 4.0000, 4.0000],
[16.0000, 16.0000],
[18.5000, 18.5000],
[19.0000, 19.0000]])
```

In `update()`, we have access to both the final message aggregate (`inputs`) and the nodes’ prior states (`x`). Here, I’m just averaging those two.
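To get a feel for repeated application, the following plain-Python sketch applies the same rule (sum over neighbors and self, then average with the previous state) twice in a row. The function `layer` is a hand-written stand-in for the module above, with one feature per node.

```python
src = [0, 1, 0, 2, 1, 2, 2, 3, 2, 4, 3, 4]
dst = [1, 0, 2, 0, 2, 1, 3, 2, 4, 2, 4, 3]

def layer(x):
    """Sum over neighbors plus self loop, then average with the old state."""
    agg = list(x)  # the self loop contributes each node's own feature
    for s, t in zip(src, dst):
        agg[t] += x[s]
    return [(a + old) / 2 for a, old in zip(agg, x)]

x0 = [1.0, 2.0, 3.0, 11.0, 12.0]
once = layer(x0)
twice = layer(once)

print(once)
print(twice)
```

The first pass reproduces the result shown above; the second starts from those values rather than from the original features.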

At this point, we’ve successfully acquainted ourselves with the three stages of message passing: acting on individual messages, aggregating them, and self-updating based on past state and new information. But none of our models so far could be called a neural network, since there was no learning involved.

### Adding parameters

If we look back at the generic message passing formulation:

$$x_i^{(k)} = \gamma^{(k)} \left( x_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(x_i^{(k-1)}, x_j^{(k-1)},e_{j,i}\right) \right)$$

we see two places where neural network modules can act on the computation: before message aggregation, and as part of the node update process. First, we illustrate the former option. For example, we can apply an MLP in `forward()`, before the call to `propagate()`:

```
from torch.nn import Sequential as Seq, Linear, ReLU

class ILearnAndEvolve(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr = "sum")
        self.mlp = Seq(Linear(in_channels, out_channels),
                       ReLU(),
                       Linear(out_channels, out_channels))

    def forward(self, x, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        x = self.mlp(x)
        out = self.propagate(edge_index = edge_index, x = x)
        return out

    def update(self, inputs, x):
        return (inputs + x)/2

module = ILearnAndEvolve(2, 2)
out = module(data.x, data.edge_index)
print("result is:")
out
```

```
result is:
tensor([[-0.8724, -0.4407],
[-0.9056, -0.4623],
[-2.0229, -1.1240],
[-1.8691, -1.0867],
[-1.9024, -1.1082]], grad_fn=<DivBackward0>)
```
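The numbers above depend on the randomly initialized weights of the MLP, so they will differ from run to run. To see the mechanics with fixed numbers, here is a plain-Python sketch in which a single linear layer plus ReLU, with hand-picked weights `W` and `b` standing in for learned parameters, transforms the node features before they are summed over neighbors and self loops. All names and values are made up for illustration.

```python
def linear(v, weight, bias):
    """Compute weight @ v + bias for one feature vector."""
    return [sum(w * a for w, a in zip(row, v)) + b for row, b in zip(weight, bias)]

def relu(v):
    return [max(0.0, a) for a in v]

# Hypothetical parameters; in a real layer these would be learned.
W = [[0.5, 0.0], [0.0, -1.0]]
b = [0.0, 4.0]

x = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [11.0, 11.0], [12.0, 12.0]]
src = [0, 1, 0, 2, 1, 2, 2, 3, 2, 4, 3, 4]
dst = [1, 0, 2, 0, 2, 1, 3, 2, 4, 2, 4, 3]

# Transform first ...
h = [relu(linear(v, W, b)) for v in x]

# ... then sum over neighbors, with self loops included.
agg = [list(row) for row in h]
for s, t in zip(src, dst):
    agg[t] = [acc + m for acc, m in zip(agg[t], h[s])]

print(h)
print(agg)
```

Because the transformation happens before aggregation, what gets summed are the transformed features, not the raw ones.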

Finally, we can apply network modules in both places, as exemplified next.

### General message passing

We keep the MLP from the previous class, and add a second in `update()`:

```
class ILearnAndEvolveDoubly(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr = "sum")
        self.mlp_msg = Seq(Linear(in_channels, out_channels),
                           ReLU(),
                           Linear(out_channels, out_channels))
        self.mlp_upd = Seq(Linear(out_channels, out_channels),
                           ReLU(),
                           Linear(out_channels, out_channels))

    def forward(self, x, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        x = self.mlp_msg(x)
        out = self.propagate(edge_index = edge_index, x = x)
        return out

    def update(self, inputs, x):
        return self.mlp_upd((inputs + x)/2)

module = ILearnAndEvolveDoubly(2, 2)
out = module(data.x, data.edge_index)
print("result is:")
out
```

```
result is:
tensor([[ 0.0573, -0.6988],
[ 0.0358, -0.6894],
[-0.1730, -0.6450],
[-0.5855, -0.4171],
[-0.5890, -0.4141]], grad_fn=<AddmmBackward0>)
```

At this point, I hope you’ll feel comfortable to play around, subclassing the `MessagePassing` base class. Also, if now you consult the above-mentioned documentation page (Creating message passing networks), you’ll be able to map the example implementations (dedicated to popular GNN layer types) to where they “hook into” the message passing process.

Experimentation with `MessagePassing` was the point of this post. However, you may be wondering: How do I actually use this for node classification? Didn’t the graph have a class defined for each node? (It did: `data.y`.)

So let me conclude with a (minimal) end-to-end example that uses one of the above modules.

### A minimal workflow

To that purpose, we compose that module with a linear one that performs node classification:

```
class Network(torch.nn.Module):
    def __init__(self, in_channels, out_channels, num_classes):
        super().__init__()
        self.conv = ILearnAndEvolveDoubly(in_channels, out_channels)
        self.classifier = Linear(out_channels, num_classes)

    def forward(self, x, edge_index):
        # use the arguments passed in, not the global `data` object
        x = self.conv(x, edge_index)
        return self.classifier(x)

model = Network(2, 2, 1)
```

We can then train the model like any other:

```
import torch.nn.functional as F

optimizer = torch.optim.Adam(model.parameters(), lr = 0.01)
model.train()

for epoch in range(5):
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.binary_cross_entropy_with_logits(out, data.y)
    loss.backward()
    optimizer.step()

preds = torch.sigmoid(out)
preds
```

```
tensor([[0.6502],
[0.6532],
[0.7027],
[0.7145],
[0.7165]], grad_fn=<SigmoidBackward0>)
```
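As a sanity check on what these probabilities mean, the plain-Python sketch below recomputes the binary cross-entropy from the printed values (rounded to four digits, so the result is approximate) and thresholds them at 0.5. The threshold is my choice; the post does not specify one.

```python
import math

# Probabilities printed above and the labels from data.y.
preds = [0.6502, 0.6532, 0.7027, 0.7145, 0.7165]
labels = [0, 0, 0, 1, 1]

# Binary cross-entropy per node, then the mean over nodes.
losses = [
    -(y * math.log(p) + (1 - y) * math.log(1 - p))
    for p, y in zip(preds, labels)
]
loss = sum(losses) / len(losses)

# Hard predictions with a 0.5 threshold, and the resulting accuracy.
hard = [int(p > 0.5) for p in preds]
accuracy = sum(h == y for h, y in zip(hard, labels)) / len(labels)

print(round(loss, 3), hard, accuracy)
```

After only five epochs, every node is still assigned to class 1, although the probabilities are already ordered the right way: the three class-0 nodes get the lowest values.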

And that’s it for this time. Stay tuned for examples of how graph models are applied in the sciences, as well as illustrations of bleeding-edge developments in Geometric Deep Learning, the principles-based, heuristics-transcending approach to neural networks.

Thanks for reading!

Photo by Alina Grubnyak on Unsplash