Балаганский Никита Николаевич commited on
Commit
ffdcf9b
1 Parent(s): e15d353

add plotly chart

Browse files
Files changed (2) hide show
  1. app.py +21 -4
  2. requirements.txt +1 -0
app.py CHANGED
@@ -13,6 +13,10 @@ import tokenizers
13
  from sampling import CAIFSampler, TopKWithTemperatureSampler
14
  from generator import Generator
15
 
 
 
 
 
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
  ATTRIBUTE_MODELS = {
@@ -34,7 +38,7 @@ LANGUAGE_MODELS = {
34
  'sberbank-ai/rugpt3small_based_on_gpt2',
35
  "sberbank-ai/rugpt3large_based_on_gpt2"
36
  ),
37
- "English": ("distilgpt2", "gpt2", "EleutherAI/gpt-neo-1.3B")
38
  }
39
 
40
  ATTRIBUTE_MODEL_LABEL = {
@@ -65,6 +69,21 @@ PROMPT_EXAMPLE = {
65
 
66
  def main():
67
  st.header("CAIF")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  language = st.selectbox("Language", ("English", "Russian"))
69
  cls_model_name = st.selectbox(
70
  ATTRIBUTE_MODEL_LABEL[language],
@@ -87,6 +106,7 @@ def main():
87
  target_label_id = 1
88
  prompt = st.text_input(TEXT_PROMPT_LABEL[language], PROMPT_EXAMPLE[language])
89
  alpha = st.slider("Alpha", min_value=-10, max_value=10, step=1, value=0)
 
90
  entropy_threshold = st.slider("Entropy threshold", min_value=0., max_value=5., step=.1, value=2.)
91
  auth_token = os.environ.get('TOKEN') or True
92
  fp16 = st.checkbox("FP16", value=True)
@@ -103,9 +123,6 @@ def main():
103
  st.subheader("Generated text:")
104
  st.markdown(text)
105
 
106
-
107
-
108
-
109
  @st.cache(hash_funcs={tokenizers.Tokenizer: lambda lm_tokenizer: hash(lm_tokenizer.to_str)}, allow_output_mutation=True)
110
  def load_generator(lm_model_name: str) -> Generator:
111
  with st.spinner('Loading language model...'):
 
13
  from sampling import CAIFSampler, TopKWithTemperatureSampler
14
  from generator import Generator
15
 
16
+ import pickle
17
+
18
+ from plotly import graph_objects as go
19
+
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
21
 
22
  ATTRIBUTE_MODELS = {
 
38
  'sberbank-ai/rugpt3small_based_on_gpt2',
39
  "sberbank-ai/rugpt3large_based_on_gpt2"
40
  ),
41
+ "English": ("gpt2", "distilgpt2", "EleutherAI/gpt-neo-1.3B")
42
  }
43
 
44
  ATTRIBUTE_MODEL_LABEL = {
 
69
 
70
  def main():
71
  st.header("CAIF")
72
+ with open("entropy_cdf.pkl", "rb") as inp:
73
+ x_s, y_s = pickle.load(inp)
74
+ scatter = go.Scatter({
75
+ "x": x_s,
76
+ "y": y_s,
77
+ "name": "GPT2",
78
+ "mode": "lines",
79
+ }
80
+ )
81
+ layout = go.Layout({
82
+ "yaxis": {"title": "CAIF step probability"},
83
+ "xaxis": {"title": "Entropy threshold"},
84
+ "template": "plotly_white",
85
+ })
86
+ figure = go.Figure(data=[scatter], layout=layout)
87
  language = st.selectbox("Language", ("English", "Russian"))
88
  cls_model_name = st.selectbox(
89
  ATTRIBUTE_MODEL_LABEL[language],
 
106
  target_label_id = 1
107
  prompt = st.text_input(TEXT_PROMPT_LABEL[language], PROMPT_EXAMPLE[language])
108
  alpha = st.slider("Alpha", min_value=-10, max_value=10, step=1, value=0)
109
+ st.plotly_chart(figure)
110
  entropy_threshold = st.slider("Entropy threshold", min_value=0., max_value=5., step=.1, value=2.)
111
  auth_token = os.environ.get('TOKEN') or True
112
  fp16 = st.checkbox("FP16", value=True)
 
123
  st.subheader("Generated text:")
124
  st.markdown(text)
125
 
 
 
 
126
  @st.cache(hash_funcs={tokenizers.Tokenizer: lambda lm_tokenizer: hash(lm_tokenizer.to_str)}, allow_output_mutation=True)
127
  def load_generator(lm_model_name: str) -> Generator:
128
  with st.spinner('Loading language model...'):
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  streamlit
2
  transformers
3
  torch
 
 
1
  streamlit
2
  transformers
3
  torch
4
+ plotly