import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces  # Hugging Face Spaces helper (provides @spaces.GPU); not strictly needed on dedicated L40S hardware.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
DESCRIPTION = """\
# ⛰️⛰️ JAIS Initiative: Atlas-Chat-9B ⛰️⛰️

Disclaimer: This research demonstration of Atlas-Chat-9B is not intended for end-user applications. The model may generate biased, offensive, or inaccurate content, as it is trained on diverse internet data. The developers do not endorse any views expressed by the model and assume no responsibility for the consequences of its use. Users should critically evaluate the generated responses and use the tool at their own risk.

Note: The model is expected to take input and generate output in Moroccan Darija with Arabic script.
"""
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "2048"))  # Prompts longer than this are left-truncated.
model_id = "MBZUAI-Paris/Atlas-Chat-9B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",  # Place the model on the available GPU(s) automatically.
    torch_dtype=torch.bfloat16,  # bfloat16 halves memory use versus float32 with little quality loss.
)
model.eval()  # Inference mode: disables dropout and other training-only behavior.
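# Illustrative only: Atlas-Chat-9B is a Gemma-2-based model, so apply_chat_template
# is assumed here to render conversations in the Gemma turn format (verify against the
# checkpoint's tokenizer_config.json). A sketch of the prompt generate() builds below:
#
#   conversation = [{"role": "user", "content": "السلام"}]
#   prompt = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
#   # prompt ≈ "<bos><start_of_turn>user\nالسلام<end_of_turn>\n<start_of_turn>model\n"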
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    do_sample: bool = False,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    """Stream the model's reply to `message`, given `chat_history` as a list of
    {"role": ..., "content": ...} dicts (the gr.ChatInterface "messages" format)."""
    # Copy the history so appending the new user turn does not mutate Gradio's state.
    conversation = chat_history.copy()
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens so the newest turns survive truncation.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # skip_prompt avoids echoing the input back; the timeout raises if generation stalls for 20s.
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Run generation on a background thread; the streamer hands decoded text
    # chunks back to this thread as they are produced.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        # Yield the accumulated text so the UI shows the response growing in place.
        yield "".join(outputs)
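# Illustrative only: generate() is a plain Python generator, so it can also be
# exercised outside Gradio. Left commented out so it does not run at import time:
#
#   for partial in generate("شكون لي صنعك؟", chat_history=[], max_new_tokens=64):
#       print(partial)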
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Checkbox(label="Do Sample"),
        gr.Slider(
            label="Temperature",
            minimum=0.0,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.0,
        ),
    ],
    stop_btn=None,
    examples=[
        ["شكون لي صنعك؟"],  # "Who created you?"
        ["شنو كيتسمى المنتخب المغربي ؟"],  # "What is the Moroccan national team called?"
        ["أشنو كايمييز المملكة المغربية."],  # "What distinguishes the Kingdom of Morocco?"
        ["ترجم للدارجة:\nAtlas-Chat is the first open source large language model that talks in Darija."],  # "Translate to Darija: ..."
    ],
    cache_examples=False,
    type="messages",
)
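# Each entry in additional_inputs is passed positionally to generate() after
# (message, chat_history), so the slider/checkbox order above must match the
# parameter order of the function signature.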
with gr.Blocks(css_paths="style.css", fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()

if __name__ == "__main__":
    demo.queue(max_size=20).launch()  # Queue at most 20 pending requests; excess submissions are rejected.
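# Running locally (assumed defaults): `python app.py` serves the demo on
# http://127.0.0.1:7860. It expects a style.css file next to this script, a GPU
# with enough memory for the 9B weights in bfloat16 (roughly 18 GB), and the
# gradio, spaces, torch, and transformers packages installed.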