import subprocess # Installing flash_attn subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True) from threading import Thread import spaces import gradio as gr import torch from transformers import ( AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer ) model = AutoModelForCausalLM.from_pretrained("THUDM/longwriter-glm4-9b", trust_remote_code=True, device_map='auto') tokenizer = AutoTokenizer.from_pretrained("THUDM/longwriter-glm4-9b", trust_remote_code=True) class StopOnTokens(StoppingCriteria): def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: stop_ids = model.config.eos_token_id for stop_id in stop_ids: if input_ids[0][-1] == stop_id: return True return False @spaces.GPU() def predict(history, prompt, max_length, top_p, temperature): stop = StopOnTokens() messages = [] if prompt: messages.append({"role": "system", "content": prompt}) for idx, (user_msg, model_msg) in enumerate(history): if prompt and idx == 0: continue if idx == len(history) - 1 and not model_msg: query = user_msg break if user_msg: messages.append({"role": "user", "content": user_msg}) if model_msg: messages.append({"role": "assistant", "content": model_msg}) model_inputs = tokenizer.build_chat_input(query, history=messages, role='user').input_ids.to( next(model.parameters()).device) streamer = TextIteratorStreamer(tokenizer, timeout=600, skip_prompt=True) eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"), tokenizer.get_command("<|observation|>")] generate_kwargs = { "input_ids": model_inputs, "streamer": streamer, "max_new_tokens": max_length, "do_sample": True, "top_p": top_p, "temperature": temperature, "stopping_criteria": StoppingCriteriaList([stop]), "repetition_penalty": 1, "eos_token_id": eos_token_id, } t = Thread(target=model.generate, kwargs=generate_kwargs) t.start() for new_token in streamer: if new_token and '<|user|>' not in new_token: history[-1][1] += new_token yield history with gr.Blocks() as demo: gr.Markdown( """