Data Handling in PyTorch Geometric - Part 1
Inspired by Antonio Longa
import numpy as np
import torch
import torch_geometric.datasets as datasets
import torch_geometric.data as data
import torch_geometric.transforms as transforms
import torch_geometric.loader as loader
import networkx as nx
from torch_geometric.utils.convert import to_networkx
# Build the raw tensors for a random toy graph: 100 nodes, 500 directed edges.

# Node feature matrix: one row of 16 random float features per node.
node_features = torch.rand((100, 16), dtype=torch.float)
print(node_features.shape)

# Random source (rows) and target (cols) node indices for the 500 edges.
rows = np.random.choice(100, 500)
cols = np.random.choice(100, 500)
# Stack into a single (2, 500) ndarray before converting: creating a tensor
# from a Python list of ndarrays is slow and makes PyTorch emit a warning.
# The explicit .long() pins the index dtype to int64, since NumPy's default
# integer width is platform-dependent.
edges = torch.from_numpy(np.stack([rows, cols])).long()
print(rows.shape)
print(cols.shape)
print(edges.shape)

# One categorical attribute per edge, taking one of three values 0, 1, 2.
# Converted to a tensor so it has the same type as the other graph attributes
# rather than staying a raw NumPy array.
edges_attr = torch.from_numpy(np.random.choice(3, 500)).long()

# Binary target class (0 or 1) for each node.
ys = torch.rand(100).round().long()
Storing the graph information in a PyG `Data` object
# Bundle node features, connectivity, edge attributes and node targets into
# a single graph object.
graph = data.Data(
    x=node_features,
    edge_index=edges,
    edge_attr=edges_attr,
    y=ys,
)
graph  # bare expression: only displays something in an interactive cell
type(graph)

# Print each item the Data object yields when iterated.
for prop in graph:
    print(prop)

# Convert to a NetworkX graph for drawing; nodes are coloured by target class.
vis = to_networkx(graph)
node_labels = graph.y.numpy()

import matplotlib.pyplot as plt

plt.figure(1, figsize=(15, 13))
nx.draw(
    vis,
    node_color=node_labels,
    cmap=plt.get_cmap('Set3'),
    node_size=70,
    linewidths=6,
)
plt.show()
With `Batch`, we can represent several graphs as a single large disconnected graph
# graph2 is a second reference to the same Data object (not a copy), so the
# batch below holds the same graph twice.
graph2 = graph

# Call from_data_list on the Batch class itself; creating an empty Batch()
# instance first just to reach the method is unnecessary.
batch = data.Batch.from_data_list([graph, graph2])

print("Number of graphs:", batch.num_graphs)
print("Graph at index 1:", batch[1])
print("Retrieve the list of graphs:\n", len(batch.to_data_list()))
For each convolutional layer, sample at most a fixed number of nodes from each neighbourhood (as in GraphSAGE)
# NeighborSampler is taken from torch_geometric.loader (already imported as
# `loader` and used for DataLoader in this file); the copy exposed under
# torch_geometric.data is a deprecated alias.
sampler = loader.NeighborSampler(
    graph.edge_index,
    sizes=[3, 10],  # max number of neighbours sampled per node at each layer
    batch_size=4,
    shuffle=False,
)

# Look at the first batch of sampled neighbours only.
for s in sampler:
    print(s)
    break

print("Batch Size:", s[0])
print("Number of unique nodes involved in the sampling:", len(s[1]))
print(
    "Number of neighbors sampled:",
    len(s[2][0].edge_index[0]),
    len(s[2][1].edge_index[0]),
)
datasets.__all__  # bare expression: lists the exported dataset names in a notebook

name = 'Cora'

# Steps chained by `transform`: a random node split followed by TargetIndegree.
transform_steps = [
    transforms.RandomNodeSplit('train_rest', num_val=500, num_test=500),
    transforms.TargetIndegree(),
]
transform = transforms.Compose(transform_steps)

# Per the PyG dataset docs:
# - pre_transform runs once, before the processed dataset is saved to disk;
# - transform is applied to a data object every time it is accessed;
# - if the dataset is already present under ./data, running this again loads
#   it from the local copy instead of downloading it.
cora = datasets.Planetoid(
    './data',
    name,
    transform=transform,
    pre_transform=transforms.NormalizeFeatures(),
)
# Load the AIDS graph-classification dataset from the TU collection.
aids = datasets.TUDataset(root='./data', name="AIDS")

# Summary of AIDS.
print("AIDS info:")
print("# of graphs:", len(aids))
print("# Classes {graphs}", aids.num_classes)
print("# Edge features", aids.num_edge_features)
print("# Edge labels", aids.num_edge_labels)
print("# Node features", aids.num_node_features)

# Summary of Cora. The classes here label nodes, not graphs, so the label says
# {nodes} instead of the {graphs} wording used for AIDS above. The edge-label
# count is not printed for Cora.
print("Cora info:")
print("# of graphs:", len(cora))
print("# Classes {nodes}", cora.num_classes)
print("# Edge features", cora.num_edge_features)
print("# Node features", cora.num_node_features)
Note: unlike the AIDS dataset, the Cora dataset does not have a `num_edge_labels` attribute
# The bare expressions below only display something in an interactive cell.
aids.data  # the dataset's single underlying data object
aids[0]  # info of one specific graph
cora.data
cora[0]

# Wrap Cora in a DataLoader and print only the first batch it yields.
cora_loader = loader.DataLoader(cora)
for mini_batch in cora_loader:
    print(mini_batch)
    break