Aekanun committed on
Commit
11aa08b
1 Parent(s): 4c21d22

revised inference code

Files changed (1)
  1. README.md +76 -57
README.md CHANGED
@@ -27,65 +27,84 @@ A LoRA-adapted vision-language model based on Llama-3.2-11B-Vision-Instruct that
### Single Image
```python
import torch
- from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
from PIL import Image

- def load_model_and_processor():
-     model_path = "Aekanun/thai-handwriting-llm"
-     base_model_path = "meta-llama/Llama-3.2-11B-Vision-Instruct"
-
-     # BitsAndBytes config
-     bnb_config = BitsAndBytesConfig(
-         load_in_4bit=True,
-         bnb_4bit_use_double_quant=True,
-         bnb_4bit_quant_type="nf4",
-         bnb_4bit_compute_dtype=torch.bfloat16
-     )
-
-     # Load processor from base model
-     processor = AutoProcessor.from_pretrained(base_model_path)
-
-     # Load fine-tuned model
-     model = AutoModelForVision2Seq.from_pretrained(
-         model_path,
-         device_map="auto",
-         torch_dtype=torch.bfloat16,
-         quantization_config=bnb_config
-     )
-     return model, processor

- def process_image(image_path, model, processor):
-     image = Image.open(image_path)
-
-     prompt = """Transcribe the Thai handwritten text from the provided image.
Only return the transcription in Thai language."""

-     messages = [
-         {
-             "role": "user",
-             "content": [
-                 {"type": "text", "text": prompt},
-                 {"type": "image", "image": image}
-             ],
-         }
-     ]
-
-     text = processor.apply_chat_template(messages, tokenize=False)
-     inputs = processor(text=text, images=image, return_tensors="pt")
-     inputs = {k: v.to(model.device) for k, v in inputs.items()}
-
-     with torch.no_grad():
-         outputs = model.generate(
-             **inputs,
-             max_new_tokens=256,
-             do_sample=False,
-             pad_token_id=processor.tokenizer.pad_token_id
-         )
-
-     transcription = processor.decode(outputs[0], skip_special_tokens=True)
-     return transcription
-
- # Usage
- model, processor = load_model_and_processor()
- result = process_image("path/to/image.jpg", model, processor)
- print(result)

### Single Image
```python
import torch
+ from transformers import AutoModelForVision2Seq, AutoProcessor
+ from peft import PeftModel
from PIL import Image

+ def load_model():
+     # Model paths
+     base_model_path = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+     adapter_path = "Aekanun/thai-handwriting-llm"
+
+     # Load processor
+     processor = AutoProcessor.from_pretrained(
+         base_model_path,
+         use_auth_token=True
+     )
+
+     # Load base model
+     base_model = AutoModelForVision2Seq.from_pretrained(
+         base_model_path,
+         device_map="auto",
+         torch_dtype=torch.float16,
+         trust_remote_code=True,
+         use_auth_token=True
+     )
+
+     # Load adapter
+     model = PeftModel.from_pretrained(
+         base_model,
+         adapter_path,
+         device_map="auto",
+         torch_dtype=torch.float16,
+         use_auth_token=True
+     )
+
+     return model, processor

+ def transcribe_thai_handwriting(image_path, model, processor):
+     # Load and prepare image
+     image = Image.open(image_path)
+
+     # Create prompt
+     prompt = """Transcribe the Thai handwritten text from the provided image.
Only return the transcription in Thai language."""
+
+     # Prepare inputs
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "text", "text": prompt},
+                 {"type": "image", "image": image}
+             ],
+         }
+     ]
+
+     # Process with model
+     text = processor.apply_chat_template(messages, tokenize=False)
+     inputs = processor(text=text, images=image, return_tensors="pt")
+     inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+     # Generate
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=512,
+             do_sample=False,
+             pad_token_id=processor.tokenizer.pad_token_id
+         )
+
+     # Decode output
+     transcription = processor.decode(outputs[0], skip_special_tokens=True)
+     return transcription.strip()

+ # Example usage
+ if __name__ == "__main__":
+     # Load model
+     model, processor = load_model()
+
+     # Transcribe image
+     image_path = "path/to/your/image.jpg"
+     result = transcribe_thai_handwriting(image_path, model, processor)
+     print(f"Transcription: {result}")