Балаганский Никита Николаевич commited on
Commit
147f159
1 Parent(s): 47f9ff2
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -6,6 +6,8 @@ import streamlit as st
6
  import torch
7
 
8
  import transformers
 
 
9
  import tokenizers
10
 
11
  from sampling import CAIFSampler, TopKWithTemperatureSampler
@@ -18,14 +20,18 @@ def main():
18
  st.header("CAIF")
19
  cls_model_name = st.selectbox(
20
  'Выберите модель классификации',
21
- ('tinkoff-ai/response-quality-classifier-tiny', 'tinkoff-ai/response-quality-classifier-base',
22
- 'tinkoff-ai/response-quality-classifier-large', "SkolkovoInstitute/roberta_toxicity_classifier")
 
 
 
 
23
  )
24
  lm_model_name = st.selectbox(
25
  'Выберите языковую модель',
26
  ('sberbank-ai/rugpt3small_based_on_gpt2',)
27
  )
28
- cls_model_config = transformers.AutoConfig.from_pretrained(cls_model_name)
29
  if cls_model_config.problem_type == "multi_label_classification":
30
  label2id = cls_model_config.label2id
31
  label_key = st.selectbox("Веберите нужный атрибут текста", label2id.keys())
 
6
  import torch
7
 
8
  import transformers
9
+
10
+ from transformers import AutoConfig
11
  import tokenizers
12
 
13
  from sampling import CAIFSampler, TopKWithTemperatureSampler
 
20
  st.header("CAIF")
21
  cls_model_name = st.selectbox(
22
  'Выберите модель классификации',
23
+ (
24
+ 'tinkoff-ai/response-quality-classifier-tiny',
25
+ 'tinkoff-ai/response-quality-classifier-base',
26
+ 'tinkoff-ai/response-quality-classifier-large',
27
+ "SkolkovoInstitute/roberta_toxicity_classifier"
28
+ )
29
  )
30
  lm_model_name = st.selectbox(
31
  'Выберите языковую модель',
32
  ('sberbank-ai/rugpt3small_based_on_gpt2',)
33
  )
34
+ cls_model_config = AutoConfig.from_pretrained(cls_model_name)
35
  if cls_model_config.problem_type == "multi_label_classification":
36
  label2id = cls_model_config.label2id
37
  label_key = st.selectbox("Веберите нужный атрибут текста", label2id.keys())