Spaces:
Runtime error
Runtime error
import io | |
from PIL import Image | |
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize, ToTensor, Compose | |
from torchvision.transforms.functional import InterpolationMode | |
import torch | |
import numpy as np | |
from transformers import MarianTokenizer | |
from flax_clip_vision_marian.modeling_clip_vision_marian import FlaxCLIPVisionMarianForConditionalGeneration | |
import logging | |
import streamlit as st | |
from mtranslate import translate | |
class CaptionGenerator:
    """Generate Indonesian image captions with a CLIP-vision + Marian model.

    An instance starts out empty. Call :meth:`load` once to fetch the
    tokenizer and the model, then call :meth:`get_caption` for each image.
    """

    def __init__(self):
        self.tokenizer = None
        # ``model`` is the attribute the generation code reads.
        # ``clip_marian_model`` is kept as an alias of the same object so
        # existing code that looks at it keeps working.
        self.model = None
        self.clip_marian_model = None
        self.marian_model_name = 'Helsinki-NLP/opus-mt-en-id'
        self.clip_marian_model_name = 'flax-community/Image-captioning-Indonesia'
        self.config = None
        self.image_size = None
        self.custom_transforms = None

    def load(self):
        """Load the tokenizer and model and build the image preprocessing pipeline.

        Populates ``tokenizer``, ``model`` (and its alias
        ``clip_marian_model``), ``config``, ``image_size`` and
        ``custom_transforms``.
        """
        logging.info("Loading tokenizer...")
        self.tokenizer = MarianTokenizer.from_pretrained(self.marian_model_name)
        logging.info("Tokenizer loaded.")

        logging.info("Loading model...")
        self.model = FlaxCLIPVisionMarianForConditionalGeneration.from_pretrained(
            self.clip_marian_model_name
        )
        self.clip_marian_model = self.model
        logging.info("Model loaded.")

        self.config = self.model.config
        self.image_size = self.config.clip_vision_config.image_size
        # Tensor-based preprocessing: resize, crop to a square of
        # ``image_size``, convert to float and normalise per channel.
        # NOTE(review): the mean/std values are presumably the CLIP
        # preprocessing constants - confirm against the model card.
        self.custom_transforms = torch.nn.Sequential(
            Resize([self.image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(self.image_size),
            ConvertImageDtype(torch.float),
            Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            ),
        )

    def _ensure_loaded(self):
        """Raise ``RuntimeError`` when :meth:`load` has not been called yet."""
        if (
            self.model is None
            or self.tokenizer is None
            or self.custom_transforms is None
        ):
            raise RuntimeError("CaptionGenerator is not loaded; call load() first.")

    def process_image(self, file):
        """Read an image from a binary file-like object and preprocess it.

        Args:
            file: Object exposing ``read()`` that returns the encoded image
                bytes (for example an uploaded file).

        Returns:
            A NumPy array holding a batch of one preprocessed image, with
            the channel axis moved last by the final ``permute``.

        Raises:
            RuntimeError: If :meth:`load` has not been called.
        """
        self._ensure_loaded()
        logging.info("Loading image...")
        # Rewind when possible so a file object that was already consumed
        # is decoded from the start instead of yielding empty bytes.
        try:
            file.seek(0)
        except (AttributeError, OSError, ValueError):
            # Not seekable: fall back to reading from the current position.
            pass
        image_data = file.read()
        input_image = Image.open(io.BytesIO(image_data)).convert("RGB")
        loader = Compose([ToTensor()])
        image = loader(input_image)
        image = self.custom_transforms(image)
        # Add a batch axis, then reorder the axes before handing the data
        # to the Flax model as a NumPy array.
        pixel_values = torch.stack([image]).permute(0, 2, 3, 1).numpy()
        logging.info("Image loaded.")
        return pixel_values

    def generate_step(self, pixel_values, max_len, num_beams):
        """Generate and decode a caption for preprocessed pixel values.

        Args:
            pixel_values: Array produced by :meth:`process_image`.
            max_len: Forwarded to ``generate`` as ``max_length``.
            num_beams: Forwarded to ``generate`` as ``num_beams``.

        Returns:
            The decoded caption string for the first sequence. Special
            tokens are not removed here; callers strip them as needed.

        Raises:
            RuntimeError: If :meth:`load` has not been called.
        """
        self._ensure_loaded()
        gen_kwargs = {"max_length": max_len, "num_beams": num_beams}
        logging.info("Generating caption...")
        output_ids = self.model.generate(pixel_values, **gen_kwargs)
        token_ids = np.array(output_ids.sequences)[0]
        caption = self.tokenizer.decode(token_ids)
        logging.info("Caption generated.")
        return caption

    def get_caption(self, file, max_len, num_beams):
        """Preprocess ``file`` and return the generated caption string.

        Args:
            file: Binary file-like object containing the image.
            max_len: Maximum length passed on to generation.
            num_beams: Number of beams passed on to generation.

        Returns:
            The decoded caption string from :meth:`generate_step`.
        """
        pixel_values = self.process_image(file)
        caption = self.generate_step(pixel_values, max_len, num_beams)
        return caption
def load_caption_generator():
    """Create a ``CaptionGenerator``, load its tokenizer and model, and return it."""
    caption_generator = CaptionGenerator()
    # load() fills in the instance in place, so the same object is returned.
    caption_generator.load()
    return caption_generator
def main():
    """Render the Streamlit captioning page and handle a click on "Run".

    Shows the page header, two sidebar generation parameters and a file
    uploader. When "Run" is pressed with an image uploaded, the generated
    Indonesian caption and its English translation are displayed;
    otherwise the user is asked to upload an image.
    """
    st.set_page_config(page_title="Indonesian Image Captioning Demo", page_icon="🖼️")
    # NOTE(review): this runs on every execution of main(); if Streamlit
    # re-runs the script per interaction the model is reloaded each time -
    # consider caching the generator.
    generator = load_caption_generator()

    st.title("Indonesian Image Captioning Demo")
    st.markdown(
        """Indonesian image captioning demo, trained on [CLIP](https://huggingface.co/transformers/model_doc/clip.html) and [Marian](https://huggingface.co/transformers/model_doc/marian.html). Part of the [Huggingface JAX/Flax event](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/).
        """
    )

    st.sidebar.subheader("Configurable parameters")
    # min_value=1 rejects zero or negative lengths / beam counts before
    # they reach the generation call.
    max_len = st.sidebar.number_input(
        "Maximum length",
        min_value=1,
        value=8,
        help="The maximum length of the sequence (caption) to be generated."
    )
    num_beams = st.sidebar.number_input(
        "Number of beams",
        min_value=1,
        value=4,
        help="Number of beams for beam search. 1 means no beam search."
    )

    input_image = st.file_uploader("Insert image")

    if st.button("Run"):
        with st.spinner(text="Getting results..."):
            if input_image:
                raw_caption = generator.get_caption(
                    file=input_image, max_len=max_len, num_beams=num_beams
                )
                # Drop padding tokens once, up front, so the translator
                # never sees "<pad>" markers it could mangle.
                caption = raw_caption.replace("<pad>", "").strip()
                st.subheader("Result")
                st.write(caption)
                st.text("English translation")
                st.write(translate(caption, "en", "id"))
            else:
                st.write("Please upload an image.")
if __name__ == '__main__': | |
main() | |