bird_cloud_gnn.cross_validation
Helper functions for cross validation.
Module Contents
Functions
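| get_dataloaders(dataset, train_idx, test_idx, batch_size) | Returns train and test dataloaders for a given dataset, train indices, test indices, and batch size. |
| kfold_evaluate(dataset, layers_data[, ...]) | Evaluate the model on a dataset using StratifiedKFold. |
| leave_one_origin_out_evaluate(dataset, layers_data[, ...]) | Evaluate the model on a dataset by looping over each origin, training on all data not from that origin, and testing on the data from that origin. |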
- bird_cloud_gnn.cross_validation.get_dataloaders(dataset, train_idx, test_idx, batch_size)
Returns train and test dataloaders for a given dataset, train indices, test indices, and batch size.
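- Parameters:
dataset – The dataset.
train_idx – Indices of the samples used for training.
test_idx – Indices of the samples used for testing.
batch_size (int) – Batch size used in the data loaders.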
- Returns:
A tuple containing the train and test dataloaders.
- Return type:
tuple
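A minimal, dependency-free sketch of the idea behind get_dataloaders: select the training and test samples by index and group each subset into batches of at most batch_size items. The names make_batches and get_dataloaders_sketch are hypothetical, the batches are plain lists rather than real data loader objects, and shuffling only the training split is an assumption, not documented behaviour of this module.

```python
import random


def make_batches(dataset, indices, batch_size, shuffle=False, seed=None):
    """Group dataset[i] for i in indices into lists of at most batch_size items."""
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    order = list(indices)
    if shuffle:
        # Shuffle a copy so the caller's index list is left untouched.
        random.Random(seed).shuffle(order)
    return [
        [dataset[i] for i in order[start:start + batch_size]]
        for start in range(0, len(order), batch_size)
    ]


def get_dataloaders_sketch(dataset, train_idx, test_idx, batch_size):
    """Hypothetical stand-in: batched train/test subsets selected by index."""
    # Assumption: only the training split is shuffled; the test order is kept.
    train_batches = make_batches(dataset, train_idx, batch_size, shuffle=True, seed=0)
    test_batches = make_batches(dataset, test_idx, batch_size)
    return train_batches, test_batches


# Usage: 7 training samples and 3 test samples, batches of 3.
train_batches, test_batches = get_dataloaders_sketch(
    list("abcdefghij"), range(7), [7, 8, 9], batch_size=3
)
```

Keeping the split as index lists means the same dataset object can be reused for every fold without copying samples.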
- bird_cloud_gnn.cross_validation.kfold_evaluate(dataset, layers_data, n_splits=5, learning_rate=0.01, num_epochs=100, batch_size=512)
Evaluate the model on a dataset using StratifiedKFold.
- Parameters:
dataset (RadarDataset) – The dataset.
layers_data (list) – List of layer specifications, each giving an input size and an activation function.
n_splits (int, optional) – Number of folds. Defaults to 5.
learning_rate (float, optional) – Learning rate. Defaults to 0.01.
num_epochs (int, optional) – Training epochs. Defaults to 100.
batch_size (int, optional) – Batch size used in the data loaders. Defaults to 512.
- Returns:
None
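The docstring names StratifiedKFold (presumably scikit-learn's sklearn.model_selection.StratifiedKFold). The sketch below reimplements the idea with the standard library only, assuming each sample has a single class label: samples are dealt into n_splits folds so that every class is spread evenly, then the model is trained on n_splits - 1 folds and scored on the remaining one, once per fold. stratified_folds, kfold_sketch, and the train_and_score callback are hypothetical names. Unlike kfold_evaluate, which is documented to return None, the sketch returns the per-fold scores so they can be inspected.

```python
from collections import defaultdict
from typing import Callable, Hashable, List, Sequence


def stratified_folds(labels: Sequence[Hashable], n_splits: int) -> List[List[int]]:
    """Split sample indices into n_splits folds with similar label proportions."""
    if n_splits < 2:
        raise ValueError("n_splits must be at least 2")
    by_class = defaultdict(list)
    for i, label in enumerate(labels):
        by_class[label].append(i)
    folds: List[List[int]] = [[] for _ in range(n_splits)]
    k = 0
    # Deal each class's samples round-robin; the shared counter k keeps the
    # overall fold sizes balanced across classes.
    for indices in by_class.values():
        for i in indices:
            folds[k % n_splits].append(i)
            k += 1
    return [sorted(fold) for fold in folds]


def kfold_sketch(
    labels: Sequence[Hashable],
    n_splits: int,
    train_and_score: Callable[[List[int], List[int]], object],
) -> list:
    """Hypothetical loop: train on n_splits - 1 folds, score on the held-out fold."""
    folds = stratified_folds(labels, n_splits)
    scores = []
    for held_out, test_idx in enumerate(folds):
        train_idx = sorted(
            i for j, fold in enumerate(folds) if j != held_out for i in fold
        )
        scores.append(train_and_score(train_idx, test_idx))
    return scores
```

Stratifying keeps the class balance of each test fold close to that of the whole dataset, which matters when one class is much rarer than the other.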
- bird_cloud_gnn.cross_validation.leave_one_origin_out_evaluate(dataset, layers_data, learning_rate=0.01, num_epochs=100, batch_size=512)
Evaluate the model on a dataset with leave-one-out validation over the origins: for each origin, train on all data not from that origin and test on the data from that origin.
- Parameters:
dataset (RadarDataset) – The dataset.
layers_data (list) – List of layer specifications, each giving an input size and an activation function.
learning_rate (float, optional) – Learning rate. Defaults to 0.01.
num_epochs (int, optional) – Training epochs. Defaults to 100.
batch_size (int, optional) – Batch size used in the data loaders. Defaults to 512.
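Leave-one-origin-out is leave-one-group-out cross validation with the origin as the group (scikit-learn offers the same scheme as LeaveOneGroupOut). A minimal standard-library sketch, assuming each sample carries a hashable origin label; the function name leave_one_origin_out_splits is hypothetical. One round is run per distinct origin, so the number of rounds is set by the data rather than by a parameter.

```python
from typing import Hashable, Iterator, List, Sequence, Tuple


def leave_one_origin_out_splits(
    origins: Sequence[Hashable],
) -> Iterator[Tuple[Hashable, List[int], List[int]]]:
    """Yield (origin, train_idx, test_idx), holding out one origin at a time."""
    # dict.fromkeys drops duplicate origins while keeping first-seen order.
    for held_out in dict.fromkeys(origins):
        train_idx = [i for i, o in enumerate(origins) if o != held_out]
        test_idx = [i for i, o in enumerate(origins) if o == held_out]
        yield held_out, train_idx, test_idx


# Usage: four samples from three origins give three train/test rounds.
splits = list(leave_one_origin_out_splits(["a", "b", "a", "c"]))
```

Holding out whole origins tests how well the model generalizes to a source it has never seen, which a random split cannot show when samples from the same origin are correlated.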