Spaces:
Running
on
Zero
Running
on
Zero
from st_keyup import st_keyup | |
from streamlit_helpers import * | |
from sgm.modules.diffusionmodules.sampling import EulerAncestralSampler | |
# Settings that are identical for every Turbo variant listed below.
_TURBO_SHARED_SPEC = {
    "H": 512,
    "W": 512,
    "C": 4,
    "f": 8,
    "is_legacy": False,
}

# Maps the model name offered in the UI to its spec: the shared values above
# plus the per-model config file and checkpoint path.
VERSION2SPECS = {
    "SDXL-Turbo": {
        **_TURBO_SHARED_SPEC,
        "config": "configs/inference/sd_xl_base.yaml",
        "ckpt": "checkpoints/sd_xl_turbo_1.0.safetensors",
    },
    "SD-Turbo": {
        **_TURBO_SHARED_SPEC,
        "config": "configs/inference/sd_2_1.yaml",
        "ckpt": "checkpoints/sd_turbo.safetensors",
    },
}
class SubstepSampler(EulerAncestralSampler):
    """Euler-ancestral sampler that only visits a few entries of the schedule.

    `steps_subset` lists candidate indices into the sigma schedule. The first
    `n_sample_steps` of them are used, followed by the last entry of the list.
    """

    def __init__(self, n_sample_steps=1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_sample_steps = n_sample_steps
        self.steps_subset = [0, 100, 200, 300, 1000]

    def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
        """Select the reduced sigma schedule and scale the initial latent.

        Args:
            x: Initial noise tensor; it is scaled IN PLACE by
                sqrt(1 + sigmas[0] ** 2).
            cond: Conditioning, returned unchanged.
            uc: Accepted for interface compatibility but discarded; the
                conditioning is returned in its place.
            num_steps: Schedule length passed to the discretization; falls
                back to `self.num_steps` when None.

        Returns:
            Tuple `(x, s_in, sigmas, num_sigmas, cond, uc)` where `s_in` is a
            vector of ones with one entry per batch element of `x`.
        """
        if num_steps is None:
            num_steps = self.num_steps
        full_schedule = self.discretization(num_steps, device=self.device)

        # Leading n_sample_steps indices plus the final index of steps_subset.
        picked = self.steps_subset[: self.n_sample_steps] + self.steps_subset[-1:]
        sigmas = full_schedule[picked]

        # The unconditional branch is deliberately replaced by the conditioning.
        uc = cond

        x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
        ones_per_sample = x.new_ones([x.shape[0]])
        return x, ones_per_sample, sigmas, len(sigmas), cond, uc
def seeded_randn(shape, seed, device="cuda"):
    """Return reproducible standard-normal noise as a float32 tensor.

    The values are drawn by a NumPy `RandomState` seeded with `seed` and then
    copied to `device`, so the result depends only on `shape` and `seed`.

    Args:
        shape: Iterable of ints giving the tensor shape.
        seed: Seed handed to `np.random.RandomState`.
        device: Target torch device. Defaults to "cuda", the previously
            hard-coded value, so existing callers are unaffected.

    Returns:
        A `torch.float32` tensor of the requested shape on `device`.
    """
    noise = np.random.RandomState(seed).randn(*shape)
    return torch.from_numpy(noise).to(device=device, dtype=torch.float32)
class SeededNoise:
    """Callable noise source whose seed advances by one on every call."""

    def __init__(self, seed):
        self.seed = seed

    def __call__(self, x):
        """Return seeded noise shaped like `x`, drawn with the advanced seed."""
        advanced = self.seed + 1
        self.seed = advanced
        return seeded_randn(x.shape, advanced)
def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
    """Build the conditioning value dict for the requested embedder keys.

    Args:
        keys: Iterable of embedder input keys. Recognised keys are "txt",
            "original_size_as_tuple", "crop_coords_top_left",
            "aesthetic_score" and "target_size_as_tuple"; others are ignored.
        init_dict: Mapping that supplies "orig_width"/"orig_height" and
            "target_width"/"target_height" when the matching key is requested.
        prompt: Text stored under "prompt" when "txt" is requested.
        negative_prompt: Text stored under "negative_prompt" when "txt" is
            requested. An omitted (None) value is stored as "".

    Returns:
        Dict containing only the entries required by `keys`.

    Raises:
        KeyError: If a size key is requested but missing from `init_dict`.
    """
    value_dict = {}
    for key in keys:
        if key == "txt":
            value_dict["prompt"] = prompt
            # Previously this was always "", silently dropping the argument.
            value_dict["negative_prompt"] = (
                "" if negative_prompt is None else negative_prompt
            )
        elif key == "original_size_as_tuple":
            value_dict["orig_width"] = init_dict["orig_width"]
            value_dict["orig_height"] = init_dict["orig_height"]
        elif key == "crop_coords_top_left":
            # No cropping: the crop origin is fixed at the top-left corner.
            value_dict["crop_coords_top"] = 0
            value_dict["crop_coords_left"] = 0
        elif key == "aesthetic_score":
            value_dict["aesthetic_score"] = 6.0
            value_dict["negative_aesthetic_score"] = 2.5
        elif key == "target_size_as_tuple":
            value_dict["target_width"] = init_dict["target_width"]
            value_dict["target_height"] = init_dict["target_height"]
    return value_dict
def sample(
    model,
    sampler,
    prompt="A lush garden with oversized flowers and vibrant colors, inhabited by miniature animals.",
    H=1024,
    W=1024,
    seed=0,
    filter=None,
):
    """Generate one image for `prompt` and return it as a uint8 NumPy array.

    Args:
        model: Object providing `conditioner`, `denoiser`, `model` and
            `decode_first_stage`, as used below.
        sampler: Callable invoked as `sampler(denoiser, noise, cond=..., uc=...)`.
        prompt: Text prompt forwarded to the conditioner.
        H: Image height in pixels; the latent height is `H // 8`.
        W: Image width in pixels; the latent width is `W // 8`.
        seed: Seed for the initial noise. When None, a fresh seed is taken
            from `torch.seed()`.
        filter: Optional callable applied to the decoded samples in [0, 1]
            before they are converted to uint8.

    Returns:
        NumPy uint8 array obtained by permuting the decoded batch with
        `(0, 2, 3, 1)` (presumably batch, height, width, channels; confirm
        against `decode_first_stage`).
    """
    F = 8
    C = 4
    shape = (1, C, H // F, W // F)

    # Resolve the embedder keys once; they are needed for both the value dict
    # and the batch construction.
    embedder_keys = get_unique_embedder_keys_from_conditioner(model.conditioner)

    value_dict = init_embedder_options(
        keys=embedder_keys,
        init_dict={
            "orig_width": W,
            "orig_height": H,
            "target_width": W,
            "target_height": H,
        },
        prompt=prompt,
    )

    if seed is None:
        # torch.seed() returns a 64-bit value, but np.random.RandomState only
        # accepts seeds in [0, 2**32 - 1]; reduce it so seeding cannot raise.
        seed = torch.seed() % (2**32)

    precision_scope = autocast
    with torch.no_grad():
        with precision_scope("cuda"):
            batch, _batch_uc = get_batch(
                embedder_keys,
                value_dict,
                [1],
            )
            c = model.conditioner(batch)
            uc = None
            randn = seeded_randn(shape, seed)

            def denoiser(input, sigma, c):
                return model.denoiser(
                    model.model,
                    input,
                    sigma,
                    c,
                )

            samples_z = sampler(denoiser, randn, cond=c, uc=uc)
            samples_x = model.decode_first_stage(samples_z)
            # Map decoder output from [-1, 1] to [0, 1] and clip.
            samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
            if filter is not None:
                samples = filter(samples)
            samples = (
                (255 * samples)
                .to(dtype=torch.uint8)
                .permute(0, 2, 3, 1)
                .detach()
                .cpu()
                .numpy()
            )
    return samples
def v_spacer(height) -> None:
    """Write a newline via `st.write` `height` times as vertical padding."""
    newline = "\n"
    repetitions = range(height)
    for _ in repetitions:
        st.write(newline)
if __name__ == "__main__":
    # Widgets are created in the same order as before: title, header columns
    # (model version, load checkbox, step count), prompt box, then the image row.
    st.title("Turbo")

    head_cols = st.columns([1, 1, 1])

    # Column 0: choose which model spec to use.
    with head_cols[0]:
        version = st.selectbox("Model Version", list(VERSION2SPECS), 0)
        version_dict = VERSION2SPECS[version]

    # Column 1: nothing is loaded or sampled until the checkbox is ticked.
    with head_cols[1]:
        v_spacer(2)
        mode = "txt2img" if st.checkbox("Load Model") else "skip"

    if mode != "skip":
        state = init_st(version_dict, load_filter=True)
        if state["msg"]:
            st.info(state["msg"])
        model = state["model"]
        load_model(model)

    # The seed lives in the session state and is changed by the arrow buttons.
    if "seed" not in st.session_state:
        st.session_state.seed = 0

    def increment_counter():
        st.session_state.seed += 1

    def decrement_counter():
        # The seed never goes below zero.
        if st.session_state.seed > 0:
            st.session_state.seed -= 1

    # Column 2: number of sampling steps (1 to 4).
    with head_cols[2]:
        n_steps = st.number_input(label="number of steps", min_value=1, max_value=4)

    discretization_config = {
        "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization"
    }
    sampler = SubstepSampler(
        n_sample_steps=1,
        num_steps=1000,
        eta=1.0,
        discretization_config=discretization_config,
    )
    sampler.n_sample_steps = n_steps

    default_prompt = (
        "A cinematic shot of a baby racoon wearing an intricate italian priest robe."
    )
    prompt = st_keyup(
        "Enter a value", value=default_prompt, debounce=300, key="interactive_text"
    )

    cols = st.columns([1, 5, 1])
    if mode != "skip":
        # Seed buttons sit either side of the image.
        with cols[0]:
            v_spacer(14)
            st.button("↩", on_click=decrement_counter)
        with cols[2]:
            v_spacer(14)
            st.button("↪", on_click=increment_counter)

        current_seed = st.session_state.seed
        sampler.noise_sampler = SeededNoise(seed=current_seed)
        out = sample(
            model,
            sampler,
            H=512,
            W=512,
            seed=current_seed,
            prompt=prompt,
            filter=state.get("filter"),
        )
        with cols[1]:
            st.image(out[0])