lambdaofgod commited on
Commit
34c16bd
1 Parent(s): a284f57
Files changed (1) hide show
  1. app_implementation.py +13 -5
app_implementation.py CHANGED
@@ -16,9 +16,17 @@ from search_utils import (
16
 
17
 
18
  class RetrievalApp:
 
 
 
 
 
 
 
 
19
  def get_device_options(self):
20
- if torch.cuda.is_available:
21
- return ["cuda", "cpu"]
22
  else:
23
  return ["cpu"]
24
 
@@ -67,9 +75,9 @@ class RetrievalApp:
67
  st.table(get_repos_with_descriptions(retrieval_pipe.X_df, task_repos))
68
  with st.spinner(text="fetching results"):
69
  st.write(
70
- retrieval_pipe.search(query, k, description_length, additional_shown_cols).to_html(
71
- escape=False, index=False
72
- ),
73
  unsafe_allow_html=True,
74
  )
75
  print("finished retrieval")
 
16
 
17
 
18
  class RetrievalApp:
19
+ def is_cuda_available(self):
20
+ try:
21
+ t = torch.Tensor([1]).cuda()
22
+ except:
23
+ return False
24
+ finally:
25
+ return True
26
+
27
  def get_device_options(self):
28
+ if self.is_cuda_available():
29
+ return ["cpu", "cuda"]
30
  else:
31
  return ["cpu"]
32
 
 
75
  st.table(get_repos_with_descriptions(retrieval_pipe.X_df, task_repos))
76
  with st.spinner(text="fetching results"):
77
  st.write(
78
+ retrieval_pipe.search(
79
+ query, k, description_length, additional_shown_cols
80
+ ).to_html(escape=False, index=False),
81
  unsafe_allow_html=True,
82
  )
83
  print("finished retrieval")