paralym committed on
Commit 8723271
1 Parent(s): c3ccc0f

Update app.py

Files changed (1)
  1. app.py +81 -91
app.py CHANGED
@@ -156,6 +156,22 @@ class InferenceDemo(object):
         self.conversation = conv_templates[args.conv_mode].copy()
         self.num_frames = args.num_frames
 
+class ChatSessionManager:
+    def __init__(self):
+        self.chatbot_instance = None
+
+    def initialize_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
+        self.chatbot_instance = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)
+        print(f"Initialized Chatbot instance with ID: {id(self.chatbot_instance)}")
+
+    def reset_chatbot(self):
+        self.chatbot_instance = None
+
+    def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
+        if self.chatbot_instance is None:
+            self.initialize_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
+        return self.chatbot_instance
+
 
 def is_valid_video_filename(name):
     video_extensions = ["avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"]
@@ -209,13 +225,6 @@ def load_image(image_file):
     return image
 
 
-def clear_history(history):
-
-    our_chatbot.conversation = conv_templates[our_chatbot.conv_mode].copy()
-
-    return None
-
-
 def clear_response(history):
     for index_conv in range(1, len(history)):
         # loop until get a text response from our model.
@@ -226,60 +235,80 @@ def clear_response(history):
         history = history[:-index_conv]
     return history, question
 
+chat_manager = ChatSessionManager()
+
+
+def clear_history(history):
+    chatbot_instance = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
+    chatbot_instance.conversation = conv_templates[chatbot_instance.conv_mode].copy()
+    return None
 
-# def print_like_dislike(x: gr.LikeData):
-#     print(x.index, x.value, x.liked)
 
 
 def add_message(history, message):
-    # history=[]
-    global our_chatbot
-    if len(history) == 0:
-        our_chatbot = InferenceDemo(
-            args, model_path, tokenizer, model, image_processor, context_len
-        )
+    global chat_image_num
+    if not history:
+        history = []
+        our_chatbot = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
+        chat_image_num = 0
+
+    if len(message["files"]) <= 1:
+        for x in message["files"]:
+            history.append(((x,), None))
+            chat_image_num += 1
+        if chat_image_num > 1:
+            history = []
+            chat_manager.reset_chatbot()
+            our_chatbot = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
+            chat_image_num = 0
+            for x in message["files"]:
+                history.append(((x,), None))
+                chat_image_num += 1
+
+        if message["text"] is not None:
+            history.append((message["text"], None))
+
+        print(f"### Chatbot instance ID: {id(our_chatbot)}")
+        return history, gr.MultimodalTextbox(value=None, interactive=False)
+    else:
+        for x in message["files"]:
+            history.append(((x,), None))
+        if message["text"] is not None:
+            history.append((message["text"], None))
 
-    for x in message["files"]:
-        history.append(((x,), None))
-    if message["text"] is not None:
-        history.append((message["text"], None))
-    return history, gr.MultimodalTextbox(value=None, interactive=False)
+        return history, gr.MultimodalTextbox(value=None, interactive=False)
 
 
 @spaces.GPU
 def bot(history, temperature, top_p, max_output_tokens):
-    print("### turn start history",history)
-    print("### turn start conv",our_chatbot.conversation)
+    our_chatbot = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
+    print(f"### Chatbot instance ID: {id(our_chatbot)}")
     text = history[-1][0]
     images_this_term = []
     text_this_term = ""
-    # import pdb;pdb.set_trace()
+
     num_new_images = 0
+    previous_image = False
     for i, message in enumerate(history[:-1]):
         if type(message[0]) is tuple:
-            if len(message[0])>1:
+            if previous_image:
                 gr.Warning("Only one image can be uploaded in a conversation. Please reduce the number of images and start a new conversation.")
-                return history
+                our_chatbot.conversation = conv_templates[our_chatbot.conv_mode].copy()
+                return None
+
+            images_this_term.append(message[0][0])
+            if is_valid_video_filename(message[0][0]):
+                raise ValueError("Video is not supported")
+                num_new_images += our_chatbot.num_frames
+            elif is_valid_image_filename(message[0][0]):
+                print("#### Load image from local file",message[0][0])
+                num_new_images += 1
             else:
-                images_this_term.append(message[0][0])
-                if is_valid_video_filename(message[0][0]):
-                    # 不接受视频
-                    raise ValueError("Video is not supported")
-                    num_new_images += our_chatbot.num_frames
-                elif is_valid_image_filename(message[0][0]):
-                    print("#### Load image from local file",message[0][0])
-                    num_new_images += 1
-                else:
-                    raise ValueError("Invalid image file")
+                raise ValueError("Invalid image file")
+            previous_image = True
         else:
             num_new_images = 0
-
-            # for message in history[-i-1:]:
-            #     images_this_term.append(message[0][0])
-
-            # assert len(images_this_term) > 0, "must have an image"
-            # image_files = (args.image_file).split(',')
-            # image = [load_image(f) for f in images_this_term if f]
+            previous_image = False
 
     all_image_hash = []
     all_image_path = []
@@ -323,9 +352,7 @@ def bot(history, temperature, top_p, max_output_tokens):
 
     image_tensor = torch.stack(image_tensor)
     image_token = DEFAULT_IMAGE_TOKEN * num_new_images
-    # if our_chatbot.model.config.mm_use_im_start_end:
-    #     inp = DEFAULT_IM_START_TOKEN + image_token + DEFAULT_IM_END_TOKEN + "\n" + inp
-    # else:
+
     inp = text
     inp = image_token + "\n" + inp
     our_chatbot.conversation.append_message(our_chatbot.conversation.roles[0], inp)
@@ -333,18 +360,6 @@ def bot(history, temperature, top_p, max_output_tokens):
     our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
     prompt = our_chatbot.conversation.get_prompt()
 
-
-    if len(images_this_term) > 1:
-        gr.Warning("Only one image can be uploaded in a conversation. Please reduce the number of images and start a new conversation.")
-        return history
-
-    # input_ids = (
-    #     tokenizer_image_token(
-    #         prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
-    #     )
-    #     .unsqueeze(0)
-    #     .to(our_chatbot.model.device)
-    # )
     input_ids = tokenizer_image_token(
         prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
     ).unsqueeze(0).to(our_chatbot.model.device)
@@ -358,9 +373,7 @@ def bot(history, temperature, top_p, max_output_tokens):
     stopping_criteria = KeywordsStoppingCriteria(
         keywords, our_chatbot.tokenizer, input_ids
     )
-    # streamer = TextStreamer(
-    #     our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
-    # )
+
     streamer = TextIteratorStreamer(
         our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
@@ -368,27 +381,6 @@ def bot(history, temperature, top_p, max_output_tokens):
     print(input_ids.device)
     print(image_tensor.device)
 
-    # with torch.inference_mode():
-    #     output_ids = our_chatbot.model.generate(
-    #         input_ids,
-    #         images=image_tensor,
-    #         do_sample=True,
-    #         temperature=0.7,
-    #         top_p=1.0,
-    #         max_new_tokens=4096,
-    #         streamer=streamer,
-    #         use_cache=False,
-    #         stopping_criteria=[stopping_criteria],
-    #     )
-
-    #     outputs = our_chatbot.tokenizer.decode(output_ids[0]).strip()
-    #     if outputs.endswith(stop_str):
-    #         outputs = outputs[: -len(stop_str)]
-    #     our_chatbot.conversation.messages[-1][-1] = outputs
-
-    #     history[-1] = [text, outputs]
-
-    #     return history
     generate_kwargs = dict(
         inputs=input_ids,
         streamer=streamer,
@@ -407,13 +399,12 @@ def bot(history, temperature, top_p, max_output_tokens):
     outputs = []
     for stream_token in streamer:
         outputs.append(stream_token)
-        # print("### stream_token",stream_token)
-        # our_chatbot.conversation.messages[-1][-1] = "".join(outputs)
+
         history[-1] = [text, "".join(outputs)]
         yield history
     our_chatbot.conversation.messages[-1][-1] = "".join(outputs)
-    print("### turn end history", history)
-    print("### turn end conv",our_chatbot.conversation)
+    # print("### turn end history", history)
+    # print("### turn end conv",our_chatbot.conversation)
 
     with open(get_conv_log_filename(), "a") as fout:
         data = {
@@ -677,11 +668,10 @@ with gr.Blocks(
     gr.Markdown(learn_more_markdown)
     gr.Markdown(bibtext)
 
-    chat_msg = chat_input.submit(
-        add_message, [chatbot, chat_input], [chatbot, chat_input]
-    )
-    bot_msg = chat_msg.then(bot, [chatbot,temperature, top_p, max_output_tokens], chatbot, api_name="bot_response")
-    bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
+    chat_input.submit(
+        add_message, [chatbot, chat_input], [chatbot, chat_input]
+    ).then(bot, [chatbot, temperature, top_p, max_output_tokens], chatbot, api_name="bot_response").then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
+
 
     # chatbot.like(print_like_dislike, None, None)
     clear_btn.click(
@@ -727,5 +717,5 @@ if __name__ == "__main__":
     model_name = get_model_name_from_path(args.model_path)
     tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit)
     model=model.to(torch.device('cuda'))
-    our_chatbot = None
+    chat_image_num = 0
     demo.launch()
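
The core of this change is replacing the module-level `our_chatbot` global with a lazily initialized `ChatSessionManager`. The following is a minimal, self-contained sketch of that pattern; `Engine`, `build_engine`, and `SessionManager` are illustrative stand-ins for this note, not symbols from app.py.

# Minimal sketch of the lazy-initialization pattern behind ChatSessionManager.
# Engine and build_engine stand in for the expensive InferenceDemo construction.

class Engine:
    """Stand-in for the heavy inference object."""
    pass


def build_engine():
    return Engine()


class SessionManager:
    def __init__(self, factory):
        self._factory = factory   # callable that builds the heavy object
        self._instance = None     # nothing is built until first use

    def get(self):
        if self._instance is None:           # lazy initialization on first access
            self._instance = self._factory()
        return self._instance

    def reset(self):
        self._instance = None                # next get() builds a fresh instance


if __name__ == "__main__":
    manager = SessionManager(build_engine)
    first = manager.get()   # built here, on first access
    again = manager.get()   # the same instance is reused
    manager.reset()         # e.g. when a new conversation must start clean
    fresh = manager.get()   # a new instance replaces the old one
    assert first is again and first is not fresh

Keeping construction behind get() means the expensive object is only created when a request actually needs it, and reset() gives every new conversation a clean state instead of mutating a shared global.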