- KANs were proposed by Liu et al.
- See the Fourier-KAN implementation, which replaces splines with Fourier coefficients.
General Message Passing Neural Network (MPNN)
Input Node and Edge Features:
- Nodes: $\mathbf{x}_i$ (node features)
- Edges: $\mathbf{e}_{ij}$ (edge features)
Message Passing Layer (per layer):
a. Edge Feature Transformation:
\[\mathbf{e}'_{ij} = f_e(\mathbf{e}_{ij})\]where $f_e$ is a transformation function applied to edge features.
b. Message Computation:
\[\mathbf{m}_{ij} = f_m(\mathbf{x}_i, \mathbf{x}_j, \mathbf{e}'_{ij})\]where $f_m$ computes messages from the node features $\mathbf{x}_i, \mathbf{x}_j$ and the transformed edge features $\mathbf{e}'_{ij}$.
c. Message Aggregation:
\[\mathbf{m}_i = \sum_{j \in \mathcal{N}(i)} \mathbf{m}_{ij}\]where $\mathcal{N}(i)$ denotes the set of neighbors of node $i$.
d. Node Feature Update:
\[\mathbf{x}'_i = f_n(\mathbf{x}_i, \mathbf{m}_i)\]where $f_n$ updates node features using the aggregated messages $\mathbf{m}_i$.
Output Node and Edge Features:
- Nodes: $\mathbf{x}'_i$ (updated node features)
- Edges: $\mathbf{e}'_{ij}$ (updated edge features)
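The four steps above can be sketched as a single PyTorch layer. This is a minimal illustration, not the only choice: here $f_e$, $f_m$, $f_n$ are each taken to be a single linear layer, and the aggregation is a plain sum via `index_add`.

```python
import torch
import torch.nn as nn

class MPNNLayer(nn.Module):
    # Minimal sketch of one message-passing layer; f_e, f_m, f_n are linear maps here.
    def __init__(self, node_dim, edge_dim, hidden_dim):
        super().__init__()
        self.f_e = nn.Linear(edge_dim, hidden_dim)                   # edge transform f_e
        self.f_m = nn.Linear(2 * node_dim + hidden_dim, hidden_dim)  # message function f_m
        self.f_n = nn.Linear(node_dim + hidden_dim, node_dim)        # node update f_n

    def forward(self, x, edge_index, edge_attr):
        row, col = edge_index                                # edge (i, j): row = i, col = j
        e = self.f_e(edge_attr)                              # e'_ij = f_e(e_ij)
        m_ij = self.f_m(torch.cat([x[row], x[col], e], -1))  # m_ij = f_m(x_i, x_j, e'_ij)
        m_i = torch.zeros(x.size(0), m_ij.size(-1))
        m_i = m_i.index_add(0, row, m_ij)                    # m_i = sum over neighbors j
        return self.f_n(torch.cat([x, m_i], -1))             # x'_i = f_n(x_i, m_i)

layer = MPNNLayer(node_dim=4, edge_dim=3, hidden_dim=8)
x = torch.randn(5, 4)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
edge_attr = torch.randn(3, 3)
out = layer(x, edge_index, edge_attr)
print(out.shape)  # torch.Size([5, 4])
```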
E(3)-Equivariant GNN with Learnable Activation Functions on Edges
Input Node and Edge Features:
- Nodes: $\mathbf{x}_i$ (node features)
- Edges: $\mathbf{e}_{ij}$ (edge features)
Learnable Edge Feature Transformation:
Fourier-based Edge Transformation:
\[\mathbf{e}'_{ij} = \text{FourierTransform}(\mathbf{e}_{ij})\]where the Fourier transformation is applied to edge features. Specifically, the transformation is defined as:
\[\mathbf{e}'_{ij} = \sum_{k=1}^{K} a_{ij,k} \cos(k \mathbf{e}_{ij}) + b_{ij,k} \sin(k \mathbf{e}_{ij})\]Here, $a_{ij,k}$ and $b_{ij,k}$ are learnable parameters, and $K$ is the number of Fourier terms.
Message Passing and Aggregation:
a. Message Computation:
\[\mathbf{m}_{ij} = \mathbf{e}'_{ij} \odot \mathbf{x}_j\]where $\odot$ denotes element-wise multiplication, combining the transformed edge features $\mathbf{e}'_{ij}$ with the neighboring node features $\mathbf{x}_j$.
b. Message Aggregation:
\[\mathbf{m}_i = \sum_{j \in \mathcal{N}(i)} \mathbf{m}_{ij}\]
c. Simple Node Feature Transformation:
\[\mathbf{x}'_i = \mathbf{W} (\mathbf{x}_i + \mathbf{m}_i) + \mathbf{b}\]where $\mathbf{W}$ is a learnable weight matrix and $\mathbf{b}$ is a bias vector.
Output Node and Edge Features:
- Nodes: $\mathbf{x}'_i$ (updated node features)
- Edges: $\mathbf{e}'_{ij}$ (updated edge features)
Full Implementation
import torch
import torch.nn as nn
from torch_scatter import scatter_add
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import QM9
from torch_geometric.transforms import Distance
from torch_geometric.nn import MessagePassing
from torch.optim import Adam
from e3nn import o3
from e3nn.nn import Gate, FullyConnectedNet
class LearnableActivationEdge(nn.Module):
"""
Class to define learnable activation functions on edges using Fourier series.
Inspired by Kolmogorov-Arnold Networks (KANs) to capture complex, non-linear transformations on edge features.
"""
def __init__(self, inputdim, outdim, num_terms, addbias=True):
"""
Initialize the LearnableActivationEdge module.
Args:
inputdim (int): Dimension of input edge features.
outdim (int): Dimension of output edge features.
num_terms (int): Number of Fourier terms.
addbias (bool): Whether to add a bias term. Default is True.
"""
super(LearnableActivationEdge, self).__init__()
self.num_terms = num_terms
self.addbias = addbias
self.inputdim = inputdim
self.outdim = outdim
# Initialize learnable Fourier coefficients
        self.fouriercoeffs = nn.Parameter(
            torch.randn(2, outdim, inputdim, num_terms) / (inputdim * num_terms) ** 0.5
        )
if self.addbias:
self.bias = nn.Parameter(torch.zeros(1, outdim))
def forward(self, edge_attr):
"""
Forward pass to apply learnable activation functions on edge attributes.
Args:
edge_attr (Tensor): Edge attributes of shape (..., inputdim).
Returns:
Tensor: Transformed edge attributes of shape (..., outdim).
"""
# Reshape edge attributes for Fourier transformation
xshp = edge_attr.shape
outshape = xshp[0:-1] + (self.outdim,)
edge_attr = torch.reshape(edge_attr, (-1, self.inputdim))
# Generate Fourier terms
k = torch.reshape(torch.arange(1, self.num_terms + 1, device=edge_attr.device), (1, 1, 1, self.num_terms))
xrshp = torch.reshape(edge_attr, (edge_attr.shape[0], 1, edge_attr.shape[1], 1))
# Compute cosine and sine components
c = torch.cos(k * xrshp)
s = torch.sin(k * xrshp)
# Apply learnable Fourier coefficients
y = torch.sum(c * self.fouriercoeffs[0:1], (-2, -1))
y += torch.sum(s * self.fouriercoeffs[1:2], (-2, -1))
# Add bias if applicable
if self.addbias:
y += self.bias
# Reshape to original edge attribute shape
y = torch.reshape(y, outshape)
return y
class E3EquivariantGNN(MessagePassing):
    """
    Graph neural network with learnable Fourier activation functions on edges.
    The edge features used here (bond features plus interatomic distances) are
    E(3)-invariant scalars, so the node updates and predictions are E(3)-invariant.
    """
    def __init__(self, in_features, out_features, hidden_dim, num_layers, num_terms, edge_dim=5):
        """
        Initialize the E3EquivariantGNN module.
        Args:
            in_features (int): Dimension of input node features.
            out_features (int): Dimension of output node features.
            hidden_dim (int): Dimension of hidden layers.
            num_layers (int): Number of layers in the network.
            num_terms (int): Number of Fourier terms for learnable activation functions.
            edge_dim (int): Dimension of input edge features (QM9 bond features
                plus the appended distance = 5).
        """
        super(E3EquivariantGNN, self).__init__(aggr='add')
        self.num_layers = num_layers
        # Project input node features to the hidden dimension
        self.input_proj = nn.Linear(in_features, hidden_dim)
        # Learnable Fourier activation functions on edges, one per layer
        self.fourier_layers = nn.ModuleList([
            LearnableActivationEdge(edge_dim, hidden_dim, num_terms)
            for _ in range(num_layers)
        ])
        # Node update x'_i = W (x_i + m_i) + b, one per layer
        self.node_layers = nn.ModuleList([
            nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers)
        ])
        # Output layer
        self.output_layer = nn.Linear(hidden_dim, out_features)
    def forward(self, x, edge_index, edge_attr):
        """
        Forward pass to propagate node features through the GNN.
        Args:
            x (Tensor): Node features of shape (num_nodes, in_features).
            edge_index (Tensor): Edge indices of shape (2, num_edges).
            edge_attr (Tensor): Edge attributes of shape (num_edges, edge_dim).
        Returns:
            Tensor: Output node features of shape (num_nodes, out_features).
        """
        row, col = edge_index
        x = self.input_proj(x)
        for i in range(self.num_layers):
            # Transform edge features with the learnable Fourier series
            e = self.fourier_layers[i](edge_attr)
            # Message computation: m_ij = e'_ij ⊙ x_j
            m_ij = e * x[col]
            # Message aggregation: m_i = sum over neighbors j of node i
            m_i = scatter_add(m_ij, row, dim=0, dim_size=x.size(0))
            # Node feature update: x'_i = W (x_i + m_i) + b
            x = self.node_layers[i](x + m_i)
        # Apply the final linear layer
        return self.output_layer(x)
# Load and prepare the QM9 dataset; the Distance transform appends the
# interatomic distance to the 4 bond-type edge features (5-dim edge attributes)
dataset = QM9(root='data/QM9')
dataset.transform = Distance(norm=False)
# Split dataset into training, validation, and test sets
train_dataset = dataset[:110000]
val_dataset = dataset[110000:120000]
test_dataset = dataset[120000:]
# Data loaders for training, validation, and test sets
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
# Define the loss function and optimizer
criterion = nn.MSELoss()
model = E3EquivariantGNN(in_features=11, out_features=1, hidden_dim=32, num_layers=3, num_terms=5)  # QM9 nodes have 11 features
optimizer = Adam(model.parameters(), lr=1e-3)
def train_step(model, optimizer, criterion, data):
"""
Perform a single training step.
Args:
model (nn.Module): The neural network model.
optimizer (Optimizer): The optimizer.
criterion (Loss): The loss function.
data (Data): The input data batch.
Returns:
float: The loss value.
"""
model.train()
optimizer.zero_grad()
    out = model(data.x, data.edge_index, data.edge_attr)
    # Sum-pool node-level outputs into one prediction per graph and regress
    # the first QM9 target (an illustrative choice)
    out = scatter_add(out, data.batch, dim=0)
    loss = criterion(out, data.y[:, 0:1])
loss.backward()
optimizer.step()
return loss.item()
# Training loop
num_epochs = 100
for epoch in range(num_epochs):
train_loss = 0
for data in train_loader:
train_loss += train_step(model, optimizer, criterion, data)
train_loss /= len(train_loader)
val_loss = 0
model.eval()
with torch.no_grad():
for data in val_loader:
            out = model(data.x, data.edge_index, data.edge_attr)
            out = scatter_add(out, data.batch, dim=0)
            loss = criterion(out, data.y[:, 0:1])
val_loss += loss.item()
val_loss /= len(val_loader)
print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
Detailed Explanation of Mathematical Formulations
Learnable Edge Feature Transformation
For each edge $(i, j)$ with feature $\mathbf{e}_{ij}$:
\[\mathbf{e}'_{ij} = \sum_{k=1}^{K} a_{ij,k} \cos(k \mathbf{e}_{ij}) + b_{ij,k} \sin(k \mathbf{e}_{ij})\]where $a_{ij,k}$ and $b_{ij,k}$ are learnable parameters, and $K$ is the number of terms.
Message Computation
For each edge $(i, j)$:
\[\mathbf{m}_{ij} = \mathbf{e}'_{ij} \odot \mathbf{x}_j\]where $\odot$ denotes element-wise multiplication.
Message Aggregation
For each node $i$:
\[\mathbf{m}_i = \sum_{j \in \mathcal{N}(i)} \mathbf{m}_{ij}\]where $\mathcal{N}(i)$ denotes the set of neighbors of node $i$.
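In code this aggregation is a scatter-add. The implementation uses `torch_scatter.scatter_add`, but the same operation can be sketched with plain torch (hypothetical values):

```python
import torch

# Sum each message m_ij into its target node i
m_ij = torch.tensor([[1.0], [2.0], [4.0]])   # one message per edge
row = torch.tensor([0, 0, 1])                # target node index i of each edge
m_i = torch.zeros(2, 1).index_add(0, row, m_ij)
print(m_i)  # tensor([[3.], [4.]])
```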
Node Feature Update
For each node $i$:
\[\mathbf{x}'_i = \mathbf{W} (\mathbf{x}_i + \mathbf{m}_i) + \mathbf{b}\]where $\mathbf{W}$ is a learnable weight matrix and $\mathbf{b}$ is a bias vector.
Summary
This implementation combines learnable activation functions on edges with simple linear updates on node features. Because the edge features it acts on are E(3)-invariant scalars (distances and bond types), the resulting predictions are invariant under rotations and translations. The mathematical formulations above explain each step of the process for a physicist audience familiar with these concepts.
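As a standalone sanity check (separate from the training script), one can verify that a Fourier series acting on pairwise distances is E(3)-invariant: rotating and translating the positions leaves the transformed edge features unchanged. The coefficients and positions below are arbitrary illustrative values.

```python
import torch

torch.manual_seed(0)
pos = torch.randn(6, 3)                          # random 3D positions
# Random orthogonal matrix via QR decomposition, plus a translation
Q, _ = torch.linalg.qr(torch.randn(3, 3))
pos_t = pos @ Q.T + torch.tensor([1.0, -2.0, 0.5])

K = 5
a = torch.randn(K)                               # stand-ins for learnable a_k
b = torch.randn(K)                               # stand-ins for learnable b_k
k = torch.arange(1, K + 1, dtype=torch.float32)

def fourier_edge(d):
    # sum_k a_k cos(k d) + b_k sin(k d), applied element-wise to distances
    d = d.unsqueeze(-1)
    return (a * torch.cos(k * d) + b * torch.sin(k * d)).sum(-1)

out1 = fourier_edge(torch.cdist(pos, pos))
out2 = fourier_edge(torch.cdist(pos_t, pos_t))
print(torch.allclose(out1, out2, atol=1e-4))  # True
```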
#Idea #TODO: KANs for learnable edge activations in MACE - to have it as an option. Train on the same set.