Kohaku-Blueleaf
first commit
7d4afe8
import time
import pathlib
import kgen.models as models
from kgen.formatter import seperate_tags, apply_format, apply_dtg_prompt
from kgen.metainfo import TARGET
from kgen.generate import tag_gen
from kgen.logging import logger
# Upper bound for generation seeds: ``process`` reduces the caller's seed
# modulo this value, so the seed handed to ``tag_gen`` is always a
# non-negative integer below 2**31 - 1.
SEED_MAX = (1 << 31) - 1

# Default layout template for the ``format`` argument of ``process`` (it is
# what the ``__main__`` demo passes). Each ``<|name|>`` placeholder presumably
# stands for one tag category and is filled in by ``apply_format`` — confirm
# against kgen.formatter.
DEFAULT_FORMAT = (
    "<|special|>,\n"
    "<|characters|>, <|copyrights|>,\n"
    "<|artist|>,\n"
    "<|general|>,\n"
    "<|quality|>, <|meta|>, <|rating|>"
)
def process(
    prompt: str,
    aspect_ratio: float,
    seed: int,
    tag_length: str,
    ban_tags: str,
    format: str,
    temperature: float,
):
    """Expand a comma-separated tag prompt with model-generated tags.

    The prompt is split into tags and grouped by ``seperate_tags``. A
    generation prompt is then built with ``apply_dtg_prompt``, and ``tag_gen``
    is run with ``models.text_model`` / ``models.tokenizer``. The extra tags
    reported by the generator's last step are appended to the ``"general"``
    category. The result is rendered with ``apply_format``.

    Args:
        prompt: Comma-separated input tags. Surrounding whitespace and empty
            entries are ignored.
        aspect_ratio: Forwarded to ``apply_dtg_prompt``.
        seed: Random seed. It is reduced modulo ``SEED_MAX`` before being
            passed to ``tag_gen``.
        tag_length: Target-length name. Spaces are replaced by underscores
            before it is looked up in ``TARGET``.
        ban_tags: Comma-separated tags passed to ``tag_gen`` as a blacklist.
        format: Template passed to ``apply_format`` (e.g. ``DEFAULT_FORMAT``).
            The name shadows the builtin but is kept for caller compatibility.
        temperature: Sampling temperature forwarded to ``tag_gen``.

    Returns:
        The formatted prompt produced by ``apply_format`` (presumably a str).

    Raises:
        KeyError: If the normalized ``tag_length`` is not a key of ``TARGET``.
    """
    prompt_preview = prompt.replace("\n", " ")[:40]
    logger.info(f"Processing prompt: {prompt_preview}...")
    logger.info(f"Processing with seed: {seed}")

    black_list = [tag.strip() for tag in ban_tags.split(",") if tag.strip()]
    all_tags = [tag.strip() for tag in prompt.strip().split(",") if tag.strip()]

    tag_length = tag_length.replace(" ", "_")
    len_target = TARGET[tag_length]

    tag_map = seperate_tags(all_tags)
    dtg_prompt = apply_dtg_prompt(tag_map, tag_length, aspect_ratio)

    # tag_gen is iterated to exhaustion and only the values from its last
    # yield are used. These defaults keep the code below well-defined if it
    # yields nothing. Previously that case raised UnboundLocalError.
    extra_tokens = []
    iter_count = 0
    for _, extra_tokens, iter_count in tag_gen(
        models.text_model,
        models.tokenizer,
        dtg_prompt,
        tag_map["special"] + tag_map["general"],
        len_target,
        black_list,
        temperature=temperature,
        top_p=0.8,
        top_k=80,
        max_new_tokens=512,
        max_retry=10,
        max_same_output=5,
        seed=seed % SEED_MAX,
    ):
        pass

    tag_map["general"] += extra_tokens
    prompt_by_dtg = apply_format(tag_map, format)

    # Sum the lengths instead of concatenating the lists just to count them.
    tag_count = len(tag_map["general"]) + len(tag_map["special"])
    logger.info(
        "Prompt processing done. General Tags Count: "
        f"{tag_count}"
        f" | Total iterations: {iter_count}"
    )
    return prompt_by_dtg
if __name__ == "__main__":
    # Demo / smoke test: load a local GGUF model and expand a sample prompt.
    models.model_dir = pathlib.Path(__file__).parent / "models"

    # Called for its side effect (fetching a model into model_dir — TODO
    # confirm). Its return value was previously assigned to ``file`` and then
    # immediately overwritten, so it is not kept.
    models.download_gguf()

    files = models.list_gguf()
    if not files:
        # Fail with a clear message instead of an opaque IndexError.
        raise FileNotFoundError(f"No gguf model found in {models.model_dir}")
    # Use the last entry from list_gguf (presumably the preferred/newest
    # model — TODO confirm list_gguf's ordering).
    file = files[-1]
    logger.info(f"Use gguf model from local file: {file}")
    models.load_model(file, gguf=True)

    prompt = "\n1girl, ask (askzy), masterpiece\n"

    # perf_counter_ns is monotonic, so the duration cannot be skewed by
    # wall-clock adjustments the way time_ns can.
    t0 = time.perf_counter_ns()
    result = process(
        prompt,
        aspect_ratio=1.0,
        seed=1,
        tag_length="long",
        ban_tags="",
        format=DEFAULT_FORMAT,
        temperature=1.35,
    )
    t1 = time.perf_counter_ns()

    logger.info(f"Result:\n{result}")
    logger.info(f"Time cost: {(t1 - t0) / 10**6:.1f}ms")