Source code for accmt.trainer

# Copyright 2025 ghanvert. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import inspect
import logging
import math
import os
import shutil
import signal
import sys
import time
import traceback
from collections import defaultdict
from collections.abc import Mapping
from contextlib import nullcontext
from typing import Any, Callable, Optional, Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from accelerate import Accelerator, DistributedType
from accelerate.utils import ProjectConfiguration
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, Dataset

from .callbacks import Callback, CallbackMaster
from .curriculum import _CurriculumLearning
from .debug_timings import DebugTimings
from .evaluator import Evaluator
from .hyperparameters import HyperParameters
from .metrics import Metric
from .model_wrapper import _DistributedDataParallel
from .modules import AcceleratorModule
from .monitor import Monitor
from .states import LossState, TrainingState
from .tqdm import tqdm
from .tracker import _tracker_map
from .tunnel import AsyncDiskQueue, AsyncState, ModelTunnel
from .utils import clear_device_cache
from .utils.distributed import all_gather_dictionary
from .utils.globals import (
    ASYNC,
    ASYNC_HASH,
    ASYNC_TRAIN_GROUP,
    DEBUG_MODE,
    DEBUG_TIMINGS,
    DIST_HASH,
    MASTER_PROCESS,
    WORLD_SIZE,
    __version__,
)
from .utils.maps import _operator_map
from .utils.misc import filter_kwargs, get_number_and_unit, get_time_prefix, is_url, print_gpu_users_by_device, rprint
from .utils.seed import get_seed, set_seed


# Sub-directory of `model_path` that holds checkpoints.
CHECKPOINT_DIR = "checkpoint"
# File names, inside a checkpoint directory, for the serialized training state
# and for the serialized train-loss state.
STATE_FILE = "state.json"
TRAIN_LOSS_STATE_FILE = "train_loss_state.pt"

# Progress-bar layout and keyword arguments intended for `tqdm`.
_bar_format = (
    "{l_bar}{bar}| {n_fmt}/{total_fmt}"
    " - ETA: {remaining}{postfix}"
    " - {rate_s}"
)
_tqdm_kwargs = dict(leave=False, ncols=100, bar_format=_bar_format)
# Buffer with "batch" and "step" entries initialised to 0.0; presumably timing
# accumulators for DebugTimings (usage not visible here — confirm).
_debug_timings_buffer = dict.fromkeys(("batch", "step"), 0.0)


[docs] class Trainer: """Class to implement full training process."""
def __init__(
    self,
    hps_config: Union[str, dict, HyperParameters],
    model_path: str,
    track_name: Optional[str] = None,
    enable_checkpointing: bool = True,
    multiple_checkpoints: bool = False,
    max_checkpoints: Optional[int] = None,
    resume: Optional[Union[bool, int]] = None,
    disable_model_saving: bool = False,
    patience: Optional[Union[int, dict[str, Any]]] = None,
    evaluate_every_n_steps: Optional[int] = None,
    checkpoint_every: Optional[str] = "epoch",
    logging_dir: str = "logs",
    log_with: Optional[str] = None,
    log_every: Optional[int] = 1,
    grad_accumulation_steps: Optional[int] = None,
    gradient_checkpointing: bool = False,
    gradient_checkpointing_kwargs: Optional[dict[str, Any]] = None,
    clip_grad: Optional[float] = 1.0,
    set_to_none: bool = True,
    shuffle_train: bool = True,
    sampler: Optional[Union[Any, list]] = None,
    sampler_train: Optional[Union[Any, list]] = None,
    sampler_val: Optional[Union[Any, list]] = None,
    batch_sampler: Optional[Union[Any, list]] = None,
    batch_sampler_train: Optional[Union[Any, list]] = None,
    batch_sampler_val: Optional[Union[Any, list]] = None,
    collate_fn: Optional[Callable] = None,
    collate_fn_train: Optional[Callable] = None,
    collate_fn_val: Optional[Callable] = None,
    max_shard_size: str = "10GB",
    safe_serialization: bool = False,
    compile: bool = False,
    compile_kwargs: Optional[dict[str, Any]] = None,
    safe_mode: bool = True,
    train_loss_metric_name: str = "train_loss",
    val_loss_metric_name: str = "val_loss",
    dataloader_pin_memory: bool = True,
    dataloader_num_workers: Optional[int] = None,
    dataloader_drop_last: bool = False,
    eval_when_finish: bool = True,
    eval_when_start: bool = False,
    monitor: Optional[Monitor] = None,
    metrics: Optional[Union[Metric, list[Metric], dict[Any, Union[Metric, list[Metric]]]]] = None,
    consolidate_metrics: Optional[dict[str, Callable[[str], bool]]] = None,
    additional_metric_consolidation: Optional[Callable[[dict], dict]] = None,
    report_all_metrics: bool = False,
    cleanup_cache_every_n_steps: Optional[int] = None,
    callback: Optional[Union[Callback, list[Callback]]] = None,
    additional_tracker_config: Optional[dict[str, Any]] = None,
    batch_device_placement: bool = True,
    prepare_batch: bool = True,
    safe_steps: bool = True,
    destroy_after_training: bool = True,
    enable_prepare_logging: bool = False,
    module_hooks: bool = True,
    inference_mode: bool = True,
    reset_optimizer_every_n_steps: Optional[int] = None,
    **kwargs: Optional[Any],
):
    """
    Trainer constructor to set configuration.

    Args:
        hps_config (`str`, `dict`, or `HyperParameters`):
            YAML hyperparameters file path, dictionary or `HyperParameters`.
        model_path (`str`):
            Path to save model.
        track_name (`str`, *optional*, defaults to `None`):
            Track name for trackers. If set to `None` (default), the track name will be the model's
            folder name.
        enable_checkpointing (`bool`, *optional*, defaults to `True`):
            Enable checkpointing.
        multiple_checkpoints (`bool`, *optional*, defaults to `False`):
            Enable multiple checkpoints.
        max_checkpoints (`int`, *optional*, defaults to `None`):
            Maximum number of checkpoints to keep. If set to `None`, all checkpoints will be kept.
            Must be greater than 0 when given.
        resume (`bool` or `int`, *optional*, defaults to `None`):
            Whether to resume from checkpoint. Default option is `None`, which means resuming is decided
            automatically: training resumes only if the checkpoint directory exists and is not empty.
            If set to `True`, the latest checkpoint will be loaded. If set to an integer, the checkpoint
            will be loaded from the given index (requires `multiple_checkpoints=True`; the index must be
            greater than 0). If set to `-1`, the latest checkpoint will be loaded (requires
            `multiple_checkpoints=True`).
        disable_model_saving (`bool`, *optional*, defaults to `False`):
            Disable any model saving registered (by default, `"best_valid_loss"` is registered, or if
            there are no evaluations to do, default will be `"best_train_loss"`).
        patience (`int` or `dict`, *optional*, defaults to `None`):
            Set up a patience parameter for model savings. If set, every model saving will check if the
            previous metric was higher. If the metric has not improved over the N model savings
            (`patience`), then the training process will stop. Can also implement patience per model
            saving in a dictionary (keys must be registered model savings; values must be non-zero).
        evaluate_every_n_steps (`int`, *optional*, defaults to `None`):
            Evaluate model in validation dataset (if implemented) every N steps. If this is set to `None`
            (default option), evaluation will happen at the end of every epoch.
        checkpoint_every (`str`, *optional*, defaults to `epoch`):
            Checkpoint every N epochs, steps or evaluations. Requires a number and a unit in a string.
            The following examples are valid:

            - `"epoch"`, `"ep"`, `"1epoch"`, `"1ep"`, `"1 epoch"`, `"1 ep"`: 1 Epoch
            - `"step"`, `"st"`, `"1step"`, `"1st"`, `"1 step"`, `"1 st"`: 1 Step
            - `"evaluation"`, `"eval"`, `"1evaluation"`, `"1eval"`, `"1 evaluation"`, `"1 eval"`: 1 Evaluation

            (a character `s` at the end of the string is also valid)

            If set to `None`, checkpointing will be disabled.
        logging_dir (`str`, *optional*, defaults to `logs`):
            Path where to save logs to show progress. It can be an IP address (local or remote), HTTP or
            HTTPS link, or simply a directory.
        log_with (`str`, *optional*, defaults to `None`):
            Logger to log metrics (case-insensitive). It can be one of the following:

            - `mlflow`

            NOTE: MLFlow is the only one supported right now. Other trackers are not currently available.
        log_every (`int`, *optional*, defaults to `1`):
            Log train loss every N steps. If set to `-1`, training loss will be logged at the end of
            every epoch.
        grad_accumulation_steps (`int`, *optional*, defaults to `None`):
            Accumulate gradients for N steps. Useful for training large models and simulating large
            batches when memory is not enough. If set to `None` or `1`, no accumulation will be performed.
        gradient_checkpointing (`bool`, *optional*, defaults to `False`):
            Use gradient checkpointing. It requires a `gradient_checkpointing_enable` method in the model
            (models from HuggingFace's `transformers` library have this method already implemented) with
            a single argument `gradient_checkpointing_kwargs` (can be a dictionary or `None`).
        gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`):
            Keyword arguments for `gradient_checkpointing_enable` method.
        clip_grad (`float`, *optional*, defaults to 1.0):
            Performs gradient clipping in between backpropagation and optimizer's step function.
            `None` is stored as `0.0`. With DeepSpeed, the value is written into the DeepSpeed config
            as `gradient_clipping`.
        set_to_none (`bool`, *optional*, defaults to `True`):
            From PyTorch documentation: "instead of setting to zero, set the grads to None. This will in
            general have lower memory footprint, and can modestly improve performance." Some optimizers
            have a different behaviour if the gradient is 0 or None. See PyTorch docs for more information:
            https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html
        shuffle_train (`bool`, *optional*, defaults to `True`):
            Whether to shuffle train DataLoader.
        sampler (`list` or `Any`, *optional*, defaults to `None`):
            Sampler (or list of samplers) for train and validation DataLoader.
        sampler_train (`list` or `Any`, *optional*, defaults to `None`):
            Sampler (or list of samplers) for train DataLoader. Cannot be implemented if `sampler` was
            already declared.
        sampler_val (`list` or `Any`, *optional*, defaults to `None`):
            Sampler (or list of samplers) for validation DataLoader. Cannot be implemented if `sampler`
            was already declared.
        batch_sampler (`list` or `Any`, *optional*, defaults to `None`):
            Batch sampler for train and validation DataLoader.
        batch_sampler_train (`list` or `Any`, *optional*, defaults to `None`):
            Batch sampler for train DataLoader. Cannot be implemented if `batch_sampler` was already
            declared.
        batch_sampler_val (`list` or `Any`, *optional*, defaults to `None`):
            Batch sampler for validation DataLoader. Cannot be implemented if `batch_sampler` was already
            declared.
        collate_fn (`Callable`, *optional*, defaults to `None`):
            Collate function to be implemented in both train and validation dataloaders.
        collate_fn_train (`Callable`, *optional*, defaults to `None`):
            Collate function to be implemented in train dataloader. Cannot be implemented if `collate_fn`
            was already declared.
        collate_fn_val (`Callable`, *optional*, defaults to `None`):
            Collate function to be implemented in validation dataloader. Cannot be implemented if
            `collate_fn` was already declared.
        max_shard_size (`str`, *optional*, defaults to `10GB`):
            Max model shard size to be used.
        safe_serialization (`bool`, *optional*, defaults to `False`):
            Whether to save model using safe tensors or the traditional PyTorch way. If `True`, some
            tensors will be lost.
        compile (`bool`, *optional*, defaults to `False`):
            Whether to call `torch.compile` on model (and teacher, if implemented).
        compile_kwargs (`dict`, *optional*, defaults to `None`):
            `torch.compile` kwargs for additional customization.
        safe_mode (`bool`, *optional*, defaults to `True`):
            Run forward passes of the model in safe mode. This means that the forward pass of the model
            will run through the corresponding wrapper (DDP, FSDP or DeepSpeedEngine). If not running in
            safe mode, forward pass will skip the wrapper and run directly on the module (instance of
            `nn.Module`). Running with safe mode disabled will slightly improve throughput, although
            gradients consistency and mixed precision could be affected because skipping the wrapper's
            forward pass might skip internal parallel functionality.

            **NOTE**: This parameter takes no effect running with FSDP since forward passes are already
            done through this wrapper.
        train_loss_metric_name (`str`, *optional*, defaults to `train_loss`):
            Metric name for train loss in logs.
        val_loss_metric_name (`str`, *optional*, defaults to `val_loss`):
            Metric name for validation loss in logs.
        dataloader_pin_memory (`bool`, *optional*, defaults to `True`):
            Enables pin memory option in DataLoader (only if GPU is enabled).
        dataloader_num_workers (`int`, *optional*, defaults to `None`):
            Number of processes for DataLoader. This defaults to `None`, meaning the number of workers
            will be equal to the number of processes set for training. NOTE: the value is forced to `0`
            when training runs in a single process, or in debug mode.
        dataloader_drop_last (`bool`, *optional*, defaults to `False`):
            Whether to drop last batch on DataLoader or not.
        eval_when_finish (`bool`, *optional*, defaults to `True`):
            At the end of training, evaluate model on validation dataset (if available). This option is
            only valid when `evaluate_every_n_steps` is not `None`.
        eval_when_start (`bool`, *optional*, defaults to `False`):
            Start training with evaluation (if available).
        monitor (`Monitor` or `dict`, *optional*, defaults to `None`):
            Monitor arguments to keep track of variables during training. If not specified, 'train_loss'
            and 'validation_loss' will be set to `True` by default.

            NOTE: Learning rate, GPU and CPU monitoring will only be reported during training, not
            evaluation. Also, GPU and CPU monitoring will only be reported on main process (index 0).
        metrics (`Metric`, `list` or `dict`, *optional*, defaults to `None`):
            List of additional metrics of type 'Metric' to track. When doing multiple evaluations, this
            should be a dictionary of metrics (or list of metrics), where each key corresponds to the
            dataset to evaluate (specified in `val_dataset` in `fit` function) and the value corresponds
            to a `Metric` or list of metrics. If metrics are given as only `Metric` or list of metrics,
            these metrics will apply for all evaluations. If you want specific metrics for specific
            evaluations, consider dividing your metrics per validation dataset in a dictionary.
        consolidate_metrics (`dict`, *optional*, defaults to `None`):
            Dictionary of metrics to consolidate. The key is the metric name and the value is a function
            that takes the metric's key and returns a `bool` value indicating whether the metric should
            be included in the consolidation. Consolidated metrics will be reported as a single value,
            and individual metrics used for consolidation will not be reported by default. To control
            this behavior, set `report_all_metrics` to `True` (default is `False`).

            NOTE: When doing multiple evaluations, the metric name has a prefix corresponding to the
            dataset being evaluated. e.g. `accuracy__first_val_dataset`, `accuracy_second_val_dataset`, etc.

            Example:

            ```
            Trainer(
                ...,
                metrics=...,  # <-- at least one metric with name 'accuracy'
                consolidate_metrics={"average_accuracy": lambda key: key.startswith("accuracy")},
            )
            ```
        additional_metric_consolidation (`Callable`, *optional*, defaults to `None`):
            Function to apply additional consolidation on metrics. This function takes a dictionary of
            metrics and returns a dictionary of consolidated metrics.

            Example:

            ```
            def additional_metric_consolidation(metrics: dict):
                # `metrics` is a dictionary where keys are metric names consolidated and values are arrays of metric values.
                # Here, "max_accuracy" would be the new reported metric.
                return {
                    "max_accuracy": max(metrics["accuracy"].values())
                }
            ```
        report_all_metrics (`bool`, *optional*, defaults to `False`):
            When `consolidate_metrics` is implemented, this option controls whether to report all
            individual metrics in addition to the consolidated metrics. If `False` (default), individual
            metrics used for consolidation will not be reported. This option has no effect when
            `consolidate_metrics` is not implemented.
        cleanup_cache_every_n_steps (`int`, *optional*, defaults to `None`):
            Cleanup CPU and CUDA caches every N steps. Default is no cleanup.

            NOTE: On every epoch and evaluation call we cleanup cache.
        callback (`Callback` or `list`, *optional*, defaults to `None`):
            `Callback` or callbacks to implement.
        additional_tracker_config (`dict`, *optional*, defaults to `None`):
            Additional configuration specification for tracker (e.g. hyper-parameters).
        batch_device_placement (`bool`, *optional*, defaults to `True`):
            Move batches to correct device automatically. If `False`, batches will be in CPU.
        prepare_batch (`bool`, *optional*, defaults to `True`):
            Prepares a batch dynamically when using Mixed Precision. When using DeepSpeed, we need to
            scale down the floating point tensors to be able to do calculations with the model. If not
            using DeepSpeed, this argument takes no effect.
        safe_steps (`bool`, *optional*, defaults to `True`):
            Run safe training and validation steps to avoid OOMs (Out Of Memory errors) and retry steps.
            If a retry does not solve the problem, a list of users using GPUs will pop up and the OOM
            error will raise.
        destroy_after_training (`bool`, *optional*, defaults to `True`):
            Destroy the process group after training. Set to `False` if you're running multiple trainings
            in the same script.
        enable_prepare_logging (`bool`, *optional*, defaults to `False`):
            Enable internal model preparation logging. When using DeepSpeed, there are many messages that
            appear in the terminal that can be annoying.
        module_hooks (`bool`, *optional*, defaults to `True`):
            Whether to call the `before_eval` and `after_eval` hooks of the module.
        inference_mode (`bool`, *optional*, defaults to `True`):
            Whether to run evaluation in `torch.inference_mode` or simply `torch.no_grad`. This takes no
            effect if given `module` in `fit` function is an instance of `ExtendedAcceleratorModule`,
            since context manager needs to be given manually by the user.
        reset_optimizer_every_n_steps (`int`, *optional*, defaults to `None`):
            Reset optimizer state every N steps.
        kwargs (`Any`, *optional*):
            Extra arguments for specific `init` function in Tracker, e.g. `run_name`, `tags`, etc.

    Raises:
        AssertionError: If `hps_config` is not a `str`, `dict` or `HyperParameters`, or if `clip_grad`
            is neither `None` nor a `float`.
        ValueError: If `resume` is an index while `multiple_checkpoints` is disabled, or an index equal
            to 0 or below -1; if `patience` is 0, contains a 0 value, or has an unsupported type; if
            `max_checkpoints` is not positive; or if a shared `sampler`, `batch_sampler` or `collate_fn`
            is combined with its train/val specific counterpart.
    """
    # do some previous checks
    self.log_with = log_with.lower() if isinstance(log_with, str) else log_with
    # trackers are disabled entirely in any debug mode
    self.tracker = _tracker_map[self.log_with]() if self.log_with is not None and DEBUG_MODE < 1 else None
    assert isinstance(hps_config, (str, dict, HyperParameters)), (
        "'hps_config' needs to be either a string, dictionary or HyperParameters class."
    )
    assert clip_grad is None or isinstance(clip_grad, float), "'clip_grad' argument needs to be a float."

    from . import IS_GPU, accelerator

    self.is_gpu = IS_GPU
    self.accelerator = accelerator
    self.hps = HyperParameters.from_config(hps_config) if isinstance(hps_config, (str, dict)) else hps_config
    self.track_name = track_name
    self.checkpoint_path = os.path.join(model_path, CHECKPOINT_DIR)
    self.model_path = model_path

    # `type(...) is int` on purpose: `True`/`False` are bools, not checkpoint indices
    if type(resume) is int:
        if not multiple_checkpoints:
            raise ValueError(
                "Cannot specify a checkpoint index in 'resume' when 'multiple_checkpoints' is disabled."
            )
        elif resume == 0 or resume < -1:
            raise ValueError(
                "Checkpoint index in 'resume' must be greater than 0 (or -1 to resume from latest checkpoint)."
            )
    # when `resume` is None, resume only if a non-empty checkpoint directory exists;
    # DEBUG_MODE >= 3 never resumes
    self.resume = (
        (
            resume
            if resume is not None
            else os.path.exists(self.checkpoint_path) and len(os.listdir(self.checkpoint_path)) > 0
        )
        if DEBUG_MODE < 3
        else False
    )

    self.metrics: dict[Any, list[Metric]] = metrics
    self.consolidate_metrics = consolidate_metrics
    self.additional_metric_consolidation = additional_metric_consolidation
    self.report_all_metrics = report_all_metrics
    self.disable_model_saving = disable_model_saving
    # key: model saving, value: (saving_below, saving_above)
    self.model_saving: dict[str, tuple[float, float]] = {}

    if isinstance(patience, int):
        if patience == 0:
            raise ValueError("The 'patience' argument in Trainer should have a value greater than 0.")
        self.patience = patience
    elif isinstance(patience, dict):
        for k, v in patience.items():
            if v == 0:
                raise ValueError(
                    "The 'patience' argument when declared as a dictionary needs to have values above 0. "
                    f"Got {v} in '{k}'."
                )
        # must be stored: `fit` builds per-saving patience counters from this mapping
        # (previously it was validated but never assigned, so `fit` failed on `self.patience`)
        self.patience = dict(patience)
    elif patience is not None:
        raise ValueError("'patience' must be either an integer value or a dictionary.")
    else:
        # sentinel used when no patience is configured
        self.patience = -1

    self.evaluate_every_n_steps = evaluate_every_n_steps
    self.enable_checkpointing = enable_checkpointing if DEBUG_MODE < 3 else False
    self.multiple_checkpoints = multiple_checkpoints
    if max_checkpoints is not None and max_checkpoints <= 0:
        raise ValueError("'max_checkpoints' must be greater than 0 or `None`.")
    self.max_checkpoints = max_checkpoints
    self.checkpoint_every, self.checkpoint_strat = get_number_and_unit(checkpoint_every)
    self.logging_dir = logging_dir
    self.log_every = log_every
    self.grad_accumulation_steps = grad_accumulation_steps if grad_accumulation_steps is not None else 1
    self.gradient_checkpointing = gradient_checkpointing
    self.gradient_checkpointing_kwargs = gradient_checkpointing_kwargs
    self.clip_grad = clip_grad if clip_grad is not None else 0.0
    if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
        # DeepSpeed performs clipping itself, so the value goes into its config
        self.accelerator.deepspeed_plugin.deepspeed_config["gradient_clipping"] = self.clip_grad
    self.set_to_none = set_to_none
    self.shuffle_train = shuffle_train

    # a shared sampler / batch sampler / collate function overrides the split-specific ones,
    # so declaring both is rejected
    if sampler is not None and (sampler_train is not None or sampler_val is not None):
        raise ValueError("'sampler' cannot be declared along with 'sampler_train' or 'sampler_val'.")
    self.sampler = sampler
    self.sampler_train = sampler_train if sampler is None else sampler
    self.sampler_val = sampler_val if sampler is None else sampler
    if batch_sampler is not None and (batch_sampler_train is not None or batch_sampler_val is not None):
        raise ValueError(
            "'batch_sampler' cannot be declared along with 'batch_sampler_train' or 'batch_sampler_val'."
        )
    self.batch_sampler = batch_sampler
    self.batch_sampler_train = batch_sampler_train if batch_sampler is None else batch_sampler
    self.batch_sampler_val = batch_sampler_val if batch_sampler is None else batch_sampler
    if collate_fn is not None and (collate_fn_train is not None or collate_fn_val is not None):
        raise ValueError("'collate_fn' cannot be declared along with 'collate_fn_train' or 'collate_fn_val'.")
    self.collate_fn = collate_fn
    self.collate_fn_train = collate_fn_train if collate_fn is None else collate_fn
    self.collate_fn_val = collate_fn_val if collate_fn is None else collate_fn

    self.max_shard_size = max_shard_size
    self.safe_serialization = safe_serialization
    self.compile = compile
    self.compile_kwargs = compile_kwargs if compile_kwargs is not None else {}
    self.safe_mode = safe_mode
    self.train_loss_metric_name = train_loss_metric_name
    self.val_loss_metric_name = val_loss_metric_name
    self.dataloader_pin_memory = dataloader_pin_memory if IS_GPU else False
    self.dataloader_num_workers = (
        dataloader_num_workers if dataloader_num_workers is not None else self.accelerator.num_processes
    )
    # NOTE(review): this also overrides an explicitly passed `dataloader_num_workers` whenever a single
    # process is used — confirm that is intended.
    if (DEBUG_MODE > 0 and self.dataloader_num_workers != 0) or self.accelerator.num_processes == 1:
        # force when debugging to not have problems with dataloader during breakpoints
        self.dataloader_num_workers = 0
    self.dataloader_drop_last = dataloader_drop_last
    self.eval_when_finish = eval_when_finish
    self.eval_when_start = eval_when_start if DEBUG_MODE < 4 else False
    self.monitor = monitor if isinstance(monitor, Monitor) else Monitor.from_config(monitor)
    self.cleanup_cache_every_n_steps = cleanup_cache_every_n_steps

    # normalize `callback` to a non-empty list before handing it to CallbackMaster
    callback = callback if callback is not None else Callback()
    callback = callback if isinstance(callback, list) else [callback]
    self.callback = CallbackMaster(callback)

    self.additional_tracker_config = additional_tracker_config if additional_tracker_config is not None else {}
    self.batch_device_placement = batch_device_placement
    self.prepare_batch = prepare_batch
    self.safe_steps = safe_steps
    self.destroy_after_training = destroy_after_training
    self.enable_prepare_logging = enable_prepare_logging
    self.module_hooks = module_hooks
    self.inference_mode = inference_mode
    self.reset_optimizer_every_n_steps = reset_optimizer_every_n_steps
    self.init_kwargs = kwargs
    self.accelerator.project_configuration = ProjectConfiguration(
        project_dir=".", logging_dir=logging_dir, total_limit=1
    )
    self._logging = self.log_with is not None

    self.state = TrainingState()
    # adding a total (at maximum) of 64 bytes for additional tensors
    self.train_loss_state = LossState(self.accelerator, self.accelerator.device, self.log_every, pin_memory=IS_GPU)
    self.val_loss_state: dict[Any, LossState] = None  # prepare val loss states in 'fit' function

    self._checkpointing_every_n_steps = self.enable_checkpointing and self.checkpoint_strat == "step"
    self._checkpointing_after_evaluation = self.enable_checkpointing and self.checkpoint_strat == "eval"
    self._checkpointing_when_epoch_ends = self.enable_checkpointing and self.checkpoint_strat == "epoch"

    # runtime objects populated by `fit`
    self._module: AcceleratorModule = None
    self._scheduler: LRScheduler = None
    self._optimizer: Optimizer = None
    self._multiple_evaluations = False
    self._multiple_train_datasets = False
    self.unwrapped_model: nn.Module = None
    self.wrapped_model = None

    # asynchronous train/evaluation helpers only exist when ASYNC is enabled
    self.async_state = AsyncState(self.model_path) if ASYNC else None
    self.async_queue = AsyncDiskQueue(self.model_path, self.accelerator) if ASYNC else None
    self.tunnel = ModelTunnel(ASYNC_HASH) if ASYNC else None

    # initialize trackers (in ASYNC mode only the train group does it)
    self.run_id = None
    self.tracker_initialized = False
    if self._logging and DEBUG_MODE < 1 and ((ASYNC and ASYNC_TRAIN_GROUP) or not ASYNC):
        self.run_id = self._init_trackers()
        self.tracker_initialized = True

    self._model_dtype = torch.float32
    self.do_sync = False
    self.accum_steps_done = 0
    self._max_steps: int = None
    self._deepspeed_default_micro_batch_size = 1
    self._consolidated_metrics: dict[str, list[float]] = defaultdict(list)
    self._debug_timings = DebugTimings(DEBUG_TIMINGS)
def fit(
    self,
    module: Union[AcceleratorModule, str, Union[tuple[str, str], tuple[str, Any]]],
    train_dataset: Optional[
        Union[Dataset, list[Union[tuple[int, Dataset], tuple[int, Dataset, dict]]], _CurriculumLearning]
    ] = None,
    val_dataset: Optional[Union[Dataset, list[Dataset], dict[str, Dataset]]] = None,
    **kwargs: Any,
):
    """
    Function to train a given `AcceleratorModule`.

    Args:
        module (`AcceleratorModule`, `str` or `tuple`):
            `AcceleratorModule` class containing the training logic. This can also be a string specifying
            a HuggingFace model, or a tuple of type (model, type), where 'model' is a string for the
            HuggingFace model, and 'type' is a string or class (from transformers library) for the model
            type.
        train_dataset (`torch.utils.data.Dataset` or `list`, *optional*, defaults to `None`):
            `Dataset` class from PyTorch containing the train dataset logic. If not provided, then
            `get_train_dataloader` from `module` will be used to get the train DataLoader. Can also be a
            list of tuples, in that case, the first element of each tuple is the maximum step for each
            dataset, and the second element is the `Dataset` to use, and optionally, a dictionary of
            keyword arguments for the dataloader as the third element. For more simple definitions, you
            can use an instance of `StepsCurriculum`, `RangeCurriculum` or `RatioCurriculum` from
            `accmt.curriculum`.
        val_dataset (`torch.utils.data.Dataset`, `list` or `dict`, *optional*, defaults to `None`):
            `Dataset` class from PyTorch containing the validation dataset logic. This can also be a list
            or a dictionary of `Dataset`, in that case, multiple evaluations will run following the logic
            of `validation_step` and specified metrics. Metric names reported for a multiple evaluation
            setting will add a '_' followed by a key related to the dataset (e.g. 'accuracy_1' or
            'accuracy_another_dataset'). If this dataset is not specified, then the validation logic of
            `AcceleratorModule` (if specified) will be skipped.
        kwargs (`Any`):
            Keyword arguments for `from_pretrained` function for model initialization.

    Raises:
        RuntimeError: If the module's `model` is missing or not an `nn.Module`; if the keys of a
            dictionary `patience` are not registered model savings; if the resumed training state is
            flagged as finished; if `step_scheduler_per_epoch` is combined with multiple training
            datasets or `max_steps`; or if `log_every` is negative with multiple training datasets.
        ValueError: If `val_dataset` is empty, or if `max_steps` is missing while using multiple
            training datasets.
        FileNotFoundError: If resuming and the resolved checkpoint path ends with `checkpoint_0`.
    """
    # reset loss states in case of another fit function call in the script
    clear_device_cache(garbage_collection=True)
    self.train_loss_state.reset()
    if self.val_loss_state is not None:
        for v in self.val_loss_state.values():
            v.reset()

    module = self._get_module(module, **kwargs)
    self._module = module
    self._module._model_path = self.model_path
    self._module._temp_path = os.path.join(self.model_path, f"_temp_state_{DIST_HASH}")

    model = module.model
    self.unwrapped_model = model
    if model is None or not isinstance(model, nn.Module):
        raise RuntimeError(
            "`AcceleratorModule` subclass requires `self.model` and needs to be an instance of `nn.Module`."
        )

    teacher = module.teacher
    if torch.cuda.is_available():
        model.to(self.accelerator.device)
        if teacher is not None:
            # the teacher is only used for inference
            teacher.eval()
            teacher.to(self.accelerator.device)

    module.state = self.state
    module.accelerator = self.accelerator
    module.device = self.accelerator.device

    if MASTER_PROCESS and DEBUG_MODE < 3 and (self.enable_checkpointing or not self.disable_model_saving):
        os.makedirs(self.model_path, exist_ok=True)

    # normalize a single validation dataset into a one-element list
    val_dataset = val_dataset if val_dataset is None or isinstance(val_dataset, (list, dict)) else [val_dataset]
    if val_dataset is not None and len(val_dataset) == 0:
        raise ValueError("'val_dataset' cannot be empty.")

    self._multiple_train_datasets = (
        isinstance(train_dataset, (list, _CurriculumLearning)) and len(train_dataset) > 1
    )
    self._multiple_evaluations = val_dataset is not None and len(val_dataset) > 1

    train_dataloader, val_dataloader = self._get_dataloaders(module, train_dataset, val_dataset)
    self.metrics = self._prepare_metrics(self.metrics, val_dataloader)

    # register a default model saving when the user did not register any
    if len(self.model_saving) == 0:
        if val_dataloader is not None:
            self.model_saving["best_valid_loss"] = (float("inf"), float("-inf"))
        else:
            self.model_saving["best_train_loss"] = (float("inf"), float("-inf"))

    if isinstance(self.patience, int):
        self.state.patience_left = {k: self.patience for k in self.model_saving.keys()}
    else:
        if not all(k in self.model_saving for k in self.patience.keys()):
            raise RuntimeError("Keys declared in 'patience' do not match model savings.")
        # model savings without an explicit patience get -1
        self.state.patience_left = {k: (self.patience.get(k, -1)) for k in self.model_saving.keys()}

    if self.metrics is not None:
        for k, v in self.metrics.items():
            self.state.additional_metrics[k] = {m.main_metric: 0 for m in v}
    elif val_dataloader is not None:
        for k in val_dataloader.keys():
            self.state.additional_metrics[k] = {}

    if self.resume:
        checkpoint_path = self._get_current_checkpoint_path()
        if checkpoint_path.endswith("checkpoint_0"):
            raise FileNotFoundError("Checkpoint directory is empty or not found.")
        training_state_path = os.path.join(checkpoint_path, STATE_FILE)
        loss_tracker_path = os.path.join(checkpoint_path, TRAIN_LOSS_STATE_FILE)
        self.state.load(training_state_path)
        self.train_loss_state.load(loss_tracker_path)
        if self.state.finished:
            raise RuntimeError("Training process has been flagged as finished.")
        module.state = self.state

    self.monitor._set_extra(
        self.accelerator, self.state, self.train_loss_metric_name, self.val_loss_metric_name, self.tracker
    )
    self.monitor._tracking = self.tracker is not None

    # one validation loss accumulator per validation dataloader
    self.val_loss_state = (
        {
            k: LossState(
                self.accelerator, self.accelerator.device, -1, include_per_batch=False, pin_memory=self.is_gpu
            )
            for k in val_dataloader.keys()
        }
        if val_dataloader is not None
        else None
    )

    if self._multiple_train_datasets and self.hps.max_steps is None:
        raise ValueError("`max_steps` must be specified when using multiple training datasets.")

    # `train_dataloader` is a list of (max_step, DataLoader) pairs
    length_first_train_dataloader = len(train_dataloader[0][1])
    max_steps = (
        self.hps.max_steps
        if self._multiple_train_datasets or self.hps.max_steps is not None
        else math.ceil(
            length_first_train_dataloader / (self.accelerator.num_processes * self.grad_accumulation_steps)
        )
        * self.hps.epochs
    )
    self._max_steps = max_steps

    optimizer = self._get_optimizer(module)
    if self.hps.step_scheduler_per_epoch:
        if self._multiple_train_datasets:
            raise RuntimeError("`step_scheduler_per_epoch` is incompatible with curriculum learning.")
        elif self.hps.max_steps is not None:
            raise RuntimeError("`step_scheduler_per_epoch` cannot be used if `max_steps` is specified.")
        # presumably the scheduler is stepped once per epoch here, hence `epochs` as its step count
        scheduler = self._get_scheduler(module, optimizer, self.hps.epochs, self.hps.epochs)
    elif self.hps.max_steps is not None:
        if not self._multiple_train_datasets:
            # epochs are not taken into account if multiple training datasets are used
            steps_per_epoch = length_first_train_dataloader / (
                self.accelerator.num_processes * self.grad_accumulation_steps
            )
            self.hps.epochs = math.ceil(max_steps / steps_per_epoch)
        scheduler = self._get_scheduler(module, optimizer, max_steps, 1)  # ignore epochs to avoid multiplication
        # avoid double evaluation at the end of training
        # NOTE(review): only an exact match is handled; if `max_steps` is a larger multiple of
        # `evaluate_every_n_steps`, the final step may still evaluate twice — confirm.
        if max_steps == self.evaluate_every_n_steps:
            self.eval_when_finish = False
    else:
        scheduler = self._get_scheduler(module, optimizer, max_steps, self.hps.epochs)

    if ASYNC:
        if ASYNC_TRAIN_GROUP:
            self.tunnel.init(model)
            self.async_state.init()
            self.async_state.update(tunnel_ready=True, run_id=self.run_id)
            # only MASTER_PROCESS returns a valid 'run_id', and 'update' function already handles that.
        else:
            self.async_state.wait_for_tunnel()

    model, teacher, train_dataloader, val_dataloader, optimizer, scheduler = self._prepare(
        module,
        train_dataloader,
        val_dataloader,
        optimizer,
        scheduler,
        batch_device_placement=self.batch_device_placement,
    )
    self._model_dtype = next(model.parameters()).dtype

    if ASYNC and not ASYNC_TRAIN_GROUP:
        # force train dataloader, optimizer and scheduler to be None in evaluation group since they're not being used.
        train_dataloader = None
        optimizer = None
        scheduler = None

    self._scheduler = scheduler
    self._optimizer = optimizer
    module.scheduler = scheduler
    module.optimizer = optimizer
    self.wrapped_model = model

    if self.log_every < 0:
        # report training loss at the last step (or end of an epoch)
        if self._multiple_train_datasets:
            raise RuntimeError("`log_every` cannot be negative if multiple training datasets are used.")
        self.log_every = math.ceil(length_first_train_dataloader / self.grad_accumulation_steps)
    module.log_every = self.log_every

    for callback in self.callback.children:
        callback.module = module
        callback.trainer = self
        callback.state = self.state
    self.callback.on_fit_start()

    if ASYNC and not ASYNC_TRAIN_GROUP:
        self.dispatch_async_eval(module, model, val_dataloader)
    else:
        self.loop(module, model, train_dataloader, val_dataloader, optimizer, scheduler)

    if ASYNC and ASYNC_TRAIN_GROUP:
        self.async_state.update(train_finished=True)
        # wait until evaluation group is finished
        while not self.async_state.evaluation_finished:
            time.sleep(0.5)
        # evaluation group delegates the job to train group
        eval_runs_pending = self.async_state.evaluations_in_queue
        if eval_runs_pending > 0:
            for _ in range(eval_runs_pending):
                self._async_eval(module, model, val_dataloader)
        self.tunnel.close()

    self.state.finished = True
    self.callback.on_fit_end()
    # each object is passed once (the scheduler was previously passed twice)
    self.accelerator.free_memory(model, train_dataloader, val_dataloader, optimizer, scheduler)
    if self.log_with is not None:
        self.accelerator.get_tracker(self.log_with).finish()

    if self.destroy_after_training and WORLD_SIZE > 1:
        # done to avoid pytorch distributed warnings if script finishes here
        dist.destroy_process_group()
    else:
        # restore the unwrapped model on the module so it can be reused
        module.model = self.unwrapped_model
# TODO still getting memory leaks if running multiple trainings using the very same module def loop( self, module: AcceleratorModule, model: nn.Module, train_dataloader: list[tuple[int, DataLoader]], val_dataloader: Optional[dict[Any, DataLoader]], optimizer: Optimizer, scheduler: Optional[LRScheduler], ): """Runs a training loop.""" if self.state.evaluations_done == 0 and self.eval_when_start: self.launch_eval(module, model, val_dataloader, ignore_sync=True) for _ in self.epoch_iterator(): for batch in self.batch_iterator(module, train_dataloader, model): self._train_logic(module, model, optimizer, batch, scheduler) if ( self.evaluate_every_n_steps is not None and (self.state.global_step + 1) % self.evaluate_every_n_steps == 0 ): self.launch_eval(module, model, val_dataloader) if self.evaluate_every_n_steps is None or (self.eval_when_finish and self.state.is_last_epoch): self.launch_eval(module, model, val_dataloader) def dispatch_async_eval( self, module: AcceleratorModule, model: nn.Module, dataloader: dict[Any, DataLoader], delay: float = 0.1 ): while not self.async_state.train_finished: self._async_eval(module, model, dataloader) # continue checking for evaluations time.sleep(delay) self.async_state.update(evaluation_finished=True) def _async_eval(self, module: AcceleratorModule, model: nn.Module, dataloader: dict[str, DataLoader]): evals_in_queue = self.async_state.evaluations_in_queue if evals_in_queue > 0: # read last model from SHM self.tunnel.read(model) if evals_in_queue >= 2: # read next model from disk and write it into SHM state_dict = self.async_queue.dequeue() self.tunnel.write_state_dict(state_dict, non_blocking=True) self.async_state.update(evaluations_in_queue=-1) self.eval(module, model, dataloader) def launch_eval( self, module: AcceleratorModule, model: nn.Module, dataloader: dict[Any, DataLoader], ignore_sync: bool = False, ): if not self.do_sync and not ignore_sync: # launch evaluation only after gradient synchronization return if ASYNC: 
self.accelerator.wait_for_everyone() unwrapped_model = self.accelerator.unwrap_model(model) if self.async_state.evaluations_in_queue == 0: # SHM is free self.tunnel.write(unwrapped_model) else: # SHM waiting, then we write to disk self.async_queue.enqueue(unwrapped_model) self.async_state.update(evaluations_in_queue=1) self.accelerator.wait_for_everyone() else: self.eval(module, model, dataloader) should_save_model = not ( self.eval_when_start and self.state.evaluations_done == 1 ) # not doing first requested evaluation if self._checkpointing_after_evaluation and should_save_model: self._save_checkpoint( self.state.epoch + (0 if not self.state.is_end_of_epoch else 1), self.state.train_step + (1 if not self.state.is_end_of_epoch else 0), self.state.global_step + (1 if not self.state.is_end_of_epoch else 0), self.state.evaluations_done, finished=self.state.finished, ) def eval(self, module: AcceleratorModule, model: nn.Module, dataloader: Optional[dict[Any, DataLoader]]): """ NOTE: This function is only used in the training loop. Consider using `evaluate` instead. Runs evaluation on a given dataloader. 
""" no_patience_left = all(v == 0 for v in self.state.patience_left.values()) if DEBUG_MODE >= 5 or no_patience_left or dataloader is None: return if self.module_hooks: module.before_eval() if model.training: model.eval() for name, _ in module._registered_models: getattr(module, name).eval() clear_device_cache(garbage_collection=True) self.callback.on_evaluation_start() run_id = self.async_state.run_id if ASYNC and MASTER_PROCESS else None for k, val_dataloader in dataloader.items(): val_str = f" ({k}) " if self._multiple_evaluations else " " for i, batch in tqdm( iterable=enumerate(val_dataloader), total=len(val_dataloader), desc=f"📊{val_str}Evaluating in Epoch {self.state.epoch + 1}/{self.hps.epochs}", position=1, colour="cyan", **_tqdm_kwargs, ): self.state.val_step = i self.state.is_last_validation_batch = i == len(val_dataloader) - 1 batch = self._prepare_batch(batch) if self.prepare_batch else batch self._validation_logic(module, k, batch) self.state.additional_metrics[k]["valid_loss"] = self.val_loss_state[k].get_total_loss() with torch.inference_mode(): if self.metrics is not None: for metric in self.metrics[k]: if (not metric._parallel and MASTER_PROCESS) or metric._parallel: # we don't want to call '_compute' for metrics that are not implemented in main process, # since the state on other processes is empty if metric._per_batch: metric_dict = metric._get_metric_averages(clear=True) else: metric_dict = metric._compute() self.state.additional_metrics[k].update(metric_dict) # re-format metrics, instead of a dict dataset_key (key) and metrics (dictionary value), gather # all metrics into a single dictionary with the format {metric__dataset_key: value}. # e.g. 
{"accuracy__dataset1": 0.21, "accuracy__dataset2": 0.67} log_dict = {} for _metric_name, _value in self.state.additional_metrics[k].items(): if _metric_name.startswith("best_"): continue _metric_name = f"{_metric_name}__{k}" if self._multiple_evaluations else _metric_name log_dict[_metric_name] = _value if self.consolidate_metrics is not None: _log_dict = log_dict.copy() for metric_name, consolidate_fn in self.consolidate_metrics.items(): for local_metric_name, value in log_dict.items(): if consolidate_fn(local_metric_name): self._consolidated_metrics[metric_name].append(value) else: if not self.report_all_metrics: _log_dict.pop(local_metric_name) self.monitor.log_additional_metrics({local_metric_name: value}, run_id=run_id) log_dict = _log_dict if self.consolidate_metrics is None or self.report_all_metrics: self.monitor.log_additional_metrics(log_dict, run_id=run_id) if len(self._consolidated_metrics) > 0: if self.additional_metric_consolidation is not None: _additional_metric_consolidation = self.additional_metric_consolidation( dict(self._consolidated_metrics) ) self.monitor.log_additional_metrics(_additional_metric_consolidation, run_id=run_id) self._consolidated_metrics = {k: np.mean(v).item() for k, v in self._consolidated_metrics.items()} self.monitor.log_additional_metrics(self._consolidated_metrics, run_id=run_id) self._consolidated_metrics = defaultdict(list) self.state.evaluations_done += 1 if self.module_hooks: module.after_eval() should_save_model = not ( self.eval_when_start and self.state.evaluations_done == 1 ) # not doing first requested evaluation # save model if self.model_saving is not None and should_save_model and DEBUG_MODE < 3: self._save_model_on_criteria(model) else: # reset total loss state for validation since it's not being used with torch.inference_mode(): for k in self.val_loss_state.keys(): self.val_loss_state[k].total_loss.zero_() self.val_loss_state[k].num_steps.zero_() self.state.val_step = 0 # flag as finished if doing very last 
evaluation self.state.finished = self.state.is_last_training_batch and self.state.is_last_epoch self.callback.on_evaluation_end() def _save_model_on_criteria(self, model: nn.Module): """Save model depending on criteria defined in `model_saving`""" self.accelerator.wait_for_everyone() train_loss = self.train_loss_state.get_total_loss() can_save = not (self.eval_when_start and self.state.evaluations_done == 1) # TODO: add support to save other additional models if needed. def _check_and_save(model_saving: str): _model_saving = model_saving model_saving_without_prefix = model_saving.removeprefix("best_") # we already have all metrics calculated per dataset in self.state.additional_metrics metrics_and_datasets = defaultdict( list ) # e.g. {"accuracy": ["dataset1", "dataset2"], "metric": [dataset_keys, ...]} for metric in model_saving_without_prefix.split("/"): metric, *datasets = metric.split("@") if len(datasets) == 0: # if datasets are not specified for a metric, then it means that we need to average # across all datasets datasets = self.state.additional_metrics.keys() for dataset in datasets: metrics_and_datasets[metric].append(dataset) # now create a buffer per metric, where each value in the buffer corresponds to the # metric found in a dataset metric_buffer = defaultdict(list) # e.g. 
{"accuracy": [0.2, 0.5], "metric": [values, ...]} for dataset_key, metrics_dict in self.state.additional_metrics.items(): for metric, value in metrics_dict.items(): if dataset_key in set(metrics_and_datasets[metric]): metric_buffer[metric].append(value) # now average those metrics in buffer metric_avgs = {k: (np.mean(v) if len(v) > 1 else v[0]) for k, v in metric_buffer.items()} _metrics = [ms.split("@")[0] for ms in model_saving_without_prefix.split("/")] count = 0 for metric in _metrics: best_metric_str = f"best_{metric}" comparator = self._get_comparator(metric) if metric != "valid_loss" else "<" compare = _operator_map[comparator] new = metric_avgs[metric] # calculate average between previous metrics in wanted datasets prev = [] for dataset_key in set(metrics_and_datasets[metric]): if best_metric_str not in self.state.additional_metrics[dataset_key]: # only register best metrics in wanted datasets self.state.additional_metrics[dataset_key][best_metric_str] = ( float("inf") if comparator in {"<", "<=", "=="} else float("-inf") ) prev.append(self.state.additional_metrics[dataset_key][best_metric_str]) prev = np.mean(prev) if len(prev) > 1 else prev[0] saving_below, saving_above = self.model_saving[_model_saving] is_better = compare(new, prev) and new < saving_below and new > saving_above if is_better: count += 1 # register best metrics for all wanted datasets for dataset, new_metric_calculated in zip(metrics_and_datasets[metric], metric_buffer[metric]): local_prev = self.state.additional_metrics[dataset][best_metric_str] if compare(new_metric_calculated, local_prev): self.state.additional_metrics[dataset][best_metric_str] = new_metric_calculated if count == len(_metrics): # all these metrics have improved if MASTER_PROCESS and can_save and not self.disable_model_saving: model_path = os.path.join(self.model_path, _model_saving.replace("/", "__")) self._save_model(model, model_path) elif can_save and self.state.patience_left[_model_saving] > 0: 
self.state.patience_left[_model_saving] -= 1 if len(self.model_saving) > 0: if train_loss < self.state.best_train_loss: self.state.best_train_loss = train_loss if ( "best_train_loss" in self.model_saving and MASTER_PROCESS and can_save and not self.disable_model_saving ): model_path = os.path.join(self.model_path, "best_train_loss") self._save_model(model, model_path) elif ( can_save and "best_train_loss" in self.model_saving and self.state.patience_left["best_train_loss"] > 0 ): self.state.patience_left["best_train_loss"] -= 1 for model_saving in self.model_saving: # TODO we could implement a tqdm bar maybe... _check_and_save(model_saving) # if all model savings have no patience anymore, finish training process count = 0 for model_saving in self.model_saving.keys(): count += self.state.patience_left[model_saving] == 0 self.accelerator.wait_for_everyone() if count == len(self.model_saving): rprint("Ran out of patience. Process finished.") self.state.finished = True if MASTER_PROCESS: state_in_checkpoint = os.path.join(self.checkpoint_path, STATE_FILE) self.state.save(state_in_checkpoint) for model_saving in self.model_saving: model_saving_path = os.path.join(self.model_path, model_saving) os.makedirs(model_saving_path, exist_ok=True) model_saving_path = os.path.join(model_saving_path, STATE_FILE) self.state.save(model_saving_path) self.accelerator.end_training() sys.exit(0) def _save_model(self, model: nn.Module, path: str): """Save model inside a path.""" tqdm.write(f"\r{get_time_prefix()} Saving model...") os.makedirs(path, exist_ok=True) unwrapped_model = self.accelerator.unwrap_model(model, keep_torch_compile=False) state_dict = unwrapped_model.state_dict() if hasattr(unwrapped_model, "save_pretrained"): # special function for models from transformers library unwrapped_model.save_pretrained( path, is_main_process=True, state_dict=state_dict, max_shard_size=self.max_shard_size, save_function=self.accelerator.save, safe_serialization=self.safe_serialization, ) 
else: pt_state_dict = os.path.join(path, "pytorch_model.pt") self.accelerator.save(state_dict, pt_state_dict, safe_serialization=self.safe_serialization) training_state_path = os.path.join(path, STATE_FILE) self.state.save(training_state_path) tqdm.write(f"\033[A\033[K{get_time_prefix()} Model saved.") def _validation_logic(self, module: AcceleratorModule, dataloader_key: Any, batch: Any): """Runs all the validation logic.""" no_grad_context = torch.inference_mode if self.inference_mode else torch.no_grad with no_grad_context() if not self._module._extended else nullcontext(): self.callback.on_before_validation_step(batch) if self.safe_steps: metrics = self._safe_step(module.validation_step, dataloader_key, batch) else: metrics = module.validation_step(dataloader_key, batch) if isinstance(metrics, torch.Tensor): # assume it's loss value, so convert wrap it into a dictionary metrics = {"loss": metrics} self.callback.on_after_validation_step() # track loss loss = metrics["loss"].detach() self.val_loss_state[dataloader_key].add_total_loss(loss) # track metrics if self.metrics is not None: for metric in self.metrics[dataloader_key]: if metric.main_metric not in metrics: raise RuntimeError("Make sure to align 'validation_step' with declared metrics.") metric_compute_arguments = metrics[metric.main_metric] if not isinstance(metric_compute_arguments, tuple): metric_compute_arguments = (metric_compute_arguments,) if not metric._parallel: # check if any argument is on CPU _runtime_error = RuntimeError( "Metric arguments must be on GPU. If they are not, you can use `MetricParallel` " "or `MetricBatch` instead." 
) for arg in metric_compute_arguments: if isinstance(arg, dict): for v in arg.values(): if isinstance(v, torch.Tensor) and v.device.type == "cpu": raise _runtime_error elif isinstance(arg, torch.Tensor) and arg.device.type == "cpu": raise _runtime_error elif not isinstance(arg, (dict, torch.Tensor)): raise RuntimeError(f"Metric argument {arg} is not a dictionary or a tensor.") metric_compute_arguments = ( *( ( all_gather_dictionary(arg) if isinstance(arg, dict) else self.accelerator.gather_for_metrics(arg) ) for arg in metric_compute_arguments ), # leave it as tuple ) if MASTER_PROCESS and metric_compute_arguments[0] is not None: metric.add_batch(*metric_compute_arguments) elif metric_compute_arguments[0] is not None: metric.add_batch(*metric_compute_arguments) if metric._per_batch: metric._compute() def _prepare_batch(self, batch: Any) -> Any: """ Prepare elements in a batch based on Mixed Precision. This function only takes effect when using DeepSpeed. """ if self.accelerator.distributed_type != DistributedType.DEEPSPEED: return batch return self._prepare_nested_batch(batch) def _prepare_nested_batch(self, batch: Any) -> Any: """ Prepare nested batch. This function is derived from `transformers` library (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py). 
""" if isinstance(batch, Mapping): return type(batch)({k: self._prepare_nested_batch(v) for k, v in batch.items()}) elif isinstance(batch, (tuple, list)): return type(batch)(self._prepare_nested_batch(v) for v in batch) elif isinstance(batch, torch.Tensor): kwargs = {"device": self.accelerator.device} if torch.is_floating_point(batch) or torch.is_complex(batch): kwargs.update({"dtype": self._model_dtype}) return batch.to(**kwargs) return batch def _safe_step(self, fn: Callable, *args, **kwargs) -> Union[torch.Tensor, dict, Any]: try: return fn(*args, **kwargs) except RuntimeError as e: if "out of memory" in str(e): for p in self.wrapped_model.parameters(): if p.grad is not None: del p.grad torch.cuda.empty_cache() try: return fn(*args, **kwargs) except RuntimeError as _e: rprint("CUDA: Out Of Memory.") if "out of memory" in str(_e): print_gpu_users_by_device() if self.tracker is not None: self.tracker.end(status="FAILED") if WORLD_SIZE > 1: dist.destroy_process_group() sys.exit(1) else: raise e # noqa: TRY201 def _train_logic( self, module: AcceleratorModule, model: nn.Module, optimizer: Optimizer, batch: Any, scheduler: Optional[LRScheduler], ): """Runs all the training logic.""" if ASYNC and not ASYNC_TRAIN_GROUP: return # code snippet taken from https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L2545 no_sync_context = ( functools.partial(self.accelerator.no_sync, model=model) if self.accelerator.distributed_type != DistributedType.DEEPSPEED and not self.state.is_last_training_batch else nullcontext ) with no_sync_context(): self.callback.on_before_training_step(batch) # forward pass with self._debug_timings.record_times("training_step"): if self.safe_steps: loss = self._safe_step(module.training_step, batch) else: loss = module.training_step(batch) if self.grad_accumulation_steps > 1: # normalize loss by the number of gradient accumulation steps loss /= self.grad_accumulation_steps self.callback.on_after_training_step() # track 
_loss = loss.detach() self.train_loss_state.add_batch_loss(_loss) self.train_loss_state.add_total_loss(_loss) self.callback.on_before_backward(loss) if not module._extended: # backpropagation kwargs = {} if self.grad_accumulation_steps > 1 and self.accelerator.distributed_type == DistributedType.DEEPSPEED: # disable gradient scaling when using gradient accumulation and DeepSpeed: # https://github.com/huggingface/transformers/pull/35808 kwargs["scale_wrt_gas"] = False with self._debug_timings.record_times("backward"): self.accelerator.backward(loss, **kwargs) self.callback.on_after_backward() if self.do_sync: if self.grad_accumulation_steps > 1: with torch.inference_mode(): self.train_loss_state.num_batches -= self.accum_steps_done self.train_loss_state.num_steps -= self.accum_steps_done norm = None if self.clip_grad > 0.0 and self.accelerator.distributed_type != DistributedType.DEEPSPEED: norm = self.accelerator.clip_grad_norm_(model.parameters(), self.clip_grad) if (self.state.global_step + 1) % self.log_every == 0: batch_loss = self.train_loss_state.get_batch_loss() if MASTER_PROCESS and self.monitor.grad_norm and norm is None: norm = self._get_grad_norm() self.monitor.log_train_loss_and_grad_norm(batch_loss, norm) if not module._extended: self.callback.on_before_optimizer_step(optimizer) with self._debug_timings.record_times("optimizer_step"): optimizer.step() self.callback.on_after_optimizer_step(optimizer) if scheduler is not None and not self.hps.step_scheduler_per_epoch: self.callback.on_before_scheduler_step(scheduler) with self._debug_timings.record_times("scheduler_step"): scheduler.step() self.callback.on_after_scheduler_step(scheduler) # reset gradients self.callback.on_before_zero_grad(optimizer) with self._debug_timings.record_times("zero_grad"): optimizer.zero_grad(set_to_none=self.set_to_none) self.callback.on_after_zero_grad(optimizer) if ( self.reset_optimizer_every_n_steps is not None and (self.state.global_step + 1) % 
self.reset_optimizer_every_n_steps == 0 ): optimizer.state.clear() self.accum_steps_done = 0 else: self.accum_steps_done += 1 self._debug_timings.print_cache(reset_cache=True) def batch_iterator(self, module: AcceleratorModule, dataloader: list[tuple[int, DataLoader]], model: nn.Module): """Batch iterator for training handling checkpointing.""" if not model.training: model.train() for name, _ in module._registered_models: getattr(module, name).train() if self.shuffle_train: global_seed = get_seed(default=0) set_seed(global_seed + self.state.epoch) for _, dl in dataloader: if hasattr(dl, "set_epoch"): dl.set_epoch(self.state.epoch) elif hasattr(dl.batch_sampler, "set_epoch"): dl.batch_sampler.set_epoch(self.state.epoch) for idx, (_, dl) in enumerate(dataloader): if self.state.train_dataloader_idx == idx: _dataloader = self.accelerator.skip_first_batches(dl, self.state.train_step) break clear_device_cache(garbage_collection=True) start = self.state.train_step # determine total steps for the current epoch if self._multiple_train_datasets: total_steps_in_epoch = self._max_steps else: total_steps_in_epoch = math.ceil(len(dataloader[0][1]) / self.grad_accumulation_steps) # calculate remaining steps in current epoch remaining_steps = total_steps_in_epoch - start # for progress bar, use max_steps if defined, otherwise use dataloader length progress_total = self.hps.max_steps if self.hps.max_steps is not None else total_steps_in_epoch progress_initial = self.state.global_step if self.hps.max_steps is not None else start training_dataloader_pbar = None if remaining_steps > 0: # TODO: add support for multiple iterations over different datasets without going back to epoch iterator # since epochs are not considered when using multiple training datasets _tqdm_str = ( f"🚀 Training in Epoch {self.state.epoch + 1}/{self.hps.epochs}" if not self._multiple_train_datasets else "🚀 Training" ) training_dataloader_pbar = tqdm( iterable=range(progress_total), # dummy iterable to show 
progress bar total=progress_total, initial=progress_initial, desc=_tqdm_str, position=0, colour="green", **_tqdm_kwargs, ) train_step = 0 for dl_idx, (max_step, dl) in enumerate(dataloader): if self._multiple_train_datasets and dl_idx != self.state.train_dataloader_idx: continue # skip if not the current dataloader repeat_dataloader = True while repeat_dataloader: # for curriculum learning, we need to repeat the dataloader in case # the user wants more steps on this data last_dataloader_batch = False should_finish_iteration = False for i, batch in enumerate(dl): last_dataloader_batch = i == len(dl) - 1 self.state.train_step = train_step self.state.is_last_training_batch = ( self.state.is_last_epoch and i == total_steps_in_epoch - 1 ) or (self.hps.max_steps is not None and self.state.global_step + 1 >= self.hps.max_steps) self.do_sync = ( self.state.batch_iteration + 1 ) % self.grad_accumulation_steps == 0 or self.state.is_last_training_batch self.accelerator.gradient_state._set_sync_gradients(self.do_sync) if (self.state.global_step + 1) % self.log_every == 0 and self.do_sync: lr = ( self._scheduler.get_last_lr()[-1] if self._scheduler is not None else self._optimizer.param_groups[0]["lr"] ) # TODO we can fuse these functions to only report once to the server self.monitor.log_learning_rate(lr) self.monitor.log_cpu_utilization() self.monitor.log_gpu_utilization() batch = self._prepare_batch(batch) if self.prepare_batch else batch with self._debug_timings.record_times("batch_iteration"): yield batch if ( self.cleanup_cache_every_n_steps is not None and (self.state.global_step + 1) % self.cleanup_cache_every_n_steps == 0 and self.do_sync ): clear_device_cache(garbage_collection=True) if ( self._checkpointing_every_n_steps and (self.state.global_step + 1) % self.checkpoint_every == 0 and not self.state.is_last_training_batch and self.do_sync ): self._save_checkpoint( self.state.epoch, self.state.train_step + 1, self.state.global_step + 1, self.state.evaluations_done, ) 
self.state.batch_iteration += 1 train_step += 1 if self.state.batch_iteration % self.grad_accumulation_steps == 0: self.state.global_step += 1 training_dataloader_pbar.update(1) should_finish_iteration = self._multiple_train_datasets and self.state.global_step >= max_step # check if we've reached max_steps for current dataloader if should_finish_iteration or ( self.hps.max_steps is not None and self.state.global_step >= self.hps.max_steps ): break repeat_dataloader = ( self._multiple_train_datasets and not last_dataloader_batch and not should_finish_iteration ) self.state.train_dataloader_idx += 1 if training_dataloader_pbar is not None: training_dataloader_pbar.close() self.state.is_end_of_epoch = True self.state.train_step = 0 def _save_checkpoint( self, epoch: int, train_step: int, global_step: int, evaluations_done: int, finished: bool = False ): """Save checkpoint at a given point in time (`epoch` and `train_step`).""" self.callback.on_save_checkpoint() if MASTER_PROCESS: tqdm.write(f"\r{get_time_prefix()} Saving checkpoint...") os.makedirs(self.checkpoint_path, exist_ok=True) self.accelerator.wait_for_everyone() checkpoint_path = self.checkpoint_path if self.multiple_checkpoints: if ( MASTER_PROCESS and self.max_checkpoints is not None and len(os.listdir(checkpoint_path)) >= self.max_checkpoints ): min_checkpoint = min(os.listdir(checkpoint_path), key=lambda x: int(x.split("_")[-1])) shutil.rmtree(os.path.join(checkpoint_path, min_checkpoint)) last_checkpoint_num = int(self._get_current_checkpoint_path(ignore_resume_idx=True).split("_")[-1]) new_checkpoint_path = os.path.join(checkpoint_path, f"checkpoint_{last_checkpoint_num + 1}") if MASTER_PROCESS: os.makedirs(new_checkpoint_path, exist_ok=True) checkpoint_path = new_checkpoint_path self.accelerator.save_state(checkpoint_path, safe_serialization=self.safe_serialization) if len(self._module._registered_accelerators) > 0: additional_checkpoint_path = os.path.join(checkpoint_path, "additional_checkpoints") if 
MASTER_PROCESS: os.makedirs(additional_checkpoint_path, exist_ok=True) seen = set() count = 1 for local_accelerator in self._module._registered_accelerators.values(): id_local_accelerator = id(local_accelerator) if id_local_accelerator != id(self.accelerator) and id_local_accelerator not in seen: seen.add(id_local_accelerator) local_path = os.path.join(additional_checkpoint_path, f"accelerator{count}") local_accelerator.save_state(local_path, safe_serialization=self.safe_serialization) count += 1 loss_tracker_path = os.path.join(checkpoint_path, TRAIN_LOSS_STATE_FILE) self.train_loss_state.save(loss_tracker_path) self.state.num_checkpoints_made += 1 if MASTER_PROCESS: training_state_dict = self.state.to_dict() training_state_dict["epoch"] = epoch training_state_dict["train_step"] = train_step training_state_dict["global_step"] = global_step training_state_dict["evaluations_done"] = evaluations_done training_state_dict["finished"] = finished training_state_path = os.path.join(checkpoint_path, STATE_FILE) self.state.save(training_state_path, training_state_dict) tqdm.write(f"\033[A\033[K{get_time_prefix()} Checkpoint saved.") self.monitor.log_checkpoint() def epoch_iterator(self): """Epoch iterator handling logic for checkpointing.""" start = self.state.epoch for epoch in range(start, self.hps.epochs): self.state.epoch = epoch self.monitor.log_epoch(epoch) self.state.is_end_of_epoch = False self.state.is_last_epoch = epoch == self.hps.epochs - 1 self.callback.on_epoch_start() yield epoch self.callback.on_epoch_end() if not self._module._extended and self._scheduler is not None and self.hps.step_scheduler_per_epoch: self.callback.on_before_scheduler_step(self._scheduler) self._scheduler.step() self.callback.on_after_scheduler_step(self._scheduler) if self._checkpointing_when_epoch_ends and (self.state.epoch + 1) % self.checkpoint_every == 0: self._save_checkpoint( self.state.epoch + 1, self.state.train_step, # always 0 at this stage self.state.global_step, 
self.state.evaluations_done, # flag as finished if checkpointing at the end of the last epoch finished=self.state.is_last_epoch, ) def _prepare_additional_models(self, module: AcceleratorModule, train_dataloaders: list[DataLoader]): # in-place modification of the module if len(module._registered_models) == 0: return for idx, (additional_model_name, additional_model) in enumerate(module._registered_models): if additional_model is None: raise RuntimeError(f"Model '{additional_model_name}' was registered but is `None`.") additional_optimizer_name, additional_optimizer = module._registered_optimizers[idx] additional_scheduler_name, additional_scheduler = module._registered_schedulers[idx] if additional_scheduler is not None and additional_optimizer is None: raise RuntimeError( f"Model '{additional_model_name}' was registered with a scheduler but no optimizer." ) if additional_optimizer is not None: accelerator = self.accelerator if self.accelerator.distributed_type == DistributedType.DEEPSPEED: accelerator = Accelerator() additional_model, additional_optimizer, additional_scheduler, _ = accelerator.prepare( additional_model, additional_optimizer, additional_scheduler, train_dataloaders ) module._registered_accelerators[id(additional_model)] = accelerator module._registered_accelerators[id(additional_optimizer)] = accelerator module._registered_accelerators[id(additional_scheduler)] = accelerator setattr(module, additional_model_name, additional_model) if additional_optimizer is not None: setattr(module, additional_optimizer_name, additional_optimizer) if additional_scheduler is not None: setattr(module, additional_scheduler_name, additional_scheduler) def _prepare_ddp( self, module: AcceleratorModule, optimizer: Optimizer, scheduler: LRScheduler, train_dataloaders: list[DataLoader], ): module.model, optimizer, scheduler, *train_dataloaders = self.accelerator.prepare( module.model, optimizer, scheduler, *train_dataloaders ) module.model = 
_DistributedDataParallel(module.model) if module.teacher is not None: module.teacher = self.accelerator.prepare_model(module.teacher) return module, optimizer, scheduler, train_dataloaders def _prepare_fsdp( self, module: AcceleratorModule, optimizer: Optimizer, scheduler: LRScheduler, train_dataloaders: list[DataLoader], ): # preparing model before dataloaders is only supported by FSDP apparently, and this is the # recommended setting to prepare training. module.model = self.accelerator.prepare_model(module.model) # ignore model preparation since it was already done before (only in the case of FSDP) optimizer, scheduler, *train_dataloaders = self.accelerator.prepare(optimizer, scheduler, *train_dataloaders) if module.teacher is not None: module.teacher = self.accelerator.prepare_model(module.teacher) return module, optimizer, scheduler, train_dataloaders def _prepare_deepspeed( self, module: AcceleratorModule, optimizer: Optimizer, scheduler: LRScheduler, train_dataloaders: list[DataLoader], ): if not self.enable_prepare_logging: from deepspeed.utils import logger logger.setLevel(logging.WARNING) # DeepSpeed requires contiguous parameters for param in module.model.parameters(): if not param.is_contiguous(): param.data = param.data.contiguous() module.model, optimizer, scheduler, *train_dataloaders = self.accelerator.prepare( module.model, optimizer, scheduler, *train_dataloaders ) # TODO DeepSpeed does not support a model without an optimizer in this setting, so we leave # the teacher model as is (i.e. a replica per process). 
return module, optimizer, scheduler, train_dataloaders def _prepare( self, module: AcceleratorModule, train_dataloader: Optional[list[tuple[int, DataLoader]]], val_dataloader: Optional[dict[Any, DataLoader]], optimizer: Optional[Optimizer], scheduler: Optional[LRScheduler], batch_device_placement: bool = True, ) -> tuple[ nn.Module, Optional[nn.Module], Optional[list[tuple[int, DataLoader]]], Optional[dict[Any, DataLoader]], Optional[Optimizer], Optional[LRScheduler], ]: """ Call Accelerate's backend to prepare instances for distributed training. This will also load states for objects in case of resuming training. """ if self.gradient_checkpointing and hasattr(module.model, "gradient_checkpointing_enable"): module.model.gradient_checkpointing_enable(self.gradient_checkpointing_kwargs) if self.compile and DEBUG_MODE < 2: module.compile() if val_dataloader is not None: for k, dataloader in val_dataloader.items(): val_dataloader[k] = self.accelerator.prepare_data_loader(dataloader) # prepare the dataloader even if it is a custom batch sampler to avoid some errors when using DeepSpeed # and only evaluating if ( hasattr(dataloader.batch_sampler, "_custom_batch_sampler") and dataloader.batch_sampler._custom_batch_sampler ): val_dataloader[k] = dataloader # go back to the original dataloader _train_dataloaders = [dl for _, dl in train_dataloader] if train_dataloader is not None else [None] if self.accelerator.distributed_type == DistributedType.FSDP: _prepare_fn = self._prepare_fsdp elif self.accelerator.distributed_type == DistributedType.DEEPSPEED: _prepare_fn = self._prepare_deepspeed else: _prepare_fn = self._prepare_ddp module, optimizer, scheduler, train_dataloaders = _prepare_fn(module, optimizer, scheduler, _train_dataloaders) self._prepare_additional_models(module, train_dataloaders) if train_dataloaders[0] is not None: if len(train_dataloaders) == 1: # here we change -1 to the actual max step train_dataloader[0] = (self._max_steps, train_dataloader[0][1]) 
max_steps = [max_step for max_step, _ in train_dataloader] train_dataloader = [ ( max_step, ( orig_dataloader if hasattr(orig_dataloader.batch_sampler, "_custom_batch_sampler") and orig_dataloader.batch_sampler._custom_batch_sampler else dataloader ), ) for max_step, dataloader, orig_dataloader in zip(max_steps, train_dataloaders, _train_dataloaders) ] if scheduler is not None: self.accelerator.register_for_checkpointing(scheduler) # load states if resuming if self.resume: self.callback.on_resume() if os.path.exists(self.checkpoint_path): checkpoint_path = self._get_current_checkpoint_path() if checkpoint_path.endswith("checkpoint_0"): raise FileNotFoundError("Checkpoint directory is empty or not found.") self.accelerator.load_state(checkpoint_path) if len(self._module._registered_accelerators) > 0: additional_checkpoint_path = os.path.join(checkpoint_path, "additional_checkpoints") if os.path.exists(additional_checkpoint_path): seen = set() for local_accelerator in self._module._registered_accelerators.values(): id_local_accelerator = id(local_accelerator) if id_local_accelerator != id(self.accelerator) and id_local_accelerator not in seen: seen.add(id_local_accelerator) count = len(seen) local_path = os.path.join(additional_checkpoint_path, f"accelerator{count}") if os.path.exists(local_path): local_accelerator.load_state(local_path) else: raise FileNotFoundError( f"Additional checkpoint for accelerator {count} was not found." 
) else: raise FileNotFoundError(f"Additional checkpoints folder was not found in '{checkpoint_path}'.") else: raise FileNotFoundError(f"'{self.checkpoint_path}' was not found.") if not batch_device_placement: cpu = torch.device("cpu") if train_dataloader[0] is not None: for _, dl, _ in train_dataloader: dl.device = cpu if val_dataloader is not None: for dl in val_dataloader.values(): dl.device = cpu train_dataloader = train_dataloader if train_dataloader[0] is not None else None module._prepared = True return module.model, module.teacher, train_dataloader, val_dataloader, optimizer, scheduler def _get_current_checkpoint_path(self, ignore_resume_idx: bool = False) -> str: """ Get the checkpoint path based on the 'resume' argument or the latest checkpoint. If this returns a path ending with "checkpoint_0", it means that the checkpoint directory is empty or not found. """ checkpoint_path = self.checkpoint_path if self.multiple_checkpoints: num_checkpoints = len(os.listdir(checkpoint_path)) if os.path.exists(checkpoint_path) else 0 if num_checkpoints > 0: if type(self.resume) is int and self.resume != -1 and not ignore_resume_idx: # load the checkpoint at the given index checkpoint_path = os.path.join(checkpoint_path, f"checkpoint_{self.resume}") else: # find the latest checkpoint by getting the maximum checkpoint number latest_checkpoint = max(os.listdir(checkpoint_path), key=lambda x: int(x.split("_")[-1])) checkpoint_path = os.path.join(checkpoint_path, latest_checkpoint) else: # to handle creation afterwards checkpoint_path = os.path.join(checkpoint_path, "checkpoint_0") return checkpoint_path def _get_optimizer(self, module: AcceleratorModule) -> Optimizer: """Get optimizer from either module or trainer.""" optimizer = module.get_optimizer() if optimizer is None: optimizer = self.hps.optimizer fused_available = "fused" in inspect.signature(optimizer).parameters optim_kwargs = self.hps.optim_kwargs optim_kwargs["fused"] = fused_available and "cuda" in 
self.accelerator.device.type filtered_kwargs = filter_kwargs(optim_kwargs, optimizer) optimizer = optimizer(module.model.parameters(), **filtered_kwargs) return optimizer def _get_scheduler( self, module: AcceleratorModule, optimizer: Optimizer, num_training_steps: int, num_epochs: int ) -> Optional[LRScheduler]: """Get scheduler from either module or trainer.""" scheduler = module.get_scheduler(optimizer, num_training_steps, num_epochs) if self.hps.scheduler is not None and scheduler is None: schlr_kwargs = self.hps.scheduler_kwargs schlr_kwargs["last_epoch"] = -1 schlr_kwargs["steps_per_epoch"] = num_training_steps total_steps = num_training_steps * num_epochs schlr_kwargs["num_training_steps"] = total_steps schlr_kwargs["epochs"] = num_epochs if "num_warmup_steps" in schlr_kwargs and isinstance(schlr_kwargs["num_warmup_steps"], float): if schlr_kwargs["num_warmup_steps"] < 0.0 or schlr_kwargs["num_warmup_steps"] > 1.0: raise ValueError( "If 'num_warmup_steps' is a ratio (float value), it needs to be a value between 0 and 1." ) schlr_kwargs["num_warmup_steps"] = math.ceil(total_steps * schlr_kwargs["num_warmup_steps"]) elif "warmup_ratio" in schlr_kwargs: if schlr_kwargs["warmup_ratio"] > 1.0: raise ValueError( "'warmup_ratio' value in scheduler configuration needs to be a value between 0 and 1." ) schlr_kwargs["num_warmup_steps"] = math.ceil(total_steps * schlr_kwargs["warmup_ratio"]) scheduler = self.hps.scheduler filtered_kwargs = filter_kwargs(schlr_kwargs, scheduler) scheduler = scheduler(optimizer, **filtered_kwargs) return scheduler def _get_dataloaders( self, module: AcceleratorModule, train_dataset: Optional[ Union[Dataset, list[Union[tuple[int, Dataset], tuple[int, Dataset, dict]]], _CurriculumLearning] ] = None, val_dataset: Optional[Union[list[Dataset], dict[Any, Dataset]]] = None, ) -> tuple[Optional[list[tuple[int, DataLoader]]], Optional[dict[Any, DataLoader]]]: """ Get DataLoaders for training and validation. 
Each DataLoader will be wrapped in a dictionary. """ is_tuple = hasattr(self.hps.batch_size, "__len__") if is_tuple and len(self.hps.batch_size) != 2: raise ValueError( "'batch_size' in hyper parameters needs to be an integer value or a tuple with 2 values " "(one for training and the other for validation)." ) train_batch_size = self.hps.batch_size[0] if is_tuple else self.hps.batch_size val_batch_size = self.hps.batch_size[1] if is_tuple else self.hps.batch_size dl_args = { "pin_memory": self.dataloader_pin_memory, "num_workers": self.dataloader_num_workers, "drop_last": self.dataloader_drop_last, } train_dataloader = module.get_train_dataloader(train_dataset) assert train_dataloader is not None or train_dataset is not None, ( "Either 'train_dataset' or 'get_train_dataloader' must be given." ) # ignoring 'train_dataset' if 'get_train_dataloader' was implemented in AcceleratorModule if train_dataset is not None and train_dataloader is None: shuffle_train = ( self.shuffle_train if self.sampler_train is None and self.batch_sampler_train is None else None ) dl_train_kwargs = { "shuffle": shuffle_train, "sampler": self.sampler_train, "batch_size": train_batch_size, "collate_fn": self.collate_fn_train, "batch_sampler": self.batch_sampler_train, **dl_args, } if isinstance(train_dataset, Dataset): # -1 will be converted dynamically afterwards train_dataloader = [(-1, DataLoader(train_dataset, **dl_train_kwargs))] elif isinstance(train_dataset, _CurriculumLearning): train_dataset.convert_datasets_to_dataloaders(**dl_train_kwargs) train_dataloader = train_dataset.convert_to_max_step_per_dataloader(self.hps.max_steps) elif isinstance(train_dataset, list): # if there are only 2 elements in a tuple, the third element an empty dataloader kwargs for i, _tuple in enumerate(train_dataset): if len(_tuple) == 2: train_dataset[i] = (*(_tuple), {}) train_dataloader = [] for max_step, dataset, dataloader_kwargs in train_dataset: if not isinstance(dataloader_kwargs, dict): raise 
TypeError( "If 'train_dataset' is a list of tuples, the third element must be a dictionary of " "keyword arguments for the dataloader." ) # update global dataloader kwargs without modifying the original dict dl_specific_kwargs = {**dl_train_kwargs} dl_specific_kwargs.update(dataloader_kwargs) train_dataloader.append((max_step, DataLoader(dataset, **dl_specific_kwargs))) else: raise TypeError(f"Invalid type for 'train_dataset': {type(train_dataset)}") # remove the third element of the tuples (if exist) train_dataloader = [(_tuple[0], _tuple[1]) for _tuple in train_dataloader] val_dataloader = module.get_validation_dataloader(val_dataset) if val_dataloader is not None and not isinstance(val_dataloader, (list, dict)): val_dataloader = [val_dataloader] # ignoring 'val_dataset' if 'get_validation_dataloader' was implemented in AcceleratorModule if val_dataset is not None and val_dataloader is None: val_dataset = ( val_dataset if isinstance(val_dataset, dict) else {str(i): ds for i, ds in enumerate(val_dataset)} ) val_dataloader = {} for k, dataset in val_dataset.items(): val_dataloader[k] = DataLoader( dataset, batch_size=val_batch_size, collate_fn=self.collate_fn_val, sampler=self.sampler_val, batch_sampler=self.batch_sampler_val, **dl_args, ) # check if we need uneven batches if any(d.batch_size is None for _, d in train_dataloader) or ( val_dataloader is not None and any(d.batch_size is None for d in val_dataloader.values()) ): self.accelerator.even_batches = False if self.accelerator.distributed_type == DistributedType.DEEPSPEED: self.accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = ( self._deepspeed_default_micro_batch_size ) return train_dataloader, val_dataloader def _get_module( self, module: Union[AcceleratorModule, str, Union[tuple[str, str], tuple[str, Any]]], **kwargs: Any ) -> AcceleratorModule: """Get module corresponding to the arguments given.""" if isinstance(module, str): return AcceleratorModule.from_hf(module, 
**kwargs) elif isinstance(module, tuple): return AcceleratorModule.from_hf(*module, **kwargs) # delete any existing temp state if MASTER_PROCESS and os.path.exists(self.model_path): deleted_count = 0 for p in os.listdir(self.model_path): if p.startswith("_temp_state_"): full_path = os.path.join(self.model_path, p) if os.path.isdir(full_path): rprint(f"Deleting {full_path}...", start_char="") shutil.rmtree(full_path) deleted_count += 1 rprint(f"Deleted {full_path}.", start_char="") if deleted_count > 0: rprint(f"Deleted {deleted_count} temp states.") self.accelerator.wait_for_everyone() return module def _init_trackers(self) -> Optional[str]: """Initialize all trackers along with the training configuration from Hyper Parameters and 'additional_tracker_config'.""" self.accelerator.log_with = [self.tracker.logger_type] track_name = os.path.basename(self.model_path) if self.track_name is None else self.track_name init_kwargs = self.tracker.get_init_kwargs(**self.init_kwargs) config = self.hps.get_config() config["effective_batch_size"] = ( tuple(batch_size * self.accelerator.num_processes for batch_size in self.hps.batch_size) if isinstance(self.hps.batch_size, (tuple, list)) else self.hps.batch_size * self.accelerator.num_processes ) if self.grad_accumulation_steps > 1: obj = config["effective_batch_size"] if isinstance(obj, tuple): config["effective_batch_size"] = (obj[0] * self.grad_accumulation_steps, obj[1]) else: config["effective_batch_size"] = (obj * self.grad_accumulation_steps, obj) config["grad_accumulation_steps"] = self.grad_accumulation_steps config["gradient_checkpointing"] = self.gradient_checkpointing config["gradient_checkpointing_kwargs"] = self.gradient_checkpointing_kwargs config["clip_grad"] = self.clip_grad config["num_processes"] = self.accelerator.num_processes config["accmt_version"] = __version__ if self.hps.max_steps is not None: config.pop("epochs") tracker_config = config | self.additional_tracker_config # register signals to end process 
safely def end_process(signum, frame): if self.tracker is not None: self.tracker.end(status="KILLED") sys.exit(0) def end_on_exception(exc_type, exc_value, exc_traceback): if issubclass(exc_type, KeyboardInterrupt): sys.__excepthook__(exc_type, exc_value, exc_traceback) return if self.tracker is not None: self.tracker.end(status="FAILED") traceback.print_exception(exc_type, exc_value, exc_traceback) signal.signal(signal.SIGTERM, end_process) signal.signal(signal.SIGINT, end_process) sys.excepthook = end_on_exception if MASTER_PROCESS: # TODO with a Tracker Wrapper this should be fixed. _is_url = is_url(self.logging_dir) if _is_url and not self._logging: raise RuntimeError(f"Cannot log results in '{self.logging_dir}' because 'log_with' was not declared.") self.accelerator.init_trackers(track_name, config=tracker_config, init_kwargs=init_kwargs) self.tracker.set_tracking_uri(self.logging_dir) return self.tracker.run_id def _get_grad_norm(self, norm_type: float = 2.0) -> Union[torch.Tensor, float]: """Calculates grad norm of model.""" if self.accelerator.distributed_type == DistributedType.DEEPSPEED: return self.wrapped_model.get_global_grad_norm() total_norm = 0 for p in self.unwrapped_model.parameters(): if p.grad is not None: total_norm += p.grad.detach().norm(norm_type) ** norm_type return total_norm ** (1.0 / norm_type)
def log_artifact(self, path: str):
    """
    Upload a single file to the active tracking run as an artifact.

    The upload is skipped unless logging is enabled, `DEBUG_MODE` is below 1, and the tracker has already been
    initialized.

    Args:
        path (`str`):
            Location of the file to upload.
    """
    should_log = self._logging and DEBUG_MODE < 1 and self.tracker_initialized
    if not should_log:
        return
    self.tracker.log_artifact(path)
def log_artifacts(self, path: str):
    """
    Upload every file in a directory to the active tracking run as artifacts.

    The upload is skipped unless logging is enabled, `DEBUG_MODE` is below 1, and the tracker has already been
    initialized.

    Args:
        path (`str`):
            Directory whose contents are uploaded.
    """
    should_log = self._logging and DEBUG_MODE < 1 and self.tracker_initialized
    if not should_log:
        return
    self.tracker.log_artifacts(path)
def _prepare_metrics(
    self,
    metrics: Union[Metric, list[Metric], dict[Any, Union[Metric, list[Metric]]]],
    val_dataloader: Optional[dict[Any, DataLoader]],
) -> dict[Any, list[Metric]]:
    """
    Map metrics onto validation datasets, normalizing every value into a list of metrics.

    A single `Metric` or a list of metrics is shared by every validation dataset key. A dictionary must only use
    keys that exist in `val_dataloader`, and single-metric values are wrapped in a list. Any other input is
    returned unchanged.
    """
    if isinstance(metrics, Metric):
        return {name: [metrics] for name in val_dataloader.keys()}

    if isinstance(metrics, list):
        return {name: metrics for name in val_dataloader.keys()}

    if isinstance(metrics, dict):
        assert all(name in val_dataloader for name in metrics), (
            f"There is a mismatch between given metrics and validation datasets. Got {list(metrics.keys())} "
            f"for 'metrics' and {list(val_dataloader.keys())} for validation datasets."
        )
        return {name: value if isinstance(value, list) else [value] for name, value in metrics.items()}

    return metrics
def register_model_saving(
    self,
    model_saving: str,
    saving_below: Optional[float] = None,
    saving_above: Optional[float] = None,
):
    """
    Register a model-saving criterion.

    Args:
        model_saving (`str`):
            Criterion name. It can be `"best_valid_loss"` (default), `"best_train_loss"` or `"best_{METRIC}"`.
            The `"best_"` prefix is optional and is added when missing. Metric names must refer to registered
            metrics and validation datasets. Dataset-qualified forms are also accepted:
            `"best_{METRIC}@{DATASET}"` (metric on one dataset), `"best_{METRIC}@{DATASET1}@{DATASET2}"` (metric
            on two datasets), `"best_{METRIC1}@{DATASET1}/{METRIC2}@{DATASET2}"` (best metric1 on dataset1 and
            best metric2 on dataset2), `"best_{METRIC1}/{METRIC2}@{DATASET2}"` (best metric1 across all datasets
            that report it and best metric2 on dataset2 only), etc.
        saving_below (`float`, *optional*, defaults to `None`):
            Only save under this criterion while its value is below this threshold. `None` means no upper bound.
        saving_above (`float`, *optional*, defaults to `None`):
            Only save under this criterion while its value is above this threshold. `None` means no lower bound.
    """
    upper_bound = float("inf") if saving_below is None else saving_below
    lower_bound = float("-inf") if saving_above is None else saving_above
    key = model_saving if model_saving.startswith("best_") else "best_" + model_saving

    self.model_saving[key] = (upper_bound, lower_bound)
def _get_comparator(self, metric: str) -> str:
    """
    Get comparator for a given metric.

    Returns the `comparator` of the first registered metric whose `main_metric` equals `metric`.

    Raises:
        `RuntimeError`: If no registered metric has `metric` as its main metric.
    """
    for metrics in self.metrics.values():
        for _metric in metrics:
            if metric == _metric.main_metric:
                return _metric.comparator

    raise RuntimeError(f"No comparator was found for metric '{metric}'.")


def evaluate(
    self,
    module: AcceleratorModule,
    dataset: Dataset,
    eval_logic_fn_name: str = "test_step",
    results_output: Optional[str] = "results.json",
    verbose: bool = True,
    module_hooks: bool = True,
    *,
    metrics: Optional[Union[Metric, list[Metric], dict[Any, Union[Metric, list[Metric]]]]] = None,
    compile: Optional[bool] = None,
    batch_size: Optional[int] = None,
    device_placement: Optional[bool] = None,
    num_workers: Optional[int] = None,
    pin_memory: Optional[bool] = None,
    collate_fn: Optional[Callable] = None,
    prepare_batch: Optional[bool] = None,
    enable_prepare_logging: Optional[bool] = None,
) -> dict[str, Any]:
    """
    Evaluate the model on a given dataset.

    Args:
        module (`AcceleratorModule`):
            The module to evaluate.
        dataset (`Dataset`):
            The dataset to evaluate on.
        eval_logic_fn_name (`str`, *optional*, defaults to `"test_step"`):
            The name of the evaluation logic function to use.
        results_output (`str`, *optional*, defaults to `"results.json"`):
            The path to the file to save the results to.
        verbose (`bool`, *optional*, defaults to `True`):
            Whether to print the results to the console.
        module_hooks (`bool`, *optional*, defaults to `True`):
            Whether to call the `before_eval` and `after_eval` hooks of the module.
        metrics (`Metric`, `list[Metric]` or `dict`, *optional*, defaults to `None`):
            The metrics to use for evaluation. If `None`, the metrics used in the trainer will be used.
        compile (`bool`, *optional*, defaults to `None`):
            Whether to compile the model. If `None`, the compile setting used in the trainer will be used.
        batch_size (`int`, *optional*, defaults to `None`):
            The batch size to use for evaluation. If `None`, the batch size used in the trainer will be used.
            When the trainer's batch size is a `(train, validation)` tuple or list, its last (validation) value
            is used.
        device_placement (`bool`, *optional*, defaults to `None`):
            Whether to place the batch on the device. If `None`, the device placement setting used in the
            trainer will be used.
        num_workers (`int`, *optional*, defaults to `None`):
            The number of workers to use for evaluation in the dataloader. If `None`, the number of workers
            used in the trainer will be used.
        pin_memory (`bool`, *optional*, defaults to `None`):
            Whether to pin the memory of the batch. If `None`, the pin memory setting used in the trainer will
            be used.
        collate_fn (`Callable`, *optional*, defaults to `None`):
            The collate function to use for evaluation.
        prepare_batch (`bool`, *optional*, defaults to `None`):
            Whether to prepare the batch based on Mixed Precision. This only takes effect when using DeepSpeed.
            If `None`, the prepare batch setting used in the trainer will be used.
        enable_prepare_logging (`bool`, *optional*, defaults to `None`):
            Whether to enable logging preparation (DeepSpeed). If `None`, the enable prepare logging setting
            used in the trainer will be used.

    Returns:
        `dict`: The results of the evaluation.
    """
    metrics = metrics if metrics is not None else self.metrics
    compile = compile if compile is not None else self.compile
    device_placement = device_placement if device_placement is not None else self.batch_device_placement
    num_workers = num_workers if num_workers is not None else self.dataloader_num_workers
    pin_memory = pin_memory if pin_memory is not None else self.dataloader_pin_memory
    prepare_batch = prepare_batch if prepare_batch is not None else self.prepare_batch
    enable_prepare_logging = (
        enable_prepare_logging if enable_prepare_logging is not None else self.enable_prepare_logging
    )
    batch_size = batch_size if batch_size is not None else self.hps.batch_size
    # the trainer accepts (train, validation) batch sizes as a tuple or a list; evaluation uses the last one
    if isinstance(batch_size, (tuple, list)):
        batch_size = batch_size[-1]

    evaluator = Evaluator(
        metrics=metrics,
        compile=compile,
        batch_size=batch_size,
        device_placement=device_placement,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=collate_fn,
        prepare_batch=prepare_batch,
        enable_prepare_logging=enable_prepare_logging,
    )

    return evaluator.evaluate(module, dataset, eval_logic_fn_name, results_output, verbose, module_hooks)


def evaluate_on_test(
    self,
    module: AcceleratorModule,
    dataset: Dataset,
    results_output: Optional[str] = "results.json",
    verbose: bool = True,
    module_hooks: bool = True,
    *,
    metrics: Optional[Union[Metric, list[Metric], dict[Any, Union[Metric, list[Metric]]]]] = None,
    compile: Optional[bool] = None,
    batch_size: Optional[int] = None,
    device_placement: Optional[bool] = None,
    num_workers: Optional[int] = None,
    pin_memory: Optional[bool] = None,
    collate_fn: Optional[Callable] = None,
    prepare_batch: Optional[bool] = None,
    enable_prepare_logging: Optional[bool] = None,
) -> dict[str, Any]:
    """
    Alias for `evaluate` with `eval_logic_fn_name` set to `"test_step"`.

    Args:
        module (`AcceleratorModule`):
            The module to evaluate.
        dataset (`Dataset`):
            The dataset to evaluate on.
        results_output (`str`, *optional*, defaults to `"results.json"`):
            The path to the file to save the results to.
        verbose (`bool`, *optional*, defaults to `True`):
            Whether to print the results to the console.
        module_hooks (`bool`, *optional*, defaults to `True`):
            Whether to call the `before_eval` and `after_eval` hooks of the module.
        metrics (`Metric`, `list[Metric]` or `dict`, *optional*, defaults to `None`):
            The metrics to use for evaluation. If `None`, the metrics used in the trainer will be used.
        compile (`bool`, *optional*, defaults to `None`):
            Whether to compile the model. If `None`, the compile setting used in the trainer will be used.
        batch_size (`int`, *optional*, defaults to `None`):
            The batch size to use for evaluation. If `None`, the batch size used in the trainer will be used
            (its validation value when it is a `(train, validation)` pair).
        device_placement (`bool`, *optional*, defaults to `None`):
            Whether to place the batch on the device. If `None`, the device placement setting used in the
            trainer will be used.
        num_workers (`int`, *optional*, defaults to `None`):
            The number of workers to use for evaluation in the dataloader. If `None`, the number of workers
            used in the trainer will be used.
        pin_memory (`bool`, *optional*, defaults to `None`):
            Whether to pin the memory of the batch. If `None`, the pin memory setting used in the trainer will
            be used.
        collate_fn (`Callable`, *optional*, defaults to `None`):
            The collate function to use for evaluation.
        prepare_batch (`bool`, *optional*, defaults to `None`):
            Whether to prepare the batch based on Mixed Precision. This only takes effect when using DeepSpeed.
            If `None`, the prepare batch setting used in the trainer will be used.
        enable_prepare_logging (`bool`, *optional*, defaults to `None`):
            Whether to enable logging preparation (DeepSpeed). If `None`, the enable prepare logging setting
            used in the trainer will be used.

    Returns:
        `dict`: The results of the evaluation.
    """
    return self.evaluate(
        module=module,
        dataset=dataset,
        eval_logic_fn_name="test_step",
        results_output=results_output,
        verbose=verbose,
        metrics=metrics,
        compile=compile,
        batch_size=batch_size,
        device_placement=device_placement,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=collate_fn,
        prepare_batch=prepare_batch,
        enable_prepare_logging=enable_prepare_logging,
        module_hooks=module_hooks,
    )


def evaluate_on_validation(
    self,
    module: AcceleratorModule,
    dataset: Dataset,
    results_output: Optional[str] = "results.json",
    verbose: bool = True,
    module_hooks: bool = True,
    *,
    metrics: Optional[Union[Metric, list[Metric], dict[Any, Union[Metric, list[Metric]]]]] = None,
    compile: Optional[bool] = None,
    batch_size: Optional[int] = None,
    device_placement: Optional[bool] = None,
    num_workers: Optional[int] = None,
    pin_memory: Optional[bool] = None,
    collate_fn: Optional[Callable] = None,
    prepare_batch: Optional[bool] = None,
    enable_prepare_logging: Optional[bool] = None,
) -> dict[str, Any]:
    """
    Alias for `evaluate` with `eval_logic_fn_name` set to `"validation_step"`.

    Args:
        module (`AcceleratorModule`):
            The module to evaluate.
        dataset (`Dataset`):
            The dataset to evaluate on.
        results_output (`str`, *optional*, defaults to `"results.json"`):
            The path to the file to save the results to.
        verbose (`bool`, *optional*, defaults to `True`):
            Whether to print the results to the console.
        module_hooks (`bool`, *optional*, defaults to `True`):
            Whether to call the `before_eval` and `after_eval` hooks of the module.
        metrics (`Metric`, `list[Metric]` or `dict`, *optional*, defaults to `None`):
            The metrics to use for evaluation. If `None`, the metrics used in the trainer will be used.
        compile (`bool`, *optional*, defaults to `None`):
            Whether to compile the model. If `None`, the compile setting used in the trainer will be used.
        batch_size (`int`, *optional*, defaults to `None`):
            The batch size to use for evaluation. If `None`, the batch size used in the trainer will be used
            (its validation value when it is a `(train, validation)` pair).
        device_placement (`bool`, *optional*, defaults to `None`):
            Whether to place the batch on the device. If `None`, the device placement setting used in the
            trainer will be used.
        num_workers (`int`, *optional*, defaults to `None`):
            The number of workers to use for evaluation in the dataloader. If `None`, the number of workers
            used in the trainer will be used.
        pin_memory (`bool`, *optional*, defaults to `None`):
            Whether to pin the memory of the batch. If `None`, the pin memory setting used in the trainer will
            be used.
        collate_fn (`Callable`, *optional*, defaults to `None`):
            The collate function to use for evaluation.
        prepare_batch (`bool`, *optional*, defaults to `None`):
            Whether to prepare the batch based on Mixed Precision. This only takes effect when using DeepSpeed.
            If `None`, the prepare batch setting used in the trainer will be used.
        enable_prepare_logging (`bool`, *optional*, defaults to `None`):
            Whether to enable logging preparation (DeepSpeed). If `None`, the enable prepare logging setting
            used in the trainer will be used.

    Returns:
        `dict`: The results of the evaluation.
    """
    return self.evaluate(
        module=module,
        dataset=dataset,
        eval_logic_fn_name="validation_step",
        results_output=results_output,
        verbose=verbose,
        metrics=metrics,
        compile=compile,
        batch_size=batch_size,
        device_placement=device_placement,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=collate_fn,
        prepare_batch=prepare_batch,
        enable_prepare_logging=enable_prepare_logging,
        module_hooks=module_hooks,
    )