Spaces:
Paused
Paused
import gradio as gr | |
from transformers import pipeline | |
import torch | |
import re | |
import os | |
import spaces | |
# Read the system prompt first so that a missing variable fails immediately
# with a clear message, instead of only after the model has been loaded and
# with a bare KeyError: 'sys'.
try:
    system_prompt = os.environ["sys"]
except KeyError:
    raise RuntimeError(
        "Environment variable 'sys' must be set to the system prompt."
    ) from None

# Make CUDA the default device for tensors created without an explicit device.
torch.set_default_device("cuda")
model_id = "glides/llama-eap"
# device_map="auto" lets Transformers/Accelerate decide weight placement.
pipe = pipeline("text-generation", model=model_id, device_map="auto")
def follows_rules(s):
    """Report whether *s* starts with the four required tagged sections.

    All "\\n" characters are deleted first. The remaining text must then
    begin with non-empty <thinking>, <output>, <reflecting> and <refined>
    sections, in that order. The match is anchored only at the start, so
    text following a matching prefix is ignored.

    Returns:
        True when the layout matches, otherwise False.
    """
    layout = re.compile(r'<thinking>.+?</thinking><output>.+?</output><reflecting>.+?</reflecting><refined>.+?</refined>')
    flattened = s.replace("\n", "")
    # Match objects are always truthy, so "is not None" equals bool(match).
    return layout.match(flattened) is not None
def _extract_tagged_answer(text):
    """Cut the "<thinking>...</refined>" span out of *text*, or return None.

    The span starts at the last "<thinking>" in *text* and ends at the first
    "</refined>" after it. None is returned when either tag is absent, which
    includes a generation that stopped before closing </refined>.
    """
    _, opened, after_open = text.rpartition("<thinking>")
    if not opened:
        return None
    body, closed, _ = after_open.partition("</refined>")
    if not closed:
        return None
    return "<thinking>" + body + "</refined>"


def predict(input_text, history, *, max_attempts=5):
    """Generate a reply and return only the text of its <refined> section.

    The system prompt, the earlier turns and the new message are sent to the
    text-generation pipeline. Generation is retried until the tagged answer
    passes follows_rules, up to *max_attempts* times.

    Args:
        input_text: The user's new message.
        history: Earlier turns. Each item is indexed as [0] = user message and
            [1] = assistant reply; a reply of None is skipped.
        max_attempts: Maximum number of generations to try (keyword-only).
            With a value below 1, no generation is attempted.

    Returns:
        The contents of the <refined> section of the first valid generation.

    Raises:
        gr.Error: If no attempt produced a well-formed answer.
    """
    # NOTE(review): `spaces` is imported at module level but never used. If this
    # Space runs on ZeroGPU hardware, this function likely needs @spaces.GPU —
    # confirm.
    chat = [{"role": "system", "content": system_prompt}]
    for item in history:
        chat.append({"role": "user", "content": item[0]})
        if item[1] is not None:
            chat.append({"role": "assistant", "content": item[1]})
    chat.append({"role": "user", "content": input_text})

    for _ in range(max_attempts):
        generated_text = pipe(chat, max_new_tokens=2 ** 16)[0]['generated_text'][-1]['content']
        # Re-extract from *this* generation on every attempt. The original code
        # kept re-checking the stale first extraction and looped forever.
        candidate = _extract_tagged_answer(generated_text)
        if candidate is not None and follows_rules(candidate):
            return candidate.split("<refined>")[-1].replace("</refined>", "")
        print(f"model output {generated_text} was found invalid")

    raise gr.Error(
        "The model did not return a well-formed answer after several attempts. "
        "Please try again."
    )
# Keep the interface at module level under the conventional name `demo`, which
# Gradio's reload mode (`gradio <this file>`) looks for. Start the server only
# when run as a script, so importing this module does not block in launch().
demo = gr.ChatInterface(predict, theme="soft")

if __name__ == "__main__":
    demo.launch()