Балаганский Никита Николаевич commited on
Commit
b5251fc
1 Parent(s): 367e85c
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -89,6 +89,7 @@ def main():
89
  alpha = st.slider("Alpha", min_value=-10, max_value=10, step=1, value=0)
90
  entropy_threshold = st.slider("Entropy threshold", min_value=0., max_value=5., step=.1, value=2.)
91
  auth_token = os.environ.get('TOKEN') or True
 
92
  with st.spinner('Running inference...'):
93
  text = inference(
94
  lm_model_name=lm_model_name,
@@ -97,6 +98,7 @@ def main():
97
  alpha=alpha,
98
  target_label_id=target_label_id,
99
  entropy_threshold=entropy_threshold,
 
100
  )
101
  st.subheader("Generated text:")
102
  st.markdown(text)
@@ -110,7 +112,7 @@ def load_generator(lm_model_name: str) -> Generator:
110
  generator = Generator(lm_model_name=lm_model_name, device=device)
111
  return generator
112
 
113
- @st.cache(hash_funcs={tokenizers.Tokenizer: lambda lm_tokenizer: hash(lm_tokenizer.to_str)}, allow_output_mutation=True)
114
  def load_sampler(cls_model_name, lm_tokenizer):
115
  with st.spinner('Loading classifier model...'):
116
  sampler = CAIFSampler(classifier_name=cls_model_name, lm_tokenizer=lm_tokenizer, device=device)
 
89
  alpha = st.slider("Alpha", min_value=-10, max_value=10, step=1, value=0)
90
  entropy_threshold = st.slider("Entropy threshold", min_value=0., max_value=5., step=.1, value=2.)
91
  auth_token = os.environ.get('TOKEN') or True
92
+ fp16 = st.checkbox("FP16", value=True)
93
  with st.spinner('Running inference...'):
94
  text = inference(
95
  lm_model_name=lm_model_name,
 
98
  alpha=alpha,
99
  target_label_id=target_label_id,
100
  entropy_threshold=entropy_threshold,
101
+ fp16=fp16,
102
  )
103
  st.subheader("Generated text:")
104
  st.markdown(text)
 
112
  generator = Generator(lm_model_name=lm_model_name, device=device)
113
  return generator
114
 
115
+ #@st.cache(hash_funcs={tokenizers.Tokenizer: lambda lm_tokenizer: hash(lm_tokenizer.to_str)}, allow_output_mutation=True)
116
  def load_sampler(cls_model_name, lm_tokenizer):
117
  with st.spinner('Loading classifier model...'):
118
  sampler = CAIFSampler(classifier_name=cls_model_name, lm_tokenizer=lm_tokenizer, device=device)