Балаганский Никита Николаевич committed on
Commit 7ff7323
1 Parent(s): cb18e78
Files changed (1)
  1. app.py +9 -4
app.py CHANGED
@@ -179,7 +179,8 @@ def main():
     st.plotly_chart(figure, use_container_width=True)
     auth_token = os.environ.get('TOKEN') or True
     fp16 = st.checkbox("FP16", value=True)
-    with st.spinner('Running inference...'):
+
+    def generate():
         text = inference(
             lm_model_name=lm_model_name,
             cls_model_name=cls_model_name,
@@ -190,8 +191,11 @@ def main():
             fp16=fp16,
             act_type=act_type
         )
-    st.subheader("Generated text:")
-    st.write(text)
+        st.subheader("Generated text:")
+        st.write(text)
+    generate()
+    st.button("Generate new", on_click=generate())
+
 
 @st.cache(hash_funcs={tokenizers.Tokenizer: lambda lm_tokenizer: hash(lm_tokenizer.to_str)}, allow_output_mutation=True)
 def load_generator(lm_model_name: str) -> Generator:
@@ -199,7 +203,8 @@ def load_generator(lm_model_name: str) -> Generator:
     generator = Generator(lm_model_name=lm_model_name, device=device)
     return generator
 
-#@st.cache(hash_funcs={tokenizers.Tokenizer: lambda lm_tokenizer: hash(lm_tokenizer.to_str)}, allow_output_mutation=True)
+
+# @st.cache(hash_funcs={tokenizers.Tokenizer: lambda lm_tokenizer: hash(lm_tokenizer.to_str)}, allow_output_mutation=True)
 def load_sampler(cls_model_name, lm_tokenizer):
     with st.spinner('Loading classifier model...'):
         sampler = CAIFSampler(classifier_name=cls_model_name, lm_tokenizer=lm_tokenizer, device=device)
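
A note on the new button wiring: st.button("Generate new", on_click=generate()) invokes generate while the line is evaluated and passes its return value, None, to Streamlit as the callback, so inference effectively runs twice per script run. Since any button click already triggers a full rerun, no callback is needed at all. A minimal sketch of that pattern follows; the inference stub is a hypothetical stand-in for the app's real function, which takes lm_model_name, cls_model_name, fp16, act_type and so on:

import streamlit as st

def inference(**kwargs) -> str:
    # Hypothetical stand-in for the app's real inference() call.
    return "sampled text"

def main():
    def generate():
        text = inference()
        st.subheader("Generated text:")
        st.write(text)

    # A bare button is enough: clicking it reruns the whole script, and the
    # unconditional generate() below then samples fresh text. If a callback
    # is really wanted, pass the function object itself, not a call:
    # st.button("Generate new", on_click=generate)
    st.button("Generate new")
    generate()

main()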
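On the caching side, load_generator keeps the legacy st.cache decorator while the one on load_sampler remains commented out. One detail worth flagging: in hash_funcs, lm_tokenizer.to_str is handed to hash() without parentheses, so it hashes the bound method rather than the tokenizer's serialized form; to key the cache on tokenizer content, the lambda presumably wants to_str(). A sketch of that variant, assuming the function body stays as in app.py (newer Streamlit releases would use st.cache_resource in place of the since-deprecated st.cache):

import streamlit as st
import tokenizers

# Streamlit's default hasher cannot introspect the Rust-backed Tokenizer,
# hence the custom hash function. Calling to_str() serializes the tokenizer
# to JSON, so equal tokenizers produce equal cache keys; the committed
# to_str (no call) keys on the method object instead.
@st.cache(
    hash_funcs={tokenizers.Tokenizer: lambda t: hash(t.to_str())},
    allow_output_mutation=True,
)
def load_sampler(cls_model_name, lm_tokenizer):
    ...  # body as in app.py: build and return a CAIFSampler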