m3hrdadfi commited on
Commit
da9f194
1 Parent(s): 997f011

Fix some bugs

Browse files
src/run_persian.sh CHANGED
@@ -4,7 +4,7 @@ export LC_ALL=C.UTF-8
4
  export LANG=C.UTF-8
5
 
6
  export OUTPUT_DIR=/home/m3hrdadfi/code/wav2vec2-base-persian
7
- export OUTPUT_DIR=/home/m3hrdadfi/data_cache/
8
  export MODEL_NAME_OR_PATH=/home/m3hrdadfi/code/wav2vec2-base-persian
9
 
10
 
 
4
  export LANG=C.UTF-8
5
 
6
  export OUTPUT_DIR=/home/m3hrdadfi/code/wav2vec2-base-persian
7
+ export CACHE_DIR=/home/m3hrdadfi/data_cache/
8
  export MODEL_NAME_OR_PATH=/home/m3hrdadfi/code/wav2vec2-base-persian
9
 
10
 
src/run_wav2vec2_pretrain_flax.py CHANGED
@@ -49,6 +49,7 @@ from transformers.models.wav2vec2.modeling_flax_wav2vec2 import _compute_mask_in
49
 
50
  logger = logging.getLogger(__name__)
51
 
 
52
 
53
  @flax.struct.dataclass
54
  class ModelArguments:
 
49
 
50
  logger = logging.getLogger(__name__)
51
 
52
+ print(f"TPU: {jax.devices())")
53
 
54
  @flax.struct.dataclass
55
  class ModelArguments: