import ast
import json
import os
import pickle
import signal
import threading
import time
import zipfile
from functools import wraps

import gdown
import numpy as np
import requests
import torch
import tqdm
from autocuda import auto_cuda, auto_cuda_name
from findfile import find_files, find_cwd_file, find_file
from termcolor import colored
from update_checker import parse_version

from anonymous_demo import __version__


def save_args(config, save_path):
    """Write every argument that has been accessed at least once to a plain-text file."""
    with open(save_path, mode="w", encoding="utf8") as f:
        for arg in config.args:
            if config.args_call_count[arg]:
                f.write("{}: {}\n".format(arg, config.args[arg]))


def print_args(config, logger=None, mode=0):
    """Print (or log) every argument together with how many times it was accessed."""
    for arg in sorted(config.args.keys()):
        line = "{0}:{1}\t-->\tCalling Count:{2}".format(
            arg, config.args[arg], config.args_call_count[arg]
        )
        if logger:
            logger.info(line)
        else:
            print(line)
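
# Hedged usage sketch (assumes a config object exposing ``args`` and
# ``args_call_count`` dicts, which is what save_args/print_args expect):
#
#   print_args(config)                           # dump to stdout
#   save_args(config, "run/training_args.txt")   # persist to disk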


def check_and_fix_labels(label_set: set, label_name, all_data, opt):
    """Map raw labels to contiguous indices (keeping -100 as the ignore index)
    and rewrite the labels of ``all_data`` in place."""
    if "-100" in label_set:
        # Skip the ignore label so the remaining labels start at index 0.
        label_to_index = {
            origin_label: int(idx) - 1 if origin_label != "-100" else -100
            for origin_label, idx in zip(sorted(label_set), range(len(label_set)))
        }
        index_to_label = {
            int(idx) - 1 if origin_label != "-100" else -100: origin_label
            for origin_label, idx in zip(sorted(label_set), range(len(label_set)))
        }
    else:
        label_to_index = {
            origin_label: int(idx)
            for origin_label, idx in zip(sorted(label_set), range(len(label_set)))
        }
        index_to_label = {
            int(idx): origin_label
            for origin_label, idx in zip(sorted(label_set), range(len(label_set)))
        }
    if "index_to_label" not in opt.args:
        opt.index_to_label = index_to_label
        opt.label_to_index = label_to_index
    if opt.index_to_label != index_to_label:
        opt.index_to_label.update(index_to_label)
        opt.label_to_index.update(label_to_index)
    num_label = {l: 0 for l in label_set}
    num_label["Sum"] = len(all_data)
    for item in all_data:
        try:
            num_label[item[label_name]] += 1
            item[label_name] = label_to_index[item[label_name]]
        except Exception:
            # Some dataset items store the label as a ``polarity`` attribute
            # instead of a dict entry; fall back to that.
            num_label[item.polarity] += 1
            item.polarity = label_to_index[item.polarity]
    print("Dataset Label Details: {}".format(num_label))


def check_and_fix_IOB_labels(label_map, opt):
    """Expose the inverse IOB mapping (index -> tag) on ``opt`` for decoding."""
    opt.index_to_IOB_label = {
        int(label_map[origin_label]): origin_label for origin_label in label_map
    }


def get_device(auto_device):
    """Resolve ``auto_device`` (bool, device string, or "allcuda") to a usable
    torch device string plus a human-readable device name."""
    if isinstance(auto_device, str) and auto_device == "allcuda":
        device = "cuda"
    elif isinstance(auto_device, str):
        device = auto_device
    elif isinstance(auto_device, bool):
        device = auto_cuda() if auto_device else "cpu"
    else:
        # Any other value: let autocuda pick the best available GPU.
        device = auto_cuda()
    try:
        torch.device(device)  # validate the device string
    except RuntimeError as e:
        print(
            colored("Device assignment error: {}, redirect to CPU".format(e), "red")
        )
        device = "cpu"
    device_name = auto_cuda_name()
    return device, device_name
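
# Hedged usage sketch:
#
#   device, device_name = get_device(True)   # e.g. ("cuda:0", "NVIDIA ...")
#   model.to(device)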


def _load_word_vec(path, word2idx=None, embed_dim=300):
    """Read a GloVe-style text file and return vectors for the words in ``word2idx``."""
    word_vec = {}
    with open(path, "r", encoding="utf-8", newline="\n", errors="ignore") as fin:
        for line in tqdm.tqdm(fin, postfix="Loading embedding file..."):
            tokens = line.rstrip().split()
            # The last ``embed_dim`` tokens are the vector; everything before is the word.
            word, vec = " ".join(tokens[:-embed_dim]), tokens[-embed_dim:]
            if word2idx is None or word in word2idx:
                word_vec[word] = np.asarray(vec, dtype="float32")
    return word_vec


def build_embedding_matrix(word2idx, embed_dim, dat_fname, opt):
    """Load a cached embedding matrix if present; otherwise build one from GloVe
    vectors and cache it under ``run/<dataset_name>/``."""
    embed_matrix_path = "run/{}".format(os.path.join(opt.dataset_name, dat_fname))
    # Make sure the dataset-specific cache directory exists before reading/writing.
    os.makedirs(os.path.dirname(embed_matrix_path), exist_ok=True)
    if os.path.exists(embed_matrix_path):
        print(
            colored(
                "Loading cached embedding_matrix from {} (Please remove all cached files if there is any problem!)".format(
                    embed_matrix_path
                ),
                "green",
            )
        )
        with open(embed_matrix_path, "rb") as f:
            embedding_matrix = pickle.load(f)
    else:
        glove_path = prepare_glove840_embedding(embed_matrix_path)
        # The two extra rows stay zero and serve as padding / out-of-vocabulary vectors.
        embedding_matrix = np.zeros((len(word2idx) + 2, embed_dim))
        word_vec = _load_word_vec(glove_path, word2idx=word2idx, embed_dim=embed_dim)
        for word, i in tqdm.tqdm(
            word2idx.items(),
            postfix=colored("Building embedding_matrix {}".format(dat_fname), "yellow"),
        ):
            vec = word_vec.get(word)
            if vec is not None:
                embedding_matrix[i] = vec
        with open(embed_matrix_path, "wb") as f:
            pickle.dump(embedding_matrix, f)
    return embedding_matrix
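
# Hedged usage sketch (``tokenizer.word2idx`` is an assumed attribute name,
# standing in for whatever vocabulary mapping the caller has built):
#
#   embedding_matrix = build_embedding_matrix(
#       word2idx=tokenizer.word2idx,
#       embed_dim=300,
#       dat_fname="glove_300d.dat",
#       opt=opt,
#   )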


def pad_and_truncate(
    sequence, maxlen, dtype="int64", padding="post", truncating="post", value=0
):
    """Truncate ``sequence`` to ``maxlen`` and pad it with ``value`` on the
    requested side, returning a numpy array of length ``maxlen``."""
    x = (np.ones(maxlen) * value).astype(dtype)
    if truncating == "pre":
        trunc = sequence[-maxlen:]
    else:
        trunc = sequence[:maxlen]
    trunc = np.asarray(trunc, dtype=dtype)
    if padding == "post":
        x[: len(trunc)] = trunc
    else:
        x[-len(trunc):] = trunc
    return x
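
# Hedged usage sketch:
#
#   >>> pad_and_truncate([1, 2, 3], maxlen=5)
#   array([1, 2, 3, 0, 0])
#   >>> pad_and_truncate([1, 2, 3, 4, 5, 6], maxlen=5, truncating="pre")
#   array([2, 3, 4, 5, 6])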


class TransformerConnectionError(ValueError):
    """Connection failure around transformer downloads; caught by ``retry`` below."""

    def __init__(self):
        pass


def retry(f):
    """Retry ``f`` up to five times on connection-related errors, sleeping 60 s
    between attempts. Returns None if every attempt fails."""

    @wraps(f)
    def decorated(*args, **kwargs):
        count = 5
        while count:
            try:
                return f(*args, **kwargs)
            except (
                TransformerConnectionError,
                requests.exceptions.RequestException,
                requests.exceptions.ConnectionError,
                requests.exceptions.HTTPError,
                requests.exceptions.ConnectTimeout,
                requests.exceptions.ProxyError,
                requests.exceptions.SSLError,
                requests.exceptions.BaseHTTPError,
            ) as e:
                print(
                    colored("Training Exception: {}, will retry later".format(e), "red")
                )
                time.sleep(60)
                count -= 1

    return decorated
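
# Hedged usage sketch:
#
#   @retry
#   def download_checkpoint(url):
#       return requests.get(url, timeout=30)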


def save_json(dic, save_path):
    """Serialize ``dic`` to ``save_path`` as a single-line JSON document."""
    if isinstance(dic, str):
        # Accept a stringified dict; literal_eval is safer than eval here.
        dic = ast.literal_eval(dic)
    with open(save_path, "w", encoding="utf-8") as f:
        f.write(json.dumps(dic, ensure_ascii=False))


def load_json(save_path):
    """Load the JSON document written by ``save_json``."""
    with open(save_path, "r", encoding="utf-8") as f:
        data = f.read().strip()
    return json.loads(data)
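
# Hedged usage sketch (round trip):
#
#   save_json({"acc": 0.91, "f1": 0.89}, "run/metrics.json")
#   metrics = load_json("run/metrics.json")  # -> {"acc": 0.91, "f1": 0.89}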


def init_optimizer(optimizer):
    """Resolve an optimizer given as a name string or a ``torch.optim`` class."""
    optimizers = {
        "adadelta": torch.optim.Adadelta,  # default lr=1.0
        "adagrad": torch.optim.Adagrad,  # default lr=0.01
        "adam": torch.optim.Adam,  # default lr=0.001
        "adamax": torch.optim.Adamax,  # default lr=0.002
        "asgd": torch.optim.ASGD,  # default lr=0.01
        "rmsprop": torch.optim.RMSprop,  # default lr=0.01
        "sgd": torch.optim.SGD,
        "adamw": torch.optim.AdamW,
        torch.optim.Adadelta: torch.optim.Adadelta,
        torch.optim.Adagrad: torch.optim.Adagrad,
        torch.optim.Adam: torch.optim.Adam,
        torch.optim.Adamax: torch.optim.Adamax,
        torch.optim.ASGD: torch.optim.ASGD,
        torch.optim.RMSprop: torch.optim.RMSprop,
        torch.optim.SGD: torch.optim.SGD,
        torch.optim.AdamW: torch.optim.AdamW,
    }
    if optimizer in optimizers:
        return optimizers[optimizer]
    # Accept any other optimizer class that lives in torch.optim.
    elif hasattr(optimizer, "__name__") and hasattr(torch.optim, optimizer.__name__):
        return optimizer
    else:
        raise KeyError(
            "Unsupported optimizer: {}. Please use a name string or an optimizer class from torch.optim".format(
                optimizer
            )
        )
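
# Hedged usage sketch:
#
#   opt_class = init_optimizer("adamw")   # -> torch.optim.AdamW
#   optimizer = opt_class(model.parameters(), lr=2e-5)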