# simple-ML-app / app.py
# jatnikonm
# kode post-ocr
# 135d706
# # import gradio as gr
# #
# # def greet(name):
# # return "Hello " + name + "!!"
# #
# # demo = gr.Interface(fn=greet, inputs="text", outputs="text")
# # demo.launch()
# #
#
# import gradio as gr
# from sklearn.neighbors import KNeighborsClassifier
# import numpy as np
#
# # Training data
# X = np.array([[1, 2], [2, 3], [3, 1], [6, 5], [7, 7], [8, 6]])
# y = np.array([0, 0, 0, 1, 1, 1])
#
# # Training the model
# model = KNeighborsClassifier(n_neighbors=3)
# model.fit(X, y)
#
# # Define the prediction function
# def classify_point(x, y):
# prediction = model.predict([[x, y]])
# return "Class " + str(prediction[0])
#
# # Create a Gradio interface
# demo = gr.Interface(
# fn=classify_point,
# inputs=["number", "number"],
# outputs="text",
# description="Predict the class of a point based on its coordinates using K-Nearest Neighbors"
# )
#
# # Launch the app
# demo.launch()
from dotenv import load_dotenv
import os
load_dotenv()
hf_token = os.getenv("HF_TOKEN")
import gradio as gr
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, BitsAndBytesConfig
import torch
# Hub repository id used for both the PEFT model and its tokenizer, kept in
# one place so the two loads cannot drift apart.
MODEL_ID = 'pykale/llama-2-7b-ocr'

# 4-bit NF4 quantization with double quantization; matmuls are computed in
# bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# NOTE(review): torch_dtype is float16 while the 4-bit compute dtype above is
# bfloat16 -- confirm this mismatch is intentional.
model = AutoPeftModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    # Pass the same Hub token as the tokenizer load below so both downloads
    # authenticate the same way.
    token=hf_token,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=hf_token)
def fix_ocr_errors(ocr: str) -> str:
    """Return the model's corrected version of an OCR text.

    The text is wrapped in the instruction prompt, the prompt is truncated to
    at most 1024 tokens, and up to 1024 new tokens are sampled from the
    module-level ``model``.

    Args:
        ocr: Raw OCR output to correct.

    Returns:
        The generated correction with surrounding whitespace stripped, or an
        empty string when ``ocr`` is ``None``, empty, or only whitespace.
    """
    # Nothing to correct: skip generation instead of sampling from a prompt
    # that has no input text.
    if not ocr or not ocr.strip():
        return ""

    prompt = (
        "### instruksi:\n"
        "perbaiki kata yang salah pada hasil OCR, hasil perbaikan harus dalam bahasa indonesia.\n"
        "### Input:\n"
        f"{ocr}\n"
        "### Response:\n"
    )

    # Keep the whole encoding (input ids + attention mask) and place it on
    # the device the model lives on rather than hard-coding CUDA.
    inputs = tokenizer(
        prompt,
        max_length=1024,
        return_tensors='pt',
        truncation=True,
    ).to(model.device)
    prompt_token_count = inputs["input_ids"].shape[1]

    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=True,
            temperature=0.7,
            top_p=0.1,
            top_k=40,
        )

    # The returned sequence starts with the prompt tokens. Drop them by token
    # count: slicing the decoded text by len(prompt) is wrong whenever the
    # prompt was truncated or does not decode back to the exact same string.
    generated_ids = outputs[0][prompt_token_count:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
# Build the web UI: one multi-line textbox for the raw OCR text and one
# textbox that shows the value returned by fix_ocr_errors.
ocr_input_box = gr.Textbox(lines=5, placeholder="Masukkan teks OCR di sini...")
corrected_output_box = gr.Textbox(label="text")

iface = gr.Interface(
    fn=fix_ocr_errors,
    inputs=ocr_input_box,
    outputs=corrected_output_box,
    title="Perbaiki Kesalahan OCR",
    description="Masukkan teks dengan kesalahan OCR dan model akan mencoba memperbaikinya.",
)

# Start the server only when this file is run as a script, not on import.
if __name__ == "__main__":
    iface.launch()