|
import json |
|
import subprocess |
|
import os |
|
import codecs |
|
import logging |
|
import os |
|
import math |
|
|
|
import json |
|
import random |
|
from tqdm import tqdm |
|
from transformers import pipeline |
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig |
|
|
|
|
|
from flask import Flask, request, jsonify |
|
import json |
|
import random |
|
from tqdm import tqdm |
|
import os |
|
import pickle as pkl |
|
from argparse import Namespace |
|
|
|
from models import Elect |
|
|
|
import torch |
|
from transformers import AutoModel,AutoTokenizer |
|
|
|
from sklearn.preprocessing import MultiLabelBinarizer |
|
|
|
# Module-level logger for this service.
logger = logging.getLogger(__name__)

# Flask application that serves the /check_hunyin endpoint.
app = Flask(__name__)

# Model handles start out empty; the __main__ block fills them in before
# app.run(), and check_hunyin reads them on every request.

# text-classification pipeline used as the marriage ("hunyin") relevance gate.
hunyin_classifier = None

# Mutable config bag for the law-article ("fatiao") model; startup code adds
# ckpt_dir, device, labels, tokenizer and mlb attributes to it.
fatiao_args = Namespace()

# Tokenizer and model for law-article prediction.
fatiao_tokenizer = fatiao_model = None
|
|
|
|
|
@app.route('/check_hunyin', methods=['GET', 'POST'])
def check_hunyin():
    """Predict the law articles ("fatiao") most relevant to a piece of text.

    Request body (JSON object):
        input (str): the text to analyse; surrounding whitespace is stripped.
        force_return (bool, optional): when truthy, skip the marriage
            ("hunyin") classifier gate and always run law prediction.

    Unless ``force_return`` is set, the text is first passed to
    ``hunyin_classifier``; if its top label compares equal to ``False`` an
    empty result is returned. Any text containing '婚' (marriage) bypasses
    that gate.

    Returns:
        JSON ``{"output": [...]}`` holding up to three
        ``{"id": int, "score": float, "text": str}`` dicts ordered by
        descending sigmoid score, or ``{"output": []}`` for empty input or
        when the gate rejects the text. Responds with HTTP 400 and
        ``{"error": ...}`` if the body is not a JSON object with a string
        ``input`` field.
    """
    top_k = 3  # number of law articles returned

    # get_json(silent=True) yields None instead of raising for non-JSON bodies
    # (e.g. plain GET requests), so malformed requests become a clean 400.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or not isinstance(payload.get('input'), str):
        return jsonify({"error": "request body must be a JSON object with a string 'input' field"}), 400

    input_text = payload['input'].strip()
    force_return = payload.get('force_return', False)

    print("input_text:", input_text)

    if not input_text:
        return jsonify({"output": []})

    # Default to "relevant": when force_return skips the classifier, the gate
    # below must not see an unbound value (previously an UnboundLocalError).
    classifier_result = True
    if not force_return:
        # Only the first 500 characters are classified, presumably to stay
        # within the encoder's 512-token limit — TODO confirm.
        classifier_output = hunyin_classifier(input_text[:500])
        print(classifier_output)
        classifier_result = classifier_output[0]['label']

    # Keyword override: any mention of 婚 (marriage) is treated as relevant.
    if '婚' in input_text:
        classifier_result = True

    # NOTE(review): pipeline labels come from the model config's id2label
    # (by default strings such as "LABEL_0"). A string never equals False,
    # so this gate only fires if id2label maps to booleans/0 — confirm
    # against the model config. Equality (not truthiness) is kept on purpose.
    if classifier_result == False:
        return jsonify({"output": []})

    inputs = fatiao_tokenizer(input_text, padding='max_length', truncation=True, max_length=256, return_tensors="pt")
    device = fatiao_args.device
    # Move inputs to the model's device; .to() is a no-op if Elect already
    # handles placement itself.
    batch = {
        'ids': inputs['input_ids'].to(device),
        'mask': inputs['attention_mask'].to(device),
        'token_type_ids': inputs["token_type_ids"].to(device),
    }

    # Pure inference: disable autograd so no graph is built per request.
    with torch.no_grad():
        scores = torch.sigmoid(fatiao_model(batch))
    pred = scores.cpu().detach().numpy()[0]

    # Stable descending sort (ties keep label order); build dicts for the top-k only.
    ranked = sorted(enumerate(pred), key=lambda item: item[1], reverse=True)[:top_k]
    label_names = fatiao_args.mlb.classes_
    pred_laws = [
        {
            'id': law_id,
            'score': float(score),
            'text': label_names[law_id],
        }
        for law_id, score in ranked
    ]

    json_result = {
        "output": pred_laws
    }

    print("json_result:", json_result)

    return jsonify(json_result)
|
|
|
|
|
if __name__ == '__main__':
    # Without any configured handler, INFO records (e.g. "model loaded"
    # below) are dropped by the logging module's WARNING-level fallback.
    logging.basicConfig(level=logging.INFO)

    # ---- Marriage ("hunyin") relevance classifier ----------------------------
    hunyin_classifier_path = "./pretrained_models/roberta_wwm_ext_hunyin_2epoch"
    hunyin_config = AutoConfig.from_pretrained(
        hunyin_classifier_path,
        num_labels=2,
    )
    hunyin_tokenizer = AutoTokenizer.from_pretrained(hunyin_classifier_path)
    hunyin_model = AutoModelForSequenceClassification.from_pretrained(
        hunyin_classifier_path,
        config=hunyin_config,
    )
    hunyin_classifier = pipeline(model=hunyin_model, tokenizer=hunyin_tokenizer, task="text-classification", device=0)

    # ---- Law-article ("fatiao") multi-label model ----------------------------
    fatiao_args.ckpt_dir = "./pretrained_models/chinese-roberta-wwm-ext"
    fatiao_args.device = "cuda:0"

    # NOTE: pickle.load can execute arbitrary code; this file must be trusted.
    with open("data/labels2id.pkl", "rb") as f:
        laws2id = pkl.load(f)
    fatiao_args.labels = list(laws2id.keys())

    id2laws = {v: k for k, v in laws2id.items()}
    print("法条个数:", len(id2laws))

    fatiao_tokenizer = AutoTokenizer.from_pretrained(fatiao_args.ckpt_dir)
    fatiao_args.tokenizer = fatiao_tokenizer
    # Use the configured device everywhere instead of repeating "cuda:0".
    fatiao_model = Elect(fatiao_args, fatiao_args.device).to(fatiao_args.device)
    fatiao_model.eval()

    # check_hunyin maps output index i to mlb.classes_[i] (sorted label names).
    mlb = MultiLabelBinarizer()
    mlb.fit([fatiao_args.labels])
    fatiao_args.mlb = mlb

    # Seed each label embedding with the encoder's first-token hidden state
    # (presumably [CLS]) for the label text after its first ':'.
    # NOTE(review): entries are seeded in fatiao_args.labels order, which may
    # differ from the sorted mlb.classes_ order; and if `la` is in the
    # checkpoint, load_state_dict below overwrites this — confirm both.
    with torch.no_grad():
        for idx, label in enumerate(fatiao_args.labels):
            # partition(':')[2] == ':'.join(label.split(':')[1:]), including '' when no ':'.
            text = label.partition(':')[2].lower()
            la_in = fatiao_tokenizer(text, padding='max_length', truncation=True, max_length=256,
                                     return_tensors="pt")
            ids = la_in['input_ids'].to(fatiao_args.device)
            mask = la_in['attention_mask'].to(fatiao_args.device)
            fatiao_model.la[idx] += (fatiao_model.plm(input_ids=ids, attention_mask=mask)[0][:, 0]).squeeze(0)

    # load_state_dict copies into the existing tensors, which already live on
    # fatiao_args.device, so no extra .to() is needed afterwards.
    fatiao_model.load_state_dict(torch.load('./pretrained_models/ELECT', map_location=torch.device(fatiao_args.device)))

    logger.info("model loaded")
    app.run(host="0.0.0.0", port=9098, debug=False)
|
|