ItzRoBeerT commited on
Commit
8e4035c
1 Parent(s): 8fc1f74

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -15
app.py CHANGED
@@ -1,43 +1,63 @@
1
  import gradio as gr
2
- from transformers import pipeline
3
  from diffusers import StableDiffusionPipeline
4
  from diffusers import DiffusionPipeline
5
  import torch
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- device = "cuda" if torch.cuda.is_available() else "cpu"
8
- model_id = "CompVis/stable-diffusion-v1-4"
9
  torch_dtype = torch.float32
10
 
11
  if torch.cuda.is_available():
12
  torch_dtype = torch.bfloat16
13
 
14
  def generate_description(image):
15
- model = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
16
- return model(image)[0]['generated_text']
17
 
18
- def generate_image_by_description(description):
19
- pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
 
 
 
 
 
20
  pipe = pipe.to(device)
21
  pipe.enable_attention_slicing()
22
-
23
  prompt = (
24
- f"Generate a image of a pigeon for my profile avatar. "
25
- f"The description of the pigeon is: {description}. "
26
  )
 
 
 
 
27
  image = pipe(prompt).images[0]
28
  return image
29
 
30
 
 
 
 
 
31
  with gr.Blocks() as demo:
32
  with gr.Row():
33
  with gr.Column(scale=2, min_width=300):
34
  selected_image = gr.Image(type="filepath", label="Upload an Image of the Pigeon",height=300)
 
 
35
  generate_button = gr.Button("Generate Avatar", variant="primary")
36
  with gr.Column(scale=2, min_width=300):
37
  generated_image = gr.Image(type="numpy", label="Generated Avatar", height=300)
38
- def process_and_generate(image):
39
- description = generate_description(image)
40
- return generate_image_by_description(description)
41
-
42
- generate_button.click(process_and_generate, inputs=selected_image, outputs=generated_image)
43
  demo.launch()
 
1
  import gradio as gr
2
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
3
  from diffusers import StableDiffusionPipeline
4
  from diffusers import DiffusionPipeline
5
  import torch
6
+ from PIL import Image
7
+
8
+ device = "cpu"
9
+ if torch.cuda.is_available():
10
+ device = "cuda"
11
+ elif torch.mps.is_available():
12
+ device = "mps"
13
+
14
+ model_id_image = "CompVis/stable-diffusion-v1-4"
15
+ model_id_image_description = "vikhyatk/moondream2"
16
+ revision = "2024-08-26"
17
 
 
 
18
  torch_dtype = torch.float32
19
 
20
  if torch.cuda.is_available():
21
  torch_dtype = torch.bfloat16
22
 
23
  def generate_description(image):
24
+ model = AutoModelForCausalLM.from_pretrained(model_id_image_description, trust_remote_code=True, revision=revision)
25
+ tokenizer = AutoTokenizer.from_pretrained(model_id_image_description, revision=revision)
26
 
27
+ image_test = Image.open(image)
28
+ enc_image = model.encode_image(image_test)
29
+ res = model.answer_question(enc_image, "Describe this image to create an avatar", tokenizer)
30
+ return res
31
+
32
+ def generate_image_by_description(description, avatar_style=None):
33
+ pipe = StableDiffusionPipeline.from_pretrained(model_id_image, torch_dtype=torch_dtype)
34
  pipe = pipe.to(device)
35
  pipe.enable_attention_slicing()
36
+
37
  prompt = (
38
+ f"Create a pigeon profile avatar. "
39
+ f"Use the following description: {description}. "
40
  )
41
+
42
+ if avatar_style:
43
+ prompt += f"Use {avatar_style} avatar style."
44
+
45
  image = pipe(prompt).images[0]
46
  return image
47
 
48
 
49
+ def process_and_generate(image, avatar_style):
50
+ description = generate_description(image)
51
+ return generate_image_by_description(description, avatar_style)
52
+
53
  with gr.Blocks() as demo:
54
  with gr.Row():
55
  with gr.Column(scale=2, min_width=300):
56
  selected_image = gr.Image(type="filepath", label="Upload an Image of the Pigeon",height=300)
57
+ avatar_style = gr.Radio(
58
+ ["Realistic", "Pixel Art", "Imaginative", "Cartoon"], label="(optional) Select the avatar style:")
59
  generate_button = gr.Button("Generate Avatar", variant="primary")
60
  with gr.Column(scale=2, min_width=300):
61
  generated_image = gr.Image(type="numpy", label="Generated Avatar", height=300)
62
+ generate_button.click(process_and_generate, inputs=[selected_image, avatar_style ], outputs=generated_image)
 
 
 
 
63
  demo.launch()