Балаганский Никита Николаевич
commited on
Commit
•
0405bbc
1
Parent(s):
8a23a0e
add warning for russian language
Browse files
app.py
CHANGED
@@ -68,6 +68,29 @@ PROMPT_EXAMPLE = {
|
|
68 |
"Russian": "Привет, сегодня я"
|
69 |
}
|
70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
|
72 |
def main():
|
73 |
st.header("CAIF")
|
@@ -110,15 +133,7 @@ def main():
|
|
110 |
print(list(label2id.keys()))
|
111 |
label_key = st.selectbox(ATTRIBUTE_LABEL[language], [list(label2id.keys())[-1]])
|
112 |
target_label_id = 1
|
113 |
-
st.write(
|
114 |
-
**Warning!**
|
115 |
-
|
116 |
-
If you are clicking checkbox bellow positive alphas for CAIF sampling become available.
|
117 |
-
It means that language model will be forced to generate toxic or/and abusive text.
|
118 |
-
This space is only a demonstration of our method and we are not responsible for the content of generated text.
|
119 |
-
|
120 |
-
**Please use it carefully!**
|
121 |
-
""")
|
122 |
show_pos_alpha = st.checkbox("Show positive alphas", value=False)
|
123 |
prompt = st.text_input(TEXT_PROMPT_LABEL[language], PROMPT_EXAMPLE[language])
|
124 |
st.latex(r"p(x_i|x_{<i}, c) \propto p(x_i|x_{<i})p(c|x_{\leq i})^{\alpha}")
|
@@ -178,7 +193,8 @@ def inference(
|
|
178 |
fp16: bool = True,
|
179 |
alpha: float = 5,
|
180 |
target_label_id: int = 0,
|
181 |
-
entropy_threshold: float = 0
|
|
|
182 |
) -> str:
|
183 |
torch.set_grad_enabled(False)
|
184 |
generator = load_generator(lm_model_name=lm_model_name)
|
@@ -198,7 +214,8 @@ def inference(
|
|
198 |
"temperature": 1.0,
|
199 |
"top_k_classifier": 100,
|
200 |
"classifier_weight": alpha,
|
201 |
-
"target_cls_id": target_label_id
|
|
|
202 |
}
|
203 |
generator.set_ordinary_sampler(ordinary_sampler)
|
204 |
if device == "cpu":
|
|
|
68 |
"Russian": "Привет, сегодня я"
|
69 |
}
|
70 |
|
71 |
+
WARNING_TEXT = {
|
72 |
+
"English": """
|
73 |
+
**Warning!**
|
74 |
+
|
75 |
+
If you are clicking checkbox bellow positive alphas for CAIF sampling become available.
|
76 |
+
It means that language model will be forced to produce toxic or/and abusive text.
|
77 |
+
This space is only a demonstration of our method for controllable text generation
|
78 |
+
and we are not responsible for the content produced by this method.
|
79 |
+
|
80 |
+
**Please use it carefully and with positive intentions!**
|
81 |
+
""",
|
82 |
+
"Russian": """
|
83 |
+
***Внимание!**
|
84 |
+
|
85 |
+
После нажатия на чекбокс ниже положительные $\alpha$ станут доступны.
|
86 |
+
Это означает, что языковая модель будет генерировать токсичные тексты.
|
87 |
+
Это демо служит лишь демонстрацией нашего метода контролируемой генерации.
|
88 |
+
Мы не несем ответственности за полученные тексты.
|
89 |
+
|
90 |
+
**Используйте этот метод осторожно и с положительными намерениями!**
|
91 |
+
"""
|
92 |
+
}
|
93 |
+
|
94 |
|
95 |
def main():
|
96 |
st.header("CAIF")
|
|
|
133 |
print(list(label2id.keys()))
|
134 |
label_key = st.selectbox(ATTRIBUTE_LABEL[language], [list(label2id.keys())[-1]])
|
135 |
target_label_id = 1
|
136 |
+
st.write(WARNING_TEXT[language])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
show_pos_alpha = st.checkbox("Show positive alphas", value=False)
|
138 |
prompt = st.text_input(TEXT_PROMPT_LABEL[language], PROMPT_EXAMPLE[language])
|
139 |
st.latex(r"p(x_i|x_{<i}, c) \propto p(x_i|x_{<i})p(c|x_{\leq i})^{\alpha}")
|
|
|
193 |
fp16: bool = True,
|
194 |
alpha: float = 5,
|
195 |
target_label_id: int = 0,
|
196 |
+
entropy_threshold: float = 0,
|
197 |
+
act_type: str = "sigmoid"
|
198 |
) -> str:
|
199 |
torch.set_grad_enabled(False)
|
200 |
generator = load_generator(lm_model_name=lm_model_name)
|
|
|
214 |
"temperature": 1.0,
|
215 |
"top_k_classifier": 100,
|
216 |
"classifier_weight": alpha,
|
217 |
+
"target_cls_id": target_label_id,
|
218 |
+
"act_type": act_type
|
219 |
}
|
220 |
generator.set_ordinary_sampler(ordinary_sampler)
|
221 |
if device == "cpu":
|