Update README.md
Browse files
README.md
CHANGED
@@ -12,4 +12,41 @@ tags:
|
|
12 |
|
13 |
Pretrained SD-1.5 weight for SePPO
|
14 |
|
15 |
-
See Github Repo:[SePPO](https://github.com/DwanZhang-AI/SePPO/tree/main)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
Pretrained SD-1.5 weight for SePPO
|
14 |
|
15 |
+
See Github Repo:[SePPO](https://github.com/DwanZhang-AI/SePPO/tree/main)
|
16 |
+
|
17 |
+
Inference Code:
|
18 |
+
|
19 |
+
```
|
20 |
+
import os
|
21 |
+
import argparse
|
22 |
+
import numpy as np
|
23 |
+
import torch
|
24 |
+
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
|
25 |
+
from PIL import Image
|
26 |
+
|
27 |
+
torch.set_grad_enabled(False)
|
28 |
+
|
29 |
+
if __name__ == "__main__":
|
30 |
+
parser = argparse.ArgumentParser(description="Generate images and calculate scores.")
|
31 |
+
parser.add_argument('--unet_checkpoint', type=str, required=True, help="Path to the UNet model checkpoint")
|
32 |
+
parser.add_argument('--prompt', type=str, required=True, help="Prompt")
|
33 |
+
|
34 |
+
args = parser.parse_args()
|
35 |
+
|
36 |
+
unet = UNet2DConditionModel.from_pretrained(args.unet_checkpoint, torch_dtype=torch.float16).to('cuda')
|
37 |
+
|
38 |
+
pipe = StableDiffusionPipeline.from_pretrained("pt-sk/stable-diffusion-1.5", torch_dtype=torch.float16)
|
39 |
+
|
40 |
+
pipe = pipe.to('cuda')
|
41 |
+
pipe.safety_checker = None
|
42 |
+
pipe.unet = unet
|
43 |
+
generator = torch.Generator(device='cuda').manual_seed(0)
|
44 |
+
gs = 7.5
|
45 |
+
|
46 |
+
ims = pipe(prompt=args.prompt, generator=generator, guidance_scale=gs).images[0]
|
47 |
+
img_path = os.path.join('SePPO', "0.png")
|
48 |
+
|
49 |
+
if isinstance(ims, np.ndarray):
|
50 |
+
ims = Image.fromarray(ims)
|
51 |
+
ims.save(img_path, format='PNG')
|
52 |
+
```
|