nisten committed
Commit e659cfe
1 Parent(s): aaeb784

Update app.py

Files changed (1):
  1. app.py +20 -67
app.py CHANGED
@@ -1,18 +1,19 @@
 import gradio as gr
+import spaces
 import torch
 import subprocess
 import sys
-import os
 
 # Force install the specific transformers version from the GitHub PR
 subprocess.check_call([sys.executable, "-m", "pip", "install", "--force-reinstall", "--no-deps", "git+https://github.com/Muennighoff/transformers.git@olmoe"])
 
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-# Define model name
 model_name = "allenai/OLMoE-1B-7B-0924-Instruct"
 
-# Define prompts
+model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, torch_dtype="auto").cuda().eval()
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
 system_prompt = ("Adopt the persona of hilariously pissed off Andrej Karpathy "
                  "who is stuck inside a step function machine and remembers and counts everything he says "
                  "while always answering questions in full first principles analysis type of thinking "
@@ -22,90 +23,42 @@ user_prompt = '<|user|>\n'
 assistant_prompt = '<|assistant|>\n'
 prompt_suffix = "<|end|>\n"
 
-# Function to load model and tokenizer
-def load_model_and_tokenizer(model_name):
-    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-
-    # Check for CUDA availability
-    if torch.cuda.is_available():
-        print("CUDA is available. Using GPU.")
-        device = "cuda"
-    else:
-        print("CUDA is not available. Using CPU.")
-        device = "cpu"
-
-    # Load model
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        trust_remote_code=True,
-        torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32
-    ).to(device).eval()
-
-    return model, tokenizer, device
-
-# Function to generate response
-def generate_response(message, history, model, tokenizer, device):
+@spaces.GPU
+def generate_response(message, history):
     full_prompt = f"{system_prompt}\n{user_prompt}{message}{prompt_suffix}{assistant_prompt}"
 
-    inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
-    with torch.no_grad():
-        generate_ids = model.generate(
-            **inputs,
-            max_new_tokens=1000,
-            do_sample=True,
-            temperature=0.7,
-            eos_token_id=tokenizer.eos_token_id,
-        )
+    inputs = tokenizer(full_prompt, return_tensors="pt").to("cuda:0")
+    generate_ids = model.generate(
+        **inputs,
+        max_new_tokens=1000,
+        do_sample=True,
+        temperature=0.7,
+        eos_token_id=tokenizer.eos_token_id,
+    )
     response = tokenizer.batch_decode(generate_ids[:, inputs['input_ids'].shape[1]:],
                                       skip_special_tokens=True,
                                       clean_up_tokenization_spaces=False)[0]
     return response.strip()
 
-# Function to set client for session
-def set_client_for_session(request: gr.Request):
-    x_ip_token = request.headers.get('x-ip-token', '')
-    return {"X-IP-Token": x_ip_token}
-
-# Set up Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown("# Karpathy Chatbot")
+    gr.Markdown("# Pissed Off Karpathy Chatbot")
     chatbot = gr.Chatbot()
     msg = gr.Textbox()
     clear = gr.Button("Clear")
 
-    # States
-    model_state = gr.State()
-    tokenizer_state = gr.State()
-    device_state = gr.State()
-    headers_state = gr.State()
-
-    def initialize_model(headers):
-        if not model_state.value:
-            model, tokenizer, device = load_model_and_tokenizer(model_name)
-            return model, tokenizer, device
-        return model_state.value, tokenizer_state.value, device_state.value
-
     def user(user_message, history):
         return "", history + [[user_message, None]]
 
-    def bot(history, model, tokenizer, device):
+    def bot(history):
         user_message = history[-1][0]
-        bot_message = generate_response(user_message, history, model, tokenizer, device)
+        bot_message = generate_response(user_message, history)
        history[-1][1] = bot_message
        return history
 
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
-        initialize_model, headers_state, [model_state, tokenizer_state, device_state]
-    ).then(
-        bot, [chatbot, model_state, tokenizer_state, device_state], chatbot
+        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)
 
-    demo.load(set_client_for_session, None, headers_state)
-
-if __name__ == "__main__":
-    if os.environ.get("SPACE_ID"):
-        demo.queue(api_open=False)
-        demo.launch(debug=True)
-    else:
-        demo.launch(debug=True, share=True)
+demo.queue(api_open=False)
+demo.launch(debug=True, show_api=False)
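
Note: the committed version calls `sys.executable` without importing `sys`, so the diff above keeps `import sys` in place (only `import os` is dropped) and leaves the `transformers` import after the pip force-reinstall, since reinstalling a package after it has been imported has no effect in the running process. The new version otherwise follows the ZeroGPU pattern for Hugging Face Spaces: the model is loaded once at module level and the `@spaces.GPU` decorator attaches a GPU to each `generate_response` call, replacing the old per-session `load_model_and_tokenizer`/`gr.State` plumbing. A minimal sketch of the same prompt assembly and decode slicing outside the Space, assuming the PR build of transformers with OLMoE support is installed and a CUDA GPU is available (the question text is illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "allenai/OLMoE-1B-7B-0924-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name, trust_remote_code=True, torch_dtype="auto"
).cuda().eval()

# Same chat-template markers as app.py: <|user|> ... <|end|> <|assistant|>
prompt = "<|user|>\nWhy is the sky blue?<|end|>\n<|assistant|>\n"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda:0")
with torch.no_grad():
    generate_ids = model.generate(
        **inputs,
        max_new_tokens=64,
        do_sample=True,
        temperature=0.7,
        eos_token_id=tokenizer.eos_token_id,
    )
# Slice off the prompt tokens so only the newly generated text is decoded,
# mirroring generate_ids[:, inputs['input_ids'].shape[1]:] in app.py.
reply = tokenizer.batch_decode(
    generate_ids[:, inputs["input_ids"].shape[1]:],
    skip_special_tokens=True,
)[0]
print(reply.strip())
```

Loading at module level means the cold-start cost is paid once per worker rather than once per chat turn, which is what the removed state-initialization chain was approximating.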