taka-yamakoshi committed on
Commit 8a204f8
1 Parent(s): e9daec4

add more functions

Files changed (1)
  1. app.py +60 -30
app.py CHANGED
@@ -2,16 +2,46 @@ import pandas as pd
 import streamlit as st
 import numpy as np
 import torch
-from transformers import AlbertTokenizer
 import io
 import time
 
 @st.cache(show_spinner=True,allow_output_mutation=True)
 def load_model(model_name):
-    if model_name.startswith('albert'):
+    if model_name.startswith('bert'):
+        from transformers import BertTokenizer
+        tokenizer = BertTokenizer.from_pretrained(model_name)
+    elif model_name.startswith('gpt2'):
+        from transformers import GPT2Tokenizer
+        tokenizer = GPT2Tokenizer.from_pretrained(model_name)
+    elif model_name.startswith('roberta'):
+        from transformers import RobertaTokenizer
+        tokenizer = RobertaTokenizer.from_pretrained(model_name)
+    elif model_name.startswith('albert'):
+        from transformers import AlbertTokenizer
         tokenizer = AlbertTokenizer.from_pretrained(model_name)
     return tokenizer
 
+def generate_markdown(text,color='black',font='Arial',size=20):
+    return f"<p style='text-align:center; color:{color}; font-family:{font}; font-size:{size}px;'>{text}</p>"
+
+def TokenizeText(sentence):
+    if len(sentence)>0:
+        input_sent = tokenizer(sentence)['input_ids']
+        encoded_sent = [str(token) for token in input_sent[1:-1]]
+        decoded_sent = [tokenizer.decode([token]) for token in input_sent[1:-1]]
+        num_tokens = len(decoded_sent)
+
+        #char_nums = [len(word)+2 for word in decoded_sent]
+        #word_cols = st.columns(char_nums)
+        #for word_col,word in zip(word_cols,decoded_sent):
+            #with word_col:
+                #st.write(word)
+        st.write(' '.join(encoded_sent))
+        st.write(' '.join(decoded_sent))
+        st.markdown(generate_markdown(f'{num_tokens} tokens'), unsafe_allow_html=True)
+
+        return num_tokens
+
 
 if __name__=='__main__':
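The per-architecture if/elif chain in load_model could be collapsed with transformers' AutoTokenizer, which resolves the concrete tokenizer class from the checkpoint name. A minimal sketch, assuming a transformers release that ships AutoTokenizer (it covers all four model families this app offers):

    import streamlit as st
    from transformers import AutoTokenizer

    @st.cache(show_spinner=True,allow_output_mutation=True)
    def load_model(model_name):
        # AutoTokenizer picks BertTokenizer, GPT2Tokenizer, RobertaTokenizer,
        # AlbertTokenizer, etc. based on the checkpoint's configuration
        return AutoTokenizer.from_pretrained(model_name)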
 
@@ -43,34 +73,34 @@ if __name__=='__main__':
     st.markdown(hide_table_row_index, unsafe_allow_html=True)
 
     # Title
-    st.markdown("<p style='text-align:center; color:black; font-family:Arial; font-size:32px;'>Tokenizer Demo</p>", unsafe_allow_html=True)
+    st.markdown(generate_markdown('Tokenizer Demo',size=32), unsafe_allow_html=True)
+
+    # Select and load the tokenizer
+    tokenizer_name = st.selectbox('Choose the tokenizer from below',
+                                  ('bert-base-uncased','bert-large-cased',
+                                   'gpt2','gpt2-large',
+                                   'roberta-base','roberta-large',
+                                   'albert-base-v2','albert-xxlarge-v2'),index=7)
+    tokenizer = load_model(tokenizer_name)
 
-    tokenizer = load_model('albert-xxlarge-v2')
-    sent_cols = st.columns(2)
-    num_tokens = {}
-    sents = {}
-    for sent_id, sent_col in enumerate(sent_cols):
-        with sent_col:
-            sentence = st.text_input(f'Sentence {sent_id+1}')
-            sents[f'sent_{sent_id+1}'] = sentence
-            if len(sentence)>0:
-                input_sent = tokenizer(sentence)['input_ids']
-                encoded_sent = [str(token) for token in input_sent[1:-1]]
-                decoded_sent = [tokenizer.decode([token]) for token in input_sent[1:-1]]
-                num_tokens[f'sent_{sent_id+1}'] = len(decoded_sent)
-
-                #char_nums = [len(word)+2 for word in decoded_sent]
-                #word_cols = st.columns(char_nums)
-                #for word_col,word in zip(word_cols,decoded_sent):
-                    #with word_col:
-                        #st.write(word)
-                st.write(' '.join(encoded_sent))
-                st.write(' '.join(decoded_sent))
-                st.markdown(f"<p style='text-align: center; color: black; font-family:Arial; font-size:20px;'>{len(decoded_sent)} tokens </p>", unsafe_allow_html=True)
-
-    if len(sents['sent_1'])>0 and len(sents['sent_2'])>0:
-        st.markdown("<p style='text-align:center; color:black; font-family:Arial; font-size:16px;'>Result&colon; </p>", unsafe_allow_html=True)
-        if num_tokens[f'sent_1']==num_tokens[f'sent_2']:
-            st.markdown("<p style='text-align:center; color:MediumAquamarine; font-family:Arial; font-size:20px;'>Matched! </p>", unsafe_allow_html=True)
-        else:
-            st.markdown("<p style='text-align:center; color:Salmon; font-family:Arial; font-size:20px;'>Not Matched... </p>", unsafe_allow_html=True)
+    comparison_mode = st.checkbox('Compare two texts')
+    if comparison_mode:
+        sent_cols = st.columns(2)
+        num_tokens = {}
+        sents = {}
+        for sent_id, sent_col in enumerate(sent_cols):
+            with sent_col:
+                sentence = st.text_input(f'Text {sent_id+1}')
+                sents[f'sent_{sent_id+1}'] = sentence
+                num_tokens[f'sent_{sent_id+1}'] = TokenizeText(sentence)
+
+        if len(sents['sent_1'])>0 and len(sents['sent_2'])>0:
+            st.markdown(generate_markdown('Result&colon; ',size=16), unsafe_allow_html=True)
+            if num_tokens['sent_1']==num_tokens['sent_2']:
+                st.markdown(generate_markdown('Matched! ',color='MediumAquamarine'), unsafe_allow_html=True)
+            else:
+                st.markdown(generate_markdown('Not Matched... ',color='Salmon'), unsafe_allow_html=True)
+
+    else:
+        sentence = st.text_input('Text')
+        num_tokens = TokenizeText(sentence)
 
 
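On the caching decorator retained here: st.cache(allow_output_mutation=True) was the idiom for caching unhashable singletons like tokenizers when this commit was written; later Streamlit releases (1.18+) deprecate st.cache in favor of st.cache_resource. A sketch of the equivalent under that newer API, assuming Streamlit >= 1.18:

    import streamlit as st
    from transformers import AlbertTokenizer

    @st.cache_resource(show_spinner=True)  # replaces st.cache(allow_output_mutation=True) for shared objects
    def load_model(model_name):
        return AlbertTokenizer.from_pretrained(model_name)

The app itself runs the usual Streamlit way: streamlit run app.py.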