kevin-yang committed
Commit 4f76eaa • 1 Parent(s): 5d01a65

Add application file

Files changed (1)
  1. app.py +55 -0
app.py ADDED
@@ -0,0 +1,55 @@
+ import streamlit as st
+ from transformers import GPT2LMHeadModel, PreTrainedTokenizerFast
+
+
+ # Cache the model and tokenizer so Streamlit does not reload them on every rerun.
+ @st.cache_resource
+ def load_model():
+     model = GPT2LMHeadModel.from_pretrained("jason9693/soongsil-univ-gpt-v1")
+     tokenizer = PreTrainedTokenizerFast.from_pretrained("jason9693/soongsil-univ-gpt-v1")
+     return model, tokenizer
+
+
+ model, tokenizer = load_model()
+
+ # Each community board maps to the control token used during fine-tuning.
+ category_map = {
+     "Soongsil Univ Everytime": "<unused5>",  # 숭실대 에타
+     "Everyone's Romance": "<unused3>",  # 모두의 연애
+     "College Chat Room": "<unused4>",  # 대학생 잡담방
+ }
+
+ st.markdown("""# University Community KoGPT2 : Soongsil Univ Everytime Bot
+
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1p6DIxsesi3eJNPwFwvMw0MeM5LkSGoPW?usp=sharing) [![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/jason9693/UCK-GPT2/issues) ![GitHub](https://img.shields.io/github/license/jason9693/UCK-GPT2)
+
+ ## University Community Post Generator
+
+ This model fine-tunes [KoGPT2](https://github.com/SKT-AI/KoGPT2), released by SKT-AI, to generate university community posts. It was trained on a total of 220,000 posts from Everytime and CampusPick, and training took roughly **3 days**.
+
+ * [GPT paper review link](https://www.notion.so/Improve-Language-Understanding-by-Generative-Pre-Training-GPT-afb4b5ef6e984961ac022b700c152b6b)
+
+ ## Demo
+ """)
+
+ seed = st.text_input("Seed", "조만식기념관")  # default seed: "Cho Man-sik Memorial Hall"; the model expects Korean input
+ category = st.selectbox("Category", list(category_map.keys()))
+ go = st.button("Generate")
+
+ st.markdown("## Generated Result")
+ if go:
+     # Prepend the board's control token so generation is conditioned on it.
+     input_context = category_map[category] + seed
+     input_ids = tokenizer(input_context, return_tensors="pt")
+     outputs = model.generate(
+         input_ids=input_ids["input_ids"],
+         max_length=250,
+         num_return_sequences=1,
+         no_repeat_ngram_size=3,  # block verbatim 3-gram repeats
+         repetition_penalty=2.0,  # further discourage repetition
+         do_sample=True,  # sample instead of decoding greedily
+         use_cache=True,
+         eos_token_id=tokenizer.eos_token_id,
+     )
+     # The fine-tuned model emits <unused2> for line breaks; restore them.
+     st.write(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].replace("<unused2>", "\n"))
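For quick checks outside the UI, the same checkpoint can be driven directly from Python. The sketch below mirrors the app's generate call and its control-token-plus-seed prompt format; the specific board token `<unused4>` and the demo seed are only examples, and the generation parameters are copied from app.py.

```python
from transformers import GPT2LMHeadModel, PreTrainedTokenizerFast

# Same checkpoint the app loads.
model = GPT2LMHeadModel.from_pretrained("jason9693/soongsil-univ-gpt-v1")
tokenizer = PreTrainedTokenizerFast.from_pretrained("jason9693/soongsil-univ-gpt-v1")

# Prompt format used by the app: category control token followed by Korean seed text.
prompt = "<unused4>" + "조만식기념관"  # chat-room board token + demo seed
input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]

outputs = model.generate(
    input_ids=input_ids,
    max_length=250,
    num_return_sequences=1,
    no_repeat_ngram_size=3,
    repetition_penalty=2.0,
    do_sample=True,
    use_cache=True,
    eos_token_id=tokenizer.eos_token_id,
)

# The model encodes line breaks as <unused2>; map them back to newlines.
text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
print(text.replace("<unused2>", "\n"))
```

The app itself is served with Streamlit's CLI, e.g. `streamlit run app.py`.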