|
|
|
|
|
|
|
|
|
|
|
import json |
|
import os |
|
from abc import abstractmethod |
|
from pathlib import Path |
|
|
|
import json5 |
|
import torch |
|
import yaml |
|
|
|
|
|
|
|
class BaseDataset(torch.utils.data.Dataset):
    r"""Base dataset for training and validating."""

    def __init__(self, args, cfg, is_valid=False):
        """Accept the standard dataset constructor arguments and ignore them.

        Args:
            args: Not used by this base class; presumably consumed by
                subclasses — TODO confirm.
            cfg: Not used by this base class; presumably consumed by
                subclasses — TODO confirm.
            is_valid: Not used by this base class; the name suggests it
                selects a validation split in subclasses — TODO confirm.
        """
        # Deliberately a no-op: nothing is stored on the instance here.
|
|
|
|
|
class BaseTestDataset(torch.utils.data.Dataset):
    r"""Test dataset for inference.

    Args:
        args: Namespace-like object; ``get_metadata`` reads ``args.source``
            as the path of a metadata file.
        cfg: Configuration object, stored unchanged as ``self.cfg``.
        infer_type: Either ``"from_dataset"`` or ``"from_file"``.

    Subclasses must implement ``__getitem__`` and assign ``self.metadata``,
    which ``__len__`` reads; this class never sets it.
    """

    def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
        assert infer_type in ["from_dataset", "from_file"], (
            f"infer_type must be 'from_dataset' or 'from_file', got {infer_type!r}"
        )

        self.args = args
        self.cfg = cfg
        self.infer_type = infer_type

    # NOTE(review): @abstractmethod is only enforced when the class's metaclass
    # is ABCMeta; nothing here declares one, so presumably this marker is
    # advisory and subclasses are not forced to override it — confirm.
    @abstractmethod
    def __getitem__(self, index):
        pass

    def __len__(self):
        # self.metadata is not assigned anywhere in this class; a subclass must
        # set it, otherwise this raises AttributeError.
        return len(self.metadata)

    def get_metadata(self):
        """Load and return the metadata file named by ``self.args.source``.

        ``.json``/``.jsonc`` files are parsed with ``json5.load`` and
        ``.yaml``/``.yml`` files with ``yaml.full_load``. The suffix check is
        case-insensitive.

        Returns:
            The parsed document (whatever top-level object the file holds).

        Raises:
            ValueError: If the file suffix is not a supported type.
        """
        path = Path(self.args.source)
        suffix = path.suffix.lower()

        if suffix in (".json", ".jsonc"):
            loader = json5.load
        elif suffix in (".yaml", ".yml"):
            # NOTE: full_load is not meant for untrusted input; switch to
            # yaml.safe_load if these files can come from external sources.
            loader = yaml.full_load
        else:
            raise ValueError(f"Unsupported file type: {path.suffix}")

        # Context manager closes the handle even if parsing raises; pin UTF-8
        # so decoding does not depend on the platform's default locale.
        with open(path, "r", encoding="utf-8") as f:
            return loader(f)
|
|