from typing import Dict, List, Optional

from opencompass.models.base import BaseModel
from opencompass.utils import get_logger

try:
    from vllm import LLM, SamplingParams
except ImportError:
    LLM, SamplingParams = None, None

DEFAULT_MODEL_KWARGS = dict(trust_remote_code=True)


class VLLM(BaseModel):
    """Model Wrapper for VLLM."""

    def __init__(
        self,
        path: str,
        max_seq_len: int = 2048,
        model_kwargs: Optional[dict] = None,
        generation_kwargs: Optional[dict] = None,
        meta_template: Optional[Dict] = None,
        mode: str = 'none',
        use_fastchat_template: bool = False,
        end_str: Optional[str] = None,
    ):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         meta_template=meta_template)

        assert LLM, ('Please install VLLM with `pip install vllm`. '
                     'Note: torch==2.1.2 is required.')
        self.logger = get_logger()
        self._load_model(path, model_kwargs)
        self.tokenizer = self.model.get_tokenizer()
        # Copy so the caller's dict (or a shared default) is never mutated.
        self.generation_kwargs = dict(generation_kwargs or {})
        # `do_sample` is a HuggingFace-style flag; vLLM's SamplingParams
        # does not accept it, so drop it here.
        self.generation_kwargs.pop('do_sample', None)

        assert mode in ['none', 'mid']
        self.mode = mode
        self.use_fastchat_template = use_fastchat_template
        self.end_str = end_str

    def _load_model(self,
                    path: str,
                    add_model_kwargs: Optional[dict] = None,
                    num_retry: int = 3):
        # Merge user-provided kwargs on top of the defaults before building
        # the vLLM engine.
        model_kwargs = DEFAULT_MODEL_KWARGS.copy()
        if add_model_kwargs is not None:
            model_kwargs.update(add_model_kwargs)
        self.model = LLM(path, **model_kwargs)

    def generate(self, inputs: List[str], max_out_len: int,
                 **kwargs) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str]): A list of strings.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        if self.mode == 'mid':
            # Middle truncation: if a prompt is too long to leave room for
            # `max_out_len` new tokens, keep its head and tail and drop the
            # middle.
            input_ids = self.tokenizer(inputs, truncation=False)['input_ids']
            inputs = []
            for input_id in input_ids:
                if len(input_id) > self.max_seq_len - max_out_len:
                    half = int((self.max_seq_len - max_out_len) / 2)
                    inputs.append(
                        self.tokenizer.decode(input_id[:half],
                                              skip_special_tokens=True) +
                        self.tokenizer.decode(input_id[-half:],
                                              skip_special_tokens=True))
                else:
                    inputs.append(
                        self.tokenizer.decode(input_id,
                                              skip_special_tokens=True))

        generation_kwargs = kwargs.copy()
        generation_kwargs.update(self.generation_kwargs)
        generation_kwargs.update({'max_tokens': max_out_len})
        sampling_kwargs = SamplingParams(**generation_kwargs)
        outputs = self.model.generate(inputs, sampling_kwargs)

        prompt_list, output_strs = [], []
        for output in outputs:
            prompt = output.prompt
            generated_text = output.outputs[0].text

            # Truncate at the first occurrence of the optional stop string.
            if self.end_str:
                generated_text = generated_text.split(self.end_str)[0]
            prompt_list.append(prompt)
            output_strs.append(generated_text)

        return output_strs

    def prompts_preproccess(self, inputs: List[str]):
        if self.use_fastchat_template:
            try:
                from fastchat.model import get_conversation_template
            except ModuleNotFoundError:
                raise ModuleNotFoundError(
                    'Fastchat is not installed. You can install it with '
                    "'pip install \"fschat[model_worker,webui]\"'.")
            # Wrap the raw prompt in the vicuna conversation template.
            conv = get_conversation_template('vicuna')
            conv.append_message(conv.roles[0], inputs[0])
            conv.append_message(conv.roles[1], None)
            inputs = [conv.get_prompt()]
        return inputs

    def get_token_len(self, prompt: str) -> int:
        """Get the length of the tokenized string.

        Args:
            prompt (str): Input string.

        Returns:
            int: Length of the input in tokens.
        """
        return len(self.model.get_tokenizer().encode(prompt))
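
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the wrapper). It
# assumes vllm is installed and uses a placeholder checkpoint name; swap in
# whatever model you actually evaluate.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    model = VLLM(
        path='meta-llama/Llama-2-7b-hf',           # hypothetical example path
        max_seq_len=2048,
        model_kwargs=dict(dtype='float16'),        # forwarded to vllm.LLM
        generation_kwargs=dict(temperature=0.0),   # forwarded to SamplingParams
        mode='mid',      # truncate over-long prompts from the middle
        end_str='</s>',  # trim everything after this stop string
    )
    prompts = ['The capital of France is']
    print(model.generate(prompts, max_out_len=32))
    print(model.get_token_len(prompts[0]))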