# NOTE: removed Hugging Face Spaces page-scrape residue that preceded this file
# (status lines, commit hashes, file size, line-number gutter) — it was not code.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
import gradio as gr
from gradio import deploy
def generate_prompt(instruction, input=""):
    """Build an RWKV-style prompt string.

    With a non-empty *input*, an Instruction/Input/Response prompt is
    produced; otherwise a short chat transcript ending at "Assistant:" so
    the model completes the assistant's turn.  Both fields are stripped,
    CRLF-normalized, and have doubled newlines collapsed once.
    """
    instruction = instruction.strip().replace('\r\n', '\n').replace('\n\n', '\n')
    input = input.strip().replace('\r\n', '\n').replace('\n\n', '\n')

    if input:
        return (
            f"Instruction: {instruction}\n"
            f"Input: {input}\n"
            "Response:"
        )

    return (
        "User: hi\n"
        "Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n"
        f"User: {instruction}\n"
        "Assistant:"
    )
model_path = "models/rwkv-6-world-1b6/" # Path to your local model directory
# Load the RWKV model; trust_remote_code is required because the RWKV
# architecture ships as custom code alongside the checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    # use_flash_attention_2=False
).to(torch.float32)  # cast weights to full fp32 (safe on CPU; no half-precision kernels needed)
# Create a custom tokenizer (make sure to download vocab.json)
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    bos_token="</s>",
    # FIX: was "</ s>" — the stray space made the EOS string unmatchable
    # against any vocab entry, so it should mirror bos_token exactly.
    eos_token="</s>",
    unk_token="<unk>",
    pad_token="<pad>",
    trust_remote_code=True,
    padding_side='left',  # left-pad so generation continues from the last real prompt token
    clean_up_tokenization_spaces=False  # Or set to True if you prefer
)
# Function to handle text generation with word-by-word output and stop sequence
def generate_text(input_text):
    """Stream a model reply one token at a time for the Gradio interface.

    Builds a chat prompt from *input_text*, then repeatedly asks the model
    for a single new token, yielding the accumulated text after each token
    so Gradio can render the response incrementally.  Stops after at most
    333 tokens, or earlier when the model emits its end-of-sequence token
    (the original loop always ran all 333 iterations despite the header
    comment promising a stop condition).
    """
    prompt = generate_prompt(input_text)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    generated_text = ""
    eos_id = tokenizer.eos_token_id  # hoisted out of the loop
    for _ in range(333):
        output = model.generate(
            input_ids,
            max_new_tokens=1,
            do_sample=True,
            temperature=1.0,
            top_p=0.3,
            top_k=0,
        )
        new_token_id = output[0, -1].item()
        # Stop cleanly instead of sampling past the end of the reply.
        if eos_id is not None and new_token_id == eos_id:
            break
        new_word = tokenizer.decode(output[0][-1:], skip_special_tokens=True)
        print(new_word, end="", flush=True)  # Print to console for monitoring
        generated_text += new_word
        input_ids = output  # feed the grown sequence back in for the next token
        yield generated_text  # Yield the updated text after each word
# Create the Gradio interface
# generate_text is a generator, so Gradio streams each yielded update to the UI.
iface = gr.Interface(
    fn=generate_text,
    inputs="text",
    outputs="text",
    title="RWKV Chatbot",
    description="Enter your prompt below:",
    # flagging_callback=None
    # NOTE(review): flagging_dir was removed in Gradio 4.x — confirm the
    # installed Gradio version accepts this keyword.
    flagging_dir="gradio_flagged/"
)
# For local testing:
iface.launch(share=False)  # share=False: no public tunnel; Spaces serves it directly
# deploy()
# Hugging Face Spaces will automatically launch the interface.