File size: 4,202 Bytes
7556e63
815f887
 
 
7556e63
815f887
 
 
7556e63
 
815f887
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
---
language:
- ru
license: apache-2.0
library_name: transformers
datasets:
- hivaze/ru-AAQG-QA-QG
pipeline_tag: text2text-generation
---

## Description

This is **ai-forever/FRED-T5-large** model trained on **Question-Answering**, **Question-Generation** and **Answer-Aware Question Generation** tasks on russian dataset (**hivaze/ru-AAQG-QA-QG**)

### Prompts

```python
AAQG_PROMPT = "Сгенерируй вопрос по тексту, используя известный ответ. Текст: '{context}'. Ответ: '{answer}'."
QG_PROMPT = "Сгенерируй вопрос по тексту. Текст: '{context}'."
QA_PROMPT = "Сгенерируй ответ на вопрос по тексту. Текст: '{context}'. Вопрос: '{question}'."
```

### Examples and code

```python
from transformers import AutoTokenizer, T5ForConditionalGeneration
from functools import partial

saved_checkpoint = 'hivaze/AAQG-QA-QG-FRED-T5-large'
tokenizer = AutoTokenizer.from_pretrained(saved_checkpoint)
model = T5ForConditionalGeneration.from_pretrained(saved_checkpoint).cuda()

def generate_text(prompt, tokenizer, model, n=1, temperature=0.8, num_beams=3):
  encoded_input = tokenizer.encode_plus(prompt, return_tensors='pt')
  encoded_input = {k: v.to(model.device) for k, v in encoded_input.items()}

  resulted_tokens = model.generate(**encoded_input,
                                   max_new_tokens=64,
                                   do_sample=True,
                                   num_beams=num_beams,
                                   num_return_sequences=n,
                                   temperature=temperature,
                                   top_p=0.9,
                                   top_k=50)
  resulted_texts = tokenizer.batch_decode(resulted_tokens, skip_special_tokens=True)

  return resulted_texts

generate_text = partial(generate_text, tokenizer=tokenizer, model=model)

test_context = "Путешественник Федор Конюхов и пилот Игорь Потапкин установили мировой рекорд высоты полета на паралёте, поднявшись на высоту 4728 метров — сайт Конюхова"
```

#### AAQG
```python
generate_text(AAQG_PROMPT.format(
  context=test_context,
  answer='на паралёте'
), n=1)
```
> "На чём установили мировой рекорд высоты полета Федор Конюхов и пилот Игорь Потапкин?"


```python
generate_text(AAQG_PROMPT.format(
  context=test_context,
  answer='рекорд высоты полета'
), n=1)
```
> "Что установили Конюхов и Потапкин?"


#### QA
```python
generate_text(QA_PROMPT.format(
  context=test_context,
  question='Что установили путешественник Федор Конюхов и пилот Игорь Потапкин?'
), n=1)
```
> "мировой рекорд высоты полета на паралёте, поднявшись на высоту 4728 метров — сайт Конюхова"

#### QG
```python
generate_text(QG_PROMPT.format(context=test_context), n=1)
```
> "Кто установил мировой рекорд высоты полета на паралёте, поднявшись на высоту 4728 метров?"

## Metrics

| Step | Training Loss | Validation Loss | Sbleu | Chr F | Rouge1 | Rouge2 | Rougel |
|------|---------------|-----------------|-------|-------|--------|--------|--------|
| 500  | 1.183100      | 1.188049        | 40.114700 | 62.147000 | 0.104600 | 0.034500 | 0.104300 |
| 1000 | 1.193000      | 1.125300        | 40.722300 | 62.661400 | 0.104700 | 0.033900 | 0.104300 |
| 1500 | 1.114300      | 1.097496        | 41.416600 | 63.060300 | 0.106100 | 0.033800 | 0.105800 |
| 2000 | 1.081300      | 1.080900        | 41.600200 | 63.260500 | 0.106200 | 0.033700 | 0.105900 |
| 2500 | 1.076900      | 1.070221        | 41.722300 | 63.315300 | 0.106300 | 0.034100 | 0.106000 |
| 3000 | 1.125600      | 1.062671        | 41.744500 | 63.409400 | 0.106400 | 0.034200 | 0.106200 |

## Authors
- Sergei Bratchikov (https://t.me/nlpwanderer)