Update nsfw_detector.py

#1
by sewanapi - opened
Files changed (1) hide show
  1. nsfw_detector.py +0 -65
nsfw_detector.py CHANGED
@@ -1,65 +0,0 @@
1
- from torchvision.transforms import Normalize
2
- import torchvision.transforms as T
3
- import torch.nn as nn
4
- from PIL import Image
5
- import numpy as np
6
- import torch
7
- import timm
8
- from tqdm import tqdm
9
-
10
- # https://github.com/Whiax/NSFW-Classifier/raw/main/nsfwmodel_281.pth
11
- normalize_t = Normalize((0.4814, 0.4578, 0.4082), (0.2686, 0.2613, 0.2757))
12
-
13
- #nsfw classifier
14
- class NSFWClassifier(nn.Module):
15
- def __init__(self):
16
- super().__init__()
17
- nsfw_model=self
18
- nsfw_model.root_model = timm.create_model('convnext_base_in22ft1k', pretrained=True)
19
- nsfw_model.linear_probe = nn.Linear(1024, 1, bias=False)
20
-
21
- def forward(self, x):
22
- nsfw_model = self
23
- x = normalize_t(x)
24
- x = nsfw_model.root_model.stem(x)
25
- x = nsfw_model.root_model.stages(x)
26
- x = nsfw_model.root_model.head.global_pool(x)
27
- x = nsfw_model.root_model.head.norm(x)
28
- x = nsfw_model.root_model.head.flatten(x)
29
- x = nsfw_model.linear_probe(x)
30
- return x
31
-
32
- def is_nsfw(self, img_paths, threshold = 0.98):
33
- skip_step = 1
34
- total_len = len(img_paths)
35
- if total_len < 100: skip_step = 1
36
- if total_len > 100 and total_len < 500: skip_step = 10
37
- if total_len > 500 and total_len < 1000: skip_step = 20
38
- if total_len > 1000 and total_len < 10000: skip_step = 50
39
- if total_len > 10000: skip_step = 100
40
-
41
- for idx in tqdm(range(0, total_len, skip_step), total=int(total_len // skip_step), desc="Checking for NSFW contents"):
42
- _img = Image.open(img_paths[idx]).convert('RGB')
43
- img = _img.resize((224, 224))
44
- img = np.array(img)/255
45
- img = T.ToTensor()(img).unsqueeze(0).float()
46
- if next(self.parameters()).is_cuda:
47
- img = img.cuda()
48
- with torch.no_grad():
49
- score = self.forward(img).sigmoid()[0].item()
50
- if score > threshold:
51
- print(f"Detected nsfw score:{score}")
52
- _img.save("nsfw.jpg")
53
- return True
54
- return False
55
-
56
- def get_nsfw_detector(model_path='nsfwmodel_281.pth', device="cpu"):
57
- #load base model
58
- nsfw_model = NSFWClassifier()
59
- nsfw_model = nsfw_model.eval()
60
- #load linear weights
61
- linear_pth = model_path
62
- linear_state_dict = torch.load(linear_pth, map_location='cpu')
63
- nsfw_model.linear_probe.load_state_dict(linear_state_dict)
64
- nsfw_model = nsfw_model.to(device)
65
- return nsfw_model