Spaces:
Runtime error
Runtime error
gchhablani
commited on
Commit
•
bea24f7
1
Parent(s):
185a893
Allow clearing of cache
Browse files
app.py
CHANGED
@@ -1,10 +1,9 @@
|
|
1 |
from io import BytesIO
|
2 |
import streamlit as st
|
3 |
import pandas as pd
|
4 |
-
import json
|
5 |
import os
|
6 |
import numpy as np
|
7 |
-
from streamlit
|
8 |
from PIL import Image
|
9 |
from model.flax_clip_vision_marian.modeling_clip_vision_marian import (
|
10 |
FlaxCLIPVisionMarianMT,
|
@@ -31,7 +30,7 @@ tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-es")
|
|
31 |
|
32 |
@st.cache(persist=True)
|
33 |
def generate_sequence(pixel_values, num_beams, temperature, top_p):
|
34 |
-
output_ids = model.generate(input_ids=pixel_values, max_length=64, num_beams=num_beams, temperature=temperature, top_p = top_p)
|
35 |
print(output_ids)
|
36 |
output_sequence = tokenizer.batch_decode(output_ids[0], skip_special_tokens=True, max_length=64)
|
37 |
return output_sequence
|
@@ -60,7 +59,8 @@ st.sidebar.title("Generation Parameters")
|
|
60 |
num_beams = st.sidebar.number_input("Number of Beams", min_value=2, max_value=10, value=4, step=1, help="Number of beams to be used in beam search.")
|
61 |
temperature = st.sidebar.select_slider("Temperature", options = list(np.arange(0.0,1.1, step=0.1)), value=1.0, help ="The value used to module the next token probabilities.", format_func=lambda x: f"{x:.2f}")
|
62 |
top_p = st.sidebar.select_slider("Top-P", options = list(np.arange(0.0,1.1, step=0.1)),value=1.0, help="Nucleus Sampling : If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are kept for generation.", format_func=lambda x: f"{x:.2f}")
|
63 |
-
|
|
|
64 |
|
65 |
image_col, intro_col = st.beta_columns([3, 8])
|
66 |
image_col.image("./misc/sic-logo.png", use_column_width="always")
|
@@ -84,6 +84,10 @@ with st.beta_expander("Article"):
|
|
84 |
st.write(read_markdown("acknowledgements.md"))
|
85 |
|
86 |
|
|
|
|
|
|
|
|
|
87 |
first_index = 20
|
88 |
# Init Session State
|
89 |
if state.image_file is None:
|
@@ -124,8 +128,7 @@ new_col2.markdown(
|
|
124 |
f"""**English Translation**: {translate(state.caption, 'en')}"""
|
125 |
)
|
126 |
|
127 |
-
|
128 |
-
model = load_model(checkpoints[0])
|
129 |
sequence = ['']
|
130 |
if new_col2.button("Generate Caption", help="Generate a caption in the specified language."):
|
131 |
with st.spinner("Generating Sequence..."):
|
|
|
1 |
from io import BytesIO
|
2 |
import streamlit as st
|
3 |
import pandas as pd
|
|
|
4 |
import os
|
5 |
import numpy as np
|
6 |
+
from streamlit import caching
|
7 |
from PIL import Image
|
8 |
from model.flax_clip_vision_marian.modeling_clip_vision_marian import (
|
9 |
FlaxCLIPVisionMarianMT,
|
|
|
30 |
|
31 |
@st.cache(persist=True)
|
32 |
def generate_sequence(pixel_values, num_beams, temperature, top_p):
|
33 |
+
output_ids = state.model.generate(input_ids=pixel_values, max_length=64, num_beams=num_beams, temperature=temperature, top_p = top_p)
|
34 |
print(output_ids)
|
35 |
output_sequence = tokenizer.batch_decode(output_ids[0], skip_special_tokens=True, max_length=64)
|
36 |
return output_sequence
|
|
|
59 |
num_beams = st.sidebar.number_input("Number of Beams", min_value=2, max_value=10, value=4, step=1, help="Number of beams to be used in beam search.")
|
60 |
temperature = st.sidebar.select_slider("Temperature", options = list(np.arange(0.0,1.1, step=0.1)), value=1.0, help ="The value used to module the next token probabilities.", format_func=lambda x: f"{x:.2f}")
|
61 |
top_p = st.sidebar.select_slider("Top-P", options = list(np.arange(0.0,1.1, step=0.1)),value=1.0, help="Nucleus Sampling : If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are kept for generation.", format_func=lambda x: f"{x:.2f}")
|
62 |
+
if st.sidebar.button("Clear All Cache"):
|
63 |
+
caching.clear_cache()
|
64 |
|
65 |
image_col, intro_col = st.beta_columns([3, 8])
|
66 |
image_col.image("./misc/sic-logo.png", use_column_width="always")
|
|
|
84 |
st.write(read_markdown("acknowledgements.md"))
|
85 |
|
86 |
|
87 |
+
if state.model is None:
|
88 |
+
with st.spinner("Loading model..."):
|
89 |
+
state.model = load_model(checkpoints[0])
|
90 |
+
|
91 |
first_index = 20
|
92 |
# Init Session State
|
93 |
if state.image_file is None:
|
|
|
128 |
f"""**English Translation**: {translate(state.caption, 'en')}"""
|
129 |
)
|
130 |
|
131 |
+
|
|
|
132 |
sequence = ['']
|
133 |
if new_col2.button("Generate Caption", help="Generate a caption in the specified language."):
|
134 |
with st.spinner("Generating Sequence..."):
|