File size: 12,435 Bytes
dc12c31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
import json
import os
import traceback
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread

import extensions.openai.completions as OAIcompletions
import extensions.openai.edits as OAIedits
import extensions.openai.embeddings as OAIembeddings
import extensions.openai.images as OAIimages
import extensions.openai.models as OAImodels
import extensions.openai.moderations as OAImoderations
from extensions.openai.defaults import clamp, default, get_default_req_params
from extensions.openai.errors import (
    InvalidRequestError,
    OpenAIError,
    ServiceUnavailableError
)
from extensions.openai.tokens import token_count, token_decode, token_encode
from extensions.openai.utils import debug_msg
from modules import shared

import cgi
import speech_recognition as sr
from pydub import AudioSegment

# Extension settings. The listening port comes from the OPENEDAI_PORT
# environment variable when it is set, and is 5001 otherwise.
params = {
    'port': int(os.environ.get('OPENEDAI_PORT', 5001)),
}


class Handler(BaseHTTPRequestHandler):
    """HTTP request handler implementing the OpenAI-style REST routes.

    GET serves the model/engine listing and a stub billing endpoint; POST
    serves audio transcription, (chat) completions, edits, image generation,
    embeddings, moderations and the non-standard token helper routes.
    Exceptions raised inside ``do_GET``/``do_POST`` are converted to JSON
    error responses by ``openai_error_handler``.
    """

    def send_access_control_headers(self):
        """Queue the permissive CORS headers that accompany every response."""
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header("Access-Control-Allow-Credentials", "true")
        self.send_header(
            "Access-Control-Allow-Methods",
            "GET,HEAD,OPTIONS,POST,PUT"
        )
        self.send_header(
            "Access-Control-Allow-Headers",
            "Origin, Accept, X-Requested-With, Content-Type, "
            "Access-Control-Request-Method, Access-Control-Request-Headers, "
            "Authorization"
        )

    def do_OPTIONS(self):
        """Answer a (CORS preflight) OPTIONS request with 200 and the body ``OK``."""
        self.send_response(200)
        self.send_access_control_headers()
        self.send_header('Content-Type', 'application/json')
        self.end_headers()
        self.wfile.write("OK".encode('utf-8'))

    def start_sse(self):
        """Send the status line and headers that open a server-sent-event stream."""
        self.send_response(200)
        self.send_access_control_headers()
        self.send_header('Content-Type', 'text/event-stream')
        self.send_header('Cache-Control', 'no-cache')
        # self.send_header('Connection', 'keep-alive')
        self.end_headers()

    def send_sse(self, chunk: dict):
        """Write one JSON-encoded ``data:`` event to the open stream."""
        response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
        debug_msg(response[:-4])  # log without the trailing '\r\n\r\n'
        self.wfile.write(response.encode('utf-8'))

    def end_sse(self):
        """Write the terminating ``data: [DONE]`` event."""
        response = 'data: [DONE]\r\n\r\n'
        debug_msg(response[:-4])  # log without the trailing '\r\n\r\n'
        self.wfile.write(response.encode('utf-8'))

    def return_json(self, ret: dict, code: int = 200, no_debug=False):
        """Send ``ret`` as a complete JSON response.

        Args:
            ret: JSON-serialisable payload.
            code: HTTP status code of the response.
            no_debug: when true, the encoded payload is not passed to
                ``debug_msg``.
        """
        self.send_response(code)
        self.send_access_control_headers()
        self.send_header('Content-Type', 'application/json')

        response = json.dumps(ret)
        r_utf8 = response.encode('utf-8')

        # Content-Length must be the byte length of the encoded payload
        self.send_header('Content-Length', str(len(r_utf8)))
        self.end_headers()

        self.wfile.write(r_utf8)
        if not no_debug:
            debug_msg(r_utf8)

    def openai_error(self, message, code=500, error_type='APIError', param='', internal_message=''):
        """Send an error response shaped as ``{"error": {...}}``.

        Args:
            message: text placed in the response's ``message`` field.
            code: used both as the HTTP status and as the ``code`` field.
            error_type: value of the ``type`` field.
            param: value of the ``param`` field.
            internal_message: when non-empty it is printed on the server
                (together with the type and message) but never sent to the
                client.
        """
        error_resp = {
            'error': {
                'message': message,
                'code': code,
                'type': error_type,
                'param': param,
            }
        }
        if internal_message:
            print(error_type, message)
            print(internal_message)
            # error_resp['internal_message'] = internal_message

        self.return_json(error_resp, code)

    def openai_error_handler(func):
        """Decorator: turn exceptions raised by a ``do_*`` method into JSON errors.

        ``InvalidRequestError`` and ``OpenAIError`` are reported with their own
        message/code; any other exception becomes a 500 whose traceback is
        printed server-side only.
        """
        def wrapper(self):
            try:
                func(self)
            except InvalidRequestError as e:
                self.openai_error(e.message, e.code, e.__class__.__name__, e.param, internal_message=e.internal_message)
            except OpenAIError as e:
                self.openai_error(e.message, e.code, e.__class__.__name__, internal_message=e.internal_message)
            except Exception as e:
                self.openai_error(repr(e), 500, 'OpenAIError', internal_message=traceback.format_exc())

        return wrapper

    @openai_error_handler
    def do_GET(self):
        """Serve ``/v1/models``, ``/v1/engines`` and the billing-usage stub.

        Any other path is answered with 404.
        """
        debug_msg(self.requestline)
        debug_msg(self.headers)

        if self.path.startswith(('/v1/engines', '/v1/models')):
            # Decide legacy vs. current from the route prefix only: a plain
            # substring test would misroute /v1/models/<name> whenever the
            # model name itself contains "engines".
            is_legacy = self.path.startswith('/v1/engines')
            is_list = self.path in ('/v1/engines', '/v1/models')

            if is_list:
                resp = OAImodels.list_models(is_legacy)
            elif is_legacy:
                model_name = self.path[len('/v1/engines/'):]
                resp = OAImodels.load_model(model_name)
            else:
                model_name = self.path[len('/v1/models/'):]
                resp = OAImodels.model_info(model_name)

            self.return_json(resp)

        elif '/billing/usage' in self.path:
            #  Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
            self.return_json({"total_usage": 0}, no_debug=True)

        else:
            self.send_error(404)

    def _read_body(self):
        """Read the raw request body from ``self.rfile``.

        Returns:
            The body bytes, or ``None`` when the request has neither a
            ``Content-Length`` header nor ``Transfer-Encoding: chunked``.

        Raises:
            ValueError: if ``Content-Length`` or a chunk-size line is not a
                valid number.
        """
        content_length = self.headers.get('Content-Length')
        if content_length:
            return self.rfile.read(int(content_length))

        # Transfer-coding names are case-insensitive
        transfer_encoding = self.headers.get('Transfer-Encoding') or ''
        if transfer_encoding.strip().lower() != 'chunked':
            return None

        chunks = []
        while True:
            # Chunk header is "<hex size>[;extension...]\r\n"; ignore extensions
            size_line = self.rfile.readline()
            chunk_size = int(size_line.split(b';', 1)[0].strip(), 16)
            if chunk_size == 0:
                break  # End of chunks
            chunks.append(self.rfile.read(chunk_size))
            self.rfile.readline()  # Consume the trailing newline after each chunk
        return b''.join(chunks)

    def _transcribe_audio(self):
        """Handle ``/v1/audio/transcriptions``: decode the upload and run Whisper.

        Reads the multipart form fields ``file`` (required), ``language``
        (optional) and ``model`` (defaults to ``tiny``). Recognition failures
        are reported inside the ``text`` field of a normal 200 response.
        """
        recognizer = sr.Recognizer()

        # Parse the form data
        form = cgi.FieldStorage(
            fp=self.rfile,
            headers=self.headers,
            environ={'REQUEST_METHOD': 'POST', 'CONTENT_TYPE': self.headers['Content-Type']}
        )

        segment = AudioSegment.from_file(form['file'].file)

        # Wrap the decoded raw samples in the AudioData type the recognizer takes
        audio = sr.AudioData(segment.raw_data, segment.frame_rate, segment.sample_width)
        whisper_language = form.getvalue('language', None)
        whisper_model = form.getvalue('model', 'tiny')  # Use the model from the form data if it exists, otherwise default to tiny

        transcription = {"text": ""}

        try:
            transcription["text"] = recognizer.recognize_whisper(audio, language=whisper_language, model=whisper_model)
        except sr.UnknownValueError:
            print("Whisper could not understand audio")
            transcription["text"] = "Whisper could not understand audio UnknownValueError"
        except sr.RequestError as e:
            print("Could not request results from Whisper", e)
            transcription["text"] = "Whisper could not understand audio RequestError"

        self.return_json(transcription, no_debug=True)

    @openai_error_handler
    def do_POST(self):
        """Dispatch a POST request to the matching API implementation.

        Audio transcription consumes a multipart form; every other route
        expects a JSON body. A request without a readable body gets a bare
        400, a body that is not valid JSON gets a JSON 400 error, and an
        unknown path gets 404.
        """
        if '/v1/audio/transcriptions' in self.path:
            self._transcribe_audio()
            return

        debug_msg(self.requestline)
        debug_msg(self.headers)

        raw_body = self._read_body()
        if raw_body is None:
            self.send_response(400, "Bad Request: Either Content-Length or Transfer-Encoding header expected.")
            self.end_headers()
            return

        try:
            body = json.loads(raw_body.decode('utf-8'))
        except ValueError as e:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors;
            # a malformed body is the client's fault, so report 400 not 500.
            self.openai_error(f"Request body is not valid JSON: {e}", 400, 'InvalidRequestError')
            return

        debug_msg(body)

        if '/completions' in self.path or '/generate' in self.path:

            if not shared.model:
                raise ServiceUnavailableError("No model loaded.")

            is_legacy = '/generate' in self.path
            is_chat = 'chat' in self.path
            is_streaming = body.get('stream', False)

            if is_streaming:
                self.start_sse()

                if is_chat:
                    response = OAIcompletions.stream_chat_completions(body, is_legacy=is_legacy)
                else:
                    response = OAIcompletions.stream_completions(body, is_legacy=is_legacy)

                for resp in response:
                    self.send_sse(resp)

                self.end_sse()

            else:
                if is_chat:
                    response = OAIcompletions.chat_completions(body, is_legacy=is_legacy)
                else:
                    response = OAIcompletions.completions(body, is_legacy=is_legacy)

                self.return_json(response)

        elif '/edits' in self.path:
            # deprecated

            if not shared.model:
                raise ServiceUnavailableError("No model loaded.")

            req_params = get_default_req_params()

            instruction = body['instruction']
            edit_input = body.get('input', '')
            temperature = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999)  # fixup absolute 0.0
            top_p = clamp(default(body, 'top_p', req_params['top_p']), 0.001, 1.0)

            response = OAIedits.edits(instruction, edit_input, temperature, top_p)

            self.return_json(response)

        elif '/images/generations' in self.path:
            if 'SD_WEBUI_URL' not in os.environ:
                raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")

            prompt = body['prompt']
            size = default(body, 'size', '1024x1024')
            response_format = default(body, 'response_format', 'url')  # or b64_json
            n = default(body, 'n', 1)  # ignore the batch limits of max 10

            response = OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n)

            self.return_json(response, no_debug=True)

        elif '/embeddings' in self.path:
            encoding_format = body.get('encoding_format', '')

            # 'text' is accepted as a fallback name for the 'input' field
            texts = body.get('input', body.get('text', ''))
            if not texts:
                # NOTE(review): keyword is `param` to match the `e.param` attribute
                # read by openai_error_handler — confirm against errors.py.
                raise InvalidRequestError("Missing required argument input", param='input')

            if isinstance(texts, str):
                texts = [texts]

            response = OAIembeddings.embeddings(texts, encoding_format)

            self.return_json(response, no_debug=True)

        elif '/moderations' in self.path:
            # .get() so that a missing field reaches the 400 below instead of
            # surfacing as a KeyError (500)
            moderation_input = body.get('input')
            if not moderation_input:
                # NOTE(review): keyword is `param` to match the `e.param` attribute
                # read by openai_error_handler — confirm against errors.py.
                raise InvalidRequestError("Missing required argument input", param='input')

            response = OAImoderations.moderations(moderation_input)

            self.return_json(response, no_debug=True)

        elif self.path == '/api/v1/token-count':
            # NOT STANDARD. lifted from the api extension, but it's still very useful to calculate tokenized length client side.
            response = token_count(body['prompt'])

            self.return_json(response, no_debug=True)

        elif self.path == '/api/v1/token/encode':
            # NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models
            encoding_format = body.get('encoding_format', '')

            response = token_encode(body['input'], encoding_format)

            self.return_json(response, no_debug=True)

        elif self.path == '/api/v1/token/decode':
            # NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models
            encoding_format = body.get('encoding_format', '')

            response = token_decode(body['input'], encoding_format)

            self.return_json(response, no_debug=True)

        else:
            self.send_error(404)


def run_server():
    """Create the threaded HTTP server and serve requests until the process ends.

    Binds to all interfaces when ``shared.args.listen`` is set and to loopback
    otherwise, on the port from ``params``. With ``shared.args.share`` a
    cloudflared tunnel is started (if ``flask_cloudflared`` is installed) and
    its public URL is printed; otherwise the local URL is printed.
    """
    host = '0.0.0.0' if shared.args.listen else '127.0.0.1'
    port = params['port']
    server_addr = (host, port)
    server = ThreadingHTTPServer(server_addr, Handler)

    if not shared.args.share:
        print(f'OpenAI compatible API ready at: OPENAI_API_BASE=http://{host}:{port}/v1')
    else:
        try:
            from flask_cloudflared import _run_cloudflared
            public_url = _run_cloudflared(port, port + 1)
            print(f'OpenAI compatible API ready at: OPENAI_API_BASE={public_url}/v1')
        except ImportError:
            print('You should install flask_cloudflared manually')

    server.serve_forever()


def setup():
    """Start ``run_server`` on a background daemon thread and return at once."""
    server_thread = Thread(target=run_server, daemon=True)
    server_thread.start()