linoyts HF staff commited on
Commit
b658584
1 Parent(s): 0382e32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -9
app.py CHANGED
@@ -25,6 +25,7 @@ pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell",
25
 
26
  pipe.transformer.to(memory_format=torch.channels_last)
27
  pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
 
28
  clip_slider = CLIPSliderFlux(pipe, device=torch.device("cuda"))
29
 
30
 
@@ -46,6 +47,7 @@ def generate(slider_x, prompt, seed, recalc_directions, iterations, steps, guida
46
  # check if avg diff for directions need to be re-calculated
47
  print("slider_x", slider_x)
48
  print("x_concept_1", x_concept_1, "x_concept_2", x_concept_2)
 
49
 
50
  if not sorted(slider_x) == sorted([x_concept_1, x_concept_2]) or recalc_directions:
51
  #avg_diff = clip_slider.find_latent_direction(slider_x[0], slider_x[1], num_iterations=iterations).to(torch.float16)
@@ -63,6 +65,8 @@ def generate(slider_x, prompt, seed, recalc_directions, iterations, steps, guida
63
  scale=0, scale_2nd=0,
64
  seed=seed, num_inference_steps=steps, avg_diff=avg_diff)
65
 
 
 
66
  comma_concepts_x = f"{slider_x[1]}, {slider_x[0]}"
67
 
68
  avg_diff_x = avg_diff.cpu()
@@ -75,16 +79,36 @@ def update_scales(x,prompt,seed, steps, guidance_scale,
75
  img2img_type = None, img = None,
76
  controlnet_scale= None, ip_adapter_scale=None,):
77
  avg_diff = avg_diff_x.cuda()
 
78
  if img2img_type=="controlnet canny" and img is not None:
79
  control_img = process_controlnet_img(img)
80
  image = t5_slider_controlnet.generate(prompt, guidance_scale=guidance_scale, image=control_img, controlnet_conditioning_scale =controlnet_scale, scale=x, seed=seed, num_inference_steps=steps, avg_diff=avg_diff)
81
  elif img2img_type=="ip adapter" and img is not None:
82
  image = clip_slider.generate(prompt, guidance_scale=guidance_scale, ip_adapter_image=img, scale=x,seed=seed, num_inference_steps=steps, avg_diff=avg_diff)
83
  else:
84
- image = clip_slider.generate(prompt,
85
- #guidance_scale=guidance_scale,
86
- scale=x,
87
- seed=seed, num_inference_steps=steps, avg_diff=avg_diff)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  return image
89
 
90
  def reset_recalc_directions():
@@ -98,14 +122,12 @@ css = '''
98
  margin-bottom: 20px;
99
  background-color: white;
100
  }
101
-
102
  #x {
103
  position: absolute;
104
  bottom: 20px; /* Moved further down */
105
  left: 30px; /* Adjusted left margin */
106
  width: 540px; /* Increased width to match the new container size */
107
  }
108
-
109
  #y {
110
  position: absolute;
111
  bottom: 200px; /* Increased bottom margin to ensure proper spacing from #x */
@@ -114,14 +136,12 @@ css = '''
114
  transform: rotate(-90deg);
115
  transform-origin: left bottom;
116
  }
117
-
118
  #image_out {
119
  position: absolute;
120
  width: 80%; /* Adjust width as needed */
121
  right: 10px;
122
  top: 10px; /* Increased top margin to clear space occupied by #x */
123
  }
124
-
125
  '''
126
  intro = """
127
  <div style="display: flex;align-items: center;justify-content: center">
@@ -166,7 +186,7 @@ with gr.Blocks(css=css) as demo:
166
  submit = gr.Button("find directions")
167
  with gr.Column():
168
  with gr.Group(elem_id="group"):
169
- x = gr.Slider(minimum=-3, value=0, maximum=3.5, step=0.1, elem_id="x", interactive=False)
170
  #y = gr.Slider(minimum=-10, value=0, maximum=10, elem_id="y", interactive=False)
171
  output_image = gr.Image(elem_id="image_out")
172
  # with gr.Row():
 
25
 
26
  pipe.transformer.to(memory_format=torch.channels_last)
27
  pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
28
+ #pipe.enable_model_cpu_offload()
29
  clip_slider = CLIPSliderFlux(pipe, device=torch.device("cuda"))
30
 
31
 
 
47
  # check if avg diff for directions need to be re-calculated
48
  print("slider_x", slider_x)
49
  print("x_concept_1", x_concept_1, "x_concept_2", x_concept_2)
50
+ #torch.manual_seed(seed)
51
 
52
  if not sorted(slider_x) == sorted([x_concept_1, x_concept_2]) or recalc_directions:
53
  #avg_diff = clip_slider.find_latent_direction(slider_x[0], slider_x[1], num_iterations=iterations).to(torch.float16)
 
65
  scale=0, scale_2nd=0,
66
  seed=seed, num_inference_steps=steps, avg_diff=avg_diff)
67
 
68
+
69
+ #comma_concepts_x = ', '.join(slider_x)
70
  comma_concepts_x = f"{slider_x[1]}, {slider_x[0]}"
71
 
72
  avg_diff_x = avg_diff.cpu()
 
79
  img2img_type = None, img = None,
80
  controlnet_scale= None, ip_adapter_scale=None,):
81
  avg_diff = avg_diff_x.cuda()
82
+ torch.manual_seed(seed)
83
  if img2img_type=="controlnet canny" and img is not None:
84
  control_img = process_controlnet_img(img)
85
  image = t5_slider_controlnet.generate(prompt, guidance_scale=guidance_scale, image=control_img, controlnet_conditioning_scale =controlnet_scale, scale=x, seed=seed, num_inference_steps=steps, avg_diff=avg_diff)
86
  elif img2img_type=="ip adapter" and img is not None:
87
  image = clip_slider.generate(prompt, guidance_scale=guidance_scale, ip_adapter_image=img, scale=x,seed=seed, num_inference_steps=steps, avg_diff=avg_diff)
88
  else:
89
+ image = clip_slider.generate(prompt, guidance_scale=guidance_scale, scale=x, seed=seed, num_inference_steps=steps, avg_diff=avg_diff)
90
+ return image
91
+
92
+
93
+
94
+ @spaces.GPU
95
+ def update_x(x,y,prompt,seed, steps,
96
+ avg_diff_x, avg_diff_y,
97
+ img2img_type = None,
98
+ img = None):
99
+ avg_diff = avg_diff_x.cuda()
100
+ avg_diff_2nd = avg_diff_y.cuda()
101
+ image = clip_slider.generate(prompt, scale=x, scale_2nd=y, seed=seed, num_inference_steps=steps, avg_diff=avg_diff,avg_diff_2nd=avg_diff_2nd)
102
+ return image
103
+
104
+ @spaces.GPU
105
+ def update_y(x,y,prompt,seed, steps,
106
+ avg_diff_x, avg_diff_y,
107
+ img2img_type = None,
108
+ img = None):
109
+ avg_diff = avg_diff_x.cuda()
110
+ avg_diff_2nd = avg_diff_y.cuda()
111
+ image = clip_slider.generate(prompt, scale=x, scale_2nd=y, seed=seed, num_inference_steps=steps, avg_diff=avg_diff,avg_diff_2nd=avg_diff_2nd)
112
  return image
113
 
114
  def reset_recalc_directions():
 
122
  margin-bottom: 20px;
123
  background-color: white;
124
  }
 
125
  #x {
126
  position: absolute;
127
  bottom: 20px; /* Moved further down */
128
  left: 30px; /* Adjusted left margin */
129
  width: 540px; /* Increased width to match the new container size */
130
  }
 
131
  #y {
132
  position: absolute;
133
  bottom: 200px; /* Increased bottom margin to ensure proper spacing from #x */
 
136
  transform: rotate(-90deg);
137
  transform-origin: left bottom;
138
  }
 
139
  #image_out {
140
  position: absolute;
141
  width: 80%; /* Adjust width as needed */
142
  right: 10px;
143
  top: 10px; /* Increased top margin to clear space occupied by #x */
144
  }
 
145
  '''
146
  intro = """
147
  <div style="display: flex;align-items: center;justify-content: center">
 
186
  submit = gr.Button("find directions")
187
  with gr.Column():
188
  with gr.Group(elem_id="group"):
189
+ x = gr.Slider(minimum=-3, value=0, maximum=3.5, elem_id="x", interactive=False)
190
  #y = gr.Slider(minimum=-10, value=0, maximum=10, elem_id="y", interactive=False)
191
  output_image = gr.Image(elem_id="image_out")
192
  # with gr.Row():