# Source code for accmt.modules

# Copyright 2025 ghanvert. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import os
from abc import ABC
from typing import Callable, Optional, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import Accelerator, DistributedType
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, Dataset
from typing_extensions import Any, Literal, override

from .curriculum import _CurriculumLearning
from .states import TrainingState
from .tracker import BaseTracker
from .utils import clear_device_cache


class AcceleratorModule(ABC):
    """
    Super class to define training and validation logic without the need to write a training loop.

    The constructor of this class must implement `self.model`, specifying the model from `torch.nn.Module`.
    `self.teacher` is also a reserved property for teacher-student approaches.
    """

    accelerator: Accelerator = None
    tracker: BaseTracker = None
    log_every: int = None
    state: TrainingState = None
    device: torch.device = None
    _implemented_collate_fn_train = False
    _implemented_collate_fn_val = False
    _extended = False
    model: nn.Module = None
    teacher: Optional[nn.Module] = None
    optimizer: Optimizer = None
    scheduler: LRScheduler = None
    _prepared: bool = False
    # The mutable containers below are only class-level *defaults*. `log` and the `_register_*`
    # helpers replace them with per-instance containers before writing (see `_own_container`),
    # so two modules living in the same process never share logging caches or registrations.
    _log_cache = {}  # noqa: RUF012
    _registered_models: list[tuple[str, nn.Module]] = []  # noqa: RUF012
    _registered_optimizers: list[tuple[str, Optimizer]] = []  # noqa: RUF012
    _registered_schedulers: list[tuple[str, LRScheduler]] = []  # noqa: RUF012
    # key is the object id
    # NOTE(review): nothing in this class writes this dict; whoever fills it (presumably `Trainer`)
    # mutates the class-level default, which is shared across instances.
    _registered_accelerators: dict[int, Accelerator] = {}  # noqa: RUF012
    _model_path: str = None  # initialized in Trainer
    _temp_path: str = None  # initialized in Trainer

    @override
    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Defines the flow of data."""

    @override
    def training_step(self, batch: Any) -> torch.Tensor:
        """Defines the training logic. Must return a loss tensor (scalar)."""

    @override
    def validation_step(self, key: str, batch: Any) -> Union[dict, torch.Tensor]:
        """
        Defines the validation logic. Must return a dictionary containing
        each metric with corresponding arguments, and also the loss value
        in the dictionary.

        Example:
            ```
            # format is ==> "metric": (predictions, targets, ...)
            return {
                "loss": validation_loss_tensor, # (scalar tensor)
                # with additional metrics:
                "accuracy": (accuracy_predictions, accuracy_targets),
                "bleu": (bleu_predictions, bleu_targets)
            }
            ```
        """

    @override
    def test_step(self, batch: Any) -> Union[dict, torch.Tensor]:
        """
        Defines the test logic. Must return a dictionary containing each metric with corresponding arguments.

        This function is similar to `validation_step`, but it is used for testing using the `Evaluator` class.

        Example:
            ```
            # format is ==> "metric": (predictions, targets, ...)
            return {
                "accuracy": (accuracy_predictions, accuracy_targets),
                "...": (..., ...)
            }
            ```
        """

    @override
    def collate_fn_train(self, batch: list) -> Any:
        """Defines a collate function for PyTorch train DataLoader."""

    @override
    def collate_fn_val(self, batch: list) -> Any:
        """Defines a collate function for PyTorch validation DataLoader."""

    @override
    def get_optimizer(self) -> Optimizer:
        """Defines a custom PyTorch optimizer logic here."""

    @override
    def get_scheduler(self, optimizer: Optimizer, steps_per_epoch: int, epochs: int) -> LRScheduler:
        """Defines a custom PyTorch scheduler logic here."""

    @override
    def get_train_dataloader(
        self, dataset: Union[Dataset, list[Union[tuple[int, Dataset], tuple[int, Dataset, dict]]], _CurriculumLearning]
    ) -> Union[DataLoader, list[tuple[int, DataLoader]]]:
        """
        Defines a custom PyTorch DataLoader class for training. In case of returning a `list` of tuples,
        the first element of each tuple represents the maximum step for each dataset, and the second element
        is the `DataLoader` for that dataset.

        For simple definitions of curriculum learning, you can use an instance of `StepsCurriculum`,
        `RangeCurriculum` or `RatioCurriculum` from `accmt.curriculum`.

        Must return a `DataLoader` or a `list` of tuples of `(max_step, DataLoader)`.
        """

    @override
    def get_validation_dataloader(
        self, dataset: Union[Dataset, dict[int, Dataset], list[Dataset]]
    ) -> Union[DataLoader, dict[int, DataLoader], list[DataLoader]]:
        """Defines a custom PyTorch DataLoader class for validation."""

    def _own_container(self, name: str, factory: Callable[[], Any]) -> Any:
        """Return this instance's own container stored under `name`, creating it with `factory` on first use."""
        # Without this, writes would go to the class-level default shared by every instance.
        if name not in vars(self):
            setattr(self, name, factory())
        return getattr(self, name)

    def log(
        self,
        values: dict[str, Union[torch.Tensor, float]],
        step: Optional[int] = None,
        reduction: Literal["sum", "mean"] = "mean",
        instant: bool = False,
    ):
        """
        Log metrics to the tracker every N steps (defined in `Trainer`). If you want to apply any other logic,
        consider using `self.tracker.log` directly.

        This function will reduce tensors across all processes and only the main process will log the metrics.
        Values are accumulated across calls and, when it's time to log, divided by the logging interval.

        If no tracker is active, this function will do nothing.

        Args:
            values (`dict`):
                Dictionary of metrics to log. Values may be Python numbers, numpy arrays/scalars or tensors;
                all are converted to `float64` tensors on `self.device` so they can be reduced across processes.
            step (`int`, *optional*, defaults to `None`):
                Step number to log the metrics. Can access `self.state.global_step` (default) to log the current
                step, `self.state.train_step` or `self.state.val_step`.
            reduction (`str`, *optional*, defaults to `mean`):
                Reduction method to apply to tensors. Available options are `sum` and `mean`.
            instant (`bool`, *optional*, defaults to `False`):
                If `True`, log the metrics immediately, ignoring the `log_every` property.

        Raises:
            `TypeError`: If a value has an unsupported type.
        """
        if self.tracker is None:
            return

        if step is None:
            step = self.state.global_step
        log_every = 1 if instant else self.log_every

        cache = self._own_container("_log_cache", dict)
        for k, v in values.items():
            v = self._to_log_tensor(v)
            # Out-of-place add: `v` (and hence the first cached entry) may share storage with the
            # caller's tensor/array, so an in-place `+=` could silently modify user data.
            cache[k] = cache[k] + v if k in cache else v

        if step % log_every == 0:
            reduced = {
                k: self.accelerator.reduce(cache.pop(k) / log_every, reduction=reduction).float() for k in values
            }
            self.tracker.log(reduced, step=step, run_id=self.tracker.run_id)

    def _to_log_tensor(self, value: Any) -> torch.Tensor:
        """Convert a loggable value to a detached `float64` tensor on `self.device`; raise `TypeError` otherwise."""
        if isinstance(value, (float, int)):
            return torch.tensor(value, device=self.device, dtype=torch.float64)
        if isinstance(value, (np.ndarray, np.generic)):
            # `np.asarray` also turns numpy scalars (e.g. `np.int64`) into 0-d arrays.
            return torch.from_numpy(np.asarray(value)).to(dtype=torch.float64, device=self.device)
        if isinstance(value, torch.Tensor):
            return value.detach().to(dtype=torch.float64, device=self.device)
        raise TypeError(f"Unsupported type for logging: {type(value)}")

    def __init_subclass__(cls, **kwargs):
        # check collate functions
        if cls.collate_fn_train != AcceleratorModule.collate_fn_train:
            cls._implemented_collate_fn_train = True

        if cls.collate_fn_val != AcceleratorModule.collate_fn_val:
            cls._implemented_collate_fn_val = True

        super().__init_subclass__(**kwargs)

    def __call__(self, *args: Any, **kwargs: Any):
        return self.forward(*args, **kwargs)

    def __repr__(self):
        # `__repr__` must return a `str`; returning the module object itself made `repr()` raise.
        return repr(self.model)

    def __str__(self):
        return repr(self.model)

    def __len__(self):
        """Total number of parameters in `self.model` (note: this also drives the truthiness of the module)."""
        return sum(p.numel() for p in self.model.parameters())

    @classmethod
    def from_hf(cls, path: str, type: Union[str, Any] = None, **kwargs: Optional[Any]):
        """
        Build a custom AcceleratorModule for HuggingFace's transformers library. It simply replaces the following
        standard:

        ```
        class Module(AcceleratorModule):
            def __init__(self):
                self.model = AutoModel.from_pretrained(path, **kwargs)

            def training_step(self, batch):
                return self.model(**batch).loss

            def validation_step(self, key, batch):
                return self.model(**batch).loss
        ```

        Args:
            path (`str`):
                Path for HuggingFace model.
            type (`str` or `Any`):
                Model type in transformers library. It can be the class itself or a string (no need for imports).
                Defaults to `AutoModel`.
            kwargs (`Any`):
                Keyword arguments for `from_pretrained` function for model initialization.
        """
        # `type` shadows the builtin (public parameter name kept for compatibility); bind it to a clear local.
        model_cls = type
        if isinstance(model_cls, str):
            import importlib

            model_cls = getattr(importlib.import_module("transformers"), model_cls)
        elif model_cls is None:
            from transformers import AutoModel

            model_cls = AutoModel

        class Module(AcceleratorModule):
            def __init__(self):
                self.model = model_cls.from_pretrained(path, **kwargs)

            def training_step(self, batch):
                return self.model(**batch).loss

            def validation_step(self, *args, **kwargs):
                # The base signature is `(key, batch)`; accept `(batch)` too. The batch is always last.
                batch = kwargs["batch"] if "batch" in kwargs else args[-1]
                return self.model(**batch).loss

        return Module()

    def freeze(self, module: nn.Module):
        """
        Freeze all parameters inside a module.

        Args:
            module (`nn.Module`):
                Module where all parameters will have `requires_grad` set to `False`.
        """
        for param in module.parameters():
            param.requires_grad = False

    def unfreeze(self, module: nn.Module):
        """
        Unfreeze all parameters inside a module.

        Args:
            module (`nn.Module`):
                Module where all parameters will have `requires_grad` set to `True`.
        """
        for param in module.parameters():
            param.requires_grad = True

    def pad(
        self,
        tensor: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]],
        value: float,
        padding: Optional[Literal["max_length", "longest"]] = None,
        max_length: Optional[int] = None,
        side: Literal["left", "right"] = "right",
        op: Optional[Union[str, Callable]] = None,
    ) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]:
        """
        Pad last dimension of tensors to a given 'max_length' or to the longest tensor in an iterable
        (`tuple` or `list`).

        Args:
            tensor (`torch.Tensor`, `list` or `tuple`):
                Single tensor or an iterable of tensors to be padded.
            value (`int` or `float`):
                Constant value to be added when padding.
            padding (`str`, *optional*, defaults to `None`):
                Padding strategy to apply. `longest` means that all tensors in an iterable will be padded
                to the longest tensor, and `max_length` will pad all tensors to a given `max_length`.

                **NOTE**: A single tensor (or a one-element `list`/`tuple`) can only be padded to `max_length`.
                If padding is not specified, its value will default to `longest` for iterables and
                `max_length` for single tensors.
            max_length (`int`, *optional*, defaults to `None`):
                Max length for tensors to calculate remaining padding amount. This applies only when `padding`
                is set to `max_length` or `tensor` is a single tensor.
            side (`str`, *optional*, defaults to `right`):
                Padding side. Available options are `right` and `left`.
            op (`str`, *optional*, defaults to `None`):
                PyTorch operation to do after tensors are padded. Options can be `stack`, `cat` or a function.
                Only applicable for iterable of tensors.

        Returns:
            (`torch.Tensor`, `list` or `tuple`):
                Padded tensors. For iterables with more than one element, the tensors inside the iterable are
                modified in place and the same iterable is returned (or the result of `op`).

        Raises:
            `ValueError`: If a required `max_length` is missing or `side` is invalid.
            `RuntimeError`: If a tensor is longer than the target length (no truncation is done).
        """
        # `isinstance` so tensor subclasses such as `nn.Parameter` are treated as single tensors.
        if isinstance(tensor, torch.Tensor):
            return self._pad_single(tensor, value, max_length, side)

        tensors = tensor
        if len(tensors) == 0:
            return tensors

        if len(tensors) == 1 and type(tensors) in (list, tuple):
            # A one-element list/tuple is padded like a single tensor and wrapped back in its container type.
            padded = type(tensors)([self._pad_single(tensors[0], value, max_length, side)])
        else:
            # if it's an iterable of tensors, pad to 'padding', and if 'padding' is not specified,
            # pad to 'longest'.
            padding = padding if padding is not None else "longest"
            if padding == "max_length":
                if max_length is None:
                    raise ValueError("Must provide 'max_length' argument when padding = 'max_length'.")
                target_length = max_length
            else:
                # 0-d tensors are treated as length 1 (they are unsqueezed before padding).
                target_length = max(x.size(-1) if x.ndim > 0 else 1 for x in tensors)

            for x in tensors:
                x.data = self._pad_single(x, value, target_length, side)  # objects inside iterable modified
            padded = tensors

        if op is not None:
            padded = getattr(torch, op)(padded) if isinstance(op, str) else op(padded)

        return padded

    def _pad_single(self, tensor: torch.Tensor, value: float, max_length: Optional[int], side: str) -> torch.Tensor:
        """Return a new tensor with the last dimension of `tensor` padded with `value` up to `max_length`."""
        if tensor.ndim == 0:
            # Out-of-place, so the caller's tensor is not reshaped as a side effect.
            tensor = tensor.unsqueeze(0)

        if max_length is None:
            if self.accelerator is not None:
                self.accelerator.end_training()
            raise ValueError("When padding a single tensor, you must provide 'max_length'.")

        amount = max_length - tensor.size(-1)
        if amount < 0:
            raise RuntimeError("'pad' function is intended for padding and not truncation.")

        if side == "right":
            return F.pad(tensor, pad=(0, amount), mode="constant", value=value)
        if side == "left":
            return F.pad(tensor, pad=(amount, 0), mode="constant", value=value)
        raise ValueError("'side' argument must be either 'left' or 'right'.")

    def compile(self):
        """
        Compile the model and teacher. At this stage, models are already on the correct device.
        """
        self.model = torch.compile(self.model)
        if self.teacher is not None:
            self.teacher = torch.compile(self.teacher)

    def before_eval(self):
        """
        This function is called before the evaluation loop.
        """

    def after_eval(self):
        """
        This function is called after the evaluation loop.
        """

    def free_memory(self, *objects, clear_cache: bool = False, gc_collect: bool = False):
        """
        Drop the references this call holds to `objects`, then optionally clear the device cache
        (`clear_cache`) and/or collect garbage (`gc_collect`).

        **NOTE**: Python cannot rebind the caller's variables. The caller must also `del` (or reassign)
        its own references for the memory to actually be released.

        Args:
            `objects` (`Any`):
                Objects whose references held by this call are released.
            `clear_cache` (`bool`, *optional*, defaults to `False`):
                Clear device cache (garbage collection is forwarded as `gc_collect`).
            `gc_collect` (`bool`, *optional*, defaults to `False`):
                Collect garbage.
        """
        # Release this frame's references before any garbage collection runs.
        del objects

        if clear_cache:
            clear_device_cache(garbage_collection=gc_collect)
        elif gc_collect:
            gc.collect()

    def _register_model(self, name: Optional[str] = None):
        """
        Register a model to be wrapped by the accelerator. For safety, use `register` function instead.

        Args:
            name (`str`, *optional*, defaults to `None`):
                Attribute name of the model to register. If `None`, the model will be registered as `None`
                (no wrapping).
        """
        model = getattr(self, name) if name is not None else None
        self._own_container("_registered_models", list).append((name, model))

    def _register_optimizer(self, name: Optional[str] = None):
        """
        Register an optimizer to be wrapped by the accelerator. For safety, use `register` function instead.

        Args:
            name (`str`, *optional*, defaults to `None`):
                Attribute name of the optimizer to register. If `None`, the optimizer will be registered as `None`
                (no wrapping).
        """
        optimizer = getattr(self, name) if name is not None else None
        self._own_container("_registered_optimizers", list).append((name, optimizer))

    def _register_scheduler(self, name: Optional[str] = None):
        """
        Register a scheduler to be wrapped by the accelerator. For safety, use `register` function instead.

        Args:
            name (`str`, *optional*, defaults to `None`):
                Attribute name of the scheduler to register. If `None`, the scheduler will be registered as `None`
                (no wrapping).
        """
        scheduler = getattr(self, name) if name is not None else None
        self._own_container("_registered_schedulers", list).append((name, scheduler))

    def register(
        self,
        model: str,
        optimizer: Optional[str] = None,
        scheduler: Optional[str] = None,
    ):
        """
        Register a model, optimizer and scheduler to be wrapped by the accelerator.

        NOTE: Additional models will require custom compilation.

        Args:
            model (`str`):
                Attribute name of the model to register.
            optimizer (`str`, *optional*, defaults to `None`):
                Attribute name of the optimizer to register.
            scheduler (`str`, *optional*, defaults to `None`):
                Attribute name of the scheduler to register.

        Raises:
            `TypeError`: If any argument is not an attribute name (or `None` where allowed).
        """
        if not isinstance(model, str):
            raise TypeError("'model' must be an attribute name (`str` instance).")
        if optimizer is not None and not isinstance(optimizer, str):
            raise TypeError("'optimizer' must be an attribute name (`str` instance) or `None`.")
        if scheduler is not None and not isinstance(scheduler, str):
            raise TypeError("'scheduler' must be an attribute name (`str` instance) or `None`.")

        self._register_model(model)
        self._register_optimizer(optimizer)
        self._register_scheduler(scheduler)

    def additional_backward(
        self,
        model: nn.Module,
        loss: torch.Tensor,
        lomo_optimizer: Optional[Any] = None,
        **kwargs,
    ):
        """
        Similar to `self.backward(...)`, but for additional models created.

        Args:
            model (`nn.Module`):
                Model to backward. Under DeepSpeed its own `backward` method is called.
            loss (`torch.Tensor`):
                Loss tensor to backward.
            lomo_optimizer (`Lomo` or `AdaLomo`):
                LOMO optimizer to use for backward pass. Its `.optimizer.fused_backward` is called, so this is
                presumably the accelerator-wrapped optimizer (TODO confirm).
            kwargs (`Any`):
                Extra arguments to be passed to `backward(...)`. Can include `learning_rate` for LOMO.

        Raises:
            `ValueError`: If `lomo_optimizer` is given without a `learning_rate` (outside DeepSpeed,
                Megatron-LM and mixed-precision scaling).
        """
        # copied from accelerate's implementation
        learning_rate = kwargs.get("learning_rate")
        distributed_type = self.accelerator.distributed_type
        if distributed_type == DistributedType.DEEPSPEED:
            model.backward(loss, **kwargs)
        elif distributed_type == DistributedType.MEGATRON_LM:
            return
        elif self.accelerator.scaler is not None:
            self.accelerator.scaler.scale(loss).backward(**kwargs)
        elif lomo_optimizer is not None:
            # Check inside the LOMO branch; previously the branch already required a learning rate,
            # which made this error unreachable and silently fell back to a plain backward pass.
            if learning_rate is None:
                raise ValueError("`learning_rate` must be passed in order to call backward pass with LOMO optimizer.")
            lomo_optimizer.optimizer.fused_backward(loss, learning_rate)
        else:
            loss.backward(**kwargs)

    def additional_optimizer_step(self, optimizer: Optimizer, **kwargs):
        """
        Similar to `self.step_optimizer(...)`, but for additional models created.

        Args:
            optimizer (`Optimizer`):
                Optimizer to step.
            kwargs (`Any`):
                Extra arguments to be passed to `step(...)`.
        """
        optimizer.step(**kwargs)

    def additional_optimizer_zero_grad(self, optimizer: Optimizer, **kwargs):
        """
        Similar to `self.zero_grad(...)`, but for additional models created.

        Args:
            optimizer (`Optimizer`):
                Optimizer to zero gradients.
            kwargs (`Any`):
                Extra arguments to be passed to `zero_grad(...)`.
        """
        optimizer.zero_grad(**kwargs)

    def additional_scheduler_step(self, scheduler: LRScheduler, **kwargs):
        """
        Similar to `self.step_scheduler(...)`, but for additional models created.

        Args:
            scheduler (`LRScheduler`):
                Scheduler to step.
            kwargs (`Any`):
                Extra arguments to be passed to `step(...)`.
        """
        scheduler.step(**kwargs)

    def _additional_accelerator_paths(self):
        """
        Yield `(path, accelerator)` for each distinct accelerator in `_registered_accelerators`.

        Directories are named `accelerator{i}` with `i` the 1-based position in the dict, so numbers
        may be skipped when one accelerator appears several times. Save and load share this helper,
        so their numbering always matches.
        """
        seen = set()
        for i, accelerator in enumerate(self._registered_accelerators.values(), start=1):
            if id(accelerator) in seen:
                continue
            seen.add(id(accelerator))
            yield os.path.join(self._temp_path, f"accelerator{i}"), accelerator

    def save_temp_state(self, safe_serialization: bool = False, **save_model_func_kwargs: Any):
        """Save the main accelerator's state to `<_temp_path>/accelerator0` and any additional ones alongside it."""
        default_path = os.path.join(self._temp_path, "accelerator0")
        self.accelerator.save_state(default_path, safe_serialization=safe_serialization, **save_model_func_kwargs)
        for path, accelerator in self._additional_accelerator_paths():
            accelerator.save_state(path, safe_serialization=safe_serialization, **save_model_func_kwargs)

    def load_temp_state(self, load_kwargs: Optional[dict] = None, **load_model_func_kwargs: Any):
        """Load states previously written by `save_temp_state` from the same directories."""
        default_path = os.path.join(self._temp_path, "accelerator0")
        self.accelerator.load_state(default_path, load_kwargs, **load_model_func_kwargs)
        for path, accelerator in self._additional_accelerator_paths():
            accelerator.load_state(path, load_kwargs, **load_model_func_kwargs)

    def reset_optimizer(self, optimizer: Optional[torch.optim.Optimizer] = None):
        """Clear the internal state (e.g. momentum buffers) of `optimizer`, defaulting to `self.optimizer`."""
        if optimizer is None:
            optimizer = self.optimizer
        optimizer.state.clear()
class ExtendedAcceleratorModule(AcceleratorModule):
    """
    Variant of `AcceleratorModule` whose `training_step` owns the whole optimization step,
    i.e. backpropagation is performed manually inside it.

    Example:
        ```
        class Module(ExtendedAcceleratorModule):
            # other logic remains the same

            def training_step(self, batch):
                loss = ...
                self.backward(loss)
                self.step_optimizer()
                self.step_scheduler()
                return loss  # loss will only be used to log metrics.
        ```

    NOTE: `grad_accumulation_steps` in `fit` function from `Trainer` will not work. If you want to
    accumulate gradients and then backpropagate, you may want to make use of `self.state.global_step`.
    """

    _extended = True

    def backward(self, loss: torch.Tensor, **kwargs):
        """
        Backpropagate `loss` through the accelerator.

        Args:
            `loss` (`torch.Tensor`):
                Scalar loss tensor to backward.
            `kwargs` (`Any`):
                Forwarded unchanged to 'accelerator.backward'.
        """
        accelerator = self.accelerator
        accelerator.backward(loss, **kwargs)

    def step_optimizer(self):
        """Advance `self.optimizer` by one step."""
        optimizer = self.optimizer
        optimizer.step()

    def step_scheduler(self):
        """Advance `self.scheduler` by one step."""
        scheduler = self.scheduler
        scheduler.step()

    def step(self):
        """Step the optimizer, then the scheduler; the scheduler step is skipped when `self.scheduler` is `None`."""
        self.step_optimizer()
        if self.scheduler is None:
            return
        self.step_scheduler()

    def zero_grad(self, set_to_none: bool = True):
        """
        Reset gradients through the optimizer's 'zero_grad'.

        Args:
            `set_to_none` (`bool`, *optional*, defaults to `True`):
                When `True`, gradients become `None` rather than being filled with `0`.
        """
        optimizer = self.optimizer
        optimizer.zero_grad(set_to_none=set_to_none)

    @override
    def training_step(self, batch: Any):
        return None

    def __init_subclass__(cls, **kwargs):
        # Intentionally skips `super().__init_subclass__`, so the parent's subclass hook does not run.
        # NOTE(review): that hook is what sets `_implemented_collate_fn_train/_val`, so subclasses of
        # this class never get those flags set even if they override the collate functions — confirm
        # this is intended.
        return None