import torch
from transformers import AutoTokenizer
from functools import partial

from .configuration_live import LiveConfigMixin

def get_stream_placeholder_len(num_frames: int, model_config: LiveConfigMixin) -> int:
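    """Character (not token) length of the rendered stream placeholder for `num_frames` frames.

    Each frame renders as `frame_num_tokens` copies of `v_placeholder`, and consecutive
    frames are joined by `frame_token_interval`. E.g. with v_placeholder='<v>',
    frame_num_tokens=10, frame_token_interval=',' and num_frames=2:
    2 * 10 * 3 + 1 * 1 = 61 characters.
    """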
    return num_frames * model_config.frame_num_tokens * len(model_config.v_placeholder) + len(model_config.frame_token_interval) * (num_frames - 1)

def get_stream_placeholder_jinja2(model_config: LiveConfigMixin) -> str:
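    """Jinja2 expression that renders the placeholder text for a 'stream' message.

    Evaluates inside the chat template to `message['num_frames']` groups of
    `frame_num_tokens * v_placeholder`, joined by `frame_token_interval`.
    """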
    return f"'{model_config.frame_token_interval}'.join([{model_config.frame_num_tokens} * '{model_config.v_placeholder}'] * message['num_frames'])"

def get_stream_learn_ranges(num_frames: int, model_config: LiveConfigMixin) -> torch.Tensor:
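    """Character ranges of the learnable targets inside a rendered stream placeholder.

    One range per frame, positioned just after that frame's placeholder tokens:
    the `frame_token_interval` between frames (or one `v_placeholder` at the
    boundary when the interval is empty), and for the final frame the position
    of the stream's closing ']' (the caller extends that last range).
    """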
    len_frame_placeholder_with_interval = model_config.frame_num_tokens * len(model_config.v_placeholder) + len(model_config.frame_token_interval)
    intermediate_interval_idxs = torch.arange(
        len_frame_placeholder_with_interval,
        len_frame_placeholder_with_interval * num_frames + 1,
        len_frame_placeholder_with_interval
    ) - len(model_config.frame_token_interval)
    len_learn = len(model_config.frame_token_interval) if model_config.frame_token_interval else len(model_config.v_placeholder)
    learn_ranges = torch.stack([
        intermediate_interval_idxs,
        intermediate_interval_idxs + len_learn
    ], dim=1)
    return learn_ranges

def chat_template(tokenizer, stream_placeholder_jinja2: str):
    """Build the Jinja2 chat template string.

    `tokenizer` is accepted for call-site symmetry but is not used here.
    A rendered conversation looks like:

    system prompt
    [<v>,<v>,<v>]
    User: ...
    Assistant: ...</s>
    [<v>,<v>]
    Assistant: ...</s>
    User: ...
    Assistant: ...</s>
    """
    template = (
        "{% if messages[0]['role'] == 'system' %}"
        "{{ bos_token + messages[0]['content'] + '\n' }}" # system
        "{% set messages = messages[1:] %}"
        "{% endif %}"
        "{% for message in messages %}"
        "{% if message['role'] == 'user' %}"
        "{% if add_stream_query_prompt %}"
        "{{ ']\nUser: ' + message['content'] }}"
        "{% else %}"
        "{{ '\nUser: ' + message['content'] }}"
        "{% endif %}"
        "{% elif message['role'] == 'assistant' %}"
        "{{ '\nAssistant: '  + message['content'] + eos_token }}"
        "{% elif message['role'] == 'stream' and message['num_frames'] > 0: %}"
        "{{ '\n[' + STREAM_PLACEHOLDER + ']' }}"
        "{% endif %}"
        "{% endfor %}"
        "{% if add_generation_prompt %}"
        "{{ '\nAssistant:' }}"
        "{% elif add_stream_prompt %}"
        "{{ '\n[' }}"
        "{% elif add_stream_generation_prompt %}"
        "{{ ']\nAssistant:' }}"
        "{% endif %}"
    )
    template = template.replace('STREAM_PLACEHOLDER', stream_placeholder_jinja2)
    return template

def chat_template_transition(tokenizer):
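    """Literal text the chat template inserts at each (previous_role, role) transition.

    Used via `chat_template_offsets` to replay the template's character offsets
    without rendering it; the 'assistant' and 'eos_token' entries are helpers
    for locating the span of an assistant reply.
    """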
    return {
        (None, 'system'): tokenizer.bos_token,
        ('system', 'user'): '\n\nUser: ',
        ('system', 'stream'): '\n\n[',
        ('user', 'assistant'): '\nAssistant: ',
        ('user', 'stream'): '\n[',
        ('user', 'user'): '\nUser: ',
        ('assistant', 'user'): f'{tokenizer.eos_token}\nUser: ',
        ('assistant', 'stream'): f'{tokenizer.eos_token}\n[',
        ('stream', 'user'): ']\nUser: ',
        ('stream', 'assistant'): ']\nAssistant: ',
        'assistant': 'Assistant: ',
        'eos_token': tokenizer.eos_token,
    }

def chat_template_offsets(tokenizer):
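    """Character length of each transition string in `chat_template_transition`."""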
    return {k:len(v) for k, v in chat_template_transition(tokenizer).items()}

def get_learn_ranges(conversation: list[dict], *, chat_template_offsets: dict[tuple, int], model_config: LiveConfigMixin):
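    """Character ranges of the rendered prompt that should be supervised.

    Walks the conversation in template order, accumulating character offsets,
    and collects ranges for (a) stream messages with `learn` set (all frame
    boundaries, or only the first `learn` of them when `learn` is an int), and
    (b) assistant messages with `learn=True` (the full 'Assistant: ...' reply
    plus the eos token).
    """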
    offset = 0
    learn_ranges = []
    last_role = None
    for message in conversation:
        role = message['role']
        offset += chat_template_offsets[(last_role, role)]
        last_role = role
        if role == 'stream':
            if message.get('learn', False):
                ranges = get_stream_learn_ranges(message['num_frames'], model_config) + offset
                # the final range falls on the closing ']'; extend it by 1 to also learn the '\n' that follows
                ranges[-1, 1] += 1
                if not isinstance(message['learn'], bool):
                    ranges = ranges[:message['learn']]
                learn_ranges.extend([range(r[0], r[1]) for r in ranges])
            offset += get_stream_placeholder_len(message['num_frames'], model_config)
        else:
            if role == 'assistant':
                if message.get('learn', False):
                    learn_ranges.append(range(
                        offset - chat_template_offsets['assistant'],
                        offset + len(message['content']) + chat_template_offsets['eos_token'],
                    ))
            offset += len(message['content'])
    return learn_ranges

def build_live_tokenizer_and_update_config(llm_pretrained: str, model_config: LiveConfigMixin) -> AutoTokenizer:
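    """Load the LLM tokenizer, register the vision placeholder, and sync the config.

    Adds `v_placeholder` as a special token (so its id is len(tokenizer) - 1),
    writes the placeholder/interval/eos ids back into `model_config`, installs
    the streaming chat template, and attaches `get_learn_ranges` to the
    tokenizer for convenience.
    """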
    tokenizer = AutoTokenizer.from_pretrained(llm_pretrained, use_fast=True, padding_side='left')
    tokenizer.add_special_tokens({'additional_special_tokens': [model_config.v_placeholder]})
    v_placeholder_id = len(tokenizer) - 1
    if model_config.frame_token_interval:
        frame_token_interval_id = tokenizer.convert_tokens_to_ids(model_config.frame_token_interval)
    else:
        frame_token_interval_id = None
    tokenizer.pad_token = tokenizer.eos_token
    model_config.update(dict(v_placeholder_id=v_placeholder_id, frame_token_interval_id=frame_token_interval_id, eos_token_id=tokenizer.eos_token_id))
    tokenizer.chat_template = chat_template(tokenizer, get_stream_placeholder_jinja2(model_config))
    tokenizer.get_learn_ranges = partial(get_learn_ranges, chat_template_offsets=chat_template_offsets(tokenizer), model_config=model_config)
    return tokenizer

if __name__ == '__main__':
    config = LiveConfigMixin(frame_token_interval=',', frame_token_cls=True, frame_token_pooled=[3,3], frame_num_tokens=10)
    tokenizer = build_live_tokenizer_and_update_config('meta-llama/Meta-Llama-3-8B-Instruct', config)
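    # 'learn' on a stream message may be True/False or an int N
    # (supervise only the first N frame boundaries of that stream);
    # on an assistant message it supervises the whole reply plus eos.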
    chat = [
        {'role': 'system', 'content': 'cool.'},
        {'role': 'stream', 'num_frames': 2, 'learn': 1},
        {'role': 'user', 'content': 'cool?'},
        {'role': 'assistant', 'content': 'cool.', 'learn': True},
        {'role': 'stream', 'num_frames': 3, 'learn': 3},
        {'role': 'assistant', 'content': 'so cool.', 'learn': True},
    ]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
    learn_ranges = tokenizer.get_learn_ranges(chat)
    batch = tokenizer([prompt], return_offsets_mapping=True, add_special_tokens=False, return_tensors="pt", padding=True)
    batch_labels = torch.full_like(batch.input_ids, -100, dtype=torch.long)
    for text, labels, input_ids, offset_mapping, learn_range in zip(
        [prompt], batch_labels, batch.input_ids, batch.offset_mapping, [learn_ranges]
    ):
        for learn_r in learn_range:
            start = torch.nonzero(offset_mapping[:,0] == learn_r.start).item()
            if offset_mapping[:,0][-1] >= learn_r.stop:
                stop = torch.nonzero(offset_mapping[:,0] == learn_r.stop).item()
            else: # learn_r.stop lies past the last offset (the trailing eos token): take everything to the end
                stop = len(input_ids)
            labels[start-1:stop-1] = input_ids[start:stop]
            # NOTE: labels may contain v_placeholder_id (== len(tokenizer) - 1, the added vision
            # placeholder), since some frame boundaries have a placeholder as the next-token target;
            # map those targets to the eos token instead.
            labels[labels >= len(tokenizer) - 1] = tokenizer.eos_token_id
    print(batch.input_ids)
    print(batch_labels)