---
language:
- zh
tags:
- generation
- poetry

---

# Couldn't avoid the cheesy cliché after all, so here is this pesky acrostic (藏头诗) model

> This is a model that generates Chinese poetry with given leading characters and a certain mood.

> The model is built for two purposes:

* compose acrostic (藏头) poems 🎸
* work the mood of the given keywords into the poem as much as possible 🪁 🌼 ❄️ 🌝 (the prompt format is sketched just below)
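To make the input format concrete before the full snippet, here is a minimal sketch, with hypothetical keywords and lead characters, of the prompt that the `inference` function below builds: the keywords joined by `-`, followed by the lead wrapped in `《 》`.

```python
# Minimal sketch of the prompt format used by inference() below (hypothetical values).
keywords = ["高楼", "虹光"]  # hypothetical keywords
lead = "上海"                # hypothetical lead characters
prompt = "-".join(keywords) + f"《{lead}》"
print(prompt)  # 高楼-虹光《上海》
```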

## Inference: the pipeline is a bit fussy, just copy the parameters below verbatim

```python
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained('raynardj/keywords-cangtou-chinese-poetry')
model = AutoModel.from_pretrained('raynardj/keywords-cangtou-chinese-poetry')

def inference(lead, keywords=[]):
    """
    lead: the characters to hide at the head of each line, e.g. a person's name, 2, 3 or 4 characters
    keywords: the keywords, 0 to 12 of them work best
    """
    leading = f"《{lead}》"
    text = "-".join(keywords) + leading
    # prompt ids, dropping the trailing [SEP]
    input_ids = tokenizer(text, return_tensors='pt').input_ids[:, :-1]
    # ids of the lead characters, without [CLS] and [SEP]
    lead_tok = tokenizer(lead, return_tensors='pt').input_ids[0, 1:-1]

    with torch.no_grad():
        pred = model.generate(
            input_ids,
            max_length=256,
            num_beams=5,
            do_sample=True,
            repetition_penalty=2.1,
            top_p=.6,
            bos_token_id=tokenizer.sep_token_id,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.sep_token_id,
        )[0, 1:]

    # replace each [CLS] token (id 101) with the corresponding lead character
    mask = (pred == 101)
    while mask.sum() < len(lead_tok):
        lead_tok = lead_tok[:mask.sum()]
    while mask.sum() > len(lead_tok):
        reversed_lead_tok = lead_tok.flip(0)
        lead_tok = torch.cat([
            lead_tok, reversed_lead_tok[:mask.sum() - len(lead_tok)]])
    pred[mask] = lead_tok
    # decode the token ids back into text
    generate = tokenizer.decode(pred, skip_special_tokens=True)
    # tidy up the output
    generate = generate.replace("》", "》\n").replace("。", "。\n").replace(" ", "")
    return generate
```
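To see what the [CLS]-replacement step in the snippet does, here is a toy run of the same padding logic with hypothetical token ids: when the poem has more lines (more [CLS] slots) than lead characters, the lead is extended with its mirror image; when it has fewer, the lead is truncated.

```python
import torch

# Hypothetical ids: two lead characters, a prediction with three [CLS] (id 101) slots.
lead_tok = torch.tensor([11, 22])
pred = torch.tensor([101, 5, 6, 101, 7, 8, 101, 9])

mask = (pred == 101)
while mask.sum() < len(lead_tok):   # fewer slots than lead characters: truncate the lead
    lead_tok = lead_tok[:mask.sum()]
while mask.sum() > len(lead_tok):   # more slots than lead characters: pad with the mirrored lead
    lead_tok = torch.cat([lead_tok, lead_tok.flip(0)[:mask.sum() - len(lead_tok)]])

pred[mask] = lead_tok
print(pred)  # tensor([11,  5,  6, 22,  7,  8, 22,  9])
```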

Here is the kind of poem it currently produces; download the model and 🍒 pick your own:

```python
>>> inference("上海", ["高楼", "虹光", "灯红酒绿", "华厦"])
高楼-虹光-灯红酒绿-华厦《上海》
『二』
上台星月明如昼。
海阁珠帘卷画堂。
```