from typing import Dict, Iterable, List, Optional, Union

import numpy as np
import torch.distributed as dist

from opencompass.models.base import BaseModel
from opencompass.models.base_api import APITemplateParser
from opencompass.utils.logging import get_logger
from opencompass.utils.prompt import PromptList

PromptType = Union[PromptList, str]


class LLaMA2AccessoryModel(BaseModel):
    """LLaMA2-Accessory model wrapper.

    Project: https://github.com/Alpha-VLLM/LLaMA2-Accessory

    Args:
        tokenizer_only (bool): whether to load the tokenizer only
        meta_template (dict): meta template for the model
        additional_stop_symbols (Iterable[str]): additional symbols that mark
            the end of generation, e.g. the "###" symbol for separating turns
            in the chat template.
        from_pretrained_kwargs: kwargs that will be passed to
            `accessory.MetaModel.from_pretrained` for model instantiation.
    """

    def __init__(self,
                 tokenizer_only: bool = False,
                 meta_template: Optional[Dict] = None,
                 additional_stop_symbols: Iterable[str] = (),
                 **from_pretrained_kwargs):
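        # Tokenizer-only mode skips loading the model weights; it is enough
        # for prompt-length queries via `get_token_len`.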
        if tokenizer_only:
            self._load_tokenizer(from_pretrained_kwargs)
        else:
            self._load_model(from_pretrained_kwargs)

        self.additional_stop_symbols = additional_stop_symbols
        self.max_seq_len = from_pretrained_kwargs.get('max_seq_len', 4096)
        self.template_parser = APITemplateParser(meta_template)
        self.logger = get_logger()

    def _load_model(self, from_pretrained_kwargs):
        from accessory.model.meta import MetaModel
        from accessory.util.misc import init_distributed_mode
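        # Accessory expects a torch.distributed process group; initialize one
        # if the caller has not, then pass the whole world as the
        # model-parallel group to `MetaModel.from_pretrained`.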
        if not dist.is_initialized():
            init_distributed_mode()
        model_parallel_group = dist.GroupMember.WORLD
        from_pretrained_kwargs['mp_group'] = model_parallel_group
        self.model = MetaModel.from_pretrained(**from_pretrained_kwargs)
        self.tokenizer = self.model.tokenizer
        self.logger = get_logger()

    def _load_tokenizer(self, from_pretrained_kwargs):
        from accessory.model.tokenizer import (
            Tokenizer, probe_tokenizer_path_from_pretrained)
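        # Use an explicit `tokenizer_path` when given; otherwise probe the
        # last entry of `pretrained_path` for a bundled tokenizer file.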
        if 'tokenizer_path' in from_pretrained_kwargs:
            tokenizer_path = from_pretrained_kwargs['tokenizer_path']
        else:
            pretrained_path = from_pretrained_kwargs['pretrained_path']
            if isinstance(pretrained_path, str):
                pretrained_path = [pretrained_path]
            tokenizer_path = probe_tokenizer_path_from_pretrained(
                pretrained_path[-1])
        self.tokenizer = Tokenizer(tokenizer_path)

    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
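        # Decode greedily (temperature 0) so evaluation is deterministic;
        # generation also halts at any of the configured stop symbols.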
        results = self.model.generate(
            prompts=inputs,
            max_gen_len=max_out_len,
            temperature=0.,
            additional_stop_symbols=self.additional_stop_symbols)
        return results

    def get_ppl(self,
                inputs: List[str],
                mask_length: Optional[List[int]] = None):
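        # `mask_length` (masking out context tokens from the loss) is not
        # supported by accessory's example evaluator, hence the assertion.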
        assert mask_length is None, 'mask_length is not supported'
        evaluation_results = self.model.evaluate_examples(examples=inputs)
        ppl = evaluation_results['ppl']
        return np.array(ppl, dtype=np.float32)

    def get_token_len(self, prompt: str) -> int:
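        # The two boolean flags ask the tokenizer to include the BOS and EOS
        # tokens in the count.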
        return len(self.tokenizer.encode(prompt, True, True))
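
# A minimal usage sketch (illustrative only; the checkpoint path and the
# extra `from_pretrained` kwargs below are assumptions, see the
# LLaMA2-Accessory docs for the real options):
#
#     model = LLaMA2AccessoryModel(
#         additional_stop_symbols=('###', ),
#         pretrained_path='/path/to/accessory/checkpoint',
#         max_seq_len=4096)
#     print(model.generate(['The capital of France is'], max_out_len=32))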