import torch
import gradio as gr
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer
from huggingface_hub import InferenceClient

# https://huggingface.co/collections/p1atdev/dart-v2-danbooru-tags-transformer-v2-66291115701b6fe773399b0a
model_id = "p1atdev/dart-v2-sft"
model = ORTModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Second tokenizer with add_prefix_space=True, as recommended in the
# NoBadWordsLogitsProcessor docs for building bad_words_ids.
tokenizer_with_prefix_space = AutoTokenizer.from_pretrained(model_id, add_prefix_space=True)

txt2imgclient = InferenceClient()


# https://huggingface.co/docs/transformers/v4.44.2/en/internal/generation_utils#transformers.NoBadWordsLogitsProcessor
def get_tokens_as_list(word_list):
    "Converts a sequence of words into a list of token-id lists"
    tokens_list = []
    for word in word_list:
        tokenized_word = tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0]
        tokens_list.append(tokenized_word)
    return tokens_list


def generate_tags(general_tags: str, generate_image: bool = False):
    # https://huggingface.co/p1atdev/dart-v2-sft#prompt-format
    # Normalize the comma-separated input: trim whitespace, drop empty entries.
    general_tags = ",".join(tag.strip() for tag in general_tags.split(",") if tag.strip())
    prompt = (
        "<|bos|>"
        # "<copyright></copyright>"
        # "<character></character>"
        "<|rating:general|><|aspect_ratio:tall|><|length:medium|>"
        f"{general_tags}<|identity:none|><|input_end|>"
    )
    inputs = tokenizer(prompt, return_tensors="pt").input_ids
    # bad_words_ids = get_tokens_as_list(word_list=[""])

    with torch.no_grad():
        outputs = model.generate(
            inputs,
            do_sample=True,
            temperature=1.0,
            top_p=1.0,
            top_k=100,
            max_new_tokens=128,
            num_beams=1,
            # bad_words_ids=bad_words_ids,
        )

    # Each Danbooru tag is a single token in the dart tokenizer, so decoding
    # the output ids one by one yields one tag per entry.
    output_tags = ", ".join(
        [tag for tag in tokenizer.batch_decode(outputs[0], skip_special_tokens=True) if tag.strip() != ""]
    )
    yield (output_tags, None)

    if generate_image:
        txt2img_prompt = f"score_9, score_8_up, score_7_up, {output_tags}"
        img = txt2imgclient.text_to_image(
            prompt=txt2img_prompt,
            negative_prompt="score_6, score_5, score_4, rating_explicit, child, loli, shota",
            num_inference_steps=25,
            height=1152,
            width=896,
            model="John6666/wai-real-mix-v8-sdxl",
            scheduler="EulerAncestralDiscreteScheduler",
        )
        yield (output_tags, img)


demo = gr.Interface(
    fn=generate_tags,
    inputs=[
        gr.TextArea("1girl, black hair", lines=4),
        gr.Checkbox(
            False,
            label="Generate Image",
            info="Generate an image with InferenceClient (really slow), using output_tags as the prompt",
        ),
    ],
    outputs=[
        gr.Textbox(label="output_tags", show_copy_button=True),
        gr.Image(label="generated_image", format="jpeg", type="pil"),
    ],
    clear_btn=None,
    analytics_enabled=False,
    concurrency_limit=64,
)

demo.queue().launch()
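# Usage sketch (an assumption, not part of the original app): generate_tags is a
# plain generator, so it can be smoke-tested without the Gradio UI by iterating
# it directly (run this before the blocking launch() call above):
#
#     for tags, image in generate_tags("1girl, black hair", generate_image=False):
#         print(tags)
#
# The first yield is always (output_tags, None); a second (output_tags, PIL image)
# yield follows only when generate_image=True.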