Source code for accmt.callbacks

# Copyright 2025 ghanvert. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from dataclasses import dataclass

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from typing_extensions import Any, override

from .modules import AcceleratorModule
from .states import TrainingState


@dataclass
class Callback(ABC):
    """
    Callback module containing different callback functions for different stages of the training process.

    Every hook below is a no-op by default (its body is only a docstring), so subclasses only need to
    override the hooks they care about.

    NOTE: Every callback function will run on every process. If you want your callback functions to only run
    on a single process, make sure to import `accmt.decorators` for different function decorators.

    Attributes:
        module (`AcceleratorModule`):
            Training module.
        trainer (`Trainer`):
            Defined `Trainer` class.
        state (`TrainingState`):
            Reference to `TrainingState` class.

    Methods:
        on_fit_start (*optional*):
            Callback when training process starts.
        on_fit_end (*optional*):
            Callback when training process ends.
        on_before_backward (*optional*):
            Callback before engine's backward.
        on_after_backward (*optional*):
            Callback after engine's backward.
        on_before_optimizer_step (*optional*):
            Callback before optimizer steps.
        on_after_optimizer_step (*optional*):
            Callback after optimizer steps.
        on_before_scheduler_step (*optional*):
            Callback before scheduler steps.
        on_after_scheduler_step (*optional*):
            Callback after scheduler steps.
        on_before_zero_grad (*optional*):
            Callback before optimizer resets gradients.
        on_after_zero_grad (*optional*):
            Callback after optimizer resets gradients.
        on_runtime_error (*optional*):
            Callback when process raises a `RuntimeError` exception.
        on_cuda_out_of_memory (*optional*):
            Callback when process raises a `RuntimeError` exception with CUDA Out Of Memory.
        on_keyboard_interrupt (*optional*):
            Callback when process raises a `KeyboardInterrupt` exception.
        on_exception (*optional*):
            Callback when process raises any other exception (neither `RuntimeError` nor `KeyboardInterrupt`).
        on_resume (*optional*):
            Callback when resuming training process.
        on_save_checkpoint (*optional*):
            Callback when saving checkpoint.
        on_before_training_step (*optional*):
            Callback before `training_step` function.
        on_after_training_step (*optional*):
            Callback after `training_step` function.
        on_before_validation_step (*optional*):
            Callback before `validation_step` function.
        on_after_validation_step (*optional*):
            Callback after `validation_step` function.
        on_epoch_start (*optional*):
            Callback when an epoch starts.
        on_epoch_end (*optional*):
            Callback when an epoch ends.
        on_evaluation_start (*optional*):
            Callback when evaluation starts.
        on_evaluation_end (*optional*):
            Callback when evaluation ends.
    """

    module: AcceleratorModule = None
    # Unannotated, so `@dataclass` does not treat this as a field: it is a plain class attribute that
    # `__init__` does not accept and has to be assigned after construction (presumably by the trainer — verify).
    trainer = None
    state: TrainingState = None

    # The hooks below are the base definitions, so they are not marked `@override` (PEP 698 reserves that
    # decorator for methods that override a base-class method; type checkers flag it otherwise).
    # Subclasses should mark their implementations with `@override`.

    def on_fit_start(self):
        """Callback when training process starts."""

    def on_fit_end(self):
        """Callback when training process ends."""

    def on_before_backward(self, loss: torch.Tensor):
        """
        Callback before engine's backward.

        Args:
            loss (`torch.Tensor`):
                Scalar loss tensor.
        """

    def on_after_backward(self):
        """Callback after engine's backward."""

    def on_before_optimizer_step(self, optimizer: Optimizer):
        """
        Callback before optimizer steps.

        Args:
            optimizer (`Optimizer`):
                Wrapped optimizer.
        """

    def on_after_optimizer_step(self, optimizer: Optimizer):
        """
        Callback after optimizer steps.

        Args:
            optimizer (`Optimizer`):
                Wrapped optimizer.
        """

    def on_before_scheduler_step(self, scheduler: LRScheduler):
        """
        Callback before scheduler steps.

        Args:
            scheduler (`LRScheduler`):
                Wrapped scheduler.
        """

    def on_after_scheduler_step(self, scheduler: LRScheduler):
        """
        Callback after scheduler steps.

        Args:
            scheduler (`LRScheduler`):
                Wrapped scheduler.
        """

    def on_before_zero_grad(self, optimizer: Optimizer):
        """
        Callback before optimizer resets gradients.

        Args:
            optimizer (`Optimizer`):
                Wrapped optimizer.
        """

    def on_after_zero_grad(self, optimizer: Optimizer):
        """
        Callback after optimizer resets gradients.

        Args:
            optimizer (`Optimizer`):
                Wrapped optimizer.
        """

    def on_runtime_error(self, exception: Exception):
        """
        Callback when process raises a `RuntimeError` exception.

        Args:
            exception (`Exception`):
                Raised exception.
        """

    def on_cuda_out_of_memory(self, exception: Exception):
        """
        Callback when process raises a `RuntimeError` exception with CUDA Out Of Memory.

        Args:
            exception (`Exception`):
                Raised exception.
        """

    def on_keyboard_interrupt(self, exception: KeyboardInterrupt):
        """
        Callback when process raises a `KeyboardInterrupt` exception.

        Args:
            exception (`KeyboardInterrupt`):
                Raised exception. Note that `KeyboardInterrupt` derives from `BaseException`, not `Exception`.
        """

    def on_exception(self, exception: Exception):
        """
        Callback when process raises any other exception (neither `RuntimeError` nor `KeyboardInterrupt`).

        Args:
            exception (`Exception`):
                Raised exception.
        """

    def on_resume(self):
        """Callback when resuming training process."""

    def on_save_checkpoint(self):
        """Callback when saving checkpoint."""

    def on_before_training_step(self, batch: Any):
        """
        Callback before `training_step` function.

        Args:
            batch (`Any`):
                Dataloader's batch.
        """

    def on_after_training_step(self):
        """Callback after `training_step` function."""

    def on_before_validation_step(self, batch: Any):
        """
        Callback before `validation_step` function.

        Args:
            batch (`Any`):
                Dataloader's batch.
        """

    def on_after_validation_step(self):
        """Callback after `validation_step` function."""

    def on_epoch_start(self):
        """Callback when an epoch starts."""

    def on_epoch_end(self):
        """Callback when an epoch ends."""

    def on_evaluation_start(self):
        """Callback when evaluation starts."""

    def on_evaluation_end(self):
        """Callback when evaluation ends."""
# TODO: a registration decorator (e.g. @register_callback("on_fit_start")) could generate these
# dispatchers automatically; planned (probably) before the 2.0 release.
@dataclass
class CallbackMaster:
    """
    Dispatcher that forwards every `Callback` hook to each of its children.

    Children are invoked in list order with the same positional arguments the hook received. Exceptions
    raised by a child are not caught here, so they propagate and the remaining children are skipped.

    Attributes:
        children (`list[Callback]`):
            Callbacks to notify.
    """

    children: list["Callback"]

    def _broadcast(self, hook: str, *args: Any) -> None:
        """Call the method named `hook` on every child with `args`, in list order."""
        for child in self.children:
            # Looked up per child at call time, so each child's own override is used.
            getattr(child, hook)(*args)

    def on_fit_start(self):
        """Forward `on_fit_start` to every child."""
        self._broadcast("on_fit_start")

    def on_fit_end(self):
        """Forward `on_fit_end` to every child."""
        self._broadcast("on_fit_end")

    def on_before_backward(self, loss: torch.Tensor):
        """Forward `on_before_backward(loss)` to every child."""
        self._broadcast("on_before_backward", loss)

    def on_after_backward(self):
        """Forward `on_after_backward` to every child."""
        self._broadcast("on_after_backward")

    def on_before_optimizer_step(self, optimizer: Optimizer):
        """Forward `on_before_optimizer_step(optimizer)` to every child."""
        self._broadcast("on_before_optimizer_step", optimizer)

    def on_after_optimizer_step(self, optimizer: Optimizer):
        """Forward `on_after_optimizer_step(optimizer)` to every child."""
        self._broadcast("on_after_optimizer_step", optimizer)

    def on_before_scheduler_step(self, scheduler: LRScheduler):
        """Forward `on_before_scheduler_step(scheduler)` to every child."""
        self._broadcast("on_before_scheduler_step", scheduler)

    def on_after_scheduler_step(self, scheduler: LRScheduler):
        """Forward `on_after_scheduler_step(scheduler)` to every child."""
        self._broadcast("on_after_scheduler_step", scheduler)

    def on_before_zero_grad(self, optimizer: Optimizer):
        """Forward `on_before_zero_grad(optimizer)` to every child."""
        self._broadcast("on_before_zero_grad", optimizer)

    def on_after_zero_grad(self, optimizer: Optimizer):
        """Forward `on_after_zero_grad(optimizer)` to every child."""
        self._broadcast("on_after_zero_grad", optimizer)

    def on_runtime_error(self, exception: Exception):
        """Forward `on_runtime_error(exception)` to every child."""
        self._broadcast("on_runtime_error", exception)

    def on_cuda_out_of_memory(self, exception: Exception):
        """Forward `on_cuda_out_of_memory(exception)` to every child."""
        self._broadcast("on_cuda_out_of_memory", exception)

    def on_keyboard_interrupt(self, exception: KeyboardInterrupt):
        """Forward `on_keyboard_interrupt(exception)` to every child."""
        self._broadcast("on_keyboard_interrupt", exception)

    def on_exception(self, exception: Exception):
        """Forward `on_exception(exception)` to every child."""
        self._broadcast("on_exception", exception)

    def on_resume(self):
        """Forward `on_resume` to every child."""
        self._broadcast("on_resume")

    def on_save_checkpoint(self):
        """Forward `on_save_checkpoint` to every child."""
        self._broadcast("on_save_checkpoint")

    def on_before_training_step(self, batch: Any):
        """Forward `on_before_training_step(batch)` to every child."""
        self._broadcast("on_before_training_step", batch)

    def on_after_training_step(self):
        """Forward `on_after_training_step` to every child."""
        self._broadcast("on_after_training_step")

    def on_before_validation_step(self, batch: Any):
        """Forward `on_before_validation_step(batch)` to every child."""
        self._broadcast("on_before_validation_step", batch)

    def on_after_validation_step(self):
        """Forward `on_after_validation_step` to every child."""
        self._broadcast("on_after_validation_step")

    def on_epoch_start(self):
        """Forward `on_epoch_start` to every child."""
        self._broadcast("on_epoch_start")

    def on_epoch_end(self):
        """Forward `on_epoch_end` to every child."""
        self._broadcast("on_epoch_end")

    def on_evaluation_start(self):
        """Forward `on_evaluation_start` to every child."""
        self._broadcast("on_evaluation_start")

    def on_evaluation_end(self):
        """Forward `on_evaluation_end` to every child."""
        self._broadcast("on_evaluation_end")