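# Gradio Space: streaming chat interface for a Hugging Face causal LM
# (a Phi-3.5 fine-tune for ERPNext, selected via the MODEL_NAME secret).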
import gradio as gr
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer
)
import os
from threading import Thread
import spaces
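# The HF access token and model id are provided via environment variables (Space secrets).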
hf_token = os.environ["HF_TOKEN"]
model_name = os.environ["MODEL_NAME"]
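# Load the model and tokenizer from the Hub; the token allows access to gated or private repos.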
model = AutoModelForCausalLM.from_pretrained(
model_name,
token=hf_token
)
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
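# Stop generation at the tokenizer's end-of-sequence token.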
terminators = [
tokenizer.eos_token_id,
]
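# Use a CUDA device when available, otherwise fall back to CPU.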
if torch.cuda.is_available():
device = torch.device("cuda")
print(f"Using GPU: {torch.cuda.get_device_name(device)}")
else:
device = torch.device("cpu")
print("Using CPU")
model = model.to(device)
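# ZeroGPU decorator: requests a GPU for up to 60 seconds per call.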
@spaces.GPU(duration=60)
def chat(message, history, temperature, do_sample, max_tokens):
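    # Rebuild the conversation as role/content messages from Gradio's (user, assistant) history pairs.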
    conversation = []
for item in history:
        conversation.append({
"role": "user",
"content": item[0]
})
if item[1] is not None:
            conversation.append({
"role": "assistant",
"content": item[1]
})
    conversation.append({
"role": "user",
"content": message
})
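    # Render the conversation with the model's chat template and tokenize it.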
    messages = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
model_inputs = tokenizer([messages], return_tensors="pt").to(device)
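    # Stream decoded tokens as they are generated, skipping the prompt and special tokens.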
streamer = TextIteratorStreamer(
tokenizer,
timeout=20,
skip_prompt=True,
skip_special_tokens=True
)
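    # model_inputs is a BatchEncoding, so dict() spreads input_ids/attention_mask into the kwargs.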
generate_kwargs = dict(
model_inputs,
streamer=streamer,
max_new_tokens=max_tokens,
temperature=temperature,
eos_token_id=terminators
)
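    # Honor the sampling checkbox; a temperature of 0 forces greedy decoding.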
    generate_kwargs["do_sample"] = False if temperature == 0 else do_sample
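    # Run generate() on a background thread so tokens can be yielded as they arrive.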
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
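    # Accumulate the streamed chunks and re-yield the growing reply so the UI updates live.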
partial_text = ""
for new_text in streamer:
partial_text += new_text
yield partial_text
yield partial_text
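# Chat UI with a collapsible parameters accordion (temperature, sampling toggle, max new tokens).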
demo = gr.ChatInterface(
fn=chat,
    examples=[["Write me a poem about machine learning"]],
additional_inputs_accordion=gr.Accordion(
label="⚙️ Parameters", open=False, render=False
),
additional_inputs=[
gr.Slider(
minimum=0, maximum=1, step=0.1, value=0.9, label="Temperature", render=False
),
        gr.Checkbox(label="Sampling", value=True, render=False),
gr.Slider(
minimum=128,
maximum=4096,
step=1,
value=512,
label="Max new tokens",
render=False
),
],
stop_btn="Stop Generation",
title="Chat with Phi3.5 ERPNext",
    description="Now running antony - Phi3.5 ERPNext"
)
demo.launch()