skytnt commited on
Commit
0b4ff36
1 Parent(s): 8275d1e

Update anime_aesthetic.py

Browse files
Files changed (1) hide show
  1. anime_aesthetic.py +10 -3
anime_aesthetic.py CHANGED
@@ -216,11 +216,18 @@ def rescale_pad(image, output_size, random_pad=False):
216
  if h != output_size or w != output_size:
217
  r = min(output_size / h, output_size / w)
218
  new_h, new_w = int(h * r), int(w * r)
 
 
 
219
  ph = output_size - new_h
220
  pw = output_size - new_w
 
 
 
 
221
  image = transforms.functional.resize(image, [new_h, new_w])
222
  image = transforms.functional.pad(
223
- image, [pw // 2, ph // 2, pw // 2 + pw % 2, ph // 2 + ph % 2], random.uniform(0, 1) if random_pad else 0
224
  )
225
  return image
226
 
@@ -435,7 +442,7 @@ if __name__ == "__main__":
435
  parser.add_argument(
436
  "--data-split",
437
  type=float,
438
- default=0.9995,
439
  help="split rate for training and validation",
440
  )
441
 
@@ -486,7 +493,7 @@ if __name__ == "__main__":
486
  "--log-step", type=int, default=2, help="log training loss every n steps"
487
  )
488
  parser.add_argument(
489
- "--val-epoch", type=int, default=0.1, help="valid and save every n epoch"
490
  )
491
 
492
  opt = parser.parse_args()
 
216
  if h != output_size or w != output_size:
217
  r = min(output_size / h, output_size / w)
218
  new_h, new_w = int(h * r), int(w * r)
219
+ if random_pad:
220
+ r2 = random.uniform(0.9, 1)
221
+ new_h, new_w = int(new_h * r2), int(new_w * r2)
222
  ph = output_size - new_h
223
  pw = output_size - new_w
224
+ left = random.randint(0, pw) if random_pad else pw // 2
225
+ right = pw - left
226
+ top = random.randint(0, ph) if random_pad else ph // 2
227
+ bottom = ph - top
228
  image = transforms.functional.resize(image, [new_h, new_w])
229
  image = transforms.functional.pad(
230
+ image, [left, top, right, bottom], random.uniform(0, 1) if random_pad else 0
231
  )
232
  return image
233
 
 
442
  parser.add_argument(
443
  "--data-split",
444
  type=float,
445
+ default=0.9999,
446
  help="split rate for training and validation",
447
  )
448
 
 
493
  "--log-step", type=int, default=2, help="log training loss every n steps"
494
  )
495
  parser.add_argument(
496
+ "--val-epoch", type=int, default=0.025, help="valid and save every n epoch"
497
  )
498
 
499
  opt = parser.parse_args()