tombetthauser commited on
Commit
6f7ae8f
β€’
1 Parent(s): 5da0d3d

Added prompt filtering for terms of service

Browse files
Files changed (1) hide show
  1. app.py +36 -6
app.py CHANGED
@@ -28,6 +28,17 @@ my_token = os.environ['api_key']
28
 
29
  pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2", revision="fp16", torch_dtype=torch.float16, use_auth_token=my_token).to("cuda")
30
 
 
 
 
 
 
 
 
 
 
 
 
31
  def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
32
  loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
33
 
@@ -135,17 +146,25 @@ def image_prompt(prompt, guidance, steps, seed, height, width):
135
  # image_count += 1
136
  curr_time = datetime.datetime.now()
137
 
 
 
138
  print("----- advanced tab prompt ------------------------------")
139
  print(f"prompt: {prompt}, size: {width}px x {height}px, guidance: {guidance}, steps: {steps}, seed: {int(seed)}")
140
  # print(f"image_count: {image_count}, datetime: `{e}`")
141
  print(f"datetime: `{curr_time}`")
 
142
  print("-------------------------------------------------------")
143
-
144
- return (
 
145
  pipe(prompt=prompt, guidance_scale=guidance, num_inference_steps=steps, generator=generator, height=height, width=width).images[0],
146
  f"prompt: '{prompt}', seed = {int(seed)},\nheight: {height}px, width: {width}px,\nguidance: {guidance}, steps: {steps}"
147
- )
148
-
 
 
 
 
149
 
150
  def default_guidance():
151
  return 7.5
@@ -282,17 +301,28 @@ def simple_image_prompt(prompt, dropdown, size_dropdown):
282
 
283
  # image_count += 1
284
  curr_time = datetime.datetime.now()
 
285
 
286
  print("----- welcome / beta tab prompt ------------------------------")
287
  print(f"prompt: {prompt}, size: {width}px x {height}px, guidance: {guidance}, steps: {steps}, seed: {int(seed)}")
288
  # print(f"image_count: {image_count}, datetime: `{e}`")
289
  print(f"datetime: `{curr_time}`")
 
290
  print("-------------------------------------------------------")
291
-
292
- return (
 
293
  pipe(prompt=prompt, guidance_scale=guidance, num_inference_steps=steps, generator=generator, height=height, width=width).images[0],
294
  f"prompt: '{prompt}', seed = {int(seed)},\nheight: {height}px, width: {width}px,\nguidance: {guidance}, steps: {steps}"
295
  )
 
 
 
 
 
 
 
 
296
 
297
 
298
 
 
28
 
29
  pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2", revision="fp16", torch_dtype=torch.float16, use_auth_token=my_token).to("cuda")
30
 
31
+ def check_prompt(prompt):
32
+ SPAM_WORDS = [
33
+ "seductive",
34
+ "breast"
35
+ ]
36
+ for spam_word in SPAM_WORDS:
37
+ if spam_word in prompt:
38
+ return False
39
+ return True
40
+
41
+
42
  def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
43
  loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
44
 
 
146
  # image_count += 1
147
  curr_time = datetime.datetime.now()
148
 
149
+ is_clean = check_prompt(prompt)
150
+
151
  print("----- advanced tab prompt ------------------------------")
152
  print(f"prompt: {prompt}, size: {width}px x {height}px, guidance: {guidance}, steps: {steps}, seed: {int(seed)}")
153
  # print(f"image_count: {image_count}, datetime: `{e}`")
154
  print(f"datetime: `{curr_time}`")
155
+ print(f"is_prompt_clean: {is_clean}")
156
  print("-------------------------------------------------------")
157
+
158
+ if is_clean:
159
+ return (
160
  pipe(prompt=prompt, guidance_scale=guidance, num_inference_steps=steps, generator=generator, height=height, width=width).images[0],
161
  f"prompt: '{prompt}', seed = {int(seed)},\nheight: {height}px, width: {width}px,\nguidance: {guidance}, steps: {steps}"
162
+ )
163
+ else:
164
+ return (
165
+ pipe(prompt="", guidance_scale=0, num_inference_steps=1, generator=generator, height=50, width=50).images[0],
166
+ f"Prompt violates Hugging Face's Terms of Service"
167
+ )
168
 
169
  def default_guidance():
170
  return 7.5
 
301
 
302
  # image_count += 1
303
  curr_time = datetime.datetime.now()
304
+ is_clean = check_prompt(prompt)
305
 
306
  print("----- welcome / beta tab prompt ------------------------------")
307
  print(f"prompt: {prompt}, size: {width}px x {height}px, guidance: {guidance}, steps: {steps}, seed: {int(seed)}")
308
  # print(f"image_count: {image_count}, datetime: `{e}`")
309
  print(f"datetime: `{curr_time}`")
310
+ print(f"is_prompt_clean: {is_clean}")
311
  print("-------------------------------------------------------")
312
+
313
+ if is_clean:
314
+ return (
315
  pipe(prompt=prompt, guidance_scale=guidance, num_inference_steps=steps, generator=generator, height=height, width=width).images[0],
316
  f"prompt: '{prompt}', seed = {int(seed)},\nheight: {height}px, width: {width}px,\nguidance: {guidance}, steps: {steps}"
317
  )
318
+ else:
319
+ return (
320
+ pipe(prompt="", guidance_scale=0, num_inference_steps=1, generator=generator, height=50, width=50).images[0],
321
+ f"Prompt violates Hugging Face's Terms of Service"
322
+ )
323
+
324
+
325
+
326
 
327
 
328