Балаганский Никита Николаевич
commited on
Commit
•
7e5a783
1
Parent(s):
d95e99d
fix out
Browse files
app.py
CHANGED
@@ -15,7 +15,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
15 |
|
16 |
|
17 |
def main():
|
18 |
-
st.
|
19 |
cls_model_name = st.selectbox(
|
20 |
'Выберите модель классификации',
|
21 |
('tinkoff-ai/response-quality-classifier-tiny', 'tinkoff-ai/response-quality-classifier-base',
|
@@ -26,6 +26,7 @@ def main():
|
|
26 |
('sberbank-ai/rugpt3small_based_on_gpt2',)
|
27 |
)
|
28 |
prompt = st.text_input("Начало текста:", "Привет")
|
|
|
29 |
auth_token = os.environ.get('TOKEN') or True
|
30 |
with st.spinner('Running inference...'):
|
31 |
text = inference(lm_model_name=lm_model_name, cls_model_name=cls_model_name, prompt=prompt)
|
@@ -49,7 +50,7 @@ def load_sampler(cls_model_name, lm_tokenizer):
|
|
49 |
|
50 |
|
51 |
@st.cache
|
52 |
-
def inference(lm_model_name: str, cls_model_name: str, prompt: str, fp16: bool = True) -> str:
|
53 |
generator = load_generator(lm_model_name=lm_model_name)
|
54 |
lm_tokenizer = transformers.AutoTokenizer.from_pretrained(lm_model_name)
|
55 |
caif_sampler = load_sampler(cls_model_name=cls_model_name, lm_tokenizer=lm_tokenizer)
|
@@ -59,7 +60,7 @@ def inference(lm_model_name: str, cls_model_name: str, prompt: str, fp16: bool =
|
|
59 |
"top_k": 20,
|
60 |
"temperature": 1.0,
|
61 |
"top_k_classifier": 100,
|
62 |
-
"classifier_weight":
|
63 |
}
|
64 |
generator.set_ordinary_sampler(ordinary_sampler)
|
65 |
if device == "cpu":
|
|
|
15 |
|
16 |
|
17 |
def main():
|
18 |
+
st.header("CAIF")
|
19 |
cls_model_name = st.selectbox(
|
20 |
'Выберите модель классификации',
|
21 |
('tinkoff-ai/response-quality-classifier-tiny', 'tinkoff-ai/response-quality-classifier-base',
|
|
|
26 |
('sberbank-ai/rugpt3small_based_on_gpt2',)
|
27 |
)
|
28 |
prompt = st.text_input("Начало текста:", "Привет")
|
29 |
+
alpha = st.slider("Alpha:", min_value=-10, max_value=10, step=1)
|
30 |
auth_token = os.environ.get('TOKEN') or True
|
31 |
with st.spinner('Running inference...'):
|
32 |
text = inference(lm_model_name=lm_model_name, cls_model_name=cls_model_name, prompt=prompt)
|
|
|
50 |
|
51 |
|
52 |
@st.cache
|
53 |
+
def inference(lm_model_name: str, cls_model_name: str, prompt: str, fp16: bool = True, alpha: float = 5) -> str:
|
54 |
generator = load_generator(lm_model_name=lm_model_name)
|
55 |
lm_tokenizer = transformers.AutoTokenizer.from_pretrained(lm_model_name)
|
56 |
caif_sampler = load_sampler(cls_model_name=cls_model_name, lm_tokenizer=lm_tokenizer)
|
|
|
60 |
"top_k": 20,
|
61 |
"temperature": 1.0,
|
62 |
"top_k_classifier": 100,
|
63 |
+
"classifier_weight": alpha,
|
64 |
}
|
65 |
generator.set_ordinary_sampler(ordinary_sampler)
|
66 |
if device == "cpu":
|