"""Streamlit app that samples images from a pretrained pro-gan generator."""

import os
from math import log2

import streamlit as st
import torch
import torchvision.utils as vutils  # noqa: F401  (unused here; kept for compatibility)

from model import Generator

# Constructor arguments passed to ``Generator``; they must match the checkpoint.
Z_DIM = 256
IN_CHANNELS = 256

CHECKPOINT_PATH = "generator.pth"
OUTPUT_DIR = "generated_images"

# Square output resolutions to sample at, one batch per entry.
IMG_SIZES = (256,)
DEFAULT_NUM_IMAGES = 6


def _load_generator():
    """Build a ``Generator`` and load its weights from ``CHECKPOINT_PATH`` on CPU."""
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=torch.device('cpu'))
    # The checkpoint is a dict; only its 'state_dict' entry is model weights.
    state_dict = checkpoint['state_dict']

    generator = Generator(Z_DIM, IN_CHANNELS, img_channels=3)
    generator.load_state_dict(state_dict)
    generator.eval()
    return generator


def generate_images(num_images=DEFAULT_NUM_IMAGES):
    """Sample images from the pretrained generator.

    Args:
        num_images: Number of images to generate per resolution in
            ``IMG_SIZES``. Defaults to 6.

    Returns:
        A list of tensors, each of shape ``(3, img_size, img_size)`` with
        values in ``[0, 1]``, ready to be displayed as float images.

    Raises:
        RuntimeError: If the generator output does not have the expected
            ``(num_images, 3, img_size, img_size)`` shape.
    """
    generator = _load_generator()

    # Kept from the original script; nothing in this file writes into it.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    images = []
    for img_size in IMG_SIZES:
        # steps = log2(img_size / 4), e.g. 6 for a 256x256 output.
        num_steps = int(log2(img_size / 4))
        noise = torch.randn((num_images, Z_DIM, 1, 1))

        with torch.no_grad():
            batch = generator(noise, alpha=0.5, steps=num_steps)

        # Explicit check instead of ``assert`` so it survives ``python -O``.
        expected_shape = (num_images, 3, img_size, img_size)
        if tuple(batch.shape) != expected_shape:
            raise RuntimeError(
                f"Generator returned shape {tuple(batch.shape)}, "
                f"expected {expected_shape}"
            )

        # Map the generator output from a nominal [-1, 1] to [0, 1].
        # Nothing visible here guarantees the output is bounded, so clamp:
        # st.image rejects float data outside [0.0, 1.0].
        batch = ((batch + 1) / 2).clamp(0.0, 1.0)

        images.extend(batch.detach().unbind(0))

    return images


def main():
    """Render the Streamlit page and show generated images on button click."""
    st.title('Image Generation with pro-gan 🤖')
    st.write("Click the buttons below to generate images.")
    st.write("Trained on CelebHQ dataset.")

    # Tell the user about the output resolution up front.
    st.write("Note: Due to limited resources, the model has been trained to generate 256x256 size images. They are still awesome!")

    if st.button('Generate Images'):
        images = generate_images()

        for index, image in enumerate(images, start=1):
            # Tensors are channel-first (C, H, W); st.image wants (H, W, C).
            # NOTE(review): newer Streamlit releases deprecate
            # ``use_column_width`` in favour of ``use_container_width`` --
            # confirm against the deployed version before switching.
            st.image(
                image.permute(1, 2, 0).cpu().numpy(),
                caption=f'Generated Image {index}',
                use_column_width=True,
            )


if __name__ == '__main__':
    main()