---
license: apache-2.0
datasets:
- yuvalkirstain/pickapic_v2
base_model:
- stable-diffusion-v1-5/stable-diffusion-v1-5
library_name: diffusers
pipeline_tag: text-to-image
tags:
- text-generation-inference
---

Pretrained SD-1.5 weights for [SePPO: Semi-Policy Preference Optimization for Diffusion Alignment](https://huggingface.co/papers/2410.05255)

GitHub Repo: [SePPO](https://github.com/DwanZhang-AI/SePPO/tree/main)

Paper Report: [Daily Paper](https://huggingface.co/papers/2410.05255)

Inference Code:

```
import os
import argparse
import numpy as np
import torch
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
from PIL import Image

# Inference only: disable gradient tracking globally.
torch.set_grad_enabled(False)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate an image from a prompt with a SePPO UNet checkpoint.")
    parser.add_argument('--unet_checkpoint', type=str, required=True, help="Path to the UNet model checkpoint")
    parser.add_argument('--prompt', type=str, required=True, help="Text prompt to generate an image from")

    args = parser.parse_args()

    # Load the SePPO fine-tuned UNet in half precision.
    unet = UNet2DConditionModel.from_pretrained(args.unet_checkpoint, torch_dtype=torch.float16).to('cuda')

    # Load the base SD-1.5 pipeline, then swap in the fine-tuned UNet.
    pipe = StableDiffusionPipeline.from_pretrained("pt-sk/stable-diffusion-1.5", torch_dtype=torch.float16)

    pipe = pipe.to('cuda')
    pipe.safety_checker = None
    pipe.unet = unet
    generator = torch.Generator(device='cuda').manual_seed(0)  # fixed seed for reproducibility
    gs = 7.5  # classifier-free guidance scale

    ims = pipe(prompt=args.prompt, generator=generator, guidance_scale=gs).images[0]

    # Create the output directory if it does not exist yet.
    os.makedirs('SePPO', exist_ok=True)
    img_path = os.path.join('SePPO', "0.png")

    # The pipeline returns PIL images by default; convert only if an array comes back.
    if isinstance(ims, np.ndarray):
        ims = Image.fromarray(ims)
    ims.save(img_path, format='PNG')
```
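
The script above takes two required flags, `--unet_checkpoint` and `--prompt`. As a rough sketch of how they are consumed, the snippet below rebuilds the same parser in a function and feeds it an argument list directly, which is equivalent to passing the flags on the command line. `build_parser` and `output_path` are hypothetical helper names that are not part of the SePPO repository, and the checkpoint path and prompt are placeholders. It needs no GPU or model download.

```python
import argparse
import os


def build_parser():
    """Recreate the two required flags used by the inference script."""
    parser = argparse.ArgumentParser(description="SePPO inference arguments (sketch).")
    parser.add_argument("--unet_checkpoint", type=str, required=True,
                        help="Path to the UNet model checkpoint")
    parser.add_argument("--prompt", type=str, required=True, help="Prompt")
    return parser


def output_path(out_dir, seed):
    """Hypothetical helper: one PNG per seed, e.g. 0.png inside out_dir for seed 0."""
    return os.path.join(out_dir, f"{seed}.png")


# Placeholder values; replace them with a real checkpoint path and prompt.
args = build_parser().parse_args(
    ["--unet_checkpoint", "path/to/unet", "--prompt", "a photo of a corgi"]
)

# One output file per seed, mirroring the script's SePPO/0.png for seed 0.
paths = [output_path("SePPO", seed) for seed in range(3)]
```

If several images per prompt are wanted, one option is to loop over seeds, create a fresh `torch.Generator(device='cuda').manual_seed(seed)` for each, and save to the matching path; this is an assumption about usage, not something the original script does.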