---
language:
  - ba
license: apache-2.0
tags:
  - grammatical error correction
---

Canine-c Bashkir Spelling Correction v1

This model is a version of google/canine-c fine-tuned to fix corrupted texts. It was trained on a mixture of two parallel datasets in the Bashkir language:

  • sentences post-edited by humans after OCR
  • sentences corrupted artificially at random, paired with their original versions
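
The card does not say how the artificial corruptions were generated. As a purely illustrative sketch, a (corrupted, original) pair can be produced by randomly deleting, replacing, or inserting characters. The function `corrupt`, its edit probability `p`, and the equal split between the three operations are hypothetical choices, not the authors' procedure.

```python
import random
from typing import Optional

def corrupt(sentence: str, p: float = 0.1, alphabet: Optional[str] = None,
            seed: Optional[int] = None) -> str:
    """Hypothetical noise model: with total probability p, each character is
    deleted, replaced, or followed by an inserted character (p/3 each)."""
    rng = random.Random(seed)
    alphabet = alphabet or sentence  # by default, noise characters come from the sentence itself
    out = []
    for ch in sentence:
        r = rng.random()
        if r < p / 3:
            continue                          # delete the character
        elif r < 2 * p / 3:
            out.append(rng.choice(alphabet))  # replace it with a random character
        elif r < p:
            out.append(ch)                    # keep it...
            out.append(rng.choice(alphabet))  # ...and insert a random character after it
        else:
            out.append(ch)                    # leave it unchanged
    return ''.join(out)

clean = 'Уйылдандың йөҙө һөрөмләнде.'
noisy = corrupt(clean, p=0.15, seed=0)
# A (model input, training target) pair:
print((noisy, clean))
```

Turning such a pair into per-character training labels would additionally require a character alignment step, which the card does not describe.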

For each character, the model predicts two labels: whether to keep, delete, or replace the character, and whether to insert another character after it.

In this way, the model can be used to fix spelling or OCR errors.
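
To make the tagging scheme concrete, here is a minimal, model-free sketch of how per-character labels turn into an edited string. It mirrors the loop inside `fix_text` in the How to use section, but the function name `apply_edits` and the toy label sequences are made up for illustration. Each character gets one THIS_ label (KEEP, DELETE, or a replacement character) and one NEXT_ label (PASS or a character to insert after it), with the prefixes already stripped.

```python
from typing import List

def apply_edits(text: str, this_labels: List[str], next_labels: List[str]) -> str:
    """Rebuild a string from one THIS_ and one NEXT_ label per character (prefixes stripped)."""
    if not (len(text) == len(this_labels) == len(next_labels)):
        raise ValueError('need exactly one label of each kind per character')
    result = []
    for c, l1, l2 in zip(text, this_labels, next_labels):
        if l1 == 'KEEP':
            result.append(c)   # keep the character unchanged
        elif l1 != 'DELETE':
            result.append(l1)  # replace it with the predicted character
        # for 'DELETE', nothing is appended
        if l2 != 'PASS':
            result.append(l2)  # insert a character right after it
    return ''.join(result)

# Hypothetical labels for a toy English input (11 characters).
toy_text = 'helo  wurld'
toy_this = ['KEEP', 'KEEP', 'KEEP', 'KEEP', 'KEEP', 'DELETE', 'KEEP', 'o', 'KEEP', 'KEEP', 'KEEP']
toy_next = ['PASS', 'PASS', 'l'] + ['PASS'] * 8

print(apply_edits(toy_text, toy_this, toy_next))  # hello world
```

Because an insertion is attached to the preceding character, a single pass over the input is enough to apply any mix of replacements, deletions, and one-character insertions.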

On a held-out set, applying the model reduces the number of edits still required to reach the correct text by 40%.
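
The card does not define how "required edits" are counted. One plausible reading, assumed here, is character-level Levenshtein distance to the reference text, with the reduction computed as 1 - (total distance from model outputs to references) / (total distance from corrupted inputs to references). The sketch below applies this assumed metric to toy strings invented for illustration, not to the actual evaluation data.

```python
from typing import Iterable, Tuple

def levenshtein(a: str, b: str) -> int:
    """Character-level edit distance (insertions, deletions, substitutions all cost 1)."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        cur = [i]
        for j, cb in enumerate(b, start=1):
            cur.append(min(prev[j] + 1,                 # delete ca
                           cur[j - 1] + 1,              # insert cb
                           prev[j - 1] + (ca != cb)))   # substitute (free if equal)
        prev = cur
    return prev[-1]

def edit_reduction(triples: Iterable[Tuple[str, str, str]]) -> float:
    """Assumed metric over (corrupted input, model output, reference) triples."""
    before = after = 0
    for src, out, ref in triples:
        before += levenshtein(src, ref)
        after += levenshtein(out, ref)
    return 1 - after / before if before else 0.0

# Made-up examples: the first output fixes 2 of 3 errors, the second fixes both.
toy = [('helo  wurld', 'hello wurld', 'hello world'),
       ('wrold', 'world', 'world')]
print(f'{edit_reduction(toy):.0%}')  # 80%
```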

How to use

You can apply the model to one sentence at a time with the following code:

```python
import torch
from transformers import CanineTokenizer, CanineForTokenClassification

tokenizer = CanineTokenizer.from_pretrained('slone/canine-c-bashkir-gec-v1')
model = CanineForTokenClassification.from_pretrained('slone/canine-c-bashkir-gec-v1')
if torch.cuda.is_available():
    model.cuda()

# Labels come in two groups: THIS_* (KEEP, DELETE or a replacement for the current
# character) and NEXT_* (PASS or a character to insert after it).
LABELS_THIS = [c[5:] for c in model.config.id2label.values() if c.startswith('THIS_')]
LABELS_NEXT = [c[5:] for c in model.config.id2label.values() if c.startswith('NEXT_')]

def fix_text(text, boost=0):
    """Apply the model to edit the text. `boost` is a parameter to control edit aggressiveness."""
    # CanineTokenizer produces one token per character, plus [CLS] and [SEP].
    bx = tokenizer(text, return_tensors='pt', padding=True)
    with torch.inference_mode():
        out = model(**bx.to(model.device))
        n1, n2 = len(LABELS_THIS), len(LABELS_NEXT)
        # The first n1 logits at each position score THIS_* labels, the remaining n2 score NEXT_* labels.
        logits1 = out.logits[0, :, :n1].view(-1, n1)
        logits2 = out.logits[0, :, n1:].view(-1, n2)
        if boost:
            # Lower the score of label 0 (the no-edit label) to encourage edits;
            # the [CLS] position is skipped for the THIS_* labels.
            logits1[1:, 0] -= boost
            logits2[:, 0] -= boost
        ids1, ids2 = logits1.argmax(-1).tolist(), logits2.argmax(-1).tolist()
    result = []
    # The leading space is paired with [CLS]; zip() stops before the trailing [SEP].
    for c, id1, id2 in zip(' ' + text, ids1, ids2):
        l1, l2 = LABELS_THIS[id1], LABELS_NEXT[id2]
        if l1 == 'KEEP':
            result.append(c)   # keep the character
        elif l1 != 'DELETE':
            result.append(l1)  # replace it with the predicted character
        if l2 != 'PASS':
            result.append(l2)  # insert a character after it
    return ''.join(result)

text = 'У йыл дан д ың йөҙө һoрөмлэнде.'
print(fix_text(text))  # Уйылдандың йөҙө һөрөмләнде.
```

The `boost` parameter controls how aggressively the text is edited: positive values make changes more likely, and negative values make them less likely.
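
The effect of `boost` can be reproduced without the model. `fix_text` subtracts it from the logit of label 0, which, judging from the description above, is the no-edit label (KEEP or PASS); a large enough positive value therefore lets the best competing edit win the argmax. The helper `pick_label` and the logits below are invented for a single character position and are only meant to illustrate the mechanism.

```python
from typing import List

def pick_label(labels: List[str], logits: List[float], boost: float = 0.0) -> str:
    """Mimic the boost step of fix_text for one position: penalize label 0, then take the argmax."""
    adjusted = list(logits)
    adjusted[0] -= boost
    best = max(range(len(adjusted)), key=adjusted.__getitem__)
    return labels[best]

# Hypothetical labels and made-up logits for one character.
labels = ['KEEP', 'DELETE', 'ә']
logits = [2.0, -1.0, 1.4]

print(pick_label(labels, logits))              # KEEP (2.0 beats 1.4)
print(pick_label(labels, logits, boost=1.0))   # ә (2.0 - 1.0 = 1.0 < 1.4)
print(pick_label(labels, logits, boost=-1.0))  # KEEP (now preferred even more strongly)
```

In this toy case any `boost` above 0.6 (the gap between the KEEP logit and the best edit logit) flips the decision, which is why small positive values already change the output at positions where the model was uncertain.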