kcarnold commited on
Commit
18f01d3
1 Parent(s): 39fec23

Import from other repo

Browse files
Files changed (1) hide show
  1. app.py +82 -0
app.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM
7
+ import pandas as pd
8
+
9
+ model_options = [
10
+ 'google/gemma-1.1-2b-it',
11
+ 'google/gemma-1.1-7b-it'
12
+ ]
13
+
14
+ model_name = st.selectbox("Select a model", model_options + ['other'])
15
+
16
+ if model_name == 'other':
17
+ model_name = st.text_input("Enter model name", model_options[0])
18
+
19
+ @st.cache_resource
20
+ def get_tokenizer(model_name):
21
+ return AutoTokenizer.from_pretrained(model_name).from_pretrained(model_name)
22
+
23
+ @st.cache_resource
24
+ def get_model(model_name):
25
+ model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto', torch_dtype=torch.bfloat16)
26
+ print(f"Loaded model, {model.num_parameters():,d} parameters.")
27
+ return model
28
+
29
+ tokenizer = get_tokenizer(model_name)
30
+ model = get_model(model_name)
31
+
32
+ prompt = st.text_area("Prompt", "Rewrite this document to be more clear and concise.")
33
+ doc = st.text_area("Document", "This is a document that I would like to have rewritten to be more concise.")
34
+
35
+
36
+ messages = [
37
+ {
38
+ "role": "user",
39
+ "content": f"{prompt}\n\n{doc}",
40
+ },
41
+ ]
42
+ tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")[0]
43
+ assert len(tokenized_chat.shape) == 1
44
+
45
+ doc_ids = tokenizer(doc, return_tensors='pt')['input_ids'][0]
46
+ joined_ids = torch.cat([tokenized_chat, doc_ids[1:]])
47
+
48
+ # Call the model
49
+ with torch.no_grad():
50
+ logits = model(joined_ids[None].to(model.device)).logits[0].cpu()
51
+
52
+ spans = []
53
+ length_so_far = 0
54
+ for idx in range(len(tokenized_chat), len(joined_ids)):
55
+ probs = logits[idx - 1].softmax(dim=-1)
56
+ token_id = joined_ids[idx]
57
+ token = tokenizer.decode(token_id)
58
+ token_loss = -probs[token_id].log().item()
59
+ most_likely_token_id = probs.argmax()
60
+ print(idx, token, token_loss, tokenizer.decode(most_likely_token_id))
61
+ spans.append(dict(
62
+ start=length_so_far,
63
+ end=length_so_far + len(token),
64
+ token=token,
65
+ token_loss=token_loss,
66
+ most_likely_token=tokenizer.decode(most_likely_token_id)
67
+ ))
68
+ length_so_far += len(token)
69
+
70
+ highest_loss = max(span['token_loss'] for span in spans[1:])
71
+ for span in spans:
72
+ span['loss_ratio'] = span['token_loss'] / highest_loss
73
+
74
+ html = ''
75
+ for span in spans:
76
+ b = int(256 * span["token_loss"] / highest_loss)
77
+ html += f'<span style="color: rgba(128, 128, {b:d})" title="{span["most_likely_token"]}">{span["token"]}</span>'
78
+ html = f"<p style=\"background: white;\">{html}</p>"
79
+
80
+ st.subheader("Rewritten document")
81
+ st.write(html, unsafe_allow_html=True)
82
+ st.write(pd.DataFrame(spans))