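"""Minimal inference example for a Spanish FlaxHybridCLIP checkpoint.

Loads the hybrid CLIP model and a Spanish BERT tokenizer, preprocesses a
single image with the training-time Transform, and prints a sigmoid-scaled
image-text similarity score for one image/caption pair.
"""
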
import os

import jax
import torch
from torchvision.io import ImageReadMode, read_image
from transformers import AutoTokenizer

from modeling_hybrid_clip import FlaxHybridCLIP
from run_hybrid_clip import Transform

# Load the hybrid CLIP checkpoint and the matching Spanish BERT tokenizer.
model = FlaxHybridCLIP.from_pretrained("clip_spanish_1_percent")
tokenizer = AutoTokenizer.from_pretrained("dccuchile/bert-base-spanish-wwm-cased")


def prepare_image(image_path):
    # Decode the image as RGB and apply the same torchvision Transform used
    # during training (scripted with TorchScript).
    image = read_image(image_path, mode=ImageReadMode.RGB)
    preprocess = Transform(model.config.vision_config.image_size)
    preprocess = torch.jit.script(preprocess)
    preprocessed_image = preprocess(image)
    # Add a batch dimension, convert NCHW -> NHWC (the layout expected by the
    # Flax vision encoder), and move to a NumPy array.
    pixel_values = torch.stack([preprocessed_image]).permute(0, 2, 3, 1).numpy()
    return pixel_values


def prepare_text(text):
    # Tokenize the caption and return NumPy arrays for the Flax model.
    return tokenizer(text, return_tensors="np")


def run_inference(image_path, text):
    pixel_values = prepare_image(image_path)
    input_text = prepare_text(text)
    model_output = model(
        input_text["input_ids"],
        pixel_values,
        attention_mask=input_text["attention_mask"],
        token_type_ids=input_text["token_type_ids"],
        train=False,
        return_dict=True,
    )
    # Squash the image-text similarity logits to (0, 1) with a sigmoid.
    logits = model_output["logits_per_image"]
    score = jax.nn.sigmoid(logits)
    return score


image_path = f"/home/{os.environ['USER']}/data/wit_scale_converted/Self_Portrait_by_David_Allan.jpg"
text = "Patio interior de un edificio"  # "Interior courtyard of a building"
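
# The printed score is a (1, 1) array: the single image scored against the single caption.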
print(run_inference(image_path, text))
|