File size: 5,007 Bytes
546a9ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import itertools
from typing import List, Tuple

from dataset.non_huggingface_datasets import ScisummnetDataset
from dataset.st_dataset import SummDataset

from model import SUPPORTED_SUMM_MODELS
from model.base_model import SummModel
from model.single_doc import LexRankModel


def get_lxr_train_set(dataset: SummDataset, size: int = 100) -> List[str]:
    """
    Return up to `size` training examples from `dataset`, flattened into a list
    of plain source strings (used as a LexRank training corpus).

    :param SummDataset `dataset`: Dataset to draw training examples from.
    :param int `size`: Maximum number of examples to take. Defaults to 100.
    :returns List of source strings, one per example.
    """
    # Create the iterator ONCE and take a prefix. The previous version called
    # `next(iter(dataset.train_set))` inside the loop, which re-created the
    # iterator every pass and therefore collected the SAME first example
    # `size` times whenever `train_set` is a re-iterable sequence. islice also
    # degrades gracefully (returns fewer items) if the dataset is short.
    subset = list(itertools.islice(iter(dataset.train_set), size))

    def _flatten(example) -> str:
        # Dialogue/multi-doc examples carry a list of utterances/documents:
        # join them into one text. Scisummnet wraps its single source in a
        # one-element list. Everything else is already a plain string
        # (assumed from the original mapping — TODO confirm against SummDataset).
        if dataset.is_dialogue_based or dataset.is_multi_document:
            return " ".join(example.source)
        if isinstance(dataset, ScisummnetDataset):
            return example.source[0]
        return example.source

    return [_flatten(example) for example in subset]


def assemble_model_pipeline(
    dataset: SummDataset, model_list: List[SummModel] = SUPPORTED_SUMM_MODELS
) -> List[Tuple[SummModel, str]]:
    """
    Return initialized list of all model pipelines that match the summarization task of given dataset.

    :param SummDataset `dataset`: Dataset to retrieve model pipelines for.
    :param List[SummModel] `model_list`: List of candidate model classes (uninitialized). Defaults to `model.SUPPORTED_SUMM_MODELS`.
    :returns List of tuples, where each tuple contains an initialized model and the name of that model as `(model, name)`.
    """

    # Accept either a dataset instance or an uninitialized dataset class.
    dataset = dataset if isinstance(dataset, SummDataset) else dataset()

    # Partition candidate classes by task capability. "Single-doc" means none
    # of the specialised flags are set.
    single_doc_model_list = list(
        filter(
            lambda model_cls: not (
                model_cls.is_dialogue_based
                or model_cls.is_query_based
                or model_cls.is_multi_document
            ),
            model_list,
        )
    )
    multi_doc_model_list = list(
        filter(lambda model_cls: model_cls.is_multi_document, model_list)
    )
    query_based_model_list = list(
        filter(lambda model_cls: model_cls.is_query_based, model_list)
    )
    dialogue_based_model_list = list(
        filter(lambda model_cls: model_cls.is_dialogue_based, model_list)
    )

    def _init_single_doc(model_cls) -> SummModel:
        # LexRank requires a training corpus at construction time; every other
        # single-doc model takes no constructor arguments.
        if model_cls == LexRankModel:
            return model_cls(get_lxr_train_set(dataset))
        return model_cls()

    def _wrap_backends(
        wrapper_cls_list: List[SummModel], backend_cls_list: List[SummModel]
    ) -> List[Tuple[SummModel, str]]:
        # Pair every wrapper class with every backend class. A LexRank backend
        # additionally needs training data threaded through the wrapper.
        pairs = []
        for wrapper_cls in wrapper_cls_list:
            for backend_cls in backend_cls_list:
                if backend_cls == LexRankModel:
                    model = wrapper_cls(
                        model_backend=backend_cls, data=get_lxr_train_set(dataset)
                    )
                else:
                    model = wrapper_cls(model_backend=backend_cls)
                pairs.append(
                    (model, f"{wrapper_cls.model_name} ({backend_cls.model_name})")
                )
        return pairs

    # Instantiate models lazily, only in the branch that returns them.
    # Previously, every single-doc model (and, for dialogue datasets, every
    # dialogue model) was constructed up front and then discarded on the
    # query-based / multi-doc paths — wasted work for potentially heavy models.
    if dataset.is_query_based:
        # Query-based models wrap a backend: dialogue backends for dialogue
        # datasets, single-doc backends otherwise.
        backend_cls_list = (
            dialogue_based_model_list
            if dataset.is_dialogue_based
            else single_doc_model_list
        )
        return _wrap_backends(query_based_model_list, backend_cls_list)

    if dataset.is_multi_document:
        return _wrap_backends(multi_doc_model_list, single_doc_model_list)

    if dataset.is_dialogue_based:
        return [
            (model_cls(), model_cls.model_name)
            for model_cls in dialogue_based_model_list
        ]

    # Plain single-document dataset.
    return [
        (_init_single_doc(model_cls), model_cls.model_name)
        for model_cls in single_doc_model_list
    ]