# DeCRED-base / generation.py
# Last commit by Lakoc: d7361d8 (verified) "Upload JointCTCAttentionEncoderDecoder"
from transformers import GenerationConfig
class GenerationConfigCustom(GenerationConfig):
    """``GenerationConfig`` with extra decoding options for joint CTC/attention models.

    All standard generation arguments are forwarded unchanged to
    ``GenerationConfig.__init__`` through ``**kwargs``. The extra options below are
    only stored on the instance here. Their meaning is defined by the decoding code
    that reads them, which is not in this file, so the descriptions are inferred from
    the names -- confirm against the consumer.

    Args:
        ctc_weight: Presumably the weight of the CTC score when scoring hypotheses.
        ctc_margin: Presumably a margin/window used by CTC-assisted decoding.
        lm_weight: Presumably the weight of an external language-model score.
        lm_model: External language model (or a handle to one); ``None`` by default.
        space_token_id: Id of the space token; ``-1`` presumably means "not set".
        eos_space_trick_weight: Weight for the "EOS/space trick".
        apply_eos_space_trick: Whether the EOS/space trick is enabled.

    Plain ``int`` values passed for the three weight arguments are stored as
    ``float``. ``update_from_string`` picks the conversion from the current value's
    type, so an int weight would make an override such as ``"lm_weight=0.3"`` fail
    in ``int("0.3")``.
    """

    def __init__(
        self,
        ctc_weight=0.0,
        ctc_margin=0,
        lm_weight=0.0,
        lm_model=None,
        space_token_id=-1,
        eos_space_trick_weight=0.0,
        apply_eos_space_trick=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.ctc_weight = self._as_float_weight(ctc_weight)
        self.ctc_margin = ctc_margin
        self.lm_weight = self._as_float_weight(lm_weight)
        # NOTE(review): if a model object is stored here, serializing this config to
        # JSON (e.g. via ``save_pretrained``) will presumably fail -- confirm callers
        # only store something JSON-serializable (or None).
        self.lm_model = lm_model
        self.space_token_id = space_token_id
        self.eos_space_trick_weight = self._as_float_weight(eos_space_trick_weight)
        self.apply_eos_space_trick = apply_eos_space_trick

    @staticmethod
    def _as_float_weight(value):
        """Return ``value`` converted to ``float`` if it is a plain ``int``, else unchanged."""
        # bool is a subclass of int; leave it alone instead of turning True into 1.0.
        if isinstance(value, int) and not isinstance(value, bool):
            return float(value)
        return value

    def update_from_string(self, update_str: str):
        """
        Updates attributes of this class with attributes from `update_str`.

        `update_str` is a `;`-separated list of `key=value` pairs. Each value is
        converted to the type of the attribute's current value:

        - Ints, floats and strings are given as-is.
        - For booleans use `true` or `false`. `1`/`0`, `y`/`n` and `yes`/`no` are
          also accepted, case-insensitively.

        For example:

            "ctc_weight=0.3;num_beams=5;do_sample=false"

        Only the first `=` of each pair separates key from value, so values may
        themselves contain `=`. Empty segments, such as a trailing `;` or an empty
        `update_str`, are ignored.

        The keys to change have to already exist in the config object. Their current
        value must be a bool, int, float or str. Attributes that are currently `None`
        (such as `lm_model` by default) cannot be set this way.

        Args:
            update_str (`str`): String with attributes that should be updated for this class.

        Raises:
            ValueError: If a segment has no `=`, a key is unknown, a value cannot be
                converted to the attribute's type, or the attribute's current type is
                not bool/int/float/str. Segment parsing finishes before any attribute
                changes. The per-key checks run while applying, so keys processed
                before a failing key keep their new values.
        """
        updates = {}
        for item in update_str.split(";"):
            if not item:
                # Tolerate empty segments such as a trailing ";" or an empty string.
                continue
            # Split on the first "=" only, so values may contain "=" themselves.
            key, sep, value = item.partition("=")
            if not sep:
                raise ValueError(f"expected 'key=value' in update string, got {item!r}")
            # A later duplicate key overrides an earlier one.
            updates[key] = value

        for k, v in updates.items():
            if not hasattr(self, k):
                raise ValueError(f"key {k} isn't in the original config dict")
            old_v = getattr(self, k)
            # bool must be checked before int because bool is a subclass of int.
            if isinstance(old_v, bool):
                if v.lower() in ["true", "1", "y", "yes"]:
                    v = True
                elif v.lower() in ["false", "0", "n", "no"]:
                    v = False
                else:
                    raise ValueError(f"can't derive true or false from {v} (key {k})")
            elif isinstance(old_v, int):
                v = int(v)
            elif isinstance(old_v, float):
                v = float(v)
            elif not isinstance(old_v, str):
                raise ValueError(
                    f"You can only update int, float, bool or string values in the config, got {v} for key {k}"
                )
            setattr(self, k, v)