Refactor test_on_image.py
test_on_image.py CHANGED (+14 -11)
@@ -8,10 +8,8 @@ from transformers import AutoTokenizer
 from modeling_hybrid_clip import FlaxHybridCLIP
 from run_hybrid_clip import Transform
 
-model = FlaxHybridCLIP.from_pretrained("clip_spanish_1_percent")
-tokenizer = AutoTokenizer.from_pretrained("dccuchile/bert-base-spanish-wwm-cased")
 
-def prepare_image(image_path):
+def prepare_image(image_path, model):
     image = read_image(image_path, mode=ImageReadMode.RGB)
     preprocess = Transform(model.config.vision_config.image_size)
     preprocess = torch.jit.script(preprocess)
@@ -19,18 +17,23 @@ def prepare_image(image_path):
     pixel_values = torch.stack([preprocessed_image]).permute(0, 2, 3, 1).numpy()
     return pixel_values
 
-def prepare_text(text):
+def prepare_text(text, tokenizer):
     return tokenizer(text, return_tensors="np")
 
-def run_inference(image_path, text):
-    pixel_values = prepare_image(image_path)
-    input_text = prepare_text(text)
+def run_inference(image_path, text, model, tokenizer):
+    pixel_values = prepare_image(image_path, model)
+    input_text = prepare_text(text, tokenizer)
     model_output = model(input_text["input_ids"], pixel_values, attention_mask=input_text["attention_mask"], token_type_ids=input_text["token_type_ids"], train=False, return_dict=True)
     logits = model_output["logits_per_image"]
-    score = jax.nn.sigmoid(logits)
+    score = jax.nn.sigmoid(logits)[0][0]
     return score
 
-image_path = f"/home/{os.environ['USER']}/data/wit_scale_converted/Self_Portrait_by_David_Allan.jpg"
-text = "Patio interior de un edificio"
 
-
+if __name__ == "__main__":
+    model = FlaxHybridCLIP.from_pretrained("clip_spanish_141230_samples")
+    tokenizer = AutoTokenizer.from_pretrained("dccuchile/bert-base-spanish-wwm-cased")
+
+    image_path = f"/home/{os.environ['USER']}/data/wit_scale_converted/Santuar.jpg"
+    text = "Fachada del Santuario"
+
+    print(run_inference(image_path, text, model, tokenizer))
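With the module-level loading moved under the if __name__ == "__main__" guard, the helpers can be imported from another script without loading the checkpoint at import time. A minimal usage sketch under that assumption; the image path below is a placeholder, while the checkpoint and tokenizer names are the ones used in the commit:

# Hypothetical caller reusing the refactored helpers.
# "example.jpg" is a placeholder path, not a file from this repo.
from transformers import AutoTokenizer

from modeling_hybrid_clip import FlaxHybridCLIP
from test_on_image import run_inference

model = FlaxHybridCLIP.from_pretrained("clip_spanish_141230_samples")
tokenizer = AutoTokenizer.from_pretrained("dccuchile/bert-base-spanish-wwm-cased")

# logits_per_image has shape (num_images, num_texts), here (1, 1), so the
# [0][0] indexing inside run_inference yields a single sigmoid score in (0, 1).
score = run_inference("example.jpg", "Fachada del Santuario", model, tokenizer)
print(float(score))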