K00B404 commited on
Commit
5010115
1 Parent(s): d626bab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -13
app.py CHANGED
@@ -25,19 +25,6 @@ LR = 0.0002
25
  dataset_id = "K00B404/pix2pix_flux_set"
26
  model_repo_id = "K00B404/pix2pix_flux"
27
 
28
- # Training function
29
- def train_model(epochs):
30
- # Load the dataset
31
- ds = load_dataset(dataset_id)
32
-
33
- # Transform function to resize and convert to tensor
34
- transform = transforms.Compose([
35
- transforms.Resize((IMG_SIZE, IMG_SIZE)),
36
- transforms.ToTensor(),
37
- ])
38
-
39
- # Create dataset and dataloader
40
- # Create dataset and dataloader
41
  class Pix2PixDataset(torch.utils.data.Dataset):
42
  def __init__(self, ds):
43
  self.originals = [x for x in ds["train"] if x['label'] == 'original']
@@ -60,6 +47,18 @@ class Pix2PixDataset(torch.utils.data.Dataset):
60
 
61
  # Return transformed original and target images
62
  return transform(original), transform(target)
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  dataset = Pix2PixDataset(ds)
65
  dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
 
25
  dataset_id = "K00B404/pix2pix_flux_set"
26
  model_repo_id = "K00B404/pix2pix_flux"
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  class Pix2PixDataset(torch.utils.data.Dataset):
29
  def __init__(self, ds):
30
  self.originals = [x for x in ds["train"] if x['label'] == 'original']
 
47
 
48
  # Return transformed original and target images
49
  return transform(original), transform(target)
50
+
51
+ # Training function
52
+ def train_model(epochs):
53
+ # Load the dataset
54
+ ds = load_dataset(dataset_id)
55
+
56
+ # Transform function to resize and convert to tensor
57
+ transform = transforms.Compose([
58
+ transforms.Resize((IMG_SIZE, IMG_SIZE)),
59
+ transforms.ToTensor(),
60
+ ])
61
+
62
 
63
  dataset = Pix2PixDataset(ds)
64
  dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)