Балаганский Никита Николаевич commited on
Commit
9320186
1 Parent(s): ab4c068

add states

Browse files
Files changed (2) hide show
  1. app.py +2 -2
  2. generator.py +2 -0
app.py CHANGED
@@ -179,6 +179,7 @@ def main():
179
  st.plotly_chart(figure, use_container_width=True)
180
  auth_token = os.environ.get('TOKEN') or True
181
  fp16 = st.checkbox("FP16", value=True)
 
182
 
183
  def generate():
184
  text = inference(
@@ -192,8 +193,7 @@ def main():
192
  act_type=act_type
193
  )
194
  st.subheader("Generated text:")
195
- st.write(text)
196
- generate()
197
  st.button("Generate new", on_click=generate())
198
 
199
 
 
179
  st.plotly_chart(figure, use_container_width=True)
180
  auth_token = os.environ.get('TOKEN') or True
181
  fp16 = st.checkbox("FP16", value=True)
182
+ st.session_state["generated_text"] = None
183
 
184
  def generate():
185
  text = inference(
 
193
  act_type=act_type
194
  )
195
  st.subheader("Generated text:")
196
+ st.write(st.session_state["generated_text"])
 
197
  st.button("Generate new", on_click=generate())
198
 
199
 
generator.py CHANGED
@@ -2,6 +2,7 @@ from typing import Optional, Union
2
 
3
  import torch
4
  import transformers
 
5
 
6
 
7
  class Generator:
@@ -69,6 +70,7 @@ class Generator:
69
  progress_bar.progress((i+1)/max_length)
70
  if ended_sequences.all():
71
  break
 
72
 
73
  return (
74
  [
 
2
 
3
  import torch
4
  import transformers
5
+ import streamlit as st
6
 
7
 
8
  class Generator:
 
70
  progress_bar.progress((i+1)/max_length)
71
  if ended_sequences.all():
72
  break
73
+ st.session_state["generated_text"] = self.tokenizer.decode(input_ids[0])
74
 
75
  return (
76
  [