dialofred / README.md
TeraSpace's picture
Update README.md
b77ba7a
|
raw
history blame
1.59 kB
---
license: mit
widget:
- text: |-
    <SC1>- как ты?
    - <extra_id_0>
  example_title: how r u
language:
- ru
pipeline_tag: text2text-generation
---
# Usage
```python
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Use the GPU when one is available, otherwise fall back to CPU so the
# example also runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = AutoTokenizer.from_pretrained('TeraSpace/dialofred')
model = AutoModelForSeq2SeqLM.from_pretrained('TeraSpace/dialofred').to(device)

# Interactive chat loop; press Ctrl+D (EOF) or Ctrl+C to quit cleanly.
while True:
    try:
        text_inp = input("=>")
    except (EOFError, KeyboardInterrupt):
        break
    # Prompt format matching the widget example in the model card header.
    lm_text = f'<SC1>- {text_inp}\n- <extra_id_0>'
    input_ids = torch.tensor([tokenizer.encode(lm_text)]).to(device)
    # Sampling-based generation. `early_stopping` is omitted because it only
    # affects beam search and is ignored (with a warning) when do_sample=True.
    outputs = model.generate(
        input_ids=input_ids,
        max_length=200,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.7,
        top_k=0,  # 0 disables top-k filtering, so only top_p applies
        top_p=0.8,
    )
    # Skip the first generated position (decoder start token) and drop any
    # remaining special tokens such as </s> from the printed reply.
    print(tokenizer.decode(outputs[0][1:], skip_special_tokens=True))
```