from transformers import GPT2Tokenizer, GPT2LMHeadModel, BertTokenizer
import torch
import streamlit as st
import re
from typing import List, Tuple
import spacy
import numpy as np
from dataclasses import dataclass
from nltk.tokenize import sent_tokenize, word_tokenize

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

st.set_page_config(layout="wide")
@dataclass
class LexicalUnits:
    unit_type: str
    text: List[str]
    self_info: List[float] = None

    def __add__(self, other):
        assert self.unit_type == other.unit_type, 'Cannot add two different unit types'
        return LexicalUnits(self.unit_type, self.text + other.text, self.self_info + other.self_info)

    def __radd__(self, other):
        # support sum() with its default start value of 0
        if other == 0:
            return self
        return NotImplemented

    def add_to_head(self, token, self_info):
        return LexicalUnits(self.unit_type, [token] + self.text, [self_info] + self.self_info)

    def add_to_tail(self, token, self_info):
        return LexicalUnits(self.unit_type, self.text + [token], self.self_info + [self_info])
class SelectiveContext:

    def __init__(self, model_type='gpt2', lang='en'):
        self.model_type = model_type
        self.lang = lang

        # this means we calculate self-information sentence by sentence
        self.sent_level_self_info = True

        self._prepare_phrase_tokenizer()
        self.sent_tokenize_pattern = r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s"
        self.phrase_mask_token = ''
        self.sent_mask_token = "<deleted>"

        self._prepare_model()

    def _prepare_phrase_tokenizer(self):
        # we use spaCy to tokenize a sentence into phrases
        # for English, we use `spacy.load("en_core_web_sm")` with the `merge_noun_chunks` pipe added
        # for Chinese, we use `spacy.load('zh_core_web_sm')` directly
        lang = self.lang
        if lang == "en":
            self.nlp = spacy.load("en_core_web_sm", disable=["ner"])
            self.nlp.add_pipe('merge_noun_chunks')
        elif lang == "zh":
            self.nlp = spacy.load('zh_core_web_sm', disable=["ner"])

    def _prepare_model(self):
        if self.model_type == 'gpt2':
            if self.lang == 'zh':
                self.model = GPT2LMHeadModel.from_pretrained('uer/gpt2-chinese-cluecorpussmall')
                self.tokenizer = BertTokenizer.from_pretrained('uer/gpt2-chinese-cluecorpussmall')
            else:
                self.model = GPT2LMHeadModel.from_pretrained('gpt2')
                self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

            self.model.to(DEVICE)
            self.model.eval()

            print('model loaded')

            self.max_token_length = self.model.config.n_positions
            self.get_self_information = self._get_self_info_via_gpt2
    def get_self_information(self, text: str) -> Tuple[List[str], List[float]]:
        # takes text as input and returns a list of tokens and a list of self-information scores
        raise NotImplementedError

    def _get_self_info_via_gpt2(self, text: str) -> Tuple[List[str], List[float]]:
        if self.lang == 'en':
            text = f"<|endoftext|>{text}"
        elif self.lang == 'zh':
            text = f"[CLS]{text}"
        with torch.no_grad():
            encoding = self.tokenizer(text, add_special_tokens=False, return_tensors='pt')
            encoding = encoding.to(DEVICE)
            outputs = self.model(**encoding)
            logits = outputs.logits
            probs = torch.softmax(logits, dim=-1)
            self_info = -torch.log(probs)

        input_ids = encoding['input_ids']
        # logits at position i predict token i+1, so shift by one position to align
        # each token with the surprisal assigned to it given its prefix
        input_ids_expanded = input_ids[:, 1:].unsqueeze(-1)

        tokens = [self.tokenizer.decode(token_) for token_ in input_ids.squeeze().tolist()[1:]]
        return tokens, self_info[:, :-1].gather(-1, input_ids_expanded).squeeze(-1).squeeze(0).tolist()
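    # Example with hypothetical values (for illustration only):
    #   _get_self_info_via_gpt2("The sun rises in the east.")
    # might return tokens like ['The', ' sun', ' rises', ...] paired with per-token
    # surprisal scores, where less predictable tokens receive higher scores.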
    def _lexical_unit(self, sents):

        if self.sent_level_self_info:
            sent_self_info = []
            all_noun_phrases = []
            all_noun_phrases_info = []
            all_tokens = []
            all_token_self_info = []

            for sent in sents:
                print(sent)
                tokens, self_info = self.get_self_information(sent)
                sent_self_info.append(np.mean(self_info))
                all_tokens.extend(tokens)
                all_token_self_info.extend(self_info)
                noun_phrases, noun_phrases_info = self._calculate_lexical_unit(tokens, self_info)

                # We need to add a space before the first noun phrase for
                # every sentence except the first one
                if len(all_noun_phrases) != 0:
                    noun_phrases[0] = f" {noun_phrases[0]}"
                all_noun_phrases.extend(noun_phrases)
                all_noun_phrases_info.extend(noun_phrases_info)

            return [
                LexicalUnits('sent', text=sents, self_info=sent_self_info),
                LexicalUnits('phrase', text=all_noun_phrases, self_info=all_noun_phrases_info),
                LexicalUnits('token', text=all_tokens, self_info=all_token_self_info)
            ]
    def _calculate_lexical_unit(self, tokens, self_info):
        def _unit_info(tokens, self_info, units):
            current_unit_idx = 0
            current_position = 0
            unit_self_info = [[] for _ in range(len(units))]

            for idx, (token, info) in enumerate(zip(tokens, self_info)):
                current_position += len(token)
                if current_position == len(units[current_unit_idx]):
                    unit_self_info[current_unit_idx].append(info)
                    current_position = current_position - len(units[current_unit_idx])
                    current_unit_idx += 1
                elif current_position > len(units[current_unit_idx]):
                    counter_ = 1
                    current_position = current_position - len(units[current_unit_idx])
                    current_unit_idx += 1
                    while current_position >= len(units[current_unit_idx]):
                        counter_ += 1
                        current_position = current_position - len(units[current_unit_idx])
                        current_unit_idx += 1
                        if current_unit_idx >= len(units):
                            break
                    partial_info = info / counter_
                    for _ in range(counter_):
                        unit_self_info[(current_unit_idx - 1) - _].append(partial_info)
                else:
                    if token == " ":
                        continue
                    unit_self_info[current_unit_idx].append(info)

            unit_self_info_ = [np.mean(info) for info in unit_self_info]
            return unit_self_info_

        def _noun_phrases(sent):
            noun_phrases = []
            doc = self.nlp(sent)
            for index, chunk in enumerate(doc):
                if index == 0:
                    noun_phrases.append(chunk.text)
                else:
                    noun_phrases.append(doc[index - 1].whitespace_ + chunk.text)
            return noun_phrases

        if self.sent_level_self_info:
            # in this case, the self_info is for each sentence
            # we only need to calculate the self_info for each phrase
            sent = ''.join(tokens)
            # noun_phrases = [chunk.text for chunk in self.nlp(sent).noun_chunks]
            noun_phrases = _noun_phrases(sent)
            # noun_phrases[-1] = noun_phrases[-1] + ' '
            noun_phrases_info = _unit_info(tokens, self_info, noun_phrases)

            return noun_phrases, noun_phrases_info
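    # Illustrative sketch with a hypothetical sentence: with `merge_noun_chunks` enabled,
    # "The quick fox jumps" is split into phrase units like ['The quick fox', ' jumps'];
    # _unit_info then walks the GPT-2 tokens by character position and averages the
    # per-token self-information falling inside each phrase span, splitting a token's
    # score into equal shares when it straddles a phrase boundary.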
    def beautify_context(self, context: str) -> str:
        context = re.sub(r"\s+", " ", context)
        return context

    def self_info_mask(self, sents: List[str], self_info: List[float], mask_level):
        # mask_level: mask sentences, phrases, or tokens
        sents_after_mask = []
        masked_sents = []

        self.ppl_threshold = np.nanpercentile(self_info, self.mask_ratio * 100)

        # if title is not None:
        #     with open(os.path.join(self.path, title + '_prob_token.tsv'), 'w', encoding='utf-8') as f:
        #         for token, info in zip(tokens, self_info):
        #             f.write(f"{token}\t{info}\n")
        #     with open(os.path.join(self.path, title + '_prob_sent.tsv'), 'w', encoding='utf-8') as f:
        #         for sent, info in zip(sents, sent_self_info):
        #             f.write(f"{sent}\n{info}\n\n")

        for sent, info in zip(sents, self_info):
            if info < self.ppl_threshold:
                masked_sents.append(sent)
                sents_after_mask.append(self.mask_a_sent(sent, mask_level))
            else:
                sents_after_mask.append(sent)
        masked_context = " ".join(sents_after_mask) if mask_level == 'sent' else "".join(sents_after_mask)

        return masked_context, masked_sents

    def mask_a_sent(self, sent, level):
        if level == 'phrase':
            return self.phrase_mask_token
        elif level == 'sent':
            return self.sent_mask_token
        elif level == 'token':
            return ''
    def __call__(self, text: str, reduce_ratio: float = 0.35, reduce_level: str = 'phrase') -> Tuple[str, List[str]]:
        context = self.beautify_context(text)

        self.mask_ratio = reduce_ratio

        sents = re.split(self.sent_tokenize_pattern, context)
        sents = [sent.strip() for sent in sents if sent.strip()]

        # Should the reduction happen at the sentence, phrase, or token level?
        assert reduce_level in ['sent', 'phrase', 'token'], f"reduce_level should be one of ['sent', 'phrase', 'token'], got {reduce_level}"
        sent_lus, phrase_lus, token_lus = self._lexical_unit(sents)
        lexical_level = {
            'sent': sent_lus,
            'phrase': phrase_lus,
            'token': token_lus
        }

        # context is the reduced context; masked_sents holds the content that was filtered out
        context, masked_sents = self.self_info_mask(lexical_level[reduce_level].text, lexical_level[reduce_level].self_info, reduce_level)
        return context, masked_sents
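# Programmatic usage (a minimal sketch mirroring the Streamlit demo below):
#
#   sc = SelectiveContext(model_type='gpt2', lang='en')
#   context, masked_sents = sc(text, reduce_ratio=0.5, reduce_level='phrase')
#   # `context` is the compressed prompt; `masked_sents` lists the filtered-out units.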
# Streamlit app
# Ask the user for the input text and the reduce ratio,
# then call SelectiveContext to compress the text.

st.title("Selective Context: Compress your prompt")
st.markdown("This is a demo for the **Selective Context** algorithm.")
st.markdown("Use this algorithm to **compress** your prompt, so that LLMs can deal with **2x more context**!")
st.markdown("- The algorithm filters out the content that is less informative. \n - You can also choose to filter out phrases or tokens instead of sentences. \n - Check out the paper for details and experiments: [https://arxiv.org/abs/2304.12102](https://arxiv.org/abs/2304.12102).")
st.write("")

st.subheader("Demo")

lang = st.radio("Please choose the language: ", ('en', 'zh'))
ratio = st.radio("Please choose the compression ratio [we recommend 0.5]: ", (0.5, 0.2, 0.35, 0.65, 0.8))
reduce_level = st.radio("Please choose the reduce level: ", ('phrase', 'token', 'sent'))

text = st.text_area("Please input your text here", height=300)

def load_model(lang):
    model = SelectiveContext(lang=lang)
    return model

if st.button("Compress"):
    model = load_model(lang)
    context, masked_sents = model(text, reduce_ratio=ratio, reduce_level=reduce_level)
    st.subheader("The compressed context is:")
    st.code(context)
    # st.divider()
    st.subheader("The filtered out content is:")
    st.write(masked_sents)