Update README.md
Browse files
README.md
CHANGED
@@ -68,12 +68,12 @@ inputs = processor.preprocess(image, return_tensors="pt")
|
|
68 |
|
69 |
with torch.no_grad():
|
70 |
outputs = model(**inputs.to(model.device, model.dtype))
|
71 |
-
logits = torch.sigmoid(outputs.logits[0])
|
72 |
|
73 |
# get probabilities
|
74 |
results = {model.config.id2label[i]: logit.float() for i, logit in enumerate(logits)}
|
75 |
results = {
|
76 |
-
k: v for k, v in sorted(results.items(), key=lambda item: item[1], reverse=True)
|
77 |
}
|
78 |
print(results) # rating tags and character tags are also included
|
79 |
#{'1girl': tensor(0.9974),
|
|
|
68 |
|
69 |
with torch.no_grad():
|
70 |
outputs = model(**inputs.to(model.device, model.dtype))
|
71 |
+
logits = torch.sigmoid(outputs.logits[0]) # take the first logits
|
72 |
|
73 |
# get probabilities
|
74 |
results = {model.config.id2label[i]: logit.float() for i, logit in enumerate(logits)}
|
75 |
results = {
|
76 |
+
k: v for k, v in sorted(results.items(), key=lambda item: item[1], reverse=True) if v > 0.35 # 35% threshold
|
77 |
}
|
78 |
print(results) # rating tags and character tags are also included
|
79 |
#{'1girl': tensor(0.9974),
|