BeveledCube committed on
Commit
5516522
1 Parent(s): dda9b20
Files changed (3)
  1. .gitignore +2 -1
  2. main.py +4 -4
  3. models/mixtral.py +9 -15
.gitignore CHANGED
@@ -1 +1,2 @@
-client.js
+client.js
+__pycache__
main.py CHANGED
@@ -16,14 +16,14 @@ def test_route():
 @app.route("/api", methods=["POST"])
 def receive_data():
     data = request.get_json()
+
     print("Prompt:", data["prompt"])
-    print("System:", data["prompt"])
+    print("System:", data["system"])
 
-    generated_text = mixtral.generate("helo", "You are friendly", ["helo"], False, False)
+    generated_text = mixtral.generate(data["prompt"], data["system"], data["history"], False, None)
 
-    answer_data = { "answer": generated_text }
     print("Response:", generated_text)
 
-    return jsonify(answer_data)
+    return { "answer": generated_text }
 
 app.run(host="0.0.0.0", port=7860, debug=False)
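
For reference, the reworked receive_data() now reads the prompt, system and history fields from the request JSON and returns the answer as a plain dict (which Flask serialises to JSON), so jsonify() is no longer needed. A minimal client sketch, assuming the service runs on the host/port passed to app.run() and that history is a flat list of earlier user prompts (as models/mixtral.py below iterates over it):

import requests

# Hypothetical payload: the field names match what receive_data() reads,
# the values are made up for illustration.
payload = {
    "prompt": "What is the capital of France?",
    "system": "You are a friendly assistant.",
    "history": ["hello"],
}

resp = requests.post("http://localhost:7860/api", json=payload)
print(resp.json()["answer"])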
models/mixtral.py CHANGED
@@ -9,20 +9,20 @@ def split_list(lst, chunk_size):
 
 def format_prompt(message, history, system_prompt):
     prompt = f"<s>[INST] <<SYS>>{system_prompt}<</SYS>> [/INST] </s>" if system_prompt else "<s>"
-    for user_prompt, bot_response in history:
+
+    for user_prompt in history:
         prompt += f"[INST] {user_prompt} [/INST]"
-        prompt += f" {bot_response}</s> "
+
     prompt += f"[INST] {message} [/INST]"
+
     return prompt
 
 def generate(
     prompt, system_prompt, history, shouldoverridehistory, historyoverride, max_new_tokens=1024, temperature=1.2, top_p=0.95, repetition_penalty=1.0,
 ):
-    print(history)
-    print(historyoverride)
     temperature = float(temperature)
     if temperature < 1e-2:
-        temperature = 1e-2
+        temperature = 1e-2
     top_p = float(top_p)
 
     generate_kwargs = dict(
@@ -37,15 +37,9 @@ def generate(
     if shouldoverridehistory:
         history = split_list(historyoverride[0], 2)
 
-    print(history)
-
     formatted_prompt = format_prompt(prompt, history, system_prompt)
+    print(formatted_prompt)
 
-    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-    output = ""
-
-    for response in stream:
-        output += response.token.text
-        yield output
-
-    return output
+    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=False, details=True, return_full_text=False)
+
+    return stream
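
With stream=False, generate() now returns the result of client.text_generation() directly instead of yielding partial outputs token by token. A usage sketch, assuming client is a huggingface_hub InferenceClient (suggested by the stream/details/return_full_text keywords); with details=True the call returns a details object rather than a bare string, so the caller may need to pull the text out of its generated_text field:

from models.mixtral import generate

# Arguments mirror the call in main.py; the concrete strings are illustrative.
result = generate(
    "What is the capital of France?",
    "You are a friendly assistant.",
    ["hello"],
    False,
    None,
)

# Assumption: with details=True and stream=False the return value carries the
# text under .generated_text; fall back to the raw value otherwise.
answer = getattr(result, "generated_text", result)
print("Answer:", answer)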