Source code for accmt.collate_fns

# Copyright 2025 ghanvert. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Optional, Union

import numpy as np
import torch
import transformers
from torch.utils.data._utils.collate import default_collate
from transformers.data.data_collator import pad_without_fast_tokenizer_warning
from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTrainedTokenizerBase


class BaseDataCollator(ABC):
    @abstractmethod
    def collate_tokenizer_inputs(self, batch: list[Union[BatchEncoding, dict]]) -> Union[tuple, Any]:
        pass

    def __call__(self, batch: list) -> Union[tuple, Any]:
        first_elem = batch[0]
        if isinstance(first_elem, (dict, BatchEncoding)):
            if "input_ids" in first_elem:
                return self.collate_tokenizer_inputs(batch)
            else:
                collated_dict = {}
                for key in first_elem.keys():
                    values = [elem[key] for elem in batch]
                    collated_dict[key] = self.__call__(values)
                return collated_dict
        elif isinstance(first_elem, (list, tuple)):
            return type(first_elem)(self.__call__(list(elem)) for elem in zip(*batch))

        return default_collate(batch)


[docs] @dataclass class DataCollatorForSeq2Seq(BaseDataCollator): """ `DataCollatorForSeq2Seq` from `transformers` library with a recursive approach to handle diverse structures. For documentation, refer to `DataCollatorForSeq2Seq` class (https://huggingface.co/docs/transformers/main_classes/data_collator#transformers.DataCollatorForSeq2Seq). """ tokenizer: PreTrainedTokenizerBase model: Optional[Any] = None padding: Union[bool, str, PaddingStrategy] = True max_length: Optional[int] = None pad_to_multiple_of: Optional[int] = None label_pad_token_id: int = -100 def collate_tokenizer_inputs(self, features: list[Union[BatchEncoding, dict]]): label_name = "label" if "label" in features[0].keys() else "labels" labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None # reconvert list[None] to None if necessary # this might occur when we pass {..., "labels": None} if labels is not None and all(label is None for label in labels): labels = None non_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features] # run through tokenizer without labels to ensure no side effects batch = pad_without_fast_tokenizer_warning( self.tokenizer, non_labels_features, padding=self.padding, max_length=self.max_length, pad_to_multiple_of=self.pad_to_multiple_of, return_tensors="pt", ) # we have to pad the labels manually as we cannot rely on `tokenizer.pad` and we need them to be of the same length to return tensors no_padding = self.padding is False or self.padding == PaddingStrategy.DO_NOT_PAD if labels is not None: if no_padding: if isinstance(features[0][label_name], list): batch["labels"] = list(labels) else: batch["labels"] = [np.concatenate([label, []]) for label in labels] else: max_padding = self.padding == PaddingStrategy.MAX_LENGTH and self.max_length is not None max_label_length = max(len(l) for l in labels) if not max_padding else self.max_length if self.pad_to_multiple_of is not None: max_label_length = ( (max_label_length + self.pad_to_multiple_of - 1) // self.pad_to_multiple_of * self.pad_to_multiple_of ) padding_side = self.tokenizer.padding_side if isinstance(features[0][label_name], list): batch["labels"] = [ label + [self.label_pad_token_id] * (max_label_length - len(label)) if padding_side == "right" else [self.label_pad_token_id] * (max_label_length - len(label)) + label for label in labels ] else: batch["labels"] = [ np.concatenate( [ label, np.array([self.label_pad_token_id] * (max_label_length - len(label)), dtype=np.int64), ] ) if padding_side == "right" else np.concatenate( [ np.array([self.label_pad_token_id] * (max_label_length - len(label)), dtype=np.int64), label, ] ) for label in labels ] # reintroduce side effects via tokenizer that return respective datatypes for the `return_tensors` argument if batch.get("labels", None) is not None: batch["labels"] = torch.from_numpy(np.array(batch["labels"], dtype=np.int64)) else: batch["labels"] = None # prepare decoder_input_ids if ( labels is not None and self.model is not None and hasattr(self.model, "prepare_decoder_input_ids_from_labels") ): decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=batch["labels"]) batch["decoder_input_ids"] = decoder_input_ids return batch
@dataclass class DataCollatorWithPadding(BaseDataCollator, transformers.DataCollatorWithPadding): """ `DataCollatorWithPadding` from `transformers` library with a recursive approach to handle diverse structures. For documentation, refer to `DataCollatorWithPadding` class (https://huggingface.co/docs/transformers/main_classes/data_collator#transformers.DataCollatorWithPadding). """ tokenizer: PreTrainedTokenizerBase padding: Union[bool, str, PaddingStrategy] = True max_length: Optional[int] = None pad_to_multiple_of: Optional[int] = None return_tensors: str = "pt" def collate_tokenizer_inputs(self, features: list[Union[BatchEncoding, dict]]): return transformers.DataCollatorWithPadding.__call__(self, features) @dataclass class DataCollatorForTokenClassification(BaseDataCollator, transformers.DataCollatorForTokenClassification): """ `DataCollatorForTokenClassification` from `transformers` library with a recursive approach to handle diverse structures. For documentation, refer to `DataCollatorForTokenClassification` class (https://huggingface.co/docs/transformers/main_classes/data_collator#transformers.DataCollatorForTokenClassification). """ tokenizer: PreTrainedTokenizerBase padding: Union[bool, str, PaddingStrategy] = True max_length: Optional[int] = None pad_to_multiple_of: Optional[int] = None label_pad_token_id: int = -100 return_tensors: str = "pt" def collate_tokenizer_inputs(self, features: list[Union[BatchEncoding, dict]]): return transformers.DataCollatorForTokenClassification.__call__(self, features)
[docs] @dataclass class DataCollatorForLanguageModeling(BaseDataCollator, transformers.DataCollatorForLanguageModeling): """ `DataCollatorForLanguageModeling` from `transformers` library with a recursive approach to handle diverse structures. For documentation, refer to `DataCollatorForLanguageModeling` class (https://huggingface.co/docs/transformers/main_classes/data_collator#transformers.DataCollatorForLanguageModeling). """ tokenizer: PreTrainedTokenizerBase mlm: bool = True mlm_probability: float = 0.15 pad_to_multiple_of: Optional[int] = None return_tensors: str = "pt" def collate_tokenizer_inputs(self, features: list[Union[BatchEncoding, dict]]): return transformers.DataCollatorForLanguageModeling.__call__(self, features)
@dataclass class DataCollatorForWholeWordMask(BaseDataCollator, transformers.DataCollatorForWholeWordMask): """ `DataCollatorForWholeWordMask` from `transformers` library with a recursive approach to handle diverse structures. For documentation, refer to `DataCollatorForWholeWordMask` class (https://huggingface.co/docs/transformers/main_classes/data_collator#transformers.DataCollatorForWholeWordMask). """ tokenizer: PreTrainedTokenizerBase mlm: bool = True mlm_probability: float = 0.15 pad_to_multiple_of: Optional[int] = None return_tensors: str = "pt" def collate_tokenizer_inputs(self, features: list[Union[BatchEncoding, dict]]): return transformers.DataCollatorForWholeWordMask.__call__(self, features) class DataCollatorForPermutationLanguageModeling( BaseDataCollator, transformers.DataCollatorForPermutationLanguageModeling ): """ `DataCollatorForPermutationLanguageModeling` from `transformers` library with a recursive approach to handle diverse structures. For documentation, refer to `DataCollatorForPermutationLanguageModeling` class (https://huggingface.co/docs/transformers/main_classes/data_collator#transformers.DataCollatorForPermutationLanguageModeling). """ tokenizer: PreTrainedTokenizerBase plm_probability: float = 1 / 6 max_span_length: int = 5 # maximum length of a span of masked tokens return_tensors: str = "pt" def collate_tokenizer_inputs(self, features: list[Union[BatchEncoding, dict]]): return transformers.DataCollatorForPermutationLanguageModeling.__call__(self, features)