ItzRoBeerT commited on
Commit
d5e510d
1 Parent(s): 84cbd85

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -29
app.py CHANGED
@@ -1,50 +1,52 @@
1
  import gradio as gr
2
  from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
3
  from diffusers import StableDiffusionPipeline
4
- from diffusers import DiffusionPipeline
5
  import torch
6
  from PIL import Image
7
 
8
- device = "cpu"
 
9
  if torch.cuda.is_available():
10
  device = "cuda"
11
- elif torch.mps.is_available():
12
  device = "mps"
13
-
 
14
  model_id_image = "CompVis/stable-diffusion-v1-4"
15
  model_id_image_description = "vikhyatk/moondream2"
16
  revision = "2024-08-26"
17
 
18
  torch_dtype = torch.float32
19
-
20
  if torch.cuda.is_available():
21
- torch_dtype = torch.bfloat16
22
 
23
- def generate_description(image):
24
- model = AutoModelForCausalLM.from_pretrained(model_id_image_description, trust_remote_code=True, revision=revision)
25
- tokenizer = AutoTokenizer.from_pretrained(model_id_image_description, revision=revision)
 
26
 
27
- image_test = Image.open(image)
28
- enc_image = model.encode_image(image_test)
29
- res = model.answer_question(enc_image, "Describe this image to create an avatar", tokenizer)
30
- return res
31
 
32
- def generate_image_by_description(description, avatar_style=None):
33
- pipe = StableDiffusionPipeline.from_pretrained(model_id_image, torch_dtype=torch_dtype)
34
- pipe = pipe.to(device)
35
- pipe.enable_attention_slicing()
36
-
37
- prompt = (
38
- f"Create a pigeon profile avatar. "
39
- f"Use the following description: {description}. "
40
- )
41
 
42
- if avatar_style:
43
- prompt += f"Use {avatar_style} avatar style."
 
 
 
44
 
45
- image = pipe(prompt).images[0]
46
- return image
 
 
47
 
 
 
48
 
49
  def process_and_generate(image, avatar_style):
50
  description = generate_description(image)
@@ -53,11 +55,15 @@ def process_and_generate(image, avatar_style):
53
  with gr.Blocks() as demo:
54
  with gr.Row():
55
  with gr.Column(scale=2, min_width=300):
56
- selected_image = gr.Image(type="filepath", label="Upload an Image of the Pigeon",height=300)
57
  avatar_style = gr.Radio(
58
- ["Realistic", "Pixel Art", "Imaginative", "Cartoon"], label="(optional) Select the avatar style:")
 
 
59
  generate_button = gr.Button("Generate Avatar", variant="primary")
60
  with gr.Column(scale=2, min_width=300):
61
  generated_image = gr.Image(type="numpy", label="Generated Avatar", height=300)
62
- generate_button.click(process_and_generate, inputs=[selected_image, avatar_style ], outputs=generated_image)
 
 
63
  demo.launch()
 
1
  import gradio as gr
2
  from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
3
  from diffusers import StableDiffusionPipeline
 
4
  import torch
5
  from PIL import Image
6
 
7
+ # Configuración del dispositivo
8
+ device = "cpu"
9
  if torch.cuda.is_available():
10
  device = "cuda"
11
+ elif torch.backends.mps.is_available():
12
  device = "mps"
13
+
14
+ # Configuración de modelos
15
  model_id_image = "CompVis/stable-diffusion-v1-4"
16
  model_id_image_description = "vikhyatk/moondream2"
17
  revision = "2024-08-26"
18
 
19
  torch_dtype = torch.float32
 
20
  if torch.cuda.is_available():
21
+ torch_dtype = torch.bfloat16 # Optimización en GPU
22
 
23
+ # Carga de modelos persistente
24
+ print("Cargando modelo de descripción de imágenes...")
25
+ model_description = AutoModelForCausalLM.from_pretrained(model_id_image_description, trust_remote_code=True, revision=revision)
26
+ tokenizer_description = AutoTokenizer.from_pretrained(model_id_image_description, revision=revision)
27
 
28
+ print("Cargando modelo de Stable Diffusion...")
29
+ pipe_sd = StableDiffusionPipeline.from_pretrained(model_id_image, torch_dtype=torch_dtype)
30
+ pipe_sd = pipe_sd.to(device)
 
31
 
32
+ # Opciones para optimizar memoria
33
+ pipe_sd.enable_attention_slicing()
34
+ if device == "cuda":
35
+ pipe_sd.enable_sequential_cpu_offload() # Liberar memoria gradualmente para GPUs pequeñas
 
 
 
 
 
36
 
37
+ def generate_description(image_path):
38
+ image_test = Image.open(image_path)
39
+ enc_image = model_description.encode_image(image_test)
40
+ description = model_description.answer_question(enc_image, "Describe this image to create an avatar", tokenizer_description)
41
+ return description
42
 
43
+ def generate_image_by_description(description, avatar_style=None):
44
+ prompt = f"Create a pigeon profile avatar. Use the following description: {description}."
45
+ if avatar_style:
46
+ prompt += f" Use {avatar_style} style."
47
 
48
+ result = pipe_sd(prompt)
49
+ return result.images[0]
50
 
51
  def process_and_generate(image, avatar_style):
52
  description = generate_description(image)
 
55
  with gr.Blocks() as demo:
56
  with gr.Row():
57
  with gr.Column(scale=2, min_width=300):
58
+ selected_image = gr.Image(type="filepath", label="Upload an Image of the Pigeon", height=300)
59
  avatar_style = gr.Radio(
60
+ ["Realistic", "Pixel Art", "Imaginative", "Cartoon"],
61
+ label="(optional) Select the avatar style:"
62
+ )
63
  generate_button = gr.Button("Generate Avatar", variant="primary")
64
  with gr.Column(scale=2, min_width=300):
65
  generated_image = gr.Image(type="numpy", label="Generated Avatar", height=300)
66
+
67
+ generate_button.click(process_and_generate, inputs=[selected_image, avatar_style], outputs=generated_image)
68
+
69
  demo.launch()