Fix inconsistent types in custom_autotune.py
Browse files- custom_autotune.py +7 -2
custom_autotune.py
CHANGED
@@ -81,16 +81,21 @@ class Autotuner(triton.KernelInterface):
|
|
81 |
# In my testing this gives decent results, and greatly reduces the amount of tuning required
|
82 |
if self.nearest_power_of_two:
|
83 |
key = tuple([2 ** int(math.log2(x) + 0.5) for x in key])
|
84 |
-
|
85 |
if key not in self.cache:
|
86 |
# prune configs
|
87 |
pruned_configs = self.prune_configs(kwargs)
|
88 |
bench_start = time.time()
|
89 |
timings = {config: self._bench(*args, config=config, **kwargs)
|
90 |
for config in pruned_configs}
|
|
|
|
|
|
|
|
|
|
|
91 |
bench_end = time.time()
|
92 |
self.bench_time = bench_end - bench_start
|
93 |
-
|
|
|
94 |
self.hook(args)
|
95 |
self.configs_timings = timings
|
96 |
config = self.cache[key]
|
|
|
81 |
# In my testing this gives decent results, and greatly reduces the amount of tuning required
|
82 |
if self.nearest_power_of_two:
|
83 |
key = tuple([2 ** int(math.log2(x) + 0.5) for x in key])
|
|
|
84 |
if key not in self.cache:
|
85 |
# prune configs
|
86 |
pruned_configs = self.prune_configs(kwargs)
|
87 |
bench_start = time.time()
|
88 |
timings = {config: self._bench(*args, config=config, **kwargs)
|
89 |
for config in pruned_configs}
|
90 |
+
temp = {}
|
91 |
+
for config in pruned_configs:
|
92 |
+
if isinstance(self._bench(*args, config=config, **kwargs), float) :
|
93 |
+
continue
|
94 |
+
temp[config] = {self._bench(*args, config=config, **kwargs)}
|
95 |
bench_end = time.time()
|
96 |
self.bench_time = bench_end - bench_start
|
97 |
+
|
98 |
+
self.cache[key] = builtins.min(temp, key=timings.get)
|
99 |
self.hook(args)
|
100 |
self.configs_timings = timings
|
101 |
config = self.cache[key]
|