salomonsky commited on
Commit
13d8085
1 Parent(s): 608a8fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -81
app.py CHANGED
@@ -3,112 +3,74 @@ from PIL import Image
3
  import streamlit as st
4
  import os
5
  import random
6
- import numpy as np
7
  import torch
8
- from diffusers import FluxPipeline # Importar FluxPipeline
9
 
10
- # Configuraciones y cliente de inferencia
11
- MAX_SEED = np.iinfo(np.int32).max
12
  DATA_PATH = Path("./data")
13
  DATA_PATH.mkdir(exist_ok=True)
14
 
15
- # Cargar modelo de Flux
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
- flux_pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16).to(device)
18
 
19
- # Generación de imagen usando prompt y/o imagen subida
20
- async def gen(prompt, width, height, uploaded_image=None):
21
- images = []
22
- try:
23
- seed = random.randint(0, MAX_SEED)
24
- if uploaded_image:
25
- uploaded_image = Image.open(uploaded_image).convert("RGB")
26
- image = await generate_img2img(prompt, uploaded_image)
27
- else:
28
- image = await generate_image(prompt, width, height, seed)
29
- image_path = save_image(image, f"generated_image_{seed}.jpg", prompt)
30
- if image_path:
31
- images.append(str(image_path))
32
- except Exception as e:
33
- st.error(f"Error al generar imágenes: {e}")
34
- return images
 
35
 
36
- # Guardar y mostrar galería
37
  def list_saved_images():
38
  return list(DATA_PATH.glob("*.jpg"))
39
 
40
  def display_gallery():
41
- st.header("Galería de Imágenes Guardadas")
42
  images = list_saved_images()
43
  if images:
44
- cols = st.columns(8)
45
  for i, image_file in enumerate(images):
46
- with cols[i % 8]:
47
  st.image(str(image_file), caption=image_file.name, use_column_width=True)
48
- prompt = get_prompt_for_image(image_file.name)
49
- st.write(prompt[:300])
50
-
51
- if st.button(f"Borrar", key=f"delete_{i}_{image_file.name}"):
52
  os.remove(image_file)
53
- st.success("Imagen borrada")
54
- display_gallery()
55
  else:
56
  st.info("No hay imágenes guardadas.")
57
 
58
- def save_image(image, file_name, prompt=None):
59
- image_path = DATA_PATH / file_name
60
- if image_path.exists():
61
- st.warning(f"La imagen '{file_name}' ya existe en la galería. No se guardó.")
62
- return None
63
- else:
64
- image.save(image_path, format="JPEG")
65
- if prompt:
66
- save_prompt(f"{file_name}: {prompt}")
67
- return image_path
68
-
69
- # Generación de imagen desde texto (txt2img)
70
- async def generate_image(prompt, width, height, seed):
71
- image = flux_pipeline(prompt=prompt, width=width, height=height, num_inference_steps=4).images[0]
72
- return image
73
-
74
- # Generación de imagen desde imagen y texto (img+txt=img)
75
- async def generate_img2img(prompt, image):
76
- image = flux_pipeline(prompt=prompt, init_image=image, num_inference_steps=4).images[0]
77
- return image
78
-
79
- def get_prompt_for_image(image_name):
80
- prompts = {}
81
- try:
82
- with open(DATA_PATH / "prompts.txt", "r") as f:
83
- for line in f:
84
- if line.startswith(image_name):
85
- prompts[image_name] = line.split(": ", 1)[1].strip()
86
- except FileNotFoundError:
87
- return "No hay prompt asociado."
88
- return prompts.get(image_name, "No hay prompt asociado.")
89
-
90
- # Función principal
91
- async def main():
92
  st.set_page_config(layout="wide")
 
93
 
94
- st.title("Flux +Image +Variants")
95
- prompt = st.sidebar.text_area("Descripción de la imagen", height=150, max_chars=500)
96
- format_option = st.sidebar.selectbox("Formato", ["9:16", "16:9"])
97
- width, height = (360, 640) if format_option == "9:16" else (640, 360)
98
-
99
- uploaded_image = st.sidebar.file_uploader("Sube una imagen (opcional)", type=["png", "jpg", "jpeg"])
100
 
101
  if st.sidebar.button("Generar Imagen"):
102
- with st.spinner("Generando imágenes..."):
103
- try:
104
- results = await gen(prompt, width, height, uploaded_image)
105
- for result in results:
106
- st.image(result, caption="Imagen Generada")
107
- except Exception as e:
108
- st.error(f"Error al generar las imágenes: {str(e)}")
 
 
109
 
110
  display_gallery()
111
 
112
- # Ejecución del código principal
113
  if __name__ == "__main__":
114
- asyncio.run(main())
 
3
  import streamlit as st
4
  import os
5
  import random
 
6
  import torch
7
+ from diffusers import StableDiffusionImg2ImgPipeline # Reemplazar FluxPipeline
8
 
9
+ # Configuraciones y ruta de guardado
 
10
  DATA_PATH = Path("./data")
11
  DATA_PATH.mkdir(exist_ok=True)
12
 
13
+ # Cargar modelo de Stable Diffusion
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device)
16
 
17
+ # Generar imagen con Stable Diffusion (txt2img o img2img)
18
+ def generate_image(prompt, init_image=None, strength=0.75):
19
+ generator = torch.manual_seed(random.randint(0, 1000000))
20
+ if init_image:
21
+ image = pipe(prompt=prompt, init_image=init_image, strength=strength, generator=generator).images[0]
22
+ else:
23
+ image = pipe(prompt=prompt, generator=generator).images[0]
24
+ return image
25
+
26
+ # Guardar imagen
27
+ def save_image(image, file_name, prompt=None):
28
+ image_path = DATA_PATH / file_name
29
+ image.save(image_path, format="JPEG")
30
+ if prompt:
31
+ with open(DATA_PATH / "prompts.txt", "a") as f:
32
+ f.write(f"{file_name}: {prompt}\n")
33
+ return image_path
34
 
35
+ # Mostrar galería
36
  def list_saved_images():
37
  return list(DATA_PATH.glob("*.jpg"))
38
 
39
  def display_gallery():
40
+ st.header("Galería de Imágenes")
41
  images = list_saved_images()
42
  if images:
43
+ cols = st.columns(4)
44
  for i, image_file in enumerate(images):
45
+ with cols[i % 4]:
46
  st.image(str(image_file), caption=image_file.name, use_column_width=True)
47
+ if st.button(f"Borrar {image_file.name}", key=f"delete_{i}"):
 
 
 
48
  os.remove(image_file)
49
+ st.success(f"Imagen {image_file.name} borrada.")
 
50
  else:
51
  st.info("No hay imágenes guardadas.")
52
 
53
+ # Función principal de la app
54
+ def main():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  st.set_page_config(layout="wide")
56
+ st.title("Generación de Imágenes - Stable Diffusion")
57
 
58
+ prompt = st.sidebar.text_area("Descripción de la imagen", height=150)
59
+ uploaded_image = st.sidebar.file_uploader("Sube una imagen para img2img (opcional)", type=["png", "jpg", "jpeg"])
 
 
 
 
60
 
61
  if st.sidebar.button("Generar Imagen"):
62
+ with st.spinner("Generando imagen..."):
63
+ if uploaded_image:
64
+ init_image = Image.open(uploaded_image).convert("RGB")
65
+ generated_image = generate_image(prompt, init_image)
66
+ else:
67
+ generated_image = generate_image(prompt)
68
+
69
+ image_path = save_image(generated_image, f"generated_{random.randint(0, 10000)}.jpg", prompt)
70
+ st.image(str(image_path), caption="Imagen Generada")
71
 
72
  display_gallery()
73
 
74
+ # Ejecución de la app
75
  if __name__ == "__main__":
76
+ main()