# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Dict, Sequence

import torch
from transformers import DataCollatorForSeq2Seq


@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
    r"""
    Data collator for pairwise (chosen / rejected) data.
    """

    @staticmethod
    def _select_side(feature: Dict[str, Any], side: str) -> Dict[str, Any]:
        r"""
        Extracts one side ("chosen" or "rejected") of a pairwise feature as a flat example.

        Reads `<side>_input_ids`, `<side>_attention_mask` and `<side>_labels` (in that order)
        into unprefixed keys. `pixel_values` is copied unchanged when present, so both sides
        share the same value. `<side>_token_type_ids` is copied to `token_type_ids` only when
        the feature has it.
        """
        example = {name: feature[f"{side}_{name}"] for name in ("input_ids", "attention_mask", "labels")}

        if "pixel_values" in feature:
            example["pixel_values"] = feature["pixel_values"]

        type_ids_key = f"{side}_token_type_ids"
        if type_ids_key in feature:
            example["token_type_ids"] = feature[type_ids_key]

        return example

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        r"""
        Flattens pairwise features into 2 * n examples and pads them as one batch.

        The first n examples are the chosen sides of `features` and the last n are the
        rejected sides, each half in the original feature order. Padding itself is delegated
        to `DataCollatorForSeq2Seq.__call__` and follows its configured padding settings.
        """
        flat_examples = []
        for side in ("chosen", "rejected"):
            # Keep all chosen examples ahead of all rejected ones so that batch row i and
            # row i + n refer to the same original pair.
            flat_examples.extend(self._select_side(feature, side) for feature in features)

        return super().__call__(flat_examples)


@dataclass
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
    r"""
    Data collator for KTO data.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads the target examples and their KL counterparts, then merges them into one batch.

        The target examples and the KL examples are padded in two separate calls to
        `DataCollatorForSeq2Seq.__call__`. The KL tensors are then stored in the target batch
        under `kl_`-prefixed keys. `kto_tags` is collected per feature and added as a single
        tensor.
        """
        core_fields = ("input_ids", "attention_mask", "labels")

        target_examples = []
        kl_examples = []
        tags = []
        for feature in features:
            target_example = {name: feature[name] for name in core_fields}
            kl_example = {name: feature[f"kl_{name}"] for name in core_fields}

            # Only the target example carries pixel_values; the KL example gets none.
            # NOTE(review): presumably consumers reuse the target batch's pixel_values for the
            # KL pass -- confirm against the trainer.
            if "pixel_values" in feature:
                target_example["pixel_values"] = feature["pixel_values"]

            # token_type_ids are assumed to come in target/KL pairs: when the target one is
            # present, the KL one is read unconditionally.
            if "token_type_ids" in feature:
                target_example["token_type_ids"] = feature["token_type_ids"]
                kl_example["token_type_ids"] = feature["kl_token_type_ids"]

            target_examples.append(target_example)
            kl_examples.append(kl_example)
            tags.append(feature["kto_tags"])

        batch = super().__call__(target_examples)
        kl_batch = super().__call__(kl_examples)

        for name in core_fields:
            batch[f"kl_{name}"] = kl_batch[name]

        if "token_type_ids" in batch:
            batch["kl_token_type_ids"] = kl_batch["token_type_ids"]

        batch["kto_tags"] = torch.tensor(tags)
        return batch