linoyts HF staff commited on
Commit
b1c5569
1 Parent(s): 1b59bc0

offloading to cpu

Browse files
Files changed (1) hide show
  1. app.py +19 -12
app.py CHANGED
@@ -10,7 +10,10 @@ flash_pipe.scheduler = EulerDiscreteScheduler.from_config(flash_pipe.scheduler.c
10
  clip_slider = CLIPSliderXL(flash_pipe, device=torch.device("cuda"), iterations=50)
11
 
12
  @spaces.GPU
13
- def generate(slider_x, slider_y, prompt, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y):
 
 
 
14
 
15
  # check if avg diff for directions need to be re-calculated
16
  if not sorted(slider_x) == sorted([x_concept_1, x_concept_2]):
@@ -25,16 +28,18 @@ def generate(slider_x, slider_y, prompt, x_concept_1, x_concept_2, y_concept_1,
25
  comma_concepts_x = ', '.join(slider_x)
26
  comma_concepts_y = ', '.join(slider_y)
27
 
28
- avg_diff_x = clip_slider.avg_diff.cpu()
29
- avg_diff_y = clip_slider.avg_diff_2nd.cpu()
 
 
30
 
31
- return gr.update(label=comma_concepts_x, interactive=True),gr.update(label=comma_concepts_y, interactive=True), x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y, image
32
 
33
- def update_x(x,y,prompt, avg_diff_x, avg_diff_y):
34
  image = clip_slider.generate(prompt, scale=x, scale_2nd=y, num_inference_steps=8)
35
  return image
36
 
37
- def update_y(x,y,prompt, avg_diff_x, avg_diff_y):
38
  image = clip_slider.generate(prompt, scale=x, scale_2nd=y, num_inference_steps=8)
39
  return image
40
 
@@ -69,8 +74,10 @@ with gr.Blocks(css=css) as demo:
69
  y_concept_1 = gr.State("")
70
  y_concept_2 = gr.State("")
71
 
72
- avg_diff_x = gr.State()
73
- avg_diff_y = gr.State()
 
 
74
 
75
  with gr.Row():
76
  with gr.Column():
@@ -84,10 +91,10 @@ with gr.Blocks(css=css) as demo:
84
  output_image = gr.Image(elem_id="image_out")
85
 
86
  submit.click(fn=generate,
87
- inputs=[slider_x, slider_y, prompt, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y],
88
- outputs=[x, y, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y, output_image])
89
- x.change(fn=update_x, inputs=[x,y, prompt, avg_diff_x, avg_diff_y], outputs=[output_image])
90
- y.change(fn=update_y, inputs=[x,y, prompt, avg_diff_x, avg_diff_y], outputs=[output_image])
91
 
92
  if __name__ == "__main__":
93
  demo.launch()
 
10
  clip_slider = CLIPSliderXL(flash_pipe, device=torch.device("cuda"), iterations=50)
11
 
12
  @spaces.GPU
13
+ def generate(slider_x, slider_y, prompt,
14
+ x_concept_1, x_concept_2, y_concept_1, y_concept_2,
15
+ avg_diff_x_1, avg_diff_x_2,
16
+ avg_diff_y_1, avg_diff_y_2):
17
 
18
  # check if avg diff for directions need to be re-calculated
19
  if not sorted(slider_x) == sorted([x_concept_1, x_concept_2]):
 
28
  comma_concepts_x = ', '.join(slider_x)
29
  comma_concepts_y = ', '.join(slider_y)
30
 
31
+ avg_diff_x_1 = clip_slider.avg_diff[0].cpu()
32
+ avg_diff_x_2 = clip_slider.avg_diff[1].cpu()
33
+ avg_diff_y_1 = clip_slider.avg_diff_2nd[0].cpu()
34
+ avg_diff_y_2 = clip_slider.avg_diff_2nd[1].cpu()
35
 
36
+ return gr.update(label=comma_concepts_x, interactive=True),gr.update(label=comma_concepts_y, interactive=True), x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2, image
37
 
38
+ def update_x(x,y,prompt, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2):
39
  image = clip_slider.generate(prompt, scale=x, scale_2nd=y, num_inference_steps=8)
40
  return image
41
 
42
+ def update_y(x,y,prompt, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2):
43
  image = clip_slider.generate(prompt, scale=x, scale_2nd=y, num_inference_steps=8)
44
  return image
45
 
 
74
  y_concept_1 = gr.State("")
75
  y_concept_2 = gr.State("")
76
 
77
+ avg_diff_x_1 = gr.State()
78
+ avg_diff_x_2 = gr.State()
79
+ avg_diff_y_1 = gr.State()
80
+ avg_diff_y_2 = gr.State()
81
 
82
  with gr.Row():
83
  with gr.Column():
 
91
  output_image = gr.Image(elem_id="image_out")
92
 
93
  submit.click(fn=generate,
94
+ inputs=[slider_x, slider_y, prompt, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2],
95
+ outputs=[x, y, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2, output_image])
96
+ x.change(fn=update_x, inputs=[x,y, prompt, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2], outputs=[output_image])
97
+ y.change(fn=update_y, inputs=[x,y, prompt, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2], outputs=[output_image])
98
 
99
  if __name__ == "__main__":
100
  demo.launch()