File size: 1,450 Bytes
fa8a9aa
 
 
 
 
cf81fc2
fa8a9aa
 
 
 
 
f00c92a
fa8a9aa
 
cf81fc2
fa8a9aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eeda2cf
fa8a9aa
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import gradio as gr
from transformers import pipeline
import torch

import re
import os

import spaces

# Make CUDA the default device for tensors created without an explicit device.
torch.set_default_device("cuda")

# Model repository id passed to the text-generation pipeline below.
model_id = "glides/mistral-eap"
# device_map="auto" delegates weight placement to transformers/accelerate.
# NOTE(review): this overlaps with the global CUDA default above — confirm both are needed.
pipe = pipeline("text-generation", model=model_id, device_map="auto")

# System prompt comes from the "sys" environment variable (presumably a Space
# secret — TODO confirm). os.environ[...] raises KeyError at import time if unset.
system_prompt = os.environ["sys"]

# Required response layout: four tagged, non-empty sections in this exact order.
_RULES_RE = re.compile(
    r'<thinking>.+?</thinking><output>.+?</output><reflecting>.+?</reflecting><refined>.+?</refined>'
)


def follows_rules(s):
    """Report whether ``s`` begins with the four required tagged sections.

    Every "\\n" is removed before matching, so sections may be separated by
    line breaks and a section's content may span lines. Matching is anchored
    at the start of the string only; anything after the final ``</refined>``
    is ignored.

    Returns:
        True if the layout matches, otherwise False.
    """
    flattened = s.replace("\n", "")
    return _RULES_RE.match(flattened) is not None

# Upper bound on generation attempts per request. Each attempt is a full
# generation, so retrying without a bound could hold the request (and the GPU)
# until the decorator's time budget runs out.
MAX_GENERATION_ATTEMPTS = 3


def _build_chat(input_text, history):
    """Return the pipeline message list: system prompt, prior turns, then the new user turn."""
    chat = [{"role": "system", "content": system_prompt}]
    # NOTE(review): assumes Gradio's pair-style history (item[0] = user message,
    # item[1] = assistant reply or None); revisit if type="messages" is adopted.
    for item in history:
        chat.append({"role": "user", "content": item[0]})
        if item[1] is not None:
            chat.append({"role": "assistant", "content": item[1]})
    chat.append({"role": "user", "content": input_text})
    return chat


def _generate(chat):
    """Run the text-generation pipeline on ``chat``; return the new assistant message text."""
    return pipe(chat, max_new_tokens=2 ** 16)[0]["generated_text"][-1]["content"]


def _extract_refined_section(generated_text):
    """Trim ``generated_text`` to the span from its last ``<thinking>`` to the first ``</refined>`` after it.

    The result is not validated here; callers check it with ``follows_rules``.
    """
    removed_pres = "<thinking>" + generated_text.split("<thinking>")[-1]
    return removed_pres.split("</refined>")[0] + "</refined>"


@spaces.GPU(duration=120)
def predict(input_text, history):
    """Generate a reply to ``input_text`` and return only its ``<refined>`` section.

    The model is expected to answer with <thinking>, <output>, <reflecting>
    and <refined> sections. Malformed responses are regenerated up to
    MAX_GENERATION_ATTEMPTS times.

    Args:
        input_text: The new user message.
        history: Previous conversation turns as supplied by gr.ChatInterface.

    Returns:
        The stripped text inside the ``<refined>`` tags.

    Raises:
        gr.Error: If no well-formed response was produced within the attempt limit.
    """
    chat = _build_chat(input_text, history)

    for attempt in range(1, MAX_GENERATION_ATTEMPTS + 1):
        generated_text = _generate(chat)
        # Re-extract on every attempt. The previous version only regenerated the
        # text, so an initially invalid response made the loop spin forever.
        structured = _extract_refined_section(generated_text)
        if follows_rules(structured):
            refined = structured.split("<refined>")[-1].replace("</refined>", "")
            return refined.strip()
        print(
            f"model output {generated_text} was found invalid "
            f"(attempt {attempt}/{MAX_GENERATION_ATTEMPTS})"
        )

    raise gr.Error("The model did not produce a well-formed response. Please try again.")

# Build the chat UI around predict and start the Gradio server.
demo = gr.ChatInterface(predict, theme="soft")
demo.launch()