"""Score how well a caption matches an image using a trained FlaxHybridCLIP checkpoint."""

import os

import jax
import torch
from torchvision.io import ImageReadMode, read_image
from transformers import AutoTokenizer

from modeling_hybrid_clip import FlaxHybridCLIP
from run_hybrid_clip import Transform
|
|
def prepare_image(image_path, model):
    # Load the image as an RGB uint8 tensor in CHW layout.
    image = read_image(image_path, mode=ImageReadMode.RGB)
    # Resize and normalize to the size expected by the vision encoder.
    preprocess = Transform(model.config.vision_config.image_size)
    preprocess = torch.jit.script(preprocess)
    preprocessed_image = preprocess(image)
    # Add a batch dimension and convert NCHW -> NHWC for the Flax model.
    pixel_values = torch.stack([preprocessed_image]).permute(0, 2, 3, 1).numpy()
    return pixel_values

|
|
def prepare_text(text, tokenizer):
    return tokenizer(text, return_tensors="np")

|
|
def run_inference(image_path, text, model, tokenizer):
    pixel_values = prepare_image(image_path, model)
    input_text = prepare_text(text, tokenizer)
    model_output = model(
        input_text["input_ids"],
        pixel_values,
        attention_mask=input_text["attention_mask"],
        train=False,
        return_dict=True,
    )
    # Take the image-text similarity logit and map it to a (0, 1) score.
    logits = model_output["logits_per_image"]
    score = jax.nn.sigmoid(logits)[0][0]
    return score

|
|
if __name__ == "__main__":
    model = FlaxHybridCLIP.from_pretrained("./")
    tokenizer = AutoTokenizer.from_pretrained("bertin-project/bertin-roberta-base-spanish")

    image_path = f"/home/{os.environ['USER']}/data/wit_scale_converted/Santuar.jpg"
    text = "Fachada del Santuario"

    print(run_inference(image_path, text, model, tokenizer))