Poster2Plot
An image captioning model that generates movie/TV show plots from posters. The plots it produces are decent but by no means perfect, and we are still working on improving the model.
Live demo on Hugging Face Spaces: https://huggingface.co/spaces/deepklarity/poster2plot
Model Details
The model is a VisionEncoderDecoderModel that pairs a Vision Transformer (ViT) image encoder with a GPT-2 text decoder.
We used the following models:
- Encoder: google/vit-base-patch16-224-in21k
- Decoder: gpt2
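To give a rough sense of what the GPT-2 decoder attends to, the sketch below computes how many encoder tokens ViT-base/16 produces for one poster. It assumes the 224×224 input resolution implied by the checkpoint name and ViT's single prepended [CLS] token. The helper name vit_sequence_length is hypothetical; this is plain arithmetic, not part of the transformers API.

```python
def vit_sequence_length(image_size: int = 224, patch_size: int = 16, num_cls_tokens: int = 1) -> int:
    """Number of tokens a ViT encoder emits for one square image (hypothetical helper)."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size  # 224 // 16 = 14
    num_patches = patches_per_side ** 2          # 14 * 14 = 196 non-overlapping patches
    return num_patches + num_cls_tokens          # plus the prepended [CLS] token


# Each poster becomes a sequence of this many hidden states for GPT-2 to cross-attend to.
print(vit_sequence_length())  # 197
```

Assuming the feature extractor resizes every poster to 224×224, the decoder always sees a fixed-length visual context, regardless of the original poster's size.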
Datasets
Publicly available IMDb datasets were used to train the model.
How to use
In PyTorch
```python
import re

import requests
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoTokenizer, VisionEncoderDecoderModel

# Pattern used to discard all text after two or more consecutive full stops
regex_pattern = "[.]{2,}"


def post_process(text):
    """Strip whitespace and cut the generated plot at the first run of '..' or longer."""
    text = text.strip()
    return re.split(regex_pattern, text)[0]


def predict(image, max_length=64, num_beams=4):
    """Generate a plot description for a single PIL image using beam search."""
    # Resize and normalize the poster into the tensor format the ViT encoder expects
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    with torch.no_grad():
        output_ids = model.generate(
            pixel_values,
            max_length=max_length,
            num_beams=num_beams,
            return_dict_in_generate=True,
        ).sequences

    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return post_process(preds[0])


model_name_or_path = "deepklarity/poster2plot"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the encoder-decoder model and move it to the selected device
model = VisionEncoderDecoderModel.from_pretrained(model_name_or_path)
model.to(device)
print("Loaded model")

# Image preprocessing matching the ViT encoder
feature_extractor = AutoFeatureExtractor.from_pretrained(model.encoder.name_or_path)
print("Loaded feature_extractor")

# GPT-2 has no padding token, so reuse its end-of-sequence token for padding
tokenizer = AutoTokenizer.from_pretrained(model.decoder.name_or_path, use_fast=True)
if model.decoder.name_or_path == "gpt2":
    tokenizer.pad_token = tokenizer.eos_token
print("Loaded tokenizer")

url = "https://upload.wikimedia.org/wikipedia/en/2/26/Moana_Teaser_Poster.jpg"
# Wikimedia asks clients to send a descriptive User-Agent; fail early on HTTP errors
response = requests.get(url, stream=True, headers={"User-Agent": "poster2plot-example"})
response.raise_for_status()
with Image.open(response.raw) as image:
    # The ViT feature extractor expects 3-channel RGB input
    pred = predict(image.convert("RGB"))
print(pred)
```
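The post_process step keeps only the text before the first run of two or more consecutive full stops, so any trailing fragment after an ellipsis is dropped while single full stops are left alone. A standalone copy of that logic, run on made-up sample strings (not real model output), shows the effect:

```python
import re

# Same pattern as in the usage example: two or more consecutive full stops.
regex_pattern = "[.]{2,}"


def post_process(text: str) -> str:
    # Trim surrounding whitespace, then keep everything before the first run of "..".
    text = text.strip()
    return re.split(regex_pattern, text)[0]


samples = [
    "  A girl sails across the ocean to save her island.. and then  ",
    "A hero returns home. A storm follows... The end",
    "No ellipsis here.",
]
for s in samples:
    print(repr(post_process(s)))
# 'A girl sails across the ocean to save her island'
# 'A hero returns home. A storm follows'
# 'No ellipsis here.'
```

When the cut happens at an ellipsis, the kept prefix ends without a full stop; appending one is a possible tweak if complete sentences are preferred.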