|
import re |
|
import yaml |
|
import gc |
|
import copy |
|
import time |
|
from tenacity import RetryError |
|
from tenacity import retry, stop_after_attempt, wait_fixed |
|
import gradio as gr |
|
import torch |
|
from peft import PeftModel |
|
from transformers import ( |
|
LLaMATokenizer, |
|
LLaMAForCausalLM, |
|
GenerationConfig, |
|
AutoModelForCausalLM, |
|
AutoModelForSeq2SeqLM, |
|
AutoTokenizer, |
|
LogitsProcessorList, |
|
MinNewTokensLengthLogitsProcessor, |
|
TemperatureLogitsWarper, |
|
TopPLogitsWarper, |
|
MinLengthLogitsProcessor |
|
) |
|
|
|
# Fail fast when no CUDA device is visible: the model below is loaded in 8-bit
# and inputs are moved to "cuda". An explicit raise is used instead of `assert`
# so the check still runs under `python -O`.
if not torch.cuda.is_available():
    raise RuntimeError("Change the runtime type to GPU")
|
|
|
|
|
# Conversation length (in characters of the rendered history prompt) above
# which chat_stream summarizes the conversation before answering.
num_of_characters_to_keep = 1000

# Pre-compiled patterns used when cleaning prompts and model responses.
html_tag_pattern = re.compile(r"<.*?>")
multi_line_pattern = re.compile(r"\n+")
# NOTE(review): despite its name this matches a single space; the literal may
# have lost whitespace at some point -- confirm before relying on it.
multi_space_pattern = re.compile(r"( )")
# A run of one or more <br> tags, optionally separated by whitespace.
multi_br_tag_pattern = re.compile(r'<br>\s*(<br>\s*)*')

# Replacement strings paired with the patterns above.
repl_linebreak = "\n"
repl_empty_str = ""
|
|
|
# Heading shown at the top of the demo page.
TITLE = "🦌 Stambecco 🇮🇹"

# Markdown introduction rendered under the title. The input limits quoted in
# the NOTE mirror the length check performed in chat_stream (1000 / 300
# characters).
ABSTRACT = """
Stambecco is an Italian instruction-following model based on the [LLaMA](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) model. It comes in two versions: 7b and 13b parameters. It is trained on an Italian version of the [GPT-4-LLM](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) dataset, a dataset of `GPT-4` generated instruction-following data.
This demo is intended to show and evaluate the conversational capabilities of the model.
For more information, please visit [the project's website](https://github.com/mchl-labs/stambecco).
NOTE: Too long input (context, instruction) will not be allowed. Please keep the context within 1000 characters and the instruction within 300 characters.
"""

# Markdown footer rendered below the examples.
BOTTOM_LINE = """
By default, this demo runs with streaming mode, but you can also run with dynamic batch generation mode.
Stambecco is built on the same concept as the Stanford Alpaca project, but using LoRA lets us train and run inference on smaller GPUs such as an RTX4090 for the 7B version. Also, we could build very small checkpoints on top of base models thanks to [🤗 transformers](https://huggingface.co/docs/transformers/index), [🤗 peft](https://github.com/huggingface/peft), and [bitsandbytes](https://github.com/TimDettmers/bitsandbytes/tree/main) libraries.
This demo currently runs 8Bit 7b version of the model.
"""
|
|
|
def _numbered(questions):
    """Pair each question with its 1-based position rendered as a string."""
    return [[str(position), question] for position, question in enumerate(questions, start=1)]


# Example conversations offered in the UI, grouped by category. Every example
# row is [order, instruction].
DEFAULT_EXAMPLES = {
    "Typical Questions": [
        {
            "title": "Parlami di Giulio Cesare.",
            "examples": _numbered([
                "Scrivi un articolo su Giulio Cesare",
                "Davvero?",
                "Quanto era ricco Giulio Cesare?",
                "Chi è stato il suo successore?",
            ]),
        },
        {
            "title": "Parigi",
            "examples": _numbered([
                "Scrivi un tema sulla città di Parigi",
                "Fai un elenco di 5 posti da visitare assolutamente",
                "Quali eventi importanti della Storia sono avvenuti a Parigi?",
                "Quale è il periodo migliore per visitare Parigi?",
            ]),
        },
        {
            "title": "Scrivi un programma in Python che stampi i primi 10 numeri di Fibonacci",
            "examples": _numbered([
                "Scrivi un programma in Python che stampi i primi 10 numeri di Fibonacci",
                "Potresti spiegarmi come funziona il codice?",
                "Cos'è la ricorsione?",
            ]),
        },
    ],
}

# Instructions with a special meaning for the chat handlers: "continue" is
# hidden from the transcript, "summarize" asks for a recap of the conversation.
SPECIAL_STRS = dict(
    [
        ("continue", "continua"),
        ("summarize", "Di cosa abbiamo discusso finora? Descrivi nella user's view."),
    ]
)

# Stylesheet handed to gr.Blocks; the ids match the elem_id values used in the layout.
PARENT_BLOCK_CSS = """
#col_container {
    width: 95%;
    margin-left: auto;
    margin-right: auto;
}
#chatbot {
    height: 500px;
    overflow: auto;
}
"""
|
|
|
def load_model(
    base="decapoda-research/llama-7b-hf",
    finetuned="mchl-labs/stambecco-7b-plus",
):
    """Load the base LLaMA checkpoint and wrap it with the fine-tuned adapter.

    Args:
        base: Identifier of the base checkpoint, used for both the tokenizer
            and the model weights.
        finetuned: Identifier of the PEFT weights applied on top of the base model.

    Returns:
        A ``(model, tokenizer)`` tuple. The tokenizer pads on the left with
        token id 0; the base model is loaded with ``load_in_8bit=True`` and
        ``device_map="auto"``.
    """
    llama_tokenizer = LLaMATokenizer.from_pretrained(base)
    llama_tokenizer.pad_token_id = 0
    llama_tokenizer.padding_side = "left"

    base_model = LLaMAForCausalLM.from_pretrained(
        base,
        load_in_8bit=True,
        device_map="auto",
    )
    adapted_model = PeftModel.from_pretrained(base_model, finetuned)

    return adapted_model, llama_tokenizer
|
|
|
def get_generation_config(path):
    """Build a ``GenerationConfig`` from a YAML file.

    Args:
        path: Location of a YAML document whose top-level
            ``generation_config`` mapping holds the keyword arguments.

    Returns:
        The ``GenerationConfig`` created from that mapping.
    """
    with open(path, 'rb') as stream:
        raw_yaml = stream.read()

    settings = yaml.safe_load(raw_yaml)
    return GenerationConfig(**settings["generation_config"])
|
|
|
def generate_prompt(prompt, histories, ctx=None, partial=False):
    """Render the conversation and the new instruction as a model prompt.

    Args:
        prompt: The instruction to answer now.
        histories: Sequence of ``(instruction, response)`` pairs, oldest first.
            Only turns after the most recent summarization point are rendered.
        ctx: Optional context text emitted as the ``### Input:`` section.
        partial: When true, return only the instruction/response turns,
            without the preamble and the context section.

    Returns:
        A ``(prompt_text, turns_length)`` tuple, where ``turns_length`` is the
        character length of the rendered turns (preamble and context excluded).
    """
    convs = f"""Di seguito è riportata una cronologia delle istruzioni che descrivono le tasks, abbinate a un input che fornisce ulteriore contesto. Scrivi una risposta che completi adeguatamente la richiesta ricordando la cronologia della conversazione.

"""

    if ctx is not None:
        # Extend the preamble with the context section. Assigning here would
        # throw the preamble away whenever a context is supplied.
        convs = convs + f"""### Input: {ctx}
"""

    summary_marker = "✅ Riepilogo della conversazione effettuato e impostato come contesto"

    # Index of the first turn that follows the latest summarization point.
    # Earlier turns are already covered by the summary and are left out.
    first_kept = 0
    for idx, (history_prompt, history_response) in enumerate(histories):
        if history_response == summary_marker or history_prompt == SPECIAL_STRS["summarize"]:
            first_kept = idx + 1

    sub_convs = ""
    for history_prompt, history_response in histories[first_kept:]:
        # A turn shown without an instruction is a "continue" request
        # (chat_stream hides that instruction in the transcript); restore it
        # so the literal "None" does not end up in the prompt.
        if history_prompt is None:
            history_prompt = SPECIAL_STRS["continue"]

        # An interrupted turn has no response yet; treat it as empty.
        history_response = (history_response or "").replace("<br>", "\n")
        history_response = re.sub(
            html_tag_pattern, repl_empty_str, history_response
        )

        sub_convs = sub_convs + f"""### Istruzione: {history_prompt}
### Risposta: {history_response}
"""

    sub_convs = sub_convs + f"""### Istruzione: {prompt}
### Risposta:"""

    convs = convs + sub_convs
    return sub_convs if partial else convs, len(sub_convs)
|
|
|
def common_post_process(original_str):
    """Collapse every run of consecutive newlines in ``original_str`` into one."""
    return multi_line_pattern.sub(repl_linebreak, original_str)
|
|
|
def post_process_stream(bot_response):
    """Clean a partially streamed response.

    Returns:
        ``(text, should_stop)``. When the text contains a ``### Risposta:`` or
        ``### Input:`` marker, the markers are removed, the text is stripped
        and ``should_stop`` is True. Otherwise the text only goes through
        ``common_post_process`` and ``should_stop`` is False.
    """
    section_markers = ("### Risposta:", "### Input:")

    if not any(marker in bot_response for marker in section_markers):
        return common_post_process(bot_response), False

    cleaned = bot_response
    for marker in section_markers:
        cleaned = cleaned.replace(marker, '')
    return cleaned.strip(), True
|
|
|
def post_process_batch(bot_response):
    """Keep only the text after the last ``### Risposta:`` marker and tidy it."""
    _, _, answer = bot_response.rpartition("### Risposta:")
    answer = answer.strip()
    return common_post_process(answer)
|
|
|
def post_processes_batch(bot_responses):
    """Apply ``post_process_batch`` to every response and return the results as a list."""
    return list(map(post_process_batch, bot_responses))
|
|
|
def get_output_batch(
    model, tokenizer, prompts, generation_config, max_new_tokens=256
):
    """Generate one completion per prompt in a single batched call.

    Args:
        model: Model exposing ``generate``.
        tokenizer: Tokenizer used to encode the prompts and decode the output.
        prompts: List of prompt strings.
        generation_config: Generation settings forwarded to ``model.generate``.
        max_new_tokens: Upper bound on generated tokens per prompt.

    Returns:
        A list of decoded strings, one per prompt, with special tokens
        (BOS/EOS/padding) removed. An empty ``prompts`` list yields ``[]``.
    """
    if not prompts:
        return []

    # Padding is only requested when there is more than one prompt, so a
    # single prompt is encoded exactly as before.
    encodings = tokenizer(
        prompts, padding=len(prompts) > 1, return_tensors="pt"
    ).to('cuda')

    # Pass the attention mask explicitly on every path so that padded
    # positions are never attended to.
    generated_ids = model.generate(
        input_ids=encodings["input_ids"],
        attention_mask=encodings["attention_mask"],
        generation_config=generation_config,
        max_new_tokens=max_new_tokens,
    )

    # Drop special tokens so EOS/padding markers do not leak into the text
    # that callers display or reuse as context.
    decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    del encodings, generated_ids
    torch.cuda.empty_cache()
    return decoded
|
|
|
|
|
|
|
|
|
class StreamModel:
    """StreamModel wraps around a language model to provide stream decoding."""

    def __init__(self, model, tokenizer):
        """Store the model/tokenizer pair and build the fixed logits pipeline."""
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Sampling distribution is always shaped by these two warpers; the
        # temperature/top_p arguments of __call__ do not reconfigure them.
        self.processor = LogitsProcessorList()
        self.processor.append(TemperatureLogitsWarper(0.9))
        self.processor.append(TopPLogitsWarper(0.75))

    def __call__(
        self,
        prompt,
        min_tokens=0,
        max_tokens=16,
        temperature=1.0,
        top_p=1.0,
        n=1,
        logprobs=0,
    ):
        """Create a completion stream for the provided prompt.

        Yields the decoded text of everything generated so far (the prompt
        itself is not included), once every ``chunk_size`` tokens plus a final
        flush for any remainder.

        Args:
            prompt: Text to complete.
            min_tokens: Stored in the generation config as ``min_new_tokens``.
            max_tokens: Maximum number of tokens to generate.
            temperature: A value <= 0 switches ``generate`` to greedy decoding.
            top_p: A value <= 0 switches ``generate`` to greedy decoding.
            n: Number of copies of the prompt to generate for. The tokens of
                all copies are concatenated into one sequence before decoding,
                so only ``n=1`` gives a readable stream.
            logprobs: Clamped to >= 0 and forwarded to ``generate``.
        """
        input_ids = self.tokenize(prompt)
        logprobs = max(logprobs, 0)

        # Number of generated tokens accumulated between two yields.
        chunk_size = 2
        pending = 0

        # Token ids are integers; keep the buffer integral instead of letting
        # torch.cat promote everything to float.
        final_tokens = torch.empty(0, dtype=torch.long)

        try:
            for tokens in self.generate(
                input_ids[None, :].repeat(n, 1),
                logprobs=logprobs,
                min_new_tokens=min_tokens,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
            ):
                final_tokens = torch.cat((final_tokens, tokens.to("cpu")))
                pending += 1

                if pending >= chunk_size:
                    pending = 0
                    yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)

            # Flush the tokens of a last, incomplete chunk.
            if pending > 0:
                yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)
        finally:
            # Runs on normal exhaustion and also when the consumer stops
            # iterating early (the generator is closed at a yield).
            del final_tokens, input_ids
            if self.device == "cuda":
                torch.cuda.empty_cache()

    def _infer(self, model_fn, **kwargs):
        """Call ``model_fn(**kwargs)`` under ``torch.inference_mode``."""
        with torch.inference_mode():
            return model_fn(**kwargs)

    def tokenize(self, text):
        """Tokenize a string into a tensor of token IDs."""
        batch = self.tokenizer.encode(text, return_tensors="pt")
        return batch[0].to(self.device)

    def generate(self, input_ids, logprobs=0, **kwargs):
        """Generate a stream of predicted tokens using the language model.

        Args:
            input_ids: Prompt token ids; indexed as ``(batch, length)``.
            logprobs: Accepted for interface compatibility; not used here.
            **kwargs: Overrides applied to a copy of the model's generation
                config (e.g. ``max_new_tokens``, ``temperature``, ``top_p``).

        Yields:
            One tensor per step holding the next token id for each sequence
            in the batch. Sequences that already produced an EOS token keep
            yielding the pad token until every sequence is finished or
            ``max_new_tokens`` is reached.
        """
        batch_size = input_ids.shape[0]
        input_length = input_ids.shape[-1]

        # Work on a private copy so overrides never touch the model's config.
        config = copy.deepcopy(self.model.generation_config)
        # update() applies the overrides it recognises and returns the rest.
        kwargs = config.update(**kwargs)
        kwargs["output_attentions"] = False
        kwargs["output_hidden_states"] = False
        kwargs["use_cache"] = True

        pad_token_id = config.pad_token_id
        eos_token_id = config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        if pad_token_id is None and eos_token_id is not None:
            pad_token_id = eos_token_id[0]

        # An empty prompt is replaced by a single starting token per sequence.
        if input_length == 0:
            input_ids = input_ids.new_ones((batch_size, 1)).long()
            if eos_token_id is not None:
                input_ids = input_ids * eos_token_id[0]
            input_length = 1

        # 1 while a sequence is still generating, 0 once it has emitted EOS.
        unfinished = input_ids.new_ones(batch_size)

        while True:
            # NOTE(review): use_cache is requested, but outputs.past_key_values
            # is never handed back through kwargs, so every step passes the
            # whole sequence to the model -- confirm whether cache reuse is wanted.
            inputs = self.model.prepare_inputs_for_generation(
                input_ids, **kwargs
            )

            outputs = self._infer(
                self.model,
                **inputs,
                output_attentions=False,
                output_hidden_states=False,
            )

            # Only the distribution for the last position is needed.
            logits = outputs.logits[:, -1, :]
            with torch.inference_mode():
                logits = self.processor(input_ids, logits)
                probs = torch.nn.functional.softmax(logits, dim=-1)

            # A non-positive top_p or temperature selects greedy decoding.
            if (config.top_p is not None and config.top_p <= 0) or (
                config.temperature is not None and config.temperature <= 0
            ):
                tokens = torch.argmax(probs, dim=-1)[:, None]
            else:
                tokens = torch.multinomial(probs, num_samples=1)

            tokens = tokens.squeeze(1)

            # Finished sequences emit the pad token from now on.
            if pad_token_id is not None:
                tokens = tokens * unfinished + pad_token_id * (1 - unfinished)

            input_ids = torch.cat([input_ids, tokens[:, None]], dim=-1)

            # A sequence is finished as soon as it emits ANY of the EOS ids.
            # The mask stays strictly 0/1 so `unfinished` remains a valid
            # multiplier in the padding expression above.
            if eos_token_id is not None:
                is_eos = torch.zeros_like(tokens, dtype=torch.bool)
                for eos in eos_token_id:
                    is_eos = is_eos | (tokens == eos)
                unfinished = unfinished.mul((~is_eos).long())

            # Negating the flags once the token budget is used up makes the
            # loop stop even if some sequences are still unfinished.
            status = unfinished.clone()
            if input_ids.shape[-1] - input_length >= config.max_new_tokens:
                status = 0 - status

            yield tokens

            if status.max() <= 0:
                break
|
|
|
# Generation settings used by the non-streaming (batch) path.
generation_config = get_generation_config("./generation_config_default.yaml")

# Model and tokenizer are loaded once at import time with load_model's defaults.
model, tokenizer = load_model()

# Streaming wrapper used by chat_stream.
stream_model = StreamModel(model, tokenizer)
|
|
|
# Marker that shows the model has started writing the next instruction itself.
_STREAM_STOP_MARKER = "### Istruzione:"
# Every prompt section marker that must not be shown half-written.
_STREAM_SECTION_MARKERS = ("### Istruzione:", "### Risposta:", "### Input:")


def _drop_partial_marker(text):
    """Return ``text`` without a trailing, still incomplete section marker."""
    longest = 0
    for marker in _STREAM_SECTION_MARKERS:
        for size in range(min(len(marker) - 1, len(text)), longest, -1):
            if text.endswith(marker[:size]):
                longest = size
                break
    return text[:len(text) - longest]


def chat_stream(
    context,
    instruction,
    state_chatbot,
):
    """Answer ``instruction`` while streaming partial responses to the UI.

    When the rendered history exceeds ``num_of_characters_to_keep`` characters,
    the conversation is summarized first and the summary is appended to the
    context.

    Args:
        context: Current content of the context textbox.
        instruction: Instruction to answer.
        state_chatbot: List of ``(instruction, response)`` pairs so far.

    Yields:
        ``(state, chatbot, context)`` tuples: the updated conversation (twice,
        for the state and the chatbot component) and the context text.

    Raises:
        gr.Error: If the context is longer than 1000 characters or the
            instruction is longer than 300 characters.
    """
    if len(context) > 1000 or len(instruction) > 300:
        raise gr.Error("Context or prompt is too long!")

    bot_summarized_response = ''

    instruction_display = instruction
    instruction_prompt, conv_length = generate_prompt(instruction, state_chatbot, context)

    # Context value reported back to the UI. It is computed once so that every
    # yield of this turn carries exactly the same text.
    updated_context = context.strip()

    if conv_length > num_of_characters_to_keep:
        instruction_prompt = generate_prompt(SPECIAL_STRS["summarize"], state_chatbot, context, partial=True)[0]

        state_chatbot = state_chatbot + [
            (
                None,
                "![](https://s2.gifyu.com/images/icons8-loading-circle.gif) Conversazione troppo lunga, sto riassumendo..."
            )
        ]
        yield (state_chatbot, state_chatbot, context)

        bot_summarized_response = get_output_batch(
            model, tokenizer, [instruction_prompt], generation_config
        )[0]
        bot_summarized_response = bot_summarized_response.split("### Risposta:")[-1].strip()

        # generate_prompt treats this entry as the point before which the
        # history is no longer rendered.
        state_chatbot[-1] = (
            None,
            "✅ Riepilogo della conversazione effettuato e impostato come contesto"
        )
        print(f"bot_summarized_response: {bot_summarized_response}")

        updated_context = f"{context} {bot_summarized_response}".strip()
        yield (state_chatbot, state_chatbot, updated_context)

        instruction_prompt = generate_prompt(instruction, state_chatbot, updated_context)[0]

    bot_response = stream_model(
        instruction_prompt,
        max_tokens=256,
        temperature=1,
        top_p=0.9
    )

    # A "continue" request is not shown as a user message in the transcript.
    instruction_display = None if instruction_display == SPECIAL_STRS["continue"] else instruction_display
    state_chatbot = state_chatbot + [(instruction_display, None)]
    yield (state_chatbot, state_chatbot, updated_context)

    # Full text of the last chunk when part of it was hidden as a possible marker.
    held_back_text = None
    try:
        for text in bot_response:
            text = text.strip()
            held_back_text = None

            # The whole text generated so far is searched on every chunk, so
            # the marker is found even when it arrives split across chunks.
            stop_at = text.find(_STREAM_STOP_MARKER)
            if stop_at > -1:
                processed_response, _ = post_process_stream(text[:stop_at].strip())
                state_chatbot[-1] = (instruction_display, processed_response)
                yield (state_chatbot, state_chatbot, updated_context)
                break

            # Hide a trailing fragment that may still grow into a marker.
            visible = _drop_partial_marker(text)
            processed_response, to_exit = post_process_stream(visible)
            state_chatbot[-1] = (instruction_display, processed_response)
            yield (state_chatbot, state_chatbot, updated_context)

            if to_exit:
                break

            if visible != text:
                held_back_text = text
    finally:
        # Stop the token stream when leaving early (stop marker, cancellation).
        close_stream = getattr(bot_response, "close", None)
        if close_stream is not None:
            close_stream()

    # The stream ended while a fragment was hidden: it was ordinary text after all.
    if held_back_text is not None:
        processed_response, _ = post_process_stream(held_back_text)
        state_chatbot[-1] = (instruction_display, processed_response)

    yield (
        state_chatbot,
        state_chatbot,
        updated_context
    )
|
|
|
|
|
def chat_batch(
    contexts,
    instructions,
    state_chatbots,
):
    """Answer several conversations in one batched generation call.

    Args:
        contexts: Context text for each conversation.
        instructions: New instruction for each conversation.
        state_chatbots: ``(instruction, response)`` history of each conversation.

    Returns:
        ``(states, chatbots, contexts)``: the updated histories (the same list
        is returned for the state and the chatbot outputs) and, per
        conversation, either the unchanged context or a textbox update
        carrying the response when the instruction was the summarize request.
    """
    # generate_prompt returns (prompt_text, length); only the text is tokenized.
    instruct_prompts = [
        generate_prompt(instruct, histories, ctx)[0]
        for ctx, instruct, histories in zip(contexts, instructions, state_chatbots)
    ]

    bot_responses = get_output_batch(
        model, tokenizer, instruct_prompts, generation_config
    )
    bot_responses = post_processes_batch(bot_responses)

    state_results = []
    ctx_results = []
    for ctx, instruction, bot_response, state_chatbot in zip(contexts, instructions, bot_responses, state_chatbots):
        # A "continue" request is stored with an empty user message.
        shown_instruction = '' if instruction == SPECIAL_STRS["continue"] else instruction
        state_results.append(state_chatbot + [(shown_instruction, bot_response)])

        # A summary replaces the context; any other response leaves it as is.
        if instruction == SPECIAL_STRS["summarize"]:
            ctx_results.append(gr.Textbox.update(value=bot_response))
        else:
            ctx_results.append(ctx)

    return (state_results, state_results, ctx_results)
|
|
|
def reset_textbox():
    """Return a textbox update that empties the target textbox."""
    cleared = gr.Textbox.update(value='')
    return cleared
|
|
|
def reset_everything(
    context_txtbox,
    instruction_txtbox,
    state_chatbot):
    """Clear the conversation and both textboxes.

    The incoming values are ignored. The result is
    ``(state, chatbot, context_update, instruction_update)``, where the first
    two items are the same new empty list.
    """
    fresh_history = []
    context_update = gr.Textbox.update(value='')
    instruction_update = gr.Textbox.update(value='')

    return (fresh_history, fresh_history, context_update, instruction_update)
|
|
|
# Layout and event wiring of the demo page.
with gr.Blocks(css=PARENT_BLOCK_CSS) as demo:
    # Conversation history as a list of (instruction, response) pairs.
    state_chatbot = gr.State([])

    with gr.Column(elem_id='col_container'):
        gr.Markdown(f"## {TITLE}\n\n\n{ABSTRACT}")

        with gr.Accordion("Context Setting", open=False):
            context_txtbox = gr.Textbox(placeholder="Surrounding information to AI", label="Enter Context")
            # Invisible target for the "order" column of the example rows.
            hidden_txtbox = gr.Textbox(placeholder="", label="Order", visible=False)

        chatbot = gr.Chatbot(elem_id='chatbot', label="Stambecco")
        instruction_txtbox = gr.Textbox(placeholder="What do you want to say to AI?", label="Instruction")
        with gr.Row():
            cancel_btn = gr.Button(value="Cancel")
            reset_btn = gr.Button(value="Reset")

        with gr.Accordion("Helper Buttons", open=False):
            gr.Markdown(f"`Continue` lets AI to complete the previous incomplete answers. `Summarize` lets AI to summarize the conversations so far.")
            # Invisible textboxes carrying the special instructions sent by the buttons below.
            continue_txtbox = gr.Textbox(value=SPECIAL_STRS["continue"], visible=False)
            summrize_txtbox = gr.Textbox(value=SPECIAL_STRS["summarize"], visible=False)

            continue_btn = gr.Button(value="Continue")
            summarize_btn = gr.Button(value="Summarize")

        gr.Markdown("#### Examples")
        for category, examples in DEFAULT_EXAMPLES.items():
            with gr.Accordion(category, open=False):
                # Rows of the "Identity" category also fill the context textbox.
                fills_context = category == "Identity"
                for item in examples:
                    with gr.Accordion(item["title"], open=False):
                        gr.Examples(
                            examples=item["examples"],
                            inputs=(
                                [hidden_txtbox, context_txtbox, instruction_txtbox]
                                if fills_context
                                else [hidden_txtbox, instruction_txtbox]
                            ),
                            label=None,
                        )

        gr.Markdown(f"{BOTTOM_LINE}")

    # Submitting the instruction starts a streamed answer and empties the box.
    send_event = instruction_txtbox.submit(
        fn=chat_stream,
        inputs=[context_txtbox, instruction_txtbox, state_chatbot],
        outputs=[state_chatbot, chatbot, context_txtbox],
    )
    reset_event = instruction_txtbox.submit(
        fn=reset_textbox,
        inputs=[],
        outputs=[instruction_txtbox],
    )

    # "Continue" sends the hidden continue instruction.
    continue_event = continue_btn.click(
        fn=chat_stream,
        inputs=[context_txtbox, continue_txtbox, state_chatbot],
        outputs=[state_chatbot, chatbot, context_txtbox],
    )
    reset_continue_event = continue_btn.click(
        fn=reset_textbox,
        inputs=[],
        outputs=[instruction_txtbox],
    )

    # "Summarize" sends the hidden summarize instruction.
    summarize_event = summarize_btn.click(
        fn=chat_stream,
        inputs=[context_txtbox, summrize_txtbox, state_chatbot],
        outputs=[state_chatbot, chatbot, context_txtbox],
    )
    summarize_reset_event = summarize_btn.click(
        fn=reset_textbox,
        inputs=[],
        outputs=[instruction_txtbox],
    )

    # "Cancel" only interrupts the running chat events.
    cancel_btn.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[send_event, continue_event, summarize_event],
    )

    # "Reset" interrupts the running chat events and clears everything.
    reset_btn.click(
        fn=reset_everything,
        inputs=[context_txtbox, instruction_txtbox, state_chatbot],
        outputs=[state_chatbot, chatbot, context_txtbox, instruction_txtbox],
        cancels=[send_event, continue_event, summarize_event],
    )
|
|
|
# Enable request queuing (one concurrent job, at most 100 waiting), then
# serve on all interfaces with a public share link.
queued_demo = demo.queue(concurrency_count=1, max_size=100)
queued_demo.launch(max_threads=5, server_name="0.0.0.0", share=True)