m3hrdadfi committed
Commit
1d83b89
1 Parent(s): d9f0ed0

Update for the revision

events.out.tfevents.1626448850.t1v-n-278acf21-w-0.590260.3.v2 CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:1fbe385b41508eae766e3ae9763a6bf8a20b0dad2a36c5058b526b6884a8433a
- size 662195
+ oid sha256:0f8f2848f118433d3ae3412ed5ed0df7242cdf899879357f922313aeaf0b7b5d
+ size 809333
flax_model.msgpack DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:dd33994b480ef0a93c7821a12df82c34656dc30539b623c1fb2050b1ba03be19
- size 190539834
src/run_wav2vec2_pretrain_flax.py CHANGED
@@ -160,7 +160,6 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
  """
  Data collator that will dynamically pad the inputs received and prepare masked indices
  for self-supervised pretraining.
-
  Args:
      model (:class:`~transformers.FlaxWav2Vec2ForPreTraining`):
          The Wav2Vec2 model used for pretraining. The data collator needs to have access
@@ -203,6 +202,7 @@ class FlaxDataCollatorForWav2Vec2Pretraining:

  batch_size = batch["input_values"].shape[0]

+ attention_mask = None
  if batch["attention_mask"] is not None:
      output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1))
      attention_mask = np.zeros((batch_size, mask_indices_seq_length), dtype=np.int8)
@@ -225,9 +225,11 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
  batch["sampled_negative_indices"] = _sample_negative_indices(
      (batch["mask_time_indices"].shape + (self.model.config.proj_codevector_dim,)),
      self.model.config.num_negatives,
+     attention_mask=attention_mask,
  )

  return batch
+

  def configure_logger(model_args: ModelArguments, training_args: TrainingArguments):
      logging.basicConfig(
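
The substantive change in this file is that the collator now initializes attention_mask to None and forwards it to _sample_negative_indices, presumably so that negative samples for the contrastive objective are drawn only from real (non-padded) frames when a batch contains clips of different lengths. Below is a minimal, self-contained NumPy sketch of that idea; sample_negatives_sketch and its signature are illustrative assumptions, not the script's actual _sample_negative_indices helper.

import numpy as np

def sample_negatives_sketch(seq_len, num_negatives, attention_mask):
    # Sketch only (assumed helper, not the script's implementation):
    # sample `num_negatives` distractor indices per time step, restricted to
    # non-padded frames. attention_mask has shape (seq_len,), 1 = real frame, 0 = padding.
    valid = np.flatnonzero(attention_mask)  # positions that correspond to real audio
    negatives = np.empty((seq_len, num_negatives), dtype=np.int64)
    for t in range(seq_len):
        # draw candidates from the valid frames, excluding the positive at position t
        candidates = valid[valid != t]
        negatives[t] = np.random.choice(candidates, size=num_negatives, replace=True)
    return negatives

# Example: a 10-frame sequence whose last 3 frames are padding.
mask = np.array([1] * 7 + [0] * 3, dtype=np.int8)
neg = sample_negatives_sketch(seq_len=10, num_negatives=4, attention_mask=mask)
assert neg.max() < 7  # no negative ever points at a padded frame

Without the mask, negatives could land on padded positions and inject meaningless padding features as distractors; restricting sampling to valid frames keeps the contrastive targets informative.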