kcarnold committed on
Commit
5aa0ce2
1 Parent(s): 58a8c64

Multipageify

Browse files
Files changed (3) hide show
  1. app.py +0 -108
  2. rewrite.py → pages/1_Rewrite.py +0 -0
  3. pages/2_Highlights.py +109 -0
app.py CHANGED
@@ -1,109 +1 @@
1
  import streamlit as st
2
- import pandas as pd
3
- import html
4
-
5
- model_options = [
6
- 'API',
7
- 'google/gemma-1.1-2b-it',
8
- 'google/gemma-1.1-7b-it'
9
- ]
10
-
11
- model_name = st.selectbox("Select a model", model_options + ['other'])
12
-
13
- if model_name == 'other':
14
- model_name = st.text_input("Enter model name", model_options[0])
15
-
16
- @st.cache_resource
17
- def get_tokenizer(model_name):
18
- from transformers import AutoTokenizer
19
- return AutoTokenizer.from_pretrained(model_name).from_pretrained(model_name)
20
-
21
- @st.cache_resource
22
- def get_model(model_name):
23
- import torch
24
- from transformers import AutoModelForCausalLM
25
- model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto', torch_dtype=torch.bfloat16)
26
- print(f"Loaded model, {model.num_parameters():,d} parameters.")
27
- return model
28
-
29
- prompt = st.text_area("Prompt", "Rewrite this document to be more clear and concise.")
30
- doc = st.text_area("Document", "This is a document that I would like to have rewritten to be more concise.")
31
- updated_doc = st.text_area("Updated Doc", help="Your edited document. Leave this blank to use your original document.")
32
-
33
-
34
- def get_spans_local(prompt, doc, updated_doc):
35
- import torch
36
-
37
- tokenizer = get_tokenizer(model_name)
38
- model = get_model(model_name)
39
-
40
-
41
- messages = [
42
- {
43
- "role": "user",
44
- "content": f"{prompt}\n\n{doc}",
45
- },
46
- ]
47
- tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")[0]
48
- assert len(tokenized_chat.shape) == 1
49
-
50
- if len(updated_doc.strip()) == 0:
51
- updated_doc = doc
52
- updated_doc_ids = tokenizer(updated_doc, return_tensors='pt')['input_ids'][0]
53
- joined_ids = torch.cat([tokenized_chat, updated_doc_ids[1:]])
54
-
55
- with torch.no_grad():
56
- logits = model(joined_ids[None].to(model.device)).logits[0].cpu()
57
-
58
- spans = []
59
- length_so_far = 0
60
- for idx in range(len(tokenized_chat), len(joined_ids)):
61
- probs = logits[idx - 1].softmax(dim=-1)
62
- token_id = joined_ids[idx]
63
- token = tokenizer.decode(token_id)
64
- token_loss = -probs[token_id].log().item()
65
- most_likely_token_id = probs.argmax()
66
- print(idx, token, token_loss, tokenizer.decode(most_likely_token_id))
67
- spans.append(dict(
68
- start=length_so_far,
69
- end=length_so_far + len(token),
70
- token=token,
71
- token_loss=token_loss,
72
- most_likely_token=tokenizer.decode(most_likely_token_id)
73
- ))
74
- length_so_far += len(token)
75
- return spans
76
-
77
- def get_highlights_api(prompt, doc, updated_doc):
78
- # Make a request to the API. prompt and doc are query parameters:
79
- # https://tools.kenarnold.org/api/highlights?prompt=Rewrite%20this%20document&doc=This%20is%20a%20document
80
- # The response is a JSON array
81
- import requests
82
- response = requests.get("https://tools.kenarnold.org/api/highlights", params=dict(prompt=prompt, doc=doc, updated_doc=updated_doc))
83
- return response.json()['highlights']
84
-
85
- if model_name == 'API':
86
- spans = get_highlights_api(prompt, doc, updated_doc)
87
- else:
88
- spans = get_spans_local(prompt, doc, updated_doc)
89
-
90
- if len(spans) < 2:
91
- st.write("No spans found.")
92
- st.stop()
93
-
94
- highest_loss = max(span['token_loss'] for span in spans[1:])
95
- for span in spans:
96
- span['loss_ratio'] = span['token_loss'] / highest_loss
97
-
98
- html_out = ''
99
- for span in spans:
100
- is_different = span['token'] != span['most_likely_token']
101
- html_out += '<span style="color: {color}" title="{title}">{orig_token}</span>'.format(
102
- color="blue" if is_different else "black",
103
- title=html.escape(span["most_likely_token"]).replace('\n', ' '),
104
- orig_token=html.escape(span["token"]).replace('\n', '<br>')
105
- )
106
- html_out = f"<p style=\"background: white;\">{html_out}</p>"
107
-
108
- st.write(html_out, unsafe_allow_html=True)
109
- st.write(pd.DataFrame(spans)[['token', 'token_loss', 'most_likely_token', 'loss_ratio']])
 
1
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
rewrite.py → pages/1_Rewrite.py RENAMED
File without changes
pages/2_Highlights.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import html
4
+
5
# Preset choices for the model picker: 'API' plus two Gemma model names.
model_options = ['API', 'google/gemma-1.1-2b-it', 'google/gemma-1.1-7b-it']

# A trailing 'other' entry lets the user type an arbitrary model name instead.
model_name = st.selectbox("Select a model", [*model_options, 'other'])
if model_name == 'other':
    # The free-text box is pre-filled with the first preset option.
    model_name = st.text_input("Enter model name", value=model_options[0])
15
+
16
@st.cache_resource
def get_tokenizer(model_name):
    """Load the tokenizer for ``model_name``.

    The result is wrapped in ``st.cache_resource`` so repeated calls with the
    same ``model_name`` can reuse the already-loaded tokenizer.

    Args:
        model_name: Model identifier passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        The tokenizer object returned by ``AutoTokenizer.from_pretrained``.
    """
    # Function-scope import; presumably so that code paths which never load a
    # local model do not need `transformers` -- TODO confirm.
    from transformers import AutoTokenizer

    # A single from_pretrained call is enough; chaining a second one on the
    # returned instance only repeated the load.
    return AutoTokenizer.from_pretrained(model_name)
20
+
21
@st.cache_resource
def get_model(model_name):
    """Load a causal language model for ``model_name`` and report its size.

    Args:
        model_name: Model identifier passed to
            ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        The loaded model, requested in bfloat16 with ``device_map='auto'``.
    """
    import torch
    from transformers import AutoModelForCausalLM

    # Loading options kept together: automatic device placement, bfloat16 weights.
    load_kwargs = dict(device_map='auto', torch_dtype=torch.bfloat16)
    model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)

    # Report the parameter count on stdout once the model is loaded.
    print(f"Loaded model, {model.num_parameters():,d} parameters.")
    return model
28
+
29
# Instruction given to the model, pre-filled with a default rewrite request.
prompt = st.text_area(
    "Prompt",
    value="Rewrite this document to be more clear and concise.",
)

# The original document the instruction applies to.
doc = st.text_area(
    "Document",
    value="This is a document that I would like to have rewritten to be more concise.",
)

# Optional edited version of the document; no default text is supplied.
updated_doc = st.text_area(
    "Updated Doc",
    help="Your edited document. Leave this blank to use your original document.",
)
32
+
33
+
34
def get_spans_local(prompt, doc, updated_doc):
    """Score every token of the (possibly edited) document with a local model.

    ``prompt`` and ``doc`` are combined into a single user chat message. The
    token ids of ``updated_doc`` (or of ``doc`` when ``updated_doc`` is blank)
    are appended after the chat template's generation prompt, and one forward
    pass gives the model's next-token distribution at every position.

    Uses the module-level ``model_name`` to pick the tokenizer and model.

    Args:
        prompt: Instruction text placed before the document in the user message.
        doc: The original document text.
        updated_doc: Edited document text; if empty or whitespace-only, ``doc``
            is scored instead.

    Returns:
        A list with one dict per scored token, holding:
        ``start`` / ``end`` (character offsets into the concatenation of the
        decoded tokens), ``token`` (decoded text), ``token_loss`` (negative
        log-probability the model assigned to that token) and
        ``most_likely_token`` (decoded argmax prediction at that position).
    """
    import torch

    tokenizer = get_tokenizer(model_name)
    model = get_model(model_name)

    messages = [
        {
            "role": "user",
            "content": f"{prompt}\n\n{doc}",
        },
    ]
    # [0] drops the batch dimension added by return_tensors="pt".
    tokenized_chat = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    )[0]
    assert len(tokenized_chat.shape) == 1

    if len(updated_doc.strip()) == 0:
        updated_doc = doc
    updated_doc_ids = tokenizer(updated_doc, return_tensors='pt')['input_ids'][0]
    # NOTE(review): the [1:] assumes the tokenizer prepends exactly one special
    # token (e.g. BOS) -- confirm for models other than the Gemma presets.
    joined_ids = torch.cat([tokenized_chat, updated_doc_ids[1:]])

    with torch.no_grad():
        logits = model(joined_ids[None].to(model.device)).logits[0].cpu()

    spans = []
    length_so_far = 0
    for idx in range(len(tokenized_chat), len(joined_ids)):
        # The logits at position idx - 1 are the prediction for the token at idx.
        # log_softmax in float32 avoids the underflow (probability 0 -> infinite
        # loss) that softmax-then-log can hit on low-precision logits.
        log_probs = logits[idx - 1].float().log_softmax(dim=-1)
        token_id = joined_ids[idx]
        token = tokenizer.decode(token_id)
        token_loss = -log_probs[token_id].item()
        # Decode the model's top prediction once and reuse it below.
        most_likely_token = tokenizer.decode(log_probs.argmax())
        print(idx, token, token_loss, most_likely_token)
        spans.append(dict(
            start=length_so_far,
            end=length_so_far + len(token),
            token=token,
            token_loss=token_loss,
            most_likely_token=most_likely_token,
        ))
        length_so_far += len(token)
    return spans
76
+
77
def get_highlights_api(prompt, doc, updated_doc):
    """Fetch highlight data for a document from the hosted highlights API.

    ``prompt``, ``doc`` and ``updated_doc`` are sent as URL query parameters,
    e.g.
    https://tools.kenarnold.org/api/highlights?prompt=Rewrite%20this%20document&doc=This%20is%20a%20document

    Args:
        prompt: Instruction text.
        doc: The original document text.
        updated_doc: Edited document text (may be empty).

    Returns:
        The value stored under the ``'highlights'`` key of the JSON response.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If connecting or reading exceeds the timeout.
    """
    import requests

    response = requests.get(
        "https://tools.kenarnold.org/api/highlights",
        params=dict(prompt=prompt, doc=doc, updated_doc=updated_doc),
        # (connect, read) seconds; without a timeout a stalled server would
        # block the request indefinitely.
        timeout=(10, 120),
    )
    # Surface HTTP errors directly instead of failing later on a missing key
    # or an unparsable body.
    response.raise_for_status()
    return response.json()['highlights']
84
+
85
# Choose the scoring backend: the hosted API, or a locally loaded model.
if model_name == 'API':
    spans = get_highlights_api(prompt, doc, updated_doc)
else:
    spans = get_spans_local(prompt, doc, updated_doc)

if len(spans) < 2:
    st.write("No spans found.")
    st.stop()

# The first span is left out of the maximum, so its own ratio may exceed 1.
# NOTE(review): presumably the first token's loss is not considered
# representative -- confirm the intent.
highest_loss = max(span['token_loss'] for span in spans[1:])
for span in spans:
    # Guard against a zero maximum so the normalisation cannot divide by zero.
    span['loss_ratio'] = span['token_loss'] / highest_loss if highest_loss > 0 else 0.0

# One <span> per token: blue when the token differs from the model's top
# prediction, with that prediction shown as the hover title.
html_body = ''.join(
    '<span style="color: {color}" title="{title}">{orig_token}</span>'.format(
        color="blue" if span['token'] != span['most_likely_token'] else "black",
        title=html.escape(span["most_likely_token"]).replace('\n', ' '),
        orig_token=html.escape(span["token"]).replace('\n', '<br>'),
    )
    for span in spans
)
html_out = f"<p style=\"background: white;\">{html_body}</p>"

st.write(html_out, unsafe_allow_html=True)
st.write(pd.DataFrame(spans)[['token', 'token_loss', 'most_likely_token', 'loss_ratio']])