kcarnold commited on
Commit
325ca0f
1 Parent(s): 47f3f6e

Just playing around

Browse files
Files changed (1) hide show
  1. app.py +27 -13
app.py CHANGED
@@ -1,5 +1,8 @@
1
  import streamlit as st
2
 
 
 
 
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
@@ -10,18 +13,18 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
 
11
  from transformers import MarianMTModel, MarianTokenizer
12
 
13
- model_name = st.radio("Select a model", [
14
  'Helsinki-NLP/opus-mt-roa-en',
15
  'Helsinki-NLP/opus-mt-en-roa',
16
- 'other'
17
- ])
18
 
19
- if model_name == 'other':
20
- model_name = st.text_input("Enter model name", 'Helsinki-NLP/opus-mt-ROMANCE-en')
21
 
22
- if not hasattr(st, "cache_resource"):
23
- st.cache_resource = st.experimental_singleton
24
 
 
 
25
 
26
  @st.cache_resource
27
  def get_tokenizer(model_name):
@@ -42,7 +45,9 @@ else:
42
  lang_code = None
43
 
44
 
45
- input_text = st.text_input("Enter text to translate", "Hola, mi nombre es Juan")
 
 
46
  input_text = input_text.strip()
47
  if not input_text:
48
  st.stop()
@@ -51,7 +56,6 @@ if not input_text:
51
  if lang_code:
52
  input_text = f"{lang_code} {input_text}"
53
 
54
- output_so_far = st.text_input("Enter text translated so far", "Hello, my")
55
 
56
  input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
57
 
@@ -60,8 +64,18 @@ example_generations = model.generate(
60
  num_beams=4,
61
  num_return_sequences=4,
62
  )
63
- st.write("Example generations:")
64
- st.write(tokenizer.batch_decode(example_generations, skip_special_tokens=True))
 
 
 
 
 
 
 
 
 
 
65
 
66
  # tokenize the output so far
67
  with tokenizer.as_target_tokenizer():
@@ -93,10 +107,10 @@ with tokenizer.as_target_tokenizer():
93
  })
94
 
95
  st.subheader("Most likely next tokens")
96
- st.write(probs_table)
97
 
98
  if len(decoder_input_ids) > 1:
99
- st.subheader("Loss by token")
100
  loss_table = pd.DataFrame({
101
  'token': [tokenizer.decode(token_id) for token_id in decoder_input_ids[1:]],
102
  'loss': F.cross_entropy(model_output.logits[0, :-1], torch.tensor(decoder_input_ids[1:]).to(device), reduction='none').cpu()
 
1
  import streamlit as st
2
 
3
+ if not hasattr(st, "cache_resource"):
4
+ st.cache_resource = st.experimental_singleton
5
+
6
  import torch
7
  import torch.nn as nn
8
  import torch.nn.functional as F
 
13
 
14
  from transformers import MarianMTModel, MarianTokenizer
15
 
16
+ model_options = [
17
  'Helsinki-NLP/opus-mt-roa-en',
18
  'Helsinki-NLP/opus-mt-en-roa',
19
+ ]
 
20
 
21
+ col1, col2 = st.columns(2)
 
22
 
23
+ with col1:
24
+ model_name = st.selectbox("Select a model", model_options + ['other'])
25
 
26
+ if model_name == 'other':
27
+ model_name = st.text_input("Enter model name", model_options[0])
28
 
29
  @st.cache_resource
30
  def get_tokenizer(model_name):
 
45
  lang_code = None
46
 
47
 
48
+ with col2:
49
+ input_text = st.text_input("Enter text to translate", "Hola, mi nombre es Juan")
50
+
51
  input_text = input_text.strip()
52
  if not input_text:
53
  st.stop()
 
56
  if lang_code:
57
  input_text = f"{lang_code} {input_text}"
58
 
 
59
 
60
  input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
61
 
 
64
  num_beams=4,
65
  num_return_sequences=4,
66
  )
67
+
68
+ col1, col2 = st.columns(2)
69
+ with col1:
70
+ st.write("Example generations:")
71
+ st.write('\n'.join(
72
+ '- ' + translation
73
+ for translation in tokenizer.batch_decode(example_generations, skip_special_tokens=True)))
74
+
75
+ with col2:
76
+ example_first_word = tokenizer.decode(example_generations[0, 1])
77
+ output_so_far = st.text_input("Enter text translated so far", example_first_word)
78
+
79
 
80
  # tokenize the output so far
81
  with tokenizer.as_target_tokenizer():
 
107
  })
108
 
109
  st.subheader("Most likely next tokens")
110
+ st.table(probs_table.style.hide(axis='index'))
111
 
112
  if len(decoder_input_ids) > 1:
113
+ st.subheader("Loss by already-generated token")
114
  loss_table = pd.DataFrame({
115
  'token': [tokenizer.decode(token_id) for token_id in decoder_input_ids[1:]],
116
  'loss': F.cross_entropy(model_output.logits[0, :-1], torch.tensor(decoder_input_ids[1:]).to(device), reduction='none').cpu()