John6666 commited on
Commit
75db765
β€’
1 Parent(s): 2ff2df6

Upload 13 files

Browse files
Files changed (5) hide show
  1. README.md +2 -2
  2. app.py +86 -101
  3. model.py +3 -0
  4. multit2i.py +88 -105
  5. requirements.txt +1 -1
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: Free Multi Models Text-to-Image Heavy-Armed Demo
3
  emoji: 🌐🌊
4
  colorFrom: blue
5
  colorTo: purple
6
  sdk: gradio
7
- sdk_version: 4.39.0
8
  app_file: app.py
9
  short_description: Text-to-Image
10
  pinned: true
 
1
  ---
2
+ title: Free Multi Models Text-to-Image Heavy-Armed Demo V2
3
  emoji: 🌐🌊
4
  colorFrom: blue
5
  colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 4.40.0
8
  app_file: app.py
9
  short_description: Text-to-Image
10
  pinned: true
app.py CHANGED
@@ -1,71 +1,37 @@
1
  import gradio as gr
2
  from model import models
3
  from multit2i import (
4
- load_models,
5
- infer_multi,
6
- infer_multi_random,
7
- save_gallery_images,
8
- change_model,
9
- get_model_info_md,
10
- loaded_models,
11
- get_positive_prefix,
12
- get_positive_suffix,
13
- get_negative_prefix,
14
- get_negative_suffix,
15
- get_recom_prompt_type,
16
- set_recom_prompt_preset,
17
- get_tag_type,
18
  )
19
  from tagger.tagger import (
20
- predict_tags_wd,
21
- remove_specific_prompt,
22
- convert_danbooru_to_e621_prompt,
23
- insert_recom_prompt,
24
- compose_prompt_to_copy,
25
  )
26
  from tagger.fl2sd3longcap import predict_tags_fl2_sd3
27
- from tagger.v2 import (
28
- V2_ALL_MODELS,
29
- v2_random_prompt,
30
- )
31
  from tagger.utils import (
32
- V2_ASPECT_RATIO_OPTIONS,
33
- V2_RATING_OPTIONS,
34
- V2_LENGTH_OPTIONS,
35
- V2_IDENTITY_OPTIONS,
36
  )
37
 
38
 
39
- load_models(models, 5)
40
- #load_models(models, 20) # Fetching 20 models at the same time. default: 5
41
-
42
 
43
  css = """
44
- #model_info { text-align: center; }
 
 
45
  """
46
 
47
- with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", css=css) as demo:
48
  with gr.Column():
49
- with gr.Accordion("Advanced settings", open=False):
50
- with gr.Accordion("Recommended Prompt", open=True):
51
- recom_prompt_preset = gr.Radio(label="Set Presets", choices=get_recom_prompt_type(), value="Common")
52
- with gr.Row():
53
- positive_prefix = gr.CheckboxGroup(label="Use Positive Prefix", choices=get_positive_prefix(), value=[])
54
- positive_suffix = gr.CheckboxGroup(label="Use Positive Suffix", choices=get_positive_suffix(), value=["Common"])
55
- negative_prefix = gr.CheckboxGroup(label="Use Negative Prefix", choices=get_negative_prefix(), value=[], visible=False)
56
- negative_suffix = gr.CheckboxGroup(label="Use Negative Suffix", choices=get_negative_suffix(), value=["Common"], visible=False)
57
- with gr.Accordion("Prompt Transformer", open=False):
58
- v2_rating = gr.Radio(label="Rating", choices=list(V2_RATING_OPTIONS), value="sfw")
59
- v2_aspect_ratio = gr.Radio(label="Aspect ratio", info="The aspect ratio of the image.", choices=list(V2_ASPECT_RATIO_OPTIONS), value="square", visible=False)
60
- v2_length = gr.Radio(label="Length", info="The total length of the tags.", choices=list(V2_LENGTH_OPTIONS), value="long")
61
- v2_identity = gr.Radio(label="Keep identity", info="How strictly to keep the identity of the character or subject. If you specify the detail of subject in the prompt, you should choose `strict`. Otherwise, choose `none` or `lax`. `none` is very creative but sometimes ignores the input prompt.", choices=list(V2_IDENTITY_OPTIONS), value="lax")
62
- v2_ban_tags = gr.Textbox(label="Ban tags", info="Tags to ban from the output.", placeholder="alternate costumen, ...", value="censored")
63
- v2_tag_type = gr.Radio(label="Tag Type", info="danbooru for common, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru", visible=False)
64
- v2_model = gr.Dropdown(label="Model", choices=list(V2_ALL_MODELS.keys()), value=list(V2_ALL_MODELS.keys())[0])
65
- v2_copy = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
66
- with gr.Accordion("Model", open=True):
67
- model_name = gr.Dropdown(label="Select Model", show_label=False, choices=list(loaded_models.keys()), value=list(loaded_models.keys())[0], allow_custom_value=True)
68
- model_info = gr.Markdown(value=get_model_info_md(list(loaded_models.keys())[0]), elem_id="model_info")
69
  with gr.Group():
70
  with gr.Accordion("Prompt from Image File", open=False):
71
  tagger_image = gr.Image(label="Input image", type="pil", sources=["upload", "clipboard"], height=256)
@@ -82,63 +48,82 @@ with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", css=css) as demo:
82
  v2_series = gr.Textbox(label="Series", placeholder="vocaloid", scale=2)
83
  random_prompt = gr.Button(value="Extend Prompt 🎲", size="sm", scale=1)
84
  clear_prompt = gr.Button(value="Clear Prompt πŸ—‘οΈ", size="sm", scale=1)
85
- prompt = gr.Text(label="Prompt", lines=1, max_lines=8, placeholder="1girl, solo, ...", show_copy_button=True)
86
  neg_prompt = gr.Text(label="Negative Prompt", lines=1, max_lines=8, placeholder="", visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  with gr.Row():
88
  run_button = gr.Button("Generate Image", scale=6)
89
  random_button = gr.Button("Random Model 🎲", scale=3)
90
- image_num = gr.Number(label="Count", minimum=1, maximum=16, value=1, step=1, interactive=True, scale=1)
91
- results = gr.Gallery(label="Gallery", interactive=False, show_download_button=True, show_share_button=False,
92
- container=True, format="png", object_fit="contain")
93
- image_files = gr.Files(label="Download", interactive=False)
94
- clear_results = gr.Button("Clear Gallery / Download")
95
- examples = gr.Examples(
96
- examples = [
97
- ["souryuu asuka langley, 1girl, neon genesis evangelion, plugsuit, pilot suit, red bodysuit, sitting, crossing legs, black eye patch, cat hat, throne, symmetrical, looking down, from bottom, looking at viewer, outdoors"],
98
- ["sailor moon, magical girl transformation, sparkles and ribbons, soft pastel colors, crescent moon motif, starry night sky background, shoujo manga style"],
99
- ["kafuu chino, 1girl, solo"],
100
- ["1girl"],
101
- ["beautiful sunset"],
102
- ],
103
- inputs=[prompt],
104
- )
105
- gr.Markdown(
106
- f"""This demo was created in reference to the following demos.
107
- - [Nymbo/Flood](https://huggingface.co/spaces/Nymbo/Flood).
108
- - [Yntec/ToyWorldXL](https://huggingface.co/spaces/Yntec/ToyWorldXL).
109
- """
110
- )
111
- gr.DuplicateButton(value="Duplicate Space")
 
 
 
 
 
 
 
 
 
112
 
113
- model_name.change(change_model, [model_name], [model_info], queue=False, show_api=False)
114
- gr.on(
115
- triggers=[run_button.click, prompt.submit],
116
- fn=infer_multi,
117
- inputs=[prompt, neg_prompt, results, image_num, model_name,
118
- positive_prefix, positive_suffix, negative_prefix, negative_suffix],
119
- outputs=[results],
120
- queue=True,
121
- trigger_mode="multiple",
122
- concurrency_limit=5,
123
- show_progress="full",
124
- show_api=True,
125
- ).then(save_gallery_images, [results], [results, image_files], queue=False, show_api=False)
126
- gr.on(
127
- triggers=[random_button.click],
128
- fn=infer_multi_random,
129
- inputs=[prompt, neg_prompt, results, image_num,
130
- positive_prefix, positive_suffix, negative_prefix, negative_suffix],
131
- outputs=[results],
132
- queue=True,
133
- trigger_mode="multiple",
134
- concurrency_limit=5,
135
- show_progress="full",
136
- show_api=True,
137
- ).then(save_gallery_images, [results], [results, image_files], queue=False, show_api=False)
138
- clear_prompt.click(lambda: (None, None, None), None, [prompt, v2_series, v2_character], queue=False, show_api=False)
139
  clear_results.click(lambda: (None, None), None, [results, image_files], queue=False, show_api=False)
140
  recom_prompt_preset.change(set_recom_prompt_preset, [recom_prompt_preset],
141
  [positive_prefix, positive_suffix, negative_prefix, negative_suffix], queue=False, show_api=False)
 
142
  random_prompt.click(
143
  v2_random_prompt, [prompt, v2_series, v2_character, v2_rating, v2_aspect_ratio, v2_length,
144
  v2_identity, v2_ban_tags, v2_model], [prompt, v2_series, v2_character], show_api=False,
 
1
  import gradio as gr
2
  from model import models
3
  from multit2i import (
4
+ load_models, infer_fn, infer_rand_fn, save_gallery,
5
+ change_model, warm_model, get_model_info_md, loaded_models,
6
+ get_positive_prefix, get_positive_suffix, get_negative_prefix, get_negative_suffix,
7
+ get_recom_prompt_type, set_recom_prompt_preset, get_tag_type,
 
 
 
 
 
 
 
 
 
 
8
  )
9
  from tagger.tagger import (
10
+ predict_tags_wd, remove_specific_prompt, convert_danbooru_to_e621_prompt,
11
+ insert_recom_prompt, compose_prompt_to_copy,
 
 
 
12
  )
13
  from tagger.fl2sd3longcap import predict_tags_fl2_sd3
14
+ from tagger.v2 import V2_ALL_MODELS, v2_random_prompt
 
 
 
15
  from tagger.utils import (
16
+ V2_ASPECT_RATIO_OPTIONS, V2_RATING_OPTIONS,
17
+ V2_LENGTH_OPTIONS, V2_IDENTITY_OPTIONS,
 
 
18
  )
19
 
20
 
21
+ max_images = 8
22
+ load_models(models)
 
23
 
24
  css = """
25
+ .model_info { text-align: center; }
26
+ .output { width=112px; height=112px; !important; }
27
+ .gallery { width=100%; min_height=768px; !important; }
28
  """
29
 
30
+ with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", fill_width=True, css=css) as demo:
31
  with gr.Column():
32
+ with gr.Group():
33
+ model_name = gr.Dropdown(label="Select Model", choices=list(loaded_models.keys()), value=list(loaded_models.keys())[0], allow_custom_value=True)
34
+ model_info = gr.Markdown(value=get_model_info_md(list(loaded_models.keys())[0]), elem_classes="model_info")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  with gr.Group():
36
  with gr.Accordion("Prompt from Image File", open=False):
37
  tagger_image = gr.Image(label="Input image", type="pil", sources=["upload", "clipboard"], height=256)
 
48
  v2_series = gr.Textbox(label="Series", placeholder="vocaloid", scale=2)
49
  random_prompt = gr.Button(value="Extend Prompt 🎲", size="sm", scale=1)
50
  clear_prompt = gr.Button(value="Clear Prompt πŸ—‘οΈ", size="sm", scale=1)
51
+ prompt = gr.Text(label="Prompt", lines=2, max_lines=8, placeholder="1girl, solo, ...", show_copy_button=True)
52
  neg_prompt = gr.Text(label="Negative Prompt", lines=1, max_lines=8, placeholder="", visible=False)
53
+ with gr.Accordion("Recommended Prompt", open=False):
54
+ recom_prompt_preset = gr.Radio(label="Set Presets", choices=get_recom_prompt_type(), value="Common")
55
+ with gr.Row():
56
+ positive_prefix = gr.CheckboxGroup(label="Use Positive Prefix", choices=get_positive_prefix(), value=[])
57
+ positive_suffix = gr.CheckboxGroup(label="Use Positive Suffix", choices=get_positive_suffix(), value=["Common"])
58
+ negative_prefix = gr.CheckboxGroup(label="Use Negative Prefix", choices=get_negative_prefix(), value=[], visible=False)
59
+ negative_suffix = gr.CheckboxGroup(label="Use Negative Suffix", choices=get_negative_suffix(), value=["Common"], visible=False)
60
+ with gr.Accordion("Prompt Transformer", open=False):
61
+ v2_rating = gr.Radio(label="Rating", choices=list(V2_RATING_OPTIONS), value="sfw")
62
+ v2_aspect_ratio = gr.Radio(label="Aspect ratio", info="The aspect ratio of the image.", choices=list(V2_ASPECT_RATIO_OPTIONS), value="square", visible=False)
63
+ v2_length = gr.Radio(label="Length", info="The total length of the tags.", choices=list(V2_LENGTH_OPTIONS), value="long")
64
+ v2_identity = gr.Radio(label="Keep identity", info="How strictly to keep the identity of the character or subject. If you specify the detail of subject in the prompt, you should choose `strict`. Otherwise, choose `none` or `lax`. `none` is very creative but sometimes ignores the input prompt.", choices=list(V2_IDENTITY_OPTIONS), value="lax")
65
+ v2_ban_tags = gr.Textbox(label="Ban tags", info="Tags to ban from the output.", placeholder="alternate costumen, ...", value="censored")
66
+ v2_tag_type = gr.Radio(label="Tag Type", info="danbooru for common, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru", visible=False)
67
+ v2_model = gr.Dropdown(label="Model", choices=list(V2_ALL_MODELS.keys()), value=list(V2_ALL_MODELS.keys())[0])
68
+ v2_copy = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
69
+ image_num = gr.Slider(label="Number of images", minimum=1, maximum=max_images, value=1, step=1, interactive=True, scale=1)
70
  with gr.Row():
71
  run_button = gr.Button("Generate Image", scale=6)
72
  random_button = gr.Button("Random Model 🎲", scale=3)
73
+ stop_button = gr.Button('Stop', interactive=False, scale=1)
74
+ with gr.Column():
75
+ with gr.Group():
76
+ with gr.Row():
77
+ output = [gr.Image(label='', elem_classes="output", type="filepath", format=".png",
78
+ show_download_button=True, show_share_button=False, show_label=False,
79
+ interactive=False, min_width=80, visible=True) for _ in range(max_images)]
80
+ with gr.Group():
81
+ results = gr.Gallery(label="Gallery", elem_classes="gallery", interactive=False, show_download_button=True, show_share_button=False,
82
+ container=True, format="png", object_fit="cover", columns=2, rows=2)
83
+ image_files = gr.Files(label="Download", interactive=False)
84
+ clear_results = gr.Button("Clear Gallery / Download πŸ—‘οΈ")
85
+ with gr.Column():
86
+ examples = gr.Examples(
87
+ examples = [
88
+ ["souryuu asuka langley, 1girl, neon genesis evangelion, plugsuit, pilot suit, red bodysuit, sitting, crossing legs, black eye patch, cat hat, throne, symmetrical, looking down, from bottom, looking at viewer, outdoors"],
89
+ ["sailor moon, magical girl transformation, sparkles and ribbons, soft pastel colors, crescent moon motif, starry night sky background, shoujo manga style"],
90
+ ["kafuu chino, 1girl, solo"],
91
+ ["1girl"],
92
+ ["beautiful sunset"],
93
+ ],
94
+ inputs=[prompt],
95
+ )
96
+ gr.Markdown(
97
+ f"""This demo was created in reference to the following demos.<br>
98
+ [Nymbo/Flood](https://huggingface.co/spaces/Nymbo/Flood),
99
+ [Yntec/ToyWorldXL](https://huggingface.co/spaces/Yntec/ToyWorldXL),
100
+ [Yntec/Diffusion80XX](https://huggingface.co/spaces/Yntec/Diffusion80XX).
101
+ """
102
+ )
103
+ gr.DuplicateButton(value="Duplicate Space")
104
 
105
+ gr.on(triggers=[run_button.click, prompt.submit, random_button.click], fn=lambda: gr.update(interactive=True), inputs=None, outputs=stop_button, show_api=False)
106
+ model_name.change(change_model, [model_name], [model_info], queue=False, show_api=False)\
107
+ .success(warm_model, [model_name], None, queue=True, show_api=False)
108
+ for i, o in enumerate(output):
109
+ img_i = gr.Number(i, visible=False)
110
+ image_num.change(lambda i, n: gr.update(visible = (i < n)), [img_i, image_num], o, show_api=False)
111
+ gen_event = gr.on(triggers=[run_button.click, prompt.submit],
112
+ fn=lambda i, n, m, t1, t2, l1, l2, l3, l4: infer_fn(m, t1, t2, l1, l2, l3, l4) if (i < n) else None,
113
+ inputs=[img_i, image_num, model_name, prompt, neg_prompt, positive_prefix, positive_suffix, negative_prefix, negative_suffix],
114
+ outputs=[o], queue=True, show_api=True)
115
+ gen_event2 = gr.on(triggers=[random_button.click],
116
+ fn=lambda i, n, m, t1, t2, l1, l2, l3, l4: infer_rand_fn(m, t1, t2, l1, l2, l3, l4) if (i < n) else None,
117
+ inputs=[img_i, image_num, model_name, prompt, neg_prompt, positive_prefix, positive_suffix, negative_prefix, negative_suffix],
118
+ outputs=[o], queue=True, show_api=True)
119
+ o.change(save_gallery, [o, results], [results, image_files], show_api=False)
120
+ stop_button.click(lambda: gr.update(interactive=False), None, stop_button, cancels=[gen_event, gen_event2], show_api=False)
121
+
122
+ clear_prompt.click(lambda: None, None, [prompt], queue=False, show_api=False)
 
 
 
 
 
 
 
 
123
  clear_results.click(lambda: (None, None), None, [results, image_files], queue=False, show_api=False)
124
  recom_prompt_preset.change(set_recom_prompt_preset, [recom_prompt_preset],
125
  [positive_prefix, positive_suffix, negative_prefix, negative_suffix], queue=False, show_api=False)
126
+
127
  random_prompt.click(
128
  v2_random_prompt, [prompt, v2_series, v2_character, v2_rating, v2_aspect_ratio, v2_length,
129
  v2_identity, v2_ban_tags, v2_model], [prompt, v2_series, v2_character], show_api=False,
model.py CHANGED
@@ -9,6 +9,7 @@ models = [
9
  'votepurchase/ponyDiffusionV6XL',
10
  'eienmojiki/Anything-XL',
11
  'eienmojiki/Starry-XL-v5.2',
 
12
  'digiplay/majicMIX_sombre_v2',
13
  'digiplay/majicMIX_realistic_v7',
14
  'votepurchase/counterfeitV30_v30',
@@ -19,6 +20,8 @@ models = [
19
  'Raelina/Raemu-XL-V4',
20
  ]
21
 
 
 
22
 
23
  # Examples:
24
  #models = ['yodayo-ai/kivotos-xl-2.0', 'yodayo-ai/holodayo-xl-2.1'] # specific models
 
9
  'votepurchase/ponyDiffusionV6XL',
10
  'eienmojiki/Anything-XL',
11
  'eienmojiki/Starry-XL-v5.2',
12
+ "digiplay/MilkyWonderland_v1",
13
  'digiplay/majicMIX_sombre_v2',
14
  'digiplay/majicMIX_realistic_v7',
15
  'votepurchase/counterfeitV30_v30',
 
20
  'Raelina/Raemu-XL-V4',
21
  ]
22
 
23
+ #models = find_model_list("Disty0", [], "", "last_modified", 100)
24
+
25
 
26
  # Examples:
27
  #models = ['yodayo-ai/kivotos-xl-2.0', 'yodayo-ai/holodayo-xl-2.1'] # specific models
multit2i.py CHANGED
@@ -80,31 +80,32 @@ def get_t2i_model_info_dict(repo_id: str):
80
  return info
81
 
82
 
83
- def save_gallery_images(images, progress=gr.Progress(track_tqdm=True)):
 
84
  from datetime import datetime, timezone, timedelta
85
- progress(0, desc="Updating gallery...")
86
  dt_now = datetime.now(timezone(timedelta(hours=9)))
87
- basename = dt_now.strftime('%Y%m%d_%H%M%S_')
88
- i = 1
89
- if not images: return images
90
- output_images = []
91
- output_paths = []
92
- for image in images:
93
- filename = f'{image[1]}_{basename}{str(i)}.png'
94
- i += 1
95
- oldpath = Path(image[0])
96
- newpath = oldpath
97
- try:
98
- if oldpath.stem == "image" and oldpath.exists():
99
- newpath = oldpath.resolve().rename(Path(filename).resolve())
100
- except Exception as e:
101
- print(e)
102
- pass
103
- finally:
104
- output_paths.append(str(newpath))
105
- output_images.append((str(newpath), str(filename)))
106
- progress(1, desc="Gallery updated.")
107
- return gr.update(value=output_images), gr.update(value=output_paths)
108
 
109
 
110
  # https://github.com/gradio-app/gradio/blob/main/gradio/external.py
@@ -124,7 +125,7 @@ def load_from_model(model_name: str, hf_token: str = None):
124
  f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `hf_token` parameter."
125
  )
126
  headers["X-Wait-For-Model"] = "true"
127
- client = huggingface_hub.InferenceClient(model=model_name, headers=headers, token=hf_token, timeout=300)
128
  inputs = gr.components.Textbox(label="Input")
129
  outputs = gr.components.Image(label="Output")
130
  fn = client.text_to_image
@@ -163,28 +164,9 @@ def load_model(model_name: str):
163
  return loaded_models[model_name]
164
 
165
 
166
- async def async_load_models(models: list, limit: int=5):
167
- sem = asyncio.Semaphore(limit)
168
- async def async_load_model(model: str):
169
- async with sem:
170
- try:
171
- await asyncio.sleep(0.5)
172
- return await asyncio.to_thread(load_model, model)
173
- except Exception as e:
174
- print(e)
175
- tasks = [asyncio.create_task(async_load_model(model)) for model in models]
176
- return await asyncio.gather(*tasks, return_exceptions=True)
177
-
178
-
179
- def load_models(models: list, limit: int=5):
180
- loop = asyncio.new_event_loop()
181
- try:
182
- loop.run_until_complete(async_load_models(models, limit))
183
- except Exception as e:
184
- print(e)
185
- pass
186
- finally:
187
- loop.close()
188
 
189
 
190
  positive_prefix = {
@@ -298,72 +280,73 @@ def change_model(model_name: str):
298
  return get_model_info_md(model_name)
299
 
300
 
301
- def infer(prompt: str, neg_prompt: str, model_name: str):
302
- from PIL import Image
 
 
 
 
 
 
 
 
 
303
  import random
304
- seed = ""
305
  rand = random.randint(1, 500)
306
  for i in range(rand):
307
- seed += " "
308
- caption = model_name.split("/")[-1]
 
 
 
309
  try:
310
- model = load_model(model_name)
311
- if not model: return (Image.Image(), None)
312
- image_path = model(prompt + seed, neg_prompt)
313
- image = Image.open(image_path).convert('RGBA')
314
- except Exception as e:
315
  print(e)
316
- return (Image.Image(), None)
317
- return (image, caption)
318
-
319
-
320
- async def infer_multi(prompt: str, neg_prompt: str, results: list, image_num: float, model_name: str,
321
- pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = [], progress=gr.Progress(track_tqdm=True)):
322
- import asyncio
323
- progress(0, desc="Start inference.")
324
- image_num = int(image_num)
325
- images = results if results else []
326
- image_num_offset = len(images)
327
- prompt, neg_prompt = recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
328
- tasks = [asyncio.create_task(asyncio.to_thread(infer, prompt, neg_prompt, model_name)) for i in range(image_num)]
329
- await asyncio.sleep(0)
330
- for task in tasks:
331
- progress(float(len(images) - image_num_offset) / float(image_num), desc="Running inference.")
332
- try:
333
- result = await asyncio.wait_for(task, timeout=120)
334
- except (Exception, asyncio.TimeoutError) as e:
335
- print(e)
336
- if not task.done(): task.cancel()
337
- result = None
338
- image_num_offset += 1
339
  with lock:
340
- if result and len(result) == 2 and result[1]: images.append(result)
341
- await asyncio.sleep(0)
342
- yield images
 
343
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
 
345
- async def infer_multi_random(prompt: str, neg_prompt: str, results: list, image_num: float,
346
- pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = [], progress=gr.Progress(track_tqdm=True)):
 
347
  import random
348
- progress(0, desc="Start inference.")
349
- image_num = int(image_num)
350
- images = results if results else []
351
- image_num_offset = len(images)
352
  random.seed()
353
- model_names = random.choices(list(loaded_models.keys()), k = image_num)
354
- prompt, neg_prompt = recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
355
- tasks = [asyncio.create_task(asyncio.to_thread(infer, prompt, neg_prompt, model_name)) for model_name in model_names]
356
- await asyncio.sleep(0)
357
- for task in tasks:
358
- progress(float(len(images) - image_num_offset) / float(image_num), desc="Running inference.")
359
- try:
360
- result = await asyncio.wait_for(task, timeout=120)
361
- except (Exception, asyncio.TimeoutError) as e:
362
- print(e)
363
- if not task.done(): task.cancel()
364
- result = None
365
- image_num_offset += 1
366
- with lock:
367
- if result and len(result) == 2 and result[1]: images.append(result)
368
- await asyncio.sleep(0)
369
- yield images
 
80
  return info
81
 
82
 
83
+ def rename_image(image_path: str | None, model_name: str):
84
+ from PIL import Image
85
  from datetime import datetime, timezone, timedelta
86
+ if image_path is None: return None
87
  dt_now = datetime.now(timezone(timedelta(hours=9)))
88
+ filename = f"{model_name.split('/')[-1]}_{dt_now.strftime('%Y%m%d_%H%M%S')}.png"
89
+ try:
90
+ if Path(image_path).exists():
91
+ png_path = "image.png"
92
+ Image.open(image_path).convert('RGBA').save(png_path, "PNG")
93
+ new_path = str(Path(png_path).resolve().rename(Path(filename).resolve()))
94
+ return new_path
95
+ else:
96
+ return None
97
+ except Exception as e:
98
+ print(e)
99
+ return None
100
+
101
+
102
+ def save_gallery(image_path: str | None, images: list[tuple] | None):
103
+ if images is None: images = []
104
+ files = [i[0] for i in images]
105
+ if image_path is None: return images, files
106
+ files.insert(0, str(image_path))
107
+ images.insert(0, (str(image_path), Path(image_path).stem))
108
+ return images, files
109
 
110
 
111
  # https://github.com/gradio-app/gradio/blob/main/gradio/external.py
 
125
  f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `hf_token` parameter."
126
  )
127
  headers["X-Wait-For-Model"] = "true"
128
+ client = huggingface_hub.InferenceClient(model=model_name, headers=headers, token=hf_token, timeout=600)
129
  inputs = gr.components.Textbox(label="Input")
130
  outputs = gr.components.Image(label="Output")
131
  fn = client.text_to_image
 
164
  return loaded_models[model_name]
165
 
166
 
167
+ def load_models(models: list):
168
+ for model in models:
169
+ load_model(model)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
 
172
  positive_prefix = {
 
280
  return get_model_info_md(model_name)
281
 
282
 
283
+ def warm_model(model_name: str):
284
+ model = load_model(model_name)
285
+ if model:
286
+ try:
287
+ print(f"Warming model: {model_name}")
288
+ model(" ")
289
+ except Exception as e:
290
+ print(e)
291
+
292
+
293
+ async def infer(model_name: str, prompt: str, neg_prompt: str, timeout: float):
294
  import random
295
+ noise = ""
296
  rand = random.randint(1, 500)
297
  for i in range(rand):
298
+ noise += " "
299
+ model = load_model(model_name)
300
+ if not model: return None
301
+ task = asyncio.create_task(asyncio.to_thread(model, f'{prompt} {noise}'))
302
+ await asyncio.sleep(0)
303
  try:
304
+ result = await asyncio.wait_for(task, timeout=timeout)
305
+ except (Exception, asyncio.TimeoutError) as e:
 
 
 
306
  print(e)
307
+ print(f"Task timed out: {model_name}")
308
+ if not task.done(): task.cancel()
309
+ result = None
310
+ if task.done() and result is not None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  with lock:
312
+ image = rename_image(result, model_name)
313
+ return image
314
+ return None
315
+
316
 
317
+ infer_timeout = 300
318
+ def infer_fn(model_name: str, prompt: str, neg_prompt: str,
319
+ pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = []):
320
+ if model_name == 'NA':
321
+ return None
322
+ try:
323
+ prompt, neg_prompt = recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
324
+ loop = asyncio.new_event_loop()
325
+ result = loop.run_until_complete(infer(model_name, prompt, neg_prompt, infer_timeout))
326
+ except (Exception, asyncio.CancelledError) as e:
327
+ print(e)
328
+ print(f"Task aborted: {model_name}")
329
+ result = None
330
+ finally:
331
+ loop.close()
332
+ return result
333
 
334
+
335
+ def infer_rand_fn(model_name_dummy: str, prompt: str, neg_prompt: str,
336
+ pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = []):
337
  import random
338
+ if model_name_dummy == 'NA':
339
+ return None
 
 
340
  random.seed()
341
+ model_name = random.choice(list(loaded_models.keys()))
342
+ try:
343
+ prompt, neg_prompt = recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
344
+ loop = asyncio.new_event_loop()
345
+ result = loop.run_until_complete(infer(model_name, prompt, neg_prompt, infer_timeout))
346
+ except (Exception, asyncio.CancelledError) as e:
347
+ print(e)
348
+ print(f"Task aborted: {model_name}")
349
+ result = None
350
+ finally:
351
+ loop.close()
352
+ return result
 
 
 
 
 
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
  huggingface_hub
2
- torch
3
  torchvision
4
  accelerate
5
  transformers
 
1
  huggingface_hub
2
+ torch==2.2.0
3
  torchvision
4
  accelerate
5
  transformers