Балаганский Никита Николаевич
committed on
Commit
•
ffdcf9b
1
Parent(s):
e15d353
add plotly chart
Browse files- app.py +21 -4
- requirements.txt +1 -0
app.py
CHANGED
@@ -13,6 +13,10 @@ import tokenizers
|
|
13 |
from sampling import CAIFSampler, TopKWithTemperatureSampler
|
14 |
from generator import Generator
|
15 |
|
|
|
|
|
|
|
|
|
16 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
17 |
|
18 |
ATTRIBUTE_MODELS = {
|
@@ -34,7 +38,7 @@ LANGUAGE_MODELS = {
|
|
34 |
'sberbank-ai/rugpt3small_based_on_gpt2',
|
35 |
"sberbank-ai/rugpt3large_based_on_gpt2"
|
36 |
),
|
37 |
-
"English": ("
|
38 |
}
|
39 |
|
40 |
ATTRIBUTE_MODEL_LABEL = {
|
@@ -65,6 +69,21 @@ PROMPT_EXAMPLE = {
|
|
65 |
|
66 |
def main():
|
67 |
st.header("CAIF")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
language = st.selectbox("Language", ("English", "Russian"))
|
69 |
cls_model_name = st.selectbox(
|
70 |
ATTRIBUTE_MODEL_LABEL[language],
|
@@ -87,6 +106,7 @@ def main():
|
|
87 |
target_label_id = 1
|
88 |
prompt = st.text_input(TEXT_PROMPT_LABEL[language], PROMPT_EXAMPLE[language])
|
89 |
alpha = st.slider("Alpha", min_value=-10, max_value=10, step=1, value=0)
|
|
|
90 |
entropy_threshold = st.slider("Entropy threshold", min_value=0., max_value=5., step=.1, value=2.)
|
91 |
auth_token = os.environ.get('TOKEN') or True
|
92 |
fp16 = st.checkbox("FP16", value=True)
|
@@ -103,9 +123,6 @@ def main():
|
|
103 |
st.subheader("Generated text:")
|
104 |
st.markdown(text)
|
105 |
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
@st.cache(hash_funcs={tokenizers.Tokenizer: lambda lm_tokenizer: hash(lm_tokenizer.to_str)}, allow_output_mutation=True)
|
110 |
def load_generator(lm_model_name: str) -> Generator:
|
111 |
with st.spinner('Loading language model...'):
|
|
|
13 |
from sampling import CAIFSampler, TopKWithTemperatureSampler
|
14 |
from generator import Generator
|
15 |
|
16 |
+
import pickle
|
17 |
+
|
18 |
+
from plotly import graph_objects as go
|
19 |
+
|
20 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
21 |
|
22 |
ATTRIBUTE_MODELS = {
|
|
|
38 |
'sberbank-ai/rugpt3small_based_on_gpt2',
|
39 |
"sberbank-ai/rugpt3large_based_on_gpt2"
|
40 |
),
|
41 |
+
"English": ("gpt2", "distilgpt2", "EleutherAI/gpt-neo-1.3B")
|
42 |
}
|
43 |
|
44 |
ATTRIBUTE_MODEL_LABEL = {
|
|
|
69 |
|
70 |
def main():
|
71 |
st.header("CAIF")
|
72 |
+
with open("entropy_cdf.pkl", "rb") as inp:
|
73 |
+
x_s, y_s = pickle.load(inp)
|
74 |
+
scatter = go.Scatter({
|
75 |
+
"x": x_s,
|
76 |
+
"y": y_s,
|
77 |
+
"name": "GPT2",
|
78 |
+
"mode": "lines",
|
79 |
+
}
|
80 |
+
)
|
81 |
+
layout = go.Layout({
|
82 |
+
"yaxis": {"title": "CAIF step probability"},
|
83 |
+
"xaxis": {"title": "Entropy threshold"},
|
84 |
+
"template": "plotly_white",
|
85 |
+
})
|
86 |
+
figure = go.Figure(data=[scatter], layout=layout)
|
87 |
language = st.selectbox("Language", ("English", "Russian"))
|
88 |
cls_model_name = st.selectbox(
|
89 |
ATTRIBUTE_MODEL_LABEL[language],
|
|
|
106 |
target_label_id = 1
|
107 |
prompt = st.text_input(TEXT_PROMPT_LABEL[language], PROMPT_EXAMPLE[language])
|
108 |
alpha = st.slider("Alpha", min_value=-10, max_value=10, step=1, value=0)
|
109 |
+
st.plotly_chart(figure)
|
110 |
entropy_threshold = st.slider("Entropy threshold", min_value=0., max_value=5., step=.1, value=2.)
|
111 |
auth_token = os.environ.get('TOKEN') or True
|
112 |
fp16 = st.checkbox("FP16", value=True)
|
|
|
123 |
st.subheader("Generated text:")
|
124 |
st.markdown(text)
|
125 |
|
|
|
|
|
|
|
126 |
@st.cache(hash_funcs={tokenizers.Tokenizer: lambda lm_tokenizer: hash(lm_tokenizer.to_str)}, allow_output_mutation=True)
|
127 |
def load_generator(lm_model_name: str) -> Generator:
|
128 |
with st.spinner('Loading language model...'):
|
requirements.txt
CHANGED
@@ -1,3 +1,4 @@
|
|
1 |
streamlit
|
2 |
transformers
|
3 |
torch
|
|
|
|
1 |
streamlit
|
2 |
transformers
|
3 |
torch
|
4 |
+
plotly
|