radames committed on
Commit
c139d24
1 Parent(s): b94ad3f

add upload image component

Browse files
Files changed (1) hide show
  1. app.py +11 -7
app.py CHANGED
@@ -14,8 +14,11 @@ processor = AutoProcessor.from_pretrained(config.base_model_name_or_path)
14
  model = model.to(device)
15
  model.eval()
16
 
17
- def predict(prompt, image_url, max_length):
18
- image = processor.image_processor.fetch_images(image_url)
 
 
 
19
  prompts = [[image, prompt]]
20
  inputs = processor(prompts[0], return_tensors="pt").to(device)
21
  generated_ids = model.generate(**inputs, max_length=max_length)
@@ -28,17 +31,18 @@ title = "Midjourney-like Image Captioning with IDEFICS"
28
  description = "Gradio Demo for generating *Midjourney* like captions (describe functionality) with **IDEFICS**"
29
 
30
  examples = [
31
- ["Describe the following image:", "https://miro.medium.com/v2/resize:fit:0/1*sTXgMwDUW0pk-1yK4iHYFw.png", 64],
32
- ["Describe the following image:", "https://miro.medium.com/v2/resize:fit:1400/0*6as5rHi0sgG4W2Tq.png", 64],
33
- ["Describe the following image:", "https://cdn.arstechnica.net/wp-content/uploads/2023/06/zoomout_2-1440x807.jpg", 64],
34
- ["Describe the following image:", "https://framerusercontent.com/images/inZdRVn7eafZNvaVre2iW1a538.png", 64],
35
- ["Describe the following image:", "https://hips.hearstapps.com/hmg-prod/images/cute-photos-of-cats-in-grass-1593184777.jpg", 64]
36
 
37
  ]
38
  io = gr.Interface(fn=predict,
39
  inputs=[
40
  gr.Textbox(label="Prompt", value="Describe the following image:"),
41
  gr.Textbox(label="image URL", placeholder="Insert the URL of the image to be described"),
 
42
  gr.Slider(label="Max tokens", value=64, max=128, min=16, step=8)
43
  ],
44
  outputs=[
 
14
  model = model.to(device)
15
  model.eval()
16
 
17
+ def predict(prompt, image_url, image_pil=None, max_length=64):
18
+ if image_pil is not None:
19
+ image = image_pil
20
+ else:
21
+ image = processor.image_processor.fetch_images(image_url)
22
  prompts = [[image, prompt]]
23
  inputs = processor(prompts[0], return_tensors="pt").to(device)
24
  generated_ids = model.generate(**inputs, max_length=max_length)
 
31
  description = "Gradio Demo for generating *Midjourney* like captions (describe functionality) with **IDEFICS**"
32
 
33
  examples = [
34
+ ["Describe the following image:", "https://miro.medium.com/v2/resize:fit:0/1*sTXgMwDUW0pk-1yK4iHYFw.png", None, 64],
35
+ ["Describe the following image:", "https://miro.medium.com/v2/resize:fit:1400/0*6as5rHi0sgG4W2Tq.png", None, 64],
36
+ ["Describe the following image:", "https://cdn.arstechnica.net/wp-content/uploads/2023/06/zoomout_2-1440x807.jpg", None, 64],
37
+ ["Describe the following image:", "https://framerusercontent.com/images/inZdRVn7eafZNvaVre2iW1a538.png", None, 64],
38
+ ["Describe the following image:", "https://hips.hearstapps.com/hmg-prod/images/cute-photos-of-cats-in-grass-1593184777.jpg", None, 64]
39
 
40
  ]
41
  io = gr.Interface(fn=predict,
42
  inputs=[
43
  gr.Textbox(label="Prompt", value="Describe the following image:"),
44
  gr.Textbox(label="image URL", placeholder="Insert the URL of the image to be described"),
45
+ gr.Image(label="or upload an image", type="pil"),
46
  gr.Slider(label="Max tokens", value=64, max=128, min=16, step=8)
47
  ],
48
  outputs=[