"""Streamlit demo that samples one image from a pretrained diffusion model."""

from denoising_diffusion_pytorch import Unet, GaussianDiffusion
import streamlit as st
import torch

# Side length (pixels) shared by the model configuration and the noise tensor
# that seeds sampling; keeping it in one place stops the two drifting apart.
IMAGE_SIZE = 64


def _cache_across_reruns(func):
    """Wrap *func* in Streamlit's resource cache when one is available.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the checkpoint would be re-read from disk on each button
    press. Newer releases expose ``st.cache_resource``; some older ones expose
    ``st.experimental_singleton``. If neither attribute exists, *func* is
    returned unchanged and behaves exactly as an undecorated function.
    """
    decorator = (
        getattr(st, "cache_resource", None)
        or getattr(st, "experimental_singleton", None)
    )
    return func if decorator is None else decorator(func)


@_cache_across_reruns
def get_model():
    """Build the diffusion model and load its weights from ``./model-final.pt``.

    Returns:
        The ``GaussianDiffusion`` instance with weights loaded onto the CPU
        and switched to eval mode.
    """
    unet = Unet(
        dim=64,
        dim_mults=(1, 2, 4, 8),
    )
    model = GaussianDiffusion(
        unet,
        image_size=IMAGE_SIZE,
        timesteps=1000,  # number of steps
        sampling_timesteps=250,  # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
        loss_type='l1',  # L1 or L2
        p2_loss_weight_gamma=1.,
    )
    # NOTE(review): torch.load unpickles the file; only point this at a
    # checkpoint you trust.
    model.load_state_dict(torch.load("./model-final.pt", map_location="cpu"))
    model.eval()
    return model


def scale_to_255(x):
    """Convert an array with values in [-1, 1] to ``uint8`` pixels in [0, 255].

    Args:
        x: NumPy array of floats, nominally within [-1, 1].

    Returns:
        Array of the same shape with dtype ``uint8``. Inputs outside [-1, 1]
        saturate at 0 or 255 instead of wrapping around during the cast.
    """
    # Clip before casting: converting out-of-range floats to uint8 does not
    # saturate, so an overshoot would otherwise show up as a wrong-coloured
    # pixel.
    return ((x + 1) / 2 * 255).clip(0, 255).astype('uint8')


def main():
    """Render the page and, on button press, sample and display one image."""
    st.title("Sushi Diffusion")
    st.text("The generation process takes about 10 mins.")
    st.text("If you don't want to wait, please visit: https://thissushidoesnotexist.com/")

    model = get_model()

    st.text('Press the button below to generate sushi!')
    if st.button('🍣'):
        bar = st.progress(0)
        total_steps = model.num_timesteps

        # Start from pure Gaussian noise: one 3-channel square image.
        img = torch.randn((1, 3, IMAGE_SIZE, IMAGE_SIZE), device="cpu")

        # Step through every index in range(model.num_timesteps), highest
        # first, so the progress bar can be updated after each step.
        # NOTE(review): this manual loop does not appear to use the
        # sampling_timesteps setting passed in get_model() - confirm whether
        # the slower full-length schedule is intended.
        for t in reversed(range(total_steps)):
            # p_sample returns a pair; only the first element is carried on.
            img, _ = model.p_sample(img, t, None)
            bar.progress((total_steps - t) / total_steps)

        # Drop the batch axis and move channels last before converting to
        # uint8 pixels for st.image.
        img = scale_to_255(img.squeeze().numpy().transpose(1, 2, 0))
        st.image(img, caption='This sushi does not exist.')


if __name__ == "__main__":
    main()