import io
import logging

import numpy as np
import streamlit as st
import torch
from mtranslate import translate
from PIL import Image
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize, ToTensor
from torchvision.transforms.functional import InterpolationMode
from transformers import MarianTokenizer

from flax_clip_vision_marian.modeling_clip_vision_marian import FlaxCLIPVisionMarianForConditionalGeneration
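
# Streamlit demo: upload an image, generate an Indonesian caption with a
# CLIP-Marian encoder-decoder model, and show an English translation of the
# caption via mtranslate.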


class CaptionGenerator:
    def __init__(self):
        self.tokenizer = None
        self.model = None  # set in load(); used by generate_step()
        self.marian_model_name = 'Helsinki-NLP/opus-mt-en-id'
        self.clip_marian_model_name = 'flax-community/Image-captioning-Indonesia'
        self.config = None
        self.image_size = None
        self.custom_transforms = None

    def load(self):
        logging.info("Loading tokenizer...")
        self.tokenizer = MarianTokenizer.from_pretrained(self.marian_model_name)
        logging.info("Tokenizer loaded.")

        logging.info("Loading model...")
        self.model = FlaxCLIPVisionMarianForConditionalGeneration.from_pretrained(self.clip_marian_model_name)
        logging.info("Model loaded.")

        self.config = self.model.config
        self.image_size = self.config.clip_vision_config.image_size
        # CLIP preprocessing: resize and center-crop to the model's input size,
        # then normalize with CLIP's image mean and standard deviation.
        self.custom_transforms = torch.nn.Sequential(
            Resize([self.image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(self.image_size),
            ConvertImageDtype(torch.float),
            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        )

    def process_image(self, file):
        logging.info("Loading image...")
        image_data = file.read()
        input_image = Image.open(io.BytesIO(image_data)).convert("RGB")
        image = ToTensor()(input_image)
        image = self.custom_transforms(image)
        # Add a batch dimension and convert NCHW (torchvision) to NHWC,
        # the layout this Flax model expects.
        pixel_values = torch.stack([image]).permute(0, 2, 3, 1).numpy()
        logging.info("Image loaded.")
        return pixel_values
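
    # generate() runs beam search; the returned object's `sequences` field
    # holds the generated token ids, which are decoded back to text.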
    def generate_step(self, pixel_values, max_len, num_beams):
        gen_kwargs = {"max_length": max_len, "num_beams": num_beams}
        logging.info("Generating caption...")
        output_ids = self.model.generate(pixel_values, **gen_kwargs)
        token_ids = np.array(output_ids.sequences)[0]
        caption = self.tokenizer.decode(token_ids)
        logging.info("Caption generated.")
        return caption

    def get_caption(self, file, max_len, num_beams):
        pixel_values = self.process_image(file)
        caption = self.generate_step(pixel_values, max_len, num_beams)
        return caption


@st.cache(allow_output_mutation=True)
def load_caption_generator():
    # Cache the generator so the model is loaded only once across Streamlit reruns.
    generator = CaptionGenerator()
    generator.load()
    return generator


def main():
    st.set_page_config(page_title="Indonesian Image Captioning Demo", page_icon="🖼️")
    generator = load_caption_generator()
    st.title("Indonesian Image Captioning Demo")
    st.markdown(
        """An Indonesian image captioning demo built on [CLIP](https://huggingface.co/transformers/model_doc/clip.html) and [Marian](https://huggingface.co/transformers/model_doc/marian.html), created as part of the [Hugging Face JAX/Flax community week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/).
        """
    )

    st.sidebar.subheader("Configurable parameters")
    max_len = st.sidebar.number_input(
        "Maximum length",
        value=8,
        help="The maximum length of the sequence (caption) to be generated.",
    )
    num_beams = st.sidebar.number_input(
        "Number of beams",
        value=4,
        help="Number of beams for beam search. 1 means no beam search.",
    )

    input_image = st.file_uploader("Upload an image")
    if st.button("Run"):
        with st.spinner(text="Getting results..."):
            if input_image:
                caption = generator.get_caption(file=input_image, max_len=max_len, num_beams=num_beams)
                # Strip the tokenizer's <pad> tokens before displaying and translating.
                caption = caption.replace("<pad>", "").strip()
                st.subheader("Result")
                st.write(caption)
                st.text("English translation")
                st.write(translate(caption, "en", "id"))
            else:
                st.write("Please upload an image.")


if __name__ == '__main__':
    main()