library_name: transformers
tags:
  - spell_correction

Model Card for Llama-3.1-8B-spelling-fa

A Persian spelling correction model based on Llama 3.1 8B Instruct. It was trained on user search queries submitted to the Basalam.com marketplace.

Model Details

Model Description

Base_prompt = """You are tasked with correcting spelling mistakes in the queries that users submitted to a Persian marketplace.

Output the corrected query in the following JSON format:

- If the input requires correction, use:
  
  {"correction": "<corrected version of the query>"}
  
- If the input is correct, use:

  
  {"correction": ""}

  
Here are some examples:

"query": "ندل چسبی زنانه" Your answer: {"correction": "صندل چسبی زنانه"}

"query": "بادکنک جشن تواد"  Your answer: {"correction": "بادکنک جشن تولد"}

"query": "صندلی بادی"  Your answer: {"correction": ""}\n"""
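The prompt above defines a small response contract: a non-empty "correction" replaces the query, and an empty string means the query was already correct. The sketch below shows one way a caller might turn that contract into the final query to search with. The function name `resolve_query` is hypothetical, and the sketch assumes the answer is a single clean JSON object, which real model output may not always be.

```python
import json


def resolve_query(query: str, model_answer: str) -> str:
    """Return the query to search with, given the model's JSON answer.

    Hypothetical helper: falls back to the original query whenever the
    answer is unparseable or carries an empty correction.
    """
    try:
        payload = json.loads(model_answer)
    except json.JSONDecodeError:
        return query  # unparseable answer: keep the user's query
    if not isinstance(payload, dict):
        return query
    correction = payload.get("correction")
    # Per the prompt, an empty string means "no correction needed".
    if isinstance(correction, str) and correction.strip():
        return correction
    return query


if __name__ == "__main__":
    print(resolve_query("بادکنک جشن تواد", '{"correction": "بادکنک جشن تولد"}'))
    print(resolve_query("صندلی بادی", '{"correction": ""}'))
```

Falling back to the original query on any parsing problem keeps a malformed answer from ever replacing what the user typed.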

Uses

Use this model to correct spelling mistakes in Persian-language text, in particular short marketplace search queries like those it was trained on.

Direct Use

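# Output structuring: pull the "correction" value out of the generated text.
import json


def extract_json(text):
    """Return the first non-empty "correction" in text; otherwise the last
    empty value parsed, or None if no JSON object was found."""
    try:
        correction = None
        pos = 0
        decoder = json.JSONDecoder()
        while pos < len(text):
            # Find the next candidate JSON object.
            match = text.find('{"correction":', pos)
            if match == -1:
                break
            try:
                result, index = decoder.raw_decode(text[match:])
                correction = result.get('correction')
                if correction:
                    return correction
                # Empty correction: skip past this object and keep scanning.
                pos = match + index
            except json.JSONDecodeError:
                # Not valid JSON at this position; move on by one character.
                pos = match + 1
        return correction
    except Exception:
        # Fall back to the raw text if anything unexpected happens.
        return text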

        
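# Load model
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

BASE_MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B-Instruct"
model_name_or_path = "mfsadi/Llama-3.1-8B-spelling-fa"
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_PATH, return_dict=True)
# Attach the spelling-correction adapter to the base model.
spelling_model = PeftModel.from_pretrained(base_model, model_name_or_path).to(device)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)


# Inference. You need to pass "query".
def correct_spelling(query):
    # Base_prompt is the prompt string defined under "Model Description".
    prompt = f"""### Human: {Base_prompt} query: {query}\n ### Assistant:"""
    # str([prompt]) is kept as published: the prompt is tokenized as the
    # string form of a one-element list.
    batch = tokenizer(str([prompt]), return_tensors='pt')
    prompt_length = len(batch.get('input_ids')[0])
    max_new_tokens = 50
    with torch.no_grad():
        output_tokens = spelling_model.generate(**batch.to(device), max_new_tokens=max_new_tokens,
                                                repetition_penalty=1.1,
                                                do_sample=True,
                                                num_beams=2,
                                                temperature=0.1,
                                                top_k=10,
                                                top_p=.5,
                                                length_penalty=-1
                                                )
    # Decode only the newly generated tokens, not the prompt.
    output = tokenizer.decode(output_tokens[0][prompt_length:], skip_special_tokens=True)
    return extract_json(output)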
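The inference steps above handle one query at a time. As a rough sketch of how they might be applied to a list of queries, the example below builds the same "### Human: ... ### Assistant:" prompt layout and skips repeated queries. The names `build_prompt` and `correct_many` are hypothetical, and `generate` is an assumed callable that wraps tokenization, generation, and JSON extraction and returns the correction string (empty or None when nothing needs fixing).

```python
from typing import Callable, Dict, Iterable, Optional


def build_prompt(base_prompt: str, query: str) -> str:
    """Mirror the prompt layout used in the inference snippet."""
    return f"### Human: {base_prompt} query: {query}\n ### Assistant:"


def correct_many(
    queries: Iterable[str],
    base_prompt: str,
    generate: Callable[[str], Optional[str]],
) -> Dict[str, str]:
    """Map each distinct query to its corrected form.

    `generate` is an assumed stand-in for the model call; it receives the
    full prompt and returns the extracted correction, or ""/None.
    """
    results: Dict[str, str] = {}
    for query in queries:
        if query in results:
            continue  # repeated query: reuse the earlier result
        correction = generate(build_prompt(base_prompt, query))
        # An empty or missing correction means the query is kept as typed.
        results[query] = correction if correction else query
    return results


if __name__ == "__main__":
    # Stub in place of the real model, for illustration only.
    stub = lambda prompt: "بادکنک جشن تولد" if "تواد" in prompt else ""
    print(correct_many(["بادکنک جشن تواد", "صندلی بادی"], "PROMPT", stub))
```

Passing the model call in as a function keeps the loop testable without loading the 8B model, and deduplicating queries avoids paying for generation twice on repeated searches.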

Model Card Contact

Majid F. Sadi

[email protected]

https://www.linkedin.com/in/mfsadi/