linoyts HF staff commited on
Commit
f217e4d
1 Parent(s): 826b04b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -10
app.py CHANGED
@@ -1,17 +1,39 @@
1
  import gradio as gr
2
  import spaces
 
 
 
 
 
 
 
 
3
 
4
  @spaces.GPU
5
- def generate(slider_x, slider_y, prompt):
 
 
 
 
 
 
 
 
 
6
  comma_concepts_x = ', '.join(slider_x)
7
  comma_concepts_y = ', '.join(slider_y)
8
- return gr.update(label=comma_concepts_x, interactive=True),gr.update(label=comma_concepts_y, interactive=True), gr.update()
9
 
10
- def update_x(slider_x):
11
- return gr.update()
 
 
 
 
 
12
 
13
- def update_y(slider_y):
14
- return gr.update()
 
15
 
16
  css = '''
17
  #group {
@@ -38,6 +60,12 @@ css = '''
38
  #image_out{position:absolute; width: 80%; right: 10px; top: 40px}
39
  '''
40
  with gr.Blocks(css=css) as demo:
 
 
 
 
 
 
41
  with gr.Row():
42
  with gr.Column():
43
  slider_x = gr.Dropdown(label="Slider X concept range", allow_custom_value=True, multiselect=True, max_choices=2)
@@ -50,10 +78,10 @@ with gr.Blocks(css=css) as demo:
50
  output_image = gr.Image(elem_id="image_out")
51
 
52
  submit.click(fn=generate,
53
- inputs=[slider_x, slider_y, prompt],
54
- outputs=[x, y, output_image])
55
- x.change(fn=update_x, inputs=[slider_x], outputs=[output_image])
56
- y.change(fn=update_y, inputs=[slider_y], outputs=[output_image])
57
 
58
  if __name__ == "__main__":
59
  demo.launch()
 
1
  import gradio as gr
2
  import spaces
3
+ import torch
4
+ from clip_slider_pipeline import CLIPSliderXL
5
+ from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
6
+
7
+
8
+ flash_pipe = StableDiffusionXLPipeline.from_pretrained("sd-community/sdxl-flash").to("cuda", torch.float16)
9
+ flash_pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
10
+ clip_slider = CLIPSliderXL(flash_pipe, device=torch.device("cuda"))
11
 
12
  @spaces.GPU
13
+ def generate(slider_x, slider_y, prompt, x_concept_1, x_concept_2, y_concept_1, y_concept_2):
14
+
15
+ # check if avg diff for directions need to be re-calculated
16
+ if not sorted(slider_x) == sorted([x_concept_1, x_concept_2]):
17
+ clip_slider.avg_diff = clip_slider.find_latent_direction(slider_x[0], slider_x[0])
18
+ x_concept_1, x_concept_2 = slider_x[0], slider_x[1]
19
+ if not sorted(slider_y) == sorted([y_concept_1, y_concept_2]):
20
+ clip_slider.avg_diff_2nd = clip_slider.find_latent_direction(slider_y[0], slider_y[0])
21
+ y_concept_1, y_concept_2 = slider_y[0], slider_y[1]
22
+
23
  comma_concepts_x = ', '.join(slider_x)
24
  comma_concepts_y = ', '.join(slider_y)
 
25
 
26
+ image = clip_slider(prompt, scale=0, scale_2nd=0, num_inference_steps=8)
27
+
28
+ return gr.update(label=comma_concepts_x, interactive=True),gr.update(label=comma_concepts_y, interactive=True), x_concept_1, x_concept_2, y_concept_1, y_concept_2, image
29
+
30
+ def update_x(x,y,prompt):
31
+ image = clip_slider(prompt, scale=x, scale_2nd=y, num_inference_steps=8)
32
+ return image
33
 
34
+ def update_y(x,y,prompt):
35
+ image = clip_slider(prompt, scale=x, scale_2nd=y, num_inference_steps=8)
36
+ return image
37
 
38
  css = '''
39
  #group {
 
60
  #image_out{position:absolute; width: 80%; right: 10px; top: 40px}
61
  '''
62
  with gr.Blocks(css=css) as demo:
63
+
64
+ x_concept_1 = gr.State("")
65
+ x_concept_2 = gr.State("")
66
+ y_concept_1 = gr.State("")
67
+ y_concept_2 = gr.State("")
68
+
69
  with gr.Row():
70
  with gr.Column():
71
  slider_x = gr.Dropdown(label="Slider X concept range", allow_custom_value=True, multiselect=True, max_choices=2)
 
78
  output_image = gr.Image(elem_id="image_out")
79
 
80
  submit.click(fn=generate,
81
+ inputs=[slider_x, slider_y, prompt, x_concept_1, x_concept_2, y_concept_1, y_concept_2],
82
+ outputs=[x, y, x_concept_1, x_concept_2, y_concept_1, y_concept_2, output_image])
83
+ x.change(fn=update_x, inputs=[x,y, prompt], outputs=[output_image])
84
+ y.change(fn=update_y, inputs=[x,y, prompt], outputs=[output_image])
85
 
86
  if __name__ == "__main__":
87
  demo.launch()