flux4 / app.py
salomonsky's picture
Update app.py
599e239 verified
raw
history blame
2.93 kB
import streamlit as st
import torch
from PIL import Image
import random
from diffusers import StableDiffusionInstructPix2PixPipeline, DiffusionPipeline, FluxPipeline
# Run on the GPU when torch reports a usable CUDA device; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Sidebar display name -> model repository id handed to ``from_pretrained``.
model_options = {
    "Instruct Pix2Pix": "timbrooks/instruct-pix2pix",
    "SDXL Turbo": "stabilityai/sdxl-turbo",
    "FLUX": "alimama-creative/FLUX.1-Turbo-Alpha",
}

# The selectbox lists the display names in dict order; the chosen name is
# then mapped back to its repository id.
selected_model = st.sidebar.selectbox("Selecciona el modelo", list(model_options))
model_id = model_options[selected_model]
# max_entries=1: only the most recently requested pipeline stays cached, so
# switching models does not keep every previously loaded pipeline in memory.
@st.cache_resource(max_entries=1)
def load_model(model_name=None):
    """Load the diffusion pipeline for ``model_name`` and move it to ``device``.

    The model name is taken as an argument because ``st.cache_resource`` keys
    its cache on the call arguments: a zero-argument function would keep
    returning whichever pipeline was loaded first, even after the user picks
    a different model in the sidebar.

    Args:
        model_name: A key of ``model_options``. When omitted, the model
            currently selected in the sidebar (``selected_model``) is used.

    Returns:
        The pipeline instance created by ``from_pretrained`` after ``.to(device)``.
    """
    if model_name is None:
        model_name = selected_model
    repo_id = model_options[model_name]

    if model_name == "Instruct Pix2Pix":
        return StableDiffusionInstructPix2PixPipeline.from_pretrained(repo_id, safety_checker=None).to(device)
    elif model_name == "SDXL Turbo":
        return DiffusionPipeline.from_pretrained(repo_id, safety_checker=None).to(device)
    else:
        # Any other entry is loaded through FluxPipeline.
        return FluxPipeline.from_pretrained(repo_id, safety_checker=None).to(device)


# Pass the selection explicitly so it becomes part of the cache key.
pipe = load_model(selected_model)
def resize(img, max_size=512):
    """Scale ``img`` to fit a ``max_size`` x ``max_size`` box, keeping aspect ratio.

    The scale factor is chosen so that the longer side becomes ``max_size``;
    images smaller than the box are therefore enlarged, not left as-is.

    Args:
        img: Image object exposing ``width``, ``height`` and ``resize``.
        max_size: Target length, in pixels, for the longer side.

    Returns:
        The resized image produced by ``img.resize`` with LANCZOS resampling.
    """
    ratio = min(max_size / img.width, max_size / img.height)
    # int() truncates, so a very elongated image (e.g. 2000x1) would end up
    # with a 0-pixel side, which resize() rejects; keep at least one pixel.
    new_size = (
        max(1, int(img.width * ratio)),
        max(1, int(img.height * ratio)),
    )
    return img.resize(new_size, Image.LANCZOS)
def infer(source_img, prompt, steps, seed, guidance_scale):
    """Run the loaded pipeline once on ``source_img`` and return the result.

    Args:
        source_img: Input image; it is passed through ``resize`` before use.
        prompt: Text prompt given to the pipeline.
        steps: Value forwarded as ``num_inference_steps``.
        seed: Value given to ``torch.manual_seed`` before generation.
        guidance_scale: Value from the UI slider; it is divided by 10 before
            being forwarded to the pipeline.

    Returns:
        The first image of the pipeline output, or ``None`` when generation
        fails (the error is reported through ``st.error``).
    """
    torch.manual_seed(seed)
    source_image = resize(source_img)
    progress_bar = st.progress(0)
    st.text("Generando imagen...")
    try:
        # One call already performs all ``steps`` denoising steps. Calling the
        # pipeline inside a ``range(steps)`` loop would repeat the complete
        # generation ``steps`` times and keep only the last image.
        output = pipe(
            prompt,
            image=source_image,
            guidance_scale=guidance_scale / 10,
            num_inference_steps=steps,
        )
        if not output.images:
            raise ValueError("No se generaron imágenes.")
        result = output.images[0]
        progress_bar.progress(1.0)
    except Exception as e:
        # UI boundary: show the failure to the user instead of crashing the app.
        st.error(f"Error durante la inferencia: {str(e)}")
        return None
    return result
st.title("Flux Image to Image")

# All generation inputs are collected in the sidebar.
with st.sidebar:
    uploaded_image = st.file_uploader(
        "Sube una imagen",
        type=["png", "jpg", "jpeg"],
        key="unique_file_uploader",
    )
    prompt = st.text_input("Texto del prompt (máx. 77 tokens)")
    steps = st.slider(
        "Número de Iteraciones",
        min_value=1,
        max_value=50,
        value=2,
        step=1,
    )
    randomize_seed = st.radio("Randomize Seed", ["Randomize Seed", "Fix Seed"])

    # Initial slider position: a fresh random draw in "Randomize Seed" mode,
    # the constant 1 in "Fix Seed" mode.
    if randomize_seed == "Randomize Seed":
        default_seed = random.randint(0, 9999)
    else:
        default_seed = 1
    seed = st.slider("Seed", min_value=0, max_value=9999, step=1, value=default_seed)

    guidance_scale = st.slider(
        "Guidance Scale",
        min_value=0.0,
        max_value=10.0,
        step=0.01,
        value=9.0,
    )

# The button is only created once an image has been uploaded.
if uploaded_image is not None:
    if st.button("Generar imagen"):
        image = Image.open(uploaded_image).convert("RGB")
        result_image = infer(image, prompt, steps, seed, guidance_scale)
        if result_image is not None:
            st.image(result_image, caption="Imagen generada", use_column_width=True)