Spaces:
Runtime error
Runtime error
ugmSorcero
commited on
Commit
•
304cf45
1
Parent(s):
f456ef3
Linter
Browse files- core/pipelines.py +5 -4
- interface/components.py +5 -5
- interface/config.py +1 -4
- interface/utils.py +4 -1
core/pipelines.py
CHANGED
@@ -74,11 +74,12 @@ def dense_passage_retrieval(
|
|
74 |
|
75 |
return search_pipeline, index_pipeline
|
76 |
|
|
|
77 |
def dense_passage_retrieval_ranker(
|
78 |
index="documents",
|
79 |
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
|
80 |
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
|
81 |
-
ranker_model="cross-encoder/ms-marco-MiniLM-L-12-v2"
|
82 |
):
|
83 |
search_pipeline, index_pipeline = dense_passage_retrieval(
|
84 |
index=index,
|
@@ -86,7 +87,7 @@ def dense_passage_retrieval_ranker(
|
|
86 |
passage_embedding_model=passage_embedding_model,
|
87 |
)
|
88 |
ranker = SentenceTransformersRanker(model_name_or_path=ranker_model)
|
89 |
-
|
90 |
search_pipeline.add_node(ranker, name="Ranker", inputs=["DPRRetriever"])
|
91 |
-
|
92 |
-
return search_pipeline, index_pipeline
|
|
|
74 |
|
75 |
return search_pipeline, index_pipeline
|
76 |
|
77 |
+
|
78 |
def dense_passage_retrieval_ranker(
|
79 |
index="documents",
|
80 |
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
|
81 |
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
|
82 |
+
ranker_model="cross-encoder/ms-marco-MiniLM-L-12-v2",
|
83 |
):
|
84 |
search_pipeline, index_pipeline = dense_passage_retrieval(
|
85 |
index=index,
|
|
|
87 |
passage_embedding_model=passage_embedding_model,
|
88 |
)
|
89 |
ranker = SentenceTransformersRanker(model_name_or_path=ranker_model)
|
90 |
+
|
91 |
search_pipeline.add_node(ranker, name="Ranker", inputs=["DPRRetriever"])
|
92 |
+
|
93 |
+
return search_pipeline, index_pipeline
|
interface/components.py
CHANGED
@@ -29,13 +29,13 @@ def component_select_pipeline(container):
|
|
29 |
if (
|
30 |
st.session_state["pipeline"] is None
|
31 |
or st.session_state["pipeline"]["name"] != selected_pipeline
|
32 |
-
or list(st.session_state["pipeline_func_parameters"][index_pipe].values())
|
|
|
33 |
):
|
34 |
st.session_state["pipeline_func_parameters"] = pipeline_func_parameters
|
35 |
-
(
|
36 |
-
|
37 |
-
|
38 |
-
) = pipeline_funcs[index_pipe](**pipeline_func_parameters[index_pipe])
|
39 |
st.session_state["pipeline"] = {
|
40 |
"name": selected_pipeline,
|
41 |
"search_pipeline": search_pipeline,
|
|
|
29 |
if (
|
30 |
st.session_state["pipeline"] is None
|
31 |
or st.session_state["pipeline"]["name"] != selected_pipeline
|
32 |
+
or list(st.session_state["pipeline_func_parameters"][index_pipe].values())
|
33 |
+
!= list(pipeline_func_parameters[index_pipe].values())
|
34 |
):
|
35 |
st.session_state["pipeline_func_parameters"] = pipeline_func_parameters
|
36 |
+
(search_pipeline, index_pipeline,) = pipeline_funcs[
|
37 |
+
index_pipe
|
38 |
+
](**pipeline_func_parameters[index_pipe])
|
|
|
39 |
st.session_state["pipeline"] = {
|
40 |
"name": selected_pipeline,
|
41 |
"search_pipeline": search_pipeline,
|
interface/config.py
CHANGED
@@ -1,10 +1,7 @@
|
|
1 |
from interface.pages import page_landing_page, page_search, page_index
|
2 |
|
3 |
# Define default Session Variables over the whole session.
|
4 |
-
session_state_variables = {
|
5 |
-
"pipeline": None,
|
6 |
-
"pipeline_func_parameters": []
|
7 |
-
}
|
8 |
|
9 |
# Define Pages for the demo
|
10 |
pages = {
|
|
|
1 |
from interface.pages import page_landing_page, page_search, page_index
|
2 |
|
3 |
# Define default Session Variables over the whole session.
|
4 |
+
session_state_variables = {"pipeline": None, "pipeline_func_parameters": []}
|
|
|
|
|
|
|
5 |
|
6 |
# Define Pages for the demo
|
7 |
pages = {
|
interface/utils.py
CHANGED
@@ -16,7 +16,10 @@ def get_pipelines():
|
|
16 |
pipeline_names = [
|
17 |
" ".join([n.capitalize() for n in name.split("_")]) for name in pipeline_names
|
18 |
]
|
19 |
-
pipeline_func_parameters = [
|
|
|
|
|
|
|
20 |
return pipeline_names, pipeline_funcs, pipeline_func_parameters
|
21 |
|
22 |
|
|
|
16 |
pipeline_names = [
|
17 |
" ".join([n.capitalize() for n in name.split("_")]) for name in pipeline_names
|
18 |
]
|
19 |
+
pipeline_func_parameters = [
|
20 |
+
{key: value.default for key, value in signature(pipe_func).parameters.items()}
|
21 |
+
for pipe_func in pipeline_funcs
|
22 |
+
]
|
23 |
return pipeline_names, pipeline_funcs, pipeline_func_parameters
|
24 |
|
25 |
|