---
library_name: transformers
tags: [spell_correction]
---

# Model Card for Llama-3.1-8B-spelling-fa

A Persian spelling-correction model based on Llama 3.1 8B Instruct. The model was trained on user search queries submitted to the Basalam.com marketplace.


## Model Details

### Model Description

The model is used with the following base prompt (referenced as `Base_prompt` in the code examples below):

```python 
Base_prompt = """You are tasked with correcting spelling mistakes in the queries that users submitted to a Persian marketplace.

Output the corrected query in the following JSON format:

- If the input requires correction, use:
  
  {"correction": "<corrected version of the query>"}
  
- If the input is correct, use:

  {"correction": ""}

Here are some examples:

"query": "ندل چسبی زنانه" Your answer: {"correction": "صندل چسبی زنانه"}

"query": "بادکنک جشن تواد"  Your answer: {"correction": "بادکنک جشن تولد"}

"query": "صندلی بادی"  Your answer: {"correction": ""}\n"""
```
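
For example, using the first few-shot pair from the prompt above, the model's answer is expected to be a single JSON object that can be parsed directly. A minimal illustration (not part of the released code):

```python 
import json

# Expected answer shape for the query "ندل چسبی زنانه", taken from the few-shot examples above.
raw_answer = '{"correction": "صندل چسبی زنانه"}'
print(json.loads(raw_answer)["correction"])  # -> صندل چسبی زنانه
```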


## Uses

The model is intended for spelling correction of Persian text, in particular short Persian search queries like those submitted to an online marketplace.

### Direct Use

```python 
# Imports needed by the snippets below.
import json

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer


# Output structuring: extract the first non-empty "correction" value from a
# {"correction": ...} JSON object embedded in the model output.
def extract_json(text):
    try: 
        correction = None
        pos = 0
        decoder = json.JSONDecoder()
        while pos < len(text):
            match = text.find('{"correction":', pos)
            if match == -1:
                break
            try:
                result, index = decoder.raw_decode(text[match:])
                correction = result.get('correction')
                if correction:
                    return correction
                pos = match + index
            except json.JSONDecodeError:
                pos = match + 1
        return correction
    except Exception as e:
        return text

        
# Load the base model, attach the PEFT spelling adapter, and load the tokenizer.
BASE_MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B-Instruct"
model_name_or_path = "mfsadi/Llama-3.1-8B-spelling-fa"
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_PATH, return_dict=True)
spelling_model = PeftModel.from_pretrained(base_model, model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)

device = "cuda" if torch.cuda.is_available() else "cpu"
spelling_model.to(device)


# Inference: build the prompt from the base prompt and the query to correct.
spell_checking_prompt = Base_prompt  # Base_prompt is defined under Model Description
query = "ندل چسبی زنانه"  # example query; replace with the query you want to correct
prompt = f"""### Human: {spell_checking_prompt} query: {query}\n ### Assistant:"""
batch = tokenizer(str([prompt]), return_tensors='pt')
prompt_length = len(batch.get('input_ids')[0])
max_new_tokens = 50
with torch.no_grad():
    output_tokens = spelling_model.generate(**batch.to(device), max_new_tokens=max_new_tokens,
                                            repetition_penalty=1.1,
                                            do_sample=True,
                                            num_beams=2,
                                            temperature=0.1,
                                            top_k=10,
                                            top_p=.5,
                                            length_penalty=-1
                                            )
    output = tokenizer.decode(output_tokens[0][prompt_length:], skip_special_tokens=True)
    corrected_query = extract_json(output)
    print(corrected_query)
```
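
For repeated use, the steps above can be wrapped in a small helper. This is only an illustrative sketch: the function name `correct_spelling` is not part of the released code, and it assumes `Base_prompt`, `tokenizer`, `spelling_model`, `device`, and `extract_json` from the snippet above are already defined.

```python 
# Illustrative helper (hypothetical name), reusing the objects loaded above.
def correct_spelling(query: str) -> str:
    prompt = f"""### Human: {Base_prompt} query: {query}\n ### Assistant:"""
    batch = tokenizer(str([prompt]), return_tensors='pt').to(device)
    prompt_length = len(batch['input_ids'][0])
    with torch.no_grad():
        output_tokens = spelling_model.generate(**batch,
                                                max_new_tokens=50,
                                                repetition_penalty=1.1,
                                                do_sample=True,
                                                num_beams=2,
                                                temperature=0.1,
                                                top_k=10,
                                                top_p=0.5,
                                                length_penalty=-1)
    output = tokenizer.decode(output_tokens[0][prompt_length:], skip_special_tokens=True)
    return extract_json(output)


# Example query taken from the few-shot examples in the base prompt.
print(correct_spelling("بادکنک جشن تواد"))
```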

## Model Card Contact

Majid F. Sadi

[email protected]

https://www.linkedin.com/in/mfsadi/