Spaces:
Running
Running
Update nsfw_detector.py
#1
by
sewanapi
- opened
- nsfw_detector.py +0 -65
nsfw_detector.py
CHANGED
@@ -1,65 +0,0 @@
|
|
1 |
-
from torchvision.transforms import Normalize
|
2 |
-
import torchvision.transforms as T
|
3 |
-
import torch.nn as nn
|
4 |
-
from PIL import Image
|
5 |
-
import numpy as np
|
6 |
-
import torch
|
7 |
-
import timm
|
8 |
-
from tqdm import tqdm
|
9 |
-
|
10 |
-
# https://github.com/Whiax/NSFW-Classifier/raw/main/nsfwmodel_281.pth
|
11 |
-
normalize_t = Normalize((0.4814, 0.4578, 0.4082), (0.2686, 0.2613, 0.2757))
|
12 |
-
|
13 |
-
#nsfw classifier
|
14 |
-
class NSFWClassifier(nn.Module):
|
15 |
-
def __init__(self):
|
16 |
-
super().__init__()
|
17 |
-
nsfw_model=self
|
18 |
-
nsfw_model.root_model = timm.create_model('convnext_base_in22ft1k', pretrained=True)
|
19 |
-
nsfw_model.linear_probe = nn.Linear(1024, 1, bias=False)
|
20 |
-
|
21 |
-
def forward(self, x):
|
22 |
-
nsfw_model = self
|
23 |
-
x = normalize_t(x)
|
24 |
-
x = nsfw_model.root_model.stem(x)
|
25 |
-
x = nsfw_model.root_model.stages(x)
|
26 |
-
x = nsfw_model.root_model.head.global_pool(x)
|
27 |
-
x = nsfw_model.root_model.head.norm(x)
|
28 |
-
x = nsfw_model.root_model.head.flatten(x)
|
29 |
-
x = nsfw_model.linear_probe(x)
|
30 |
-
return x
|
31 |
-
|
32 |
-
def is_nsfw(self, img_paths, threshold = 0.98):
|
33 |
-
skip_step = 1
|
34 |
-
total_len = len(img_paths)
|
35 |
-
if total_len < 100: skip_step = 1
|
36 |
-
if total_len > 100 and total_len < 500: skip_step = 10
|
37 |
-
if total_len > 500 and total_len < 1000: skip_step = 20
|
38 |
-
if total_len > 1000 and total_len < 10000: skip_step = 50
|
39 |
-
if total_len > 10000: skip_step = 100
|
40 |
-
|
41 |
-
for idx in tqdm(range(0, total_len, skip_step), total=int(total_len // skip_step), desc="Checking for NSFW contents"):
|
42 |
-
_img = Image.open(img_paths[idx]).convert('RGB')
|
43 |
-
img = _img.resize((224, 224))
|
44 |
-
img = np.array(img)/255
|
45 |
-
img = T.ToTensor()(img).unsqueeze(0).float()
|
46 |
-
if next(self.parameters()).is_cuda:
|
47 |
-
img = img.cuda()
|
48 |
-
with torch.no_grad():
|
49 |
-
score = self.forward(img).sigmoid()[0].item()
|
50 |
-
if score > threshold:
|
51 |
-
print(f"Detected nsfw score:{score}")
|
52 |
-
_img.save("nsfw.jpg")
|
53 |
-
return True
|
54 |
-
return False
|
55 |
-
|
56 |
-
def get_nsfw_detector(model_path='nsfwmodel_281.pth', device="cpu"):
|
57 |
-
#load base model
|
58 |
-
nsfw_model = NSFWClassifier()
|
59 |
-
nsfw_model = nsfw_model.eval()
|
60 |
-
#load linear weights
|
61 |
-
linear_pth = model_path
|
62 |
-
linear_state_dict = torch.load(linear_pth, map_location='cpu')
|
63 |
-
nsfw_model.linear_probe.load_state_dict(linear_state_dict)
|
64 |
-
nsfw_model = nsfw_model.to(device)
|
65 |
-
return nsfw_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|