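---
library_name: transformers
tags: [spell_correction]
---

# Model Card for Llama-3.1-8B-spelling-fa

A Persian spelling-correction model based on Llama 3.1 8B Instruct. It was trained on user queries submitted to the Basalam.com marketplace.

## Model Details

### Model Description

The model is used with the following base prompt. The prompt asks it to return the corrected query as JSON, with an empty `correction` when the query is already correct:

```python
Base_prompt = """You are tasked with correcting spelling mistakes in the queries that users submitted to a Persian marketplace.
Output the corrected query in the following JSON format:
- If the input requires correction, use: {"correction": ""}
- If the input is correct, use: "correction": ""}
Here are some examples:
"query": "ندل چسبی زنانه"
Your answer: {"correction": "صندل چسبی زنانه"}
"query": "بادکنک جشن تواد"
Your answer: {"correction": "بادکنک جشن تولد"}
"query": "صندلی بادی"
Your answer: {"correction": ""}\n"""
```

## Uses

It is intended for correcting spelling mistakes in Persian text, such as the marketplace queries it was trained on.

### Direct Use

The snippet below attaches the fine-tuned adapter (`mfsadi/Llama-3.1-8B-spelling-fa`) to the base Llama 3.1 model with PEFT, generates an answer for a query, and extracts the correction from the model's JSON output.

```python
import json

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer


# Output structuring: pull the "correction" value out of the generated text.
# Returns the first non-empty correction, "" if only empty corrections were
# generated (query judged correct), or None if no JSON answer was found.
def extract_json(text):
    try:
        correction = None
        pos = 0
        decoder = json.JSONDecoder()
        while pos < len(text):
            match = text.find('{"correction":', pos)
            if match == -1:
                break
            try:
                result, index = decoder.raw_decode(text[match:])
                correction = result.get('correction')
                if correction:
                    return correction
                pos = match + index
            except json.JSONDecodeError:
                pos = match + 1
        return correction
    except Exception:
        # On any unexpected error, fall back to the raw generated text.
        return text


# Load the base model, attach the spelling adapter, and load the tokenizer.
BASE_MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B-Instruct"
model_name_or_path = "mfsadi/Llama-3.1-8B-spelling-fa"

base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_PATH, return_dict=True)
spelling_model = PeftModel.from_pretrained(base_model, model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)

device = "cuda" if torch.cuda.is_available() else "cpu"
spelling_model.to(device)


# Inference: pass the text to correct as `query`.
# Assumes Base_prompt from the Model Description section is in scope.
def correct_spelling(query, max_new_tokens=50):
    prompt = f"""### Human: {Base_prompt} query: {query}\n ### Assistant:"""
    # Note: this tokenizes str([prompt]) (the string form of a one-element
    # list), not prompt itself.
    batch = tokenizer(str([prompt]), return_tensors='pt')
    prompt_length = len(batch.get('input_ids')[0])
    with torch.no_grad():
        output_tokens = spelling_model.generate(
            **batch.to(device),
            max_new_tokens=max_new_tokens,
            repetition_penalty=1.1,
            do_sample=True,
            num_beams=2,
            temperature=0.1,
            top_k=10,
            top_p=.5,
            length_penalty=-1,
        )
    # Decode only the newly generated tokens, then parse the JSON answer.
    output = tokenizer.decode(output_tokens[0][prompt_length:], skip_special_tokens=True)
    return extract_json(output)


# Example call; the result depends on sampling.
print(correct_spelling("بادکنک جشن تواد"))
```

## Model Card Contact

Majid F. Sadi (mfsadi.work@gmail.com, https://www.linkedin.com/in/mfsadi/)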
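
## Example: Checking Output Parsing Offline

You can test the prompt wrapping and JSON parsing from Direct Use without downloading the model. The following is a minimal, self-contained sketch, not the author's code. `build_prompt`, `parse_correction`, and `resolve_query` are hypothetical helper names. The raw output strings are invented for illustration, not real model generations. Treating an empty or missing correction as "keep the original query" is an assumption based on the prompt's examples.

```python
import json
from typing import Optional


def build_prompt(base_prompt: str, query: str) -> str:
    # Same "### Human: ... ### Assistant:" wrapper as the Direct Use snippet.
    return f"### Human: {base_prompt} query: {query}\n ### Assistant:"


def parse_correction(text: str) -> Optional[str]:
    # Scan for '{"correction": ...}' objects; return the first non-empty value,
    # "" if only empty corrections appear, or None if no object parses.
    decoder = json.JSONDecoder()
    correction = None
    pos = 0
    while True:
        start = text.find('{"correction":', pos)
        if start == -1:
            return correction
        try:
            obj, length = decoder.raw_decode(text[start:])
        except json.JSONDecodeError:
            pos = start + 1  # malformed object: keep scanning past it
            continue
        correction = obj.get("correction")
        if correction:
            return correction
        pos = start + length


def resolve_query(query: str, raw_output: str) -> str:
    # Hypothetical policy: an empty or missing correction keeps the query as-is.
    return parse_correction(raw_output) or query


# Invented outputs for illustration only (not real model generations).
print(resolve_query("بادکنک جشن تواد", 'Your answer: {"correction": "بادکنک جشن تولد"}'))  # → بادکنک جشن تولد
print(resolve_query("صندلی بادی", '{"correction": ""}'))  # → صندلی بادی
```

Falling back to the original query means the result is always a usable string, even when the model returns an empty correction or malformed JSON. Whether that policy is appropriate depends on the application.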