edugp committed on
Commit
c309418
1 Parent(s): 5019883

Refactor test_on_image.py

Browse files
Files changed (1) hide show
  1. test_on_image.py +14 -11
test_on_image.py CHANGED
@@ -8,10 +8,8 @@ from transformers import AutoTokenizer
8
  from modeling_hybrid_clip import FlaxHybridCLIP
9
  from run_hybrid_clip import Transform
10
 
11
- model = FlaxHybridCLIP.from_pretrained("clip_spanish_1_percent")
12
- tokenizer = AutoTokenizer.from_pretrained("dccuchile/bert-base-spanish-wwm-cased")
13
 
14
- def prepare_image(image_path):
15
  image = read_image(image_path, mode=ImageReadMode.RGB)
16
  preprocess = Transform(model.config.vision_config.image_size)
17
  preprocess = torch.jit.script(preprocess)
@@ -19,18 +17,23 @@ def prepare_image(image_path):
19
  pixel_values = torch.stack([preprocessed_image]).permute(0, 2, 3, 1).numpy()
20
  return pixel_values
21
 
22
- def prepare_text(text):
23
  return tokenizer(text, return_tensors="np")
24
 
25
- def run_inference(image_path, text):
26
- pixel_values = prepare_image(image_path)
27
- input_text = prepare_text(text)
28
  model_output = model(input_text["input_ids"], pixel_values, attention_mask=input_text["attention_mask"], token_type_ids=input_text["token_type_ids"], train=False, return_dict=True)
29
  logits = model_output["logits_per_image"]
30
- score = jax.nn.sigmoid(logits)
31
  return score
32
 
33
- image_path = f"/home/{os.environ['USER']}/data/wit_scale_converted/Self_Portrait_by_David_Allan.jpg"
34
- text = "Patio interior de un edificio"
35
 
36
- print(run_inference(image_path, text))
 
 
 
 
 
 
 
 
8
  from modeling_hybrid_clip import FlaxHybridCLIP
9
  from run_hybrid_clip import Transform
10
 
 
 
11
 
12
def prepare_image(image_path, model):
    """Load an image from disk and preprocess it for the hybrid CLIP model.

    Parameters
    ----------
    image_path : str
        Filesystem path to the image.
    model : FlaxHybridCLIP
        Model whose ``config.vision_config.image_size`` defines the target size.

    Returns
    -------
    numpy.ndarray
        Batched pixel values in channels-last (NHWC) layout, batch size 1.
    """
    image = read_image(image_path, mode=ImageReadMode.RGB)
    preprocess = Transform(model.config.vision_config.image_size)
    preprocess = torch.jit.script(preprocess)
    # NOTE(review): this assignment was not visible in the diff hunk but is
    # required — `preprocessed_image` is consumed below; confirm against the
    # full file.
    preprocessed_image = preprocess(image)
    # NCHW -> NHWC: the Flax model expects channels-last input.
    pixel_values = torch.stack([preprocessed_image]).permute(0, 2, 3, 1).numpy()
    return pixel_values
19
 
20
def prepare_text(text, tokenizer):
    """Tokenize *text* with *tokenizer*, returning NumPy-backed tensors."""
    encoded = tokenizer(text, return_tensors="np")
    return encoded
22
 
23
def run_inference(image_path, text, model, tokenizer):
    """Score how well the image at *image_path* matches *text*.

    Runs the hybrid CLIP forward pass and squashes the single
    image-text logit through a sigmoid, returning a scalar score.
    """
    pixels = prepare_image(image_path, model)
    tokens = prepare_text(text, tokenizer)
    outputs = model(
        tokens["input_ids"],
        pixels,
        attention_mask=tokens["attention_mask"],
        token_type_ids=tokens["token_type_ids"],
        train=False,
        return_dict=True,
    )
    # Single image / single text -> take the lone entry of the logit matrix.
    score = jax.nn.sigmoid(outputs["logits_per_image"])[0][0]
    return score
30
 
 
 
31
 
32
if __name__ == "__main__":
    # Demo entry point: Spanish CLIP checkpoint + BETO (Spanish BERT) tokenizer.
    model = FlaxHybridCLIP.from_pretrained("clip_spanish_141230_samples")
    tokenizer = AutoTokenizer.from_pretrained(
        "dccuchile/bert-base-spanish-wwm-cased"
    )

    user = os.environ["USER"]
    image_path = f"/home/{user}/data/wit_scale_converted/Santuar.jpg"
    text = "Fachada del Santuario"

    score = run_inference(image_path, text, model, tokenizer)
    print(score)