from typing import Dict, List, Optional, Union

import torch
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                          pipeline)

from opencompass.utils.prompt import PromptList

from .base import BaseModel, LMTemplateParser

PromptType = Union[PromptList, str]


class AlayaLM(BaseModel):
    """Model wrapper for Alaya model.

    Args:
        path (str): The name or path of the Alaya model; either a local
            path or a HuggingFace model tag of Alaya.
        max_seq_len (int): The maximum length of the input sequence. Defaults
            to 2048.
        tokenizer_only (bool): If True, only the tokenizer will be initialized.
            Defaults to False.
        meta_template (Dict, optional): The model's meta prompt template,
            if needed, e.g. for injecting or wrapping meta instructions
            around the prompt.

    Note:
        Alaya requires some generation arguments to be fixed, such as
            eos_token_id and bad_words_ids.
        The model config is loaded from the model's config file.
        Triton is supported to accelerate inference.
        This class supports both the Alaya Base and Alaya Chat models.
    """

    def __init__(self,
                 path: str,
                 max_seq_len: int = 2048,
                 tokenizer_only: bool = False,
                 meta_template: Optional[Dict] = None,
                 **kwargs):

        self.template_parser = LMTemplateParser(meta_template)

        self.max_seq_len = max_seq_len
        self.tokenizer_only = tokenizer_only
        self.meta_template = meta_template

        self.name = path
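        # Fixed generation token ids for Alaya (see the docstring Note).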
        self.eos_token_id = 2
        self.bad_words_ids = 3

        self.gpu_id = '0'

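        # Load the model config from local files and switch to the Triton
        # attention implementation to accelerate inference.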
        self.config = AutoConfig.from_pretrained(self.name,
                                                 trust_remote_code=True,
                                                 local_files_only=True)
        self.config.attn_config['attn_impl'] = 'triton'
        self.config.init_device = 'cuda:' + self.gpu_id

        self.model = AutoModelForCausalLM.from_pretrained(
            self.name,
            config=self.config,
            torch_dtype=torch.bfloat16,  # Load model weights in bfloat16
            trust_remote_code=True,
        )

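        # Left padding keeps prompts aligned for decoder-only generation.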
        self.tokenizer = AutoTokenizer.from_pretrained(self.name,
                                                       local_files_only=True,
                                                       padding_side='left')

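        # Text-generation pipeline with Alaya's fixed eos/bad-words ids.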
        self.pipe = pipeline('text-generation',
                             model=self.model,
                             tokenizer=self.tokenizer,
                             bad_words_ids=[[self.bad_words_ids]],
                             eos_token_id=self.eos_token_id,
                             pad_token_id=self.eos_token_id,
                             device='cuda:' + self.gpu_id)

    def do_inference(self, instruction, history=None, max_new_tokens=100):
        PROMPT_FORMAT = '### Instruction:\t\n{instruction}\n\n'
        OUTPUT_FORMAT = '### Output:\t\n{output} </s>'

        # Avoid a shared mutable default for the conversation history.
        history = history or []

        prompt = PROMPT_FORMAT.format(instruction=instruction)

        history2llm = []

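        # Rebuild the alternating user/assistant history in Alaya's format.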
        for i, msg in enumerate(history):
            if i % 2 == 0:  # user
                msg2llm = PROMPT_FORMAT.format(instruction=msg)
            else:  # alaya
                msg2llm = OUTPUT_FORMAT.format(output=msg)
            history2llm.append(msg2llm)

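        # Marker that precedes the model's reply in the generated text.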
        flag = '### Output:\t\n'
        prompt2LLM = ''.join(history2llm) + prompt

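        # Rough character-based truncation of an overly long prompt/history.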
        if len(prompt2LLM) >= 1500:
            prompt2LLM = prompt2LLM[-1500:]

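        # max_length caps the total tokens (prompt plus newly generated).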
        result = self.pipe(prompt2LLM,
                           max_new_tokens=max_new_tokens,
                           max_length=1900,
                           do_sample=True,
                           use_cache=True,
                           eos_token_id=self.eos_token_id,
                           pad_token_id=self.eos_token_id)

        try:
            # Drop the echoed prompt, then strip the leading output marker.
            output = result[0]['generated_text'][len(prompt2LLM):]
            if output.startswith(flag):
                output = output[len(flag):]
        except Exception:
            output = result[0]['generated_text']
        return output

    def generate(
        self,
        inputs,
        max_out_len: int = 1000,
    ) -> List[str]:
        """Generate results given a list of inputs."""
        outputs = []
        for instruction in inputs:
            output = self.do_inference(instruction=instruction,
                                       max_new_tokens=max_out_len)
            outputs.append(output)
        return outputs

    def get_token_len(self, prompt: str) -> int:
        """Get lengths of the tokenized string."""
        return len(self.tokenizer.encode(prompt))

    def get_ppl(self,
                inputs: List[str],
                mask_length: Optional[List[int]] = None) -> List[float]:
        """Copied from .huggingface.py."""

        assert mask_length is None, 'mask_length is not supported'
        bsz = len(inputs)
        params = self.model.params
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
        # tokenize
        prompt_tokens = [self.tokenizer.encode(x, True, False) for x in inputs]
        max_prompt_size = max([len(t) for t in prompt_tokens])
        total_len = min(params.max_seq_len, max_prompt_size)
        tokens = torch.zeros((bsz, total_len)).cuda().long()
        for k, t in enumerate(prompt_tokens):
            num_token = min(total_len, len(t))
            tokens[k, :num_token] = torch.tensor(t[-num_token:]).long()
        # forward
        outputs = self.model.forward(tokens, 0)
        # compute ppl
        shift_logits = outputs[..., :-1, :].contiguous().float()
        shift_labels = tokens[..., 1:].contiguous()
        shift_logits = shift_logits.view(-1, shift_logits.size(-1))
        shift_labels = shift_labels.view(-1)
        loss_fct = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=0)
        loss = loss_fct(shift_logits, shift_labels).view(bsz, -1)
        lens = (tokens != 0).sum(-1).cpu().numpy()
        ce_loss = loss.sum(-1).cpu().detach().numpy() / lens
        return ce_loss