|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
from flask import Flask, request |
|
import argparse |
|
import logging |
|
|
|
|
|
class LLMInstance:
    """Wrap a Hugging Face causal LM and its tokenizer for single-turn chat queries."""

    def __init__(self, model_path: str, device: str = "cuda"):
        """Load the model and tokenizer from *model_path* and move the model to *device*.

        Args:
            model_path: Local directory or Hub id accepted by ``from_pretrained``.
            device: Torch device string the model (and later the inputs) are placed on.
        """
        self.model = AutoModelForCausalLM.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model.to(device)
        self.device = device

    def query(self, message, max_new_tokens: int = 1000, do_sample: bool = True):
        """Generate a reply to a single user *message*.

        Args:
            message: The user's text, sent as the only turn of the conversation.
            max_new_tokens: Upper bound on the number of generated tokens.
            do_sample: Sample from the model (True) or decode greedily (False).

        Returns:
            A dict with keys ``code`` (0 on success, 1 on error), ``ret``
            (success flag), ``error_msg`` (``None`` or the exception text) and
            ``output`` (the reply text with surrounding whitespace stripped, or
            ``None`` on error). Failures are reported in this dict, not raised.
        """
        try:
            messages = [{"role": "user", "content": message}]
            # add_generation_prompt appends the assistant-turn header for chat
            # templates that define one; templates that do not use it are unaffected.
            input_ids = self.tokenizer.apply_chat_template(
                messages, add_generation_prompt=True, return_tensors="pt"
            ).to(self.device)

            generated_ids = self.model.generate(
                input_ids, max_new_tokens=max_new_tokens, do_sample=do_sample
            )

            # generate() returns prompt + continuation; decode only the new tokens.
            # This replaces splitting the full text on "[/INST]"/"</s>", which only
            # worked for one chat template and returned the wrong text whenever
            # the user's message itself contained "[/INST]".
            new_tokens = generated_ids[0, input_ids.shape[-1]:]
            output = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

            return {
                'code': 0,
                'ret': True,
                'error_msg': None,
                'output': output,
            }
        except Exception as e:
            # Service boundary: the caller gets an error dict, but keep the
            # traceback in the logs instead of discarding it.
            logging.getLogger(__name__).exception("LLM query failed")
            return {
                'code': 1,
                'ret': False,
                'error_msg': str(e),
                'output': None,
            }
|
|
|
|
|
def create_app(core):
    """Build the Flask app that exposes *core* over HTTP.

    Args:
        core: Object with a ``query(text)`` method returning a JSON-serialisable
            dict (an ``LLMInstance`` when run as a script).

    Returns:
        The configured ``Flask`` application.
    """
    app = Flask(__name__)

    @app.route('/ask_llm_for_answer', methods=['POST'])
    def ask_llm_for_answer():
        """Answer ``POST {"user_text": "..."}`` with the dict from ``core.query``.

        Malformed requests get HTTP 400 with the same dict shape ``core.query``
        uses for failures, instead of an unhandled error.
        """
        # silent=True: a missing or unparsable JSON body yields None instead of raising.
        payload = request.get_json(silent=True)
        user_text = payload.get('user_text') if isinstance(payload, dict) else None
        if not isinstance(user_text, str):
            return {
                'code': 1,
                'ret': False,
                'error_msg': "request body must be a JSON object with a string 'user_text' field",
                'output': None,
            }, 400
        return core.query(user_text)

    return app
|
|
|
|
|
def parse_args(argv=None):
    """Parse command-line options for the LLM HTTP server.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``model_path``, ``ip``, ``port``,
        ``device`` and ``debug``.
    """
    parser = argparse.ArgumentParser(description='Serve a chat LLM over HTTP.')
    # Not required: combining required=True with a default made the default unreachable.
    parser.add_argument('-m', '--model_path', default='Mistral-7B-Instruct-v0.1',
                        help='local path or hub id of the chat model to serve')
    parser.add_argument('--ip', default='0.0.0.0', help='interface to bind to')
    parser.add_argument('-p', '--port', type=int, default=8001, help='port to listen on')
    parser.add_argument('--device', default='cuda', help='torch device to load the model on')
    parser.add_argument('--debug', action='store_true', help='enable DEBUG-level logging')
    return parser.parse_args(argv)


def main(argv=None):
    """Configure logging, load the model and serve it with Flask's built-in server."""
    args = parse_args(argv)

    root = logging.getLogger()
    root.setLevel(logging.DEBUG if args.debug else logging.INFO)
    # Format the handler we add, not whatever handler happens to be first on root.
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(message)s"))
    root.addHandler(handler)

    core = LLMInstance(args.model_path, device=args.device)
    app = create_app(core)
    app.run(host=args.ip, port=args.port)


if __name__ == "__main__":
    main()
|
|