Балаганский Никита Николаевич
commited on
Commit
•
bf20d4a
1
Parent(s):
cce302e
fix
Browse files
app.py
CHANGED
@@ -30,8 +30,11 @@ ATTRIBUTE_MODELS = {
|
|
30 |
}
|
31 |
|
32 |
LANGUAGE_MODELS = {
|
33 |
-
"Russian": (
|
34 |
-
|
|
|
|
|
|
|
35 |
}
|
36 |
|
37 |
ATTRIBUTE_MODEL_LABEL = {
|
@@ -81,10 +84,10 @@ def main():
|
|
81 |
label2id = cls_model_config.label2id
|
82 |
print(list(label2id.keys()))
|
83 |
label_key = st.selectbox(ATTRIBUTE_LABEL[language], [list(label2id.keys())[-1]])
|
84 |
-
target_label_id =
|
85 |
prompt = st.text_input(TEXT_PROMPT_LABEL[language], PROMPT_EXAMPLE[language])
|
86 |
alpha = st.slider("Alpha", min_value=-10, max_value=10, step=1, value=0)
|
87 |
-
entropy_threshold = st.slider("Entropy threshold", min_value=0., max_value=5., step=.1, value=
|
88 |
auth_token = os.environ.get('TOKEN') or True
|
89 |
with st.spinner('Running inference...'):
|
90 |
text = inference(
|
@@ -101,13 +104,13 @@ def main():
|
|
101 |
|
102 |
|
103 |
|
104 |
-
@st.cache
|
105 |
def load_generator(lm_model_name: str) -> Generator:
|
106 |
with st.spinner('Loading language model...'):
|
107 |
generator = Generator(lm_model_name=lm_model_name, device=device)
|
108 |
return generator
|
109 |
|
110 |
-
|
111 |
def load_sampler(cls_model_name, lm_tokenizer):
|
112 |
with st.spinner('Loading classifier model...'):
|
113 |
sampler = CAIFSampler(classifier_name=cls_model_name, lm_tokenizer=lm_tokenizer, device=device)
|
|
|
30 |
}
|
31 |
|
32 |
LANGUAGE_MODELS = {
|
33 |
+
"Russian": (
|
34 |
+
'sberbank-ai/rugpt3small_based_on_gpt2',
|
35 |
+
"sberbank-ai/rugpt3large_based_on_gpt2"
|
36 |
+
),
|
37 |
+
"English": ("distilgpt2", "gpt2", "EleutherAI/gpt-neo-1.3B")
|
38 |
}
|
39 |
|
40 |
ATTRIBUTE_MODEL_LABEL = {
|
|
|
84 |
label2id = cls_model_config.label2id
|
85 |
print(list(label2id.keys()))
|
86 |
label_key = st.selectbox(ATTRIBUTE_LABEL[language], [list(label2id.keys())[-1]])
|
87 |
+
target_label_id = 1
|
88 |
prompt = st.text_input(TEXT_PROMPT_LABEL[language], PROMPT_EXAMPLE[language])
|
89 |
alpha = st.slider("Alpha", min_value=-10, max_value=10, step=1, value=0)
|
90 |
+
entropy_threshold = st.slider("Entropy threshold", min_value=0., max_value=5., step=.1, value=2.)
|
91 |
auth_token = os.environ.get('TOKEN') or True
|
92 |
with st.spinner('Running inference...'):
|
93 |
text = inference(
|
|
|
104 |
|
105 |
|
106 |
|
107 |
+
@st.cache
|
108 |
def load_generator(lm_model_name: str) -> Generator:
|
109 |
with st.spinner('Loading language model...'):
|
110 |
generator = Generator(lm_model_name=lm_model_name, device=device)
|
111 |
return generator
|
112 |
|
113 |
+
@st.cache(hash_funcs={tokenizers.Tokenizer: lambda lm_tokenizer: hash(lm_tokenizer.to_str)}, allow_output_mutation=True)
|
114 |
def load_sampler(cls_model_name, lm_tokenizer):
|
115 |
with st.spinner('Loading classifier model...'):
|
116 |
sampler = CAIFSampler(classifier_name=cls_model_name, lm_tokenizer=lm_tokenizer, device=device)
|