Spaces:
Sleeping
Sleeping
import os | |
import numpy as np | |
from requests import get | |
import streamlit as st | |
import cv2 | |
from ultralytics import YOLO | |
import shutil | |
import easyocr | |
import imutils | |
# Root directory that YOLO writes its annotated images, labels and crops into
# (passed as `project=` in `inference`); removed wholesale by `files_cleanup`.
PREDICTION_PATH = os.path.join(os.curdir, 'predictions')
def load_od_model():
    """Return the fine-tuned YOLO detector built from ``cc_detect_best.pt``.

    NOTE(review): a new model object is constructed on every call (once per
    processed image); consider caching it once it is confirmed safe to share
    one model between concurrent Streamlit sessions.
    """
    return YOLO('cc_detect_best.pt')
def load_easyocr():
    """Return a new EasyOCR reader configured for English (``'en'``) text.

    NOTE(review): each call constructs a fresh reader; callers that OCR several
    crops per request pay that cost repeatedly.
    """
    return easyocr.Reader(['en'])
def _clean_crop_for_ocr(crop_img, scale):
    """Return an upscaled, Otsu-binarised, denoised and inverted copy of a BGR crop."""
    increase = cv2.resize(crop_img, None, fx=scale, fy=scale,
                          interpolation=cv2.INTER_CUBIC)
    gray = cv2.cvtColor(increase, cv2.COLOR_BGR2GRAY)
    # Threshold 0 is ignored when THRESH_OTSU is set; Otsu picks it instead.
    _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
    opening = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel, iterations=1)
    # Erase small connected blobs (area < 50 px) that survive the opening.
    # grab_contours normalises findContours' return across OpenCV versions.
    contours = imutils.grab_contours(
        cv2.findContours(opening, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE))
    for contour in contours:
        if cv2.contourArea(contour) < 50:
            cv2.drawContours(opening, [contour], -1, 0, -1)
    return 255 - opening


def decode_text(type: str):
    """OCR the YOLO crop for one detected field and render the result in Streamlit.

    Reads the first (sorted) crop in ``PREDICTION_PATH/predict/crops/<type>``.
    It cleans the crop for OCR and runs EasyOCR on it. It then overwrites the
    crop file with the cleaned image, resized back to the crop's original
    dimensions. Finally it shows the recognised text, its mean confidence and
    the cleaned crop in two columns.

    Nothing is rendered when no readable crop exists for this field.

    Args:
        type: Detection class / crop sub-directory name, e.g. ``'card_number'``,
            ``'holder_name'`` or ``'exp_date'``. (Shadows the builtin ``type``;
            the name is kept for caller compatibility.)
    """
    output_crop_path = os.path.join(PREDICTION_PATH, 'predict', 'crops', type)
    if not os.path.isdir(output_crop_path):
        return
    # Sorted so the pick is deterministic when several boxes of the same class
    # were saved; os.listdir order is arbitrary.
    crop_files = sorted(os.listdir(output_crop_path))
    if not crop_files:
        return
    crop_img_path = os.path.join(output_crop_path, crop_files[0])
    crop_img = cv2.imread(crop_img_path)
    if crop_img is None:  # cv2.imread returns None for unreadable files
        return

    # Card-number crops get a larger upscale (5x) than the other fields (2x).
    scale = 5 if type == 'card_number' else 2
    cleaned_image = _clean_crop_for_ocr(crop_img, scale)

    reader = load_easyocr()
    crop_ocr = reader.readtext(cleaned_image)

    # Save the cleaned image at the crop's original size for display.
    # INTER_AREA is OpenCV's recommended interpolation for shrinking.
    height, width = crop_img.shape[:2]
    display_img = cv2.resize(cleaned_image, (width, height),
                             interpolation=cv2.INTER_AREA)
    cv2.imwrite(crop_img_path, display_img)

    # readtext yields (bbox, text, confidence) triples.
    ocr_txt = ''.join(text for _, text, _ in crop_ocr)
    confidences = [conf for _, _, conf in crop_ocr]
    # np.mean([]) would give nan (plus a RuntimeWarning) when no text is found.
    ocr_txt_conf = np.round(np.mean(confidences), 4) if confidences else 0.0
    if type == 'card_number':
        ocr_txt = ocr_txt.replace(' ', '')

    col1, col2 = st.columns(2, gap='small')
    with col1:
        st.markdown(f"<h5>{type.replace('_', ' ').upper()}</h5>", unsafe_allow_html=True)
        st.text(f"{ocr_txt.upper()} ({str(ocr_txt_conf)})")
    with col2:
        st.text(' ')
        if type == 'card_number':
            st.text(' ')
        st.image(crop_img_path)
def inference(input_image_path: str):
    """Detect card fields with the fine-tuned YOLO model, OCR them and render the results.

    Runs YOLO prediction, saving the annotated image, labels and per-class crops
    under ``PREDICTION_PATH/predict``. It then calls ``decode_text`` for the card
    number, holder name and expiry date. Finally it displays the annotated image.

    Args:
        input_image_path: Path of the image file to run detection on.
    """
    finetuned_model = load_od_model()
    finetuned_model.predict(
        input_image_path,
        show=False,
        save=True,
        save_crop=True,
        imgsz=640,
        conf=0.6,
        save_txt=True,
        project=PREDICTION_PATH,
        # decode_text and the st.image call below read from PREDICTION_PATH/predict;
        # exist_ok reuses that folder instead of creating predict2, predict3, ...
        name='predict',
        show_labels=True,
        show_conf=True,
        line_width=2,
        exist_ok=True,
    )
    for field in ('card_number', 'holder_name', 'exp_date'):
        decode_text(field)
    # The annotated image keeps the source image's file name.
    annotated_name = os.path.basename(input_image_path)
    st.image(os.path.join(PREDICTION_PATH, 'predict', annotated_name))
def files_cleanup(path_: str):
    """Delete the uploaded image at *path_* and the whole ``PREDICTION_PATH`` tree.

    Targets that are already gone are ignored. This includes targets removed
    between a check and the delete (e.g. by a concurrent session), so cleanup
    never fails just because there is nothing left to clean.

    Args:
        path_: Path of the uploaded image file to remove.
    """
    # EAFP instead of os.path.exists() checks: avoids a check-then-delete race.
    try:
        os.remove(path_)
    except FileNotFoundError:
        pass
    try:
        shutil.rmtree(PREDICTION_PATH)
    except FileNotFoundError:
        pass
def get_upload_path():
    """Return the file path for the downloaded input image, creating its folder.

    Returns:
        ``os.path.join('.', 'uploads', 'input.jpg')``. The ``./uploads``
        directory is created first if it does not already exist.
    """
    # NOTE(review): this fixed path is shared by every session of the app, so
    # concurrent users can overwrite each other's upload. A per-session path
    # would also require a per-session PREDICTION_PATH.
    upload_dir = os.path.join('.', 'uploads')
    # exist_ok=True avoids the FileExistsError race of a separate exists() check.
    os.makedirs(upload_dir, exist_ok=True)
    return os.path.join(upload_dir, 'input.jpg')
def process_input_image(img_url, *, timeout=15.0):
    """Download an image, resize it to 640x640 and save it at the upload path.

    Args:
        img_url: HTTP(S) URL of the image, as entered by the user.
        timeout: Seconds to wait for the server (connect and read) before
            giving up; without one, ``requests`` can block indefinitely.

    Returns:
        The path the resized image was written to (from ``get_upload_path``).

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.Timeout: if the server does not respond within *timeout*.
        ValueError: if the response body is empty or not a decodable image.
    """
    upload_file_path = get_upload_path()
    headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/102.0.0.0 Safari/537.36'}
    # NOTE(review): the URL is user-supplied and fetched server-side (SSRF risk);
    # consider rejecting private/internal hosts before exposing the app publicly.
    r = get(img_url, headers=headers, timeout=timeout)
    r.raise_for_status()
    if not r.content:
        raise ValueError(f'No image data received from {img_url}')
    arr = np.frombuffer(r.content, np.uint8)
    # IMREAD_COLOR always decodes to 3-channel BGR (alpha dropped, grayscale
    # expanded), so any source image can be resized and saved as JPEG.
    input_image = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    if input_image is None:
        raise ValueError(f'Could not decode an image from {img_url}')
    # Resizing does not depend on channel order, so no BGR<->RGB round trip.
    input_image = cv2.resize(input_image, (640, 640))
    cv2.imwrite(upload_file_path, input_image)
    return upload_file_path
# Streamlit re-executes this script top to bottom on every interaction.
try:
    # Remove leftovers from any earlier run that did not finish cleanly.
    files_cleanup(get_upload_path())
    st.markdown("<h3>Credit Card Detection</h3>", unsafe_allow_html=True)
    desc = '''YOLOv8 is fine-tuned to detect credit card number, holder's name and expiry date. Dataset used to fine-tune YOLOv8
can be found <a href="https://universe.roboflow.com/credit-cards-detection/credit_card_detect-wjmlc/dataset/2" target="_blank">
here</a>. The detected objects are cropped, processed and passed as inputs to EasyOCR for text recognition.
'''
    st.markdown(desc, unsafe_allow_html=True)
    img_url = st.text_input("Paste the image URL of a credit card:", "")
    if img_url:
        img_path = process_input_image(img_url)
        inference(img_path)
except Exception as e:
    # Top-level UI boundary: surface any failure to the user instead of a traceback.
    st.error(f'An unexpected error occurred: \n{e}')
finally:
    # Always remove the upload and prediction files, after any error is shown.
    files_cleanup(get_upload_path())