Spaces:
Paused
Paused
PySpaces
Browse files- .gitignore +1 -0
- spaces/__init__.py +23 -0
- spaces/config.py +29 -0
- spaces/gradio.py +55 -0
- spaces/utils.py +73 -0
- spaces/zero/__init__.py +12 -0
- spaces/zero/api.py +154 -0
- spaces/zero/bitsandbytes.py +135 -0
- spaces/zero/client.py +175 -0
- spaces/zero/decorator.py +117 -0
- spaces/zero/gradio.py +108 -0
- spaces/zero/torch.py +279 -0
- spaces/zero/tqdm.py +14 -0
- spaces/zero/types.py +44 -0
- spaces/zero/utils.py +44 -0
- spaces/zero/wrappers.py +340 -0
.gitignore
CHANGED
@@ -1,3 +1,4 @@
|
|
1 |
.idea
|
2 |
.DS_Store
|
3 |
__pycache__
|
|
|
|
1 |
.idea
|
2 |
.DS_Store
|
3 |
__pycache__
|
4 |
+
*.pyc
|
spaces/__init__.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
|
4 |
+
import sys
|
5 |
+
|
6 |
+
if sys.version_info.minor < 8: # pragma: no cover
|
7 |
+
raise RuntimeError("Importing PySpaces requires Python 3.8+")
|
8 |
+
|
9 |
+
|
10 |
+
from .zero.decorator import GPU
|
11 |
+
from .zero.torch import disable_cuda_intercept
|
12 |
+
from .gradio import gradio_auto_wrap
|
13 |
+
from .gradio import disable_gradio_auto_wrap
|
14 |
+
from .gradio import enable_gradio_auto_wrap
|
15 |
+
|
16 |
+
|
17 |
+
__all__ = [
|
18 |
+
'GPU',
|
19 |
+
'disable_cuda_intercept',
|
20 |
+
'gradio_auto_wrap',
|
21 |
+
'disable_gradio_auto_wrap',
|
22 |
+
'enable_gradio_auto_wrap',
|
23 |
+
]
|
spaces/config.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import os
|
6 |
+
|
7 |
+
from .utils import boolean
|
8 |
+
|
9 |
+
|
10 |
+
class Settings:
|
11 |
+
def __init__(self):
|
12 |
+
self.zero_gpu = boolean(
|
13 |
+
os.getenv('SPACES_ZERO_GPU'))
|
14 |
+
self.zero_device_api_url = (
|
15 |
+
os.getenv('SPACES_ZERO_DEVICE_API_URL'))
|
16 |
+
self.gradio_auto_wrap = boolean(
|
17 |
+
os.getenv('SPACES_GRADIO_AUTO_WRAP'))
|
18 |
+
self.zero_patch_torch_device = boolean(
|
19 |
+
os.getenv('ZERO_GPU_PATCH_TORCH_DEVICE'))
|
20 |
+
|
21 |
+
|
22 |
+
Config = Settings()
|
23 |
+
|
24 |
+
|
25 |
+
if Config.zero_gpu:
|
26 |
+
assert Config.zero_device_api_url is not None, (
|
27 |
+
'SPACES_ZERO_DEVICE_API_URL env must be set '
|
28 |
+
'on ZeroGPU Spaces (identified by SPACES_ZERO_GPU=true)'
|
29 |
+
)
|
spaces/gradio.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
from typing import Callable
|
6 |
+
from typing import Generator
|
7 |
+
from typing import TypeVar
|
8 |
+
from typing import overload
|
9 |
+
from typing_extensions import ParamSpec
|
10 |
+
|
11 |
+
from .config import Config
|
12 |
+
from .zero.decorator import GPU
|
13 |
+
|
14 |
+
|
15 |
+
Param = ParamSpec('Param')
|
16 |
+
Res = TypeVar('Res')
|
17 |
+
|
18 |
+
|
19 |
+
gradio_auto_wrap_enabled = Config.gradio_auto_wrap
|
20 |
+
|
21 |
+
|
22 |
+
def disable_gradio_auto_wrap():
|
23 |
+
global gradio_auto_wrap_enabled
|
24 |
+
gradio_auto_wrap_enabled = False
|
25 |
+
|
26 |
+
def enable_gradio_auto_wrap():
|
27 |
+
global gradio_auto_wrap_enabled
|
28 |
+
gradio_auto_wrap_enabled = True
|
29 |
+
|
30 |
+
|
31 |
+
@overload
|
32 |
+
def gradio_auto_wrap(
|
33 |
+
task:
|
34 |
+
Callable[Param, Res],
|
35 |
+
) -> Callable[Param, Res]:
|
36 |
+
...
|
37 |
+
@overload
|
38 |
+
def gradio_auto_wrap(
|
39 |
+
task:
|
40 |
+
None,
|
41 |
+
) -> None:
|
42 |
+
...
|
43 |
+
def gradio_auto_wrap(
|
44 |
+
task:
|
45 |
+
Callable[Param, Res]
|
46 |
+
| None,
|
47 |
+
) -> (Callable[Param, Res]
|
48 |
+
| None):
|
49 |
+
"""
|
50 |
+
"""
|
51 |
+
if not gradio_auto_wrap_enabled:
|
52 |
+
return task
|
53 |
+
if not callable(task):
|
54 |
+
return task
|
55 |
+
return GPU(task) # type: ignore
|
spaces/utils.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import sys
|
6 |
+
from functools import lru_cache as cache
|
7 |
+
from functools import partial
|
8 |
+
|
9 |
+
import multiprocessing
|
10 |
+
from multiprocessing.queues import SimpleQueue as _SimpleQueue
|
11 |
+
from pathlib import Path
|
12 |
+
from pickle import PicklingError
|
13 |
+
from typing import Callable
|
14 |
+
from typing import TypeVar
|
15 |
+
|
16 |
+
|
17 |
+
GRADIO_VERSION_ERROR_MESSAGE = "Make sure Gradio version is at least 3.46"
|
18 |
+
|
19 |
+
|
20 |
+
T = TypeVar('T')
|
21 |
+
|
22 |
+
|
23 |
+
@cache
|
24 |
+
def self_cgroup_device_path() -> str:
|
25 |
+
cgroup_content = Path('/proc/self/cgroup').read_text()
|
26 |
+
for line in cgroup_content.strip().split('\n'):
|
27 |
+
contents = line.split(':devices:')
|
28 |
+
if len(contents) != 2:
|
29 |
+
continue # pragma: no cover
|
30 |
+
return contents[1]
|
31 |
+
raise Exception # pragma: no cover
|
32 |
+
|
33 |
+
|
34 |
+
if sys.version_info.minor < 9: # pragma: no cover
|
35 |
+
_SimpleQueue.__class_getitem__ = classmethod(lambda cls, _: cls) # type: ignore
|
36 |
+
|
37 |
+
class SimpleQueue(_SimpleQueue[T]):
|
38 |
+
def __init__(self, *args):
|
39 |
+
super().__init__(*args, ctx=multiprocessing.get_context('fork'))
|
40 |
+
def put(self, obj: T):
|
41 |
+
try:
|
42 |
+
super().put(obj)
|
43 |
+
except PicklingError:
|
44 |
+
raise # pragma: no cover
|
45 |
+
# https://bugs.python.org/issue29187
|
46 |
+
except Exception as e:
|
47 |
+
message = str(e)
|
48 |
+
if not "pickle" in message:
|
49 |
+
raise # pragma: no cover
|
50 |
+
raise PicklingError(message)
|
51 |
+
def close(self): # Python 3.8 static typing trick
|
52 |
+
super().close() # type: ignore
|
53 |
+
|
54 |
+
|
55 |
+
def drop_params(fn: Callable[[], T]) -> Callable[..., T]:
|
56 |
+
def drop(*args):
|
57 |
+
return fn()
|
58 |
+
return drop
|
59 |
+
|
60 |
+
|
61 |
+
def boolean(value: str | None) -> bool:
|
62 |
+
return value is not None and value.lower() in ("1", "t", "true")
|
63 |
+
|
64 |
+
|
65 |
+
def gradio_request_var():
|
66 |
+
try:
|
67 |
+
from gradio.context import LocalContext
|
68 |
+
except ImportError: # pragma: no cover
|
69 |
+
raise RuntimeError(GRADIO_VERSION_ERROR_MESSAGE)
|
70 |
+
return LocalContext.request
|
71 |
+
|
72 |
+
|
73 |
+
debug = partial(print, 'SPACES_ZERO_GPU_DEBUG')
|
spaces/zero/__init__.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
|
4 |
+
from ..config import Config
|
5 |
+
from . import torch
|
6 |
+
|
7 |
+
if Config.zero_gpu:
|
8 |
+
if torch.is_in_bad_fork():
|
9 |
+
raise RuntimeError(
|
10 |
+
"CUDA has been initialized before importing the `spaces` package"
|
11 |
+
)
|
12 |
+
torch.patch() # pragma: no cover
|
spaces/zero/api.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Synced with huggingface/pyspaces:spaces/zero/api.py
|
3 |
+
"""
|
4 |
+
from __future__ import annotations
|
5 |
+
|
6 |
+
from datetime import timedelta
|
7 |
+
from typing import Any
|
8 |
+
from typing import Generator
|
9 |
+
from typing import Literal
|
10 |
+
from typing import NamedTuple
|
11 |
+
from typing import Optional
|
12 |
+
from typing import overload
|
13 |
+
|
14 |
+
import httpx
|
15 |
+
from pydantic import BaseModel
|
16 |
+
from typing_extensions import assert_never
|
17 |
+
|
18 |
+
|
19 |
+
AllowToken = str
|
20 |
+
NvidiaIndex = int # TODO: Migrate to GpuIndex (less confusing for MIG)
|
21 |
+
NvidiaUUID = str
|
22 |
+
CGroupPath = str
|
23 |
+
VisitorId = str
|
24 |
+
Score = float
|
25 |
+
|
26 |
+
|
27 |
+
class ScheduleResponse(BaseModel):
|
28 |
+
idle: bool
|
29 |
+
nvidiaIndex: int
|
30 |
+
nvidiaUUID: str
|
31 |
+
allowToken: str | None
|
32 |
+
|
33 |
+
|
34 |
+
class QuotaInfos(BaseModel):
|
35 |
+
left: int
|
36 |
+
wait: timedelta
|
37 |
+
|
38 |
+
|
39 |
+
class ReportUsageMonitoringParams(NamedTuple):
|
40 |
+
nvidia_index: int
|
41 |
+
visitor_id: str
|
42 |
+
duration: timedelta
|
43 |
+
|
44 |
+
|
45 |
+
class QueueEvent(BaseModel):
|
46 |
+
event: Literal['ping', 'failed', 'succeeded']
|
47 |
+
data: Optional[ScheduleResponse] = None
|
48 |
+
|
49 |
+
|
50 |
+
def sse_parse(text: str):
|
51 |
+
event, *data = text.strip().splitlines()
|
52 |
+
assert event.startswith('event:')
|
53 |
+
event = event[6:].strip()
|
54 |
+
if event in ('ping', 'failed'):
|
55 |
+
return QueueEvent(event=event)
|
56 |
+
assert event == 'succeeded'
|
57 |
+
(data,) = data
|
58 |
+
assert data.startswith('data:')
|
59 |
+
data = data[5:].strip()
|
60 |
+
return QueueEvent(event=event, data=ScheduleResponse.parse_raw(data))
|
61 |
+
|
62 |
+
|
63 |
+
def sse_stream(res: httpx.Response) -> Generator[QueueEvent, Any, None]:
|
64 |
+
for text in res.iter_text():
|
65 |
+
if len(text) == 0:
|
66 |
+
break # pragma: no cover
|
67 |
+
try:
|
68 |
+
yield sse_parse(text)
|
69 |
+
except GeneratorExit:
|
70 |
+
res.close()
|
71 |
+
break
|
72 |
+
|
73 |
+
|
74 |
+
class APIClient:
|
75 |
+
|
76 |
+
def __init__(self, client: httpx.Client):
|
77 |
+
self.client = client
|
78 |
+
|
79 |
+
def startup_report(self) -> httpx.codes:
|
80 |
+
res = self.client.post('/startup-report')
|
81 |
+
return httpx.codes(res.status_code)
|
82 |
+
|
83 |
+
def schedule(
|
84 |
+
self,
|
85 |
+
cgroup_path: str,
|
86 |
+
task_id: int = 0,
|
87 |
+
token: str | None = None,
|
88 |
+
duration_seconds: int | None = None,
|
89 |
+
enable_queue: bool = True,
|
90 |
+
):
|
91 |
+
params: dict[str, str | int | bool] = {
|
92 |
+
'cgroupPath': cgroup_path,
|
93 |
+
'taskId': task_id,
|
94 |
+
'enableQueue': enable_queue,
|
95 |
+
}
|
96 |
+
if duration_seconds is not None:
|
97 |
+
params['durationSeconds'] = duration_seconds
|
98 |
+
if token is not None:
|
99 |
+
params['token'] = token
|
100 |
+
res = self.client.send(
|
101 |
+
request=self.client.build_request(
|
102 |
+
method='POST',
|
103 |
+
url='/schedule',
|
104 |
+
params=params,
|
105 |
+
),
|
106 |
+
stream=True,
|
107 |
+
)
|
108 |
+
status = httpx.codes(res.status_code)
|
109 |
+
if (status is not httpx.codes.OK and
|
110 |
+
status is not httpx.codes.TOO_MANY_REQUESTS
|
111 |
+
):
|
112 |
+
res.close()
|
113 |
+
return status
|
114 |
+
if "text/event-stream" in res.headers['content-type']:
|
115 |
+
return sse_stream(res)
|
116 |
+
res.read()
|
117 |
+
if status is httpx.codes.TOO_MANY_REQUESTS:
|
118 |
+
return QuotaInfos(**res.json()) # pragma: no cover
|
119 |
+
if status is httpx.codes.OK:
|
120 |
+
return ScheduleResponse(**res.json())
|
121 |
+
assert_never(status)
|
122 |
+
|
123 |
+
def allow(
|
124 |
+
self,
|
125 |
+
allow_token: str,
|
126 |
+
pid: int,
|
127 |
+
):
|
128 |
+
res = self.client.post('/allow', params={
|
129 |
+
'allowToken': allow_token,
|
130 |
+
'pid': pid,
|
131 |
+
})
|
132 |
+
return httpx.codes(res.status_code)
|
133 |
+
|
134 |
+
def release(
|
135 |
+
self,
|
136 |
+
nvidia_index: int,
|
137 |
+
cgroup_path: str,
|
138 |
+
task_id: int = 0,
|
139 |
+
fail: bool = False,
|
140 |
+
) -> httpx.codes:
|
141 |
+
res = self.client.post('/release', params={
|
142 |
+
'nvidiaIndex': nvidia_index,
|
143 |
+
'cgroupPath': cgroup_path,
|
144 |
+
'taskId': task_id,
|
145 |
+
'fail': fail,
|
146 |
+
})
|
147 |
+
return httpx.codes(res.status_code)
|
148 |
+
|
149 |
+
def get_queue_size(self) -> int:
|
150 |
+
res = self.client.get('/queue-size')
|
151 |
+
assert res.status_code == 200, res.status_code
|
152 |
+
size = res.json()
|
153 |
+
assert isinstance(size, int)
|
154 |
+
return size
|
spaces/zero/bitsandbytes.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
# pyright: reportPrivateImportUsage=false
|
4 |
+
|
5 |
+
from __future__ import annotations
|
6 |
+
|
7 |
+
import importlib
|
8 |
+
from typing import TYPE_CHECKING
|
9 |
+
from typing import Tuple
|
10 |
+
|
11 |
+
from .utils import cuda_unavailable
|
12 |
+
from .utils import maybe_import_torch
|
13 |
+
from .utils import maybe_import_bitsandbytes
|
14 |
+
|
15 |
+
if TYPE_CHECKING:
|
16 |
+
import torch as Torch
|
17 |
+
|
18 |
+
|
19 |
+
if (torch := maybe_import_torch()) and (bnb := maybe_import_bitsandbytes()):
|
20 |
+
|
21 |
+
from torch.utils.weak import WeakTensorKeyDictionary
|
22 |
+
|
23 |
+
with cuda_unavailable(torch):
|
24 |
+
from bitsandbytes import cextension
|
25 |
+
from bitsandbytes import functional
|
26 |
+
try: # bitsandbytes < 0.44
|
27 |
+
from bitsandbytes.cuda_setup.main import CUDASetup
|
28 |
+
except ModuleNotFoundError: # pragma: no cover
|
29 |
+
CUDASetup = None
|
30 |
+
from bitsandbytes.nn import Int8Params
|
31 |
+
from bitsandbytes.nn import Params4bit
|
32 |
+
|
33 |
+
_param_to_8bit = Int8Params.to # type: ignore
|
34 |
+
_param_cuda_8bit = Int8Params.cuda
|
35 |
+
_param_to_4bit = Params4bit.to # type: ignore
|
36 |
+
_param_cuda_4bit = Params4bit.cuda
|
37 |
+
|
38 |
+
TensorToArgs = Tuple[torch.device, torch.dtype, bool, torch.memory_format]
|
39 |
+
|
40 |
+
to_ops_8bit: dict[Int8Params, TensorToArgs | None] = WeakTensorKeyDictionary() # type: ignore
|
41 |
+
to_ops_4bit: dict[Params4bit, TensorToArgs | None] = WeakTensorKeyDictionary() # type: ignore
|
42 |
+
|
43 |
+
def _to_op_register_8bit(self: Int8Params, *args, **kwargs):
|
44 |
+
parsed = torch._C._nn._parse_to(*args, **kwargs)
|
45 |
+
device, *_ = parsed
|
46 |
+
if not isinstance(device, torch.device): # pragma: no cover
|
47 |
+
return _param_to_8bit(self, *args, **kwargs)
|
48 |
+
if device.type != 'cuda':
|
49 |
+
return _param_to_8bit(self, *args, **kwargs)
|
50 |
+
to_ops_8bit[self] = parsed
|
51 |
+
return self
|
52 |
+
|
53 |
+
def _to_op_register_4bit(self: Params4bit, *args, **kwargs):
|
54 |
+
parsed = torch._C._nn._parse_to(*args, **kwargs)
|
55 |
+
device, *_ = parsed
|
56 |
+
if not isinstance(device, torch.device): # pragma: no cover
|
57 |
+
return _param_to_4bit(self, *args, **kwargs)
|
58 |
+
if device.type != 'cuda':
|
59 |
+
return _param_to_4bit(self, *args, **kwargs)
|
60 |
+
to_ops_4bit[self] = parsed
|
61 |
+
return self
|
62 |
+
|
63 |
+
def _cuda_op_arg_check(device: Torch.device | int | str | None) -> bool:
|
64 |
+
if device is None: # pragma: no cover
|
65 |
+
return True
|
66 |
+
if isinstance(device, int):
|
67 |
+
return True
|
68 |
+
if isinstance(device, str): # pragma: no cover
|
69 |
+
device = torch.device(device)
|
70 |
+
return device.type == 'cuda' # pragma: no cover
|
71 |
+
|
72 |
+
def _cuda_op_register_8bit(self: Int8Params, device: Torch.device | int | str | None = None, **kwargs):
|
73 |
+
if not _cuda_op_arg_check(device): # pragma: no cover
|
74 |
+
# Let PyTorch handle the fail
|
75 |
+
return _param_cuda_8bit(self, device, **kwargs)
|
76 |
+
to_ops_8bit[self] = None
|
77 |
+
return self
|
78 |
+
|
79 |
+
def _cuda_op_register_4bit(self: Params4bit, device: Torch.device | int | str | None = None, **kwargs):
|
80 |
+
if not _cuda_op_arg_check(device): # pragma: no cover
|
81 |
+
# Let PyTorch handle the fail
|
82 |
+
return _param_cuda_4bit(self, device, **kwargs)
|
83 |
+
to_ops_4bit[self] = None
|
84 |
+
return self
|
85 |
+
|
86 |
+
def _patch():
|
87 |
+
Int8Params.to = _to_op_register_8bit # type: ignore
|
88 |
+
Int8Params.cuda = _cuda_op_register_8bit # type: ignore
|
89 |
+
Params4bit.to = _to_op_register_4bit # type: ignore
|
90 |
+
Params4bit.cuda = _cuda_op_register_4bit # type: ignore
|
91 |
+
|
92 |
+
def _unpatch():
|
93 |
+
Int8Params.to = _param_to_8bit # type: ignore
|
94 |
+
Int8Params.cuda = _param_cuda_8bit
|
95 |
+
Params4bit.to = _param_to_4bit # type: ignore
|
96 |
+
Params4bit.cuda = _param_cuda_4bit
|
97 |
+
|
98 |
+
def _move():
|
99 |
+
if CUDASetup is not None:
|
100 |
+
CUDASetup._instance = None
|
101 |
+
importlib.reload(cextension)
|
102 |
+
functional.lib = cextension.lib
|
103 |
+
for op in to_ops_8bit.items():
|
104 |
+
tensor, parsed_args = op
|
105 |
+
if parsed_args:
|
106 |
+
_, dtype, _, memory_format = parsed_args
|
107 |
+
else:
|
108 |
+
dtype, memory_format = None, None
|
109 |
+
tensor.data = _param_to_8bit(tensor,
|
110 |
+
device='cuda',
|
111 |
+
dtype=dtype,
|
112 |
+
memory_format=memory_format,
|
113 |
+
) # type: ignore
|
114 |
+
for op in to_ops_4bit.items():
|
115 |
+
tensor, parsed_args = op
|
116 |
+
if parsed_args:
|
117 |
+
_, dtype, _, memory_format = parsed_args
|
118 |
+
else:
|
119 |
+
dtype, memory_format = None, None
|
120 |
+
tensor.data = _param_to_4bit(tensor,
|
121 |
+
device='cuda',
|
122 |
+
dtype=dtype,
|
123 |
+
memory_format=memory_format,
|
124 |
+
) # type: ignore
|
125 |
+
|
126 |
+
else:
|
127 |
+
|
128 |
+
_patch = lambda: None
|
129 |
+
_unpatch = lambda: None
|
130 |
+
_move = lambda: None
|
131 |
+
|
132 |
+
|
133 |
+
patch = _patch
|
134 |
+
unpatch = _unpatch
|
135 |
+
move = _move
|
spaces/zero/client.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import os
|
6 |
+
import time
|
7 |
+
import warnings
|
8 |
+
from datetime import timedelta
|
9 |
+
|
10 |
+
import gradio as gr
|
11 |
+
import httpx
|
12 |
+
|
13 |
+
from .. import utils
|
14 |
+
from ..config import Config
|
15 |
+
from .api import APIClient
|
16 |
+
from .api import QuotaInfos
|
17 |
+
from .api import ScheduleResponse
|
18 |
+
from .gradio import get_event
|
19 |
+
|
20 |
+
|
21 |
+
TOKEN_HEADER = 'X-IP-Token'
|
22 |
+
DEFAULT_SCHEDULE_DURATION = 60
|
23 |
+
|
24 |
+
QUOTA_MESSAGE = "You have exceeded your GPU quota"
|
25 |
+
UNUSED_MESSAGE = "GPU device not used"
|
26 |
+
NO_GPU_MESSAGE_REGULAR = "No GPU is currently available"
|
27 |
+
NO_GPU_MESSAGE_INQUEUE = "No GPU is currently available for you after 60s"
|
28 |
+
|
29 |
+
|
30 |
+
def api_client():
|
31 |
+
assert Config.zero_device_api_url is not None
|
32 |
+
httpx_client = httpx.Client(base_url=Config.zero_device_api_url, timeout=60, verify=False)
|
33 |
+
return APIClient(httpx_client)
|
34 |
+
|
35 |
+
|
36 |
+
def startup_report():
|
37 |
+
retries, max_retries = 0, 2
|
38 |
+
client = api_client()
|
39 |
+
while (status := client.startup_report()) is httpx.codes.NOT_FOUND: # pragma: no cover
|
40 |
+
time.sleep(1)
|
41 |
+
if (retries := retries + 1) > max_retries:
|
42 |
+
raise RuntimeError("Error while initializing ZeroGPU: NotFound")
|
43 |
+
if status is not httpx.codes.OK: # pragma: no cover
|
44 |
+
raise RuntimeError("Error while initializing ZeroGPU: Unknown")
|
45 |
+
|
46 |
+
|
47 |
+
def schedule(
|
48 |
+
task_id: int,
|
49 |
+
request: gr.Request | None = None,
|
50 |
+
duration: timedelta | None = None,
|
51 |
+
_first_attempt: bool = True,
|
52 |
+
) -> ScheduleResponse:
|
53 |
+
|
54 |
+
if not gr.__version__.startswith('4.'): # pragma: no cover
|
55 |
+
raise RuntimeError("ZeroGPU is only compatible with Gradio 4+")
|
56 |
+
|
57 |
+
res = api_client().schedule(
|
58 |
+
cgroup_path=utils.self_cgroup_device_path(),
|
59 |
+
task_id=task_id,
|
60 |
+
token=_get_token(request),
|
61 |
+
duration_seconds=duration.seconds if duration is not None else None,
|
62 |
+
)
|
63 |
+
|
64 |
+
if isinstance(res, ScheduleResponse):
|
65 |
+
return res
|
66 |
+
|
67 |
+
if isinstance(res, QuotaInfos): # pragma: no cover
|
68 |
+
requested = duration.seconds if duration is not None else DEFAULT_SCHEDULE_DURATION
|
69 |
+
if res.wait < timedelta(0):
|
70 |
+
message = (
|
71 |
+
f"The requested GPU duration ({requested}s) "
|
72 |
+
f"is larger than the maximum allowed"
|
73 |
+
)
|
74 |
+
else:
|
75 |
+
message = (
|
76 |
+
f"You have exceeded your GPU quota "
|
77 |
+
f"({res.left}s left vs. {requested}s requested). "
|
78 |
+
f"Please retry in {res.wait}"
|
79 |
+
)
|
80 |
+
raise gr.Error(message)
|
81 |
+
|
82 |
+
if not isinstance(res, httpx.codes): # pragma: no cover
|
83 |
+
gr.Info("Waiting for a GPU to become available")
|
84 |
+
connection_event = get_event()
|
85 |
+
if connection_event is None and request is not None:
|
86 |
+
warnings.warn("ZeroGPU: Cannot get Gradio app Queue instance")
|
87 |
+
while True:
|
88 |
+
try:
|
89 |
+
event = next(res)
|
90 |
+
except StopIteration:
|
91 |
+
raise RuntimeError("Unexpected end of stream")
|
92 |
+
except httpx.RemoteProtocolError:
|
93 |
+
if not _first_attempt:
|
94 |
+
raise RuntimeError("Error while re-trying after queue disconnect")
|
95 |
+
return schedule(task_id, request, duration, _first_attempt=False)
|
96 |
+
if event.event == 'ping':
|
97 |
+
if connection_event is not None and not connection_event.alive:
|
98 |
+
res.close()
|
99 |
+
raise RuntimeError("Connection closed by visitor while queueing")
|
100 |
+
continue
|
101 |
+
if event.event == 'failed':
|
102 |
+
raise gr.Error(NO_GPU_MESSAGE_INQUEUE)
|
103 |
+
if event.event == 'succeeded':
|
104 |
+
assert event.data is not None
|
105 |
+
if connection_event is not None and not connection_event.alive:
|
106 |
+
release(task_id, event.data.nvidiaIndex)
|
107 |
+
raise RuntimeError("Connection closed by visitor on queue success")
|
108 |
+
gr.Info("Successfully acquired a GPU")
|
109 |
+
return event.data
|
110 |
+
|
111 |
+
if res is httpx.codes.SERVICE_UNAVAILABLE:
|
112 |
+
raise gr.Error(NO_GPU_MESSAGE_REGULAR)
|
113 |
+
|
114 |
+
# TODO: Find a way to log 'detail' response field
|
115 |
+
raise RuntimeError(f"ZeroGPU API /schedule error: {res} ({httpx.codes.get_reason_phrase(res)})") # pragma: no cover
|
116 |
+
|
117 |
+
|
118 |
+
def allow(allow_token: str) -> None:
|
119 |
+
pid = os.getpid()
|
120 |
+
assert pid != 1, "Allowing PID 1 on ZeroGPU will end up killing your Space"
|
121 |
+
assert api_client().allow(allow_token=allow_token, pid=pid) is httpx.codes.OK
|
122 |
+
|
123 |
+
|
124 |
+
def release(
|
125 |
+
task_id: int,
|
126 |
+
nvidia_index: int,
|
127 |
+
fail: bool = False,
|
128 |
+
allow_404: bool = False,
|
129 |
+
) -> None:
|
130 |
+
|
131 |
+
res = api_client().release(
|
132 |
+
cgroup_path=utils.self_cgroup_device_path(),
|
133 |
+
task_id=task_id,
|
134 |
+
nvidia_index=nvidia_index,
|
135 |
+
fail=fail,
|
136 |
+
)
|
137 |
+
|
138 |
+
if res is httpx.codes.NO_CONTENT: # pragma: no cover
|
139 |
+
try:
|
140 |
+
gr.Warning(UNUSED_MESSAGE)
|
141 |
+
except AttributeError:
|
142 |
+
pass
|
143 |
+
warnings.warn(UNUSED_MESSAGE, RuntimeWarning)
|
144 |
+
return None
|
145 |
+
|
146 |
+
if res is httpx.codes.NOT_FOUND:
|
147 |
+
if not allow_404:
|
148 |
+
warnings.warn("ZeroGPU API /release warning: 404 Not Found")
|
149 |
+
return None
|
150 |
+
|
151 |
+
if httpx.codes.is_success(res):
|
152 |
+
return None
|
153 |
+
|
154 |
+
# TODO: Find a way to log 'detail' response field
|
155 |
+
# TODO: Only raise in dev environment. Simply warn in production ?
|
156 |
+
raise RuntimeError(f"ZeroGPU API /release error: {res} ({httpx.codes.get_reason_phrase(res)})") # pragma: no cover
|
157 |
+
|
158 |
+
|
159 |
+
def _get_token(request: gr.Request | None) -> str | None:
|
160 |
+
|
161 |
+
if request is None:
|
162 |
+
return None
|
163 |
+
|
164 |
+
headers = getattr(request, 'headers', None)
|
165 |
+
if headers is None or not hasattr(headers, '__dict__'):
|
166 |
+
raise gr.Error("Internal Gradio error")
|
167 |
+
|
168 |
+
# Compatibility trick
|
169 |
+
if not hasattr(headers, 'get'):
|
170 |
+
headers = headers.__dict__ # pragma: no cover
|
171 |
+
|
172 |
+
if not (token := headers.get(TOKEN_HEADER.lower())):
|
173 |
+
raise gr.Error("Internal infra error")
|
174 |
+
|
175 |
+
return token
|
spaces/zero/decorator.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import inspect
|
6 |
+
import sys
|
7 |
+
import warnings
|
8 |
+
from datetime import timedelta
|
9 |
+
from functools import partial
|
10 |
+
from typing import Callable
|
11 |
+
from typing import TypeVar
|
12 |
+
from typing import overload
|
13 |
+
from typing_extensions import ParamSpec
|
14 |
+
from typing_extensions import Unpack
|
15 |
+
|
16 |
+
import gradio as gr
|
17 |
+
|
18 |
+
from ..config import Config
|
19 |
+
from . import client
|
20 |
+
from .types import EmptyKwargs
|
21 |
+
from .wrappers import regular_function_wrapper
|
22 |
+
from .wrappers import generator_function_wrapper
|
23 |
+
|
24 |
+
|
25 |
+
P = ParamSpec('P')
|
26 |
+
R = TypeVar('R')
|
27 |
+
|
28 |
+
|
29 |
+
decorated_cache: dict[Callable, Callable] = {}
|
30 |
+
|
31 |
+
|
32 |
+
@overload
|
33 |
+
def GPU(
|
34 |
+
task: None = None, *,
|
35 |
+
duration: int | timedelta | None = None,
|
36 |
+
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
37 |
+
...
|
38 |
+
@overload
|
39 |
+
def GPU(
|
40 |
+
task: Callable[P, R], *,
|
41 |
+
duration: int | timedelta | None = None,
|
42 |
+
) -> Callable[P, R]:
|
43 |
+
...
|
44 |
+
def GPU(
|
45 |
+
task: Callable[P, R] | None = None, *,
|
46 |
+
duration: int | timedelta | None = None,
|
47 |
+
**kwargs: Unpack[EmptyKwargs],
|
48 |
+
) -> Callable[[Callable[P, R]], Callable[P, R]] | Callable[P, R]:
|
49 |
+
"""
|
50 |
+
ZeroGPU decorator
|
51 |
+
|
52 |
+
Basic usage:
|
53 |
+
```
|
54 |
+
@spaces.GPU
|
55 |
+
def fn(...):
|
56 |
+
# CUDA is available here
|
57 |
+
pass
|
58 |
+
```
|
59 |
+
|
60 |
+
With custom duration:
|
61 |
+
```
|
62 |
+
@spaces.GPU(duration=45) # Expressed in seconds
|
63 |
+
def fn(...):
|
64 |
+
# CUDA is available here
|
65 |
+
pass
|
66 |
+
```
|
67 |
+
|
68 |
+
Args:
|
69 |
+
task (`Callable | None`): Python function that requires CUDA
|
70 |
+
duration (`int | datetime.timedelta`): Estimated duration in seconds or `datetime.timedelta`
|
71 |
+
|
72 |
+
Returns:
|
73 |
+
`Callable`: GPU-ready function
|
74 |
+
"""
|
75 |
+
if "enable_queue" in kwargs:
|
76 |
+
warnings.warn("`enable_queue` parameter is now ignored and always set to `True`")
|
77 |
+
if task is None:
|
78 |
+
return partial(_GPU, duration=duration)
|
79 |
+
return _GPU(task, duration)
|
80 |
+
|
81 |
+
|
82 |
+
def _GPU(
|
83 |
+
task: Callable[P, R],
|
84 |
+
duration: int | timedelta | None,
|
85 |
+
) -> Callable[P, R]:
|
86 |
+
|
87 |
+
if not Config.zero_gpu:
|
88 |
+
# TODO: still prepend gr.Request for type consistency ?
|
89 |
+
return task # type: ignore
|
90 |
+
|
91 |
+
if sys.version_info.minor < 9: # pragma: no cover
|
92 |
+
raise RuntimeError("Actually using @spaces.GPU on a ZeroGPU Space requires Python 3.9+")
|
93 |
+
|
94 |
+
if task in decorated_cache:
|
95 |
+
# TODO: Assert same duration ?
|
96 |
+
return decorated_cache[task] # type: ignore
|
97 |
+
|
98 |
+
if inspect.iscoroutinefunction(task):
|
99 |
+
raise NotImplementedError
|
100 |
+
|
101 |
+
if duration is None or isinstance(duration, timedelta):
|
102 |
+
timedelta_duration = duration
|
103 |
+
else:
|
104 |
+
timedelta_duration = timedelta(seconds=duration)
|
105 |
+
|
106 |
+
if inspect.isgeneratorfunction(task):
|
107 |
+
decorated = generator_function_wrapper(task, timedelta_duration)
|
108 |
+
else:
|
109 |
+
decorated = regular_function_wrapper(task, timedelta_duration)
|
110 |
+
|
111 |
+
client.startup_report()
|
112 |
+
decorated_cache.update({
|
113 |
+
task: decorated,
|
114 |
+
decorated: decorated,
|
115 |
+
})
|
116 |
+
|
117 |
+
return decorated # type: ignore
|
spaces/zero/gradio.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
from typing import NamedTuple
|
6 |
+
import warnings
|
7 |
+
|
8 |
+
from gradio.context import LocalContext
|
9 |
+
from gradio.helpers import Progress
|
10 |
+
from gradio.helpers import TrackedIterable
|
11 |
+
from gradio.queueing import Queue
|
12 |
+
from typing_extensions import assert_type
|
13 |
+
|
14 |
+
from ..utils import SimpleQueue
|
15 |
+
from .types import GeneratorResQueueResult
|
16 |
+
from .types import GradioQueueEvent
|
17 |
+
from .types import RegularResQueueResult
|
18 |
+
|
19 |
+
|
20 |
+
QUEUE_RPC_METHODS = [
|
21 |
+
"set_progress",
|
22 |
+
"log_message",
|
23 |
+
]
|
24 |
+
|
25 |
+
|
26 |
+
class GradioPartialContext(NamedTuple):
|
27 |
+
event_id: str | None
|
28 |
+
in_event_listener: bool
|
29 |
+
progress: Progress | None
|
30 |
+
|
31 |
+
@staticmethod
|
32 |
+
def get():
|
33 |
+
TrackedIterable.__reduce__ = tracked_iterable__reduce__
|
34 |
+
return GradioPartialContext(
|
35 |
+
event_id=LocalContext.event_id.get(),
|
36 |
+
in_event_listener=LocalContext.in_event_listener.get(),
|
37 |
+
progress=LocalContext.progress.get(),
|
38 |
+
)
|
39 |
+
|
40 |
+
@staticmethod
|
41 |
+
def apply(context: 'GradioPartialContext'):
|
42 |
+
LocalContext.event_id.set(context.event_id)
|
43 |
+
LocalContext.in_event_listener.set(context.in_event_listener)
|
44 |
+
LocalContext.progress.set(context.progress)
|
45 |
+
|
46 |
+
|
47 |
+
def get_queue_instance():
|
48 |
+
blocks = LocalContext.blocks.get()
|
49 |
+
if blocks is None: # pragma: no cover
|
50 |
+
return None
|
51 |
+
return blocks._queue
|
52 |
+
|
53 |
+
|
54 |
+
def get_event():
|
55 |
+
queue = get_queue_instance()
|
56 |
+
event_id = LocalContext.event_id.get()
|
57 |
+
if queue is None:
|
58 |
+
return None
|
59 |
+
if event_id is None: # pragma: no cover
|
60 |
+
return None
|
61 |
+
for job in queue.active_jobs:
|
62 |
+
if job is None: # pragma: no cover
|
63 |
+
continue
|
64 |
+
for event in job:
|
65 |
+
if event._id == event_id:
|
66 |
+
return event
|
67 |
+
|
68 |
+
|
69 |
+
def try_process_queue_event(method_name: str, *args, **kwargs):
|
70 |
+
queue = get_queue_instance()
|
71 |
+
if queue is None: # pragma: no cover
|
72 |
+
warnings.warn("ZeroGPU: Cannot get Gradio app Queue instance")
|
73 |
+
return
|
74 |
+
method = getattr(queue, method_name, None)
|
75 |
+
assert callable(method)
|
76 |
+
method(*args, **kwargs)
|
77 |
+
|
78 |
+
|
79 |
+
def patch_gradio_queue(
|
80 |
+
res_queue: SimpleQueue[RegularResQueueResult | None] | SimpleQueue[GeneratorResQueueResult | None],
|
81 |
+
):
|
82 |
+
|
83 |
+
def rpc_method(method_name: str):
|
84 |
+
def method(*args, **kwargs):
|
85 |
+
if args and isinstance(args[0], Queue):
|
86 |
+
args = args[1:] # drop `self`
|
87 |
+
res_queue.put(GradioQueueEvent(method_name, args, kwargs))
|
88 |
+
return method
|
89 |
+
|
90 |
+
for method_name in QUEUE_RPC_METHODS:
|
91 |
+
if (method := getattr(Queue, method_name, None)) is None: # pragma: no cover
|
92 |
+
warnings.warn(f"ZeroGPU: Gradio Queue has no {method_name} attribute")
|
93 |
+
continue
|
94 |
+
if not callable(method): # pragma: no cover
|
95 |
+
warnings.warn(f"ZeroGPU: Gradio Queue {method_name} is not callable")
|
96 |
+
continue
|
97 |
+
setattr(Queue, method_name, rpc_method(method_name))
|
98 |
+
|
99 |
+
TrackedIterable.__reduce__ = tracked_iterable__reduce__
|
100 |
+
|
101 |
+
|
102 |
+
def tracked_iterable__reduce__(self):
|
103 |
+
res: tuple = super(TrackedIterable, self).__reduce__() # type: ignore
|
104 |
+
cls, base, state, *_ = res
|
105 |
+
return cls, base,{**state, **{
|
106 |
+
'iterable': None,
|
107 |
+
'_tqdm': None,
|
108 |
+
}}
|
spaces/zero/torch.py
ADDED
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
# pyright: reportPrivateImportUsage=false
|
4 |
+
|
5 |
+
from __future__ import annotations
|
6 |
+
|
7 |
+
import multiprocessing
|
8 |
+
import os
|
9 |
+
from concurrent.futures import ProcessPoolExecutor
|
10 |
+
from contextlib import suppress
|
11 |
+
from functools import partial
|
12 |
+
from types import SimpleNamespace
|
13 |
+
from typing import TYPE_CHECKING
|
14 |
+
from typing import Any
|
15 |
+
from typing import Optional
|
16 |
+
from typing import Tuple
|
17 |
+
|
18 |
+
from ..config import Config
|
19 |
+
from . import bitsandbytes
|
20 |
+
from .utils import maybe_import_torch
|
21 |
+
|
22 |
+
if TYPE_CHECKING:
|
23 |
+
import torch as Torch
|
24 |
+
|
25 |
+
|
26 |
+
# Nvidia A100.80G MIG (drivers 535) / Torch 2.2.0
|
27 |
+
CUDA_DEVICE_NAME = 'NVIDIA A100-SXM4-80GB MIG 3g.40gb'
|
28 |
+
CUDA_TOTAL_MEMORY = 42144366592
|
29 |
+
CUDA_MEM_GET_INFO = (41911451648, CUDA_TOTAL_MEMORY)
|
30 |
+
CUDA_DEVICE_CAPABILITY = (8, 0)
|
31 |
+
CUDA_DEVICE_PROPERTIES = SimpleNamespace(name=CUDA_DEVICE_NAME, major=8, minor=0, total_memory=CUDA_TOTAL_MEMORY, multi_processor_count=42)
|
32 |
+
|
33 |
+
GENERIC_METHOD_NAMES = [
|
34 |
+
'arange',
|
35 |
+
'as_tensor',
|
36 |
+
'asarray',
|
37 |
+
'bartlett_window',
|
38 |
+
'blackman_window',
|
39 |
+
'empty',
|
40 |
+
'empty_like',
|
41 |
+
'empty_strided',
|
42 |
+
'eye',
|
43 |
+
'full',
|
44 |
+
'full_like',
|
45 |
+
'hamming_window',
|
46 |
+
'hann_window',
|
47 |
+
'kaiser_window',
|
48 |
+
'linspace',
|
49 |
+
'logspace',
|
50 |
+
'obj',
|
51 |
+
'ones',
|
52 |
+
'ones_like',
|
53 |
+
'rand',
|
54 |
+
'rand_like',
|
55 |
+
'randint',
|
56 |
+
'randint_like',
|
57 |
+
'randn',
|
58 |
+
'randn_like',
|
59 |
+
'randperm',
|
60 |
+
'range',
|
61 |
+
'sparse_bsc_tensor',
|
62 |
+
'sparse_bsr_tensor',
|
63 |
+
'sparse_compressed_tensor',
|
64 |
+
'sparse_coo_tensor',
|
65 |
+
'sparse_csc_tensor',
|
66 |
+
'sparse_csr_tensor',
|
67 |
+
'tensor',
|
68 |
+
'tril_indices',
|
69 |
+
'triu_indices',
|
70 |
+
'zeros',
|
71 |
+
'zeros_like',
|
72 |
+
]
|
73 |
+
|
74 |
+
|
75 |
+
if (torch := maybe_import_torch()):
|
76 |
+
|
77 |
+
from torch.utils.weak import WeakTensorKeyDictionary
|
78 |
+
|
79 |
+
TO_CUDA = (torch.device('cuda'), None, False, None)
|
80 |
+
|
81 |
+
_tensor__deepcopy__ = torch.Tensor.__deepcopy__
|
82 |
+
_tensor_to = torch.Tensor.to
|
83 |
+
_tensor_cuda = torch.Tensor.cuda
|
84 |
+
_tensor_cpu = torch.Tensor.cpu
|
85 |
+
_torch_generics = {name: getattr(torch, name) for name in GENERIC_METHOD_NAMES}
|
86 |
+
_cuda_init = torch._C._cuda_init
|
87 |
+
_cuda_available = torch.cuda.is_available
|
88 |
+
_cuda_device_count = torch.cuda.device_count
|
89 |
+
_cuda_current_device = torch.cuda.current_device
|
90 |
+
_cuda_mem_get_info = torch.cuda.mem_get_info
|
91 |
+
_cuda_get_device_capability = torch.cuda.get_device_capability
|
92 |
+
_cuda_get_device_properties = torch.cuda.get_device_properties
|
93 |
+
_cuda_get_device_name = torch.cuda.get_device_name
|
94 |
+
|
95 |
+
TensorToArgs = Tuple[Optional[torch.device], Optional[torch.dtype], bool, Optional[torch.memory_format]]
|
96 |
+
|
97 |
+
to_ops: dict[Torch.Tensor, TensorToArgs] = WeakTensorKeyDictionary() # type: ignore
|
98 |
+
|
99 |
+
def _tensor_new_register(*args, **kwargs):
|
100 |
+
new_tensor: Torch.Tensor = torch._C._TensorBase.__new__(*args, **kwargs)
|
101 |
+
if (base_tensor := new_tensor._base) is not None:
|
102 |
+
if base_tensor in to_ops:
|
103 |
+
to_ops[new_tensor] = to_ops[base_tensor]
|
104 |
+
return new_tensor
|
105 |
+
|
106 |
+
def _tensor_deepcopy_register(self: Torch.Tensor, memo):
|
107 |
+
new_tensor = _tensor__deepcopy__(self, memo)
|
108 |
+
if isinstance(new_tensor, torch.Tensor):
|
109 |
+
if self in to_ops:
|
110 |
+
to_ops[new_tensor] = to_ops[self]
|
111 |
+
return new_tensor
|
112 |
+
|
113 |
+
@property
|
114 |
+
def _tensor_device_property(self: Torch.Tensor):
|
115 |
+
if self in to_ops:
|
116 |
+
return torch.device(type='cuda', index=0)
|
117 |
+
del torch.Tensor.device
|
118 |
+
try:
|
119 |
+
return self.device
|
120 |
+
finally:
|
121 |
+
torch.Tensor.device = _tensor_device_property # type: ignore
|
122 |
+
|
123 |
+
@property
|
124 |
+
def _tensor_dtype_property(self: Torch.Tensor):
|
125 |
+
if self in to_ops:
|
126 |
+
if (to_dtype := to_ops[self][1]) is not None:
|
127 |
+
return to_dtype
|
128 |
+
del torch.Tensor.dtype
|
129 |
+
try:
|
130 |
+
return self.dtype
|
131 |
+
finally:
|
132 |
+
torch.Tensor.dtype = _tensor_dtype_property # type: ignore
|
133 |
+
|
134 |
+
def _to_op_register(self: Torch.Tensor, *args, **kwargs):
|
135 |
+
parsed = torch._C._nn._parse_to(*args, **kwargs)
|
136 |
+
device, dtype, *_ = parsed
|
137 |
+
try:
|
138 |
+
to_args = to_ops.pop(self)
|
139 |
+
except KeyError:
|
140 |
+
to_args = None
|
141 |
+
if device is None:
|
142 |
+
if to_args is not None:
|
143 |
+
to_ops[self] = (to_args[0], dtype, *to_args[2:])
|
144 |
+
return self
|
145 |
+
return _tensor_to(self, *args, **kwargs)
|
146 |
+
if device.type != 'cuda':
|
147 |
+
if to_args is not None:
|
148 |
+
if (to_dtype := to_args[1]) is not None:
|
149 |
+
kwargs = {'dtype': to_dtype, **kwargs}
|
150 |
+
return _tensor_to(self, *args, **kwargs)
|
151 |
+
to_ops[self] = parsed
|
152 |
+
return self
|
153 |
+
|
154 |
+
def _cuda_op_arg_check(device: Torch.device | int | str | None) -> bool:
|
155 |
+
if device is None:
|
156 |
+
return True
|
157 |
+
if isinstance(device, int):
|
158 |
+
return True
|
159 |
+
if isinstance(device, str):
|
160 |
+
device = torch.device(device)
|
161 |
+
return device.type == 'cuda'
|
162 |
+
|
163 |
+
def _cuda_op_register(self: Torch.Tensor, device: Torch.device | int | str | None = None, **kwargs):
|
164 |
+
if not _cuda_op_arg_check(device):
|
165 |
+
# Let PyTorch handle the fail
|
166 |
+
return _tensor_cuda(self, device, **kwargs)
|
167 |
+
to_ops[self] = TO_CUDA
|
168 |
+
return self
|
169 |
+
|
170 |
+
def _cpu_op_remove(self: Torch.Tensor, **kwargs):
|
171 |
+
try:
|
172 |
+
to_args = to_ops.pop(self)
|
173 |
+
except KeyError:
|
174 |
+
to_args = None
|
175 |
+
if to_args is not None:
|
176 |
+
if (to_dtype := to_args[1]) is not None:
|
177 |
+
return _tensor_to(self, 'cpu', **{'dtype': to_dtype, **kwargs})
|
178 |
+
return _tensor_cpu(self, **kwargs)
|
179 |
+
|
180 |
+
def _cuda_init_raise():
|
181 |
+
raise RuntimeError(
|
182 |
+
"CUDA must not be initialized in the main process "
|
183 |
+
"on Spaces with Stateless GPU environment.\n"
|
184 |
+
"You can look at this Stacktrace to find out "
|
185 |
+
"which part of your code triggered a CUDA init"
|
186 |
+
)
|
187 |
+
|
188 |
+
def _generic_method_register(name: str, *args: Any, **kwargs: Any):
|
189 |
+
try:
|
190 |
+
device = torch.device(kwargs.get('device', "cpu"))
|
191 |
+
except Exception:
|
192 |
+
return _torch_generics[name](*args, **kwargs)
|
193 |
+
if device.type != 'cuda':
|
194 |
+
return _torch_generics[name](*args, **kwargs)
|
195 |
+
tensor = _torch_generics[name](*args, **{**kwargs, 'device': "cpu"})
|
196 |
+
to_ops[tensor] = TO_CUDA
|
197 |
+
return tensor
|
198 |
+
|
199 |
+
def _patch():
|
200 |
+
torch.Tensor.__deepcopy__ = _tensor_deepcopy_register
|
201 |
+
torch.Tensor.__new__ = _tensor_new_register # pyright: ignore [reportAttributeAccessIssue]
|
202 |
+
torch.Tensor.to = _to_op_register # type: ignore
|
203 |
+
torch.Tensor.cuda = _cuda_op_register # type: ignore
|
204 |
+
torch.Tensor.cpu = _cpu_op_remove # type: ignore
|
205 |
+
if Config.zero_patch_torch_device:
|
206 |
+
torch.Tensor.device = _tensor_device_property # type: ignore
|
207 |
+
torch.Tensor.dtype = _tensor_dtype_property # pyright: ignore [reportAttributeAccessIssue]
|
208 |
+
for name in GENERIC_METHOD_NAMES:
|
209 |
+
setattr(torch, name, partial(_generic_method_register, name))
|
210 |
+
torch._C._cuda_init = _cuda_init_raise
|
211 |
+
torch.cuda.is_available = lambda: True
|
212 |
+
torch.cuda.device_count = lambda: 1
|
213 |
+
torch.cuda.current_device = lambda: 0
|
214 |
+
torch.cuda.mem_get_info = lambda *args, **kwargs: CUDA_MEM_GET_INFO
|
215 |
+
torch.cuda.get_device_capability = lambda *args, **kwargs: CUDA_DEVICE_CAPABILITY
|
216 |
+
torch.cuda.get_device_properties = lambda *args, **kwargs: CUDA_DEVICE_PROPERTIES
|
217 |
+
torch.cuda.get_device_name = lambda *args, **kwargs: CUDA_DEVICE_NAME
|
218 |
+
bitsandbytes.patch()
|
219 |
+
|
220 |
+
def _unpatch():
|
221 |
+
torch.Tensor.__deepcopy__ = _tensor__deepcopy__
|
222 |
+
with suppress(AttributeError):
|
223 |
+
del torch.Tensor.__new__
|
224 |
+
torch.Tensor.to = _tensor_to
|
225 |
+
torch.Tensor.cuda = _tensor_cuda
|
226 |
+
torch.Tensor.cpu = _tensor_cpu
|
227 |
+
with suppress(AttributeError):
|
228 |
+
del torch.Tensor.device
|
229 |
+
with suppress(AttributeError):
|
230 |
+
del torch.Tensor.dtype
|
231 |
+
for name in GENERIC_METHOD_NAMES:
|
232 |
+
setattr(torch, name, _torch_generics[name])
|
233 |
+
torch._C._cuda_init = _cuda_init
|
234 |
+
torch.cuda.is_available = _cuda_available
|
235 |
+
torch.cuda.device_count = _cuda_device_count
|
236 |
+
torch.cuda.current_device = _cuda_current_device
|
237 |
+
torch.cuda.mem_get_info = _cuda_mem_get_info
|
238 |
+
torch.cuda.get_device_capability = _cuda_get_device_capability
|
239 |
+
torch.cuda.get_device_properties = _cuda_get_device_properties
|
240 |
+
torch.cuda.get_device_name = _cuda_get_device_name
|
241 |
+
bitsandbytes.unpatch()
|
242 |
+
|
243 |
+
def _move(nvidia_uuid: str):
|
244 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = nvidia_uuid
|
245 |
+
torch.Tensor([0]).cuda() # CUDA init
|
246 |
+
for op in to_ops.items():
|
247 |
+
tensor, parsed_args = op
|
248 |
+
_, dtype, _, memory_format = parsed_args
|
249 |
+
tensor.data = _tensor_to(tensor,
|
250 |
+
device='cuda',
|
251 |
+
dtype=dtype,
|
252 |
+
memory_format=memory_format,
|
253 |
+
) # type: ignore
|
254 |
+
bitsandbytes.move()
|
255 |
+
torch.cuda.synchronize()
|
256 |
+
|
257 |
+
def _is_in_bad_fork():
|
258 |
+
with ProcessPoolExecutor(mp_context=multiprocessing.get_context('fork')) as e:
|
259 |
+
f = e.submit(torch.cuda._is_in_bad_fork)
|
260 |
+
return f.result()
|
261 |
+
|
262 |
+
def _disable_cuda_intercept():
|
263 |
+
torch.Tensor.to = _tensor_to
|
264 |
+
torch.Tensor.cuda = _tensor_cuda
|
265 |
+
|
266 |
+
else:
|
267 |
+
|
268 |
+
_patch = lambda: None
|
269 |
+
_unpatch = lambda: None
|
270 |
+
_move = lambda nvidia_uuid: None
|
271 |
+
_is_in_bad_fork = lambda: False
|
272 |
+
_disable_cuda_intercept = lambda: None
|
273 |
+
|
274 |
+
|
275 |
+
patch = _patch
|
276 |
+
unpatch = _unpatch
|
277 |
+
move = _move
|
278 |
+
is_in_bad_fork = _is_in_bad_fork
|
279 |
+
disable_cuda_intercept = _disable_cuda_intercept
|
spaces/zero/tqdm.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
|
4 |
+
from multiprocessing.synchronize import RLock as MultiprocessingRLock
|
5 |
+
|
6 |
+
|
7 |
+
def remove_tqdm_multiprocessing_lock():
|
8 |
+
from tqdm import tqdm
|
9 |
+
tqdm_lock = tqdm.get_lock()
|
10 |
+
assert tqdm_lock.__class__.__name__ == 'TqdmDefaultWriteLock'
|
11 |
+
tqdm_lock.locks = [
|
12 |
+
lock for lock in tqdm_lock.locks
|
13 |
+
if not isinstance(lock, MultiprocessingRLock)
|
14 |
+
]
|
spaces/zero/types.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
|
6 |
+
from dataclasses import dataclass
|
7 |
+
from typing import Any
|
8 |
+
from typing import Dict
|
9 |
+
from typing import Tuple
|
10 |
+
from typing import TypedDict
|
11 |
+
from typing_extensions import Generic
|
12 |
+
from typing_extensions import ParamSpec
|
13 |
+
from typing_extensions import TypeAlias
|
14 |
+
from typing_extensions import TypeVar
|
15 |
+
|
16 |
+
|
17 |
+
Params = Tuple[Tuple[object, ...], Dict[str, Any]]
|
18 |
+
Res = TypeVar('Res')
|
19 |
+
Param = ParamSpec('Param')
|
20 |
+
|
21 |
+
class EmptyKwargs(TypedDict):
|
22 |
+
pass
|
23 |
+
|
24 |
+
@dataclass
|
25 |
+
class OkResult(Generic[Res]):
|
26 |
+
value: Res
|
27 |
+
@dataclass
|
28 |
+
class ExceptionResult:
|
29 |
+
value: Exception
|
30 |
+
@dataclass
|
31 |
+
class AbortedResult:
|
32 |
+
pass
|
33 |
+
@dataclass
|
34 |
+
class EndResult:
|
35 |
+
pass
|
36 |
+
@dataclass
|
37 |
+
class GradioQueueEvent:
|
38 |
+
method_name: str
|
39 |
+
args: tuple[Any, ...]
|
40 |
+
kwargs: dict[str, Any]
|
41 |
+
|
42 |
+
RegularResQueueResult: TypeAlias = "OkResult[Res] | ExceptionResult | GradioQueueEvent"
|
43 |
+
GeneratorResQueueResult: TypeAlias = "OkResult[Res] | ExceptionResult | EndResult | GradioQueueEvent"
|
44 |
+
YieldQueueResult: TypeAlias = "OkResult[Res] | ExceptionResult | EndResult | AbortedResult"
|
spaces/zero/utils.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
from contextlib import contextmanager
|
6 |
+
from importlib import metadata
|
7 |
+
from types import ModuleType
|
8 |
+
|
9 |
+
from packaging import version
|
10 |
+
|
11 |
+
from ..config import Config
|
12 |
+
|
13 |
+
|
14 |
+
def maybe_import_torch():
|
15 |
+
if not Config.zero_gpu:
|
16 |
+
return None
|
17 |
+
try:
|
18 |
+
import torch
|
19 |
+
except ImportError:
|
20 |
+
return None
|
21 |
+
return torch
|
22 |
+
|
23 |
+
|
24 |
+
@contextmanager
|
25 |
+
def cuda_unavailable(torch: ModuleType):
|
26 |
+
_is_available = torch.cuda.is_available
|
27 |
+
torch.cuda.is_available = lambda: False
|
28 |
+
yield
|
29 |
+
torch.cuda.is_available = _is_available
|
30 |
+
|
31 |
+
|
32 |
+
def maybe_import_bitsandbytes():
|
33 |
+
if (torch := maybe_import_torch()) is None:
|
34 |
+
return None # pragma: no cover
|
35 |
+
with cuda_unavailable(torch):
|
36 |
+
try:
|
37 |
+
import bitsandbytes
|
38 |
+
except ImportError:
|
39 |
+
bitsandbytes = None
|
40 |
+
else:
|
41 |
+
if (bnb_version := version.parse(metadata.version('bitsandbytes'))) < version.parse('0.40.0'):
|
42 |
+
raise RuntimeError(f"ZeroGPU requires bitsandbytes >= 0.40.0 (installed: {bnb_version})") # pragma: no cover
|
43 |
+
print("↑ Those bitsandbytes warnings are expected on ZeroGPU ↑")
|
44 |
+
return bitsandbytes
|
spaces/zero/wrappers.py
ADDED
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import multiprocessing
|
6 |
+
import os
|
7 |
+
import signal
|
8 |
+
import traceback
|
9 |
+
from concurrent.futures import ThreadPoolExecutor
|
10 |
+
from contextvars import copy_context
|
11 |
+
from datetime import timedelta
|
12 |
+
from functools import partial
|
13 |
+
from functools import wraps
|
14 |
+
from multiprocessing.context import ForkProcess
|
15 |
+
from pickle import PicklingError
|
16 |
+
from queue import Empty
|
17 |
+
from queue import Queue as ThreadQueue
|
18 |
+
from threading import Thread
|
19 |
+
from typing import TYPE_CHECKING
|
20 |
+
from typing import Callable
|
21 |
+
from typing import Generator
|
22 |
+
from typing import Generic
|
23 |
+
from typing_extensions import assert_never
|
24 |
+
|
25 |
+
import gradio as gr
|
26 |
+
import psutil
|
27 |
+
|
28 |
+
from ..utils import debug
|
29 |
+
from ..utils import drop_params
|
30 |
+
from ..utils import gradio_request_var
|
31 |
+
from ..utils import SimpleQueue as Queue
|
32 |
+
from . import client
|
33 |
+
from . import torch
|
34 |
+
from .api import AllowToken
|
35 |
+
from .api import NvidiaIndex
|
36 |
+
from .api import NvidiaUUID
|
37 |
+
from .gradio import GradioPartialContext
|
38 |
+
from .gradio import patch_gradio_queue
|
39 |
+
from .gradio import try_process_queue_event
|
40 |
+
from .tqdm import remove_tqdm_multiprocessing_lock
|
41 |
+
from .types import * # TODO: Please don't do that
|
42 |
+
|
43 |
+
|
44 |
+
GENERATOR_GLOBAL_TIMEOUT = 20 * 60
|
45 |
+
|
46 |
+
|
47 |
+
Process = multiprocessing.get_context('fork').Process
|
48 |
+
forked = False
|
49 |
+
|
50 |
+
|
51 |
+
class Worker(Generic[Res]):
|
52 |
+
process: ForkProcess
|
53 |
+
arg_queue: Queue[tuple[Params, GradioPartialContext]]
|
54 |
+
res_queue: Queue[Res | None]
|
55 |
+
_sentinel: Thread
|
56 |
+
|
57 |
+
def __init__(
|
58 |
+
self,
|
59 |
+
target: Callable[[
|
60 |
+
Queue[tuple[Params, GradioPartialContext]],
|
61 |
+
Queue[Res | None],
|
62 |
+
AllowToken | None,
|
63 |
+
NvidiaUUID,
|
64 |
+
list[int],
|
65 |
+
], None],
|
66 |
+
allow_token: str | None,
|
67 |
+
nvidia_uuid: str,
|
68 |
+
):
|
69 |
+
self._sentinel = Thread(target=self._close_on_exit)
|
70 |
+
self.arg_queue = Queue()
|
71 |
+
self.res_queue = Queue()
|
72 |
+
fds = [c.fd for c in psutil.Process().connections()]
|
73 |
+
args = self.arg_queue, self.res_queue, allow_token, nvidia_uuid, fds
|
74 |
+
if TYPE_CHECKING:
|
75 |
+
target(*args)
|
76 |
+
self.process = Process(
|
77 |
+
target=target,
|
78 |
+
args=args,
|
79 |
+
daemon=True,
|
80 |
+
)
|
81 |
+
self.process.start()
|
82 |
+
self._sentinel.start()
|
83 |
+
|
84 |
+
def _close_on_exit(self):
|
85 |
+
self.process.join()
|
86 |
+
self.res_queue.put(None)
|
87 |
+
|
88 |
+
|
89 |
+
def worker_init(
|
90 |
+
res_queue: Queue[RegularResQueueResult | None] | Queue[GeneratorResQueueResult | None],
|
91 |
+
allow_token: str | None,
|
92 |
+
nvidia_uuid: str,
|
93 |
+
fds: list[int],
|
94 |
+
) -> None | ExceptionResult:
|
95 |
+
try: # Unrecoverable init part
|
96 |
+
if allow_token is not None:
|
97 |
+
client.allow(allow_token)
|
98 |
+
torch.unpatch()
|
99 |
+
torch.move(nvidia_uuid)
|
100 |
+
patch_gradio_queue(res_queue)
|
101 |
+
except Exception as e: # pragma: no cover
|
102 |
+
traceback.print_exc()
|
103 |
+
return ExceptionResult(e)
|
104 |
+
try:
|
105 |
+
remove_tqdm_multiprocessing_lock()
|
106 |
+
except Exception: # pragma: no cover
|
107 |
+
print("Error while trying to remove tqdm mp_lock:")
|
108 |
+
traceback.print_exc()
|
109 |
+
for fd in fds:
|
110 |
+
try:
|
111 |
+
os.close(fd)
|
112 |
+
except Exception as e: # pragma: no cover
|
113 |
+
if isinstance(e, OSError) and e.errno == 9:
|
114 |
+
continue
|
115 |
+
traceback.print_exc()
|
116 |
+
return ExceptionResult(e)
|
117 |
+
|
118 |
+
|
119 |
+
def regular_function_wrapper(
|
120 |
+
task: Callable[Param, Res],
|
121 |
+
duration: timedelta | None,
|
122 |
+
) -> Callable[Param, Res]:
|
123 |
+
|
124 |
+
request_var = gradio_request_var()
|
125 |
+
workers: dict[NvidiaIndex, Worker[RegularResQueueResult[Res]]] = {}
|
126 |
+
task_id = id(task)
|
127 |
+
|
128 |
+
@wraps(task)
|
129 |
+
def gradio_handler(*args: Param.args, **kwargs: Param.kwargs) -> Res:
|
130 |
+
|
131 |
+
if forked:
|
132 |
+
return task(*args, **kwargs)
|
133 |
+
|
134 |
+
request = request_var.get()
|
135 |
+
schedule_response = client.schedule(task_id=task_id, request=request, duration=duration)
|
136 |
+
allow_token = schedule_response.allowToken
|
137 |
+
nvidia_index = schedule_response.nvidiaIndex
|
138 |
+
nvidia_uuid = schedule_response.nvidiaUUID
|
139 |
+
release = partial(client.release, task_id=task_id, nvidia_index=nvidia_index)
|
140 |
+
|
141 |
+
worker = workers.get(nvidia_index)
|
142 |
+
if worker is None or not worker.process.is_alive():
|
143 |
+
worker = Worker(thread_wrapper, allow_token, nvidia_uuid)
|
144 |
+
workers[nvidia_index] = worker
|
145 |
+
|
146 |
+
try:
|
147 |
+
worker.arg_queue.put(((args, kwargs), GradioPartialContext.get()))
|
148 |
+
except PicklingError:
|
149 |
+
release(fail=True)
|
150 |
+
# TODO: Better error message (check what arg / kwarg is problematic ?)
|
151 |
+
raise
|
152 |
+
|
153 |
+
while True:
|
154 |
+
res = worker.res_queue.get()
|
155 |
+
if res is None:
|
156 |
+
release(fail=True, allow_404=True)
|
157 |
+
raise gr.Error("GPU task aborted")
|
158 |
+
if isinstance(res, ExceptionResult):
|
159 |
+
release(fail=True)
|
160 |
+
raise res.value
|
161 |
+
if isinstance(res, OkResult):
|
162 |
+
release()
|
163 |
+
return res.value
|
164 |
+
if isinstance(res, GradioQueueEvent):
|
165 |
+
try_process_queue_event(res.method_name, *res.args, **res.kwargs)
|
166 |
+
continue
|
167 |
+
assert_never(res)
|
168 |
+
|
169 |
+
|
170 |
+
def thread_wrapper(
|
171 |
+
arg_queue: Queue[tuple[Params, GradioPartialContext]],
|
172 |
+
res_queue: Queue[RegularResQueueResult[Res] | None],
|
173 |
+
allow_token: str | None,
|
174 |
+
nvidia_uuid: str,
|
175 |
+
fds: list[int],
|
176 |
+
):
|
177 |
+
global forked
|
178 |
+
forked = True
|
179 |
+
if (res := worker_init(
|
180 |
+
res_queue=res_queue,
|
181 |
+
allow_token=allow_token,
|
182 |
+
nvidia_uuid=nvidia_uuid,
|
183 |
+
fds=fds,
|
184 |
+
)) is not None: # pragma: no cover
|
185 |
+
res_queue.put(res)
|
186 |
+
return
|
187 |
+
signal.signal(signal.SIGTERM, drop_params(arg_queue.close))
|
188 |
+
while True:
|
189 |
+
try:
|
190 |
+
(args, kwargs), gradio_context = arg_queue.get()
|
191 |
+
except OSError:
|
192 |
+
break
|
193 |
+
GradioPartialContext.apply(gradio_context)
|
194 |
+
context = copy_context()
|
195 |
+
with ThreadPoolExecutor() as executor:
|
196 |
+
future = executor.submit(context.run, task, *args, **kwargs) # type: ignore
|
197 |
+
try:
|
198 |
+
res = future.result()
|
199 |
+
except Exception as e:
|
200 |
+
traceback.print_exc()
|
201 |
+
res = ExceptionResult(e)
|
202 |
+
else:
|
203 |
+
res = OkResult(res)
|
204 |
+
try:
|
205 |
+
res_queue.put(res)
|
206 |
+
except PicklingError as e:
|
207 |
+
res_queue.put(ExceptionResult(e))
|
208 |
+
|
209 |
+
|
210 |
+
return gradio_handler
|
211 |
+
|
212 |
+
|
213 |
+
def generator_function_wrapper(
|
214 |
+
task: Callable[Param, Generator[Res, None, None]],
|
215 |
+
duration: timedelta | None,
|
216 |
+
) -> Callable[Param, Generator[Res, None, None]]:
|
217 |
+
|
218 |
+
request_var = gradio_request_var()
|
219 |
+
workers: dict[NvidiaIndex, Worker[GeneratorResQueueResult[Res]]] = {}
|
220 |
+
task_id = id(task)
|
221 |
+
|
222 |
+
@wraps(task)
|
223 |
+
def gradio_handler(*args: Param.args, **kwargs: Param.kwargs) -> Generator[Res, None, None]:
|
224 |
+
|
225 |
+
if forked:
|
226 |
+
yield from task(*args, **kwargs)
|
227 |
+
return
|
228 |
+
|
229 |
+
request = request_var.get()
|
230 |
+
schedule_response = client.schedule(task_id=task_id, request=request, duration=duration)
|
231 |
+
allow_token = schedule_response.allowToken
|
232 |
+
nvidia_index = schedule_response.nvidiaIndex
|
233 |
+
nvidia_uuid = schedule_response.nvidiaUUID
|
234 |
+
release = partial(client.release, task_id=task_id, nvidia_index=nvidia_index)
|
235 |
+
|
236 |
+
worker = workers.get(nvidia_index)
|
237 |
+
if worker is None or not worker.process.is_alive():
|
238 |
+
worker = Worker(thread_wrapper, allow_token, nvidia_uuid)
|
239 |
+
workers[nvidia_index] = worker
|
240 |
+
|
241 |
+
try:
|
242 |
+
worker.arg_queue.put(((args, kwargs), GradioPartialContext.get()))
|
243 |
+
except PicklingError:
|
244 |
+
release(fail=True)
|
245 |
+
raise
|
246 |
+
|
247 |
+
yield_queue: ThreadQueue[YieldQueueResult[Res]] = ThreadQueue()
|
248 |
+
def fill_yield_queue(worker: Worker[GeneratorResQueueResult[Res]]):
|
249 |
+
while True:
|
250 |
+
res = worker.res_queue.get()
|
251 |
+
if res is None:
|
252 |
+
release(fail=True, allow_404=True)
|
253 |
+
yield_queue.put(AbortedResult())
|
254 |
+
return
|
255 |
+
if isinstance(res, ExceptionResult):
|
256 |
+
release(fail=True)
|
257 |
+
yield_queue.put(ExceptionResult(res.value))
|
258 |
+
return
|
259 |
+
if isinstance(res, EndResult):
|
260 |
+
release()
|
261 |
+
yield_queue.put(EndResult())
|
262 |
+
return
|
263 |
+
if isinstance(res, OkResult):
|
264 |
+
yield_queue.put(OkResult(res.value))
|
265 |
+
continue
|
266 |
+
if isinstance(res, GradioQueueEvent): # pragma: no cover (not working properly on Gradio side)
|
267 |
+
try_process_queue_event(res.method_name, *res.args, **res.kwargs)
|
268 |
+
continue
|
269 |
+
debug(f"fill_yield_queue: assert_never({res=})")
|
270 |
+
assert_never(res)
|
271 |
+
from typing_extensions import assert_never
|
272 |
+
with ThreadPoolExecutor() as e:
|
273 |
+
f = e.submit(fill_yield_queue, worker)
|
274 |
+
f.add_done_callback(lambda _: debug("fill_yield_queue DONE"))
|
275 |
+
while True:
|
276 |
+
try:
|
277 |
+
res = yield_queue.get(timeout=GENERATOR_GLOBAL_TIMEOUT)
|
278 |
+
except Empty: # pragma: no cover
|
279 |
+
debug(f"yield_queue TIMEOUT ({GENERATOR_GLOBAL_TIMEOUT=})")
|
280 |
+
raise
|
281 |
+
if isinstance(res, AbortedResult):
|
282 |
+
raise gr.Error("GPU task aborted")
|
283 |
+
if isinstance(res, ExceptionResult):
|
284 |
+
raise res.value
|
285 |
+
if isinstance(res, EndResult):
|
286 |
+
break
|
287 |
+
if isinstance(res, OkResult):
|
288 |
+
yield res.value
|
289 |
+
continue
|
290 |
+
debug(f"gradio_handler: assert_never({res=})")
|
291 |
+
assert_never(res)
|
292 |
+
|
293 |
+
|
294 |
+
def thread_wrapper(
|
295 |
+
arg_queue: Queue[tuple[Params, GradioPartialContext]],
|
296 |
+
res_queue: Queue[GeneratorResQueueResult[Res] | None],
|
297 |
+
allow_token: str | None,
|
298 |
+
nvidia_uuid: str,
|
299 |
+
fds: list[int],
|
300 |
+
):
|
301 |
+
global forked
|
302 |
+
forked = True
|
303 |
+
if (res := worker_init(
|
304 |
+
res_queue=res_queue,
|
305 |
+
allow_token=allow_token,
|
306 |
+
nvidia_uuid=nvidia_uuid,
|
307 |
+
fds=fds,
|
308 |
+
)) is not None: # pragma: no cover
|
309 |
+
res_queue.put(res)
|
310 |
+
return
|
311 |
+
signal.signal(signal.SIGTERM, drop_params(arg_queue.close))
|
312 |
+
while True:
|
313 |
+
try:
|
314 |
+
(args, kwargs), gradio_context = arg_queue.get()
|
315 |
+
except OSError:
|
316 |
+
break
|
317 |
+
def iterate():
|
318 |
+
gen = task(*args, **kwargs) # type: ignore
|
319 |
+
while True:
|
320 |
+
try:
|
321 |
+
res = next(gen)
|
322 |
+
except StopIteration:
|
323 |
+
break
|
324 |
+
except Exception as e:
|
325 |
+
res_queue.put(ExceptionResult(e))
|
326 |
+
break
|
327 |
+
try:
|
328 |
+
res_queue.put(OkResult(res))
|
329 |
+
except PicklingError as e:
|
330 |
+
res_queue.put(ExceptionResult(e))
|
331 |
+
break
|
332 |
+
else:
|
333 |
+
continue
|
334 |
+
GradioPartialContext.apply(gradio_context)
|
335 |
+
context = copy_context()
|
336 |
+
with ThreadPoolExecutor() as executor:
|
337 |
+
executor.submit(context.run, iterate)
|
338 |
+
res_queue.put(EndResult())
|
339 |
+
|
340 |
+
return gradio_handler
|