
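Introduction

This is a model that generates SQL from natural language (NL2SQL/Text2SQL). It is the most basic of the many NL2SQL models we have developed in-house; the more advanced versions will be open-sourced over time.

The model is based on the BART architecture, and we frame NL2SQL as a Seq2Seq problem similar to machine translation. Its main advantage is a relatively small parameter count combined with relatively high SQL generation accuracy.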

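Usage

In the NL2SQL task, the input consists of the user's query text plus the database table information. The model input text is currently concatenated in the following format: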

Question: 名人堂一共有多少球员 <sep> Tables: hall_of_fame: player_id, yearid, votedby, ballots, needed, votes, inducted, category, needed_note ; player_award: player_id, award_id, year, league_id, tie, notes <sep>
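
The format above can be assembled programmatically. The following is a minimal sketch, not part of the released code: the helper name `build_nl2sql_input` and its `column_sep` parameter are hypothetical, and the separators (`", "` between columns, `" ; "` between tables) are inferred from the example line above.

```python
def build_nl2sql_input(question, tables, column_sep=", "):
    """Join a question and a {table: [columns]} mapping into one input string.

    Hypothetical helper: the separators are inferred from the example input,
    not taken from an official specification.
    """
    table_parts = [
        "{}: {}".format(name, column_sep.join(columns))
        for name, columns in tables.items()
    ]
    return "Question: {} <sep> Tables: {} <sep>".format(
        question, " ; ".join(table_parts))


schema = {
    "hall_of_fame": ["player_id", "yearid", "votedby", "ballots", "needed",
                     "votes", "inducted", "category", "needed_note"],
    "player_award": ["player_id", "award_id", "year", "league_id", "tie", "notes"],
}
example_input = build_nl2sql_input("名人堂一共有多少球员", schema)
print(example_input)
```

Some inputs in this card separate columns with `" , "` instead (for schemas whose column names contain spaces); passing `column_sep=" , "` reproduces that variant.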

For detailed usage, refer to the following example:

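import torch
from transformers import AutoModelForSeq2SeqLM, MBartForConditionalGeneration, AutoTokenizer

# Use the GPU when one is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_path = 'DMetaSoul/nl2sql-chinese-basic'
sampling = False  # True: sampling-based decoding; False: beam search
tokenizer = AutoTokenizer.from_pretrained(model_path, src_lang='zh_CN')
#model = MBartForConditionalGeneration.from_pretrained(model_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
if device == 'cuda':
    model = model.half()  # FP16 weights on the GPU only
model.to(device)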


input_texts = [
    "Question: 所有章节的名称和描述是什么? <sep> Tables: sections: section id , course id , section name , section description , other details <sep>",
    "Question: 名人堂一共有多少球员 <sep> Tables: hall_of_fame: player_id, yearid, votedby, ballots, needed, votes, inducted, category, needed_note ; player_award: player_id, award_id, year, league_id, tie, notes ; player_award_vote: award_id, year, league_id, player_id, points_won, points_max, votes_first ; salary: year, team_id, league_id, player_id, salary ; player: player_id, birth_year, birth_month, birth_day, birth_country, birth_state, birth_city, death_year, death_month, death_day, death_country, death_state, death_city, name_first, name_last, name_given, weight <sep>"
]
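# Tokenize the batch, padding to the longest input and truncating at 512 tokens
inputs = tokenizer(input_texts, max_length=512, return_tensors="pt",
    padding=True, truncation=True)
# Move tensors to the target device; token_type_ids are filtered out before generate()
inputs = {k:v.to(device) for k,v in inputs.items() if k not in ["token_type_ids"]}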

with torch.no_grad():
    if sampling:
        outputs = model.generate(**inputs, do_sample=True, top_k=50, top_p=0.95,
            temperature=1.0, num_return_sequences=1, 
            max_length=512, return_dict_in_generate=True, output_scores=True)
    else:
        outputs = model.generate(**inputs, num_beams=4, num_return_sequences=1, 
            max_length=512, return_dict_in_generate=True, output_scores=True)

output_ids = outputs.sequences
results = tokenizer.batch_decode(output_ids, skip_special_tokens=True,
            clean_up_tokenization_spaces=True)

for question, sql in zip(input_texts, results):
    print(question)
    print('SQL: {}'.format(sql))
    print()

The output is as follows:

Question: 所有章节的名称和描述是什么? <sep> Tables: sections: section id , course id , section name , section description , other details <sep>
SQL: SELECT section name, section description FROM sections

Question: 名人堂一共有多少球员 <sep> Tables: hall_of_fame: player_id, yearid, votedby, ballots, needed, votes, inducted, category, needed_note ; player_award: player_id, award_id, year, league_id, tie, notes ; player_award_vote: award_id, year, league_id, player_id, points_won, points_max, votes_first ; salary: year, team_id, league_id, player_id, salary ; player: player_id, birth_year, birth_month, birth_day, birth_country, birth_state, birth_city, death_year, death_month, death_day, death_country, death_state, death_city, name_first, name_last, name_given, weight <sep>
SQL: SELECT count(*) FROM hall_of_fame
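
To check that a generated statement actually runs, one option is to execute it against a throwaway database. The sketch below is an assumption-laden illustration rather than part of the model's tooling: it uses Python's standard `sqlite3` module, builds an in-memory `hall_of_fame` table from the column names in the example input, and inserts two hypothetical rows.

```python
import sqlite3

# Column names copied from the hall_of_fame table in the example input.
columns = ["player_id", "yearid", "votedby", "ballots", "needed",
           "votes", "inducted", "category", "needed_note"]
# Model output taken from the second example above.
generated_sql = "SELECT count(*) FROM hall_of_fame"

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE hall_of_fame ({})".format(", ".join(columns)))

# Hypothetical rows, for illustration only; unspecified columns stay NULL.
rows = [("player_a", 1936), ("player_b", 1937)]
conn.executemany(
    "INSERT INTO hall_of_fame (player_id, yearid) VALUES (?, ?)", rows)

player_count = conn.execute(generated_sql).fetchone()[0]
print(player_count)
conn.close()
```

Outputs whose column names contain spaces, such as `section name` in the first example, would need identifier quoting before they can be executed this way.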