sunshangquan commited on
Commit
629756b
1 Parent(s): 30e3fae

commit from ssq

Browse files
Files changed (4) hide show
  1. app.py +13 -10
  2. examples/example.jpeg +0 -0
  3. git.bash → git.sh +0 -0
  4. requirements.txt +1 -0
app.py CHANGED
@@ -3,16 +3,16 @@ import gradio as gr
3
  import numpy as np
4
  import torch.nn.functional as F
5
 
6
- from Allweather.util import load_img
7
  from basicsr.models.archs.histoformer_arch import Histoformer
8
 
9
  model_restoration = Histoformer()
10
- checkpoint = torch.load("Allweather/pretrained_models/net_g_real.pth")
11
- model_restoration.load_state_dict(checkpoint['params'])
12
- model_restoration.cuda()
13
- model_restoration = nn.DataParallel(model_restoration)
14
 
15
- def preprocess(file_, factor = 8):
 
16
  img = np.float32(load_img(file_))/255.
17
  img = torch.from_numpy(img).permute(2,0,1)
18
  input_ = img.unsqueeze(0).cuda()
@@ -24,12 +24,15 @@ def preprocess(file_, factor = 8):
24
  padw = W-w if w%factor!=0 else 0
25
  input_ = F.pad(input_, (0,padw,0,padh), 'reflect')
26
  return input_
 
 
 
 
 
 
27
 
28
- def predict(input_img):
29
- prediction = model_restoration(preprocess(input_img))
30
- return input_img, prediction
31
  example_images = [
32
- "examples/lemur.jpg",
33
  ]
34
  gradio_app = gr.Interface(
35
  predict,
 
3
  import numpy as np
4
  import torch.nn.functional as F
5
 
6
+ from Allweather.util import load_img, save_img
7
  from basicsr.models.archs.histoformer_arch import Histoformer
8
 
9
  model_restoration = Histoformer()
10
+ model, transform = Histoformer.create_model_and_transforms()
11
+ model = model.to(device)
12
+ model.eval()
 
13
 
14
+ factor = 8
15
+ def predict(input_img):
16
  img = np.float32(load_img(file_))/255.
17
  img = torch.from_numpy(img).permute(2,0,1)
18
  input_ = img.unsqueeze(0).cuda()
 
24
  padw = W-w if w%factor!=0 else 0
25
  input_ = F.pad(input_, (0,padw,0,padh), 'reflect')
26
  return input_
27
+ prediction = model_restoration(input_)
28
+ output_path = "inverse_depth_map.png"
29
+ restored = restored[:,:,:h,:w]
30
+ restored = torch.clamp(restored,0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()
31
+
32
+ save_img((os.path.join(result_dir, os.path.splitext(os.path.split(file_)[-1])[0]+'.png')), img_as_ubyte(restored))
33
 
 
 
 
34
  example_images = [
35
+ "examples/example.jpeg",
36
  ]
37
  gradio_app = gr.Interface(
38
  predict,
examples/example.jpeg ADDED
git.bash → git.sh RENAMED
File without changes
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ torch