Approach to the Problem
Learn how to solve a node classification problem.
Our approach
We use the PyTorch Geometric library, which implements several GNN algorithms, to train a GNN for this problem.
The first step in this approach is to build a data handler. The torch_geometric.data module provides a few classes for this purpose, including the Data class used in this lesson.
Custom data handler
Let's take a look at the following code, in which we create a custom data handler for our graph data using the PyTorch Geometric library. The code assumes that the graph G already exists:
import torch
from torch_geometric.data import Data
import pandas as pd
from sklearn.model_selection import train_test_split

# create edge index of the graph
edge_index = torch.tensor(list(G.edges()), dtype=torch.long).t().contiguous()

# create the edge weight tensor
edge_weights = [G[u][v]['contact'] for u, v in list(G.edges())]
edge_weight = torch.tensor(edge_weights, dtype=torch.float)

# create node features
# Create a dataframe from the graph nodes
df = pd.DataFrame(dict(G.nodes(data=True))).T

# convert selected features to tensor
node_features = torch.tensor(
    df[['tested', 'symptoms', 'vaccinated', 'mobility']].astype(float).values,
    dtype=torch.float)

# labels
node_labels = df.label.map({'infected': 1, 'not infected': 0})
y = torch.from_numpy(node_labels.values).type(torch.long)

# create train and test masks
X_train, X_test, y_train, y_test = train_test_split(pd.Series(G.nodes()),
                                                    node_labels,
                                                    stratify=node_labels,
                                                    test_size=0.20,
                                                    random_state=56)

n_nodes = G.number_of_nodes()
train_mask = torch.zeros(n_nodes, dtype=torch.bool)
test_mask = torch.zeros(n_nodes, dtype=torch.bool)
train_mask[X_train.index] = True
test_mask[X_test.index] = True

# create torch_geometric Data object
data = Data(x=node_features, edge_index=edge_index, edge_weight=edge_weight,
            y=y, train_mask=train_mask, test_mask=test_mask,
            num_classes=2, num_features=node_features.shape[1])  # number of feature columns

print(data)
Let’s look at the code explanation below:
Line 7: Creates the edge index of the graph, a tensor of shape [2, num_edges] that holds the source and target node of every edge. This is the default input format used in this library.
Lines 10–11: Create an edge weight tensor using the contact details of the graph.
Line 15: Creates a DataFrame of all the node features.
Lines 18–20: Select relevant features and convert them into a PyTorch tensor.
Lines 23–24: Map the categorical labels to numerical values and create a tensor of node labels.
Lines 27–31: Split the nodes and labels into training and testing sets in a ratio of 80:20 using a stratified split. This keeps the proportion of infected to not infected cases approximately the same in both sets. ...
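The stratified split described above can be sketched on toy data as follows. The 20 made-up labels (5 infected, 15 not infected) and the random_state value are assumptions chosen so that the class proportions divide evenly. They are not the lesson's data.

```python
import numpy as np
import torch
from sklearn.model_selection import train_test_split

# Hypothetical labels for 20 nodes: 5 infected (1) and 15 not infected (0).
labels = np.array([1] * 5 + [0] * 15)
node_ids = np.arange(len(labels))

# 80:20 split that preserves the class proportions in both parts.
train_ids, test_ids = train_test_split(
    node_ids, test_size=0.20, stratify=labels, random_state=0)

# Boolean masks over all nodes, set to True at the selected positions.
train_mask = torch.zeros(len(labels), dtype=torch.bool)
test_mask = torch.zeros(len(labels), dtype=torch.bool)
train_mask[torch.as_tensor(train_ids, dtype=torch.long)] = True
test_mask[torch.as_tensor(test_ids, dtype=torch.long)] = True

# Share of infected nodes in each part.
train_share = labels[train_ids].mean()
test_share = labels[test_ids].mean()
print(int(train_mask.sum()), int(test_mask.sum()))  # 16 4
print(train_share, test_share)                      # 0.25 0.25
```

Here "stratified" means that each set mirrors the class ratio of the full data (25% infected in this toy case). It does not mean that the two classes are balanced against each other.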