File size: 3,583 Bytes
e2cf2b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import pickle
import imageio
import numpy as np
import scipy.interpolate
import torch
from tqdm import tqdm
import gradio as gr 
from huggingface_hub import hf_hub_download

def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True, to_numpy=True):
    """Tile a batch of images into a single grid image.

    Args:
        img: 4-D tensor unpacked as (N, C, H, W). N must equal grid_w * grid_h.
        grid_w: number of grid columns; when None it is derived as N // grid_h.
        grid_h: number of grid rows.
        float_to_uint8: if True, map values with x * 127.5 + 128, clamp to
            [0, 255] and cast to torch.uint8 (so -1 -> 0 and 1 -> 255).
        chw_to_hwc: if True, move the channel axis last.
        to_numpy: if True, return a numpy array (moved to CPU first);
            otherwise return the torch tensor.

    Returns:
        The tiled image, shaped (C, grid_h*H, grid_w*W), or
        (grid_h*H, grid_w*W, C) when chw_to_hwc is set.

    Raises:
        AssertionError: if N != grid_w * grid_h.
    """
    n, c, h, w = img.shape
    cols = grid_w if grid_w is not None else n // grid_h
    assert n == cols * grid_h
    if float_to_uint8:
        # Linear rescale followed by saturation into the uint8 range.
        img = torch.clamp(img * 127.5 + 128, 0, 255).to(torch.uint8)
    # (rows, cols, C, H, W) -> (C, rows, H, cols, W) -> (C, rows*H, cols*W)
    tiled = img.reshape(grid_h, cols, c, h, w).permute(2, 0, 3, 1, 4).reshape(c, grid_h * h, cols * w)
    if chw_to_hwc:
        tiled = tiled.permute(1, 2, 0)
    return tiled.cpu().numpy() if to_numpy else tiled




# Fetch the pretrained generator checkpoint from the Hugging Face Hub and keep
# its 'G_ema' entry as the module-level generator used by predict().
# hf_hub_download takes the repository id and the file name as separate
# arguments; passing them joined as a single path leaves `filename` missing.
# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file.
# Only load checkpoints from a repository you trust.
network_pkl = hf_hub_download(repo_id='SerdarHelli/BrainMRIGAN', filename='braingan-400.pkl')
with open(network_pkl, 'rb') as f:
    G = pickle.load(f)['G_ema']

def _grid_dims(choice):
    """Return (grid_w, grid_h) for a UI grid label such as '4x2'.

    Raises:
        ValueError: if ``choice`` is not a supported layout.
    """
    layouts = {'4x2': (4, 2), '2x1': (2, 1)}
    try:
        return layouts[choice]
    except KeyError:
        raise ValueError(
            f'Unsupported grid choice {choice!r}; expected one of {list(layouts)}'
        ) from None


def _build_cell_interpolators(ws, grid_w, grid_h, num_keyframes, wraps, kind):
    """Build one scipy interp1d per grid cell over that cell's W keyframes.

    ``ws`` is indexed as ``ws[yi][xi]``, giving that cell's keyframes along
    axis 0 (assumed 3-D after .cpu().numpy(), as np.tile below requires).
    The keyframes are tiled ``2 * wraps + 1`` times along time. This gives the
    interpolation context on both sides of the rendered range [0, num_keyframes),
    and time num_keyframes lands back on keyframe 0, so the animation loops.
    """
    x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1))
    grid = []
    for yi in range(grid_h):
        row = []
        for xi in range(grid_w):
            y = np.tile(ws[yi][xi].cpu().numpy(), [wraps * 2 + 1, 1, 1])
            row.append(scipy.interpolate.interp1d(x, y, kind=kind, axis=0))
        grid.append(row)
    return grid


def predict(Seed, choices):
    """Render a looping latent-interpolation video from the global generator ``G``.

    Args:
        Seed: exclusive upper end of the seed range; converted with ``int``.
            The seeds used are ``[Seed - 2*grid_w*grid_h, Seed)``. Seed must
            therefore be at least ``2*grid_w*grid_h``, because numpy's
            RandomState rejects negative seeds.
        choices: grid layout label, ``'4x2'`` or ``'2x1'``.

    Returns:
        Path of the written MP4 file (``'ex.mp4'``): 60 fps, libx264,
        ``2 * 240`` frames.

    Raises:
        ValueError: if ``choices`` is not a supported layout.
    """
    # NOTE(review): G is the module-level generator unpickled at import time.
    # The mapping/synthesis/z_dim usage below assumes a StyleGAN2-style API -- confirm.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    G.eval()
    G.to(device)

    grid_w, grid_h = _grid_dims(choices)
    num_cells = grid_w * grid_h
    num_keyframes = 2   # latent keyframes each cell loops through
    w_frames = 60 * 4   # video frames rendered per keyframe
    kind = 'cubic'      # scipy interp1d interpolation kind
    wraps = 2           # extra copies of the keyframe sequence on each side
    psi = 1             # truncation_psi passed to G.mapping
    mp4 = 'ex.mp4'

    # One seed per (cell, keyframe), laid out so the reshape below puts
    # seed index (yi * grid_w + xi) * num_keyframes + k at ws[yi, xi, k].
    seed_end = int(Seed)
    all_seeds = np.arange(seed_end - num_keyframes * num_cells, seed_end, dtype=np.int64)

    with torch.no_grad():
        zs = torch.from_numpy(
            np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])
        ).to(device)
        ws = G.mapping(z=zs, c=None, truncation_psi=psi)
        _ = G.synthesis(ws[:1])  # warm-up pass
        ws = ws.reshape(grid_h, grid_w, num_keyframes, *ws.shape[1:])
        grid = _build_cell_interpolators(ws, grid_w, grid_h, num_keyframes, wraps, kind)

        # The context manager closes the writer even if rendering raises.
        with imageio.get_writer(mp4, mode='I', fps=60, codec='libx264') as video_out:
            for frame_idx in tqdm(range(num_keyframes * w_frames)):
                t = frame_idx / w_frames
                # Row-major order (yi, then xi), matching layout_grid's tiling.
                imgs = [
                    G.synthesis(ws=torch.from_numpy(interp(t)).to(device).unsqueeze(0),
                                noise_mode='const')[0]
                    for row in grid
                    for interp in row
                ]
                video_out.append_data(layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h))
    return mp4



# Gradio UI: a seed slider and a grid-layout radio feed predict(), whose
# returned file path is shown as a video.
choices = ['4x2', '2x1']

# Minimum 16 keeps the lowest seed non-negative for the '4x2' layout, which
# uses the 16 seeds below the chosen value.
seed_slider = gr.inputs.Slider(minimum=16, maximum=2**10, label='Seed')
grid_radio = gr.inputs.Radio(choices=choices, default='4x2', label='Image Grid')
video_output = gr.outputs.Video(label='Video')

interface = gr.Interface(
    fn=predict,
    inputs=[seed_slider, grid_radio],
    outputs=video_output,
    title="Brain MR Image Generation with StyleGAN-2",
    description="",
    article="Author: S.Serdar Helli",
)

interface.launch(debug=True)