Tonic committed on
Commit 2bdacd4
1 Parent(s): c6378e6

add reference code from vllm

Files changed (1)
  1. app.py +54 -49
app.py CHANGED
@@ -214,65 +214,68 @@ model = load_model(params, model_path)
 tokenizer = MistralTokenizer.from_model("pixtral")
 
 def preprocess_image(image):
+    if image is None:
+        raise ValueError("No image provided")
     image = image.convert('RGB')
     image = image.resize((params['vision_encoder']['image_size'], params['vision_encoder']['image_size']))
     image_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
     return image_tensor
 
-@spaces.GPU
-def generate_text(image, prompt):
-    image_tensor = preprocess_image(image).cuda()
-
-    tokenized = tokenizer.encode_chat_completion(
-        ChatCompletionRequest(
-            messages=[
-                UserMessage(
-                    content=[
-                        TextChunk(text=prompt),
-                        ImageChunk(image=image),
-                    ]
-                )
-            ],
-            model="pixtral",
-        )
-    )
-    input_ids = torch.tensor(tokenized.tokens).unsqueeze(0).cuda()
-
-    # Generate text
-    with torch.no_grad():
-        model.cuda()
-        max_length = 100  # add slider
-        for _ in range(max_length):
-            logits = model(image_tensor, input_ids)
-            next_token_logits = logits[0, -1, :]
-            next_token = torch.argmax(next_token_logits, dim=-1)
-            input_ids = torch.cat([input_ids, next_token.unsqueeze(0).unsqueeze(0)], dim=-1)
-            if next_token.item() == tokenizer.eos_token_id:
-                break
-        model.cpu()
-
-    generated_text = tokenizer.decode(input_ids[0].tolist())
-    return generated_text, len(input_ids[0]), 1  # 1 image processed
-
-@spaces.GPU
+@spaces.GPU(duration=120)
+def generate_text(image, prompt, max_tokens):
+    try:
+        image_tensor = preprocess_image(image).cuda()
+
+        tokenized = tokenizer.encode_chat_completion(
+            ChatCompletionRequest(
+                messages=[
+                    UserMessage(
+                        content=[
+                            TextChunk(text=prompt),
+                            ImageChunk(image=image),
+                        ]
+                    )
+                ],
+                model="pixtral",
+            )
+        )
+        input_ids = torch.tensor(tokenized.tokens).unsqueeze(0).cuda()
+
+        with torch.no_grad():
+            model.cuda()
+            for _ in range(max_tokens):
+                logits = model(image_tensor, input_ids)
+                next_token_logits = logits[0, -1, :]
+                next_token = torch.argmax(next_token_logits, dim=-1)
+                input_ids = torch.cat([input_ids, next_token.unsqueeze(0).unsqueeze(0)], dim=-1)
+                if next_token.item() == tokenizer.eos_token_id:
+                    break
+            model.cpu()
+
+        generated_text = tokenizer.decode(input_ids[0].tolist())
+        return generated_text, len(input_ids[0]), 1  # 1 image processed
+    except Exception as e:
+        return f"Error: {str(e)}", 0, 0
+
+@spaces.GPU(duration=60)
 def calculate_similarity(image1, image2):
-    # Preprocess images
-    tensor1 = preprocess_image(image1).cuda()
-    tensor2 = preprocess_image(image2).cuda()
+    try:
+        tensor1 = preprocess_image(image1).cuda()
+        tensor2 = preprocess_image(image2).cuda()
 
-    # Generate embeddings
-    with torch.no_grad():
-        model.cuda()
-        embedding1 = model(tensor1).mean(dim=1)  # Average over spatial dimensions
-        embedding2 = model(tensor2).mean(dim=1)
-        model.cpu()
+        with torch.no_grad():
+            model.cuda()
+            embedding1 = model(tensor1).mean(dim=1)  # Average over spatial dimensions
+            embedding2 = model(tensor2).mean(dim=1)
+            model.cpu()
 
-    # Calculate cosine similarity
-    similarity = F.cosine_similarity(embedding1, embedding2).item()
+        similarity = F.cosine_similarity(embedding1, embedding2).item()
 
-    return similarity
+        return similarity
+    except Exception as e:
+        return f"Error: {str(e)}"
 
-with gr.Blocks(theme=gr.themes.Base()) as demo:
+with gr.Blocks() as demo:
     gr.Markdown(title)
     gr.Markdown("## Model Details")
     gr.Markdown(f"- Model Dimension: {params['dim']}")
@@ -287,6 +290,7 @@ with gr.Blocks(theme=gr.themes.Base()) as demo:
     gr.Markdown("1. The image is processed by a Vision Encoder using 2D ROPE (Rotary Position Embedding).")
     gr.Markdown("2. The encoder uses SiLU activation in its feed-forward layers.")
     gr.Markdown("3. The encoded image is used for text generation or similarity comparison.")
+
     gr.Markdown(description)
 
     with gr.Tabs():
@@ -295,6 +299,7 @@ with gr.Blocks(theme=gr.themes.Base()) as demo:
         with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            input_prompt = gr.Textbox(label="Prompt")
+           max_tokens_slider = gr.Slider(minimum=60, maximum=1600, value=90, step=5, label="Max Tokens")
            submit_btn = gr.Button("Generate Text")
 
         with gr.Column():
@@ -304,7 +309,7 @@ with gr.Blocks(theme=gr.themes.Base()) as demo:
 
     submit_btn.click(
        fn=generate_text,
-       inputs=[input_image, input_prompt],
+       inputs=[input_image, input_prompt, max_tokens_slider],
       outputs=[output_text, token_count, image_count]
     )
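
The commit message references vLLM, but the diff above still runs a hand-rolled greedy decoding loop. For comparison, here is a minimal sketch of what the same Pixtral inference typically looks like through vLLM's offline chat API; the model ID, sampling settings, prompt, and image URL are illustrative assumptions, not code from this commit.

# Hypothetical vLLM reference sketch, not part of this commit.
from vllm import LLM, SamplingParams

# Pixtral checkpoints ship a Mistral-format tokenizer, hence tokenizer_mode="mistral".
llm = LLM(model="mistralai/Pixtral-12B-2409", tokenizer_mode="mistral")

# Cap generation length, analogous to the Max Tokens slider default above.
sampling_params = SamplingParams(max_tokens=90)

# OpenAI-style multimodal chat message: one text chunk plus one image URL.
messages = [{
    "role": "user",
    "content": [
        {"type": "text", "text": "Describe this image."},
        {"type": "image_url", "image_url": {"url": "https://example.com/image.png"}},
    ],
}]

# llm.chat applies the model's chat template and runs generation on GPU.
outputs = llm.chat(messages, sampling_params=sampling_params)
print(outputs[0].outputs[0].text)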