Spaces:
Runtime error
Runtime error
add app
Browse files- README.md +6 -7
- app/SessionState.py +107 -0
- app/app.py +96 -0
- app/prompts.py +18 -0
- requirements.txt +7 -0
README.md
CHANGED
@@ -1,11 +1,11 @@
|
|
1 |
---
|
2 |
-
title: Indonesian Story
|
3 |
emoji: 🔥
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: streamlit
|
7 |
-
app_file: app.py
|
8 |
-
pinned:
|
9 |
---
|
10 |
# Configuration
|
11 |
`title`: _string_
|
@@ -24,5 +24,4 @@ Path is relative to the root of the repository.
|
|
24 |
`pinned`: _boolean_
|
25 |
Whether the Space stays on top of your list.
|
26 |
|
27 |
-
# Indonesian Story
|
28 |
-
GPT2 generated Stories / Poems
|
|
|
1 |
---
|
2 |
+
title: Indonesian Story Generator
|
3 |
emoji: 🔥
|
4 |
+
colorFrom: green
|
5 |
+
colorTo: orange
|
6 |
sdk: streamlit
|
7 |
+
app_file: app/app.py
|
8 |
+
pinned: true
|
9 |
---
|
10 |
# Configuration
|
11 |
`title`: _string_
|
|
|
24 |
`pinned`: _boolean_
|
25 |
Whether the Space stays on top of your list.
|
26 |
|
27 |
+
# Indonesian Story
|
|
app/SessionState.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Hack to add per-session state to Streamlit.
|
2 |
+
Usage
|
3 |
+
-----
|
4 |
+
>>> import SessionState
|
5 |
+
>>>
|
6 |
+
>>> session_state = SessionState.get(user_name='', favorite_color='black')
|
7 |
+
>>> session_state.user_name
|
8 |
+
''
|
9 |
+
>>> session_state.user_name = 'Mary'
|
10 |
+
>>> session_state.favorite_color
|
11 |
+
'black'
|
12 |
+
Since you set user_name above, next time your script runs this will be the
|
13 |
+
result:
|
14 |
+
>>> session_state = get(user_name='', favorite_color='black')
|
15 |
+
>>> session_state.user_name
|
16 |
+
'Mary'
|
17 |
+
"""
|
18 |
+
try:
|
19 |
+
import streamlit.ReportThread as ReportThread
|
20 |
+
from streamlit.server.Server import Server
|
21 |
+
except Exception:
|
22 |
+
# Streamlit >= 0.65.0
|
23 |
+
import streamlit.report_thread as ReportThread
|
24 |
+
from streamlit.server.server import Server
|
25 |
+
|
26 |
+
|
27 |
+
class SessionState(object):
|
28 |
+
def __init__(self, **kwargs):
|
29 |
+
"""A new SessionState object.
|
30 |
+
Parameters
|
31 |
+
----------
|
32 |
+
**kwargs : any
|
33 |
+
Default values for the session state.
|
34 |
+
Example
|
35 |
+
-------
|
36 |
+
>>> session_state = SessionState(user_name='', favorite_color='black')
|
37 |
+
>>> session_state.user_name = 'Mary'
|
38 |
+
''
|
39 |
+
>>> session_state.favorite_color
|
40 |
+
'black'
|
41 |
+
"""
|
42 |
+
for key, val in kwargs.items():
|
43 |
+
setattr(self, key, val)
|
44 |
+
|
45 |
+
|
46 |
+
def get(**kwargs):
|
47 |
+
"""Gets a SessionState object for the current session.
|
48 |
+
Creates a new object if necessary.
|
49 |
+
Parameters
|
50 |
+
----------
|
51 |
+
**kwargs : any
|
52 |
+
Default values you want to add to the session state, if we're creating a
|
53 |
+
new one.
|
54 |
+
Example
|
55 |
+
-------
|
56 |
+
>>> session_state = get(user_name='', favorite_color='black')
|
57 |
+
>>> session_state.user_name
|
58 |
+
''
|
59 |
+
>>> session_state.user_name = 'Mary'
|
60 |
+
>>> session_state.favorite_color
|
61 |
+
'black'
|
62 |
+
Since you set user_name above, next time your script runs this will be the
|
63 |
+
result:
|
64 |
+
>>> session_state = get(user_name='', favorite_color='black')
|
65 |
+
>>> session_state.user_name
|
66 |
+
'Mary'
|
67 |
+
"""
|
68 |
+
# Hack to get the session object from Streamlit.
|
69 |
+
|
70 |
+
ctx = ReportThread.get_report_ctx()
|
71 |
+
|
72 |
+
this_session = None
|
73 |
+
|
74 |
+
current_server = Server.get_current()
|
75 |
+
if hasattr(current_server, '_session_infos'):
|
76 |
+
# Streamlit < 0.56
|
77 |
+
session_infos = Server.get_current()._session_infos.values()
|
78 |
+
else:
|
79 |
+
session_infos = Server.get_current()._session_info_by_id.values()
|
80 |
+
|
81 |
+
for session_info in session_infos:
|
82 |
+
s = session_info.session
|
83 |
+
if (
|
84 |
+
# Streamlit < 0.54.0
|
85 |
+
(hasattr(s, '_main_dg') and s._main_dg == ctx.main_dg)
|
86 |
+
or
|
87 |
+
# Streamlit >= 0.54.0
|
88 |
+
(not hasattr(s, '_main_dg') and s.enqueue == ctx.enqueue)
|
89 |
+
or
|
90 |
+
# Streamlit >= 0.65.2
|
91 |
+
(not hasattr(s, '_main_dg') and s._uploaded_file_mgr == ctx.uploaded_file_mgr)
|
92 |
+
):
|
93 |
+
this_session = s
|
94 |
+
|
95 |
+
if this_session is None:
|
96 |
+
raise RuntimeError(
|
97 |
+
"Oh noes. Couldn't get your Streamlit Session object. "
|
98 |
+
'Are you doing something fancy with threads?')
|
99 |
+
|
100 |
+
# Got the session object! Now let's attach some state into it.
|
101 |
+
|
102 |
+
if not hasattr(this_session, '_custom_session_state'):
|
103 |
+
this_session._custom_session_state = SessionState(**kwargs)
|
104 |
+
|
105 |
+
return this_session._custom_session_state
|
106 |
+
|
107 |
+
__all__ = ['get']
|
app/app.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import SessionState
|
3 |
+
from mtranslate import translate
|
4 |
+
from prompts import PROMPT_LIST
|
5 |
+
import random
|
6 |
+
import time
|
7 |
+
from transformers import pipeline, set_seed
|
8 |
+
|
9 |
+
# st.set_page_config(page_title="Image Search")
|
10 |
+
|
11 |
+
# vector_length = 128
|
12 |
+
model_name = "cahya/gpt2-small-indonesian-story"
|
13 |
+
|
14 |
+
|
15 |
+
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
|
16 |
+
def get_generator():
|
17 |
+
st.write("Loading the GPT2 model...")
|
18 |
+
text_generator = pipeline('text-generation', model=model_name)
|
19 |
+
return text_generator
|
20 |
+
|
21 |
+
|
22 |
+
@st.cache(suppress_st_warning=True)
|
23 |
+
def process(text: str, max_length: int = 100, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
|
24 |
+
temperature: float = 1.0, max_time: float = None):
|
25 |
+
st.write("Cache miss: process")
|
26 |
+
set_seed(42)
|
27 |
+
result = text_generator(text, max_length=max_length, do_sample=do_sample,
|
28 |
+
top_k=top_k, top_p=top_p, temperature=temperature, max_time=max_time)
|
29 |
+
return result
|
30 |
+
|
31 |
+
|
32 |
+
st.title("Indonesian Story Generator")
|
33 |
+
|
34 |
+
st.markdown(
|
35 |
+
"""
|
36 |
+
This application is a demo for Indonesian Story Generator using GPT2.
|
37 |
+
"""
|
38 |
+
)
|
39 |
+
session_state = SessionState.get(prompt=None, prompt_box=None, text=None)
|
40 |
+
ALL_PROMPTS = list(PROMPT_LIST.keys())+["Custom"]
|
41 |
+
prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS)-1)
|
42 |
+
# Update prompt
|
43 |
+
if session_state.prompt is None:
|
44 |
+
session_state.prompt = prompt
|
45 |
+
elif session_state.prompt is not None and (prompt != session_state.prompt):
|
46 |
+
session_state.prompt = prompt
|
47 |
+
session_state.prompt_box = None
|
48 |
+
session_state.text = None
|
49 |
+
else:
|
50 |
+
session_state.prompt = prompt
|
51 |
+
|
52 |
+
# Update prompt box
|
53 |
+
if session_state.prompt == "Custom":
|
54 |
+
session_state.prompt_box = "Enter your text here"
|
55 |
+
else:
|
56 |
+
if session_state.prompt is not None and session_state.prompt_box is None:
|
57 |
+
session_state.prompt_box = random.choice(PROMPT_LIST[session_state.prompt])
|
58 |
+
|
59 |
+
session_state.text = st.text_area("Enter text", session_state.prompt_box)
|
60 |
+
|
61 |
+
temp = st.sidebar.slider(
|
62 |
+
"Temperature",
|
63 |
+
value=1.0,
|
64 |
+
min_value=0.0,
|
65 |
+
max_value=100.0
|
66 |
+
)
|
67 |
+
|
68 |
+
top_k = st.sidebar.number_input(
|
69 |
+
"Top k",
|
70 |
+
value=25
|
71 |
+
)
|
72 |
+
|
73 |
+
top_p = st.sidebar.number_input(
|
74 |
+
"Top p",
|
75 |
+
value=0.95
|
76 |
+
)
|
77 |
+
|
78 |
+
text_generator = get_generator()
|
79 |
+
if st.button("Run"):
|
80 |
+
with st.spinner(text="Getting results..."):
|
81 |
+
st.subheader("Result")
|
82 |
+
time_start = time.time()
|
83 |
+
result = process(text=session_state.text, top_k=int(top_k), top_p=float(top_p))
|
84 |
+
time_end = time.time()
|
85 |
+
time_diff = time_end-time_start
|
86 |
+
#print(f"Text generated in {time_diff} seconds")
|
87 |
+
result = result[0]["generated_text"]
|
88 |
+
st.write(result.replace("\n", " \n"))
|
89 |
+
st.text("Translation")
|
90 |
+
translation = translate(result, "en", "id")
|
91 |
+
st.write(translation.replace("\n", " \n"))
|
92 |
+
|
93 |
+
# Reset state
|
94 |
+
session_state.prompt = None
|
95 |
+
session_state.prompt_box = None
|
96 |
+
session_state.text = None
|
app/prompts.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
PROMPT_LIST = {
|
2 |
+
"City": [
|
3 |
+
"Vienna is the national capital, largest city, and one of nine states of Austria. Vienna is Austria's most populous city, with about 2 million inhabitants, and its cultural, economic, and political centre. It is the 6th-largest city by population within city limits in the European Union",
|
4 |
+
"Sydney is the capital city of the state of New South Wales, and the most populous city in Australia and Oceania.",
|
5 |
+
"Ubud is a town on the Indonesian island of Bali in Ubud District, located amongst rice paddies and steep ravines in the central foothills of the Gianyar regency. Promoted as an arts and culture centre, it has developed a large tourism industry.",
|
6 |
+
"Jakarta is the capital of Indonesia"
|
7 |
+
],
|
8 |
+
"People": [
|
9 |
+
"Albert Einstein was a German-born theoretical physicist, widely acknowledged to be one of the greatest physicists of all time. Einstein is known for developing the theory of relativity, but he also made important contributions to the development of the theory of quantum mechanics.",
|
10 |
+
"Geoffrey Everest Hinton is a British-Canadian cognitive psychologist and computer scientist, most noted for his work on artificial neural networks.",
|
11 |
+
"Pramoedya Ananta Toer was an Indonesian author of novels, short stories, essays, polemics and histories of his homeland and its people."
|
12 |
+
],
|
13 |
+
"Building": [
|
14 |
+
"Borobudur is a 7th-century Mahayana Buddhist temple in Indonesia. It is the world's largest Buddhist temple. The temple consists of nine stacked platforms, six square and three circular, topped by a central dome. It is decorated with 2,672 relief panels and 504 Buddha statues.",
|
15 |
+
"The Statue of Liberty is a colossal neoclassical sculpture on Liberty Island in New York Harbor within New York City, in the United States.",
|
16 |
+
"Machu Picchu is a 15th-century Inca citadel, located in the Eastern Cordillera of southern Peru, on a 2,430-meter mountain ridge."
|
17 |
+
]
|
18 |
+
}
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy
|
2 |
+
torch
|
3 |
+
transformers
|
4 |
+
datasets
|
5 |
+
mtranslate
|
6 |
+
# streamlit version 0.67.1 is needed due to issue with caching
|
7 |
+
streamlit==0.67.1
|