hivaze's picture
Update README.md
815f887 verified
metadata
language:
  - ru
license: apache-2.0
library_name: transformers
datasets:
  - hivaze/ru-AAQG-QA-QG
pipeline_tag: text2text-generation

Description

This is ai-forever/FRED-T5-large model trained on Question-Answering, Question-Generation and Answer-Aware Question Generation tasks on russian dataset (hivaze/ru-AAQG-QA-QG)

Prompts

AAQG_PROMPT = "Сгенерируй вопрос по тексту, используя известный ответ. Текст: '{context}'. Ответ: '{answer}'."
QG_PROMPT = "Сгенерируй вопрос по тексту. Текст: '{context}'."
QA_PROMPT = "Сгенерируй ответ на вопрос по тексту. Текст: '{context}'. Вопрос: '{question}'."

Examples and code

from transformers import AutoTokenizer, T5ForConditionalGeneration
from functools import partial

saved_checkpoint = 'hivaze/AAQG-QA-QG-FRED-T5-large'
tokenizer = AutoTokenizer.from_pretrained(saved_checkpoint)
model = T5ForConditionalGeneration.from_pretrained(saved_checkpoint).cuda()

def generate_text(prompt, tokenizer, model, n=1, temperature=0.8, num_beams=3):
  encoded_input = tokenizer.encode_plus(prompt, return_tensors='pt')
  encoded_input = {k: v.to(model.device) for k, v in encoded_input.items()}

  resulted_tokens = model.generate(**encoded_input,
                                   max_new_tokens=64,
                                   do_sample=True,
                                   num_beams=num_beams,
                                   num_return_sequences=n,
                                   temperature=temperature,
                                   top_p=0.9,
                                   top_k=50)
  resulted_texts = tokenizer.batch_decode(resulted_tokens, skip_special_tokens=True)

  return resulted_texts

generate_text = partial(generate_text, tokenizer=tokenizer, model=model)

test_context = "Путешественник Федор Конюхов и пилот Игорь Потапкин установили мировой рекорд высоты полета на паралёте, поднявшись на высоту 4728 метров — сайт Конюхова"

AAQG

generate_text(AAQG_PROMPT.format(
  context=test_context,
  answer='на паралёте'
), n=1)

"На чём установили мировой рекорд высоты полета Федор Конюхов и пилот Игорь Потапкин?"

generate_text(AAQG_PROMPT.format(
  context=test_context,
  answer='рекорд высоты полета'
), n=1)

"Что установили Конюхов и Потапкин?"

QA

generate_text(QA_PROMPT.format(
  context=test_context,
  question='Что установили путешественник Федор Конюхов и пилот Игорь Потапкин?'
), n=1)

"мировой рекорд высоты полета на паралёте, поднявшись на высоту 4728 метров — сайт Конюхова"

QG

generate_text(QG_PROMPT.format(context=test_context), n=1)

"Кто установил мировой рекорд высоты полета на паралёте, поднявшись на высоту 4728 метров?"

Metrics

Step Training Loss Validation Loss Sbleu Chr F Rouge1 Rouge2 Rougel
500 1.183100 1.188049 40.114700 62.147000 0.104600 0.034500 0.104300
1000 1.193000 1.125300 40.722300 62.661400 0.104700 0.033900 0.104300
1500 1.114300 1.097496 41.416600 63.060300 0.106100 0.033800 0.105800
2000 1.081300 1.080900 41.600200 63.260500 0.106200 0.033700 0.105900
2500 1.076900 1.070221 41.722300 63.315300 0.106300 0.034100 0.106000
3000 1.125600 1.062671 41.744500 63.409400 0.106400 0.034200 0.106200

Authors