Spaces:
Runtime error
Runtime error
Upload 422 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- inference/__init__.py +3 -0
- inference/__pycache__/__init__.cpython-310.pyc +0 -0
- inference/core/__init__.py +52 -0
- inference/core/__pycache__/__init__.cpython-310.pyc +0 -0
- inference/core/__pycache__/constants.cpython-310.pyc +0 -0
- inference/core/__pycache__/env.cpython-310.pyc +0 -0
- inference/core/__pycache__/exceptions.cpython-310.pyc +0 -0
- inference/core/__pycache__/logger.cpython-310.pyc +0 -0
- inference/core/__pycache__/nms.cpython-310.pyc +0 -0
- inference/core/__pycache__/roboflow_api.cpython-310.pyc +0 -0
- inference/core/__pycache__/usage.cpython-310.pyc +0 -0
- inference/core/__pycache__/version.cpython-310.pyc +0 -0
- inference/core/active_learning/__init__.py +0 -0
- inference/core/active_learning/__pycache__/__init__.cpython-310.pyc +0 -0
- inference/core/active_learning/__pycache__/accounting.cpython-310.pyc +0 -0
- inference/core/active_learning/__pycache__/batching.cpython-310.pyc +0 -0
- inference/core/active_learning/__pycache__/cache_operations.cpython-310.pyc +0 -0
- inference/core/active_learning/__pycache__/configuration.cpython-310.pyc +0 -0
- inference/core/active_learning/__pycache__/core.cpython-310.pyc +0 -0
- inference/core/active_learning/__pycache__/entities.cpython-310.pyc +0 -0
- inference/core/active_learning/__pycache__/middlewares.cpython-310.pyc +0 -0
- inference/core/active_learning/__pycache__/post_processing.cpython-310.pyc +0 -0
- inference/core/active_learning/__pycache__/utils.cpython-310.pyc +0 -0
- inference/core/active_learning/accounting.py +96 -0
- inference/core/active_learning/batching.py +26 -0
- inference/core/active_learning/cache_operations.py +293 -0
- inference/core/active_learning/configuration.py +203 -0
- inference/core/active_learning/core.py +219 -0
- inference/core/active_learning/entities.py +141 -0
- inference/core/active_learning/middlewares.py +307 -0
- inference/core/active_learning/post_processing.py +128 -0
- inference/core/active_learning/samplers/__init__.py +0 -0
- inference/core/active_learning/samplers/__pycache__/__init__.cpython-310.pyc +0 -0
- inference/core/active_learning/samplers/__pycache__/close_to_threshold.cpython-310.pyc +0 -0
- inference/core/active_learning/samplers/__pycache__/contains_classes.cpython-310.pyc +0 -0
- inference/core/active_learning/samplers/__pycache__/number_of_detections.cpython-310.pyc +0 -0
- inference/core/active_learning/samplers/__pycache__/random.cpython-310.pyc +0 -0
- inference/core/active_learning/samplers/close_to_threshold.py +227 -0
- inference/core/active_learning/samplers/contains_classes.py +58 -0
- inference/core/active_learning/samplers/number_of_detections.py +107 -0
- inference/core/active_learning/samplers/random.py +37 -0
- inference/core/active_learning/utils.py +16 -0
- inference/core/cache/__init__.py +22 -0
- inference/core/cache/__pycache__/__init__.cpython-310.pyc +0 -0
- inference/core/cache/__pycache__/base.cpython-310.pyc +0 -0
- inference/core/cache/__pycache__/memory.cpython-310.pyc +0 -0
- inference/core/cache/__pycache__/model_artifacts.cpython-310.pyc +0 -0
- inference/core/cache/__pycache__/redis.cpython-310.pyc +0 -0
- inference/core/cache/__pycache__/serializers.cpython-310.pyc +0 -0
- inference/core/cache/base.py +130 -0
inference/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from inference.core.interfaces.stream.stream import Stream # isort:skip
|
2 |
+
from inference.core.interfaces.stream.inference_pipeline import InferencePipeline
|
3 |
+
from inference.models.utils import get_roboflow_model
|
inference/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (399 Bytes). View file
|
|
inference/core/__init__.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import threading
|
2 |
+
import time
|
3 |
+
|
4 |
+
import requests
|
5 |
+
|
6 |
+
from inference.core.env import DISABLE_VERSION_CHECK, VERSION_CHECK_MODE
|
7 |
+
from inference.core.logger import logger
|
8 |
+
from inference.core.version import __version__
|
9 |
+
|
10 |
+
latest_release = None
|
11 |
+
last_checked = 0
|
12 |
+
cache_duration = 86400 # 24 hours
|
13 |
+
log_frequency = 300 # 5 minutes
|
14 |
+
|
15 |
+
|
16 |
+
def get_latest_release_version():
|
17 |
+
global latest_release, last_checked
|
18 |
+
now = time.time()
|
19 |
+
if latest_release is None or now - last_checked > cache_duration:
|
20 |
+
try:
|
21 |
+
logger.debug("Checking for latest inference release version...")
|
22 |
+
response = requests.get(
|
23 |
+
"https://api.github.com/repos/roboflow/inference/releases/latest"
|
24 |
+
)
|
25 |
+
response.raise_for_status()
|
26 |
+
latest_release = response.json()["tag_name"].lstrip("v")
|
27 |
+
last_checked = now
|
28 |
+
except requests.exceptions.RequestException:
|
29 |
+
pass
|
30 |
+
|
31 |
+
|
32 |
+
def check_latest_release_against_current():
|
33 |
+
get_latest_release_version()
|
34 |
+
if latest_release is not None and latest_release != __version__:
|
35 |
+
logger.warning(
|
36 |
+
f"Your inference package version {__version__} is out of date! Please upgrade to version {latest_release} of inference for the latest features and bug fixes by running `pip install --upgrade inference`."
|
37 |
+
)
|
38 |
+
|
39 |
+
|
40 |
+
def check_latest_release_against_current_continuous():
|
41 |
+
while True:
|
42 |
+
check_latest_release_against_current()
|
43 |
+
time.sleep(log_frequency)
|
44 |
+
|
45 |
+
|
46 |
+
if not DISABLE_VERSION_CHECK:
|
47 |
+
if VERSION_CHECK_MODE == "continuous":
|
48 |
+
t = threading.Thread(target=check_latest_release_against_current_continuous)
|
49 |
+
t.daemon = True
|
50 |
+
t.start()
|
51 |
+
else:
|
52 |
+
check_latest_release_against_current()
|
inference/core/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (1.73 kB). View file
|
|
inference/core/__pycache__/constants.cpython-310.pyc
ADDED
Binary file (371 Bytes). View file
|
|
inference/core/__pycache__/env.cpython-310.pyc
ADDED
Binary file (6.87 kB). View file
|
|
inference/core/__pycache__/exceptions.cpython-310.pyc
ADDED
Binary file (6.17 kB). View file
|
|
inference/core/__pycache__/logger.cpython-310.pyc
ADDED
Binary file (551 Bytes). View file
|
|
inference/core/__pycache__/nms.cpython-310.pyc
ADDED
Binary file (4.74 kB). View file
|
|
inference/core/__pycache__/roboflow_api.cpython-310.pyc
ADDED
Binary file (10.1 kB). View file
|
|
inference/core/__pycache__/usage.cpython-310.pyc
ADDED
Binary file (1.85 kB). View file
|
|
inference/core/__pycache__/version.cpython-310.pyc
ADDED
Binary file (250 Bytes). View file
|
|
inference/core/active_learning/__init__.py
ADDED
File without changes
|
inference/core/active_learning/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (192 Bytes). View file
|
|
inference/core/active_learning/__pycache__/accounting.cpython-310.pyc
ADDED
Binary file (2.76 kB). View file
|
|
inference/core/active_learning/__pycache__/batching.cpython-310.pyc
ADDED
Binary file (921 Bytes). View file
|
|
inference/core/active_learning/__pycache__/cache_operations.cpython-310.pyc
ADDED
Binary file (5.9 kB). View file
|
|
inference/core/active_learning/__pycache__/configuration.cpython-310.pyc
ADDED
Binary file (5.3 kB). View file
|
|
inference/core/active_learning/__pycache__/core.cpython-310.pyc
ADDED
Binary file (5.2 kB). View file
|
|
inference/core/active_learning/__pycache__/entities.cpython-310.pyc
ADDED
Binary file (4.72 kB). View file
|
|
inference/core/active_learning/__pycache__/middlewares.cpython-310.pyc
ADDED
Binary file (8.68 kB). View file
|
|
inference/core/active_learning/__pycache__/post_processing.cpython-310.pyc
ADDED
Binary file (2.94 kB). View file
|
|
inference/core/active_learning/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (852 Bytes). View file
|
|
inference/core/active_learning/accounting.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional
|
2 |
+
|
3 |
+
from inference.core.entities.types import DatasetID, WorkspaceID
|
4 |
+
from inference.core.roboflow_api import (
|
5 |
+
get_roboflow_labeling_batches,
|
6 |
+
get_roboflow_labeling_jobs,
|
7 |
+
)
|
8 |
+
|
9 |
+
|
10 |
+
def image_can_be_submitted_to_batch(
|
11 |
+
batch_name: str,
|
12 |
+
workspace_id: WorkspaceID,
|
13 |
+
dataset_id: DatasetID,
|
14 |
+
max_batch_images: Optional[int],
|
15 |
+
api_key: str,
|
16 |
+
) -> bool:
|
17 |
+
"""Check if an image can be submitted to a batch.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
batch_name: Name of the batch.
|
21 |
+
workspace_id: ID of the workspace.
|
22 |
+
dataset_id: ID of the dataset.
|
23 |
+
max_batch_images: Maximum number of images allowed in the batch.
|
24 |
+
api_key: API key to use for the request.
|
25 |
+
|
26 |
+
Returns:
|
27 |
+
True if the image can be submitted to the batch, False otherwise.
|
28 |
+
"""
|
29 |
+
if max_batch_images is None:
|
30 |
+
return True
|
31 |
+
labeling_batches = get_roboflow_labeling_batches(
|
32 |
+
api_key=api_key,
|
33 |
+
workspace_id=workspace_id,
|
34 |
+
dataset_id=dataset_id,
|
35 |
+
)
|
36 |
+
matching_labeling_batch = get_matching_labeling_batch(
|
37 |
+
all_labeling_batches=labeling_batches["batches"],
|
38 |
+
batch_name=batch_name,
|
39 |
+
)
|
40 |
+
if matching_labeling_batch is None:
|
41 |
+
return max_batch_images > 0
|
42 |
+
batch_images_under_labeling = 0
|
43 |
+
if matching_labeling_batch["numJobs"] > 0:
|
44 |
+
labeling_jobs = get_roboflow_labeling_jobs(
|
45 |
+
api_key=api_key, workspace_id=workspace_id, dataset_id=dataset_id
|
46 |
+
)
|
47 |
+
batch_images_under_labeling = get_images_in_labeling_jobs_of_specific_batch(
|
48 |
+
all_labeling_jobs=labeling_jobs["jobs"],
|
49 |
+
batch_id=matching_labeling_batch["id"],
|
50 |
+
)
|
51 |
+
total_batch_images = matching_labeling_batch["images"] + batch_images_under_labeling
|
52 |
+
return max_batch_images > total_batch_images
|
53 |
+
|
54 |
+
|
55 |
+
def get_matching_labeling_batch(
|
56 |
+
all_labeling_batches: List[dict],
|
57 |
+
batch_name: str,
|
58 |
+
) -> Optional[dict]:
|
59 |
+
"""Get the matching labeling batch.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
all_labeling_batches: All labeling batches.
|
63 |
+
batch_name: Name of the batch.
|
64 |
+
|
65 |
+
Returns:
|
66 |
+
The matching labeling batch if found, None otherwise.
|
67 |
+
|
68 |
+
"""
|
69 |
+
matching_batch = None
|
70 |
+
for labeling_batch in all_labeling_batches:
|
71 |
+
if labeling_batch["name"] == batch_name:
|
72 |
+
matching_batch = labeling_batch
|
73 |
+
break
|
74 |
+
return matching_batch
|
75 |
+
|
76 |
+
|
77 |
+
def get_images_in_labeling_jobs_of_specific_batch(
|
78 |
+
all_labeling_jobs: List[dict],
|
79 |
+
batch_id: str,
|
80 |
+
) -> int:
|
81 |
+
"""Get the number of images in labeling jobs of a specific batch.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
all_labeling_jobs: All labeling jobs.
|
85 |
+
batch_id: ID of the batch.
|
86 |
+
|
87 |
+
Returns:
|
88 |
+
The number of images in labeling jobs of the batch.
|
89 |
+
|
90 |
+
"""
|
91 |
+
|
92 |
+
matching_jobs = []
|
93 |
+
for labeling_job in all_labeling_jobs:
|
94 |
+
if batch_id in labeling_job["sourceBatch"]:
|
95 |
+
matching_jobs.append(labeling_job)
|
96 |
+
return sum(job["numImages"] for job in matching_jobs)
|
inference/core/active_learning/batching.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from inference.core.active_learning.entities import (
|
2 |
+
ActiveLearningConfiguration,
|
3 |
+
BatchReCreationInterval,
|
4 |
+
)
|
5 |
+
from inference.core.active_learning.utils import (
|
6 |
+
generate_start_timestamp_for_this_month,
|
7 |
+
generate_start_timestamp_for_this_week,
|
8 |
+
generate_today_timestamp,
|
9 |
+
)
|
10 |
+
|
11 |
+
RECREATION_INTERVAL2TIMESTAMP_GENERATOR = {
|
12 |
+
BatchReCreationInterval.DAILY: generate_today_timestamp,
|
13 |
+
BatchReCreationInterval.WEEKLY: generate_start_timestamp_for_this_week,
|
14 |
+
BatchReCreationInterval.MONTHLY: generate_start_timestamp_for_this_month,
|
15 |
+
}
|
16 |
+
|
17 |
+
|
18 |
+
def generate_batch_name(configuration: ActiveLearningConfiguration) -> str:
|
19 |
+
batch_name = configuration.batches_name_prefix
|
20 |
+
if configuration.batch_recreation_interval is BatchReCreationInterval.NEVER:
|
21 |
+
return batch_name
|
22 |
+
timestamp_generator = RECREATION_INTERVAL2TIMESTAMP_GENERATOR[
|
23 |
+
configuration.batch_recreation_interval
|
24 |
+
]
|
25 |
+
timestamp = timestamp_generator()
|
26 |
+
return f"{batch_name}_{timestamp}"
|
inference/core/active_learning/cache_operations.py
ADDED
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import threading
|
2 |
+
from contextlib import contextmanager
|
3 |
+
from datetime import datetime
|
4 |
+
from typing import Generator, List, Optional, OrderedDict, Union
|
5 |
+
|
6 |
+
import redis.lock
|
7 |
+
|
8 |
+
from inference.core import logger
|
9 |
+
from inference.core.active_learning.entities import StrategyLimit, StrategyLimitType
|
10 |
+
from inference.core.active_learning.utils import TIMESTAMP_FORMAT
|
11 |
+
from inference.core.cache.base import BaseCache
|
12 |
+
|
13 |
+
MAX_LOCK_TIME = 5
|
14 |
+
SECONDS_IN_HOUR = 60 * 60
|
15 |
+
USAGE_KEY = "usage"
|
16 |
+
|
17 |
+
LIMIT_TYPE2KEY_INFIX_GENERATOR = {
|
18 |
+
StrategyLimitType.MINUTELY: lambda: f"minute_{datetime.utcnow().minute}",
|
19 |
+
StrategyLimitType.HOURLY: lambda: f"hour_{datetime.utcnow().hour}",
|
20 |
+
StrategyLimitType.DAILY: lambda: f"day_{datetime.utcnow().strftime(TIMESTAMP_FORMAT)}",
|
21 |
+
}
|
22 |
+
LIMIT_TYPE2KEY_EXPIRATION = {
|
23 |
+
StrategyLimitType.MINUTELY: 120,
|
24 |
+
StrategyLimitType.HOURLY: 2 * SECONDS_IN_HOUR,
|
25 |
+
StrategyLimitType.DAILY: 25 * SECONDS_IN_HOUR,
|
26 |
+
}
|
27 |
+
|
28 |
+
|
29 |
+
def use_credit_of_matching_strategy(
|
30 |
+
cache: BaseCache,
|
31 |
+
workspace: str,
|
32 |
+
project: str,
|
33 |
+
matching_strategies_limits: OrderedDict[str, List[StrategyLimit]],
|
34 |
+
) -> Optional[str]:
|
35 |
+
# In scope of this function, cache keys updates regarding usage limits for
|
36 |
+
# specific :workspace and :project are locked - to ensure increment to be done atomically
|
37 |
+
# Limits are accounted at the moment of registration - which may introduce inaccuracy
|
38 |
+
# given that registration is postponed from prediction
|
39 |
+
# Returns: strategy with spare credit if found - else None
|
40 |
+
with lock_limits(cache=cache, workspace=workspace, project=project):
|
41 |
+
strategy_with_spare_credit = find_strategy_with_spare_usage_credit(
|
42 |
+
cache=cache,
|
43 |
+
workspace=workspace,
|
44 |
+
project=project,
|
45 |
+
matching_strategies_limits=matching_strategies_limits,
|
46 |
+
)
|
47 |
+
if strategy_with_spare_credit is None:
|
48 |
+
return None
|
49 |
+
consume_strategy_limits_usage_credit(
|
50 |
+
cache=cache,
|
51 |
+
workspace=workspace,
|
52 |
+
project=project,
|
53 |
+
strategy_name=strategy_with_spare_credit,
|
54 |
+
)
|
55 |
+
return strategy_with_spare_credit
|
56 |
+
|
57 |
+
|
58 |
+
def return_strategy_credit(
|
59 |
+
cache: BaseCache,
|
60 |
+
workspace: str,
|
61 |
+
project: str,
|
62 |
+
strategy_name: str,
|
63 |
+
) -> None:
|
64 |
+
# In scope of this function, cache keys updates regarding usage limits for
|
65 |
+
# specific :workspace and :project are locked - to ensure decrement to be done atomically
|
66 |
+
# Returning strategy is a bit naive (we may add to a pool of credits from the next period - but only
|
67 |
+
# if we have previously taken from the previous one and some credits are used in the new pool) -
|
68 |
+
# in favour of easier implementation.
|
69 |
+
with lock_limits(cache=cache, workspace=workspace, project=project):
|
70 |
+
return_strategy_limits_usage_credit(
|
71 |
+
cache=cache,
|
72 |
+
workspace=workspace,
|
73 |
+
project=project,
|
74 |
+
strategy_name=strategy_name,
|
75 |
+
)
|
76 |
+
|
77 |
+
|
78 |
+
@contextmanager
|
79 |
+
def lock_limits(
|
80 |
+
cache: BaseCache,
|
81 |
+
workspace: str,
|
82 |
+
project: str,
|
83 |
+
) -> Generator[Union[threading.Lock, redis.lock.Lock], None, None]:
|
84 |
+
limits_lock_key = generate_cache_key_for_active_learning_usage_lock(
|
85 |
+
workspace=workspace,
|
86 |
+
project=project,
|
87 |
+
)
|
88 |
+
with cache.lock(key=limits_lock_key, expire=MAX_LOCK_TIME) as lock:
|
89 |
+
yield lock
|
90 |
+
|
91 |
+
|
92 |
+
def find_strategy_with_spare_usage_credit(
|
93 |
+
cache: BaseCache,
|
94 |
+
workspace: str,
|
95 |
+
project: str,
|
96 |
+
matching_strategies_limits: OrderedDict[str, List[StrategyLimit]],
|
97 |
+
) -> Optional[str]:
|
98 |
+
for strategy_name, strategy_limits in matching_strategies_limits.items():
|
99 |
+
rejected_by_strategy = (
|
100 |
+
datapoint_should_be_rejected_based_on_strategy_usage_limits(
|
101 |
+
cache=cache,
|
102 |
+
workspace=workspace,
|
103 |
+
project=project,
|
104 |
+
strategy_name=strategy_name,
|
105 |
+
strategy_limits=strategy_limits,
|
106 |
+
)
|
107 |
+
)
|
108 |
+
if not rejected_by_strategy:
|
109 |
+
return strategy_name
|
110 |
+
return None
|
111 |
+
|
112 |
+
|
113 |
+
def datapoint_should_be_rejected_based_on_strategy_usage_limits(
|
114 |
+
cache: BaseCache,
|
115 |
+
workspace: str,
|
116 |
+
project: str,
|
117 |
+
strategy_name: str,
|
118 |
+
strategy_limits: List[StrategyLimit],
|
119 |
+
) -> bool:
|
120 |
+
for strategy_limit in strategy_limits:
|
121 |
+
limit_reached = datapoint_should_be_rejected_based_on_limit_usage(
|
122 |
+
cache=cache,
|
123 |
+
workspace=workspace,
|
124 |
+
project=project,
|
125 |
+
strategy_name=strategy_name,
|
126 |
+
strategy_limit=strategy_limit,
|
127 |
+
)
|
128 |
+
if limit_reached:
|
129 |
+
logger.debug(
|
130 |
+
f"Violated Active Learning strategy limit: {strategy_limit.limit_type.name} "
|
131 |
+
f"with value {strategy_limit.value} for sampling strategy: {strategy_name}."
|
132 |
+
)
|
133 |
+
return True
|
134 |
+
return False
|
135 |
+
|
136 |
+
|
137 |
+
def datapoint_should_be_rejected_based_on_limit_usage(
|
138 |
+
cache: BaseCache,
|
139 |
+
workspace: str,
|
140 |
+
project: str,
|
141 |
+
strategy_name: str,
|
142 |
+
strategy_limit: StrategyLimit,
|
143 |
+
) -> bool:
|
144 |
+
current_usage = get_current_strategy_limit_usage(
|
145 |
+
cache=cache,
|
146 |
+
workspace=workspace,
|
147 |
+
project=project,
|
148 |
+
strategy_name=strategy_name,
|
149 |
+
limit_type=strategy_limit.limit_type,
|
150 |
+
)
|
151 |
+
if current_usage is None:
|
152 |
+
current_usage = 0
|
153 |
+
return current_usage >= strategy_limit.value
|
154 |
+
|
155 |
+
|
156 |
+
def consume_strategy_limits_usage_credit(
|
157 |
+
cache: BaseCache,
|
158 |
+
workspace: str,
|
159 |
+
project: str,
|
160 |
+
strategy_name: str,
|
161 |
+
) -> None:
|
162 |
+
for limit_type in StrategyLimitType:
|
163 |
+
consume_strategy_limit_usage_credit(
|
164 |
+
cache=cache,
|
165 |
+
workspace=workspace,
|
166 |
+
project=project,
|
167 |
+
strategy_name=strategy_name,
|
168 |
+
limit_type=limit_type,
|
169 |
+
)
|
170 |
+
|
171 |
+
|
172 |
+
def consume_strategy_limit_usage_credit(
|
173 |
+
cache: BaseCache,
|
174 |
+
workspace: str,
|
175 |
+
project: str,
|
176 |
+
strategy_name: str,
|
177 |
+
limit_type: StrategyLimitType,
|
178 |
+
) -> None:
|
179 |
+
current_value = get_current_strategy_limit_usage(
|
180 |
+
cache=cache,
|
181 |
+
limit_type=limit_type,
|
182 |
+
workspace=workspace,
|
183 |
+
project=project,
|
184 |
+
strategy_name=strategy_name,
|
185 |
+
)
|
186 |
+
if current_value is None:
|
187 |
+
current_value = 0
|
188 |
+
current_value += 1
|
189 |
+
set_current_strategy_limit_usage(
|
190 |
+
current_value=current_value,
|
191 |
+
cache=cache,
|
192 |
+
limit_type=limit_type,
|
193 |
+
workspace=workspace,
|
194 |
+
project=project,
|
195 |
+
strategy_name=strategy_name,
|
196 |
+
)
|
197 |
+
|
198 |
+
|
199 |
+
def return_strategy_limits_usage_credit(
|
200 |
+
cache: BaseCache,
|
201 |
+
workspace: str,
|
202 |
+
project: str,
|
203 |
+
strategy_name: str,
|
204 |
+
) -> None:
|
205 |
+
for limit_type in StrategyLimitType:
|
206 |
+
return_strategy_limit_usage_credit(
|
207 |
+
cache=cache,
|
208 |
+
workspace=workspace,
|
209 |
+
project=project,
|
210 |
+
strategy_name=strategy_name,
|
211 |
+
limit_type=limit_type,
|
212 |
+
)
|
213 |
+
|
214 |
+
|
215 |
+
def return_strategy_limit_usage_credit(
|
216 |
+
cache: BaseCache,
|
217 |
+
workspace: str,
|
218 |
+
project: str,
|
219 |
+
strategy_name: str,
|
220 |
+
limit_type: StrategyLimitType,
|
221 |
+
) -> None:
|
222 |
+
current_value = get_current_strategy_limit_usage(
|
223 |
+
cache=cache,
|
224 |
+
limit_type=limit_type,
|
225 |
+
workspace=workspace,
|
226 |
+
project=project,
|
227 |
+
strategy_name=strategy_name,
|
228 |
+
)
|
229 |
+
if current_value is None:
|
230 |
+
return None
|
231 |
+
current_value = max(current_value - 1, 0)
|
232 |
+
set_current_strategy_limit_usage(
|
233 |
+
current_value=current_value,
|
234 |
+
cache=cache,
|
235 |
+
limit_type=limit_type,
|
236 |
+
workspace=workspace,
|
237 |
+
project=project,
|
238 |
+
strategy_name=strategy_name,
|
239 |
+
)
|
240 |
+
|
241 |
+
|
242 |
+
def get_current_strategy_limit_usage(
|
243 |
+
cache: BaseCache,
|
244 |
+
workspace: str,
|
245 |
+
project: str,
|
246 |
+
strategy_name: str,
|
247 |
+
limit_type: StrategyLimitType,
|
248 |
+
) -> Optional[int]:
|
249 |
+
usage_key = generate_cache_key_for_active_learning_usage(
|
250 |
+
limit_type=limit_type,
|
251 |
+
workspace=workspace,
|
252 |
+
project=project,
|
253 |
+
strategy_name=strategy_name,
|
254 |
+
)
|
255 |
+
value = cache.get(usage_key)
|
256 |
+
if value is None:
|
257 |
+
return value
|
258 |
+
return value[USAGE_KEY]
|
259 |
+
|
260 |
+
|
261 |
+
def set_current_strategy_limit_usage(
|
262 |
+
current_value: int,
|
263 |
+
cache: BaseCache,
|
264 |
+
workspace: str,
|
265 |
+
project: str,
|
266 |
+
strategy_name: str,
|
267 |
+
limit_type: StrategyLimitType,
|
268 |
+
) -> None:
|
269 |
+
usage_key = generate_cache_key_for_active_learning_usage(
|
270 |
+
limit_type=limit_type,
|
271 |
+
workspace=workspace,
|
272 |
+
project=project,
|
273 |
+
strategy_name=strategy_name,
|
274 |
+
)
|
275 |
+
expire = LIMIT_TYPE2KEY_EXPIRATION[limit_type]
|
276 |
+
cache.set(key=usage_key, value={USAGE_KEY: current_value}, expire=expire) # type: ignore
|
277 |
+
|
278 |
+
|
279 |
+
def generate_cache_key_for_active_learning_usage_lock(
|
280 |
+
workspace: str,
|
281 |
+
project: str,
|
282 |
+
) -> str:
|
283 |
+
return f"active_learning:usage:{workspace}:{project}:usage:lock"
|
284 |
+
|
285 |
+
|
286 |
+
def generate_cache_key_for_active_learning_usage(
|
287 |
+
limit_type: StrategyLimitType,
|
288 |
+
workspace: str,
|
289 |
+
project: str,
|
290 |
+
strategy_name: str,
|
291 |
+
) -> str:
|
292 |
+
time_infix = LIMIT_TYPE2KEY_INFIX_GENERATOR[limit_type]()
|
293 |
+
return f"active_learning:usage:{workspace}:{project}:{strategy_name}:{time_infix}"
|
inference/core/active_learning/configuration.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hashlib
|
2 |
+
from dataclasses import asdict
|
3 |
+
from typing import Any, Dict, List, Optional
|
4 |
+
|
5 |
+
from inference.core import logger
|
6 |
+
from inference.core.active_learning.entities import (
|
7 |
+
ActiveLearningConfiguration,
|
8 |
+
RoboflowProjectMetadata,
|
9 |
+
SamplingMethod,
|
10 |
+
)
|
11 |
+
from inference.core.active_learning.samplers.close_to_threshold import (
|
12 |
+
initialize_close_to_threshold_sampling,
|
13 |
+
)
|
14 |
+
from inference.core.active_learning.samplers.contains_classes import (
|
15 |
+
initialize_classes_based_sampling,
|
16 |
+
)
|
17 |
+
from inference.core.active_learning.samplers.number_of_detections import (
|
18 |
+
initialize_detections_number_based_sampling,
|
19 |
+
)
|
20 |
+
from inference.core.active_learning.samplers.random import initialize_random_sampling
|
21 |
+
from inference.core.cache.base import BaseCache
|
22 |
+
from inference.core.exceptions import (
|
23 |
+
ActiveLearningConfigurationDecodingError,
|
24 |
+
ActiveLearningConfigurationError,
|
25 |
+
RoboflowAPINotAuthorizedError,
|
26 |
+
RoboflowAPINotNotFoundError,
|
27 |
+
)
|
28 |
+
from inference.core.roboflow_api import (
|
29 |
+
get_roboflow_active_learning_configuration,
|
30 |
+
get_roboflow_dataset_type,
|
31 |
+
get_roboflow_workspace,
|
32 |
+
)
|
33 |
+
from inference.core.utils.roboflow import get_model_id_chunks
|
34 |
+
|
35 |
+
TYPE2SAMPLING_INITIALIZERS = {
|
36 |
+
"random": initialize_random_sampling,
|
37 |
+
"close_to_threshold": initialize_close_to_threshold_sampling,
|
38 |
+
"classes_based": initialize_classes_based_sampling,
|
39 |
+
"detections_number_based": initialize_detections_number_based_sampling,
|
40 |
+
}
|
41 |
+
ACTIVE_LEARNING_CONFIG_CACHE_EXPIRE = 900 # 15 min
|
42 |
+
|
43 |
+
|
44 |
+
def prepare_active_learning_configuration(
|
45 |
+
api_key: str,
|
46 |
+
model_id: str,
|
47 |
+
cache: BaseCache,
|
48 |
+
) -> Optional[ActiveLearningConfiguration]:
|
49 |
+
project_metadata = get_roboflow_project_metadata(
|
50 |
+
api_key=api_key,
|
51 |
+
model_id=model_id,
|
52 |
+
cache=cache,
|
53 |
+
)
|
54 |
+
if not project_metadata.active_learning_configuration.get("enabled", False):
|
55 |
+
return None
|
56 |
+
logger.info(
|
57 |
+
f"Configuring active learning for workspace: {project_metadata.workspace_id}, "
|
58 |
+
f"project: {project_metadata.dataset_id} of type: {project_metadata.dataset_type}. "
|
59 |
+
f"AL configuration: {project_metadata.active_learning_configuration}"
|
60 |
+
)
|
61 |
+
return initialise_active_learning_configuration(
|
62 |
+
project_metadata=project_metadata,
|
63 |
+
)
|
64 |
+
|
65 |
+
|
66 |
+
def prepare_active_learning_configuration_inplace(
|
67 |
+
api_key: str,
|
68 |
+
model_id: str,
|
69 |
+
active_learning_configuration: Optional[dict],
|
70 |
+
) -> Optional[ActiveLearningConfiguration]:
|
71 |
+
if (
|
72 |
+
active_learning_configuration is None
|
73 |
+
or active_learning_configuration.get("enabled", False) is False
|
74 |
+
):
|
75 |
+
return None
|
76 |
+
dataset_id, version_id = get_model_id_chunks(model_id=model_id)
|
77 |
+
workspace_id = get_roboflow_workspace(api_key=api_key)
|
78 |
+
dataset_type = get_roboflow_dataset_type(
|
79 |
+
api_key=api_key,
|
80 |
+
workspace_id=workspace_id,
|
81 |
+
dataset_id=dataset_id,
|
82 |
+
)
|
83 |
+
project_metadata = RoboflowProjectMetadata(
|
84 |
+
dataset_id=dataset_id,
|
85 |
+
version_id=version_id,
|
86 |
+
workspace_id=workspace_id,
|
87 |
+
dataset_type=dataset_type,
|
88 |
+
active_learning_configuration=active_learning_configuration,
|
89 |
+
)
|
90 |
+
return initialise_active_learning_configuration(
|
91 |
+
project_metadata=project_metadata,
|
92 |
+
)
|
93 |
+
|
94 |
+
|
95 |
+
def get_roboflow_project_metadata(
|
96 |
+
api_key: str,
|
97 |
+
model_id: str,
|
98 |
+
cache: BaseCache,
|
99 |
+
) -> RoboflowProjectMetadata:
|
100 |
+
logger.info(f"Fetching active learning configuration.")
|
101 |
+
config_cache_key = construct_cache_key_for_active_learning_config(
|
102 |
+
api_key=api_key, model_id=model_id
|
103 |
+
)
|
104 |
+
cached_config = cache.get(config_cache_key)
|
105 |
+
if cached_config is not None:
|
106 |
+
logger.info("Found Active Learning configuration in cache.")
|
107 |
+
return parse_cached_roboflow_project_metadata(cached_config=cached_config)
|
108 |
+
dataset_id, version_id = get_model_id_chunks(model_id=model_id)
|
109 |
+
workspace_id = get_roboflow_workspace(api_key=api_key)
|
110 |
+
dataset_type = get_roboflow_dataset_type(
|
111 |
+
api_key=api_key,
|
112 |
+
workspace_id=workspace_id,
|
113 |
+
dataset_id=dataset_id,
|
114 |
+
)
|
115 |
+
try:
|
116 |
+
roboflow_api_configuration = get_roboflow_active_learning_configuration(
|
117 |
+
api_key=api_key, workspace_id=workspace_id, dataset_id=dataset_id
|
118 |
+
)
|
119 |
+
except (RoboflowAPINotAuthorizedError, RoboflowAPINotNotFoundError):
|
120 |
+
# currently backend returns HTTP 404 if dataset does not exist
|
121 |
+
# or workspace_id from api_key indicate that the owner is different,
|
122 |
+
# so in the situation when we query for Universe dataset.
|
123 |
+
# We want the owner of public dataset to be able to set AL configs
|
124 |
+
# and use them, but not other people. At this point it's known
|
125 |
+
# that HTTP 404 means not authorised (which will probably change
|
126 |
+
# in future iteration of backend) - so on both NotAuth and NotFound
|
127 |
+
# errors we assume that we simply cannot use AL with this model and
|
128 |
+
# this api_key.
|
129 |
+
roboflow_api_configuration = {"enabled": False}
|
130 |
+
configuration = RoboflowProjectMetadata(
|
131 |
+
dataset_id=dataset_id,
|
132 |
+
version_id=version_id,
|
133 |
+
workspace_id=workspace_id,
|
134 |
+
dataset_type=dataset_type,
|
135 |
+
active_learning_configuration=roboflow_api_configuration,
|
136 |
+
)
|
137 |
+
cache.set(
|
138 |
+
key=config_cache_key,
|
139 |
+
value=asdict(configuration),
|
140 |
+
expire=ACTIVE_LEARNING_CONFIG_CACHE_EXPIRE,
|
141 |
+
)
|
142 |
+
return configuration
|
143 |
+
|
144 |
+
|
145 |
+
def construct_cache_key_for_active_learning_config(api_key: str, model_id: str) -> str:
|
146 |
+
dataset_id = model_id.split("/")[0]
|
147 |
+
api_key_hash = hashlib.md5(api_key.encode("utf-8")).hexdigest()
|
148 |
+
return f"active_learning:configurations:{api_key_hash}:{dataset_id}"
|
149 |
+
|
150 |
+
|
151 |
+
def parse_cached_roboflow_project_metadata(
|
152 |
+
cached_config: dict,
|
153 |
+
) -> RoboflowProjectMetadata:
|
154 |
+
try:
|
155 |
+
return RoboflowProjectMetadata(**cached_config)
|
156 |
+
except Exception as error:
|
157 |
+
raise ActiveLearningConfigurationDecodingError(
|
158 |
+
f"Failed to initialise Active Learning configuration. Cause: {str(error)}"
|
159 |
+
) from error
|
160 |
+
|
161 |
+
|
162 |
+
def initialise_active_learning_configuration(
|
163 |
+
project_metadata: RoboflowProjectMetadata,
|
164 |
+
) -> ActiveLearningConfiguration:
|
165 |
+
sampling_methods = initialize_sampling_methods(
|
166 |
+
sampling_strategies_configs=project_metadata.active_learning_configuration[
|
167 |
+
"sampling_strategies"
|
168 |
+
],
|
169 |
+
)
|
170 |
+
target_workspace_id = project_metadata.active_learning_configuration.get(
|
171 |
+
"target_workspace", project_metadata.workspace_id
|
172 |
+
)
|
173 |
+
target_dataset_id = project_metadata.active_learning_configuration.get(
|
174 |
+
"target_project", project_metadata.dataset_id
|
175 |
+
)
|
176 |
+
return ActiveLearningConfiguration.init(
|
177 |
+
roboflow_api_configuration=project_metadata.active_learning_configuration,
|
178 |
+
sampling_methods=sampling_methods,
|
179 |
+
workspace_id=target_workspace_id,
|
180 |
+
dataset_id=target_dataset_id,
|
181 |
+
model_id=f"{project_metadata.dataset_id}/{project_metadata.version_id}",
|
182 |
+
)
|
183 |
+
|
184 |
+
|
185 |
+
def initialize_sampling_methods(
|
186 |
+
sampling_strategies_configs: List[Dict[str, Any]]
|
187 |
+
) -> List[SamplingMethod]:
|
188 |
+
result = []
|
189 |
+
for sampling_strategy_config in sampling_strategies_configs:
|
190 |
+
sampling_type = sampling_strategy_config["type"]
|
191 |
+
if sampling_type not in TYPE2SAMPLING_INITIALIZERS:
|
192 |
+
logger.warn(
|
193 |
+
f"Could not identify sampling method `{sampling_type}` - skipping initialisation."
|
194 |
+
)
|
195 |
+
continue
|
196 |
+
initializer = TYPE2SAMPLING_INITIALIZERS[sampling_type]
|
197 |
+
result.append(initializer(sampling_strategy_config))
|
198 |
+
names = set(m.name for m in result)
|
199 |
+
if len(names) != len(result):
|
200 |
+
raise ActiveLearningConfigurationError(
|
201 |
+
"Detected duplication of Active Learning strategies names."
|
202 |
+
)
|
203 |
+
return result
|
inference/core/active_learning/core.py
ADDED
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
from typing import List, Optional, Tuple
|
3 |
+
from uuid import uuid4
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from inference.core import logger
|
8 |
+
from inference.core.active_learning.cache_operations import (
|
9 |
+
return_strategy_credit,
|
10 |
+
use_credit_of_matching_strategy,
|
11 |
+
)
|
12 |
+
from inference.core.active_learning.entities import (
|
13 |
+
ActiveLearningConfiguration,
|
14 |
+
ImageDimensions,
|
15 |
+
Prediction,
|
16 |
+
PredictionType,
|
17 |
+
SamplingMethod,
|
18 |
+
)
|
19 |
+
from inference.core.active_learning.post_processing import (
|
20 |
+
adjust_prediction_to_client_scaling_factor,
|
21 |
+
encode_prediction,
|
22 |
+
)
|
23 |
+
from inference.core.cache.base import BaseCache
|
24 |
+
from inference.core.env import ACTIVE_LEARNING_TAGS
|
25 |
+
from inference.core.roboflow_api import (
|
26 |
+
annotate_image_at_roboflow,
|
27 |
+
register_image_at_roboflow,
|
28 |
+
)
|
29 |
+
from inference.core.utils.image_utils import encode_image_to_jpeg_bytes
|
30 |
+
from inference.core.utils.preprocess import downscale_image_keeping_aspect_ratio
|
31 |
+
|
32 |
+
|
33 |
+
def execute_sampling(
|
34 |
+
image: np.ndarray,
|
35 |
+
prediction: Prediction,
|
36 |
+
prediction_type: PredictionType,
|
37 |
+
sampling_methods: List[SamplingMethod],
|
38 |
+
) -> List[str]:
|
39 |
+
matching_strategies = []
|
40 |
+
for method in sampling_methods:
|
41 |
+
sampling_result = method.sample(image, prediction, prediction_type)
|
42 |
+
if sampling_result:
|
43 |
+
matching_strategies.append(method.name)
|
44 |
+
return matching_strategies
|
45 |
+
|
46 |
+
|
47 |
+
def execute_datapoint_registration(
|
48 |
+
cache: BaseCache,
|
49 |
+
matching_strategies: List[str],
|
50 |
+
image: np.ndarray,
|
51 |
+
prediction: Prediction,
|
52 |
+
prediction_type: PredictionType,
|
53 |
+
configuration: ActiveLearningConfiguration,
|
54 |
+
api_key: str,
|
55 |
+
batch_name: str,
|
56 |
+
) -> None:
|
57 |
+
local_image_id = str(uuid4())
|
58 |
+
encoded_image, scaling_factor = prepare_image_to_registration(
|
59 |
+
image=image,
|
60 |
+
desired_size=configuration.max_image_size,
|
61 |
+
jpeg_compression_level=configuration.jpeg_compression_level,
|
62 |
+
)
|
63 |
+
prediction = adjust_prediction_to_client_scaling_factor(
|
64 |
+
prediction=prediction,
|
65 |
+
scaling_factor=scaling_factor,
|
66 |
+
prediction_type=prediction_type,
|
67 |
+
)
|
68 |
+
matching_strategies_limits = OrderedDict(
|
69 |
+
(strategy_name, configuration.strategies_limits[strategy_name])
|
70 |
+
for strategy_name in matching_strategies
|
71 |
+
)
|
72 |
+
strategy_with_spare_credit = use_credit_of_matching_strategy(
|
73 |
+
cache=cache,
|
74 |
+
workspace=configuration.workspace_id,
|
75 |
+
project=configuration.dataset_id,
|
76 |
+
matching_strategies_limits=matching_strategies_limits,
|
77 |
+
)
|
78 |
+
if strategy_with_spare_credit is None:
|
79 |
+
logger.debug(f"Limit on Active Learning strategy reached.")
|
80 |
+
return None
|
81 |
+
register_datapoint_at_roboflow(
|
82 |
+
cache=cache,
|
83 |
+
strategy_with_spare_credit=strategy_with_spare_credit,
|
84 |
+
encoded_image=encoded_image,
|
85 |
+
local_image_id=local_image_id,
|
86 |
+
prediction=prediction,
|
87 |
+
prediction_type=prediction_type,
|
88 |
+
configuration=configuration,
|
89 |
+
api_key=api_key,
|
90 |
+
batch_name=batch_name,
|
91 |
+
)
|
92 |
+
|
93 |
+
|
94 |
+
def prepare_image_to_registration(
|
95 |
+
image: np.ndarray,
|
96 |
+
desired_size: Optional[ImageDimensions],
|
97 |
+
jpeg_compression_level: int,
|
98 |
+
) -> Tuple[bytes, float]:
|
99 |
+
scaling_factor = 1.0
|
100 |
+
if desired_size is not None:
|
101 |
+
height_before_scale = image.shape[0]
|
102 |
+
image = downscale_image_keeping_aspect_ratio(
|
103 |
+
image=image,
|
104 |
+
desired_size=desired_size.to_wh(),
|
105 |
+
)
|
106 |
+
scaling_factor = image.shape[0] / height_before_scale
|
107 |
+
return (
|
108 |
+
encode_image_to_jpeg_bytes(image=image, jpeg_quality=jpeg_compression_level),
|
109 |
+
scaling_factor,
|
110 |
+
)
|
111 |
+
|
112 |
+
|
113 |
+
def register_datapoint_at_roboflow(
|
114 |
+
cache: BaseCache,
|
115 |
+
strategy_with_spare_credit: str,
|
116 |
+
encoded_image: bytes,
|
117 |
+
local_image_id: str,
|
118 |
+
prediction: Prediction,
|
119 |
+
prediction_type: PredictionType,
|
120 |
+
configuration: ActiveLearningConfiguration,
|
121 |
+
api_key: str,
|
122 |
+
batch_name: str,
|
123 |
+
) -> None:
|
124 |
+
tags = collect_tags(
|
125 |
+
configuration=configuration,
|
126 |
+
sampling_strategy=strategy_with_spare_credit,
|
127 |
+
)
|
128 |
+
roboflow_image_id = safe_register_image_at_roboflow(
|
129 |
+
cache=cache,
|
130 |
+
strategy_with_spare_credit=strategy_with_spare_credit,
|
131 |
+
encoded_image=encoded_image,
|
132 |
+
local_image_id=local_image_id,
|
133 |
+
configuration=configuration,
|
134 |
+
api_key=api_key,
|
135 |
+
batch_name=batch_name,
|
136 |
+
tags=tags,
|
137 |
+
)
|
138 |
+
if is_prediction_registration_forbidden(
|
139 |
+
prediction=prediction,
|
140 |
+
persist_predictions=configuration.persist_predictions,
|
141 |
+
roboflow_image_id=roboflow_image_id,
|
142 |
+
):
|
143 |
+
return None
|
144 |
+
encoded_prediction, prediction_file_type = encode_prediction(
|
145 |
+
prediction=prediction, prediction_type=prediction_type
|
146 |
+
)
|
147 |
+
_ = annotate_image_at_roboflow(
|
148 |
+
api_key=api_key,
|
149 |
+
dataset_id=configuration.dataset_id,
|
150 |
+
local_image_id=local_image_id,
|
151 |
+
roboflow_image_id=roboflow_image_id,
|
152 |
+
annotation_content=encoded_prediction,
|
153 |
+
annotation_file_type=prediction_file_type,
|
154 |
+
is_prediction=True,
|
155 |
+
)
|
156 |
+
|
157 |
+
|
158 |
+
def collect_tags(
|
159 |
+
configuration: ActiveLearningConfiguration, sampling_strategy: str
|
160 |
+
) -> List[str]:
|
161 |
+
tags = ACTIVE_LEARNING_TAGS if ACTIVE_LEARNING_TAGS is not None else []
|
162 |
+
tags.extend(configuration.tags)
|
163 |
+
tags.extend(configuration.strategies_tags[sampling_strategy])
|
164 |
+
if configuration.persist_predictions:
|
165 |
+
# this replacement is needed due to backend input validation
|
166 |
+
tags.append(configuration.model_id.replace("/", "-"))
|
167 |
+
return tags
|
168 |
+
|
169 |
+
|
170 |
+
def safe_register_image_at_roboflow(
|
171 |
+
cache: BaseCache,
|
172 |
+
strategy_with_spare_credit: str,
|
173 |
+
encoded_image: bytes,
|
174 |
+
local_image_id: str,
|
175 |
+
configuration: ActiveLearningConfiguration,
|
176 |
+
api_key: str,
|
177 |
+
batch_name: str,
|
178 |
+
tags: List[str],
|
179 |
+
) -> Optional[str]:
|
180 |
+
credit_to_be_returned = False
|
181 |
+
try:
|
182 |
+
registration_response = register_image_at_roboflow(
|
183 |
+
api_key=api_key,
|
184 |
+
dataset_id=configuration.dataset_id,
|
185 |
+
local_image_id=local_image_id,
|
186 |
+
image_bytes=encoded_image,
|
187 |
+
batch_name=batch_name,
|
188 |
+
tags=tags,
|
189 |
+
)
|
190 |
+
image_duplicated = registration_response.get("duplicate", False)
|
191 |
+
if image_duplicated:
|
192 |
+
credit_to_be_returned = True
|
193 |
+
logger.warning(f"Image duplication detected: {registration_response}.")
|
194 |
+
return None
|
195 |
+
return registration_response["id"]
|
196 |
+
except Exception as error:
|
197 |
+
credit_to_be_returned = True
|
198 |
+
raise error
|
199 |
+
finally:
|
200 |
+
if credit_to_be_returned:
|
201 |
+
return_strategy_credit(
|
202 |
+
cache=cache,
|
203 |
+
workspace=configuration.workspace_id,
|
204 |
+
project=configuration.dataset_id,
|
205 |
+
strategy_name=strategy_with_spare_credit,
|
206 |
+
)
|
207 |
+
|
208 |
+
|
209 |
+
def is_prediction_registration_forbidden(
|
210 |
+
prediction: Prediction,
|
211 |
+
persist_predictions: bool,
|
212 |
+
roboflow_image_id: Optional[str],
|
213 |
+
) -> bool:
|
214 |
+
return (
|
215 |
+
roboflow_image_id is None
|
216 |
+
or persist_predictions is False
|
217 |
+
or prediction.get("is_stub", False) is True
|
218 |
+
or (len(prediction.get("predictions", [])) == 0 and "top" not in prediction)
|
219 |
+
)
|
inference/core/active_learning/entities.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from enum import Enum
|
3 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from inference.core.entities.types import DatasetID, WorkspaceID
|
8 |
+
from inference.core.exceptions import ActiveLearningConfigurationDecodingError
|
9 |
+
|
10 |
+
LocalImageIdentifier = str
|
11 |
+
PredictionType = str
|
12 |
+
Prediction = dict
|
13 |
+
SerialisedPrediction = str
|
14 |
+
PredictionFileType = str
|
15 |
+
|
16 |
+
|
17 |
+
@dataclass(frozen=True)
|
18 |
+
class ImageDimensions:
|
19 |
+
height: int
|
20 |
+
width: int
|
21 |
+
|
22 |
+
def to_hw(self) -> Tuple[int, int]:
|
23 |
+
return self.height, self.width
|
24 |
+
|
25 |
+
def to_wh(self) -> Tuple[int, int]:
|
26 |
+
return self.width, self.height
|
27 |
+
|
28 |
+
|
29 |
+
@dataclass(frozen=True)
|
30 |
+
class SamplingMethod:
|
31 |
+
name: str
|
32 |
+
sample: Callable[[np.ndarray, Prediction, PredictionType], bool]
|
33 |
+
|
34 |
+
|
35 |
+
class BatchReCreationInterval(Enum):
|
36 |
+
NEVER = "never"
|
37 |
+
DAILY = "daily"
|
38 |
+
WEEKLY = "weekly"
|
39 |
+
MONTHLY = "monthly"
|
40 |
+
|
41 |
+
|
42 |
+
class StrategyLimitType(Enum):
|
43 |
+
MINUTELY = "minutely"
|
44 |
+
HOURLY = "hourly"
|
45 |
+
DAILY = "daily"
|
46 |
+
|
47 |
+
|
48 |
+
@dataclass(frozen=True)
|
49 |
+
class StrategyLimit:
|
50 |
+
limit_type: StrategyLimitType
|
51 |
+
value: int
|
52 |
+
|
53 |
+
@classmethod
|
54 |
+
def from_dict(cls, specification: dict) -> "StrategyLimit":
|
55 |
+
return cls(
|
56 |
+
limit_type=StrategyLimitType(specification["type"]),
|
57 |
+
value=specification["value"],
|
58 |
+
)
|
59 |
+
|
60 |
+
|
61 |
+
@dataclass(frozen=True)
|
62 |
+
class ActiveLearningConfiguration:
|
63 |
+
max_image_size: Optional[ImageDimensions]
|
64 |
+
jpeg_compression_level: int
|
65 |
+
persist_predictions: bool
|
66 |
+
sampling_methods: List[SamplingMethod]
|
67 |
+
batches_name_prefix: str
|
68 |
+
batch_recreation_interval: BatchReCreationInterval
|
69 |
+
max_batch_images: Optional[int]
|
70 |
+
workspace_id: WorkspaceID
|
71 |
+
dataset_id: DatasetID
|
72 |
+
model_id: str
|
73 |
+
strategies_limits: Dict[str, List[StrategyLimit]]
|
74 |
+
tags: List[str]
|
75 |
+
strategies_tags: Dict[str, List[str]]
|
76 |
+
|
77 |
+
@classmethod
|
78 |
+
def init(
|
79 |
+
cls,
|
80 |
+
roboflow_api_configuration: Dict[str, Any],
|
81 |
+
sampling_methods: List[SamplingMethod],
|
82 |
+
workspace_id: WorkspaceID,
|
83 |
+
dataset_id: DatasetID,
|
84 |
+
model_id: str,
|
85 |
+
) -> "ActiveLearningConfiguration":
|
86 |
+
try:
|
87 |
+
max_image_size = roboflow_api_configuration.get("max_image_size")
|
88 |
+
if max_image_size is not None:
|
89 |
+
max_image_size = ImageDimensions(
|
90 |
+
height=roboflow_api_configuration["max_image_size"][0],
|
91 |
+
width=roboflow_api_configuration["max_image_size"][1],
|
92 |
+
)
|
93 |
+
strategies_limits = {
|
94 |
+
strategy["name"]: [
|
95 |
+
StrategyLimit.from_dict(specification=specification)
|
96 |
+
for specification in strategy.get("limits", [])
|
97 |
+
]
|
98 |
+
for strategy in roboflow_api_configuration["sampling_strategies"]
|
99 |
+
}
|
100 |
+
strategies_tags = {
|
101 |
+
strategy["name"]: strategy.get("tags", [])
|
102 |
+
for strategy in roboflow_api_configuration["sampling_strategies"]
|
103 |
+
}
|
104 |
+
return cls(
|
105 |
+
max_image_size=max_image_size,
|
106 |
+
jpeg_compression_level=roboflow_api_configuration.get(
|
107 |
+
"jpeg_compression_level", 95
|
108 |
+
),
|
109 |
+
persist_predictions=roboflow_api_configuration["persist_predictions"],
|
110 |
+
sampling_methods=sampling_methods,
|
111 |
+
batches_name_prefix=roboflow_api_configuration["batching_strategy"][
|
112 |
+
"batches_name_prefix"
|
113 |
+
],
|
114 |
+
batch_recreation_interval=BatchReCreationInterval(
|
115 |
+
roboflow_api_configuration["batching_strategy"][
|
116 |
+
"recreation_interval"
|
117 |
+
]
|
118 |
+
),
|
119 |
+
max_batch_images=roboflow_api_configuration["batching_strategy"].get(
|
120 |
+
"max_batch_images"
|
121 |
+
),
|
122 |
+
workspace_id=workspace_id,
|
123 |
+
dataset_id=dataset_id,
|
124 |
+
model_id=model_id,
|
125 |
+
strategies_limits=strategies_limits,
|
126 |
+
tags=roboflow_api_configuration.get("tags", []),
|
127 |
+
strategies_tags=strategies_tags,
|
128 |
+
)
|
129 |
+
except (KeyError, ValueError) as e:
|
130 |
+
raise ActiveLearningConfigurationDecodingError(
|
131 |
+
f"Failed to initialise Active Learning configuration. Cause: {str(e)}"
|
132 |
+
) from e
|
133 |
+
|
134 |
+
|
135 |
+
@dataclass(frozen=True)
|
136 |
+
class RoboflowProjectMetadata:
|
137 |
+
dataset_id: DatasetID
|
138 |
+
version_id: str
|
139 |
+
workspace_id: WorkspaceID
|
140 |
+
dataset_type: str
|
141 |
+
active_learning_configuration: dict
|
inference/core/active_learning/middlewares.py
ADDED
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import queue
|
2 |
+
from queue import Queue
|
3 |
+
from threading import Thread
|
4 |
+
from typing import Any, List, Optional
|
5 |
+
|
6 |
+
from inference.core import logger
|
7 |
+
from inference.core.active_learning.accounting import image_can_be_submitted_to_batch
|
8 |
+
from inference.core.active_learning.batching import generate_batch_name
|
9 |
+
from inference.core.active_learning.configuration import (
|
10 |
+
prepare_active_learning_configuration,
|
11 |
+
prepare_active_learning_configuration_inplace,
|
12 |
+
)
|
13 |
+
from inference.core.active_learning.core import (
|
14 |
+
execute_datapoint_registration,
|
15 |
+
execute_sampling,
|
16 |
+
)
|
17 |
+
from inference.core.active_learning.entities import (
|
18 |
+
ActiveLearningConfiguration,
|
19 |
+
Prediction,
|
20 |
+
PredictionType,
|
21 |
+
)
|
22 |
+
from inference.core.cache.base import BaseCache
|
23 |
+
from inference.core.utils.image_utils import load_image
|
24 |
+
|
25 |
+
MAX_REGISTRATION_QUEUE_SIZE = 512
|
26 |
+
|
27 |
+
|
28 |
+
class NullActiveLearningMiddleware:
|
29 |
+
def register_batch(
|
30 |
+
self,
|
31 |
+
inference_inputs: List[Any],
|
32 |
+
predictions: List[Prediction],
|
33 |
+
prediction_type: PredictionType,
|
34 |
+
disable_preproc_auto_orient: bool = False,
|
35 |
+
) -> None:
|
36 |
+
pass
|
37 |
+
|
38 |
+
def register(
|
39 |
+
self,
|
40 |
+
inference_input: Any,
|
41 |
+
prediction: dict,
|
42 |
+
prediction_type: PredictionType,
|
43 |
+
disable_preproc_auto_orient: bool = False,
|
44 |
+
) -> None:
|
45 |
+
pass
|
46 |
+
|
47 |
+
def start_registration_thread(self) -> None:
|
48 |
+
pass
|
49 |
+
|
50 |
+
def stop_registration_thread(self) -> None:
|
51 |
+
pass
|
52 |
+
|
53 |
+
def __enter__(self) -> "NullActiveLearningMiddleware":
|
54 |
+
return self
|
55 |
+
|
56 |
+
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
57 |
+
pass
|
58 |
+
|
59 |
+
|
60 |
+
class ActiveLearningMiddleware:
|
61 |
+
@classmethod
|
62 |
+
def init(
|
63 |
+
cls, api_key: str, model_id: str, cache: BaseCache
|
64 |
+
) -> "ActiveLearningMiddleware":
|
65 |
+
configuration = prepare_active_learning_configuration(
|
66 |
+
api_key=api_key,
|
67 |
+
model_id=model_id,
|
68 |
+
cache=cache,
|
69 |
+
)
|
70 |
+
return cls(
|
71 |
+
api_key=api_key,
|
72 |
+
configuration=configuration,
|
73 |
+
cache=cache,
|
74 |
+
)
|
75 |
+
|
76 |
+
@classmethod
|
77 |
+
def init_from_config(
|
78 |
+
cls, api_key: str, model_id: str, cache: BaseCache, config: Optional[dict]
|
79 |
+
) -> "ActiveLearningMiddleware":
|
80 |
+
configuration = prepare_active_learning_configuration_inplace(
|
81 |
+
api_key=api_key,
|
82 |
+
model_id=model_id,
|
83 |
+
active_learning_configuration=config,
|
84 |
+
)
|
85 |
+
return cls(
|
86 |
+
api_key=api_key,
|
87 |
+
configuration=configuration,
|
88 |
+
cache=cache,
|
89 |
+
)
|
90 |
+
|
91 |
+
def __init__(
|
92 |
+
self,
|
93 |
+
api_key: str,
|
94 |
+
configuration: Optional[ActiveLearningConfiguration],
|
95 |
+
cache: BaseCache,
|
96 |
+
):
|
97 |
+
self._api_key = api_key
|
98 |
+
self._configuration = configuration
|
99 |
+
self._cache = cache
|
100 |
+
|
101 |
+
def register_batch(
|
102 |
+
self,
|
103 |
+
inference_inputs: List[Any],
|
104 |
+
predictions: List[Prediction],
|
105 |
+
prediction_type: PredictionType,
|
106 |
+
disable_preproc_auto_orient: bool = False,
|
107 |
+
) -> None:
|
108 |
+
for inference_input, prediction in zip(inference_inputs, predictions):
|
109 |
+
self.register(
|
110 |
+
inference_input=inference_input,
|
111 |
+
prediction=prediction,
|
112 |
+
prediction_type=prediction_type,
|
113 |
+
disable_preproc_auto_orient=disable_preproc_auto_orient,
|
114 |
+
)
|
115 |
+
|
116 |
+
def register(
|
117 |
+
self,
|
118 |
+
inference_input: Any,
|
119 |
+
prediction: dict,
|
120 |
+
prediction_type: PredictionType,
|
121 |
+
disable_preproc_auto_orient: bool = False,
|
122 |
+
) -> None:
|
123 |
+
self._execute_registration(
|
124 |
+
inference_input=inference_input,
|
125 |
+
prediction=prediction,
|
126 |
+
prediction_type=prediction_type,
|
127 |
+
disable_preproc_auto_orient=disable_preproc_auto_orient,
|
128 |
+
)
|
129 |
+
|
130 |
+
def _execute_registration(
|
131 |
+
self,
|
132 |
+
inference_input: Any,
|
133 |
+
prediction: dict,
|
134 |
+
prediction_type: PredictionType,
|
135 |
+
disable_preproc_auto_orient: bool = False,
|
136 |
+
) -> None:
|
137 |
+
if self._configuration is None:
|
138 |
+
return None
|
139 |
+
image, is_bgr = load_image(
|
140 |
+
value=inference_input,
|
141 |
+
disable_preproc_auto_orient=disable_preproc_auto_orient,
|
142 |
+
)
|
143 |
+
if not is_bgr:
|
144 |
+
image = image[:, :, ::-1]
|
145 |
+
matching_strategies = execute_sampling(
|
146 |
+
image=image,
|
147 |
+
prediction=prediction,
|
148 |
+
prediction_type=prediction_type,
|
149 |
+
sampling_methods=self._configuration.sampling_methods,
|
150 |
+
)
|
151 |
+
if len(matching_strategies) == 0:
|
152 |
+
return None
|
153 |
+
batch_name = generate_batch_name(configuration=self._configuration)
|
154 |
+
if not image_can_be_submitted_to_batch(
|
155 |
+
batch_name=batch_name,
|
156 |
+
workspace_id=self._configuration.workspace_id,
|
157 |
+
dataset_id=self._configuration.dataset_id,
|
158 |
+
max_batch_images=self._configuration.max_batch_images,
|
159 |
+
api_key=self._api_key,
|
160 |
+
):
|
161 |
+
logger.debug(f"Limit on Active Learning batch size reached.")
|
162 |
+
return None
|
163 |
+
execute_datapoint_registration(
|
164 |
+
cache=self._cache,
|
165 |
+
matching_strategies=matching_strategies,
|
166 |
+
image=image,
|
167 |
+
prediction=prediction,
|
168 |
+
prediction_type=prediction_type,
|
169 |
+
configuration=self._configuration,
|
170 |
+
api_key=self._api_key,
|
171 |
+
batch_name=batch_name,
|
172 |
+
)
|
173 |
+
|
174 |
+
|
175 |
+
class ThreadingActiveLearningMiddleware(ActiveLearningMiddleware):
|
176 |
+
@classmethod
|
177 |
+
def init(
|
178 |
+
cls,
|
179 |
+
api_key: str,
|
180 |
+
model_id: str,
|
181 |
+
cache: BaseCache,
|
182 |
+
max_queue_size: int = MAX_REGISTRATION_QUEUE_SIZE,
|
183 |
+
) -> "ThreadingActiveLearningMiddleware":
|
184 |
+
configuration = prepare_active_learning_configuration(
|
185 |
+
api_key=api_key,
|
186 |
+
model_id=model_id,
|
187 |
+
cache=cache,
|
188 |
+
)
|
189 |
+
task_queue = Queue(max_queue_size)
|
190 |
+
return cls(
|
191 |
+
api_key=api_key,
|
192 |
+
configuration=configuration,
|
193 |
+
cache=cache,
|
194 |
+
task_queue=task_queue,
|
195 |
+
)
|
196 |
+
|
197 |
+
@classmethod
|
198 |
+
def init_from_config(
|
199 |
+
cls,
|
200 |
+
api_key: str,
|
201 |
+
model_id: str,
|
202 |
+
cache: BaseCache,
|
203 |
+
config: Optional[dict],
|
204 |
+
max_queue_size: int = MAX_REGISTRATION_QUEUE_SIZE,
|
205 |
+
) -> "ThreadingActiveLearningMiddleware":
|
206 |
+
configuration = prepare_active_learning_configuration_inplace(
|
207 |
+
api_key=api_key,
|
208 |
+
model_id=model_id,
|
209 |
+
active_learning_configuration=config,
|
210 |
+
)
|
211 |
+
task_queue = Queue(max_queue_size)
|
212 |
+
return cls(
|
213 |
+
api_key=api_key,
|
214 |
+
configuration=configuration,
|
215 |
+
cache=cache,
|
216 |
+
task_queue=task_queue,
|
217 |
+
)
|
218 |
+
|
219 |
+
def __init__(
|
220 |
+
self,
|
221 |
+
api_key: str,
|
222 |
+
configuration: ActiveLearningConfiguration,
|
223 |
+
cache: BaseCache,
|
224 |
+
task_queue: Queue,
|
225 |
+
):
|
226 |
+
super().__init__(api_key=api_key, configuration=configuration, cache=cache)
|
227 |
+
self._task_queue = task_queue
|
228 |
+
self._registration_thread: Optional[Thread] = None
|
229 |
+
|
230 |
+
def register(
|
231 |
+
self,
|
232 |
+
inference_input: Any,
|
233 |
+
prediction: dict,
|
234 |
+
prediction_type: PredictionType,
|
235 |
+
disable_preproc_auto_orient: bool = False,
|
236 |
+
) -> None:
|
237 |
+
logger.debug(f"Putting registration task into queue")
|
238 |
+
try:
|
239 |
+
self._task_queue.put_nowait(
|
240 |
+
(
|
241 |
+
inference_input,
|
242 |
+
prediction,
|
243 |
+
prediction_type,
|
244 |
+
disable_preproc_auto_orient,
|
245 |
+
)
|
246 |
+
)
|
247 |
+
except queue.Full:
|
248 |
+
logger.warning(
|
249 |
+
f"Dropping datapoint registered in Active Learning due to insufficient processing "
|
250 |
+
f"capabilities."
|
251 |
+
)
|
252 |
+
|
253 |
+
def start_registration_thread(self) -> None:
|
254 |
+
if self._registration_thread is not None:
|
255 |
+
logger.warning(f"Registration thread already started.")
|
256 |
+
return None
|
257 |
+
logger.debug("Staring registration thread")
|
258 |
+
self._registration_thread = Thread(target=self._consume_queue)
|
259 |
+
self._registration_thread.start()
|
260 |
+
|
261 |
+
def stop_registration_thread(self) -> None:
|
262 |
+
if self._registration_thread is None:
|
263 |
+
logger.warning("Registration thread is already stopped.")
|
264 |
+
return None
|
265 |
+
logger.debug("Stopping registration thread")
|
266 |
+
self._task_queue.put(None)
|
267 |
+
self._registration_thread.join()
|
268 |
+
if self._registration_thread.is_alive():
|
269 |
+
logger.warning(f"Registration thread stopping was unsuccessful.")
|
270 |
+
self._registration_thread = None
|
271 |
+
|
272 |
+
def _consume_queue(self) -> None:
|
273 |
+
queue_closed = False
|
274 |
+
while not queue_closed:
|
275 |
+
queue_closed = self._consume_queue_task()
|
276 |
+
|
277 |
+
def _consume_queue_task(self) -> bool:
|
278 |
+
logger.debug("Consuming registration task")
|
279 |
+
task = self._task_queue.get()
|
280 |
+
logger.debug("Received registration task")
|
281 |
+
if task is None:
|
282 |
+
logger.debug("Terminating registration thread")
|
283 |
+
self._task_queue.task_done()
|
284 |
+
return True
|
285 |
+
inference_input, prediction, prediction_type, disable_preproc_auto_orient = task
|
286 |
+
try:
|
287 |
+
self._execute_registration(
|
288 |
+
inference_input=inference_input,
|
289 |
+
prediction=prediction,
|
290 |
+
prediction_type=prediction_type,
|
291 |
+
disable_preproc_auto_orient=disable_preproc_auto_orient,
|
292 |
+
)
|
293 |
+
except Exception as error:
|
294 |
+
# Error handling to be decided
|
295 |
+
logger.warning(
|
296 |
+
f"Error in datapoint registration for Active Learning. Details: {error}. "
|
297 |
+
f"Error is suppressed in favour of normal operations of registration thread."
|
298 |
+
)
|
299 |
+
self._task_queue.task_done()
|
300 |
+
return False
|
301 |
+
|
302 |
+
def __enter__(self) -> "ThreadingActiveLearningMiddleware":
|
303 |
+
self.start_registration_thread()
|
304 |
+
return self
|
305 |
+
|
306 |
+
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
307 |
+
self.stop_registration_thread()
|
inference/core/active_learning/post_processing.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from typing import List, Tuple
|
3 |
+
|
4 |
+
from inference.core.active_learning.entities import (
|
5 |
+
Prediction,
|
6 |
+
PredictionFileType,
|
7 |
+
PredictionType,
|
8 |
+
SerialisedPrediction,
|
9 |
+
)
|
10 |
+
from inference.core.constants import (
|
11 |
+
CLASSIFICATION_TASK,
|
12 |
+
INSTANCE_SEGMENTATION_TASK,
|
13 |
+
OBJECT_DETECTION_TASK,
|
14 |
+
)
|
15 |
+
from inference.core.exceptions import PredictionFormatNotSupported
|
16 |
+
|
17 |
+
|
18 |
+
def adjust_prediction_to_client_scaling_factor(
|
19 |
+
prediction: dict, scaling_factor: float, prediction_type: PredictionType
|
20 |
+
) -> dict:
|
21 |
+
if abs(scaling_factor - 1.0) < 1e-5:
|
22 |
+
return prediction
|
23 |
+
if "image" in prediction:
|
24 |
+
prediction["image"] = {
|
25 |
+
"width": round(prediction["image"]["width"] / scaling_factor),
|
26 |
+
"height": round(prediction["image"]["height"] / scaling_factor),
|
27 |
+
}
|
28 |
+
if predictions_should_not_be_post_processed(
|
29 |
+
prediction=prediction, prediction_type=prediction_type
|
30 |
+
):
|
31 |
+
return prediction
|
32 |
+
if prediction_type == INSTANCE_SEGMENTATION_TASK:
|
33 |
+
prediction["predictions"] = (
|
34 |
+
adjust_prediction_with_bbox_and_points_to_client_scaling_factor(
|
35 |
+
predictions=prediction["predictions"],
|
36 |
+
scaling_factor=scaling_factor,
|
37 |
+
points_key="points",
|
38 |
+
)
|
39 |
+
)
|
40 |
+
if prediction_type == OBJECT_DETECTION_TASK:
|
41 |
+
prediction["predictions"] = (
|
42 |
+
adjust_object_detection_predictions_to_client_scaling_factor(
|
43 |
+
predictions=prediction["predictions"],
|
44 |
+
scaling_factor=scaling_factor,
|
45 |
+
)
|
46 |
+
)
|
47 |
+
return prediction
|
48 |
+
|
49 |
+
|
50 |
+
def predictions_should_not_be_post_processed(
|
51 |
+
prediction: dict, prediction_type: PredictionType
|
52 |
+
) -> bool:
|
53 |
+
# excluding from post-processing classification output, stub-output and empty predictions
|
54 |
+
return (
|
55 |
+
"is_stub" in prediction
|
56 |
+
or "predictions" not in prediction
|
57 |
+
or CLASSIFICATION_TASK in prediction_type
|
58 |
+
or len(prediction["predictions"]) == 0
|
59 |
+
)
|
60 |
+
|
61 |
+
|
62 |
+
def adjust_object_detection_predictions_to_client_scaling_factor(
|
63 |
+
predictions: List[dict],
|
64 |
+
scaling_factor: float,
|
65 |
+
) -> List[dict]:
|
66 |
+
result = []
|
67 |
+
for prediction in predictions:
|
68 |
+
prediction = adjust_bbox_coordinates_to_client_scaling_factor(
|
69 |
+
bbox=prediction,
|
70 |
+
scaling_factor=scaling_factor,
|
71 |
+
)
|
72 |
+
result.append(prediction)
|
73 |
+
return result
|
74 |
+
|
75 |
+
|
76 |
+
def adjust_prediction_with_bbox_and_points_to_client_scaling_factor(
|
77 |
+
predictions: List[dict],
|
78 |
+
scaling_factor: float,
|
79 |
+
points_key: str,
|
80 |
+
) -> List[dict]:
|
81 |
+
result = []
|
82 |
+
for prediction in predictions:
|
83 |
+
prediction = adjust_bbox_coordinates_to_client_scaling_factor(
|
84 |
+
bbox=prediction,
|
85 |
+
scaling_factor=scaling_factor,
|
86 |
+
)
|
87 |
+
prediction[points_key] = adjust_points_coordinates_to_client_scaling_factor(
|
88 |
+
points=prediction[points_key],
|
89 |
+
scaling_factor=scaling_factor,
|
90 |
+
)
|
91 |
+
result.append(prediction)
|
92 |
+
return result
|
93 |
+
|
94 |
+
|
95 |
+
def adjust_bbox_coordinates_to_client_scaling_factor(
|
96 |
+
bbox: dict,
|
97 |
+
scaling_factor: float,
|
98 |
+
) -> dict:
|
99 |
+
bbox["x"] = bbox["x"] / scaling_factor
|
100 |
+
bbox["y"] = bbox["y"] / scaling_factor
|
101 |
+
bbox["width"] = bbox["width"] / scaling_factor
|
102 |
+
bbox["height"] = bbox["height"] / scaling_factor
|
103 |
+
return bbox
|
104 |
+
|
105 |
+
|
106 |
+
def adjust_points_coordinates_to_client_scaling_factor(
|
107 |
+
points: List[dict],
|
108 |
+
scaling_factor: float,
|
109 |
+
) -> List[dict]:
|
110 |
+
result = []
|
111 |
+
for point in points:
|
112 |
+
point["x"] = point["x"] / scaling_factor
|
113 |
+
point["y"] = point["y"] / scaling_factor
|
114 |
+
result.append(point)
|
115 |
+
return result
|
116 |
+
|
117 |
+
|
118 |
+
def encode_prediction(
|
119 |
+
prediction: Prediction,
|
120 |
+
prediction_type: PredictionType,
|
121 |
+
) -> Tuple[SerialisedPrediction, PredictionFileType]:
|
122 |
+
if CLASSIFICATION_TASK not in prediction_type:
|
123 |
+
return json.dumps(prediction), "json"
|
124 |
+
if "top" in prediction:
|
125 |
+
return prediction["top"], "txt"
|
126 |
+
raise PredictionFormatNotSupported(
|
127 |
+
f"Prediction type or prediction format not supported."
|
128 |
+
)
|
inference/core/active_learning/samplers/__init__.py
ADDED
File without changes
|
inference/core/active_learning/samplers/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (201 Bytes). View file
|
|
inference/core/active_learning/samplers/__pycache__/close_to_threshold.cpython-310.pyc
ADDED
Binary file (4.68 kB). View file
|
|
inference/core/active_learning/samplers/__pycache__/contains_classes.cpython-310.pyc
ADDED
Binary file (1.71 kB). View file
|
|
inference/core/active_learning/samplers/__pycache__/number_of_detections.cpython-310.pyc
ADDED
Binary file (2.74 kB). View file
|
|
inference/core/active_learning/samplers/__pycache__/random.cpython-310.pyc
ADDED
Binary file (1.22 kB). View file
|
|
inference/core/active_learning/samplers/close_to_threshold.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from functools import partial
|
3 |
+
from typing import Any, Dict, Optional, Set
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from inference.core.active_learning.entities import (
|
8 |
+
Prediction,
|
9 |
+
PredictionType,
|
10 |
+
SamplingMethod,
|
11 |
+
)
|
12 |
+
from inference.core.constants import (
|
13 |
+
CLASSIFICATION_TASK,
|
14 |
+
INSTANCE_SEGMENTATION_TASK,
|
15 |
+
KEYPOINTS_DETECTION_TASK,
|
16 |
+
OBJECT_DETECTION_TASK,
|
17 |
+
)
|
18 |
+
from inference.core.exceptions import ActiveLearningConfigurationError
|
19 |
+
|
20 |
+
ELIGIBLE_PREDICTION_TYPES = {
|
21 |
+
CLASSIFICATION_TASK,
|
22 |
+
INSTANCE_SEGMENTATION_TASK,
|
23 |
+
KEYPOINTS_DETECTION_TASK,
|
24 |
+
OBJECT_DETECTION_TASK,
|
25 |
+
}
|
26 |
+
|
27 |
+
|
28 |
+
def initialize_close_to_threshold_sampling(
|
29 |
+
strategy_config: Dict[str, Any]
|
30 |
+
) -> SamplingMethod:
|
31 |
+
try:
|
32 |
+
selected_class_names = strategy_config.get("selected_class_names")
|
33 |
+
if selected_class_names is not None:
|
34 |
+
selected_class_names = set(selected_class_names)
|
35 |
+
sample_function = partial(
|
36 |
+
sample_close_to_threshold,
|
37 |
+
selected_class_names=selected_class_names,
|
38 |
+
threshold=strategy_config["threshold"],
|
39 |
+
epsilon=strategy_config["epsilon"],
|
40 |
+
only_top_classes=strategy_config.get("only_top_classes", True),
|
41 |
+
minimum_objects_close_to_threshold=strategy_config.get(
|
42 |
+
"minimum_objects_close_to_threshold",
|
43 |
+
1,
|
44 |
+
),
|
45 |
+
probability=strategy_config["probability"],
|
46 |
+
)
|
47 |
+
return SamplingMethod(
|
48 |
+
name=strategy_config["name"],
|
49 |
+
sample=sample_function,
|
50 |
+
)
|
51 |
+
except KeyError as error:
|
52 |
+
raise ActiveLearningConfigurationError(
|
53 |
+
f"In configuration of `close_to_threshold_sampling` missing key detected: {error}."
|
54 |
+
) from error
|
55 |
+
|
56 |
+
|
57 |
+
def sample_close_to_threshold(
|
58 |
+
image: np.ndarray,
|
59 |
+
prediction: Prediction,
|
60 |
+
prediction_type: PredictionType,
|
61 |
+
selected_class_names: Optional[Set[str]],
|
62 |
+
threshold: float,
|
63 |
+
epsilon: float,
|
64 |
+
only_top_classes: bool,
|
65 |
+
minimum_objects_close_to_threshold: int,
|
66 |
+
probability: float,
|
67 |
+
) -> bool:
|
68 |
+
if is_prediction_a_stub(prediction=prediction):
|
69 |
+
return False
|
70 |
+
if prediction_type not in ELIGIBLE_PREDICTION_TYPES:
|
71 |
+
return False
|
72 |
+
close_to_threshold = prediction_is_close_to_threshold(
|
73 |
+
prediction=prediction,
|
74 |
+
prediction_type=prediction_type,
|
75 |
+
selected_class_names=selected_class_names,
|
76 |
+
threshold=threshold,
|
77 |
+
epsilon=epsilon,
|
78 |
+
only_top_classes=only_top_classes,
|
79 |
+
minimum_objects_close_to_threshold=minimum_objects_close_to_threshold,
|
80 |
+
)
|
81 |
+
if not close_to_threshold:
|
82 |
+
return False
|
83 |
+
return random.random() < probability
|
84 |
+
|
85 |
+
|
86 |
+
def is_prediction_a_stub(prediction: Prediction) -> bool:
|
87 |
+
return prediction.get("is_stub", False)
|
88 |
+
|
89 |
+
|
90 |
+
def prediction_is_close_to_threshold(
|
91 |
+
prediction: Prediction,
|
92 |
+
prediction_type: PredictionType,
|
93 |
+
selected_class_names: Optional[Set[str]],
|
94 |
+
threshold: float,
|
95 |
+
epsilon: float,
|
96 |
+
only_top_classes: bool,
|
97 |
+
minimum_objects_close_to_threshold: int,
|
98 |
+
) -> bool:
|
99 |
+
if CLASSIFICATION_TASK not in prediction_type:
|
100 |
+
return detections_are_close_to_threshold(
|
101 |
+
prediction=prediction,
|
102 |
+
selected_class_names=selected_class_names,
|
103 |
+
threshold=threshold,
|
104 |
+
epsilon=epsilon,
|
105 |
+
minimum_objects_close_to_threshold=minimum_objects_close_to_threshold,
|
106 |
+
)
|
107 |
+
checker = multi_label_classification_prediction_is_close_to_threshold
|
108 |
+
if "top" in prediction:
|
109 |
+
checker = multi_class_classification_prediction_is_close_to_threshold
|
110 |
+
return checker(
|
111 |
+
prediction=prediction,
|
112 |
+
selected_class_names=selected_class_names,
|
113 |
+
threshold=threshold,
|
114 |
+
epsilon=epsilon,
|
115 |
+
only_top_classes=only_top_classes,
|
116 |
+
)
|
117 |
+
|
118 |
+
|
119 |
+
def multi_class_classification_prediction_is_close_to_threshold(
|
120 |
+
prediction: Prediction,
|
121 |
+
selected_class_names: Optional[Set[str]],
|
122 |
+
threshold: float,
|
123 |
+
epsilon: float,
|
124 |
+
only_top_classes: bool,
|
125 |
+
) -> bool:
|
126 |
+
if only_top_classes:
|
127 |
+
return (
|
128 |
+
multi_class_classification_prediction_is_close_to_threshold_for_top_class(
|
129 |
+
prediction=prediction,
|
130 |
+
selected_class_names=selected_class_names,
|
131 |
+
threshold=threshold,
|
132 |
+
epsilon=epsilon,
|
133 |
+
)
|
134 |
+
)
|
135 |
+
for prediction_details in prediction["predictions"]:
|
136 |
+
if class_to_be_excluded(
|
137 |
+
class_name=prediction_details["class"],
|
138 |
+
selected_class_names=selected_class_names,
|
139 |
+
):
|
140 |
+
continue
|
141 |
+
if is_close_to_threshold(
|
142 |
+
value=prediction_details["confidence"], threshold=threshold, epsilon=epsilon
|
143 |
+
):
|
144 |
+
return True
|
145 |
+
return False
|
146 |
+
|
147 |
+
|
148 |
+
def multi_class_classification_prediction_is_close_to_threshold_for_top_class(
|
149 |
+
prediction: Prediction,
|
150 |
+
selected_class_names: Optional[Set[str]],
|
151 |
+
threshold: float,
|
152 |
+
epsilon: float,
|
153 |
+
) -> bool:
|
154 |
+
if (
|
155 |
+
selected_class_names is not None
|
156 |
+
and prediction["top"] not in selected_class_names
|
157 |
+
):
|
158 |
+
return False
|
159 |
+
return abs(prediction["confidence"] - threshold) < epsilon
|
160 |
+
|
161 |
+
|
162 |
+
def multi_label_classification_prediction_is_close_to_threshold(
|
163 |
+
prediction: Prediction,
|
164 |
+
selected_class_names: Optional[Set[str]],
|
165 |
+
threshold: float,
|
166 |
+
epsilon: float,
|
167 |
+
only_top_classes: bool,
|
168 |
+
) -> bool:
|
169 |
+
predicted_classes = set(prediction["predicted_classes"])
|
170 |
+
for class_name, prediction_details in prediction["predictions"].items():
|
171 |
+
if only_top_classes and class_name not in predicted_classes:
|
172 |
+
continue
|
173 |
+
if class_to_be_excluded(
|
174 |
+
class_name=class_name, selected_class_names=selected_class_names
|
175 |
+
):
|
176 |
+
continue
|
177 |
+
if is_close_to_threshold(
|
178 |
+
value=prediction_details["confidence"], threshold=threshold, epsilon=epsilon
|
179 |
+
):
|
180 |
+
return True
|
181 |
+
return False
|
182 |
+
|
183 |
+
|
184 |
+
def detections_are_close_to_threshold(
|
185 |
+
prediction: Prediction,
|
186 |
+
selected_class_names: Optional[Set[str]],
|
187 |
+
threshold: float,
|
188 |
+
epsilon: float,
|
189 |
+
minimum_objects_close_to_threshold: int,
|
190 |
+
) -> bool:
|
191 |
+
detections_close_to_threshold = count_detections_close_to_threshold(
|
192 |
+
prediction=prediction,
|
193 |
+
selected_class_names=selected_class_names,
|
194 |
+
threshold=threshold,
|
195 |
+
epsilon=epsilon,
|
196 |
+
)
|
197 |
+
return detections_close_to_threshold >= minimum_objects_close_to_threshold
|
198 |
+
|
199 |
+
|
200 |
+
def count_detections_close_to_threshold(
|
201 |
+
prediction: Prediction,
|
202 |
+
selected_class_names: Optional[Set[str]],
|
203 |
+
threshold: float,
|
204 |
+
epsilon: float,
|
205 |
+
) -> int:
|
206 |
+
counter = 0
|
207 |
+
for prediction_details in prediction["predictions"]:
|
208 |
+
if class_to_be_excluded(
|
209 |
+
class_name=prediction_details["class"],
|
210 |
+
selected_class_names=selected_class_names,
|
211 |
+
):
|
212 |
+
continue
|
213 |
+
if is_close_to_threshold(
|
214 |
+
value=prediction_details["confidence"], threshold=threshold, epsilon=epsilon
|
215 |
+
):
|
216 |
+
counter += 1
|
217 |
+
return counter
|
218 |
+
|
219 |
+
|
220 |
+
def class_to_be_excluded(
|
221 |
+
class_name: str, selected_class_names: Optional[Set[str]]
|
222 |
+
) -> bool:
|
223 |
+
return selected_class_names is not None and class_name not in selected_class_names
|
224 |
+
|
225 |
+
|
226 |
+
def is_close_to_threshold(value: float, threshold: float, epsilon: float) -> bool:
|
227 |
+
return abs(value - threshold) < epsilon
|
inference/core/active_learning/samplers/contains_classes.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
from typing import Any, Dict, Set
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from inference.core.active_learning.entities import (
|
7 |
+
Prediction,
|
8 |
+
PredictionType,
|
9 |
+
SamplingMethod,
|
10 |
+
)
|
11 |
+
from inference.core.active_learning.samplers.close_to_threshold import (
|
12 |
+
sample_close_to_threshold,
|
13 |
+
)
|
14 |
+
from inference.core.constants import CLASSIFICATION_TASK
|
15 |
+
from inference.core.exceptions import ActiveLearningConfigurationError
|
16 |
+
|
17 |
+
ELIGIBLE_PREDICTION_TYPES = {CLASSIFICATION_TASK}
|
18 |
+
|
19 |
+
|
20 |
+
def initialize_classes_based_sampling(
|
21 |
+
strategy_config: Dict[str, Any]
|
22 |
+
) -> SamplingMethod:
|
23 |
+
try:
|
24 |
+
sample_function = partial(
|
25 |
+
sample_based_on_classes,
|
26 |
+
selected_class_names=set(strategy_config["selected_class_names"]),
|
27 |
+
probability=strategy_config["probability"],
|
28 |
+
)
|
29 |
+
return SamplingMethod(
|
30 |
+
name=strategy_config["name"],
|
31 |
+
sample=sample_function,
|
32 |
+
)
|
33 |
+
except KeyError as error:
|
34 |
+
raise ActiveLearningConfigurationError(
|
35 |
+
f"In configuration of `classes_based_sampling` missing key detected: {error}."
|
36 |
+
) from error
|
37 |
+
|
38 |
+
|
39 |
+
def sample_based_on_classes(
|
40 |
+
image: np.ndarray,
|
41 |
+
prediction: Prediction,
|
42 |
+
prediction_type: PredictionType,
|
43 |
+
selected_class_names: Set[str],
|
44 |
+
probability: float,
|
45 |
+
) -> bool:
|
46 |
+
if prediction_type not in ELIGIBLE_PREDICTION_TYPES:
|
47 |
+
return False
|
48 |
+
return sample_close_to_threshold(
|
49 |
+
image=image,
|
50 |
+
prediction=prediction,
|
51 |
+
prediction_type=prediction_type,
|
52 |
+
selected_class_names=selected_class_names,
|
53 |
+
threshold=0.5,
|
54 |
+
epsilon=1.0,
|
55 |
+
only_top_classes=True,
|
56 |
+
minimum_objects_close_to_threshold=1,
|
57 |
+
probability=probability,
|
58 |
+
)
|
inference/core/active_learning/samplers/number_of_detections.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from functools import partial
|
3 |
+
from typing import Any, Dict, Optional, Set
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from inference.core.active_learning.entities import (
|
8 |
+
Prediction,
|
9 |
+
PredictionType,
|
10 |
+
SamplingMethod,
|
11 |
+
)
|
12 |
+
from inference.core.active_learning.samplers.close_to_threshold import (
|
13 |
+
count_detections_close_to_threshold,
|
14 |
+
is_prediction_a_stub,
|
15 |
+
)
|
16 |
+
from inference.core.constants import (
|
17 |
+
INSTANCE_SEGMENTATION_TASK,
|
18 |
+
KEYPOINTS_DETECTION_TASK,
|
19 |
+
OBJECT_DETECTION_TASK,
|
20 |
+
)
|
21 |
+
from inference.core.exceptions import ActiveLearningConfigurationError
|
22 |
+
|
23 |
+
ELIGIBLE_PREDICTION_TYPES = {
|
24 |
+
INSTANCE_SEGMENTATION_TASK,
|
25 |
+
KEYPOINTS_DETECTION_TASK,
|
26 |
+
OBJECT_DETECTION_TASK,
|
27 |
+
}
|
28 |
+
|
29 |
+
|
30 |
+
def initialize_detections_number_based_sampling(
|
31 |
+
strategy_config: Dict[str, Any]
|
32 |
+
) -> SamplingMethod:
|
33 |
+
try:
|
34 |
+
more_than = strategy_config.get("more_than")
|
35 |
+
less_than = strategy_config.get("less_than")
|
36 |
+
ensure_range_configuration_is_valid(more_than=more_than, less_than=less_than)
|
37 |
+
selected_class_names = strategy_config.get("selected_class_names")
|
38 |
+
if selected_class_names is not None:
|
39 |
+
selected_class_names = set(selected_class_names)
|
40 |
+
sample_function = partial(
|
41 |
+
sample_based_on_detections_number,
|
42 |
+
less_than=less_than,
|
43 |
+
more_than=more_than,
|
44 |
+
selected_class_names=selected_class_names,
|
45 |
+
probability=strategy_config["probability"],
|
46 |
+
)
|
47 |
+
return SamplingMethod(
|
48 |
+
name=strategy_config["name"],
|
49 |
+
sample=sample_function,
|
50 |
+
)
|
51 |
+
except KeyError as error:
|
52 |
+
raise ActiveLearningConfigurationError(
|
53 |
+
f"In configuration of `detections_number_based_sampling` missing key detected: {error}."
|
54 |
+
) from error
|
55 |
+
|
56 |
+
|
57 |
+
def sample_based_on_detections_number(
|
58 |
+
image: np.ndarray,
|
59 |
+
prediction: Prediction,
|
60 |
+
prediction_type: PredictionType,
|
61 |
+
more_than: Optional[int],
|
62 |
+
less_than: Optional[int],
|
63 |
+
selected_class_names: Optional[Set[str]],
|
64 |
+
probability: float,
|
65 |
+
) -> bool:
|
66 |
+
if is_prediction_a_stub(prediction=prediction):
|
67 |
+
return False
|
68 |
+
if prediction_type not in ELIGIBLE_PREDICTION_TYPES:
|
69 |
+
return False
|
70 |
+
detections_close_to_threshold = count_detections_close_to_threshold(
|
71 |
+
prediction=prediction,
|
72 |
+
selected_class_names=selected_class_names,
|
73 |
+
threshold=0.5,
|
74 |
+
epsilon=1.0,
|
75 |
+
)
|
76 |
+
if is_in_range(
|
77 |
+
value=detections_close_to_threshold, less_than=less_than, more_than=more_than
|
78 |
+
):
|
79 |
+
return random.random() < probability
|
80 |
+
return False
|
81 |
+
|
82 |
+
|
83 |
+
def is_in_range(
|
84 |
+
value: int,
|
85 |
+
more_than: Optional[int],
|
86 |
+
less_than: Optional[int],
|
87 |
+
) -> bool:
|
88 |
+
# calculates value > more_than and value < less_than, with optional borders of range
|
89 |
+
less_than_satisfied, more_than_satisfied = less_than is None, more_than is None
|
90 |
+
if less_than is not None and value < less_than:
|
91 |
+
less_than_satisfied = True
|
92 |
+
if more_than is not None and value > more_than:
|
93 |
+
more_than_satisfied = True
|
94 |
+
return less_than_satisfied and more_than_satisfied
|
95 |
+
|
96 |
+
|
97 |
+
def ensure_range_configuration_is_valid(
|
98 |
+
more_than: Optional[int],
|
99 |
+
less_than: Optional[int],
|
100 |
+
) -> None:
|
101 |
+
if more_than is None or less_than is None:
|
102 |
+
return None
|
103 |
+
if more_than >= less_than:
|
104 |
+
raise ActiveLearningConfigurationError(
|
105 |
+
f"Misconfiguration of detections number sampling: "
|
106 |
+
f"`more_than` parameter ({more_than}) >= `less_than` ({less_than})."
|
107 |
+
)
|
inference/core/active_learning/samplers/random.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from functools import partial
|
3 |
+
from typing import Any, Dict
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from inference.core.active_learning.entities import (
|
8 |
+
Prediction,
|
9 |
+
PredictionType,
|
10 |
+
SamplingMethod,
|
11 |
+
)
|
12 |
+
from inference.core.exceptions import ActiveLearningConfigurationError
|
13 |
+
|
14 |
+
|
15 |
+
def initialize_random_sampling(strategy_config: Dict[str, Any]) -> SamplingMethod:
|
16 |
+
try:
|
17 |
+
sample_function = partial(
|
18 |
+
sample_randomly,
|
19 |
+
traffic_percentage=strategy_config["traffic_percentage"],
|
20 |
+
)
|
21 |
+
return SamplingMethod(
|
22 |
+
name=strategy_config["name"],
|
23 |
+
sample=sample_function,
|
24 |
+
)
|
25 |
+
except KeyError as error:
|
26 |
+
raise ActiveLearningConfigurationError(
|
27 |
+
f"In configuration of `random_sampling` missing key detected: {error}."
|
28 |
+
) from error
|
29 |
+
|
30 |
+
|
31 |
+
def sample_randomly(
|
32 |
+
image: np.ndarray,
|
33 |
+
prediction: Prediction,
|
34 |
+
prediction_type: PredictionType,
|
35 |
+
traffic_percentage: float,
|
36 |
+
) -> bool:
|
37 |
+
return random.random() < traffic_percentage
|
inference/core/active_learning/utils.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime, timedelta
|
2 |
+
|
3 |
+
TIMESTAMP_FORMAT = "%Y_%m_%d"
|
4 |
+
|
5 |
+
|
6 |
+
def generate_today_timestamp() -> str:
|
7 |
+
return datetime.today().strftime(TIMESTAMP_FORMAT)
|
8 |
+
|
9 |
+
|
10 |
+
def generate_start_timestamp_for_this_week() -> str:
|
11 |
+
today = datetime.today()
|
12 |
+
return (today - timedelta(days=today.weekday())).strftime(TIMESTAMP_FORMAT)
|
13 |
+
|
14 |
+
|
15 |
+
def generate_start_timestamp_for_this_month() -> str:
|
16 |
+
return datetime.today().replace(day=1).strftime(TIMESTAMP_FORMAT)
|
inference/core/cache/__init__.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from redis.exceptions import ConnectionError, TimeoutError
|
2 |
+
|
3 |
+
from inference.core import logger
|
4 |
+
from inference.core.cache.memory import MemoryCache
|
5 |
+
from inference.core.cache.redis import RedisCache
|
6 |
+
from inference.core.env import REDIS_HOST, REDIS_PORT, REDIS_SSL, REDIS_TIMEOUT
|
7 |
+
|
8 |
+
if REDIS_HOST is not None:
|
9 |
+
try:
|
10 |
+
cache = RedisCache(
|
11 |
+
host=REDIS_HOST, port=REDIS_PORT, ssl=REDIS_SSL, timeout=REDIS_TIMEOUT
|
12 |
+
)
|
13 |
+
logger.info(f"Redis Cache initialised")
|
14 |
+
except (TimeoutError, ConnectionError):
|
15 |
+
logger.error(
|
16 |
+
f"Could not connect to Redis under {REDIS_HOST}:{REDIS_PORT}. MemoryCache to be used."
|
17 |
+
)
|
18 |
+
cache = MemoryCache()
|
19 |
+
logger.info(f"Memory Cache initialised")
|
20 |
+
else:
|
21 |
+
cache = MemoryCache()
|
22 |
+
logger.info(f"Memory Cache initialised")
|
inference/core/cache/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (864 Bytes). View file
|
|
inference/core/cache/__pycache__/base.cpython-310.pyc
ADDED
Binary file (4.93 kB). View file
|
|
inference/core/cache/__pycache__/memory.cpython-310.pyc
ADDED
Binary file (6.56 kB). View file
|
|
inference/core/cache/__pycache__/model_artifacts.cpython-310.pyc
ADDED
Binary file (3.17 kB). View file
|
|
inference/core/cache/__pycache__/redis.cpython-310.pyc
ADDED
Binary file (7.3 kB). View file
|
|
inference/core/cache/__pycache__/serializers.cpython-310.pyc
ADDED
Binary file (1.91 kB). View file
|
|
inference/core/cache/base.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from contextlib import contextmanager
|
2 |
+
from typing import Any, Optional
|
3 |
+
|
4 |
+
from inference.core import logger
|
5 |
+
|
6 |
+
|
7 |
+
class BaseCache:
|
8 |
+
"""
|
9 |
+
BaseCache is an abstract base class that defines the interface for a cache.
|
10 |
+
"""
|
11 |
+
|
12 |
+
def get(self, key: str):
|
13 |
+
"""
|
14 |
+
Gets the value associated with the given key.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
key (str): The key to retrieve the value.
|
18 |
+
|
19 |
+
Raises:
|
20 |
+
NotImplementedError: This method must be implemented by subclasses.
|
21 |
+
"""
|
22 |
+
raise NotImplementedError()
|
23 |
+
|
24 |
+
def set(self, key: str, value: str, expire: float = None):
|
25 |
+
"""
|
26 |
+
Sets a value for a given key with an optional expire time.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
key (str): The key to store the value.
|
30 |
+
value (str): The value to store.
|
31 |
+
expire (float, optional): The time, in seconds, after which the key will expire. Defaults to None.
|
32 |
+
|
33 |
+
Raises:
|
34 |
+
NotImplementedError: This method must be implemented by subclasses.
|
35 |
+
"""
|
36 |
+
raise NotImplementedError()
|
37 |
+
|
38 |
+
def zadd(self, key: str, value: str, score: float, expire: float = None):
|
39 |
+
"""
|
40 |
+
Adds a member with the specified score to the sorted set stored at key.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
key (str): The key of the sorted set.
|
44 |
+
value (str): The value to add to the sorted set.
|
45 |
+
score (float): The score associated with the value.
|
46 |
+
expire (float, optional): The time, in seconds, after which the key will expire. Defaults to None.
|
47 |
+
|
48 |
+
Raises:
|
49 |
+
NotImplementedError: This method must be implemented by subclasses.
|
50 |
+
"""
|
51 |
+
raise NotImplementedError()
|
52 |
+
|
53 |
+
def zrangebyscore(
|
54 |
+
self,
|
55 |
+
key: str,
|
56 |
+
min: Optional[float] = -1,
|
57 |
+
max: Optional[float] = float("inf"),
|
58 |
+
withscores: bool = False,
|
59 |
+
):
|
60 |
+
"""
|
61 |
+
Retrieves a range of members from a sorted set.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
key (str): The key of the sorted set.
|
65 |
+
start (int, optional): The starting index of the range. Defaults to -1.
|
66 |
+
stop (int, optional): The ending index of the range. Defaults to float("inf").
|
67 |
+
withscores (bool, optional): Whether to return the scores along with the values. Defaults to False.
|
68 |
+
|
69 |
+
Raises:
|
70 |
+
NotImplementedError: This method must be implemented by subclasses.
|
71 |
+
"""
|
72 |
+
raise NotImplementedError()
|
73 |
+
|
74 |
+
def zremrangebyscore(
|
75 |
+
self,
|
76 |
+
key: str,
|
77 |
+
start: Optional[int] = -1,
|
78 |
+
stop: Optional[int] = float("inf"),
|
79 |
+
):
|
80 |
+
"""
|
81 |
+
Removes all members in a sorted set within the given scores.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
key (str): The key of the sorted set.
|
85 |
+
start (int, optional): The minimum score of the range. Defaults to -1.
|
86 |
+
stop (int, optional): The maximum score of the range. Defaults to float("inf").
|
87 |
+
|
88 |
+
Raises:
|
89 |
+
NotImplementedError: This method must be implemented by subclasses.
|
90 |
+
"""
|
91 |
+
raise NotImplementedError()
|
92 |
+
|
93 |
+
def acquire_lock(self, key: str, expire: float = None) -> Any:
|
94 |
+
raise NotImplementedError()
|
95 |
+
|
96 |
+
@contextmanager
|
97 |
+
def lock(self, key: str, expire: float = None) -> Any:
|
98 |
+
logger.debug(f"Acquiring lock at cache key: {key}")
|
99 |
+
l = self.acquire_lock(key, expire=expire)
|
100 |
+
try:
|
101 |
+
yield l
|
102 |
+
finally:
|
103 |
+
logger.debug(f"Releasing lock at cache key: {key}")
|
104 |
+
l.release()
|
105 |
+
|
106 |
+
def set_numpy(self, key: str, value: Any, expire: float = None):
|
107 |
+
"""
|
108 |
+
Caches a numpy array.
|
109 |
+
|
110 |
+
Args:
|
111 |
+
key (str): The key to store the value.
|
112 |
+
value (Any): The value to store.
|
113 |
+
expire (float, optional): The time, in seconds, after which the key will expire. Defaults to None.
|
114 |
+
|
115 |
+
Raises:
|
116 |
+
NotImplementedError: This method must be implemented by subclasses.
|
117 |
+
"""
|
118 |
+
raise NotImplementedError()
|
119 |
+
|
120 |
+
def get_numpy(self, key: str) -> Any:
|
121 |
+
"""
|
122 |
+
Retrieves a numpy array from the cache.
|
123 |
+
|
124 |
+
Args:
|
125 |
+
key (str): The key of the value to retrieve.
|
126 |
+
|
127 |
+
Raises:
|
128 |
+
NotImplementedError: This method must be implemented by subclasses.
|
129 |
+
"""
|
130 |
+
raise NotImplementedError()
|