from typing import Dict, Iterable, List, Optional, Union

import numpy as np
import torch.distributed as dist

from opencompass.models.base import BaseModel
from opencompass.models.base_api import APITemplateParser
from opencompass.utils.logging import get_logger
from opencompass.utils.prompt import PromptList

PromptType = Union[PromptList, str]


class LLaMA2AccessoryModel(BaseModel):
"""LLaMA2-Accessory model wrapper.
Project: https://github.com/Alpha-VLLM/LLaMA2-Accessory
Args:
tokenizer_only (bool): whether to load tokenizer only
meta_template (dict): meta template for the model
additional_stop_symbols: (Iterable[str]): additional symbols that mark
the end of generation, e.g. the "###" symbol for separating turns
in the chat template.
from_pretrained_kwargs: kwargs that will be passed to
`accessory.MetaModel.from_pretrained` for model instantiation.
"""
def __init__(self,
tokenizer_only: bool = False,
meta_template: Optional[Dict] = None,
additional_stop_symbols: Iterable[str] = (),
**from_pretrained_kwargs):
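        # In tokenizer-only mode (e.g. when only token counting is needed)
        # skip loading the full model weights.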
if tokenizer_only:
self._load_tokenizer(from_pretrained_kwargs)
else:
self._load_model(from_pretrained_kwargs)

        self.additional_stop_symbols = additional_stop_symbols
self.max_seq_len = from_pretrained_kwargs.get('max_seq_len', 4096)
self.template_parser = APITemplateParser(meta_template)
self.logger = get_logger()

    def _load_model(self, from_pretrained_kwargs):
from accessory.model.meta import MetaModel
from accessory.util.misc import init_distributed_mode
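        # LLaMA2-Accessory expects a torch.distributed process group for
        # model parallelism; initialize one here if the caller has not
        # already done so.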
if not dist.is_initialized():
init_distributed_mode()
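
        # Reuse the global (WORLD) process group as the model-parallel
        # group passed to `MetaModel.from_pretrained`.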
model_parallel_group = dist.GroupMember.WORLD
from_pretrained_kwargs['mp_group'] = model_parallel_group
self.model = MetaModel.from_pretrained(**from_pretrained_kwargs)
self.tokenizer = self.model.tokenizer
self.logger = get_logger()

    def _load_tokenizer(self, from_pretrained_kwargs):
from accessory.model.tokenizer import (
Tokenizer, probe_tokenizer_path_from_pretrained)
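        # Prefer an explicitly supplied tokenizer path; otherwise probe the
        # (last) pretrained checkpoint directory for a bundled tokenizer.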
if 'tokenizer_path' in from_pretrained_kwargs:
tokenizer_path = from_pretrained_kwargs['tokenizer_path']
else:
pretrained_path = from_pretrained_kwargs['pretrained_path']
if isinstance(pretrained_path, str):
pretrained_path = [pretrained_path]
tokenizer_path = probe_tokenizer_path_from_pretrained(
pretrained_path[-1])
self.tokenizer = Tokenizer(tokenizer_path)

    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
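        # Greedy decoding (temperature 0); generation additionally stops
        # at any of the configured extra stop symbols.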
results = self.model.generate(
prompts=inputs,
max_gen_len=max_out_len,
temperature=0.,
additional_stop_symbols=self.additional_stop_symbols)
return results

    def get_ppl(self,
inputs: List[str],
mask_length: Optional[List[int]] = None):
assert mask_length is None, 'mask_length is not supported'
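        # Delegate to accessory's built-in example evaluation; its result
        # dict exposes per-example perplexities under the 'ppl' key.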
evaluation_results = self.model.evaluate_examples(examples=inputs)
ppl = evaluation_results['ppl']
return np.array(ppl, dtype=np.float32)

    def get_token_len(self, prompt: str) -> int:
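        # The two boolean flags presumably toggle BOS/EOS insertion, as in
        # the LLaMA reference tokenizer; both tokens count toward the
        # reported length.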
return len(self.tokenizer.encode(prompt, True, True))
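

# A minimal usage sketch. The checkpoint path below is a placeholder, and
# any kwargs beyond the named arguments are simply forwarded to
# `accessory.MetaModel.from_pretrained`:
#
#     model = LLaMA2AccessoryModel(
#         additional_stop_symbols=('###',),
#         pretrained_path='/path/to/llama2-accessory-checkpoint',
#         max_seq_len=4096)
#     outputs = model.generate(['The capital of France is'],
#                              max_out_len=32)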