vobecant commited on
Commit
435cc18
1 Parent(s): 0c4f2ac

Initial commit.

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -136,10 +136,14 @@ model, window_size, window_stride, im_size = create_model()
136
 
137
 
138
  def get_transformations():
139
- return transforms.Compose([
140
- transforms.ToTensor(),
141
- transforms.Resize(im_size),
142
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
 
 
 
 
143
 
144
 
145
  def predict(input_img, cs_mapping):
 
136
 
137
 
138
  def get_transformations():
139
+ trans_list = [transforms.ToTensor()]
140
+
141
+ if im_size != 1024:
142
+ trans_list.append(transforms.Resize(im_size))
143
+
144
+ trans_list.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
145
+
146
+ return transforms.Compose(trans_list)
147
 
148
 
149
  def predict(input_img, cs_mapping):