cocktailpeanut commited on
Commit
0ab242d
1 Parent(s): 16a2b40

device handling

Browse files
Files changed (2) hide show
  1. app.py +9 -1
  2. requirements.txt +5 -5
app.py CHANGED
@@ -5,11 +5,19 @@ from PIL import Image, ImageFilter
5
  import uuid
6
  from scipy.interpolate import interp1d, PchipInterpolator
7
  import torchvision
 
8
  from utils import *
9
 
10
  output_dir = "outputs"
11
  ensure_dirname(output_dir)
12
 
 
 
 
 
 
 
 
13
  def interpolate_trajectory(points, n_points):
14
  x = [point[0] for point in points]
15
  y = [point[1] for point in points]
@@ -219,7 +227,7 @@ with gr.Blocks() as demo:
219
  2.4. Click "Delete last step" to delete the lastest clicked control point.<br>
220
  3. Animate the image according the path with a click on "Run" button. <br>""")
221
 
222
- DragNUWA_net = Drag("cuda:0", 'drag_nuwa.pth', 'DragNUWA_net.py', 320, 576, 14)
223
  first_frame_path = gr.State()
224
  tracking_points = gr.State([])
225
 
 
5
  import uuid
6
  from scipy.interpolate import interp1d, PchipInterpolator
7
  import torchvision
8
+ import torch
9
  from utils import *
10
 
11
  output_dir = "outputs"
12
  ensure_dirname(output_dir)
13
 
14
+ if torch.cuda.is_available():
15
+ device = "cuda"
16
+ elif torch.backends.mps.is_available():
17
+ device = "mps"
18
+ else:
19
+ device = "cpu"
20
+
21
  def interpolate_trajectory(points, n_points):
22
  x = [point[0] for point in points]
23
  y = [point[1] for point in points]
 
227
  2.4. Click "Delete last step" to delete the lastest clicked control point.<br>
228
  3. Animate the image according the path with a click on "Run" button. <br>""")
229
 
230
+ DragNUWA_net = Drag(device, 'drag_nuwa.pth', 'DragNUWA_net.py', 320, 576, 14)
231
  first_frame_path = gr.State()
232
  tracking_points = gr.State([])
233
 
requirements.txt CHANGED
@@ -24,11 +24,11 @@ streamlit>=0.73.1
24
  tensorboardx==2.6
25
  timm>=0.9.2
26
  tokenizers==0.12.1
27
- torch>=2.0.1
28
- torchaudio>=2.0.2
29
  torchdata==0.6.1
30
  torchmetrics>=1.0.1
31
- torchvision>=0.15.2
32
  tqdm>=4.65.0
33
  transformers==4.19.1
34
  triton==2.0.0
@@ -36,7 +36,7 @@ urllib3<1.27,>=1.25.4
36
  wandb>=0.15.6
37
  webdataset>=0.2.33
38
  wheel>=0.41.0
39
- xformers>=0.0.20
40
  colorlog
41
  deepdish
42
  json_lines
@@ -51,4 +51,4 @@ pyyaml
51
  pandas
52
  einops
53
  deepspeed
54
- gradio==3.50.2
 
24
  tensorboardx==2.6
25
  timm>=0.9.2
26
  tokenizers==0.12.1
27
+ #torch>=2.0.1
28
+ #torchaudio>=2.0.2
29
  torchdata==0.6.1
30
  torchmetrics>=1.0.1
31
+ #torchvision>=0.15.2
32
  tqdm>=4.65.0
33
  transformers==4.19.1
34
  triton==2.0.0
 
36
  wandb>=0.15.6
37
  webdataset>=0.2.33
38
  wheel>=0.41.0
39
+ #xformers>=0.0.20
40
  colorlog
41
  deepdish
42
  json_lines
 
51
  pandas
52
  einops
53
  deepspeed
54
+ gradio==3.50.2