Randolphzeng committed
Commit 486312f • 1 Parent(s): d5b4c41
Update README.md

README.md CHANGED
@@ -51,12 +51,13 @@ vae_model = Della.from_pretrained("IDEA-CCNL/Randeng-DELLA-CVAE-226M-NER-Chinese
 special_tokens_dict = {'bos_token': '<BOS>', 'eos_token': '<EOS>', 'additional_special_tokens': ['<ENT>', '<ENS>']}
 tokenizer.add_special_tokens(special_tokens_dict)
 
-
+device = 0
+model = vae_model.model.to(device)
 ent_token_type_id = tokenizer.additional_special_tokens_ids[0]
 ent_token_sep_id = tokenizer.additional_special_tokens_ids[1]
 bos_token_id, eos_token_id = tokenizer.bos_token_id, tokenizer.eos_token_id
 decoder_target, decoder_entities = [], []
-entity_list = [('
+entity_list = [('深圳', '地点/地理位置'), ('昨天', '时间')]
 
 for ent in entity_list:
     entity_name = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(ent[0]))
@@ -72,7 +73,7 @@ prior_z_list, prior_output_list = model.get_cond_prior_vecs(encoder_outputs.hidd
 outputs = model.decoder.generate(input_ids=inputs.to(device), layer_latent_vecs=prior_z_list, labels=None,
                                  label_ignore=model.pad_token_id, num_return_sequences=32, max_new_tokens=256,
                                  eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id,
-                                 no_repeat_ngram_size=-1, do_sample=True, top_p=0.
+                                 no_repeat_ngram_size=-1, do_sample=True, top_p=0.8)
 
 print(tokenizer.decode(inputs[0]))
 gen_sents = []
@@ -84,7 +85,6 @@ for idx in range(len(outputs)):
     gen_sents.append(gen_sent)
 for s in gen_sents:
     print(s)
-
 ```
 
 ## 引用 Citation
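Taken together, the hunks complete the model card's previously truncated generation example: the model is moved onto a GPU, the NER conditioning list is filled in with (entity, type) pairs, and the dangling generate(...) call is closed with top_p=0.8. As a quick check of the token setup these hunks rely on, here is a minimal self-contained sketch. It substitutes the stock bert-base-chinese tokenizer for the card's own (an assumption made purely so the snippet runs stand-alone; the card loads its tokenizer alongside the Della model):

```python
# Minimal sketch of the special-token setup used in the diff above.
# Assumption: bert-base-chinese stands in for the card's own tokenizer,
# so the snippet runs without the fengshen/Della checkpoint.
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
special_tokens_dict = {'bos_token': '<BOS>', 'eos_token': '<EOS>',
                       'additional_special_tokens': ['<ENT>', '<ENS>']}
tokenizer.add_special_tokens(special_tokens_dict)

# Ids come back in registration order, which is why the card indexes 0 and 1.
ent_token_type_id = tokenizer.additional_special_tokens_ids[0]  # id of <ENT>
ent_token_sep_id = tokenizer.additional_special_tokens_ids[1]   # id of <ENS>
bos_token_id, eos_token_id = tokenizer.bos_token_id, tokenizer.eos_token_id

# The conditioning added by this commit: (entity, entity-type) pairs.
entity_list = [('深圳', '地点/地理位置'), ('昨天', '时间')]
for ent in entity_list:
    # Same call as the README: tokenize the entity surface form to ids.
    entity_name = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(ent[0]))
    print(ent[0], ent[1], entity_name)
```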
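The second hunk repairs the sampling arguments. With do_sample=True and top_p=0.8, the decoder samples each token from the smallest vocabulary slice covering 80% of the probability mass, and num_return_sequences=32 draws 32 independent candidates per prompt (no_repeat_ngram_size=-1 presumably disables the n-gram constraint in the card's custom decoder). The same settings can be exercised with a stock model; the sketch below uses GPT-2 as a stand-in and omits layer_latent_vecs and label_ignore, which are specific to the Della decoder:

```python
# Illustration of the sampling settings fixed by this commit, using stock
# GPT-2 as a stand-in. Assumption: the card's custom decoder forwards these
# arguments to the usual Hugging Face generate() machinery.
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tok = GPT2Tokenizer.from_pretrained("gpt2")
lm = GPT2LMHeadModel.from_pretrained("gpt2")
input_ids = tok("The city of Shenzhen", return_tensors="pt").input_ids

torch.manual_seed(0)  # reproducible sampling for the demo
outputs = lm.generate(input_ids,
                      do_sample=True, top_p=0.8,  # nucleus sampling, as in the patch
                      num_return_sequences=4,     # the card draws 32 candidates
                      max_new_tokens=32,
                      eos_token_id=tok.eos_token_id,
                      pad_token_id=tok.eos_token_id)
for seq in outputs:
    print(tok.decode(seq, skip_special_tokens=True))
```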