Spaces:
Sleeping
Sleeping
Import from other repo
Browse files
app.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
7 |
+
import pandas as pd
|
8 |
+
|
9 |
+
model_options = [
|
10 |
+
'google/gemma-1.1-2b-it',
|
11 |
+
'google/gemma-1.1-7b-it'
|
12 |
+
]
|
13 |
+
|
14 |
+
model_name = st.selectbox("Select a model", model_options + ['other'])
|
15 |
+
|
16 |
+
if model_name == 'other':
|
17 |
+
model_name = st.text_input("Enter model name", model_options[0])
|
18 |
+
|
19 |
+
@st.cache_resource
|
20 |
+
def get_tokenizer(model_name):
|
21 |
+
return AutoTokenizer.from_pretrained(model_name).from_pretrained(model_name)
|
22 |
+
|
23 |
+
@st.cache_resource
|
24 |
+
def get_model(model_name):
|
25 |
+
model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto', torch_dtype=torch.bfloat16)
|
26 |
+
print(f"Loaded model, {model.num_parameters():,d} parameters.")
|
27 |
+
return model
|
28 |
+
|
29 |
+
tokenizer = get_tokenizer(model_name)
|
30 |
+
model = get_model(model_name)
|
31 |
+
|
32 |
+
prompt = st.text_area("Prompt", "Rewrite this document to be more clear and concise.")
|
33 |
+
doc = st.text_area("Document", "This is a document that I would like to have rewritten to be more concise.")
|
34 |
+
|
35 |
+
|
36 |
+
messages = [
|
37 |
+
{
|
38 |
+
"role": "user",
|
39 |
+
"content": f"{prompt}\n\n{doc}",
|
40 |
+
},
|
41 |
+
]
|
42 |
+
tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")[0]
|
43 |
+
assert len(tokenized_chat.shape) == 1
|
44 |
+
|
45 |
+
doc_ids = tokenizer(doc, return_tensors='pt')['input_ids'][0]
|
46 |
+
joined_ids = torch.cat([tokenized_chat, doc_ids[1:]])
|
47 |
+
|
48 |
+
# Call the model
|
49 |
+
with torch.no_grad():
|
50 |
+
logits = model(joined_ids[None].to(model.device)).logits[0].cpu()
|
51 |
+
|
52 |
+
spans = []
|
53 |
+
length_so_far = 0
|
54 |
+
for idx in range(len(tokenized_chat), len(joined_ids)):
|
55 |
+
probs = logits[idx - 1].softmax(dim=-1)
|
56 |
+
token_id = joined_ids[idx]
|
57 |
+
token = tokenizer.decode(token_id)
|
58 |
+
token_loss = -probs[token_id].log().item()
|
59 |
+
most_likely_token_id = probs.argmax()
|
60 |
+
print(idx, token, token_loss, tokenizer.decode(most_likely_token_id))
|
61 |
+
spans.append(dict(
|
62 |
+
start=length_so_far,
|
63 |
+
end=length_so_far + len(token),
|
64 |
+
token=token,
|
65 |
+
token_loss=token_loss,
|
66 |
+
most_likely_token=tokenizer.decode(most_likely_token_id)
|
67 |
+
))
|
68 |
+
length_so_far += len(token)
|
69 |
+
|
70 |
+
highest_loss = max(span['token_loss'] for span in spans[1:])
|
71 |
+
for span in spans:
|
72 |
+
span['loss_ratio'] = span['token_loss'] / highest_loss
|
73 |
+
|
74 |
+
html = ''
|
75 |
+
for span in spans:
|
76 |
+
b = int(256 * span["token_loss"] / highest_loss)
|
77 |
+
html += f'<span style="color: rgba(128, 128, {b:d})" title="{span["most_likely_token"]}">{span["token"]}</span>'
|
78 |
+
html = f"<p style=\"background: white;\">{html}</p>"
|
79 |
+
|
80 |
+
st.subheader("Rewritten document")
|
81 |
+
st.write(html, unsafe_allow_html=True)
|
82 |
+
st.write(pd.DataFrame(spans))
|