|
--- |
|
library_name: transformers |
|
tags: [spell_correction] |
|
--- |
|
|
|
# Model Card for Llama-3.1-8B-spelling-fa
|
|
|
Persian spelling-correction model based on Llama 3.1 8B Instruct. It was fine-tuned on user search queries submitted to the Basalam.com marketplace.
|
|
|
|
|
## Model Details |
|
|
|
### Model Description |
|
The model was trained with the following system prompt (the examples are Persian marketplace queries):

```python

spell_checking_prompt = """You are tasked with correcting spelling mistakes in the queries that users submitted to a Persian marketplace.
|
|
|
Output the corrected query in the following JSON format: |
|
|
|
- If the input requires correction, use: |
|
|
|
{"correction": "<corrected version of the query>"} |
|
|
|
- If the input is correct, use: |
|
|
|
|
|
{"correction": ""}
|
|
|
|
|
Here are some examples: |
|
|
|
"query": "ندل چسبی زنانه" Your answer: {"correction": "صندل چسبی زنانه"} |
|
|
|
"query": "بادکنک جشن تواد" Your answer: {"correction": "بادکنک جشن تولد"} |
|
|
|
"query": "صندلی بادی" Your answer: {"correction": ""}\n""" |
|
``` |
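For reference, the system prompt is combined with each user query into a `### Human: ... ### Assistant:` template at inference time. A minimal illustration, where `system_prompt` is a shortened stand-in for the full prompt above:

```python
# Illustrative only: assemble the inference prompt for one query.
# `system_prompt` is a shortened stand-in for the full system prompt above.
system_prompt = "You are tasked with correcting spelling mistakes in Persian queries."
query = "بادکنک جشن تواد"
prompt = f"### Human: {system_prompt} query: {query}\n ### Assistant:"
print(prompt)
```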
|
|
|
|
|
## Uses |
|
|
|
It is intended for spelling correction of short Persian-language user queries, e.g. search queries in an e-commerce setting.
|
|
|
### Direct Use |
|
|
|
```python
import json

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"


# Output structuring: extract the first non-empty "correction" value
# from the model's raw text output.
def extract_json(text):
    try:
        correction = None
        pos = 0
        decoder = json.JSONDecoder()
        while pos < len(text):
            match = text.find('{"correction":', pos)
            if match == -1:
                break
            try:
                result, index = decoder.raw_decode(text[match:])
                correction = result.get('correction')
                if correction:
                    return correction
                pos = match + index
            except json.JSONDecodeError:
                pos = match + 1
        return correction
    except Exception:
        return text


# Load the base model and apply the fine-tuned adapter.
BASE_MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B-Instruct"
model_name_or_path = "mfsadi/Llama-3.1-8B-spelling-fa"
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_PATH, return_dict=True)
spelling_model = PeftModel.from_pretrained(base_model, model_name_or_path).to(device)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)


# Inference: `query` is the user query to correct; `spell_checking_prompt`
# is the system prompt from the Model Description section.
def correct_query(query):
    prompt = f"### Human: {spell_checking_prompt} query: {query}\n ### Assistant:"
    batch = tokenizer(prompt, return_tensors='pt').to(device)
    prompt_length = batch['input_ids'].shape[1]
    with torch.no_grad():
        output_tokens = spelling_model.generate(
            **batch,
            max_new_tokens=50,
            repetition_penalty=1.1,
            do_sample=True,
            num_beams=2,
            temperature=0.1,
            top_k=10,
            top_p=0.5,
            length_penalty=-1,
        )
    output = tokenizer.decode(output_tokens[0][prompt_length:], skip_special_tokens=True)
    return extract_json(output)
```
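`extract_json` relies on `json.JSONDecoder.raw_decode`, which parses one JSON object starting at a given position and ignores any trailing text. A standalone sketch of that behavior, with a hypothetical raw model output:

```python
import json

decoder = json.JSONDecoder()
# Model output may wrap the JSON object in extra chatter; raw_decode
# parses the object at the found position and ignores what follows.
raw_output = 'Sure! {"correction": "صندل چسبی زنانه"} hope that helps'
start = raw_output.find('{"correction":')
obj, consumed = decoder.raw_decode(raw_output[start:])
print(obj["correction"])  # the corrected query
```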
|
|
|
## Model Card Contact |
|
|
|
Majid F. Sadi |
|
|
|
[email protected] |
|
|
|
https://www.linkedin.com/in/mfsadi/ |