K00B404 commited on
Commit
d626bab
1 Parent(s): e4391fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -9
app.py CHANGED
@@ -37,17 +37,29 @@ def train_model(epochs):
37
  ])
38
 
39
  # Create dataset and dataloader
40
- class Pix2PixDataset(torch.utils.data.Dataset):
41
- def __init__(self, ds):
42
- self.ds = ds
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- def __len__(self):
45
- return len(self.ds["train"])
 
46
 
47
- def __getitem__(self, idx):
48
- original = Image.open(self.ds["train"][idx]['original_image']).convert('RGB')
49
- target = Image.open(self.ds["train"][idx]['target_image']).convert('RGB')
50
- return transform(original), transform(target)
51
 
52
  dataset = Pix2PixDataset(ds)
53
  dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
 
37
  ])
38
 
39
  # Create dataset and dataloader
40
+ # Create dataset and dataloader
41
+ class Pix2PixDataset(torch.utils.data.Dataset):
42
+ def __init__(self, ds):
43
+ self.originals = [x for x in ds["train"] if x['label'] == 'original']
44
+ self.targets = [x for x in ds["train"] if x['label'] == 'target']
45
+
46
+ # Ensure original and target images match by their index
47
+ assert len(self.originals) == len(self.targets), "Mismatch in number of original and target images."
48
+
49
+ def __len__(self):
50
+ return len(self.originals)
51
+
52
+ def __getitem__(self, idx):
53
+ # Load original and target images for the given index
54
+ original_img = self.originals[idx]['image']
55
+ target_img = self.targets[idx]['image']
56
 
57
+ # Apply the necessary transforms
58
+ original = Image.open(original_img).convert('RGB')
59
+ target = Image.open(target_img).convert('RGB')
60
 
61
+ # Return transformed original and target images
62
+ return transform(original), transform(target)
 
 
63
 
64
  dataset = Pix2PixDataset(ds)
65
  dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)