Rohil Bansal committed on
Commit
09ae4e4
1 Parent(s): d36d296

New training

Browse files
.gitattributes CHANGED
@@ -1 +1,2 @@
1
  checkpoints/latest_checkpoint.pth.tar filter=lfs diff=lfs merge=lfs -text
 
 
1
  checkpoints/latest_checkpoint.pth.tar filter=lfs diff=lfs merge=lfs -text
2
+ checkpoints/latest_checkpoint1.pth.tar filter=lfs diff=lfs merge=lfs -text
checkpoints/latest_checkpoint.pth.tar CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c8de65df3e4879e931cdf7f3de2fdc3d05298c0e955b39b2281627f36e27fcff
3
- size 686252474
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b32b1f4363aad01e662d468989e7e0b8f41afec20ffcbf1e87b6a6147454cbd
3
+ size 686253114
checkpoints/latest_checkpoint1.pth.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c8de65df3e4879e931cdf7f3de2fdc3d05298c0e955b39b2281627f36e27fcff
3
+ size 686252474
colorizer_pipeline.py CHANGED
@@ -230,13 +230,22 @@ def visualize_results(epoch, generator, train_loader, device):
230
  generator.train()
231
 
232
  def save_checkpoint(state, filename="checkpoint.pth.tar"):
233
- torch.save(state, filename)
 
 
 
 
 
 
 
 
234
  mlflow.log_artifact(filename)
235
 
236
- def load_checkpoint(filename, generator, discriminator, optimizerG, optimizerD):
237
  if os.path.isfile(filename):
238
  print(f"Loading checkpoint '{filename}'")
239
- checkpoint = torch.load(filename)
 
240
  start_epoch = checkpoint['epoch'] + 1
241
  generator.load_state_dict(checkpoint['generator_state_dict'])
242
  discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
@@ -248,6 +257,11 @@ def load_checkpoint(filename, generator, discriminator, optimizerG, optimizerD):
248
  print(f"No checkpoint found at '{filename}'")
249
  return 0
250
 
 
 
 
 
 
251
  # Training function
252
  def train(generator, discriminator, train_loader, num_epochs, device, lr=0.0002, beta1=0.5):
253
  criterion = nn.BCEWithLogitsLoss()
@@ -256,12 +270,8 @@ def train(generator, discriminator, train_loader, num_epochs, device, lr=0.0002,
256
  optimizerG = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
257
  optimizerD = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
258
 
259
- checkpoint_dir = "checkpoints"
260
- os.makedirs(checkpoint_dir, exist_ok=True)
261
- os.makedirs("results", exist_ok=True)
262
-
263
  checkpoint_path = os.path.join(checkpoint_dir, "latest_checkpoint.pth.tar")
264
- start_epoch = load_checkpoint(checkpoint_path, generator, discriminator, optimizerG, optimizerD)
265
 
266
  experiment_id = setup_mlflow()
267
  with mlflow.start_run(experiment_id=experiment_id, run_name="training_run") as run:
@@ -270,7 +280,7 @@ def train(generator, discriminator, train_loader, num_epochs, device, lr=0.0002,
270
  generator.train()
271
  discriminator.train()
272
 
273
- num_iterations = 2
274
  pbar = tqdm(enumerate(islice(train_loader, num_iterations)), total=num_iterations, desc=f"Epoch {epoch+1}/{num_epochs}")
275
 
276
  for i, (real_L, real_AB) in pbar:
 
230
  generator.train()
231
 
232
  def save_checkpoint(state, filename="checkpoint.pth.tar"):
233
+ # Only save the necessary state
234
+ save_state = {
235
+ 'epoch': state['epoch'],
236
+ 'generator_state_dict': state['generator_state_dict'],
237
+ 'discriminator_state_dict': state['discriminator_state_dict'],
238
+ 'optimizerG_state_dict': state['optimizerG_state_dict'],
239
+ 'optimizerD_state_dict': state['optimizerD_state_dict'],
240
+ }
241
+ torch.save(save_state, filename)
242
  mlflow.log_artifact(filename)
243
 
244
+ def load_checkpoint(filename, generator, discriminator, optimizerG, optimizerD, device):
245
  if os.path.isfile(filename):
246
  print(f"Loading checkpoint '{filename}'")
247
+ # Use weights_only=True for safer loading
248
+ checkpoint = torch.load(filename, map_location=device, weights_only=True)
249
  start_epoch = checkpoint['epoch'] + 1
250
  generator.load_state_dict(checkpoint['generator_state_dict'])
251
  discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
 
257
  print(f"No checkpoint found at '{filename}'")
258
  return 0
259
 
260
+ # Global variables
261
+ checkpoint_dir = "checkpoints"
262
+ os.makedirs(checkpoint_dir, exist_ok=True)
263
+ os.makedirs("results", exist_ok=True)
264
+
265
  # Training function
266
  def train(generator, discriminator, train_loader, num_epochs, device, lr=0.0002, beta1=0.5):
267
  criterion = nn.BCEWithLogitsLoss()
 
270
  optimizerG = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
271
  optimizerD = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
272
 
 
 
 
 
273
  checkpoint_path = os.path.join(checkpoint_dir, "latest_checkpoint.pth.tar")
274
+ start_epoch = load_checkpoint(checkpoint_path, generator, discriminator, optimizerG, optimizerD, device)
275
 
276
  experiment_id = setup_mlflow()
277
  with mlflow.start_run(experiment_id=experiment_id, run_name="training_run") as run:
 
280
  generator.train()
281
  discriminator.train()
282
 
283
+ num_iterations = 2000
284
  pbar = tqdm(enumerate(islice(train_loader, num_iterations)), total=num_iterations, desc=f"Epoch {epoch+1}/{num_epochs}")
285
 
286
  for i, (real_L, real_AB) in pbar:
convert_checkpoint.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+
4
def load_and_save_checkpoint(input_filename, output_filename, device):
    """Re-save a training checkpoint keeping only the resumable state.

    Loads ``input_filename``, copies the epoch counter plus the generator,
    discriminator and both optimizer state dicts into a new dict, and writes
    that dict to ``output_filename``. Any other entries in the loaded
    checkpoint are dropped.

    Args:
        input_filename: Path of the checkpoint to read.
        output_filename: Path the trimmed checkpoint is written to. Its
            parent directory is created if it does not exist yet.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        True if the checkpoint was converted and written, False if
        ``input_filename`` does not exist (nothing is written in that case).

    Raises:
        KeyError: If the loaded checkpoint lacks one of the required keys.
    """
    required_keys = (
        'epoch',
        'generator_state_dict',
        'discriminator_state_dict',
        'optimizerG_state_dict',
        'optimizerD_state_dict',
    )

    if not os.path.isfile(input_filename):
        print(f"No checkpoint found at '{input_filename}'")
        return False

    print(f"Loading checkpoint '{input_filename}'")
    # NOTE(review): torch.load unpickles the file, which can execute code
    # embedded in it -- only run this on checkpoints from a trusted source.
    checkpoint = torch.load(input_filename, map_location=device)

    # Report every missing key at once instead of failing on the first one.
    missing = [key for key in required_keys if key not in checkpoint]
    if missing:
        raise KeyError(
            f"Checkpoint '{input_filename}' is missing keys: {missing}"
        )

    # Extract only the necessary state
    save_state = {key: checkpoint[key] for key in required_keys}

    # torch.save does not create intermediate directories itself.
    output_dir = os.path.dirname(output_filename)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Save the checkpoint
    torch.save(save_state, output_filename)
    print(f"Saved checkpoint to '{output_filename}'")
    return True
23
+
24
if __name__ == "__main__":
    # Script entry point: pick CUDA when torch reports it as available,
    # otherwise fall back to the CPU, then convert the latest checkpoint.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device: {device}")

    load_and_save_checkpoint(
        "checkpoints/latest_checkpoint.pth.tar",
        "checkpoints/converted_checkpoint.pth.tar",
        device,
    )