import torch
from torch import nn
from torch.utils.data import DataLoader

# Hyperparameters
image_size = (224, 224, 3)  # Adjust based on your data
epochs = 100                # Number of training epochs; adjust as needed

# Define the Generator Network
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # Define convolutional layers with appropriate filters and activations
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        # ... Add more convolutional layers as needed
        self.conv_final = nn.Conv2d(128, 3, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # Forward pass through the convolutional layers
        x = self.conv1(x)
        # ... Forward pass through remaining convolutional layers
        return torch.tanh(self.conv_final(x))  # Tanh for shadow intensity

# Define the Discriminator Network
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # Define convolutional layers with appropriate filters and activations
        self.conv1 = nn.Conv2d(6, 32, kernel_size=3, stride=1, padding=1)
        # ... Add more convolutional layers as needed
        self.linear = nn.Linear(128, 1)  # Final layer; sigmoid applied in forward()

    def forward(self, car, shadow):
        # Concatenate car and shadow images along the channel dimension
        x = torch.cat([car, shadow], dim=1)
        # Forward pass through the convolutional layers
        x = self.conv1(x)
        # ... Forward pass through remaining layers; pool/flatten to 128 features
        #     before the final linear layer
        return torch.sigmoid(self.linear(x))

# Create data loaders for training and validation data
# ... (Implement data loading logic using PyTorch's DataLoader)

# Create the models
generator = Generator()
discriminator = Discriminator()

# Define loss function and optimizers
criterion = nn.BCELoss()
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002)

# Training loop
for epoch in range(epochs):
    # Train the Discriminator
    # ... (Implement discriminator training logic with loss calculation and updates)

    # Train the Generator
    # ... (Implement generator training logic with loss calculation and updates)

    # Print training progress
    # ... (Print loss values or other metrics)
    pass  # Replace with the actual per-batch training steps (see sketch below)

# Save the trained generator
torch.save(generator.state_dict(), 'generator.pt')
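
# --- Illustrative only: a minimal sketch of one GAN training iteration ---
# Assumptions not present in the original code: the data loader yields
# (car, real_shadow) image batches, the discriminator returns a
# (batch_size, 1) probability, and the helper name `train_step` is hypothetical.
def train_step(car, real_shadow):
    batch_size = car.size(0)
    real_labels = torch.ones(batch_size, 1)
    fake_labels = torch.zeros(batch_size, 1)

    # Train the Discriminator on real and generated shadows
    d_optimizer.zero_grad()
    fake_shadow = generator(car)
    d_loss_real = criterion(discriminator(car, real_shadow), real_labels)
    d_loss_fake = criterion(discriminator(car, fake_shadow.detach()), fake_labels)
    d_loss = d_loss_real + d_loss_fake
    d_loss.backward()
    d_optimizer.step()

    # Train the Generator to fool the Discriminator
    g_optimizer.zero_grad()
    g_loss = criterion(discriminator(car, fake_shadow), real_labels)
    g_loss.backward()
    g_optimizer.step()

    return d_loss.item(), g_loss.item()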