K00B404 commited on
Commit
d71e34d
1 Parent(s): 23e8de2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -12
app.py CHANGED
@@ -27,7 +27,7 @@ model_repo_id = "K00B404/pix2pix_flux"
27
 
28
  # Create dataset and dataloader
29
  class Pix2PixDataset(torch.utils.data.Dataset):
30
- def __init__(self, ds):
31
  # Filter dataset for 'original' (label = 0) and 'target' (label = 1) images
32
  self.originals = [x for x in ds["train"] if x['label'] == 0]
33
  self.targets = [x for x in ds["train"] if x['label'] == 1]
@@ -39,35 +39,34 @@ class Pix2PixDataset(torch.utils.data.Dataset):
39
  print(f"Number of original images: {len(self.originals)}")
40
  print(f"Number of target images: {len(self.targets)}")
41
 
 
 
42
  def __len__(self):
43
  return len(self.originals)
44
 
45
  def __getitem__(self, idx):
46
- # Directly use the 'image' object without loading via Image.open()
47
  original_img = self.originals[idx]['image']
48
  target_img = self.targets[idx]['image']
49
 
50
- # Apply the necessary transforms
51
- original = original_img.convert('RGB')
52
- target = target_img.convert('RGB')
53
 
54
- # Return transformed original and target images
55
- return transform(original), transform(target)
56
 
57
  # Training function
58
  def train_model(epochs):
59
  # Load the dataset
60
  ds = load_dataset(dataset_id)
61
  print(f"ds{ds}")
62
- # Transform function to resize and convert to tensor
63
  transform = transforms.Compose([
64
  transforms.Resize((IMG_SIZE, IMG_SIZE)),
65
  transforms.ToTensor(),
66
  ])
67
-
68
-
69
- dataset = Pix2PixDataset(ds)
70
- print(f"dataset{dataset}")
71
  dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
72
 
73
  # Initialize model, loss function, and optimizer
 
27
 
28
  # Create dataset and dataloader
29
  class Pix2PixDataset(torch.utils.data.Dataset):
30
+ def __init__(self, ds, transform):
31
  # Filter dataset for 'original' (label = 0) and 'target' (label = 1) images
32
  self.originals = [x for x in ds["train"] if x['label'] == 0]
33
  self.targets = [x for x in ds["train"] if x['label'] == 1]
 
39
  print(f"Number of original images: {len(self.originals)}")
40
  print(f"Number of target images: {len(self.targets)}")
41
 
42
+ self.transform = transform # Store the transform
43
+
44
  def __len__(self):
45
  return len(self.originals)
46
 
47
  def __getitem__(self, idx):
 
48
  original_img = self.originals[idx]['image']
49
  target_img = self.targets[idx]['image']
50
 
51
+ original = original_img.convert('RGB') # Convert to RGB if needed
52
+ target = target_img.convert('RGB') # Convert to RGB if needed
 
53
 
54
+ # Apply the necessary transforms
55
+ return self.transform(original), self.transform(target)
56
 
57
  # Training function
58
  def train_model(epochs):
59
  # Load the dataset
60
  ds = load_dataset(dataset_id)
61
  print(f"ds{ds}")
62
+ # Create the transform function outside of the dataset class
63
  transform = transforms.Compose([
64
  transforms.Resize((IMG_SIZE, IMG_SIZE)),
65
  transforms.ToTensor(),
66
  ])
67
+
68
+ # Create dataset and dataloader
69
+ dataset = Pix2PixDataset(ds, transform)
 
70
  dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
71
 
72
  # Initialize model, loss function, and optimizer