MagicDoodles / app.py
yasserrmd's picture
Update app.py
85913ad verified
raw
history blame contribute delete
No virus
2.81 kB
import torch
from diffusers import FluxPipeline
from transformers import pipeline
import gradio as gr
import spaces
# Device object handed to the classifier pipeline created further down.
device = torch.device('cuda')

# Build the FLUX.1-dev pipeline in bfloat16, attach the children's
# simple-sketch LoRA, fuse it into the base weights with a scale of 1.5,
# and finally move the whole pipeline onto CUDA.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
)
pipe.load_lora_weights(
    "Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch",
    weight_name="FLUX-dev-lora-children-simple-sketch.safetensors",
)
pipe.fuse_lora(lora_scale=1.5)
pipe.to("cuda")

# Image classifier used to screen generated pictures for NSFW content.
image_classifier = pipeline(
    "image-classification",
    model="Falconsai/nsfw_image_detection",
    device=device,
)
# NOTE: a prompt-level text classifier
# ("eliasalbouzidi/distilbert-nsfw-text-classifier") used to be loaded here
# and is currently disabled.

# An 'nsfw' label scoring strictly above this value is treated as unsafe.
NSFW_THRESHOLD = 0.3
# Define the function to generate the sketch
@spaces.GPU
def generate_sketch(prompt, num_inference_steps, guidance_scale):
    """Generate a sketch-style image for ``prompt`` and screen it for NSFW content.

    The prompt is prefixed with ``"sketched style, "`` and passed to the
    module-level ``pipe``. The first generated image is then run through
    ``image_classifier``; if any result is labelled ``'nsfw'`` with a score
    above ``NSFW_THRESHOLD`` the image is withheld.

    Only the generated image is screened; prompt text is not classified.

    Args:
        prompt: Text description of the sketch to draw.
        num_inference_steps: Number of denoising steps; coerced to ``int``.
        guidance_scale: Guidance strength; coerced to ``float``.

    Returns:
        A ``(image_path, message)`` pair. On success ``image_path`` is the
        path of a freshly written PNG file and ``message`` is ``None``. When
        the prompt is empty or the image is flagged, ``image_path`` is
        ``None`` and ``message`` explains why.
    """
    import os
    import tempfile

    # Reject missing/blank prompts up front instead of spending GPU time on
    # them (a None prompt would otherwise fail on string concatenation).
    if not prompt or not prompt.strip():
        return None, "Please enter a prompt."

    print(prompt)
    image = pipe(
        "sketched style, " + prompt,
        # UI sliders may deliver numbers as either int or float; normalise
        # them to the types the pipeline arguments are meant to have.
        num_inference_steps=int(num_inference_steps),
        guidance_scale=float(guidance_scale),
    ).images[0]

    # Classify the generated image for NSFW content.
    image_classification = image_classifier(image)
    print(image_classification)

    flagged = any(
        result['label'] == 'nsfw' and result['score'] > NSFW_THRESHOLD
        for result in image_classification
    )
    if flagged:
        return None, "Inappropriate content detected. Please try another prompt."

    # Write each result to its own file. A single fixed filename would let
    # concurrent requests overwrite one another's image before it is served.
    # The file is intentionally not deleted here because the caller still
    # needs to read it after this function returns.
    fd, image_path = tempfile.mkstemp(prefix="generated_sketch_", suffix=".png")
    os.close(fd)
    image.save(image_path)
    return image_path, None
# Sliders feeding the num_inference_steps and guidance_scale arguments of
# generate_sketch.
_steps_slider = gr.Slider(
    minimum=5,
    maximum=50,
    value=24,
    step=1,
    label="Number of Inference Steps",
)
_guidance_slider = gr.Slider(
    minimum=1.0,
    maximum=10.0,
    value=3.5,
    step=0.1,
    label="Guidance Scale",
)

# Output widgets: the resulting picture and an accompanying status message.
_sketch_output = gr.Image(label="Generated Sketch")
_message_output = gr.Textbox(label="Message")

# Wire everything together: a free-text prompt plus the two sliders go in,
# the sketch and a message come out.
interface = gr.Interface(
    fn=generate_sketch,
    inputs=["text", _steps_slider, _guidance_slider],
    outputs=[_sketch_output, _message_output],
    title="Kids Sketch Generator",
    description="Enter a text prompt and generate a fun sketch for kids with customizable inference steps and guidance scale.",
)

# Start serving the app.
interface.launch()