Балаганский Никита Николаевич committed on
Commit
7e5a783
1 Parent(s): d95e99d
Files changed (1)
  1. app.py +4 -3
app.py CHANGED
@@ -15,7 +15,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 def main():
-    st.subheader("CAIF")
+    st.header("CAIF")
     cls_model_name = st.selectbox(
         'Выберите модель классификации',
         ('tinkoff-ai/response-quality-classifier-tiny', 'tinkoff-ai/response-quality-classifier-base',
@@ -26,6 +26,7 @@ def main():
         ('sberbank-ai/rugpt3small_based_on_gpt2',)
     )
     prompt = st.text_input("Начало текста:", "Привет")
+    alpha = st.slider("Alpha:", min_value=-10, max_value=10, step=1)
     auth_token = os.environ.get('TOKEN') or True
     with st.spinner('Running inference...'):
         text = inference(lm_model_name=lm_model_name, cls_model_name=cls_model_name, prompt=prompt)
@@ -49,7 +50,7 @@ def load_sampler(cls_model_name, lm_tokenizer):
 
 
 @st.cache
-def inference(lm_model_name: str, cls_model_name: str, prompt: str, fp16: bool = True) -> str:
+def inference(lm_model_name: str, cls_model_name: str, prompt: str, fp16: bool = True, alpha: float = 5) -> str:
     generator = load_generator(lm_model_name=lm_model_name)
     lm_tokenizer = transformers.AutoTokenizer.from_pretrained(lm_model_name)
     caif_sampler = load_sampler(cls_model_name=cls_model_name, lm_tokenizer=lm_tokenizer)
@@ -59,7 +60,7 @@ def inference(lm_model_name: str, cls_model_name: str, prompt: str, fp16: bool =
         "top_k": 20,
         "temperature": 1.0,
         "top_k_classifier": 100,
-        "classifier_weight": 5,
+        "classifier_weight": alpha,
     }
     generator.set_ordinary_sampler(ordinary_sampler)
    if device == "cpu":
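Worth flagging: none of the hunks touch the call site inside main(), which still reads text = inference(lm_model_name=lm_model_name, cls_model_name=cls_model_name, prompt=prompt). As committed (the +4/-3 stats confirm no other line changed), the slider value is never forwarded, so inference always falls back to its alpha=5 default. A minimal sketch of the follow-up wiring this would need, hypothetical and not part of this commit:

    with st.spinner('Running inference...'):
        text = inference(
            lm_model_name=lm_model_name,
            cls_model_name=cls_model_name,
            prompt=prompt,
            alpha=alpha,  # hypothetical: forward the slider value into "classifier_weight"
        )

A side effect of the new signature: since inference is wrapped in @st.cache, alpha becomes part of the cache key once it is actually passed, so each slider setting gets its own cached generation instead of reusing the first result.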
 
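One more small detail: st.slider defaults to min_value when no value argument is given, so the widget initially renders at -10 rather than the previously hard-coded weight of 5. If the old behavior should remain the starting point, an explicit value would do it, again a suggestion rather than part of the commit:

    alpha = st.slider("Alpha:", min_value=-10, max_value=10, value=5, step=1)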