ziqima commited on
Commit
f0dc24f
1 Parent(s): 7591db0
Files changed (2) hide show
  1. inference/inference.py +2 -2
  2. requirements.txt +1 -1
inference/inference.py CHANGED
@@ -77,7 +77,7 @@ def get_heatmap_rgb(model, data, N_CHUNKS=5): # evaluate loader can only have ba
77
  heatmap_rgb = torch.tensor(plt.cm.jet(scores.numpy())[:,:3]).squeeze()
78
  return heatmap_rgb
79
 
80
- @spaces.GPU(duration=90)
81
  def segment_obj(xyz, rgb, normal, queries):
82
  model = load_model()
83
  data_dict = preprocess_pcd(torch.tensor(xyz).float().to(DEVICE), torch.tensor(rgb).float().to(DEVICE), torch.tensor(normal).float().to(DEVICE))
@@ -85,7 +85,7 @@ def segment_obj(xyz, rgb, normal, queries):
85
  seg_rgb = get_segmentation_rgb(model, data_dict)
86
  return seg_rgb
87
 
88
- @spaces.GPU(duration=90)
89
  def get_heatmap(xyz, rgb, normal, query):
90
  model = load_model()
91
  data_dict = preprocess_pcd(torch.tensor(xyz).float().to(DEVICE), torch.tensor(rgb).float().to(DEVICE), torch.tensor(normal).float().to(DEVICE))
 
77
  heatmap_rgb = torch.tensor(plt.cm.jet(scores.numpy())[:,:3]).squeeze()
78
  return heatmap_rgb
79
 
80
+ @spaces.GPU
81
  def segment_obj(xyz, rgb, normal, queries):
82
  model = load_model()
83
  data_dict = preprocess_pcd(torch.tensor(xyz).float().to(DEVICE), torch.tensor(rgb).float().to(DEVICE), torch.tensor(normal).float().to(DEVICE))
 
85
  seg_rgb = get_segmentation_rgb(model, data_dict)
86
  return seg_rgb
87
 
88
+ @spaces.GPU
89
  def get_heatmap(xyz, rgb, normal, query):
90
  model = load_model()
91
  data_dict = preprocess_pcd(torch.tensor(xyz).float().to(DEVICE), torch.tensor(rgb).float().to(DEVICE), torch.tensor(normal).float().to(DEVICE))
requirements.txt CHANGED
@@ -24,7 +24,7 @@ scipy
24
  plyfile
25
  termcolor
26
  timm
27
- spconv
28
  transformers
29
  open3d
30
  sentencepiece
 
24
  plyfile
25
  termcolor
26
  timm
27
+ spconv-cu118
28
  transformers
29
  open3d
30
  sentencepiece