Training Setup
Defines the TrainingSetup class, which is an interface for training setups.
Functions that raise NotImplementedError are meant to be overridden.
To create a custom training setup, define a new class in the bfm.training.setups folder that
inherits from TrainingSetup and implements the required methods.
TrainingSetup
Interface for training setups.
Subclasses must implement the following methods:
- initialize_model: Initializes the model components and sets self.model_components.
- calculate_pretrain_loss: Calculates the pretraining loss for a batch of data.
- generate_frozen_features: Generates frozen features for a batch of data.
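The override pattern described above can be sketched as follows. This is a minimal illustration, not the library's code: the `TrainingSetup` base class here is a local stand-in so the example runs on its own (in practice it would come from bfm/training/training_setup.py, and its constructor arguments are not shown here), while `MyTrainingSetup`, `TinyModel`, and the toy loss are hypothetical. Nested lists stand in for tensors.

```python
class TrainingSetup:
    """Local stand-in for the real interface; each method must be overridden."""

    def __init__(self):
        self.model_components = {}

    def initialize_model(self):
        raise NotImplementedError

    def calculate_pretrain_loss(self, batch, output_accuracy=True):
        raise NotImplementedError

    def generate_frozen_features(self, batch):
        raise NotImplementedError


class TinyModel:
    """Hypothetical placeholder; a real component would be a PyTorch module."""

    def __call__(self, data):
        # One time bin per electrode: the mean over time samples.
        # Output shape: (batch_size, n_electrodes, 1)
        return [[[sum(ch) / len(ch)] for ch in sample] for sample in data]


class MyTrainingSetup(TrainingSetup):
    def initialize_model(self):
        # The interface requires self.model_components to be set here.
        self.model_components = {"model": TinyModel()}

    def calculate_pretrain_loss(self, batch, output_accuracy=True):
        features = self.model_components["model"](batch["data"])
        flat = [v for sample in features for ch in sample for v in ch]
        # Toy loss for illustration only: mean squared feature value.
        losses = {"mse_loss": sum(v * v for v in flat) / len(flat)}
        if output_accuracy:
            losses["accuracy"] = 0.0  # placeholder metric, used only for logging
        return losses

    def generate_frozen_features(self, batch):
        return self.model_components["model"](batch["data"])


setup = MyTrainingSetup()
setup.initialize_model()
# Shape (batch_size=1, n_electrodes=2, n_timesamples=2)
batch = {"data": [[[1.0, 3.0], [2.0, 2.0]]]}
losses = setup.calculate_pretrain_loss(batch)
features = setup.generate_frozen_features(batch)
```

Calling any of the three methods on the stand-in base class directly raises NotImplementedError, which is how a missing override would surface.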
Source code in bfm/training/training_setup.py
calculate_pretrain_loss(batch, output_accuracy=True)
Calculate the pretraining loss for a batch of data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | dict | Dictionary containing: data (torch.Tensor): shape (batch_size, n_electrodes, n_timesamples); electrode_index (torch.Tensor): shape (batch_size, n_electrodes); metadata (dict): contains subject identifier, trial id, sampling rate, etc. | required |
| output_accuracy | bool | Whether to output accuracy metrics | True |
Returns:
| Name | Type | Description |
|---|---|---|
| | dict | Dictionary of losses where keys are loss names and values are loss values. The final loss is the mean of all losses. Accuracies are exempt and used only for logging. |
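As a rough illustration of how such a dictionary might be reduced to the final loss, the sketch below averages every entry except accuracies. How accuracies are identified is not stated here; the sketch assumes their keys contain the substring "accuracy". `final_loss` is a hypothetical helper, not part of bfm, and plain floats stand in for loss tensors.

```python
def final_loss(loss_dict):
    """Mean of all losses, skipping entries assumed to be accuracies."""
    loss_values = [
        value
        for name, value in loss_dict.items()
        if "accuracy" not in name  # assumption: accuracy keys are named this way
    ]
    if not loss_values:
        raise ValueError("loss_dict contains no loss entries")
    return sum(loss_values) / len(loss_values)


# Hypothetical loss names, for illustration only.
example = {"contrastive_loss": 0.8, "reconstruction_loss": 0.4, "accuracy": 0.9}
total = final_loss(example)  # averages the two losses; the accuracy is ignored
```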
Source code in bfm/training/training_setup.py
calculate_pretrain_test_loss()
Calculate the pretraining test loss. This function uses the calculate_pretrain_loss function.
Returns:
| Name | Type | Description |
|---|---|---|
| | dict | Dictionary of losses where keys are loss names and values are loss values. The final loss is the mean of all losses. Accuracies are exempt and used only for logging. |
Source code in bfm/training/training_setup.py
generate_frozen_features(batch)
Generate frozen features (that is, features computed with the model weights frozen) for a batch of data. This function is used to evaluate the model on the benchmarks.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | dict | Dictionary containing: data (torch.Tensor): shape (batch_size, n_electrodes, n_timesamples); electrode_labels (list): list of length 1 (since it's the same across the batch), each element is a list of electrode labels; metadata (dict): contains subject identifier, trial id, sampling rate, etc. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| features | Tensor | Shape (batch_size, n_electrodes or n_electrodes+1, n_timebins, *), where * can be arbitrary. If the second dimension is n_electrodes+1, its first entry is the cls token. |
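The shape rule for the returned features can be expressed as a small check. This is a sketch under assumptions: `has_cls_token` is a hypothetical helper, not part of bfm, and it works on a plain shape tuple (for example `tuple(features.shape)`) rather than on a tensor.

```python
def has_cls_token(feature_shape, batch_size, n_electrodes):
    """Return True if the electrode axis carries one extra (cls) entry."""
    if len(feature_shape) < 3:
        raise ValueError("expected at least (batch_size, n_electrodes, n_timebins)")
    if feature_shape[0] != batch_size:
        raise ValueError("batch dimension mismatch")
    if feature_shape[1] == n_electrodes:
        return False
    if feature_shape[1] == n_electrodes + 1:
        return True
    raise ValueError("second dimension must be n_electrodes or n_electrodes + 1")


# 64 electrodes plus a cls token, 10 time bins, 128 trailing feature dims.
with_cls = has_cls_token((4, 65, 10, 128), batch_size=4, n_electrodes=64)
without_cls = has_cls_token((4, 64, 10), batch_size=4, n_electrodes=64)
```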
Source code in bfm/training/training_setup.py
generate_state_dicts()
This function generates the state dicts for the model components. It is used for saving and retrieving the model.
Source code in bfm/training/training_setup.py
get_preprocess_functions(pretraining=False)
Get the preprocess functions for the training setup. By default, the only preprocessing is to subset electrodes to the maximum number during pretraining (during eval, all electrodes are passed).
The return value must be a list of functions; each takes only batch as input and returns the (modified) batch. Modifying the batch in place is fine, but the function must still return it.
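A minimal sketch of such a list of functions and of how it might be applied. The names `make_subset_electrodes` and `apply_preprocess`, the `max_n_electrodes` parameter, and the choice to keep the first electrodes are all assumptions for illustration; the actual default may select electrodes differently. Nested lists stand in for tensors so the example runs without extra dependencies.

```python
def make_subset_electrodes(max_n_electrodes):
    """Build a preprocess function that keeps at most max_n_electrodes electrodes."""

    def subset_electrodes(batch):
        # Modifies the batch in place, but still returns it.
        batch["data"] = [sample[:max_n_electrodes] for sample in batch["data"]]
        return batch

    return subset_electrodes


def get_preprocess_functions(pretraining=False, max_n_electrodes=2):
    # Subset only during pretraining; during eval, pass all electrodes through.
    if pretraining:
        return [make_subset_electrodes(max_n_electrodes)]
    return []


def apply_preprocess(batch, functions):
    """Run each preprocess function in order, threading the batch through."""
    for fn in functions:
        batch = fn(batch)
    return batch


# One sample with 3 electrodes and 2 time samples each.
batch = {"data": [[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]]}
train_batch = apply_preprocess(
    {"data": list(batch["data"])}, get_preprocess_functions(pretraining=True)
)
eval_batch = apply_preprocess(batch, get_preprocess_functions(pretraining=False))
```

Returning the batch even when it is modified in place keeps every function usable in the same loop, whether it mutates its input or builds a new dictionary.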
Source code in bfm/training/training_setup.py
initialize_model()
This function initializes the model.
It must set self.model_components to a dictionary of the model components, such as {"model": model, "electrode_embeddings": electrode_embeddings}, where model and electrode_embeddings are PyTorch modules (their classes must inherit from bfm.model.base).
Source code in bfm/training/training_setup.py
load_dataloaders()
This function loads the dataloaders for the training and test sets.
It must set the self.train_dataloader and self.test_dataloader attributes to the dataloaders, which are used by the pretraining code in pretrain.py.
Source code in bfm/training/training_setup.py
model_parameters(verbose=False)
This function returns a list of all parameters in the model and stores the number of parameters in the config.
Source code in bfm/training/training_setup.py