Spaces:
Build error
Build error
from model import SUPPORTED_SUMM_MODELS | |
from model.base_model import SummModel | |
from model.single_doc import LexRankModel | |
from dataset.st_dataset import SummDataset | |
from dataset.non_huggingface_datasets import ScisummnetDataset | |
from typing import List, Tuple | |
def get_lxr_train_set(dataset: SummDataset, size: int = 100) -> List[str]:
    """
    Return up to `size` source texts taken from the dataset's training split.

    In this module the result is handed to `LexRankModel` (directly or as the
    `data` of a wrapping pipeline) as its training data.

    :param SummDataset `dataset`: Dataset whose `train_set` is iterated.
    :param int `size`: Maximum number of training examples to take. Values <= 0
        yield an empty list. Defaults to 100.
    :returns List of sources, one per example. For dialogue-based or
        multi-document datasets the entries of `source` are joined with single
        spaces. For `ScisummnetDataset` the first entry of `source` is used.
        Otherwise `source` is returned unchanged.
    """
    import itertools

    # Build the iterator once. Calling `iter(dataset.train_set)` on every step
    # restarts a re-iterable container, which produced `size` copies of the
    # first example. `islice` also stops cleanly when the split holds fewer
    # than `size` examples instead of leaking StopIteration.
    examples = itertools.islice(iter(dataset.train_set), max(size, 0))

    join_entries = dataset.is_dialogue_based or dataset.is_multi_document
    use_first_entry = isinstance(dataset, ScisummnetDataset)

    src = []
    for example in examples:
        if join_entries:
            src.append(" ".join(example.source))
        elif use_first_entry:
            src.append(example.source[0])
        else:
            src.append(example.source)
    return src
def _wrap_backends(
    wrapper_classes: List[type],
    backend_classes: List[type],
    dataset: SummDataset,
    lexrank_gets_train_set: bool,
) -> List[Tuple[SummModel, str]]:
    """
    Instantiate every wrapper class around every backend class.

    Each wrapper is constructed with the backend *class* (not an instance) as
    `model_backend`. When `lexrank_gets_train_set` is true and the backend is
    `LexRankModel`, a freshly sampled training set is also passed as `data`.

    :returns List of `(model, "<wrapper name> (<backend name>)")` tuples, in
        wrapper-major order.
    """
    pipelines = []
    for wrapper_cls in wrapper_classes:
        for backend_cls in backend_classes:
            if lexrank_gets_train_set and backend_cls == LexRankModel:
                model = wrapper_cls(
                    model_backend=backend_cls,
                    data=get_lxr_train_set(dataset),
                )
            else:
                model = wrapper_cls(model_backend=backend_cls)
            pipelines.append(
                (model, f"{wrapper_cls.model_name} ({backend_cls.model_name})")
            )
    return pipelines


def assemble_model_pipeline(
    dataset: SummDataset, model_list: List[SummModel] = SUPPORTED_SUMM_MODELS
) -> List[Tuple[SummModel, str]]:
    """
    Return initialized list of all model pipelines that match the summarization task of given dataset.

    Only the models belonging to the returned pipelines are instantiated.
    Candidates for other task types are never constructed.

    Selection order:
      1. Query-based dataset: each query-based model wraps each dialogue-based
         model (dialogue datasets) or each single-document model (otherwise).
      2. Multi-document dataset: each multi-document model wraps each
         single-document model.
      3. Dialogue-based dataset: every dialogue-based model on its own.
      4. Otherwise: every single-document model on its own.

    :param SummDataset `dataset`: Dataset to retrieve model pipelines for. A dataset class is also accepted and is instantiated with no arguments.
    :param List[SummModel] `model_list`: List of candidate model classes (uninitialized). Defaults to `model.SUPPORTED_SUMM_MODELS`.
    :returns List of tuples, where each tuple contains an initialized model and the name of that model as `(model, name)`.
    """
    # Accept either a dataset instance or a dataset class.
    dataset = dataset if isinstance(dataset, SummDataset) else dataset()

    # Single-document models: none of the dialogue/query/multi-document flags set.
    single_doc_model_list = [
        model_cls
        for model_cls in model_list
        if not (
            model_cls.is_dialogue_based
            or model_cls.is_query_based
            or model_cls.is_multi_document
        )
    ]

    if dataset.is_query_based:
        query_based_model_list = [
            model_cls for model_cls in model_list if model_cls.is_query_based
        ]
        if dataset.is_dialogue_based:
            dialogue_based_model_list = [
                model_cls for model_cls in model_list if model_cls.is_dialogue_based
            ]
            return _wrap_backends(
                query_based_model_list,
                dialogue_based_model_list,
                dataset,
                lexrank_gets_train_set=False,
            )
        return _wrap_backends(
            query_based_model_list,
            single_doc_model_list,
            dataset,
            lexrank_gets_train_set=True,
        )

    if dataset.is_multi_document:
        multi_doc_model_list = [
            model_cls for model_cls in model_list if model_cls.is_multi_document
        ]
        return _wrap_backends(
            multi_doc_model_list,
            single_doc_model_list,
            dataset,
            lexrank_gets_train_set=True,
        )

    if dataset.is_dialogue_based:
        dialogue_based_model_instances = [
            model_cls() for model_cls in model_list if model_cls.is_dialogue_based
        ]
        return [(model, model.model_name) for model in dialogue_based_model_instances]

    # LexRank is the only single-document model constructed with training data.
    single_doc_model_instances = [
        model_cls(get_lxr_train_set(dataset))
        if model_cls == LexRankModel
        else model_cls()
        for model_cls in single_doc_model_list
    ]
    return [(model, model.model_name) for model in single_doc_model_instances]