File size: 6,307 Bytes
256a159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import json
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional

import numpy as np
import requests

from opencompass.registry import MODELS
from opencompass.utils.logging import get_logger

from .base_api import BaseAPIModel


@MODELS.register_module()
class LightllmAPI(BaseAPIModel):

    is_api: bool = True

    def __init__(
            self,
            path: str = 'LightllmAPI',
            url: str = 'http://localhost:8080/generate',
            input_format: str = '<input_text_to_replace>',
            max_seq_len: int = 2048,
            meta_template: Optional[Dict] = None,
            retry: int = 2,
            generation_kwargs: Optional[Dict] = dict(),
    ):

        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         meta_template=meta_template,
                         retry=retry,
                         generation_kwargs=generation_kwargs)
        self.logger = get_logger()
        self.url = url
        self.input_format = input_format
        self.generation_kwargs = generation_kwargs
        self.max_out_len = self.generation_kwargs.get('max_new_tokens', 1024)

    def generate(self, inputs: List[str], max_out_len: int,
                 **kwargs) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """

        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [self.max_out_len] * len(inputs)))
        return results

    def _generate(self, input: str, max_out_len: int) -> str:
        max_num_retries = 0
        while max_num_retries < self.retry:
            self.wait()
            header = {'content-type': 'application/json'}
            try:
                input = self.input_format.replace('<input_text_to_replace>',
                                                  input)
                data = dict(inputs=input, parameters=self.generation_kwargs)
                raw_response = requests.post(self.url,
                                             headers=header,
                                             data=json.dumps(data))
            except requests.ConnectionError:
                self.logger.error('Got connection error, retrying...')
                continue
            try:
                response = raw_response.json()
                generated_text = response['generated_text']
                if isinstance(generated_text, list):
                    generated_text = generated_text[0]
                return generated_text
            except requests.JSONDecodeError:
                self.logger.error('JsonDecode error, got',
                                  str(raw_response.content))
            except KeyError:
                self.logger.error(f'KeyError. Response: {str(response)}')
            max_num_retries += 1

        raise RuntimeError('Calling LightllmAPI failed after retrying for '
                           f'{max_num_retries} times. Check the logs for '
                           'details.')

    def get_ppl(self, inputs: List[str], max_out_len: int,
                **kwargs) -> List[float]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """

        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._get_ppl, inputs,
                             [self.max_out_len] * len(inputs)))
        return np.array(results)

    def _get_ppl(self, input: str, max_out_len: int) -> float:
        max_num_retries = 0
        if max_out_len is None:
            max_out_len = 1
        while max_num_retries < self.retry:
            self.wait()
            header = {'content-type': 'application/json'}
            try:
                input = self.input_format.replace('<input_text_to_replace>',
                                                  input)
                data = dict(inputs=input, parameters=self.generation_kwargs)
                raw_response = requests.post(self.url,
                                             headers=header,
                                             data=json.dumps(data))
            except requests.ConnectionError:
                self.logger.error('Got connection error, retrying...')
                continue
            try:
                response = raw_response.json()

                assert ('prompt_token_ids' in response and 'prompt_logprobs'
                        in response), f'prompt_token_ids and prompt_logprobs \
                    must be in the output. \
                    Please consider adding \
                    --return_all_prompt_logprobs argument \
                    when starting lightllm service. Response: {str(response)}'

                prompt_token_ids = response['prompt_token_ids'][1:]
                prompt_logprobs = [
                    item[1] for item in response['prompt_logprobs']
                ]
                logprobs = [
                    item[str(token_id)] for token_id, item in zip(
                        prompt_token_ids, prompt_logprobs)
                ]
                if len(logprobs) == 0:
                    return 0.0
                ce_loss = -sum(logprobs) / len(logprobs)
                return ce_loss
            except requests.JSONDecodeError:
                self.logger.error('JsonDecode error, got',
                                  str(raw_response.content))
            max_num_retries += 1
        raise RuntimeError('Calling LightllmAPI failed after retrying for '
                           f'{max_num_retries} times. Check the logs for '
                           'details.')