# TRCaptionNet / demo.py
import argparse
import glob
import os

import cv2
import numpy
import torch
from PIL import Image

from Model import TRCaptionNet, clip_transform
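

# demo(): load a TRCaptionNet checkpoint, generate a Turkish caption for every
# .jpg image in --input-dir, print it, and show the image in an OpenCV window.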
def demo(opt):
    preprocess = clip_transform(224)

    # TRCaptionNet configuration: CLIP ViT-L/14 image encoder with a Turkish
    # BERT (BERTurk) text model, matching the released L14_berturk checkpoint.
    model = TRCaptionNet({
        "max_length": 35,
        "clip": "ViT-L/14",
        "bert": "dbmdz/bert-base-turkish-cased",
        "proj": True,
        "proj_num_head": 16
    })

    # Load the checkpoint weights and switch the model to inference mode.
    device = torch.device(opt.device)
    model.load_state_dict(torch.load(opt.model_ckpt, map_location=device)["model"], strict=True)
    model = model.to(device)
    model.eval()
    # Caption every .jpg image in the input directory, one at a time.
    image_paths = glob.glob(os.path.join(opt.input_dir, '*.jpg'))
    for image_path in sorted(image_paths):
        img_name = os.path.basename(image_path)
        img0 = Image.open(image_path).convert('RGB')

        batch = preprocess(img0).unsqueeze(0).to(device)
        caption = model.generate(batch, min_length=11, repetition_penalty=1.6)[0]
        print(f"{img_name} :", caption)
        # Convert the PIL image (RGB) to BGR for OpenCV, resize to a fixed
        # 800 px height while keeping the aspect ratio, and display it.
        orj_img = numpy.array(img0)[:, :, ::-1]
        h, w, _ = orj_img.shape
        new_h = 800
        new_w = int(new_h * (w / h))
        orj_img = cv2.resize(orj_img, (new_w, new_h))

        cv2.imshow("image", orj_img)
        cv2.waitKey(0)
    return


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Turkish-Image-Captioning!')
    parser.add_argument('--model-ckpt', type=str, default='./checkpoints/TRCaptionNet_L14_berturk.pth')
    parser.add_argument('--input-dir', type=str, default='./images/')
    parser.add_argument('--device', type=str, default='cuda:0')
    args = parser.parse_args()

    demo(args)
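
# Example invocation (a sketch; it assumes the default checkpoint and image
# directory exist; pass --device cpu if no CUDA GPU is available):
#   python demo.py --model-ckpt ./checkpoints/TRCaptionNet_L14_berturk.pth --input-dir ./images/ --device cuda:0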