# SummerTime / pipeline / __init__.py
# Vendored from akhaliq3's HuggingFace Spaces demo (commit 546a9ba).
import itertools
from typing import List, Tuple

from model import SUPPORTED_SUMM_MODELS
from model.base_model import SummModel
from model.single_doc import LexRankModel
from dataset.st_dataset import SummDataset
from dataset.non_huggingface_datasets import ScisummnetDataset
def get_lxr_train_set(dataset: SummDataset, size: int = 100) -> List[str]:
    """Return up to `size` training examples from `dataset` as a list of source strings.

    Used to provide unsupervised training data for LexRank-style models.

    :param SummDataset `dataset`: dataset whose `train_set` is sampled.
    :param int `size`: maximum number of examples to take. Defaults to 100.
    :returns: list of source strings, flattened according to the dataset type.
    """
    # BUG FIX: the previous implementation called `next(iter(dataset.train_set))`
    # inside a loop, which re-created the iterator on every pass and therefore
    # collected the SAME first example `size` times. `islice` over a single
    # iterator takes the first `size` distinct examples, and also stops cleanly
    # when the dataset has fewer than `size` entries.
    subset = list(itertools.islice(iter(dataset.train_set), size))

    def to_source(example) -> str:
        # Dialogue/multi-document sources are lists of utterances/documents
        # that must be joined; ScisummNet wraps its source in a one-element
        # list; everything else stores the source string directly.
        if dataset.is_dialogue_based or dataset.is_multi_document:
            return " ".join(example.source)
        if isinstance(dataset, ScisummnetDataset):
            return example.source[0]
        return example.source

    return [to_source(example) for example in subset]
def assemble_model_pipeline(
    dataset: SummDataset, model_list: List[SummModel] = SUPPORTED_SUMM_MODELS
) -> List[Tuple[SummModel, str]]:
    """
    Return initialized list of all model pipelines that match the summarization task of given dataset.

    :param SummDataset `dataset`: Dataset (instance or class) to retrieve model pipelines for.
    :param List[SummModel] `model_list`: List of candidate model classes (uninitialized).
        Defaults to `model.SUPPORTED_SUMM_MODELS`.
    :returns: List of tuples, where each tuple contains an initialized model and the
        name of that model as `(model, name)`.
    """
    # Accept either an instance or an (uninstantiated) dataset class.
    dataset = dataset if isinstance(dataset, SummDataset) else dataset()

    # Cheap filtering only — model construction can be expensive (LexRank even
    # requires training data), so instances are created exclusively on the
    # branch that actually returns them. The previous implementation eagerly
    # instantiated every single-document (and dialogue) model even on the
    # query-based and multi-document paths that never used those instances.
    single_doc_model_list = [
        model_cls
        for model_cls in model_list
        if not (
            model_cls.is_dialogue_based
            or model_cls.is_query_based
            or model_cls.is_multi_document
        )
    ]
    query_based_model_list = [m for m in model_list if m.is_query_based]
    multi_doc_model_list = [m for m in model_list if m.is_multi_document]
    dialogue_based_model_list = [m for m in model_list if m.is_dialogue_based]

    def init_with_backend(wrapper_cls, backend_cls):
        # LexRank is unsupervised and needs training data at construction time;
        # every other backend is wrapped without extra arguments.
        if backend_cls == LexRankModel:
            return wrapper_cls(
                model_backend=backend_cls, data=get_lxr_train_set(dataset)
            )
        return wrapper_cls(model_backend=backend_cls)

    if dataset.is_query_based:
        # Query-based pipelines wrap a backend: dialogue models for dialogue
        # datasets, otherwise plain single-document models.
        backends = (
            dialogue_based_model_list
            if dataset.is_dialogue_based
            else single_doc_model_list
        )
        return [
            (
                init_with_backend(query_model_cls, backend_cls),
                f"{query_model_cls.model_name} ({backend_cls.model_name})",
            )
            for query_model_cls in query_based_model_list
            for backend_cls in backends
        ]

    if dataset.is_multi_document:
        # Multi-document pipelines wrap each single-document backend.
        return [
            (
                init_with_backend(multi_doc_model_cls, backend_cls),
                f"{multi_doc_model_cls.model_name} ({backend_cls.model_name})",
            )
            for multi_doc_model_cls in multi_doc_model_list
            for backend_cls in single_doc_model_list
        ]

    if dataset.is_dialogue_based:
        return [
            (model_cls(), model_cls.model_name)
            for model_cls in dialogue_based_model_list
        ]

    # Plain single-document dataset: instantiate each model directly
    # (LexRank takes its training data positionally here, as before).
    return [
        (
            model_cls(get_lxr_train_set(dataset))
            if model_cls == LexRankModel
            else model_cls(),
            model_cls.model_name,
        )
        for model_cls in single_doc_model_list
    ]