# Copyright 2025 ghanvert. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
import math
import os
import shutil
import signal
import sys
import time
import traceback
from collections import defaultdict
from collections.abc import Mapping
from contextlib import nullcontext
from typing import Any, Callable, Optional, Union
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from accelerate import DistributedType
from accelerate.utils import ProjectConfiguration
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, Dataset
from .callbacks import Callback, CallbackMaster
from .dist_utils import Gatherer, rprint, time_prefix
from .hyperparameters import HyperParameters
from .metrics import Metric
from .model_wrapper import _DistributedDataParallel
from .modules import AcceleratorModule
from .monitor import Monitor
from .states import LossState, TrainingState
from .tqdm import tqdm
from .tracker import _tracker_map
from .tunnel import AsyncDiskQueue, AsyncState, ModelTunnel
from .utility import ASYNC, ASYNC_HASH, ASYNC_TRAIN_GROUP, DEBUG_MODE, MASTER_PROCESS, WORLD_SIZE
from .utils import (
cleanup,
filter_kwargs,
get_number_and_unit,
get_seed,
is_url,
operator_map,
print_gpu_users_by_device,
set_seed,
)
# Name of the sub-directory (inside the model path) that holds checkpoints.
CHECKPOINT_DIR = "checkpoint"
# File names used for the serialized training state and train-loss state.
STATE_FILE = "state.json"
TRAIN_LOSS_STATE_FILE = "train_loss_state.pt"

# Layout shared by the progress bars created in this module.
_bar_format = (
    "{l_bar}{bar}| {n_fmt}/{total_fmt}"
    " - ETA: {remaining}{postfix}"
    " - {rate_s}"
)
# Keyword arguments unpacked into every `tqdm(...)` call in this module.
_tqdm_kwargs = dict(leave=False, ncols=100, bar_format=_bar_format)
[docs]
class Trainer:
"""Class to implement full training process."""
[docs]
def __init__(
    self,
    hps_config: Union[str, dict, HyperParameters],
    model_path: str,
    track_name: Optional[str] = None,
    enable_checkpointing: bool = True,
    multiple_checkpoints: bool = False,
    max_checkpoints: Optional[int] = None,
    resume: Optional[Union[bool, int]] = None,
    disable_model_saving: bool = False,
    patience: Optional[Union[int, dict[str, Any]]] = None,
    evaluate_every_n_steps: Optional[int] = None,
    checkpoint_every: Optional[str] = "epoch",
    logging_dir: str = "logs",
    log_with: Optional[str] = None,
    log_every: Optional[int] = -1,
    grad_accumulation_steps: Optional[int] = None,
    gradient_checkpointing: bool = False,
    gradient_checkpointing_kwargs: Optional[dict[str, Any]] = None,
    clip_grad: Optional[float] = 1.0,
    set_to_none: bool = True,
    shuffle_train: bool = True,
    sampler: Optional[Union[Any, list]] = None,
    collate_fn: Optional[Callable] = None,
    collate_fn_train: Optional[Callable] = None,
    collate_fn_val: Optional[Callable] = None,
    max_shard_size: str = "10GB",
    safe_serialization: bool = False,
    compile: bool = False,
    compile_kwargs: Optional[dict[str, Any]] = None,
    safe_mode: bool = True,
    train_loss_metric_name: str = "train_loss",
    val_loss_metric_name: str = "val_loss",
    dataloader_pin_memory: bool = True,
    dataloader_num_workers: Optional[int] = None,
    dataloader_drop_last: bool = False,
    eval_when_finish: bool = True,
    eval_when_start: bool = False,
    monitor: Optional[Monitor] = None,
    metrics: Optional[Union[Metric, list[Metric], dict[Any, Union[Metric, list[Metric]]]]] = None,
    cleanup_cache_every_n_steps: Optional[int] = None,
    callback: Optional[Union[Callback, list[Callback]]] = None,
    additional_tracker_config: Optional[dict[str, Any]] = None,
    batch_device_placement: bool = True,
    prepare_batch: bool = True,
    safe_steps: bool = True,
    destroy_after_training: bool = True,
    enable_prepare_logging: bool = False,
    **kwargs: Optional[Any],
):
    """
    Trainer constructor to set configuration.

    Args:
        hps_config (`str`, `dict`, or `HyperParameters`):
            YAML hyperparameters file path, dictionary or `HyperParameters`.
        model_path (`str`):
            Path to save model.
        track_name (`str`, *optional*, defaults to `None`):
            Track name for trackers. If set to `None` (default), the track name will be
            the model's folder name.
        enable_checkpointing (`bool`, *optional*, defaults to `True`):
            Enable checkpointing.
        multiple_checkpoints (`bool`, *optional*, defaults to `False`):
            Enable multiple checkpoints.
        max_checkpoints (`int`, *optional*, defaults to `None`):
            Maximum number of checkpoints to keep. If set to `None`, all checkpoints will be kept.
        resume (`bool` or `int`, *optional*, defaults to `None`):
            Whether to resume from checkpoint. Default option is `None`, which means resuming from checkpoint
            will be handled automatically, whether the checkpoint directory exists or not.
            If set to `True`, the latest checkpoint will be loaded.
            If set to an integer, the checkpoint will be loaded from the given index (if `multiple_checkpoints` is `True`).
            If set to `-1`, the latest checkpoint will be loaded (if `multiple_checkpoints` is `True`).
        disable_model_saving (`bool`, *optional*, defaults to `False`):
            Disable any model saving registered (by default, `"best_valid_loss"` is registered, or if there are no
            evaluations to do, default will be `"best_train_loss"`).
        patience (`int` or `dict`, *optional*, defaults to `None`):
            Set up a patience parameter for model savings. If set, every model saving will check if the previous metric was higher.
            If the metric has not improved over the N model savings (`patience`), then the training process will stop. Can also
            implement patience per model saving in a dictionary (keys are model saving names, values are patience counts).
        evaluate_every_n_steps (`int`, *optional*, defaults to `None`):
            Evaluate model in validation dataset (if implemented) every N steps. If this is set
            to `None` (default option), evaluation will happen at the end of every epoch.
        checkpoint_every (`str`, *optional*, defaults to `epoch`):
            Checkpoint every N epochs, steps or evaluations. Requires a number and a unit in a string.
            The following examples are valid:

            - `"epoch"`, `"ep"`, `"1epoch"`, `"1ep"`, `"1 epoch"`, `"1 ep"`: 1 Epoch
            - `"step"`, `"st"`, `"1step"`, `"1st"`, `"1 step"`, `"1 st"`: 1 Step
            - `"evaluation"`, `"eval"`, `"1evaluation"`, `"1eval"`, `"1 evaluation"`, `"1 eval"`: 1 Evaluation

            (a character `s` at the end of the string is also valid)

            If set to `None`, checkpointing will be disabled.
        logging_dir (`str`, *optional*, defaults to `logs`):
            Path where to save logs to show progress. It can be an IP address (local or remote), HTTP or HTTPS link,
            or simply a directory.
        log_with (`str`, *optional*, defaults to `None`):
            Logger to log metrics. It can be one of the following:

            - `mlflow`

            NOTE: MLFlow is the only one supported right now. Other trackers are not currently available.
        log_every (`int`, *optional*, defaults to `-1`):
            Log train loss every N steps. If set to `-1`, training loss will be logged at the end of every epoch (or if gradient accumulation
            is enabled, the value will be the length of the training dataloader divided by the number of accumulation steps).
            If gradient accumulation is enabled and the value is not `-1`, this value will be multiplied by the number of accumulation steps.
        grad_accumulation_steps (`int`, *optional*, defaults to `None`):
            Accumulate gradients for N steps. Useful for training large models and simulate
            large batches when memory is not enough. If set to `None` or `1`, no accumulation will be performed.
        gradient_checkpointing (`bool`, *optional*, defaults to `False`):
            Use gradient checkpointing. It requires a `gradient_checkpointing_enable` method in the model (models from
            HuggingFace's `transformers` library have this method already implemented) with a single argument `gradient_checkpointing_kwargs`
            (can be a dictionary or `None`).
        gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`):
            Keyword arguments for `gradient_checkpointing_enable` method.
        clip_grad (`float`, *optional*, defaults to 1.0):
            Performs gradient clipping in between backpropagation and optimizer's step function.
            Plain integers are accepted and converted to `float`.
        set_to_none (`bool`, *optional*, defaults to `True`):
            From PyTorch documentation: "instead of setting to zero, set the grads to None. This will
            in general have lower memory footprint, and can modestly improve performance." Some
            optimizers have a different behaviour if the gradient is 0 or None. See PyTorch docs
            for more information: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html
        shuffle_train (`bool`, *optional*, defaults to `True`):
            Whether to shuffle train DataLoader.
        sampler (`list` or `Any`, *optional*, defaults to `None`):
            Sampler (or list of samplers) for train DataLoader.
        collate_fn (`Callable`, *optional*, defaults to `None`):
            Collate function to be implemented in both train and validation dataloaders.
        collate_fn_train (`Callable`, *optional*, defaults to `None`):
            Collate function to be implemented in train dataloader. Cannot be implemented if `collate_fn` was
            already declared.
        collate_fn_val (`Callable`, *optional*, defaults to `None`):
            Collate function to be implemented in validation dataloader. Cannot be implemented if `collate_fn` was
            already declared.
        max_shard_size (`str`, *optional*, defaults to `10GB`):
            Max model shard size to be used.
        safe_serialization (`bool`, *optional*, defaults to `False`):
            Whether to save model using safe tensors or the traditional PyTorch way. If `True`, some tensors
            will be lost.
        compile (`bool`, *optional*, defaults to `False`):
            Whether to call `torch.compile` on model (and teacher, if implemented).
        compile_kwargs (`dict`, *optional*, defaults to `None`):
            `torch.compile` kwargs for additional customization.
        safe_mode (`bool`, *optional*, defaults to `True`):
            Run forward passes of the model in safe mode. This means that the forward pass of the model will run
            through the corresponding wrapper (DDP, FSDP or DeepSpeedEngine). If not running in safe mode, forward pass
            will skip the wrapper and run directly on the module (instance of `nn.Module`). Running with safe mode disabled
            will slightly improve throughput, although gradients consistency and mixed precision could be affected because
            skipping the wrapper's forward pass might skip internal parallel functionality.

            **NOTE**: This parameter takes no effect running with FSDP since forward passes are already done through this wrapper.
        train_loss_metric_name (`str`, *optional*, defaults to `train_loss`):
            Metric name for train loss in logs.
        val_loss_metric_name (`str`, *optional*, defaults to `val_loss`):
            Metric name for validation loss in logs.
        dataloader_pin_memory (`bool`, *optional*, defaults to `True`):
            Enables pin memory option in DataLoader (only if GPU is enabled).
        dataloader_num_workers (`int`, *optional*, defaults to `None`):
            Number of processes for DataLoader. This defaults to `None`, meaning the number of workers will be equal to the
            number of processes set for training.
        dataloader_drop_last (`bool`, *optional*, defaults to `False`):
            Whether to drop last batch on DataLoader or not.
        eval_when_finish (`bool`, *optional*, defaults to `True`):
            At the end of training, evaluate model on validation dataset (if available). This option is only valid when
            `evaluate_every_n_steps` is not `None`.
        eval_when_start (`bool`, *optional*, defaults to `False`):
            Start training with evaluation (if available).
        monitor (`Monitor` or `dict`, *optional*, defaults to `None`):
            Monitor arguments to keep track of variables during training. If not specified, 'train_loss' and 'validation_loss' will
            be set to `True` by default.

            NOTE: Learning rate, GPU and CPU monitoring will only be reported during training, not evaluation. Also, GPU and CPU
            monitoring will only be reported on main process (index 0).
        metrics (`Metric`, `list` or `dict`, *optional*, defaults to `None`):
            List of additional metrics of type 'Metric' to track. When doing multiple evaluations, this should be a dictionary
            of metrics (or list of metrics), where each key corresponds to the dataset to evaluate (specified in `val_dataset`
            in `fit` function) and the value corresponds to a `Metric` or list of metrics. If metrics are given as only `Metric`
            or list of metrics, these metrics will apply for all evaluations. If you want specific metrics for specific evaluations,
            consider dividing your metrics per validation dataset in a dictionary.
        cleanup_cache_every_n_steps (`int`, *optional*, defaults to `None`):
            Cleanup CPU and CUDA caches every N steps. Default is no cleanup.

            NOTE: On every epoch and evaluation call we cleanup cache.
        callback (`Callback` or `list`, *optional*, defaults to `None`):
            `Callback` or callbacks to implement.
        additional_tracker_config (`dict`, *optional*, defaults to `None`):
            Additional configuration specification for tracker (e.g. hyper-parameters).
        batch_device_placement (`bool`, *optional*, defaults to `True`):
            Move batches to correct device automatically. If `False`, batches will be in CPU.
        prepare_batch (`bool`, *optional*, defaults to `True`):
            Prepares a batch dynamically when using Mixed Precision. When using DeepSpeed, we need to scale down
            the floating point tensors to be able to do calculations with the model. If not using DeepSpeed,
            this argument takes no effect.
        safe_steps (`bool`, *optional*, defaults to `True`):
            Run safe training and validation steps to avoid OOMs (Out Of Memory errors) and retry steps. If a retry does not
            solve the problem, a list of users using GPUs will pop up and the OOM error will raise.
        destroy_after_training (`bool`, *optional*, defaults to `True`):
            Destroy the process group after training. Set to `False` if you're running multiple trainings in the same script.
        enable_prepare_logging (`bool`, *optional*, defaults to `False`):
            Enable internal model preparation logging. When using DeepSpeed, there are many messages that appear
            in the terminal that can be annoying.
        kwargs (`Any`, *optional*):
            Extra arguments for specific `init` function in Tracker, e.g. `run_name`, `tags`, etc.

    Raises:
        ValueError: If `resume`, `patience`, `max_checkpoints` or the collate function arguments
            are inconsistent or out of range.
        AssertionError: If `hps_config` or `clip_grad` have an unsupported type.
    """
    # do some previous checks
    self.log_with = log_with.lower() if isinstance(log_with, str) else log_with
    self.tracker = _tracker_map[self.log_with]() if self.log_with is not None and DEBUG_MODE < 1 else None
    assert isinstance(hps_config, (str, dict, HyperParameters)), (
        "'hps_config' needs to be either a string, dictionary or HyperParameters class."
    )
    if isinstance(clip_grad, int) and not isinstance(clip_grad, bool):
        # a plain integer such as `clip_grad=1` is a perfectly valid clipping value
        clip_grad = float(clip_grad)
    assert clip_grad is None or isinstance(clip_grad, float), "'clip_grad' argument needs to be a float."

    # imported here (not at module level) exactly as the original code does
    from . import IS_GPU, accelerator

    self.is_gpu = IS_GPU
    self.accelerator = accelerator
    self.hps = HyperParameters.from_config(hps_config) if isinstance(hps_config, (str, dict)) else hps_config
    self.track_name = track_name
    self.checkpoint_path = os.path.join(model_path, CHECKPOINT_DIR)
    self.model_path = model_path

    # `type(...) is int` on purpose: `True`/`False` are ints too but are not checkpoint indices
    if type(resume) is int:
        if not multiple_checkpoints:
            raise ValueError(
                "Cannot specify a checkpoint index in 'resume' when 'multiple_checkpoints' is disabled."
            )
        elif resume == 0 or resume < -1:
            raise ValueError(
                "Checkpoint index in 'resume' must be greater than 0 (or -1 to resume from latest checkpoint)."
            )

    if DEBUG_MODE >= 3:
        self.resume = False
    elif resume is not None:
        self.resume = resume
    else:
        # automatic mode: resume only when a non-empty checkpoint directory already exists
        self.resume = os.path.isdir(self.checkpoint_path) and len(os.listdir(self.checkpoint_path)) > 0

    self.metrics: dict[Any, list[Metric]] = metrics
    self.disable_model_saving = disable_model_saving
    self.model_saving: dict[
        str, tuple[float, float]
    ] = {}  # key: model saving, value: (saving_below, saving_above)

    if patience is None:
        self.patience = -1  # -1 is the value used when no patience was requested
    elif isinstance(patience, int):
        if patience == 0:
            raise ValueError("The 'patience' argument in Trainer should have a value greater than 0.")
        self.patience = patience
    elif isinstance(patience, dict):
        for k, v in patience.items():
            if v == 0:
                raise ValueError(
                    "The 'patience' argument when declared as a dictionary needs to have values above 0. "
                    f"Got {v} in '{k}'."
                )
        # keep the per-model-saving mapping; `fit` reads `self.patience` for both int and dict forms
        self.patience = dict(patience)
    else:
        raise ValueError("'patience' must be either an integer value or a dictionary.")

    self.evaluate_every_n_steps = evaluate_every_n_steps
    self.enable_checkpointing = enable_checkpointing if DEBUG_MODE < 3 else False
    self.multiple_checkpoints = multiple_checkpoints
    if max_checkpoints is not None and max_checkpoints <= 0:
        raise ValueError("'max_checkpoints' must be greater than 0 or `None`.")
    self.max_checkpoints = max_checkpoints
    self.checkpoint_every, self.checkpoint_strat = get_number_and_unit(checkpoint_every)
    self.logging_dir = logging_dir
    self.log_every = log_every
    self.grad_accumulation_steps = grad_accumulation_steps if grad_accumulation_steps is not None else 1
    self.gradient_checkpointing = gradient_checkpointing
    self.gradient_checkpointing_kwargs = gradient_checkpointing_kwargs
    self.accelerator.gradient_accumulation_steps = self.grad_accumulation_steps
    self.clip_grad = clip_grad if clip_grad is not None else 0.0
    if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
        # mirror the clipping value into the DeepSpeed configuration
        self.accelerator.deepspeed_plugin.deepspeed_config["gradient_clipping"] = self.clip_grad
    self.set_to_none = set_to_none
    self.shuffle_train = shuffle_train
    self.sampler = sampler

    if collate_fn is not None and (collate_fn_train is not None or collate_fn_val is not None):
        raise ValueError("'collate_fn' cannot be declared along with 'collate_fn_train' or 'collate_fn_val'.")
    self.collate_fn = collate_fn
    # a shared `collate_fn` takes the place of both split-specific collate functions
    self.collate_fn_train = collate_fn_train if collate_fn is None else collate_fn
    self.collate_fn_val = collate_fn_val if collate_fn is None else collate_fn

    self.max_shard_size = max_shard_size
    self.safe_serialization = safe_serialization
    self.compile = compile
    self.compile_kwargs = compile_kwargs if compile_kwargs is not None else {}
    self.safe_mode = safe_mode
    self.train_loss_metric_name = train_loss_metric_name
    self.val_loss_metric_name = val_loss_metric_name
    self.dataloader_pin_memory = dataloader_pin_memory if IS_GPU else False
    self.dataloader_num_workers = (
        dataloader_num_workers if dataloader_num_workers is not None else self.accelerator.num_processes
    )
    if (DEBUG_MODE > 0 and self.dataloader_num_workers != 0) or self.accelerator.num_processes == 1:
        # force when debugging to not have problems with dataloader during breakpoints
        self.dataloader_num_workers = 0
    self.dataloader_drop_last = dataloader_drop_last
    self.samplers = sampler
    self.eval_when_finish = eval_when_finish
    self.eval_when_start = eval_when_start if DEBUG_MODE < 4 else False
    self.monitor = monitor if isinstance(monitor, Monitor) else Monitor.from_config(monitor)
    self.cleanup_cache_every_n_steps = cleanup_cache_every_n_steps

    # normalize callbacks into a list handled by a single `CallbackMaster`
    callback = callback if callback is not None else Callback()
    callback = callback if isinstance(callback, list) else [callback]
    self.callback = CallbackMaster(callback)

    self.additional_tracker_config = additional_tracker_config if additional_tracker_config is not None else {}
    self.batch_device_placement = batch_device_placement
    self.prepare_batch = prepare_batch
    self.safe_steps = safe_steps
    self.destroy_after_training = destroy_after_training
    self.enable_prepare_logging = enable_prepare_logging
    self.init_kwargs = kwargs
    self.accelerator.project_configuration = ProjectConfiguration(
        project_dir=".", logging_dir=logging_dir, total_limit=1
    )
    self._logging = self.log_with is not None
    self.state = TrainingState()
    self.gatherer = Gatherer()
    # adding a total (at maximum) of 64 bytes for additional tensors
    self.train_loss_state = LossState(self.accelerator, self.accelerator.device, self.log_every, pin_memory=IS_GPU)
    self.val_loss_state: dict[Any, LossState] = None  # prepare val loss states in 'fit' function

    # exactly one of these can be true, depending on the unit parsed from `checkpoint_every`
    self._checkpointing_every_n_steps = self.enable_checkpointing and self.checkpoint_strat == "step"
    self._checkpointing_after_evaluation = self.enable_checkpointing and self.checkpoint_strat == "eval"
    self._checkpointing_when_epoch_ends = self.enable_checkpointing and self.checkpoint_strat == "epoch"

    # populated in `fit`
    self._module: AcceleratorModule = None
    self._scheduler: LRScheduler = None
    self._optimizer: Optimizer = None
    self._multiple_evaluations = False
    self.unwrapped_model: nn.Module = None
    self.wrapped_model = None

    # helpers only created when asynchronous evaluation is enabled
    self.async_state = AsyncState(self.model_path) if ASYNC else None
    self.async_queue = AsyncDiskQueue(self.model_path, self.accelerator) if ASYNC else None
    self.tunnel = ModelTunnel(ASYNC_HASH) if ASYNC else None

    # initialize trackers
    self.run_id = None
    self.tracker_initialized = False
    if self._logging and DEBUG_MODE < 1 and ((ASYNC and ASYNC_TRAIN_GROUP) or not ASYNC):
        self.run_id = self._init_trackers()
        self.tracker_initialized = True

    self._model_dtype = torch.float32
[docs]
def fit(
    self,
    module: Union[AcceleratorModule, str, Union[tuple[str, str], tuple[str, Any]]],
    train_dataset: Optional[Dataset] = None,
    val_dataset: Optional[Union[Dataset, list[Dataset], dict[str, Dataset]]] = None,
    **kwargs: Any,
):
    """
    Function to train a given `AcceleratorModule`.

    Args:
        module (`AcceleratorModule`, `str` or `tuple`):
            `AcceleratorModule` class containing the training logic. This can also be a string specifying a
            HuggingFace model, or a tuple of type (model, type), where 'model' is a string for the HuggingFace model,
            and 'type' is a string or class (from transformers library) for the model type.
        train_dataset (`torch.utils.data.Dataset`, *optional*, defaults to `None`):
            `Dataset` class from PyTorch containing the train dataset logic. If not provided, then
            `get_train_dataloader` from `module` will be used to get the train DataLoader.
        val_dataset (`torch.utils.data.Dataset`, `list` or `dict`, *optional*, defaults to `None`):
            `Dataset` class from PyTorch containing the validation dataset logic. This can also be a list or a dictionary
            of `Dataset`, in that case, multiple evaluations will run following the logic of `validation_step` and
            specified metrics. Metric names reported for a multiple evaluation setting will add a '_' followed by a key
            related to the dataset (e.g. 'accuracy_1' or 'accuracy_another_dataset').

            If this dataset is not specified, then the validation logic of `AcceleratorModule`
            (if specified) will be skipped.
        kwargs (`Any`):
            Keyword arguments for `from_pretrained` function for model initialization.

    Raises:
        RuntimeError: If the module has no `nn.Module` model, if the keys in a `patience`
            dictionary do not match the registered model savings, or if the resumed training
            state is already flagged as finished.
        FileNotFoundError: If resuming was requested but no checkpoint is available.
    """
    # reset loss states in case of another fit function call in the script
    cleanup()
    self.train_loss_state.reset()
    if self.val_loss_state is not None:
        for v in self.val_loss_state.values():
            v.reset()

    module = self._get_module(module, **kwargs)
    module.log_every = self.log_every
    self._module = module

    model = module.model
    self.unwrapped_model = model
    if model is None or not isinstance(model, nn.Module):
        raise RuntimeError(
            "`AcceleratorModule` subclass requires `self.model` and needs to be an instance of `nn.Module`."
        )

    teacher = module.teacher
    if teacher is not None:
        # the teacher is put in eval mode regardless of the device in use
        teacher.eval()
    if torch.cuda.is_available():
        model.to(self.accelerator.device)
        if teacher is not None:
            teacher.to(self.accelerator.device)

    module.state = self.state
    module.accelerator = self.accelerator
    module.device = self.accelerator.device

    if MASTER_PROCESS and DEBUG_MODE < 3 and (self.enable_checkpointing or not self.disable_model_saving):
        os.makedirs(self.model_path, exist_ok=True)

    # a single validation dataset is wrapped in a list so that it is handled like many
    val_dataset = val_dataset if val_dataset is None or isinstance(val_dataset, (list, dict)) else [val_dataset]
    self._multiple_evaluations = val_dataset is not None and len(val_dataset) > 1

    train_dataloader, val_dataloader = self._get_dataloaders(module, train_dataset, val_dataset)
    # `val_dataloader` is treated as optional further down, so never dereference it blindly
    eval_keys = list(val_dataloader.keys()) if val_dataloader is not None else []

    self.metrics = self._prepare_metrics(self.metrics, val_dataloader)

    if len(self.model_saving) == 0:
        # default model saving: validation loss when evaluating, train loss otherwise
        if val_dataloader is not None:
            self.model_saving["best_valid_loss"] = (float("inf"), float("-inf"))
        else:
            self.model_saving["best_train_loss"] = (float("inf"), float("-inf"))

    if isinstance(self.patience, int):
        self.state.patience_left = {k: self.patience for k in self.model_saving.keys()}
    else:
        if not all(k in self.model_saving for k in self.patience.keys()):
            raise RuntimeError("Keys declared in 'patience' do not match model savings.")
        # model savings without an explicit patience get -1
        self.state.patience_left = {k: self.patience.get(k, -1) for k in self.model_saving.keys()}

    if self.metrics is not None:
        for k, v in self.metrics.items():
            self.state.additional_metrics[k] = {m.main_metric: 0 for m in v}
    else:
        for k in eval_keys:
            self.state.additional_metrics[k] = {}

    if self.resume:
        checkpoint_path = self._get_current_checkpoint_path()
        if checkpoint_path.endswith("checkpoint_0"):
            raise FileNotFoundError("Checkpoint directory is empty or not found.")

        training_state_path = os.path.join(checkpoint_path, STATE_FILE)
        loss_tracker_path = os.path.join(checkpoint_path, TRAIN_LOSS_STATE_FILE)
        self.state.load(training_state_path)
        self.train_loss_state.load(loss_tracker_path)
        if self.state.finished:
            raise RuntimeError("Training process has been flagged as finished.")

    module.state = self.state

    self.monitor._set_extra(
        self.accelerator, self.state, self.train_loss_metric_name, self.val_loss_metric_name, self.tracker
    )
    self.monitor._tracking = self.tracker is not None

    if self.accelerator.distributed_type == DistributedType.FSDP:
        # preparing model before dataloaders is only supported by FSDP apparently, and this is the
        # recommended setting to prepare training.
        model = self.accelerator.prepare_model(model)

    # one validation loss state per evaluation dataset (empty when there is nothing to evaluate)
    self.val_loss_state = {
        k: LossState(
            self.accelerator, self.accelerator.device, -1, include_per_batch=False, pin_memory=self.is_gpu
        )
        for k in eval_keys
    }

    optimizer = self._get_optimizer(module)
    if self.hps.step_scheduler_per_epoch:
        scheduler = self._get_scheduler(module, optimizer, self.hps.epochs, self.hps.epochs)
    elif self.hps.max_steps is not None:
        num_training_steps = self.hps.max_steps
        # derive the number of epochs needed to reach `max_steps`
        self.hps.epochs = math.ceil(num_training_steps / (len(train_dataloader) / self.accelerator.num_processes))
        scheduler = self._get_scheduler(
            module, optimizer, num_training_steps, 1
        )  # ignore epochs to avoid multiplication
    else:
        scheduler = self._get_scheduler(
            module, optimizer, round(len(train_dataloader) / self.accelerator.num_processes), self.hps.epochs
        )

    if ASYNC:
        if ASYNC_TRAIN_GROUP:
            self.tunnel.init(model)
            self.async_state.init()
            self.async_state.update(tunnel_ready=True, run_id=self.run_id)
            # only MASTER_PROCESS returns a valid 'run_id', and 'update' function already handles that.
        else:
            self.async_state.wait_for_tunnel()

    model, teacher, train_dataloader, val_dataloader, optimizer, scheduler = self._prepare(
        module,
        model,
        teacher,
        train_dataloader,
        val_dataloader,
        optimizer,
        scheduler,
        batch_device_placement=self.batch_device_placement,
    )
    self._model_dtype = next(model.parameters()).dtype

    if ASYNC and not ASYNC_TRAIN_GROUP:
        # force train dataloader, optimizer and scheduler to be None in evaluation group since they're not being used.
        train_dataloader = None
        optimizer = None
        scheduler = None

    self._scheduler = scheduler
    self._optimizer = optimizer
    module.scheduler = scheduler
    module.optimizer = optimizer
    self.wrapped_model = model

    if self.log_every < 0:  # report training loss at the last step (or end of an epoch)
        # the async evaluation group has no train dataloader, so there is no length to derive
        # a logging interval from; leave the value untouched in that case.
        if train_dataloader is not None:
            self.log_every = len(train_dataloader) // self.grad_accumulation_steps
    elif self.grad_accumulation_steps > 1:
        self.log_every = self.grad_accumulation_steps * self.log_every

    for callback in self.callback.children:
        callback.module = module
        callback.trainer = self
        callback.state = self.state

    self.callback.on_fit_start()

    if ASYNC and not ASYNC_TRAIN_GROUP:
        self.dispatch_async_eval(module, model, val_dataloader)
    else:
        self.loop(module, model, train_dataloader, val_dataloader, optimizer, scheduler)
        if ASYNC and ASYNC_TRAIN_GROUP:
            self.async_state.update(train_finished=True)
            # wait until evaluation group is finished
            while not self.async_state.evaluation_finished:
                time.sleep(0.5)

            # evaluation group delegates the job to train group
            eval_runs_pending = self.async_state.evaluations_in_queue
            for _ in range(eval_runs_pending):
                self._async_eval(module, model, val_dataloader)

            self.tunnel.close()

    self.state.finished = True
    self.callback.on_fit_end()

    # each object is passed once (the original listed `scheduler` twice)
    self.accelerator.free_memory(model, train_dataloader, val_dataloader, scheduler, optimizer)
    if self.log_with is not None:
        self.accelerator.get_tracker(self.log_with).finish()

    if self.destroy_after_training and WORLD_SIZE > 1:
        # done to avoid pytorch distributed warnings if script finishes here
        dist.destroy_process_group()
    else:
        module.model = self.unwrapped_model
        # TODO still getting memory leaks if running multiple trainings using the very same module
def loop(
    self,
    module: AcceleratorModule,
    model: nn.Module,
    train_dataloader: DataLoader,
    val_dataloader: Optional[dict[Any, DataLoader]],
    optimizer: Optimizer,
    scheduler: Optional[LRScheduler],
):
    """
    Run the training loop: iterate epochs and batches, launching evaluations in between.

    An evaluation is launched:
        - before the first batch, when `eval_when_start` is set and no evaluation was done yet;
        - after a training step, when `evaluate_every_n_steps` is set and
          `global_step + 1` is a multiple of it;
        - after an epoch, when `evaluate_every_n_steps` is `None`, or when `eval_when_finish`
          is set and the state reports the last epoch.
    """

    def evaluate() -> None:
        # every evaluation in this loop uses the same module, model and dataloaders
        self.launch_eval(module, model, val_dataloader)

    if self.state.evaluations_done == 0 and self.eval_when_start:
        evaluate()

    for _ in self.epoch_iterator():
        for batch in self.batch_iterator(train_dataloader, model):
            self._train_logic(module, model, optimizer, batch, scheduler)

            if self.evaluate_every_n_steps is None:
                continue
            next_step = self.state.global_step + 1
            if next_step % self.evaluate_every_n_steps == 0:
                evaluate()

        if self.evaluate_every_n_steps is None:
            evaluate()
        elif self.eval_when_finish and self.state.is_last_epoch:
            evaluate()
def dispatch_async_eval(
    self, module: AcceleratorModule, model: nn.Module, dataloader: dict[Any, DataLoader], delay: float = 0.1
):
    """
    Poll for pending evaluations until the async state reports that training finished.

    Each iteration calls `_async_eval` once and then sleeps for `delay` seconds. When the
    polling stops, the async state is updated with `evaluation_finished=True`.
    """
    while True:
        if self.async_state.train_finished:
            break
        self._async_eval(module, model, dataloader)
        # pause before checking for evaluations again
        time.sleep(delay)

    self.async_state.update(evaluation_finished=True)
def _async_eval(self, module: AcceleratorModule, model: nn.Module, dataloader: dict[str, DataLoader]):
evals_in_queue = self.async_state.evaluations_in_queue
if evals_in_queue > 0:
# read last model from SHM
self.tunnel.read(model)
if evals_in_queue >= 2:
# read next model from disk and write it into SHM
state_dict = self.async_queue.dequeue()
self.tunnel.write_state_dict(state_dict, non_blocking=True)
self.async_state.update(evaluations_in_queue=-1)
self.eval(module, model, dataloader)
def launch_eval(
    self,
    module: AcceleratorModule,
    model: nn.Module,
    dataloader: dict[Any, DataLoader],
):
    """
    Trigger an evaluation and, when configured, save a checkpoint afterwards.

    In async mode the unwrapped model is handed over instead of being evaluated here: it is
    written to the tunnel when no evaluation is queued, otherwise it is enqueued in the async
    queue; the queue counter is then updated with `1`. Outside async mode `eval` runs directly.

    A checkpoint is saved when checkpointing after evaluation is enabled, except when
    `eval_when_start` is set and exactly one evaluation has been done.
    """
    if not ASYNC:
        self.eval(module, model, dataloader)
    else:
        self.accelerator.wait_for_everyone()
        bare_model = self.accelerator.unwrap_model(model)
        shm_is_free = self.async_state.evaluations_in_queue == 0
        if shm_is_free:
            # SHM is free
            self.tunnel.write(bare_model)
        else:
            # SHM waiting, then we write to disk
            self.async_queue.enqueue(bare_model)

        self.async_state.update(evaluations_in_queue=1)
        self.accelerator.wait_for_everyone()

    # not doing first requested evaluation
    is_start_eval = self.eval_when_start and self.state.evaluations_done == 1
    if is_start_eval or not self._checkpointing_after_evaluation:
        return

    if self.state.is_end_of_epoch:
        epoch_offset, step_offset = 1, 0
    else:
        epoch_offset, step_offset = 0, 1

    self._save_checkpoint(
        self.state.epoch + epoch_offset,
        self.state.train_step + step_offset,
        self.state.global_step + step_offset,
        self.state.evaluations_done,
        finished=self.state.finished,
    )
@torch.inference_mode()
def eval(self, module: AcceleratorModule, model: nn.Module, dataloader: Optional[dict[Any, DataLoader]]):
    """
    Run evaluation over every validation dataloader, log the resulting metrics and
    apply the model saving criteria.

    Args:
        module (`AcceleratorModule`):
            Module whose validation step is run on every batch (through `_validation_logic`).
        model (`nn.Module`):
            Model being evaluated. It is switched to eval mode if it is in training mode.
        dataloader (`dict`, *optional*):
            Validation dataloaders keyed by dataset. Nothing is done when this is `None`.

    Raises:
        ValueError: If a non-parallel metric returns a value that is not a `float`, `int`,
            `torch.Tensor` or NumPy array.
    """
    # NOTE: `all()` over an empty iterable is True, so an empty `patience_left` also skips evaluation
    no_patience_left = all(v == 0 for v in self.state.patience_left.values())
    if DEBUG_MODE >= 5 or no_patience_left or dataloader is None:
        return

    if model.training:
        model.eval()

    cleanup()
    self.callback.on_evaluation_start()
    for k, val_dataloader in dataloader.items():
        # dataset key is only shown in the progress bar when evaluating on several datasets
        val_str = f" ({k}) " if self._multiple_evaluations else " "
        for i, batch in tqdm(
            iterable=enumerate(val_dataloader),
            total=len(val_dataloader),
            desc=f"📊{val_str}Evaluating in Epoch {self.state.epoch + 1}/{self.hps.epochs}",
            position=1,
            colour="cyan",
            **_tqdm_kwargs,
        ):
            self.state.val_step = i
            self.state.is_last_validation_batch = i == len(val_dataloader) - 1
            batch = self._prepare_batch(batch) if self.prepare_batch else batch

            self._validation_logic(module, k, batch)

        self.state.additional_metrics[k]["valid_loss"] = self.val_loss_state[k].get_total_loss()

        if self.metrics is not None:
            for metric in self.metrics[k]:
                if not metric._parallel and MASTER_PROCESS:
                    # we don't want to call '_compute' for metrics that are implemented in main process,
                    # since the state on other processes is empty
                    metric_dict = metric._compute()

                    for m, v in metric_dict.items():
                        if not isinstance(v, (float, int, torch.Tensor, np.ndarray)):
                            raise ValueError(
                                f"Value in metric's dict does not accept {type(v)}, only "
                                f"`float`, `int`, `torch.Tensor` (torch) or `NDArray` (numpy)"
                            )
                        # tensors and arrays are converted to Python scalars through `.item()`
                        self.state.additional_metrics[k][m] = (
                            v if not isinstance(v, (torch.Tensor, np.ndarray)) else v.item()
                        )
                elif metric._parallel:
                    metric_dict = metric._compute()
                    # we are not fixing objects since in parallel mode they're already converted to python values
                    self.state.additional_metrics[k].update(metric_dict)

        # re-format metrics: when evaluating on several datasets, every metric name gets the
        # dataset key appended as suffix, e.g. {"accuracy__dataset1": 0.21}
        log_dict = {}
        for _metric_name, _value in self.state.additional_metrics[k].items():
            # "best_*" entries are bookkeeping values and are not logged
            if _metric_name.startswith("best_"):
                continue

            _metric_name = f"{_metric_name}__{k}" if self._multiple_evaluations else _metric_name
            log_dict[_metric_name] = _value

        run_id = self.async_state.run_id if ASYNC and MASTER_PROCESS else None
        self.monitor.log_additional_metrics(log_dict, run_id=run_id)

    self.state.evaluations_done += 1

    should_save_model = not (
        self.eval_when_start and self.state.evaluations_done == 1
    )  # not doing first requested evaluation

    # save model
    if self.model_saving is not None and should_save_model and DEBUG_MODE < 3:
        self._save_model_on_criteria(model)
    else:
        # reset total loss state for validation since it's not being used
        for k in self.val_loss_state.keys():
            self.val_loss_state[k].total_loss.zero_()
            self.val_loss_state[k].num_steps.zero_()

    self.state.val_step = 0
    # flag as finished if doing very last evaluation
    self.state.finished = self.state.is_last_training_batch and self.state.is_last_epoch
    self.callback.on_evaluation_end()
def _save_model_on_criteria(self, model: nn.Module):
    """
    Save the model for every registered model saving whose criteria are met, update the
    per-dataset best metrics and the patience counters, and stop the process once every
    model saving has run out of patience.

    Args:
        model (`nn.Module`):
            Model passed to `_save_model` whenever a criterion is satisfied.
    """
    self.accelerator.wait_for_everyone()

    train_loss = self.train_loss_state.get_total_loss()

    # nothing is written (and no patience is consumed) for the first evaluation requested at start
    can_save = not (self.eval_when_start and self.state.evaluations_done == 1)

    def _check_and_save(model_saving: str):
        # Evaluate a single model saving key, e.g. "best_accuracy@ds1/f1": every metric listed
        # (separated by "/") must improve for the model to be saved.
        _model_saving = model_saving
        model_saving_without_prefix = model_saving.removeprefix("best_")

        # we already have all metrics calculated per dataset in self.state.additional_metrics
        metrics_and_datasets = defaultdict(
            list
        )  # e.g. {"accuracy": ["dataset1", "dataset2"], "metric": [dataset_keys, ...]}
        for metric in model_saving_without_prefix.split("/"):
            metric, *datasets = metric.split("@")
            if len(datasets) == 0:
                # if datasets are not specified for a metric, then it means that we need to average
                # across all datasets
                datasets = self.state.additional_metrics.keys()

            for dataset in datasets:
                metrics_and_datasets[metric].append(dataset)

        # now create a buffer per metric, where each value in the buffer corresponds to the
        # metric found in a dataset
        metric_buffer = defaultdict(list)  # e.g. {"accuracy": [0.2, 0.5], "metric": [values, ...]}
        for dataset_key, metrics_dict in self.state.additional_metrics.items():
            for metric, value in metrics_dict.items():
                if dataset_key in set(metrics_and_datasets[metric]):
                    metric_buffer[metric].append(value)

        # now average those metrics in buffer
        metric_avgs = {k: (np.mean(v) if len(v) > 1 else v[0]) for k, v in metric_buffer.items()}

        _metrics = [ms.split("@")[0] for ms in model_saving_without_prefix.split("/")]
        count = 0  # number of metrics in this model saving that improved
        for metric in _metrics:
            best_metric_str = f"best_{metric}"
            # validation loss is always "lower is better"; other metrics declare their own comparator
            comparator = self._get_comparator(metric) if metric != "valid_loss" else "<"
            compare = operator_map[comparator]

            new = metric_avgs[metric]

            # calculate average between previous metrics in wanted datasets
            prev = []
            for dataset_key in set(metrics_and_datasets[metric]):
                if best_metric_str not in self.state.additional_metrics[dataset_key]:
                    # only register best metrics in wanted datasets; the initial value is the
                    # worst possible one for the comparator's direction
                    self.state.additional_metrics[dataset_key][best_metric_str] = (
                        float("inf") if comparator in {"<", "<=", "=="} else float("-inf")
                    )

                prev.append(self.state.additional_metrics[dataset_key][best_metric_str])

            prev = np.mean(prev) if len(prev) > 1 else prev[0]

            # bounds are exclusive: the new value must lie strictly between them
            saving_below, saving_above = self.model_saving[_model_saving]

            is_better = compare(new, prev) and new < saving_below and new > saving_above
            if is_better:
                count += 1
                # register best metrics for all wanted datasets
                # NOTE(review): `metrics_and_datasets[metric]` follows the order written in the model
                # saving key while `metric_buffer[metric]` follows the iteration order of
                # `additional_metrics`; confirm both orders always agree before relying on this zip
                for dataset, new_metric_calculated in zip(metrics_and_datasets[metric], metric_buffer[metric]):
                    local_prev = self.state.additional_metrics[dataset][best_metric_str]
                    if compare(new_metric_calculated, local_prev):
                        self.state.additional_metrics[dataset][best_metric_str] = new_metric_calculated

        if count == len(_metrics):
            # all these metrics have improved
            if MASTER_PROCESS and can_save and not self.disable_model_saving:
                model_path = os.path.join(self.model_path, _model_saving.replace("/", "__"))
                self._save_model(model, model_path)
        elif can_save and self.state.patience_left[_model_saving] > 0:
            # no improvement: consume one unit of patience for this model saving
            self.state.patience_left[_model_saving] -= 1

    if len(self.model_saving) > 0:
        if train_loss < self.state.best_train_loss:
            self.state.best_train_loss = train_loss
            if "best_train_loss" in self.model_saving:
                if MASTER_PROCESS and can_save and not self.disable_model_saving:
                    model_path = os.path.join(self.model_path, "best_train_loss")
                    self._save_model(model, model_path)
        elif (
            can_save and "best_train_loss" in self.model_saving and self.state.patience_left["best_train_loss"] > 0
        ):
            self.state.patience_left["best_train_loss"] -= 1

        # NOTE(review): this loop does not skip "best_train_loss"; `_check_and_save` would then look
        # up a comparator for "train_loss" — confirm that key cannot reach this point
        for model_saving in self.model_saving:  # TODO we could implement a tqdm bar maybe...
            _check_and_save(model_saving)

    # if all model savings have no patience anymore, finish training process
    count = 0
    for model_saving in self.model_saving.keys():
        count += self.state.patience_left[model_saving] == 0

    self.accelerator.wait_for_everyone()

    if count == len(self.model_saving):
        rprint("Ran out of patience. Process finished.")
        self.state.finished = True
        if MASTER_PROCESS:
            # persist the final training state in the checkpoint and next to every saved model
            state_in_checkpoint = os.path.join(self.checkpoint_path, STATE_FILE)
            self.state.save(state_in_checkpoint)
            for model_saving in self.model_saving:
                model_saving_path = os.path.join(self.model_path, model_saving)
                os.makedirs(model_saving_path, exist_ok=True)
                model_saving_path = os.path.join(model_saving_path, STATE_FILE)
                self.state.save(model_saving_path)

        self.accelerator.end_training()
        exit(0)
def _save_model(self, model: nn.Module, path: str):
    """
    Write the model weights and the current training state into `path`.

    Models exposing `save_pretrained` (e.g. from the transformers library) are saved through
    that method; any other model has its state dict saved as `pytorch_model.pt`.

    Args:
        model (`nn.Module`):
            Model to save. It is unwrapped before its weights are collected.
        path (`str`):
            Output directory; created if missing.
    """
    tqdm.write(f"\r{time_prefix()} Saving model...")
    os.makedirs(path, exist_ok=True)

    unwrapped_model = self.accelerator.unwrap_model(model)
    if self.compile:
        # weights are read from the module held by the compiled wrapper
        state_dict = unwrapped_model._orig_mod.state_dict()
    else:
        state_dict = unwrapped_model.state_dict()

    if not hasattr(unwrapped_model, "save_pretrained"):
        self.accelerator.save(
            state_dict,
            os.path.join(path, "pytorch_model.pt"),
            safe_serialization=self.safe_serialization,
        )
    else:
        unwrapped_model.save_pretrained(
            path,
            is_main_process=True,
            state_dict=state_dict,
            max_shard_size=self.max_shard_size,
            save_function=self.accelerator.save,
            safe_serialization=self.safe_serialization,
        )

    # keep the training state next to the weights
    self.state.save(os.path.join(path, STATE_FILE))

    tqdm.write(f"\033[A\033[K{time_prefix()} Model saved.")
def _validation_logic(self, module: AcceleratorModule, dataloader_key: Any, batch: Any):
    """
    Run one validation step: call the module, accumulate the loss and feed the metrics.

    Args:
        module (`AcceleratorModule`):
            Module providing `validation_step`.
        dataloader_key (`Any`):
            Key of the validation dataset the batch belongs to.
        batch (`Any`):
            Batch handed to `validation_step`.

    Raises:
        RuntimeError: If the step output lacks the main key of a declared metric.
    """
    self.callback.on_before_validation_step(batch)
    step_output = (
        self._safe_step(module.validation_step, dataloader_key, batch)
        if self.safe_steps
        else module.validation_step(dataloader_key, batch)
    )
    if isinstance(step_output, torch.Tensor):
        # a bare tensor is taken to be the loss, so wrap it into a dictionary
        step_output = {"loss": step_output}
    self.callback.on_after_validation_step()

    # track loss
    loss = step_output["loss"].detach()
    self.val_loss_state[dataloader_key].add_total_loss(loss)

    # track metrics
    if self.metrics is None:
        return

    for metric in self.metrics[dataloader_key]:
        if metric.main_metric not in step_output:
            raise RuntimeError("Make sure to align 'validation_step' with declared metrics.")

        compute_args = step_output[metric.main_metric]
        if not isinstance(compute_args, tuple):
            compute_args = (compute_args,)

        if metric._parallel:
            # parallel metrics receive the local values on every process
            if compute_args[0] is not None:
                metric.add_batch(*compute_args)
            continue

        # non-parallel metrics: gather every argument first, then add only on the main process
        gathered = []
        for arg in compute_args:
            if isinstance(arg, dict):
                gathered.append(self.gatherer.all_gather_dictionary(arg))
            else:
                gathered.append(self.accelerator.gather_for_metrics(arg))
        compute_args = tuple(gathered)

        if MASTER_PROCESS and compute_args[0] is not None:
            metric.add_batch(*compute_args)
def _prepare_batch(self, batch: Any) -> Any:
    """
    Convert the elements of a batch through `_prepare_nested_batch` when running with
    DeepSpeed; with any other distributed type the batch is returned untouched.
    """
    uses_deepspeed = self.accelerator.distributed_type == DistributedType.DEEPSPEED
    return self._prepare_nested_batch(batch) if uses_deepspeed else batch
def _prepare_nested_batch(self, batch: Any) -> Any:
    """
    Recursively move the tensors of a (possibly nested) batch to the accelerator device,
    casting floating point and complex tensors to `self._model_dtype`.

    This function is derived from `transformers` library
    (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py).

    Args:
        batch (`Any`):
            A tensor, a mapping, a tuple/list (namedtuples included) of those, or any other
            object. Non-tensor leaves are returned unchanged.

    Returns:
        `Any`: Object with the same container types as `batch`.
    """
    if isinstance(batch, Mapping):
        return type(batch)({k: self._prepare_nested_batch(v) for k, v in batch.items()})

    if isinstance(batch, (tuple, list)):
        prepared = [self._prepare_nested_batch(v) for v in batch]
        if isinstance(batch, tuple) and hasattr(batch, "_fields"):
            # namedtuple constructors take one positional argument per field, not an iterable
            return type(batch)(*prepared)
        return type(batch)(prepared)

    if isinstance(batch, torch.Tensor):
        kwargs = {"device": self.accelerator.device}
        if torch.is_floating_point(batch) or torch.is_complex(batch):
            # integer/bool tensors (ids, masks) keep their dtype
            kwargs["dtype"] = self._model_dtype
        return batch.to(**kwargs)

    return batch
def _safe_step(self, fn: Callable, *args, **kwargs) -> Union[torch.Tensor, dict, Any]:
    """
    Call `fn(*args, **kwargs)` and retry it once if it fails with a CUDA out-of-memory error.

    Before retrying, gradients of the wrapped model are dropped and the CUDA cache is emptied.
    If the retry runs out of memory as well, the failure is reported, the tracker (if any) is
    ended with status "FAILED", the process group is destroyed when running distributed, and
    the process exits with code 1.

    Args:
        fn (`Callable`):
            Step function to execute.
        *args, **kwargs:
            Arguments forwarded to `fn`.

    Returns:
        Whatever `fn` returns.

    Raises:
        RuntimeError: Any `RuntimeError` raised by `fn` (on the first call or on the retry)
            that is not an out-of-memory error.
        SystemExit: With code 1 when the retry also runs out of memory.
    """
    try:
        return fn(*args, **kwargs)
    except RuntimeError as e:
        if "out of memory" not in str(e):
            raise

    # Recovery runs outside the `except` block on purpose: while an exception is being handled
    # its traceback keeps the failed frames (and their tensors) alive, which would prevent the
    # memory from actually being released.
    for p in self.wrapped_model.parameters():
        if p.grad is not None:
            del p.grad
    torch.cuda.empty_cache()

    try:
        return fn(*args, **kwargs)
    except RuntimeError as retry_error:
        if "out of memory" not in str(retry_error):
            # an unrelated failure must surface as such, not be reported as OOM
            raise

        rprint("CUDA: Out Of Memory.")
        print_gpu_users_by_device()
        if self.tracker is not None:
            self.tracker.end(status="FAILED")
        if WORLD_SIZE > 1:
            dist.destroy_process_group()
        # `sys.exit` does not depend on the `site` module providing the `exit` builtin
        sys.exit(1)
def _train_logic(
    self,
    module: AcceleratorModule,
    model: nn.Module,
    optimizer: Optimizer,
    batch: Any,
    scheduler: Optional[LRScheduler],
):
    """
    Run one training step: forward pass, loss tracking, backward pass, gradient clipping,
    logging, optimizer/scheduler step and gradient reset.

    Nothing is done in asynchronous mode for processes outside the training group. For
    extended modules (`module._extended`) the backward pass and the optimizer, scheduler and
    zero-grad steps are skipped here.

    Args:
        module (`AcceleratorModule`):
            Module providing `training_step`.
        model (`nn.Module`):
            Prepared model, used for gradient accumulation and gradient clipping.
        optimizer (`Optimizer`):
            Optimizer to step and reset.
        batch (`Any`):
            Batch handed to `training_step`.
        scheduler (`LRScheduler`, *optional*):
            Scheduler stepped after the optimizer unless it is configured to step per epoch.
    """
    if ASYNC and not ASYNC_TRAIN_GROUP:
        return

    self.callback.on_before_training_step(batch)
    # gradient accumulation context is only needed when accumulating over several steps
    with self.accelerator.accumulate(model) if self.grad_accumulation_steps > 1 else nullcontext():
        # forward pass
        if self.safe_steps:
            loss = self._safe_step(module.training_step, batch)
        else:
            loss = module.training_step(batch)
        self.callback.on_after_training_step()

        # track
        _loss = loss.detach()
        self.train_loss_state.add_batch_loss(_loss)
        self.train_loss_state.add_total_loss(_loss)

        self.callback.on_before_backward(loss)
        if not module._extended:
            # backpropagation
            self.accelerator.backward(loss)
        self.callback.on_after_backward()

        norm = None
        if (
            self.accelerator.sync_gradients
            and self.clip_grad > 0.0
            and self.accelerator.distributed_type != DistributedType.DEEPSPEED
        ):
            norm = self.accelerator.clip_grad_norm_(model.parameters(), self.clip_grad)

        if self.state.global_step % self.log_every == 0:
            batch_loss = self.train_loss_state.get_batch_loss()
            # clipping already returned the norm; otherwise compute it only if it is monitored
            if MASTER_PROCESS and self.monitor.grad_norm and norm is None:
                norm = self._get_grad_norm()

            self.monitor.log_train_loss_and_grad_norm(batch_loss, norm)

        if not module._extended:
            self.callback.on_before_optimizer_step(optimizer)
            optimizer.step()
            self.callback.on_after_optimizer_step(optimizer)

            if scheduler is not None and not self.hps.step_scheduler_per_epoch:
                self.callback.on_before_scheduler_step(scheduler)
                scheduler.step()
                self.callback.on_after_scheduler_step(scheduler)

            # reset gradients; the hook fired here must be the zero-grad one, not a second
            # 'before optimizer step' notification
            self.callback.on_before_zero_grad(optimizer)
            optimizer.zero_grad(set_to_none=self.set_to_none)
            self.callback.on_after_zero_grad(optimizer)
def batch_iterator(self, dataloader: DataLoader, model: nn.Module):
    """
    Batch iterator for training handling checkpointing.

    Yields the batches of the current epoch, skipping the first `self.state.train_step`
    batches (resume case). Around every yielded batch it updates the training state, logs the
    learning rate and resource utilization every `log_every` steps, optionally cleans the
    cache and saves step-based checkpoints, and stops early once `max_steps` is reached.

    Args:
        dataloader (`DataLoader`):
            Training dataloader.
        model (`nn.Module`):
            Model to put in training mode before iterating.

    Yields:
        The next training batch (passed through `_prepare_batch` when `self.prepare_batch` is set).
    """
    if not model.training:
        model.train()

    if self.shuffle_train:
        # derive a per-epoch seed and propagate the epoch to the dataloader
        global_seed = get_seed(default=0)
        set_seed(global_seed + self.state.epoch)
        dataloader.set_epoch(self.state.epoch)

    _dataloader = self.accelerator.skip_first_batches(dataloader, self.state.train_step)

    cleanup()
    start = self.state.train_step

    # determine total steps for the current epoch
    total_steps_in_epoch = len(dataloader)

    # calculate remaining steps in current epoch
    remaining_steps = total_steps_in_epoch - start

    # for progress bar, use max_steps if defined, otherwise use dataloader length
    progress_total = self.hps.max_steps if self.hps.max_steps is not None else total_steps_in_epoch
    progress_initial = self.state.global_step if self.hps.max_steps is not None else start

    if remaining_steps > 0:
        for i, batch in tqdm(
            iterable=enumerate(_dataloader, start),
            total=progress_total,
            initial=progress_initial,
            desc=f"🚀 Training in Epoch {self.state.epoch + 1}/{self.hps.epochs}",
            position=0,
            colour="green",
            **_tqdm_kwargs,
        ):
            self.state.train_step = i
            # last batch of the epoch, or the batch that reaches `max_steps`
            self.state.is_last_training_batch = (i == total_steps_in_epoch - 1) or (
                self.hps.max_steps is not None and self.state.global_step + 1 >= self.hps.max_steps
            )
            if self.state.global_step % self.log_every == 0:
                lr = (
                    self._scheduler.get_last_lr()[-1]
                    if self._scheduler is not None
                    else self._optimizer.param_groups[0]["lr"]
                )
                # TODO we can fuse these functions to only report once to the server
                self.monitor.log_learning_rate(lr)
                self.monitor.log_cpu_utilization()
                self.monitor.log_gpu_utilization()

            batch = self._prepare_batch(batch) if self.prepare_batch else batch

            yield batch

            # everything below runs after the consumer has processed the batch;
            # `global_step` has not been incremented yet, hence the `+ 1`
            if (
                self.cleanup_cache_every_n_steps is not None
                and (self.state.global_step + 1) % self.cleanup_cache_every_n_steps == 0
            ):
                cleanup()

            if self._checkpointing_every_n_steps and (self.state.global_step + 1) % self.checkpoint_every == 0:
                # checkpoint stores the position of the NEXT step to run
                self._save_checkpoint(
                    self.state.epoch,
                    self.state.train_step + 1,
                    self.state.global_step + 1,
                    self.state.evaluations_done,
                )

            self.state.global_step += 1

            # check if we've reached max_steps
            if self.hps.max_steps is not None and self.state.global_step >= self.hps.max_steps:
                break

    # if length of _dataloader is 0, then we do not iterate
    self.state.is_end_of_epoch = True
    self.state.train_step = 0
def _save_checkpoint(
    self, epoch: int, train_step: int, global_step: int, evaluations_done: int, finished: bool = False
):
    """
    Save checkpoint at a given point in time (`epoch` and `train_step`).

    The accelerator state and the training loss state are written by every process; the
    training state file is written by the main process only, with the given values overriding
    the corresponding entries of the current state.

    Args:
        epoch (`int`):
            Epoch to store in the checkpoint.
        train_step (`int`):
            Step inside the epoch to store in the checkpoint.
        global_step (`int`):
            Global step to store in the checkpoint.
        evaluations_done (`int`):
            Number of evaluations done to store in the checkpoint.
        finished (`bool`, *optional*, defaults to `False`):
            Value of the "finished" flag to store in the checkpoint.
    """
    self.callback.on_save_checkpoint()
    if MASTER_PROCESS:
        tqdm.write(f"\r{time_prefix()} Saving checkpoint...")
        os.makedirs(self.checkpoint_path, exist_ok=True)

    self.accelerator.wait_for_everyone()

    checkpoint_path = self.checkpoint_path
    if self.multiple_checkpoints:
        if (
            MASTER_PROCESS
            and self.max_checkpoints is not None
            and len(os.listdir(checkpoint_path)) >= self.max_checkpoints
        ):
            # drop the checkpoint with the lowest number to make room for the new one
            min_checkpoint = min(os.listdir(checkpoint_path), key=lambda x: int(x.split("_")[-1]))
            shutil.rmtree(os.path.join(checkpoint_path, min_checkpoint))

        # NOTE(review): the next number is derived from the directory listing after the pruning
        # above, which only the main process performs — confirm that every process ends up with
        # the same `new_checkpoint_path`
        last_checkpoint_num = int(self._get_current_checkpoint_path(ignore_resume_idx=True).split("_")[-1])
        new_checkpoint_path = os.path.join(checkpoint_path, f"checkpoint_{last_checkpoint_num + 1}")
        if MASTER_PROCESS:
            os.makedirs(new_checkpoint_path, exist_ok=True)
        checkpoint_path = new_checkpoint_path

    self.accelerator.save_state(checkpoint_path, safe_serialization=self.safe_serialization)

    loss_tracker_path = os.path.join(checkpoint_path, TRAIN_LOSS_STATE_FILE)
    self.train_loss_state.save(loss_tracker_path)

    self.state.num_checkpoints_made += 1
    if MASTER_PROCESS:
        # snapshot the state and override the positional fields with the values given by the caller
        training_state_dict = self.state.to_dict()
        training_state_dict["epoch"] = epoch
        training_state_dict["train_step"] = train_step
        training_state_dict["global_step"] = global_step
        training_state_dict["evaluations_done"] = evaluations_done
        training_state_dict["finished"] = finished

        training_state_path = os.path.join(checkpoint_path, STATE_FILE)
        self.state.save(training_state_path, training_state_dict)

        # ANSI sequence: move the cursor up one line and clear it, replacing the "Saving" message
        tqdm.write(f"\033[A\033[K{time_prefix()} Checkpoint saved.")

    self.monitor.log_checkpoint()
def epoch_iterator(self):
    """
    Iterate over the remaining epochs, starting at `self.state.epoch`.

    Before yielding an epoch, the training state is updated and the epoch is logged. Once the
    consumer is done with it, the per-epoch scheduler step and the epoch-based checkpoint are
    handled.

    Yields:
        `int`: Index of the current epoch.
    """
    first_epoch = self.state.epoch
    for current_epoch in range(first_epoch, self.hps.epochs):
        self.state.epoch = current_epoch
        self.monitor.log_epoch(current_epoch)
        self.state.is_end_of_epoch = False
        self.state.is_last_epoch = current_epoch == self.hps.epochs - 1

        self.callback.on_epoch_start()
        yield current_epoch
        self.callback.on_epoch_end()

        steps_scheduler_now = (
            not self._module._extended and self._scheduler is not None and self.hps.step_scheduler_per_epoch
        )
        if steps_scheduler_now:
            self.callback.on_before_scheduler_step(self._scheduler)
            self._scheduler.step()
            self.callback.on_after_scheduler_step(self._scheduler)

        checkpoint_due = (
            self._checkpointing_when_epoch_ends and (self.state.epoch + 1) % self.checkpoint_every == 0
        )
        if not checkpoint_due:
            continue

        self._save_checkpoint(
            self.state.epoch + 1,
            self.state.train_step,  # always 0 at this stage
            self.state.global_step,
            self.state.evaluations_done,
            # checkpointing at the end of the last epoch means training is finished
            finished=self.state.is_last_epoch,
        )
def _prepare(
    self,
    module: AcceleratorModule,
    model: nn.Module,
    teacher: Optional[nn.Module],
    train_dataloader: Optional[DataLoader],
    val_dataloader: Optional[dict[Any, DataLoader]],
    optimizer: Optional[Optimizer],
    scheduler: Optional[LRScheduler],
    batch_device_placement: bool = True,
) -> tuple[
    nn.Module,
    Optional[nn.Module],
    DataLoader,
    Optional[dict[Any, DataLoader]],
    Optimizer,
    Optional[LRScheduler],
]:
    """
    Call Accelerate's backend to prepare instances for distributed training. This will also load states for objects
    in case of resuming training.

    Args:
        module (`AcceleratorModule`):
            Module whose `model` attribute is replaced by the prepared model in safe mode or with FSDP.
        model (`nn.Module`):
            Model to prepare (optionally compiled and with gradient checkpointing enabled).
        teacher (`nn.Module`, *optional*):
            Teacher model, compiled and prepared alongside the model when given.
        train_dataloader (`DataLoader`, *optional*):
            Training dataloader to prepare.
        val_dataloader (`dict`, *optional*):
            Validation dataloaders; every entry is prepared in place.
        optimizer (`Optimizer`, *optional*):
            Optimizer to prepare.
        scheduler (`LRScheduler`, *optional*):
            Scheduler to prepare; registered for checkpointing when given.
        batch_device_placement (`bool`, *optional*, defaults to `True`):
            When `False`, the `device` attribute of the dataloaders is set to CPU.

    Returns:
        `tuple`: `(model, teacher, train_dataloader, val_dataloader, optimizer, scheduler)` after preparation.

    Raises:
        FileNotFoundError: When resuming and the checkpoint directory is missing or empty.
    """
    if not self.enable_prepare_logging and self.accelerator.distributed_type == DistributedType.DEEPSPEED:
        # imported lazily: deepspeed is only required when it is the active backend
        from deepspeed.utils import logger

        logger.setLevel(logging.WARNING)

    if self.gradient_checkpointing:
        if hasattr(model, "gradient_checkpointing_enable"):
            model.gradient_checkpointing_enable(self.gradient_checkpointing_kwargs)

    if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
        # DeepSpeed requires contiguous parameters
        for param in model.parameters():
            if not param.is_contiguous():
                param.data = param.data.contiguous()

    if self.compile and DEBUG_MODE < 2:
        model = torch.compile(model, **self.compile_kwargs)
        if teacher is not None:
            teacher = torch.compile(teacher, **self.compile_kwargs)

    if val_dataloader is not None:
        for k, dataloader in val_dataloader.items():
            val_dataloader[k] = self.accelerator.prepare_data_loader(dataloader)

    if self.accelerator.distributed_type == DistributedType.FSDP:
        # ignore model preparation since it was already done before (only in the case of FSDP)
        train_dataloader, optimizer, scheduler = self.accelerator.prepare(train_dataloader, optimizer, scheduler)
    else:
        model, train_dataloader, optimizer, scheduler = self.accelerator.prepare(
            model, train_dataloader, optimizer, scheduler
        )

    if self.accelerator.distributed_type != DistributedType.DEEPSPEED and teacher is not None:
        teacher = self.accelerator.prepare_model(teacher)

    if self.safe_mode or self.accelerator.distributed_type == DistributedType.FSDP:
        module.model = model
        if self.accelerator.distributed_type == DistributedType.MULTI_GPU:
            module.model = _DistributedDataParallel(module.model)

    if scheduler is not None:
        self.accelerator.register_for_checkpointing(scheduler)

    # load states if resuming
    if self.resume:
        self.callback.on_resume()
        if os.path.exists(self.checkpoint_path):
            checkpoint_path = self._get_current_checkpoint_path()
            # "checkpoint_0" is the placeholder returned when no checkpoint exists yet
            if checkpoint_path.endswith("checkpoint_0"):
                raise FileNotFoundError("Checkpoint directory is empty or not found.")
            self.accelerator.load_state(checkpoint_path)
        else:
            raise FileNotFoundError(f"'{self.checkpoint_path}' was not found.")

    cpu = torch.device("cpu")
    if not batch_device_placement and train_dataloader is not None:
        train_dataloader.device = cpu
        # validation dataloaders are optional, so guard before touching them
        if val_dataloader is not None:
            for val_loader in val_dataloader.values():
                val_loader.device = cpu

    return model, teacher, train_dataloader, val_dataloader, optimizer, scheduler
def _get_current_checkpoint_path(self, ignore_resume_idx: bool = False) -> str:
    """
    Get the checkpoint path based on the 'resume' argument or the latest checkpoint.

    If this returns a path ending with "checkpoint_0", it means that the checkpoint directory is empty or not found.

    Directory entries whose last "_"-separated segment is not an integer (stray files such as
    ".DS_Store") are ignored instead of raising `ValueError`.

    Args:
        ignore_resume_idx (`bool`, *optional*, defaults to `False`):
            When `True`, an integer `self.resume` is ignored and the latest checkpoint is used.

    Returns:
        `str`: `self.checkpoint_path` itself when multiple checkpoints are disabled, otherwise
        the path of the selected (or placeholder) checkpoint inside it.
    """
    checkpoint_path = self.checkpoint_path
    if not self.multiple_checkpoints:
        return checkpoint_path

    # map every valid entry to its checkpoint number, listing the directory only once
    numbered = {}
    if os.path.exists(checkpoint_path):
        for entry in os.listdir(checkpoint_path):
            try:
                numbered[entry] = int(entry.split("_")[-1])
            except ValueError:
                continue

    if not numbered:
        # to handle creation afterwards
        return os.path.join(checkpoint_path, "checkpoint_0")

    # `type(...) is int` on purpose: `resume=True` is a bool and must not be read as index 1
    if type(self.resume) is int and self.resume != -1 and not ignore_resume_idx:
        # load the checkpoint at the given index
        return os.path.join(checkpoint_path, f"checkpoint_{self.resume}")

    # find the latest checkpoint by getting the maximum checkpoint number
    latest_checkpoint = max(numbered, key=numbered.get)
    return os.path.join(checkpoint_path, latest_checkpoint)
def _get_optimizer(self, module: AcceleratorModule) -> Optimizer:
    """
    Return the optimizer provided by the module, or build the one declared in the
    hyper parameters when the module provides none.
    """
    module_optimizer = module.get_optimizer()
    if module_optimizer is not None:
        return module_optimizer

    optimizer_cls = self.hps.optimizer
    supports_fused = "fused" in inspect.signature(optimizer_cls).parameters

    # the kwargs dictionary held by the hyper parameters is updated in place
    optim_kwargs = self.hps.optim_kwargs
    optim_kwargs["fused"] = supports_fused and "cuda" in self.accelerator.device.type

    accepted_kwargs = filter_kwargs(optim_kwargs, optimizer_cls)
    return optimizer_cls(module.model.parameters(), **accepted_kwargs)
def _get_scheduler(
    self, module: AcceleratorModule, optimizer: Optimizer, num_training_steps: int, num_epochs: int
) -> Optional[LRScheduler]:
    """
    Get scheduler from either module or trainer.

    The scheduler returned by the module takes precedence. Otherwise, if a scheduler is
    declared in the hyper parameters, it is built with the scheduler kwargs completed with the
    step/epoch information, keeping only the arguments it accepts (`filter_kwargs`).

    Args:
        module (`AcceleratorModule`):
            Module asked for a scheduler first.
        optimizer (`Optimizer`):
            Optimizer the scheduler is attached to.
        num_training_steps (`int`):
            Number of training steps per epoch.
        num_epochs (`int`):
            Number of epochs.

    Returns:
        `LRScheduler` or `None` when neither the module nor the hyper parameters define one.

    Raises:
        ValueError: If a float `num_warmup_steps` or `warmup_ratio` is outside `[0, 1]`.
    """
    scheduler = module.get_scheduler(optimizer, num_training_steps, num_epochs)
    if self.hps.scheduler is not None and scheduler is None:
        # the kwargs dictionary held by the hyper parameters is updated in place
        schlr_kwargs = self.hps.scheduler_kwargs
        schlr_kwargs["last_epoch"] = -1
        schlr_kwargs["steps_per_epoch"] = num_training_steps
        total_steps = num_training_steps * num_epochs
        schlr_kwargs["num_training_steps"] = total_steps
        schlr_kwargs["epochs"] = num_epochs
        if "num_warmup_steps" in schlr_kwargs and isinstance(schlr_kwargs["num_warmup_steps"], float):
            # a float is interpreted as a ratio of the total number of steps
            if schlr_kwargs["num_warmup_steps"] < 0.0 or schlr_kwargs["num_warmup_steps"] > 1.0:
                raise ValueError(
                    "If 'num_warmup_steps' is a ratio (float value), it needs to be a value between 0 and 1."
                )
            schlr_kwargs["num_warmup_steps"] = round(total_steps * schlr_kwargs["num_warmup_steps"])
        elif "warmup_ratio" in schlr_kwargs:
            # validate both bounds, as for a float 'num_warmup_steps'
            if schlr_kwargs["warmup_ratio"] < 0.0 or schlr_kwargs["warmup_ratio"] > 1.0:
                raise ValueError(
                    "'warmup_ratio' value in scheduler configuration needs to be a value between 0 and 1."
                )
            schlr_kwargs["num_warmup_steps"] = round(total_steps * schlr_kwargs["warmup_ratio"])

        scheduler = self.hps.scheduler
        filtered_kwargs = filter_kwargs(schlr_kwargs, scheduler)

        scheduler = scheduler(optimizer, **filtered_kwargs)

    return scheduler
def _get_dataloaders(
    self,
    module: AcceleratorModule,
    train_dataset: Optional[Dataset] = None,
    val_dataset: Optional[Union[list[Dataset], dict[Any, Dataset]]] = None,
) -> tuple[DataLoader, Optional[dict[Any, DataLoader]]]:
    """
    Build the training and validation dataloaders.

    Dataloaders returned by the module (`get_train_dataloader` / `get_validation_dataloader`)
    take precedence over the given datasets. Validation dataloaders built from datasets are
    wrapped in a dictionary; a list of datasets is keyed by its indices as strings.

    Raises:
        ValueError: If `batch_size` is a sequence whose length is not 2.
        AssertionError: If neither a training dataset nor a training dataloader is available.
    """
    batch_size = self.hps.batch_size
    is_pair = hasattr(batch_size, "__len__")
    if is_pair and len(batch_size) != 2:
        raise ValueError(
            "'batch_size' in hyper parameters needs to be an integer value or a tuple with 2 values "
            "(one for training and the other for validation)."
        )

    if is_pair:
        train_batch_size, val_batch_size = batch_size[0], batch_size[1]
    else:
        train_batch_size = val_batch_size = batch_size

    shared_loader_args = dict(
        pin_memory=self.dataloader_pin_memory,
        num_workers=self.dataloader_num_workers,
        drop_last=self.dataloader_drop_last,
    )

    train_dataloader = module.get_train_dataloader(train_dataset)
    assert train_dataloader is not None or train_dataset is not None, (
        "Either 'train_dataset' or 'get_train_dataloader' must be given."
    )

    # 'train_dataset' is ignored if 'get_train_dataloader' was implemented in AcceleratorModule
    if train_dataloader is None and train_dataset is not None:
        # NOTE(review): the condition reads `self.sampler` while the DataLoader receives
        # `self.samplers` — confirm both attributes exist and this is intended
        shuffle_train = self.shuffle_train if self.sampler is None else None
        train_dataloader = DataLoader(
            train_dataset,
            shuffle=shuffle_train,
            sampler=self.samplers,
            batch_size=train_batch_size,
            collate_fn=self.collate_fn_train,
            **shared_loader_args,
        )

    val_dataloader = module.get_validation_dataloader(val_dataset)
    if val_dataloader is not None and not isinstance(val_dataloader, (list, dict)):
        # a single dataloader is wrapped into a list
        val_dataloader = [val_dataloader]

    # 'val_dataset' is ignored if 'get_validation_dataloader' was implemented in AcceleratorModule
    if val_dataloader is None and val_dataset is not None:
        if not isinstance(val_dataset, dict):
            val_dataset = {str(i): ds for i, ds in enumerate(val_dataset)}

        val_dataloader = {
            key: DataLoader(
                dataset,
                batch_size=val_batch_size,
                collate_fn=self.collate_fn_val,
                **shared_loader_args,
            )
            for key, dataset in val_dataset.items()
        }

    return train_dataloader, val_dataloader
def _get_module(
    self, module: Union[AcceleratorModule, str, Union[tuple[str, str], tuple[str, Any]]], **kwargs: Any
) -> AcceleratorModule:
    """
    Resolve `module` into an `AcceleratorModule`.

    A string or a tuple is forwarded to `AcceleratorModule.from_hf` (a tuple is unpacked as
    positional arguments); anything else is returned as given.
    """
    if isinstance(module, tuple):
        return AcceleratorModule.from_hf(*module, **kwargs)
    if isinstance(module, str):
        return AcceleratorModule.from_hf(module, **kwargs)

    return module
def _init_trackers(self) -> Optional[str]:
    """
    Initialize all trackers along with the training configuration from Hyper Parameters and
    'additional_tracker_config'.

    Besides initializing the trackers on the main process, this installs handlers for
    SIGTERM/SIGINT and replaces `sys.excepthook`, so that the tracker run is ended with an
    explicit status ("KILLED" or "FAILED") when the process stops abnormally.

    Returns:
        `str` or `None`: The tracker's run id on the main process.

    Raises:
        RuntimeError: On the main process, if `logging_dir` is a URL but logging was not declared.
    """
    self.accelerator.log_with = [self.tracker.logger_type]
    # default run name is the last component of the model path
    track_name = os.path.basename(self.model_path) if self.track_name is None else self.track_name
    init_kwargs = self.tracker.get_init_kwargs(**self.init_kwargs)

    config = self.hps.get_config()
    # per-process batch size(s) multiplied by the number of processes
    config["effective_batch_size"] = (
        tuple(batch_size * self.accelerator.num_processes for batch_size in self.hps.batch_size)
        if isinstance(self.hps.batch_size, (tuple, list))
        else self.hps.batch_size * self.accelerator.num_processes
    )
    if self.grad_accumulation_steps > 1:
        # gradient accumulation only scales the first (training) entry
        obj = config["effective_batch_size"]
        if isinstance(obj, tuple):
            config["effective_batch_size"] = (obj[0] * self.grad_accumulation_steps, obj[1])
        else:
            config["effective_batch_size"] = (obj * self.grad_accumulation_steps, obj)

    config["grad_accumulation_steps"] = self.grad_accumulation_steps
    config["gradient_checkpointing"] = self.gradient_checkpointing
    config["gradient_checkpointing_kwargs"] = self.gradient_checkpointing_kwargs
    config["clip_grad"] = self.clip_grad
    config["num_processes"] = self.accelerator.num_processes

    # on key collisions, values from 'additional_tracker_config' win
    tracker_config = config | self.additional_tracker_config

    # register signals to end process safely
    def end_process(signum, frame):
        # signal handler: close the run as "KILLED" and exit
        if self.tracker is not None:
            self.tracker.end(status="KILLED")
        exit(0)

    def end_on_exception(exc_type, exc_value, exc_traceback):
        # excepthook: KeyboardInterrupt goes to the default hook; any other uncaught exception
        # closes the run as "FAILED" before printing the traceback
        if issubclass(exc_type, KeyboardInterrupt):
            sys.__excepthook__(exc_type, exc_value, exc_traceback)
            return

        if self.tracker is not None:
            self.tracker.end(status="FAILED")
        traceback.print_exception(exc_type, exc_value, exc_traceback)

    signal.signal(signal.SIGTERM, end_process)
    signal.signal(signal.SIGINT, end_process)
    sys.excepthook = end_on_exception

    if MASTER_PROCESS:
        # TODO with a Tracker Wrapper this should be fixed.
        _is_url = is_url(self.logging_dir)
        if _is_url and not self._logging:
            raise RuntimeError(f"Cannot log results in '{self.logging_dir}' because 'log_with' was not declared.")

        self.accelerator.init_trackers(track_name, config=tracker_config, init_kwargs=init_kwargs)
        self.tracker.set_tracking_uri(self.logging_dir)

        return self.tracker.run_id
def _get_grad_norm(self, norm_type: float = 2.0) -> Union[torch.Tensor, float]:
    """
    Compute the global gradient norm of the model.

    With DeepSpeed the value comes from the wrapped engine (`get_global_grad_norm`). Otherwise
    the `norm_type`-norms of all existing parameter gradients of the unwrapped model are
    combined into a single norm.
    """
    if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
        return self.wrapped_model.get_global_grad_norm()

    powered_norms = (
        param.grad.detach().norm(norm_type) ** norm_type
        for param in self.unwrapped_model.parameters()
        if param.grad is not None
    )
    return sum(powered_norms) ** (1.0 / norm_type)
[docs]
def log_artifact(self, path: str):
    """
    Log a single file as an artifact of the current run.

    Nothing happens unless logging is enabled, `DEBUG_MODE` is below 1 and the tracker has
    been initialized.

    Args:
        path (`str`):
            Path to the file to be logged as an artifact.
    """
    can_log = self._logging and DEBUG_MODE < 1 and self.tracker_initialized
    if not can_log:
        return

    self.tracker.log_artifact(path)
[docs]
def log_artifacts(self, path: str):
    """
    Log the contents of a directory as artifacts of the current run.

    Nothing happens unless logging is enabled, `DEBUG_MODE` is below 1 and the tracker has
    been initialized.

    Args:
        path (`str`):
            Path to the directory to be logged as an artifact.
    """
    can_log = self._logging and DEBUG_MODE < 1 and self.tracker_initialized
    if not can_log:
        return

    self.tracker.log_artifacts(path)
def _prepare_metrics(
    self,
    metrics: Union[Metric, list[Metric], dict[Any, Union[Metric, list[Metric]]]],
    val_dataloader: Optional[dict[Any, DataLoader]],
) -> dict[Any, list[Metric]]:
    """
    Normalize `metrics` into a dictionary mapping every validation dataset key to a list of metrics.

    A single metric or a list of metrics is assigned to every validation dataset. A dictionary
    must only use keys present in `val_dataloader`; its values are wrapped into lists when
    needed. Any other input is returned unchanged.
    """
    if isinstance(metrics, Metric):
        return {dataset_key: [metrics] for dataset_key in val_dataloader.keys()}

    if isinstance(metrics, list):
        # the same list object is assigned to every dataset
        return {dataset_key: metrics for dataset_key in val_dataloader.keys()}

    if isinstance(metrics, dict):
        assert all(k in val_dataloader for k in metrics), (
            f"There is a mismatch between given metrics and validation datasets. Got {list(metrics.keys())} "
            f"for 'metrics' and {list(val_dataloader.keys())} for validation datasets."
        )
        return {
            dataset_key: (entry if isinstance(entry, list) else [entry]) for dataset_key, entry in metrics.items()
        }

    return metrics
[docs]
def register_model_saving(
    self,
    model_saving: str,
    saving_below: Optional[float] = None,
    saving_above: Optional[float] = None,
):
    """
    Register a type of model saving.

    Args:
        model_saving (`str`):
            Type of model saving. It can be `"best_valid_loss"`, `"best_train_loss"` or follow the
            format `"best_{METRIC}"`; the `"best_"` prefix is optional and added when missing.
            Datasets can be targeted with `@`, and several metrics combined with `/`, e.g.
            `"best_{METRIC}@{DATASET}"` (metric at a specific dataset),
            `"best_{METRIC}@{DATASET1}@{DATASET2}"` (metric at dataset1 and dataset2),
            `"best_{METRIC1}@{DATASET1}/{METRIC2}@{DATASET2}"` (best metric1 at dataset1 and best
            metric2 at dataset2) or `"best_{METRIC1}/{METRIC2}@{DATASET2}"` (best metric1 between all
            datasets containing this metric and best metric2 at dataset2 only).
        saving_below (`float`, *optional*, defaults to `None`):
            Only save when the value is lower than this. `None` is stored as positive infinity.
        saving_above (`float`, *optional*, defaults to `None`):
            Only save when the value is higher than this. `None` is stored as negative infinity.
    """
    upper_bound = float("inf") if saving_below is None else saving_below
    lower_bound = float("-inf") if saving_above is None else saving_above

    key = model_saving if model_saving.startswith("best_") else f"best_{model_saving}"
    self.model_saving[key] = (upper_bound, lower_bound)
def _get_comparator(self, metric: str) -> str:
    """
    Return the comparator of the first declared metric whose `main_metric` equals `metric`.

    Raises:
        RuntimeError: If no declared metric matches.
    """
    not_found = object()
    comparators = (
        candidate.comparator
        for dataset_metrics in self.metrics.values()
        for candidate in dataset_metrics
        if metric == candidate.main_metric
    )
    comparator = next(comparators, not_found)
    if comparator is not_found:
        raise RuntimeError(f"No comparator was found for metric '{metric}'.")

    return comparator