bird_cloud_gnn.gnn_model
Module for creating the GCN class.
Module Contents
Classes
GCN – Graph Convolutional Network construction module
- class bird_cloud_gnn.gnn_model.GCN(in_feats: int, layers_data: list)
Bases: torch.nn.Module
Graph Convolutional Network construction module
An n-layer GCN is constructed from the input features and a list of layers. Each layer computes new node representations by aggregating information from neighbouring nodes.
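To make the aggregation step concrete, here is a minimal pure-Python sketch of a stack of graph convolution layers. It assumes the symmetric-normalised propagation rule with self-loops from Kipf and Welling's GCN; this page does not say which layer types layers_data describes or how they normalise, so treat the rule as an illustration only. The helpers gcn_layer and gcn_forward are hypothetical and are not part of bird_cloud_gnn.

```python
import math


def relu(x):
    return max(0.0, x)


def gcn_layer(adj, features, weight, activation=None):
    """One graph convolution: aggregate neighbours, then apply a linear map.

    adj[i] lists the neighbours of node i (undirected graph, no self-loops);
    features[i] is node i's feature vector; weight is an in_dim x out_dim matrix.
    ASSUMPTION: symmetric normalisation 1/sqrt(d_i * d_j) with self-loops.
    """
    n = len(adj)
    # Add a self-loop so each node keeps part of its own representation.
    neighbours = [set(adj[i]) | {i} for i in range(n)]
    degree = [len(neighbours[i]) for i in range(n)]
    in_dim, out_dim = len(weight), len(weight[0])
    output = []
    for i in range(n):
        # Aggregate: normalised sum of the neighbours' feature vectors.
        aggregated = [0.0] * in_dim
        for j in neighbours[i]:
            norm = 1.0 / math.sqrt(degree[i] * degree[j])
            for k in range(in_dim):
                aggregated[k] += norm * features[j][k]
        # Transform: multiply the aggregated vector by the weight matrix.
        h = [sum(aggregated[k] * weight[k][c] for k in range(in_dim))
             for c in range(out_dim)]
        if activation is not None:
            h = [activation(v) for v in h]
        output.append(h)
    return output


def gcn_forward(adj, features, weights):
    """Apply len(weights) layers, with ReLU between layers but not after the last."""
    h = features
    for index, weight in enumerate(weights):
        is_last = index == len(weights) - 1
        h = gcn_layer(adj, h, weight, activation=None if is_last else relu)
    return h
```

On the path graph 0-1-2 with a feature of 1.0 only on node 0 and 1x1 identity weights, one layer leaves node 2 at 0.0, while two layers give it 1/6. Stacking n layers widens each node's receptive field to its n-hop neighbourhood.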
- Attributes:
- layers
The layers that make up the network.
- Type:
nn.ModuleList
- forward(g, in_feats)
Compute the output of the model.
- Parameters:
g – The graph whose structure is used to pass messages between nodes.
in_feats – The input features of the nodes.
- Returns:
The output of the final layer.
- evaluate(test_dataloader)
Evaluate the model on the test set.
- Parameters:
test_dataloader – Data loader for the test set, for example one whose samples are drawn with a SubsetRandomSampler.
- Returns:
Accuracy
- Return type:
float
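The accuracy that evaluate returns can be sketched as follows. This assumes the data loader yields (inputs, labels) batches and that the model produces one row of class scores per sample; both assumptions, and the helper name evaluate_batches, are hypothetical rather than taken from the module. The predicted class is the index of the highest score.

```python
def evaluate_batches(batches):
    """Hypothetical helper: accuracy over (class_scores, labels) batches.

    ASSUMPTION: each batch pairs one list of class scores per sample
    with the matching integer labels.
    """
    correct = 0
    total = 0
    for scores_batch, labels in batches:
        for scores, label in zip(scores_batch, labels):
            # argmax: index of the highest class score (first one on ties)
            predicted = max(range(len(scores)), key=scores.__getitem__)
            correct += int(predicted == label)
        total += len(labels)
    # Guard against an empty loader instead of dividing by zero.
    return correct / total if total else 0.0
```

Summing correct predictions and sample counts across batches, rather than averaging per-batch accuracies, weights every sample equally even when the last batch is smaller.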
- fit_and_evaluate(train_dataloader, test_dataloader, callback=None, learning_rate=0.01, num_epochs=20, sch_explr_gamma=0.99, sch_multisteplr_milestones=None, sch_multisteplr_gamma=0.1)
Fit the model, evaluating it on the test set after every epoch.
- Parameters:
train_dataloader (RandomWSubsetSampler) – Data loader for the training set.
test_dataloader (RandomWSubsetSampler) – Data loader for the test set.
callback (callable, optional) – Callback function. If defined, it receives a dict containing “Loss/train”, “Accuracy/train”, “Loss/test”, “Accuracy/test”, and “epoch” for a single epoch. Return True from the callback to stop training. Defaults to None.
learning_rate (float, optional) – Learning rate. Defaults to 0.01.
num_epochs (int, optional) – Number of training epochs. Defaults to 20.
sch_explr_gamma (float, optional) – The exponential decay rate of the learning rate. Defaults to 0.99.
sch_multisteplr_milestones (list, optional) – Epoch numbers at which the learning rate is multiplied by sch_multisteplr_gamma. If None, this is done at epoch 100. Defaults to None.
sch_multisteplr_gamma (float, optional) – Multiplicative factor applied to the learning rate at each milestone of the stepped decay. Defaults to 0.1.
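The following pure-Python sketch illustrates two of these parameters. learning_rate_at assumes that the exponential and stepped schedules are both advanced once per epoch and therefore compound, as happens when PyTorch's ExponentialLR and MultiStepLR are chained on one optimizer; this page does not confirm how the two are combined. make_early_stopping_callback builds a callback that follows the documented contract (it receives the per-epoch dict and returns True to stop) and requests a stop once “Loss/test” has not improved for patience epochs. Both helper names are hypothetical.

```python
from bisect import bisect_right


def learning_rate_at(epoch, learning_rate=0.01, sch_explr_gamma=0.99,
                     sch_multisteplr_milestones=None, sch_multisteplr_gamma=0.1):
    """Learning rate in effect during a 0-indexed epoch.

    ASSUMPTION: both schedules step once per epoch and their factors
    multiply (chained schedulers); not confirmed by the documentation.
    """
    # Documented default: with no milestones, the step decay happens at epoch 100.
    if sch_multisteplr_milestones is None:
        milestones = [100]
    else:
        milestones = sorted(sch_multisteplr_milestones)
    exp_factor = sch_explr_gamma ** epoch
    # Number of milestones already reached (milestone <= epoch).
    step_factor = sch_multisteplr_gamma ** bisect_right(milestones, epoch)
    return learning_rate * exp_factor * step_factor


def make_early_stopping_callback(patience=5):
    """Hypothetical callback: stop when "Loss/test" stops improving."""
    best_loss = float("inf")
    stale_epochs = 0
    history = []

    def callback(stats):
        nonlocal best_loss, stale_epochs
        history.append(stats)  # keep every epoch's dict for later inspection
        if stats["Loss/test"] < best_loss:
            best_loss = stats["Loss/test"]
            stale_epochs = 0
        else:
            stale_epochs += 1
        # Returning True is the documented stop signal.
        return stale_epochs >= patience

    callback.history = history
    return callback
```

Such a callback would be passed as callback=make_early_stopping_callback(patience=5) when calling fit_and_evaluate. Under the chaining assumption, the default settings give a learning rate of 0.01 * 0.99^e before epoch 100 and ten times less from epoch 100 onward.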