# Compression / core_utils_llmlingua2.py
# Uploaded by qminh369 ("Upload 4 files", commit 10f85ab, 5.16 kB).
import os
import random
import string
import numpy as np
import torch
from torch.utils.data import Dataset
class TokenClfDataset(Dataset):  # Custom dataset for token classification
    """Map-style dataset producing fixed-length token ids and attention masks.

    Each item is ``texts[index]`` tokenized with ``tokenizer``, wrapped in the
    model's CLS/SEP special tokens, then truncated or right-padded with the
    PAD token to exactly ``max_len`` tokens.

    Args:
        texts: Indexable, sized sequence of input strings.
        max_len: Length of every returned sequence (author's note: 256 for
            PhoBERT, 512 for XLM-RoBERTa). Assumed to be >= 2 so that CLS and
            SEP both fit.
        tokenizer: Object providing ``tokenize(text) -> list[str]`` and
            ``convert_tokens_to_ids(list[str]) -> list[int]``.
        model_name: Chooses the special-token strings by substring match,
            checked in this order: ``"m_bert"`` -> BERT-style (``[CLS]``,
            ``[SEP]``, ...); ``"xlm-roberta"`` or ``"phobert"`` ->
            RoBERTa-style (``<s>``, ``</s>``, ...).

    Raises:
        NotImplementedError: If ``model_name`` matches none of the above.
    """

    # BERT-style special tokens (names containing "m_bert").
    _BERT_SPECIAL_TOKENS = {
        "cls_token": "[CLS]",
        "sep_token": "[SEP]",
        "unk_token": "[UNK]",
        "pad_token": "[PAD]",
        "mask_token": "[MASK]",
    }
    # RoBERTa-style special tokens shared by XLM-RoBERTa (any size) and PhoBERT.
    _ROBERTA_SPECIAL_TOKENS = {
        "bos_token": "<s>",
        "eos_token": "</s>",
        "sep_token": "</s>",
        "cls_token": "<s>",
        "unk_token": "<unk>",
        "pad_token": "<pad>",
        "mask_token": "<mask>",
    }

    def __init__(
        self,
        texts,
        max_len=512,  # 256 (phobert) 512 (xlm-roberta)
        tokenizer=None,
        model_name="m_bert",
    ):
        self.len = len(texts)
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.model_name = model_name
        if "m_bert" in model_name:
            special_tokens = self._BERT_SPECIAL_TOKENS
        elif "xlm-roberta" in model_name or "phobert" in model_name:
            special_tokens = self._ROBERTA_SPECIAL_TOKENS
        else:
            # Fail here instead of with an AttributeError on the first item.
            raise NotImplementedError(
                f"TokenClfDataset: unsupported model_name {model_name!r}"
            )
        # Expose each special token as an instance attribute
        # (self.cls_token, self.sep_token, self.pad_token, ...).
        for attr_name, token in special_tokens.items():
            setattr(self, attr_name, token)

    def __getitem__(self, index):
        """Return ``{"ids": LongTensor[max_len], "mask": LongTensor[max_len]}``."""
        text = self.texts[index]
        tokenized_text = (
            [self.cls_token] + self.tokenizer.tokenize(text) + [self.sep_token]
        )  # add special tokens
        if len(tokenized_text) > self.max_len:
            # Truncate the content but keep SEP as the final token so the
            # encoder still sees a properly terminated sequence.
            tokenized_text = tokenized_text[: self.max_len - 1] + [self.sep_token]
        n_real = len(tokenized_text)
        n_pad = self.max_len - n_real
        tokenized_text = tokenized_text + [self.pad_token] * n_pad
        # Mask by position rather than by comparing strings with pad_token,
        # so a literal pad-token string present in the text is not masked.
        attn_mask = [1] * n_real + [0] * n_pad
        ids = self.tokenizer.convert_tokens_to_ids(tokenized_text)
        return {
            "ids": torch.tensor(ids, dtype=torch.long),
            "mask": torch.tensor(attn_mask, dtype=torch.long),
        }

    def __len__(self):
        """Return the number of texts in the dataset."""
        return self.len
def seed_everything(seed: int):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Value passed to every seeding call below.

    Side effects:
        * sets ``os.environ["PYTHONHASHSEED"]`` to ``str(seed)`` (per the
          Python docs this variable is read at interpreter start-up, so it
          affects child processes started afterwards, not the current one);
        * seeds ``random``, ``numpy.random``, ``torch`` and ``torch.cuda``;
        * sets ``torch.backends.cudnn.deterministic = True`` and
          ``torch.backends.cudnn.benchmark = False``.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
def is_begin_of_new_word(token, model_name, force_tokens, token_map):
    """Return True if ``token`` should start a new word when merging subwords.

    Args:
        token: A single subword token string.
        model_name: Selects the subword convention by substring match,
            checked in this order: ``"m_bert"``, ``"xlm-roberta"``,
            ``"phobert"``.
        force_tokens: Container of tokens that always start a new word.
        token_map: Dict mapping original strings to placeholder tokens; any
            placeholder (a dict value) always starts a new word.

    Returns:
        bool.

    Raises:
        NotImplementedError: For an unrecognised ``model_name``.
    """
    # Membership test on the values view; no need to build a set per call.
    placeholders = token_map.values()

    if "m_bert" in model_name:
        # WordPiece continuation pieces carry a single "##" prefix. Remove
        # exactly that prefix: str.lstrip("##") strips *every* leading "#",
        # which would also eat a literal "#" character.
        pure_token = token[2:] if token.startswith("##") else token
        if pure_token in force_tokens or pure_token in placeholders:
            return True
        return not token.startswith("##")

    if "xlm-roberta" in model_name:
        # SentencePiece: a leading "▁" marks a piece that begins a new word.
        starts_word = token.startswith("▁")
    elif "phobert" in model_name:
        # Pieces ending with "@@" are treated as not starting a word.
        # NOTE(review): PhoBERT's BPE appends "@@" to a piece that continues
        # into the *next* piece, so this test really flags the last piece of
        # a word; whether a token starts a word depends on the previous
        # token. Confirm against the caller's merge logic.
        starts_word = not token.endswith("@@")
    else:
        raise NotImplementedError(
            f"is_begin_of_new_word: unsupported model_name {model_name!r}"
        )

    # NOTE: string.punctuation is a str, so `in` is a substring test. It is
    # True for single punctuation characters, but also for "" and for
    # contiguous runs such as "()". Kept as-is to preserve behavior.
    return (
        token in string.punctuation
        or token in force_tokens
        or token in placeholders
        or starts_word
    )
def replace_added_token(token, token_map):
    """Undo placeholder substitution in ``token``.

    ``token_map`` maps an original string to the placeholder that stood in
    for it. Every occurrence of each placeholder in ``token`` is replaced by
    its original string. Pairs are applied in the dict's iteration order,
    each acting on the result of the previous replacement.
    """
    restored = token
    for original_text, placeholder in token_map.items():
        if placeholder in restored:
            restored = restored.replace(placeholder, original_text)
    return restored
def get_pure_token(token, model_name):
    """Return ``token`` with its subword marker removed.

    Args:
        token: A single subword token string.
        model_name: Selects the marker by substring match, checked in this
            order: ``"m_bert"`` -> one leading ``"##"``; ``"xlm-roberta"`` ->
            leading ``"▁"`` characters; ``"phobert"`` -> one trailing ``"@@"``.

    Returns:
        The token text without the marker.

    Raises:
        NotImplementedError: For an unrecognised ``model_name``.
    """
    if "m_bert" in model_name:
        # Remove exactly one "##" prefix. str.lstrip("##") strips every
        # leading "#", turning a genuine "#" token into "".
        return token[2:] if token.startswith("##") else token
    if "xlm-roberta" in model_name:
        # Drop the SentencePiece word-start marker "▁" from the left.
        return token.lstrip("▁")
    if "phobert" in model_name:
        # Remove exactly one "@@" suffix. str.rstrip("@@") strips every
        # trailing "@", e.g. "@@@" -> "" instead of "@".
        return token[:-2] if token.endswith("@@") else token
    raise NotImplementedError(
        f"get_pure_token: unsupported model_name {model_name!r}"
    )