sayakpaul committed
Commit 0acc836
1 Parent(s): d46e1fe
app.py CHANGED
@@ -1,15 +1,150 @@
+import random
+
+import gradio as gr
+import PIL
 import torch
 from diffusers import StableDiffusionInstructPix2PixPipeline
-import gradio as gr
-import PIL
 
 cartoonization_id = "instruction-tuning-sd/cartoonizer"
 image_proc_id = "instruction-tuning-sd/low-level-img-proc"
 
+title = "Instruction-tuned Stable Diffusion"
+description = "This Space demonstrates instruction-tuning of Stable Diffusion. To learn more, please check out the [corresponding blog post](https://hf.co/blog/instruction-tuning-sd). Some experimentation tips are available in [the original InstructPix2Pix Space](https://huggingface.co/spaces/timbrooks/instruct-pix2pix)."
+
+
 def load_pipeline(id: str):
-    pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(id, torch_dtype=torch.float16).to("cuda")
-    return pipeline
-
-
-def infer(prompt: str, image: PIL.Image.Image, num_inference_steps: int, img_guidance_scale: float):
-    pass
+    # Load a fine-tuned InstructPix2Pix checkpoint in fp16 on the GPU.
+    pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
+        id, torch_dtype=torch.float16
+    ).to("cuda")
+    return pipeline
+
+
+def infer_cartoonization(
+    prompt: str,
+    negative_prompt: str,
+    image: PIL.Image.Image,
+    steps: int,
+    img_cfg: float,
+    text_cfg: float,
+    seed: int,
+):
+    pipeline = load_pipeline(cartoonization_id)
+    images = pipeline(
+        prompt,
+        image,
+        negative_prompt=negative_prompt,
+        num_inference_steps=int(steps),
+        image_guidance_scale=img_cfg,
+        guidance_scale=text_cfg,
+        generator=torch.manual_seed(int(seed)),
+    ).images
+    return images
+
+
+def infer_img_proc(
+    prompt: str,
+    negative_prompt: str,
+    image: PIL.Image.Image,
+    steps: int,
+    img_cfg: float,
+    text_cfg: float,
+    seed: int,
+):
+    pipeline = load_pipeline(image_proc_id)
+    images = pipeline(
+        prompt,
+        image,
+        negative_prompt=negative_prompt,
+        num_inference_steps=int(steps),
+        image_guidance_scale=img_cfg,
+        guidance_scale=text_cfg,
+        generator=torch.manual_seed(int(seed)),
+    ).images
+    return images
+
+
+# One example per tab: prompt, negative prompt, image, steps, image CFG, text CFG, seed.
+examples = [
+    [
+        "cartoonize this image",
+        "low quality",
+        "examples/mountain.png",
+        20,
+        1.5,
+        7.5,
+        random.randint(0, 100000),
+    ],
+    [
+        "derain this image",
+        "low quality",
+        "examples/duck.png",
+        20,
+        1.5,
+        7.5,
+        random.randint(0, 100000),
+    ],
+]
+
+with gr.Blocks(theme="gradio/soft") as demo:
+    gr.Markdown(f"## {title}")
+    gr.Markdown(description)
+
+    with gr.Tab("Cartoonization"):
+        prompt = gr.Textbox(label="Prompt")
+        neg_prompt = gr.Textbox(label="Negative Prompt")
+        input_image = gr.Image(label="Input Image")
+        steps = gr.Slider(minimum=5, maximum=100, step=1, label="Steps")
+        img_cfg = gr.Number(value=1.5, label="Image CFG", interactive=True)
+        text_cfg = gr.Number(value=7.5, label="Text CFG", interactive=True)
+        seed = gr.Slider(minimum=0, maximum=100000, step=1, label="Seed")
+
+        car_output_gallery = gr.Gallery(
+            label="Generated images", show_label=False, elem_id="gallery"
+        ).style(columns=[2], rows=[2], object_fit="contain", height="auto")
+        submit_btn = gr.Button(value="Submit")
+        all_car_inputs = [prompt, neg_prompt, input_image, steps, img_cfg, text_cfg, seed]
+        submit_btn.click(
+            fn=infer_cartoonization,
+            inputs=all_car_inputs,
+            outputs=[car_output_gallery],
+        )
+
+    with gr.Tab("Low-level image processing"):
+        prompt = gr.Textbox(label="Prompt")
+        neg_prompt = gr.Textbox(label="Negative Prompt")
+        input_image = gr.Image(label="Input Image")
+        steps = gr.Slider(minimum=5, maximum=100, step=1, label="Steps")
+        img_cfg = gr.Number(value=1.5, label="Image CFG", interactive=True)
+        text_cfg = gr.Number(value=7.5, label="Text CFG", interactive=True)
+        seed = gr.Slider(minimum=0, maximum=100000, step=1, label="Seed")
+
+        img_proc_output_gallery = gr.Gallery(
+            label="Generated images", show_label=False, elem_id="gallery"
+        ).style(columns=[2], rows=[2], object_fit="contain", height="auto")
+        submit_btn = gr.Button(value="Submit")
+        all_img_proc_inputs = [prompt, neg_prompt, input_image, steps, img_cfg, text_cfg, seed]
+        submit_btn.click(
+            fn=infer_img_proc,
+            inputs=all_img_proc_inputs,
+            outputs=[img_proc_output_gallery],
+        )
+
+    gr.Markdown("### Cartoonization example")
+    gr.Examples(
+        [examples[0]],
+        inputs=all_car_inputs,
+        outputs=car_output_gallery,
+        fn=infer_cartoonization,
+        cache_examples=True,
+    )
+    gr.Markdown("### Low-level image processing example")
+    gr.Examples(
+        [examples[1]],
+        inputs=all_img_proc_inputs,
+        outputs=img_proc_output_gallery,
+        fn=infer_img_proc,
+        cache_examples=True,
+    )
+
+demo.launch()
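For quick local verification outside the Gradio UI, the same checkpoints can be exercised directly with diffusers. The following is a minimal sketch, not part of the commit: it assumes a CUDA GPU, the torch/diffusers/PIL dependencies used in app.py, and that examples/mountain.png from this repository is available locally; the output filename is illustrative.

import PIL.Image
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline

# Load the cartoonizer checkpoint in half precision on the GPU,
# mirroring load_pipeline() in app.py.
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "instruction-tuning-sd/cartoonizer", torch_dtype=torch.float16
).to("cuda")

image = PIL.Image.open("examples/mountain.png").convert("RGB")

# Same arguments the Space forwards from its UI controls.
edited = pipeline(
    "cartoonize this image",
    image=image,
    negative_prompt="low quality",
    num_inference_steps=20,
    image_guidance_scale=1.5,
    guidance_scale=7.5,
    generator=torch.manual_seed(0),
).images[0]
edited.save("cartoonized_mountain.png")  # illustrative output path

Fixing the generator seed makes a run reproducible; the seed slider in the Space controls the same value.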
 
examples/{derain_the_image_1.png → duck.png} RENAMED
File without changes