ziqima commited on
Commit
9a7ea82
1 Parent(s): 4893ce0

add zerogpu

Browse files
Files changed (1) hide show
  1. inference/inference.py +5 -2
inference/inference.py CHANGED
@@ -2,6 +2,7 @@ import torch
2
  import numpy as np
3
  import matplotlib.pyplot as plt
4
  from inference.utils import get_seg_color, load_model, preprocess_pcd, encode_text
 
5
 
6
  DEVICE = "cpu"
7
  if torch.cuda.is_available():
@@ -75,14 +76,16 @@ def get_heatmap_rgb(model, data, N_CHUNKS=5): # evaluate loader can only have ba
75
  scores = all_logits.squeeze().cpu()
76
  heatmap_rgb = torch.tensor(plt.cm.jet(scores.numpy())[:,:3]).squeeze()
77
  return heatmap_rgb
78
-
 
79
  def segment_obj(xyz, rgb, normal, queries):
80
  model = load_model()
81
  data_dict = preprocess_pcd(torch.tensor(xyz).float().to(DEVICE), torch.tensor(rgb).float().to(DEVICE), torch.tensor(normal).float().to(DEVICE))
82
  data_dict["label_embeds"] = encode_text(queries)
83
  seg_rgb = get_segmentation_rgb(model, data_dict)
84
  return seg_rgb
85
-
 
86
  def get_heatmap(xyz, rgb, normal, query):
87
  model = load_model()
88
  data_dict = preprocess_pcd(torch.tensor(xyz).float().to(DEVICE), torch.tensor(rgb).float().to(DEVICE), torch.tensor(normal).float().to(DEVICE))
 
2
  import numpy as np
3
  import matplotlib.pyplot as plt
4
  from inference.utils import get_seg_color, load_model, preprocess_pcd, encode_text
5
+ import spaces
6
 
7
  DEVICE = "cpu"
8
  if torch.cuda.is_available():
 
76
  scores = all_logits.squeeze().cpu()
77
  heatmap_rgb = torch.tensor(plt.cm.jet(scores.numpy())[:,:3]).squeeze()
78
  return heatmap_rgb
79
+
80
+ @spaces.GPU(duration=90)
81
  def segment_obj(xyz, rgb, normal, queries):
82
  model = load_model()
83
  data_dict = preprocess_pcd(torch.tensor(xyz).float().to(DEVICE), torch.tensor(rgb).float().to(DEVICE), torch.tensor(normal).float().to(DEVICE))
84
  data_dict["label_embeds"] = encode_text(queries)
85
  seg_rgb = get_segmentation_rgb(model, data_dict)
86
  return seg_rgb
87
+
88
+ @spaces.GPU(duration=90)
89
  def get_heatmap(xyz, rgb, normal, query):
90
  model = load_model()
91
  data_dict = preprocess_pcd(torch.tensor(xyz).float().to(DEVICE), torch.tensor(rgb).float().to(DEVICE), torch.tensor(normal).float().to(DEVICE))