Балаганский Никита Николаевич committed on
Commit 01f8fc1
1 Parent(s): b5beaeb

add num_tokens arg

Files changed (1):
  app.py  +10 -6
app.py CHANGED
@@ -146,9 +146,9 @@ def main():
         act_type = "sigmoid"
     else:
         label2id = cls_model_config.label2id
-        filtered_label2id = {k: v for k, v in label2id.items() if "negative" in k}
-        label_key = st.selectbox(ATTRIBUTE_LABEL[language], label2id.keys())
-        target_label_id = label2id[label_key]
+        filtered_label2id = {k: v for k, v in label2id.items() if "negative" in k.lower()}
+        label_key = st.selectbox(ATTRIBUTE_LABEL[language], filtered_label2id.keys())
+        target_label_id = filtered_label2id[label_key]
         act_type = "softmax"
     st.write(WARNING_TEXT[language])
     show_pos_alpha = st.checkbox("Show positive alphas", value=False)
@@ -183,6 +183,8 @@ def main():
         prompt = st.text_input(TEXT_PROMPT_LABEL[language], "The movie")
     else:
         prompt = st.text_input(TEXT_PROMPT_LABEL[language], PROMPT_EXAMPLE[language])
+    num_tokens = st.slider("# tokens to be generated", min_value=5, max_value=40, step=1, value=20)
+    num_tokens = int(num_tokens)
     st.subheader("Generated text:")
 
     def generate():
@@ -194,7 +196,8 @@ def main():
             target_label_id=target_label_id,
             entropy_threshold=entropy_threshold,
             fp16=fp16,
-            act_type=act_type
+            act_type=act_type,
+            num_tokens=num_tokens
         )
 
     st.button("Generate new", on_click=generate())
@@ -225,7 +228,8 @@ def inference(
     alpha: float = 5,
     target_label_id: int = 0,
     entropy_threshold: float = 0,
-    act_type: str = "sigmoid"
+    act_type: str = "sigmoid",
+    num_tokens=10,
 ) -> str:
     torch.set_grad_enabled(False)
     generator = load_generator(lm_model_name=lm_model_name)
@@ -259,7 +263,7 @@ def inference(
     sequences, tokens = generator.sample_sequences(
         num_samples=1,
         input_prompt=prompt,
-        max_length=20,
+        max_length=num_tokens,
         caif_period=1,
         entropy=entropy_threshold,
         progress_bar=progress_bar,
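
Taken together, the change threads one value from the UI down to the sampler: the slider sets num_tokens, main() forwards it to inference(), and inference() passes it on as max_length to generator.sample_sequences(). Below is a minimal, self-contained sketch of that plumbing; the inference_stub, the label strings, and the echoed output are placeholders, since the real generator stack (load_generator, CAIF sampling, etc.) is not reproduced here.

# Sketch only: a stub stands in for app.py's real inference()/CAIF sampler so the
# num_tokens plumbing added by this commit can be seen end to end.
import streamlit as st


def inference_stub(prompt: str, num_tokens: int = 10) -> str:
    # In app.py, inference() forwards num_tokens as max_length to
    # generator.sample_sequences(...); the stub just echoes the budget.
    return f"{prompt} [{num_tokens} generated tokens would follow]"


def main():
    prompt = st.text_input("Text prompt", "The movie")
    # Same widget and bounds as the commit: 5..40 tokens, default 20.
    num_tokens = int(st.slider("# tokens to be generated",
                               min_value=5, max_value=40, step=1, value=20))
    st.subheader("Generated text:")

    def generate():
        st.write(inference_stub(prompt, num_tokens=num_tokens))

    st.button("Generate new", on_click=generate)


if __name__ == "__main__":
    main()

One small difference from the diffed code: the sketch passes the callable generate to on_click, whereas app.py uses on_click=generate(), which effectively runs generate on every script rerun and registers its return value (None) as the callback; this commit leaves that behavior unchanged.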
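
The first hunk is a separate fix to attribute selection: the "negative" match is now case-insensitive via k.lower(), and the selectbox and target_label_id read from filtered_label2id instead of the full label2id (previously the filtered dict was built but never used). A tiny illustration with a hypothetical mapping, since the real values come from cls_model_config.label2id:

# Hypothetical mapping; actual labels depend on the classifier's config.
label2id = {"NEGATIVE": 0, "POSITIVE": 1}

before = {k: v for k, v in label2id.items() if "negative" in k}          # {} -- misses upper-case labels
after = {k: v for k, v in label2id.items() if "negative" in k.lower()}   # {'NEGATIVE': 0}
print(before, after)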