Spaces:
Build error
Build error
File size: 5,007 Bytes
546a9ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
from model import SUPPORTED_SUMM_MODELS
from model.base_model import SummModel
from model.single_doc import LexRankModel
from dataset.st_dataset import SummDataset
from dataset.non_huggingface_datasets import ScisummnetDataset
from typing import List, Tuple
def get_lxr_train_set(dataset: SummDataset, size: int = 100) -> List[str]:
    """
    Return up to `size` source texts taken from the start of the dataset's training split.

    The result is a small corpus of plain strings; `assemble_model_pipeline` passes it
    to `LexRankModel`, which is constructed from a list of documents.

    :param SummDataset `dataset`: Dataset whose `train_set` is sampled.
    :param int `size`: Maximum number of examples to take. Defaults to 100.
    :returns List of source strings. Shorter than `size` if the training split
        has fewer examples.
    """
    # Function-scope import so the block does not depend on the file's import header.
    from itertools import islice

    # islice takes the first `size` items from any iterable. Calling
    # next(iter(train_set)) repeatedly would restart at the first element on
    # every pass when train_set is a re-iterable container (e.g. a list), and
    # would raise StopIteration when the split holds fewer than `size` items.
    examples = list(islice(dataset.train_set, size))

    # These conditions do not depend on the individual example, so evaluate once.
    join_parts = dataset.is_dialogue_based or dataset.is_multi_document
    first_part_only = isinstance(dataset, ScisummnetDataset)

    sources = []
    for example in examples:
        if join_parts:
            # Source is a sequence of strings (presumably turns or documents);
            # flatten it into a single text.
            sources.append(" ".join(example.source))
        elif first_part_only:
            # For ScisummnetDataset only the first element of the source is used.
            sources.append(example.source[0])
        else:
            sources.append(example.source)
    return sources
def assemble_model_pipeline(
    dataset: SummDataset, model_list: List[SummModel] = SUPPORTED_SUMM_MODELS
) -> List[Tuple[SummModel, str]]:
    """
    Return initialized list of all model pipelines that match the summarization task of given dataset.

    Task matching, in order of precedence:

    * query-based dataset: every query-based model wrapped around every
      dialogue-based model (if the dataset is also dialogue-based) or every
      single-document model (otherwise);
    * multi-document dataset: every multi-document model wrapped around every
      single-document model;
    * dialogue-based dataset: every dialogue-based model on its own;
    * anything else: every single-document model on its own.

    Models are instantiated only in the branch that returns them, so no model
    is constructed and then thrown away.

    :param SummDataset `dataset`: Dataset to retrieve model pipelines for. If the argument is
        not a `SummDataset` instance it is called with no arguments to obtain one.
    :param List[SummModel] `model_list`: List of candidate model classes (uninitialized). Defaults to `model.SUPPORTED_SUMM_MODELS`.
    :returns List of tuples, where each tuple contains an initialized model and the name of that model as `(model, name)`.
    """
    if not isinstance(dataset, SummDataset):
        dataset = dataset()

    # Partition the candidate classes by the task flags they declare.
    single_doc_model_list = [
        model_cls
        for model_cls in model_list
        if not (
            model_cls.is_dialogue_based
            or model_cls.is_query_based
            or model_cls.is_multi_document
        )
    ]
    multi_doc_model_list = [
        model_cls for model_cls in model_list if model_cls.is_multi_document
    ]
    query_based_model_list = [
        model_cls for model_cls in model_list if model_cls.is_query_based
    ]
    dialogue_based_model_list = [
        model_cls for model_cls in model_list if model_cls.is_dialogue_based
    ]

    if dataset.is_query_based:
        if dataset.is_dialogue_based:
            # Query-based wrappers receive the dialogue model *class* as backend.
            matching_models = []
            for query_model_cls in query_based_model_list:
                for dialogue_model in dialogue_based_model_list:
                    full_query_dialogue_model = query_model_cls(
                        model_backend=dialogue_model
                    )
                    matching_models.append(
                        (
                            full_query_dialogue_model,
                            f"{query_model_cls.model_name} ({dialogue_model.model_name})",
                        )
                    )
            return matching_models
        return _wrap_single_doc_backends(
            query_based_model_list, single_doc_model_list, dataset
        )

    if dataset.is_multi_document:
        return _wrap_single_doc_backends(
            multi_doc_model_list, single_doc_model_list, dataset
        )

    if dataset.is_dialogue_based:
        dialogue_models = [model_cls() for model_cls in dialogue_based_model_list]
        return [(db_model, db_model.model_name) for db_model in dialogue_models]

    # Plain single-document task. LexRankModel is constructed from a corpus,
    # so it is given a sample of the training sources.
    single_doc_models = [
        model_cls(get_lxr_train_set(dataset))
        if model_cls == LexRankModel
        else model_cls()
        for model_cls in single_doc_model_list
    ]
    return [(s_model, s_model.model_name) for s_model in single_doc_models]


def _wrap_single_doc_backends(wrapper_classes, backend_classes, dataset):
    """
    Instantiate every wrapper class around every single-document backend class.

    :param `wrapper_classes`: Model classes accepting a `model_backend` keyword
        (the query-based or multi-document models).
    :param `backend_classes`: Single-document model classes used as backends.
    :param SummDataset `dataset`: Dataset supplying training sources for LexRank backends.
    :returns List of `(model, name)` tuples with name `"<wrapper> (<backend>)"`.
    """
    pipelines = []
    for wrapper_cls in wrapper_classes:
        for backend_cls in backend_classes:
            if backend_cls == LexRankModel:
                # A LexRank backend additionally gets a corpus via `data`.
                wrapped_model = wrapper_cls(
                    model_backend=backend_cls, data=get_lxr_train_set(dataset)
                )
            else:
                wrapped_model = wrapper_cls(model_backend=backend_cls)
            pipelines.append(
                (
                    wrapped_model,
                    f"{wrapper_cls.model_name} ({backend_cls.model_name})",
                )
            )
    return pipelines
|