ugmSorcero commited on
Commit
304cf45
1 Parent(s): f456ef3
core/pipelines.py CHANGED
@@ -74,11 +74,12 @@ def dense_passage_retrieval(
74
 
75
  return search_pipeline, index_pipeline
76
 
 
77
  def dense_passage_retrieval_ranker(
78
  index="documents",
79
  query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
80
  passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
81
- ranker_model="cross-encoder/ms-marco-MiniLM-L-12-v2"
82
  ):
83
  search_pipeline, index_pipeline = dense_passage_retrieval(
84
  index=index,
@@ -86,7 +87,7 @@ def dense_passage_retrieval_ranker(
86
  passage_embedding_model=passage_embedding_model,
87
  )
88
  ranker = SentenceTransformersRanker(model_name_or_path=ranker_model)
89
-
90
  search_pipeline.add_node(ranker, name="Ranker", inputs=["DPRRetriever"])
91
-
92
- return search_pipeline, index_pipeline
 
74
 
75
  return search_pipeline, index_pipeline
76
 
77
+
78
  def dense_passage_retrieval_ranker(
79
  index="documents",
80
  query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
81
  passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
82
+ ranker_model="cross-encoder/ms-marco-MiniLM-L-12-v2",
83
  ):
84
  search_pipeline, index_pipeline = dense_passage_retrieval(
85
  index=index,
 
87
  passage_embedding_model=passage_embedding_model,
88
  )
89
  ranker = SentenceTransformersRanker(model_name_or_path=ranker_model)
90
+
91
  search_pipeline.add_node(ranker, name="Ranker", inputs=["DPRRetriever"])
92
+
93
+ return search_pipeline, index_pipeline
interface/components.py CHANGED
@@ -29,13 +29,13 @@ def component_select_pipeline(container):
29
  if (
30
  st.session_state["pipeline"] is None
31
  or st.session_state["pipeline"]["name"] != selected_pipeline
32
- or list(st.session_state["pipeline_func_parameters"][index_pipe].values()) != list(pipeline_func_parameters[index_pipe].values())
 
33
  ):
34
  st.session_state["pipeline_func_parameters"] = pipeline_func_parameters
35
- (
36
- search_pipeline,
37
- index_pipeline,
38
- ) = pipeline_funcs[index_pipe](**pipeline_func_parameters[index_pipe])
39
  st.session_state["pipeline"] = {
40
  "name": selected_pipeline,
41
  "search_pipeline": search_pipeline,
 
29
  if (
30
  st.session_state["pipeline"] is None
31
  or st.session_state["pipeline"]["name"] != selected_pipeline
32
+ or list(st.session_state["pipeline_func_parameters"][index_pipe].values())
33
+ != list(pipeline_func_parameters[index_pipe].values())
34
  ):
35
  st.session_state["pipeline_func_parameters"] = pipeline_func_parameters
36
+ (search_pipeline, index_pipeline,) = pipeline_funcs[
37
+ index_pipe
38
+ ](**pipeline_func_parameters[index_pipe])
 
39
  st.session_state["pipeline"] = {
40
  "name": selected_pipeline,
41
  "search_pipeline": search_pipeline,
interface/config.py CHANGED
@@ -1,10 +1,7 @@
1
  from interface.pages import page_landing_page, page_search, page_index
2
 
3
  # Define default Session Variables over the whole session.
4
- session_state_variables = {
5
- "pipeline": None,
6
- "pipeline_func_parameters": []
7
- }
8
 
9
  # Define Pages for the demo
10
  pages = {
 
1
  from interface.pages import page_landing_page, page_search, page_index
2
 
3
  # Define default Session Variables over the whole session.
4
+ session_state_variables = {"pipeline": None, "pipeline_func_parameters": []}
 
 
 
5
 
6
  # Define Pages for the demo
7
  pages = {
interface/utils.py CHANGED
@@ -16,7 +16,10 @@ def get_pipelines():
16
  pipeline_names = [
17
  " ".join([n.capitalize() for n in name.split("_")]) for name in pipeline_names
18
  ]
19
- pipeline_func_parameters = [{key: value.default for key, value in signature(pipe_func).parameters.items()} for pipe_func in pipeline_funcs]
 
 
 
20
  return pipeline_names, pipeline_funcs, pipeline_func_parameters
21
 
22
 
 
16
  pipeline_names = [
17
  " ".join([n.capitalize() for n in name.split("_")]) for name in pipeline_names
18
  ]
19
+ pipeline_func_parameters = [
20
+ {key: value.default for key, value in signature(pipe_func).parameters.items()}
21
+ for pipe_func in pipeline_funcs
22
+ ]
23
  return pipeline_names, pipeline_funcs, pipeline_func_parameters
24
 
25