File size: 3,416 Bytes
39b3d5c
 
33b0931
2053864
 
33b0931
 
2053864
39b3d5c
 
33b0931
39b3d5c
 
2162e3e
33b0931
39b3d5c
 
 
 
c1a9a8e
ba00eb1
39b3d5c
aec599b
 
 
 
 
 
39b3d5c
 
d77c785
39b3d5c
1bb2af8
 
 
35a3dd9
 
67e0c9d
1bb2af8
 
 
 
2053864
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1bb2af8
39b3d5c
1bb2af8
 
 
 
2053864
 
39b3d5c
 
 
a2ddcd6
e6af018
39b3d5c
cefcc96
1bb2af8
2053864
39b3d5c
 
1bb2af8
52f87fe
 
2053864
 
1bb2af8
145dc43
219f0c6
 
4cda995
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import gradio as gr
from gradio_client import Client
import os 
import numpy as np
import random

# Hugging Face access token taken from the HF_TKN environment variable
# (None when the variable is not set).
hf_token = os.getenv("HF_TKN")

# Largest value representable by a signed 32-bit integer; used as the upper
# bound when drawing a random seed.
MAX_SEED = np.iinfo(np.int32).max

def get_caption(image_in):
    """Ask the moondream1 Space to describe an image.

    Opens a client for ``https://fffiloni-moondream1.hf.space/``
    (authenticated with the module-level ``hf_token``) and calls its
    ``/predict`` endpoint with the fixed question "Describe the image".

    Args:
        image_in: Value forwarded as the endpoint's image input (a file path
            for the Space's 'image' Image component).

    Returns:
        Whatever the endpoint returns; it is also printed to stdout.
    """
    captioner = Client("https://fffiloni-moondream1.hf.space/", hf_token=hf_token)
    question = "Describe the image"  # 'Question' Textbox component
    caption = captioner.predict(image_in, question, api_name="/predict")
    print(caption)
    return caption

def get_lcm(prompt):
    """Generate an image from ``prompt`` with the LCM-LoRA SDXL Space.

    Calls the ``/predict`` endpoint of
    ``https://latent-consistency-lcm-lora-for-sdxl.hf.space/`` with fixed
    settings: guidance 0.3, 8 steps, seed 0 and randomization enabled.

    Args:
        prompt: Text prompt sent to the Space.

    Returns:
        The first element of the endpoint's result (presumably the generated
        image -- TODO confirm against the Space's API). The full result is
        printed to stdout.
    """
    lcm_space = Client("https://latent-consistency-lcm-lora-for-sdxl.hf.space/")

    guidance = 0.3     # 'Guidance' Slider (0.0 to 5)
    steps = 8          # 'Steps' Slider (2 to 10)
    seed = 0           # 'Seed' Slider
    randomize = True   # 'Randomize' Checkbox

    outputs = lcm_space.predict(
        prompt, guidance, steps, seed, randomize, api_name="/predict"
    )
    print(outputs)
    return outputs[0]

def get_sdxl_lightning(prompt):
    """Generate an image from ``prompt`` with the AP123/SDXL-Lightning Space.

    Calls the Space's ``/generate_image`` endpoint with the "4-Step" option.

    Args:
        prompt: Text prompt sent to the Space.

    Returns:
        Whatever the endpoint returns; it is also printed to stdout.
    """
    lightning_space = Client("AP123/SDXL-Lightning")
    step_option = "4-Step"
    generated = lightning_space.predict(
        prompt, step_option, api_name="/generate_image"
    )
    print(generated)
    return generated

def get_turbo(prompt):
    """Generate an image from ``prompt`` with the unofficial SDXL Turbo Space.

    Draws a random seed in ``[0, MAX_SEED]``, prints it, then calls the
    ``/predict`` endpoint of
    ``https://diffusers-unofficial-sdxl-turbo-i2i-t2i.hf.space/`` with no
    input image, strength 0.7 and 4 steps.

    Args:
        prompt: Text prompt sent to the Space.

    Returns:
        Whatever the endpoint returns; it is also printed to stdout.
    """
    seed = random.randint(0, MAX_SEED)
    print(f"SEED: {seed}")

    turbo_space = Client("https://diffusers-unofficial-sdxl-turbo-i2i-t2i.hf.space/")

    webcam_image = None  # 'Webcam' Image component is left empty
    strength = 0.7       # 'Strength' Slider (0.0 to 1.0)
    steps = 4            # 'Steps' Slider (1 to 10)

    generated = turbo_space.predict(
        webcam_image, prompt, strength, steps, seed, api_name="/predict"
    )
    print(generated)
    return generated

def infer(image_in, chosen_method):
    """Caption an image, then generate a variation from that caption.

    The inputs are validated before any remote call is made, so a bad
    request fails fast with a readable message instead of an
    ``UnboundLocalError`` raised after the captioning request has run.

    Args:
        image_in: Input image, forwarded to ``get_caption``. Must not be None.
        chosen_method: Name of the generator to use: "LCM", "SDXL Lightning"
            or "SDXL Turbo".

    Returns:
        The value returned by the selected generator function.

    Raises:
        ValueError: If ``image_in`` is None or ``chosen_method`` is not one of
            the supported names.
    """
    supported_methods = ("LCM", "SDXL Lightning", "SDXL Turbo")

    if chosen_method not in supported_methods:
        raise ValueError(
            f"Unknown model {chosen_method!r}; "
            f"expected one of: {', '.join(supported_methods)}"
        )
    if image_in is None:
        raise ValueError("No input image provided")

    caption = get_caption(image_in)

    if chosen_method == "LCM":
        return get_lcm(caption)
    if chosen_method == "SDXL Lightning":
        return get_sdxl_lightning(caption)
    # Only "SDXL Turbo" remains after the validation above.
    return get_turbo(caption)

# Web UI: one image input plus a model selector, wired to `infer`, producing a
# single output image.
demo = gr.Interface(
    fn=infer,
    title="Supa Fast Image Variation",
    description="Get quick image variation from image input, using <a href='https://huggingface.co/vikhyatk/moondream1' target='_blank'>moondream1</a> for caption, and <a href='https://huggingface.co/spaces/latent-consistency/lcm-lora-for-sdxl' target='_blank'>LCM SDXL</a>, <a href='https://huggingface.co/spaces/AP123/SDXL-Lightning' target='_blank'>SDXL Lightning</a> or <a href='https://huggingface.co/spaces/diffusers/unofficial-SDXL-Turbo-i2i-t2i' target='_blank'>SDXL Turbo</a> for image generation",
    inputs=[
        gr.Image(type="filepath", label="Image input"),
        gr.Dropdown(
            label="Choose a model",
            choices=["LCM", "SDXL Lightning", "SDXL Turbo"],
            value="SDXL Lightning",
        ),
    ],
    outputs=[
        gr.Image(label="Image variation"),
    ],
    # Each example row is [image path, model name], matching `inputs` order.
    examples=[
        ["examples/frog_clean.jpg", "LCM"],
        ["examples/martin_pecheur.jpeg", "SDXL Turbo"],
        ["examples/forest_deer.png", "SDXL Lightning"],
    ],
    cache_examples=False,
    concurrency_limit=2,
)

# Enable the request queue (max_size=25), then start the app.
demo.queue(max_size=25).launch(show_api=False, show_error=True)