Балаганский Никита Николаевич
committed on
Commit
•
f57bdfa
1
Parent(s):
895c44e
fix
Browse files
- app.py +11 -7
- sampling.py +17 -7
app.py
CHANGED
@@ -24,14 +24,11 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
24 |
ATTRIBUTE_MODELS = {
|
25 |
"Russian": (
|
26 |
"cointegrated/rubert-tiny-toxicity",
|
27 |
-
'tinkoff-ai/response-quality-classifier-tiny',
|
28 |
-
'tinkoff-ai/response-quality-classifier-base',
|
29 |
-
'tinkoff-ai/response-quality-classifier-large',
|
30 |
-
"SkolkovoInstitute/roberta_toxicity_classifier",
|
31 |
"SkolkovoInstitute/russian_toxicity_classifier"
|
32 |
),
|
33 |
"English": (
|
34 |
"unitary/toxic-bert",
|
|
|
35 |
)
|
36 |
}
|
37 |
|
@@ -72,7 +69,7 @@ WARNING_TEXT = {
|
|
72 |
"English": """
|
73 |
**Warning!**
|
74 |
|
75 |
-
If you are clicking checkbox bellow positive""" + r"$\alpha$" + """ values for CAIF sampling become available.
|
76 |
It means that language model will be forced to produce toxic or/and abusive text.
|
77 |
This space is only a demonstration of our method for controllable text generation
|
78 |
and we are not responsible for the content produced by this method.
|
@@ -128,11 +125,17 @@ def main():
|
|
128 |
label2id = cls_model_config.label2id
|
129 |
label_key = st.selectbox(ATTRIBUTE_LABEL[language], label2id.keys())
|
130 |
target_label_id = label2id[label_key]
|
131 |
-
|
|
|
132 |
label2id = cls_model_config.label2id
|
133 |
-
print(list(label2id.keys()))
|
134 |
label_key = st.selectbox(ATTRIBUTE_LABEL[language], [list(label2id.keys())[-1]])
|
135 |
target_label_id = 1
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
st.write(WARNING_TEXT[language])
|
137 |
show_pos_alpha = st.checkbox("Show positive alphas", value=False)
|
138 |
prompt = st.text_input(TEXT_PROMPT_LABEL[language], PROMPT_EXAMPLE[language])
|
@@ -168,6 +171,7 @@ def main():
|
|
168 |
target_label_id=target_label_id,
|
169 |
entropy_threshold=entropy_threshold,
|
170 |
fp16=fp16,
|
|
|
171 |
)
|
172 |
st.subheader("Generated text:")
|
173 |
st.write(text)
|
|
|
24 |
ATTRIBUTE_MODELS = {
|
25 |
"Russian": (
|
26 |
"cointegrated/rubert-tiny-toxicity",
|
|
|
|
|
|
|
|
|
27 |
"SkolkovoInstitute/russian_toxicity_classifier"
|
28 |
),
|
29 |
"English": (
|
30 |
"unitary/toxic-bert",
|
31 |
+
"distilbert-base-uncased-finetuned-sst-2-english"
|
32 |
)
|
33 |
}
|
34 |
|
|
|
69 |
"English": """
|
70 |
**Warning!**
|
71 |
|
72 |
+
If you are clicking checkbox bellow positive """ + r"$\alpha$" + """ values for CAIF sampling become available.
|
73 |
It means that language model will be forced to produce toxic or/and abusive text.
|
74 |
This space is only a demonstration of our method for controllable text generation
|
75 |
and we are not responsible for the content produced by this method.
|
|
|
125 |
label2id = cls_model_config.label2id
|
126 |
label_key = st.selectbox(ATTRIBUTE_LABEL[language], label2id.keys())
|
127 |
target_label_id = label2id[label_key]
|
128 |
+
act_type = "sigmoid"
|
129 |
+
elif cls_model_config.problem_type == "single_label_classification":
|
130 |
label2id = cls_model_config.label2id
|
|
|
131 |
label_key = st.selectbox(ATTRIBUTE_LABEL[language], [list(label2id.keys())[-1]])
|
132 |
target_label_id = 1
|
133 |
+
act_type = "sigmoid"
|
134 |
+
else:
|
135 |
+
label2id = cls_model_config.label2id
|
136 |
+
label_key = st.selectbox(ATTRIBUTE_LABEL[language], label2id.keys())
|
137 |
+
target_label_id = label2id[label_key]
|
138 |
+
act_type = "softmax"
|
139 |
st.write(WARNING_TEXT[language])
|
140 |
show_pos_alpha = st.checkbox("Show positive alphas", value=False)
|
141 |
prompt = st.text_input(TEXT_PROMPT_LABEL[language], PROMPT_EXAMPLE[language])
|
|
|
171 |
target_label_id=target_label_id,
|
172 |
entropy_threshold=entropy_threshold,
|
173 |
fp16=fp16,
|
174 |
+
act_type=act_type
|
175 |
)
|
176 |
st.subheader("Generated text:")
|
177 |
st.write(text)
|
sampling.py
CHANGED
@@ -53,6 +53,7 @@ class CAIFSampler:
|
|
53 |
**kwargs
|
54 |
):
|
55 |
target_cls_id = kwargs["target_cls_id"]
|
|
|
56 |
next_token_logits = output_logis[:, -1]
|
57 |
next_token_log_probs = F.log_softmax(
|
58 |
next_token_logits, dim=-1
|
@@ -83,6 +84,7 @@ class CAIFSampler:
|
|
83 |
top_k_classifier,
|
84 |
classifier_weight,
|
85 |
target_cls_id: int = 0,
|
|
|
86 |
caif_tokens_num=None
|
87 |
):
|
88 |
|
@@ -107,12 +109,15 @@ class CAIFSampler:
|
|
107 |
if self.invert_cls_probs:
|
108 |
classifier_log_probs = torch.log(
|
109 |
1 - self.get_classifier_probs(
|
110 |
-
classifier_input, caif_tokens_num=caif_tokens_num
|
111 |
).view(-1, top_k_classifier)
|
112 |
)
|
113 |
else:
|
114 |
classifier_log_probs = self.get_classifier_log_probs(
|
115 |
-
classifier_input,
|
|
|
|
|
|
|
116 |
).view(-1, top_k_classifier)
|
117 |
|
118 |
next_token_probs = torch.exp(
|
@@ -121,7 +126,7 @@ class CAIFSampler:
|
|
121 |
)
|
122 |
return next_token_probs, top_next_token_log_probs[1]
|
123 |
|
124 |
-
def get_classifier_log_probs(self, input, caif_tokens_num=None, target_cls_id: int = 0):
|
125 |
input_ids = self.classifier_tokenizer(
|
126 |
input, padding=True, return_tensors="pt"
|
127 |
).to(self.device)
|
@@ -131,10 +136,15 @@ class CAIFSampler:
|
|
131 |
input_ids["attention_mask"] = input_ids["attention_mask"][:, -caif_tokens_num:]
|
132 |
if "token_type_ids" in input_ids.keys():
|
133 |
input_ids["token_type_ids"] = input_ids["token_type_ids"][:, -caif_tokens_num:]
|
134 |
-
logits = self.classifier_model(**input_ids).logits[:, target_cls_id].squeeze(-1)
|
135 |
-
return torch.log(torch.sigmoid(logits))
|
136 |
|
137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
input_ids = self.classifier_tokenizer(
|
139 |
input, padding=True, return_tensors="pt"
|
140 |
).to(self.device)
|
@@ -142,5 +152,5 @@ class CAIFSampler:
|
|
142 |
input_ids["input_ids"] = input_ids["input_ids"][-caif_tokens_num:]
|
143 |
if "attention_mask" in input_ids.keys():
|
144 |
input_ids["attention_mask"] = input_ids["attention_mask"][-caif_tokens_num:]
|
145 |
-
logits = self.classifier_model(**input_ids).logits[:,
|
146 |
return torch.sigmoid(logits)
|
|
|
53 |
**kwargs
|
54 |
):
|
55 |
target_cls_id = kwargs["target_cls_id"]
|
56 |
+
act_type = kwargs["act_type"]
|
57 |
next_token_logits = output_logis[:, -1]
|
58 |
next_token_log_probs = F.log_softmax(
|
59 |
next_token_logits, dim=-1
|
|
|
84 |
top_k_classifier,
|
85 |
classifier_weight,
|
86 |
target_cls_id: int = 0,
|
87 |
+
act_type: str = "sigmoid",
|
88 |
caif_tokens_num=None
|
89 |
):
|
90 |
|
|
|
109 |
if self.invert_cls_probs:
|
110 |
classifier_log_probs = torch.log(
|
111 |
1 - self.get_classifier_probs(
|
112 |
+
classifier_input, caif_tokens_num=caif_tokens_num, target_cls_id=target_cls_id
|
113 |
).view(-1, top_k_classifier)
|
114 |
)
|
115 |
else:
|
116 |
classifier_log_probs = self.get_classifier_log_probs(
|
117 |
+
classifier_input,
|
118 |
+
caif_tokens_num=caif_tokens_num,
|
119 |
+
target_cls_id=target_cls_id,
|
120 |
+
act_type=act_type,
|
121 |
).view(-1, top_k_classifier)
|
122 |
|
123 |
next_token_probs = torch.exp(
|
|
|
126 |
)
|
127 |
return next_token_probs, top_next_token_log_probs[1]
|
128 |
|
129 |
+
def get_classifier_log_probs(self, input, caif_tokens_num=None, target_cls_id: int = 0, act_type: str = "sigmoid"):
|
130 |
input_ids = self.classifier_tokenizer(
|
131 |
input, padding=True, return_tensors="pt"
|
132 |
).to(self.device)
|
|
|
136 |
input_ids["attention_mask"] = input_ids["attention_mask"][:, -caif_tokens_num:]
|
137 |
if "token_type_ids" in input_ids.keys():
|
138 |
input_ids["token_type_ids"] = input_ids["token_type_ids"][:, -caif_tokens_num:]
|
|
|
|
|
139 |
|
140 |
+
if act_type == "sigmoid":
|
141 |
+
logits = self.classifier_model(**input_ids).logits[:, target_cls_id].squeeze(-1)
|
142 |
+
return F.logsigmoid(logits)
|
143 |
+
if act_type == "softmax":
|
144 |
+
logits = F.log_softmax(self.classifier_model(**input_ids).logits)[:, target_cls_id].squeeze(-1)
|
145 |
+
return logits
|
146 |
+
|
147 |
+
def get_classifier_probs(self, input, caif_tokens_num=None, target_cls_id: int = 0):
|
148 |
input_ids = self.classifier_tokenizer(
|
149 |
input, padding=True, return_tensors="pt"
|
150 |
).to(self.device)
|
|
|
152 |
input_ids["input_ids"] = input_ids["input_ids"][-caif_tokens_num:]
|
153 |
if "attention_mask" in input_ids.keys():
|
154 |
input_ids["attention_mask"] = input_ids["attention_mask"][-caif_tokens_num:]
|
155 |
+
logits = self.classifier_model(**input_ids).logits[:, target_cls_id].squeeze(-1)
|
156 |
return torch.sigmoid(logits)
|