# Hugging Face Spaces status: Runtime error
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Base chat checkpoint; the model and the tokenizer must come from the same id.
BASE_MODEL_ID = "NousResearch/Llama-2-7b-chat-hf"
# PEFT adapter loaded on top of the base model.
ADAPTER_ID = "maxspin/medichat"

# Load the base model quantized to 8-bit and let accelerate place it on the
# available devices. `quantization_config` replaces the deprecated
# `load_in_8bit=True` keyword argument of `from_pretrained`.
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    return_dict=True,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
# Wrap the base model with the adapter weights fetched from ADAPTER_ID.
model = PeftModel.from_pretrained(model, ADAPTER_ID)

import gradio as gr
def query_handling(query, conversation):
    """Reset the chat history once the user says thank you.

    Args:
        query: The user's latest message.
        conversation: The accumulated ``[INST]...[/INST]`` transcript.

    Returns:
        An empty string if ``query`` contains "thanks" or "thank you"
        (case-insensitive); otherwise ``conversation`` unchanged.
    """
    lowered = query.lower()
    # "thank you very much" needs no separate check: it contains "thank you".
    if any(phrase in lowered for phrase in ("thanks", "thank you")):
        return ""
    return conversation


def process_response(input_string):
    """Strip Llama-2 chat markup from a decoded generation.

    Removes everything from the first ``[INST]`` through the last ``[/INST]``
    (in this app, that is the echoed prompt with all earlier turns). It then
    deletes any remaining ``<s>``, ``</s>``, ``[INST]`` and ``[/INST]``
    markers.

    Args:
        input_string: Text decoded from the model output.

    Returns:
        The cleaned text. If either marker is missing, or the last ``[/INST]``
        precedes the first ``[INST]``, only the marker-deletion step applies.
    """
    open_tag, close_tag = "[INST]", "[/INST]"
    start_index = input_string.find(open_tag)
    end_index = input_string.rfind(close_tag)
    if start_index != -1 and end_index >= start_index:
        # Positional cut of the span [first "[INST]", end of last "[/INST]"].
        cleaned_string = (
            input_string[:start_index] + input_string[end_index + len(close_tag):]
        )
    else:
        cleaned_string = input_string
    # Order matters for overlapping leftovers; keep <s>, </s>, [INST], [/INST].
    for token in ("<s>", "</s>", open_tag, close_tag):
        cleaned_string = cleaned_string.replace(token, "")
    return cleaned_string


# Transcript of every turn so far, in "[INST]prompt[/INST]reply" form.
# NOTE(review): this is one module-level global, so every visitor of the app
# shares (and can reset) the same conversation. Per-session history would need
# gr.State, which changes predict's signature.
conversation = ""


def predict(prompt):
    """Generate the model's reply to ``prompt`` and append it to the history.

    Args:
        prompt: The user's message from the Gradio text box.

    Returns:
        The cleaned model reply for this turn.
    """
    global conversation
    conversation += f"[INST]{prompt}[/INST]"
    # NOTE(review): the tokenizer may also prepend its own BOS, so this literal
    # "<s>" could yield a doubled BOS. Confirm against the adapter's training
    # format before changing it.
    batch = tokenizer("<s>" + conversation, return_tensors="pt").to("cuda")
    with torch.autocast(device_type="cuda"):
        # NOTE(review): 4000 new tokens on top of a growing transcript may
        # exceed the base model's context window (commonly 4096 for Llama-2).
        # Confirm before raising the limit or letting history grow unbounded.
        output_tokens = model.generate(**batch, max_new_tokens=4000)
    # Decode once and reuse the text for both logging and post-processing.
    decoded = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
    print("\n\n", decoded)
    response = process_response(decoded)
    conversation += response
    conversation = query_handling(prompt, conversation)
    print(conversation)
    return response


iface = gr.Interface(
    fn=predict,
    inputs="text",  # Accepts a single text input
    outputs="text",  # Outputs a single text response
)

# Launch only after `iface` exists; calling it earlier raised NameError.
if __name__ == "__main__":
    iface.launch()