jerukperas committed on
Commit
13e5846
1 Parent(s): b411b2e

Update application files

Browse files
Files changed (3) hide show
  1. README.md +2 -2
  2. app.py +58 -3
  3. requirements.txt +7 -0
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
- title: Dart
3
- emoji: 👁
4
  colorFrom: yellow
5
  colorTo: pink
6
  sdk: gradio
 
1
  ---
2
+ title: dart
3
+ emoji: 🏷️
4
  colorFrom: yellow
5
  colorTo: pink
6
  sdk: gradio
app.py CHANGED
@@ -1,7 +1,62 @@
 
1
  import gradio as gr
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
  demo.launch()
 
1
import torch
import gradio as gr
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer

# Model card / collection:
# https://huggingface.co/collections/p1atdev/dart-v2-danbooru-tags-transformer-v2-66291115701b6fe773399b0a
model_id = "p1atdev/dart-v2-sft"
# ONNX Runtime export of the causal LM — runs on CPU without a GPU.
model = ORTModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Second tokenizer with a prefix space, used when building bad-words token
# lists (words tokenize differently mid-sentence vs. at the start).
tokenizer_with_prefix_space = AutoTokenizer.from_pretrained(model_id, add_prefix_space=True)
11
+
12
+
13
# https://huggingface.co/docs/transformers/v4.44.2/en/internal/generation_utils#transformers.NoBadWordsLogitsProcessor
def get_tokens_as_list(word_list):
    """Tokenize each word in *word_list* and return one token-id list per word."""
    return [
        tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0]
        for word in word_list
    ]
21
+
22
+
23
def generate_tags(general_tags: str):
    """Generate danbooru-style tags from a comma-separated tag string.

    Builds the dart-v2 prompt, samples a completion from the ONNX model,
    and returns the decoded tags joined by ", ".
    """
    # Normalize the input: strip whitespace around each tag and drop empty
    # entries. Filtering on the *stripped* value fixes a bug where a
    # whitespace-only entry (e.g. "a, , b") passed the `if tag` check and
    # left an empty slot (",,") in the prompt.
    general_tags = ",".join(
        stripped for tag in general_tags.split(",") if (stripped := tag.strip())
    )
    # Prompt format: https://huggingface.co/p1atdev/dart-v2-sft#prompt-format
    prompt = (
        "<|bos|>"
        # "<copyright></copyright>"
        # "<character></character>"
        "<|rating:general|><|aspect_ratio:tall|><|length:long|>"
        f"<general>{general_tags}<|identity:none|><|input_end|>"
    )

    inputs = tokenizer(prompt, return_tensors="pt").input_ids
    # bad_words_ids = get_tokens_as_list(word_list=[""])

    # Inference only — no autograd graph needed.
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            do_sample=True,
            temperature=1.0,
            top_p=1.0,
            top_k=100,
            max_new_tokens=128,
            num_beams=1,
            # bad_words_ids=bad_words_ids,
        )

    # NOTE(review): batch_decode over outputs[0] decodes each token id
    # individually; dart tags are mostly single tokens, so this yields one
    # tag per entry. The generated sequence also includes the prompt tokens,
    # so the input tags reappear in the output — presumably intended
    # (the result is the full tag list); confirm against the model card.
    return ", ".join(
        tag
        for tag in tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
        if tag.strip() != ""
    )
52
+
53
+
54
# Minimal Gradio UI: free-text tag prompt in, generated tag string out.
demo = gr.Interface(
    fn=generate_tags,
    inputs=gr.TextArea("1girl, black hair", lines=4),  # example starting tags
    outputs=gr.Textbox(show_copy_button=True),  # copy button for easy pasting
    clear_btn=None,  # hide the Clear button
    analytics_enabled=False,  # no usage telemetry
)

demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Comment out the line below for default installation on Linux
2
+ --extra-index-url https://download.pytorch.org/whl/cpu
3
+
4
+ gradio==4.42.0
5
+ torch
6
+ transformers
7
+ optimum[onnxruntime] # or optimum[onnxruntime-gpu]