import numpy as np
import torch
import torch_geometric.datasets as datasets
import torch_geometric.data as data
import torch_geometric.transforms as transforms
import torch_geometric.loader as loader
import networkx as nx
from torch_geometric.utils.convert import to_networkx

Data Manipulation

# 100 nodes, each described by a 16-dimensional feature vector drawn uniformly from [0, 1).
node_features = torch.rand(100, 16, dtype=torch.float32)
print(node_features.shape)
torch.Size([100, 16])
rows = np.random.choice(100, 500)  # first endpoint of each of the 500 random edges
cols = np.random.choice(100, 500)  # second endpoint of each of the 500 random edges
# Stack into one (2, 500) ndarray before converting: torch.tensor() on a Python
# list of ndarrays is slow and emits a UserWarning. Cast to int64 because PyG
# expects edge_index to be a long tensor, while NumPy's default integer type
# is platform-dependent on older NumPy versions.
edges = torch.from_numpy(np.stack([rows, cols])).long()
print(rows.shape)
print(cols.shape)
print(edges.shape)
(500,)
(500,)
torch.Size([2, 500])
# One categorical attribute per edge, drawn from {0, 1, 2}. Stored as a torch
# tensor rather than a NumPy array, because PyG utilities such as batching and
# `.to(device)` are built around tensor attributes.
edges_attr = torch.from_numpy(np.random.choice(3, 500)).long()
# Binary target class (0 or 1) for each of the 100 nodes.
ys = torch.rand(100).round().long()

Creating the graph as a PyG `Data` object

# Bundle node features, connectivity, edge attributes and node targets into a
# single PyG graph object.
graph = data.Data(
    x=node_features,
    edge_index=edges,
    edge_attr=edges_attr,
    y=ys,
)
graph
Data(x=[100, 16], edge_index=[2, 500], edge_attr=[500], y=[100])
type(graph)  # notebook display: the concrete class of the graph object
torch_geometric.data.data.Data
# Iterating a Data object yields (attribute_name, value) pairs; print each pair.
for attr_name, attr_value in graph:
    print((attr_name, attr_value))
('x', tensor([[0.3260, 0.5614, 0.7219,  ..., 0.8333, 0.2684, 0.6991],
        [0.2552, 0.7965, 0.5601,  ..., 0.2872, 0.7215, 0.5561],
        [0.4680, 0.9911, 0.2609,  ..., 0.7926, 0.7183, 0.5728],
        ...,
        [0.8130, 0.6181, 0.0477,  ..., 0.4111, 0.1915, 0.2908],
        [0.5820, 0.0549, 0.5638,  ..., 0.9298, 0.0790, 0.3821],
        [0.8570, 0.0329, 0.2215,  ..., 0.7657, 0.5044, 0.0548]]))
('edge_index', tensor([[22, 46, 57, 33, 51, 36, 73, 53, 98, 44, 19,  1, 43, 18, 75, 40, 28, 29,
         81, 16, 81, 77, 19, 28, 26, 56, 14, 65, 56, 23, 29, 32, 87, 62,  8, 82,
         17, 67, 27, 56, 39,  4, 34, 77, 45, 96, 28, 46, 32, 12, 15, 32,  8, 59,
         93, 49, 24, 50, 62, 95, 17, 16, 95, 61, 77,  2, 75, 32, 33, 19, 96, 49,
         33, 74, 61, 65, 13,  8,  0, 71, 29, 78, 24, 10, 35, 66, 49,  6, 38, 93,
          6, 54, 99, 41, 29, 59, 18, 62, 24, 48, 87, 77, 58, 60, 62, 37, 36, 31,
          7, 39, 12, 43, 26, 50, 37, 53, 98, 99, 58, 12,  7, 89, 78, 26, 81, 25,
         13, 79, 28, 21,  4, 57, 63,  0, 25, 17, 64, 59, 59, 96,  9, 98, 78,  5,
         10, 17, 63, 45, 68, 54, 67,  3, 47, 19, 49, 35, 24, 22, 29, 28, 39, 76,
         63, 29, 35, 92, 97, 62, 96, 79, 78, 84, 41, 39, 28,  3, 53, 46, 88, 79,
         39, 77, 94, 10, 10, 33, 63,  4, 73, 35, 10, 68, 99, 72, 58, 96, 10, 34,
         42, 62, 62,  7, 86, 39, 27, 42, 32, 91, 52, 52, 71, 36, 10, 93, 20,  6,
         65, 39, 33, 46, 23, 92, 84, 57, 29, 27, 29, 84, 94, 32, 76, 68, 32, 56,
         64, 12, 49, 48, 63, 92, 71, 25,  0, 15, 93,  0,  4, 60, 86, 68, 32,  0,
         13, 83,  8, 29, 37,  7, 49, 17, 37, 57, 40, 27, 26, 74, 74, 21, 90, 88,
         53, 40, 39, 36,  5,  5,  3, 82, 44, 19, 69, 76,  0, 58, 33, 26, 60, 52,
         27, 26, 18, 82, 40,  9, 18, 51,  4, 47, 53, 46, 59, 25, 55, 68, 75, 72,
         24, 22, 15, 92, 73, 12, 18, 15, 73, 67, 12, 16, 55, 13, 45, 64, 80, 33,
         17, 62,  1, 43, 29, 16, 51, 44, 14, 37, 71, 17, 55, 58, 93, 94, 21, 54,
         61, 81, 96, 34, 50, 97, 81, 77, 44, 64, 82, 26, 16, 25, 31, 66, 11, 74,
         24, 12, 50, 88, 69, 98, 26, 59, 35, 77, 68, 37, 79,  9, 50, 41, 76, 17,
         86, 36, 96, 72, 92, 31, 73, 32, 64, 37, 69,  2, 45, 14, 51, 19, 72, 65,
         51, 79, 11, 89, 61, 88, 85, 58,  4,  9, 93, 89, 78,  3, 15, 58, 16, 25,
         32, 99, 82, 33,  3, 22, 29, 40, 24, 81, 38, 56,  2, 72, 20,  9, 10, 58,
         51,  1, 46, 41, 62, 32, 14, 74, 56, 94, 27,  5, 69,  2, 30, 35, 59, 15,
         30, 72, 20,  5, 86, 28, 15, 26, 62, 57, 37,  6, 94, 84, 92,  2,  4, 63,
         53, 42, 62,  6, 71, 87, 17,  6, 84, 10, 27, 59, 87,  4,  6, 23, 23, 57,
         39, 20, 31, 97, 86, 52, 36, 26, 80, 77, 92, 54, 94, 29],
        [81, 22, 82, 30, 24, 54, 37, 31, 26, 45, 35, 26, 93, 28,  2, 42, 71, 40,
         73, 96, 90, 18, 44, 36, 39, 92, 73, 18, 38,  2, 91, 16, 43, 49, 39, 23,
         18,  3, 78, 91, 40, 73,  6, 47, 85, 46, 36, 67, 73, 60, 72, 18, 21, 15,
         15, 46,  3, 48, 82,  2, 54, 41,  9, 32, 71, 80, 60, 80, 45, 80, 35, 76,
         41, 59,  3, 97, 46, 84, 58, 31, 16, 52, 84, 85, 66,  6, 32, 83, 52, 81,
         86, 36, 58, 11, 48, 27, 52, 92, 73, 38,  4, 85,  9, 55, 20, 73, 44, 37,
         58, 86,  4, 39, 77, 61, 91, 41, 85, 75, 74, 88, 84, 28, 14,  8, 90, 96,
         61, 18, 24, 16, 19,  1, 86, 35, 74, 65, 63, 17, 32, 62, 23, 67, 31,  8,
         93, 62, 62, 76,  0, 20, 18, 69, 29, 94, 25,  6, 42, 46, 46, 32,  7, 80,
         13, 38, 67, 29, 11,  7, 57, 28, 62, 91, 17, 15, 88, 69, 55,  8, 19, 18,
         67, 90, 28, 58, 86, 13, 64, 51, 47, 61, 31, 57, 60, 69, 64, 39, 91,  3,
         57, 15, 27, 31, 15, 39, 16, 34, 11, 91, 86, 88, 23, 26, 45, 88, 41, 89,
         76,  3, 84, 26, 42, 29, 23, 55, 94, 55, 18,  2, 40, 13,  0, 54, 89, 88,
         25, 71, 49, 50, 28, 68, 36, 74, 68, 66, 51, 56, 85, 58, 61, 20, 11, 19,
         48, 19, 86, 56, 27, 44, 65, 82, 57, 60, 68,  2, 54,  9, 51, 97, 70, 21,
         95, 91, 50, 61, 47, 46, 94, 92, 24, 34, 50, 24, 58, 72, 34, 30, 16, 96,
         95, 56, 34, 93, 25, 26, 44, 26, 96, 14, 15, 48, 46, 30, 97, 89, 31, 99,
         58, 18, 45, 49, 80,  5,  4,  7, 88, 46, 17, 33,  4, 72, 94, 30, 22, 64,
         66, 95, 88, 44, 19, 63, 68, 12, 49, 67, 82, 49, 38, 91, 53, 34, 87, 19,
         35, 19, 40, 46,  4, 34,  5, 78,  0, 47, 98,  5, 77, 29, 15, 88, 17, 15,
         15,  8, 61, 59, 31, 44, 71, 57, 12, 47, 13, 28, 61, 13, 96, 93, 76, 62,
         57,  7, 10, 93, 45, 94,  8, 90, 74, 33, 66, 67, 52, 76, 28, 88, 95, 72,
         90, 69, 17, 63, 87, 83,  6, 14, 94, 67, 64, 85,  6,  3, 76, 25, 59, 93,
         64, 19, 53, 42, 98, 66, 96, 83, 18, 64, 55, 58, 73, 15, 35, 52, 67, 37,
         10, 52, 83, 20, 58, 93,  1, 74, 85, 16, 79, 59, 65, 41, 91, 78, 54, 48,
          3, 83, 41,  0, 39, 20, 40, 32,  7, 37, 57, 22, 35, 70, 13, 71, 41, 41,
         63, 11, 89, 26, 56, 57, 95, 72, 28, 34, 38, 25,  5, 65, 38,  7, 94, 35,
         38, 65, 25, 12, 56, 70, 49, 98, 56, 39, 99, 98, 28, 35]]))
('edge_attr', array([0, 1, 1, 0, 1, 2, 1, 0, 1, 1, 2, 0, 2, 2, 0, 1, 1, 2, 1, 2, 2, 0,
       2, 0, 2, 2, 0, 0, 2, 0, 2, 1, 2, 1, 2, 1, 0, 2, 2, 1, 0, 1, 2, 1,
       0, 2, 1, 0, 2, 0, 1, 0, 0, 1, 0, 2, 2, 2, 1, 1, 0, 0, 2, 2, 2, 0,
       1, 2, 1, 1, 0, 2, 0, 0, 1, 2, 0, 0, 2, 0, 2, 2, 2, 0, 2, 2, 0, 1,
       0, 0, 2, 0, 0, 0, 2, 1, 0, 1, 0, 0, 1, 2, 2, 0, 2, 2, 0, 1, 0, 1,
       2, 0, 1, 2, 2, 2, 0, 0, 2, 0, 0, 0, 2, 2, 0, 0, 2, 1, 2, 1, 0, 2,
       1, 1, 0, 0, 0, 0, 2, 0, 1, 1, 2, 2, 2, 1, 2, 0, 2, 0, 0, 0, 0, 0,
       1, 1, 2, 1, 1, 1, 2, 2, 2, 0, 1, 1, 2, 0, 0, 1, 1, 0, 2, 0, 1, 1,
       0, 2, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 2, 2, 1, 0, 2, 0, 0, 1, 2,
       2, 0, 0, 2, 2, 2, 1, 1, 1, 2, 2, 0, 1, 2, 1, 1, 2, 0, 0, 0, 1, 2,
       2, 2, 2, 1, 1, 0, 2, 0, 2, 1, 0, 0, 2, 1, 1, 0, 0, 1, 2, 2, 1, 0,
       2, 1, 0, 1, 0, 2, 1, 1, 0, 1, 1, 2, 2, 2, 0, 2, 0, 0, 0, 1, 2, 0,
       0, 2, 2, 0, 0, 1, 0, 2, 1, 2, 2, 2, 0, 1, 2, 1, 1, 1, 2, 1, 0, 2,
       2, 1, 2, 1, 1, 1, 0, 0, 2, 1, 1, 1, 1, 0, 1, 1, 0, 0, 2, 2, 0, 1,
       2, 1, 1, 2, 0, 1, 0, 2, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 2, 2, 1,
       0, 2, 0, 1, 0, 2, 0, 1, 0, 2, 0, 0, 0, 1, 2, 2, 1, 2, 1, 2, 2, 2,
       1, 1, 2, 0, 1, 2, 0, 1, 2, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1,
       1, 2, 2, 2, 2, 1, 0, 2, 1, 1, 2, 2, 2, 0, 0, 2, 1, 0, 2, 1, 2, 2,
       2, 0, 2, 0, 0, 0, 1, 1, 1, 1, 2, 0, 1, 0, 0, 0, 0, 2, 2, 0, 1, 1,
       1, 2, 0, 0, 2, 1, 2, 0, 1, 1, 2, 1, 1, 0, 0, 2, 0, 1, 1, 1, 0, 1,
       1, 2, 1, 2, 2, 1, 0, 2, 0, 1, 1, 2, 2, 0, 2, 0, 1, 0, 1, 0, 1, 2,
       1, 1, 0, 2, 2, 1, 1, 0, 0, 1, 1, 0, 0, 1, 2, 1, 2, 1, 2, 0, 2, 1,
       0, 0, 0, 0, 1, 1, 2, 1, 2, 1, 1, 0, 0, 2, 1, 1]))
('y', tensor([0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0,
        0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0,
        1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1,
        1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0,
        1, 1, 0, 0]))
import matplotlib.pyplot as plt

# Convert the PyG graph to a networkx graph so networkx can draw it.
vis = to_networkx(graph)
# Colour each node by its class label from graph.y.
node_labels = graph.y.numpy()

plt.figure(1, figsize=(15, 13))
nx.draw(
    vis,
    cmap=plt.get_cmap('Set3'),
    node_color=node_labels,
    node_size=70,
    linewidths=6,
)
plt.show()

Batch

With `Batch`, several graphs can be represented as a single large graph whose components are disconnected from each other.

# Batch two graphs together. `graph2` is the same object as `graph` (not a
# copy); the collated batch still contains two graphs.
graph2 = graph
# `from_data_list` is a classmethod, so call it on the class directly instead
# of first constructing a throwaway empty Batch instance.
batch = data.Batch.from_data_list([graph, graph2])
print("Number of graphs:", batch.num_graphs)
print("Graph at index 1:", batch[1])
print("Retrieve the list of graphs:\n", len(batch.to_data_list()))
Number of graphs: 2
Graph at index 1: Data(x=[100, 16], edge_index=[2, 500], edge_attr=[500], y=[100])
Retrieve the list of graphs:
 2

Cluster

 
 

Sampler

For each convolutional layer, sample at most a fixed number of neighbours from each node's neighbourhood (as in GraphSAGE).

# GraphSAGE-style neighbour sampling: sizes[l] is the maximum number of
# neighbours sampled per node in layer l (here 3, then 10).
# `loader.NeighborSampler` replaces the deprecated `data.NeighborSampler` alias.
sampler = loader.NeighborSampler(graph.edge_index, sizes=[3, 10], batch_size=4, shuffle=False)
/home/siddy/anaconda3/envs/torch/lib/python3.8/site-packages/torch_geometric/deprecation.py:13: UserWarning: 'data.NeighborSampler' is deprecated, use 'loader.NeighborSampler' instead
  warnings.warn(out)
# Take only the first mini-batch produced by the sampler and show it.
s = next(iter(sampler))
print(s)
(4, tensor([ 0,  1,  2,  3, 68,  5, 44, 57, 14, 84, 27, 23, 30, 24, 39, 76, 95, 75,
        67, 61, 34, 92, 51, 40, 87, 81, 12, 26, 98, 18,  7, 19, 36, 43, 96, 86,
        37, 42, 59, 78, 47, 58, 33,  8, 62, 82,  9, 71, 64, 25, 28, 77]), [EdgeIndex(edge_index=tensor([[ 4,  5,  6, 15,  7,  8,  9, 10, 11, 16, 17,  3, 12, 13, 14, 18, 19, 20,
          0, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,  4, 24, 34, 35,
         36, 36, 37, 38, 39, 40, 41, 13, 30, 42, 43, 36, 38, 44,  9, 45, 46, 47,
         27, 42, 48, 49,  6, 15, 22, 50, 14, 27, 33, 34, 35, 43, 51],
        [ 0,  0,  0,  0,  1,  1,  2,  2,  2,  2,  2,  3,  3,  3,  3,  3,  3,  3,
          4,  4,  4,  4,  5,  5,  5,  5,  6,  6,  6,  6,  6,  6,  7,  7,  7,  7,
          7,  7,  7,  7,  8,  8,  8,  9,  9,  9,  9, 10, 10, 10, 11, 11, 11, 11,
         12, 12, 12, 12, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14]]), e_id=tensor([148, 453, 350, 230, 131, 438, 227, 263,  29,  59,  14, 409, 450,  56,
        217,  37,  74, 197, 242, 239, 330, 262, 480, 348, 311, 353, 365, 294,
        257,  22, 106, 327, 191, 473, 168, 378, 260, 460, 198, 367, 122, 297,
        403,  82, 120, 218,  77, 256,  95, 200, 222,  35, 140, 210, 285,   3,
        321, 301, 278, 281,   4, 128, 203,  24, 111, 195, 454,  34, 495]), size=(52, 15)), EdgeIndex(edge_index=tensor([[ 4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14],
        [ 0,  0,  0,  1,  1,  2,  2,  2,  3,  3,  3]]), e_id=tensor([148, 453, 350, 131, 438, 227, 263,  29, 450,  56, 217]), size=(15, 4))])
# A sampled mini-batch is (batch_size, n_id, adjs): the number of target nodes,
# the ids of every node touched by the sampling, and one EdgeIndex per layer.
sampled_batch_size, sampled_n_id, sampled_adjs = s
print("Batch Size:", sampled_batch_size)
print("Number of unique nodes involved in the sampling:", len(sampled_n_id))
print(
    "Number of neighbors sampled:",
    sampled_adjs[0].edge_index.size(1),
    sampled_adjs[1].edge_index.size(1),
)
Batch Size: 4
Number of unique nodes involved in the sampling: 52
Number of neighbors sampled: 69 11

Datasets

datasets.__all__  # notebook display: public names exported by torch_geometric.datasets
['KarateClub',
 'TUDataset',
 'GNNBenchmarkDataset',
 'Planetoid',
 'FakeDataset',
 'FakeHeteroDataset',
 'NELL',
 'CitationFull',
 'CoraFull',
 'Coauthor',
 'Amazon',
 'PPI',
 'Reddit',
 'Reddit2',
 'Flickr',
 'Yelp',
 'AmazonProducts',
 'QM7b',
 'QM9',
 'MD17',
 'ZINC',
 'MoleculeNet',
 'Entities',
 'RelLinkPredDataset',
 'GEDDataset',
 'AttributedGraphDataset',
 'MNISTSuperpixels',
 'FAUST',
 'DynamicFAUST',
 'ShapeNet',
 'ModelNet',
 'CoMA',
 'SHREC2016',
 'TOSCA',
 'PCPNetDataset',
 'S3DIS',
 'GeometricShapes',
 'BitcoinOTC',
 'ICEWS18',
 'GDELT',
 'DBP15K',
 'WILLOWObjectClass',
 'PascalVOCKeypoints',
 'PascalPF',
 'SNAPDataset',
 'SuiteSparseMatrixCollection',
 'TrackMLParticleTrackingDataset',
 'AMiner',
 'WordNet18',
 'WordNet18RR',
 'WikiCS',
 'WebKB',
 'WikipediaNetwork',
 'Actor',
 'OGB_MAG',
 'DBLP',
 'MovieLens',
 'IMDB',
 'LastFM',
 'HGBDataset',
 'JODIEDataset',
 'MixHopSyntheticDataset',
 'UPFD',
 'GitHub',
 'FacebookPagePage',
 'LastFMAsia',
 'DeezerEurope',
 'GemsecDeezer',
 'Twitch',
 'Airports',
 'BAShapes',
 'MalNetTiny',
 'OMDB',
 'PolBlogs',
 'EmailEUCore',
 'StochasticBlockModelDataset',
 'RandomPartitionGraphDataset',
 'LINKXDataset']
name = 'Cora'
# `transform` is applied on EVERY access to the dataset (e.g. cora[0]) and is
# not stored on disk:
#   - RandomNodeSplit: overwrites train/val/test masks with a random split
#     (500 val nodes, 500 test nodes, all remaining nodes train).
#   - TargetIndegree: adds an edge_attr holding the target node's in-degree.
transform = transforms.Compose([
    transforms.RandomNodeSplit(split='train_rest', num_val=500, num_test=500),
    transforms.TargetIndegree(),
])
# `pre_transform` runs once, when the raw files are processed, and its result
# is saved to disk. If ./data already holds the processed dataset, nothing is
# downloaded or reprocessed, so changing `pre_transform` later has no effect
# until the processed files are deleted.
# NOTE(review): because RandomNodeSplit lives in `transform`, every cora[0]
# draws a NEW random split. Index the dataset once and reuse that Data object
# for training and evaluation, or move the split into `pre_transform`.
cora = datasets.Planetoid('./data', name, pre_transform=transforms.NormalizeFeatures(), transform=transform)
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!
# AIDS graph-classification dataset from the TU collection, stored under ./data.
aids = datasets.TUDataset(name="AIDS", root='./data')
Downloading https://www.chrsmrrs.com/graphkerneldatasets/AIDS.zip
Extracting data/AIDS/AIDS.zip
Processing...
Done!

Cora and AIDS Datasets

# Summary of the AIDS dataset. Its targets are one label per graph, so the
# class count refers to graph labels. (The old label "{graphs}" looked like a
# broken f-string placeholder.)
print("AIDS info:")
print("# of graphs:", len(aids))
print("# Classes (graphs)", aids.num_classes)
print("# Edge features", aids.num_edge_features)
print("# Edge labels", aids.num_edge_labels)
print("# Node features", aids.num_node_features)
AIDS info:
# of graphs: 2000
# Classes {graphs} 2
# Edge features 3
# Edge labels 3
# Node features 38
# Summary of the Cora dataset: a single graph whose targets are one label per
# node, so the class count refers to node labels, not graph labels.
# The single edge feature comes from the TargetIndegree transform.
print("Cora info:")
print("# of graphs:", len(cora))
print("# Classes (nodes)", cora.num_classes)
print("# Edge features", cora.num_edge_features)
print("# Node features", cora.num_node_features)
Cora info:
# of graphs: 1
# Classes {graphs} 7
# Edge features 1
# Node features 1433

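Cora (a `Planetoid` dataset) does not have a `num_edge_labels` attribute; that attribute is available on `TUDataset`, which is used for AIDS above.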

aids.data
# internal storage: all graphs of the dataset collated into ONE Data object (y has one entry per graph)
Data(x=[31385, 38], edge_index=[2, 64780], edge_attr=[64780, 3], y=[2000])
aids[0]
# a single graph (the first one), extracted from the collated storage
Data(edge_index=[2, 106], x=[47, 38], edge_attr=[106, 3], y=[1])
cora.data  # internal storage, bypasses `transform`: no TargetIndegree edge_attr here
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
cora[0]  # indexing applies `transform`: adds TargetIndegree edge_attr; RandomNodeSplit re-draws the masks on each access
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708], edge_attr=[10556, 1])
# Wrap the dataset in a DataLoader and inspect the first mini-batch it yields.
cora_loader = loader.DataLoader(cora)
l = next(iter(cora_loader))
print(l)
DataBatch(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708], edge_attr=[10556, 1], batch=[2708], ptr=[2])