"""Streamlit front page for the IRRA text-to-image person retrieval demo.

The user enters a free-text description and uploads candidate images. The app
then ranks the images by similarity to the description using the IRRA model,
optionally segmenting each image with DETR first, and shows the top `ranks`
matches together with their similarity scores.
"""
import streamlit as st
from st_pages import Page, show_pages, add_page_title, Section

from lib.utils.model import get_model, get_similarities, get_detr, segment_images
from lib.utils.timer import timer

# NOTE(review): the page paths and link targets in this file appear to have
# been stripped (empty '' page path, truncated 'pages/' path, markdown links
# ending in '(' with no URL). Restore them from version control; they are
# kept unchanged here because the originals are unknown.

add_page_title()

show_pages(
    [
        Page('', 'IRRA Text-To-Image-Retrieval'),
        Section('Implementation Details'),
        Page('pages/', 'Loss functions'),
    ]
)

st.markdown('''
A text-to-image retrieval model implemented from [arXiv: Cross-Modal Implicit Relation Reasoning and Aligning for Text-to-Image Person Retrieval](

The uploaded images should be `384x128` with only one person in the shot.
''')

# ---------------------------------------------------------------- inputs ----
st.header('Inputs')
caption = st.text_input('Description Input')

# With accept_multiple_files=True Streamlit returns a list of uploaded files
# (an empty list when nothing is uploaded), never None. Test for emptiness so
# st.image is not called with an empty list.
images = st.file_uploader('Upload images', accept_multiple_files=True)
if images:
    st.image(images)  # type: ignore

# --------------------------------------------------------------- options ----
st.header('Options')
st.subheader('Ranks', help='How many predictions the model is allowed to make')
ranks = st.slider(
    'slider_ranks',
    min_value=1,
    max_value=10,
    label_visibility='collapsed',
    value=5,
)
do_segment = st.checkbox('Segment images with DETR', value=False)

# Matching needs both a description and at least one uploaded image.
button = st.button('Match most similar', disabled=not images or not caption)

# --------------------------------------------------------------- results ----
if button:
    if do_segment:
        with st.spinner('Segmenting images with DETR'):
            detr, processor = get_detr()
            images = segment_images(detr, processor, images)

    st.header('Results')
    with st.spinner('Loading model'):
        model = get_model()
    n_params = sum(p.numel() for p in model.parameters())
    st.text(f'IRRA model loaded with {n_params / 1e6:.0f}M parameters')

    with st.spinner('Computing and ranking similarities'):
        with timer() as t:
            # presumably one score per uploaded image once the leading
            # (caption) axis is squeezed away — TODO confirm shape
            similarities = get_similarities(caption, images, model).squeeze(0)
        elapsed = t()
        # Highest similarity first. Slicing also copes with fewer images
        # than `ranks`.
        indices = similarities.argsort(descending=True).cpu().tolist()[:ranks]

    rank_col, image_col, score_col = st.columns(3)
    with rank_col:
        st.subheader('Rank')
    with image_col:
        st.subheader('Image')
    with score_col:
        st.subheader(
            'Cosine Similarity',
            help='Due to the nature of the SDM loss, the higher the similarity, the more similar the match is',
        )

    for rank, idx in enumerate(indices, start=1):
        rank_col, image_col, score_col = st.columns(3)
        with rank_col:
            st.text(f'{rank}')
        with image_col:
            st.image(images[idx])
        with score_col:
            # .item() yields a plain Python float (works on CPU or GPU tensors).
            st.text(f'{similarities[idx].item():.2f}')

    st.success(f'Done in {elapsed:.2f}s')

# --------------------------------------------------------------- sidebar ----
with st.sidebar:
    st.title('IRRA Text-To-Image Retrieval')
    st.subheader('Useful Links')
    st.markdown('[arXiv: Cross-Modal Implicit Relation Reasoning and Aligning for Text-to-Image Person Retrieval](')
    st.markdown(
        '[IRRA implementation (Pytorch Lightning + Transformers)](')
    st.markdown(
        '[IRRA implementation (PyTorch)](')