cahya committed on
Commit
59cee12
1 Parent(s): 8bf08bf
Files changed (5) hide show
  1. README.md +6 -7
  2. app/SessionState.py +107 -0
  3. app/app.py +96 -0
  4. app/prompts.py +18 -0
  5. requirements.txt +7 -0
README.md CHANGED
@@ -1,11 +1,11 @@
1
  ---
2
- title: Indonesian Story
3
  emoji: 🔥
4
- colorFrom: pink
5
- colorTo: indigo
6
  sdk: streamlit
7
- app_file: app.py
8
- pinned: false
9
  ---
10
  # Configuration
11
  `title`: _string_
@@ -24,5 +24,4 @@ Path is relative to the root of the repository.
24
  `pinned`: _boolean_
25
  Whether the Space stays on top of your list.
26
 
27
- # Indonesian Story
28
- GPT2 generated Stories / Poems
 
1
  ---
2
+ title: Indonesian Story Generator
3
  emoji: 🔥
4
+ colorFrom: green
5
+ colorTo: orange
6
  sdk: streamlit
7
+ app_file: app/app.py
8
+ pinned: true
9
  ---
10
  # Configuration
11
  `title`: _string_
 
24
  `pinned`: _boolean_
25
  Whether the Space stays on top of your list.
26
 
27
+ # Indonesian Story
 
app/SessionState.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hack to add per-session state to Streamlit.
2
+ Usage
3
+ -----
4
+ >>> import SessionState
5
+ >>>
6
+ >>> session_state = SessionState.get(user_name='', favorite_color='black')
7
+ >>> session_state.user_name
8
+ ''
9
+ >>> session_state.user_name = 'Mary'
10
+ >>> session_state.favorite_color
11
+ 'black'
12
+ Since you set user_name above, next time your script runs this will be the
13
+ result:
14
+ >>> session_state = get(user_name='', favorite_color='black')
15
+ >>> session_state.user_name
16
+ 'Mary'
17
+ """
18
+ try:
19
+ import streamlit.ReportThread as ReportThread
20
+ from streamlit.server.Server import Server
21
+ except Exception:
22
+ # Streamlit >= 0.65.0
23
+ import streamlit.report_thread as ReportThread
24
+ from streamlit.server.server import Server
25
+
26
+
27
class SessionState(object):
    """Plain attribute container for the custom state of one session."""

    def __init__(self, **kwargs):
        """Create a new SessionState object.

        Parameters
        ----------
        **kwargs : any
            Default values for the session state. Every keyword becomes an
            attribute of the same name on the new object.

        Example
        -------
        >>> session_state = SessionState(user_name='', favorite_color='black')
        >>> session_state.user_name
        ''
        >>> session_state.user_name = 'Mary'
        >>> session_state.favorite_color
        'black'
        """
        for key, val in kwargs.items():
            setattr(self, key, val)

    def __repr__(self):
        """Return a debugging representation listing the stored attributes."""
        fields = ', '.join(
            '{}={!r}'.format(key, val) for key, val in sorted(vars(self).items())
        )
        return '{}({})'.format(type(self).__name__, fields)
44
+
45
+
46
def get(**kwargs):
    """Get the SessionState object for the current session.

    Creates a new object if necessary.

    Parameters
    ----------
    **kwargs : any
        Default values you want to add to the session state, if we're
        creating a new one. They are ignored when the session already has a
        state object attached.

    Returns
    -------
    SessionState
        The state object attached to the current Streamlit session.

    Raises
    ------
    RuntimeError
        If no report context is available for the calling thread, or no
        running session matches that context.

    Example
    -------
    >>> session_state = get(user_name='', favorite_color='black')
    >>> session_state.user_name
    ''
    >>> session_state.user_name = 'Mary'
    >>> session_state.favorite_color
    'black'

    Since you set user_name above, next time your script runs this will be
    the result:

    >>> session_state = get(user_name='', favorite_color='black')
    >>> session_state.user_name
    'Mary'
    """
    not_found_message = (
        "Oh noes. Couldn't get your Streamlit Session object. "
        'Are you doing something fancy with threads?')

    # Hack to get the session object from Streamlit.
    ctx = ReportThread.get_report_ctx()
    if ctx is None:
        # Without a context the attribute comparisons below would fail with
        # an AttributeError instead of the intended, more helpful error.
        raise RuntimeError(not_found_message)

    current_server = Server.get_current()
    if hasattr(current_server, '_session_infos'):
        # Streamlit < 0.56
        session_infos = current_server._session_infos.values()
    else:
        session_infos = current_server._session_info_by_id.values()

    this_session = None
    for session_info in session_infos:
        s = session_info.session
        if hasattr(s, '_main_dg'):
            # Streamlit < 0.54.0
            matches = s._main_dg == ctx.main_dg
        else:
            matches = (
                # Streamlit >= 0.54.0
                s.enqueue == ctx.enqueue
                or
                # Streamlit >= 0.65.2
                s._uploaded_file_mgr == ctx.uploaded_file_mgr
            )
        if matches:
            # No early exit: as before, the last matching session wins.
            this_session = s

    if this_session is None:
        raise RuntimeError(not_found_message)

    # Got the session object! Now let's attach some state into it.
    if not hasattr(this_session, '_custom_session_state'):
        this_session._custom_session_state = SessionState(**kwargs)

    return this_session._custom_session_state
106
+
107
+ __all__ = ['get']
app/app.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import SessionState
3
+ from mtranslate import translate
4
+ from prompts import PROMPT_LIST
5
+ import random
6
+ import time
7
+ from transformers import pipeline, set_seed
8
+
9
# Model identifier handed to the transformers text-generation pipeline.
model_name = "cahya/gpt2-small-indonesian-story"
13
+
14
+
15
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def get_generator():
    """Build and return the text-generation pipeline for ``model_name``.

    A loading notice is written to the page whenever the body actually runs;
    the function is wrapped in ``st.cache``.
    """
    st.write("Loading the GPT2 model...")
    return pipeline('text-generation', model=model_name)
20
+
21
+
22
@st.cache(suppress_st_warning=True)
def process(text: str, max_length: int = 100, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
            temperature: float = 1.0, max_time: float = None):
    """Generate a continuation of ``text`` with the module-level generator.

    The seed is fixed to 42 before every generation. All sampling options
    are forwarded unchanged to ``text_generator``, whose return value is
    passed straight back to the caller.
    """
    st.write("Cache miss: process")
    set_seed(42)
    sampling_options = dict(
        max_length=max_length,
        do_sample=do_sample,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        max_time=max_time,
    )
    return text_generator(text, **sampling_options)
30
+
31
+
32
st.title("Indonesian Story Generator")

st.markdown(
    """
This application is a demo for Indonesian Story Generator using GPT2.
"""
)

# Per-session state: selected category, text shown in the box, entered text.
session_state = SessionState.get(prompt=None, prompt_box=None, text=None)
ALL_PROMPTS = list(PROMPT_LIST.keys()) + ["Custom"]
prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS) - 1)

# Switching to a different category invalidates the text chosen for the
# previous one; in every case the current selection is then recorded.
previous_prompt = session_state.prompt
if previous_prompt is not None and prompt != previous_prompt:
    session_state.prompt_box = None
    session_state.text = None
session_state.prompt = prompt

# "Custom" always shows the placeholder; any other category gets a random
# seed text, picked only while the box is still empty.
if session_state.prompt == "Custom":
    session_state.prompt_box = "Enter your text here"
elif session_state.prompt is not None and session_state.prompt_box is None:
    session_state.prompt_box = random.choice(PROMPT_LIST[session_state.prompt])

session_state.text = st.text_area("Enter text", session_state.prompt_box)
60
+
61
# Sampling controls shown in the sidebar.
# The temperature lower bound is strictly positive: the value is now
# forwarded to the generator, and transformers rejects a temperature of 0.
temp = st.sidebar.slider(
    "Temperature",
    value=1.0,
    min_value=0.1,
    max_value=100.0
)

top_k = st.sidebar.number_input(
    "Top k",
    value=25
)

top_p = st.sidebar.number_input(
    "Top p",
    value=0.95
)

# Bound at module level because process() reads this global.
text_generator = get_generator()
if st.button("Run"):
    with st.spinner(text="Getting results..."):
        st.subheader("Result")
        time_start = time.time()
        # Pass the slider value through; previously it was read but never
        # used, so the Temperature control had no effect on the output.
        result = process(
            text=session_state.text,
            top_k=int(top_k),
            top_p=float(top_p),
            temperature=float(temp),
        )
        time_end = time.time()
        time_diff = time_end - time_start
        # print(f"Text generated in {time_diff} seconds")
        result = result[0]["generated_text"]
        # Two trailing spaces before the newline force a Markdown line break.
        st.write(result.replace("\n", "  \n"))
        st.text("Translation")
        translation = translate(result, "en", "id")
        st.write(translation.replace("\n", "  \n"))

        # Reset state
        # NOTE(review): indentation was lost in the scraped diff; the reset is
        # assumed to belong to the Run branch - confirm against the repo.
        session_state.prompt = None
        session_state.prompt_box = None
        session_state.text = None
app/prompts.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Seed texts offered by the app, grouped by category name. Each key maps to a
# list of candidate prompts; app.py picks one at random for the category the
# user selects.
# NOTE(review): the seed texts are English while app.py loads an Indonesian
# story model - confirm whether these are placeholders.
PROMPT_LIST = {
    # Cities
    "City": [
        "Vienna is the national capital, largest city, and one of nine states of Austria. Vienna is Austria's most populous city, with about 2 million inhabitants, and its cultural, economic, and political centre. It is the 6th-largest city by population within city limits in the European Union",
        "Sydney is the capital city of the state of New South Wales, and the most populous city in Australia and Oceania.",
        "Ubud is a town on the Indonesian island of Bali in Ubud District, located amongst rice paddies and steep ravines in the central foothills of the Gianyar regency. Promoted as an arts and culture centre, it has developed a large tourism industry.",
        "Jakarta is the capital of Indonesia"
    ],
    # Notable people
    "People": [
        "Albert Einstein was a German-born theoretical physicist, widely acknowledged to be one of the greatest physicists of all time. Einstein is known for developing the theory of relativity, but he also made important contributions to the development of the theory of quantum mechanics.",
        "Geoffrey Everest Hinton is a British-Canadian cognitive psychologist and computer scientist, most noted for his work on artificial neural networks.",
        "Pramoedya Ananta Toer was an Indonesian author of novels, short stories, essays, polemics and histories of his homeland and its people."
    ],
    # Landmarks and monuments
    "Building": [
        "Borobudur is a 7th-century Mahayana Buddhist temple in Indonesia. It is the world's largest Buddhist temple. The temple consists of nine stacked platforms, six square and three circular, topped by a central dome. It is decorated with 2,672 relief panels and 504 Buddha statues.",
        "The Statue of Liberty is a colossal neoclassical sculpture on Liberty Island in New York Harbor within New York City, in the United States.",
        "Machu Picchu is a 15th-century Inca citadel, located in the Eastern Cordillera of southern Peru, on a 2,430-meter mountain ridge."
    ]
}
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ numpy
2
+ torch
3
+ transformers
4
+ datasets
5
+ mtranslate
6
+ # streamlit version 0.67.1 is needed due to issue with caching
7
+ streamlit==0.67.1