File size: 2,876 Bytes
e5928ae
2efb72f
e5928ae
 
9b6975c
2efb72f
9b6975c
e5928ae
9b6975c
e5928ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44b50c2
e5928ae
 
 
 
 
 
 
 
 
 
 
 
 
44b50c2
2efb72f
e5928ae
d53e153
 
 
 
8d5e1fd
 
d53e153
 
 
a97e214
 
bd72865
a97e214
7d56800
4e976d5
a97e214
f01378a
a97e214
 
 
 
ce5d4ab
b4bd245
cc8c305
a97e214
defc45e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2efb72f
e5928ae
2efb72f
e5928ae
 
2efb72f
e5928ae
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from huggingface_hub import InferenceClient
import uvicorn

# ASGI application object; the route decorators in this file register handlers on it.
app = FastAPI()

# Shared Hugging Face inference client for the Mixtral-8x7B-Instruct model.
# Created once at import time and used by generate() for every request.
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

# Request body accepted by the POST /generate/ endpoint.
class Item(BaseModel):
    # Newest user message; generate() joins it to system_prompt before formatting.
    prompt: str
    # Earlier conversation turns; format_prompt() unpacks each entry as a
    # (user_prompt, bot_response) pair.
    history: list
    # Instruction text placed in front of `prompt`, separated by ", ".
    system_prompt: str
    # Sampling temperature; generate() raises values below 1e-2 up to 1e-2.
    temperature: float = 0.01
    # Nucleus-sampling threshold, forwarded to client.text_generation().
    top_p: float = 1.0
    # Forwarded to client.text_generation() as `details`.
    details: bool = True
    # Forwarded to client.text_generation() as `return_full_text`.
    return_full_text: bool = False
    # NOTE(review): currently ignored; generate() always passes stream=False.
    stream: bool = False

def format_prompt(message, history):
    """Render the conversation as a single instruction-tagged prompt string.

    Args:
        message: The new user message to be answered.
        history: Iterable of ``(user_prompt, bot_response)`` pairs for the
            turns that already happened, oldest first.

    Returns:
        ``"<s>"`` followed by one ``[INST] ... [/INST] reply</s> `` segment per
        past turn, ending with the new message wrapped in ``[INST] ... [/INST]``.
    """
    past_turns = "".join(
        f"[INST] {asked} [/INST] {answered}</s> "
        for asked, answered in history
    )
    return f"<s>{past_turns}[INST] {message} [/INST]"

def _extract_generated_text(result):
    """Return the generated text from a ``client.text_generation`` result.

    The call can hand back different shapes depending on its ``stream`` and
    ``details`` flags, so each one is handled explicitly:

    * a plain ``str`` is returned unchanged;
    * an object exposing ``generated_text`` yields that attribute;
    * anything else is treated as an iterable of streamed chunks, each either
      a ``str`` or an object carrying ``token.text``, and is concatenated.
    """
    if isinstance(result, str):
        return result

    generated_text = getattr(result, "generated_text", None)
    if generated_text is not None:
        return generated_text

    return "".join(
        chunk if isinstance(chunk, str) else chunk.token.text
        for chunk in result
    )


def generate(item: Item):
    """Run one text-generation request and return the model's reply.

    Args:
        item: Request payload. ``system_prompt`` and ``prompt`` are joined
            with ", " and, together with ``history``, rendered by
            ``format_prompt``. ``temperature``, ``top_p``, ``details`` and
            ``return_full_text`` are forwarded to the inference client.

    Returns:
        A one-element list ``[{'msg': <generated text>}]``.

    Raises:
        Whatever ``client.text_generation`` raises; errors are not handled
        here and propagate to the caller.
    """
    # Values below 1e-2 are raised to 1e-2 before being sent to the backend.
    temperature = max(float(item.temperature), 1e-2)
    top_p = float(item.top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=1048,
        top_p=top_p,
        repetition_penalty=1.0,
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)

    # The caller receives one aggregated payload, so streaming is disabled
    # regardless of item.stream.
    result = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=False,
        details=item.details,
        return_full_text=item.return_full_text,
    )

    # A non-streaming call does not yield a token stream, so the text is
    # pulled out by result shape rather than by iterating tokens.
    output = _extract_generated_text(result)

    return [{'msg': output}]


# def generate(item: Item):
#     temperature = float(item.temperature)
#     if temperature < 1e-2:
#         temperature = 1e-2
#     top_p = float(item.top_p)

#     generate_kwargs = dict(
#         temperature=temperature,
#         max_new_tokens=1048,
#         top_p=top_p,
#         repetition_penalty=1.0,
#         do_sample=True,
#         seed=42,
#     )

#     formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
#     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=item.stream, details=item.details, return_full_text=item.return_full_text)
#     output = ""

#     for response in stream:
#         output += response.token.text
#     return output

@app.post("/generate/")
async def generate_text(item: Item):
    """Handle POST /generate/: run text generation for the request body.

    Args:
        item: Validated request payload.

    Returns:
        ``{"response": <result of generate(item)>}``.

    Raises:
        HTTPException: Status 500 with the error message as ``detail`` when
            generation fails for any reason.
    """
    # Imported locally so this handler carries its own dependency.
    from fastapi.concurrency import run_in_threadpool

    try:
        # generate() makes a blocking network call; running it in the
        # threadpool keeps the event loop free for other requests.
        response = await run_in_threadpool(generate, item)
    except Exception as e:
        # Chain the original error so the traceback survives in server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"response": response}

# @app.get("/health")
# async def health_check():
#     return {
#         "status": "healthy",
#         "huggingface_client": "initialized",
#         "auth_required": True
#     }