# Scraped from a Hugging Face Space file viewer (Space status: Paused).
# File size: 1,453 bytes; revision fa8a9aa.
import gradio as gr
from transformers import pipeline
import torch
import re
import spaces
# Make CUDA the default device for tensors that torch creates from here on.
torch.set_default_device("cuda")

# Hub id of the model served by this app.
model_id = "glides/llama-eap"

# Text-generation pipeline, built once at import time and reused by every request.
pipe = pipeline("text-generation", model=model_id, device_map="auto")

# The system prompt is read from the file named "sys" in the working directory.
# The encoding is pinned so the prompt text does not depend on the platform's
# default locale encoding.
with open("sys", "r", encoding="utf-8") as f:
    system_prompt = f.read()
# Required layout of a model reply: four tagged sections, in this order, each
# with at least one character of content. Compiled once at import time.
_RULES_PATTERN = re.compile(
    r'<thinking>.+?</thinking><output>.+?</output><reflecting>.+?</reflecting><refined>.+?</refined>'
)


def follows_rules(s):
    """Return True if *s* is exactly one well-formed tagged reply.

    Newlines are removed before matching, so sections may span several lines
    and may be separated by line breaks. The whole remaining string must match
    the layout: text before ``<thinking>`` or after the closing ``</refined>``
    makes the reply invalid.

    Args:
        s: Candidate reply text.

    Returns:
        bool: Whether the reply follows the required section layout.
    """
    # fullmatch (rather than match) anchors the end as well as the start, so
    # leading and trailing junk are rejected consistently.
    return _RULES_PATTERN.fullmatch(s.replace("\n", "")) is not None
def _trim_to_tagged_section(generated_text):
    """Cut raw model text down to the span from the last ``<thinking>`` to the first ``</refined>`` after it."""
    from_thinking = "<thinking>" + generated_text.split("<thinking>")[-1]
    return from_thinking.split("</refined>")[0] + "</refined>"


@spaces.GPU(duration=120)
def predict(input_text, history):
    """Generate a reply for the chat UI and return only its refined answer.

    The conversation sent to the model is the system prompt, then every prior
    turn from *history*, then the new user message. Generation is repeated
    until the trimmed reply passes ``follows_rules``; the text inside the
    ``<refined>`` section of that reply is returned.

    Args:
        input_text: The new user message.
        history: Prior turns; each item is indexed as ``item[0]`` (user text)
            and ``item[1]`` (assistant text, skipped when it is ``None``).

    Returns:
        str: Content of the ``<refined>`` section with its tags removed.
    """
    chat = [{"role": "system", "content": system_prompt}]
    for item in history:
        chat.append({"role": "user", "content": item[0]})
        if item[1] is not None:
            chat.append({"role": "assistant", "content": item[1]})
    chat.append({"role": "user", "content": input_text})

    # NOTE(review): there is no retry cap, so a model that never produces a
    # valid reply keeps this loop running; consider bounding the attempts.
    while True:
        # The pipeline returns the full conversation; the last message is the
        # newly generated assistant turn.
        generated_text = pipe(chat, max_new_tokens=2 ** 16)[0]['generated_text'][-1]['content']
        # Re-trim on every attempt so validation always sees the latest
        # generation instead of the first one.
        candidate = _trim_to_tagged_section(generated_text)
        if follows_rules(candidate):
            break
        print(f"model output {generated_text} was found invalid")

    return candidate.split("<refined>")[-1].replace("</refined>", "")
# Build the chat UI around predict() and start serving it.
gr.ChatInterface(predict, theme="soft").launch()