# Llama3_kw_gen_new / handler.py
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import subprocess
import sys
# Helper to install packages at runtime with the current interpreter's pip
def install(*packages):
    subprocess.check_call([sys.executable, "-m", "pip", "install", *packages])

# Upgrade pip and huggingface_hub before the model is loaded
subprocess.run([sys.executable, "-m", "pip", "install", "--upgrade", "pip"], check=True)
subprocess.run([sys.executable, "-m", "pip", "install", "--upgrade", "huggingface_hub"], check=True)
class EndpointHandler:
    def __init__(self, model_dir):
        self.model = None
        self.tokenizer = None
        self.model_dir = model_dir
        self.load_model(model_dir)
    def load_model(self, model_dir):
        model_id = model_dir or self.model_dir
        # Read the Hugging Face access token from the environment
        token = os.getenv("HF_API_TOKEN")
        # Load model & tokenizer; Llama 3 checkpoints are distributed as
        # PyTorch/safetensors weights, so the earlier from_tf=True flag is dropped
        self.model = AutoModelForCausalLM.from_pretrained(model_id, token=token)
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    def predict(self, inputs):
        # Tokenize the input and move it to the model's device to avoid a
        # CPU/GPU mismatch when the model runs on an accelerator
        tokens = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(**tokens)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
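
# A minimal local smoke test, assuming HF_API_TOKEN is set and the checkpoint
# is accessible (the prompt below is a hypothetical example):
#   handler = EndpointHandler("NiCETmtm/Llama3_kw_gen_new")
#   print(handler.predict("Generate keywords for: portable solar chargers"))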
# Make sure the remaining dependencies are installed before instantiating
# and loading the model
if __name__ == "__main__":
    try:
        # pip.main() was removed in pip >= 10, so install via the helper instead
        install("accelerate")
        # Install bitsandbytes from the main PyPI index
        install("-i", "https://pypi.org/simple/", "bitsandbytes")
    except Exception as e:
        print(f"Error installing dependencies: {e}")
    model_dir = "NiCETmtm/Llama3_kw_gen_new"
    # __init__ already calls load_model(), so no extra call is needed here
    handler = EndpointHandler(model_dir)
# Handler function invoked by the serving runtime; instantiate the model lazily
# so the function also works when this module is imported (the __main__ block
# above only runs when the file is executed directly)
def inference(event, context):
    global handler
    if "handler" not in globals():
        handler = EndpointHandler("NiCETmtm/Llama3_kw_gen_new")
    inputs = event["data"]
    outputs = handler.predict(inputs)
    return {"predictions": outputs}