|
import torch |
|
from chat_anything.face_generator.pipelines.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline |
|
|
|
@torch.no_grad()
def generate(pipe, prompt, negative_prompt, **generating_conf):
    """Generate images with ``pipe`` using long-prompt-aware text encoding.

    The prompt and negative prompt are encoded by a
    ``StableDiffusionLongPromptWeightingPipeline`` that shares ``pipe``'s
    unet / text encoder / vae / tokenizer / scheduler. The resulting
    embeddings are then passed to ``pipe`` through ``prompt_embeds`` /
    ``negative_prompt_embeds`` instead of raw text.

    Args:
        pipe: A diffusers Stable Diffusion pipeline exposing ``unet``,
            ``text_encoder``, ``vae``, ``tokenizer``, ``scheduler`` and
            ``device``. Its ``__call__`` must accept ``prompt_embeds`` and
            ``negative_prompt_embeds``.
        prompt: Text prompt(s) to encode.
        negative_prompt: Negative prompt(s) to encode.
        **generating_conf: Extra keyword arguments forwarded unchanged to
            ``pipe(...)`` (e.g. ``guidance_scale``). When ``guidance_scale``
            is omitted, 7.5 is assumed for the CFG decision below. This
            mirrors diffusers' ``StableDiffusionPipeline`` default.

    Returns:
        Whatever ``pipe(...)`` returns.
    """
    # Wraps the existing modules; safety checker / feature extractor are not
    # needed because only the text-encoding helper is used.
    pipe_longprompt = StableDiffusionLongPromptWeightingPipeline(
        unet=pipe.unet,
        text_encoder=pipe.text_encoder,
        vae=pipe.vae,
        tokenizer=pipe.tokenizer,
        scheduler=pipe.scheduler,
        safety_checker=None,
        feature_extractor=None,
    )

    print('generating: ', prompt)
    print('using negative prompt: ', negative_prompt)

    # The encoder must agree with `pipe` on whether classifier-free guidance
    # is active. Use .get so a missing key does not raise KeyError.
    guidance_scale = generating_conf.get('guidance_scale', 7.5)
    do_classifier_free_guidance = guidance_scale > 1

    embeds = pipe_longprompt._encode_prompt(
        prompt=prompt,
        negative_prompt=negative_prompt,
        device=pipe.device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=do_classifier_free_guidance,
    )

    if do_classifier_free_guidance:
        # Negative and positive embeddings are stacked along dim 0,
        # negative first.
        negative_prompt_embeds, prompt_embeds = embeds.chunk(2)
    else:
        # Without CFG the encoder presumably returns only the positive
        # embeddings. Halving them would fail (split size 0 for one prompt).
        prompt_embeds, negative_prompt_embeds = embeds, None

    pipe_out = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        **generating_conf,
    )
    return pipe_out
|
|
|
if __name__ == '__main__':
    from diffusers.pipelines import StableDiffusionPipeline
    import argparse

    def main():
        """Render each ``--prompts`` entry with plain Stable Diffusion and save PNGs.

        Each image is saved as ``<prompt>_<index>.png`` in the working
        directory. Any ``/`` in the prompt is replaced by ``_`` so it is not
        interpreted as a directory separator.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument(
            '--prompts',
            type=str,
            default=['starry night', 'Impression Sunrise, drawn by Claude Monet'],
            nargs='*',
        )
        parser.add_argument(
            '--model_id',
            type=str,
            default='pretrained_model/sd-v1-4',
            help='path or hub id passed to StableDiffusionPipeline.from_pretrained',
        )
        parser.add_argument(
            '--device',
            type=str,
            default='cuda',
            help='torch device the pipeline is moved to',
        )
        args = parser.parse_args()

        prompts = args.prompts
        # `--prompts` with no values yields []; fail early with a clear message.
        if not prompts:
            parser.error('--prompts requires at least one prompt')

        print(f'generating {prompts}')

        pipe = StableDiffusionPipeline.from_pretrained(args.model_id).to(args.device)
        images = pipe(prompts).images

        for i, (prompt, image) in enumerate(zip(prompts, images)):
            # A '/' in the prompt would otherwise point save() at a
            # (nonexistent) subdirectory.
            safe_name = prompt.replace('/', '_')
            image.save(f'{safe_name}_{i}.png')

    main()
|
|