File size: 4,313 Bytes
72a1159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1b5167
 
 
 
 
 
72a1159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52c67ef
 
72a1159
 
 
 
 
 
 
 
 
 
 
 
ee83d59
 
 
 
 
 
 
52c67ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import torch
import transformers
from utils import load_model, static_init
from global_config import GlobalConfig


@static_init
class ModelFactory:
    """Lazy loader/cache for language models and tokenizers.

    Keeps at most one model resident on ``run_device`` at a time; the
    previously active model is moved back to ``load_device`` when a new
    one is requested.  Configuration is read once from ``GlobalConfig``
    by ``__static_init__`` (invoked by the ``@static_init`` decorator).
    """

    # Short model name -> checkpoint identifier (from "models.names" config).
    models_names = {}
    # Instantiated models / tokenizers, keyed by short name.
    models = {}
    tokenizers = {}
    # Short name of the model currently on `run_device`, or None.
    run_model = None
    # Defaults; may be overridden from the "models.params" config section.
    dtype = torch.bfloat16
    load_device = torch.device("cpu")
    run_device = torch.device("cpu")

    @classmethod
    def __static_init__(cls):
        """Populate class state from GlobalConfig (names, dtype, devices)."""
        names_sec = GlobalConfig.get_section("models.names")
        if names_sec is not None:
            for name in names_sec:
                cls.models_names[name] = GlobalConfig.get("models.names", name)

        if GlobalConfig.get_section("models.params") is not None:
            # Dispatch table instead of an if/elif chain; unknown or missing
            # dtype strings keep the class default.
            dtype = GlobalConfig.get("models.params", "dtype")
            cls.dtype = {
                "bfloat16": torch.bfloat16,
                "float16": torch.float16,
                "float32": torch.float32,
            }.get(dtype, cls.dtype)

            load_device = GlobalConfig.get("models.params", "load_device")
            run_device = GlobalConfig.get("models.params", "run_device")
            # Degrade gracefully when the config asks for CUDA that isn't there.
            if not torch.cuda.is_available() and "cuda" in (load_device, run_device):
                print("cuda is not available, use cpu instead")
                load_device = "cpu"
                run_device = "cpu"

            if load_device is not None:
                cls.load_device = torch.device(str(load_device))
            if run_device is not None:
                cls.run_device = torch.device(str(run_device))

    @classmethod
    def __load_model(cls, name):
        """Instantiate (or fetch cached) model+tokenizer for `name`.

        Returns a (model, tokenizer) tuple, or None when `name` is not a
        configured model name.
        """
        if name not in cls.models_names:
            print(f"{name} is not a valid model name")
            return None

        if name not in cls.models:
            model, tokenizer = load_model(
                cls.models_names[name], cls.load_device
            )
            cls.models[name] = model
            cls.tokenizers[name] = tokenizer
        else:
            model, tokenizer = cls.models[name], cls.tokenizers[name]

        return model, tokenizer

    @classmethod
    def load_model(cls, name):
        """Return (model, tokenizer) for `name`, moving it onto `run_device`.

        The previously active model (if different) is parked on
        `load_device`.  Returns (None, None) for an unknown name.
        """
        if name not in cls.models:
            if cls.__load_model(name) is None:
                return None, None

        # Swap: park the currently running model before promoting this one.
        if name != cls.run_model and cls.run_model is not None:
            cls.models[cls.run_model].to(cls.load_device)

        cls.models[name].to(cls.run_device)
        cls.run_model = name

        return cls.models[name], cls.tokenizers[name]

    @classmethod
    def get_models_names(cls):
        """Return the configured short model names."""
        return list(cls.models_names.keys())

    @classmethod
    def get_model_max_length(cls, name: str):
        """Return the tokenizer's model_max_length, or 0 if not yet loaded."""
        if name in cls.tokenizers:
            return cls.tokenizers[name].model_max_length
        else:
            return 0

    @classmethod
    def compute_perplexity(cls, model_name, text):
        """Compute the perplexity of `text` under the named model.

        Uses a sliding window with 50% overlap so long texts are scored
        with (near) full left context.  Returns 0 when the model cannot
        be loaded or the text tokenizes to an empty sequence.

        This code is adapted from
        https://huggingface.co/docs/transformers/perplexity
        """
        model, tokenizer = cls.load_model(model_name)
        if model is None or tokenizer is None:
            return 0
        device = model.device
        encodings = tokenizer(text, return_tensors="pt").to(device)

        # `n_positions` only exists on GPT-2-style configs; most other HF
        # models expose `max_position_embeddings` instead.  Fall back to
        # the tokenizer's limit as a last resort.
        max_length = getattr(
            model.config,
            "n_positions",
            getattr(model.config, "max_position_embeddings",
                    tokenizer.model_max_length),
        )
        stride = max_length // 2
        seq_len = encodings.input_ids.size(1)
        if seq_len == 0:
            # Empty tokenization: no windows, torch.stack below would raise.
            return 0

        nlls = []
        prev_end_loc = 0
        for begin_loc in range(0, seq_len, stride):
            end_loc = min(begin_loc + max_length, seq_len)
            trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
            input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
            target_ids = input_ids.clone()
            # Mask the overlap so each token is scored exactly once.
            target_ids[:, :-trg_len] = -100

            with torch.no_grad():
                outputs = model(input_ids, labels=target_ids)

                # loss is calculated using CrossEntropyLoss which averages over valid labels
                # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
                # to the left by 1.
                neg_log_likelihood = outputs.loss

            nlls.append(neg_log_likelihood)

            prev_end_loc = end_loc
            if end_loc == seq_len:
                break

        ppl = torch.exp(torch.stack(nlls).mean()).item()
        return ppl