# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple

from .processors.feedback import preprocess_feedback_dataset
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
from .processors.pretrain import preprocess_pretrain_dataset
from .processors.supervised import (
    preprocess_packed_supervised_dataset,
    preprocess_supervised_dataset,
    print_supervised_dataset_example,
)
from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments

    from ..hparams import DataArguments
    from .template import Template


def get_preprocess_and_print_func(
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
) -> Tuple[Callable, Callable]:
    r"""
    Pick the dataset preprocessing function and the matching example printer for a stage.

    Both returned callables are ``functools.partial`` objects with their shared
    arguments already bound.

    Stage routing:
      - ``"pt"``: ``preprocess_pretrain_dataset`` / ``print_unsupervised_dataset_example``.
      - ``"sft"`` when ``training_args.predict_with_generate`` is falsy:
        ``preprocess_packed_supervised_dataset`` if ``data_args.packing`` is truthy,
        otherwise ``preprocess_supervised_dataset``; printed with
        ``print_supervised_dataset_example``.
      - ``"rm"``: ``preprocess_pairwise_dataset`` / ``print_pairwise_dataset_example``.
      - ``"kto"``: ``preprocess_feedback_dataset`` / ``print_supervised_dataset_example``.
      - any other value (including ``"ppo"``, and ``"sft"`` with ``predict_with_generate``):
        ``preprocess_unsupervised_dataset`` / ``print_unsupervised_dataset_example``.

    Args:
        data_args: Bound into every preprocessing function; ``packing`` is read for SFT.
        training_args: Only ``predict_with_generate`` is read, and only when ``stage == "sft"``.
        stage: Training stage name that selects the routing above.
        template: Bound into every preprocessing function except the pre-training one.
        tokenizer: Bound into both the preprocessing function and the printer.
        processor: Bound into every preprocessing function except the pre-training and
            packed-supervised ones.

    Returns:
        A ``(preprocess_func, print_function)`` tuple.
    """
    # Pre-training binds neither a template nor a processor.
    if stage == "pt":
        return (
            partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args),
            partial(print_unsupervised_dataset_example, tokenizer=tokenizer),
        )

    common_kwargs = {"template": template, "tokenizer": tokenizer, "data_args": data_args}

    # `training_args` is only consulted for SFT; SFT with generation-based prediction
    # is routed to the unsupervised fallback further below.
    if stage == "sft" and not training_args.predict_with_generate:
        if data_args.packing:
            # NOTE(review): `processor` is not bound for the packed variant; presumably its
            # signature does not accept one — confirm before passing it.
            sft_preprocess = partial(preprocess_packed_supervised_dataset, **common_kwargs)
        else:
            sft_preprocess = partial(preprocess_supervised_dataset, processor=processor, **common_kwargs)
        return sft_preprocess, partial(print_supervised_dataset_example, tokenizer=tokenizer)

    # Every remaining stage binds `processor` alongside the common arguments.
    stage_handlers = {
        "rm": (preprocess_pairwise_dataset, print_pairwise_dataset_example),
        "kto": (preprocess_feedback_dataset, print_supervised_dataset_example),
    }
    preprocess_impl, printer_impl = stage_handlers.get(
        stage, (preprocess_unsupervised_dataset, print_unsupervised_dataset_example)
    )
    return (
        partial(preprocess_impl, processor=processor, **common_kwargs),
        partial(printer_impl, tokenizer=tokenizer),
    )