koziev ilya
commited on
Commit
•
68864bd
1
Parent(s):
e3f5def
в примере сэмплинг переделан на жадную генерацию
Browse files
README.md
CHANGED
@@ -42,10 +42,10 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
42 |
|
43 |
|
44 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
45 |
-
|
46 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
47 |
tokenizer.add_special_tokens({'bos_token': '<s>', 'eos_token': '</s>', 'pad_token': '<pad>'})
|
48 |
-
model = AutoModelForCausalLM.from_pretrained(
|
49 |
model.to(device)
|
50 |
|
51 |
# На вход модели подаем последние 2-3 реплики диалога. Каждая реплика на отдельной строке, начинается с символа "-"
|
@@ -57,17 +57,7 @@ input_text = """<s>- Как тебя зовут?
|
|
57 |
|
58 |
encoded_prompt = tokenizer.encode(input_text, add_special_tokens=False, return_tensors="pt").to(device)
|
59 |
|
60 |
-
output_sequences = model.generate(
|
61 |
-
input_ids=encoded_prompt,
|
62 |
-
max_length=100,
|
63 |
-
temperature=1.0,
|
64 |
-
top_k=30,
|
65 |
-
top_p=0.85,
|
66 |
-
repetition_penalty=1.2,
|
67 |
-
do_sample=True,
|
68 |
-
num_return_sequences=1,
|
69 |
-
pad_token_id=tokenizer.pad_token_id,
|
70 |
-
)
|
71 |
|
72 |
text = tokenizer.decode(output_sequences[0].tolist(), clean_up_tokenization_spaces=True)[len(input_text)+1:]
|
73 |
text = text[: text.find('</s>')]
|
|
|
42 |
|
43 |
|
44 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
45 |
+
model_name = "inkoziev/rugpt_interpreter"
|
46 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
47 |
tokenizer.add_special_tokens({'bos_token': '<s>', 'eos_token': '</s>', 'pad_token': '<pad>'})
|
48 |
+
model = AutoModelForCausalLM.from_pretrained(model_name)
|
49 |
model.to(device)
|
50 |
|
51 |
# На вход модели подаем последние 2-3 реплики диалога. Каждая реплика на отдельной строке, начинается с символа "-"
|
|
|
57 |
|
58 |
encoded_prompt = tokenizer.encode(input_text, add_special_tokens=False, return_tensors="pt").to(device)
|
59 |
|
60 |
+
output_sequences = model.generate(input_ids=encoded_prompt, max_length=100, num_return_sequences=1, pad_token_id=tokenizer.pad_token_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
text = tokenizer.decode(output_sequences[0].tolist(), clean_up_tokenization_spaces=True)[len(input_text)+1:]
|
63 |
text = text[: text.find('</s>')]
|