Spaces:
Runtime error
Runtime error
CorvaeOboro
commited on
Commit
β’
debdde2
1
Parent(s):
3ada297
Upload custom_ops.py
Browse files- torch_utils/custom_ops.py +126 -0
torch_utils/custom_ops.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
import os
|
10 |
+
import glob
|
11 |
+
import torch
|
12 |
+
import torch.utils.cpp_extension
|
13 |
+
import importlib
|
14 |
+
import hashlib
|
15 |
+
import shutil
|
16 |
+
from pathlib import Path
|
17 |
+
|
18 |
+
from torch.utils.file_baton import FileBaton
|
19 |
+
|
20 |
+
#----------------------------------------------------------------------------
|
21 |
+
# Global options.
|
22 |
+
|
23 |
+
verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
|
24 |
+
|
25 |
+
#----------------------------------------------------------------------------
|
26 |
+
# Internal helper funcs.
|
27 |
+
|
28 |
+
def _find_compiler_bindir():
|
29 |
+
patterns = [
|
30 |
+
'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
31 |
+
'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
32 |
+
'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
33 |
+
'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
|
34 |
+
]
|
35 |
+
for pattern in patterns:
|
36 |
+
matches = sorted(glob.glob(pattern))
|
37 |
+
if len(matches):
|
38 |
+
return matches[-1]
|
39 |
+
return None
|
40 |
+
|
41 |
+
#----------------------------------------------------------------------------
|
42 |
+
# Main entry point for compiling and loading C++/CUDA plugins.
|
43 |
+
|
44 |
+
_cached_plugins = dict()
|
45 |
+
|
46 |
+
def get_plugin(module_name, sources, **build_kwargs):
|
47 |
+
assert verbosity in ['none', 'brief', 'full']
|
48 |
+
|
49 |
+
# Already cached?
|
50 |
+
if module_name in _cached_plugins:
|
51 |
+
return _cached_plugins[module_name]
|
52 |
+
|
53 |
+
# Print status.
|
54 |
+
if verbosity == 'full':
|
55 |
+
print(f'Setting up PyTorch plugin "{module_name}"...')
|
56 |
+
elif verbosity == 'brief':
|
57 |
+
print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
|
58 |
+
|
59 |
+
try: # pylint: disable=too-many-nested-blocks
|
60 |
+
# Make sure we can find the necessary compiler binaries.
|
61 |
+
if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
|
62 |
+
compiler_bindir = _find_compiler_bindir()
|
63 |
+
if compiler_bindir is None:
|
64 |
+
raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
|
65 |
+
os.environ['PATH'] += ';' + compiler_bindir
|
66 |
+
|
67 |
+
# Compile and load.
|
68 |
+
verbose_build = (verbosity == 'full')
|
69 |
+
|
70 |
+
# Incremental build md5sum trickery. Copies all the input source files
|
71 |
+
# into a cached build directory under a combined md5 digest of the input
|
72 |
+
# source files. Copying is done only if the combined digest has changed.
|
73 |
+
# This keeps input file timestamps and filenames the same as in previous
|
74 |
+
# extension builds, allowing for fast incremental rebuilds.
|
75 |
+
#
|
76 |
+
# This optimization is done only in case all the source files reside in
|
77 |
+
# a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
|
78 |
+
# environment variable is set (we take this as a signal that the user
|
79 |
+
# actually cares about this.)
|
80 |
+
source_dirs_set = set(os.path.dirname(source) for source in sources)
|
81 |
+
if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
|
82 |
+
all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
|
83 |
+
|
84 |
+
# Compute a combined hash digest for all source files in the same
|
85 |
+
# custom op directory (usually .cu, .cpp, .py and .h files).
|
86 |
+
hash_md5 = hashlib.md5()
|
87 |
+
for src in all_source_files:
|
88 |
+
with open(src, 'rb') as f:
|
89 |
+
hash_md5.update(f.read())
|
90 |
+
build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
|
91 |
+
digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
|
92 |
+
|
93 |
+
if not os.path.isdir(digest_build_dir):
|
94 |
+
os.makedirs(digest_build_dir, exist_ok=True)
|
95 |
+
baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
|
96 |
+
if baton.try_acquire():
|
97 |
+
try:
|
98 |
+
for src in all_source_files:
|
99 |
+
shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
|
100 |
+
finally:
|
101 |
+
baton.release()
|
102 |
+
else:
|
103 |
+
# Someone else is copying source files under the digest dir,
|
104 |
+
# wait until done and continue.
|
105 |
+
baton.wait()
|
106 |
+
digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
|
107 |
+
torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
|
108 |
+
verbose=verbose_build, sources=digest_sources, **build_kwargs)
|
109 |
+
else:
|
110 |
+
torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
|
111 |
+
module = importlib.import_module(module_name)
|
112 |
+
|
113 |
+
except:
|
114 |
+
if verbosity == 'brief':
|
115 |
+
print('Failed!')
|
116 |
+
raise
|
117 |
+
|
118 |
+
# Print status and add to cache.
|
119 |
+
if verbosity == 'full':
|
120 |
+
print(f'Done setting up PyTorch plugin "{module_name}".')
|
121 |
+
elif verbosity == 'brief':
|
122 |
+
print('Done.')
|
123 |
+
_cached_plugins[module_name] = module
|
124 |
+
return module
|
125 |
+
|
126 |
+
#----------------------------------------------------------------------------
|