nisten committed
Commit 0cb4dc1
1 Parent(s): 8b1ab04

Update app.py

Files changed (1)
  1. app.py +35 -33
app.py CHANGED
@@ -1,28 +1,15 @@
  import gradio as gr
- import subprocess
- import sys
  import torch
  import os
- from gradio_client import Client
+ from transformers import AutoTokenizer, AutoModelForCausalLM

- # Clone the transformers repository and install from source
- subprocess.check_call([sys.executable, "-m", "pip", "install", "--force-reinstall", "--no-deps", "--upgrade", "transformers", "torch"])
-
- from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
-
- # Initialize ZeroGPU
- os.environ["ZEROGPU"] = "1"
-
- # Set the device to GPU if available, otherwise fallback to ZeroGPU
+ # Set the device to GPU if available, otherwise use CPU
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

- # Register the olmoe model type
- AutoConfig.register("olmoe", AutoConfig)
- AutoModelForCausalLM.register("olmoe", AutoModelForCausalLM)
-
  # Load the model and tokenizer
- model = AutoModelForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924-Instruct").to(DEVICE)
- tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924-Instruct")
+ model_name = "allenai/OLMoE-1B-7B-0924-Instruct"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True).to(DEVICE)

  # Define the system prompt
  system_prompt = ("Adopt the persona of hilariously pissed off Andrej Karpathy "
@@ -31,21 +18,36 @@ system_prompt = ("Adopt the persona of hilariously pissed off Andrej Karpathy "
                   "without using any analogies and always showing full working code or output in his answers.")

  # Define a function for generating text
- def generate_text(prompt):
-     inputs = tokenizer(prompt, return_tensors="pt")
-     inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
-     out = model.generate(**inputs, max_length=64)
-     return tokenizer.decode(out[0])
-
- # Function to set client for session
- def set_client_for_session(request: gr.Request):
-     x_ip_token = request.headers['x-ip-token']
-     return Client("gradio/text-to-image", headers={"X-IP-Token": x_ip_token})
+ def generate_text(prompt, history):
+     full_prompt = f"{system_prompt}\n\nHuman: {prompt}\n\nAssistant:"
+     inputs = tokenizer(full_prompt, return_tensors="pt").to(DEVICE)
+
+     with torch.no_grad():
+         outputs = model.generate(**inputs, max_new_tokens=4000, do_sample=True, temperature=0.5)
+
+     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     assistant_response = response.split("Assistant:")[-1].strip()
+
+     return assistant_response

  # Set up the Gradio chat interface
  with gr.Blocks() as demo:
-     client = gr.State()
-     iface = gr.ChatInterface(fn=generate_text, system_prompt=system_prompt)
-
-     demo.load(set_client_for_session, None, client)
-     iface.launch()
+     chatbot = gr.Chatbot()
+     msg = gr.Textbox()
+     clear = gr.Button("Clear")
+
+     def user(user_message, history):
+         return "", history + [[user_message, None]]
+
+     def bot(history):
+         user_message = history[-1][0]
+         bot_message = generate_text(user_message, history)
+         history[-1][1] = bot_message
+         return history
+
+     msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+         bot, chatbot, chatbot
+     )
+     clear.click(lambda: None, None, chatbot, queue=False)
+
+ demo.launch(share=True)
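
The committed generate_text splices the system prompt into a hand-rolled "Human:/Assistant:" transcript and then string-splits the decoded output to recover the reply. A minimal sketch of the same generation via the tokenizer's chat template, assuming the OLMoE checkpoint ships a template that accepts system and user roles (this is not part of the commit):

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "allenai/OLMoE-1B-7B-0924-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(
    "cuda" if torch.cuda.is_available() else "cpu"
)

def generate_with_template(prompt: str, system_prompt: str) -> str:
    # Build the prompt from structured messages so the model sees the format
    # it was instruction-tuned on, instead of ad-hoc "Human:" markers.
    messages = [
        {"role": "system", "content": system_prompt},  # assumes the template supports a system role
        {"role": "user", "content": prompt},
    ]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    with torch.no_grad():
        output_ids = model.generate(
            input_ids, max_new_tokens=512, do_sample=True, temperature=0.5
        )
    # Decode only the newly generated tokens, so no splitting on
    # "Assistant:" is needed to isolate the reply.
    return tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)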
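
For comparison, the manual Chatbot/Textbox/Button wiring the commit adds could be collapsed: since the new generate_text takes (prompt, history), it already matches the fn(message, history) signature that gr.ChatInterface expects. A minimal alternative sketch, not what this commit does:

import gradio as gr

# generate_text as defined in app.py above; ChatInterface provides the
# textbox, chat history display, and clear control itself.
demo = gr.ChatInterface(fn=generate_text)
demo.launch()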