Gurgen-Blbulyan
adding files for app
08cc25a
raw
history blame
1.15 kB
import torch
from transformers import AutoTokenizer, VisionEncoderDecoderModel
import utils
class Inference:
    """Generate a text summary for a video with a pretrained VisionEncoderDecoderModel.

    The tokenizer is loaded from ``decoder_model_name`` and the encoder-decoder
    weights from ``encoder_decoder_model_name`` (by default the
    ``armgabrielyan/video-summarization`` checkpoint). The model is moved to
    CUDA when it is available, otherwise to CPU.
    """

    DEFAULT_ENCODER_DECODER_MODEL = 'armgabrielyan/video-summarization'

    def __init__(self, decoder_model_name, max_length=32,
                 encoder_decoder_model_name=DEFAULT_ENCODER_DECODER_MODEL):
        """Load the tokenizer and model and prepare them for generation.

        Args:
            decoder_model_name: Hugging Face name or path the tokenizer is
                loaded from. Presumably the decoder the checkpoint was trained
                with; confirm.
            max_length: value passed as ``max_length`` to ``generate`` on every
                ``generate_text`` call.
            encoder_decoder_model_name: Hugging Face name or path of the
                VisionEncoderDecoderModel checkpoint. Defaults to the
                previously hard-coded ``armgabrielyan/video-summarization``.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = AutoTokenizer.from_pretrained(decoder_model_name)
        self.encoder_decoder_model = VisionEncoderDecoderModel.from_pretrained(
            encoder_decoder_model_name)
        self.encoder_decoder_model.to(self.device)
        self.max_length = max_length
        # One-time setup: done here rather than inside generate_text() so that
        # running inference never mutates the tokenizer/model as a side effect.
        self._ensure_pad_token()

    def _ensure_pad_token(self):
        """Add a '[PAD]' pad token if the tokenizer has none, growing the decoder embeddings."""
        if not self.tokenizer.pad_token:
            self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            # Resize the decoder's embeddings to the tokenizer's new length so
            # the newly added token id has an embedding row.
            self.encoder_decoder_model.decoder.resize_token_embeddings(len(self.tokenizer))

    def generate_text(self, video, encoder_model_name):
        """Return the generated text for a single video.

        Args:
            video: either a path (``str``) to a video file, or pixel values for
                one video (a tensor or any array-like accepted by
                ``torch.as_tensor``). A batch dimension is added here, so the
                input is assumed to be unbatched. TODO: confirm the expected
                frame/channel layout against ``utils.video2image_from_path``.
            encoder_model_name: forwarded to ``utils.video2image_from_path``
                when ``video`` is a path; presumably selects the image
                preprocessing to apply. It is unused otherwise.

        Returns:
            The first decoded sequence, with special tokens skipped.
        """
        if isinstance(video, str):
            pixel_values = utils.video2image_from_path(video, encoder_model_name)
        else:
            pixel_values = video
        # No-op for tensors; converts lists / numpy arrays so .unsqueeze works.
        pixel_values = torch.as_tensor(pixel_values)
        # Single-video batch for generate().
        batch = pixel_values.unsqueeze(0).to(self.device)
        generated_ids = self.encoder_decoder_model.generate(batch, max_length=self.max_length)
        return self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]