Upload JointCTCAttentionEncoderDecoder
Files changed:
- config.json +1 -1
- generation.py +61 -0
- modeling_decred.py +10 -6
config.json
CHANGED
@@ -5,7 +5,7 @@
   ],
   "auto_map": {
     "AutoConfig": "configuration_decred.JointCTCAttentionEncoderDecoderConfig",
-    "
+    "AutoModel": "modeling_decred.JointCTCAttentionEncoderDecoder"
   },
   "ctc_weight": 0.3,
   "decoder": {
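The added `auto_map` entry is what lets transformers' remote-code loading resolve the custom model class straight from the Hub. A minimal loading sketch, assuming a placeholder repo id (not part of this commit):

from transformers import AutoConfig, AutoModel

repo_id = "user/decred-model"  # placeholder; substitute the actual Hub repository

# trust_remote_code=True allows transformers to import configuration_decred.py and
# modeling_decred.py from the repo and use the classes named in auto_map.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
print(type(model).__name__)  # JointCTCAttentionEncoderDecoder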
generation.py
ADDED
@@ -0,0 +1,61 @@
+from transformers import GenerationConfig
+
+
+class GenerationConfigCustom(GenerationConfig):
+    def __init__(
+        self,
+        ctc_weight=0.0,
+        ctc_margin=0,
+        lm_weight=0,
+        lm_model=None,
+        space_token_id=-1,
+        eos_space_trick_weight=0,
+        apply_eos_space_trick=False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.ctc_weight = ctc_weight
+        self.ctc_margin = ctc_margin
+        self.lm_weight = lm_weight
+        self.lm_model = lm_model
+        self.space_token_id = space_token_id
+        self.eos_space_trick_weight = eos_space_trick_weight
+        self.apply_eos_space_trick = apply_eos_space_trick
+
+    def update_from_string(self, update_str: str):
+        """
+        Updates attributes of this class with attributes from `update_str`.
+
+        The expected format is `;`-separated `key=value` pairs; ints, floats and strings are given as is, and for booleans use `true` or `false`. For example:
+        "ctc_weight=0.5;ctc_margin=4;apply_eos_space_trick=true"
+
+        The keys to change have to already exist in the config object.
+
+        Args:
+            update_str (`str`): String with attributes that should be updated for this class.
+
+        """
+
+        d = dict(x.split("=") for x in update_str.split(";"))
+        for k, v in d.items():
+            if not hasattr(self, k):
+                raise ValueError(f"key {k} isn't in the original config dict")
+
+            old_v = getattr(self, k)
+            if isinstance(old_v, bool):
+                if v.lower() in ["true", "1", "y", "yes"]:
+                    v = True
+                elif v.lower() in ["false", "0", "n", "no"]:
+                    v = False
+                else:
+                    raise ValueError(f"can't derive true or false from {v} (key {k})")
+            elif isinstance(old_v, int):
+                v = int(v)
+            elif isinstance(old_v, float):
+                v = float(v)
+            elif not isinstance(old_v, str):
+                raise ValueError(
+                    f"You can only update int, float, bool or string values in the config, got {v} for key {k}"
+                )
+
+            setattr(self, k, v)
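A short usage sketch of the new config class (the values below are illustrative, not taken from this commit). It behaves like a regular `GenerationConfig` but carries the extra CTC/LM decoding options, and `update_from_string` consumes `;`-separated `key=value` pairs:

from generation import GenerationConfigCustom  # local module added by this commit

gen_config = GenerationConfigCustom(
    max_length=128,
    num_beams=5,
    ctc_weight=0.3,
    ctc_margin=0,
    space_token_id=4,  # placeholder token id
)

# Only keys that already exist on the config are accepted; values are cast to the
# type of the current attribute, so e.g. ctc_margin (an int) would reject "0.5".
gen_config.update_from_string("ctc_weight=0.5;apply_eos_space_trick=true")
assert gen_config.ctc_weight == 0.5 and gen_config.apply_eos_space_trick is True

One caveat of the type-driven casting: attributes initialised with integer defaults (`lm_weight=0`, `eos_space_trick_weight=0`) can only be updated to integer strings through `update_from_string`; fractional weights have to be set in the constructor or assigned directly.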
modeling_decred.py
CHANGED
@@ -8,7 +8,6 @@ from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
     AutoModelForSpeechSeq2Seq,
-    GenerationConfig,
     LogitsProcessor,
     PretrainedConfig,
     PreTrainedModel,
@@ -28,6 +27,7 @@ from .auto_wrappers import CustomAutoModelForCTC
 from .configuration_decred import JointCTCAttentionEncoderDecoderConfig
 from .ctc_scorer import CTCRescorerLogitsProcessor, LogSoftmaxProcessor
 from .embeddings import AdaptiveEmbedding, PositionalEmbedding
+from .generation import GenerationConfigCustom
 from .multi_head_gpt2 import GPT2LMMultiHeadModel
 
 logger = logging.get_logger("transformers")
@@ -433,7 +433,7 @@ class JointCTCAttentionEncoderDecoder(SpeechEncoderDecoderModel):
 
     def _get_logits_processor(
         self,
-        generation_config: GenerationConfig,
+        generation_config: GenerationConfigCustom,
         input_ids_seq_length: int,
         encoder_input_ids: torch.LongTensor,
         prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
@@ -464,9 +464,13 @@ class JointCTCAttentionEncoderDecoder(SpeechEncoderDecoderModel):
                 self.generation_config.ctc_margin,
                 self.generation_config.ctc_weight,
                 self.generation_config.num_beams,
-                self.generation_config.space_token_id,
-                self.generation_config.apply_eos_space_trick,
-                self.generation_config.eos_space_trick_weight,
+                self.generation_config.space_token_id if hasattr(self.generation_config, "space_token_id") else None,
+                self.generation_config.apply_eos_space_trick
+                if hasattr(self.generation_config, "apply_eos_space_trick")
+                else False,
+                self.generation_config.eos_space_trick_weight
+                if hasattr(self.generation_config, "eos_space_trick_weight")
+                else 0.0,
             )
             processors.append(self.ctc_rescorer)
         if hasattr(generation_config, "lm_weight") and generation_config.lm_weight > 0:
@@ -524,7 +528,7 @@ class JointCTCAttentionEncoderDecoder(SpeechEncoderDecoderModel):
     def generate(
         self,
         inputs: Optional[torch.Tensor] = None,
-        generation_config: Optional[GenerationConfig] = None,
+        generation_config: Optional[GenerationConfigCustom] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
         prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
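Continuing the loading sketch from the config.json section above, decoding with the custom config could look as follows (feature shapes and token ids are placeholders). Note that `_get_logits_processor` reads the CTC options from `self.generation_config`, and the new `hasattr` fallbacks mean a plain `GenerationConfig` without these attributes now degrades gracefully to `space_token_id=None`, `apply_eos_space_trick=False`, `eos_space_trick_weight=0.0`:

import torch

gen_config = GenerationConfigCustom(
    max_length=128,
    num_beams=5,
    ctc_weight=0.3,
    space_token_id=4,            # placeholder token id
    apply_eos_space_trick=True,
    eos_space_trick_weight=0.1,
)

# The CTC rescorer is configured from the model's own generation_config,
# so attach the custom config to the model before decoding.
model.generation_config = gen_config
input_features = torch.randn(1, 3000, 80)  # placeholder log-mel features
predicted_ids = model.generate(input_features)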