Source code for accmt.states

# Copyright 2025 ghanvert. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from dataclasses import asdict, dataclass, field
from typing import Any, Optional

import torch
from accelerate import Accelerator

from .utils.globals import MASTER_PROCESS


[docs] @dataclass class TrainingState: """ General training state. Args: batch_iteration (`int`): Batch iteration index. This is incremented every time a train step is done. This ignores gradient accumulation. If gradient accumulation is not used, this is the same as `global_step`. global_step (`int`): Global step index. This is incremented every time a train step and a gradient accumulation step is done. train_step (`int`): Training step index inside a training dataloader (can be considered as batch index). val_step (`int`): Validation step index inside an evaluation loop (can be considered as batch index). epoch (`int`): Epoch index. is_end_of_epoch (`bool`): Flag to check if current state is at the end of an epoch. is_last_training_batch (`bool`): Flag to check if current state is processing the last training batch. is_last_validation_batch (`bool`): Flag to check if current state is processing the last validation batch is_last_epoch (`bool`): Flag to check if current state is processing the last epoch. evaluations_done (`int`): Number of evaluations done. additional_metrics (`dict`): Additional metrics (e.g. accuracy, bleu, f1, etc). model_savings (`dict`): Bests model saving values. patience_left (`dict`): Patience left per model saving (in case it's implemented, otherwise values are set to -1). best_train_loss (`float`): Best training loss achieved. best_valid_loss (`float`): Best validation loss achieved. finished (`bool`, *optional*, defaults to `False`): Flag to identify if the process has already finished. num_checkpoints_made (`int`, *optional*, defaults to `0`): Number of checkpoints made. train_dataloader_idx (`int`, *optional*, defaults to `None`): Index of the train dataloader being used. """ batch_iteration: int = field(default=0) global_step: int = field(default=0) train_step: int = field(default=0) val_step: int = field(default=0) epoch: int = field(default=0) is_end_of_epoch: bool = field(default=False) is_last_training_batch: bool = field(default=False) is_last_validation_batch: bool = field(default=False) is_last_epoch: bool = field(default=False) evaluations_done: int = field(default=0) additional_metrics: dict[str, dict[str, Any]] = field(default_factory=dict) patience_left: dict[str, int] = field(default_factory=dict) best_train_loss: float = field(default=float("inf")) finished: bool = field(default=False) num_checkpoints_made: int = field(default=0) train_dataloader_idx: int = field(default=0) def update( self, *, batch_iteration: Optional[int] = None, global_step: Optional[int] = None, train_step: Optional[int] = None, val_step: Optional[int] = None, epoch: Optional[int] = None, is_end_of_epoch: Optional[bool] = None, is_last_training_batch: Optional[bool] = None, is_last_validation_batch: Optional[bool] = None, is_last_epoch: Optional[bool] = None, evaluations_done: Optional[int] = None, additional_metrics: Optional[dict[str, dict[str, Any]]] = None, patience_left: Optional[dict[str, int]] = None, best_train_loss: Optional[float] = None, finished: Optional[bool] = None, num_checkpoints_made: Optional[int] = None, train_dataloader_idx: Optional[int] = None, **kwargs: Any, ): # ignore positional arguments for safety updates = { "batch_iteration": batch_iteration, "global_step": global_step, "train_step": train_step, "val_step": val_step, "epoch": epoch, "is_end_of_epoch": is_end_of_epoch, "is_last_epoch": is_last_epoch, "is_last_training_batch": is_last_training_batch, "is_last_validation_batch": is_last_validation_batch, "evaluations_done": evaluations_done, "additional_metrics": additional_metrics, "patience_left": patience_left, "best_train_loss": best_train_loss, "finished": finished, "num_checkpoints_made": num_checkpoints_made, "train_dataloader_idx": train_dataloader_idx, } for key, value in updates.items(): if value is not None: setattr(self, key, value) def to_dict(self) -> dict[str, Any]: return asdict(self) def from_dict(self, _dict: dict[str, Any]): self.update(**_dict) def save(self, to: str, _dict: Optional[dict[str, Any]] = None): state = self.to_dict() if _dict is None else _dict json.dump(state, open(to, "w"), indent=4, ensure_ascii=False) def load(self, path: str): _dict = json.load(open(path)) self.from_dict(_dict)
class LossState: @torch.inference_mode() def __init__( self, accelerator: Accelerator, device: torch.device, log_every: int = -1, pin_memory: bool = True, include_per_batch: bool = True, ): self.accelerator = accelerator if device == torch.device("cpu") and pin_memory: pin_memory = False if log_every < 0: int_dtype = torch.int64 # value of -1 or lower means maximum dtype elif log_every < 2147483648: int_dtype = torch.int32 else: int_dtype = torch.int64 kwargs = {"device": device, "pin_memory": pin_memory} # Keep all tensors in the same device to avoid GPU -> CPU transfer. # Also, all tensors must be same data type to avoid internal convertions. self.num_batches = torch.tensor(0, dtype=int_dtype, **kwargs) if include_per_batch else None self.batch_loss = torch.tensor(0, dtype=torch.float32, **kwargs) if include_per_batch else None self.num_steps = torch.tensor(0, dtype=int_dtype, **kwargs) self.total_loss = torch.tensor(0, dtype=torch.float32, **kwargs) # create this extra tensor to avoid creating additional tensors to compute additions (~1.07x speedup) self._incrementor = torch.tensor(1, dtype=int_dtype, **kwargs) # This produces extra memory consumption in all devices, so per device # there are going to be (at maximum) 28 bytes in total allocated. self._include_per_batch = include_per_batch @torch.inference_mode() def get_batch_loss(self) -> float: if not self._include_per_batch: self.accelerator.end_training() raise RuntimeError("Batch loss calculation is not implemented. Use 'include_per_batch'.") batch_loss = self.batch_loss / self.num_batches batch_loss = self.accelerator.reduce(batch_loss, reduction="mean") # GPU intercommunication self.num_batches.zero_() # reset batches count self.batch_loss.zero_() # reset batch loss tracker return batch_loss.item() # CPU transfer @torch.inference_mode() def add_batch_loss(self, value: torch.Tensor): if not self._include_per_batch: self.accelerator.end_training() raise RuntimeError("Batch loss calculation is not implemented. Use 'include_per_batch'.") self.batch_loss.add_(value) self.num_batches.add_(self._incrementor) @torch.inference_mode() def get_total_loss(self) -> float: total_loss = self.total_loss / self.num_steps total_loss = self.accelerator.reduce(total_loss, reduction="mean") # GPU intercommunication self.total_loss.zero_() self.num_steps.zero_() return total_loss.item() # CPU transter @torch.inference_mode def add_total_loss(self, value: torch.Tensor): self.total_loss.add_(value) self.num_steps.add_(self._incrementor) @torch.inference_mode() def reset(self): if self._include_per_batch: self.num_batches.zero_() self.batch_loss.zero_() self.num_steps.zero_() self.total_loss.zero_() def save(self, path: str): _dict = { "steps": self.accelerator.gather(self.num_steps).cpu(), "total_loss": self.accelerator.gather(self.total_loss).cpu(), } if self._include_per_batch: _dict.update( { "batches": self.accelerator.gather(self.num_batches).cpu(), "batch_loss": self.accelerator.gather(self.batch_loss).cpu(), } ) if MASTER_PROCESS: torch.save(_dict, path) @torch.inference_mode() def load(self, path: str): _dict = torch.load(path) with self.accelerator.split_between_processes(_dict) as inputs: if self._include_per_batch: self.num_batches.fill_(inputs["batches"].squeeze_()) self.batch_loss.fill_(inputs["batch_loss"].squeeze_()) self.num_steps.fill_(inputs["steps"].squeeze_()) self.total_loss.fill_(inputs["total_loss"].squeeze_())