Spaces:
Running
Running
Pedro Cuenca
committed on
Commit
•
de74f11
1
Parent(s):
9c0e5c9
fix typos and update requirements
Browse files- seq2seq/requirements.txt +2 -0
- seq2seq/run_seq2seq_flax.py +4 -4
seq2seq/requirements.txt
CHANGED
@@ -4,3 +4,5 @@ jaxlib>=0.1.59
|
|
4 |
flax>=0.3.4
|
5 |
optax>=0.0.8
|
6 |
tensorboard
|
|
|
|
|
|
4 |
flax>=0.3.4
|
5 |
optax>=0.0.8
|
6 |
tensorboard
|
7 |
+
nltk
|
8 |
+
wandb
|
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -19,7 +19,7 @@ Script adapted from run_summarization_flax.py
|
|
19 |
"""
|
20 |
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
|
21 |
|
22 |
-
import logging
|
23 |
import os
|
24 |
import sys
|
25 |
import time
|
@@ -60,7 +60,7 @@ from transformers.file_utils import is_offline_mode
|
|
60 |
|
61 |
import wandb
|
62 |
|
63 |
-
logger =
|
64 |
|
65 |
try:
|
66 |
nltk.data.find("tokenizers/punkt")
|
@@ -389,7 +389,7 @@ def main():
|
|
389 |
data_files["validation"] = data_args.validation_file
|
390 |
if data_args.test_file is not None:
|
391 |
data_files["test"] = data_args.test_file
|
392 |
-
dataset = load_dataset"csv", data_files=data_files, cache_dir=model_args.cache_dir, delimiter="\t")
|
393 |
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
394 |
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
395 |
|
@@ -411,7 +411,7 @@ def main():
|
|
411 |
|
412 |
|
413 |
# Create a custom model and initialize it randomly
|
414 |
-
model = CustomFlaxBartForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
415 |
|
416 |
# Use pre-trained weights for encoder
|
417 |
model.params['model']['encoder'] = base_model.params['model']['encoder']
|
|
|
19 |
"""
|
20 |
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
|
21 |
|
22 |
+
import logging as pylogging # To avoid collision with transformers.utils.logging
|
23 |
import os
|
24 |
import sys
|
25 |
import time
|
|
|
60 |
|
61 |
import wandb
|
62 |
|
63 |
+
logger = pylogging.getLogger(__name__)
|
64 |
|
65 |
try:
|
66 |
nltk.data.find("tokenizers/punkt")
|
|
|
389 |
data_files["validation"] = data_args.validation_file
|
390 |
if data_args.test_file is not None:
|
391 |
data_files["test"] = data_args.test_file
|
392 |
+
dataset = load_dataset("csv", data_files=data_files, cache_dir=model_args.cache_dir, delimiter="\t")
|
393 |
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
394 |
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
395 |
|
|
|
411 |
|
412 |
|
413 |
# Create a custom model and initialize it randomly
|
414 |
+
model = CustomFlaxBartForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
|
415 |
|
416 |
# Use pre-trained weights for encoder
|
417 |
model.params['model']['encoder'] = base_model.params['model']['encoder']
|