future-xy committed
Commit: f0ad559
Parent: 3020792

add base class code

Files changed (1): src/backend/hflm_with_measurement.py (+207, -0)
src/backend/hflm_with_measurement.py ADDED
@@ -0,0 +1,207 @@
+import copy
+import os
+from datetime import timedelta
+from pathlib import Path
+from typing import List, Literal, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import transformers
+from accelerate import (
+    Accelerator,
+    DistributedType,
+    InitProcessGroupKwargs,
+    find_executable_batch_size,
+)
+from packaging import version
+from peft import PeftModel
+from peft import __version__ as PEFT_VERSION
+from tqdm import tqdm
+from transformers.models.auto.modeling_auto import (
+    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
+    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
+)
+
+from lm_eval import utils
+from lm_eval.api.instance import Instance
+from lm_eval.api.model import TemplateLM
+from lm_eval.api.registry import register_model
+from lm_eval.models.utils import (
+    Collator,
+    clear_torch_cache,
+    get_dtype,
+    pad_and_concat,
+    stop_sequences_criteria,
+)
+from lm_eval.models.huggingface import HFLM
+
+
+class HFLMWithMeasurement(HFLM):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def _model_generate(self, context, max_length, stop, **generation_kwargs):
+        # temperature = 0.0 if not set
+        # if do_sample is false and temp==0.0:
+        # remove temperature, as do_sample=False takes care of this
+        # and we don't want a warning from HF
+        generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
+        do_sample = generation_kwargs.get("do_sample", None)
+
+        # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
+        if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
+            generation_kwargs["do_sample"] = do_sample = False
+
+        if do_sample is False and generation_kwargs.get("temperature") == 0.0:
+            generation_kwargs.pop("temperature")
+        # build stopping criteria
+        stopping_criteria = stop_sequences_criteria(
+            self.tokenizer, stop, context.shape[1], context.shape[0]
+        )
+        return self.model.generate(
+            input_ids=context,
+            max_length=max_length,
+            stopping_criteria=stopping_criteria,
+            pad_token_id=self.tokenizer.pad_token_id,
+            use_cache=True,
+            **generation_kwargs,
+        )
+
+    def generate_until(
+        self, requests: List[Instance], disable_tqdm: bool = False
+    ) -> List[str]:
+        res = []
+
+        def _collate(req: Tuple[str, dict]):
+            """Defines the key for the sorted method"""
+            # the negative sign on len(toks) sorts descending - this has a few advantages:
+            # - time estimates will always be over not underestimates, which is more useful for planning
+            # - to know the size of a batch when going through the list, you know the first one is always the batch
+            #   padded context length. this is useful to simplify the batching logic and more importantly to make
+            #   automatic adaptive batches much much easier to implement
+            # - any OOMs will happen right away rather than near the end
+            toks = self.tok_encode(req[0])
+            return -len(toks), req[0]
+
+        pbar = tqdm(
+            total=len(requests),
+            disable=(disable_tqdm or (self.rank != 0)),
+            desc="Running generate_until requests",
+        )
+        adaptive_batch_size = None
+        if self.batch_size == "auto":
+            # using rolling window with maximum context
+            print("Passed argument batch_size = auto. Detecting largest batch size")
+            batch_size = self._detect_batch_size()
+            print(f"Determined Largest batch size: {batch_size}")
+            adaptive_batch_size = batch_size
+        # for each different set of kwargs, we execute all requests, by batch.
+        batch_size = (
+            self.batch_size
+            if self.batch_size != "auto"
+            else adaptive_batch_size
+            if adaptive_batch_size is not None
+            else 0
+        )
+        batch_fn = (
+            self._batch_scheduler
+            if self.batch_size == "auto" and not adaptive_batch_size
+            else None
+        )
+
+        # we group requests by their generation_kwargs,
+        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
+        # in the same batch.
+        # group_fn=lambda x: x[1] -> x=(context, gen_kwargs)
+        re_ords = Collator(
+            [reg.args for reg in requests],
+            sort_fn=_collate,
+            group_by="gen_kwargs",
+            group_fn=lambda x: x[1],
+        )
+        chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
+        for chunk in chunks:
+            contexts, all_gen_kwargs = zip(*chunk)
+            # we assume all gen kwargs in the batch are the same
+            # this is safe to assume because the `grouper` object ensures it.
+            gen_kwargs = all_gen_kwargs[0]
+            # unpack our keyword arguments.
+            until = None
+            if isinstance(gen_kwargs, dict):
+                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
+                if "until" in kwargs.keys():
+                    until = kwargs.pop("until")
+                    if isinstance(until, str):
+                        until = [until]
+                    elif not isinstance(until, list):
+                        raise ValueError(
+                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
+                        )
+            else:
+                raise ValueError(
+                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
+                )
+            # add EOS token to stop sequences
+            eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
+            if not until:
+                until = [eos]
+            else:
+                until.append(eos)
+            if "max_gen_toks" in kwargs.keys():
+                max_gen_toks = kwargs.pop("max_gen_toks")
+            else:
+                max_gen_toks = self.max_gen_toks
+
+            # set the max length in tokens of inputs ("context_enc")
+            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+                # max len for inputs = max length, minus room to generate the max new tokens
+                max_ctx_len = self.max_length - max_gen_toks
+            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
+                # max len for inputs = encoder's whole max_length
+                max_ctx_len = self.max_length
+
+            # encode, pad, and truncate contexts for this batch
+            context_enc, attn_masks = self.tok_batch_encode(
+                contexts,
+                left_truncate_len=max_ctx_len,
+                truncation=self.truncation,
+            )
+            context_enc = context_enc.to(self.device)
+            attn_masks = attn_masks.to(self.device)
+
+            if "max_length" not in kwargs:
+                kwargs["max_length"] = context_enc.shape[1] + max_gen_toks
+
+            # perform batched generation
+            cont = self._model_generate(
+                context=context_enc,
+                attention_mask=attn_masks,
+                stop=until,
+                **kwargs,
+            )
+
+            cont_toks_list = cont.tolist()
+            for cont_toks, context in zip(cont_toks_list, contexts):
+                # discard context + left-padding toks if using causal decoder-only LM
+                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+                    cont_toks = cont_toks[context_enc.shape[1] :]

+                s = self.tok_decode(cont_toks)
+
+                # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
+                for term in until:
+                    if len(term) > 0:
+                        # ignore '' separator,
+                        # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
+                        s = s.split(term)[0]
+
+                res.append(s)
+
+                self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
+                pbar.update(1)
+        # reorder this group of results back to original unsorted form
+        res = re_ords.get_original(res)
+
+        pbar.close()
+
+        return res
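
The added class subclasses HFLM and, in this commit, re-implements `_model_generate` and `generate_until` with essentially the stock lm-evaluation-harness logic, so it can be dropped in wherever an HFLM instance is accepted. Below is a minimal usage sketch, assuming lm-eval's 0.4-style `simple_evaluate` entry point and that the repository root is on PYTHONPATH; the model name, task, and batch size are illustrative placeholders, not part of this commit.

    # Hypothetical usage sketch -- model name, task, and batch size are examples only.
    import lm_eval
    from src.backend.hflm_with_measurement import HFLMWithMeasurement

    # HFLMWithMeasurement forwards all kwargs to HFLM.__init__
    lm = HFLMWithMeasurement(pretrained="gpt2", device="cuda", batch_size=8)

    # simple_evaluate accepts an already-constructed LM object in place of a model-type string
    results = lm_eval.simple_evaluate(model=lm, tasks=["hellaswag"], num_fewshot=0)
    print(results["results"])

Note that `register_model` is imported but not applied here, so the class is not yet selectable by a model-type name; passing an instance directly, as sketched above, is the way to use it until a later commit registers it.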