Балаганский Никита Николаевич committed on
Commit
bf20d4a
1 Parent(s): cce302e
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -30,8 +30,11 @@ ATTRIBUTE_MODELS = {
30
  }
31
 
32
  LANGUAGE_MODELS = {
33
- "Russian": ('sberbank-ai/rugpt3small_based_on_gpt2',),
34
- "English": ("distilgpt2", )
 
 
 
35
  }
36
 
37
  ATTRIBUTE_MODEL_LABEL = {
@@ -81,10 +84,10 @@ def main():
81
  label2id = cls_model_config.label2id
82
  print(list(label2id.keys()))
83
  label_key = st.selectbox(ATTRIBUTE_LABEL[language], [list(label2id.keys())[-1]])
84
- target_label_id = 0
85
  prompt = st.text_input(TEXT_PROMPT_LABEL[language], PROMPT_EXAMPLE[language])
86
  alpha = st.slider("Alpha", min_value=-10, max_value=10, step=1, value=0)
87
- entropy_threshold = st.slider("Entropy threshold", min_value=0., max_value=5., step=.1, value=5.)
88
  auth_token = os.environ.get('TOKEN') or True
89
  with st.spinner('Running inference...'):
90
  text = inference(
@@ -101,13 +104,13 @@ def main():
101
 
102
 
103
 
104
- @st.cache(hash_funcs={str: lambda lm_model_name: hash(lm_model_name)}, allow_output_mutation=True)
105
  def load_generator(lm_model_name: str) -> Generator:
106
  with st.spinner('Loading language model...'):
107
  generator = Generator(lm_model_name=lm_model_name, device=device)
108
  return generator
109
 
110
-
111
  def load_sampler(cls_model_name, lm_tokenizer):
112
  with st.spinner('Loading classifier model...'):
113
  sampler = CAIFSampler(classifier_name=cls_model_name, lm_tokenizer=lm_tokenizer, device=device)
 
30
  }
31
 
32
  LANGUAGE_MODELS = {
33
+ "Russian": (
34
+ 'sberbank-ai/rugpt3small_based_on_gpt2',
35
+ "sberbank-ai/rugpt3large_based_on_gpt2"
36
+ ),
37
+ "English": ("distilgpt2", "gpt2", "EleutherAI/gpt-neo-1.3B")
38
  }
39
 
40
  ATTRIBUTE_MODEL_LABEL = {
 
84
  label2id = cls_model_config.label2id
85
  print(list(label2id.keys()))
86
  label_key = st.selectbox(ATTRIBUTE_LABEL[language], [list(label2id.keys())[-1]])
87
+ target_label_id = 1
88
  prompt = st.text_input(TEXT_PROMPT_LABEL[language], PROMPT_EXAMPLE[language])
89
  alpha = st.slider("Alpha", min_value=-10, max_value=10, step=1, value=0)
90
+ entropy_threshold = st.slider("Entropy threshold", min_value=0., max_value=5., step=.1, value=2.)
91
  auth_token = os.environ.get('TOKEN') or True
92
  with st.spinner('Running inference...'):
93
  text = inference(
 
104
 
105
 
106
 
107
+ @st.cache
108
  def load_generator(lm_model_name: str) -> Generator:
109
  with st.spinner('Loading language model...'):
110
  generator = Generator(lm_model_name=lm_model_name, device=device)
111
  return generator
112
 
113
+ @st.cache(hash_funcs={tokenizers.Tokenizer: lambda lm_tokenizer: hash(lm_tokenizer.to_str)}, allow_output_mutation=True)
114
  def load_sampler(cls_model_name, lm_tokenizer):
115
  with st.spinner('Loading classifier model...'):
116
  sampler = CAIFSampler(classifier_name=cls_model_name, lm_tokenizer=lm_tokenizer, device=device)