File size: 1,325 Bytes
561ca81
ab85003
561ca81
2432281
ab85003
b74218d
2432281
b74218d
561ca81
 
 
 
 
be6ae85
588129f
 
d903d0d
588129f
561ca81
99dd4f5
 
 
561ca81
99dd4f5
 
 
 
 
 
 
 
 
 
 
 
 
 
561ca81
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# Load the RWKV-4 Pile 1B5 "Instruct" checkpoint via rwkvstic.
# NOTE(review): passing a URL means rwkvstic fetches the weights from
# Hugging Face on first run — network + disk side effect at import time.
from rwkvstic.load import RWKV
import torch
model = RWKV(
    "https://huggingface.co/BlinkDL/rwkv-4-pile-1b5/resolve/main/RWKV-4-Pile-1B5-Instruct-test1-20230124.pth",
    "pytorch(cpu/gpu)",  # backend/mode selector string understood by rwkvstic
    runtimedtype=torch.float32,
    useGPU=torch.cuda.is_available(),  # use CUDA when present, else CPU
    dtype=torch.float32
)
import gradio as gr


def predict(input, history=None):
    """Run one chat turn against the RWKV model.

    Args:
        input: The user's message text.
        history: ``[chat_pairs, model_state]`` carried over from the previous
            turn, or ``None`` to start a fresh conversation.

    Returns:
        A 2-tuple ``(chat_pairs, [chat_pairs, model_state])``: the first
        element feeds the Chatbot widget, the second is written back into
        the session ``gr.State``.
    """
    # Bug fix: the original dereferenced history[1] unconditionally, so the
    # declared default of None crashed with a TypeError. Seed a fresh state.
    if history is None:
        history = [[], model.emptyState]
    model.setState(history[1])
    model.loadContext(newctx=f"Prompt: {input}\n\nExpert Long Detailed Response: ")
    r = model.forward(number=100, stopStrings=["\n\nPrompt"])
    rr = [(input, r["output"])]
    return [*history[0], *rr], [[*history[0], *rr], r["state"]]

def freegen(input):
    """Free-form generation: continue *input* from a blank model state.

    Args:
        input: Prompt text to continue.

    Returns:
        The generated continuation string.
    """
    # Clear any state left over from the chat tab before generating.
    model.resetState()
    result = model.loadContext(newctx=input)
    return result["output"]

# Two-tab Gradio UI: a stateful chatbot and a stateless free-generation box.
with gr.Blocks() as demo:
    with gr.Tab("Chatbot"):
        chatbot = gr.Chatbot()
        # Per-session state: [chat_pairs, model_state], seeded with a fresh
        # model state. (Original assigned model.emptyState to `state` and
        # then immediately overwrote it — collapsed into one expression.)
        state = gr.State([[], model.emptyState])
        with gr.Row():
            txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter").style(container=False)

        txt.submit(predict, [txt, state], [chatbot, state])
    with gr.Tab("Free Gen"):
        with gr.Row():
            # Renamed from `input` to stop shadowing the builtin.
            gen_box = gr.Textbox(show_label=False, placeholder="Enter text and press enter").style(container=False)
            outtext = gr.Textbox()
        gen_box.submit(freegen, gen_box, outtext)
demo.launch()