Linoy Tsaban commited on
Commit
6494dc6
1 Parent(s): 17db690

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -16,17 +16,17 @@ from transformers import AutoProcessor, BlipForConditionalGeneration
16
  # load pipelines
17
  sd_model_id = "stabilityai/stable-diffusion-2-1-base"
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
- sd_pipe = StableDiffusionPipeline.from_pretrained(sd_model_id).to(device)
20
  sd_pipe.scheduler = DDIMScheduler.from_config(sd_model_id, subfolder = "scheduler")
21
- sem_pipe = SemanticStableDiffusionPipeline.from_pretrained(sd_model_id).to(device)
22
  blip_processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
23
- blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
24
 
25
 
26
 
27
  ## IMAGE CPATIONING ##
28
  def caption_image(input_image):
29
- inputs = blip_processor(images=input_image, return_tensors="pt").to(device)
30
  pixel_values = inputs.pixel_values
31
 
32
  generated_ids = blip_model.generate(pixel_values=pixel_values, max_length=50)
 
16
  # load pipelines
17
  sd_model_id = "stabilityai/stable-diffusion-2-1-base"
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+ sd_pipe = StableDiffusionPipeline.from_pretrained(sd_model_id,torch_dtype=torch.float16).to(device)
20
  sd_pipe.scheduler = DDIMScheduler.from_config(sd_model_id, subfolder = "scheduler")
21
+ sem_pipe = SemanticStableDiffusionPipeline.from_pretrained(sd_model_id, torch_dtype=torch.float16).to(device)
22
  blip_processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
23
+ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base",torch_dtype=torch.float16).to(device)
24
 
25
 
26
 
27
  ## IMAGE CPATIONING ##
28
  def caption_image(input_image):
29
+ inputs = blip_processor(images=input_image, return_tensors="pt").to(device, torch.float16)
30
  pixel_values = inputs.pixel_values
31
 
32
  generated_ids = blip_model.generate(pixel_values=pixel_values, max_length=50)