m3hrdadfi commited on
Commit
ffdda5a
1 Parent(s): b2bfca3

Add pytorch version

Browse files
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:66a6beff446b16ed2090a1915780a6526bed00aff8c393138f46a979fec19353
3
+ size 380290869
src/convert_flax_to_pytorch.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ import jax.numpy as jnp
3
+
4
+ from transformers import FlaxWav2Vec2ForPreTraining, Wav2Vec2ForPreTraining
5
+
6
+
7
+ def to_f32(t):
8
+ return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
9
+
10
+
11
+ model_fx = FlaxWav2Vec2ForPreTraining.from_pretrained("../")
12
+ model_fx.params = to_f32(model_fx.params)
13
+ model_fx.save_pretrained("./fx")
14
+
15
+ model_pt = Wav2Vec2ForPreTraining.from_pretrained("./fx", from_flax=True)
16
+ model_pt.save_pretrained("./pt")
src/convert_flax_to_tf.py ADDED
File without changes