|
import json |
|
import time |
|
from PIL import Image |
|
import torch |
|
from torchvision.transforms import transforms |
|
import gradio as gr |
|
from huggingface_hub import snapshot_download |
|
# Fetch the published model repository into the working directory so that
# model.pth (loaded just below) exists locally. Presumably the tag list read
# later in this file comes from the same download -- confirm against the repo.
# This runs at import time.
snapshot_download(repo_id="Thouph/eva-vit-1b-224-8043", local_dir="./")

# The loaded object is used directly as a callable module (``.eval()`` is
# invoked on it), so model.pth must be a fully pickled model rather than a
# state_dict. PyTorch 2.6 switched ``torch.load`` to ``weights_only=True`` by
# default, which refuses such files; request full unpickling explicitly so the
# script behaves the same on old and new PyTorch releases.
# SECURITY: full unpickling can execute arbitrary code stored in the file.
# Only load checkpoints from a repository you trust.
model = torch.load(
    'model.pth',
    map_location=torch.device('cpu'),
    weights_only=False,
)

# Put the model in evaluation mode before it is used for inference.
model.eval()
|
# Per-channel normalisation statistics (three values each; the images fed to
# this pipeline are converted to RGB by the caller).
_NORMALIZE_MEAN = [0.48145466, 0.4578275, 0.40821073]
_NORMALIZE_STD = [0.26862954, 0.26130258, 0.27577711]

# Preprocessing applied to every image before inference, in this order:
#   1. resize to exactly 224x224,
#   2. convert the PIL image to a tensor,
#   3. normalise each channel with the statistics above.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
transform = transforms.Compose(_preprocessing_steps)
|
|
|
|
|
# Load the tag vocabulary. JSON text is UTF-8 by specification, so decode it
# explicitly instead of relying on the platform's default encoding (which is
# not UTF-8 on every system).
with open("tags_8040.json", "r", encoding="utf-8") as tags_file:
    tags = json.load(tags_file)

# Label list that create_tags() indexes by model output position: the sorted
# vocabulary followed by the three rating labels. The order built here must
# not change, because labels are looked up purely by index.
allowed_tags = sorted(tags)
allowed_tags.extend(["explicit", "questionable", "safe"])
|
|
|
# Vocabulary entry that must never appear in the output. It is filtered by
# exact tag name, so it is removed even when it is the last selected tag
# (a text-level replace of "placeholder1, " would miss that case).
_PLACEHOLDER_TAG = "placeholder1"

# Single-pass character rewrite applied to the joined tag string:
# underscores become spaces and parentheses are backslash-escaped.
# NOTE(review): the escaping presumably targets prompt syntaxes where
# parentheses are special -- confirm with the consumers of this text.
_TAG_TEXT_ESCAPES = str.maketrans({"_": " ", "(": "\\(", ")": "\\)"})


def _format_tags(names):
    """Join tag names into the comma-separated, escaped output string."""
    kept = [name for name in names if name != _PLACEHOLDER_TAG]
    return ", ".join(kept).translate(_TAG_TEXT_ESCAPES)


def create_tags(image, thres):
    """Return the tags whose predicted probability exceeds ``thres``.

    Args:
        image: PIL image to tag, or ``None`` when the caller submitted no
            image.
        thres: Probability cut-off; only tags scoring strictly above it are
            kept.

    Returns:
        A single string of tag names separated by ``", "``, in vocabulary
        (index) order, with underscores replaced by spaces and parentheses
        backslash-escaped. Empty when no image was given or no tag passes
        the threshold.
    """
    # The UI hands over None when the form is submitted without an image;
    # answer with an empty result instead of failing on ``None.convert``.
    if image is None:
        return ""

    rgb_image = image.convert('RGB')
    # Add a leading batch dimension of size 1 for the model call.
    batch = transform(rgb_image).unsqueeze(0)

    with torch.no_grad():
        out = model(batch)
        # out[0] is taken as the scores for the single submitted image
        # (presumably one logit per entry of allowed_tags -- confirm).
        probabilities = torch.sigmoid(out[0])
        # Plain Python ints, so they can index the label list directly.
        selected = torch.where(probabilities > thres)[0].tolist()

    text = _format_tags(allowed_tags[index] for index in selected)
    print(text)
    return text
|
|
|
# Input widgets: the uploaded picture, delivered to the callback as a PIL
# image, and the probability threshold, adjustable from 0 to 1 in steps of
# 0.01 and starting at 0.3.
image_input = gr.Image(type="pil")
threshold_slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.3)

# Web UI wiring: both inputs are passed to create_tags() and its return value
# is shown as text.
demo = gr.Interface(
    fn=create_tags,
    inputs=[image_input, threshold_slider],
    outputs=["text"],
)

# Enable request queuing, then start the app with debug output turned on.
demo.queue()
demo.launch(debug=True)
|
|