# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial | |
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple | |
from .processors.feedback import preprocess_feedback_dataset | |
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example | |
from .processors.pretrain import preprocess_pretrain_dataset | |
from .processors.supervised import ( | |
preprocess_packed_supervised_dataset, | |
preprocess_supervised_dataset, | |
print_supervised_dataset_example, | |
) | |
from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example | |
if TYPE_CHECKING: | |
from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments | |
from ..hparams import DataArguments | |
from .template import Template | |
def get_preprocess_and_print_func(
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
) -> Tuple[Callable, Callable]:
    r"""
    Pick the dataset preprocessing function and the matching example printer for a training stage.

    Dispatch (first match wins):
      - ``stage == "pt"``: ``preprocess_pretrain_dataset`` (bound to ``tokenizer`` and
        ``data_args`` only) with ``print_unsupervised_dataset_example``.
      - ``stage == "sft"`` and ``training_args.predict_with_generate`` is falsy:
        ``preprocess_packed_supervised_dataset`` when ``data_args.packing`` is truthy (bound
        without ``processor``), otherwise ``preprocess_supervised_dataset``; both are paired
        with ``print_supervised_dataset_example``.
      - ``stage == "rm"``: ``preprocess_pairwise_dataset`` with ``print_pairwise_dataset_example``.
      - ``stage == "kto"``: ``preprocess_feedback_dataset`` with ``print_supervised_dataset_example``.
      - any other case (e.g. ``"ppo"``, or ``"sft"`` with ``predict_with_generate``):
        ``preprocess_unsupervised_dataset`` with ``print_unsupervised_dataset_example``.

    Args:
        data_args: data arguments bound into every preprocess function; ``packing`` is read
            only on the SFT (non-generate) path.
        training_args: only ``predict_with_generate`` is read, and only when ``stage == "sft"``.
        stage: training stage that selects the branch above.
        template: template bound into every preprocess function except the pretrain one.
        tokenizer: bound into both returned callables.
        processor: optional processor bound into the non-packed supervised, pairwise, feedback
            and unsupervised preprocess functions.

    Returns:
        ``(preprocess_func, print_function)``, both ``functools.partial`` objects.
    """
    # Keyword arguments shared by every template-aware preprocessor that also takes a processor.
    templated_kwargs = dict(
        template=template,
        tokenizer=tokenizer,
        processor=processor,
        data_args=data_args,
    )

    if stage == "pt":
        # Pretraining works on raw text: no template and no processor are bound.
        preprocessor = partial(
            preprocess_pretrain_dataset,
            tokenizer=tokenizer,
            data_args=data_args,
        )
        printer = print_unsupervised_dataset_example
    elif stage == "sft" and not training_args.predict_with_generate:
        if data_args.packing:
            # The packed variant is bound without ``processor``.
            preprocessor = partial(
                preprocess_packed_supervised_dataset,
                template=template,
                tokenizer=tokenizer,
                data_args=data_args,
            )
        else:
            preprocessor = partial(preprocess_supervised_dataset, **templated_kwargs)
        printer = print_supervised_dataset_example
    elif stage == "rm":
        preprocessor = partial(preprocess_pairwise_dataset, **templated_kwargs)
        printer = print_pairwise_dataset_example
    elif stage == "kto":
        preprocessor = partial(preprocess_feedback_dataset, **templated_kwargs)
        # Feedback examples are displayed with the supervised printer.
        printer = print_supervised_dataset_example
    else:
        # Fallback for "ppo" and for SFT evaluation with generation.
        preprocessor = partial(preprocess_unsupervised_dataset, **templated_kwargs)
        printer = print_unsupervised_dataset_example

    return preprocessor, partial(printer, tokenizer=tokenizer)