LittleApple-fp16's picture
Upload 88 files
4f8ad24
raw
history blame
3.2 kB
import copy
from typing import Iterator, Optional
from tqdm.auto import tqdm
from ..action import BaseAction
from ..export import BaseExporter
from ..model import ImageItem
from ..utils import task_ctx, get_task_names
def _merge_sources(composite_type: type, left, right):
    """Build ``composite_type(...)`` from two operands, flattening as it goes.

    Any operand that is already an instance of ``composite_type`` contributes
    its ``sources`` individually instead of being nested, so a chain such as
    ``a | b | c`` yields a single flat composite of ``a``, ``b`` and ``c``.
    """
    members = []
    for operand in (left, right):
        if isinstance(operand, composite_type):
            members.extend(operand.sources)
        else:
            members.append(operand)
    return composite_type(*members)


class BaseDataSource:
    """Base class for iterable sources of :class:`ImageItem` objects.

    Subclasses implement :meth:`_iter`. Sources can be combined with ``|``
    (into a ``ParallelDataSource``) and ``+`` (into a ``ComposedDataSource``),
    sliced with ``source[start:stop:step]``, post-processed with
    :meth:`attach`, and written out with :meth:`export`.
    """

    def _iter(self) -> Iterator[ImageItem]:
        """Yield the raw items of this source; must be overridden."""
        raise NotImplementedError  # pragma: no cover

    def _iter_from(self) -> Iterator[ImageItem]:
        """Hook between :meth:`__iter__` and :meth:`_iter`.

        The default simply delegates to :meth:`_iter`; subclasses may
        override it to wrap the raw item stream.
        """
        yield from self._iter()

    def __iter__(self) -> Iterator[ImageItem]:
        """Iterate over the items produced by :meth:`_iter_from`."""
        yield from self._iter_from()

    def __or__(self, other):
        """Return a ``ParallelDataSource`` over ``self`` and ``other``.

        Operands that are already ``ParallelDataSource`` instances are
        flattened into their member sources rather than nested.
        """
        # Imported lazily -- presumably to avoid a circular import with .compose.
        from .compose import ParallelDataSource
        return _merge_sources(ParallelDataSource, self, other)

    def __add__(self, other):
        """Return a ``ComposedDataSource`` over ``self`` and ``other``.

        Operands that are already ``ComposedDataSource`` instances are
        flattened into their member sources rather than nested (see
        ``.compose`` for how the members are combined).
        """
        # Imported lazily -- presumably to avoid a circular import with .compose.
        from .compose import ComposedDataSource
        return _merge_sources(ComposedDataSource, self, other)

    def __getitem__(self, item):
        """Return this source restricted by a slice, via ``SliceSelectAction``.

        Raises:
            TypeError: if ``item`` is not a :class:`slice`.
        """
        from ..action import SliceSelectAction
        if isinstance(item, slice):
            return self.attach(SliceSelectAction(item.start, item.stop, item.step))
        else:
            raise TypeError(f'Data source\'s getitem only accept slices, but {item!r} found.')

    def attach(self, *actions: BaseAction) -> 'AttachedDataSource':
        """Return an ``AttachedDataSource`` that pipes this source through ``actions`` in order."""
        return AttachedDataSource(self, *actions)

    def export(self, exporter: BaseExporter, name: Optional[str] = None):
        """Feed every item of this source into a private copy of ``exporter``.

        The exporter is deep-copied and reset first, so the instance passed in
        is not itself reset or driven by this call. Iteration runs inside the
        ``task_ctx(name)`` context.

        Args:
            exporter: Template exporter; only a deep copy of it is used.
            name: Optional task name handed to ``task_ctx``.

        Returns:
            Whatever ``exporter.export_from`` returns.
        """
        exporter = copy.deepcopy(exporter)
        exporter.reset()
        with task_ctx(name):
            return exporter.export_from(iter(self))
class RootDataSource(BaseDataSource):
    """Leaf data source whose iteration is wrapped in a ``tqdm`` progress bar.

    Subclasses implement :meth:`_iter`. The progress-bar label is the class
    name, followed by ``" - "`` and the current task names joined with
    ``"."`` whenever ``get_task_names()`` returns a non-empty result.
    """

    def _iter(self) -> Iterator[ImageItem]:
        """Yield the raw items of this source; must be overridden."""
        raise NotImplementedError  # pragma: no cover

    def _iter_from(self) -> Iterator[ImageItem]:
        """Yield the items of :meth:`_iter` while reporting progress via tqdm."""
        task_names = get_task_names()
        label = self.__class__.__name__
        if task_names:
            label = f'{label} - {".".join(task_names)}'

        for image_item in tqdm(self._iter(), desc=label):
            yield image_item
class AttachedDataSource(BaseDataSource):
    """A data source that pipes another source through a chain of actions."""

    def __init__(self, source: BaseDataSource, *actions: BaseAction):
        """
        Args:
            source: Upstream source whose items feed the first action.
            *actions: Actions applied in the given order; each one consumes
                the output of the stage before it.
        """
        self.source = source
        self.actions = actions

    def _iter(self) -> Iterator[ImageItem]:
        """Yield the items of ``source`` after passing through every action."""
        # Each pass works on fresh, reset deep copies, so the stored action
        # objects are never mutated and every iteration starts from scratch.
        pipeline = self.source
        for template in self.actions:
            stage = copy.deepcopy(template)
            stage.reset()
            pipeline = stage.iter_from(pipeline)
        yield from pipeline
class EmptySource(BaseDataSource):
    """A data source that yields no items at all."""

    def _iter(self) -> Iterator[ImageItem]:
        """Produce an exhausted stream (nothing is ever yielded)."""
        yield from ()