Балаганский Никита Николаевич
commited on
Commit
•
7ff7323
1
Parent(s):
cb18e78
fixes
Browse files
app.py
CHANGED
@@ -179,7 +179,8 @@ def main():
|
|
179 |
st.plotly_chart(figure, use_container_width=True)
|
180 |
auth_token = os.environ.get('TOKEN') or True
|
181 |
fp16 = st.checkbox("FP16", value=True)
|
182 |
-
|
|
|
183 |
text = inference(
|
184 |
lm_model_name=lm_model_name,
|
185 |
cls_model_name=cls_model_name,
|
@@ -190,8 +191,11 @@ def main():
|
|
190 |
fp16=fp16,
|
191 |
act_type=act_type
|
192 |
)
|
193 |
-
|
194 |
-
|
|
|
|
|
|
|
195 |
|
196 |
@st.cache(hash_funcs={tokenizers.Tokenizer: lambda lm_tokenizer: hash(lm_tokenizer.to_str)}, allow_output_mutation=True)
|
197 |
def load_generator(lm_model_name: str) -> Generator:
|
@@ -199,7 +203,8 @@ def load_generator(lm_model_name: str) -> Generator:
|
|
199 |
generator = Generator(lm_model_name=lm_model_name, device=device)
|
200 |
return generator
|
201 |
|
202 |
-
|
|
|
203 |
def load_sampler(cls_model_name, lm_tokenizer):
|
204 |
with st.spinner('Loading classifier model...'):
|
205 |
sampler = CAIFSampler(classifier_name=cls_model_name, lm_tokenizer=lm_tokenizer, device=device)
|
|
|
179 |
st.plotly_chart(figure, use_container_width=True)
|
180 |
auth_token = os.environ.get('TOKEN') or True
|
181 |
fp16 = st.checkbox("FP16", value=True)
|
182 |
+
|
183 |
+
def generate():
|
184 |
text = inference(
|
185 |
lm_model_name=lm_model_name,
|
186 |
cls_model_name=cls_model_name,
|
|
|
191 |
fp16=fp16,
|
192 |
act_type=act_type
|
193 |
)
|
194 |
+
st.subheader("Generated text:")
|
195 |
+
st.write(text)
|
196 |
+
generate()
|
197 |
+
st.button("Generate new", on_click=generate())
|
198 |
+
|
199 |
|
200 |
@st.cache(hash_funcs={tokenizers.Tokenizer: lambda lm_tokenizer: hash(lm_tokenizer.to_str)}, allow_output_mutation=True)
|
201 |
def load_generator(lm_model_name: str) -> Generator:
|
|
|
203 |
generator = Generator(lm_model_name=lm_model_name, device=device)
|
204 |
return generator
|
205 |
|
206 |
+
|
207 |
+
# @st.cache(hash_funcs={tokenizers.Tokenizer: lambda lm_tokenizer: hash(lm_tokenizer.to_str)}, allow_output_mutation=True)
|
208 |
def load_sampler(cls_model_name, lm_tokenizer):
|
209 |
with st.spinner('Loading classifier model...'):
|
210 |
sampler = CAIFSampler(classifier_name=cls_model_name, lm_tokenizer=lm_tokenizer, device=device)
|