salomonsky commited on
Commit
4533b44
1 Parent(s): f7786aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -32
app.py CHANGED
@@ -18,11 +18,11 @@ model_id = model_options[selected_model]
18
  @st.cache_resource
19
  def load_model():
20
  if selected_model == "Instruct Pix2Pix":
21
- return StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float32, safety_checker=None).to(device)
22
  elif selected_model == "SDXL Turbo":
23
- return DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32, safety_checker=None).to(device)
24
  else:
25
- return FluxPipeline.from_pretrained(model_id, torch_dtype=torch.float32, safety_checker=None).to(device)
26
 
27
  pipe = load_model()
28
 
@@ -31,7 +31,7 @@ def resize(img, max_size=512):
31
  new_size = (int(img.width * ratio), int(img.height * ratio))
32
  return img.resize(new_size, Image.LANCZOS)
33
 
34
- def infer(source_img, prompt, steps, seed, text_cfg_scale, image_cfg_scale):
35
  torch.manual_seed(seed)
36
  source_image = resize(source_img)
37
  progress_bar = st.progress(0)
@@ -39,35 +39,19 @@ def infer(source_img, prompt, steps, seed, text_cfg_scale, image_cfg_scale):
39
 
40
  try:
41
  for i in range(steps):
42
- if selected_model == "Instruct Pix2Pix":
43
- output = pipe(
44
- prompt,
45
- image=source_image,
46
- guidance_scale=text_cfg_scale,
47
- num_inference_steps=1,
48
- image_guidance_scale=image_cfg_scale
49
- )
50
- elif selected_model == "SDXL Turbo":
51
- output = pipe(
52
- prompt,
53
- num_inference_steps=steps,
54
- guidance_scale=text_cfg_scale,
55
- image=source_image
56
- )
57
- else:
58
- output = pipe(
59
- prompt,
60
- image=source_image,
61
- num_inference_steps=steps,
62
- guidance_scale=text_cfg_scale
63
- )
64
 
65
  result = output.images[0] if output.images else None
66
-
67
  if result is None:
68
  raise ValueError("No se generaron imágenes.")
69
 
70
- progress_bar.progress((i + 1) / steps)
71
 
72
  progress_bar.progress(1.0)
73
  except Exception as e:
@@ -84,11 +68,10 @@ with st.sidebar:
84
  steps = st.slider("Número de Iteraciones", min_value=1, max_value=50, value=11, step=1)
85
  randomize_seed = st.radio("Randomize Seed", ["Randomize Seed", "Fix Seed"])
86
  seed = st.slider("Seed", min_value=0, max_value=9999, step=1, value=random.randint(0, 9999) if randomize_seed == "Randomize Seed" else 1111)
87
- text_cfg_scale = st.slider("Text CFG Scale", min_value=0.0, max_value=10.0, step=0.1, value=9.0)
88
- image_cfg_scale = st.slider("Image CFG Scale", min_value=0.0, max_value=10.0, step=0.1, value=1.0)
89
 
90
  if uploaded_image is not None and st.button("Generar imagen"):
91
  image = Image.open(uploaded_image).convert("RGB")
92
- result_image = infer(image, prompt, steps, seed, text_cfg_scale, image_cfg_scale)
93
  if result_image is not None:
94
- st.image(result_image, caption="Imagen generada", use_column_width=True)
 
18
  @st.cache_resource
19
  def load_model():
20
  if selected_model == "Instruct Pix2Pix":
21
+ return StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, safety_checker=None).to(device)
22
  elif selected_model == "SDXL Turbo":
23
+ return DiffusionPipeline.from_pretrained(model_id, safety_checker=None).to(device)
24
  else:
25
+ return FluxPipeline.from_pretrained(model_id, safety_checker=None).to(device)
26
 
27
  pipe = load_model()
28
 
 
31
  new_size = (int(img.width * ratio), int(img.height * ratio))
32
  return img.resize(new_size, Image.LANCZOS)
33
 
34
+ def infer(source_img, prompt, steps, seed, guidance_scale):
35
  torch.manual_seed(seed)
36
  source_image = resize(source_img)
37
  progress_bar = st.progress(0)
 
39
 
40
  try:
41
  for i in range(steps):
42
+ output = pipe(
43
+ prompt,
44
+ image=source_image,
45
+ guidance_scale=guidance_scale / 10,
46
+ num_inference_steps=steps
47
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  result = output.images[0] if output.images else None
50
+
51
  if result is None:
52
  raise ValueError("No se generaron imágenes.")
53
 
54
+ progress_bar.progress((i + 1) / steps)
55
 
56
  progress_bar.progress(1.0)
57
  except Exception as e:
 
68
  steps = st.slider("Número de Iteraciones", min_value=1, max_value=50, value=11, step=1)
69
  randomize_seed = st.radio("Randomize Seed", ["Randomize Seed", "Fix Seed"])
70
  seed = st.slider("Seed", min_value=0, max_value=9999, step=1, value=random.randint(0, 9999) if randomize_seed == "Randomize Seed" else 1111)
71
+ guidance_scale = st.slider("Guidance Scale", min_value=0.0, max_value=10.0, step=0.01, value=9.0)
 
72
 
73
  if uploaded_image is not None and st.button("Generar imagen"):
74
  image = Image.open(uploaded_image).convert("RGB")
75
+ result_image = infer(image, prompt, steps, seed, guidance_scale)
76
  if result_image is not None:
77
+ st.image(result_image, caption="Imagen generada", use_column_width=True)