Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -27,7 +27,7 @@ model_repo_id = "K00B404/pix2pix_flux"
|
|
27 |
|
28 |
# Create dataset and dataloader
|
29 |
class Pix2PixDataset(torch.utils.data.Dataset):
|
30 |
-
def __init__(self, ds):
|
31 |
# Filter dataset for 'original' (label = 0) and 'target' (label = 1) images
|
32 |
self.originals = [x for x in ds["train"] if x['label'] == 0]
|
33 |
self.targets = [x for x in ds["train"] if x['label'] == 1]
|
@@ -39,35 +39,34 @@ class Pix2PixDataset(torch.utils.data.Dataset):
|
|
39 |
print(f"Number of original images: {len(self.originals)}")
|
40 |
print(f"Number of target images: {len(self.targets)}")
|
41 |
|
|
|
|
|
42 |
def __len__(self):
|
43 |
return len(self.originals)
|
44 |
|
45 |
def __getitem__(self, idx):
|
46 |
-
# Directly use the 'image' object without loading via Image.open()
|
47 |
original_img = self.originals[idx]['image']
|
48 |
target_img = self.targets[idx]['image']
|
49 |
|
50 |
-
#
|
51 |
-
|
52 |
-
target = target_img.convert('RGB')
|
53 |
|
54 |
-
#
|
55 |
-
return transform(original), transform(target)
|
56 |
|
57 |
# Training function
|
58 |
def train_model(epochs):
|
59 |
# Load the dataset
|
60 |
ds = load_dataset(dataset_id)
|
61 |
print(f"ds{ds}")
|
62 |
-
#
|
63 |
transform = transforms.Compose([
|
64 |
transforms.Resize((IMG_SIZE, IMG_SIZE)),
|
65 |
transforms.ToTensor(),
|
66 |
])
|
67 |
-
|
68 |
-
|
69 |
-
dataset = Pix2PixDataset(ds)
|
70 |
-
print(f"dataset{dataset}")
|
71 |
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
72 |
|
73 |
# Initialize model, loss function, and optimizer
|
|
|
27 |
|
28 |
# Create dataset and dataloader
|
29 |
class Pix2PixDataset(torch.utils.data.Dataset):
|
30 |
+
def __init__(self, ds, transform):
|
31 |
# Filter dataset for 'original' (label = 0) and 'target' (label = 1) images
|
32 |
self.originals = [x for x in ds["train"] if x['label'] == 0]
|
33 |
self.targets = [x for x in ds["train"] if x['label'] == 1]
|
|
|
39 |
print(f"Number of original images: {len(self.originals)}")
|
40 |
print(f"Number of target images: {len(self.targets)}")
|
41 |
|
42 |
+
self.transform = transform # Store the transform
|
43 |
+
|
44 |
def __len__(self):
|
45 |
return len(self.originals)
|
46 |
|
47 |
def __getitem__(self, idx):
|
|
|
48 |
original_img = self.originals[idx]['image']
|
49 |
target_img = self.targets[idx]['image']
|
50 |
|
51 |
+
original = original_img.convert('RGB') # Convert to RGB if needed
|
52 |
+
target = target_img.convert('RGB') # Convert to RGB if needed
|
|
|
53 |
|
54 |
+
# Apply the necessary transforms
|
55 |
+
return self.transform(original), self.transform(target)
|
56 |
|
57 |
# Training function
|
58 |
def train_model(epochs):
|
59 |
# Load the dataset
|
60 |
ds = load_dataset(dataset_id)
|
61 |
print(f"ds{ds}")
|
62 |
+
# Create the transform function outside of the dataset class
|
63 |
transform = transforms.Compose([
|
64 |
transforms.Resize((IMG_SIZE, IMG_SIZE)),
|
65 |
transforms.ToTensor(),
|
66 |
])
|
67 |
+
|
68 |
+
# Create dataset and dataloader
|
69 |
+
dataset = Pix2PixDataset(ds, transform)
|
|
|
70 |
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
71 |
|
72 |
# Initialize model, loss function, and optimizer
|