Spaces:
Runtime error
Runtime error
import copy | |
from typing import Iterator, Optional | |
from tqdm.auto import tqdm | |
from ..action import BaseAction | |
from ..export import BaseExporter | |
from ..model import ImageItem | |
from ..utils import task_ctx, get_task_names | |
class BaseDataSource:
    """Base class for iterable sources of ``ImageItem`` objects.

    Subclasses supply items by overriding :meth:`_iter`. Sources can be
    combined with operators:

    * ``a | b`` builds a ``ParallelDataSource`` (imported from ``.compose``);
    * ``a + b`` builds a ``ComposedDataSource`` (imported from ``.compose``);

    in both cases an operand that already is the resulting composite type has
    its ``sources`` spliced in, so composites are flattened instead of nested.
    ``source[start:stop:step]`` attaches a ``SliceSelectAction``.
    """

    def _iter(self) -> Iterator[ImageItem]:
        # Concrete sources must override this to produce their items.
        raise NotImplementedError  # pragma: no cover

    def _iter_from(self) -> Iterator[ImageItem]:
        # Hook between __iter__ and _iter; subclasses (e.g. RootDataSource)
        # override it to wrap iteration with extra behavior.
        yield from self._iter()

    def __iter__(self) -> Iterator[ImageItem]:
        yield from self._iter_from()

    @staticmethod
    def _flatten_join(kind, left, right):
        """Return ``kind(*members)`` built from *left* then *right*.

        An operand that is already an instance of *kind* contributes its
        ``sources`` (in order) instead of itself, so the result is flat.
        """
        members = []
        for operand in (left, right):
            if isinstance(operand, kind):
                members.extend(operand.sources)
            else:
                members.append(operand)
        return kind(*members)

    def __or__(self, other):
        # Imported lazily; .compose is a separate module of this package.
        from .compose import ParallelDataSource
        return self._flatten_join(ParallelDataSource, self, other)

    def __add__(self, other):
        # Imported lazily; .compose is a separate module of this package.
        from .compose import ComposedDataSource
        return self._flatten_join(ComposedDataSource, self, other)

    def __getitem__(self, item):
        """Support ``source[start:stop:step]`` by attaching a slice action.

        Raises:
            TypeError: if *item* is not a ``slice``.
        """
        from ..action import SliceSelectAction
        if not isinstance(item, slice):
            raise TypeError(f'Data source\'s getitem only accept slices, but {item!r} found.')
        return self.attach(SliceSelectAction(item.start, item.stop, item.step))

    def attach(self, *actions: BaseAction) -> 'AttachedDataSource':
        """Return a new source that pipes this one through *actions* in order."""
        return AttachedDataSource(self, *actions)

    def export(self, exporter: BaseExporter, name: Optional[str] = None):
        """Feed this source's items to a private copy of *exporter*.

        The exporter is deep-copied and reset before use, so this call does
        not mutate the caller's exporter object. Iteration happens inside
        ``task_ctx(name)``; the result of ``export_from`` is returned.
        """
        private = copy.deepcopy(exporter)
        private.reset()
        with task_ctx(name):
            # NOTE(review): if export_from returns a lazy iterator, it would be
            # consumed after task_ctx has exited — confirm export_from is eager.
            return private.export_from(iter(self))
class RootDataSource(BaseDataSource):
    """Base class for leaf sources that produce items themselves.

    Subclasses implement :meth:`_iter`. Iterating a root source wraps that
    iterator in a ``tqdm`` progress bar whose description is the class name,
    followed by ``" - "`` and the current task names joined with ``"."`` when
    any are active.
    """

    def _iter(self) -> Iterator[ImageItem]:
        raise NotImplementedError  # pragma: no cover

    def _iter_from(self) -> Iterator[ImageItem]:
        names = get_task_names()
        if names:
            desc = f'{self.__class__.__name__} - {".".join(names)}'
        else:
            desc = f'{self.__class__.__name__}'
        # Use the bar as a context manager and delegate with ``yield from``:
        # if the consumer stops early (e.g. a slice selection) and this
        # generator is closed, the close is forwarded to tqdm's iterator and
        # the bar is finalized immediately rather than whenever the abandoned
        # iterator happens to be garbage-collected.
        with tqdm(self._iter(), desc=desc) as pbar:
            yield from pbar
class AttachedDataSource(BaseDataSource):
    """Source that streams another source through a chain of actions.

    The stored actions are used only as templates: on each iteration every
    action is deep-copied and ``reset()`` before use, so this class never
    mutates the action objects it was constructed with.
    """

    def __init__(self, source: BaseDataSource, *actions: BaseAction):
        self.source = source    # upstream source feeding the first action
        self.actions = actions  # tuple of actions, applied left to right

    def _iter(self) -> Iterator[ImageItem]:
        stream = self.source
        for template in self.actions:
            # Fresh copy per iteration; reset() presumably clears any per-run
            # state (TODO confirm against BaseAction).
            fresh = copy.deepcopy(template)
            fresh.reset()
            # Each action wraps the output of the previous stage.
            stream = fresh.iter_from(stream)
        yield from stream
class EmptySource(BaseDataSource):
    """Data source that yields no items at all."""

    def _iter(self) -> Iterator[ImageItem]:
        # Still a generator function: calling _iter() returns a generator
        # that is exhausted immediately.
        yield from ()