ai_architecture / utils.py
JMalott's picture
Update utils.py
8a41a8e
raw
history blame
5.63 kB
from htbuilder import HtmlElement, div, ul, li, br, hr, a, p, img, styles, classes, fonts
from htbuilder.units import percent, px
from htbuilder.funcs import rgba, rgb
import streamlit as st
import os
import sys
import argparse
import clip
import numpy as np
from PIL import Image
from dalle.models import Dalle
from dalle.utils.utils import set_seed, clip_score
import streamlit.components.v1 as components
import torch
#from IPython.display import display
import random
def link(link, text, **style):
    """Build an htbuilder anchor element that opens *link* in a new tab.

    Args:
        link: URL written to the anchor's ``href`` attribute.
        text: Child content placed inside the anchor.
        **style: CSS properties passed to ``htbuilder.styles`` and set as the
            anchor's inline ``style`` attribute.

    Returns:
        The htbuilder ``a`` element with *text* as its child.
    """
    inline_css = styles(**style)
    anchor = a(_href=link, _target="_blank", style=inline_css)
    return anchor(text)
def layout(*args):
    """Inject page CSS and render a footer bar fixed to the bottom of the page.

    Each positional argument that is a ``str`` or an htbuilder ``HtmlElement``
    is added to the footer's paragraph; any other argument type is ignored.

    Args:
        *args: Strings / htbuilder elements to show inside the footer.

    Side effects:
        Emits two ``st.markdown`` calls with ``unsafe_allow_html=True``: the
        CSS block first, then the rendered footer HTML.
    """
    # Fix: the ID selector must be written "#MainMenu"; with a space after
    # "#" the selector is invalid and the browser drops the whole rule, so
    # the menu was never hidden.
    style = """
        <style>
        #MainMenu {visibility: hidden;}
        footer {visibility: hidden;}
        .stApp { bottom: 125px; }
        button[title="View fullscreen"]{display:none;}
        body {background-color: white;}
        </style>
        """

    # Footer container pinned to the bottom edge, spanning the full width.
    style_div = styles(
        position="fixed",
        left=0,
        bottom=0,
        margin=px(0, 0, 0, 0),
        width=percent(100),
        color="black",
        text_align="center",
        height="auto",
        opacity=1
    )

    style_hr = styles(
        display="block",
        margin=px(8, 8, "auto", "auto"),
        border_style="inset",
        border_width=px(2)
    )

    body = p()
    # `foot` holds a reference to `body`; children added to `body` below are
    # expected to appear in the rendered footer (htbuilder's call syntax
    # appears to add children in place - NOTE(review): confirm).
    foot = div(
        style=style_div
    )(
        hr(
            style=style_hr
        ),
        body
    )

    st.markdown(style, unsafe_allow_html=True)

    for arg in args:
        if isinstance(arg, (str, HtmlElement)):
            body(arg)

    st.markdown(str(foot), unsafe_allow_html=True)
def footer():
    """Hide some Streamlit UI chrome via CSS and render the credits footer.

    Side effects:
        Emits a CSS block with ``st.markdown(..., unsafe_allow_html=True)``,
        three empty markdown calls used as vertical spacing, then three
        markdown lines of attribution text.
    """
    # Fix: "#MainMenu" must not contain a space after "#", otherwise the CSS
    # selector is invalid and the rule is discarded by the browser.
    style = """
        <style>
        #MainMenu {visibility: hidden;}
        button[title="View fullscreen"]{display:none;}
        body {background-color: white;}
        </style>
        """
    st.markdown(style, unsafe_allow_html=True)

    # Empty markdown elements act as spacers above the credits.
    for _ in range(3):
        st.markdown("")

    # Fix: "DALL·E" was mis-encoded (mojibake) in the displayed text.
    st.markdown("This app uses the [min(DALL·E)](https://github.com/kuprel/min-dalle) port of [DALL·E mini](https://github.com/borisdayma/dalle-mini)")
    st.markdown("Created by [Jonathan Malott](https://jonathanmalott.com)")
    st.markdown("[Good Systems Grand Challenge](https://bridgingbarriers.utexas.edu/good-systems), The University of Texas at Austin. Advised by Dr. Junfeng Jiao.")
from min_dalle import MinDalle
# Lazily-constructed MinDalle instance shared by every generate2() call.
_min_dalle_model = None


def _get_min_dalle():
    """Return the shared MinDalle model, constructing it on first use.

    The model is configured with ``is_reusable=True``, i.e. it is meant to be
    reused across generations; previously a fresh instance (and a fresh load
    from ``./pretrained``) was created on every call.
    """
    global _min_dalle_model
    if _min_dalle_model is None:
        _min_dalle_model = MinDalle(
            models_root='./pretrained',
            dtype=torch.float32,
            device='cpu',
            is_mega=False,
            is_reusable=True
        )
    return _min_dalle_model


def generate2(prompt, crazy, k):
    """Generate one image with min-dalle and append it to the session results.

    Args:
        prompt: User prompt. For generation, " architecture" is appended
            unless the prompt already contains "architecture"
            (case-insensitive); the unmodified prompt is what gets stored.
        crazy: Sampling temperature forwarded to ``generate_image``.
        k: Top-k value forwarded to ``generate_image``.

    Side effects:
        Appends a dict with keys ``'prompt'``, ``'crazy'``, ``'k'`` and
        ``'image'`` to ``st.session_state.results``.
    """
    mm = _get_min_dalle()

    # Bias every generation toward architectural imagery.
    generation_prompt = prompt
    if "architecture" not in prompt.lower():
        generation_prompt += " architecture"

    image = mm.generate_image(
        text=generation_prompt,
        seed=np.random.randint(0, 10000),
        grid_size=1,
        is_seamless=False,
        temperature=crazy,
        top_k=k,
        supercondition_factor=32,
        is_verbose=False
    )

    st.session_state.results.append(
        {'prompt': prompt, 'crazy': crazy, 'k': k, 'image': image}
    )
# minDALL-E model, loaded lazily by generate(); False until the first call.
model = False
# (model_clip, preprocess_clip) pair, loaded lazily the first time CLIP
# re-ranking is actually needed.
_clip_bundle = None


def generate(prompt, crazy, k, num_candidates=1):
    """Sample an image with minDALL-E and append it to the session results.

    Args:
        prompt: User prompt. For sampling and CLIP scoring, " architecture" is
            appended unless the prompt already contains "architecture"
            (case-insensitive); the unmodified prompt is what gets stored.
        crazy: Softmax temperature forwarded to ``model.sampling``.
        k: Top-k value forwarded to ``model.sampling``.
        num_candidates: Number of images to sample (default 1, the previous
            hard-coded value). When greater than 1 the candidates are
            re-ranked with CLIP and the top-ranked one is kept.

    Side effects:
        Appends a dict with keys ``'prompt'``, ``'crazy'``, ``'k'`` and
        ``'image'`` (a PIL image) to ``st.session_state.results``.
    """
    global model, _clip_bundle
    device = 'cpu'

    if model is False:
        # Downloads the pretrained model on first use.
        model = Dalle.from_pretrained('minDALL-E/1.3B')
        model.to(device=device)

    set_seed(np.random.randint(0, 10000))

    # Bias every generation toward architectural imagery.
    new_prompt = prompt
    if "architecture" not in prompt.lower():
        new_prompt += " architecture"

    images = model.sampling(prompt=new_prompt,
                            top_k=k,
                            top_p=None,
                            softmax_temperature=crazy,
                            num_candidates=num_candidates,
                            device=device).cpu().numpy()
    # Presumably (N, C, H, W) -> (N, H, W, C) for PIL - TODO confirm layout.
    images = np.transpose(images, (0, 2, 3, 1))

    if num_candidates > 1:
        # CLIP re-ranking only matters with several candidates; load the CLIP
        # model once and reuse it instead of reloading it on every call.
        if _clip_bundle is None:
            model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
            model_clip.to(device=device)
            _clip_bundle = (model_clip, preprocess_clip)
        model_clip, preprocess_clip = _clip_bundle
        rank = clip_score(prompt=new_prompt,
                          images=images,
                          model_clip=model_clip,
                          preprocess_clip=preprocess_clip,
                          device=device)
        # Assumes clip_score returns candidate indices ordered best-first;
        # atleast_1d also copes with a 0-d result. Indexing with the whole
        # rank array would yield a stack that Image.fromarray cannot take.
        best = int(np.atleast_1d(rank)[0])
    else:
        best = 0

    result = images[best]
    st.session_state.results.append({
        'prompt': prompt,
        'crazy': crazy,
        'k': k,
        # Scaling by 255 assumes pixel values in [0, 1] - NOTE(review).
        'image': Image.fromarray((result * 255).astype(np.uint8)),
    })
def drawGrid():
    """Render session results grouped by prompt / temperature / top-k.

    Results are walked newest-first and bucketed by the string
    "<prompt> <crazy> <k>". Each bucket gets a subheader, and its images are
    laid out round-robin over three columns. Every ``st.image`` handle is
    appended to ``st.session_state.images``.
    """
    groups = {}
    for entry in reversed(st.session_state.results):
        group_key = " ".join([entry['prompt'], str(entry['crazy']), str(entry['k'])])
        groups.setdefault(group_key, []).append(entry)

    # One empty placeholder is emitted per image handle recorded so far.
    for _handle in st.session_state.images:
        st.empty()

    placeholder = st.empty()
    with placeholder.container():
        for group in groups.values():
            head = group[0]
            st.subheader(f"{head['prompt']} (Temperature:{head['crazy']!s}, Top K:{head['k']!s})")
            columns = st.columns(3)
            for position, entry in enumerate(group):
                with columns[position % 3]:
                    st.session_state.images.append(st.image(entry["image"]))