Spaces:
Runtime error
Runtime error
File size: 3,198 Bytes
4f8ad24 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import copy
from typing import Iterator, Optional
from tqdm.auto import tqdm
from ..action import BaseAction
from ..export import BaseExporter
from ..model import ImageItem
from ..utils import task_ctx, get_task_names
class BaseDataSource:
    """Base class for iterable sources of ``ImageItem`` objects.

    Subclasses implement :meth:`_iter` to produce the raw items. Iterating the
    source goes through :meth:`_iter_from`, a hook subclasses may override to
    wrap the raw stream; the default passes items through unchanged.

    Sources combine with operators:

    * ``a | b`` builds a ``ParallelDataSource`` (from ``.compose``);
    * ``a + b`` builds a ``ComposedDataSource`` (from ``.compose``);
    * ``a[start:stop:step]`` attaches a ``SliceSelectAction``.

    The exact semantics of the two composite types live in ``.compose`` and are
    not visible here.
    """

    def _iter(self) -> Iterator[ImageItem]:
        """Yield the raw items of this source. Must be implemented by subclasses."""
        raise NotImplementedError  # pragma: no cover

    def _iter_from(self) -> Iterator[ImageItem]:
        """Wrap :meth:`_iter`; the base implementation yields its items unchanged."""
        yield from self._iter()

    def __iter__(self) -> Iterator[ImageItem]:
        # Always route through the _iter_from hook so subclass wrappers apply.
        yield from self._iter_from()

    @staticmethod
    def _merge_sources(composite_cls, left, right):
        """Build ``composite_cls(...)`` from two operands, flattening composites.

        An operand that is already an instance of ``composite_cls`` contributes
        its ``sources`` instead of being nested, so chains such as
        ``a | b | c`` produce one flat composite rather than nested ones.
        Operand order is preserved: ``left``'s sources come first.
        """
        parts = []
        for operand in (left, right):
            if isinstance(operand, composite_cls):
                parts.extend(operand.sources)
            else:
                parts.append(operand)
        return composite_cls(*parts)

    def __or__(self, other):
        """Return a ``ParallelDataSource`` combining ``self`` and ``other``."""
        # Imported lazily; .compose presumably imports this module (avoids a cycle).
        from .compose import ParallelDataSource
        return self._merge_sources(ParallelDataSource, self, other)

    def __add__(self, other):
        """Return a ``ComposedDataSource`` combining ``self`` and ``other``."""
        # Imported lazily; .compose presumably imports this module (avoids a cycle).
        from .compose import ComposedDataSource
        return self._merge_sources(ComposedDataSource, self, other)

    def __getitem__(self, item):
        """Attach a ``SliceSelectAction`` for ``source[start:stop:step]``.

        Raises:
            TypeError: if ``item`` is not a ``slice`` (integer indexing is
                not supported).
        """
        from ..action import SliceSelectAction
        if isinstance(item, slice):
            return self.attach(SliceSelectAction(item.start, item.stop, item.step))
        raise TypeError(f'Data source\'s getitem only accept slices, but {item!r} found.')

    def attach(self, *actions: BaseAction) -> 'AttachedDataSource':
        """Return a new source that feeds this source's items through ``actions``."""
        return AttachedDataSource(self, *actions)

    def export(self, exporter: BaseExporter, name: Optional[str] = None):
        """Export every item of this source with ``exporter``.

        The exporter is deep-copied before ``reset()`` so the caller's instance
        is left untouched and can be reused for another export. The export runs
        inside ``task_ctx(name)``; presumably this tags progress output with
        ``name`` (see ``get_task_names`` usage) — confirm in ``..utils``.

        Returns:
            Whatever ``exporter.export_from`` returns.
        """
        exporter = copy.deepcopy(exporter)
        exporter.reset()
        with task_ctx(name):
            return exporter.export_from(iter(self))
class RootDataSource(BaseDataSource):
    """A leaf data source whose iteration is shown with a tqdm progress bar.

    Subclasses implement :meth:`_iter`; iteration wraps it in ``tqdm``. The bar
    is labelled with the class name, plus the dotted task names returned by
    ``get_task_names()`` when there are any.
    """

    def _iter(self) -> Iterator[ImageItem]:
        """Yield the raw items of this source. Must be implemented by subclasses."""
        raise NotImplementedError  # pragma: no cover

    def _iter_from(self) -> Iterator[ImageItem]:
        """Yield the items of :meth:`_iter` through a labelled progress bar."""
        label = type(self).__name__
        task_names = get_task_names()
        if task_names:
            # e.g. "MySource - outer.inner" when nested task contexts are active.
            label = f'{label} - {".".join(task_names)}'
        yield from tqdm(self._iter(), desc=label)
class AttachedDataSource(BaseDataSource):
    """A data source whose items pass through a chain of actions.

    Args:
        source: The upstream data source.
        *actions: Actions applied in order; each one's ``iter_from`` receives
            the output of the previous one (the first receives ``source``).
    """

    def __init__(self, source: BaseDataSource, *actions: BaseAction):
        self.source = source
        self.actions = actions

    def _iter(self) -> Iterator[ImageItem]:
        """Yield the items of ``source`` after piping them through every action."""
        stream = self.source
        for template in self.actions:
            # Work on a fresh, reset copy each iteration so the stored actions
            # are never mutated and this source can be iterated repeatedly.
            runner = copy.deepcopy(template)
            runner.reset()
            stream = runner.iter_from(stream)
        yield from stream
class EmptySource(BaseDataSource):
    """A data source that yields no items."""

    def _iter(self) -> Iterator[ImageItem]:
        """Yield nothing; iteration finishes immediately."""
        yield from ()
|