m3 commited on
Commit
e97054d
1 Parent(s): 8e91c3f

chore: add sscd models

Browse files
config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "model_type": "sscd-copy-detection",
3
+ "model_path": "sscd_disc_mixup.torchscript.pt"
4
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d4eeba5dba464b0a8cc110ef7d1c3f261cb001f5dce9a39123390b700c20172b
3
+ size 104
preprocessor_config.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "size": 288,
3
+ "do_resize": true,
4
+ "image_mean": [
5
+ 0.485,
6
+ 0.456,
7
+ 0.406
8
+ ],
9
+ "image_processor_type": "SscdImageProcessor",
10
+ "image_std": [
11
+ 0.229,
12
+ 0.224,
13
+ 0.225
14
+ ],
15
+ "do_convert_rgb": true
16
+ }
src/model.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Union
2
+ from torchvision import transforms
3
+ from PIL import Image
4
+
5
+ from transformers.image_processing_utils import BaseImageProcessor
6
+ from transformers import PreTrainedModel, PretrainedConfig
7
+ import os
8
+ from huggingface_hub import hf_hub_download
9
+ import torch
10
+ import torch.nn as nn
11
+ class SscdImageProcessor(BaseImageProcessor):
12
+ def __init__(
13
+ self,
14
+ do_resize: bool = True,
15
+ size: int = 288,
16
+ image_mean: Optional[Union[float, List[float]]] = None,
17
+ image_std: Optional[Union[float, List[float]]] = None,
18
+ do_convert_rgb: bool = True,
19
+ **kwargs,
20
+ ) -> None:
21
+ super().__init__(**kwargs)
22
+ self.size = size
23
+ self.image_mean = image_mean if image_mean is not None else [0.485, 0.456, 0.406]
24
+ self.image_std = image_std if image_std is not None else [0.229, 0.224, 0.225]
25
+ self.do_convert_rgb = do_convert_rgb
26
+ self.do_resize = do_resize
27
+
28
+ def preprocess(
29
+ self,
30
+ image: Image,
31
+ do_resize: bool = None,
32
+ **kwargs,
33
+ ):
34
+ size_transforms = [
35
+ transforms.ToTensor(),
36
+ transforms.Normalize(
37
+ mean=self.image_mean, std=self.image_std,
38
+ ),
39
+ ]
40
+ if do_resize is None:
41
+ do_resize = self.do_resize
42
+ if do_resize:
43
+ size_transforms.append(transforms.Resize(self.size))
44
+ preprocess = transforms.Compose([
45
+ transforms.Resize(self.size),
46
+ transforms.ToTensor(),
47
+ transforms.Normalize(
48
+ mean=self.image_mean, std=self.image_std,
49
+ ),
50
+ ])
51
+ if self.do_convert_rgb:
52
+ image = image.convert('RGB')
53
+ return preprocess(image).unsqueeze(0)
54
+
55
+ class SscdConfig(PretrainedConfig):
56
+ model_type = 'sscd-copy-detection'
57
+ def __init__(self, model_path: str = None, **kwargs):
58
+ if model_path is None:
59
+ model_path = 'sscd_disc_mixup.torchscript.pt'
60
+ super().__init__(model_path=model_path, **kwargs)
61
+
62
+ class SscdModel(PreTrainedModel):
63
+ config_class = SscdConfig
64
+
65
+ def __init__(self, config):
66
+ super().__init__(config)
67
+ self.dummy_param = nn.Parameter(torch.zeros(0))
68
+
69
+ print("______", config.name_or_path)
70
+
71
+ is_local = os.path.isdir(config.name_or_path)
72
+ if is_local:
73
+ config.base_path = config.name_or_path
74
+ else:
75
+ config_path = hf_hub_download(repo_id=config.name_or_path, filename='config.json')
76
+ config.base_path = os.path.dirname(config_path)
77
+ model_path = config.base_path + '/' + config.model_path
78
+ print("___model_path___", model_path)
79
+
80
+ def forward(self, inputs):
81
+ return self.model(inputs)
82
+
83
+ sscd_processor = SscdImageProcessor()
84
+ sscd_processor.save_pretrained('new_model')
85
+ sscd_config = SscdConfig(model_path='sscd_disc_mixup.torchscript.pt')
86
+ sscd_config.save_pretrained('new_model')
87
+
88
+ model = SscdModel.from_pretrained('new_model')
89
+
90
+
91
+
sscd_disc_advanced.torchscript.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b3d50eb1056ca163de7c31c5bba739e21705cdc65c5802628716420286100ff2
3
+ size 98795803
sscd_disc_blur.torchscript.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b1573bb7f3f62051c9a63c6f20b38e0215499dc24ca6d21255783342aa6c43ae
3
+ size 98790335
sscd_disc_mixup.torchscript.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9f26bd4c848cc19b73d2ae92eea6e04886f61a7b764ceb7a13aeee62e6a6db56
3
+ size 98791638