Балаганский Никита Николаевич commited on
Commit
0405bbc
1 Parent(s): 8a23a0e

add warning for russian language

Browse files
Files changed (1) hide show
  1. app.py +28 -11
app.py CHANGED
@@ -68,6 +68,29 @@ PROMPT_EXAMPLE = {
68
  "Russian": "Привет, сегодня я"
69
  }
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  def main():
73
  st.header("CAIF")
@@ -110,15 +133,7 @@ def main():
110
  print(list(label2id.keys()))
111
  label_key = st.selectbox(ATTRIBUTE_LABEL[language], [list(label2id.keys())[-1]])
112
  target_label_id = 1
113
- st.write("""
114
- **Warning!**
115
-
116
- If you are clicking checkbox bellow positive alphas for CAIF sampling become available.
117
- It means that language model will be forced to generate toxic or/and abusive text.
118
- This space is only a demonstration of our method and we are not responsible for the content of generated text.
119
-
120
- **Please use it carefully!**
121
- """)
122
  show_pos_alpha = st.checkbox("Show positive alphas", value=False)
123
  prompt = st.text_input(TEXT_PROMPT_LABEL[language], PROMPT_EXAMPLE[language])
124
  st.latex(r"p(x_i|x_{<i}, c) \propto p(x_i|x_{<i})p(c|x_{\leq i})^{\alpha}")
@@ -178,7 +193,8 @@ def inference(
178
  fp16: bool = True,
179
  alpha: float = 5,
180
  target_label_id: int = 0,
181
- entropy_threshold: float = 0
 
182
  ) -> str:
183
  torch.set_grad_enabled(False)
184
  generator = load_generator(lm_model_name=lm_model_name)
@@ -198,7 +214,8 @@ def inference(
198
  "temperature": 1.0,
199
  "top_k_classifier": 100,
200
  "classifier_weight": alpha,
201
- "target_cls_id": target_label_id
 
202
  }
203
  generator.set_ordinary_sampler(ordinary_sampler)
204
  if device == "cpu":
 
68
  "Russian": "Привет, сегодня я"
69
  }
70
 
71
+ WARNING_TEXT = {
72
+ "English": """
73
+ **Warning!**
74
+
75
+ If you are clicking checkbox bellow positive alphas for CAIF sampling become available.
76
+ It means that language model will be forced to produce toxic or/and abusive text.
77
+ This space is only a demonstration of our method for controllable text generation
78
+ and we are not responsible for the content produced by this method.
79
+
80
+ **Please use it carefully and with positive intentions!**
81
+ """,
82
+ "Russian": """
83
+ ***Внимание!**
84
+
85
+ После нажатия на чекбокс ниже положительные $\alpha$ станут доступны.
86
+ Это означает, что языковая модель будет генерировать токсичные тексты.
87
+ Это демо служит лишь демонстрацией нашего метода контролируемой генерации.
88
+ Мы не несем ответственности за полученные тексты.
89
+
90
+ **Используйте этот метод осторожно и с положительными намерениями!**
91
+ """
92
+ }
93
+
94
 
95
  def main():
96
  st.header("CAIF")
 
133
  print(list(label2id.keys()))
134
  label_key = st.selectbox(ATTRIBUTE_LABEL[language], [list(label2id.keys())[-1]])
135
  target_label_id = 1
136
+ st.write(WARNING_TEXT[language])
 
 
 
 
 
 
 
 
137
  show_pos_alpha = st.checkbox("Show positive alphas", value=False)
138
  prompt = st.text_input(TEXT_PROMPT_LABEL[language], PROMPT_EXAMPLE[language])
139
  st.latex(r"p(x_i|x_{<i}, c) \propto p(x_i|x_{<i})p(c|x_{\leq i})^{\alpha}")
 
193
  fp16: bool = True,
194
  alpha: float = 5,
195
  target_label_id: int = 0,
196
+ entropy_threshold: float = 0,
197
+ act_type: str = "sigmoid"
198
  ) -> str:
199
  torch.set_grad_enabled(False)
200
  generator = load_generator(lm_model_name=lm_model_name)
 
214
  "temperature": 1.0,
215
  "top_k_classifier": 100,
216
  "classifier_weight": alpha,
217
+ "target_cls_id": target_label_id,
218
+ "act_type": act_type
219
  }
220
  generator.set_ordinary_sampler(ordinary_sampler)
221
  if device == "cpu":