ydshieh
commited on
Commit
•
7d3b1a0
1
Parent(s):
e5b3a97
Change a parameter in decode(): deterministic -> train
Browse files
vit_gpt2/modeling_flax_vit_gpt2_lm.py
CHANGED
@@ -398,21 +398,6 @@ class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
|
|
398 |
)
|
399 |
|
400 |
|
401 |
-
# @add_start_docstrings(
|
402 |
-
# "The bare Bart Model transformer outputting raw hidden-states without any specific head on top.",
|
403 |
-
# BART_START_DOCSTRING,
|
404 |
-
# )
|
405 |
-
# class FlaxViTGPT2LMModel(FlaxViTGPT2LMPreTrainedModel):
|
406 |
-
# config: BartConfig
|
407 |
-
# dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
408 |
-
# module_class = FlaxViTGPT2LMModule
|
409 |
-
#
|
410 |
-
#
|
411 |
-
# append_call_sample_docstring(
|
412 |
-
# FlaxBartModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC
|
413 |
-
# )
|
414 |
-
|
415 |
-
|
416 |
class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
|
417 |
module_class = FlaxViTGPT2LMForConditionalGenerationModule
|
418 |
dtype: jnp.dtype = jnp.float32
|
@@ -428,7 +413,7 @@ class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
|
|
428 |
output_attentions: Optional[bool] = None,
|
429 |
output_hidden_states: Optional[bool] = None,
|
430 |
return_dict: Optional[bool] = None,
|
431 |
-
|
432 |
params: dict = None,
|
433 |
dropout_rng: PRNGKey = None,
|
434 |
):
|
@@ -443,7 +428,7 @@ class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
|
|
443 |
output_attentions,
|
444 |
output_hidden_states,
|
445 |
return_dict,
|
446 |
-
|
447 |
params,
|
448 |
dropout_rng,
|
449 |
)
|
|
|
398 |
)
|
399 |
|
400 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
401 |
class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
|
402 |
module_class = FlaxViTGPT2LMForConditionalGenerationModule
|
403 |
dtype: jnp.dtype = jnp.float32
|
|
|
413 |
output_attentions: Optional[bool] = None,
|
414 |
output_hidden_states: Optional[bool] = None,
|
415 |
return_dict: Optional[bool] = None,
|
416 |
+
train: bool = False,
|
417 |
params: dict = None,
|
418 |
dropout_rng: PRNGKey = None,
|
419 |
):
|
|
|
428 |
output_attentions,
|
429 |
output_hidden_states,
|
430 |
return_dict,
|
431 |
+
train,
|
432 |
params,
|
433 |
dropout_rng,
|
434 |
)
|