File size: 5,301 Bytes
3adea03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import time
from jinja2 import meta, TemplateError
import pytest
import promptsource.templates
from promptsource.utils import get_dataset_builder
from uuid import UUID

# Jinja environment exposed by promptsource.templates; test_dataset uses it to
# parse each template's Jinja source.
env = promptsource.templates.env

# Shared TemplateCollection instance, created once when this module is imported.
# The tests below iterate over its (dataset_name, subset_name) keys, and the
# same keys parametrize test_dataset.
template_collection = promptsource.templates.TemplateCollection()


def test_uuids():
    """
    Checks that no two templates anywhere in promptsource share a UUID.
    (Collisions are unlikely by chance, but copying and pasting YAML files
    could lead to duplicates.)

    :raises ValueError: on the first UUID seen a second time; the message names
        the offending template and the data (sub)set that already uses the UUID
    """
    # Maps each UUID seen so far to the (dataset_name, subset_name) it was
    # most recently recorded for.
    owner_by_uuid = {}

    for dataset_name, subset_name in template_collection.keys:
        templates_for_dataset = template_collection.get_dataset(dataset_name, subset_name)

        for template_name in templates_for_dataset.all_template_names:
            template_id = templates_for_dataset[template_name].get_id()

            if template_id in owner_by_uuid:
                first_dataset, first_subset = owner_by_uuid[template_id]
                raise ValueError(
                    f"Template {template_name} for dataset {dataset_name}/{subset_name} "
                    f"has duplicate uuid {template_id} as "
                    f"{first_dataset}/{first_subset}."
                )

            owner_by_uuid[template_id] = (dataset_name, subset_name)


@pytest.mark.parametrize("dataset", template_collection.keys)
def test_dataset(dataset):
    """
    Validates all the templates of one data (sub)set with simple syntactic checks:
    0. Are all templates parsable YAML? (not checked explicitly in this function;
       presumably exercised when the template collection is loaded - TODO confirm)
    1. Do all templates parse in Jinja and are all referenced variables in the dataset schema?
    2. Does the template contain a prompt/output separator "|||" ?
    3. Are all names and templates within a data (sub)set unique?
    4. Is the YAML dictionary properly formatted?
    5. Is the UUID valid?

    :param dataset: (dataset_name, subset_name) pair to test
    :raises ValueError: if any template of the data (sub)set fails one of the checks
    :raises ConnectionError: if the dataset builder still cannot be obtained after
        all retries
    """
    dataset_name, subset_name = dataset

    # Loads dataset information. A ConnectionError is retried up to max_retries
    # times with a 2 second pause in between (max_retries + 1 attempts in total).
    max_retries = 3
    for attempt in range(max_retries + 1):
        try:
            builder_instance = get_dataset_builder(dataset_name, subset_name)
            break
        except ConnectionError:
            if attempt == max_retries:
                # Out of retries: re-raise with the original traceback intact.
                raise
            time.sleep(2)

    # Feature names are compared against Jinja variable names with dashes
    # replaced by underscores. `features` stays None when the builder exposes
    # no schema, in which case the variable check of Check 1 is skipped.
    info_features = builder_instance.info.features
    features = None
    if info_features is not None:
        features = {feature.replace("-", "_") for feature in info_features}

    # Initializes sets for checking uniqueness among templates
    template_name_set = set()
    template_jinja_set = set()

    # Iterates over each template for current data (sub)set
    dataset_templates = template_collection.get_dataset(dataset_name, subset_name)
    any_original = False
    for template_name in dataset_templates.all_template_names:
        template = dataset_templates[template_name]
        template_id = template.get_id()
        # Common prefix shared by every failure message for this template.
        where = f"Template for dataset {dataset_name}/{subset_name} with uuid {template_id}"

        # Only consumed by the assertion that is currently disabled at the end
        # of this function.
        any_original = any_original or template.metadata.original_task

        # Check 1: Jinja and all features valid?
        try:
            parse = env.parse(template.jinja)
        except TemplateError as e:
            raise ValueError(f"{where} failed to parse.") from e

        variables = meta.find_undeclared_variables(parse)
        if features is not None:
            # Sorted so that the variable reported on failure is deterministic.
            for variable in sorted(variables):
                # "answer_choices" is accepted even though it is not a dataset feature.
                if variable != "answer_choices" and variable not in features:
                    raise ValueError(f"{where} has unrecognized variable {variable}.")

        # Check 2: Prompt/output separator present?
        if "|||" not in template.jinja:
            raise ValueError(f"{where} has no prompt/output separator.")

        # Check 3: Unique names and templates?
        name = template.get_name()
        if name in template_name_set:
            raise ValueError(f"{where} has duplicate name.")

        if template.jinja in template_jinja_set:
            raise ValueError(f"{where} has duplicate definition.")

        template_name_set.add(name)
        template_jinja_set.add(template.jinja)

        # Check 4: Is the YAML dictionary properly formatted?
        # The template must be stored under its own UUID. Only the lookup sits
        # inside the try so that an unrelated KeyError is not misreported.
        try:
            keyed_template = dataset_templates.templates[template_id]
        except KeyError as e:
            raise ValueError(f"{where} has wrong YAML key.") from e
        if keyed_template != template:
            raise ValueError(f"{where} has wrong YAML key.")

        # Check 5: Is the UUID valid?
        # UUID() raises ValueError for a malformed string and TypeError or
        # AttributeError for non-string input; all are re-raised with the
        # dataset/template context that the bare exception lacks.
        try:
            UUID(template_id)
        except (ValueError, TypeError, AttributeError) as e:
            raise ValueError(f"{where} has invalid UUID.") from e

    # Turned off for now until we fix.
    #assert any_original, "There must be at least one original task template for each dataset"