import argparse
import glob
import os

import cv2
import numpy
import torch
from PIL import Image

from Model import TRCaptionNet, clip_transform


def demo(opt):
    """Caption every ``*.jpg`` in ``opt.input_dir`` and display each image.

    Loads a TRCaptionNet checkpoint onto ``opt.device``, prints the generated
    Turkish caption for each image, and shows the image (resized to 800 px
    height, aspect ratio preserved) in an OpenCV window. Press any key to
    advance to the next image.

    Args:
        opt: argparse.Namespace with ``model_ckpt``, ``input_dir``, ``device``.
    """
    preprocess = clip_transform(224)
    model = TRCaptionNet({
        "max_length": 35,
        "clip": "ViT-L/14",
        "bert": "dbmdz/bert-base-turkish-cased",
        "proj": True,
        "proj_num_head": 16
    })
    device = torch.device(opt.device)
    model.load_state_dict(torch.load(opt.model_ckpt, map_location=device)["model"], strict=True)
    model = model.to(device)
    model.eval()

    image_paths = glob.glob(os.path.join(opt.input_dir, '*.jpg'))
    for image_path in sorted(image_paths):
        # basename is portable; the original split('/') breaks on Windows paths.
        img_name = os.path.basename(image_path)
        # Context manager closes the file handle; convert("RGB") guards against
        # grayscale/CMYK JPEGs, which would otherwise crash the BGR slice below.
        with Image.open(image_path) as img_file:
            img0 = img_file.convert("RGB")
        batch = preprocess(img0).unsqueeze(0).to(device)
        # Inference only — disable autograd so no graph is built for generate().
        with torch.no_grad():
            caption = model.generate(batch, min_length=11, repetition_penalty=1.6)[0]
        print(f"{img_name} :", caption)
        orj_img = numpy.array(img0)[:, :, ::-1]  # RGB -> BGR for OpenCV display
        h, w, _ = orj_img.shape
        new_h = 800
        new_w = int(new_h * (w / h))  # keep aspect ratio at fixed display height
        orj_img = cv2.resize(orj_img, (new_w, new_h))
        cv2.imshow("image", orj_img)
        cv2.waitKey(0)  # block until a key press before the next image
    cv2.destroyAllWindows()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Turkish-Image-Captioning!')
    parser.add_argument('--model-ckpt', type=str, default='./checkpoints/TRCaptionNet_L14_berturk.pth')
    parser.add_argument('--input-dir', type=str, default='./images/')
    parser.add_argument('--device', type=str, default='cuda:0')
    args = parser.parse_args()
    demo(args)