# -*- coding:utf-8 -*-
import os
import sys
import shutil
from tqdm import tqdm
import yaml
import random
import importlib
from PIL import Image
import imageio
import numpy as np
import cv2
import torch
from torchvision import utils
from scipy.interpolate import PchipInterpolator

def split_filename(filename):
    absname = os.path.abspath(filename)
    dirname, basename = os.path.split(absname)
    split_tmp = basename.rsplit('.', maxsplit=1)
    if len(split_tmp) == 2:
        rootname, extname = split_tmp
    elif len(split_tmp) == 1:
        rootname = split_tmp[0]
        extname = None
    else:
        raise ValueError("programming error!")
    return dirname, rootname, extname
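
# Illustrative usage (not part of the original module):
#   split_filename('/data/out/frame.png')  ->  ('/data/out', 'frame', 'png')
#   split_filename('/data/out/README')     ->  ('/data/out', 'README', None)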

def data2file(data, filename, type=None, override=False, printable=False, **kwargs):
    """Save `data` to `filename`, dispatching on the file extension (or an explicit `type`)."""
    dirname, rootname, extname = split_filename(filename)
    print_did_not_save_flag = True
    if type:
        extname = type
    if not os.path.exists(dirname):
        os.makedirs(dirname, exist_ok=True)
    if not os.path.exists(filename) or override:
        if extname in ['jpg', 'png', 'jpeg']:
            utils.save_image(data, filename, **kwargs)
        elif extname == 'gif':
            imageio.mimsave(filename, data, format='GIF', duration=kwargs.get('duration'), loop=0)
        elif extname == 'txt':
            max_step = kwargs.get('max_step')
            if max_step is None:
                max_step = np.inf
            with open(filename, 'w', encoding='utf-8') as f:
                for i, e in enumerate(data):
                    if i < max_step:
                        f.write(str(e) + '\n')
                    else:
                        break
        else:
            raise ValueError('Do not support this type')
        if printable:
            print('Saved data to %s' % os.path.abspath(filename))
    else:
        if print_did_not_save_flag:
            print('Did not save data to %s because file exists and override is False' % os.path.abspath(filename))
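
# Illustrative usage (a sketch, not from the original code): the writer is picked from the
# extension, so an image tensor goes to PNG via torchvision and a list of uint8 frames to GIF
# via imageio. `image_tensor`, `frame_list`, and the 'outputs/...' paths are hypothetical.
#   data2file(image_tensor, 'outputs/sample.png', override=True)
#   data2file(frame_list, 'outputs/sample.gif', duration=0.1, override=True)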

def file2data(filename, type=None, printable=True, **kwargs):
    """Load data from `filename`, dispatching on the file extension (or an explicit `type`)."""
    dirname, rootname, extname = split_filename(filename)
    print_load_flag = True
    if type:
        extname = type
    if extname in ['pth', 'ckpt', 'bin']:
        data = torch.load(filename, map_location=kwargs.get('map_location'))
        if "state_dict" in data.keys():
            data = data["state_dict"]
        data = {k.replace("_forward_module.", ""): v for k, v in data.items()}
    elif extname == 'txt':
        top = kwargs.get('top', None)
        with open(filename, encoding='utf-8') as f:
            if top:
                data = [f.readline() for _ in range(top)]
            else:
                data = [e for e in f.read().split('\n') if e]
    elif extname == 'yaml':
        with open(filename, 'r') as f:
            data = yaml.safe_load(f)
    else:
        raise ValueError('type can only support pth, ckpt, bin, txt, yaml')
    if printable:
        if print_load_flag:
            print('Loaded data from %s' % os.path.abspath(filename))
    return data
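
# Illustrative usage (a sketch, not from the original code): load a checkpoint onto CPU and
# read a YAML config; 'ckpt/model.pth' and 'configs/train.yaml' are hypothetical paths.
#   state_dict = file2data('ckpt/model.pth', map_location='cpu')
#   cfg = file2data('configs/train.yaml')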

def ensure_dirname(dirname, override=False):
    if os.path.exists(dirname) and override:
        print('Removing dirname: %s' % os.path.abspath(dirname))
        try:
            shutil.rmtree(dirname)
        except OSError as e:
            raise ValueError('Failed to delete %s because %s' % (dirname, e))
    if not os.path.exists(dirname):
        print('Making dirname: %s' % os.path.abspath(dirname))
        os.makedirs(dirname, exist_ok=True)

def import_filename(filename):
    spec = importlib.util.spec_from_file_location("mymodule", filename)
    module = importlib.util.module_from_spec(spec)
    sys.modules[spec.name] = module
    spec.loader.exec_module(module)
    return module
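
# Illustrative usage (a sketch, not from the original code): import a Python file by path and
# access its attributes; 'configs/model_config.py' and `Net` are hypothetical names.
#   config_module = import_filename('configs/model_config.py')
#   model = config_module.Net()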

def adaptively_load_state_dict(target, state_dict):
    """Load `state_dict` into `target`, keeping only keys whose shapes match and reporting the rest."""
    target_dict = target.state_dict()
    try:
        common_dict = {k: v for k, v in state_dict.items() if k in target_dict and v.size() == target_dict[k].size()}
        # unmatch_dict = {k: v for k, v in state_dict.items() if k not in target_dict or v.size() != target_dict[k].size()}
    except Exception as e:
        print('load error %s' % e)
        common_dict = {k: v for k, v in state_dict.items() if k in target_dict}
    if 'param_groups' in common_dict and common_dict['param_groups'][0]['params'] != \
            target.state_dict()['param_groups'][0]['params']:
        print('Detected mismatched params, auto-adapting state_dict to the current target')
        common_dict['param_groups'][0]['params'] = target.state_dict()['param_groups'][0]['params']
    target_dict.update(common_dict)
    target.load_state_dict(target_dict)
    missing_keys = [k for k in target_dict.keys() if k not in common_dict]
    unexpected_keys = [k for k in state_dict.keys() if k not in common_dict]
    if len(unexpected_keys) != 0:
        print(f"Some weights of state_dict were not used in target: {unexpected_keys}")
    if len(missing_keys) != 0:
        print(f"Some weights required by target are missing from state_dict: {missing_keys}")
    if len(unexpected_keys) == 0 and len(missing_keys) == 0:
        print("Strictly loaded state_dict.")

def set_seed(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

def image2pil(filename):
    return Image.open(filename)

def image2arr(filename):
    pil = image2pil(filename)
    return pil2arr(pil)

def pil2arr(pil):
    if isinstance(pil, list):
        arr = np.array(
            [np.array(e.convert('RGB').getdata(), dtype=np.uint8).reshape(e.size[1], e.size[0], 3) for e in pil])
    else:
        arr = np.array(pil)
    return arr

def arr2pil(arr):
    if arr.ndim == 3:
        return Image.fromarray(arr.astype('uint8'), 'RGB')
    elif arr.ndim == 4:
        return [Image.fromarray(e.astype('uint8'), 'RGB') for e in list(arr)]
    else:
        raise ValueError('arr must have ndim of 3 or 4, but got %s' % arr.ndim)
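
# Illustrative round trip (not part of the original module): a single image maps to an
# (H, W, 3) uint8 array and a frame list to (T, H, W, 3); arr2pil reverses both cases.
#   arr = image2arr('example.png')   # hypothetical path, shape (H, W, 3)
#   pil = arr2pil(arr)               # back to a PIL.Image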

def interpolate_trajectory(points, n_points):
    x = [point[0] for point in points]
    y = [point[1] for point in points]
    t = np.linspace(0, 1, len(points))
    fx = PchipInterpolator(t, x)
    fy = PchipInterpolator(t, y)
    new_t = np.linspace(0, 1, n_points)
    new_x = fx(new_t)
    new_y = fy(new_t)
    new_points = list(zip(new_x, new_y))
    return new_points
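
# Illustrative usage (a sketch, not from the original code): PCHIP interpolation densifies a
# sparse drag track to one point per output frame without overshooting between control points,
# e.g. for a 16-frame model:
#   sparse_track = [(10.0, 20.0), (60.0, 30.0), (120.0, 90.0)]   # hypothetical points
#   dense_track = interpolate_trajectory(sparse_track, n_points=16)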

def visualize_drag(background_image_path, splited_tracks, drag_mode, width, height, model_length):
    # Red arrows for object drags, blue arrows for camera drags.
    if drag_mode == 'object':
        color = (255, 0, 0, 255)
    elif drag_mode == 'camera':
        color = (0, 0, 255, 255)
    else:
        raise ValueError('drag_mode must be "object" or "camera", but got %s' % drag_mode)
    background_image = Image.open(background_image_path).convert('RGBA')
    background_image = background_image.resize((width, height))
    w, h = background_image.size
    transparent_background = np.array(background_image)
    transparent_background[:, :, -1] = 128
    transparent_background = Image.fromarray(transparent_background)
    # Create a transparent layer with the same size as the background image
    transparent_layer = np.zeros((h, w, 4))
    for splited_track in splited_tracks:
        if len(splited_track) > 1:
            splited_track = interpolate_trajectory(splited_track, model_length)
            splited_track = splited_track[:model_length]
            for i in range(len(splited_track) - 1):
                start_point = (int(splited_track[i][0]), int(splited_track[i][1]))
                end_point = (int(splited_track[i + 1][0]), int(splited_track[i + 1][1]))
                vx = end_point[0] - start_point[0]
                vy = end_point[1] - start_point[1]
                arrow_length = np.sqrt(vx ** 2 + vy ** 2)
                if i == len(splited_track) - 2:
                    # Draw the final segment with an arrow head; guard against zero-length segments.
                    cv2.arrowedLine(transparent_layer, start_point, end_point, color, 2,
                                    tipLength=8 / max(arrow_length, 1e-6))
                else:
                    cv2.line(transparent_layer, start_point, end_point, color, 2)
        else:
            # Single-point track: mark it with a filled dot.
            cv2.circle(transparent_layer, (int(splited_track[0][0]), int(splited_track[0][1])), 5, color, -1)
    transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
    trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
    return trajectory_map, transparent_layer
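
# Illustrative usage (a sketch, not from the original code): overlay the drag tracks on a
# half-transparent copy of the conditioning image; the paths and sizes below are hypothetical.
#   tracks = [[(100, 120), (180, 140)], [(300, 200)]]
#   trajectory_map, layer = visualize_drag('inputs/first_frame.png', tracks,
#                                          drag_mode='object', width=512, height=320,
#                                          model_length=16)
#   trajectory_map.save('outputs/drag_preview.png')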