ziqima commited on
Commit
7591db0
1 Parent(s): 9a7ea82
Files changed (2) hide show
  1. inference/inference.py +3 -3
  2. inference/utils.py +3 -3
inference/inference.py CHANGED
@@ -4,9 +4,9 @@ import matplotlib.pyplot as plt
4
  from inference.utils import get_seg_color, load_model, preprocess_pcd, encode_text
5
  import spaces
6
 
7
- DEVICE = "cpu"
8
- if torch.cuda.is_available():
9
- DEVICE = "cuda:0"
10
 
11
  def pred_3d_upsample(
12
  pred, # n_subsampled_pts, feat_dim
 
4
  from inference.utils import get_seg_color, load_model, preprocess_pcd, encode_text
5
  import spaces
6
 
7
+ DEVICE = "cuda:0"
8
+ #if torch.cuda.is_available():
9
+ #DEVICE = "cuda:0"
10
 
11
  def pred_3d_upsample(
12
  pred, # n_subsampled_pts, feat_dim
inference/utils.py CHANGED
@@ -5,9 +5,9 @@ import numpy as np
5
  import random
6
  from transformers import AutoTokenizer, AutoModel
7
 
8
- DEVICE = "cpu"
9
- if torch.cuda.is_available():
10
- DEVICE = "cuda:0"
11
 
12
  def get_seg_color(labels):
13
  part_num = labels.max()
 
5
  import random
6
  from transformers import AutoTokenizer, AutoModel
7
 
8
+ DEVICE = "cuda:0"
9
+ #if torch.cuda.is_available():
10
+ #DEVICE = "cuda:0"
11
 
12
  def get_seg_color(labels):
13
  part_num = labels.max()