Davies commited on
Commit
f6e2254
1 Parent(s): ca1fa59

we added app.py, requirements and utils.py

Browse files
Files changed (2) hide show
  1. app.py +41 -4
  2. utils.py +15 -0
app.py CHANGED
@@ -1,7 +1,44 @@
1
  import streamlit as st
 
2
 
 
3
 
4
- st.title("Este es mi demo")
5
- st.markdown("Esta es mi descripción")
6
- x = st.slider("Selecciona un Valor")
7
- st.write(x, "el cuadrado es", x*x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from utils import carga_modelo, genera
3
 
4
+ ## pagina principal
5
 
6
+ st.title("Generador de Mariposas")
7
+ st.write("Este es un modelo Light GAN pre entrenado")
8
+
9
+ ##Barra lateral
10
+
11
+ st.sidebar.subheader("Esta Mariposa no existe! haha")
12
+ st.sidebar.image("assets/logo.png", width = 200)
13
+ st.sidebar.caption("Dema creado usando Streamlit")
14
+
15
+
16
+ ## Cargamos el modelo
17
+ repo_id = "ceyda/butterfly_cropped_uniq1K_512"
18
+ modelo_gan = carga_modelo(repo_id)
19
+
20
+ ##Generamos 4 mariposas
21
+ n_mariposas = 4
22
+
23
+
24
+ def run_model():
25
+ with st.spinner("Generando Mariposa... :)")
26
+ ims = genera(modelo_gan, n_mariposas)
27
+ st.session_state["ims"] = ims
28
+ if "ims" not in st.session_state:
29
+ st.session_state["ims"] = None
30
+ run_model()
31
+ ims = st.session_state["ims"]
32
+
33
+ run_model_button = st.button(
34
+ "Genera Mariposa",
35
+ on_click = run_model()
36
+ help="Estamos generando la Mariposa "
37
+
38
+ )
39
+
40
+ if ims is not None:
41
+ clos = st.columns(n_mariposas)
42
+ for jo, im in enumerate(ims):
43
+ i = j% n_mariposas
44
+ cols[i].image(im, use_colum_width=True)
utils.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
4
+
5
+ def carga_modelo(modelo_name ="ceyda/butterfly_cropped_uniq1K_512", model_version =None):
6
+ gan = LightweightGAN.from_pretrained(model_name, version =model_version)
7
+ gan.eval()
8
+ return gan
9
+
10
+ def genera(gan, batch_size =1):
11
+ with torch.no_grad():
12
+ ims = gan.G(torch.randn(batch_size, gan.latent_dim)).clamp_(0.0, 1.0)*255
13
+ ims = ims.permute(0,2,3,1).deatch().cpu().numpy().astype(np.uint8)
14
+
15
+ return ims