from typing import List, Set

import torch


def sorted_list(s: Set[str]) -> List[str]:
    """Return the distinct elements of *s* as a list in ascending order."""
    # set() removes duplicates if a non-set iterable is passed; sorted()
    # already returns a new list, so no intermediate list is needed.
    return sorted(set(s))


def device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def nested_to_device(s):
    """Move a tensor, or every value of a (possibly nested) dict, to ``device()``.

    Args:
        s: Either a ``torch.Tensor`` or a dictionary. Dictionary values that
            are themselves dictionaries are processed recursively; every
            other value is moved by calling its ``.to(...)`` method.

    Returns:
        The moved tensor, or a new dict with the same keys whose values have
        been moved. The input dictionary is not modified.
    """
    # Resolve the target once per call instead of once per dictionary value.
    target = device()
    if isinstance(s, torch.Tensor):
        return s.to(target)
    return {
        key: nested_to_device(value) if isinstance(value, dict) else value.to(target)
        for key, value in s.items()
    }


def nested_apply(h, s):
    """Apply *h* to every string found in an arbitrarily nested container.

    Args:
        h: A unary function called on each string leaf.
        s: A string, or an iterable (tuple, list, set, ...) whose elements
            are strings or further nested iterables.

    Returns:
        ``h(s)`` when *s* is a string. Otherwise the mapped elements, as a
        tuple if *s* is a tuple, as a set if *s* is a set, and as a list for
        any other iterable.
    """
    # Strings are the leaves; they must be checked first because a string is
    # itself iterable and would otherwise be traversed character by character.
    if isinstance(s, str):
        return h(s)
    mapped = [nested_apply(h, item) for item in s]
    if isinstance(s, tuple):
        return tuple(mapped)
    if isinstance(s, set):
        return set(mapped)
    return mapped