from __future__ import annotations
import json
import random
import re
from abc import ABC, abstractmethod
from typing import List, Dict, Union, Optional
from huggingface_hub import InferenceClient
from tenacity import retry, stop_after_attempt, wait_random_exponential
from transformers import AutoTokenizer
ROLE_SYSTEM = 'system'
ROLE_USER = 'user'
ROLE_ASSISTANT = 'assistant'
SUPPORTED_MISTRAL_MODELS = ['mistralai/Mixtral-8x7B-Instruct-v0.1', 'mistralai/Mistral-7B-Instruct-v0.2']
SUPPORTED_NOUS_MODELS = ['NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO']
SUPPORTED_LLAMA_MODELS = ['meta-llama/Llama-2-70b-chat-hf',
'meta-llama/Llama-2-13b-chat-hf',
'meta-llama/Llama-2-7b-chat-hf']
ALL_SUPPORTED_MODELS = SUPPORTED_MISTRAL_MODELS + SUPPORTED_NOUS_MODELS + SUPPORTED_LLAMA_MODELS
def select_model(model_name: str, system_prompt: str, **kwargs) -> Model:
if model_name in SUPPORTED_MISTRAL_MODELS:
return MistralModel(system_prompt, model_name)
elif model_name in SUPPORTED_NOUS_MODELS:
return NousHermesModel(system_prompt, model_name)
elif model_name in SUPPORTED_LLAMA_MODELS:
return LlamaModel(system_prompt, model_name)
else:
raise ValueError(f'Model {model_name} not supported')
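# Minimal usage sketch (assumes Hugging Face Inference API access is configured;
# the model name and prompts below are illustrative only):
#
#   model = select_model('mistralai/Mixtral-8x7B-Instruct-v0.1',
#                        system_prompt='You are a helpful assistant.')
#   answer = model('Summarize the plot of Hamlet in one sentence.')
#   scored = model('Rate Hamlet out of 10 as {"score": <int>}.', use_json=True)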
class Model(ABC):
name: str
messages: List[Dict[str, str]]
system_prompt: str
def __init__(self, model_name: str, system_prompt: str):
self.name = model_name
self.system_prompt = system_prompt
self.messages = [
{'role': ROLE_SYSTEM, 'content': system_prompt}
]
@abstractmethod
def __call__(self, *args, **kwargs) -> Union[str, Dict]:
raise NotImplementedError
def add_message(self, role: str, content: str):
assert role in [ROLE_SYSTEM, ROLE_USER, ROLE_ASSISTANT]
self.messages.append({'role': role, 'content': content})
def clear_conversations(self):
self.messages.clear()
self.add_message(ROLE_SYSTEM, self.system_prompt)
def __str__(self) -> str:
return self.name
def __repr__(self) -> str:
return self.name
class HFAPIModel(Model):
def __call__(self, user_prompt: str, *args,
use_json: bool = False,
temperature: float = 0,
                 timeout: Optional[float] = None,
cache: bool = False,
json_retry_count: int = 5,
**kwargs) -> Union[str, Dict]:
"""
Returns the model's response.
If use_json = True, will try its best to return a json dict, but not guaranteed.
If we cannot parse the JSON, we will return the response string directly.
"""
self.add_message(ROLE_USER, user_prompt)
response = self.get_response(temperature, use_json, timeout, cache)
if use_json:
for i in range(json_retry_count):
# cache only if both instruct to do and first try
response = self.get_response(temperature, use_json, timeout, cache and i == 0)
json_obj = self.find_first_valid_json(response)
if json_obj is not None:
response = json_obj
break
self.add_message(ROLE_ASSISTANT, response)
return response
@retry(stop=stop_after_attempt(5), wait=wait_random_exponential(max=10), reraise=True) # retry if exception
def get_response(self, temperature: float, use_json: bool, timeout: float, cache: bool) -> str:
client = InferenceClient(model=self.name, timeout=timeout)
# client = InferenceClient(model=self.name, token=random.choice(HF_API_TOKENS), timeout=timeout)
if not cache:
client.headers["x-use-cache"] = "0"
        # print(self.format_messages())  # debug
r = client.text_generation(self.format_messages(),
do_sample=temperature > 0,
temperature=temperature if temperature > 0 else None,
max_new_tokens=4096)
return r
@abstractmethod
def format_messages(self) -> str:
raise NotImplementedError
def get_short_name(self) -> str:
"""
Returns the last part of the model name.
For example, "mistralai/Mixtral-8x7B-Instruct-v0.1" -> "Mixtral-8x7B-Instruct-v0.1"
"""
return self.name.split('/')[-1]
@staticmethod
    def find_first_valid_json(s: str) -> Optional[Dict]:
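        """
        Scan s and return the first substring that parses as a JSON object, or None.
        For example, 'noise {"a": 1} trailing text' yields {'a': 1}.
        """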
        s = re.sub(r'\\(?!["\\/bfnrt]|u[0-9a-fA-F]{4})', lambda m: m.group(0)[1:], s)  # drop backslashes that start invalid JSON escape sequences
for i in range(len(s)):
if s[i] != '{':
continue
for j in range(i + 1, len(s) + 1):
if s[j - 1] != '}':
continue
try:
potential_json = s[i:j]
json_obj = json.loads(potential_json, strict=False)
return json_obj # Return the first valid JSON object found
except json.JSONDecodeError:
pass # Continue searching if JSON decoding fails
return None # Return None if no valid JSON object is found
class MistralModel(HFAPIModel):
def __init__(self, system_prompt: str, model_name: str = 'mistralai/Mixtral-8x7B-Instruct-v0.1') -> None:
        assert model_name in SUPPORTED_MISTRAL_MODELS, 'Model not supported'
super().__init__(model_name, system_prompt)
def format_messages(self) -> str:
messages = self.messages
        # Mistral chat templates have no system role, so fold the system prompt into the first user message
if messages[0]['role'] == ROLE_SYSTEM:
assert len(self.messages) >= 2
messages = [{'role' : ROLE_USER,
'content': messages[0]['content'] + '\n' + messages[1]['content']}] + messages[2:]
tokenizer = AutoTokenizer.from_pretrained(self.name)
r = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, max_length=4096)
# print(r)
return r
class NousHermesModel(HFAPIModel):
def __init__(self, system_prompt: str, model_name: str = 'NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO') -> None:
        assert model_name in SUPPORTED_NOUS_MODELS, 'Model not supported'
super().__init__(model_name, system_prompt)
def format_messages(self) -> str:
messages = self.messages
        assert len(messages) >= 2  # must contain at least a system and a user message
assert messages[0]['role'] == ROLE_SYSTEM and messages[1]['role'] == ROLE_USER
tokenizer = AutoTokenizer.from_pretrained(self.name)
r = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, max_length=4096)
# print(r)
return r
class LlamaModel(HFAPIModel):
def __init__(self, system_prompt: str, model_name: str = 'meta-llama/Llama-2-70b-chat-hf') -> None:
        assert model_name in SUPPORTED_LLAMA_MODELS, 'Model not supported'
super().__init__(model_name, system_prompt)
def format_messages(self) -> str:
"""
<s>[INST] <<SYS>>
{system_prompt}
<</SYS>>
{user_message} [/INST]
"""
messages = self.messages
        assert len(messages) >= 2  # must contain at least a system and a user message
r = f'<s>[INST] <<SYS>>\n{messages[0]["content"]}\n<</SYS>>\n\n{messages[1]["content"]} [/INST]'
for msg in messages[2:]:
role, content = msg['role'], msg['content']
            if role == ROLE_SYSTEM:
                raise ValueError('system messages are only allowed as the first message')
elif role == ROLE_USER:
if r.endswith('</s>'):
r += '<s>'
r += f'[INST] {content} [/INST]'
elif role == ROLE_ASSISTANT:
r += f'{content}</s>'
else:
raise ValueError
return r
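

if __name__ == '__main__':
    # Quick manual smoke test (a sketch only: the model name and prompt are
    # illustrative, and running it requires Hugging Face Inference API access).
    demo_model = select_model('meta-llama/Llama-2-7b-chat-hf',
                              system_prompt='You are a concise assistant.')
    print(demo_model('In one sentence, what does this module do?'))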