# Source: Hugging Face file view — uploaded by luciusssss ("Upload 22 files",
# commit a48216a, 4.69 kB). Page header kept as a comment so the module parses.
import json
import subprocess
import os
import codecs
import logging
import os
import math
import json
import random
from tqdm import tqdm
from transformers import pipeline
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
from flask import Flask, request, jsonify
import json
import random
from tqdm import tqdm
import os
import pickle as pkl
from argparse import Namespace
from models import Elect
import torch
from transformers import AutoModel,AutoTokenizer
from sklearn.preprocessing import MultiLabelBinarizer
logger = logging.getLogger(__name__)
app = Flask(__name__)

# Shared state for the request handler. The placeholders below stay None (and
# fatiao_args stays empty) until the ``__main__`` start-up code loads the models.
hunyin_classifier = fatiao_tokenizer = fatiao_model = None
fatiao_args = Namespace()
# Inputs are cut to this many characters before the marriage classifier runs.
_CLASSIFIER_MAX_CHARS = 500
# Token budget for the law-article retrieval model (pad/truncate to this length).
_FATIAO_MAX_TOKENS = 256
# Number of highest-scoring law articles returned per query.
_TOP_K_LAWS = 3


def _empty_output():
    """Return the JSON response used when no law articles should be suggested."""
    return jsonify({"output": []})


def _is_hunyin_related(input_text):
    """Decide whether *input_text* is a marriage-related consultation.

    Rule first: any text containing the character "婚" (marriage) counts as
    related, and the classifier is not run at all. Otherwise the label of the
    classifier's first prediction on the first ``_CLASSIFIER_MAX_CHARS``
    characters decides; only a label equal to ``False`` means "unrelated".
    """
    if '婚' in input_text:
        return True
    classifier_result = hunyin_classifier(input_text[:_CLASSIFIER_MAX_CHARS])
    print(classifier_result)
    label = classifier_result[0]['label']
    # NOTE(review): pipeline labels come from model.config.id2label. With a
    # default config they are strings such as "LABEL_0"/"LABEL_1". Those never
    # compare equal to False, so nothing would ever be rejected. Confirm that
    # this checkpoint maps its ids to boolean labels.
    return label != False  # noqa: E712 - deliberately equality, as before


def _rank_laws(pred, top_k=_TOP_K_LAWS):
    """Return the *top_k* highest-scoring law articles as JSON-ready dicts.

    *pred* holds one sigmoid score per label index. Indices are mapped to
    article text via ``fatiao_args.mlb.classes_``. Ties keep ascending index
    order because ``sorted`` is stable even with ``reverse=True``.
    """
    ranked = sorted(enumerate(pred), key=lambda item: item[1], reverse=True)[:top_k]
    return [
        {
            'id': law_id,
            'score': float(score),
            'text': fatiao_args.mlb.classes_[law_id],
        }
        for law_id, score in ranked
    ]


@app.route('/check_hunyin', methods=['GET', 'POST'])
def check_hunyin():
    """Suggest the law articles most relevant to a marriage-related query.

    Expects a JSON object body ``{"input": str, "force_return": bool?}``.

    Responds with ``{"output": [...]}`` holding up to ``_TOP_K_LAWS`` entries
    of the form ``{"id", "score", "text"}``. The list is empty when the input
    is blank, or when the query is judged not marriage-related. A truthy
    ``force_return`` skips that relevance check.

    Requests without a JSON object carrying a string ``input`` get HTTP 400
    instead of crashing with a 500.
    """
    payload = request.get_json(silent=True)
    raw_input = payload.get('input') if isinstance(payload, dict) else None
    if not isinstance(raw_input, str):
        return jsonify({"error": "request body must be a JSON object with a string 'input' field"}), 400
    input_text = raw_input.strip()
    force_return = payload.get('force_return', False)
    print("input_text:", input_text)

    if not input_text:
        return _empty_output()
    # If the query is not marriage-related, return an empty result.
    if not force_return and not _is_hunyin_related(input_text):
        return _empty_output()

    inputs = fatiao_tokenizer(input_text, padding='max_length', truncation=True,
                              max_length=_FATIAO_MAX_TOKENS, return_tensors="pt")
    # Put inputs on the model's device. This mirrors the explicit .to(device)
    # used at start-up, and is a no-op if Elect already moves its inputs.
    device = fatiao_args.device
    batch = {
        'ids': inputs['input_ids'].to(device),
        'mask': inputs['attention_mask'].to(device),
        'token_type_ids': inputs['token_type_ids'].to(device),
    }
    # Inference only: without no_grad every request would build and hold an autograd graph.
    with torch.no_grad():
        model_output = fatiao_model(batch)
    pred = torch.sigmoid(model_output).detach().cpu().numpy()[0]

    json_result = {"output": _rank_laws(pred)}
    print("json_result:", json_result)
    return jsonify(json_result)
# Checkpoint / label-map locations loaded at start-up (relative to the working
# directory) and the device every model is placed on.
_HUNYIN_CLASSIFIER_PATH = "./pretrained_models/roberta_wwm_ext_hunyin_2epoch"
_FATIAO_PLM_DIR = "./pretrained_models/chinese-roberta-wwm-ext"
_FATIAO_CKPT_PATH = "./pretrained_models/ELECT"
_LABELS2ID_PATH = os.path.join("data", "labels2id.pkl")
_DEVICE = "cuda:0"


def _load_hunyin_classifier(model_path):
    """Build the binary "is this marriage-related?" text-classification pipeline.

    Loads the config (forcing ``num_labels=2``), tokenizer and sequence
    classification model from *model_path*. Wraps them in a ``transformers``
    text-classification pipeline on CUDA device 0.
    """
    config = AutoConfig.from_pretrained(model_path, num_labels=2)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path, config=config)
    # device=0 is the first CUDA device, i.e. the same GPU as _DEVICE.
    return pipeline(model=model, tokenizer=tokenizer, task="text-classification", device=0)


def _load_fatiao_model(args, labels2id_path, ckpt_path):
    """Load the ELECT law-article retrieval model and its label metadata.

    Steps:
      * Read the label -> id map from *labels2id_path*.
      * Store on *args*: the label list (``args.labels``), the tokenizer
        (``args.tokenizer``) and a fitted ``MultiLabelBinarizer``
        (``args.mlb``).
      * Seed the model's per-label embeddings ``la`` from the encoder.
      * Load the fine-tuned weights from *ckpt_path*.

    Expects ``args.ckpt_dir`` and ``args.device`` to be set already. Returns
    ``(tokenizer, model)``; the model is in eval mode on ``args.device``.
    """
    # Trusted local artifact: pickle must never be pointed at user-supplied files.
    with open(labels2id_path, "rb") as f:
        laws2id = pkl.load(f)
    args.labels = list(laws2id.keys())
    # Number of distinct ids, i.e. the size of the id -> label inverse map.
    print("法条个数:", len(set(laws2id.values())))

    tokenizer = AutoTokenizer.from_pretrained(args.ckpt_dir)
    args.tokenizer = tokenizer
    model = Elect(args, args.device).to(args.device)
    model.eval()

    # mlb.classes_ maps output index -> law article text (used by the request handler).
    # NOTE(review): MultiLabelBinarizer sorts classes_. The `la` rows filled
    # below follow args.labels order instead. They only line up if the
    # laws2id keys are already sorted, or if `la` is overwritten by the
    # checkpoint. Confirm.
    mlb = MultiLabelBinarizer()
    mlb.fit([args.labels])
    args.mlb = mlb

    with torch.no_grad():
        for idx, label in enumerate(args.labels):
            # Drop the "《民法典》第xxxx条:" citation prefix; keep the article body.
            text = label.partition(':')[2].lower()
            la_in = tokenizer(text, padding='max_length', truncation=True, max_length=256,
                              return_tensors="pt")
            ids = la_in['input_ids'].to(args.device)
            mask = la_in['attention_mask'].to(args.device)
            # Presumably the first-token ([CLS]) hidden state of the encoder — confirm plm's output layout.
            first_token = model.plm(input_ids=ids, attention_mask=mask)[0][:, 0]
            model.la[idx] += first_token.squeeze(0)

    # NOTE(review): if `la` is part of the checkpoint's state dict, this load
    # overwrites the embeddings seeded above. Confirm whether that loop is needed.
    # load_state_dict copies in place, so the model stays on args.device.
    model.load_state_dict(torch.load(ckpt_path, map_location=torch.device(args.device)))
    return tokenizer, model


if __name__ == '__main__':
    # Without this, logger.info below is dropped: the root logger defaults to WARNING.
    logging.basicConfig(level=logging.INFO)
    # Consultation classifier: decides whether a query is marriage-related.
    hunyin_classifier = _load_hunyin_classifier(_HUNYIN_CLASSIFIER_PATH)
    # Law-article (法条) retrieval model.
    fatiao_args.ckpt_dir = _FATIAO_PLM_DIR
    fatiao_args.device = _DEVICE
    fatiao_tokenizer, fatiao_model = _load_fatiao_model(fatiao_args, _LABELS2ID_PATH, _FATIAO_CKPT_PATH)
    logger.info("model loaded")
    app.run(host="0.0.0.0", port=9098, debug=False)