sandz7 committed
Commit 24384a7 • 1 Parent(s): d69619f

Update app.py

Files changed (1)
  1. app.py +40 -54
app.py CHANGED
@@ -28,70 +28,56 @@ processor = AutoProcessor.from_pretrained(model_id)
 # Confirming and setting the eos_token_id (if necessary)
 model.generation_config.eos_token_id = processor.tokenizer.eos_token_id
 
-@spaces.GPU(duration=120)
-def krypton(input, history):
-    print(f"Input: {input}") # Debug input
-    print(f"History: {history}") # Debug history
-
-    image_path = None
-    if input["files"]:
-        print("Found the image")
-        image_path = input["files"][-1]["path"] if isinstance(input["files"][-1], dict) else input["files"][-1]
-        print(f"Image path: {image_path}")
+@spaces.GPU
+def bot_streaming(message, history):
+    print(message)
+    if message["files"]:
+        # message["files"][-1] is a Dict or just a string
+        if type(message["files"][-1]) == dict:
+            image = message["files"][-1]["path"]
+        else:
+            image = message["files"][-1]
     else:
+        # if there's no image uploaded for this turn, look for images in the past turns
+        # kept inside tuples, take the last one
         for hist in history:
-            if isinstance(hist[0], tuple):
-                image_path = hist[0][0]
-                break
+            if type(hist[0]) == tuple:
+                image = hist[0][0]
+    try:
+        if image is None:
+            # Handle the case where image is None
+            gr.Error("You need to upload an image for LLaVA to work.")
+    except NameError:
+        # Handle the case where 'image' is not defined at all
+        gr.Error("You need to upload an image for LLaVA to work.")
 
-    if not image_path:
-        gr.Error("You need to upload an image for Krypton to work.")
-        return
+    prompt = f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{message['text']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+    # print(f"prompt: {prompt}")
+    image = Image.open(image)
+    inputs = processor(prompt, image, return_tensors='pt').to(0, torch.float16)
 
-    try:
-        image = Image.open(image_path)
-        image.show() # Show the image to confirm it's loaded
-        print(f"Image open: {image}")
-    except Exception as e:
-        print(f"Error opening image: {e}")
-        gr.Error("Failed to open the image.")
-        return
-
-    # Adding more context to the prompt with a placeholder for the image
-    prompt = f"user: Here is an image and a question about it.\n<image>{input['text']}\nassistant: "
-    print("Made the prompt")
+    streamer = TextIteratorStreamer(processor, **{"skip_special_tokens": False, "skip_prompt": True})
+    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024, do_sample=False)
 
-    try:
-        inputs = processor(text=prompt, images=image, return_tensors='pt').to('cuda', torch.float16)
-        print(f"Processed inputs: {inputs}")
-    except Exception as e:
-        print(f"Error processing inputs: {e}")
-        gr.Error("Failed to process the inputs.")
-        return
-
-    # Streamer
-    print('About to init streamer')
-    streamer = TextIteratorStreamer(processor.tokenizer, skip_special_tokens=False, skip_prompt=True)
-
-    # Generation kwargs
-    generation_kwargs = dict(
-        inputs=inputs['input_ids'],
-        attention_mask=inputs['attention_mask'],
-        streamer=streamer,
-        max_new_tokens=1024,
-        do_sample=False
-    )
-
-    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
-    print('Thread about to start')
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
-
+
+    text_prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{message['text']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+    # print(f"text_prompt: {text_prompt}")
+
     buffer = ""
-    # time.sleep(0.5)
+    time.sleep(0.5)
     for new_text in streamer:
+        # find <|eot_id|> and remove it from the new_text
+        if "<|eot_id|>" in new_text:
+            new_text = new_text.split("<|eot_id|>")[0]
         buffer += new_text
+
+        # generated_text_without_prompt = buffer[len(text_prompt):]
         generated_text_without_prompt = buffer
+        # print(generated_text_without_prompt)
+        time.sleep(0.06)
+        # print(f"new_text: {generated_text_without_prompt}")
         yield generated_text_without_prompt
 
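For context on how this change is consumed: the new bot_streaming handler is a Python generator that launches model.generate on a background thread and yields partial text as the TextIteratorStreamer produces it. Below is a minimal sketch of how such a handler is typically wired into a Gradio multimodal chat UI; the gr.ChatInterface wiring, the title string, and the launch() call are illustrative assumptions, not part of this commit.

# Sketch only (assumption, not from this commit): hook the streaming
# generator defined above into a multimodal Gradio chat interface.
import gradio as gr

# With multimodal=True, Gradio passes each message as a dict with
# "text" and "files" keys, which is exactly the shape bot_streaming
# indexes into (message["text"], message["files"]).
demo = gr.ChatInterface(
    fn=bot_streaming,
    multimodal=True,   # textbox accepts text plus image uploads
    title="Krypton",   # placeholder title, an assumption
)

if __name__ == "__main__":
    demo.launch()

Because bot_streaming yields successive snapshots of the growing buffer rather than per-token deltas, ChatInterface replaces the displayed bot message with each yielded value, which is what produces the token-by-token streaming effect in the UI.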