ysharma (HF staff) committed
Commit 8ca72f4
1 Parent(s): 785765a

Update app.py

Files changed (1): app.py (+21 -119)
app.py CHANGED
@@ -18,84 +18,23 @@ PLACEHOLDER = """
 </div>
 """
 
-model_id_llama3 = "xtuner/llava-llama-3-8b-v1_1-transformers"
-model_id_phi3 = "xtuner/llava-phi-3-mini-hf"
+model_id = "xtuner/llava-llama-3-8b-v1_1-transformers"
 
-processor = AutoProcessor.from_pretrained(model_id_llama3)
-processor = AutoProcessor.from_pretrained(model_id_phi3)
+processor = AutoProcessor.from_pretrained(model_id)
 
-model_llama3 = LlavaForConditionalGeneration.from_pretrained(
-    model_id_llama3,
-    torch_dtype=torch.float16,
-    low_cpu_mem_usage=True,
-)
-model_llama3.to("cuda:0")
-model_llama3.generation_config.eos_token_id = 128009
-
 
-model_phi3 = LlavaForConditionalGeneration.from_pretrained(
-    model_id_phi3,
+model = LlavaForConditionalGeneration.from_pretrained(
+    model_id,
     torch_dtype=torch.float16,
     low_cpu_mem_usage=True,
 )
-model_phi3.to("cuda:0")
-model_phi3.generation_config.eos_token_id = 128009
+model.to("cuda:0")
+model.generation_config.eos_token_id = 128009
 
-
-@spaces.GPU
-def bot_streaming_llama3(message, history):
-    print(message)
-    if message["files"]:
-        # message["files"][-1] is a Dict or just a string
-        if type(message["files"][-1]) == dict:
-            image = message["files"][-1]["path"]
-        else:
-            image = message["files"][-1]
-    else:
-        # if there's no image uploaded for this turn, look for images in the past turns
-        # kept inside tuples, take the last one
-        for hist in history:
-            if type(hist[0]) == tuple:
-                image = hist[0][0]
-    try:
-        if image is None:
-            # Handle the case where image is None
-            gr.Error("You need to upload an image for LLaVA to work.")
-    except NameError:
-        # Handle the case where 'image' is not defined at all
-        gr.Error("You need to upload an image for LLaVA to work.")
-
-    prompt = f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{message['text']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
-    # print(f"prompt: {prompt}")
-    image = Image.open(image)
-    inputs = processor(prompt, image, return_tensors='pt').to(0, torch.float16)
-
-    streamer = TextIteratorStreamer(processor, **{"skip_special_tokens": False, "skip_prompt": True})
-    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024, do_sample=False)
-
-    thread = Thread(target=model_llama3.generate, kwargs=generation_kwargs)
-    thread.start()
-
-    text_prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{message['text']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
-    # print(f"text_prompt: {text_prompt}")
-
-    buffer = ""
-    time.sleep(0.5)
-    for new_text in streamer:
-        # find <|eot_id|> and remove it from the new_text
-        if "<|eot_id|>" in new_text:
-            new_text = new_text.split("<|eot_id|>")[0]
-        buffer += new_text
-
-        # generated_text_without_prompt = buffer[len(text_prompt):]
-        generated_text_without_prompt = buffer
-        # print(generated_text_without_prompt)
-        time.sleep(0.06)
-        # print(f"new_text: {generated_text_without_prompt}")
-        yield generated_text_without_prompt
 
 
 @spaces.GPU
-def bot_streaming_phi3(message, history):
+def bot_streaming(message, history):
     print(message)
     if message["files"]:
         # message["files"][-1] is a Dict or just a string
@@ -125,7 +64,7 @@ def bot_streaming_phi3(message, history):
     streamer = TextIteratorStreamer(processor, **{"skip_special_tokens": False, "skip_prompt": True})
     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024, do_sample=False)
 
-    thread = Thread(target=model_phi3.generate, kwargs=generation_kwargs)
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
 
     text_prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{message['text']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
@@ -147,57 +86,20 @@ def bot_streaming_phi3(message, history):
         yield generated_text_without_prompt
 
 
-def print_like_dislike(x: gr.LikeData):
-    print(x.index, x.value, x.liked)
-#chatbot=gr.Chatbot(placeholder=PLACEHOLDER,scale=1)
-#chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
-
+chatbot=gr.Chatbot(placeholder=PLACEHOLDER,scale=1)
+chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
 with gr.Blocks(fill_height=True, ) as demo:
-    with gr.Row():
-        chatbot1 = gr.Chatbot(
-            [],
-            elem_id="llama3",
-            bubble_full_width=False,
-            label='LLaVa-Llama3'
-        )
-        chatbot2 = gr.Chatbot(
-            [],
-            elem_id="phi3",
-            bubble_full_width=False,
-            label='LLaVa-Phi3'
+    gr.ChatInterface(
+        fn=bot_streaming,
+        title="LLaVA Llama-3-8B",
+        examples=[{"text": "What is on the flower?", "files": ["./bee.jpg"]},
+                  {"text": "How to make this pastry?", "files": ["./baklava.png"]}],
+        description="Try [LLaVA Llama-3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",
+        stop_btn="Stop Generation",
+        multimodal=True,
+        textbox=chat_input,
+        chatbot=chatbot,
     )
-
-    chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
-
-    gr.Examples(examples=[{"text": "What is on the flower?", "files": ["./bee.png"]},
-                          {"text": "How to make this pastry?", "files": ["./baklava.png"]},],
-                inputs=chat_input)
-
-    #chat_input.submit(lambda: gr.MultimodalTextbox(interactive=False), None, [chat_input]).then(bot_streaming_llama3, [chat_input, chatbot1,], [chatbot1,])
-
-    chat_msg1 = chat_input.submit(bot_streaming_llama3, [chat_input, chatbot1,], [chatbot1,])
-    chat_msg2 = chat_input.submit(bot_streaming_phi3, [chat_input, chatbot2,], [chatbot2,])
-
-    #bot_msg1 = chat_msg1.then(bot, chatbot1, chatbot1, api_name="bot_response1")
-    #chat_msg1.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
-    #bot_msg2 = chat_msg2.then(bot, chatbot2, chatbot2, api_name="bot_response2")
-    #bot_msg2.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
-
-    chatbot1.like(print_like_dislike, None, None)
-    chatbot2.like(print_like_dislike, None, None)
-
-
-    #gr.ChatInterface(
-    #fn=bot_streaming_llama3,
-    #title="LLaVA Llama-3-8B",
-    #examples=[{"text": "What is on the flower?", "files": ["./bee.jpg"]},
-    #          {"text": "How to make this pastry?", "files": ["./baklava.png"]}],
-    #description="Try [LLaVA Llama-3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",
-    #stop_btn="Stop Generation",
-    #multimodal=True,
-    #textbox=chat_input,
-    #chatbot=chatbot,
-    #)
 
 demo.queue(api_open=False)
-demo.launch(show_api=False, share=False)
+demo.launch(show_api=False, share=False)
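Note: both the removed handlers and the surviving bot_streaming stream tokens the same way — model.generate runs on a background thread while a TextIteratorStreamer is drained inside a generator function. A minimal self-contained sketch of that pattern follows; it uses a hypothetical distilgpt2 stand-in checkpoint (so it runs on CPU) rather than the LLaVA model and AutoProcessor this app loads.

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Hypothetical stand-in checkpoint; app.py loads the LLaVA model and an
# AutoProcessor instead, but the streaming mechanics are identical.
checkpoint = "distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

def stream_reply(prompt: str):
    inputs = tokenizer(prompt, return_tensors="pt")
    # skip_prompt=True keeps the echoed prompt out of the stream, as in
    # bot_streaming. generate() blocks, so it runs on a worker thread
    # while this generator drains the streamer chunk by chunk.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=64, do_sample=False)
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer  # yield the accumulated text so the UI shows a growing reply

for partial in stream_reply("The quick brown fox"):
    print(partial)
```

Yielding the whole buffer each time (rather than the delta) is what lets Gradio re-render the partial reply in place on every step.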
 
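For reference, with multimodal=True the fn passed to gr.ChatInterface receives the textbox value as a dict with "text" and "files" keys — the shape bot_streaming unpacks above. A minimal sketch under that assumption (Gradio 4.x; echo is a made-up stand-in handler, not part of this commit):

```python
import gradio as gr

# Toy stand-in for bot_streaming: with multimodal=True, `message` arrives as
# {"text": str, "files": [...]} rather than a plain string.
def echo(message, history):
    yield f"You said {message['text']!r} and attached {len(message['files'])} file(s)."

demo = gr.ChatInterface(fn=echo, multimodal=True, title="Multimodal echo")
demo.queue(api_open=False)
demo.launch(show_api=False, share=False)
```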