---
language:
  - zh
tags:
  - generation
  - poetry
widget:
  - text: 疆场-思乡-归家-耕织《丘处机》
---

In the end I couldn't escape the greasy cliché, so here it is: this wretched acrostic-poem (藏头诗) model.

This model generates Chinese poetry with given leading characters (an acrostic) and a chosen mood.

The model has two goals:

  • compose acrostic poems 🎸
  • weave the mood of the given keywords into the poem as much as possible 🪁 🌼 ❄️ 🌝

The inference pipeline is a little fussy; just copy the parameters below as-is.


```python
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained('raynardj/keywords-cangtou-chinese-poetry')
model = AutoModel.from_pretrained('raynardj/keywords-cangtou-chinese-poetry')

def inference(lead, keywords=[]):
    """
    lead: the acrostic text, e.g. a person's name, 2 to 4 characters
    keywords: 0 to 12 keywords work best
    """
    leading = f"《{lead}》"
    text = "-".join(keywords) + leading
    input_ids = tokenizer(text, return_tensors='pt').input_ids[:, :-1]
    lead_tok = tokenizer(lead, return_tensors='pt').input_ids[0, 1:-1]

    with torch.no_grad():
        pred = model.generate(
            input_ids,
            max_length=256,
            num_beams=5,
            do_sample=True,
            repetition_penalty=2.1,
            top_p=.6,
            bos_token_id=tokenizer.sep_token_id,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.sep_token_id,
        )[0, 1:]

    # Swap every [CLS] token (id 101) back for the acrostic characters,
    # mirroring or truncating the lead when the counts differ
    mask = (pred == 101)
    while mask.sum() < len(lead_tok):
        lead_tok = lead_tok[:mask.sum()]
    while mask.sum() > len(lead_tok):
        reversed_lead_tok = lead_tok.flip(0)
        lead_tok = torch.cat([
            lead_tok, reversed_lead_tok[:mask.sum() - len(lead_tok)]])
    pred[mask] = lead_tok
    # Decode the token ids back into text
    generate = tokenizer.decode(pred, skip_special_tokens=True)
    # Tidy up the output
    generate = generate.replace("》", "》\n").replace("。", "。\n").replace(" ", "")
    return generate
```
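The [CLS]-replacement step above is the subtle part: the model emits a [CLS] token (id 101) at each line's leading slot, and the function fills those slots with the acrostic characters, mirroring the lead when there are more slots than characters and truncating it when there are fewer. A plain-Python sketch of that logic (the function name `fill_acrostic_slots` and the toy token ids are made up for illustration; the real code operates on torch tensors):

```python
CLS = 101  # token id that marks each acrostic slot

def fill_acrostic_slots(pred, lead_tok):
    """Fill every CLS slot in pred with the acrostic characters."""
    slots = [i for i, t in enumerate(pred) if t == CLS]
    # Pad by appending the reversed lead until every slot is covered ...
    while len(lead_tok) < len(slots):
        lead_tok = lead_tok + lead_tok[::-1][:len(slots) - len(lead_tok)]
    # ... or truncate if there are more lead characters than slots
    lead_tok = lead_tok[:len(slots)]
    out = list(pred)
    for i, t in zip(slots, lead_tok):
        out[i] = t
    return out

fill_acrostic_slots([101, 7, 8, 101, 9, 101], [21, 22])
# → [21, 7, 8, 22, 9, 22]  (lead mirrored to cover the third slot)
```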

Here is what the model can generate today; download it and 🍒 pick your own cherries:

```python
>>> inference("上海", ["高楼", "虹光", "灯红酒绿", "华厦"])
高楼-虹光-灯红酒绿-华厦《上海》
『二』
上台星月明如昼。
海阁珠帘卷画堂。
```