"""FastAPI service that samples logo images from a fine-tuned minDALL-E model.

At import time this module downloads a checkpoint from the Hugging Face Hub,
loads it into a pretrained ``minDALL-E/1.3B`` model and moves the model to the
GPU when one is available. The HTTP endpoints then sample candidate images for
a text prompt and return them as base64-encoded JPEGs.
"""
import os
import sys
import base64
from io import BytesIO

# Add the parent of this file's directory to sys.path (presumably so the
# local `dalle` package imported below can be resolved).
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
from torch import nn
from fastapi import FastAPI
import numpy as np
from PIL import Image
from dalle.models import Dalle
import logging
import streamlit as st
from huggingface_hub import hf_hub_download

logger = logging.getLogger(__name__)

print("Loading models...")
app = FastAPI()

logger.info("Start downloading")
full_dict_path = hf_hub_download(
    repo_id="MatthiasC/dall-e-logo",
    filename="full_dict_new.ckpt",
    # NOTE(review): newer huggingface_hub releases deprecate `use_auth_token`
    # in favour of `token`; confirm against the pinned version.
    use_auth_token=st.secrets["model_hub"],
)
logger.info("End downloading")

device = "cuda" if torch.cuda.is_available() else "cpu"

model = Dalle.from_pretrained("minDALL-E/1.3B")
# Map the checkpoint to CPU first so loading works without a GPU; the model is
# moved to `device` right after.
# SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted repository; consider
# `weights_only=True` if the installed torch version and checkpoint allow it.
model.load_state_dict(torch.load(full_dict_path, map_location=torch.device("cpu")))
model.to(device=device)
# This service only runs inference: disable training-time behaviour such as
# dropout regardless of what `Dalle.sampling` does internally.
model.eval()
print("Models loaded !")


@app.get("/")
def read_root():
    """Return a fixed one-element JSON array identifying the service."""
    return ["minDALL-E!"]


@app.get("/{generate}")
def generate(prompt: str):
    """Sample images for ``prompt`` and return them base64-encoded.

    The ``{generate}`` path segment is not bound to any parameter, so any
    single-segment path reaches this handler; the prompt itself is read from
    the ``prompt`` query parameter (e.g. ``/generate?prompt=a+red+fox``).

    Returns:
        ``{"images": [...]}`` with one base64-encoded JPEG per candidate.
    """
    images = sample(prompt)
    return {"images": [to_base64(image) for image in images]}


def sample(prompt, *, top_k=96, top_p=None, softmax_temperature=1.0, num_candidates=9):
    """Generate candidate images for ``prompt`` with the loaded model.

    Args:
        prompt: Text prompt forwarded to ``model.sampling``.
        top_k: Top-k sampling cutoff forwarded to ``model.sampling``.
        top_p: Nucleus sampling threshold (``None`` disables it).
        softmax_temperature: Sampling temperature.
        num_candidates: Number of images to generate.

    The defaults match the settings used by the ``/{generate}`` endpoint.

    Returns:
        A list of ``PIL.Image.Image`` objects, one per candidate.
    """
    logger.info("starting sampling")
    # Gradients are never needed here; no_grad avoids building an autograd
    # graph and guarantees `.numpy()` works on the result.
    with torch.no_grad():
        images = (
            model.sampling(
                prompt=prompt,
                top_k=top_k,
                top_p=top_p,
                softmax_temperature=softmax_temperature,
                num_candidates=num_candidates,
                device=device,
            )
            .cpu()
            .numpy()
        )
    logger.info("sampling succeeded")
    # Assumes a (N, C, H, W) float batch nominally in [0, 1] — TODO confirm
    # against Dalle.sampling. Reorder to (N, H, W, C) as PIL expects.
    images = np.transpose(images, (0, 2, 3, 1))
    # Clip before the uint8 cast: out-of-range floats overflow the
    # float->uint8 conversion and would show up as speckled pixels.
    images = np.clip(images, 0.0, 1.0)
    return [Image.fromarray((image * 255).astype(np.uint8)) for image in images]


def to_base64(pil_image):
    """Encode ``pil_image`` as JPEG and return its base64 encoding as ``bytes``.

    FastAPI's JSON encoder converts ``bytes`` to ``str`` when building the
    response, so the result can be placed directly in a JSON payload.
    """
    with BytesIO() as buffered:
        pil_image.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue())