0xdant committed on
Commit
9722366
1 Parent(s): 837cedd

Refactor image captioning model initialization and device assignment

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -1,15 +1,17 @@
 
1
  import gradio as gr
2
  from transformers import BlipProcessor, BlipForConditionalGeneration
3
  from PIL import Image
4
 
 
5
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
6
- model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to("cuda")
7
 
8
  def generate_caption(image):
9
  # Now directly using the PIL Image object
10
- inputs = processor(images=image, return_tensors="pt").to("cuda")
11
  outputs = model.generate(**inputs)
12
- caption = processor.decode(outputs[0], skip_special_tokens=True)
13
  return caption
14
 
15
  def caption_image(image):
import torch
import gradio as gr
from transformers import BlipProcessor, BlipForConditionalGeneration
from PIL import Image

# Select the best available accelerator, falling back to CPU when CUDA is absent.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# BLIP-large image-captioning checkpoint: the processor turns a PIL image into
# model-ready tensors; the model is placed on the selected device once, at load.
_MODEL_ID = "Salesforce/blip-image-captioning-large"
processor = BlipProcessor.from_pretrained(_MODEL_ID)
model = BlipForConditionalGeneration.from_pretrained(_MODEL_ID).to(device)
10
def generate_caption(image):
    """Generate a text caption for an image with the BLIP model.

    Args:
        image: A ``PIL.Image.Image`` (Gradio hands the PIL object in directly).

    Returns:
        str: The decoded caption, special tokens stripped.
    """
    # Bug fix: the model lives on `device` (GPU when available), so the input
    # tensors must be moved there as well — otherwise generate() raises a
    # device-mismatch error on CUDA hosts. On CPU-only hosts .to(device) is a no-op.
    inputs = processor(images=image, return_tensors="pt").to(device)
    outputs = model.generate(**inputs)
    caption = processor.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
    return caption
16
 
17
  def caption_image(image):