
Trainers

mlpoppyns.learning.trainers.trainer_base

Base trainer.

Authors:

Alberto Garcia Garcia (garciagarcia@ice.csic.es)

BaseTrainer

Base trainer.

This class serves as a blueprint for creating trainers. It defines the essential methods that every trainer must implement, ensuring a consistent interface.
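A concrete trainer subclasses BaseTrainer and implements _train_epoch. Below is a minimal, hypothetical sketch (every name other than BaseTrainer and _train_epoch is illustrative, not library API); train() unpacks exactly this four-tuple and looks up the metric's class name in the result dictionaries:

import typing

class MyTrainer(BaseTrainer):
    def _train_epoch(self, epoch: int) -> typing.Tuple[dict, dict, dict, dict]:
        # ... iterate over batches, compute losses, update weights ...
        train_result = {"loss": 0.123, "MyMetric": 0.9}  # keyed by metric class name
        val_result = None  # None when no validation pass is performed
        train_eval_result = {"loss": 0.120, "MyMetric": 0.9}  # eval pass on the training set
        losses = {"target_a": 0.05, "target_b": 0.07}  # per-target losses
        return train_result, val_result, train_eval_result, losses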

Source code in mlpoppyns/learning/trainers/trainer_base.py
class BaseTrainer:
    """
    Base trainer.

    This class serves as a blueprint for creating trainers. It defines the
    essential methods that every trainer must implement, ensuring a
    consistent interface.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        criterion: LossBase,
        metric: MetricBase,
        optimizer: torch.optim.Optimizer,
        configuration: ConfigurationParser,
    ) -> None:
        """
        Trainer initialization.

        Args:
            model (torch.nn.Module): Network to train.
            criterion (LossBase): Loss criterion.
            metric (MetricBase): Metric to validate.
            optimizer (torch.optim.Optimizer): Optimizer for training.
            configuration (ConfigurationParser): Current experiment configuration.
        """

        self.configuration = configuration

        self.logger = configuration.get_logger(
            "trainer", configuration["trainer"]["verbosity"]
        )

        self.criterion = criterion
        self.metric = metric
        self.optimizer = optimizer

        # Set up GPU device(s) if available and move the model onto the configured device.
        self.logger.info(
            "Requesting {} GPUs...".format(configuration["n_gpu"])
        )
        self.device, device_ids = request_device(
            self.logger, configuration["n_gpu"]
        )
        self.logger.info("Devices obtained: {}".format(device_ids))
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.logger.info(
                f"{len(device_ids)} GPUs detected, running in parallel!"
            )
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        # Trainer configuration and parameter fetching from config dictionary.
        self.logger.info("Configuring trainer...")
        trainer_configuration = configuration["trainer"]
        self.epochs = trainer_configuration["epochs"]
        self.save_period = trainer_configuration["save_period"]
        self.monitor_best = self.metric.initial_value()
        self.early_stop = trainer_configuration.get("early_stop", inf)
        self.start_epoch = 1
        self.checkpoint_dir = self.configuration.save_dir
        self.log_dir = self.configuration.log_dir

        # Initialize training and validation JSONs.
        self.train_json_path = pathlib.Path(self.log_dir) / "train_result.json"
        self.train_json: dict = {}
        with open(self.train_json_path, "w") as f:
            json.dump(self.train_json, f, indent=2, sort_keys=True)

        self.train_eval_json_path = pathlib.Path(self.log_dir) / "train_eval_result.json"
        self.train_eval_json: dict = {}
        with open(self.train_eval_json_path, "w") as f:
            json.dump(self.train_eval_json, f, indent=2, sort_keys=True)

        self.validation_json_path = pathlib.Path(self.log_dir) / "validation_result.json"
        self.validation_json: dict = {}
        with open(self.validation_json_path, "w") as f:
            json.dump(self.validation_json, f, indent=2, sort_keys=True)

        # Set up the visualization writer instance.
        self.writer = TensorboardWriter(
            configuration.log_dir,
            self.logger,
            trainer_configuration["tensorboard"],
        )

        if configuration.resume:
            self._resume_checkpoint(
                configuration["resume_training"]["save_dir"]
            )

    @abstractmethod
    def _train_epoch(self, epoch: int) -> typing.Any:
        """
        Trains the model for a single epoch.

        This method trains the model for the specified epoch.

        Args:
            epoch (int): The current epoch number.

        Returns:
            (Any): The output of the training for a single epoch. Concrete
                trainers return a tuple (train_result, val_result,
                train_eval_result, losses), which is what `train` unpacks.
        """

        raise NotImplementedError

    def train(self, trial: typing.Optional[int] = None) -> typing.Tuple[dict, float]:
        """
        Main training procedure.

        This captures the whole training process. It calls the specific epoch
        training method from the derived trainers and updates the logging
        information accordingly.

        It also monitors the metric to check whether it has improved and
        performs early stopping if needed.

        Finally, it checkpoints the training process at the specified interval
        and saves the best model so far to `best_model_trial{trial}.pth`.

        Args:
            trial (int): The current trial, used to suffix the saved models
                and checkpoints. May be None if no trial is specified.

        Returns:
            (dict): A dictionary with the best values for each individual loss
                for each one of the targets.
            (float): The best result for the specified metric over the whole
                training process (validation accuracy according to the metric
                if validation is performed, training accuracy otherwise).
        """

        best_losses = {}
        not_improved_count = 0

        for epoch in range(self.start_epoch, self.epochs + 1):
            epoch_str = f"{epoch:05d}"

            self.logger.info(
                "************************************************"
            )
            self.logger.info(f"Epoch {epoch_str}")
            self.logger.info(f"Best accuracy: {self.monitor_best}")

            # Run one epoch and fetch the result dictionaries for train/val and
            # the losses that will be used for convergence.
            (
                train_result,
                val_result,
                train_eval_result,
                losses,
            ) = self._train_epoch(epoch)

            # Update current epoch logging dictionary with the results from the
            # training epoch (usually loss and accuracy averages).
            log = {"epoch": epoch}
            log.update(train_result)
            current_result = log[self.metric.__class__.__name__]

            # Print training per-epoch logged information to the screen.
            self.logger.info("Training results...")
            for key, value in log.items():
                self.logger.info(f"    {str(key):15s}: {value}")

            # Log results to training JSON.
            self.train_json[epoch_str] = {}
            for key, value in log.items():
                if key == "epoch":
                    continue
                self.train_json[epoch_str][key] = value

            with open(self.train_json_path, "w") as f:
                json.dump(self.train_json, f, indent=2, sort_keys=True)

            # Update current epoch logging dictionary with the results from the
            # training evaluation epoch (usually loss and accuracy averages).
            train_eval_log = {"epoch": epoch}
            train_eval_log.update(train_eval_result)

            # Print training evaluation per-epoch information to the screen.
            self.logger.info("Training evaluation results...")
            for key, value in train_eval_log.items():
                self.logger.info(f"    {str(key):15s}: {value}")

            # Log results to the training evaluation JSON.
            self.train_eval_json[epoch_str] = {}
            for key, value in train_eval_log.items():
                if key == "epoch":
                    continue
                self.train_eval_json[epoch_str][key] = value

            with open(self.train_eval_json_path, "w") as f:
                json.dump(self.train_eval_json, f, indent=2, sort_keys=True)

            # Print validation information if validation was performed and use
            # it to update the training tracking metrics if so (like the current
            # best loss so far).
            if val_result is not None:
                # Update current epoch logging dictionary with the results from
                # the validation epoch (usually loss and accuracy averages).
                val_log = {"epoch": epoch}
                val_log.update(val_result)
                current_result = val_log[self.metric.__class__.__name__]

                # Print validation per-epoch logged information to the screen.
                self.logger.info("Validation results...")
                for key, value in val_log.items():
                    self.logger.info("    {:15s}: {}".format(str(key), value))

                # Log results to validation JSON.
                self.validation_json[epoch_str] = {}
                for key, value in val_log.items():
                    if key == "epoch":
                        continue
                    self.validation_json[epoch_str][key] = value

                with open(self.validation_json_path, "w") as f:
                    json.dump(
                        self.validation_json, f, indent=2, sort_keys=True
                    )

            # Check whether model performance improved or not, according
            # to specified metric behavior (minimum or maximum). The metric will
            # be the validation one if validation is performed or training if
            # no validation is carried out.
            best = False

            if self.metric.improved(self.monitor_best, current_result):
                # The current result improves the running best, save it and
                # reset the patience counter for early stopping.
                self.monitor_best = current_result
                not_improved_count = 0
                best = True
                best_losses = losses
                self.logger.info("Metric improved!")
            else:
                # The current result did not improve the running best; increase
                # the patience counter for early stopping.
                not_improved_count += 1
                self.logger.info(
                    f"Metric did not improve for {not_improved_count} epochs..."
                )

            # Perform early stopping if the metric has not improved for
            # the specified number of epochs (patience).
            if not_improved_count > self.early_stop:
                self.logger.info(
                    "Target metric did not improve for {} epochs. "
                    "Training stops.".format(self.early_stop)
                )
                break

            # Create checkpoint at the requested interval.
            if (epoch % self.save_period) == 0:
                self._save_checkpoint(
                    epoch, f"checkpoint_trial{trial}_epoch{epoch}.pth"
                )
                self.logger.info("Saved checkpoint...")

            # Save best model if it is the case.
            if best:
                self._save_checkpoint(epoch, f"best_model_trial{trial}.pth")
                self.logger.info("Saved best model so far...")

        return best_losses, self.monitor_best

    def _progress(
        self, batch_idx: int, data_loader: LoaderBase, len_epoch: int
    ) -> str:
        """
        Epoch progress tracker.

        Args:
            batch_idx (int): Current batch index.
            data_loader (LoaderBase): Data loader in use.
            len_epoch (int): Length of a whole epoch.

        Returns:
            (str): A string representation of the progress in the current epoch.
        """

        base = "[{}/{} ({:.0f}%)]"

        if hasattr(data_loader, "n_samples"):
            current = (batch_idx + 1) * data_loader.batch_size
            total = data_loader.n_samples

        else:
            current = batch_idx + 1
            total = len_epoch

        return base.format(current, total, 100.0 * current / total)

    def _save_checkpoint(self, epoch: int, filename: str) -> None:
        """
        Checkpoint saving.

        Saves the current state of the training process to a checkpoint:
        the model architecture, the epoch, the optimizer state, and the
        configuration.

        Args:
            epoch (int): Current training epoch.
            filename (str): Filename to save the checkpoint to.
        """

        arch = type(self.model).__name__

        state = {
            "arch": arch,
            "epoch": epoch,
            "state_dict": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "monitor_best": self.monitor_best,
            "config": self.configuration,
        }

        filename = str(self.checkpoint_dir / filename)
        torch.save(state, filename)

    def _resume_checkpoint(self, checkpoint_path: pathlib.Path) -> None:
        """
        Resumes a checkpoint to continue training.

        Checks whether the checkpoint architecture matches the current one;
        if not, no model parameters are loaded. The same check is performed
        for the optimizer state.

        Args:
            checkpoint_path (pathlib.Path): Path to checkpoint to resume.
        """

        checkpoint_path = str(checkpoint_path)

        self.logger.info("Loading checkpoint: {} ...".format(checkpoint_path))
        # Map tensors onto the configured device so resuming works even when
        # the checkpoint was saved on a different device.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.start_epoch = checkpoint["epoch"] + 1
        self.monitor_best = checkpoint["monitor_best"]

        # Only load model parameters (architecture) if the checkpoint arch
        # is the same as the current architecture.
        if checkpoint["config"]["arch"] != self.configuration["arch"]:
            self.logger.warning(
                "Architecture configuration given in config file differs "
                "from that of the checkpoint. Model parameters not resumed."
            )

        else:
            self.model.load_state_dict(checkpoint["state_dict"])

        # Load optimizer state from checkpoint only when the optimizer type
        # from the checkpoint is the same as the current in use.
        if (
            checkpoint["config"]["optimizer"]["type"]
            != self.configuration["optimizer"]["type"]
        ):
            self.logger.warning(
                "Optimizer type given in config file is different from that of "
                "checkpoint. Optimizer parameters not being resumed."
            )

        else:
            self.optimizer.load_state_dict(checkpoint["optimizer"])

        self.logger.info(
            "Checkpoint loaded. Resume training from epoch {}".format(
                self.start_epoch
            )
        )
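For orientation, the checkpoint dictionary written by _save_checkpoint can be inspected offline. A hedged sketch (the path, model, and optimizer are illustrative; the dictionary keys match the source above):

import torch

state = torch.load("saved/checkpoint_trialNone_epoch10.pth", map_location="cpu")
print(state["arch"], state["epoch"], state["monitor_best"])
model.load_state_dict(state["state_dict"])     # model weights
optimizer.load_state_dict(state["optimizer"])  # optimizer state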

__init__(model, criterion, metric, optimizer, configuration)

Trainer initialization.

Parameters:

    model (torch.nn.Module, required): Network to train.
    criterion (LossBase, required): Loss criterion.
    metric (MetricBase, required): Metric to validate.
    optimizer (torch.optim.Optimizer, required): Optimizer for training.
    configuration (ConfigurationParser, required): Current experiment configuration.
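For reference, these are the configuration keys the initializer actually reads; the values below are illustrative, and a real ConfigurationParser may carry more. "early_stop" falls back to inf when omitted, and "resume_training" is only consulted when configuration.resume is set:

configuration_dict = {
    "n_gpu": 1,                          # passed to request_device()
    "trainer": {
        "verbosity": 2,                  # logger verbosity
        "epochs": 100,
        "save_period": 10,               # checkpoint every N epochs
        "early_stop": 20,                # patience; defaults to inf
        "tensorboard": True,             # enables the TensorboardWriter
    },
    "resume_training": {"save_dir": "saved/checkpoint.pth"},
}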

train(trial=None)

Main training procedure.

This captures the whole training process. It calls the specific epoch training method from the derived trainers and updates the logging information accordingly.

It also monitors the metric to check whether it has improved and performs early stopping if needed.

Finally, it checkpoints the training process at the specified interval and saves the best model so far to best_model_trial{trial}.pth.

Parameters:

    trial (int, default None): The current trial, used to suffix the saved
        models and checkpoints. May be None if no trial is specified.

Returns:

    (dict): A dictionary with the best values for each individual loss for
        each one of the targets.
    (float): The best result for the specified metric over the whole training
        process (validation accuracy according to the metric if validation is
        performed, training accuracy otherwise).
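A hedged usage sketch, assuming trainer is an instance of a concrete BaseTrainer subclass:

best_losses, best_metric = trainer.train(trial=0)
print(f"Best monitored metric: {best_metric}")
for target, loss in best_losses.items():
    print(f"  best loss for {target}: {loss}")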


mlpoppyns.learning.trainers.trainer_basic

Basic trainer.

Authors:

Alberto Garcia Garcia (garciagarcia@ice.csic.es)

TrainerBasic

Bases: BaseTrainer

Basic trainer.

This class implements the simplest training pipeline. It lets the user run training epochs coupled with validation passes and fully customize every step of the pipeline: the model to use, the criterion to optimize, the metrics to compute, the optimizer that updates the weights, and the loaders from which data and targets are fetched.
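A hedged construction sketch; the criterion, metric, configuration, and loader objects are placeholders for whatever your experiment defines, and only the constructor and train() calls reflect this page:

import torch

model = torch.nn.Linear(16, 3)  # placeholder network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50)

trainer = TrainerBasic(
    model=model,
    criterion=criterion,          # a LossBase instance (assumed available)
    metric=metric,                # a MetricBase instance (assumed available)
    optimizer=optimizer,
    configuration=configuration,  # a ConfigurationParser (assumed available)
    train_loader=train_loader,    # a LoaderBase (assumed available)
    val_loader=val_loader,        # optional; enables validation epochs
    lr_scheduler=scheduler,       # optional; stepped once per epoch
)
best_losses, best_metric = trainer.train(trial=0)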

Source code in mlpoppyns/learning/trainers/trainer_basic.py
class TrainerBasic(BaseTrainer):
    """
    Basic trainer.

    This class implements the simplest training pipeline. It lets the user
    run training epochs coupled with validation passes and fully customize
    every step of the pipeline (model to use, criterion to optimize, metrics
    to compute, optimizer to update the weights, and loaders from which data
    and targets can be fetched).
    """

    def __init__(
        self,
        model: torch.nn.Module,
        criterion: LossBase,
        metric: MetricBase,
        optimizer: torch.optim.Optimizer,
        configuration: ConfigurationParser,
        train_loader: LoaderBase,
        val_loader: typing.Optional[LoaderBase] = None,
        lr_scheduler: typing.Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    ) -> None:
        """
        Basic trainer initialization.

        Args:
            model (torch.nn.Module): Network model to train.
            criterion (LossBase): Criterion for the loss calculation.
            metric (MetricBase): Accuracy metric to be computed.
            optimizer (torch.optim.Optimizer): Optimizer for training.
            configuration (ConfigurationParser): Current experiment configuration.
            train_loader (LoaderBase): Train loader.
            val_loader (LoaderBase): Validation loader. Optional.
            lr_scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler. Optional.
        """

        super().__init__(model, criterion, metric, optimizer, configuration)

        self.train_loader = train_loader
        self.len_epoch = len(self.train_loader)

        self.logger.info(
            "Training loader normalization: {}".format(
                self.train_loader.normalize
            )
        )
        self.logger.info(
            "Training loader standardization: {}".format(
                self.train_loader.standardize
            )
        )

        if self.train_loader.normalize and self.train_loader.standardize:
            self.logger.error(
                "Both standardization and normalization enabled for the train "
                "loader. Choose only one of the two options."
            )
            raise ValueError(
                "Enable either normalization or standardization, not both."
            )

        self.val_loader = val_loader
        self.validate = self.val_loader is not None

        if self.validate:
            self.logger.info(
                "Validation loader normalization: {}".format(
                    self.val_loader.normalize
                )
            )
            self.logger.info(
                "Validation loader standardization: {}".format(
                    self.val_loader.standardize
                )
            )

            if self.val_loader.normalize and self.val_loader.standardize:
                self.logger.error(
                    "Both standardization and normalization enabled for the "
                    "validation loader. Choose only one of the two options."
                )
                raise ValueError(
                    "Enable either normalization or standardization, not both."
                )

        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(self.train_loader.batch_size))

        self.train_metrics = learning_utils.metric_tracker.MetricTracker(
            [], writer=self.writer
        )
        self.train_denormalized_metrics = (
            learning_utils.metric_tracker.MetricTracker([], writer=self.writer)
        )
        self.valid_metrics = learning_utils.metric_tracker.MetricTracker(
            [], writer=self.writer
        )

    def _train_epoch(self, epoch: int) -> typing.Tuple[dict, dict, dict, dict]:
        """
        Single-epoch training routine.

        Args:
            epoch (int): Current epoch number.

        Returns:
            (Tuple[dict, dict, dict, dict]): A tuple containing the following
                dictionaries:

                - A dictionary with the results for the epoch, i.e., the
                  averages of the losses and the tracked metric on the
                  training set.
                - A dictionary with the same info for the validation set
                  (None if no validation is performed).
                - A dictionary with the evaluation results on the
                  denormalized training set.
                - A dictionary with the values of each individual loss for
                  each one of the targets. If validation is performed, these
                  are validation losses; otherwise they are training losses.
        """

        # Set the model on training mode and reset all tracked metrics to zero.
        self.model.train()
        self.train_metrics.reset()

        for batch_idx, (data, target) in enumerate(self.train_loader):
            # Fetch data and labels and move them to the appropriate device.
            data, target = data.to(self.device), target.to(self.device)

            # Zero gradients to reset loss.
            self.optimizer.zero_grad()

            # Compute output for this batch.
            output = self.model(data)

            # Compute each individual loss on each of the parameters to be
            # predicted by comparing the output and the ground truth for each
            # one of them. Then accumulate each individual loss in the total one.
            loss = 0.0
            for i in range(len(output[0])):
                # Compute individual loss for this output.
                loss_i = self.criterion(output[:, i], target[:, i])
                # Update tracked loss and output to TensorBoard.
                self.train_metrics.update(
                    "{}".format(self.train_loader.target_names[i]),
                    loss_i.item(),
                )
                # Accumulate into total loss.
                loss = loss + loss_i

            # Update tracked general loss and output to TensorBoard.
            self.train_metrics.update("loss", loss.item())

            # Only backpropagate on total loss not on individual ones.
            loss.backward()
            self.optimizer.step()

            # Now output all the log information to console and write the
            # necessary log values for TensorBoard.

            # Set the TensorBoard step.
            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)

            # Update tracked metric and output to TensorBoard.
            self.train_metrics.update(
                self.metric.__class__.__name__,
                self.metric(output, target).item(),
            )

            # Every log_step batches, show the current epoch training
            # information (batch progress, loss...). We don't log every
            # batch because the output would be too verbose.
            if batch_idx % self.log_step == 0:
                self.logger.debug(
                    "Train Epoch: {} {} Loss: {:.6f}".format(
                        epoch,
                        self._progress(
                            batch_idx, self.train_loader, self.len_epoch
                        ),
                        loss.item(),
                    )
                )

            # Stop early when len_epoch is set below the loader length; with
            # len_epoch == len(self.train_loader) this never triggers.
            if batch_idx == self.len_epoch:
                break

        # After a whole epoch has been carried out, store the dictionary of
        # results for each tracked metrics: usually the average loss and any
        # other specified accuracy metrics.
        log = self.train_metrics.result()
        # Pack the individual losses separately.
        losses = dict(
            filter(
                lambda e: e[0] in self.train_loader.target_names, log.items()
            )
        )

        # Perform evaluation on denormalized training set.
        train_denormalized_log = self._training_eval_epoch(epoch)

        # If there is a validation set, perform a validation step and fetch
        # the logged metrics and losses.
        val_log = None
        if self.validate:
            val_log = self._valid_epoch(epoch)
            # Pack the individual losses separately.
            losses = dict(
                filter(
                    lambda e: e[0] in self.val_loader.target_names,
                    val_log.items(),
                )
            )

        # Step learning rate if a scheduler is provided.
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return log, val_log, train_denormalized_log, losses

    def _training_eval_epoch(self, epoch: int) -> dict:
        """
        Evaluate the model on the training dataset for a single epoch.

        This method sets the model to evaluation mode and processes the training dataset
        without gradient computation. It computes the loss and metrics, denormalizes or
        destandardizes the outputs and targets if necessary, and logs the results using
        TensorBoard.

        Args:
            epoch (int): The current epoch number.

        Returns:
            (dict): A dictionary containing the evaluation metrics for the training dataset.
        """

        # Set the model to evaluation mode and reset validation metrics.
        self.model.eval()
        self.train_denormalized_metrics.reset()

        # Fetch standardization and normalization factors.
        target_max = torch.tensor(self.train_loader.target_max).to(self.device)
        target_min = torch.tensor(self.train_loader.target_min).to(self.device)
        target_std = torch.tensor(self.train_loader.target_std).to(self.device)
        target_mean = torch.tensor(self.train_loader.target_mean).to(
            self.device
        )

        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.train_loader):
                # Fetch data and targets, move them to the compute device.
                data, target = data.to(self.device), target.to(self.device)

                # Compute predictions.
                output = self.model(data)

                # De-normalize or de-standardize targets and outputs on the fly
                # if needed to rescale the loss values to a more readable range.
                # TODO: THIS COULD BE IMPROVED AND IDEALLY I WOULD LIKE THIS TO
                # BE DONE MORE TRANSPARENTLY, I DON'T KNOW HOW NOW.
                if self.train_loader.normalize:
                    output = output * (target_max - target_min) + target_min
                    target = target * (target_max - target_min) + target_min
                elif self.train_loader.standardize:
                    output = output * target_std + target_mean
                    target = target * target_std + target_mean

                # Compute each individual loss on each of the parameters to be
                # predicted by comparing the output and the ground truth for
                # each one of them. Then accumulate each individual loss in the
                # total one which will be reported.
                loss = 0.0
                for i in range(len(output[0])):
                    # Compute individual loss for this output.
                    loss_i = self.criterion(output[:, i], target[:, i])
                    # Update tracked loss and output to TensorBoard.
                    self.train_denormalized_metrics.update(
                        "{}".format(self.train_loader.target_names[i]),
                        loss_i.item(),
                    )
                    # Accumulate into total loss.
                    loss = loss + loss_i

                # Update tracked loss and output to TensorBoard.
                self.train_denormalized_metrics.update("loss", loss.item())
                # Update tracked metric and output to TensorBoard.
                self.train_denormalized_metrics.update(
                    self.metric.__class__.__name__,
                    self.metric(output, target).item(),
                )

                # Set TensorBoard step (based on the training loader length).
                self.writer.set_step(
                    (epoch - 1) * len(self.train_loader) + batch_idx,
                    "training_denormalized",
                )

        return self.train_denormalized_metrics.result()

    def _valid_epoch(self, epoch: int) -> dict:
        """
        Single-epoch validation routine.

        Args:
            epoch (int): Current epoch number.

        Returns:
            (dict): A dictionary which contains the results of the validation over the
                whole dataset for all the requested metrics and losses.
        """

        # Set the model to evaluation mode and reset validation metrics.
        self.model.eval()
        self.valid_metrics.reset()

        # Fetch standardization and normalization factors.
        target_max = torch.tensor(self.val_loader.target_max).to(self.device)
        target_min = torch.tensor(self.val_loader.target_min).to(self.device)
        target_std = torch.tensor(self.val_loader.target_std).to(self.device)
        target_mean = torch.tensor(self.val_loader.target_mean).to(self.device)

        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.val_loader):
                # Fetch data and targets, move them to the compute device.
                data, target = data.to(self.device), target.to(self.device)

                # Compute predictions.
                output = self.model(data)

                # De-normalize or de-standardize targets and outputs on the fly
                # if needed to rescale the loss values to a more readable range.
                # TODO: THIS COULD BE IMPROVED AND IDEALLY I WOULD LIKE THIS TO
                # BE DONE MORE TRANSPARENTLY, I DON'T KNOW HOW NOW.
                if self.val_loader.normalize:
                    output = output * (target_max - target_min) + target_min
                    target = target * (target_max - target_min) + target_min
                elif self.val_loader.standardize:
                    output = output * target_std + target_mean
                    target = target * target_std + target_mean

                # Compute each individual loss on each of the parameters to be
                # predicted by comparing the output and the ground truth for
                # each one of them. Then accumulate each individual loss in the
                # total one which will be reported.
                loss = 0.0
                for i in range(len(output[0])):
                    # Compute individual loss for this output.
                    loss_i = self.criterion(output[:, i], target[:, i])
                    # Update tracked loss and output to TensorBoard.
                    self.valid_metrics.update(
                        "{}".format(self.val_loader.target_names[i]),
                        loss_i.item(),
                    )
                    # Accumulate into total loss.
                    loss = loss + loss_i

                # Update tracked loss and output to TensorBoard.
                self.valid_metrics.update("loss", loss.item())
                # Update tracked metric and output to TensorBoard.
                self.valid_metrics.update(
                    self.metric.__class__.__name__,
                    self.metric(output, target).item(),
                )

                # Set TensorBoard step.
                self.writer.set_step(
                    (epoch - 1) * len(self.val_loader) + batch_idx,
                    "validation",
                )

        # Add histogram of model parameters to TensorBoard.
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins="auto")

        return self.valid_metrics.result()
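The de-normalization applied in the evaluation loops above is the inverse of min-max scaling (and, in the standardize branch, of z-scoring). A tiny self-contained check of the identity, with made-up numbers:

import torch

t_min, t_max = torch.tensor([0.0]), torch.tensor([10.0])
target = torch.tensor([2.5, 7.5])
normalized = (target - t_min) / (t_max - t_min)   # forward min-max scaling
recovered = normalized * (t_max - t_min) + t_min  # inverse, as in the code above
assert torch.allclose(recovered, target)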

__init__(model, criterion, metric, optimizer, configuration, train_loader, val_loader=None, lr_scheduler=None)

Basic trainer initialization.

Parameters:

    model (torch.nn.Module, required): Network model to train.
    criterion (LossBase, required): Criterion for the loss calculation.
    metric (MetricBase, required): Accuracy metric to be computed.
    optimizer (torch.optim.Optimizer, required): Optimizer for training.
    configuration (ConfigurationParser, required): Current experiment configuration.
    train_loader (LoaderBase, required): Train loader.
    val_loader (LoaderBase, default None): Validation loader.
    lr_scheduler (torch.optim.lr_scheduler._LRScheduler, default None): Learning rate scheduler.
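When val_loader is omitted, self.validate is False, no validation epochs run, and train() monitors the training metric instead of a validation one. A hedged sketch (surrounding objects as in the earlier example):

trainer = TrainerBasic(
    model=model, criterion=criterion, metric=metric,
    optimizer=optimizer, configuration=configuration,
    train_loader=train_loader,  # val_loader defaults to None
)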

mlpoppyns.learning.trainers.trainers

Trainers.

This is just an empty module that gathers all the available trainer modules.

Authors:

Alberto Garcia Garcia (garciagarcia@ice.csic.es)
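As a hedged sketch, a gathering module like this one typically just imports the sibling trainer modules so they can be reached from a single place; the exact import list below is an assumption, not the documented contents:

from mlpoppyns.learning.trainers import trainer_base   # BaseTrainer
from mlpoppyns.learning.trainers import trainer_basic  # TrainerBasic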