Aekanun committed
Commit 79ec84c
Parent: b5217a9

fixed app.py

Files changed (2):
  1. app.py      +25 -15
  2. config.json +0 -9
app.py CHANGED
```diff
@@ -5,7 +5,9 @@ import gc
 from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
 from PIL import Image
 import gradio as gr
+from huggingface_hub import login
 
+# Basic settings
 warnings.filterwarnings('ignore')
 os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 
```
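The only change in this hunk besides the section comment is the new `login` import from `huggingface_hub`. It supports the authentication block added below: the base model `meta-llama/Llama-3.2-11B-Vision-Instruct` is a gated repository, so the Space must authenticate before `from_pretrained` can download from it.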
 
```diff
@@ -13,22 +15,30 @@ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 model = None
 processor = None
 
+# Clear CUDA cache
 if torch.cuda.is_available():
     torch.cuda.empty_cache()
     gc.collect()
     print("เคลียร์ CUDA cache เรียบร้อยแล้ว")
 
+# Login to Hugging Face Hub
+if 'HUGGING_FACE_HUB_TOKEN' in os.environ:
+    print("กำลังเข้าสู่ระบบ Hugging Face Hub...")
+    login(token=os.environ['HUGGING_FACE_HUB_TOKEN'])
+else:
+    print("คำเตือน: ไม่พบ HUGGING_FACE_HUB_TOKEN")
+
 def load_model_and_processor():
     """โหลดโมเดลและ processor"""
     global model, processor
     print("กำลังโหลดโมเดลและ processor...")
 
     try:
-        # กำหนด paths
+        # Model paths
         base_model_path = "meta-llama/Llama-3.2-11B-Vision-Instruct"
         hub_model_path = "Aekanun/thai-handwriting-llm"
 
-        # ตั้งค่า BitsAndBytes
+        # BitsAndBytes config
         bnb_config = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_use_double_quant=True,
```
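This hunk adds the login step itself, keyed off the `HUGGING_FACE_HUB_TOKEN` environment variable. For readers without Thai: the status strings say "CUDA cache cleared", "Logging in to Hugging Face Hub...", and "Warning: HUGGING_FACE_HUB_TOKEN not found"; the docstring reads "Load the model and processor", and the removed Thai comments ("define paths", "configure BitsAndBytes") are simply re-worded in English. A minimal sketch for verifying the token outside the app, using only documented `huggingface_hub` calls:

```python
import os
from huggingface_hub import login, whoami

# Assumes HUGGING_FACE_HUB_TOKEN is set, e.g. as a Space secret.
token = os.environ.get("HUGGING_FACE_HUB_TOKEN")
if token:
    login(token=token)
    print("Logged in as:", whoami()["name"])  # round-trips the credential
else:
    print("HUGGING_FACE_HUB_TOKEN not set; gated downloads will fail")
```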
```diff
@@ -36,21 +46,19 @@ def load_model_and_processor():
             bnb_4bit_compute_dtype=torch.bfloat16
         )
 
-        # โหลด processor จาก base model
-        print("Loading processor...")
+        # Load processor from base model
         processor = AutoProcessor.from_pretrained(base_model_path)
 
-        # โหลดโมเดลจาก Hub
-        print("Loading model...")
+        # Load model from Hub
+        print("กำลังโหลดโมเดลจาก Hub...")
         model = AutoModelForVision2Seq.from_pretrained(
             hub_model_path,
             device_map="auto",
             torch_dtype=torch.bfloat16,
             quantization_config=bnb_config,
-            trust_remote_code=True,
-            force_download=True  # เพิ่มมาเพื่อให้โหลดใหม่
+            trust_remote_code=True
         )
-        print("Model loaded successfully!")
+        print("โหลดโมเดลสำเร็จ!")
 
         return True
     except Exception as e:
```
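Two substantive changes here: the status prints switch to Thai ("Loading the model from the Hub..." / "Model loaded successfully!"), and `force_download=True` (its removed comment read "added to force a fresh download") is dropped, so restarts reuse the local cache instead of re-pulling the multi-GB checkpoint. The `bnb_4bit_quant_type` line sits in hidden context between hunks, so the standalone sketch below fills it in with `"nf4"` as an assumption; it loads the model in 4-bit and reports the footprint (GPU plus bitsandbytes required):

```python
import torch
from transformers import AutoModelForVision2Seq, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",  # assumption: this line is outside the visible hunks
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForVision2Seq.from_pretrained(
    "Aekanun/thai-handwriting-llm",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
    trust_remote_code=True,
)
# get_memory_footprint() is a standard transformers helper; an 11B model
# quantized to 4-bit should land in the mid-single-digit GB range.
print(f"Model footprint: {model.get_memory_footprint() / 1e9:.1f} GB")
```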
```diff
@@ -68,14 +76,12 @@ def process_handwriting(image):
     # Ensure image is in PIL format
     if not isinstance(image, Image.Image):
         image = Image.fromarray(image)
-
-    # Convert to RGB if needed
-    if image.mode != "RGB":
-        image = image.convert("RGB")
-
+
+    # Create prompt
     prompt = """Transcribe the Thai handwritten text from the provided image.
 Only return the transcription in Thai language."""
 
+    # Create model inputs
     messages = [
         {
             "role": "user",
```
```diff
@@ -86,10 +92,12 @@ Only return the transcription in Thai language."""
         }
     ]
 
+    # Process with model
     text = processor.apply_chat_template(messages, tokenize=False)
     inputs = processor(text=text, images=image, return_tensors="pt")
     inputs = {k: v.to(model.device) for k, v in inputs.items()}
 
+    # Generate
     with torch.no_grad():
         outputs = model.generate(
             **inputs,
```
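The device-placement line in this hunk is what makes generation work under `device_map="auto"`: the processor returns CPU tensors, and `model.device` generally resolves to the device holding the first model shard, which is where `generate` expects its inputs. The generation arguments themselves (a `max_new_tokens`-style budget, presumably) sit in the hidden context just before `pad_token_id`.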
```diff
@@ -98,6 +106,7 @@ Only return the transcription in Thai language."""
             pad_token_id=processor.tokenizer.pad_token_id
         )
 
+    # Decode output
     transcription = processor.decode(outputs[0], skip_special_tokens=True)
     return transcription.strip()
 
```
```diff
@@ -113,7 +122,8 @@ if load_model_and_processor():
         inputs=gr.Image(type="pil", label="อัพโหลดรูปลายมือเขียนภาษาไทย"),
         outputs=gr.Textbox(label="ข้อความที่แปลงได้"),
         title="Thai Handwriting Recognition",
-        description="อัพโหลดรูปภาพลายมือเขียนภาษาไทยเพื่อแปลงเป็นข้อความ"
+        description="อัพโหลดรูปภาพลายมือเขียนภาษาไทยเพื่อแปลงเป็นข้อความ",
+        examples=[["example1.jpg"], ["example2.jpg"]]
     )
 
 if __name__ == "__main__":
```
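The UI strings stay in Thai: the input label reads "Upload a Thai handwriting image", the output label "Transcribed text", and the description "Upload an image of Thai handwriting to convert it to text". The new `examples` argument assumes `example1.jpg` and `example2.jpg` are committed alongside app.py; Gradio resolves example paths at startup, so missing files can break the app before anyone uploads anything. A defensive sketch (handler stubbed, labels glossed into English):

```python
import os
import gradio as gr

def process_handwriting(image):
    """Stub standing in for the real handler defined earlier in app.py."""
    return "..."

# Only advertise example images that actually exist in the repo.
example_files = [[f] for f in ("example1.jpg", "example2.jpg") if os.path.exists(f)]

demo = gr.Interface(
    fn=process_handwriting,
    inputs=gr.Image(type="pil", label="Upload a Thai handwriting image"),
    outputs=gr.Textbox(label="Transcribed text"),
    title="Thai Handwriting Recognition",
    description="Upload an image of Thai handwriting to convert it to text",
    examples=example_files or None,  # None is Gradio's default (no examples row)
)
```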
 
config.json DELETED
```diff
@@ -1,9 +0,0 @@
-{
-    "architectures": ["LlamaForCausalLM"],
-    "model_type": "llama",
-    "tokenizer_class": "PreTrainedTokenizerFast",
-    "model_max_length": 131072,
-    "megatron_core": "megatron.core",
-    "task_type": "CAUSAL_LM",
-    "target_modules": ["q_proj", "v_proj"]
-}
```
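Deleting this stray config.json is arguably the real fix. The file mixed model-config keys (`architectures`, `model_type`) with PEFT-style training keys (`task_type`, `target_modules`) and a `megatron_core` entry; if this repo is the same `Aekanun/thai-handwriting-llm` that app.py loads, a hand-written config.json at the root would shadow the model's real configuration and mis-describe the 11B vision model as a plain text-only `LlamaForCausalLM`. A quick post-deletion check (the expected value assumes a standard Llama 3.2 Vision derivative, which the commit itself does not confirm):

```python
from transformers import AutoConfig

# With the stray file gone, the repo should resolve to its real configuration.
cfg = AutoConfig.from_pretrained("Aekanun/thai-handwriting-llm", trust_remote_code=True)
print(cfg.model_type)  # expected: "mllama" for a Llama-3.2-Vision derivative
```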