"""Custom EndpointHandler for a Hugging Face Inference Endpoint: runs a single
causal-LM forward pass and returns the logits together with per-token
log-likelihoods (no text generation)."""

from typing import Any

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# load the weights in bfloat16 to halve memory relative to float32
dtype = torch.bfloat16
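# A hedged alternative (an assumption, not part of the original handler):
# fall back to float32 on hardware without bfloat16 support.
#
#   dtype = (
#       torch.bfloat16
#       if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
#       else torch.float32
#   )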


class EndpointHandler:
    def __init__(self, path=""):
        # load the tokenizer and model; device_map="auto" places the weights
        # on the available GPU(s), falling back to CPU
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path, device_map="auto", torch_dtype=dtype
        )
        # some causal-LM tokenizers ship without a pad token; reuse EOS so
        # that padding (and the ignore_index below) is well defined
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # create an inference pipeline (kept available for text generation,
        # although __call__ below only scores the input)
        self.pipeline = pipeline(
            "text-generation", model=self.model, tokenizer=self.tokenizer
        )
        # per-token cross-entropy; pad positions are excluded from the loss
        self.ce = torch.nn.CrossEntropyLoss(
            ignore_index=self.tokenizer.pad_token_id, reduction="none"
        )

    def compute_log_likelihood(self, lm_logits, input_ids):
        # align predictions with targets: the logits at position i predict
        # token i + 1, so drop the last logit and the first input id
        predictions = lm_logits[..., :-1, :].contiguous()
        target_ids = input_ids[..., 1:].contiguous()

        ce_loss = self.ce(
            predictions.view(-1, predictions.size(-1)),
            target_ids.view(-1),
        )
        # negate the per-token loss to obtain log-likelihoods; [0] selects
        # the single sequence in the batch
        return -ce_loss.view_as(target_ids)[0]

    def __call__(self, data: Any):
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)  # accepted but unused here
        input_tokens = self.tokenizer(
            [inputs], return_tensors="pt", padding=False
        )
        # move the encoded inputs to the model's device; self.model.device
        # also works on CPU-only hosts, where torch.cuda.current_device()
        # would raise
        device = self.model.device
        for t in input_tokens:
            if torch.is_tensor(input_tokens[t]):
                input_tokens[t] = input_tokens[t].to(device)

        # scoring only: a single forward pass without gradient tracking
        with torch.no_grad():
            logits = self.model(
                input_ids=input_tokens["input_ids"],
                attention_mask=input_tokens["attention_mask"],
            ).logits
        log_likelihood = self.compute_log_likelihood(
            logits, input_tokens["input_ids"]
        )
        return (logits, log_likelihood)

# local smoke test (left disabled; uncomment to run):
# if __name__ == "__main__":
#     model = EndpointHandler("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

#     data = {
#         "inputs": "Can you please let us know more details about your ",
#         "parameters": {
#             "no_generation": True,
#             # "function_to_apply": "none",
#             # "return_text": False,
#         },
#     }
#     logits, log_likelihood = model(data)
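
# A minimal follow-up sketch (an illustration, not part of the original
# handler): pair each scored token with its log-likelihood. Position i of
# the vector returned by compute_log_likelihood scores input token i + 1,
# mirroring the shift performed there. The helper name is hypothetical.
def print_token_scores(handler: EndpointHandler, text: str) -> None:
    ids = handler.tokenizer(text, return_tensors="pt")["input_ids"][0]
    _, log_likelihood = handler({"inputs": text})
    # skip the first token: it is never predicted, only conditioned on
    tokens = handler.tokenizer.convert_ids_to_tokens(ids[1:].tolist())
    for tok, lp in zip(tokens, log_likelihood.tolist()):
        print(f"{tok}\t{lp:.3f}")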