# Source: cfe-gen/src/generate_images.py — commit 634fc83 ("upload 3 files") by anindya-hf-2002
import os
import torch
from PIL import Image
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm
from src.models import ResUNetGenerator
# Dataset that loads image files as grayscale ('L' mode) images and returns each with its path
class ImageDataset(Dataset):
    """Dataset over a list of image files, each loaded as a grayscale ('L') image.

    Items are ``(image, path)`` pairs so callers can map every output back to
    the file it was produced from.
    """

    def __init__(self, image_paths, transform=None):
        """
        Args:
            image_paths: sequence of image file paths, indexed by position.
            transform: optional callable applied to each loaded PIL image
                (e.g. a torchvision transform). When None, the PIL image
                itself is returned.
        """
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as grayscale, apply the transform, return ``(image, path)``."""
        img_path = self.image_paths[idx]
        # Open inside a with-block so the file handle is closed deterministically
        # (important with many DataLoader workers); convert() loads the pixels
        # into a new image before the block exits.
        with Image.open(img_path) as img:
            image = img.convert('L')
        if self.transform is not None:
            image = self.transform(image)
        return image, img_path
# Convert a generator output tensor (assumed to lie in [-1, 1]) to an 8-bit image and write it to disk
def save_image(tensor, path):
    """Write a single channel-first image tensor to ``path`` as an 8-bit image.

    Args:
        tensor: 3-D tensor laid out as (channels, height, width). Values are
            mapped with ``(x * 0.5 + 0.5) * 255``, i.e. the input is assumed
            to be in [-1, 1]; anything outside that range is clipped.
        path: destination file; the format is chosen by PIL from the extension.
    """
    # detach().cpu() works for tensors on any device (CUDA, MPS, ...) and is a
    # no-op for tensors that are already on the CPU.
    array = tensor.detach().cpu().permute(1, 2, 0).numpy()
    array = (array * 0.5 + 0.5) * 255
    # Clip before the uint8 cast: out-of-range floats would otherwise wrap
    # around (or be undefined) instead of saturating at black/white.
    array = np.clip(array, 0, 255).astype(np.uint8)
    if array.shape[2] == 1:
        # Drop the singleton channel; a 2-D uint8 array is saved as mode 'L'.
        array = array[:, :, 0]
    Image.fromarray(array).save(path)
# Build a model, restore its weights from a checkpoint file and switch it to eval mode
def load_model(checkpoint_path, model_class, device):
    """Instantiate a model, restore its weights and put it in inference mode.

    Args:
        checkpoint_path: file whose ``torch.load`` result is passed directly to
            ``load_state_dict`` (i.e. it is expected to hold a state dict).
        model_class: zero-argument callable that builds the model.
        device: device the model is moved to before the weights are loaded.

    Returns:
        The model on ``device``, switched to eval mode.
    """
    net = model_class()
    net = net.to(device)
    # Deserialize onto the CPU first; load_state_dict copies the tensors into
    # the parameters, which already live on `device`.
    # NOTE(review): torch.load is pickle-based — only load trusted checkpoints
    # (consider weights_only=True if the installed torch supports it).
    weights = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    net.load_state_dict(weights)
    # Module.eval() returns the module itself.
    return net.eval()
def _list_images(folder):
    """Return sorted paths of the .png/.jpg/.jpeg files directly inside ``folder``."""
    # Lower-case the name so e.g. 'scan.PNG' is not silently skipped; sort so
    # the processing order is deterministic across filesystems.
    return [
        os.path.join(folder, fname)
        for fname in sorted(os.listdir(folder))
        if fname.lower().endswith(('.png', '.jpg', '.jpeg'))
    ]


def _translate_and_save(dataloader, generator, out_dir, desc, device):
    """Run ``generator`` over every batch and save each output under ``out_dir``.

    Each output keeps the basename of the input file it was produced from.
    """
    for batch, paths in tqdm(dataloader, desc=desc):
        translated_images = generator(batch.to(device))
        for img, path in zip(translated_images, paths):
            save_image(img, os.path.join(out_dir, os.path.basename(path)))


def generate_images(image_folder, g_NP_checkpoint, g_PN_checkpoint, output_dir='data/translated_images', batch_size=16):
    """Translate class-0 images to class 1 and class-1 images to class 0.

    Args:
        image_folder: directory with subfolders '0' (negative) and '1'
            (positive) holding .png/.jpg/.jpeg files.
        g_NP_checkpoint: weights for the negative -> positive generator.
        g_PN_checkpoint: weights for the positive -> negative generator.
        output_dir: images translated from '0' are written to
            ``output_dir/1`` and images translated from '1' to
            ``output_dir/0``, keeping their original file names.
        batch_size: number of images per forward pass.

    Raises:
        FileNotFoundError: if ``image_folder/0`` or ``image_folder/1`` is missing.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def build_generator():
        return ResUNetGenerator(gf=32, channels=1)

    g_NP = load_model(g_NP_checkpoint, build_generator, device)
    g_PN = load_model(g_PN_checkpoint, build_generator, device)

    out_neg = os.path.join(output_dir, '0')
    out_pos = os.path.join(output_dir, '1')
    os.makedirs(out_neg, exist_ok=True)
    os.makedirs(out_pos, exist_ok=True)

    # NOTE(review): inputs are normalised with mean=0.485/std=0.229, while
    # save_image maps outputs back with x * 0.5 + 0.5; confirm both match the
    # normalisation and output activation used during training.
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.485], std=[0.229])])

    loader_neg = DataLoader(
        ImageDataset(_list_images(os.path.join(image_folder, '0')), transform),
        batch_size=batch_size,
        shuffle=False,
    )
    loader_pos = DataLoader(
        ImageDataset(_list_images(os.path.join(image_folder, '1')), transform),
        batch_size=batch_size,
        shuffle=False,
    )

    with torch.no_grad():
        # Exactly one generator performs each translation. Chaining g_NP and
        # g_PN would be the cycle reconstruction (N -> P -> N), which writes
        # images of the source class into the target-class folder.
        _translate_and_save(loader_neg, g_NP, out_pos, "Converting N to P: ", device)
        _translate_and_save(loader_pos, g_PN, out_neg, "Converting P to N: ", device)
if __name__ == '__main__':
    # Build paths with os.path.join instead of backslash literals. The
    # literals only resolve on Windows, and a non-raw string such as
    # 'models\g_NP_best.ckpt' contains an invalid escape sequence
    # (SyntaxWarning on Python 3.12+). On Windows the resulting paths are identical.
    image_folder = os.path.join('data', 'rsna-pneumonia-dataset', 'train')
    g_NP_checkpoint = os.path.join('models', 'g_NP_best.ckpt')
    g_PN_checkpoint = os.path.join('models', 'g_PN_best.ckpt')
    generate_images(image_folder, g_NP_checkpoint, g_PN_checkpoint)