|
import os |
|
import requests |
|
from tqdm import tqdm |
|
from datasets import load_dataset |
|
import numpy as np |
|
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input |
|
from tensorflow.keras.preprocessing import image |
|
from sklearn.neighbors import NearestNeighbors |
|
import joblib |
|
from PIL import UnidentifiedImageError, Image |
|
import gradio as gr |
|
|
|
|
|
# --- Configuration ---------------------------------------------------------
subset_size = 2700            # number of dataset rows sampled for the index
image_dir = 'civitai_images'  # local folder that receives the downloaded images

# --- Data ------------------------------------------------------------------
# Load the dataset, then draw a reproducible (fixed-seed) sample of
# `subset_size` rows from its 'train' split.
dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k")
dataset_subset = (
    dataset['train']
    .shuffle(seed=42)
    .select(range(subset_size))
)

# Make sure the download folder exists before anything is written to it.
os.makedirs(image_dir, exist_ok=True)

# --- Feature extractor -----------------------------------------------------
# ImageNet-weighted ResNet50 built without its top layers and with
# pooling='avg'; used below to turn each image into a feature vector.
model = ResNet50(weights='imagenet', include_top=False, pooling='avg')
|
|
|
|
|
def extract_features(img_path, model):
    """Compute a flat feature vector for the image stored at ``img_path``.

    The image is loaded at 224x224, converted to an array, given a leading
    batch axis, passed through ``preprocess_input`` and then through
    ``model.predict``.

    Args:
        img_path: Path of the image file to load.
        model: Object exposing ``predict``; called with a batch of one image.

    Returns:
        The prediction for the single image, flattened to one dimension.
    """
    pil_img = image.load_img(img_path, target_size=(224, 224))

    # Indexing with np.newaxis prepends the batch axis the model expects.
    batch = image.img_to_array(pil_img)[np.newaxis, ...]

    embedding = model.predict(preprocess_input(batch))
    return embedding.flatten()
|
|
|
|
|
# Parallel accumulators: entry i of each list describes the same image.
features = []
image_paths = []
model_names = []

# Download every sampled image, extract its feature vector and record it.
# Any per-image failure is reported and skipped so that a single bad URL
# cannot abort the whole indexing run.
for sample in tqdm(dataset_subset):
    img_url = sample['url']
    model_name = sample['Model']
    img_path = os.path.join(image_dir, os.path.basename(img_url))

    try:
        # The timeout keeps one stalled server from hanging the run forever.
        response = requests.get(img_url, timeout=30)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        print(f"RequestException: Failed to download {img_url} - {e}")
        continue

    # A missing Content-Type header is treated like a non-image response.
    # This is reported and skipped rather than raised: a ValueError here
    # would not be caught by the RequestException handler above.
    if 'image' not in response.headers.get('Content-Type', ''):
        print(f"Skipping {img_url}: URL does not contain an image")
        continue

    with open(img_path, 'wb') as f:
        f.write(response.content)

    try:
        img_features = extract_features(img_path, model)
    except UnidentifiedImageError:
        print(f"UnidentifiedImageError: Skipping file {img_path}")
        # Do not leave an undecodable file behind in the image folder.
        os.remove(img_path)
        continue

    features.append(img_features)
    image_paths.append(img_path)
    model_names.append(model_name)
|
|
|
|
|
# Stack the per-image vectors into one array (one row per image, assuming
# every extracted vector has the same length).
features = np.array(features)

# Fail early with a clear message instead of an opaque shape error from
# scikit-learn when no image could be downloaded and processed.
if features.size == 0:
    raise ValueError(
        "No image features were extracted; cannot build the nearest-neighbour index"
    )

# Cap the default neighbour count at the number of indexed images; otherwise
# a small index fits successfully but every later default query fails.
nbrs = NearestNeighbors(
    n_neighbors=min(5, len(features)),
    algorithm='ball_tree',
).fit(features)

# Persist the fitted index together with its index-aligned metadata.
joblib.dump(nbrs, 'nearest_neighbors_model.pkl')
np.save('image_features.npy', features)
np.save('image_paths.npy', image_paths)
np.save('model_names.npy', model_names)

# Reload from disk so the query code below works on exactly what was saved.
# NOTE(review): joblib.load and np.load(allow_pickle=True) unpickle data;
# only ever point them at files produced by this script.
nbrs = joblib.load('nearest_neighbors_model.pkl')
features = np.load('image_features.npy')
image_paths = np.load('image_paths.npy', allow_pickle=True)
model_names = np.load('model_names.npy', allow_pickle=True)
|
|
|
|
|
def get_recommendations(img, n_neighbors=5):
    """Return gallery entries for the indexed images most similar to ``img``.

    Args:
        img: PIL image to search with. ``None`` (for example a submission
            without an uploaded image) yields an empty result.
        n_neighbors: Maximum number of similar images to return. It is
            capped at the number of indexed images.

    Returns:
        A list of ``(PIL.Image, caption)`` tuples in the order returned by
        ``nbrs.kneighbors``; each caption has the form
        ``"<model name>, Distance: <distance>"``. The list is empty when
        ``img`` is ``None`` or when no neighbours can be requested.
    """
    if img is None:
        return []

    # Function-scope import: tempfile is only needed by this function.
    import tempfile

    # Never ask the index for more neighbours than it holds.
    k = min(n_neighbors, len(image_paths))
    if k < 1:
        return []

    # A per-call temporary directory keeps concurrent requests from
    # overwriting each other's input file and removes it afterwards.
    with tempfile.TemporaryDirectory() as tmp_dir:
        query_path = os.path.join(tmp_dir, "temp_input_image.png")
        img.save(query_path)
        img_features = extract_features(query_path, model)

    distances, indices = nbrs.kneighbors([img_features], n_neighbors=k)

    neighbor_ids = indices.flatten()
    recommended_images = [image_paths[idx] for idx in neighbor_ids]
    recommended_model_names = [model_names[idx] for idx in neighbor_ids]
    recommended_distances = distances.flatten()

    return [
        (Image.open(path), f'{name}, Distance: {dist:.2f}')
        for path, name, dist in zip(
            recommended_images, recommended_model_names, recommended_distances
        )
    ]
|
|
|
|
|
# Gradio front end: one PIL image in, a captioned gallery of similar images out.
interface = gr.Interface(
    # Page text
    title="Image Recommendation System",
    description="Upload an image and get similar images with their model names and distances.",
    # Wiring: the uploaded image is handed to get_recommendations and its
    # result is rendered in the gallery.
    fn=get_recommendations,
    inputs=gr.Image(type="pil"),
    outputs=gr.Gallery(label="Recommended Images"),
)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()
|
|