Vivien Chappelier
commited on
Commit
•
71f9973
1
Parent(s):
b747147
add second class in export to make it compatible with inference api
Browse files- calibration.safetensors +1 -1
- export_detector.py +20 -6
calibration.safetensors
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 1999934
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ebfb5a2481c315bca61b2348b0268d621a351d48cf5e07785d45d5877f67ebf3
|
3 |
size 1999934
|
export_detector.py
CHANGED
@@ -10,9 +10,6 @@ from PIL import Image
|
|
10 |
# read logits file
|
11 |
data=(np.asarray([float(x) for x in open(sys.argv[1]).readlines()]))
|
12 |
|
13 |
-
# negate for consistency with "1" = "watermarked"
|
14 |
-
data = -data
|
15 |
-
|
16 |
# sort and convert to safetensors format
|
17 |
data = np.sort(data)
|
18 |
data_min = data.min()
|
@@ -47,12 +44,29 @@ image_processor = BlipImageProcessor(do_resize=True,
|
|
47 |
|
48 |
detector = AutoModelForImageClassification.from_pretrained(detector_path)
|
49 |
|
50 |
-
# make it output "
|
51 |
detector.eval()
|
52 |
with torch.no_grad():
|
53 |
-
detector.classifier[1].weight
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
-
|
56 |
#image_processor.push_to_hub("imatag/stable-signature-bzh-detector-resnet18")
|
57 |
examples = ['examples/not_watermarked.png', 'examples/watermarked.png']
|
58 |
|
|
|
10 |
# read logits file
|
11 |
data=(np.asarray([float(x) for x in open(sys.argv[1]).readlines()]))
|
12 |
|
|
|
|
|
|
|
13 |
# sort and convert to safetensors format
|
14 |
data = np.sort(data)
|
15 |
data_min = data.min()
|
|
|
44 |
|
45 |
detector = AutoModelForImageClassification.from_pretrained(detector_path)
|
46 |
|
47 |
+
# make it output 2 labels, with label "0" for "watermarked"
|
48 |
detector.eval()
|
49 |
with torch.no_grad():
|
50 |
+
w0 = detector.classifier[1].weight
|
51 |
+
b0 = detector.classifier[1].bias
|
52 |
+
fdim = w0.shape[1]
|
53 |
+
w = torch.nn.Parameter(torch.zeros((2, fdim), dtype=w0.dtype))
|
54 |
+
w[0, :] = w0
|
55 |
+
w[1, :] = -w0
|
56 |
+
detector.classifier[1].weight = w
|
57 |
+
b = torch.nn.Parameter(torch.zeros((2,), dtype=b0.dtype))
|
58 |
+
b[0] = b0
|
59 |
+
b[1] = -b0
|
60 |
+
detector.classifier[1].bias = b
|
61 |
+
labels = ["no watermark detected", "watermarked"]
|
62 |
+
label2id, id2label = dict(), dict()
|
63 |
+
for i, label in enumerate(labels):
|
64 |
+
label2id[label] = str(i)
|
65 |
+
id2label[str(i)] = label
|
66 |
+
detector.config.id2label=id2label
|
67 |
+
detector.config.label2id=label2id
|
68 |
|
69 |
+
detector.push_to_hub("imatag/stable-signature-bzh-detector-resnet18")
|
70 |
#image_processor.push_to_hub("imatag/stable-signature-bzh-detector-resnet18")
|
71 |
examples = ['examples/not_watermarked.png', 'examples/watermarked.png']
|
72 |
|