# Source: file published by faceplugin ("initial commit", revision 901e379, 5.2 kB).
import sqlite3
import sys
import os
import os.path
import numpy as np
# Path prefix for the numbered database files: <module directory>/person<N>.db.
# os.path.join inserts the platform separator; a hard-coded '\\' only works on
# Windows (on POSIX it yields a file whose name contains a literal backslash,
# placed in the parent directory of the module).
database_base_name = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'person')
# Table holding one row per registered face.
table_name = "feature"
sqlite_insert_blob_query = "INSERT INTO " + table_name + " (id, filename, count, boxes, landmarks, alignimgs, features) VALUES (?, ?, ?, ?, ?, ?, ?)"
sqlite_create_table_query = "CREATE TABLE " + table_name + " ( id INTEGER PRIMARY KEY, filename TEXT, count INTEGER, boxes BLOB NOT NULL, landmarks BLOB NOT NULL, alignimgs BLOB NOT NULL, features BLOB NOT NULL)"
sqlite_update_all_query = "UPDATE " + table_name + " set filename = ?, count = ?, boxes = ?, landmarks = ?, alignimgs = ?, features = ? where id = ?"
sqlite_search_query = "SELECT * FROM " + table_name
sqlite_delete_all = "DELETE FROM " + table_name
# In-memory cache of face records (one dict per loaded or registered face).
data_all = []
# Minimum get_similarity() score that verify_face() accepts as a match.
threshold = 68
# NOTE(review): not referenced anywhere in the visible code.
max_feat_count = 8
# Highest id loaded or assigned so far; -1 means "none yet", so the first face
# registered into an empty database gets id 0.
max_id = -1
# NOTE(review): not referenced anywhere in the visible code.
feature_update = False
# Open sqlite3 connection; assigned by open_database().
face_database = None
# Open a numbered database file and load its face records into memory.
def _decode_feature_row(row):
    """Decode one row of the feature table into an in-memory record dict.

    Returns a dict with keys id, filename, count, boxes, landmarks,
    alignimgs, features. Each array is reshaped to (count, width). Returns
    None when a blob cannot be decoded or its length does not equal
    count * width; the caller then skips the row.
    """
    row_id, filename, count = row[0], row[1], row[2]
    if count is None:
        return None
    record = {'id': row_id, 'filename': filename, 'count': count}
    # (record key, column index, stored dtype, values per face)
    layout = (
        ('boxes', 3, np.float32, 4),
        ('landmarks', 4, np.float32, 136),
        ('alignimgs', 5, np.uint8, 49152),
        ('features', 6, np.float32, 256),
    )
    for key, column, dtype, width in layout:
        try:
            # np.frombuffer replaces the deprecated binary mode of
            # np.fromstring; .copy() keeps the result writable and detached
            # from the row's bytes, as fromstring's result was.
            values = np.frombuffer(row[column], dtype=dtype).copy()
        except (TypeError, ValueError):
            # Blob missing or not a whole number of elements: treat like any
            # other malformed row instead of aborting the whole load.
            return None
        if values.shape[0] != count * width:
            return None
        record[key] = values.reshape(count, width)
    return record


def open_database(db_no):
    """Open database number ``db_no`` and load its rows into ``data_all``.

    Connects to ``database_base_name + str(db_no) + ".db"`` and stores the
    connection in the module-level ``face_database``. Creates the table
    ``table_name`` if the file does not have it yet. Appends every
    well-formed row to ``data_all`` and raises ``max_id`` to the largest id
    seen. Malformed rows are skipped silently.

    NOTE(review): ``data_all`` and ``max_id`` are not reset first, so opening
    a second database adds to what is already cached.
    """
    global max_id
    global face_database
    db_name = database_base_name + str(db_no) + ".db"
    face_database = sqlite3.connect(db_name)

    cursor = face_database.execute("SELECT name FROM sqlite_master WHERE type='table';")
    try:
        tables = [v[0] for v in cursor.fetchall() if v[0] != "sqlite_sequence"]
    finally:
        cursor.close()
    # Use the shared table_name constant rather than a duplicated literal.
    if table_name not in tables:
        face_database.execute(sqlite_create_table_query)

    cursor = face_database.execute(sqlite_search_query)
    try:
        for row in cursor.fetchall():
            record = _decode_feature_row(row)
            if record is None:
                continue
            data_all.append(record)
            if record['id'] > max_id:
                max_id = record['id']
    finally:
        cursor.close()
# Create a new database file under the first unused number.
def create_database():
    """Open a fresh database using the lowest number without an existing file.

    Probes ``database_base_name + str(n) + ".db"`` for n = 0, 1, 2, ...
    until it finds a path that is not an existing file. It then calls
    ``open_database`` with that number.
    """
    candidate = 0
    while os.path.isfile(database_base_name + str(candidate) + ".db"):
        candidate += 1
    open_database(candidate)
def clear_database():
    """Forget every registered face, both in memory and in the open database.

    Empties the ``data_all`` cache first. It then runs ``sqlite_delete_all``
    on ``face_database`` and commits. ``max_id`` is left unchanged.
    """
    del data_all[:]
    deletion = face_database.execute(sqlite_delete_all)
    face_database.commit()
    deletion.close()
def register_face(filename, count, boxes, landmarks, alignimgs, features):
    """Insert a new face record into the database and the in-memory cache.

    The record gets id ``max_id + 1``. Arrays are serialized with
    ``tobytes()`` (the non-deprecated replacement for ``tostring()``).
    They presumably are float32 (boxes, landmarks, features) and uint8
    (alignimgs), matching the dtypes the loader decodes with; TODO confirm
    with callers.

    NOTE(review): the arrays are cached exactly as passed. Rows loaded from
    disk are cached reshaped to (count, width).

    Returns the new id, which is also printed.
    """
    global max_id
    new_id = max_id + 1
    cursor = face_database.cursor()
    try:
        cursor.execute(sqlite_insert_blob_query, (new_id, filename, count, boxes.tobytes(), landmarks.tobytes(), alignimgs.tobytes(), features.tobytes()))
        face_database.commit()
    finally:
        cursor.close()
    # Claim the id only after the row is committed, so a failed insert does
    # not advance max_id.
    max_id = new_id
    data_all.append({'id': new_id, 'filename': filename, 'count': count, 'boxes': boxes, 'landmarks': landmarks, 'alignimgs': alignimgs, 'features': features})
    print('id = ', new_id)
    return new_id
def update_face(id=None, filename=None, count=None, boxes=None, landmarks=None, alignimgs=None, features=None):
    """Overwrite every stored column of the face with the given ``id``.

    All arguments are effectively required despite the ``None`` defaults:
    the arrays are serialized with ``tobytes()``. The matching entry in
    ``data_all`` is updated as well. Without that, ``verify_face`` and
    ``get_info`` would keep using stale values until the database is
    reopened. Nothing changes if no face has this id.
    """
    cursor = face_database.cursor()
    try:
        cursor.execute(sqlite_update_all_query, (filename, count, boxes.tobytes(), landmarks.tobytes(), alignimgs.tobytes(), features.tobytes(), id))
        face_database.commit()
    finally:
        cursor.close()
    for data in data_all:
        if data['id'] == id:
            data.update(filename=filename, count=count, boxes=boxes, landmarks=landmarks, alignimgs=alignimgs, features=features)
            break
def get_similarity(feat1, feat2):
    """Score two feature arrays as ``(sum(feat1 * feat2) + 1) * 50``.

    For L2-normalised vectors (presumably what the feature extractor
    produces; TODO confirm) the summed product is a cosine in [-1, 1], so
    the score falls in [0, 100]. The product broadcasts and *all* of its
    elements are summed, so a 2-D input gives one combined score, not one
    score per row.
    """
    products = feat1 * feat2
    return 50 * (1 + np.sum(products))
def verify_face(feat):
    """Find the stored sub-feature that best matches ``feat``.

    Each entry in ``data_all`` holds ``count`` feature rows. Every row is
    scored on its own with ``get_similarity``; scoring the whole matrix at
    once would sum all rows into one inflated score. Entries with a falsy
    ``count`` are skipped.

    Returns ``(id, filename, sub_id)`` for the best-scoring row when its
    score is at least ``threshold``, printing the score. Here ``sub_id`` is
    the row index within that entry. Otherwise returns
    ``(-1, None, None)``.
    """
    best_score = None
    best_match = (-1, None, None)
    for data in data_all:
        count = data['count']
        if not count:
            continue
        # Loaded entries are already (count, width); entries registered in
        # this session keep the caller's shape, so normalise to one row per
        # sub-feature.
        rows = np.reshape(data['features'], (count, -1))
        for sub_id, sub_feat in enumerate(rows):
            score = get_similarity(feat, sub_feat)
            # Strict '>' keeps the first of equally good matches.
            if best_score is None or score > best_score:
                best_score = score
                best_match = (data['id'], data['filename'], sub_id)
    if best_score is not None and best_score >= threshold:
        print("score = ", best_score)
        return best_match
    return -1, None, None
def get_info(id, sub_id):
    """Return the stored data for sub-feature ``sub_id`` of the face ``id``.

    Returns ``(filename, box, landmarks, alignimg, feature)``, each array
    value being element ``sub_id`` of that entry's array. Returns five
    ``None`` values when no entry has this id or when ``sub_id`` is outside
    ``0 <= sub_id < count``.
    """
    for data in data_all:
        if data['id'] != id:
            continue
        # Valid indices are 0..count-1. The original test (count < sub_id)
        # was inverted and only admitted out-of-range indices.
        if 0 <= sub_id < data['count']:
            return (data['filename'], data['boxes'][sub_id], data['landmarks'][sub_id],
                    data['alignimgs'][sub_id], data['features'][sub_id])
        return None, None, None, None, None
    return None, None, None, None, None
def set_threshold(th):
    """Set the minimum similarity score that verify_face() accepts as a match."""
    # Without this declaration the assignment bound a local name and the
    # module-level threshold never changed.
    global threshold
    threshold = th


def get_threshold():
    """Return the current match threshold used by verify_face()."""
    return threshold