Балаганский Никита Николаевич commited on
Commit
b5beaeb
1 Parent(s): 3fc6ce8

remove russian language

Browse files
Files changed (1) hide show
  1. app.py +8 -15
app.py CHANGED
@@ -123,7 +123,7 @@ def main():
123
  "template": "plotly_white",
124
  })
125
 
126
- language = st.selectbox("Language", ("English", "Russian"))
127
  cls_model_name = st.selectbox(
128
  ATTRIBUTE_MODEL_LABEL[language],
129
  ATTRIBUTE_MODELS[language]
@@ -136,15 +136,7 @@ def main():
136
  cls_model_config = AutoConfig.from_pretrained(cls_model_name)
137
  if cls_model_config.problem_type == "multi_label_classification":
138
  label2id = cls_model_config.label2id
139
- if "rubert-tiny-toxicity" in cls_model_name:
140
- idx = 0
141
- for i, k in enumerate(label2id.keys()):
142
- if k == 'threat':
143
- idx = i
144
-
145
- label_key = st.selectbox(ATTRIBUTE_LABEL[language], label2id.keys(), index=idx)
146
- else:
147
- label_key = st.selectbox(ATTRIBUTE_LABEL[language], label2id.keys())
148
  target_label_id = label2id[label_key]
149
  act_type = "sigmoid"
150
  elif cls_model_config.problem_type == "single_label_classification":
@@ -154,20 +146,17 @@ def main():
154
  act_type = "sigmoid"
155
  else:
156
  label2id = cls_model_config.label2id
 
157
  label_key = st.selectbox(ATTRIBUTE_LABEL[language], label2id.keys())
158
  target_label_id = label2id[label_key]
159
  act_type = "softmax"
160
  st.write(WARNING_TEXT[language])
161
  show_pos_alpha = st.checkbox("Show positive alphas", value=False)
162
- if "sst" in cls_model_name:
163
- prompt = st.text_input(TEXT_PROMPT_LABEL[language], "The movie")
164
- else:
165
- prompt = st.text_input(TEXT_PROMPT_LABEL[language], PROMPT_EXAMPLE[language])
166
  st.latex(r"p(x_i|x_{<i}, c) \propto p(x_i|x_{<i})p(c|x_{\leq i})^{\alpha}")
167
  if act_type == "softmax":
168
  alpha = st.slider("α", min_value=-40, max_value=40 if show_pos_alpha else 0, step=1, value=0)
169
  else:
170
- alpha = st.slider("α", min_value=-10, max_value=10 if show_pos_alpha else 0, step=1, value=0)
171
  entropy_threshold = st.slider("Entropy threshold", min_value=0., max_value=10., step=.1, value=2.)
172
  plot_idx = np.argmin(np.abs(entropy_threshold - x_s))
173
  scatter_tip = go.Scatter({
@@ -190,6 +179,10 @@ def main():
190
  auth_token = os.environ.get('TOKEN') or True
191
  fp16 = st.checkbox("FP16", value=True)
192
  st.session_state["generated_text"] = None
 
 
 
 
193
  st.subheader("Generated text:")
194
 
195
  def generate():
 
123
  "template": "plotly_white",
124
  })
125
 
126
+ language = "English"
127
  cls_model_name = st.selectbox(
128
  ATTRIBUTE_MODEL_LABEL[language],
129
  ATTRIBUTE_MODELS[language]
 
136
  cls_model_config = AutoConfig.from_pretrained(cls_model_name)
137
  if cls_model_config.problem_type == "multi_label_classification":
138
  label2id = cls_model_config.label2id
139
+ label_key = st.selectbox(ATTRIBUTE_LABEL[language], label2id.keys())
 
 
 
 
 
 
 
 
140
  target_label_id = label2id[label_key]
141
  act_type = "sigmoid"
142
  elif cls_model_config.problem_type == "single_label_classification":
 
146
  act_type = "sigmoid"
147
  else:
148
  label2id = cls_model_config.label2id
149
+ filtered_label2id = {k: v for k, v in label2id.items() if "negative" in k}
150
  label_key = st.selectbox(ATTRIBUTE_LABEL[language], label2id.keys())
151
  target_label_id = label2id[label_key]
152
  act_type = "softmax"
153
  st.write(WARNING_TEXT[language])
154
  show_pos_alpha = st.checkbox("Show positive alphas", value=False)
 
 
 
 
155
  st.latex(r"p(x_i|x_{<i}, c) \propto p(x_i|x_{<i})p(c|x_{\leq i})^{\alpha}")
156
  if act_type == "softmax":
157
  alpha = st.slider("α", min_value=-40, max_value=40 if show_pos_alpha else 0, step=1, value=0)
158
  else:
159
+ alpha = st.slider("α", min_value=-5, max_value=5 if show_pos_alpha else 0, step=1, value=0)
160
  entropy_threshold = st.slider("Entropy threshold", min_value=0., max_value=10., step=.1, value=2.)
161
  plot_idx = np.argmin(np.abs(entropy_threshold - x_s))
162
  scatter_tip = go.Scatter({
 
179
  auth_token = os.environ.get('TOKEN') or True
180
  fp16 = st.checkbox("FP16", value=True)
181
  st.session_state["generated_text"] = None
182
+ if "sst" in cls_model_name:
183
+ prompt = st.text_input(TEXT_PROMPT_LABEL[language], "The movie")
184
+ else:
185
+ prompt = st.text_input(TEXT_PROMPT_LABEL[language], PROMPT_EXAMPLE[language])
186
  st.subheader("Generated text:")
187
 
188
  def generate():