prithivMLmods committed
Commit 018df5d
1 Parent(s): 20fe796

Update app.py

Files changed (1): app.py (+64 -66)
app.py CHANGED
@@ -1,16 +1,23 @@
-from huggingface_hub import InferenceClient
 import gradio as gr
+from openai import OpenAI
+import os

 css = '''
-.gradio-container{max-width: 690px !important}
+.gradio-container{max-width: 1000px !important}
 h1{text-align:center}
 footer {
     visibility: hidden
 }
 '''

-client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
+ACCESS_TOKEN = os.getenv("HF_TOKEN")

+client = OpenAI(
+    base_url="https://api-inference.huggingface.co/v1/",
+    api_key=ACCESS_TOKEN,
+)
+
+# Mood prompts dictionary
 mood_prompts = {
     "Fun": "Respond in a light-hearted, playful manner.",
     "Serious": "Respond in a thoughtful, serious tone.",
@@ -62,74 +69,65 @@ mood_prompts = {
     "Worried": "Respond with concern and apprehension."
 }

-def format_prompt(message, history, system_prompt=None, mood=None):
-    prompt = "<s>"
-    if mood:
-        mood_description = mood_prompts.get(mood, "")
-        prompt += f"[SYS] {mood_description} [/SYS] "
-    for user_prompt, bot_response in history:
-        prompt += f"[INST] {user_prompt} [/INST]"
-        prompt += f" {bot_response}</s> "
-    if system_prompt:
-        prompt += f"[SYS] {system_prompt} [/SYS]"
-    prompt += f"[INST] {message} [/INST]"
-    return prompt
-
-def generate(
-    prompt, history, system_prompt=None, mood=None, temperature=0.2, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0,
+def respond(
+    message,
+    history: list[tuple[str, str]],
+    system_message,
+    max_tokens,
+    temperature,
+    top_p,
+    mood
 ):
-    temperature = float(temperature)
-    if temperature < 1e-2:
-        temperature = 1e-2
-    top_p = float(top_p)
-
-    generate_kwargs = dict(
-        temperature=temperature,
-        max_new_tokens=max_new_tokens,
-        top_p=top_p,
-        repetition_penalty=repetition_penalty,
-        do_sample=True,
-        seed=42,
-    )
-
-    formatted_prompt = format_prompt(prompt, history, system_prompt, mood)
-    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-    output = ""
+    # Update system message with mood prompt
+    mood_prompt = mood_prompts.get(mood, "")
+    full_system_message = f"{system_message} {mood_prompt}".strip()

-    for response in stream:
-        output += response.token.text
-        yield output
+    messages = [{"role": "system", "content": full_system_message}]

-    # Append the latest interaction to history in tuple format
-    history.append((prompt, output))
-    return history  # Return updated history for output
+    for val in history:
+        if val[0]:
+            messages.append({"role": "user", "content": val[0]})
+        if val[1]:
+            messages.append({"role": "assistant", "content": val[1]})

-def gradio_interface():
-    with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
-        # Initialize state for history
-        history = gr.State([])
+    messages.append({"role": "user", "content": message})

-        # Row for system prompt and user prompt
-        with gr.Row():
-            system_prompt = gr.Textbox(placeholder="System prompt (optional)", lines=1, visible=False)
-            prompt = gr.Textbox(placeholder="Enter your message", lines=4)
+    response = ""
+
+    for message in client.chat.completions.create(
+        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+        max_tokens=max_tokens,
+        stream=True,
+        temperature=temperature,
+        top_p=top_p,
+        messages=messages,
+    ):
+        token = message.choices[0].delta.content

-        # Row for generate button and output
-        with gr.Row():
-            generate_btn = gr.Button("Generate")
-            output = gr.Chatbot()
-
-        # Row for mood selection
-        with gr.Row():
-            mood = gr.Radio(choices=list(mood_prompts.keys()), value="Professional", label="Select Mood")
+        response += token
+        yield response

-        # Connect button click to generate function
-        generate_btn.click(
-            generate,
-            inputs=[prompt, history, system_prompt, mood],
-            outputs=[output]
+demo = gr.ChatInterface(
+    respond,
+    additional_inputs=[
+        gr.Textbox(value="", label="System message"),
+        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+        gr.Slider(
+            minimum=0.1,
+            maximum=1.0,
+            value=0.95,
+            step=0.05,
+            label="Top-P",
+        ),
+        gr.Dropdown(
+            choices=list(mood_prompts.keys()),
+            label="Mood",
+            value="Casual"
         )
-
-    demo.queue().launch(show_api=False)
-
-gradio_interface()
+    ],
+    css=css,
+    theme="bethecloud/storj_theme",
+)
+if __name__ == "__main__":
+    demo.launch()
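
A note on the new streaming loop: with OpenAI-compatible streaming responses, choices[0].delta.content can be None on some chunks (a role-only first delta, or the final chunk), and "response += token" then raises a TypeError. Below is a minimal sketch of the same loop with a guard; stream_reply is a hypothetical helper name introduced here for illustration, and it assumes the client and model configured in app.py above.

# Sketch only: the streaming section of respond(), with a None-content guard.
# stream_reply is a hypothetical name; client, messages, max_tokens,
# temperature, and top_p are assumed to be defined as in app.py above.
def stream_reply(client, messages, max_tokens, temperature, top_p):
    response = ""
    for chunk in client.chat.completions.create(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
        messages=messages,
    ):
        token = chunk.choices[0].delta.content
        if token:  # role-only and final chunks can carry None content
            response += token
            yield response

Naming the loop variable chunk rather than message also avoids shadowing respond's message argument. Note too that the Space needs HF_TOKEN set in its environment: os.getenv("HF_TOKEN") silently returns None when the variable is missing, and the endpoint will then reject the request as unauthenticated.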