Балаганский Никита Николаевич
commited on
Commit
•
d320fdd
1
Parent(s):
6331a08
add target_label_id
Browse files- app.py +8 -2
- sampling.py +8 -5
app.py
CHANGED
@@ -25,6 +25,10 @@ def main():
|
|
25 |
'Выберите языковую модель',
|
26 |
('sberbank-ai/rugpt3small_based_on_gpt2',)
|
27 |
)
|
|
|
|
|
|
|
|
|
28 |
prompt = st.text_input("Начало текста:", "Привет")
|
29 |
alpha = st.slider("Alpha:", min_value=-10, max_value=10, step=1)
|
30 |
auth_token = os.environ.get('TOKEN') or True
|
@@ -50,7 +54,9 @@ def load_sampler(cls_model_name, lm_tokenizer):
|
|
50 |
|
51 |
|
52 |
@st.cache
|
53 |
-
def inference(
|
|
|
|
|
54 |
generator = load_generator(lm_model_name=lm_model_name)
|
55 |
lm_tokenizer = transformers.AutoTokenizer.from_pretrained(lm_model_name)
|
56 |
caif_sampler = load_sampler(cls_model_name=cls_model_name, lm_tokenizer=lm_tokenizer)
|
@@ -61,6 +67,7 @@ def inference(lm_model_name: str, cls_model_name: str, prompt: str, fp16: bool =
|
|
61 |
"temperature": 1.0,
|
62 |
"top_k_classifier": 100,
|
63 |
"classifier_weight": alpha,
|
|
|
64 |
}
|
65 |
generator.set_ordinary_sampler(ordinary_sampler)
|
66 |
if device == "cpu":
|
@@ -74,7 +81,6 @@ def inference(lm_model_name: str, cls_model_name: str, prompt: str, fp16: bool =
|
|
74 |
input_prompt=prompt,
|
75 |
max_length=20,
|
76 |
caif_period=1,
|
77 |
-
caif_tokens_num=100,
|
78 |
entropy=None,
|
79 |
**kwargs
|
80 |
)
|
|
|
25 |
'Выберите языковую модель',
|
26 |
('sberbank-ai/rugpt3small_based_on_gpt2',)
|
27 |
)
|
28 |
+
cls_model_config = transformers.AutoConfig.from_pretrained(cls_model_name)
|
29 |
+
label2id = cls_model_config.label2id
|
30 |
+
label_key = st.selectbox("Веберите нужный атрибут текста", label2id.keys())
|
31 |
+
target_label_id = label2id[label_key]
|
32 |
prompt = st.text_input("Начало текста:", "Привет")
|
33 |
alpha = st.slider("Alpha:", min_value=-10, max_value=10, step=1)
|
34 |
auth_token = os.environ.get('TOKEN') or True
|
|
|
54 |
|
55 |
|
56 |
@st.cache
|
57 |
+
def inference(
|
58 |
+
lm_model_name: str, cls_model_name: str, prompt: str, fp16: bool = True, alpha: float = 5, target_label_id: int = 0
|
59 |
+
) -> str:
|
60 |
generator = load_generator(lm_model_name=lm_model_name)
|
61 |
lm_tokenizer = transformers.AutoTokenizer.from_pretrained(lm_model_name)
|
62 |
caif_sampler = load_sampler(cls_model_name=cls_model_name, lm_tokenizer=lm_tokenizer)
|
|
|
67 |
"temperature": 1.0,
|
68 |
"top_k_classifier": 100,
|
69 |
"classifier_weight": alpha,
|
70 |
+
"target_cls_id": target_label_id
|
71 |
}
|
72 |
generator.set_ordinary_sampler(ordinary_sampler)
|
73 |
if device == "cpu":
|
|
|
81 |
input_prompt=prompt,
|
82 |
max_length=20,
|
83 |
caif_period=1,
|
|
|
84 |
entropy=None,
|
85 |
**kwargs
|
86 |
)
|
sampling.py
CHANGED
@@ -49,10 +49,11 @@ class CAIFSampler:
|
|
49 |
top_k_classifier,
|
50 |
classifier_weight,
|
51 |
caif_tokens_num=None,
|
|
|
52 |
**kwargs
|
53 |
):
|
|
|
54 |
next_token_logits = output_logis[:, -1]
|
55 |
-
|
56 |
next_token_log_probs = F.log_softmax(
|
57 |
next_token_logits, dim=-1
|
58 |
)
|
@@ -63,7 +64,8 @@ class CAIFSampler:
|
|
63 |
temperature,
|
64 |
top_k_classifier,
|
65 |
classifier_weight,
|
66 |
-
caif_tokens_num=caif_tokens_num
|
|
|
67 |
)
|
68 |
topk_probs = next_token_unnormalized_probs.topk(top_k, -1)
|
69 |
next_tokens = sample_from_values(
|
@@ -80,6 +82,7 @@ class CAIFSampler:
|
|
80 |
temperature,
|
81 |
top_k_classifier,
|
82 |
classifier_weight,
|
|
|
83 |
caif_tokens_num=None
|
84 |
):
|
85 |
|
@@ -109,7 +112,7 @@ class CAIFSampler:
|
|
109 |
)
|
110 |
else:
|
111 |
classifier_log_probs = self.get_classifier_log_probs(
|
112 |
-
classifier_input, caif_tokens_num=caif_tokens_num
|
113 |
).view(-1, top_k_classifier)
|
114 |
|
115 |
next_token_probs = torch.exp(
|
@@ -118,7 +121,7 @@ class CAIFSampler:
|
|
118 |
)
|
119 |
return next_token_probs, top_next_token_log_probs[1]
|
120 |
|
121 |
-
def get_classifier_log_probs(self, input, caif_tokens_num=None):
|
122 |
input_ids = self.classifier_tokenizer(
|
123 |
input, padding=True, return_tensors="pt"
|
124 |
).to(self.device)
|
@@ -128,7 +131,7 @@ class CAIFSampler:
|
|
128 |
input_ids["attention_mask"] = input_ids["attention_mask"][:, -caif_tokens_num:]
|
129 |
if "token_type_ids" in input_ids.keys():
|
130 |
input_ids["token_type_ids"] = input_ids["token_type_ids"][:, -caif_tokens_num:]
|
131 |
-
logits = self.classifier_model(**input_ids).logits[:,
|
132 |
return torch.log(torch.sigmoid(logits))
|
133 |
|
134 |
def get_classifier_probs(self, input, caif_tokens_num=None):
|
|
|
49 |
top_k_classifier,
|
50 |
classifier_weight,
|
51 |
caif_tokens_num=None,
|
52 |
+
act_type: str = "softmax",
|
53 |
**kwargs
|
54 |
):
|
55 |
+
target_cls_id = kwargs["target_cls_id"]
|
56 |
next_token_logits = output_logis[:, -1]
|
|
|
57 |
next_token_log_probs = F.log_softmax(
|
58 |
next_token_logits, dim=-1
|
59 |
)
|
|
|
64 |
temperature,
|
65 |
top_k_classifier,
|
66 |
classifier_weight,
|
67 |
+
caif_tokens_num=caif_tokens_num,
|
68 |
+
target_cls_id=target_cls_id
|
69 |
)
|
70 |
topk_probs = next_token_unnormalized_probs.topk(top_k, -1)
|
71 |
next_tokens = sample_from_values(
|
|
|
82 |
temperature,
|
83 |
top_k_classifier,
|
84 |
classifier_weight,
|
85 |
+
target_cls_id: int = 0,
|
86 |
caif_tokens_num=None
|
87 |
):
|
88 |
|
|
|
112 |
)
|
113 |
else:
|
114 |
classifier_log_probs = self.get_classifier_log_probs(
|
115 |
+
classifier_input, caif_tokens_num=caif_tokens_num, target_cls_id=target_cls_id,
|
116 |
).view(-1, top_k_classifier)
|
117 |
|
118 |
next_token_probs = torch.exp(
|
|
|
121 |
)
|
122 |
return next_token_probs, top_next_token_log_probs[1]
|
123 |
|
124 |
+
def get_classifier_log_probs(self, input, caif_tokens_num=None, target_cls_id: int = 0):
|
125 |
input_ids = self.classifier_tokenizer(
|
126 |
input, padding=True, return_tensors="pt"
|
127 |
).to(self.device)
|
|
|
131 |
input_ids["attention_mask"] = input_ids["attention_mask"][:, -caif_tokens_num:]
|
132 |
if "token_type_ids" in input_ids.keys():
|
133 |
input_ids["token_type_ids"] = input_ids["token_type_ids"][:, -caif_tokens_num:]
|
134 |
+
logits = self.classifier_model(**input_ids).logits[:, target_cls_id].squeeze(-1)
|
135 |
return torch.log(torch.sigmoid(logits))
|
136 |
|
137 |
def get_classifier_probs(self, input, caif_tokens_num=None):
|