Royrotem100 committed on
Commit b7358fc
1 Parent(s): 14bbdd9

Changes app.py to run as a server

Files changed (1)
  1. app.py +34 -17
app.py CHANGED
@@ -1,12 +1,13 @@
 import os
 import gradio as gr
 from http import HTTPStatus
-import openai
 from typing import Generator, List, Optional, Tuple, Dict
 from urllib.error import HTTPError
 from flask import Flask, request, jsonify
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import threading
+import requests
+
 
 # Load the model and tokenizer
 tokenizer = AutoTokenizer.from_pretrained("./dictalm2.0-instruct")
@@ -31,34 +32,50 @@ def messages_to_history(messages: Messages) -> Tuple[str, History]:
         history.append([q['content'], r['content']])
     return history
 
-def model_chat(query: Optional[str], history: Optional[History]) -> Generator[Tuple[str, History], None, None]:
-    if query is None:
-        query = ''
-    if history is None:
-        history = []
-    if not query.strip():
-        return
-    messages = history_to_messages(history)
-    messages.append({'role': 'user', 'content': query.strip()})
+
+# Flask app setup
+app = Flask(__name__)
+
+@app.route('/predict', methods=['POST'])
+def predict():
+    data = request.json
+    input_text = data.get('text', '')
 
-    # Combine all messages into one formatted input text
-    formatted_text = "<s>" + "".join(f"[INST] {m['content']} [/INST]" for m in messages if m['role'] == 'user')
+    # Format the input text with instruction tokens
+    formatted_text = f"<s>[INST] {input_text} [/INST]"
+
+    # Tokenize the input
     inputs = tokenizer(formatted_text, return_tensors='pt')
 
     # Generate the output
     outputs = model.generate(inputs['input_ids'], max_length=1024, temperature=0.7, top_p=0.9)
-    full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-    # Simulate streaming by yielding the response in chunks
-    chunk_size = 20  # You can adjust the chunk size
-    for i in range(0, len(full_response), chunk_size):
-        yield full_response[i:i+chunk_size]
+    # Decode the output
+    prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    return jsonify({"prediction": prediction})
 
 def run_flask():
     app.run(host='0.0.0.0', port=5000)
 
+
 # Run Flask in a separate thread
 threading.Thread(target=run_flask).start()
+def model_chat(query: Optional[str], history: Optional[History]) -> Generator[Tuple[str, History], None, None]:
+    if query is None:
+        query = ''
+    if history is None:
+        history = []
+    if not query.strip():
+        return
+
+    response = requests.post("http://127.0.0.1:5000/predict", json={"text": query.strip()})
+    if response.status_code == 200:
+        prediction = response.json().get("prediction", "")
+        history.append((query, prediction))
+        yield prediction, history
+    else:
+        yield "Error: Unable to get a response from the model.", history
 
 
 with gr.Blocks(css='''
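
After this change, the Gradio model_chat handler is only an HTTP client of the local Flask server, so the new /predict route can also be exercised on its own. Below is a minimal sketch of such a call, assuming app.py is already running and listening on 127.0.0.1:5000 as configured above; the prompt text and timeout value are illustrative and not part of the commit.

import requests

# POST to the /predict route added in this commit; per the handler above it
# expects a JSON body {"text": ...} and returns {"prediction": ...}.
resp = requests.post(
    "http://127.0.0.1:5000/predict",
    json={"text": "Example prompt"},  # illustrative prompt, not from the commit
    timeout=120,                      # illustrative timeout; generation can be slow
)
resp.raise_for_status()
print(resp.json().get("prediction", ""))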