Gong Baitao
committed on
Commit
•
2156c56
1
Parent(s):
32554d7
Update tokenization_cpmbee.py
Browse files- tokenization_cpmbee.py +130 -0
tokenization_cpmbee.py
CHANGED
@@ -18,6 +18,7 @@ import os
|
|
18 |
from typing import Any, Dict, List, Optional, Tuple, Union
|
19 |
|
20 |
import numpy as np
|
|
|
21 |
from typing_extensions import TypedDict
|
22 |
|
23 |
from transformers.tokenization_utils import PaddingStrategy, PreTrainedTokenizer, TensorType
|
@@ -866,3 +867,132 @@ class CpmBeeTokenizer(PreTrainedTokenizer):
|
|
866 |
)
|
867 |
|
868 |
return batch_outputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
from typing import Any, Dict, List, Optional, Tuple, Union
|
19 |
|
20 |
import numpy as np
|
21 |
+
from numpy.typing import NDArray
|
22 |
from typing_extensions import TypedDict
|
23 |
|
24 |
from transformers.tokenization_utils import PaddingStrategy, PreTrainedTokenizer, TensorType
|
|
|
867 |
)
|
868 |
|
869 |
return batch_outputs
|
870 |
+
|
871 |
+
def prepare_for_finetune(
    self,
    data_list: List[Dict],
    max_length: int = 2048
):
    """
    Encode a list of raw samples into a single padded batch for fine-tuning.

    Every sample is passed through `self.convert_data_to_id`, cut to at most `max_length` positions and copied
    into fixed-width arrays of shape `(len(data_list), max_length)`. Positions past a sample's length keep the
    fill value of their array (0, or -100 for the labels).

    Labels are built by shifting the inputs one step to the left: for `j > 1`, if position `j - 1` is not a
    context position (`context == 0`), its label is the token at position `j`. A token whose sub id is greater
    than 0 is labelled with `self.vocab_size` plus its index in the batch-wide extension table, and a
    `self.bos_token_id` at position `j` is labelled as `self.eos_token_id`. The last position of a sample is
    labelled `self.eos_token_id` when it is not a context position.

    Args:
        data_list (`List[Dict]`):
            Raw samples. Each element is handed unchanged to `self.convert_data_to_id`.
        max_length (`int`, *optional*, defaults to 2048):
            Width of the padded batch; longer samples are truncated to this many positions.

    Returns:
        `BatchEncoding`: built with `tensor_type="pt"` and holding the keys `input_ids`, `input_id_sub`,
        `length`, `context` (boolean, `context > 0`), `sample_ids`, `num_segments`, `segment`,
        `segment_rel_offset`, `segment_rel`, `span`, `labels`, `ext_table_ids` and `ext_table_sub`.
    """
    # First pass: encode and truncate every sample. The batch arrays can only be
    # allocated afterwards because `segment_rel` is padded to the longest entry.
    encoded: List[Tuple[Any, ...]] = []
    for data in data_list:
        (
            sample_input_ids,
            sample_input_id_subs,
            sample_context,
            sample_segment_ids,
            sample_segment_rel,
            n_segments,
            _,
        ) = self.convert_data_to_id(data)
        encoded.append(
            (
                sample_input_ids[:max_length],
                # The sub ids must be truncated together with the ids; otherwise the
                # row assignment below fails for samples longer than `max_length`.
                sample_input_id_subs[:max_length],
                sample_context[:max_length],
                sample_segment_ids[:max_length],
                sample_segment_rel,
                n_segments,
            )
        )

    batch_size = len(encoded)
    inputs = np.zeros((batch_size, max_length), dtype=np.int32)
    inputs_sub = np.zeros((batch_size, max_length), dtype=np.int32)
    context = np.zeros((batch_size, max_length), dtype=np.int8)
    segments = np.zeros((batch_size, max_length), dtype=np.int32)
    num_segments = np.zeros((batch_size, max_length), dtype=np.int32)
    # Each sample is its own single span with sample id 0 and relative offset 0,
    # so these three arrays stay all-zero.
    sample_ids = np.zeros((batch_size, max_length), dtype=np.int32)
    segment_rel_offset = np.zeros((batch_size, max_length), dtype=np.int32)
    spans = np.zeros((batch_size, max_length), dtype=np.int32)
    # -100 marks positions that carry no label.
    tgt = np.full((batch_size, max_length), -100, dtype=np.int32)
    length = np.zeros((batch_size,), dtype=np.int32)

    max_rel = max((sample[4].shape[0] for sample in encoded), default=0)
    segment_rel = np.zeros((batch_size, max_rel), dtype=np.int32)

    # Batch-wide extension table: (token id, sub id) -> index into the table.
    batch_ext_table_map: Dict[Tuple[int, int], int] = {}
    batch_ext_table_ids: List[int] = []
    batch_ext_table_sub: List[int] = []

    for i, (ids, id_subs, ctx, seg_ids, seg_rel, n_segments) in enumerate(encoded):
        instance_length = ids.shape[0]
        inputs[i, :instance_length] = ids
        inputs_sub[i, :instance_length] = id_subs
        context[i, :instance_length] = ctx
        segments[i, :instance_length] = seg_ids
        num_segments[i, :instance_length] = n_segments
        segment_rel[i, : seg_rel.shape[0]] = seg_rel
        length[i] = instance_length

        for j in range(instance_length):
            idx, idx_sub = ids[j], id_subs[j]
            tgt_idx = idx
            if idx_sub > 0:
                # need to be in ext table
                key = (idx, idx_sub)
                if key not in batch_ext_table_map:
                    batch_ext_table_map[key] = len(batch_ext_table_map)
                    batch_ext_table_ids.append(idx)
                    batch_ext_table_sub.append(idx_sub)
                tgt_idx = batch_ext_table_map[key] + self.vocab_size
            # Position j - 1 predicts the token at position j. Because of `j > 1`,
            # this loop never assigns a label to position 0.
            if j > 1 and context[i, j - 1] == 0:
                if idx != self.bos_token_id:
                    tgt[i, j - 1] = tgt_idx
                else:
                    tgt[i, j - 1] = self.eos_token_id
        # Guard against empty samples: index -1 would otherwise wrap around and put
        # an EOS label on the last padding column.
        if instance_length > 0 and context[i, instance_length - 1] == 0:
            tgt[i, instance_length - 1] = self.eos_token_id

    if len(batch_ext_table_map) == 0:
        # placeholder
        batch_ext_table_ids.append(0)
        batch_ext_table_sub.append(1)

    return BatchEncoding({
        "input_ids": inputs,
        "input_id_sub": inputs_sub,
        "length": length,
        "context": context > 0,
        "sample_ids": sample_ids,
        "num_segments": num_segments,
        "segment": segments,
        "segment_rel_offset": segment_rel_offset,
        "segment_rel": segment_rel,
        "span": spans,
        "labels": tgt,
        "ext_table_ids": np.array(batch_ext_table_ids, dtype=np.int32),
        "ext_table_sub": np.array(batch_ext_table_sub, dtype=np.int32)
    }, tensor_type="pt")
|