Балаганский Никита Николаевич
committed on
Commit
•
c3ce809
1
Parent(s):
53cf441
add languages
Browse files
app.py
CHANGED
@@ -15,35 +15,76 @@ from generator import Generator
|
|
15 |
|
16 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
def main():
|
20 |
st.header("CAIF")
|
|
|
21 |
cls_model_name = st.selectbox(
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
'tinkoff-ai/response-quality-classifier-base',
|
26 |
-
'tinkoff-ai/response-quality-classifier-large',
|
27 |
-
"SkolkovoInstitute/roberta_toxicity_classifier"
|
28 |
-
)
|
29 |
)
|
30 |
lm_model_name = st.selectbox(
|
31 |
-
|
32 |
-
|
33 |
)
|
34 |
cls_model_config = AutoConfig.from_pretrained(cls_model_name)
|
35 |
if cls_model_config.problem_type == "multi_label_classification":
|
36 |
label2id = cls_model_config.label2id
|
37 |
-
label_key = st.selectbox(
|
38 |
target_label_id = label2id[label_key]
|
39 |
else:
|
40 |
label2id = cls_model_config.label2id
|
41 |
print(list(label2id.keys()))
|
42 |
-
label_key = st.selectbox(
|
43 |
-
target_label_id =
|
44 |
-
prompt = st.text_input(
|
45 |
-
alpha = st.slider("Alpha
|
46 |
-
entropy_threshold = st.slider("Entropy
|
47 |
auth_token = os.environ.get('TOKEN') or True
|
48 |
with st.spinner('Running inference...'):
|
49 |
text = inference(
|
|
|
15 |
|
16 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
17 |
|
18 |
+
ATTRIBUTE_MODELS = {
|
19 |
+
"Russian": (
|
20 |
+
"cointegrated/rubert-tiny-toxicity",
|
21 |
+
'tinkoff-ai/response-quality-classifier-tiny',
|
22 |
+
'tinkoff-ai/response-quality-classifier-base',
|
23 |
+
'tinkoff-ai/response-quality-classifier-large',
|
24 |
+
"SkolkovoInstitute/roberta_toxicity_classifier",
|
25 |
+
"SkolkovoInstitute/russian_toxicity_classifier"
|
26 |
+
),
|
27 |
+
"English": (
|
28 |
+
"unitary/toxic-bert",
|
29 |
+
)
|
30 |
+
}
|
31 |
+
|
32 |
+
LANGUAGE_MODELS = {
|
33 |
+
"Russian": ('sberbank-ai/rugpt3small_based_on_gpt2',),
|
34 |
+
"English": ("distilgpt2",)
|
35 |
+
}
|
36 |
+
|
37 |
+
ATTRIBUTE_MODEL_LABEL = {
|
38 |
+
"Russian": 'Выберите модель классификации',
|
39 |
+
"English": "Choose attribute model"
|
40 |
+
}
|
41 |
+
|
42 |
+
LM_LABEL = {
|
43 |
+
"English": "Choose language model",
|
44 |
+
"Russian": "Выберите языковую модель"
|
45 |
+
}
|
46 |
+
|
47 |
+
ATTRIBUTE_LABEL = {
|
48 |
+
"Russian": "Выберите нужный атрибут текста",
|
49 |
+
"English": "Choose desired attribute",
|
50 |
+
}
|
51 |
+
|
52 |
+
TEXT_PROMPT_LABEL = {
|
53 |
+
"English": "Text prompt",
|
54 |
+
"Russian": "Начало текста"
|
55 |
+
}
|
56 |
+
|
57 |
+
PROMPT_EXAMPLE = {
|
58 |
+
"English": "Hello, today I",
|
59 |
+
"Russian": "Привет, сегодня я"
|
60 |
+
}
|
61 |
+
|
62 |
|
63 |
def main():
|
64 |
st.header("CAIF")
|
65 |
+
language = st.selectbox("Language", ("English", "Russian"))
|
66 |
cls_model_name = st.selectbox(
|
67 |
+
ATTRIBUTE_MODEL_LABEL[language],
|
68 |
+
ATTRIBUTE_MODELS[language]
|
69 |
+
|
|
|
|
|
|
|
|
|
70 |
)
|
71 |
lm_model_name = st.selectbox(
|
72 |
+
LM_LABEL[language],
|
73 |
+
LANGUAGE_MODELS[language]
|
74 |
)
|
75 |
cls_model_config = AutoConfig.from_pretrained(cls_model_name)
|
76 |
if cls_model_config.problem_type == "multi_label_classification":
|
77 |
label2id = cls_model_config.label2id
|
78 |
+
label_key = st.selectbox(ATTRIBUTE_LABEL[language], label2id.keys())
|
79 |
target_label_id = label2id[label_key]
|
80 |
else:
|
81 |
label2id = cls_model_config.label2id
|
82 |
print(list(label2id.keys()))
|
83 |
+
label_key = st.selectbox(ATTRIBUTE_LABEL[language], [list(label2id.keys())[-1]])
|
84 |
+
target_label_id = 0
|
85 |
+
prompt = st.text_input(TEXT_PROMPT_LABEL[language], PROMPT_EXAMPLE[language])
|
86 |
+
alpha = st.slider("Alpha", min_value=-10, max_value=10, step=1, value=0)
|
87 |
+
entropy_threshold = st.slider("Entropy threshold", min_value=0., max_value=5., step=.1, value=0.)
|
88 |
auth_token = os.environ.get('TOKEN') or True
|
89 |
with st.spinner('Running inference...'):
|
90 |
text = inference(
|