from typing import Dict, Iterable, List, Optional, Union

import numpy as np
import torch.distributed as dist

from opencompass.models.base import BaseModel
from opencompass.models.base_api import APITemplateParser
from opencompass.utils.logging import get_logger
from opencompass.utils.prompt import PromptList

PromptType = Union[PromptList, str]


class LLaMA2AccessoryModel(BaseModel):
    """LLaMA2-Accessory model wrapper.

    Project: https://github.com/Alpha-VLLM/LLaMA2-Accessory

    Args:
        tokenizer_only (bool): whether to load the tokenizer only
        meta_template (dict): meta template for the model
        additional_stop_symbols (Iterable[str]): additional symbols that mark
            the end of generation, e.g. the "###" symbol used to separate
            turns in the chat template.
        from_pretrained_kwargs: kwargs that will be passed to
            `accessory.MetaModel.from_pretrained` for model instantiation.
    """

    def __init__(self,
                 tokenizer_only: bool = False,
                 meta_template: Optional[Dict] = None,
                 additional_stop_symbols: Iterable[str] = (),
                 **from_pretrained_kwargs):
        if tokenizer_only:
            self._load_tokenizer(from_pretrained_kwargs)
        else:
            self._load_model(from_pretrained_kwargs)

        self.additional_stop_symbols = additional_stop_symbols
        self.max_seq_len = from_pretrained_kwargs.get('max_seq_len', 4096)
        self.template_parser = APITemplateParser(meta_template)
        self.logger = get_logger()

    def _load_model(self, from_pretrained_kwargs):
        from accessory.model.meta import MetaModel
        from accessory.util.misc import init_distributed_mode
        if not dist.is_initialized():
            init_distributed_mode()

        # Put every rank into a single model-parallel group so the model is
        # sharded across all available processes.
        model_parallel_group = dist.GroupMember.WORLD
        from_pretrained_kwargs['mp_group'] = model_parallel_group

        self.model = MetaModel.from_pretrained(**from_pretrained_kwargs)
        self.tokenizer = self.model.tokenizer

    def _load_tokenizer(self, from_pretrained_kwargs):
        from accessory.model.tokenizer import (
            Tokenizer, probe_tokenizer_path_from_pretrained)
        if 'tokenizer_path' in from_pretrained_kwargs:
            tokenizer_path = from_pretrained_kwargs['tokenizer_path']
        else:
            # No explicit tokenizer path: probe it from the (last) pretrained
            # checkpoint directory instead.
            pretrained_path = from_pretrained_kwargs['pretrained_path']
            if isinstance(pretrained_path, str):
                pretrained_path = [pretrained_path]
            tokenizer_path = probe_tokenizer_path_from_pretrained(
                pretrained_path[-1])

        self.tokenizer = Tokenizer(tokenizer_path)

    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
        # temperature=0. gives deterministic greedy decoding.
        results = self.model.generate(
            prompts=inputs,
            max_gen_len=max_out_len,
            temperature=0.,
            additional_stop_symbols=self.additional_stop_symbols)
        return results

    def get_ppl(self,
                inputs: List[str],
                mask_length: Optional[List[int]] = None):
        assert mask_length is None, 'mask_length is not supported'
        evaluation_results = self.model.evaluate_examples(examples=inputs)
        ppl = evaluation_results['ppl']
        return np.array(ppl, dtype=np.float32)

    def get_token_len(self, prompt: str) -> int:
        # The two positional flags make the tokenizer count both the BOS and
        # EOS tokens.
        return len(self.tokenizer.encode(prompt, True, True))
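

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): a minimal
# example of driving the wrapper. The checkpoint path and `llama_type` value
# below are hypothetical placeholders; all keyword arguments are forwarded
# verbatim to `accessory.MetaModel.from_pretrained`, so consult that
# signature for what your checkpoint actually needs. Run under a distributed
# launcher (e.g. `torchrun`), since model loading initializes
# torch.distributed.
if __name__ == '__main__':
    model = LLaMA2AccessoryModel(
        pretrained_path='/path/to/accessory/checkpoint',  # hypothetical path
        llama_type='llama',  # hypothetical; depends on the checkpoint
        max_seq_len=4096,
        additional_stop_symbols=('###', ),
    )
    # Greedy generation (the wrapper pins temperature to 0).
    outputs = model.generate(['The capital of France is'], max_out_len=32)
    print(outputs[0])
    # Per-example perplexity; lower is better.
    print(model.get_ppl(['The capital of France is Paris.']))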