import json
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional

import numpy as np
import requests

from opencompass.registry import MODELS
from opencompass.utils.logging import get_logger

from .base_api import BaseAPIModel


@MODELS.register_module()
class LightllmAPI(BaseAPIModel):
    """Model wrapper for a text-generation service deployed with LightLLM."""

    is_api: bool = True

    def __init__(
        self,
        path: str = 'LightllmAPI',
        url: str = 'http://localhost:8080/generate',
        input_format: str = '<input_text_to_replace>',
        max_seq_len: int = 2048,
        meta_template: Optional[Dict] = None,
        retry: int = 2,
        generation_kwargs: Optional[Dict] = None,
    ):
        # Use None instead of a mutable default argument, which would be
        # shared across instances.
        if generation_kwargs is None:
            generation_kwargs = {}
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         meta_template=meta_template,
                         retry=retry,
                         generation_kwargs=generation_kwargs)
        self.logger = get_logger()
        self.url = url
        self.input_format = input_format
        self.generation_kwargs = generation_kwargs
        self.max_out_len = self.generation_kwargs.get('max_new_tokens', 1024)

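    # A hypothetical config sketch for this wrapper (fields mirror the
    # __init__ signature above; the URL and generation_kwargs values are
    # illustrative only):
    #
    #     dict(type=LightllmAPI,
    #          url='http://localhost:8080/generate',
    #          max_seq_len=2048,
    #          generation_kwargs=dict(max_new_tokens=1024))
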
    def generate(self, inputs: List[str], max_out_len: int,
                 **kwargs) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output. Note that
                the length actually sent to the service is taken from
                ``generation_kwargs['max_new_tokens']``.

        Returns:
            List[str]: A list of generated strings.
        """
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [self.max_out_len] * len(inputs)))
        return results

    def _generate(self, input: str, max_out_len: int) -> str:
        # Substitute the prompt into the configured input template once,
        # before the retry loop, so retries do not re-wrap the prompt.
        input = self.input_format.replace('<input_text_to_replace>', input)
        max_num_retries = 0
        while max_num_retries < self.retry:
            self.wait()
            header = {'content-type': 'application/json'}
            try:
                data = dict(inputs=input, parameters=self.generation_kwargs)
                raw_response = requests.post(self.url,
                                             headers=header,
                                             data=json.dumps(data))
            except requests.ConnectionError:
                self.logger.error('Got connection error, retrying...')
                # Count connection failures against the retry budget to
                # avoid an infinite loop when the service is unreachable.
                max_num_retries += 1
                continue
            try:
                response = raw_response.json()
                generated_text = response['generated_text']
                # The service may return a list of candidates; keep the
                # first one.
                if isinstance(generated_text, list):
                    generated_text = generated_text[0]
                return generated_text
            except requests.JSONDecodeError:
                self.logger.error('JsonDecode error, got %s',
                                  str(raw_response.content))
            except KeyError:
                self.logger.error(f'KeyError. Response: {str(response)}')
            max_num_retries += 1

        raise RuntimeError('Calling LightllmAPI failed after retrying for '
                           f'{max_num_retries} times. Check the logs for '
                           'details.')

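    # For reference, a minimal sketch of the exchange performed by
    # _generate, assuming a LightLLM server at the configured URL (the
    # response shape matches what the parsing code above consumes; exact
    # schemas may vary across LightLLM versions):
    #
    #     POST /generate
    #     {"inputs": "<prompt>", "parameters": {"max_new_tokens": 1024}}
    #
    #     -> {"generated_text": ["..."]}
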
    def get_ppl(self, inputs: List[str], max_out_len: int,
                **kwargs) -> List[float]:
        """Compute the cross-entropy loss for a list of inputs.

        Args:
            inputs (List[str]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[float]: A list of per-input cross-entropy losses.
        """
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._get_ppl, inputs,
                             [self.max_out_len] * len(inputs)))
        return np.array(results)

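    # A sketch of the per-prompt response shape _get_ppl expects, assuming
    # the service was started with --return_all_prompt_logprobs (shown as
    # consumed below: the second element of each prompt_logprobs item maps
    # a token id, serialized as a string, to its logprob; the ids and
    # values are illustrative):
    #
    #     {"prompt_token_ids": [1, 1724, 338],
    #      "prompt_logprobs": [[..., {"1724": -9.2}], [..., {"338": -1.3}]]}
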
    def _get_ppl(self, input: str, max_out_len: int) -> float:
        # Substitute the prompt into the configured input template once,
        # before the retry loop, so retries do not re-wrap the prompt.
        input = self.input_format.replace('<input_text_to_replace>', input)
        max_num_retries = 0
        if max_out_len is None:
            max_out_len = 1
        while max_num_retries < self.retry:
            self.wait()
            header = {'content-type': 'application/json'}
            try:
                data = dict(inputs=input, parameters=self.generation_kwargs)
                raw_response = requests.post(self.url,
                                             headers=header,
                                             data=json.dumps(data))
            except requests.ConnectionError:
                self.logger.error('Got connection error, retrying...')
                # Count connection failures against the retry budget to
                # avoid an infinite loop when the service is unreachable.
                max_num_retries += 1
                continue
            try:
                response = raw_response.json()

                assert ('prompt_token_ids' in response
                        and 'prompt_logprobs' in response), (
                            'prompt_token_ids and prompt_logprobs must be '
                            'in the output. Please consider adding the '
                            '--return_all_prompt_logprobs argument when '
                            'starting the lightllm service. '
                            f'Response: {str(response)}')

                # Drop the first token: it has no preceding context, so the
                # service returns no logprob for it.
                prompt_token_ids = response['prompt_token_ids'][1:]
                prompt_logprobs = [
                    item[1] for item in response['prompt_logprobs']
                ]
                # Look up the logprob the model assigned to each actual
                # prompt token (keys are token ids serialized as strings).
                logprobs = [
                    item[str(token_id)] for token_id, item in zip(
                        prompt_token_ids, prompt_logprobs)
                ]
                if len(logprobs) == 0:
                    return 0.0
                # Cross-entropy loss: mean negative log-likelihood over the
                # prompt tokens.
                ce_loss = -sum(logprobs) / len(logprobs)
                return ce_loss
            except requests.JSONDecodeError:
                self.logger.error('JsonDecode error, got %s',
                                  str(raw_response.content))
            max_num_retries += 1
        raise RuntimeError('Calling LightllmAPI failed after retrying for '
                           f'{max_num_retries} times. Check the logs for '
                           'details.')
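
# A minimal usage sketch, assuming OpenCompass is installed and a LightLLM
# server is already running at the default URL; the prompt and
# max_new_tokens value are illustrative only.
if __name__ == '__main__':
    model = LightllmAPI(
        url='http://localhost:8080/generate',
        generation_kwargs=dict(max_new_tokens=32),
    )
    outputs = model.generate(['Hello, world!'], max_out_len=32)
    print(outputs)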