Image2Body_gradio / scripts /generate_prompt.py
yeq6x's picture
prepare spaces
721391f
raw
history blame
5.31 kB
import argparse
import csv
import os
import json
from PIL import Image
import cv2
import numpy as np
from tensorflow.keras.layers import TFSMLayer
from huggingface_hub import hf_hub_download
from pathlib import Path
# from wd14 tagger
IMAGE_SIZE = 448
# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
SUB_DIR = "variables"
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
CSV_FILE = FILES[-1]
def preprocess_image(image):
image = np.array(image)
image = image[:, :, ::-1] # RGB->BGR
# pad to square
size = max(image.shape[0:2])
pad_x = size - image.shape[1]
pad_y = size - image.shape[0]
pad_l = pad_x // 2
pad_t = pad_y // 2
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
image = image.astype(np.float32)
return image
def load_wd14_tagger_model():
model_dir = "wd14_tagger_model"
repo_id = DEFAULT_WD14_TAGGER_REPO
if not os.path.exists(model_dir):
print(f"downloading wd14 tagger model from hf_hub. id: {repo_id}")
for file in FILES:
hf_hub_download(repo_id, file, cache_dir=model_dir, force_download=True, force_filename=file)
for file in SUB_DIR_FILES:
hf_hub_download(
repo_id,
file,
subfolder=SUB_DIR,
cache_dir=os.path.join(model_dir, SUB_DIR),
force_download=True,
force_filename=file,
)
else:
print("using existing wd14 tagger model")
# モデルを読み込む
model = TFSMLayer(model_dir, call_endpoint='serving_default')
return model
def generate_tags(images, model_dir, model):
with open(os.path.join(model_dir, CSV_FILE), "r", encoding="utf-8") as f:
reader = csv.reader(f)
l = [row for row in reader]
header = l[0] # tag_id,name,category,count
rows = l[1:]
assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
character_tags = [row[1] for row in rows[1:] if row[2] == "4"]
tag_freq = {}
undesired_tags = ['one-piece_swimsuit',
'swimsuit',
'leotard',
'saitama_(one-punch_man)',
'1boy',
]
probs = model(images, training=False)
probs = probs['predictions_sigmoid'].numpy()
tag_text_list = []
for prob in probs:
combined_tags = []
general_tag_text = ""
character_tag_text = ""
thresh = 0.35
for i, p in enumerate(prob[4:]):
if i < len(general_tags) and p >= thresh:
tag_name = general_tags[i]
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
general_tag_text += ", " + tag_name
combined_tags.append(tag_name)
elif i >= len(general_tags) and p >= thresh:
tag_name = character_tags[i - len(general_tags)]
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
character_tag_text += ", " + tag_name
combined_tags.append(tag_name)
if len(general_tag_text) > 0:
general_tag_text = general_tag_text[2:]
if len(character_tag_text) > 0:
character_tag_text = character_tag_text[2:]
tag_text = ", ".join(combined_tags)
tag_text_list.append(tag_text)
return tag_text_list
def generate_prompt_json(target_folder, prompt_file, model_dir, model):
image_files = [f for f in os.listdir(target_folder) if os.path.isfile(os.path.join(target_folder, f))]
image_count = len(image_files)
prompt_list = []
for i, filename in enumerate(image_files, 1):
source_path = "source/" + filename
target_path = os.path.join(target_folder, filename) # Use absolute path
target_path2 = "target/" + filename
prompt = generate_tags(target_path, model_dir, model)
for j in range(4):
prompt_data = {
"source": f"{source_path.split('.')[0]}_{j}.jpg",
"target": f"{target_path2.split('.')[0]}_{j}.jpg",
"prompt": prompt
}
prompt_list.append(prompt_data)
print(f"Processed Images: {i}/{image_count}", end="\r", flush=True)
with open(prompt_file, "w") as file:
for prompt_data in prompt_list:
json.dump(prompt_data, file)
file.write("\n")
print(f"Processing completed. Total Images: {image_count}")
if __name__ == '__main__':
model_dir = "wd14_tagger_model"
model = load_wd14_tagger_model()
prompt = generate_tags(target_path, model_dir, model)