|
from transformers import GenerationConfig |
|
|
|
|
|
class GenerationConfigCustom(GenerationConfig):
    """`GenerationConfig` extended with extra decoding options.

    The extra options are only stored as instance attributes here; this class does not
    interpret them. NOTE(review): presumably they are read by a custom decoding loop
    (CTC / external-LM scoring) elsewhere in the project — confirm against the caller.

    Args:
        ctc_weight: Stored as ``self.ctc_weight``. Defaults to ``0.0``.
        ctc_margin: Stored as ``self.ctc_margin``. Defaults to ``0``.
        lm_weight: Stored as ``self.lm_weight``. Defaults to ``0``.
        lm_model: Stored as ``self.lm_model``. Defaults to ``None``.
            NOTE(review): if this holds a non-JSON-serializable object, saving/serializing
            the config may fail — confirm how it is used.
        space_token_id: Stored as ``self.space_token_id``. Defaults to ``-1``.
        eos_space_trick_weight: Stored as ``self.eos_space_trick_weight``. Defaults to ``0``.
        apply_eos_space_trick: Stored as ``self.apply_eos_space_trick``. Defaults to ``False``.
        **kwargs: Forwarded unchanged to ``GenerationConfig.__init__``.
    """

    def __init__(
        self,
        ctc_weight=0.0,
        ctc_margin=0,
        lm_weight=0,
        lm_model=None,
        space_token_id=-1,
        eos_space_trick_weight=0,
        apply_eos_space_trick=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.ctc_weight = ctc_weight
        self.ctc_margin = ctc_margin
        self.lm_weight = lm_weight
        self.lm_model = lm_model
        self.space_token_id = space_token_id
        self.eos_space_trick_weight = eos_space_trick_weight
        self.apply_eos_space_trick = apply_eos_space_trick

    @staticmethod
    def _coerce_update_value(key, value, old_value):
        """Convert the string `value` for attribute `key` to the type of its current value `old_value`."""
        # bool must be tested before int: bool is a subclass of int.
        if isinstance(old_value, bool):
            lowered = value.lower()
            if lowered in ("true", "1", "y", "yes"):
                return True
            if lowered in ("false", "0", "n", "no"):
                return False
            raise ValueError(f"can't derive true or false from {value} (key {key})")
        if isinstance(old_value, int):
            try:
                return int(value)
            except ValueError:
                # Attributes like lm_weight / eos_space_trick_weight default to the int 0,
                # but must still accept fractional values such as "0.5".
                return float(value)
        if isinstance(old_value, float):
            return float(value)
        if isinstance(old_value, str):
            return value
        raise ValueError(
            f"You can only update int, float, bool or string values in the config, got {value} for key {key}"
        )

    def update_from_string(self, update_str: str) -> None:
        """
        Updates attributes of this class with attributes from `update_str`.

        Entries are separated by semicolons (`;`) and each entry is a `key=value` pair. Ints, floats
        and strings are given as is; for booleans use `true` or `false` (`1`/`0`, `y`/`n` and
        `yes`/`no` are also accepted, case-insensitively). For example:

            "ctc_weight=0.3;lm_weight=0.5;num_beams=4;apply_eos_space_trick=true"

        Only the first `=` of an entry separates key from value, so values may contain `=`.
        Empty entries (e.g. produced by a trailing `;`) are ignored.

        The keys to change have to already exist in the config object, and each value is converted
        to the type of the attribute's current value. An attribute whose current value is an `int`
        also accepts a non-integer number, which is then stored as a `float`.

        All entries are validated and converted before any attribute is set, so an invalid entry
        leaves the config unchanged.

        Args:
            update_str (`str`): String with attributes that should be updated for this class.

        Raises:
            ValueError: If an entry has no `=`, a key does not exist on the config, a value cannot
                be converted to the type of the current value, or the current value is not an int,
                float, bool or string.
        """
        updates = {}
        for entry in update_str.split(";"):
            if not entry:
                continue
            key, sep, value = entry.partition("=")
            if not sep:
                raise ValueError(f"expected 'key=value' in update string, got {entry!r}")
            updates[key] = value

        converted = {}
        for key, value in updates.items():
            if not hasattr(self, key):
                raise ValueError(f"key {key} isn't in the original config dict")
            converted[key] = self._coerce_update_value(key, value, getattr(self, key))

        # Apply only after every entry converted cleanly (no partial updates on error).
        for key, value in converted.items():
            setattr(self, key, value)
|
|