File size: 2,809 Bytes
07a421e
 
23dca80
07a421e
 
d2d56e8
b5b4791
 
 
07a421e
 
 
 
 
a960bc2
 
b5b4791
7785249
c66e22e
85913ad
07a421e
 
 
 
b5b4791
7785249
 
7ca8bcd
b5b4791
7785249
 
 
3654a3e
a59bcf0
07a421e
 
 
a960bc2
 
 
b5b4791
a960bc2
b5b4791
a960bc2
 
b5b4791
a960bc2
4715e52
a960bc2
07a421e
23dca80
07a421e
4715e52
07a421e
 
 
 
 
 
 
 
 
7785249
4715e52
 
7785249
07a421e
 
 
 
 
7ca8bcd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import tempfile

import torch
from diffusers import FluxPipeline
from transformers import pipeline
import gradio as gr
import spaces


# Score above which an 'nsfw' classification causes the image to be rejected
NSFW_THRESHOLD = 0.3

# Device handed to the image-classification pipeline below
device = torch.device('cuda')

# Build the FLUX pipeline, then load and fuse the children's-sketch LoRA into it
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
)
pipe.load_lora_weights(
    "Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch",
    weight_name="FLUX-dev-lora-children-simple-sketch.safetensors",
)
pipe.fuse_lora(lora_scale=1.5)
pipe.to("cuda")

# Classifier used to screen generated images for NSFW content
image_classifier = pipeline(
    "image-classification",
    model="Falconsai/nsfw_image_detection",
    device=device,
)
# The prompt-text classifier is currently disabled:
#text_classifier = pipeline("text-classification", model="eliasalbouzidi/distilbert-nsfw-text-classifier",device=device)

# Define the function to generate the sketch
@spaces.GPU
def generate_sketch(prompt, num_inference_steps, guidance_scale):
    """Generate a sketch image for ``prompt`` and screen it for NSFW content.

    Args:
        prompt: User-supplied text. It is prefixed with ``"sketched style, "``
            before being handed to the pipeline.
        num_inference_steps: Forwarded to the pipeline as
            ``num_inference_steps`` after coercion to ``int``.
        guidance_scale: Forwarded to the pipeline as ``guidance_scale``
            after coercion to ``float``.

    Returns:
        A ``(image_path, message)`` pair. On success ``image_path`` is the
        path of a freshly written PNG file and ``message`` is ``None``. If
        the image classifier flags the result, ``image_path`` is ``None``
        and ``message`` explains that the content was rejected.
    """
    # NOTE: prompt-text screening is currently disabled; only the generated
    # image is checked.
    print(prompt)

    # UI sliders may deliver numeric values as floats; normalise them before
    # they reach the pipeline.
    image = pipe(
        "sketched style, " + prompt,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=float(guidance_scale),
    ).images[0]

    # Classify the generated image for NSFW content
    image_classification = image_classifier(image)
    print(image_classification)

    flagged = any(
        result['label'] == 'nsfw' and result['score'] > NSFW_THRESHOLD
        for result in image_classification
    )
    if flagged:
        return None, "Inappropriate content detected. Please try another prompt."

    # Write each result to its own file: a single fixed filename would let
    # concurrent requests overwrite one another's output.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        image.save(tmp, format="PNG")
    return tmp.name, None

# Gradio UI: a prompt box plus two sliders feed generate_sketch, which
# returns an image and an optional message.
steps_slider = gr.Slider(
    minimum=5,
    maximum=50,
    value=24,
    step=1,
    label="Number of Inference Steps",
)
guidance_slider = gr.Slider(
    minimum=1.0,
    maximum=10.0,
    value=3.5,
    step=0.1,
    label="Guidance Scale",
)
sketch_output = gr.Image(label="Generated Sketch")
message_output = gr.Textbox(label="Message")

interface = gr.Interface(
    fn=generate_sketch,
    # "text" is the prompt input; the sliders map to the remaining parameters
    inputs=["text", steps_slider, guidance_slider],
    outputs=[sketch_output, message_output],
    title="Kids Sketch Generator",
    description="Enter a text prompt and generate a fun sketch for kids with customizable inference steps and guidance scale.",
)

# Start serving the app
interface.launch()