# Copyright 2025 ghanvert. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Union
import numpy as np
import torch
from torch.utils.data._utils.collate import default_collate
from transformers.tokenization_utils_base import BatchEncoding
def collate_tokenizer_inputs(batch: list, pad_token_id: int, label_pad_token_id: int, padding_side: str):
include_attention_mask = "attention_mask" in batch[0]
include_labels = "labels" in batch[0]
inputs = []
labels = []
for feature in batch:
inputs.append(len(feature["input_ids"]))
if include_labels:
labels.append(len(feature["labels"]))
max_input_length = max(inputs)
if include_labels:
max_label_length = max(labels)
inputs = []
attention_masks = []
labels = []
for feature in batch:
inputs_remainder = [pad_token_id] * (max_input_length - len(feature["input_ids"]))
if include_attention_mask:
attention_masks_remainder = [0] * (max_input_length - len(feature["input_ids"]))
if include_labels:
labels_remainder = [label_pad_token_id] * (max_label_length - len(feature["labels"]))
if include_labels and isinstance(feature["labels"], list):
feature = {
"input_ids": feature["input_ids"] + inputs_remainder,
"attention_mask": (
feature["attention_mask"] + attention_masks_remainder if include_attention_mask else None
),
"labels": (feature["labels"] + labels_remainder),
}
elif padding_side == "right":
feature = {
"input_ids": np.concatenate([feature["input_ids"], inputs_remainder]).astype(np.int64),
"attention_mask": (
np.concatenate([feature["attention_mask"], attention_masks_remainder]).astype(np.int64)
if include_attention_mask
else None
),
"labels": (
np.concatenate([feature["labels"], labels_remainder]).astype(np.int64) if include_labels else None
),
}
else:
feature = {
"input_ids": np.concatenate([inputs_remainder, feature["input_ids"]]).astype(np.int64),
"attention_mask": (
np.concatenate([attention_masks_remainder, feature["attention_mask"]]).astype(np.int64)
if include_attention_mask
else None
),
"labels": (
np.concatenate([labels_remainder, feature["labels"]]).astype(np.int64) if include_labels else None
),
}
inputs.append(feature["input_ids"])
if include_attention_mask:
attention_masks.append(feature["attention_mask"])
if include_labels:
labels.append(feature["labels"])
output_dict = {"input_ids": torch.from_numpy(np.stack(inputs))}
if include_attention_mask:
output_dict["attention_mask"] = torch.from_numpy(np.stack(attention_masks))
if include_labels:
output_dict["labels"] = torch.from_numpy(np.stack(labels))
return output_dict
# function derived from 'transformers' library: https://github.com/huggingface/transformers/blob/main/src/transformers/data/data_collator.py#L52
def pad_without_fast_tokenizer_warning(tokenizer, *pad_args, **pad_kwargs):
"""
Pads without triggering the warning about how using the pad function is sub-optimal when using a fast tokenizer.
"""
# To avoid errors when using Feature extractors
if not hasattr(tokenizer, "deprecation_warnings"):
return tokenizer.pad(*pad_args, **pad_kwargs)
# Save the state of the warning, then disable it
warning_state = tokenizer.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False)
tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
try:
padded = tokenizer.pad(*pad_args, **pad_kwargs)
finally:
# Restore the state of the warning.
tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = warning_state
return padded
def stack_tensor_dict(tensor_dicts: list[dict[torch.Tensor]]):
keys = tensor_dicts[0].keys()
return {key: torch.stack([d[key] for d in tensor_dicts]) for key in keys}
def stack_iterables(iterables: list[list | tuple]):
return torch.stack([torch.tensor(iterable) for iterable in iterables])
[docs]
class DataCollatorForSeq2Seq:
"""
Automatically adds efficient padding for 'inputs', 'attention_mask' and 'labels'.
This works for multiple inputs from the dataset logic. If any of the objects does not
correspond to a dictionary-like structure of a decoded tokenizer's output, it will
apply the default collate function derived from PyTorch.
The output of a dictionary-like with key 'input_ids' will have the following keys:
- `input_ids`
- `attention_mask` (if found)
- `labels` (if found)
This implementation derives from `transformers` library:
https://github.com/huggingface/transformers/blob/main/src/transformers/data/data_collator.py#L543
Args:
tokenizer (`Any`):
Tokenizer using HuggingFace standard.
label_pad_token_id (`int`, *optional*, defaults to `-100`):
Label pad token id. Labels with this value will be ignored in the training process.
"""
[docs]
def __init__(self, tokenizer: Any, label_pad_token_id: int = -100):
self.tokenizer = tokenizer
self.pad_token_id = self.tokenizer.pad_token_id
self.label_pad_token_id = label_pad_token_id
self.padding_side = self.tokenizer.padding_side
def _collate_nested_dicts(self, data):
elem = data[0]
if isinstance(elem, (dict, BatchEncoding)) and "input_ids" in elem:
return collate_tokenizer_inputs(
data,
pad_token_id=self.pad_token_id,
label_pad_token_id=self.label_pad_token_id,
padding_side=self.padding_side,
)
elif isinstance(elem, (dict, BatchEncoding)) and "input_ids" not in elem:
return {key: self._collate_nested_dicts([d[key] for d in data]) for key in elem}
return default_collate(data)
def _collate_multiple(self, batch: list, length_elems: int) -> Union[tuple, Any]:
stacked_elems = [[] for _ in range(length_elems)]
for elem in batch:
for elem_index in range(length_elems):
stacked_elems[elem_index].append(elem[elem_index])
del batch
for elem_index in range(length_elems):
elem = stacked_elems[elem_index][0]
if isinstance(elem, (dict, BatchEncoding)) and "input_ids" in elem:
stacked_elems[elem_index] = collate_tokenizer_inputs(
stacked_elems[elem_index],
pad_token_id=self.pad_token_id,
label_pad_token_id=self.label_pad_token_id,
padding_side=self.padding_side,
)
elif isinstance(elem, dict):
stacked_elems[elem_index] = self._collate_nested_dicts(stacked_elems[elem_index])
else:
stacked_elems[elem_index] = default_collate(stacked_elems[elem_index])
return tuple(stacked_elems) if length_elems > 1 else stacked_elems[0]
def __call__(self, batch: list) -> Union[tuple, Any]:
length_elems = len(batch[0]) if isinstance(batch[0], tuple) else 1
if length_elems == 1:
if isinstance(batch[0], BatchEncoding) or (isinstance(batch[0], dict) and "input_ids" in batch[0]):
return collate_tokenizer_inputs(
batch,
pad_token_id=self.pad_token_id,
label_pad_token_id=self.label_pad_token_id,
padding_side=self.padding_side,
)
else:
return default_collate(batch)
return self._collate_multiple(batch, length_elems)
[docs]
class DataCollatorForLongestSequence:
"""
Automatically adds efficient padding for inputs, while preserving static labels.
If output of `__getitem__` Dataset logic looks like:
`return x, y` (x being a dictionary containing keys `input_ids` and `attention_mask`)
then the output of the collator function will be `(x, y)`, `x` being the padded inputs with
the same keys and `y` the stacked labels.
If output of `__getitem__` Dataset logic looks like:
`return x` (x being a dictionary containing keys `input_ids` and `attention_mask`)
then the output of the collator function will be `x`, being the padded inputs with the same keys.
NOTE: This collator should be used when labels on your dataset logic are not sequences. If that's the case,
see `DataCollatorForSeq2Seq`.
Args:
tokenizer (`Any`):
Tokenizer using HuggingFace standard.
"""
[docs]
def __init__(self, tokenizer: Any, torch_stack: bool = True):
self.tokenizer = tokenizer
self.pad_token_id = self.tokenizer.pad_token_id
self.padding_side = self.tokenizer.padding_side
self.torch_stack = torch_stack
self.device = None
def __call__(self, batch: list):
inputs = []
for feature in batch:
# if feature is a tuple, then it would be of type (inputs, targets)
if isinstance(feature, tuple):
feature = feature[0] # just take first element
inputs.append(len(feature["input_ids"]))
max_input_length = max(inputs)
inputs = []
attention_masks = []
labels = []
for feature in batch:
if isinstance(feature, tuple):
labels.append(feature[1])
feature = feature[0]
inputs_remainder = [self.pad_token_id] * (max_input_length - len(feature["input_ids"]))
attention_masks_remainder = [0] * (max_input_length - len(feature["input_ids"]))
if self.padding_side == "right":
feature = {
"input_ids": np.concatenate([feature["input_ids"], inputs_remainder]).astype(np.int64),
"attention_mask": np.concatenate([feature["attention_mask"], attention_masks_remainder]).astype(
np.int64
),
}
else:
feature = {
"input_ids": np.concatenate([inputs_remainder, feature["input_ids"]]).astype(np.int64),
"attention_mask": np.concatenate([attention_masks_remainder, feature["attention_mask"]]).astype(
np.int64
),
}
inputs.append(feature["input_ids"])
attention_masks.append(feature["attention_mask"])
output = {
"input_ids": torch.from_numpy(np.stack(inputs)),
"attention_mask": torch.from_numpy(np.stack(attention_masks)),
}
if len(labels) > 0:
if isinstance(labels[0], dict):
keys = labels[0].keys()
if self.torch_stack:
out_labels = {k: torch.stack([label[k] for label in labels]) for k in keys}
else:
out_labels = {k: [label[k] for label in labels] for k in keys}
elif isinstance(labels[0], torch.Tensor):
out_labels = labels
if self.torch_stack:
out_labels = torch.stack(out_labels)
elif isinstance(labels[0], (list, tuple, np.ndarray)):
out_labels = [torch.tensor(label, device=self.device) for label in labels]
if self.torch_stack:
out_labels = torch.stack(out_labels)
else:
out_labels = None
return output, out_labels
return output
[docs]
class DataCollatorForLanguageModeling:
"""
Collator function to implement automatic language modeling, such as
Masked Language Modeling.
Args:
tokenizer (`Any`):
Tokenizer using HuggingFace standard.
mlm (`bool`, *optional*, defaults to `True`):
Implements Masked Language Modeling.
mlm_probability (`float`, *optional*, defaults to `0.15`):
How much masking is implemented in Masked Language Modeling.
ignore_index (`int`, *optional*, defaults to `-100`):
Label pad token id. Labels with this value will be ignored in the training process.
masked_to_mask (`float`, *optional*, defaults to `0.8`):
Probability to replace masked input tokens with mask token. The half remaining percent will
replace masked input tokens with random word, and the other half will keep the masked input tokens
unchanged. If `apply_random_words` is set to `False`, then the entire remaining percent will be unchanged.
apply_random_words (`bool`, *optional*, defaults to `True`):
Whether to apply random words during Masked Language Modeling.
force_one_output (`bool`, *optional*, defaults to `False`):
Whether to force output one output. If Dataset object `__getitem__` function returns a tuple, only the first
element will be considered and extra targets will be dropped.
"""
[docs]
def __init__(
self,
tokenizer: Any,
mlm: bool = True,
mlm_probability: float = 0.15,
ignore_index: int = -100,
masked_to_mask: float = 0.8,
apply_random_words: bool = True,
force_one_output: bool = False,
) -> Union[dict, tuple[dict, torch.Tensor]]:
self.tokenizer = tokenizer
self.mlm = mlm
self.mlm_probability = mlm_probability
self.ignore_index = ignore_index
self.masked_to_mask = masked_to_mask
self.apply_random_words = apply_random_words
self.force_one_output = force_one_output
def __call__(self, batch: list) -> dict:
has_extra_targets = isinstance(batch[0], (tuple, list))
if not has_extra_targets:
tokenizer_dict = pad_without_fast_tokenizer_warning(self.tokenizer, batch, return_tensors="pt")
else:
tokenizer_dict_batch, extra_targets = [], []
for elems in batch:
tokenizer_dict_batch.append(elems[0])
if not self.force_one_output:
extra_targets.append(elems[1:])
tokenizer_dict = pad_without_fast_tokenizer_warning(
self.tokenizer, tokenizer_dict_batch, return_tensors="pt"
)
special_tokens_mask = tokenizer_dict.pop("special_tokens_mask", None)
if self.mlm:
tokenizer_dict["input_ids"], tokenizer_dict["labels"] = self.torch_mask_tokens(
tokenizer_dict["input_ids"], special_tokens_mask=special_tokens_mask
)
else:
labels = tokenizer_dict["input_ids"].clone()
if self.tokenizer.pad_token_id is not None:
labels[labels == self.tokenizer.pad_token_id] = self.ignore_index
tokenizer_dict["labels"] = labels
if has_extra_targets and not self.force_one_output:
num_elems = len(extra_targets[0])
extra_targets_return = [[] for _ in range(num_elems)]
for target in extra_targets:
for idx in range(num_elems):
tgt = target[idx]
extra_targets_return[idx].append(tgt)
stack_funcs = []
for idx, extra_target_return in enumerate(extra_targets_return):
first_elem = extra_target_return[0]
if isinstance(first_elem, torch.Tensor):
stack_funcs.append(torch.stack)
elif isinstance(first_elem, dict):
stack_funcs.append(stack_tensor_dict)
elif isinstance(first_elem, (tuple, list)):
stack_funcs.append(stack_iterables)
else:
stack_funcs.append(torch.tensor)
extra_targets_return = [
stack_funcs[idx](extra_target_return) for idx, extra_target_return in enumerate(extra_targets_return)
]
if has_extra_targets and not self.force_one_output:
return tokenizer_dict, *extra_targets_return
return tokenizer_dict
def torch_mask_tokens(self, inputs: torch.Tensor, special_tokens_mask):
labels = inputs.clone()
probability_matrix = torch.full(labels.shape, self.mlm_probability)
if special_tokens_mask is None:
special_tokens_mask = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
]
special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
else:
special_tokens_mask = special_tokens_mask.bool()
probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
masked_indices = torch.bernoulli(probability_matrix).bool()
labels[~masked_indices] = self.ignore_index
indices_replaced = torch.bernoulli(torch.full(labels.shape, self.masked_to_mask)).bool() & masked_indices
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
if self.apply_random_words:
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
inputs[indices_random] = random_words[indices_random]
return inputs, labels