afff4f781c38808b9aafcdd5ec92a88a4aca77a59bb34d95614d31dab397a490
- lib/python3.11/site-packages/functorch/dim/op_properties.py +311 -0
- lib/python3.11/site-packages/functorch/dim/reference.py +645 -0
- lib/python3.11/site-packages/functorch/dim/tree_map.py +14 -0
- lib/python3.11/site-packages/functorch/dim/wrap_type.py +71 -0
- lib/python3.11/site-packages/functorch/einops/__init__.py +3 -0
- lib/python3.11/site-packages/functorch/einops/__pycache__/__init__.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/functorch/einops/__pycache__/_parsing.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/functorch/einops/__pycache__/rearrange.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/functorch/einops/_parsing.py +302 -0
- lib/python3.11/site-packages/functorch/einops/rearrange.py +207 -0
- lib/python3.11/site-packages/functorch/experimental/__init__.py +6 -0
- lib/python3.11/site-packages/functorch/experimental/__pycache__/__init__.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/functorch/experimental/__pycache__/_map.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/functorch/experimental/__pycache__/control_flow.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/functorch/experimental/__pycache__/ops.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/functorch/experimental/_map.py +393 -0
- lib/python3.11/site-packages/functorch/experimental/control_flow.py +6 -0
- lib/python3.11/site-packages/functorch/experimental/ops.py +1 -0
- lib/python3.11/site-packages/huggingface_hub/__init__.py +650 -0
- lib/python3.11/site-packages/huggingface_hub/__pycache__/__init__.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/__pycache__/_commit_api.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/__pycache__/_commit_scheduler.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/__pycache__/_inference_endpoints.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/__pycache__/_login.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/__pycache__/_multi_commits.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/__pycache__/_snapshot_download.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/__pycache__/_space_api.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/__pycache__/_tensorboard_logger.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/__pycache__/_webhooks_payload.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/__pycache__/_webhooks_server.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/__pycache__/community.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/__pycache__/constants.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/__pycache__/fastai_utils.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/__pycache__/file_download.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/__pycache__/hf_api.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/__pycache__/hf_file_system.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/__pycache__/hub_mixin.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/__pycache__/inference_api.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/__pycache__/keras_mixin.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/__pycache__/lfs.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/__pycache__/repocard.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/__pycache__/repocard_data.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/__pycache__/repository.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/huggingface_hub/_commit_api.py +670 -0
- lib/python3.11/site-packages/huggingface_hub/_commit_scheduler.py +327 -0
- lib/python3.11/site-packages/huggingface_hub/_inference_endpoints.py +373 -0
- lib/python3.11/site-packages/huggingface_hub/_login.py +395 -0
- lib/python3.11/site-packages/huggingface_hub/_multi_commits.py +305 -0
- lib/python3.11/site-packages/huggingface_hub/_snapshot_download.py +319 -0
- lib/python3.11/site-packages/huggingface_hub/_space_api.py +154 -0
lib/python3.11/site-packages/functorch/dim/op_properties.py
ADDED
@@ -0,0 +1,311 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+
+# pointwise operators can go through a faster pathway
+
+tensor_magic_methods = ["add", ""]
+pointwise_magic_methods_with_reverse = (
+    "add",
+    "sub",
+    "mul",
+    "floordiv",
+    "div",
+    "truediv",
+    "mod",
+    "pow",
+    "lshift",
+    "rshift",
+    "and",
+    "or",
+    "xor",
+)
+pointwise_magic_methods = (
+    *(x for m in pointwise_magic_methods_with_reverse for x in (m, "r" + m)),
+    "eq",
+    "gt",
+    "le",
+    "lt",
+    "ge",
+    "gt",
+    "ne",
+    "neg",
+    "pos",
+    "abs",
+    "invert",
+    "iadd",
+    "isub",
+    "imul",
+    "ifloordiv",
+    "idiv",
+    "itruediv",
+    "imod",
+    "ipow",
+    "ilshift",
+    "irshift",
+    "iand",
+    "ior",
+    "ixor",
+    "int",
+    "long",
+    "float",
+    "complex",
+)
+
+pointwise_methods = (*(f"__{m}__" for m in pointwise_magic_methods),)
+
+pointwise = (
+    *(getattr(torch.Tensor, m) for m in pointwise_methods),
+    torch.nn.functional.dropout,
+    torch.where,
+    torch.Tensor.abs,
+    torch.abs,
+    torch.Tensor.acos,
+    torch.acos,
+    torch.Tensor.acosh,
+    torch.acosh,
+    torch.Tensor.add,
+    torch.add,
+    torch.Tensor.addcdiv,
+    torch.addcdiv,
+    torch.Tensor.addcmul,
+    torch.addcmul,
+    torch.Tensor.addr,
+    torch.addr,
+    torch.Tensor.angle,
+    torch.angle,
+    torch.Tensor.asin,
+    torch.asin,
+    torch.Tensor.asinh,
+    torch.asinh,
+    torch.Tensor.atan,
+    torch.atan,
+    torch.Tensor.atan2,
+    torch.atan2,
+    torch.Tensor.atanh,
+    torch.atanh,
+    torch.Tensor.bitwise_and,
+    torch.bitwise_and,
+    torch.Tensor.bitwise_left_shift,
+    torch.bitwise_left_shift,
+    torch.Tensor.bitwise_not,
+    torch.bitwise_not,
+    torch.Tensor.bitwise_or,
+    torch.bitwise_or,
+    torch.Tensor.bitwise_right_shift,
+    torch.bitwise_right_shift,
+    torch.Tensor.bitwise_xor,
+    torch.bitwise_xor,
+    torch.Tensor.ceil,
+    torch.ceil,
+    torch.celu,
+    torch.nn.functional.celu,
+    torch.Tensor.clamp,
+    torch.clamp,
+    torch.Tensor.clamp_max,
+    torch.clamp_max,
+    torch.Tensor.clamp_min,
+    torch.clamp_min,
+    torch.Tensor.copysign,
+    torch.copysign,
+    torch.Tensor.cos,
+    torch.cos,
+    torch.Tensor.cosh,
+    torch.cosh,
+    torch.Tensor.deg2rad,
+    torch.deg2rad,
+    torch.Tensor.digamma,
+    torch.digamma,
+    torch.Tensor.div,
+    torch.div,
+    torch.dropout,
+    torch.nn.functional.dropout,
+    torch.nn.functional.elu,
+    torch.Tensor.eq,
+    torch.eq,
+    torch.Tensor.erf,
+    torch.erf,
+    torch.Tensor.erfc,
+    torch.erfc,
+    torch.Tensor.erfinv,
+    torch.erfinv,
+    torch.Tensor.exp,
+    torch.exp,
+    torch.Tensor.exp2,
+    torch.exp2,
+    torch.Tensor.expm1,
+    torch.expm1,
+    torch.feature_dropout,
+    torch.Tensor.float_power,
+    torch.float_power,
+    torch.Tensor.floor,
+    torch.floor,
+    torch.Tensor.floor_divide,
+    torch.floor_divide,
+    torch.Tensor.fmod,
+    torch.fmod,
+    torch.Tensor.frac,
+    torch.frac,
+    torch.Tensor.frexp,
+    torch.frexp,
+    torch.Tensor.gcd,
+    torch.gcd,
+    torch.Tensor.ge,
+    torch.ge,
+    torch.nn.functional.gelu,
+    torch.nn.functional.glu,
+    torch.Tensor.gt,
+    torch.gt,
+    torch.Tensor.hardshrink,
+    torch.hardshrink,
+    torch.nn.functional.hardshrink,
+    torch.nn.functional.hardsigmoid,
+    torch.nn.functional.hardswish,
+    torch.nn.functional.hardtanh,
+    torch.Tensor.heaviside,
+    torch.heaviside,
+    torch.Tensor.hypot,
+    torch.hypot,
+    torch.Tensor.i0,
+    torch.i0,
+    torch.Tensor.igamma,
+    torch.igamma,
+    torch.Tensor.igammac,
+    torch.igammac,
+    torch.Tensor.isclose,
+    torch.isclose,
+    torch.Tensor.isfinite,
+    torch.isfinite,
+    torch.Tensor.isinf,
+    torch.isinf,
+    torch.Tensor.isnan,
+    torch.isnan,
+    torch.Tensor.isneginf,
+    torch.isneginf,
+    torch.Tensor.isposinf,
+    torch.isposinf,
+    torch.Tensor.isreal,
+    torch.isreal,
+    torch.Tensor.kron,
+    torch.kron,
+    torch.Tensor.lcm,
+    torch.lcm,
+    torch.Tensor.ldexp,
+    torch.ldexp,
+    torch.Tensor.le,
+    torch.le,
+    torch.nn.functional.leaky_relu,
+    torch.Tensor.lerp,
+    torch.lerp,
+    torch.Tensor.lgamma,
+    torch.lgamma,
+    torch.Tensor.log,
+    torch.log,
+    torch.Tensor.log10,
+    torch.log10,
+    torch.Tensor.log1p,
+    torch.log1p,
+    torch.Tensor.log2,
+    torch.log2,
+    torch.nn.functional.logsigmoid,
+    torch.Tensor.logical_and,
+    torch.logical_and,
+    torch.Tensor.logical_not,
+    torch.logical_not,
+    torch.Tensor.logical_or,
+    torch.logical_or,
+    torch.Tensor.logical_xor,
+    torch.logical_xor,
+    torch.Tensor.logit,
+    torch.logit,
+    torch.Tensor.lt,
+    torch.lt,
+    torch.Tensor.maximum,
+    torch.maximum,
+    torch.Tensor.minimum,
+    torch.minimum,
+    torch.nn.functional.mish,
+    torch.Tensor.mvlgamma,
+    torch.mvlgamma,
+    torch.Tensor.nan_to_num,
+    torch.nan_to_num,
+    torch.Tensor.ne,
+    torch.ne,
+    torch.Tensor.neg,
+    torch.neg,
+    torch.Tensor.nextafter,
+    torch.nextafter,
+    torch.Tensor.outer,
+    torch.outer,
+    torch.polar,
+    torch.Tensor.polygamma,
+    torch.polygamma,
+    torch.Tensor.positive,
+    torch.positive,
+    torch.Tensor.pow,
+    torch.pow,
+    torch.Tensor.prelu,
+    torch.prelu,
+    torch.nn.functional.prelu,
+    torch.Tensor.rad2deg,
+    torch.rad2deg,
+    torch.Tensor.reciprocal,
+    torch.reciprocal,
+    torch.Tensor.relu,
+    torch.relu,
+    torch.nn.functional.relu,
+    torch.nn.functional.relu6,
+    torch.Tensor.remainder,
+    torch.remainder,
+    torch.Tensor.round,
+    torch.round,
+    torch.rrelu,
+    torch.nn.functional.rrelu,
+    torch.Tensor.rsqrt,
+    torch.rsqrt,
+    torch.rsub,
+    torch.selu,
+    torch.nn.functional.selu,
+    torch.Tensor.sgn,
+    torch.sgn,
+    torch.Tensor.sigmoid,
+    torch.sigmoid,
+    torch.nn.functional.sigmoid,
+    torch.Tensor.sign,
+    torch.sign,
+    torch.Tensor.signbit,
+    torch.signbit,
+    torch.nn.functional.silu,
+    torch.Tensor.sin,
+    torch.sin,
+    torch.Tensor.sinc,
+    torch.sinc,
+    torch.Tensor.sinh,
+    torch.sinh,
+    torch.nn.functional.softplus,
+    torch.nn.functional.softshrink,
+    torch.Tensor.sqrt,
+    torch.sqrt,
+    torch.Tensor.square,
+    torch.square,
+    torch.Tensor.sub,
+    torch.sub,
+    torch.Tensor.tan,
+    torch.tan,
+    torch.Tensor.tanh,
+    torch.tanh,
+    torch.nn.functional.tanh,
+    torch.threshold,
+    torch.nn.functional.threshold,
+    torch.trapz,
+    torch.Tensor.true_divide,
+    torch.true_divide,
+    torch.Tensor.trunc,
+    torch.trunc,
+    torch.Tensor.xlogy,
+    torch.xlogy,
+    torch.rand_like,
+)
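The table above is consumed as a set: reference.py (the next file in this commit) builds pointwise = set(op_properties.pointwise) and routes an op down the fast pointwise path only if the op object is a member. A minimal sketch of that membership check, assuming a standard functorch install:

import torch
from functorch.dim import op_properties

# membership is by function identity, exactly as reference.py does it
pointwise = set(op_properties.pointwise)

print(torch.add in pointwise)     # True: torch.add is listed above
print(torch.matmul in pointwise)  # False: matmul is not a pointwise op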
lib/python3.11/site-packages/functorch/dim/reference.py
ADDED
@@ -0,0 +1,645 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# reference python implementations for C ops
+import torch
+
+from functorch._C import dim as _C
+from . import op_properties
+from .batch_tensor import _enable_layers
+from .tree_map import tree_flatten, tree_map
+
+DimList = _C.DimList
+import operator
+from functools import reduce
+
+
+# use dict to avoid writing C++ bindings for set
+pointwise = set(op_properties.pointwise)
+
+
+def prod(x):
+    return reduce(operator.mul, x, 1)
+
+
+def _wrap_dim(d, N, keepdim):
+    from . import Dim
+
+    if isinstance(d, Dim):
+        assert not keepdim, "cannot preserve first-class dimensions with keepdim=True"
+        return d
+    elif d >= 0:
+        return d - N
+    else:
+        return d
+
+
+def _dims(d, N, keepdim, single_dim):
+    from . import Dim
+
+    if isinstance(d, (Dim, int)):
+        return ltuple((_wrap_dim(d, N, keepdim),))
+    assert not single_dim, f"expected a single dimension or int but found: {d}"
+    return ltuple(_wrap_dim(x, N, keepdim) for x in d)
+
+
+def _bind_dims_to_size(lhs_size, rhs, lhs_debug):
+    from . import DimensionMismatchError
+
+    not_bound = tuple((i, r) for i, r in enumerate(rhs) if not r.is_bound)
+    if len(not_bound) == 1:
+        idx, d = not_bound[0]
+        rhs_so_far = prod(r.size for r in rhs if r.is_bound)
+        if lhs_size % rhs_so_far != 0:
+            rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
+            raise DimensionMismatchError(
+                f"inferred dimension does not evenly fit into larger dimension: {lhs_size} vs {rhs_s}"
+            )
+        new_size = lhs_size // rhs_so_far
+        d.size = new_size
+    elif len(not_bound) > 1:
+        rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
+        raise DimensionMismatchError(
+            f"cannot infer the size of two dimensions at once: {rhs} with sizes {rhs_s}"
+        )
+    else:
+        rhs_size = prod(r.size for r in rhs)
+        if lhs_size != rhs_size:
+            raise DimensionMismatchError(
+                f"Dimension sizes do not match ({lhs_size} != {rhs_size}) when matching {lhs_debug} to {rhs}"
+            )
+
+
+def _tensor_levels(inp):
+    from . import _Tensor
+
+    if isinstance(inp, _Tensor):
+        return inp._tensor, llist(inp._levels), inp._has_device
+    else:
+        return inp, llist(range(-inp.ndim, 0)), True
+
+
+def _match_levels(v, from_levels, to_levels):
+    view = []
+    permute = []
+    requires_view = False
+    size = v.size()
+    for t in to_levels:
+        try:
+            idx = from_levels.index(t)
+            permute.append(idx)
+            view.append(size[idx])
+        except ValueError:
+            view.append(1)
+            requires_view = True
+    if permute != list(range(len(permute))):
+        v = v.permute(*permute)
+    if requires_view:
+        v = v.view(*view)
+    return v
+
+
+# make a single dimension positional but do not permute it,
+# used to do multi-tensor operators where the dim being acted on
+# should not physically move if possible
+def _positional_no_permute(self, dim, expand_dim=False):
+    from . import Tensor
+
+    ptensor, levels = self._tensor, llist(self._levels)
+    try:
+        idx = levels.index(dim)
+    except ValueError:
+        if not expand_dim:
+            raise
+        idx = 0
+        ptensor = ptensor.expand(dim.size, *ptensor.size())
+        levels.insert(0, 0)
+    idx_batched = 0
+    for i in range(idx):
+        if isinstance(levels[i], int):
+            levels[i] -= 1
+            idx_batched += 1
+    levels[idx] = -idx_batched - 1
+    return Tensor.from_positional(ptensor, levels, self._has_device), idx_batched
+
+
+def seq(a, b):
+    from . import Dim
+
+    if isinstance(a, Dim) != isinstance(b, Dim):
+        return False
+    if isinstance(a, Dim):
+        return a is b
+    else:
+        return a == b
+
+
+class isin:
+    def __contains__(self, item):
+        for x in self:
+            if seq(item, x):
+                return True
+        return False
+
+    def index(self, item):
+        for i, x in enumerate(self):
+            if seq(item, x):
+                return i
+        raise ValueError
+
+
+class llist(isin, list):
+    pass
+
+
+class ltuple(isin, tuple):
+    pass
+
+
+empty_dict = {}
+
+
+@classmethod
+def __torch_function__(self, orig, cls, args, kwargs=empty_dict):
+    from . import _Tensor, Tensor, TensorLike
+    from .delayed_mul_tensor import DelayedMulTensor
+
+    if orig is torch.Tensor.__mul__:
+        lhs, rhs = args
+        if (
+            isinstance(lhs, _Tensor)
+            and isinstance(rhs, _Tensor)
+            and lhs.ndim == 0
+            and rhs.ndim == 0
+        ):
+            return DelayedMulTensor(lhs, rhs)
+    all_dims = llist()
+    flat_args, unflatten = tree_flatten((args, kwargs))
+    device_holding_tensor = None
+    for f in flat_args:
+        if isinstance(f, _Tensor):
+            if f._has_device:
+                device_holding_tensor = f._batchtensor
+            for d in f.dims:
+                if d not in all_dims:
+                    all_dims.append(d)
+
+    def unwrap(t):
+        if isinstance(t, _Tensor):
+            r = t._batchtensor
+            if device_holding_tensor is not None and not t._has_device:
+                r = r.to(device=device_holding_tensor.device)
+            return r
+        return t
+
+    if orig in pointwise:
+        result_levels = llist()
+        arg_levels = llist()
+        to_expand = []
+        for i, f in enumerate(flat_args):
+            if isinstance(f, TensorLike):
+                ptensor, levels, _ = _tensor_levels(f)
+                if (
+                    isinstance(f, _Tensor)
+                    and not f._has_device
+                    and device_holding_tensor is not None
+                ):
+                    ptensor = ptensor.to(device=device_holding_tensor.device)
+                flat_args[i] = ptensor
+                for l in levels:
+                    if l not in result_levels:
+                        result_levels.append(l)
+                to_expand.append((i, levels))
+
+        for i, levels in to_expand:
+            flat_args[i] = _match_levels(flat_args[i], levels, result_levels)
+        args, kwargs = unflatten(flat_args)
+        result = orig(*args, **kwargs)
+
+        def wrap(t):
+            if isinstance(t, TensorLike):
+                return Tensor.from_positional(
+                    t, result_levels, device_holding_tensor is not None
+                )
+            return t
+
+        return tree_map(wrap, result)
+    else:
+
+        def wrap(t):
+            if isinstance(t, TensorLike):
+                return Tensor.from_batched(t, device_holding_tensor is not None)
+            return t
+
+        with _enable_layers(all_dims):
+            print(f"batch_tensor for {orig}")
+            args, kwargs = unflatten(unwrap(f) for f in flat_args)
+            result = orig(*args, **kwargs)
+            # print("END", orig)
+            return tree_map(wrap, result)
+
+
+def positional(self, *dims):
+    from . import Dim, Tensor
+
+    ptensor, levels = self._tensor, llist(self._levels)
+    flat_dims = llist()
+    view = []
+    needs_view = False
+    ndim = self.ndim
+    for d in dims:
+        if isinstance(d, DimList):
+            flat_dims.extend(d)
+            view.extend(e.size for e in d)
+        elif isinstance(d, Dim):
+            flat_dims.append(d)
+            view.append(d.size)
+        elif isinstance(d, int):
+            d = _wrap_dim(d, ndim, False)
+            flat_dims.append(d)
+            view.append(ptensor.size(d))
+        else:
+            flat_dims.extend(d)
+            view.append(prod(e.size for e in d))
+            needs_view = True
+
+    permute = list(range(len(levels)))
+    nflat = len(flat_dims)
+    for i, d in enumerate(flat_dims):
+        try:
+            idx = levels.index(d)
+        except ValueError as e:
+            raise DimensionBindError(
+                f"tensor of dimensions {self.dims} does not contain dim {d}"
+            ) from e
+        p = permute[idx]
+        del levels[idx]
+        del permute[idx]
+        levels.insert(i, 0)
+        permute.insert(i, p)
+    ptensor = ptensor.permute(*permute)
+    seen = 0
+    for i in range(len(levels) - 1, -1, -1):
+        if isinstance(levels[i], int):
+            seen += 1
+            levels[i] = -seen
+    result = Tensor.from_positional(ptensor, levels, self._has_device)
+    if needs_view:
+        result = result.reshape(*view, *result.size()[len(flat_dims) :])
+    return result
+
+
+def _contains_dim(input):
+    from . import Dim
+
+    for i in input:
+        if isinstance(i, Dim):
+            return True
+
+
+def expand(self, *sizes):
+    if not _contains_dim(sizes):
+        return self.__torch_function__(torch.Tensor.expand, None, (self, *sizes))
+    dims = sizes
+    sizes = [d.size for d in dims] + [-1] * self.ndim
+    self = self.expand(*sizes)
+    return self[dims]
+
+
+_not_present = object()
+
+
+def _getarg(name, offset, args, kwargs, default):
+    if len(args) > offset:
+        return args[offset]
+    return kwargs.get(name, default)
+
+
+def _patcharg(name, offset, args, kwargs, value):
+    if len(args) > offset:
+        args[offset] = value
+    else:
+        kwargs[name] = value
+
+
+def _wrap(
+    orig, dim_offset=0, keepdim_offset=1, dim_name="dim", single_dim=False, reduce=True
+):
+    from . import Dim, Tensor, TensorLike
+
+    def fn(self, *args, **kwargs):
+        dim = _getarg(dim_name, dim_offset, args, kwargs, _not_present)
+        if dim is _not_present or (single_dim and not isinstance(dim, Dim)):
+            with _enable_layers(self.dims):
+                print(f"dim fallback batch_tensor for {orig}")
+                return Tensor.from_batched(
+                    orig(self._batchtensor, *args, **kwargs), self._has_device
+                )
+        keepdim = (
+            _getarg("keepdim", keepdim_offset, args, kwargs, False) if reduce else False
+        )
+        t, levels = self._tensor, llist(self._levels)
+        dims = _dims(dim, self._batchtensor.ndim, keepdim, single_dim)
+        dim_indices = tuple(levels.index(d) for d in dims)
+        if reduce and not keepdim:
+            new_levels = [l for i, l in enumerate(levels) if i not in dim_indices]
+        else:
+            new_levels = levels
+
+        if len(dim_indices) == 1:
+            dim_indices = dim_indices[
+                0
+            ]  # so that dims that really only take a single argument work...
+        args = list(args)
+        _patcharg(dim_name, dim_offset, args, kwargs, dim_indices)
+
+        def wrap(t):
+            if isinstance(t, TensorLike):
+                return Tensor.from_positional(t, new_levels, self._has_device)
+            return t
+
+        with _enable_layers(new_levels):
+            print(f"dim used batch_tensor for {orig}")
+            r = orig(t, *args, **kwargs)
+            return tree_map(wrap, r)
+
+    return fn
+
+
+def _def(name, *args, **kwargs):
+    from . import _Tensor
+
+    orig = getattr(torch.Tensor, name)
+    setattr(_Tensor, name, _wrap(orig, *args, **kwargs))
+
+
+no_slice = slice(None)
+
+_orig_getitem = torch.Tensor.__getitem__
+
+
+class dim_tracker:
+    def __init__(self):
+        self.dims = llist()
+        self.count = []
+
+    def record(self, d):
+        if d not in self.dims:
+            self.dims.append(d)
+            self.count.append(1)
+
+    def __getitem__(self, d):
+        return self.count[self.dims.index(d)]
+
+
+def t__getitem__(self, input):
+    from . import _Tensor, Dim, DimensionBindError, DimList, Tensor, TensorLike
+
+    # * bail to the original implementation if we have a single non-Dim tensor, or a non-tensor
+    # * locate ... or an unbound tensor list, and determine its size, bind dim list
+    #   (remember that None does not count to the total dim count)
+    # * bind simple dims and dim-packs to their sizes, count the number of uses of each dim,
+    #   produce the re-view if needed
+    # * for each single-use dim index, replace with no_slice and mark that it will be added
+    #   (keep track of whether we have to call super)
+    # * call super if needed
+    # * if we have dims to bind, bind them (it will help if we eliminated ... and None before)
+
+    # this handles bool indexing, as well as some other simple cases.
+
+    is_simple = (
+        not isinstance(input, Dim)
+        and not isinstance(input, (tuple, list))
+        and
+        # WAR for functorch bug where zero-dim tensors in getitem are not handled correctly.
+        not (isinstance(input, TensorLike) and input.ndim == 0)
+    )
+
+    if is_simple:
+        if isinstance(self, _Tensor):
+            return _Tensor.__torch_function__(_orig_getitem, None, (self, input))
+        else:
+            return _orig_getitem(self, input)
+
+    # can further optimize this case
+    if not isinstance(input, tuple):
+        input = [input]
+    else:
+        input = list(input)
+
+    dims_indexed = 0
+    expanding_object = None
+    dimlists = []
+    for i, s in enumerate(input):
+        if s is ... or isinstance(s, DimList) and not s.is_bound:
+            if expanding_object is not None:
+                msg = (
+                    "at most one ... or unbound dimension list can exist in indexing list but"
+                    f" found 2 at offsets {i} and {expanding_object}"
+                )
+                raise DimensionBindError(msg)
+            expanding_object = i
+
+        if isinstance(s, DimList):
+            dims_indexed += len(s) if s.is_bound else 0
+            dimlists.append(i)
+        elif s is not None and s is not ...:
+            dims_indexed += 1
+
+    ndim = self.ndim
+    if dims_indexed > ndim:
+        raise IndexError(
+            f"at least {dims_indexed} indices were supplied but the tensor only has {ndim} dimensions."
+        )
+    if expanding_object is not None:
+        expanding_ndims = ndim - dims_indexed
+        obj = input[expanding_object]
+        if obj is ...:
+            input[expanding_object : expanding_object + 1] = [
+                no_slice
+            ] * expanding_ndims
+        else:
+            obj.bind_len(expanding_ndims)
+    # flatten the dim lists into the indexing
+    for i in reversed(dimlists):
+        input[i : i + 1] = input[i]
+    dims_indexed = 0
+    requires_view = False
+    size = self.size()
+    view_sizes = []
+    dims_seen = dim_tracker()
+
+    def add_dims(t):
+        if not isinstance(t, _Tensor):
+            return
+        for d in t.dims:
+            dims_seen.record(d)
+
+    add_dims(self)
+    dim_packs = []
+    for i, idx in enumerate(input):
+        if idx is None:
+            input[i] = no_slice
+            view_sizes.append(1)
+            requires_view = True
+        else:
+            sz = size[dims_indexed]
+            if isinstance(idx, Dim):
+                idx.size = sz
+                dims_seen.record(idx)
+                view_sizes.append(sz)
+            elif isinstance(idx, (tuple, list)) and idx and isinstance(idx[0], Dim):
+                for d in idx:
+                    dims_seen.record(idx)
+                _bind_dims_to_size(sz, idx, f"offset {i}")
+                view_sizes.extend(d.size for d in idx)
+                requires_view = True
+                dim_packs.append(i)
+            else:
+                add_dims(idx)
+                view_sizes.append(sz)
+            dims_indexed += 1
+    if requires_view:
+        self = self.view(*view_sizes)
+    for i in reversed(dim_packs):
+        input[i : i + 1] = input[i]
+
+    # currently:
+    # input is flat, containing either Dim, or Tensor, or something valid for standard indexing
+    # self may have first-class dims as well.
+
+    # to index:
+    # drop the first class dims from self, they just become direct indices of their positions
+
+    # figure out the dimensions of the indexing tensors: union of all the dims in the tensors in the index.
+    # these dimensions will appear and need to be bound at the first place a tensor occurs
+
+    if isinstance(self, _Tensor):
+        ptensor_self, levels = self._tensor, list(self._levels)
+        # indices to ptensor rather than self which has first-class dimensions
+        input_it = iter(input)
+        flat_inputs = [next(input_it) if isinstance(l, int) else l for l in levels]
+        has_device = self._has_device
+        to_pad = 0
+    else:
+        ptensor_self, flat_inputs = self, input
+        to_pad = ptensor_self.ndim - len(flat_inputs)
+        has_device = True
+
+    result_levels = []
+    index_levels = []
+    tensor_insert_point = None
+    to_expand = {}
+    requires_getindex = False
+    for i, inp in enumerate(flat_inputs):
+        if isinstance(inp, Dim) and dims_seen[inp] == 1:
+            flat_inputs[i] = no_slice
+            result_levels.append(inp)
+        elif isinstance(inp, TensorLike):
+            requires_getindex = True
+            if tensor_insert_point is None:
+                tensor_insert_point = len(result_levels)
+            ptensor, levels, _ = _tensor_levels(inp)
+            to_expand[i] = levels
+            flat_inputs[i] = ptensor
+            for l in levels:
+                if l not in index_levels:
+                    index_levels.append(l)
+        else:
+            requires_getindex = True
+            result_levels.append(0)
+
+    if tensor_insert_point is not None:
+        result_levels[tensor_insert_point:tensor_insert_point] = index_levels
+
+    for i, levels in to_expand.items():
+        flat_inputs[i] = _match_levels(flat_inputs[i], levels, index_levels)
+
+    if requires_getindex:
+        result = _orig_getitem(ptensor_self, flat_inputs)
+    else:
+        result = ptensor_self
+
+    next_positional = -1
+    if to_pad > 0:
+        result_levels.extend([0] * to_pad)
+    for i, r in enumerate(reversed(result_levels)):
+        if isinstance(r, int):
+            result_levels[-1 - i] = next_positional
+            next_positional -= 1
+
+    return Tensor.from_positional(result, result_levels, has_device)
+
+
+# XXX - dim is optional and can be the outer-most dimension...
+def stack(tensors, new_dim, dim=0, out=None):
+    if isinstance(dim, int):
+        return torch.stack(tensors, dim, out).index(dim, new_dim)
+    index = None
+    if out is not None:
+        out, index = _positional_no_permute(out, dim, expand_dim=True)
+    ptensors = []
+    for t in tensors:
+        pt, pi = _positional_no_permute(t, dim, expand_dim=True)
+        if index is not None and pi != index:
+            pt = pt.move_dim(pi, index)
+        else:
+            index = pi
+        ptensors.append(pt)
+    pr = torch.stack(ptensors, index, out=out)
+    return pr.index((index, index + 1), (new_dim, dim))
+
+
+_orig_split = torch.Tensor.split
+
+
+def split(self, split_size_or_sections, dim=0):
+    from . import _Tensor, Dim
+
+    if isinstance(split_size_or_sections, int) or any(
+        isinstance(t, int) for t in split_size_or_sections
+    ):
+        if isinstance(dim, Dim):
+            raise ValueError(
+                "when dim is specified as a Dim object, split sizes must also be dimensions."
+            )
+        return _orig_split(self, split_size_or_sections, dim=dim)
+
+    if isinstance(dim, Dim):
+        assert isinstance(self, _Tensor), f"Tensor does not have dimension {dim}"
+        self, dim = _positional_no_permute(self, dim)
+
+    size = self.size(dim)
+    total_bound_size = 0
+    unbound = []
+    sizes = []
+    for i, d in enumerate(split_size_or_sections):
+        if d.is_bound:
+            sizes.append(d.size)
+            total_bound_size += d.size
+        else:
+            sizes.append(0)
+            unbound.append(i)
+
+    if unbound:
+        assert (
+            total_bound_size <= size
+        ), f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})"
+        remaining_size = size - total_bound_size
+        chunk_size = -(-remaining_size // len(unbound))
+        for u in unbound:
+            sz = min(chunk_size, remaining_size)
+            split_size_or_sections[u].size = sz
+            sizes[u] = sz
+            remaining_size -= sz
+    else:
+        assert (
+            total_bound_size == size
+        ), f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})"
+    return tuple(
+        t.index(dim, d)
+        for d, t in zip(split_size_or_sections, _orig_split(self, sizes, dim=dim))
+    )
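A hedged usage sketch of the behaviors these reference implementations back: binding first-class dims via indexing (t__getitem__), a Dim-aware reduction (the _wrap machinery), and conversion back to positional dims (positional, exposed on tensors as .order). This assumes the public functorch.dim API as documented upstream:

import torch
from functorch.dim import dims

i, j = dims(2)           # create two first-class dimensions
x = torch.randn(3, 4)
xi = x[i, j]             # t__getitem__ binds i and j to the positional dims
y = (xi * xi).sum(j)     # pointwise mul, then a reduction over a Dim
print(y.order(i).shape)  # positional()/order: back to a plain (3,) tensor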
lib/python3.11/site-packages/functorch/dim/tree_map.py
ADDED
@@ -0,0 +1,14 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from functorch._C import dim
+
+tree_flatten = dim.tree_flatten
+
+
+def tree_map(fn, tree):
+    vs, unflatten = tree_flatten(tree)
+    return unflatten(fn(v) for v in vs)
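A sketch of what tree_map does, written against the pure-Python torch.utils._pytree equivalents (the tree_flatten used above is a C binding, assumed here to flatten nested lists, tuples, and dicts the same way):

import torch.utils._pytree as pytree

def tree_map_sketch(fn, tree):
    # flatten to leaves, apply fn to each, rebuild the original structure
    leaves, spec = pytree.tree_flatten(tree)
    return pytree.tree_unflatten([fn(v) for v in leaves], spec)

print(tree_map_sketch(lambda v: v * 2, ([1, 2], {"a": 3})))  # ([2, 4], {'a': 6})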
lib/python3.11/site-packages/functorch/dim/wrap_type.py
ADDED
@@ -0,0 +1,71 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from types import (
+    BuiltinMethodType,
+    FunctionType,
+    GetSetDescriptorType,
+    MethodDescriptorType,
+    WrapperDescriptorType,
+)
+
+from functorch._C import dim as _C
+
+_wrap_method = _C._wrap_method
+
+FUNC_TYPES = (
+    FunctionType,
+    MethodDescriptorType,
+    BuiltinMethodType,
+    WrapperDescriptorType,
+)
+PROPERTY_TYPES = (GetSetDescriptorType, property)
+
+
+def _py_wrap_method(orig, __torch_function__):
+    def impl(*args, **kwargs):
+        return __torch_function__(orig, None, args, kwargs)
+
+    return impl
+
+
+def wrap_type(use_c, to_patch, pattern, __torch_function__):
+    if use_c:
+        wrap_method = _wrap_method
+    else:
+        wrap_method = _py_wrap_method
+
+    all = {}
+    for t in reversed(pattern.mro()[:-1]):  # skip object
+        all.update(t.__dict__)
+
+    def wrap_attr(orig):
+        return property(wrap_method(orig.__get__, __torch_function__))
+
+    for name, obj in all.items():
+        if name in (
+            "__dict__",
+            "__new__",
+            "__init__",
+            "__repr__",
+            "__weakref__",
+            "__doc__",
+            "__module__",
+            "__dir__",
+        ):
+            continue
+
+        # skip things that have been overloaded
+        # things that come from object like `__eq__` still need to be patched, however.
+        if hasattr(to_patch, name) and getattr(to_patch, name) is not getattr(
+            object, name, None
+        ):
+            continue
+
+        if isinstance(obj, FUNC_TYPES):
+            setattr(to_patch, name, wrap_method(obj, __torch_function__))
+        elif isinstance(obj, PROPERTY_TYPES):
+            setattr(to_patch, name, wrap_attr(obj))
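The pure-Python path above (_py_wrap_method) is small enough to demo standalone. This sketch substitutes a toy tracing handler for the real __torch_function__ and needs no torch:

def _py_wrap_method(orig, handler):
    # same shape as the diff's _py_wrap_method: defer every call to the handler
    def impl(*args, **kwargs):
        return handler(orig, None, args, kwargs)
    return impl

def trace_handler(orig, cls, args, kwargs):
    print(f"dispatching {orig.__name__}")
    return orig(*args, **kwargs)

wrapped_upper = _py_wrap_method(str.upper, trace_handler)
print(wrapped_upper("dims"))  # prints "dispatching upper", then "DIMS"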
lib/python3.11/site-packages/functorch/einops/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from .rearrange import rearrange
+
+__all__ = ["rearrange"]
lib/python3.11/site-packages/functorch/einops/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (302 Bytes)
lib/python3.11/site-packages/functorch/einops/__pycache__/_parsing.cpython-311.pyc
ADDED
Binary file (14.2 kB)
lib/python3.11/site-packages/functorch/einops/__pycache__/rearrange.cpython-311.pyc
ADDED
Binary file (10.8 kB)
lib/python3.11/site-packages/functorch/einops/_parsing.py
ADDED
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Adapted from https://github.com/arogozhnikov/einops/blob/36c7bb16e57d6e57f8f3050f9e07abdf3f00469f/einops/parsing.py.
|
2 |
+
|
3 |
+
MIT License
|
4 |
+
|
5 |
+
Copyright (c) 2018 Alex Rogozhnikov
|
6 |
+
|
7 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
8 |
+
of this software and associated documentation files (the "Software"), to deal
|
9 |
+
in the Software without restriction, including without limitation the rights
|
10 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
11 |
+
copies of the Software, and to permit persons to whom the Software is
|
12 |
+
furnished to do so, subject to the following conditions:
|
13 |
+
|
14 |
+
The above copyright notice and this permission notice shall be included in all
|
15 |
+
copies or substantial portions of the Software.
|
16 |
+
|
17 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
18 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
19 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
20 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
21 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
22 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
23 |
+
SOFTWARE.
|
24 |
+
"""
|
25 |
+
from __future__ import annotations
|
26 |
+
|
27 |
+
import keyword
|
28 |
+
import warnings
|
29 |
+
from typing import Collection, List, Mapping, Optional, Set, Tuple, Union
|
30 |
+
|
31 |
+
_ellipsis: str = "…" # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated
|
32 |
+
|
33 |
+
|
34 |
+
class AnonymousAxis:
|
35 |
+
"""Used by `ParsedExpression` to represent an axis with a size (> 1), but no associated identifier.
|
36 |
+
|
37 |
+
Note: Different instances of this class are not equal to each other, even if they have the same value.
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(self, value: str) -> None:
|
41 |
+
self.value = int(value)
|
42 |
+
if self.value < 1:
|
43 |
+
raise ValueError(
|
44 |
+
f"Anonymous axis should have positive length, not {self.value}"
|
45 |
+
)
|
46 |
+
|
47 |
+
def __repr__(self) -> str:
|
48 |
+
return f"{self.value}-axis"
|
49 |
+
|
50 |
+
|
51 |
+
class ParsedExpression:
|
52 |
+
"""Structure containing information about one side of an `einops`-style pattern (e.g. 'b c (h w)')."""
|
53 |
+
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
expression: str,
|
57 |
+
*,
|
58 |
+
allow_underscore: bool = False,
|
59 |
+
allow_duplicates: bool = False,
|
60 |
+
) -> None:
|
61 |
+
"""Parse the expression and store relevant metadata.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
expression (str): the `einops`-pattern to parse
|
65 |
+
allow_underscore (bool): whether to allow axis identifier names to begin with an underscore
|
66 |
+
allow_duplicates (bool): whether to allow an identifier to appear more than once in the expression
|
67 |
+
"""
|
68 |
+
self.has_ellipsis: bool = False
|
69 |
+
self.has_ellipsis_parenthesized: Optional[bool] = None
|
70 |
+
self.identifiers: Set[Union[str, AnonymousAxis]] = set()
|
71 |
+
# that's axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with empty composition
|
72 |
+
self.has_non_unitary_anonymous_axes: bool = False
|
73 |
+
# composition keeps structure of composite axes, see how different corner cases are handled in tests
|
74 |
+
self.composition: List[Union[List[Union[str, AnonymousAxis]], str]] = []
|
75 |
+
if "." in expression:
|
76 |
+
if "..." not in expression:
|
77 |
+
raise ValueError(
|
78 |
+
"Expression may contain dots only inside ellipsis (...)"
|
79 |
+
)
|
80 |
+
if str.count(expression, "...") != 1 or str.count(expression, ".") != 3:
|
81 |
+
raise ValueError(
|
82 |
+
"Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor "
|
83 |
+
)
|
84 |
+
expression = expression.replace("...", _ellipsis)
|
85 |
+
self.has_ellipsis = True
|
86 |
+
|
87 |
+
bracket_group: Optional[List[Union[str, AnonymousAxis]]] = None
|
88 |
+
|
89 |
+
def add_axis_name(x: str) -> None:
|
90 |
+
if x in self.identifiers:
|
91 |
+
if not (allow_underscore and x == "_") and not allow_duplicates:
|
92 |
+
raise ValueError(
|
93 |
+
f"Indexing expression contains duplicate dimension '{x}'"
|
94 |
+
)
|
95 |
+
if x == _ellipsis:
|
96 |
+
self.identifiers.add(_ellipsis)
|
97 |
+
if bracket_group is None:
|
98 |
+
self.composition.append(_ellipsis)
|
99 |
+
self.has_ellipsis_parenthesized = False
|
100 |
+
else:
|
101 |
+
bracket_group.append(_ellipsis)
|
102 |
+
self.has_ellipsis_parenthesized = True
|
103 |
+
else:
|
104 |
+
is_number = str.isdecimal(x)
|
105 |
+
if is_number and int(x) == 1:
|
106 |
+
# handling the case of anonymous axis of length 1
|
107 |
+
if bracket_group is None:
|
108 |
+
self.composition.append([])
|
109 |
+
else:
|
110 |
+
pass # no need to think about 1s inside parenthesis
|
111 |
+
return
|
112 |
+
is_axis_name, reason = self.check_axis_name_return_reason(
|
113 |
+
x, allow_underscore=allow_underscore
|
114 |
+
)
|
115 |
+
if not (is_number or is_axis_name):
|
116 |
+
raise ValueError(f"Invalid axis identifier: {x}\n{reason}")
|
117 |
+
axis_name: Union[str, AnonymousAxis] = (
|
118 |
+
AnonymousAxis(x) if is_number else x
|
119 |
+
)
|
120 |
+
self.identifiers.add(axis_name)
|
121 |
+
if is_number:
|
122 |
+
self.has_non_unitary_anonymous_axes = True
|
123 |
+
if bracket_group is None:
|
124 |
+
self.composition.append([axis_name])
|
125 |
+
else:
|
126 |
+
bracket_group.append(axis_name)
|
127 |
+
|
128 |
+
current_identifier = None
|
129 |
+
for char in expression:
|
130 |
+
if char in "() ":
|
131 |
+
if current_identifier is not None:
|
132 |
+
add_axis_name(current_identifier)
|
133 |
+
current_identifier = None
|
134 |
+
if char == "(":
|
135 |
+
if bracket_group is not None:
|
136 |
+
raise ValueError(
|
137 |
+
"Axis composition is one-level (brackets inside brackets not allowed)"
|
138 |
+
)
|
139 |
+
bracket_group = []
|
140 |
+
elif char == ")":
|
141 |
+
if bracket_group is None:
|
142 |
+
raise ValueError("Brackets are not balanced")
|
143 |
+
self.composition.append(bracket_group)
|
144 |
+
bracket_group = None
|
145 |
+
elif str.isalnum(char) or char in ["_", _ellipsis]:
|
146 |
+
if current_identifier is None:
|
147 |
+
current_identifier = char
|
148 |
+
else:
|
149 |
+
current_identifier += char
|
150 |
+
else:
|
151 |
+
raise ValueError(f"Unknown character '{char}'")
|
152 |
+
|
153 |
+
if bracket_group is not None:
|
154 |
+
raise ValueError(f"Imbalanced parentheses in expression: '{expression}'")
|
155 |
+
if current_identifier is not None:
|
156 |
+
add_axis_name(current_identifier)
|
157 |
+
|
158 |
+
@staticmethod
|
159 |
+
def check_axis_name_return_reason(
|
160 |
+
name: str, allow_underscore: bool = False
|
161 |
+
) -> Tuple[bool, str]:
|
162 |
+
"""Check if the given axis name is valid, and a message explaining why if not.
|
163 |
+
|
164 |
+
Valid axes names are python identifiers except keywords, and should not start or end with an underscore.
|
165 |
+
|
166 |
+
Args:
|
167 |
+
name (str): the axis name to check
|
168 |
+
allow_underscore (bool): whether axis names are allowed to start with an underscore
|
169 |
+
|
170 |
+
Returns:
|
171 |
+
Tuple[bool, str]: whether the axis name is valid, a message explaining why if not
|
172 |
+
"""
|
173 |
+
if not str.isidentifier(name):
|
174 |
+
return False, "not a valid python identifier"
|
175 |
+
elif name[0] == "_" or name[-1] == "_":
|
176 |
+
if name == "_" and allow_underscore:
|
177 |
+
return True, ""
|
178 |
+
return False, "axis name should should not start or end with underscore"
|
179 |
+
else:
|
180 |
+
if keyword.iskeyword(name):
|
181 |
+
warnings.warn(
|
182 |
+
f"It is discouraged to use axes names that are keywords: {name}",
|
183 |
+
RuntimeWarning,
|
184 |
+
)
|
185 |
+
if name in ["axis"]:
|
186 |
+
warnings.warn(
|
187 |
+
"It is discouraged to use 'axis' as an axis name and will raise an error in future",
|
188 |
+
FutureWarning,
|
189 |
+
)
|
190 |
+
return True, ""
|
191 |
+
|
192 |
+
@staticmethod
|
193 |
+
def check_axis_name(name: str) -> bool:
|
194 |
+
"""Check if the name is a valid axis name.
|
195 |
+
|
196 |
+
Args:
|
197 |
+
name (str): the axis name to check
|
198 |
+
|
199 |
+
Returns:
|
200 |
+
bool: whether the axis name is valid
|
201 |
+
"""
|
202 |
+
is_valid, _ = ParsedExpression.check_axis_name_return_reason(name)
|
203 |
+
return is_valid
|
204 |
+
|
205 |
+
|
206 |
+
def parse_pattern(
|
207 |
+
pattern: str, axes_lengths: Mapping[str, int]
|
208 |
+
) -> Tuple[ParsedExpression, ParsedExpression]:
|
209 |
+
"""Parse an `einops`-style pattern into a left-hand side and right-hand side `ParsedExpression` object.
|
210 |
+
|
211 |
+
Args:
|
212 |
+
pattern (str): the `einops`-style rearrangement pattern
|
213 |
+
axes_lengths (Mapping[str, int]): any additional length specifications for dimensions
|
214 |
+
|
215 |
+
Returns:
|
216 |
+
        Tuple[ParsedExpression, ParsedExpression]: a tuple containing the left-hand side and right-hand side expressions
    """
    # adapted from einops.einops._prepare_transformation_recipe
    # https://github.com/arogozhnikov/einops/blob/230ac1526c1f42c9e1f7373912c7f8047496df11/einops/einops.py
    try:
        left_str, right_str = pattern.split("->")
    except ValueError:
        raise ValueError("Pattern must contain a single '->' separator") from None

    if _ellipsis in axes_lengths:
        raise ValueError(f"'{_ellipsis}' is not an allowed axis identifier")

    left = ParsedExpression(left_str)
    right = ParsedExpression(right_str)

    if not left.has_ellipsis and right.has_ellipsis:
        raise ValueError(
            f"Ellipsis found in right side, but not left side of a pattern {pattern}"
        )
    if left.has_ellipsis and left.has_ellipsis_parenthesized:
        raise ValueError(
            f"Ellipsis inside parenthesis in the left side is not allowed: {pattern}"
        )

    return left, right


def validate_rearrange_expressions(
    left: ParsedExpression, right: ParsedExpression, axes_lengths: Mapping[str, int]
) -> None:
    """Perform expression validations that are specific to the `rearrange` operation.

    Args:
        left (ParsedExpression): left-hand side expression
        right (ParsedExpression): right-hand side expression
        axes_lengths (Mapping[str, int]): any additional length specifications for dimensions
    """
    for length in axes_lengths.values():
        if (length_type := type(length)) is not int:
            raise TypeError(
                f"rearrange axis lengths must be integers, got: {length_type}"
            )

    if left.has_non_unitary_anonymous_axes or right.has_non_unitary_anonymous_axes:
        raise ValueError("rearrange only supports unnamed axes of size 1")

    difference = set.symmetric_difference(left.identifiers, right.identifiers)
    if len(difference) > 0:
        raise ValueError(
            f"Identifiers only on one side of rearrange expression (should be on both): {difference}"
        )

    unmatched_axes = axes_lengths.keys() - left.identifiers
    if len(unmatched_axes) > 0:
        raise ValueError(
            f"Identifiers not found in rearrange expression: {unmatched_axes}"
        )


def comma_separate(collection: Collection[Union[str, Collection[str]]]) -> str:
    """Convert a collection of strings representing first class dims into a comma-separated string.

    Args:
        collection (Collection[Union[str, Collection[str]]]): the collection of strings to convert

    Returns:
        str: the comma-separated string

    Examples:
        >>> comma_separate(('d0',))
        'd0'

        >>> comma_separate(('d0', 'd1', 'd2', 'd3'))
        'd0, d1, d2, d3'

        >>> comma_separate([('d1', 'd4')])
        '(d1, d4)'

        >>> comma_separate([('d0',), (), ('d1',), ('d2',), ('d3', 'd4')])
        '(d0,), (), (d1,), (d2,), (d3, d4)'
    """
    return ", ".join(
        item
        if isinstance(item, str)
        else f"({comma_separate(item)}{',' if len(item) == 1 else ''})"
        for item in collection
    )
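Taken together, these helpers form the parsing front end for the first-class-dims `rearrange` defined in the next file. A minimal sketch of how they compose, assuming this module is importable as `functorch.einops._parsing`; the pattern and axis lengths below are illustrative values, not taken from the original file:

# Sketch: parse and validate an einops-style pattern, then render dim names.
from functorch.einops._parsing import (
    comma_separate,
    parse_pattern,
    validate_rearrange_expressions,
)

left, right = parse_pattern("b (h w) c -> b h w c", {"h": 2})
validate_rearrange_expressions(left, right, {"h": 2})

# comma_separate turns dim-name tuples into the index/order strings used by codegen
print(comma_separate(("d0", "d1")))             # 'd0, d1'
print(comma_separate([("d0",), ("d1", "d2")]))  # '(d0,), (d1, d2)'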
lib/python3.11/site-packages/functorch/einops/rearrange.py
ADDED
@@ -0,0 +1,207 @@
from __future__ import annotations

import functools
from typing import Callable, Dict, List, Sequence, Tuple, Union

import torch

from functorch._C import dim as _C
from ._parsing import (
    _ellipsis,
    AnonymousAxis,
    comma_separate,
    parse_pattern,
    validate_rearrange_expressions,
)

__all__ = ["rearrange"]

dims = _C.dims


@functools.lru_cache(256)
def _create_rearrange_callable(
    tensor_ndim: int, pattern: str, **axes_lengths: int
) -> Callable[[torch.Tensor], torch.Tensor]:
    r"""Translate an `einops`-style pattern into a callable that performs the rearrange using first-class dimensions.

    Since an equivalent result is computed for tensors with the same number of dimensions, with the same pattern and
    specified axes lengths, this function can be memoized.

    Args:
        tensor_ndim (int): the number of dimensions in the tensor to rearrange
        pattern (str): the `einops`-style rearrangement pattern
        axes_lengths (int): any additional length specifications for dimensions

    Returns:
        Callable[[torch.Tensor], torch.Tensor]: a callable that performs the rearrangement
    """
    left, right = parse_pattern(pattern, axes_lengths)
    validate_rearrange_expressions(left, right, axes_lengths)

    n_anon_dims = sum(not dim for dim in left.composition)
    if left.has_ellipsis:
        n_ellipsis_dims = tensor_ndim - (len(left.composition) - 1)
        n_named_dims = len(left.identifiers) - 1

        if (pattern_ndim := n_anon_dims + n_named_dims) > tensor_ndim:
            raise ValueError(
                f"Number of dimensions in pattern ({pattern_ndim}) must be less than or equal to the number of "
                f"dimensions in the tensor ({tensor_ndim})"
            )
    else:
        n_ellipsis_dims = 0
        n_named_dims = len(left.identifiers)

        if (pattern_ndim := len(left.composition)) != tensor_ndim:
            raise ValueError(
                f"Number of dimensions in pattern ({pattern_ndim}) must be equal to the number of dimensions in "
                f"the tensor ({tensor_ndim})"
            )
    n_dims = n_named_dims + n_ellipsis_dims + n_anon_dims

    if n_dims == 0:
        # an identity rearrangement on a 0-dimension tensor
        return lambda tensor: tensor

    first_class_dims: Tuple[str, ...] = tuple(f"d{i}" for i in range(n_dims))
    identifier_dim_map: Dict[Union[str, AnonymousAxis], Tuple[str, ...]] = {}
    anon_axes: List[AnonymousAxis] = []

    # map the left-hand side identifiers to strings representing first class dims
    dims_i = 0
    for dimension in left.composition:
        if isinstance(dimension, list):
            for identifier in dimension:
                # non-unitary anon axes are not allowed in rearrange & unitary anon axes are represented as empty lists
                assert isinstance(identifier, str)
                identifier_dim_map[identifier] = (first_class_dims[dims_i],)
                dims_i += 1
            if not dimension:
                # unitary anonymous axis
                anon_axis = AnonymousAxis("1")
                identifier_dim_map[anon_axis] = (first_class_dims[dims_i],)
                anon_axes.append(anon_axis)
                dimension.append(anon_axis)
                dims_i += 1
        elif dimension == _ellipsis:
            identifier = _ellipsis
            identifier_dim_map[identifier] = tuple(
                first_class_dims[dims_i + j] for j in range(n_ellipsis_dims)
            )
            dims_i += n_ellipsis_dims
        else:
            raise ValueError(f"Unexpected dimension: {dimension}")

    def composition_to_dims(
        composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]]
    ) -> List[Union[str, Tuple[str, ...]]]:
        """Convert a `ParsedExpression.composition` into a `Tensor.__getitem__` index of strings representing first
        class dims."""
        dim_composition: List[Union[str, Tuple[str, ...]]] = []
        for dimension in composition:
            if isinstance(dimension, list):
                dim_composition.append(
                    tuple(
                        dim
                        for identifier in dimension
                        for dim in identifier_dim_map[identifier]
                    )
                )
            elif dimension == _ellipsis:
                dim_composition.extend(identifier_dim_map[_ellipsis])
            else:
                raise ValueError(f"Unexpected dimension: {dimension}")
        return dim_composition

    left_dims = composition_to_dims(left.composition)
    right_dims = composition_to_dims(right.composition)
    anon_dims = tuple(identifier_dim_map[axis][0] for axis in anon_axes)
    specified_lengths = tuple(
        (identifier_dim_map[axis][0], length) for axis, length in axes_lengths.items()
    )

    custom_rearrange_callable_name = "do_rearrange"
    custom_rearrange_callable_code = (
        (
            f"def {custom_rearrange_callable_name}(tensor):\n"
            f"    {comma_separate(first_class_dims)} = dims({n_dims})\n"
        )
        + (
            "".join(
                f"    {dim}.size = {length}\n" for (dim, length) in specified_lengths
            )
            if specified_lengths
            else ""
        )
        + f"    tensor = tensor[{comma_separate(left_dims)}].order({comma_separate(right_dims)})\n"
        + (
            f"    return tensor.sum({comma_separate([anon_dims])}, keepdim=False)\n"
            if anon_dims
            else "    return tensor\n"
        )
    )

    exec(custom_rearrange_callable_code)
    return locals()[custom_rearrange_callable_name]


def rearrange(
    tensor: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]],
    pattern: str,
    **axes_lengths: int,
) -> torch.Tensor:
    r"""A native implementation of `einops.rearrange`, a reader-friendly smart element reordering for multidimensional
    tensors. This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze,
    stack, concatenate and other operations.

    See: https://einops.rocks/api/rearrange/

    Args:
        tensor (Tensor or sequence of Tensor): the tensor(s) to rearrange
        pattern (str): the rearrangement pattern
        axes_lengths (int): any additional length specifications for dimensions

    Returns:
        Tensor: the rearranged tensor

    Examples:
        >>> # suppose we have a set of 32 images in "h w c" format (height-width-channel)
        >>> images = torch.randn((32, 30, 40, 3))

        >>> # stack along first (batch) axis, output is a single array
        >>> rearrange(images, 'b h w c -> b h w c').shape
        torch.Size([32, 30, 40, 3])

        >>> # concatenate images along height (vertical axis), 960 = 32 * 30
        >>> rearrange(images, 'b h w c -> (b h) w c').shape
        torch.Size([960, 40, 3])

        >>> # concatenate images along horizontal axis, 1280 = 32 * 40
        >>> rearrange(images, 'b h w c -> h (b w) c').shape
        torch.Size([30, 1280, 3])

        >>> # reorder axes to "b c h w" format for deep learning
        >>> rearrange(images, 'b h w c -> b c h w').shape
        torch.Size([32, 3, 30, 40])

        >>> # flatten each image into a vector, 3600 = 30 * 40 * 3
        >>> rearrange(images, 'b h w c -> b (c h w)').shape
        torch.Size([32, 3600])

        >>> # split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2
        >>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape
        torch.Size([128, 15, 20, 3])

        >>> # space-to-depth operation
        >>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape
        torch.Size([32, 15, 20, 12])
    """
    if not isinstance(tensor, torch.Tensor):
        tensor = torch.stack(tensor)

    rearrange_callable = _create_rearrange_callable(
        tensor.ndim, pattern, **axes_lengths
    )

    return rearrange_callable(tensor)
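For intuition, here is roughly the source `_create_rearrange_callable` generates for `rearrange(x, 'b (h w) -> b h w', h=2)` on a 2-D tensor; this is a sketch inferred from the string-building logic above, not output captured from a real run:

# Sketch of the generated function (dim numbering follows first_class_dims):
def do_rearrange(tensor):
    d0, d1, d2 = dims(3)     # first-class dims for b, h, w
    d1.size = 2              # the specified axis length h=2
    # bind left-side structure via __getitem__, then lay out the right side
    tensor = tensor[(d0,), (d1, d2)].order((d0,), (d1,), (d2,))
    return tensor

The generated callable binds first-class dims on the left-hand side via indexing and permutes/reshapes with `.order`; `functools.lru_cache` makes repeat calls with the same ndim, pattern, and axis lengths cheap.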
lib/python3.11/site-packages/functorch/experimental/__init__.py
ADDED
@@ -0,0 +1,6 @@
# PyTorch forward-mode is not mature yet
from torch._functorch.apis import chunk_vmap
from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
from torch._functorch.eager_transforms import hessian, jacfwd, jvp

from functorch import functionalize
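The re-exported transforms follow the standard functorch eager-transform signatures. A small hedged sketch (shapes and values are illustrative assumptions):

import torch
from functorch.experimental import hessian, jvp

f = lambda x: (x ** 3).sum()
x = torch.randn(3)

# forward-mode JVP: returns (f(x), J_f(x) @ v) for tangent v
out, tangent = jvp(f, (x,), (torch.ones_like(x),))

H = hessian(f)(x)  # 3x3 Hessian of the scalar-valued f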
lib/python3.11/site-packages/functorch/experimental/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (605 Bytes).

lib/python3.11/site-packages/functorch/experimental/__pycache__/_map.cpython-311.pyc
ADDED
Binary file (23.8 kB).

lib/python3.11/site-packages/functorch/experimental/__pycache__/control_flow.cpython-311.pyc
ADDED
Binary file (433 Bytes).

lib/python3.11/site-packages/functorch/experimental/__pycache__/ops.cpython-311.pyc
ADDED
Binary file (300 Bytes).
lib/python3.11/site-packages/functorch/experimental/_map.py
ADDED
@@ -0,0 +1,393 @@
import torch
import torch.utils._pytree as pytree
from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet
from torch._dispatch.python import suspend_functionalization
from torch._functorch.aot_autograd import AOTConfig, create_joint
from torch._functorch.eager_transforms import (
    _unwrap_all_tensors_from_functional,
    _wrap_all_tensors_to_functional,
    functionalize,
)

from torch._higher_order_ops.cond import (
    _has_potential_branch_input_alias,
    _has_potential_branch_input_mutation,
    UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    make_fx,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._python_dispatch import (
    _get_current_dispatch_mode,
    _pop_mode_temporarily,
)


# TODO: We add this to prevent dynamo from tracing into map_wrapper,
# remove the wrapper call when it's ready.
class MapWrapper(HigherOrderOperator):
    def __call__(self, xs, *args):
        return map_wrapper(xs, *args)


map = MapWrapper("map", _deprecated_global_ns=True)
map_impl = HigherOrderOperator("map_impl", _deprecated_global_ns=True)

dummy_aot_config = AOTConfig(
    fw_compiler=None,
    bw_compiler=None,
    partition_fn=None,
    decompositions={},
    num_params_buffers=0,
    aot_id=0,
    keep_inference_input_mutations=False,
)


def create_fw_bw_graph(f, num_mapped_args, *args):
    mapped_xs = args[:num_mapped_args]
    pos_args = args[num_mapped_args:]

    # Note: We create "clean" environments for make_fx by suspending all dispatch keys
    # between Autograd and Python key. Currently, we only suspend functionalization but more can be
    # added when required. We will encounter two problems if we don't suspend functionalization:
    #
    # 1. make_fx fails to capture operations on input: the inputs are wrapped as _to_functional_tensor_wrapper,
    # but they will be unwrapped before entering ProxyTorchDispatchMode as part of the dispatching.
    # However, it's the outside wrapper that the tracer creates proxies for. This causes the tracer to fail to
    # fetch the proxy for the inputs and fail to capture any operations on them.
    #
    # 2. make_fx fails to capture output: the outputs after ProxyTorchDispatchMode are further
    # wrapped as FunctionalTensorWrapper in the Functionalize key after return. However, the tracer
    # only associates the inner tensor with a proxy in ProxyTorchDispatchMode. Therefore,
    # when creating the output node, it fails to associate the wrapped tensor with its proxy.
    # Instead, it will create _tensor_constant as output.

    with suspend_functionalization():
        with disable_proxy_modes_tracing():

            def from_fun(t):
                if isinstance(t, torch.Tensor):
                    if t.dtype != torch.bool:
                        return torch.empty_strided(
                            t.size(),
                            t.stride(),
                            dtype=t.dtype,
                            requires_grad=t.requires_grad,
                        )
                    else:
                        return t.clone()
                return t

            example_xs = [from_fun(xs) for xs in _unstack_pytree(mapped_xs)[0]]
            example_pos_args = [
                from_fun(arg) if isinstance(arg, torch.Tensor) else arg
                for arg in pos_args
            ]
            example_flat_out = pytree.tree_map(
                from_fun, f(*example_xs, *example_pos_args)
            )
            if any(
                not isinstance(out, torch.Tensor)
                for out in example_flat_out
                if out is not None
            ):
                raise RuntimeError(
                    "Expect outputs of map to only contain tensors or None. "
                    f"Got types {[type(out) for out in example_flat_out]}."
                )
            example_grad = [from_fun(out) for out in example_flat_out]

        fw_graph = make_fx(f)(*example_xs, *example_pos_args)

        def joint_f(*example_args):
            joint_mapped_args = example_args[:joint_num_mapped]
            args = example_args[joint_num_mapped:]

            mapped_input = joint_mapped_args[:num_mapped_args]
            mapped_grads = joint_mapped_args[num_mapped_args:]

            def fw_with_masks(*args):
                fw_out = f(*args)
                return fw_out, [
                    True
                    if isinstance(ret, torch.Tensor) and ret.requires_grad
                    else False
                    for ret in fw_out
                ]

            joint = create_joint(fw_with_masks, aot_config=dummy_aot_config)
            _, grads = joint(
                list(mapped_input) + list(args),
                [
                    grad
                    for grad in mapped_grads
                    if grad is not None and grad.requires_grad
                ],
            )

            # In order to keep map functional for the backward graph,
            # we clone outputs that are aliasing inputs
            input_storage = {
                StorageWeakRef(arg._typed_storage())
                for arg in example_args
                if isinstance(arg, torch.Tensor)
            }

            def maybe_clone(t):
                if (
                    isinstance(t, torch.Tensor)
                    and StorageWeakRef(t._typed_storage()) in input_storage
                ):
                    return t.clone()
                return t

            return pytree.tree_map(maybe_clone, grads)

        joint_num_mapped = len(example_grad) + len(example_xs)
        joint_graph = make_fx(joint_f)(*example_xs, *example_grad, *example_pos_args)
        return fw_graph, joint_graph


def map_wrapper(f, xs, *args):
    flat_xs, xs_spec = pytree.tree_flatten(xs)
    if not all(isinstance(t, torch.Tensor) for t in flat_xs):
        raise RuntimeError(f"Mapped xs can only consist of tensors. Got xs {flat_xs}.")

    num_mapped_args = len(flat_xs)
    shapes = [xs.shape for xs in flat_xs]
    leading_dim_size = shapes[0][0]
    if leading_dim_size == 0:
        raise RuntimeError("Leading dimensions of mapped xs cannot be 0.")

    if any(cur_shape[0] != leading_dim_size for cur_shape in shapes):
        raise RuntimeError(
            f"Leading dimensions of mapped xs must be consistent. Got shapes {shapes}."
        )

    out_spec = None

    def flat_fn(*flat_args):
        xs = pytree.tree_unflatten(flat_args[:num_mapped_args], xs_spec)
        unflattened_out = f(xs, *flat_args[num_mapped_args:])
        flat_out, tmp_out_spec = pytree.tree_flatten(unflattened_out)

        nonlocal out_spec
        out_spec = tmp_out_spec
        return flat_out

    return pytree.tree_unflatten(
        map_impl(flat_fn, num_mapped_args, *flat_xs, *args), out_spec
    )


class MapAutogradOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, fw_graph, joint_graph, num_mapped_args, *flat_args):
        ctx.save_for_backward(*flat_args)
        ctx._joint_graph = joint_graph
        ctx._num_mapped_args = num_mapped_args
        with torch._C._AutoDispatchBelowAutograd():
            return (*map_impl(fw_graph, num_mapped_args, *flat_args),)

    @staticmethod
    def backward(ctx, *flat_grads):
        fw_args = ctx.saved_tensors
        fw_mapped_args = fw_args[: ctx._num_mapped_args]
        pos_args = fw_args[ctx._num_mapped_args :]

        grads = map_impl(
            ctx._joint_graph,
            ctx._num_mapped_args + len(flat_grads),
            *fw_mapped_args,
            *flat_grads,
            *pos_args,
        )
        return None, None, None, *grads


def trace_map(proxy_mode, func_overload, f, num_mapped, *args):
    xs = list(args[:num_mapped])
    pos_args = list(args[num_mapped:])
    leading_dim_size = xs[0].shape[0]

    example_input = _unstack_pytree(xs)[0]
    body_graph = f
    if not isinstance(body_graph, torch.fx.GraphModule):
        body_graph = make_fx(body_graph)(*example_input, *pos_args)

    with disable_proxy_modes_tracing():
        example_outs = body_graph(*example_input, *pos_args)

        def expand_tensor(t):
            if isinstance(t, torch.Tensor):
                return t.expand(leading_dim_size, *t.shape)
            return t

        expanded_outs = pytree.tree_map(expand_tensor, example_outs)

    next_name = None
    i = 0
    while not next_name:
        candidate = f"body_graph_{i}"
        if hasattr(proxy_mode.tracer.root, candidate):
            i += 1
        else:
            next_name = candidate

    proxy_mode.tracer.root.register_module(next_name, body_graph)
    node_args = (body_graph, num_mapped, *args)
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", func_overload, proxy_args, {}, name="map_impl"
    )
    return track_tensor_tree(
        expanded_outs, out_proxy, constant=None, tracer=proxy_mode.tracer
    )


def _unstack_pytree(xs):
    flat_xs, inspec = pytree.tree_flatten(xs)
    if not all(isinstance(xs, torch.Tensor) for xs in flat_xs):
        raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")

    if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs):
        raise RuntimeError(
            f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}"
        )

    a = zip(*flat_xs)
    pytrees = []
    for tup in a:
        pytrees.append(pytree.tree_unflatten(tup, inspec))
    return pytrees


def _stack_pytree(pytrees):
    flat_out = []
    out_spec = None
    for pt in pytrees:
        flat_pt, out_spec = pytree.tree_flatten(pt)
        flat_out.append(flat_pt)
    b = zip(*flat_out)
    stacked_out = []
    for leaves in b:
        if all(isinstance(leaf, torch.Tensor) for leaf in leaves):
            stacked_out.append(torch.stack(leaves))
        elif all(leaf is None for leaf in leaves):
            # The backward graph can return None outputs when forward inputs don't require grad.
            # When we eagerly execute the backward graph, we need to call _stack_pytree on its output,
            # therefore we need to deal with None output.
            stacked_out.append(None)
        else:
            raise RuntimeError(f"Cannot stack {leaves}.")
    return pytree.tree_unflatten(stacked_out, out_spec)


@map_impl.py_impl(DispatchKey.CompositeExplicitAutograd)
def map_dense(f, num_mapped_args, *args):
    xs = args[:num_mapped_args]
    pos_args = args[num_mapped_args:]
    pytrees = []
    for inp in _unstack_pytree(xs):
        pytrees.append(f(*inp, *pos_args))
    return _stack_pytree(pytrees)


@map_impl.py_impl(DispatchKey.Autograd)
def map_autograd(f, num_mapped_args, *args):
    fw_graph, bw_graph = create_fw_bw_graph(f, num_mapped_args, *args)
    flat_out = MapAutogradOp.apply(fw_graph, bw_graph, num_mapped_args, *args)
    return flat_out


@map_impl.py_impl(ProxyTorchDispatchMode)
def map_proxy_torch_dispatch_mode(f, num_mapped, *args):
    mode = _get_current_dispatch_mode()
    assert mode is not None, "Mode should always be enabled for python fallback key"
    with _pop_mode_temporarily() as mode:
        if mode.enable_tracing:
            return trace_map(mode, map_impl, f, num_mapped, *args)
        else:
            return map_impl(f, num_mapped, *args)


@map_impl.py_impl(FakeTensorMode)
def map_fake_tensor_mode(f, num_mapped, *args):
    return map_dense(f, num_mapped, *args)


@map_impl.py_impl(DispatchKey.Functionalize)
def map_func(f, num_mapped, *args):
    reapply_views = torch._C._functionalization_reapply_views_tls()
    xs = args[:num_mapped]
    pos_args = args[num_mapped:]
    unwrapped_xs = _unwrap_all_tensors_from_functional(xs, reapply_views=reapply_views)
    unwrapped_args = _unwrap_all_tensors_from_functional(
        pos_args, reapply_views=reapply_views
    )
    mode = "mutations_and_views" if reapply_views else "mutations"

    with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
        functional_map_fn = functionalize(f, remove=mode)
        with disable_proxy_modes_tracing():
            example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)

        if _has_potential_branch_input_mutation(f, example_inputs):
            raise UnsupportedAliasMutationException("torch.map is mutating the input!")

        if _has_potential_branch_input_alias(f, example_inputs):
            raise UnsupportedAliasMutationException("torch.map is aliasing the input!")

        map_return = map_impl(
            functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args
        )
        return _wrap_all_tensors_to_functional(map_return, level=0)


@map_impl.py_impl(torch._C._functorch.TransformType.Functionalize)
def map_functionalize(interpreter, f, num_mapped, *args):
    """
    Functionalization implementation for torch.map. Currently:
    1. We don't allow any input mutation inside the map function
    2. Our check for the above condition is not exhaustive
    """
    xs = args[:num_mapped]
    pos_args = args[num_mapped:]
    reapply_views = interpreter.functionalize_add_back_views()
    mode = "mutations_and_views" if reapply_views else "mutations"
    # At this point, we will see functionalized tensors, so we need to unwrap them first
    unwrapped_xs = _unwrap_all_tensors_from_functional(xs, reapply_views=reapply_views)
    unwrapped_args = _unwrap_all_tensors_from_functional(
        pos_args, reapply_views=reapply_views
    )

    functional_map_fn = functionalize(f, remove=mode)

    with interpreter.lower():
        with disable_proxy_modes_tracing():
            example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
        if _has_potential_branch_input_mutation(f, example_inputs):
            raise UnsupportedAliasMutationException("torch.map is mutating the input!")

        if _has_potential_branch_input_alias(f, example_inputs):
            raise UnsupportedAliasMutationException("torch.map is aliasing the input!")

        map_return = map_impl(
            functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args
        )
        return _wrap_all_tensors_to_functional(map_return, level=interpreter.level())


# TODO(voz) Make this automatic for keys, this is very ugly atm
map_impl.fallthrough(DispatchKey.PythonDispatcher)
map_impl.fallthrough(DispatchKey.PythonTLSSnapshot)
map_impl.fallthrough(DispatchKey.ADInplaceOrView)
map_impl.fallthrough(DispatchKey.BackendSelect)
map_impl.fallthrough(DispatchKey.AutocastCPU)
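A minimal sketch of calling the `map` higher-order op defined above: the body function is applied to each slice along the leading dimension of `xs`, with extra positional args passed through unchanged. Shapes here are illustrative assumptions:

import torch
from functorch.experimental.control_flow import map

def body(x, y):
    return x + y

xs = torch.randn(4, 3)  # mapped over the leading dim of size 4
y = torch.randn(3)      # positional arg shared by every slice

out = map(body, xs, y)  # behaves like torch.stack([body(x, y) for x in xs])
print(out.shape)        # torch.Size([4, 3])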
lib/python3.11/site-packages/functorch/experimental/control_flow.py
ADDED
@@ -0,0 +1,6 @@
from torch._higher_order_ops.cond import (  # noqa: F401
    cond,
    UnsupportedAliasMutationException,
)

from ._map import map  # noqa: F401
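The re-exported `cond` follows the `torch._higher_order_ops.cond` calling convention, `cond(pred, true_fn, false_fn, operands)`. A small hedged sketch (tensor values are illustrative):

import torch
from functorch.experimental.control_flow import cond

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

x = torch.randn(4)
out = cond(x.sum() > 0, true_fn, false_fn, [x])  # selects a branch at runtime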
lib/python3.11/site-packages/functorch/experimental/ops.py
ADDED
@@ -0,0 +1 @@
from torch._ops import HigherOrderOperator  # noqa: F401
lib/python3.11/site-packages/huggingface_hub/__init__.py
ADDED
@@ -0,0 +1,650 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ***********
# `huggingface_hub` init has 2 modes:
# - Normal usage:
#       If imported to use it, all modules and functions are lazy-loaded. This means
#       they exist at top level in module but are imported only the first time they are
#       used. This way, `from huggingface_hub import something` will import `something`
#       quickly without the hassle of importing all the features from `huggingface_hub`.
# - Static check:
#       If statically analyzed, all modules and functions are loaded normally. This way
#       static typing check works properly as well as autocomplete in text editors and
#       IDEs.
#
# The static model imports are done inside the `if TYPE_CHECKING:` statement at
# the bottom of this file. Since module/functions imports are duplicated, it is
# mandatory to make sure to add them twice when adding one. This is checked in the
# `make quality` command.
#
# To update the static imports, please run the following command and commit the changes.
# ```
# # Use script
# python utils/check_static_imports.py --update-file
#
# # Or run style on codebase
# make style
# ```
#
# ***********
# Lazy loader vendored from https://github.com/scientific-python/lazy_loader
import importlib
import os
import sys
from typing import TYPE_CHECKING


__version__ = "0.20.2"

# Alphabetical order of definitions is ensured in tests
# WARNING: any comment added in this dictionary definition will be lost when
# re-generating the file!
_SUBMOD_ATTRS = {
    "_commit_scheduler": [
        "CommitScheduler",
    ],
    "_inference_endpoints": [
        "InferenceEndpoint",
        "InferenceEndpointError",
        "InferenceEndpointStatus",
        "InferenceEndpointTimeoutError",
        "InferenceEndpointType",
    ],
    "_login": [
        "interpreter_login",
        "login",
        "logout",
        "notebook_login",
    ],
    "_multi_commits": [
        "MultiCommitException",
        "plan_multi_commits",
    ],
    "_snapshot_download": [
        "snapshot_download",
    ],
    "_space_api": [
        "SpaceHardware",
        "SpaceRuntime",
        "SpaceStage",
        "SpaceStorage",
        "SpaceVariable",
    ],
    "_tensorboard_logger": [
        "HFSummaryWriter",
    ],
    "_webhooks_payload": [
        "WebhookPayload",
        "WebhookPayloadComment",
        "WebhookPayloadDiscussion",
        "WebhookPayloadDiscussionChanges",
        "WebhookPayloadEvent",
        "WebhookPayloadMovedTo",
        "WebhookPayloadRepo",
        "WebhookPayloadUrl",
        "WebhookPayloadWebhook",
    ],
    "_webhooks_server": [
        "WebhooksServer",
        "webhook_endpoint",
    ],
    "community": [
        "Discussion",
        "DiscussionComment",
        "DiscussionCommit",
        "DiscussionEvent",
        "DiscussionStatusChange",
        "DiscussionTitleChange",
        "DiscussionWithDetails",
    ],
    "constants": [
        "CONFIG_NAME",
        "FLAX_WEIGHTS_NAME",
        "HUGGINGFACE_CO_URL_HOME",
        "HUGGINGFACE_CO_URL_TEMPLATE",
        "PYTORCH_WEIGHTS_NAME",
        "REPO_TYPE_DATASET",
        "REPO_TYPE_MODEL",
        "REPO_TYPE_SPACE",
        "TF2_WEIGHTS_NAME",
        "TF_WEIGHTS_NAME",
    ],
    "fastai_utils": [
        "_save_pretrained_fastai",
        "from_pretrained_fastai",
        "push_to_hub_fastai",
    ],
    "file_download": [
        "HfFileMetadata",
        "_CACHED_NO_EXIST",
        "cached_download",
        "get_hf_file_metadata",
        "hf_hub_download",
        "hf_hub_url",
        "try_to_load_from_cache",
    ],
    "hf_api": [
        "Collection",
        "CollectionItem",
        "CommitInfo",
        "CommitOperation",
        "CommitOperationAdd",
        "CommitOperationCopy",
        "CommitOperationDelete",
        "GitCommitInfo",
        "GitRefInfo",
        "GitRefs",
        "HfApi",
        "RepoUrl",
        "User",
        "UserLikes",
        "accept_access_request",
        "add_collection_item",
        "add_space_secret",
        "add_space_variable",
        "cancel_access_request",
        "change_discussion_status",
        "comment_discussion",
        "create_branch",
        "create_collection",
        "create_commit",
        "create_commits_on_pr",
        "create_discussion",
        "create_inference_endpoint",
        "create_pull_request",
        "create_repo",
        "create_tag",
        "dataset_info",
        "delete_branch",
        "delete_collection",
        "delete_collection_item",
        "delete_file",
        "delete_folder",
        "delete_inference_endpoint",
        "delete_repo",
        "delete_space_secret",
        "delete_space_storage",
        "delete_space_variable",
        "delete_tag",
        "duplicate_space",
        "edit_discussion_comment",
        "file_exists",
        "get_collection",
        "get_dataset_tags",
        "get_discussion_details",
        "get_full_repo_name",
        "get_inference_endpoint",
        "get_model_tags",
        "get_paths_info",
        "get_repo_discussions",
        "get_safetensors_metadata",
        "get_space_runtime",
        "get_space_variables",
        "get_token_permission",
        "grant_access",
        "like",
        "list_accepted_access_requests",
        "list_collections",
        "list_datasets",
        "list_files_info",
        "list_inference_endpoints",
        "list_liked_repos",
        "list_metrics",
        "list_models",
        "list_pending_access_requests",
        "list_rejected_access_requests",
        "list_repo_commits",
        "list_repo_files",
        "list_repo_likers",
        "list_repo_refs",
        "list_repo_tree",
        "list_spaces",
        "merge_pull_request",
        "model_info",
        "move_repo",
        "parse_safetensors_file_metadata",
        "pause_inference_endpoint",
        "pause_space",
        "preupload_lfs_files",
        "reject_access_request",
        "rename_discussion",
        "repo_exists",
        "repo_info",
        "repo_type_and_id_from_hf_id",
        "request_space_hardware",
        "request_space_storage",
        "restart_space",
        "resume_inference_endpoint",
        "run_as_future",
        "scale_to_zero_inference_endpoint",
        "set_space_sleep_time",
        "space_info",
        "super_squash_history",
        "unlike",
        "update_collection_item",
        "update_collection_metadata",
        "update_inference_endpoint",
        "update_repo_visibility",
        "upload_file",
        "upload_folder",
        "whoami",
    ],
    "hf_file_system": [
        "HfFileSystem",
        "HfFileSystemFile",
        "HfFileSystemResolvedPath",
    ],
    "hub_mixin": [
        "ModelHubMixin",
        "PyTorchModelHubMixin",
    ],
    "inference._client": [
        "InferenceClient",
        "InferenceTimeoutError",
    ],
    "inference._generated._async_client": [
        "AsyncInferenceClient",
    ],
    "inference_api": [
        "InferenceApi",
    ],
    "keras_mixin": [
        "KerasModelHubMixin",
        "from_pretrained_keras",
        "push_to_hub_keras",
        "save_pretrained_keras",
    ],
    "repocard": [
        "DatasetCard",
        "ModelCard",
        "RepoCard",
        "SpaceCard",
        "metadata_eval_result",
        "metadata_load",
        "metadata_save",
        "metadata_update",
    ],
    "repocard_data": [
        "CardData",
        "DatasetCardData",
        "EvalResult",
        "ModelCardData",
        "SpaceCardData",
    ],
    "repository": [
        "Repository",
    ],
    "utils": [
        "CacheNotFound",
        "CachedFileInfo",
        "CachedRepoInfo",
        "CachedRevisionInfo",
        "CorruptedCacheException",
        "DeleteCacheStrategy",
        "HFCacheInfo",
        "HfFolder",
        "cached_assets_path",
        "configure_http_backend",
        "dump_environment_info",
        "get_session",
        "get_token",
        "logging",
        "scan_cache_dir",
    ],
    "utils.endpoint_helpers": [
        "DatasetFilter",
        "ModelFilter",
    ],
}


def _attach(package_name, submodules=None, submod_attrs=None):
    """Attach lazily loaded submodules, functions, or other attributes.

    Typically, modules import submodules and attributes as follows:

    ```py
    import mysubmodule
    import anothersubmodule

    from .foo import someattr
    ```

    The idea is to replace a package's `__getattr__`, `__dir__`, and
    `__all__`, such that all imports work exactly the way they would
    with normal imports, except that the import occurs upon first use.

    The typical way to call this function, replacing the above imports, is:

    ```python
    __getattr__, __dir__, __all__ = lazy.attach(
        __name__,
        ['mysubmodule', 'anothersubmodule'],
        {'foo': ['someattr']}
    )
    ```
    This functionality requires Python 3.7 or higher.

    Args:
        package_name (`str`):
            Typically use `__name__`.
        submodules (`set`):
            List of submodules to attach.
        submod_attrs (`dict`):
            Dictionary of submodule -> list of attributes / functions.
            These attributes are imported as they are used.

    Returns:
        __getattr__, __dir__, __all__

    """
    if submod_attrs is None:
        submod_attrs = {}

    if submodules is None:
        submodules = set()
    else:
        submodules = set(submodules)

    attr_to_modules = {attr: mod for mod, attrs in submod_attrs.items() for attr in attrs}

    __all__ = list(submodules | attr_to_modules.keys())

    def __getattr__(name):
        if name in submodules:
            return importlib.import_module(f"{package_name}.{name}")
        elif name in attr_to_modules:
            submod_path = f"{package_name}.{attr_to_modules[name]}"
            submod = importlib.import_module(submod_path)
            attr = getattr(submod, name)

            # If the attribute lives in a file (module) with the same
            # name as the attribute, ensure that the attribute and *not*
            # the module is accessible on the package.
            if name == attr_to_modules[name]:
                pkg = sys.modules[package_name]
                pkg.__dict__[name] = attr

            return attr
        else:
            raise AttributeError(f"No {package_name} attribute {name}")

    def __dir__():
        return __all__

    if os.environ.get("EAGER_IMPORT", ""):
        for attr in set(attr_to_modules.keys()) | submodules:
            __getattr__(attr)

    return __getattr__, __dir__, list(__all__)


__getattr__, __dir__, __all__ = _attach(__name__, submodules=[], submod_attrs=_SUBMOD_ATTRS)

# WARNING: any content below this statement is generated automatically. Any manual edit
# will be lost when re-generating this file!
#
# To update the static imports, please run the following command and commit the changes.
# ```
# # Use script
# python utils/check_static_imports.py --update-file
#
# # Or run style on codebase
# make style
# ```
if TYPE_CHECKING:  # pragma: no cover
    from ._commit_scheduler import CommitScheduler  # noqa: F401
    from ._inference_endpoints import (
        InferenceEndpoint,  # noqa: F401
        InferenceEndpointError,  # noqa: F401
        InferenceEndpointStatus,  # noqa: F401
        InferenceEndpointTimeoutError,  # noqa: F401
        InferenceEndpointType,  # noqa: F401
    )
    from ._login import (
        interpreter_login,  # noqa: F401
        login,  # noqa: F401
        logout,  # noqa: F401
        notebook_login,  # noqa: F401
    )
    from ._multi_commits import (
        MultiCommitException,  # noqa: F401
        plan_multi_commits,  # noqa: F401
    )
    from ._snapshot_download import snapshot_download  # noqa: F401
    from ._space_api import (
        SpaceHardware,  # noqa: F401
        SpaceRuntime,  # noqa: F401
        SpaceStage,  # noqa: F401
        SpaceStorage,  # noqa: F401
        SpaceVariable,  # noqa: F401
    )
    from ._tensorboard_logger import HFSummaryWriter  # noqa: F401
    from ._webhooks_payload import (
        WebhookPayload,  # noqa: F401
        WebhookPayloadComment,  # noqa: F401
        WebhookPayloadDiscussion,  # noqa: F401
        WebhookPayloadDiscussionChanges,  # noqa: F401
        WebhookPayloadEvent,  # noqa: F401
        WebhookPayloadMovedTo,  # noqa: F401
        WebhookPayloadRepo,  # noqa: F401
        WebhookPayloadUrl,  # noqa: F401
        WebhookPayloadWebhook,  # noqa: F401
    )
    from ._webhooks_server import (
        WebhooksServer,  # noqa: F401
        webhook_endpoint,  # noqa: F401
    )
    from .community import (
        Discussion,  # noqa: F401
        DiscussionComment,  # noqa: F401
        DiscussionCommit,  # noqa: F401
        DiscussionEvent,  # noqa: F401
        DiscussionStatusChange,  # noqa: F401
        DiscussionTitleChange,  # noqa: F401
        DiscussionWithDetails,  # noqa: F401
    )
    from .constants import (
        CONFIG_NAME,  # noqa: F401
        FLAX_WEIGHTS_NAME,  # noqa: F401
        HUGGINGFACE_CO_URL_HOME,  # noqa: F401
        HUGGINGFACE_CO_URL_TEMPLATE,  # noqa: F401
        PYTORCH_WEIGHTS_NAME,  # noqa: F401
        REPO_TYPE_DATASET,  # noqa: F401
        REPO_TYPE_MODEL,  # noqa: F401
        REPO_TYPE_SPACE,  # noqa: F401
        TF2_WEIGHTS_NAME,  # noqa: F401
        TF_WEIGHTS_NAME,  # noqa: F401
    )
    from .fastai_utils import (
        _save_pretrained_fastai,  # noqa: F401
        from_pretrained_fastai,  # noqa: F401
        push_to_hub_fastai,  # noqa: F401
    )
    from .file_download import (
        _CACHED_NO_EXIST,  # noqa: F401
        HfFileMetadata,  # noqa: F401
        cached_download,  # noqa: F401
        get_hf_file_metadata,  # noqa: F401
        hf_hub_download,  # noqa: F401
        hf_hub_url,  # noqa: F401
        try_to_load_from_cache,  # noqa: F401
    )
    from .hf_api import (
        Collection,  # noqa: F401
        CollectionItem,  # noqa: F401
        CommitInfo,  # noqa: F401
        CommitOperation,  # noqa: F401
        CommitOperationAdd,  # noqa: F401
        CommitOperationCopy,  # noqa: F401
        CommitOperationDelete,  # noqa: F401
        GitCommitInfo,  # noqa: F401
        GitRefInfo,  # noqa: F401
        GitRefs,  # noqa: F401
        HfApi,  # noqa: F401
        RepoUrl,  # noqa: F401
        User,  # noqa: F401
        UserLikes,  # noqa: F401
        accept_access_request,  # noqa: F401
        add_collection_item,  # noqa: F401
        add_space_secret,  # noqa: F401
        add_space_variable,  # noqa: F401
        cancel_access_request,  # noqa: F401
        change_discussion_status,  # noqa: F401
        comment_discussion,  # noqa: F401
        create_branch,  # noqa: F401
        create_collection,  # noqa: F401
        create_commit,  # noqa: F401
        create_commits_on_pr,  # noqa: F401
        create_discussion,  # noqa: F401
        create_inference_endpoint,  # noqa: F401
        create_pull_request,  # noqa: F401
        create_repo,  # noqa: F401
        create_tag,  # noqa: F401
        dataset_info,  # noqa: F401
        delete_branch,  # noqa: F401
        delete_collection,  # noqa: F401
        delete_collection_item,  # noqa: F401
        delete_file,  # noqa: F401
        delete_folder,  # noqa: F401
        delete_inference_endpoint,  # noqa: F401
        delete_repo,  # noqa: F401
        delete_space_secret,  # noqa: F401
        delete_space_storage,  # noqa: F401
        delete_space_variable,  # noqa: F401
        delete_tag,  # noqa: F401
        duplicate_space,  # noqa: F401
        edit_discussion_comment,  # noqa: F401
        file_exists,  # noqa: F401
        get_collection,  # noqa: F401
        get_dataset_tags,  # noqa: F401
        get_discussion_details,  # noqa: F401
        get_full_repo_name,  # noqa: F401
        get_inference_endpoint,  # noqa: F401
        get_model_tags,  # noqa: F401
        get_paths_info,  # noqa: F401
        get_repo_discussions,  # noqa: F401
        get_safetensors_metadata,  # noqa: F401
        get_space_runtime,  # noqa: F401
        get_space_variables,  # noqa: F401
        get_token_permission,  # noqa: F401
        grant_access,  # noqa: F401
        like,  # noqa: F401
        list_accepted_access_requests,  # noqa: F401
        list_collections,  # noqa: F401
        list_datasets,  # noqa: F401
        list_files_info,  # noqa: F401
        list_inference_endpoints,  # noqa: F401
        list_liked_repos,  # noqa: F401
        list_metrics,  # noqa: F401
        list_models,  # noqa: F401
        list_pending_access_requests,  # noqa: F401
        list_rejected_access_requests,  # noqa: F401
        list_repo_commits,  # noqa: F401
        list_repo_files,  # noqa: F401
        list_repo_likers,  # noqa: F401
        list_repo_refs,  # noqa: F401
        list_repo_tree,  # noqa: F401
        list_spaces,  # noqa: F401
        merge_pull_request,  # noqa: F401
        model_info,  # noqa: F401
        move_repo,  # noqa: F401
        parse_safetensors_file_metadata,  # noqa: F401
        pause_inference_endpoint,  # noqa: F401
        pause_space,  # noqa: F401
        preupload_lfs_files,  # noqa: F401
        reject_access_request,  # noqa: F401
        rename_discussion,  # noqa: F401
        repo_exists,  # noqa: F401
        repo_info,  # noqa: F401
        repo_type_and_id_from_hf_id,  # noqa: F401
        request_space_hardware,  # noqa: F401
        request_space_storage,  # noqa: F401
        restart_space,  # noqa: F401
        resume_inference_endpoint,  # noqa: F401
        run_as_future,  # noqa: F401
        scale_to_zero_inference_endpoint,  # noqa: F401
        set_space_sleep_time,  # noqa: F401
        space_info,  # noqa: F401
        super_squash_history,  # noqa: F401
        unlike,  # noqa: F401
        update_collection_item,  # noqa: F401
        update_collection_metadata,  # noqa: F401
        update_inference_endpoint,  # noqa: F401
        update_repo_visibility,  # noqa: F401
        upload_file,  # noqa: F401
        upload_folder,  # noqa: F401
        whoami,  # noqa: F401
    )
    from .hf_file_system import (
        HfFileSystem,  # noqa: F401
        HfFileSystemFile,  # noqa: F401
        HfFileSystemResolvedPath,  # noqa: F401
    )
    from .hub_mixin import (
        ModelHubMixin,  # noqa: F401
        PyTorchModelHubMixin,  # noqa: F401
    )
    from .inference._client import (
        InferenceClient,  # noqa: F401
        InferenceTimeoutError,  # noqa: F401
    )
    from .inference._generated._async_client import AsyncInferenceClient  # noqa: F401
    from .inference_api import InferenceApi  # noqa: F401
    from .keras_mixin import (
        KerasModelHubMixin,  # noqa: F401
        from_pretrained_keras,  # noqa: F401
        push_to_hub_keras,  # noqa: F401
        save_pretrained_keras,  # noqa: F401
    )
    from .repocard import (
        DatasetCard,  # noqa: F401
        ModelCard,  # noqa: F401
        RepoCard,  # noqa: F401
        SpaceCard,  # noqa: F401
        metadata_eval_result,  # noqa: F401
        metadata_load,  # noqa: F401
        metadata_save,  # noqa: F401
        metadata_update,  # noqa: F401
    )
    from .repocard_data import (
        CardData,  # noqa: F401
        DatasetCardData,  # noqa: F401
        EvalResult,  # noqa: F401
        ModelCardData,  # noqa: F401
        SpaceCardData,  # noqa: F401
    )
    from .repository import Repository  # noqa: F401
    from .utils import (
        CachedFileInfo,  # noqa: F401
        CachedRepoInfo,  # noqa: F401
        CachedRevisionInfo,  # noqa: F401
        CacheNotFound,  # noqa: F401
        CorruptedCacheException,  # noqa: F401
        DeleteCacheStrategy,  # noqa: F401
        HFCacheInfo,  # noqa: F401
        HfFolder,  # noqa: F401
        cached_assets_path,  # noqa: F401
        configure_http_backend,  # noqa: F401
        dump_environment_info,  # noqa: F401
        get_session,  # noqa: F401
        get_token,  # noqa: F401
        logging,  # noqa: F401
        scan_cache_dir,  # noqa: F401
    )
    from .utils.endpoint_helpers import (
        DatasetFilter,  # noqa: F401
        ModelFilter,  # noqa: F401
    )
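The `_attach` helper implements the PEP 562 module-level `__getattr__`/`__dir__` pattern. A toy sketch of reusing it in a hypothetical package `mypkg` whose submodule `mypkg/core.py` defines `do_work` (all names here are illustrative assumptions, not part of `huggingface_hub`):

# mypkg/__init__.py -- sketch; assumes an _attach helper like the one above
_SUBMOD_ATTRS = {"core": ["do_work"]}  # attribute name -> owning submodule

__getattr__, __dir__, __all__ = _attach(
    __name__, submodules=[], submod_attrs=_SUBMOD_ATTRS
)

# Elsewhere: `from mypkg import do_work` only imports mypkg.core at that
# moment; until then, importing mypkg itself stays cheap.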
lib/python3.11/site-packages/huggingface_hub/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (13.8 kB).

lib/python3.11/site-packages/huggingface_hub/__pycache__/_commit_api.cpython-311.pyc
ADDED
Binary file (33.8 kB).

lib/python3.11/site-packages/huggingface_hub/__pycache__/_commit_scheduler.cpython-311.pyc
ADDED
Binary file (18.6 kB).

lib/python3.11/site-packages/huggingface_hub/__pycache__/_inference_endpoints.cpython-311.pyc
ADDED
Binary file (18.7 kB).

lib/python3.11/site-packages/huggingface_hub/__pycache__/_login.cpython-311.pyc
ADDED
Binary file (17.5 kB).

lib/python3.11/site-packages/huggingface_hub/__pycache__/_multi_commits.cpython-311.pyc
ADDED
Binary file (16.8 kB).

lib/python3.11/site-packages/huggingface_hub/__pycache__/_snapshot_download.cpython-311.pyc
ADDED
Binary file (15 kB).

lib/python3.11/site-packages/huggingface_hub/__pycache__/_space_api.cpython-311.pyc
ADDED
Binary file (6.64 kB).

lib/python3.11/site-packages/huggingface_hub/__pycache__/_tensorboard_logger.cpython-311.pyc
ADDED
Binary file (7.74 kB).

lib/python3.11/site-packages/huggingface_hub/__pycache__/_webhooks_payload.cpython-311.pyc
ADDED
Binary file (4.71 kB).

lib/python3.11/site-packages/huggingface_hub/__pycache__/_webhooks_server.cpython-311.pyc
ADDED
Binary file (18.8 kB).

lib/python3.11/site-packages/huggingface_hub/__pycache__/community.cpython-311.pyc
ADDED
Binary file (16 kB).

lib/python3.11/site-packages/huggingface_hub/__pycache__/constants.cpython-311.pyc
ADDED
Binary file (7.69 kB).

lib/python3.11/site-packages/huggingface_hub/__pycache__/fastai_utils.cpython-311.pyc
ADDED
Binary file (20.1 kB).

lib/python3.11/site-packages/huggingface_hub/__pycache__/file_download.cpython-311.pyc
ADDED
Binary file (75.4 kB).

lib/python3.11/site-packages/huggingface_hub/__pycache__/hf_api.cpython-311.pyc
ADDED
Binary file (375 kB).

lib/python3.11/site-packages/huggingface_hub/__pycache__/hf_file_system.cpython-311.pyc
ADDED
Binary file (35.4 kB).

lib/python3.11/site-packages/huggingface_hub/__pycache__/hub_mixin.cpython-311.pyc
ADDED
Binary file (18.6 kB).

lib/python3.11/site-packages/huggingface_hub/__pycache__/inference_api.cpython-311.pyc
ADDED
Binary file (9.45 kB).

lib/python3.11/site-packages/huggingface_hub/__pycache__/keras_mixin.cpython-311.pyc
ADDED
Binary file (21.6 kB).

lib/python3.11/site-packages/huggingface_hub/__pycache__/lfs.cpython-311.pyc
ADDED
Binary file (27.1 kB).

lib/python3.11/site-packages/huggingface_hub/__pycache__/repocard.cpython-311.pyc
ADDED
Binary file (37.7 kB).

lib/python3.11/site-packages/huggingface_hub/__pycache__/repocard_data.cpython-311.pyc
ADDED
Binary file (34.5 kB).

lib/python3.11/site-packages/huggingface_hub/__pycache__/repository.cpython-311.pyc
ADDED
Binary file (72.1 kB).
lib/python3.11/site-packages/huggingface_hub/_commit_api.py
ADDED
@@ -0,0 +1,670 @@
"""
Type definitions and utilities for the `create_commit` API
"""
import base64
import io
import os
import warnings
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass, field
from itertools import groupby
from pathlib import Path, PurePosixPath
from typing import TYPE_CHECKING, Any, BinaryIO, Dict, Iterable, Iterator, List, Literal, Optional, Tuple, Union

from tqdm.contrib.concurrent import thread_map

from huggingface_hub import get_session

from .constants import ENDPOINT, HF_HUB_ENABLE_HF_TRANSFER
from .lfs import UploadInfo, lfs_upload, post_lfs_batch_info
from .utils import (
    EntryNotFoundError,
    build_hf_headers,
    chunk_iterable,
    hf_raise_for_status,
    logging,
    tqdm_stream_file,
    validate_hf_hub_args,
)
from .utils import tqdm as hf_tqdm


if TYPE_CHECKING:
    from .hf_api import RepoFile


logger = logging.get_logger(__name__)


UploadMode = Literal["lfs", "regular"]

# Max is 1,000 per request on the Hub for HfApi.get_paths_info
# Otherwise we get:
# HfHubHTTPError: 413 Client Error: Payload Too Large for url: https://huggingface.co/api/datasets/xxx (Request ID: xxx)\n\ntoo many parameters
# See https://github.com/huggingface/huggingface_hub/issues/1503
FETCH_LFS_BATCH_SIZE = 500


@dataclass
class CommitOperationDelete:
    """
    Data structure holding necessary info to delete a file or a folder from a repository
    on the Hub.

    Args:
        path_in_repo (`str`):
            Relative filepath in the repo, for example: `"checkpoints/1fec34a/weights.bin"`
            for a file or `"checkpoints/1fec34a/"` for a folder.
        is_folder (`bool` or `Literal["auto"]`, *optional*):
            Whether the Delete Operation applies to a folder or not. If "auto", the path
            type (file or folder) is guessed automatically by looking if path ends with
            a "/" (folder) or not (file). To explicitly set the path type, you can set
            `is_folder=True` or `is_folder=False`.
    """

    path_in_repo: str
    is_folder: Union[bool, Literal["auto"]] = "auto"

    def __post_init__(self):
        self.path_in_repo = _validate_path_in_repo(self.path_in_repo)

        if self.is_folder == "auto":
            self.is_folder = self.path_in_repo.endswith("/")
        if not isinstance(self.is_folder, bool):
            raise ValueError(
                f"Wrong value for `is_folder`. Must be one of [`True`, `False`, `'auto'`]. Got '{self.is_folder}'."
            )


@dataclass
class CommitOperationCopy:
    """
    Data structure holding necessary info to copy a file in a repository on the Hub.

    Limitations:
        - Only LFS files can be copied. To copy a regular file, you need to download it locally and re-upload it
        - Cross-repository copies are not supported.

    Note: you can combine a [`CommitOperationCopy`] and a [`CommitOperationDelete`] to rename an LFS file on the Hub.

    Args:
        src_path_in_repo (`str`):
            Relative filepath in the repo of the file to be copied, e.g. `"checkpoints/1fec34a/weights.bin"`.
        path_in_repo (`str`):
            Relative filepath in the repo where to copy the file, e.g. `"checkpoints/1fec34a/weights_copy.bin"`.
        src_revision (`str`, *optional*):
            The git revision of the file to be copied. Can be any valid git revision.
            Defaults to the target commit revision.
    """

    src_path_in_repo: str
    path_in_repo: str
    src_revision: Optional[str] = None

    def __post_init__(self):
        self.src_path_in_repo = _validate_path_in_repo(self.src_path_in_repo)
        self.path_in_repo = _validate_path_in_repo(self.path_in_repo)


@dataclass
class CommitOperationAdd:
    """
    Data structure holding necessary info to upload a file to a repository on the Hub.

    Args:
        path_in_repo (`str`):
            Relative filepath in the repo, for example: `"checkpoints/1fec34a/weights.bin"`
        path_or_fileobj (`str`, `Path`, `bytes`, or `BinaryIO`):
            Either:
            - a path to a local file (as `str` or `pathlib.Path`) to upload
            - a buffer of bytes (`bytes`) holding the content of the file to upload
            - a "file object" (subclass of `io.BufferedIOBase`), typically obtained
                with `open(path, "rb")`. It must support `seek()` and `tell()` methods.

    Raises:
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If `path_or_fileobj` is not one of `str`, `Path`, `bytes` or `io.BufferedIOBase`.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If `path_or_fileobj` is a `str` or `Path` but not a path to an existing file.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If `path_or_fileobj` is a `io.BufferedIOBase` but it doesn't support both
            `seek()` and `tell()`.
    """

    path_in_repo: str
    path_or_fileobj: Union[str, Path, bytes, BinaryIO]
    upload_info: UploadInfo = field(init=False, repr=False)

    # Internal attributes

    # set to "lfs" or "regular" once known
    _upload_mode: Optional[UploadMode] = field(init=False, repr=False, default=None)

    # set to True if .gitignore rules prevent the file from being uploaded as LFS
    # (server-side check)
    _should_ignore: Optional[bool] = field(init=False, repr=False, default=None)

    # set to True once the file has been uploaded as LFS
    _is_uploaded: bool = field(init=False, repr=False, default=False)

    # set to True once the file has been committed
    _is_committed: bool = field(init=False, repr=False, default=False)

    def __post_init__(self) -> None:
        """Validates `path_or_fileobj` and computes `upload_info`."""
        self.path_in_repo = _validate_path_in_repo(self.path_in_repo)

        # Validate `path_or_fileobj` value
        if isinstance(self.path_or_fileobj, Path):
            self.path_or_fileobj = str(self.path_or_fileobj)
        if isinstance(self.path_or_fileobj, str):
            path_or_fileobj = os.path.normpath(os.path.expanduser(self.path_or_fileobj))
            if not os.path.isfile(path_or_fileobj):
                raise ValueError(f"Provided path: '{path_or_fileobj}' is not a file on the local file system")
        elif not isinstance(self.path_or_fileobj, (io.BufferedIOBase, bytes)):
            # ^^ Inspired from: https://stackoverflow.com/questions/44584829/how-to-determine-if-file-is-opened-in-binary-or-text-mode
            raise ValueError(
                "path_or_fileobj must be either an instance of str, bytes or"
                " io.BufferedIOBase. If you passed a file-like object, make sure it is"
                " in binary mode."
            )
        if isinstance(self.path_or_fileobj, io.BufferedIOBase):
            try:
                self.path_or_fileobj.tell()
                self.path_or_fileobj.seek(0, os.SEEK_CUR)
            except (OSError, AttributeError) as exc:
                raise ValueError(
                    "path_or_fileobj is a file-like object but does not implement seek() and tell()"
                ) from exc

        # Compute "upload_info" attribute
        if isinstance(self.path_or_fileobj, str):
            self.upload_info = UploadInfo.from_path(self.path_or_fileobj)
        elif isinstance(self.path_or_fileobj, bytes):
            self.upload_info = UploadInfo.from_bytes(self.path_or_fileobj)
        else:
            self.upload_info = UploadInfo.from_fileobj(self.path_or_fileobj)

    @contextmanager
    def as_file(self, with_tqdm: bool = False) -> Iterator[BinaryIO]:
        """
        A context manager that yields a file-like object allowing to read the underlying
        data behind `path_or_fileobj`.

        Args:
            with_tqdm (`bool`, *optional*, defaults to `False`):
                If True, iterating over the file object will display a progress bar. Only
                works if the file-like object is a path to a file. Pure bytes and buffers
                are not supported.

        Example:

        ```python
        >>> operation = CommitOperationAdd(
        ...        path_in_repo="remote/dir/weights.h5",
        ...        path_or_fileobj="./local/weights.h5",
        ... )
        CommitOperationAdd(path_in_repo='remote/dir/weights.h5', path_or_fileobj='./local/weights.h5')

        >>> with operation.as_file() as file:
        ...     content = file.read()

        >>> with operation.as_file(with_tqdm=True) as file:
        ...     while True:
        ...         data = file.read(1024)
        ...         if not data:
        ...             break
        config.json: 100%|█████████████████████████| 8.19k/8.19k [00:02<00:00, 3.72kB/s]

        >>> with operation.as_file(with_tqdm=True) as file:
        ...     requests.put(..., data=file)
        config.json: 100%|█████████████████████████| 8.19k/8.19k [00:02<00:00, 3.72kB/s]
        ```
        """
        if isinstance(self.path_or_fileobj, str) or isinstance(self.path_or_fileobj, Path):
            if with_tqdm:
                with tqdm_stream_file(self.path_or_fileobj) as file:
                    yield file
            else:
                with open(self.path_or_fileobj, "rb") as file:
                    yield file
        elif isinstance(self.path_or_fileobj, bytes):
            yield io.BytesIO(self.path_or_fileobj)
        elif isinstance(self.path_or_fileobj, io.BufferedIOBase):
            prev_pos = self.path_or_fileobj.tell()
            yield self.path_or_fileobj
            self.path_or_fileobj.seek(prev_pos, io.SEEK_SET)

    def b64content(self) -> bytes:
        """
        The base64-encoded content of `path_or_fileobj`

        Returns: `bytes`
        """
        with self.as_file() as file:
            return base64.b64encode(file.read())


def _validate_path_in_repo(path_in_repo: str) -> str:
    # Validate `path_in_repo` value to prevent a server-side issue
    if path_in_repo.startswith("/"):
        path_in_repo = path_in_repo[1:]
    if path_in_repo == "." or path_in_repo == ".." or path_in_repo.startswith("../"):
        raise ValueError(f"Invalid `path_in_repo` in CommitOperation: '{path_in_repo}'")
    if path_in_repo.startswith("./"):
        path_in_repo = path_in_repo[2:]
    if any(part == ".git" for part in path_in_repo.split("/")):
        raise ValueError(
            "Invalid `path_in_repo` in CommitOperation: cannot update files under a '.git/' folder (path:"
            f" '{path_in_repo}')."
        )
    return path_in_repo


CommitOperation = Union[CommitOperationAdd, CommitOperationCopy, CommitOperationDelete]


def _warn_on_overwriting_operations(operations: List[CommitOperation]) -> None:
    """
    Warn user when a list of operations is expected to overwrite itself in a single
    commit.

    Rules:
    - If a filepath is updated by multiple `CommitOperationAdd` operations, a warning
      message is triggered.
    - If a filepath is updated at least once by a `CommitOperationAdd` and then deleted
      by a `CommitOperationDelete`, a warning is triggered.
    - If a `CommitOperationDelete` deletes a filepath that is then updated by a
      `CommitOperationAdd`, no warning is triggered. This is usually useless (no need to
      delete before upload) but can happen if a user deletes an entire folder and then
      adds new files to it.
    """
    nb_additions_per_path: Dict[str, int] = defaultdict(int)
    for operation in operations:
        path_in_repo = operation.path_in_repo
        if isinstance(operation, CommitOperationAdd):
            if nb_additions_per_path[path_in_repo] > 0:
                warnings.warn(
                    "About to update multiple times the same file in the same commit:"
                    f" '{path_in_repo}'. This can cause undesired inconsistencies in"
                    " your repo."
                )
            nb_additions_per_path[path_in_repo] += 1
            for parent in PurePosixPath(path_in_repo).parents:
                # Also keep track of number of updated files per folder
                # => warns if deleting a folder overwrites some contained files
                nb_additions_per_path[str(parent)] += 1
        if isinstance(operation, CommitOperationDelete):
            if nb_additions_per_path[str(PurePosixPath(path_in_repo))] > 0:
                if operation.is_folder:
                    warnings.warn(
                        "About to delete a folder containing files that have just been"
                        f" updated within the same commit: '{path_in_repo}'. This can"
                        " cause undesired inconsistencies in your repo."
                    )
                else:
                    warnings.warn(
                        "About to delete a file that has just been updated within the"
                        f" same commit: '{path_in_repo}'. This can cause undesired"
                        " inconsistencies in your repo."
                    )


@validate_hf_hub_args
def _upload_lfs_files(
    *,
    additions: List[CommitOperationAdd],
    repo_type: str,
    repo_id: str,
    token: Optional[str],
    endpoint: Optional[str] = None,
    num_threads: int = 5,
    revision: Optional[str] = None,
):
    """
    Uploads the content of `additions` to the Hub using the large file storage protocol.

    Relevant external documentation:
        - LFS Batch API: https://github.com/git-lfs/git-lfs/blob/main/docs/api/batch.md

    Args:
        additions (`List` of `CommitOperationAdd`):
            The files to be uploaded
        repo_type (`str`):
            Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`.
        repo_id (`str`):
            A namespace (user or an organization) and a repo name separated
            by a `/`.
        token (`str`, *optional*):
            An authentication token (see https://huggingface.co/settings/tokens)
        num_threads (`int`, *optional*):
            The number of concurrent threads to use when uploading. Defaults to 5.
        revision (`str`, *optional*):
            The git revision to upload to.

    Raises: `RuntimeError` if an upload failed for any reason

    Raises: `ValueError` if the server returns malformed responses

    Raises: `requests.HTTPError` if the LFS batch endpoint returned an HTTP error

    """
    # Step 1: retrieve upload instructions from the LFS batch endpoint.
    #         Upload instructions are retrieved in chunks of 256 files to avoid reaching
    #         the payload limit.
    batch_actions: List[Dict] = []
    for chunk in chunk_iterable(additions, chunk_size=256):
        batch_actions_chunk, batch_errors_chunk = post_lfs_batch_info(
            upload_infos=[op.upload_info for op in chunk],
            token=token,
            repo_id=repo_id,
            repo_type=repo_type,
            revision=revision,
            endpoint=endpoint,
        )

        # If at least 1 error, we do not retrieve information for other chunks
        if batch_errors_chunk:
            message = "\n".join(
                [
                    f'Encountered error for file with OID {err.get("oid")}: `{err.get("error", {}).get("message")}'
                    for err in batch_errors_chunk
                ]
            )
            raise ValueError(f"LFS batch endpoint returned errors:\n{message}")

        batch_actions += batch_actions_chunk
    oid2addop = {add_op.upload_info.sha256.hex(): add_op for add_op in additions}

    # Step 2: ignore files that have already been uploaded
    filtered_actions = []
    for action in batch_actions:
        if action.get("actions") is None:
            logger.debug(
                f"Content of file {oid2addop[action['oid']].path_in_repo} is already"
                " present upstream - skipping upload."
            )
        else:
            filtered_actions.append(action)

    if len(filtered_actions) == 0:
        logger.debug("No LFS files to upload.")
        return

    # Step 3: upload files concurrently according to these instructions
    def _wrapped_lfs_upload(batch_action) -> None:
        try:
            operation = oid2addop[batch_action["oid"]]
            lfs_upload(operation=operation, lfs_batch_action=batch_action, token=token)
        except Exception as exc:
            raise RuntimeError(f"Error while uploading '{operation.path_in_repo}' to the Hub.") from exc

    if HF_HUB_ENABLE_HF_TRANSFER:
        logger.debug(f"Uploading {len(filtered_actions)} LFS files to the Hub using `hf_transfer`.")
        for action in hf_tqdm(filtered_actions):
            _wrapped_lfs_upload(action)
    elif len(filtered_actions) == 1:
        logger.debug("Uploading 1 LFS file to the Hub")
        _wrapped_lfs_upload(filtered_actions[0])
    else:
        logger.debug(
            f"Uploading {len(filtered_actions)} LFS files to the Hub using up to {num_threads} threads concurrently"
        )
        thread_map(
            _wrapped_lfs_upload,
            filtered_actions,
            desc=f"Upload {len(filtered_actions)} LFS files",
            max_workers=num_threads,
            tqdm_class=hf_tqdm,
        )


def _validate_preupload_info(preupload_info: dict):
    files = preupload_info.get("files")
    if not isinstance(files, list):
        raise ValueError("preupload_info is improperly formatted")
    for file_info in files:
        if not (
            isinstance(file_info, dict)
            and isinstance(file_info.get("path"), str)
            and isinstance(file_info.get("uploadMode"), str)
            and (file_info["uploadMode"] in ("lfs", "regular"))
        ):
            raise ValueError("preupload_info is improperly formatted:")
    return preupload_info


@validate_hf_hub_args
def _fetch_upload_modes(
    additions: Iterable[CommitOperationAdd],
    repo_type: str,
    repo_id: str,
    token: Optional[str],
    revision: str,
    endpoint: Optional[str] = None,
    create_pr: bool = False,
    gitignore_content: Optional[str] = None,
) -> None:
    """
    Requests the Hub "preupload" endpoint to determine whether each input file should be uploaded as a regular git blob
    or as a git LFS blob. Input `additions` are mutated in-place with the upload mode.

    Args:
        additions (`Iterable` of :class:`CommitOperationAdd`):
            Iterable of :class:`CommitOperationAdd` describing the files to
            upload to the Hub.
        repo_type (`str`):
            Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`.
        repo_id (`str`):
            A namespace (user or an organization) and a repo name separated
            by a `/`.
        token (`str`, *optional*):
            An authentication token (see https://huggingface.co/settings/tokens)
        revision (`str`):
            The git revision to upload the files to. Can be any valid git revision.
        gitignore_content (`str`, *optional*):
            The content of the `.gitignore` file to know which files should be ignored. The order of priority
            is to first check if `gitignore_content` is passed, then check if the `.gitignore` file is present
            in the list of files to commit and finally default to the `.gitignore` file already hosted on the Hub
            (if any).
    Raises:
        [`~utils.HfHubHTTPError`]
            If the Hub API returned an error.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If the Hub API response is improperly formatted.
    """
    endpoint = endpoint if endpoint is not None else ENDPOINT
    headers = build_hf_headers(token=token)

    # Fetch upload mode (LFS or regular) chunk by chunk.
    upload_modes: Dict[str, UploadMode] = {}
    should_ignore_info: Dict[str, bool] = {}

    for chunk in chunk_iterable(additions, 256):
        payload: Dict = {
            "files": [
                {
                    "path": op.path_in_repo,
                    "sample": base64.b64encode(op.upload_info.sample).decode("ascii"),
                    "size": op.upload_info.size,
                    "sha": op.upload_info.sha256.hex(),
                }
                for op in chunk
            ]
        }
        if gitignore_content is not None:
            payload["gitIgnore"] = gitignore_content

        resp = get_session().post(
            f"{endpoint}/api/{repo_type}s/{repo_id}/preupload/{revision}",
            json=payload,
            headers=headers,
            params={"create_pr": "1"} if create_pr else None,
        )
        hf_raise_for_status(resp)
        preupload_info = _validate_preupload_info(resp.json())
        upload_modes.update(**{file["path"]: file["uploadMode"] for file in preupload_info["files"]})
        should_ignore_info.update(**{file["path"]: file["shouldIgnore"] for file in preupload_info["files"]})

    # Set upload mode for each addition operation
    for addition in additions:
        addition._upload_mode = upload_modes[addition.path_in_repo]
        addition._should_ignore = should_ignore_info[addition.path_in_repo]

    # Empty files cannot be uploaded as LFS (S3 would fail with a 501 Not Implemented)
    # => empty files are uploaded as "regular" to still allow users to commit them.
    for addition in additions:
        if addition.upload_info.size == 0:
            addition._upload_mode = "regular"


@validate_hf_hub_args
def _fetch_lfs_files_to_copy(
    copies: Iterable[CommitOperationCopy],
    repo_type: str,
    repo_id: str,
    token: Optional[str],
    revision: str,
    endpoint: Optional[str] = None,
) -> Dict[Tuple[str, Optional[str]], "RepoFile"]:
    """
    Requests the Hub files information of the LFS files to be copied, including their sha256.

    Args:
        copies (`Iterable` of :class:`CommitOperationCopy`):
            Iterable of :class:`CommitOperationCopy` describing the files to
            copy on the Hub.
        repo_type (`str`):
            Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`.
        repo_id (`str`):
            A namespace (user or an organization) and a repo name separated
            by a `/`.
        token (`str`, *optional*):
            An authentication token (see https://huggingface.co/settings/tokens)
        revision (`str`):
            The git revision to upload the files to. Can be any valid git revision.

    Returns: `Dict[Tuple[str, Optional[str]], RepoFile]]`
        Key is the file path and revision of the file to copy, value is the repo file.

    Raises:
        [`~utils.HfHubHTTPError`]
            If the Hub API returned an error.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If the Hub API response is improperly formatted.
    """
    from .hf_api import HfApi, RepoFolder

    hf_api = HfApi(endpoint=endpoint, token=token)
    files_to_copy = {}
    for src_revision, operations in groupby(copies, key=lambda op: op.src_revision):
        operations = list(operations)  # type: ignore
        paths = [op.src_path_in_repo for op in operations]
        for offset in range(0, len(paths), FETCH_LFS_BATCH_SIZE):
            src_repo_files = hf_api.get_paths_info(
                repo_id=repo_id,
                paths=paths[offset : offset + FETCH_LFS_BATCH_SIZE],
                revision=src_revision or revision,
                repo_type=repo_type,
            )
            for src_repo_file in src_repo_files:
                if isinstance(src_repo_file, RepoFolder):
                    raise NotImplementedError("Copying a folder is not implemented.")
                if not src_repo_file.lfs:
                    raise NotImplementedError("Copying a non-LFS file is not implemented")
                files_to_copy[(src_repo_file.rfilename, src_revision)] = src_repo_file
        for operation in operations:
            if (operation.src_path_in_repo, src_revision) not in files_to_copy:
                raise EntryNotFoundError(
                    f"Cannot copy {operation.src_path_in_repo} at revision "
                    f"{src_revision or revision}: file is missing on repo."
                )
    return files_to_copy


def _prepare_commit_payload(
    operations: Iterable[CommitOperation],
    files_to_copy: Dict[Tuple[str, Optional[str]], "RepoFile"],
    commit_message: str,
    commit_description: Optional[str] = None,
    parent_commit: Optional[str] = None,
) -> Iterable[Dict[str, Any]]:
    """
    Builds the payload to POST to the `/commit` API of the Hub.

    Payload is returned as an iterator so that it can be streamed as ndjson in the
    POST request.

    For more information, see:
        - https://github.com/huggingface/huggingface_hub/issues/1085#issuecomment-1265208073
        - http://ndjson.org/
    """
    commit_description = commit_description if commit_description is not None else ""

    # 1. Send a header item with the commit metadata
    header_value = {"summary": commit_message, "description": commit_description}
    if parent_commit is not None:
        header_value["parentCommit"] = parent_commit
    yield {"key": "header", "value": header_value}

    nb_ignored_files = 0

    # 2. Send operations, one per line
    for operation in operations:
        # Skip ignored files
        if isinstance(operation, CommitOperationAdd) and operation._should_ignore:
            logger.debug(f"Skipping file '{operation.path_in_repo}' in commit (ignored by gitignore file).")
            nb_ignored_files += 1
            continue

        # 2.a. Case adding a regular file
        if isinstance(operation, CommitOperationAdd) and operation._upload_mode == "regular":
            yield {
                "key": "file",
                "value": {
                    "content": operation.b64content().decode(),
                    "path": operation.path_in_repo,
                    "encoding": "base64",
                },
            }
        # 2.b. Case adding an LFS file
        elif isinstance(operation, CommitOperationAdd) and operation._upload_mode == "lfs":
            yield {
                "key": "lfsFile",
                "value": {
                    "path": operation.path_in_repo,
                    "algo": "sha256",
                    "oid": operation.upload_info.sha256.hex(),
                    "size": operation.upload_info.size,
                },
            }
        # 2.c. Case deleting a file or folder
        elif isinstance(operation, CommitOperationDelete):
            yield {
                "key": "deletedFolder" if operation.is_folder else "deletedFile",
                "value": {"path": operation.path_in_repo},
            }
        # 2.d. Case copying a file or folder
        elif isinstance(operation, CommitOperationCopy):
            file_to_copy = files_to_copy[(operation.src_path_in_repo, operation.src_revision)]
            if not file_to_copy.lfs:
                raise NotImplementedError("Copying a non-LFS file is not implemented")
            yield {
                "key": "lfsFile",
                "value": {
                    "path": operation.path_in_repo,
                    "algo": "sha256",
                    "oid": file_to_copy.lfs["sha256"],
                },
            }
        # 2.e. Never expected to happen
        else:
            raise ValueError(
                f"Unknown operation to commit. Operation: {operation}. Upload mode:"
                f" {getattr(operation, '_upload_mode', None)}"
            )

    if nb_ignored_files > 0:
        logger.info(f"Skipped {nb_ignored_files} file(s) in commit (ignored by gitignore file).")
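For context, the operation dataclasses defined above are the public building blocks consumed by `HfApi.create_commit`, which internally calls `_fetch_upload_modes`, `_upload_lfs_files` and `_prepare_commit_payload`. A minimal usage sketch (not part of the diff; `my-user/my-repo` and the file names are placeholders):

```python
# Hedged sketch: combine add and delete operations into a single atomic commit.
from huggingface_hub import CommitOperationAdd, CommitOperationDelete, HfApi

api = HfApi()
api.create_commit(
    repo_id="my-user/my-repo",
    operations=[
        CommitOperationAdd(path_in_repo="weights.bin", path_or_fileobj="./weights.bin"),
        CommitOperationDelete(path_in_repo="old-weights.bin"),
    ],
    commit_message="Replace weights",
)
```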
lib/python3.11/site-packages/huggingface_hub/_commit_scheduler.py
ADDED
@@ -0,0 +1,327 @@
import atexit
import logging
import os
import time
from concurrent.futures import Future
from dataclasses import dataclass
from io import SEEK_END, SEEK_SET, BytesIO
from pathlib import Path
from threading import Lock, Thread
from typing import Dict, List, Optional, Union

from .hf_api import IGNORE_GIT_FOLDER_PATTERNS, CommitInfo, CommitOperationAdd, HfApi
from .utils import filter_repo_objects


logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class _FileToUpload:
    """Temporary dataclass to store info about files to upload. Not meant to be used directly."""

    local_path: Path
    path_in_repo: str
    size_limit: int
    last_modified: float


class CommitScheduler:
    """
    Scheduler to upload a local folder to the Hub at regular intervals (e.g. push to hub every 5 minutes).

    The scheduler is started when instantiated and runs indefinitely. At the end of your script, a last commit is
    triggered. Check out the [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#scheduled-uploads)
    to learn more about how to use it.

    Args:
        repo_id (`str`):
            The id of the repo to commit to.
        folder_path (`str` or `Path`):
            Path to the local folder to upload regularly.
        every (`int` or `float`, *optional*):
            The number of minutes between each commit. Defaults to 5 minutes.
        path_in_repo (`str`, *optional*):
            Relative path of the directory in the repo, for example: `"checkpoints/"`. Defaults to the root folder
            of the repository.
        repo_type (`str`, *optional*):
            The type of the repo to commit to. Defaults to `model`.
        revision (`str`, *optional*):
            The revision of the repo to commit to. Defaults to `main`.
        private (`bool`, *optional*):
            Whether to make the repo private. Defaults to `False`. This value is ignored if the repo already exists.
        token (`str`, *optional*):
            The token to use to commit to the repo. Defaults to the token saved on the machine.
        allow_patterns (`List[str]` or `str`, *optional*):
            If provided, only files matching at least one pattern are uploaded.
        ignore_patterns (`List[str]` or `str`, *optional*):
            If provided, files matching any of the patterns are not uploaded.
        squash_history (`bool`, *optional*):
            Whether to squash the history of the repo after each commit. Defaults to `False`. Squashing commits is
            useful to avoid degraded performance on the repo when it grows too large.
        hf_api (`HfApi`, *optional*):
            The [`HfApi`] client to use to commit to the Hub. Can be set with custom settings (user agent, token,...).

    Example:
    ```py
    >>> from pathlib import Path
    >>> from huggingface_hub import CommitScheduler

    # Scheduler uploads every 10 minutes
    >>> csv_path = Path("watched_folder/data.csv")
    >>> CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path=csv_path.parent, every=10)

    >>> with csv_path.open("a") as f:
    ...     f.write("first line")

    # Some time later (...)
    >>> with csv_path.open("a") as f:
    ...     f.write("second line")
    ```
    """

    def __init__(
        self,
        *,
        repo_id: str,
        folder_path: Union[str, Path],
        every: Union[int, float] = 5,
        path_in_repo: Optional[str] = None,
        repo_type: Optional[str] = None,
        revision: Optional[str] = None,
        private: bool = False,
        token: Optional[str] = None,
        allow_patterns: Optional[Union[List[str], str]] = None,
        ignore_patterns: Optional[Union[List[str], str]] = None,
        squash_history: bool = False,
        hf_api: Optional["HfApi"] = None,
    ) -> None:
        self.api = hf_api or HfApi(token=token)

        # Folder
        self.folder_path = Path(folder_path).expanduser().resolve()
        self.path_in_repo = path_in_repo or ""
        self.allow_patterns = allow_patterns

        if ignore_patterns is None:
            ignore_patterns = []
        elif isinstance(ignore_patterns, str):
            ignore_patterns = [ignore_patterns]
        self.ignore_patterns = ignore_patterns + IGNORE_GIT_FOLDER_PATTERNS

        if self.folder_path.is_file():
            raise ValueError(f"'folder_path' must be a directory, not a file: '{self.folder_path}'.")
        self.folder_path.mkdir(parents=True, exist_ok=True)

        # Repository
        repo_url = self.api.create_repo(repo_id=repo_id, private=private, repo_type=repo_type, exist_ok=True)
        self.repo_id = repo_url.repo_id
        self.repo_type = repo_type
        self.revision = revision
        self.token = token

        # Keep track of already uploaded files
        self.last_uploaded: Dict[Path, float] = {}  # key is local path, value is timestamp

        # Scheduler
        if not every > 0:
            raise ValueError(f"'every' must be a positive number, not '{every}'.")
        self.lock = Lock()
        self.every = every
        self.squash_history = squash_history

        # Set the stop flag before starting the thread: the background thread reads it.
        self.__stopped = False

        logger.info(f"Scheduled job to push '{self.folder_path}' to '{self.repo_id}' every {self.every} minutes.")
        self._scheduler_thread = Thread(target=self._run_scheduler, daemon=True)
        self._scheduler_thread.start()
        atexit.register(self._push_to_hub)

    def stop(self) -> None:
        """Stop the scheduler.

        A stopped scheduler cannot be restarted. Mostly for test purposes.
        """
        self.__stopped = True

    def _run_scheduler(self) -> None:
        """Dumb thread waiting between each scheduled push to Hub."""
        while True:
            self.last_future = self.trigger()
            time.sleep(self.every * 60)
            if self.__stopped:
                break

    def trigger(self) -> Future:
        """Trigger a `push_to_hub` and return a future.

        This method is automatically called every `every` minutes. You can also call it manually to trigger a commit
        immediately, without waiting for the next scheduled commit.
        """
        return self.api.run_as_future(self._push_to_hub)

    def _push_to_hub(self) -> Optional[CommitInfo]:
        if self.__stopped:  # If stopped, already scheduled commits are ignored
            return None

        logger.info("(Background) scheduled commit triggered.")
        try:
            value = self.push_to_hub()
            if self.squash_history:
                logger.info("(Background) squashing repo history.")
                self.api.super_squash_history(repo_id=self.repo_id, repo_type=self.repo_type, branch=self.revision)
            return value
        except Exception as e:
            logger.error(f"Error while pushing to Hub: {e}")  # Depending on the setup, error might be silenced
            raise

    def push_to_hub(self) -> Optional[CommitInfo]:
        """
        Push folder to the Hub and return the commit info.

        <Tip warning={true}>

        This method is not meant to be called directly. It is run in the background by the scheduler, respecting a
        queue mechanism to avoid concurrent commits. Making a direct call to the method might lead to concurrency
        issues.

        </Tip>

        The default behavior of `push_to_hub` is to assume an append-only folder. It lists all files in the folder and
        uploads only changed files. If no changes are found, the method returns without committing anything. If you want
        to change this behavior, you can inherit from [`CommitScheduler`] and override this method. This can be useful
        for example to compress data together in a single file before committing. For more details and examples, check
        out our [integration guide](https://huggingface.co/docs/huggingface_hub/main/en/guides/upload#scheduled-uploads).
        """
        # Check files to upload (with lock)
        with self.lock:
            logger.debug("Listing files to upload for scheduled commit.")

            # List files from folder (taken from `_prepare_upload_folder_additions`)
            relpath_to_abspath = {
                path.relative_to(self.folder_path).as_posix(): path
                for path in sorted(self.folder_path.glob("**/*"))  # sorted to be deterministic
                if path.is_file()
            }
            prefix = f"{self.path_in_repo.strip('/')}/" if self.path_in_repo else ""

            # Filter with pattern + filter out unchanged files + retrieve current file size
            files_to_upload: List[_FileToUpload] = []
            for relpath in filter_repo_objects(
                relpath_to_abspath.keys(), allow_patterns=self.allow_patterns, ignore_patterns=self.ignore_patterns
            ):
                local_path = relpath_to_abspath[relpath]
                stat = local_path.stat()
                if self.last_uploaded.get(local_path) is None or self.last_uploaded[local_path] != stat.st_mtime:
                    files_to_upload.append(
                        _FileToUpload(
                            local_path=local_path,
                            path_in_repo=prefix + relpath,
                            size_limit=stat.st_size,
                            last_modified=stat.st_mtime,
                        )
                    )

        # Return if nothing to upload
        if len(files_to_upload) == 0:
            logger.debug("Dropping scheduled commit: no changed file to upload.")
            return None

        # Convert `_FileToUpload` as `CommitOperationAdd` (=> compute file shas + limit to file size)
        logger.debug("Removing unchanged files since previous scheduled commit.")
        add_operations = [
            CommitOperationAdd(
                # Cap the file to its current size, even if the user appends data to it while a scheduled commit is happening
                path_or_fileobj=PartialFileIO(file_to_upload.local_path, size_limit=file_to_upload.size_limit),
                path_in_repo=file_to_upload.path_in_repo,
            )
            for file_to_upload in files_to_upload
        ]

        # Upload files (append mode expected - no need for lock)
        logger.debug("Uploading files for scheduled commit.")
        commit_info = self.api.create_commit(
            repo_id=self.repo_id,
            repo_type=self.repo_type,
            operations=add_operations,
            commit_message="Scheduled Commit",
            revision=self.revision,
        )

        # Successful commit: keep track of the latest "last_modified" for each file
        for file in files_to_upload:
            self.last_uploaded[file.local_path] = file.last_modified
        return commit_info


class PartialFileIO(BytesIO):
    """A file-like object that reads only the first part of a file.

    Useful to upload a file to the Hub when the user might still be appending data to it. Only the first part of the
    file is uploaded (i.e. the part that was available when the filesystem was first scanned).

    In practice, only used internally by the CommitScheduler to regularly push a folder to the Hub with minimal
    disturbance for the user. The object is passed to `CommitOperationAdd`.

    Only supports `read`, `tell` and `seek` methods.

    Args:
        file_path (`str` or `Path`):
            Path to the file to read.
        size_limit (`int`):
            The maximum number of bytes to read from the file. If the file is larger than this, only the first part
            will be read (and uploaded).
    """

    def __init__(self, file_path: Union[str, Path], size_limit: int) -> None:
        self._file_path = Path(file_path)
        self._file = self._file_path.open("rb")
        self._size_limit = min(size_limit, os.fstat(self._file.fileno()).st_size)

    def __del__(self) -> None:
        self._file.close()
        return super().__del__()

    def __repr__(self) -> str:
        return f"<PartialFileIO file_path={self._file_path} size_limit={self._size_limit}>"

    def __len__(self) -> int:
        return self._size_limit

    def __getattribute__(self, name: str):
        if name.startswith("_") or name in ("read", "tell", "seek"):  # only 3 public methods supported
            return super().__getattribute__(name)
        raise NotImplementedError(f"PartialFileIO does not support '{name}'.")

    def tell(self) -> int:
        """Return the current file position."""
        return self._file.tell()

    def seek(self, __offset: int, __whence: int = SEEK_SET) -> int:
        """Change the stream position to the given offset.

        Behavior is the same as a regular file, except that the position is capped to the size limit.
        """
        if __whence == SEEK_END:
            # SEEK_END => set from the truncated end
            __offset = len(self) + __offset
            __whence = SEEK_SET

        pos = self._file.seek(__offset, __whence)
        if pos > self._size_limit:
            return self._file.seek(self._size_limit)
        return pos

    def read(self, __size: Optional[int] = -1) -> bytes:
        """Read at most `__size` bytes from the file.

        Behavior is the same as a regular file, except that it is capped to the size limit.
        """
        current = self._file.tell()
        if __size is None or __size < 0:
            # Read until file limit
            truncated_size = self._size_limit - current
        else:
            # Read until file limit or __size
            truncated_size = min(__size, self._size_limit - current)
        return self._file.read(truncated_size)
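The `push_to_hub` docstring above suggests overriding the method to reshape what gets committed. A minimal sketch of such a subclass (not part of the diff; the repo id, the `*.jsonl` glob, and the `consolidated.jsonl` target name are placeholder assumptions):

```python
# Hedged sketch: merge all .jsonl shards into one payload before committing,
# instead of uploading every changed file individually.
from huggingface_hub import CommitOperationAdd, CommitScheduler


class ConsolidatingScheduler(CommitScheduler):
    def push_to_hub(self):
        # Hold the lock while reading, mirroring the parent implementation.
        with self.lock:
            payload = b"".join(
                path.read_bytes() for path in sorted(self.folder_path.glob("*.jsonl"))
            )
        if not payload:
            return None
        return self.api.create_commit(
            repo_id=self.repo_id,
            repo_type=self.repo_type,
            operations=[CommitOperationAdd(path_in_repo="consolidated.jsonl", path_or_fileobj=payload)],
            commit_message="Scheduled consolidated commit",
        )


scheduler = ConsolidatingScheduler(
    repo_id="my-user/my-dataset", repo_type="dataset", folder_path="watched_folder", every=10
)
```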
lib/python3.11/site-packages/huggingface_hub/_inference_endpoints.py
ADDED
@@ -0,0 +1,373 @@
import time
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING, Dict, Optional

from .inference._client import InferenceClient
from .inference._generated._async_client import AsyncInferenceClient
from .utils import logging, parse_datetime


if TYPE_CHECKING:
    from .hf_api import HfApi


logger = logging.get_logger(__name__)


class InferenceEndpointError(Exception):
    """Generic exception when dealing with Inference Endpoints."""


class InferenceEndpointTimeoutError(InferenceEndpointError, TimeoutError):
    """Exception for timeouts while waiting for Inference Endpoint."""


class InferenceEndpointStatus(str, Enum):
    PENDING = "pending"
    INITIALIZING = "initializing"
    UPDATING = "updating"
    UPDATE_FAILED = "updateFailed"
    RUNNING = "running"
    PAUSED = "paused"
    FAILED = "failed"
    SCALED_TO_ZERO = "scaledToZero"


class InferenceEndpointType(str, Enum):
    PUBlIC = "public"
    PROTECTED = "protected"
    PRIVATE = "private"


@dataclass
class InferenceEndpoint:
    """
    Contains information about a deployed Inference Endpoint.

    Args:
        name (`str`):
            The unique name of the Inference Endpoint.
        namespace (`str`):
            The namespace where the Inference Endpoint is located.
        repository (`str`):
            The name of the model repository deployed on this Inference Endpoint.
        status ([`InferenceEndpointStatus`]):
            The current status of the Inference Endpoint.
        url (`str`, *optional*):
            The URL of the Inference Endpoint, if available. Only a deployed Inference Endpoint will have a URL.
        framework (`str`):
            The machine learning framework used for the model.
        revision (`str`):
            The specific model revision deployed on the Inference Endpoint.
        task (`str`):
            The task associated with the deployed model.
        created_at (`datetime.datetime`):
            The timestamp when the Inference Endpoint was created.
        updated_at (`datetime.datetime`):
            The timestamp of the last update of the Inference Endpoint.
        type ([`InferenceEndpointType`]):
            The type of the Inference Endpoint (public, protected, private).
        raw (`Dict`):
            The raw dictionary data returned from the API.
        token (`str`, *optional*):
            Authentication token for the Inference Endpoint, if set when requesting the API.

    Example:
    ```python
    >>> from huggingface_hub import get_inference_endpoint
    >>> endpoint = get_inference_endpoint("my-text-to-image")
    >>> endpoint
    InferenceEndpoint(name='my-text-to-image', ...)

    # Get status
    >>> endpoint.status
    'running'
    >>> endpoint.url
    'https://my-text-to-image.region.vendor.endpoints.huggingface.cloud'

    # Run inference
    >>> endpoint.client.text_to_image(...)

    # Pause endpoint to save $$$
    >>> endpoint.pause()

    # ...
    # Resume and wait for deployment
    >>> endpoint.resume()
    >>> endpoint.wait()
    >>> endpoint.client.text_to_image(...)
    ```
    """

    # Field in __repr__
    name: str = field(init=False)
    namespace: str
    repository: str = field(init=False)
    status: InferenceEndpointStatus = field(init=False)
    url: Optional[str] = field(init=False)

    # Other fields
    framework: str = field(repr=False, init=False)
    revision: str = field(repr=False, init=False)
    task: str = field(repr=False, init=False)
    created_at: datetime = field(repr=False, init=False)
    updated_at: datetime = field(repr=False, init=False)
    type: InferenceEndpointType = field(repr=False, init=False)

    # Raw dict from the API
    raw: Dict = field(repr=False)

    # Internal fields
    _token: Optional[str] = field(repr=False, compare=False)
    _api: "HfApi" = field(repr=False, compare=False)

    @classmethod
    def from_raw(
        cls, raw: Dict, namespace: str, token: Optional[str] = None, api: Optional["HfApi"] = None
    ) -> "InferenceEndpoint":
        """Initialize object from raw dictionary."""
        if api is None:
            from .hf_api import HfApi

            api = HfApi()
        if token is None:
            token = api.token

        # All other fields are populated in __post_init__
        return cls(raw=raw, namespace=namespace, _token=token, _api=api)

    def __post_init__(self) -> None:
        """Populate fields from raw dictionary."""
        self._populate_from_raw()

    @property
    def client(self) -> InferenceClient:
        """Returns a client to make predictions on this Inference Endpoint.

        Returns:
            [`InferenceClient`]: an inference client pointing to the deployed endpoint.

        Raises:
            [`InferenceEndpointError`]: If the Inference Endpoint is not yet deployed.
        """
        if self.url is None:
            raise InferenceEndpointError(
                "Cannot create a client for this Inference Endpoint as it is not yet deployed. "
                "Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again."
            )
        return InferenceClient(model=self.url, token=self._token)

    @property
    def async_client(self) -> AsyncInferenceClient:
        """Returns a client to make predictions on this Inference Endpoint.

        Returns:
            [`AsyncInferenceClient`]: an asyncio-compatible inference client pointing to the deployed endpoint.

        Raises:
            [`InferenceEndpointError`]: If the Inference Endpoint is not yet deployed.
        """
        if self.url is None:
            raise InferenceEndpointError(
                "Cannot create a client for this Inference Endpoint as it is not yet deployed. "
                "Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again."
            )
        return AsyncInferenceClient(model=self.url, token=self._token)

    def wait(self, timeout: Optional[int] = None, refresh_every: int = 5) -> "InferenceEndpoint":
        """Wait for the Inference Endpoint to be deployed.

        Information from the server will be fetched every `refresh_every` seconds. If the Inference Endpoint is not
        deployed after `timeout` seconds, an [`InferenceEndpointTimeoutError`] will be raised. The
        [`InferenceEndpoint`] will be mutated in place with the latest data.

        Args:
            timeout (`int`, *optional*):
                The maximum time to wait for the Inference Endpoint to be deployed, in seconds. If `None`, will wait
                indefinitely.
            refresh_every (`int`, *optional*):
                The time to wait between each fetch of the Inference Endpoint status, in seconds. Defaults to 5s.

        Returns:
            [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
        """
        if self.url is not None:  # Means the endpoint is deployed
            logger.info("Inference Endpoint is ready to be used.")
            return self

        if timeout is not None and timeout < 0:
            raise ValueError("`timeout` cannot be negative.")
        if refresh_every <= 0:
            raise ValueError("`refresh_every` must be positive.")

        start = time.time()
        while True:
            self.fetch()
            if self.url is not None:  # Means the endpoint is deployed
                logger.info("Inference Endpoint is ready to be used.")
                return self
            if timeout is not None:
                if time.time() - start > timeout:
                    raise InferenceEndpointTimeoutError("Timeout while waiting for Inference Endpoint to be deployed.")
            logger.info(f"Inference Endpoint is not deployed yet ({self.status}). Waiting {refresh_every}s...")
            time.sleep(refresh_every)

    def fetch(self) -> "InferenceEndpoint":
        """Fetch latest information about the Inference Endpoint.

        Returns:
            [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
        """
        obj = self._api.get_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token)
        self.raw = obj.raw
        self._populate_from_raw()
        return self

    def update(
        self,
        *,
        # Compute update
        accelerator: Optional[str] = None,
        instance_size: Optional[str] = None,
        instance_type: Optional[str] = None,
        min_replica: Optional[int] = None,
        max_replica: Optional[int] = None,
        # Model update
        repository: Optional[str] = None,
        framework: Optional[str] = None,
        revision: Optional[str] = None,
        task: Optional[str] = None,
    ) -> "InferenceEndpoint":
        """Update the Inference Endpoint.

        This method allows the update of either the compute configuration, the deployed model, or both. All arguments
        are optional but at least one must be provided.

        This is an alias for [`HfApi.update_inference_endpoint`]. The current object is mutated in place with the
        latest data from the server.

        Args:
            accelerator (`str`, *optional*):
                The hardware accelerator to be used for inference (e.g. `"cpu"`).
            instance_size (`str`, *optional*):
                The size or type of the instance to be used for hosting the model (e.g. `"large"`).
            instance_type (`str`, *optional*):
                The cloud instance type where the Inference Endpoint will be deployed (e.g. `"c6i"`).
            min_replica (`int`, *optional*):
                The minimum number of replicas (instances) to keep running for the Inference Endpoint.
            max_replica (`int`, *optional*):
                The maximum number of replicas (instances) to scale to for the Inference Endpoint.
|
262 |
+
|
263 |
+
repository (`str`, *optional*):
|
264 |
+
The name of the model repository associated with the Inference Endpoint (e.g. `"gpt2"`).
|
265 |
+
framework (`str`, *optional*):
|
266 |
+
The machine learning framework used for the model (e.g. `"custom"`).
|
267 |
+
revision (`str`, *optional*):
|
268 |
+
The specific model revision to deploy on the Inference Endpoint (e.g. `"6c0e6080953db56375760c0471a8c5f2929baf11"`).
|
269 |
+
task (`str`, *optional*):
|
270 |
+
The task on which to deploy the model (e.g. `"text-classification"`).
|
271 |
+
|
272 |
+
Returns:
|
273 |
+
[`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
|
274 |
+
"""
|
275 |
+
# Make API call
|
276 |
+
obj = self._api.update_inference_endpoint(
|
277 |
+
name=self.name,
|
278 |
+
namespace=self.namespace,
|
279 |
+
accelerator=accelerator,
|
280 |
+
instance_size=instance_size,
|
281 |
+
instance_type=instance_type,
|
282 |
+
min_replica=min_replica,
|
283 |
+
max_replica=max_replica,
|
284 |
+
repository=repository,
|
285 |
+
framework=framework,
|
286 |
+
revision=revision,
|
287 |
+
task=task,
|
288 |
+
token=self._token,
|
289 |
+
)
|
290 |
+
|
291 |
+
# Mutate current object
|
292 |
+
self.raw = obj.raw
|
293 |
+
self._populate_from_raw()
|
294 |
+
return self
|
295 |
+
|
296 |
+
def pause(self) -> "InferenceEndpoint":
|
297 |
+
"""Pause the Inference Endpoint.
|
298 |
+
|
299 |
+
A paused Inference Endpoint will not be charged. It can be resumed at any time using [`InferenceEndpoint.resume`].
|
300 |
+
This is different than scaling the Inference Endpoint to zero with [`InferenceEndpoint.scale_to_zero`], which
|
301 |
+
would be automatically restarted when a request is made to it.
|
302 |
+
|
303 |
+
This is an alias for [`HfApi.pause_inference_endpoint`]. The current object is mutated in place with the
|
304 |
+
latest data from the server.
|
305 |
+
|
306 |
+
Returns:
|
307 |
+
[`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
|
308 |
+
"""
|
309 |
+
obj = self._api.pause_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token)
|
310 |
+
self.raw = obj.raw
|
311 |
+
self._populate_from_raw()
|
312 |
+
return self
|
313 |
+
|
314 |
+
def resume(self) -> "InferenceEndpoint":
|
315 |
+
"""Resume the Inference Endpoint.
|
316 |
+
|
317 |
+
This is an alias for [`HfApi.resume_inference_endpoint`]. The current object is mutated in place with the
|
318 |
+
latest data from the server.
|
319 |
+
|
320 |
+
Returns:
|
321 |
+
[`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
|
322 |
+
"""
|
323 |
+
obj = self._api.resume_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token)
|
324 |
+
self.raw = obj.raw
|
325 |
+
self._populate_from_raw()
|
326 |
+
return self
|
327 |
+
|
328 |
+
def scale_to_zero(self) -> "InferenceEndpoint":
|
329 |
+
"""Scale Inference Endpoint to zero.
|
330 |
+
|
331 |
+
An Inference Endpoint scaled to zero will not be charged. It will be resume on the next request to it, with a
|
332 |
+
cold start delay. This is different than pausing the Inference Endpoint with [`InferenceEndpoint.pause`], which
|
333 |
+
would require a manual resume with [`InferenceEndpoint.resume`].
|
334 |
+
|
335 |
+
This is an alias for [`HfApi.scale_to_zero_inference_endpoint`]. The current object is mutated in place with the
|
336 |
+
latest data from the server.
|
337 |
+
|
338 |
+
Returns:
|
339 |
+
[`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
|
340 |
+
"""
|
341 |
+
obj = self._api.scale_to_zero_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token)
|
342 |
+
self.raw = obj.raw
|
343 |
+
self._populate_from_raw()
|
344 |
+
return self
|
345 |
+
|
346 |
+
def delete(self) -> None:
|
347 |
+
"""Delete the Inference Endpoint.
|
348 |
+
|
349 |
+
This operation is not reversible. If you don't want to be charged for an Inference Endpoint, it is preferable
|
350 |
+
to pause it with [`InferenceEndpoint.pause`] or scale it to zero with [`InferenceEndpoint.scale_to_zero`].
|
351 |
+
|
352 |
+
This is an alias for [`HfApi.delete_inference_endpoint`].
|
353 |
+
"""
|
354 |
+
self._api.delete_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token)
|
355 |
+
|
356 |
+
def _populate_from_raw(self) -> None:
|
357 |
+
"""Populate fields from raw dictionary.
|
358 |
+
|
359 |
+
Called in __post_init__ + each time the Inference Endpoint is updated.
|
360 |
+
"""
|
361 |
+
# Repr fields
|
362 |
+
self.name = self.raw["name"]
|
363 |
+
self.repository = self.raw["model"]["repository"]
|
364 |
+
self.status = self.raw["status"]["state"]
|
365 |
+
self.url = self.raw["status"].get("url")
|
366 |
+
|
367 |
+
# Other fields
|
368 |
+
self.framework = self.raw["model"]["framework"]
|
369 |
+
self.revision = self.raw["model"]["revision"]
|
370 |
+
self.task = self.raw["model"]["task"]
|
371 |
+
self.created_at = parse_datetime(self.raw["status"]["createdAt"])
|
372 |
+
self.updated_at = parse_datetime(self.raw["status"]["updatedAt"])
|
373 |
+
self.type = self.raw["type"]
|
lib/python3.11/site-packages/huggingface_hub/_login.py
ADDED
@@ -0,0 +1,395 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains methods to login to the Hub."""
import os
import subprocess
from functools import partial
from getpass import getpass
from pathlib import Path
from typing import Optional

from . import constants
from .commands._cli_utils import ANSI
from .utils import (
    capture_output,
    get_token,
    is_google_colab,
    is_notebook,
    list_credential_helpers,
    logging,
    run_subprocess,
    set_git_credential,
    unset_git_credential,
)
from .utils._token import _get_token_from_environment, _get_token_from_google_colab


logger = logging.get_logger(__name__)

_HF_LOGO_ASCII = """
    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|
"""


def login(
    token: Optional[str] = None,
    add_to_git_credential: bool = False,
    new_session: bool = True,
    write_permission: bool = False,
) -> None:
    """Login the machine to access the Hub.

    The `token` is persisted in cache and set as a git credential. Once done, the machine
    is logged in and the access token will be available across all `huggingface_hub`
    components. If `token` is not provided, it will be prompted to the user either with
    a widget (in a notebook) or via the terminal.

    To log in from outside of a script, one can also use `huggingface-cli login`, which is
    a CLI command that wraps [`login`].

    <Tip>

    [`login`] is a drop-in replacement method for [`notebook_login`] as it wraps and
    extends its capabilities.

    </Tip>

    <Tip>

    When the token is not passed, [`login`] will automatically detect if the script runs
    in a notebook or not. However, this detection might not be accurate due to the
    variety of notebooks that exist nowadays. If that is the case, you can always force
    the UI by using [`notebook_login`] or [`interpreter_login`].

    </Tip>

    Args:
        token (`str`, *optional*):
            User access token to generate from https://huggingface.co/settings/token.
        add_to_git_credential (`bool`, defaults to `False`):
            If `True`, token will be set as git credential. If no git credential helper
            is configured, a warning will be displayed to the user. If `token` is `None`,
            the value of `add_to_git_credential` is ignored and will be prompted again
            to the end user.
        new_session (`bool`, defaults to `True`):
            If `True`, will request a token even if one is already saved on the machine.
        write_permission (`bool`, defaults to `False`):
            If `True`, requires a token with write permission.
    Raises:
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If an organization token is passed. Only personal account tokens are valid
            to login.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If token is invalid.
        [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
            If running in a notebook but `ipywidgets` is not installed.
    """
    if token is not None:
        if not add_to_git_credential:
            print(
                "Token will not be saved to git credential helper. Pass"
                " `add_to_git_credential=True` if you want to set the git"
                " credential as well."
            )
        _login(token, add_to_git_credential=add_to_git_credential, write_permission=write_permission)
    elif is_notebook():
        notebook_login(new_session=new_session, write_permission=write_permission)
    else:
        interpreter_login(new_session=new_session, write_permission=write_permission)


def logout() -> None:
    """Logout the machine from the Hub.

    Token is deleted from the machine and removed from git credential.
    """
    if get_token() is None:
        print("Not logged in!")
        return

    # Delete token from git credentials
    unset_git_credential()

    # Delete token file
    try:
        Path(constants.HF_TOKEN_PATH).unlink()
    except FileNotFoundError:
        pass

    # Check if still logged in
    if _get_token_from_google_colab() is not None:
        raise EnvironmentError(
            "You are automatically logged in using a Google Colab secret.\n"
            "To log out, you must unset the `HF_TOKEN` secret in your Colab settings."
        )
    if _get_token_from_environment() is not None:
        raise EnvironmentError(
            "Token has been deleted from your machine but you are still logged in.\n"
            "To log out, you must clear out both `HF_TOKEN` and `HUGGING_FACE_HUB_TOKEN` environment variables."
        )

    print("Successfully logged out.")


###
# Interpreter-based login (text)
###


def interpreter_login(new_session: bool = True, write_permission: bool = False) -> None:
    """
    Displays a prompt to login to the HF website and store the token.

    This is equivalent to [`login`] without passing a token when not run in a notebook.
    [`interpreter_login`] is useful if you want to force the use of the terminal prompt
    instead of a notebook widget.

    For more details, see [`login`].

    Args:
        new_session (`bool`, defaults to `True`):
            If `True`, will request a token even if one is already saved on the machine.
        write_permission (`bool`, defaults to `False`):
            If `True`, requires a token with write permission.
    """
    if not new_session and _current_token_okay(write_permission=write_permission):
        print("User is already logged in.")
        return

    from .commands.delete_cache import _ask_for_confirmation_no_tui

    print(_HF_LOGO_ASCII)
    if get_token() is not None:
        print(
            "    A token is already saved on your machine. Run `huggingface-cli"
            " whoami` to get more information or `huggingface-cli logout` if you want"
            " to log out."
        )
        print("    Setting a new token will erase the existing one.")

    print("    To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .")
    if os.name == "nt":
        print("Token can be pasted using 'Right-Click'.")
    token = getpass("Token: ")
    add_to_git_credential = _ask_for_confirmation_no_tui("Add token as git credential?")

    _login(token=token, add_to_git_credential=add_to_git_credential, write_permission=write_permission)


###
# Notebook-based login (widget)
###

NOTEBOOK_LOGIN_PASSWORD_HTML = """<center> <img
src=https://huggingface.co/front/assets/huggingface_logo-noborder.svg
alt='Hugging Face'> <br> Immediately click login after typing your password or
it might be stored in plain text in this notebook file. </center>"""


NOTEBOOK_LOGIN_TOKEN_HTML_START = """<center> <img
src=https://huggingface.co/front/assets/huggingface_logo-noborder.svg
alt='Hugging Face'> <br> Copy a token from <a
href="https://huggingface.co/settings/tokens" target="_blank">your Hugging Face
tokens page</a> and paste it below. <br> Immediately click login after copying
your token or it might be stored in plain text in this notebook file. </center>"""


NOTEBOOK_LOGIN_TOKEN_HTML_END = """
<b>Pro Tip:</b> If you don't already have one, you can create a dedicated
'notebooks' token with 'write' access, that you can then easily reuse for all
notebooks. </center>"""


def notebook_login(new_session: bool = True, write_permission: bool = False) -> None:
    """
    Displays a widget to login to the HF website and store the token.

    This is equivalent to [`login`] without passing a token when run in a notebook.
    [`notebook_login`] is useful if you want to force the use of the notebook widget
    instead of a prompt in the terminal.

    For more details, see [`login`].

    Args:
        new_session (`bool`, defaults to `True`):
            If `True`, will request a token even if one is already saved on the machine.
        write_permission (`bool`, defaults to `False`):
            If `True`, requires a token with write permission.
    """
    try:
        import ipywidgets.widgets as widgets  # type: ignore
        from IPython.display import display  # type: ignore
    except ImportError:
        raise ImportError(
            "The `notebook_login` function can only be used in a notebook (Jupyter or"
            " Colab) and you need the `ipywidgets` module: `pip install ipywidgets`."
        )
    if not new_session and _current_token_okay(write_permission=write_permission):
        print("User is already logged in.")
        return

    box_layout = widgets.Layout(display="flex", flex_flow="column", align_items="center", width="50%")

    token_widget = widgets.Password(description="Token:")
    git_checkbox_widget = widgets.Checkbox(value=True, description="Add token as git credential?")
    token_finish_button = widgets.Button(description="Login")

    login_token_widget = widgets.VBox(
        [
            widgets.HTML(NOTEBOOK_LOGIN_TOKEN_HTML_START),
            token_widget,
            git_checkbox_widget,
            token_finish_button,
            widgets.HTML(NOTEBOOK_LOGIN_TOKEN_HTML_END),
        ],
        layout=box_layout,
    )
    display(login_token_widget)

    # On click events
    def login_token_event(t, write_permission: bool = False):
        """
        Event handler for the login button.

        Args:
            write_permission (`bool`, defaults to `False`):
                If `True`, requires a token with write permission.
        """
        token = token_widget.value
        add_to_git_credential = git_checkbox_widget.value
        # Erase token and clear value to make sure it's not saved in the notebook.
        token_widget.value = ""
        # Hide inputs
        login_token_widget.children = [widgets.Label("Connecting...")]
        try:
            with capture_output() as captured:
                _login(token, add_to_git_credential=add_to_git_credential, write_permission=write_permission)
            message = captured.getvalue()
        except Exception as error:
            message = str(error)
        # Print result (success message or error)
        login_token_widget.children = [widgets.Label(line) for line in message.split("\n") if line.strip()]

    token_finish_button.on_click(partial(login_token_event, write_permission=write_permission))


###
# Login private helpers
###


def _login(token: str, add_to_git_credential: bool, write_permission: bool = False) -> None:
    from .hf_api import get_token_permission  # avoid circular import

    if token.startswith("api_org"):
        raise ValueError("You must use your personal account token, not an organization token.")

    permission = get_token_permission(token)
    if permission is None:
        raise ValueError("Invalid token passed!")
    elif write_permission and permission != "write":
        raise ValueError(
            "Token is valid but is 'read-only' and a 'write' token is required.\nPlease provide a new token with"
            " correct permission."
        )
    print(f"Token is valid (permission: {permission}).")

    if add_to_git_credential:
        if _is_git_credential_helper_configured():
            set_git_credential(token)
            print(
                "Your token has been saved in your configured git credential helpers"
                + f" ({','.join(list_credential_helpers())})."
            )
        else:
            print("Token has not been saved to git credential helper.")

    # Save token
    path = Path(constants.HF_TOKEN_PATH)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(token)
    print(f"Your token has been saved to {constants.HF_TOKEN_PATH}")
    print("Login successful")


def _current_token_okay(write_permission: bool = False):
    """Check if the current token is valid.

    Args:
        write_permission (`bool`, defaults to `False`):
            If `True`, requires a token with write permission.

    Returns:
        `bool`: `True` if the current token is valid, `False` otherwise.
    """
    from .hf_api import get_token_permission  # avoid circular import

    permission = get_token_permission()
    if permission is None or (write_permission and permission != "write"):
        return False
    return True


def _is_git_credential_helper_configured() -> bool:
    """Check if a git credential helper is configured.

    Warns user if not the case (except for Google Colab where "store" is set by default
    by `huggingface_hub`).
    """
    helpers = list_credential_helpers()
    if len(helpers) > 0:
        return True  # Do not warn: at least 1 helper is set

    # Only in Google Colab to avoid the warning message
    # See https://github.com/huggingface/huggingface_hub/issues/1043#issuecomment-1247010710
    if is_google_colab():
        _set_store_as_git_credential_helper_globally()
        return True  # Do not warn: "store" is used by default in Google Colab

    # Otherwise, warn user
    print(
        ANSI.red(
            "Cannot authenticate through git-credential as no helper is defined on your"
            " machine.\nYou might have to re-authenticate when pushing to the Hugging"
            " Face Hub.\nRun the following command in your terminal in case you want to"
            " set the 'store' credential helper as default.\n\ngit config --global"
            " credential.helper store\n\nRead"
            " https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage for more"
            " details."
        )
    )
    return False


def _set_store_as_git_credential_helper_globally() -> None:
    """Set globally the credential.helper to `store`.

    To be used only in Google Colab as we assume the user doesn't care about the git
    credential config. It is the only particular case where we don't want to display the
    warning message in [`notebook_login()`].

    Related:
    - https://github.com/huggingface/huggingface_hub/issues/1043
    - https://github.com/huggingface/huggingface_hub/issues/1051
    - https://git-scm.com/docs/git-credential-store
    """
    try:
        run_subprocess("git config --global credential.helper store")
    except subprocess.CalledProcessError as exc:
        raise EnvironmentError(exc.stderr)
lib/python3.11/site-packages/huggingface_hub/_multi_commits.py
ADDED
@@ -0,0 +1,305 @@
# coding=utf-8
# Copyright 2023-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains utilities for multi-commits (i.e. pushing changes iteratively on a PR)."""
import re
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple, Union

from ._commit_api import CommitOperationAdd, CommitOperationDelete
from .community import DiscussionWithDetails
from .utils import experimental
from .utils._cache_manager import _format_size
from .utils.insecure_hashlib import sha256


if TYPE_CHECKING:
    from .hf_api import HfApi


class MultiCommitException(Exception):
    """Base exception for any exception happening while doing a multi-commit."""


MULTI_COMMIT_PR_DESCRIPTION_TEMPLATE = """
## {commit_message}

{commit_description}

**Multi commit ID:** {multi_commit_id}

Scheduled commits:

{multi_commit_strategy}

_This is a PR opened using the `huggingface_hub` library in the context of a multi-commit. The PR can be commented on as a usual PR. However, please be aware that manually updating the PR description, changing the PR status, or pushing new commits is not recommended, as it might corrupt the commit process. Learn more about multi-commits [in this guide](https://huggingface.co/docs/huggingface_hub/main/guides/upload)._
"""

MULTI_COMMIT_PR_COMPLETION_COMMENT_TEMPLATE = """
Multi-commit is now completed! You can ping the repo owner to review the changes. This PR can now be commented on or modified without risking corrupting it.

_This is a comment posted using the `huggingface_hub` library in the context of a multi-commit. Learn more about multi-commits [in this guide](https://huggingface.co/docs/huggingface_hub/main/guides/upload)._
"""

MULTI_COMMIT_PR_CLOSING_COMMENT_TEMPLATE = """
`create_pr=False` has been passed so the PR is automatically merged.

_This is a comment posted using the `huggingface_hub` library in the context of a multi-commit. Learn more about multi-commits [in this guide](https://huggingface.co/docs/huggingface_hub/main/guides/upload)._
"""

MULTI_COMMIT_PR_CLOSE_COMMENT_FAILURE_NO_CHANGES_TEMPLATE = """
Cannot merge the Pull Request as no changes are associated with it. This PR will be closed automatically.

_This is a comment posted using the `huggingface_hub` library in the context of a multi-commit. Learn more about multi-commits [in this guide](https://huggingface.co/docs/huggingface_hub/main/guides/upload)._
"""

MULTI_COMMIT_PR_CLOSE_COMMENT_FAILURE_BAD_REQUEST_TEMPLATE = """
An error occurred while trying to merge the Pull Request: `{error_message}`.

_This is a comment posted using the `huggingface_hub` library in the context of a multi-commit. Learn more about multi-commits [in this guide](https://huggingface.co/docs/huggingface_hub/main/guides/upload)._
"""


STEP_ID_REGEX = re.compile(r"- \[(?P<completed>[ |x])\].*(?P<step_id>[a-fA-F0-9]{64})", flags=re.MULTILINE)


@experimental
def plan_multi_commits(
    operations: Iterable[Union[CommitOperationAdd, CommitOperationDelete]],
    max_operations_per_commit: int = 50,
    max_upload_size_per_commit: int = 2 * 1024 * 1024 * 1024,
) -> Tuple[List[List[CommitOperationAdd]], List[List[CommitOperationDelete]]]:
    """Split a list of operations into a list of commits to perform.

    Implementation follows a sub-optimal (yet simple) algorithm:
    1. Delete operations are grouped together in commits of at most `max_operations_per_commit` operations.
    2. All additions exceeding `max_upload_size_per_commit` are committed 1 by 1.
    3. All remaining additions are grouped together and split each time the `max_operations_per_commit` or the
       `max_upload_size_per_commit` limit is reached.

    We do not try to optimize the splitting to get the lowest number of commits as this is a NP-hard problem (see
    [bin packing problem](https://en.wikipedia.org/wiki/Bin_packing_problem)). For our use case, it is not problematic
    to use a sub-optimal solution so we favored an easy-to-explain implementation.

    Args:
        operations (`List` of [`~hf_api.CommitOperation`]):
            The list of operations to split into commits.
        max_operations_per_commit (`int`):
            Maximum number of operations in a single commit. Defaults to 50.
        max_upload_size_per_commit (`int`):
            Maximum size to upload (in bytes) in a single commit. Defaults to 2GB. Files bigger than this limit are
            uploaded, 1 per commit.

    Returns:
        `Tuple[List[List[CommitOperationAdd]], List[List[CommitOperationDelete]]]`: a tuple. The first item is a
        list of lists of [`CommitOperationAdd`] representing the addition commits to push. The second item is a list
        of lists of [`CommitOperationDelete`] representing the deletion commits.

    <Tip warning={true}>

    `plan_multi_commits` is experimental. Its API and behavior are subject to change in the future without prior
    notice.

    </Tip>

    Example:
    ```python
    >>> from huggingface_hub import HfApi, plan_multi_commits
    >>> addition_commits, deletion_commits = plan_multi_commits(
    ...     operations=[
    ...         CommitOperationAdd(...),
    ...         CommitOperationAdd(...),
    ...         CommitOperationDelete(...),
    ...         CommitOperationDelete(...),
    ...         CommitOperationAdd(...),
    ...     ],
    ... )
    >>> HfApi().create_commits_on_pr(
    ...     repo_id="my-cool-model",
    ...     addition_commits=addition_commits,
    ...     deletion_commits=deletion_commits,
    ...     (...)
    ...     verbose=True,
    ... )
    ```

    <Tip warning={true}>

    The initial order of the operations is not guaranteed! All deletions will be performed before additions. If you
    are not updating the same file multiple times, you are fine.

    </Tip>
    """
    addition_commits: List[List[CommitOperationAdd]] = []
    deletion_commits: List[List[CommitOperationDelete]] = []

    additions: List[CommitOperationAdd] = []
    additions_size = 0
    deletions: List[CommitOperationDelete] = []
    for op in operations:
        if isinstance(op, CommitOperationDelete):
            # Group delete operations together
            deletions.append(op)
            if len(deletions) >= max_operations_per_commit:
                deletion_commits.append(deletions)
                deletions = []

        elif op.upload_info.size >= max_upload_size_per_commit:
            # Upload huge files 1 by 1
            addition_commits.append([op])

        elif additions_size + op.upload_info.size < max_upload_size_per_commit:
            # Group other additions and split if size limit is reached (either max_nb_files or max_upload_size)
            additions.append(op)
            additions_size += op.upload_info.size

        else:
            addition_commits.append(additions)
            additions = [op]
            additions_size = op.upload_info.size

        if len(additions) >= max_operations_per_commit:
            addition_commits.append(additions)
            additions = []
            additions_size = 0

    if len(additions) > 0:
        addition_commits.append(additions)
    if len(deletions) > 0:
        deletion_commits.append(deletions)

    return addition_commits, deletion_commits


@dataclass
class MultiCommitStep:
    """Dataclass containing a list of CommitOperation to commit at once.

    A [`MultiCommitStep`] is one atomic part of a [`MultiCommitStrategy`]. Each step is identified by its own
    deterministic ID based on the list of commit operations (hexadecimal sha256). The ID is persistent between
    re-runs if the list of commits is kept the same.
    """

    operations: List[Union[CommitOperationAdd, CommitOperationDelete]]

    id: str = field(init=False)
    completed: bool = False

    def __post_init__(self) -> None:
        if len(self.operations) == 0:
            raise ValueError("A MultiCommitStep must have at least 1 commit operation, got 0.")

        # Generate commit id
        sha = sha256()
        for op in self.operations:
            if isinstance(op, CommitOperationAdd):
                sha.update(b"ADD")
                sha.update(op.path_in_repo.encode())
                sha.update(op.upload_info.sha256)
            elif isinstance(op, CommitOperationDelete):
                sha.update(b"DELETE")
                sha.update(op.path_in_repo.encode())
                sha.update(str(op.is_folder).encode())
            else:
                # Note: the original was missing `raise`, silently discarding the error
                raise NotImplementedError()
        self.id = sha.hexdigest()

    def __str__(self) -> str:
        """Format a step for the PR description.

        Formatting can be changed in the future as long as it is a single line, starts with `- [ ]`/`- [x]` and
        contains `self.id`. Must be able to match `STEP_ID_REGEX`.
        """
        additions = [op for op in self.operations if isinstance(op, CommitOperationAdd)]
        file_deletions = [op for op in self.operations if isinstance(op, CommitOperationDelete) and not op.is_folder]
        folder_deletions = [op for op in self.operations if isinstance(op, CommitOperationDelete) and op.is_folder]
        if len(additions) > 0:
            return (
                f"- [{'x' if self.completed else ' '}] Upload {len(additions)} file(s) "
                f"totalling {_format_size(sum(add.upload_info.size for add in additions))}"
                f" ({self.id})"
            )
        else:
            return (
                f"- [{'x' if self.completed else ' '}] Delete {len(file_deletions)} file(s) and"
                f" {len(folder_deletions)} folder(s) ({self.id})"
            )


@dataclass
class MultiCommitStrategy:
    """Dataclass containing a list of [`MultiCommitStep`] to commit iteratively.

    A strategy is identified by its own deterministic ID based on the list of its steps (hexadecimal sha256). The ID
    is persistent between re-runs if the list of commits is kept the same.
    """

    addition_commits: List[MultiCommitStep]
    deletion_commits: List[MultiCommitStep]

    id: str = field(init=False)
    all_steps: Set[str] = field(init=False)

    def __post_init__(self) -> None:
        self.all_steps = {step.id for step in self.addition_commits + self.deletion_commits}
        if len(self.all_steps) < len(self.addition_commits) + len(self.deletion_commits):
            raise ValueError("Got duplicate commits in MultiCommitStrategy. All commits must be unique.")

        if len(self.all_steps) == 0:
            raise ValueError("A MultiCommitStrategy must have at least 1 commit, got 0.")

        # Generate strategy id
        sha = sha256()
        for step in self.addition_commits + self.deletion_commits:
            sha.update("new step".encode())
            sha.update(step.id.encode())
        self.id = sha.hexdigest()


def multi_commit_create_pull_request(
    api: "HfApi",
    repo_id: str,
    commit_message: str,
    commit_description: Optional[str],
    strategy: MultiCommitStrategy,
    token: Optional[str],
    repo_type: Optional[str],
) -> DiscussionWithDetails:
    return api.create_pull_request(
        repo_id=repo_id,
        title=f"[WIP] {commit_message} (multi-commit {strategy.id})",
        description=multi_commit_generate_comment(
            commit_message=commit_message, commit_description=commit_description, strategy=strategy
        ),
        token=token,
        repo_type=repo_type,
    )


def multi_commit_generate_comment(
    commit_message: str,
    commit_description: Optional[str],
    strategy: MultiCommitStrategy,
) -> str:
    return MULTI_COMMIT_PR_DESCRIPTION_TEMPLATE.format(
        commit_message=commit_message,
        commit_description=commit_description or "",
        multi_commit_id=strategy.id,
        multi_commit_strategy="\n".join(
            str(commit) for commit in strategy.deletion_commits + strategy.addition_commits
        ),
    )


def multi_commit_parse_pr_description(description: str) -> Set[str]:
    return {match[1] for match in STEP_ID_REGEX.findall(description)}
lib/python3.11/site-packages/huggingface_hub/_snapshot_download.py
ADDED
@@ -0,0 +1,319 @@
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Dict, List, Literal, Optional, Union
|
4 |
+
|
5 |
+
import requests
|
6 |
+
from tqdm.auto import tqdm as base_tqdm
|
7 |
+
from tqdm.contrib.concurrent import thread_map
|
8 |
+
|
9 |
+
from .constants import (
|
10 |
+
DEFAULT_ETAG_TIMEOUT,
|
11 |
+
DEFAULT_REVISION,
|
12 |
+
HF_HUB_CACHE,
|
13 |
+
HF_HUB_ENABLE_HF_TRANSFER,
|
14 |
+
REPO_TYPES,
|
15 |
+
)
|
16 |
+
from .file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name
|
17 |
+
from .hf_api import DatasetInfo, HfApi, ModelInfo, SpaceInfo
|
18 |
+
from .utils import (
|
19 |
+
GatedRepoError,
|
20 |
+
LocalEntryNotFoundError,
|
21 |
+
OfflineModeIsEnabled,
|
22 |
+
RepositoryNotFoundError,
|
23 |
+
RevisionNotFoundError,
|
24 |
+
filter_repo_objects,
|
25 |
+
logging,
|
26 |
+
validate_hf_hub_args,
|
27 |
+
)
|
28 |
+
from .utils import tqdm as hf_tqdm
|
29 |
+
|
30 |
+
|
31 |
+
logger = logging.get_logger(__name__)
|
32 |
+
|
33 |
+
|
34 |
+
@validate_hf_hub_args
|
35 |
+
def snapshot_download(
|
36 |
+
repo_id: str,
|
37 |
+
*,
|
38 |
+
repo_type: Optional[str] = None,
|
39 |
+
revision: Optional[str] = None,
|
40 |
+
cache_dir: Union[str, Path, None] = None,
|
41 |
+
local_dir: Union[str, Path, None] = None,
|
42 |
+
local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
|
43 |
+
library_name: Optional[str] = None,
|
44 |
+
library_version: Optional[str] = None,
|
45 |
+
user_agent: Optional[Union[Dict, str]] = None,
|
46 |
+
proxies: Optional[Dict] = None,
|
47 |
+
etag_timeout: float = DEFAULT_ETAG_TIMEOUT,
|
48 |
+
resume_download: bool = False,
|
49 |
+
force_download: bool = False,
|
50 |
+
token: Optional[Union[bool, str]] = None,
|
51 |
+
local_files_only: bool = False,
|
52 |
+
allow_patterns: Optional[Union[List[str], str]] = None,
|
53 |
+
ignore_patterns: Optional[Union[List[str], str]] = None,
|
54 |
+
max_workers: int = 8,
|
55 |
+
tqdm_class: Optional[base_tqdm] = None,
|
56 |
+
endpoint: Optional[str] = None,
|
57 |
+
) -> str:
|
58 |
+
"""Download repo files.
|
59 |
+
|
60 |
+
Download a whole snapshot of a repo's files at the specified revision. This is useful when you want all files from
|
61 |
+
a repo, because you don't know which ones you will need a priori. All files are nested inside a folder in order
|
62 |
+
to keep their actual filename relative to that folder. You can also filter which files to download using
|
63 |
+
`allow_patterns` and `ignore_patterns`.
|
64 |
+
|
65 |
+
If `local_dir` is provided, the file structure from the repo will be replicated in this location. You can configure
|
66 |
+
how you want to move those files:
|
67 |
+
- If `local_dir_use_symlinks="auto"` (default), files are downloaded and stored in the cache directory as blob
|
68 |
+
files. Small files (<5MB) are duplicated in `local_dir` while a symlink is created for bigger files. The goal
|
69 |
+
is to be able to manually edit and save small files without corrupting the cache while saving disk space for
|
70 |
+
binary files. The 5MB threshold can be configured with the `HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD`
|
71 |
+
environment variable.
|
72 |
+
- If `local_dir_use_symlinks=True`, files are downloaded, stored in the cache directory and symlinked in `local_dir`.
|
73 |
+
This is optimal in term of disk usage but files must not be manually edited.
|
74 |
+
- If `local_dir_use_symlinks=False` and the blob files exist in the cache directory, they are duplicated in the
|
75 |
+
local dir. This means disk usage is not optimized.
|
76 |
+
- Finally, if `local_dir_use_symlinks=False` and the blob files do not exist in the cache directory, then the
|
77 |
+
files are downloaded and directly placed under `local_dir`. This means if you need to download them again later,
|
78 |
+
they will be re-downloaded entirely.
|
79 |
+
|
80 |
+
An alternative would be to clone the repo but this requires git and git-lfs to be installed and properly
|
81 |
+
configured. It is also not possible to filter which files to download when cloning a repository using git.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
repo_id (`str`):
|
85 |
+
A user or an organization name and a repo name separated by a `/`.
|
86 |
+
repo_type (`str`, *optional*):
|
87 |
+
Set to `"dataset"` or `"space"` if downloading from a dataset or space,
|
88 |
+
`None` or `"model"` if downloading from a model. Default is `None`.
|
89 |
+
revision (`str`, *optional*):
|
90 |
+
An optional Git revision id which can be a branch name, a tag, or a
|
91 |
+
commit hash.
|
92 |
+
cache_dir (`str`, `Path`, *optional*):
|
93 |
+
Path to the folder where cached files are stored.
|
94 |
+
local_dir (`str` or `Path`, *optional*):
|
95 |
+
If provided, the downloaded files will be placed under this directory, either as symlinks (default) or
|
96 |
+
regular files (see description for more details).
|
97 |
+
local_dir_use_symlinks (`"auto"` or `bool`, defaults to `"auto"`):
|
98 |
+
To be used with `local_dir`. If set to "auto", the cache directory will be used and the file will be either
|
99 |
+
duplicated or symlinked to the local directory depending on its size. It set to `True`, a symlink will be
|
100 |
+
created, no matter the file size. If set to `False`, the file will either be duplicated from cache (if
|
101 |
+
already exists) or downloaded from the Hub and not cached. See description for more details.
|
102 |
+
library_name (`str`, *optional*):
|
103 |
+
The name of the library to which the object corresponds.
|
104 |
+
library_version (`str`, *optional*):
|
105 |
+
The version of the library.
|
106 |
+
user_agent (`str`, `dict`, *optional*):
|
107 |
+
The user-agent info in the form of a dictionary or a string.
|
108 |
+
proxies (`dict`, *optional*):
|
109 |
+
Dictionary mapping protocol to the URL of the proxy passed to
|
110 |
+
`requests.request`.
|
111 |
+
etag_timeout (`float`, *optional*, defaults to `10`):
|
112 |
+
When fetching ETag, how many seconds to wait for the server to send
|
113 |
+
data before giving up which is passed to `requests.request`.
|
114 |
+
resume_download (`bool`, *optional*, defaults to `False):
|
115 |
+
If `True`, resume a previously interrupted download.
|
116 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
117 |
+
Whether the file should be downloaded even if it already exists in the local cache.
|
118 |
+
token (`str`, `bool`, *optional*):
|
119 |
+
A token to be used for the download.
|
120 |
+
- If `True`, the token is read from the HuggingFace config
|
121 |
+
folder.
|
122 |
+
- If a string, it's used as the authentication token.
|
123 |
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
124 |
+
If `True`, avoid downloading the file and return the path to the
|
125 |
+
local cached file if it exists.
|
126 |
+
allow_patterns (`List[str]` or `str`, *optional*):
|
127 |
+
If provided, only files matching at least one pattern are downloaded.
|
128 |
+
ignore_patterns (`List[str]` or `str`, *optional*):
|
129 |
+
If provided, files matching any of the patterns are not downloaded.
|
130 |
+
max_workers (`int`, *optional*):
|
131 |
+
Number of concurrent threads to download files (1 thread = 1 file download).
|
132 |
+
Defaults to 8.
|
133 |
+
tqdm_class (`tqdm`, *optional*):
|
134 |
+
If provided, overwrites the default behavior for the progress bar. Passed
|
135 |
+
argument must inherit from `tqdm.auto.tqdm` or at least mimic its behavior.
|
136 |
+
Note that the `tqdm_class` is not passed to each individual download.
|
137 |
+
Defaults to the custom HF progress bar that can be disabled by setting
|
138 |
+
`HF_HUB_DISABLE_PROGRESS_BARS` environment variable.
|
139 |
+
|
140 |
+
Returns:
|
141 |
+
Local folder path (string) of repo snapshot
|
142 |
+
|
143 |
+
<Tip>
|
144 |
+
|
145 |
+
Raises the following errors:
|
146 |
+
|
147 |
+
- [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
|
148 |
+
if `token=True` and the token cannot be found.
|
149 |
+
- [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if
|
150 |
+
ETag cannot be determined.
|
151 |
+
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
|
152 |
+
if some parameter value is invalid
|
153 |
+
|
154 |
+
</Tip>
|
155 |
+
"""
|
156 |
+
if cache_dir is None:
|
157 |
+
cache_dir = HF_HUB_CACHE
|
158 |
+
if revision is None:
|
159 |
+
revision = DEFAULT_REVISION
|
160 |
+
if isinstance(cache_dir, Path):
|
161 |
+
cache_dir = str(cache_dir)
|
162 |
+
|
163 |
+
if repo_type is None:
|
164 |
+
repo_type = "model"
|
165 |
+
if repo_type not in REPO_TYPES:
|
166 |
+
raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(REPO_TYPES)}")
|
167 |
+
|
168 |
+
storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))
|
169 |
+
|
170 |
+
repo_info: Union[ModelInfo, DatasetInfo, SpaceInfo, None] = None
|
171 |
+
api_call_error: Optional[Exception] = None
|
172 |
+
if not local_files_only:
|
173 |
+
# try/except logic to handle different errors => taken from `hf_hub_download`
|
174 |
+
try:
|
175 |
+
# if we have internet connection we want to list files to download
|
```py
        api = HfApi(
            library_name=library_name, library_version=library_version, user_agent=user_agent, endpoint=endpoint
        )
        repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision, token=token)
    except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
        # Actually raise for those subclasses of ConnectionError
        raise
    except (
        requests.exceptions.ConnectionError,
        requests.exceptions.Timeout,
        OfflineModeIsEnabled,
    ) as error:
        # Internet connection is down
        # => will try to use local files only
        api_call_error = error
        pass
    except RevisionNotFoundError:
        # The repo was found but the revision doesn't exist on the Hub (never existed or got deleted)
        raise
    except requests.HTTPError as error:
        # Multiple reasons for an http error:
        # - Repository is private and an invalid/missing token was sent
        # - Repository is gated and an invalid/missing token was sent
        # - Hub is down (error 500 or 504)
        # => let's switch to 'local_files_only=True' to check if the files are already cached.
        #    (if that's not the case, the error will be re-raised)
        api_call_error = error
        pass

    # At this stage, if `repo_info` is None it means either:
    # - internet connection is down
    # - internet connection is deactivated (local_files_only=True or HF_HUB_OFFLINE=True)
    # - repo is private/gated and an invalid/missing token was sent
    # - Hub is down
    # => let's look if we can find the appropriate folder in the cache:
    #    - if the specified revision is a commit hash, look inside "snapshots".
    #    - if the specified revision is a branch or tag, look inside "refs".
    if repo_info is None:
        # Try to get which commit hash corresponds to the specified revision
        commit_hash = None
        if REGEX_COMMIT_HASH.match(revision):
            commit_hash = revision
        else:
            ref_path = os.path.join(storage_folder, "refs", revision)
            if os.path.exists(ref_path):
                # retrieve commit_hash from refs file
                with open(ref_path) as f:
                    commit_hash = f.read()

        # Try to locate snapshot folder for this commit hash
        if commit_hash is not None:
            snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
            if os.path.exists(snapshot_folder):
                # Snapshot folder exists => let's return it
                # (but we can't check if all the files are actually there)
                return snapshot_folder

        # If we couldn't find the appropriate folder on disk, raise an error.
        if local_files_only:
            raise LocalEntryNotFoundError(
                "Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and "
                "outgoing traffic has been disabled. To enable repo look-ups and downloads online, pass "
                "'local_files_only=False' as input."
            )
        elif isinstance(api_call_error, OfflineModeIsEnabled):
            raise LocalEntryNotFoundError(
                "Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and "
                "outgoing traffic has been disabled. To enable repo look-ups and downloads online, set "
                "'HF_HUB_OFFLINE=0' as environment variable."
            ) from api_call_error
        elif isinstance(api_call_error, (RepositoryNotFoundError, GatedRepoError)):
            # Repo not found => let's raise the actual error
            raise api_call_error
        else:
            # Otherwise: most likely a connection issue or Hub downtime => let's warn the user
            raise LocalEntryNotFoundError(
                "An error happened while trying to locate the files on the Hub and we cannot find the appropriate"
                " snapshot folder for the specified revision on the local disk. Please check your internet connection"
                " and try again."
            ) from api_call_error

    # At this stage, internet connection is up and running => let's download the files!
    assert repo_info.sha is not None, "Repo info returned from server must have a revision sha."
    assert repo_info.siblings is not None, "Repo info returned from server must have a siblings list."
    filtered_repo_files = list(
        filter_repo_objects(
            items=[f.rfilename for f in repo_info.siblings],
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
        )
    )
    commit_hash = repo_info.sha
    snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
    # If the passed revision is not identical to commit_hash, then the revision
    # must be a branch name or tag name. In that case, store a ref.
    if revision != commit_hash:
        ref_path = os.path.join(storage_folder, "refs", revision)
        os.makedirs(os.path.dirname(ref_path), exist_ok=True)
        with open(ref_path, "w") as f:
            f.write(commit_hash)

    # We pass the commit_hash to hf_hub_download so that no network call
    # happens if we already have the file locally.
    def _inner_hf_hub_download(repo_file: str):
        return hf_hub_download(
            repo_id,
            filename=repo_file,
            repo_type=repo_type,
            revision=commit_hash,
            endpoint=endpoint,
            cache_dir=cache_dir,
            local_dir=local_dir,
            local_dir_use_symlinks=local_dir_use_symlinks,
            library_name=library_name,
            library_version=library_version,
            user_agent=user_agent,
            proxies=proxies,
            etag_timeout=etag_timeout,
            resume_download=resume_download,
            force_download=force_download,
            token=token,
        )

    if HF_HUB_ENABLE_HF_TRANSFER:
        # When using hf_transfer we don't want extra parallelism
        # on top of the one hf_transfer already provides.
        for file in filtered_repo_files:
            _inner_hf_hub_download(file)
    else:
        thread_map(
            _inner_hf_hub_download,
            filtered_repo_files,
            desc=f"Fetching {len(filtered_repo_files)} files",
            max_workers=max_workers,
            # User can use their own tqdm class or the default one from `huggingface_hub.utils`
            tqdm_class=tqdm_class or hf_tqdm,
        )

    if local_dir is not None:
        return str(os.path.realpath(local_dir))
    return snapshot_folder
```
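Taken together, `snapshot_download` resolves a revision to a commit hash (via the cached `refs/<revision>` file when the Hub is unreachable), returns the matching `snapshots/<commit_hash>` folder when it already exists, and otherwise downloads the filtered file list in parallel. A minimal usage sketch of the public function; the repo id and file patterns below are hypothetical:

```py
from huggingface_hub import snapshot_download

# Fetch only the JSON and safetensors files of a (hypothetical) repo.
# The returned path points at the cached snapshot folder for the
# resolved commit hash, e.g. ".../snapshots/<commit_hash>".
local_folder = snapshot_download(
    repo_id="user/my-model",  # hypothetical repo id
    revision="main",  # branch, tag, or commit hash
    allow_patterns=["*.json", "*.safetensors"],
)
print(local_folder)
```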
lib/python3.11/site-packages/huggingface_hub/_space_api.py
ADDED
@@ -0,0 +1,154 @@
````py
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Dict, Optional

from huggingface_hub.utils import parse_datetime


class SpaceStage(str, Enum):
    """
    Enumeration of the possible stages of a Space on the Hub.

    Value can be compared to a string:
    ```py
    assert SpaceStage.BUILDING == "BUILDING"
    ```

    Taken from https://github.com/huggingface/moon-landing/blob/main/server/repo_types/SpaceInfo.ts#L61 (private url).
    """

    # Copied from moon-landing > server > repo_types > SpaceInfo.ts (private repo)
    NO_APP_FILE = "NO_APP_FILE"
    CONFIG_ERROR = "CONFIG_ERROR"
    BUILDING = "BUILDING"
    BUILD_ERROR = "BUILD_ERROR"
    RUNNING = "RUNNING"
    RUNNING_BUILDING = "RUNNING_BUILDING"
    RUNTIME_ERROR = "RUNTIME_ERROR"
    DELETING = "DELETING"
    STOPPED = "STOPPED"
    PAUSED = "PAUSED"


class SpaceHardware(str, Enum):
    """
    Enumeration of the hardware options available to run your Space on the Hub.

    Value can be compared to a string:
    ```py
    assert SpaceHardware.CPU_BASIC == "cpu-basic"
    ```

    Taken from https://github.com/huggingface/moon-landing/blob/main/server/repo_types/SpaceInfo.ts#L73 (private url).
    """

    CPU_BASIC = "cpu-basic"
    CPU_UPGRADE = "cpu-upgrade"
    T4_SMALL = "t4-small"
    T4_MEDIUM = "t4-medium"
    ZERO_A10G = "zero-a10g"
    A10G_SMALL = "a10g-small"
    A10G_LARGE = "a10g-large"
    A10G_LARGEX2 = "a10g-largex2"
    A10G_LARGEX4 = "a10g-largex4"
    A100_LARGE = "a100-large"


class SpaceStorage(str, Enum):
    """
    Enumeration of the persistent storage tiers available for your Space on the Hub.

    Value can be compared to a string:
    ```py
    assert SpaceStorage.SMALL == "small"
    ```

    Taken from https://github.com/huggingface/moon-landing/blob/main/server/repo_types/SpaceHardwareFlavor.ts#L24 (private url).
    """

    SMALL = "small"
    MEDIUM = "medium"
    LARGE = "large"
````
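Because these enums subclass both `str` and `Enum`, their members behave like plain strings, which is what the `assert` snippets in the docstrings rely on. A small sketch, assuming the classes above are importable from `huggingface_hub`:

```py
from huggingface_hub import SpaceHardware, SpaceStage

stage = SpaceStage.RUNNING

# Members compare equal to their plain string values...
assert stage == "RUNNING"
assert SpaceHardware.CPU_BASIC == "cpu-basic"
# ...so stage checks can be written against the raw strings the server returns.
assert stage in {"RUNNING", "RUNNING_BUILDING"}
```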
```py
@dataclass
class SpaceRuntime:
    """
    Contains information about the current runtime of a Space.

    Args:
        stage (`str`):
            Current stage of the Space. Example: RUNNING.
        hardware (`str` or `None`):
            Current hardware of the Space. Example: "cpu-basic". Can be `None` if the Space
            is `BUILDING` for the first time.
        requested_hardware (`str` or `None`):
            Requested hardware. Can be different from `hardware`, especially if the request
            has just been made. Example: "t4-medium". Can be `None` if no hardware has
            been requested yet.
        sleep_time (`int` or `None`):
            Number of seconds the Space will be kept alive after the last request. By default (if value is `None`), the
            Space will never go to sleep if it's running on upgraded hardware, while it will go to sleep after 48
            hours on the free 'cpu-basic' hardware. For more details, see https://huggingface.co/docs/hub/spaces-gpus#sleep-time.
        storage (`str` or `None`):
            Persistent storage tier of the Space, if any. Example: "small".
        raw (`dict`):
            Raw response from the server. Contains more information about the Space
            runtime like number of replicas, number of CPUs, memory size, ...
    """

    stage: SpaceStage
    hardware: Optional[SpaceHardware]
    requested_hardware: Optional[SpaceHardware]
    sleep_time: Optional[int]
    storage: Optional[SpaceStorage]
    raw: Dict

    def __init__(self, data: Dict) -> None:
        self.stage = data["stage"]
        self.hardware = data.get("hardware", {}).get("current")
        self.requested_hardware = data.get("hardware", {}).get("requested")
        self.sleep_time = data.get("gcTimeout")
        self.storage = data.get("storage")
        self.raw = data
```
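For illustration, a sketch of how `SpaceRuntime` (defined above) maps a raw server response onto its fields; the payload below is made up, shaped only after the keys that `__init__` actually reads:

```py
# Hypothetical payload, shaped after the keys accessed in SpaceRuntime.__init__.
payload = {
    "stage": "RUNNING",
    "hardware": {"current": "cpu-basic", "requested": "t4-medium"},
    "gcTimeout": 3600,
    "storage": "small",
}

runtime = SpaceRuntime(payload)
assert runtime.stage == "RUNNING"
assert runtime.hardware == "cpu-basic"
assert runtime.requested_hardware == "t4-medium"
assert runtime.sleep_time == 3600  # seconds before the Space goes to sleep
assert runtime.raw is payload  # full response kept around for extra fields
```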
```py
@dataclass
class SpaceVariable:
    """
    Contains information about the current variables of a Space.

    Args:
        key (`str`):
            Variable key. Example: `"MODEL_REPO_ID"`.
        value (`str`):
            Variable value. Example: `"the_model_repo_id"`.
        description (`str` or `None`):
            Description of the variable. Example: `"Model Repo ID of the implemented model"`.
        updated_at (`datetime`):
            Datetime of the last update of the variable.
    """

    key: str
    value: str
    description: Optional[str]
    updated_at: datetime

    def __init__(self, key: str, values: Dict) -> None:
        self.key = key
        self.value = values["value"]
        self.description = values.get("description")
        self.updated_at = parse_datetime(values["updatedAt"])
```
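And similarly for `SpaceVariable` (defined above); the `values` dict below is a hypothetical payload shaped after the keys `__init__` reads, with `updatedAt` in the ISO 8601 format that `parse_datetime` expects:

```py
# Hypothetical payload, shaped after the keys accessed in SpaceVariable.__init__.
variable = SpaceVariable(
    key="MODEL_REPO_ID",
    values={
        "value": "the_model_repo_id",
        "description": "Model Repo ID of the implemented model",
        "updatedAt": "2023-08-19T07:19:38.000Z",
    },
)

assert variable.value == "the_model_repo_id"
print(variable.updated_at)  # timezone-aware datetime parsed from "updatedAt"
```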