azizkaroui's picture
Fcommit
760312f
raw
history blame contribute delete
No virus
3.08 kB
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
# model architecture
class ImageEnhancementModel(nn.Module):
    """Small fully convolutional network mapping a 3-channel image to a 3-channel image.

    Layout: conv(3->64) -> ReLU -> conv(64->64) -> ReLU -> conv(64->3).
    Every convolution uses a 3x3 kernel with padding 1, so the spatial size
    of the output matches the input. The last layer has no activation, so
    output values are not bounded to any range.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are part of the saved state_dict keys
        # (conv1.weight, conv1.bias, ...), so they must stay stable.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(64, 3, kernel_size=3, padding=1)

    def forward(self, x):
        """Run ``x`` through the five stages in order and return the result.

        ``x`` must have 3 channels (presumably a batch shaped (N, 3, H, W)).
        """
        pipeline = (self.conv1, self.relu1, self.conv2, self.relu2, self.conv3)
        out = x
        for stage in pipeline:
            out = stage(out)
        return out
class CustomDataset(Dataset):
    """Dataset of RGB images loaded from every regular file in one directory.

    File names are sorted so that index ``i`` refers to the same file on every
    run. ``os.listdir`` itself returns entries in arbitrary order, which would
    make index-based pairing with another directory unreliable.

    Args:
        data_dir: Directory whose regular files are all treated as images
            (subdirectories are skipped).
        transform: Callable applied to each RGB PIL image. Defaults to
            ``transforms.Compose([transforms.ToTensor()])``, which produces a
            float tensor of shape (3, H, W) with values in [0, 1].
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.image_files = sorted(
            name
            for name in os.listdir(data_dir)
            if os.path.isfile(os.path.join(data_dir, name))
        )
        if transform is None:
            transform = transforms.Compose([transforms.ToTensor()])
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx``, convert it to RGB and apply ``self.transform``."""
        img_path = os.path.join(self.data_dir, self.image_files[idx])
        # The context manager closes the file handle. convert('RGB') always
        # returns a fully loaded copy (even for RGB input), so the result
        # stays valid after the file is closed.
        with Image.open(img_path) as img:
            rgb = img.convert('RGB')
        return self.transform(rgb)
class PairedImageDataset(Dataset):
    """Pair each input ("before") image with its target ("after") image.

    Pairs are formed by position after sorting each dataset's ``image_files``,
    so the i-th input file in sorted order is matched with the i-th target
    file in sorted order, regardless of the order the datasets store them in.

    Args:
        inputs: Dataset exposing ``image_files`` and returning input tensors.
        targets: Dataset exposing ``image_files`` and returning target tensors.

    Raises:
        ValueError: If the two datasets hold a different number of images.
    """

    def __init__(self, inputs, targets):
        if len(inputs) != len(targets):
            raise ValueError(
                f'found {len(inputs)} input images but {len(targets)} target images'
            )
        self.inputs = inputs
        self.targets = targets
        # Index permutations that visit each dataset in sorted file-name order.
        self._input_order = sorted(range(len(inputs)), key=inputs.image_files.__getitem__)
        self._target_order = sorted(range(len(targets)), key=targets.image_files.__getitem__)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return (
            self.inputs[self._input_order[idx]],
            self.targets[self._target_order[idx]],
        )


# Hyperparameters
batch_size = 8
learning_rate = 0.001
num_epochs = 50

model = ImageEnhancementModel()

# Loss function and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Data: "before" images are the inputs, "after" images are the targets.
train_dataset = CustomDataset(data_dir='before')
target_dataset = CustomDataset(data_dir='after')
paired_dataset = PairedImageDataset(train_dataset, target_dataset)
# Each batch is an (inputs, targets) pair that stays aligned under shuffling.
# NOTE: default collation stacks tensors, so all images in a batch must share
# the same height and width.
train_loader = DataLoader(dataset=paired_dataset, batch_size=batch_size, shuffle=True)

# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, targets in train_loader:
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch mean is exact with a short last batch.
        running_loss += loss.item() * inputs.size(0)
    epoch_loss = running_loss / len(paired_dataset)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')

# Save the trained model
torch.save(model.state_dict(), 'image_enhancement_model.pth')
# Inference (enhance a single image)
model.eval()  # Set the model to evaluation mode

# Load the input image as RGB so it matches the model's 3 input channels,
# and close the file handle once the converted copy exists.
with Image.open('testb.jpg') as raw_image:
    rgb_image = raw_image.convert('RGB')
# Same preprocessing as training (ToTensor -> values in [0, 1]); add a batch dim.
input_image = train_dataset.transform(rgb_image).unsqueeze(0)

# No gradients are needed for inference.
with torch.no_grad():
    enhanced_image = model(input_image)

# The model was trained against [0, 1] targets, so map [0, 1] -> [0, 255].
# Clamp first: the final conv is unbounded, and uint8 casting would wrap
# out-of-range values. squeeze(0) removes only the batch dimension.
output_image = (
    enhanced_image.squeeze(0)
    .clamp(0.0, 1.0)
    .mul(255.0)
    .round()
    .to(torch.uint8)
    .permute(1, 2, 0)  # CHW -> HWC for PIL
    .cpu()
    .numpy()
)
Image.fromarray(output_image).save('enhanced_image.jpg')