---
license: mit
library_name: keras
pipeline_tag: image-classification
tags:
  - image classification
  - embeddings
---

An embedding model for classifying images as FLUX-generated images or non-FLUX photographs. It produces 128-dimensional embeddings, which can be passed to a separate downstream classifier.
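The card does not say which downstream classifier to use. As a minimal sketch, assuming labeled embeddings are already available, a nearest-centroid classifier is one of the simplest options. The arrays below are synthetic stand-ins, not real model outputs; `fit_centroids` and `predict_labels` are hypothetical names, and the convention 0 = photograph, 1 = FLUX-generated is an assumption.

```python
import numpy as np

def fit_centroids(embeddings, labels):
    """Mean embedding per class (assumed labels: 0 = photograph, 1 = FLUX-generated)."""
    classes = np.unique(labels)
    centroids = np.stack([embeddings[labels == c].mean(axis=0) for c in classes])
    return classes, centroids

def predict_labels(embeddings, classes, centroids):
    """Assign each embedding to the class whose centroid is nearest (Euclidean distance)."""
    # (n_samples, 1, 128) - (1, n_classes, 128) -> distances of shape (n_samples, n_classes)
    dists = np.linalg.norm(embeddings[:, None, :] - centroids[None, :, :], axis=-1)
    return classes[np.argmin(dists, axis=1)]

# Synthetic stand-ins for 128-dimensional embeddings (NOT real model outputs)
rng = np.random.default_rng(0)
photos = rng.normal(loc=-1.0, scale=0.5, size=(50, 128))
flux = rng.normal(loc=1.0, scale=0.5, size=(50, 128))
X = np.vstack([photos, flux])
y = np.array([0] * 50 + [1] * 50)

classes, centroids = fit_centroids(X, y)
pred = predict_labels(X, classes, centroids)
print("training accuracy:", (pred == y).mean())
```

With real data, each row of `X` would be one image's embedding from the model, and accuracy should be measured on held-out images rather than on the training set. Logistic regression or a small dense network are common alternatives to nearest-centroid.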

The model takes Fourier-transformed images of size 512x512 as input and produces a 128-length output vector for each image. The embeddings are created in the following steps:

  1. Convert the images to grayscale and resize them to 512x512.
  2. Apply a 2D Fourier transform to each image and keep the magnitude of each coefficient.
  3. Pass the transformed images to the model's `predict` method.
  4. Use the resulting 128-length vectors as features for a classification model.
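Step 2 keeps only the magnitude of the 2D FFT and discards the phase. The sketch below uses a small random array in place of a real 512x512 image and checks two standard properties of the magnitude spectrum of a real-valued image: the `[0, 0]` (DC) coefficient equals the sum of all pixel values, and the spectrum is point-symmetric, so `|F[u, v]| == |F[-u, -v]|` with indices taken modulo the image size. It uses `numpy.fft.fft2` in place of the `scipy.fftpack.fft2` call in the card's code; both compute the unnormalized 2D DFT.

```python
import numpy as np

def fourier_magnitude(gray):
    """Magnitude of the unnormalized 2D DFT, as in step 2 (no fftshift, no log scaling)."""
    return np.abs(np.fft.fft2(np.asarray(gray, dtype=np.float64)))

rng = np.random.default_rng(42)
gray = rng.integers(0, 256, size=(8, 8)).astype(np.float64)  # small stand-in for a 512x512 image
mag = fourier_magnitude(gray)

# DC term: the [0, 0] coefficient is the sum of all pixel values
print(np.isclose(mag[0, 0], gray.sum()))

# Point symmetry for real input: |F[u, v]| == |F[-u mod N, -v mod M]|
flipped = np.roll(mag[::-1, ::-1], shift=(1, 1), axis=(0, 1))
print(np.allclose(mag, flipped))
```

Because nothing is normalized, the DC term of a 512x512 8-bit grayscale image can be as large as 512 * 512 * 255 = 66,846,720, so input values span a wide range. The card's code applies no rescaling, so images embedded at inference time should be prepared the same way.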

The following code performs the preprocessing and calls `predict` to compute the embeddings for classification:

```python
# Load an image and apply the Fourier transform

import numpy as np
from PIL import Image
from scipy.fftpack import fft2
from tensorflow.keras.models import load_model

# Apply a 2D Fourier transform and keep only the magnitude spectrum
def apply_fourier_transform(image):
    image = np.array(image)
    fft_image = fft2(image)
    return np.abs(fft_image)

# Load an image, convert it to grayscale, resize it to 512x512, and apply the Fourier transform
def preprocess_image(image_path):
    try:
        image = Image.open(image_path).convert('L')
        image = image.resize((512, 512))
        image = apply_fourier_transform(image)
        image = np.expand_dims(image, axis=-1)  # Add channel dimension: (512, 512, 1)
        image = np.expand_dims(image, axis=0)   # Add batch dimension: (1, 512, 512, 1)
        return image
    except Exception as e:
        print(f"Error processing image {image_path}: {e}")
        return None

# Load the embedding model and calculate embeddings for one image
def calculate_embeddings(image_path, model_path='embedding_model.keras'):
    # Load the trained model; its output is already the 128-dimensional embedding,
    # so no layers need to be removed
    model = load_model(model_path)

    # Preprocess the image; stop if it could not be loaded
    preprocessed_image = preprocess_image(image_path)
    if preprocessed_image is None:
        return None

    # Calculate embeddings (expected shape: (1, 128))
    embeddings = model.predict(preprocessed_image)

    return embeddings


embeddings = calculate_embeddings('filename.jpg')
if embeddings is not None:
    print(embeddings.shape)
```
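The code above embeds one image per `predict` call. When embedding many images, a batch can be assembled first and passed to `predict` in a single call. This is a minimal sketch, assuming the images have already been converted to 512x512 grayscale arrays as in `preprocess_image`; `to_model_input` is a hypothetical helper, and `numpy.fft.fft2` stands in for `scipy.fftpack.fft2` (both compute the unnormalized 2D DFT).

```python
import numpy as np

def to_model_input(gray_images):
    """Stack 512x512 grayscale arrays into a (N, 512, 512, 1) batch of Fourier magnitudes.

    Hypothetical helper: assumes each image is already grayscale and resized to 512x512.
    """
    batch = []
    for img in gray_images:
        img = np.asarray(img, dtype=np.float64)
        if img.shape != (512, 512):
            raise ValueError(f"expected a 512x512 grayscale image, got shape {img.shape}")
        # Magnitude of the 2D discrete Fourier transform (no shift, no log scaling)
        batch.append(np.abs(np.fft.fft2(img)))
    # (N, 512, 512) -> (N, 512, 512, 1): add the trailing channel axis
    return np.stack(batch)[..., np.newaxis]

rng = np.random.default_rng(0)
images = [rng.integers(0, 256, size=(512, 512)) for _ in range(3)]
x = to_model_input(images)
print(x.shape)
```

With a loaded model, `model.predict(x)` would then be expected to return one 128-length embedding per image, that is, an array of shape `(3, 128)` for this batch.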