Shreyas094 committed
Commit 072458d
Parent: 5f39768

Update app.py

Files changed (1): app.py (+11, -23)
app.py CHANGED
@@ -3,7 +3,7 @@ import logging
 import gradio as gr
 from transformers import pipeline
 
-from llama_cpp_agent.providers.llama_cpp_endpoint_provider import LlamaCppEndpointSettings
+from llama_cpp_agent.providers import LlamaCppPythonProvider
 from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
 from llama_cpp_agent.chat_history import BasicChatHistory
 from llama_cpp_agent.chat_history.messages import Roles
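Note: LlamaCppPythonProvider is the library's stock provider for in-process llama.cpp models. In llama-cpp-agent's documented usage it wraps a llama_cpp.Llama instance; a minimal sketch of that construction, with a hypothetical local GGUF path:

    from llama_cpp import Llama
    from llama_cpp_agent.providers import LlamaCppPythonProvider

    # Hypothetical model path; any local GGUF checkpoint works here.
    llama = Llama(model_path="models/mistral-7b-instruct-v0.3.gguf", n_ctx=4096)
    provider = LlamaCppPythonProvider(llama)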
@@ -21,7 +21,6 @@ from typing import List
 from langchain_community.llms import HuggingFaceHub
 
 huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
-
 examples = [
     ["latest news about Yann LeCun"],
     ["Latest news site:github.blog"],
@@ -37,6 +36,7 @@ def get_context_by_model(model_name):
 
 def get_messages_formatter_type(model_name):
     if model_name is None:
+        # Handle the case where model_name is None
         logging.warning("Model name is None. Defaulting to CHATML formatter.")
         return MessagesFormatterType.CHATML
 
@@ -46,17 +46,6 @@ def get_messages_formatter_type(model_name):
     else:
         return MessagesFormatterType.CHATML
 
-class HuggingFaceHubProvider(LlamaCppEndpointSettings):
-    def __init__(self, model):
-        self.model = model
-
-    def create_completion(self, prompt, **kwargs):
-        response = self.model(prompt)
-        return {'choices': [{'text': response['generated_text']}]}
-
-    def get_provider_default_settings(self):
-        return self.model.model_kwargs
-
 def get_model(temperature, top_p, repetition_penalty):
     return HuggingFaceHub(
         repo_id="mistralai/Mistral-7B-Instruct-v0.3",
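The deleted HuggingFaceHubProvider was a thin adapter: it exposed an arbitrary prompt-callable behind a create_completion method returning an OpenAI-style choices payload. A self-contained sketch of that pattern, with a hypothetical stand-in for the hosted model call:

    class CallableCompletionAdapter:
        # Wraps any prompt -> {"generated_text": ...} callable behind
        # a create_completion interface, as the removed class did.
        def __init__(self, model):
            self.model = model

        def create_completion(self, prompt, **kwargs):
            response = self.model(prompt)
            return {"choices": [{"text": response["generated_text"]}]}

    # Hypothetical stand-in for a remote inference call:
    def fake_llm(prompt):
        return {"generated_text": f"echo: {prompt}"}

    print(CallableCompletionAdapter(fake_llm).create_completion("hello"))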
@@ -94,7 +83,6 @@ class CitingSources(BaseModel):
 def write_message_to_user():
     return "Please write the message to the user."
 
-#@spaces.GPU(duration=120)
 def respond(
     message,
     history: list[tuple[str, str]],
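For context, the deleted comment referenced Hugging Face's ZeroGPU decorator, which reserves a GPU for the duration of each call. Its usual shape, sketched with a placeholder handler (the spaces package is only importable inside a Space):

    import spaces

    @spaces.GPU(duration=120)  # request a GPU slot for up to 120 s per call
    def respond(message: str) -> str:
        return message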
@@ -115,7 +103,7 @@ def respond(
     # Create a new model instance for each request
     llm = get_model(temperature, top_p, repeat_penalty)
 
-    provider = HuggingFaceHubProvider(llm)
+    provider = LlamaCppPythonProvider(llm)
     logging.info(f"Loaded chat examples: {chat_template}")
     search_tool = WebSearchTool(
         llm_provider=provider,
@@ -139,12 +127,12 @@
     )
 
     settings = provider.get_provider_default_settings()
-    settings['stream'] = False
-    settings['temperature'] = temperature
-    settings['top_k'] = top_k
-    settings['top_p'] = top_p
-    settings['max_tokens'] = max_tokens
-    settings['repeat_penalty'] = repeat_penalty
+    settings.stream = False
+    settings.temperature = temperature
+    settings.top_k = top_k
+    settings.top_p = top_p
+    settings.max_tokens = max_tokens
+    settings.repeat_penalty = repeat_penalty
 
     output_settings = LlmStructuredOutputSettings.from_functions(
         [search_tool.get_tool()]
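The settings object returned by LlamaCppPythonProvider.get_provider_default_settings() is configured through attributes rather than subscripting, which is what this hunk fixes. If code ever needs to tolerate both shapes, a small hypothetical helper can paper over the difference:

    def set_option(settings, key, value):
        # Mapping-style settings (the old dict) vs. attribute-style
        # settings objects (the new provider's defaults).
        if isinstance(settings, dict):
            settings[key] = value
        else:
            setattr(settings, key, value)

    set_option({}, "stream", False)  # works for either shape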
@@ -169,7 +157,7 @@
 
     outputs = ""
 
-    settings['stream'] = True
+    settings.stream = True
     response_text = answer_agent.get_chat_response(
         f"Write a detailed and complete research document that fulfills the following user request: '{message}', based on the information from the web below.\n\n" +
         result[0]["return_value"],
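With settings.stream = True the agent can emit tokens incrementally. A common consumption pattern in llama-cpp-agent apps (a sketch; returns_streaming_generator and print_output are the library's flags as I understand them, and answer_agent/settings come from the surrounding code in app.py) accumulates chunks so Gradio re-renders the partial answer:

    def stream_answer(answer_agent, prompt, settings):
        # Ask the agent for a generator of text chunks instead of one string.
        stream = answer_agent.get_chat_response(
            prompt,
            llm_sampling_settings=settings,
            returns_streaming_generator=True,
            print_output=False,
        )
        outputs = ""
        for chunk in stream:
            outputs += chunk
            yield outputs  # Gradio redraws on each partial string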
@@ -207,7 +195,7 @@ demo = gr.ChatInterface(
        gr.Dropdown([
            'Mistral-7B-Instruct-v0.3'
        ],
-       value="Mistral-7B-Instruct-v0.3", # This should match exactly
+       value="Mistral-7B-Instruct-v0.3", # Ensure this matches exactly with the option in the list
        label="Model"
        ),
        gr.Textbox(value=web_search_system_prompt, label="System message"),
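Gradio warns (and may ignore the default) when a Dropdown's value is not among its choices; deriving the default from the list keeps the two from drifting apart. A minimal standalone sketch:

    import gradio as gr

    choices = ["Mistral-7B-Instruct-v0.3"]
    # value taken from choices, so it always matches an option
    model_dropdown = gr.Dropdown(choices, value=choices[0], label="Model")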
 