Files changed (1)
  1. app.py +1 -32
app.py CHANGED
@@ -20,33 +20,10 @@ checkpoints = {
 }
 loaded = None
 
-
 # Ensure model and scheduler are initialized in GPU-enabled function
 if torch.cuda.is_available():
     pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to("cuda")
 
-if SAFETY_CHECKER:
-    from safety_checker import StableDiffusionSafetyChecker
-    from transformers import CLIPFeatureExtractor
-
-    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
-        "CompVis/stable-diffusion-safety-checker"
-    ).to("cuda")
-    feature_extractor = CLIPFeatureExtractor.from_pretrained(
-        "openai/clip-vit-base-patch32"
-    )
-
-    def check_nsfw_images(
-        images: list[Image.Image],
-    ) -> tuple[list[Image.Image], list[bool]]:
-        safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
-        has_nsfw_concepts = safety_checker(
-            images=[images],
-            clip_input=safety_checker_input.pixel_values.to("cuda")
-        )
-
-        return images, has_nsfw_concepts
-
 # Function
 @spaces.GPU(enable_queue=True)
 def generate_image(prompt, ckpt):
@@ -63,16 +40,8 @@ def generate_image(prompt, ckpt):
 
     results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0)
 
-    if SAFETY_CHECKER:
-        images, has_nsfw_concepts = check_nsfw_images(results.images)
-        if any(has_nsfw_concepts):
-            gr.Warning("NSFW content detected.")
-            return Image.new("RGB", (512, 512))
-        return images[0]
     return results.images[0]
 
-
-
 # Gradio Interface
 description = """
 This demo utilizes the SDXL-Lightning model by ByteDance, which is a lightning-fast text-to-image generative model capable of producing high-quality images in 4 steps.
@@ -98,4 +67,4 @@ with gr.Blocks(css="style.css") as demo:
         outputs=img,
     )
 
-    demo.queue().launch()
+demo.queue().launch()
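For reference, below is a minimal standalone sketch of the few-step inference path this Space keeps after the change, mirroring the fp16 pipeline loading and the pipe(..., guidance_scale=0) call shown in the diff. The base model id, the "ByteDance/SDXL-Lightning" repo id, the 4-step checkpoint filename, and the trailing-timestep scheduler setup are taken from the SDXL-Lightning model card rather than from this diff, so treat them as assumptions; the surrounding Gradio/Spaces wiring is omitted.

# Minimal sketch, not the Space's full app.py.
# Assumes a CUDA GPU, diffusers with StableDiffusionXLPipeline, and the
# SDXL-Lightning 4-step UNet weights (repo/filename per the model card).
import torch
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

base = "stabilityai/stable-diffusion-xl-base-1.0"   # assumed base checkpoint
repo = "ByteDance/SDXL-Lightning"                    # assumed weights repo
ckpt = "sdxl_lightning_4step_unet.safetensors"       # assumed 4-step UNet file

# Same fp16 pipeline load as in the diff's GPU branch.
pipe = StableDiffusionXLPipeline.from_pretrained(
    base, torch_dtype=torch.float16, variant="fp16"
).to("cuda")

# SDXL-Lightning is distilled for few-step sampling with trailing timesteps and
# no classifier-free guidance, which is why the kept pipe(...) call uses guidance_scale=0.
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config, timestep_spacing="trailing"
)
pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))

image = pipe("a cat in a space suit", num_inference_steps=4, guidance_scale=0).images[0]
image.save("out.png")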