nisten committed on
Commit 24b2580
1 Parent(s): c290e7a

Update app.py

Files changed (1)
  1. app.py +34 -66
app.py CHANGED
@@ -4,23 +4,13 @@ import torch
 import subprocess
 import sys
 from threading import Thread
-from transformers import AutoTokenizer, TextIteratorStreamer
-import numpy as np
-
-from queue import Queue
-from threading import Event
+from transformers import TextIteratorStreamer
 
 # Install required packages
-subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "--force-reinstall", "librosa", "parler_tts", "melotts", "funasr", "sounddevice", "nltk", "sounddevice", "deepfilternet", "einops", "accelerate", "ChatTTS", "git+https://github.com/Muennighoff/transformers.git@olmoe"])
+subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "--force-reinstall", "--no-deps", "einops", "accelerate", "git+https://github.com/Muennighoff/transformers.git@olmoe"])
 subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
-# Import the custom OlmoeForCausalLM
-from transformers import OlmoeForCausalLM
-import librosa
-# Import speech-to-speech components
-from VAD.vad_handler import VADHandler
-from STT.whisper_stt_handler import WhisperSTTHandler
-from TTS.parler_handler import ParlerTTSHandler
+from transformers import OlmoeForCausalLM, AutoTokenizer
 
 model_name = "allenai/OLMoE-1B-7B-0924-Instruct"
 
@@ -30,12 +20,12 @@ try:
     model = OlmoeForCausalLM.from_pretrained(
         model_name,
         trust_remote_code=True,
-        torch_dtype=torch.float16,
+        torch_dtype=torch.float16,  # Using float16 for lower precision
         low_cpu_mem_usage=True,
         device_map="auto",
-        _attn_implementation="flash_attention_2"
+        _attn_implementation="flash_attention_2"  # Enable Flash Attention 2
     ).to(DEVICE)
-    model.gradient_checkpointing_enable()
+    model.gradient_checkpointing_enable()  # Enable gradient checkpointing
     tokenizer = AutoTokenizer.from_pretrained(model_name)
 except Exception as e:
     print(f"Error loading model: {e}")
@@ -47,32 +37,10 @@ system_prompt = ("Adopt the persona of hilariously pissed off Andrej Karpathy "
                  "while always answering questions in full first principles analysis type of thinking "
                  "without using any analogies and always showing full working code or output in his answers.")
 
-# Setup speech-to-speech components
-stop_event = Event()
-should_listen = Event()
-vad = VADHandler(stop_event, Queue(), Queue(), setup_args=(should_listen,))
-stt = WhisperSTTHandler(stop_event, Queue(), Queue())
-tts = ParlerTTSHandler(stop_event, Queue(), Queue(), setup_args=(should_listen,))
-
-@spaces.GPU
-def speech_to_text(audio):
-    if audio is None:
-        return ""
-    audio_np = librosa.resample(audio[1], orig_sr=audio[0], target_sr=16000)
-    audio_np = (audio_np * 32768).astype(np.int16)
-
-    vad_output = vad.process(audio_np)
-    stt_output, _ = next(stt.process(vad_output))
-    return stt_output
-
-@spaces.GPU
-def user(user_message, history):
-    return "", history + [[user_message, None]]
-
 @spaces.GPU
-def bot(history, temperature, max_new_tokens):
+def generate_response(message, history, temperature, max_new_tokens):
     if model is None or tokenizer is None:
-        yield history
+        yield "Model or tokenizer not loaded properly. Please check the logs."
         return
 
     messages = [{"role": "system", "content": system_prompt}]
@@ -80,6 +48,7 @@ def bot(history, temperature, max_new_tokens):
         messages.append({"role": "user", "content": user_msg})
         if assistant_msg:
             messages.append({"role": "assistant", "content": assistant_msg})
+    messages.append({"role": "user", "content": message})
 
     inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(DEVICE)
 
@@ -97,53 +66,52 @@ def bot(history, temperature, max_new_tokens):
         thread = Thread(target=model.generate, kwargs=generation_kwargs)
         thread.start()
 
-        generated_text = ""
+        partial_message = ""
         for new_text in streamer:
-            generated_text += new_text
-            history[-1][1] = generated_text
-            yield history
-
+            partial_message += new_text
+            yield partial_message.strip()
+
+    except RuntimeError as e:
+        if "CUDA out of memory" in str(e):
+            yield "GPU memory exceeded. Try reducing the max tokens or using a smaller model."
+        else:
+            yield f"An error occurred: {str(e)}"
     except Exception as e:
-        history[-1][1] = f"An error occurred: {str(e)}"
-        yield history
-
-def text_to_speech(text):
-    audio_output = np.concatenate(list(tts.process(text)))
-    return (16000, audio_output)
+        yield f"An unexpected error occurred: {str(e)}"
 
 css = """
 #output {
-  height: 1000px;
+  height: 1100px;
   overflow: auto;
-  border: 2px solid #ccc;
+  border: 3px solid #ccc;
 }
 """
 
 with gr.Blocks(css=css) as demo:
-    gr.Markdown("# Nisten's Karpathy Chatbot with OSS OLMoE (Now with Speech-to-Speech!)")
+    gr.Markdown("# Nisten's Karpathy Chatbot with OSS OLMoE (Now with Flash Attention 2 and Streaming!)")
     chatbot = gr.Chatbot(elem_id="output")
-    audio_input = gr.Audio(source="microphone", type="numpy")
-    audio_output = gr.Audio()
    msg = gr.Textbox(label="Meow")
    with gr.Row():
        temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
        max_new_tokens = gr.Slider(minimum=50, maximum=4000, value=2000, step=50, label="Max New Tokens")
    clear = gr.Button("Clear")
 
-    def process_audio(audio, history, temp, max_tokens):
-        text = speech_to_text(audio)
-        history = history + [[text, None]]
-        for new_history in bot(history, temp, max_tokens):
-            yield new_history, text_to_speech(new_history[-1][1])
+    def user(user_message, history):
+        return "", history + [[user_message, None]]
+
+    def bot(history, temp, max_tokens):
+        user_message = history[-1][0]
+        bot_message = ""
+        for token in generate_response(user_message, history[:-1], temp, max_tokens):
+            bot_message = token
+            history[-1][1] = bot_message
+            yield history
 
-    audio_input.stop_recording(process_audio, [audio_input, chatbot, temperature, max_new_tokens], [chatbot, audio_output])
     msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
         bot, [chatbot, temperature, max_new_tokens], chatbot
-    ).then(
-        lambda history: text_to_speech(history[-1][1]), chatbot, audio_output
     )
     clear.click(lambda: None, None, chatbot, queue=False)
 
 if __name__ == "__main__":
-    demo.queue(api_open=True, max_size=10)
-    demo.launch(debug=True, show_api=True, share=False)
+    demo.queue(api_open=True, max_size=10)  # Limiting queue size
+    demo.launch(debug=True, show_api=True, share=False)  # Disabled sharing for security