Балаганский Никита Николаевич
commited on
Commit
•
b5251fc
1
Parent(s):
367e85c
fix
Browse files
app.py
CHANGED
@@ -89,6 +89,7 @@ def main():
|
|
89 |
alpha = st.slider("Alpha", min_value=-10, max_value=10, step=1, value=0)
|
90 |
entropy_threshold = st.slider("Entropy threshold", min_value=0., max_value=5., step=.1, value=2.)
|
91 |
auth_token = os.environ.get('TOKEN') or True
|
|
|
92 |
with st.spinner('Running inference...'):
|
93 |
text = inference(
|
94 |
lm_model_name=lm_model_name,
|
@@ -97,6 +98,7 @@ def main():
|
|
97 |
alpha=alpha,
|
98 |
target_label_id=target_label_id,
|
99 |
entropy_threshold=entropy_threshold,
|
|
|
100 |
)
|
101 |
st.subheader("Generated text:")
|
102 |
st.markdown(text)
|
@@ -110,7 +112,7 @@ def load_generator(lm_model_name: str) -> Generator:
|
|
110 |
generator = Generator(lm_model_name=lm_model_name, device=device)
|
111 |
return generator
|
112 |
|
113 |
-
|
114 |
def load_sampler(cls_model_name, lm_tokenizer):
|
115 |
with st.spinner('Loading classifier model...'):
|
116 |
sampler = CAIFSampler(classifier_name=cls_model_name, lm_tokenizer=lm_tokenizer, device=device)
|
|
|
89 |
alpha = st.slider("Alpha", min_value=-10, max_value=10, step=1, value=0)
|
90 |
entropy_threshold = st.slider("Entropy threshold", min_value=0., max_value=5., step=.1, value=2.)
|
91 |
auth_token = os.environ.get('TOKEN') or True
|
92 |
+
fp16 = st.checkbox("FP16", value=True)
|
93 |
with st.spinner('Running inference...'):
|
94 |
text = inference(
|
95 |
lm_model_name=lm_model_name,
|
|
|
98 |
alpha=alpha,
|
99 |
target_label_id=target_label_id,
|
100 |
entropy_threshold=entropy_threshold,
|
101 |
+
fp16=fp16,
|
102 |
)
|
103 |
st.subheader("Generated text:")
|
104 |
st.markdown(text)
|
|
|
112 |
generator = Generator(lm_model_name=lm_model_name, device=device)
|
113 |
return generator
|
114 |
|
115 |
+
#@st.cache(hash_funcs={tokenizers.Tokenizer: lambda lm_tokenizer: hash(lm_tokenizer.to_str)}, allow_output_mutation=True)
|
116 |
def load_sampler(cls_model_name, lm_tokenizer):
|
117 |
with st.spinner('Loading classifier model...'):
|
118 |
sampler = CAIFSampler(classifier_name=cls_model_name, lm_tokenizer=lm_tokenizer, device=device)
|