# LineArt-Removar / app.py
# tori29umai's picture
# Update
# c575ad5
import gradio as gr
from PIL import Image, ImageFilter, ImageOps
import cv2
import numpy as np
import os
from collections import defaultdict
from skimage.color import deltaE_ciede2000, rgb2lab
import zipfile
def DoG_filter(image, kernel_size=0, sigma=1.0, k_sigma=2.0, gamma=1.5):
    """Return a weighted difference of two Gaussian blurs of ``image``.

    The image is blurred once with standard deviation ``sigma`` and once with
    ``sigma * k_sigma``. The wider blur is scaled by ``gamma`` and subtracted
    from the narrower one. ``kernel_size`` is used for both kernel dimensions
    (per the OpenCV documentation, a size of 0 is derived from sigma).
    """
    ksize = (kernel_size, kernel_size)
    narrow_blur = cv2.GaussianBlur(image, ksize, sigma)
    wide_blur = cv2.GaussianBlur(image, ksize, sigma * k_sigma)
    return narrow_blur - gamma * wide_blur
def XDoG_filter(image, kernel_size=0, sigma=1.4, k_sigma=1.6, epsilon=0, phi=10, gamma=0.98):
    """Apply an XDoG-style soft threshold to ``image`` and return a uint8 map.

    The difference-of-Gaussians response from ``DoG_filter`` is normalised by
    its maximum, passed through ``1 + tanh(phi * (dog - epsilon))``, capped at
    1 and scaled to the 0-255 range.

    Args:
        image: Input array handed to ``DoG_filter``.
        kernel_size, sigma, k_sigma, gamma: Forwarded to ``DoG_filter``.
        epsilon: Threshold offset; it is divided by 255 before use.
        phi: Steepness of the tanh ramp.

    Returns:
        A ``uint8`` array with the same shape as the DoG response.
    """
    epsilon /= 255
    dog = DoG_filter(image, kernel_size, sigma, k_sigma, gamma)
    peak = dog.max()
    # A zero peak (e.g. an all-black input) would make the normalisation a
    # division by zero and fill the result with NaN; leave the response
    # unscaled in that case.
    if peak != 0:
        dog = dog / peak
    e = 1 + np.tanh(phi * (dog - epsilon))
    e[e >= 1] = 1
    return (e * 255).astype('uint8')
def binarize_image(image):
    """Binarize ``image`` to 0/255 using OpenCV's binary + Otsu threshold flags."""
    flags = cv2.THRESH_BINARY + cv2.THRESH_OTSU
    # cv2.threshold returns a pair; the thresholded image is the second item.
    outcome = cv2.threshold(image, 0, 255, flags)
    return outcome[1]
def process_XDoG(image_path, kernel_size=0, sigma=1.4, k_sigma=1.6, epsilon=0, phi=10, gamma=0.98):
    """Extract binarized line art from the image file at ``image_path``.

    The file is read as grayscale, filtered with ``XDoG_filter`` and then
    binarized with ``binarize_image``.

    Args:
        image_path: Path of the image file to read.
        kernel_size, sigma, k_sigma, epsilon, phi, gamma: Forwarded to
            ``XDoG_filter``. The defaults are the values this function
            previously hard-coded.

    Returns:
        A PIL image built from the binarized array.

    Raises:
        OSError: If OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None instead of raising, which
    # would otherwise surface later as an unrelated error inside the filter.
    if image is None:
        raise OSError(f"Could not read image: {image_path}")
    xdog_image = XDoG_filter(image, kernel_size, sigma, k_sigma, epsilon, phi, gamma)
    binarized_image = binarize_image(xdog_image)
    return Image.fromarray(binarized_image)
def replace_color(image, color_1, blur_radius=2):
    """Fill every pixel whose RGB equals ``color_1`` from its neighbours.

    Matching pixels are repeatedly replaced by the mean RGB of their 4-connected
    neighbours that do not have colour ``color_1``, and their alpha is set to
    255. Pixels are visited in raster order and updated in place, so a pixel
    filled earlier in a pass can feed a pixel visited later in the same pass.
    Afterwards the whole image is Gaussian-blurred and the blurred values are
    copied back only onto the pixels that were filled.

    Args:
        image: PIL image (converted to RGBA when it is in another mode) or an
            array of shape (height, width, 4).
        color_1: RGB triple marking the pixels to fill.
        blur_radius: Radius passed to ``ImageFilter.GaussianBlur``.

    Returns:
        A new RGBA PIL image.

    Raises:
        ValueError: If the pixel data does not have four channels.
    """
    # The fill writes channel 3 and rebuilds an 'RGBA' image, so the data must
    # really have four channels.
    if isinstance(image, Image.Image) and image.mode != 'RGBA':
        image = image.convert('RGBA')
    data = np.array(image)
    if data.ndim != 3 or data.shape[2] != 4:
        raise ValueError("replace_color expects an RGBA image")
    original_shape = data.shape
    height, width = original_shape[:2]
    data = data.reshape(-1, 4)
    color_1 = np.array(color_1)
    matches = np.all(data[:, :3] == color_1, axis=1)
    mask = np.zeros(data.shape[0], dtype=bool)
    while np.any(matches):
        new_matches = np.zeros_like(matches)
        filled_any = False
        # Visit only the pixels that still need filling (ascending index =
        # raster order) instead of scanning every pixel of the image.
        for i in np.flatnonzero(matches).tolist():
            row, col = divmod(i, width)
            valid_neighbors = []
            for n_row, n_col in ((row, col - 1), (row, col + 1), (row - 1, col), (row + 1, col)):
                if 0 <= n_row < height and 0 <= n_col < width:
                    ni = n_row * width + n_col
                    if not np.all(data[ni, :3] == color_1):
                        valid_neighbors.append(data[ni, :3])
            if valid_neighbors:
                data[i, :3] = np.mean(valid_neighbors, axis=0).astype(np.uint8)
                data[i, 3] = 255
                mask[i] = True
                filled_any = True
            else:
                new_matches[i] = True
        matches = new_matches
        # A pass that filled nothing left the data untouched, so every later
        # pass would be identical: stop right away.
        if not filled_any:
            break
    data = data.reshape(original_shape)
    mask = mask.reshape(original_shape[:2])
    result_image = Image.fromarray(data, 'RGBA')
    blurred_image = result_image.filter(ImageFilter.GaussianBlur(radius=blur_radius))
    blurred_data = np.array(blurred_image)
    # Only the filled pixels take the blurred values; the rest stay sharp.
    np.copyto(data, blurred_data, where=mask[..., None])
    return Image.fromarray(data, 'RGBA')
def generate_distant_colors(consolidated_colors, distance_threshold):
    """Pick a random RGB colour that is far from all the given colours.

    Random colours are drawn until one has a CIEDE2000 distance greater than
    ``distance_threshold`` to every colour in ``consolidated_colors``. If no
    draw satisfies the threshold within the attempt budget, the draw whose
    nearest given colour was farthest away is returned, rather than a fixed
    grey that might itself be one of the image colours.

    Args:
        consolidated_colors: Iterable of ``(rgb, count)`` pairs; only the RGB
            part is used.
        distance_threshold: Minimum CIEDE2000 distance required to accept a
            colour immediately.

    Returns:
        A tuple of three ints in the range 0-255.
    """
    consolidated_lab = [
        rgb2lab(np.array([color], dtype=np.float32) / 255.0).reshape(3)
        for color, _ in consolidated_colors
    ]
    max_attempts = 10000
    best_color = (128, 128, 128)
    # Distances are non-negative, so -1 guarantees the first draw is recorded.
    best_distance = -1.0
    for _ in range(max_attempts):
        random_rgb = np.random.randint(0, 256, size=3)
        random_lab = rgb2lab(np.array([random_rgb], dtype=np.float32) / 255.0).reshape(3)
        nearest = float('inf')
        for base_color_lab in consolidated_lab:
            nearest = min(nearest, float(deltaE_ciede2000(base_color_lab, random_lab)))
            # best_distance never exceeds the threshold (such a draw would have
            # been returned), so a draw that cannot beat it cannot pass either.
            if nearest <= best_distance:
                break
        else:
            candidate = tuple(int(channel) for channel in random_rgb)
            if nearest > distance_threshold:
                return candidate
            best_color, best_distance = candidate, nearest
    return best_color
def consolidate_colors(major_colors, threshold):
    """Merge colours whose CIEDE2000 distance is below ``threshold``.

    When two colours are closer than ``threshold``, the one with the smaller
    count is removed and its count is added to the other one (ties keep the
    earlier entry).

    Args:
        major_colors: List of ``(rgb, count)`` pairs. The list is modified in
            place.
        threshold: Distance below which two colours are merged.

    Returns:
        The same ``major_colors`` list after merging.
    """
    colors_lab = [
        rgb2lab(np.array([[color]], dtype=np.float32) / 255.0).reshape(3)
        for color, _ in major_colors
    ]
    i = 0
    while i < len(colors_lab):
        j = i + 1
        while j < len(colors_lab):
            if deltaE_ciede2000(colors_lab[i], colors_lab[j]) < threshold:
                if major_colors[i][1] >= major_colors[j][1]:
                    major_colors[i] = (major_colors[i][0], major_colors[i][1] + major_colors[j][1])
                    major_colors.pop(j)
                    colors_lab.pop(j)
                    # The next candidate has shifted into slot j; do not advance.
                else:
                    major_colors[j] = (major_colors[j][0], major_colors[j][1] + major_colors[i][1])
                    major_colors.pop(i)
                    colors_lab.pop(i)
                    # A different colour now occupies slot i, so it has to be
                    # compared against everything after it from the start.
                    j = i + 1
                continue
            j += 1
        i += 1
    return major_colors
def get_major_colors(image, threshold_percentage=0.01):
    """List the colours that cover at least a given fraction of ``image``.

    The image is converted to RGB when it is in another mode. Every distinct
    pixel value is counted, and the ``(color, count)`` pairs whose share of
    ``width * height`` is at least ``threshold_percentage`` are returned.
    """
    rgb_image = image if image.mode == 'RGB' else image.convert('RGB')

    tally = defaultdict(int)
    for px in rgb_image.getdata():
        tally[px] += 1

    pixel_total = rgb_image.width * rgb_image.height
    selected = []
    for rgb, occurrences in tally.items():
        if occurrences / pixel_total >= threshold_percentage:
            selected.append((rgb, occurrences))
    return selected
def line_color(image, mask, new_color):
    """Return a copy of ``image`` whose masked pixels have RGB ``new_color``.

    Only the first three channels are overwritten; any further channel keeps
    its value.
    """
    pixels = np.array(image)
    # Basic slicing yields a view, so assigning through it updates ``pixels``.
    rgb_view = pixels[..., :3]
    rgb_view[mask] = new_color
    return Image.fromarray(pixels)
def process_image(image, lineart):
    """Remove the line art from ``image`` and return the result as RGB.

    ``lineart`` is thresholded at 200, inverted so that the dark strokes become
    255, and dilated with a 3x3 kernel to form the stroke mask. The masked
    pixels are painted with a marker colour chosen by
    ``generate_distant_colors`` and then filled in by ``replace_color``.
    """
    rgba = image if image.mode == 'RGBA' else image.convert('RGBA')

    # Stroke mask: values below 200 count as line pixels.
    thresholded = lineart.point(lambda x: 0 if x < 200 else 255)
    inverted = ImageOps.invert(thresholded)
    structuring = np.ones((3, 3), np.uint8)
    dilated = cv2.dilate(np.array(inverted), structuring, iterations=1)
    mask = dilated == 255

    # Marker colour chosen relative to the dominant colours of the picture.
    dominant = get_major_colors(rgba, threshold_percentage=0.05)
    dominant = consolidate_colors(dominant, 10)
    marker = generate_distant_colors(dominant, 100)

    painted = line_color(rgba, mask, marker)
    return replace_color(painted, marker, 2).convert('RGB')
def zip_files(zip_files, zip_path):
    """Create the archive ``zip_path`` containing every file in ``zip_files``.

    Each file is stored under its base name, without directory components.
    """
    with zipfile.ZipFile(zip_path, mode='w') as archive:
        for member_path in zip_files:
            member_name = os.path.basename(member_path)
            archive.write(member_path, arcname=member_name)
class webui:
    """Gradio front end that splits an image into a line-free picture and line art."""

    def __init__(self):
        self.demo = gr.Blocks()

    def main(self, image_path):
        """Process the image at ``image_path`` and write the results beside it.

        Saves ``<name>_noline.png`` (picture with the lines removed),
        ``<name>_lineart.png`` (extracted lines with transparency) and
        ``<name>.zip`` containing both, where ``<name>`` is ``image_path``
        without its extension.

        Returns:
            A pair ``([noline_image, lineart_image], zip_path)``.
        """
        # File path with the extension removed.
        image_name = os.path.splitext(image_path)[0]
        # Open the file once and release the handle as soon as the copies
        # needed below exist.
        with Image.open(image_path) as source:
            alpha = source.getchannel('A') if source.mode == 'RGBA' else None
            rgb_image = source.convert('RGBA').convert('RGB')
        lineart = process_XDoG(image_path).convert('L')
        replace_color_image = process_image(rgb_image, lineart).convert('RGBA')
        # Restore the original transparency when the input had an alpha channel.
        if alpha is not None:
            replace_color_image.putalpha(alpha)
        replace_color_image_path = f"{image_name}_noline.png"
        replace_color_image.save(replace_color_image_path)
        lineart_image = lineart.convert('RGBA')
        # Dark strokes become opaque, the white background becomes transparent.
        lineart_alpha = 255 - np.array(lineart)
        lineart_image.putalpha(Image.fromarray(lineart_alpha))
        lineart_image_path = f"{image_name}_lineart.png"
        lineart_image.save(lineart_image_path)
        zip_files_list = [replace_color_image_path, lineart_image_path]
        zip_path = f"{image_name}.zip"
        zip_files(zip_files_list, zip_path)
        outputs = [replace_color_image, lineart_image]
        return outputs, zip_path

    def launch(self, share):
        """Build the interface and start the Gradio server.

        ``share`` is forwarded to ``gr.Blocks.launch``.
        """
        # NOTE(review): the nesting of the layout contexts below was
        # reconstructed from flattened source — confirm the intended layout.
        with self.demo:
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(type='filepath', image_mode="RGBA", label="Original Image(png画像にのみ対応しています)")
                    submit = gr.Button(value="Start")
                with gr.Row():
                    with gr.Column():
                        with gr.Tab("output"):
                            output_0 = gr.Gallery(format="png")
                            output_file = gr.File()
            submit.click(
                self.main,
                inputs=[input_image],
                outputs=[output_0, output_file]
            )
        self.demo.queue()
        self.demo.launch(share=share)
if __name__ == "__main__":
    # Script entry point: build the UI and launch it with share=True.
    webui().launch(share=True)