"""Streamlit demo for one- to four-step text-to-image sampling with SD-Turbo / SDXL-Turbo."""

import numpy as np
import streamlit as st
import torch
from torch import autocast

from st_keyup import st_keyup
from streamlit_helpers import *

from sgm.modules.diffusionmodules.sampling import EulerAncestralSampler

# Latent geometry plus config/checkpoint paths for the supported Turbo variants.
VERSION2SPECS = {
    "SDXL-Turbo": {
        "H": 512,
        "W": 512,
        "C": 4,
        "f": 8,
        "is_legacy": False,
        "config": "configs/inference/sd_xl_base.yaml",
        "ckpt": "checkpoints/sd_xl_turbo_1.0.safetensors",
    },
    "SD-Turbo": {
        "H": 512,
        "W": 512,
        "C": 4,
        "f": 8,
        "is_legacy": False,
        "config": "configs/inference/sd_2_1.yaml",
        "ckpt": "checkpoints/sd_turbo.safetensors",
    },
}


class SubstepSampler(EulerAncestralSampler):
    """Euler-ancestral sampler restricted to a small subset of the discretized timesteps."""

    def __init__(self, n_sample_steps=1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_sample_steps = n_sample_steps
        self.steps_subset = [0, 100, 200, 300, 1000]

    def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
        sigmas = self.discretization(
            self.num_steps if num_steps is None else num_steps, device=self.device
        )
        # Keep only the first n_sample_steps sigmas plus the final one.
        sigmas = sigmas[
            self.steps_subset[: self.n_sample_steps] + self.steps_subset[-1:]
        ]
        # Turbo models are sampled without classifier-free guidance, so the
        # "unconditional" branch simply reuses the conditioning.
        uc = cond
        x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
        num_sigmas = len(sigmas)
        s_in = x.new_ones([x.shape[0]])
        return x, s_in, sigmas, num_sigmas, cond, uc


def seeded_randn(shape, seed):
    # Deterministic Gaussian noise for a given seed, moved to the GPU.
    randn = np.random.RandomState(seed).randn(*shape)
    randn = torch.from_numpy(randn).to(device="cuda", dtype=torch.float32)
    return randn


class SeededNoise:
    """Noise sampler that advances its seed on every call, keeping draws reproducible."""

    def __init__(self, seed):
        self.seed = seed

    def __call__(self, x):
        self.seed = self.seed + 1
        return seeded_randn(x.shape, self.seed)


def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
    # Build the value dict consumed by the model's conditioner embedders.
    value_dict = {}
    for key in keys:
        if key == "txt":
            value_dict["prompt"] = prompt
            value_dict["negative_prompt"] = ""

        if key == "original_size_as_tuple":
            orig_width = init_dict["orig_width"]
            orig_height = init_dict["orig_height"]
            value_dict["orig_width"] = orig_width
            value_dict["orig_height"] = orig_height

        if key == "crop_coords_top_left":
            crop_coord_top = 0
            crop_coord_left = 0
            value_dict["crop_coords_top"] = crop_coord_top
            value_dict["crop_coords_left"] = crop_coord_left

        if key == "aesthetic_score":
            value_dict["aesthetic_score"] = 6.0
            value_dict["negative_aesthetic_score"] = 2.5

        if key == "target_size_as_tuple":
            value_dict["target_width"] = init_dict["target_width"]
            value_dict["target_height"] = init_dict["target_height"]

    return value_dict


def sample(
    model,
    sampler,
    prompt="A lush garden with oversized flowers and vibrant colors, inhabited by miniature animals.",
    H=1024,
    W=1024,
    seed=0,
    filter=None,
):
    F = 8
    C = 4
    # One sample, C latent channels, spatial size downscaled by the autoencoder factor F.
    shape = (1, C, H // F, W // F)

    value_dict = init_embedder_options(
        keys=get_unique_embedder_keys_from_conditioner(model.conditioner),
        init_dict={
            "orig_width": W,
            "orig_height": H,
            "target_width": W,
            "target_height": H,
        },
        prompt=prompt,
    )

    if seed is None:
        seed = torch.seed()
    precision_scope = autocast
    with torch.no_grad():
        with precision_scope("cuda"):
            batch, batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model.conditioner),
                value_dict,
                [1],
            )
            c = model.conditioner(batch)
            uc = None
            randn = seeded_randn(shape, seed)

            def denoiser(input, sigma, c):
                return model.denoiser(
                    model.model,
                    input,
                    sigma,
                    c,
                )

            samples_z = sampler(denoiser, randn, cond=c, uc=uc)
            samples_x = model.decode_first_stage(samples_z)
            samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
            if filter is not None:
                samples = filter(samples)
            # Convert to uint8 HWC numpy images for display.
            samples = (
                (255 * samples)
                .to(dtype=torch.uint8)
                .permute(0, 2, 3, 1)
                .detach()
                .cpu()
                .numpy()
            )
    return samples


def v_spacer(height) -> None:
    # Insert vertical whitespace into the Streamlit layout.
    for _ in range(height):
        st.write("\n")


if __name__ == "__main__":
    st.title("Turbo")

    head_cols = st.columns([1, 1, 1])
    with head_cols[0]:
        version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
        version_dict = VERSION2SPECS[version]

    with head_cols[1]:
        v_spacer(2)
        if st.checkbox("Load Model"):
            mode = "txt2img"
        else:
            mode = "skip"

    if mode != "skip":
        state = init_st(version_dict, load_filter=True)
        if state["msg"]:
            st.info(state["msg"])
        model = state["model"]
        load_model(model)

    # seed
    if "seed" not in st.session_state:
        st.session_state.seed = 0

    def increment_counter():
        st.session_state.seed += 1

    def decrement_counter():
        if st.session_state.seed > 0:
            st.session_state.seed -= 1

    with head_cols[2]:
        n_steps = st.number_input(label="number of steps", min_value=1, max_value=4)

    sampler = SubstepSampler(
        n_sample_steps=1,
        num_steps=1000,
        eta=1.0,
        discretization_config=dict(
            target="sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization"
        ),
    )
    sampler.n_sample_steps = n_steps
    default_prompt = (
        "A cinematic shot of a baby racoon wearing an intricate italian priest robe."
    )
    prompt = st_keyup(
        "Enter a value", value=default_prompt, debounce=300, key="interactive_text"
    )

    cols = st.columns([1, 5, 1])
    if mode != "skip":
        with cols[0]:
            v_spacer(14)
            st.button("↩", on_click=decrement_counter)
        with cols[2]:
            v_spacer(14)
            st.button("↪", on_click=increment_counter)

        sampler.noise_sampler = SeededNoise(seed=st.session_state.seed)
        out = sample(
            model,
            sampler,
            H=512,
            W=512,
            seed=st.session_state.seed,
            prompt=prompt,
            filter=state.get("filter"),
        )
        with cols[1]:
            st.image(out[0])
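
# Usage sketch, assuming this file is saved as turbo.py in the repo's demo
# folder, the checkpoints listed in VERSION2SPECS have been downloaded, a CUDA
# GPU is available, and streamlit plus the st_keyup component are installed:
#
#   streamlit run turbo.py
#
# Tick "Load Model", edit the prompt, and use the ↩ / ↪ buttons to step the seed.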