Spaces:
Running
Running
feat(train): local jax cache
Browse files- tools/train/train.py +1 -3
tools/train/train.py
CHANGED
@@ -57,9 +57,7 @@ from dalle_mini.model import (
|
|
57 |
set_partitions,
|
58 |
)
|
59 |
|
60 |
-
cc.initialize_cache(
|
61 |
-
"/home/boris/dalle-mini/jax_cache", max_cache_size_bytes=5 * 2**30
|
62 |
-
)
|
63 |
|
64 |
logger = logging.getLogger(__name__)
|
65 |
|
|
|
57 |
set_partitions,
|
58 |
)
|
59 |
|
60 |
+
cc.initialize_cache("./jax_cache", max_cache_size_bytes=5 * 2**30)
|
|
|
|
|
61 |
|
62 |
logger = logging.getLogger(__name__)
|
63 |
|