|
|
|
|
|
|
|
|
|
|
|
|
|
"""Hyperparameter values.""" |
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import json |
|
import numbers |
|
import re |
|
import six |
|
|
|
|
|
|
|
|
|
|
|
|
|
# Matches a single `name=value` clause of a hyperparameter string, anchored at
# the position passed to `PARAM_RE.match`.  Named groups:
#   name  - variable name: a letter followed by word characters and/or dots.
#   index - optional run of digits inside "[...]" directly after the name.
#   val   - scalar right-hand side (possibly empty); None when the RHS is a list.
#   vals  - text between the brackets of a list RHS; None when the RHS is scalar.
# The clause has to be terminated by the end of the input or by a comma
# (optionally followed by whitespace).
PARAM_RE = re.compile(
    r"""
  (?P<name>[a-zA-Z][\w\.]*)      # variable name: "var" or "x"
  (\[\s*(?P<index>\d+)\s*\])?  # (optional) index: "1" or None
  \s*=\s*
  ((?P<val>[^,\[]*)            # single value: "a" or None
   |
   \[(?P<vals>[^\]]*)\])       # list of values: None or "1,2,3"
  ($|,\s*)""",
    re.VERBOSE,
)
|
|
|
|
|
def _parse_fail(name, var_type, value, values): |
|
"""Helper function for raising a value error for bad assignment.""" |
|
raise ValueError( |
|
"Could not parse hparam '%s' of type '%s' with value '%s' in %s" |
|
% (name, var_type.__name__, value, values) |
|
) |
|
|
|
|
|
def _reuse_fail(name, values): |
|
"""Helper function for raising a value error for reuse of name.""" |
|
raise ValueError("Multiple assignments to variable '%s' in %s" % (name, values)) |
|
|
|
|
|
def _process_scalar_value(name, parse_fn, var_type, m_dict, values, results_dictionary): |
|
"""Update results_dictionary with a scalar value. |
|
|
|
Used to update the results_dictionary to be returned by parse_values when |
|
encountering a clause with a scalar RHS (e.g. "s=5" or "arr[0]=5".) |
|
|
|
Mutates results_dictionary. |
|
|
|
Args: |
|
name: Name of variable in assignment ("s" or "arr"). |
|
parse_fn: Function for parsing the actual value. |
|
var_type: Type of named variable. |
|
m_dict: Dictionary constructed from regex parsing. |
|
m_dict['val']: RHS value (scalar) |
|
m_dict['index']: List index value (or None) |
|
values: Full expression being parsed |
|
results_dictionary: The dictionary being updated for return by the parsing |
|
function. |
|
|
|
Raises: |
|
ValueError: If the name has already been used. |
|
""" |
|
try: |
|
parsed_value = parse_fn(m_dict["val"]) |
|
except ValueError: |
|
_parse_fail(name, var_type, m_dict["val"], values) |
|
|
|
|
|
if not m_dict["index"]: |
|
if name in results_dictionary: |
|
_reuse_fail(name, values) |
|
results_dictionary[name] = parsed_value |
|
else: |
|
if name in results_dictionary: |
|
|
|
|
|
if not isinstance(results_dictionary.get(name), dict): |
|
_reuse_fail(name, values) |
|
else: |
|
results_dictionary[name] = {} |
|
|
|
index = int(m_dict["index"]) |
|
|
|
if index in results_dictionary[name]: |
|
_reuse_fail("{}[{}]".format(name, index), values) |
|
results_dictionary[name][index] = parsed_value |
|
|
|
|
|
def _process_list_value(name, parse_fn, var_type, m_dict, values, results_dictionary): |
|
"""Update results_dictionary from a list of values. |
|
|
|
Used to update results_dictionary to be returned by parse_values when |
|
encountering a clause with a list RHS (e.g. "arr=[1,2,3]".) |
|
|
|
Mutates results_dictionary. |
|
|
|
Args: |
|
name: Name of variable in assignment ("arr"). |
|
parse_fn: Function for parsing individual values. |
|
var_type: Type of named variable. |
|
m_dict: Dictionary constructed from regex parsing. |
|
m_dict['val']: RHS value (scalar) |
|
values: Full expression being parsed |
|
results_dictionary: The dictionary being updated for return by the parsing |
|
function. |
|
|
|
Raises: |
|
ValueError: If the name has an index or the values cannot be parsed. |
|
""" |
|
if m_dict["index"] is not None: |
|
raise ValueError("Assignment of a list to a list index.") |
|
elements = filter(None, re.split("[ ,]", m_dict["vals"])) |
|
|
|
if name in results_dictionary: |
|
raise _reuse_fail(name, values) |
|
try: |
|
results_dictionary[name] = [parse_fn(e) for e in elements] |
|
except ValueError: |
|
_parse_fail(name, var_type, m_dict["vals"], values) |
|
|
|
|
|
def _cast_to_type_if_compatible(name, param_type, value):
    """Convert `value` to `param_type` after checking that the two are compatible.

    Args:
      name: Name of the hparam to be cast (only used in the error message).
      param_type: The type of the hparam.
      value: The value to be cast, if compatible.

    Returns:
      `value` unchanged when `param_type` is NoneType, otherwise
      `param_type(value)`.

    Raises:
      ValueError: If the type of `value` is not compatible with param_type.
        * If `param_type` is a string type, but `value` is not.
        * If `param_type` is a boolean, but `value` is not, or vice versa.
        * If `param_type` is an integer type, but `value` is not.
        * If `param_type` is a float type, but `value` is not a numeric type.
    """
    fail_msg = "Could not cast hparam '%s' of type '%s' from value %r" % (
        name,
        param_type,
        value,
    )

    # A NoneType hparam carries no type information, so anything is accepted.
    if issubclass(param_type, type(None)):
        return value

    # Text parameters only accept text (unicode or byte strings).
    text_types = (six.string_types, six.binary_type)
    if issubclass(param_type, text_types) and not isinstance(value, text_types):
        raise ValueError(fail_msg)

    # bool is a subclass of int, so it is checked on its own first: a bool
    # parameter needs a bool value and a non-bool parameter rejects bools.
    if issubclass(param_type, bool) != isinstance(value, bool):
        raise ValueError(fail_msg)

    # Integral parameters need integral values; any other numeric parameter
    # accepts any number.
    for numeric_kind in (numbers.Integral, numbers.Number):
        if issubclass(param_type, numeric_kind) and not isinstance(
            value, numeric_kind
        ):
            raise ValueError(fail_msg)

    return param_type(value)
|
|
|
|
|
def parse_values(values, type_map, ignore_unknown=False):
    """Parses hyperparameter values from a string into a python map.

    `values` is a string containing comma-separated `name=value` pairs.
    For each pair, the value of the hyperparameter named `name` is set to
    `value`.

    If a hyperparameter name appears multiple times in `values`, a ValueError
    is raised (e.g. 'a=1,a=2', 'a[1]=1,a[1]=2').

    If a hyperparameter name appears in both an index assignment and a scalar
    or list assignment, a ValueError is raised (e.g. 'a=[1,2,3],a[0] = 1').

    The hyperparameter name may contain '.' symbols, which will result in an
    attribute name that is only accessible through the getattr and setattr
    functions. (And must first be explicitly added through add_hparam.)

    WARNING: Use of '.' in your variable names is allowed, but is not well
    supported and not recommended.

    The `value` in `name=value` must follow the syntax according to the
    type of the parameter:

    *  Scalar integer: A Python-parsable integer value.  E.g.: 1,
       100, -12.
    *  Scalar float: A Python-parsable floating point value.  E.g.: 1.0,
       -.54e89.
    *  Boolean: One of true, True, false, False, or an integer (zero is False,
       anything else is True).
    *  Scalar string: A non-empty sequence of characters, excluding comma,
       spaces, and square brackets.  E.g.: foo, bar_1.
    *  List: A comma separated list of scalar values of the parameter type
       enclosed in square brackets.  E.g.: [1,2,3], [1.0,1e-12], [high,low].

    When index assignment is used, the corresponding type_map key should be the
    list name.  E.g. for "arr[1]=0" the type_map must have the key "arr" (not
    "arr[1]").

    Args:
      values: String.  Comma separated list of `name=value` pairs where
        'value' must follow the syntax described above.
      type_map: A dictionary mapping hyperparameter names to types.  Note every
        parameter name in values must be a key in type_map.  The values must
        conform to the types indicated, where a value V is said to conform to a
        type T if either V has type T, or V is a list of elements of type T.
        Hence, for a multidimensional parameter 'x' taking float values,
        'x=[0.1,0.2]' will parse successfully if type_map['x'] = float.
      ignore_unknown: Bool. Whether values that are missing a type in type_map
        should be ignored. If set to True, a ValueError will not be raised for
        unknown hyperparameter type.

    Returns:
      A python map mapping each name to either:
      * A scalar value.
      * A list of scalar values.
      * A dictionary mapping index numbers to scalar values.
      (e.g. "x=5,L=[1,2],arr[1]=3" results in {'x':5,'L':[1,2],'arr':{1:3}}")

    Raises:
      ValueError: If there is a problem with input.
      * If `values` cannot be parsed.
      * If a list is assigned to a list index (e.g. 'a[1] = [1,2,3]').
      * If the same name is assigned two different values (e.g. 'a=1,a=2',
        'a[1]=1,a[1]=2', or 'a=1,a=[1]')
    """
    bool_literals = {"true": True, "True": True, "false": False, "False": False}

    def parse_bool(text):
        # `name` and `type_` are looked up when this runs, so they describe the
        # clause the loop below is currently processing.
        if text in bool_literals:
            return bool_literals[text]
        try:
            return bool(int(text))
        except ValueError:
            _parse_fail(name, type_, text, values)

    parsed = {}
    cursor = 0
    while cursor < len(values):
        match = PARAM_RE.match(values, cursor)
        if match is None:
            raise ValueError("Malformed hyperparameter value: %s" % values[cursor:])

        # Resume after this clause (and its trailing comma) on the next pass.
        cursor = match.end()

        groups = match.groupdict()
        name = groups["name"]
        if name not in type_map:
            if not ignore_unknown:
                raise ValueError("Unknown hyperparameter type for %s" % name)
            continue
        type_ = type_map[name]

        # Booleans need a dedicated parser; every other type parses itself.
        parse = parse_bool if type_ == bool else type_

        if groups["val"] is not None:
            # Scalar right-hand side, possibly assigned to a list index.
            _process_scalar_value(name, parse, type_, groups, values, parsed)
        elif groups["vals"] is not None:
            # Bracketed list right-hand side.
            _process_list_value(name, parse, type_, groups, values, parsed)
        else:
            _parse_fail(name, type_, "", values)

    return parsed
|
|
|
|
|
class HParams(object):
    """Class to hold a set of hyperparameters as name-value pairs.

    A `HParams` object holds hyperparameters used to build and train a model,
    such as the number of hidden units in a neural net layer or the learning rate
    to use when training.

    You first create a `HParams` object by specifying the names and values of the
    hyperparameters.

    To make them easily accessible the parameter names are added as direct
    attributes of the class.  A typical usage is as follows:

    ```python
    # Create a HParams object specifying names and values of the model
    # hyperparameters:
    hparams = HParams(learning_rate=0.1, num_hidden_units=100)

    # The hyperparameter are available as attributes of the HParams object:
    hparams.learning_rate ==> 0.1
    hparams.num_hidden_units ==> 100
    ```

    Hyperparameters have type, which is inferred from the type of their value
    passed at construction time.  The currently supported types are: integer,
    float, boolean, string, and list of integer, float, boolean, or string.

    You can override hyperparameter values by calling the
    [`parse()`](#HParams.parse) method, passing a string of comma separated
    `name=value` pairs.  This is intended to make it possible to override
    any hyperparameter values from a single command-line flag to which
    the user passes 'hyper-param=value' pairs.  It avoids having to define
    one flag for each hyperparameter.

    The syntax expected for each value depends on the type of the parameter.
    See `parse()` for a description of the syntax.

    Example:

    ```python
    # Define a command line flag to pass name=value pairs.
    # For example using argparse:
    import argparse
    parser = argparse.ArgumentParser(description='Train my model.')
    parser.add_argument('--hparams', type=str,
                        help='Comma separated list of "name=value" pairs.')
    args = parser.parse_args()
    ...
    def my_program():
      # Create a HParams object specifying the names and values of the
      # model hyperparameters:
      hparams = HParams(learning_rate=0.1, num_hidden_units=100,
                        activations=['relu', 'tanh'])

      # Override hyperparameters values by parsing the command line
      hparams.parse(args.hparams)

      # If the user passed `--hparams=learning_rate=0.3` on the command line
      # then 'hparams' has the following attributes:
      hparams.learning_rate ==> 0.3
      hparams.num_hidden_units ==> 100
      hparams.activations ==> ['relu', 'tanh']

      # If the hyperparameters are in json format use parse_json (list-valued
      # hyperparameters must be given as JSON arrays):
      hparams.parse_json('{"learning_rate": 0.3, "activations": ["relu"]}')
    ```
    """

    # NOTE(review): nothing in this module reads this flag; presumably it is a
    # marker for static analysis tools that attributes are added dynamically.
    _HAS_DYNAMIC_ATTRIBUTES = True

    def __init__(self, model_structure=None, **kwargs):
        """Create an instance of `HParams` from keyword arguments.

        The keyword arguments specify name-values pairs for the hyperparameters.
        The parameter types are inferred from the type of the values passed.

        The parameter names are added as attributes of `HParams` object, so they
        can be accessed directly with the dot notation `hparams._name_`.

        Example:

        ```python
        # Define 3 hyperparameters: 'learning_rate' is a float parameter,
        # 'num_hidden_units' an integer parameter, and 'activation' a string
        # parameter.
        hparams = HParams(
            learning_rate=0.1, num_hidden_units=100, activation='relu')

        hparams.activation ==> 'relu'
        ```

        Note that a few names are reserved and cannot be used as hyperparameter
        names.  If you use one of the reserved name the constructor raises a
        `ValueError`.

        Args:
          model_structure: Optional object stored as-is and returned by
            `get_model_structure()`; it is not interpreted by this class.
          **kwargs: Key-value pairs where the key is the hyperparameter name and
            the value is the value for the parameter.

        Raises:
          ValueError: If one of the names is reserved, or if a list/tuple value
            is empty.
        """
        # _hparam_types maps each hyperparameter name to a (type, is_list)
        # tuple.  `type` is the type of the value for scalar hyperparameters,
        # or the type of the first element for list/tuple hyperparameters;
        # `is_list` is True for list/tuple values.  parse() and set_hparam()
        # rely on it to validate and convert new values.
        self._hparam_types = {}
        self._model_structure = model_structure
        for name, value in kwargs.items():
            self.add_hparam(name, value)

    def add_hparam(self, name, value):
        """Adds {name, value} pair to hyperparameters.

        Args:
          name: Name of the hyperparameter.
          value: Value of the hyperparameter. A scalar, or a non-empty list or
            tuple whose first element determines the element type.

        Raises:
          ValueError: If `name` is already an attribute of this object with a
            value other than None (e.g. a method name or an existing
            hyperparameter), or if `value` is an empty list or tuple.
        """
        # `name` could be the name of a pre-existing attribute of this object
        # (a method, an internal field, or an earlier hyperparameter).  In that
        # case we refuse to use it as a hyperparameter name.
        if getattr(self, name, None) is not None:
            raise ValueError("Hyperparameter name is reserved: %s" % name)
        if isinstance(value, (list, tuple)):
            if not value:
                raise ValueError(
                    "Multi-valued hyperparameters cannot be empty: %s" % name
                )
            self._hparam_types[name] = (type(value[0]), True)
        else:
            self._hparam_types[name] = (type(value), False)
        setattr(self, name, value)

    def set_hparam(self, name, value):
        """Set the value of an existing hyperparameter.

        This function verifies that the type of the value matches the type of the
        existing hyperparameter.

        Args:
          name: Name of the hyperparameter.
          value: New value of the hyperparameter.

        Raises:
          KeyError: If the hyperparameter doesn't exist.
          ValueError: If there is a type mismatch.
        """
        param_type, is_list = self._hparam_types[name]
        if isinstance(value, list):
            if not is_list:
                raise ValueError(
                    "Must not pass a list for single-valued parameter: %s" % name
                )
            # Every element is validated and converted individually.
            setattr(
                self,
                name,
                [_cast_to_type_if_compatible(name, param_type, v) for v in value],
            )
        else:
            if is_list:
                raise ValueError(
                    "Must pass a list for multi-valued parameter: %s." % name
                )
            setattr(self, name, _cast_to_type_if_compatible(name, param_type, value))

    def del_hparam(self, name):
        """Removes the hyperparameter with key 'name'.

        Does nothing if it isn't present.

        Args:
          name: Name of the hyperparameter.
        """
        # Membership in _hparam_types (not hasattr) decides whether `name` is a
        # hyperparameter: hasattr is also true for methods and internal fields,
        # which must neither be deleted nor cause an error here.
        if name in self._hparam_types:
            delattr(self, name)
            del self._hparam_types[name]

    def parse(self, values):
        """Override existing hyperparameter values, parsing new values from a string.

        See parse_values for more detail on the allowed format for values.

        Args:
          values: String.  Comma separated list of `name=value` pairs where 'value'
            must follow the syntax described in parse_values.

        Returns:
          The `HParams` instance.

        Raises:
          ValueError: If `values` cannot be parsed or a hyperparameter in `values`
            doesn't exist.
        """
        # parse_values only needs the element type, not the is_list flag.
        type_map = {}
        for name, t in self._hparam_types.items():
            param_type, _ = t
            type_map[name] = param_type

        values_map = parse_values(values, type_map)
        return self.override_from_dict(values_map)

    def override_from_dict(self, values_dict):
        """Override existing hyperparameter values, parsing new values from a dictionary.

        Args:
          values_dict: Dictionary of name:value pairs.

        Returns:
          The `HParams` instance.

        Raises:
          KeyError: If a hyperparameter in `values_dict` doesn't exist.
          ValueError: If `values_dict` cannot be parsed.
        """
        for name, value in values_dict.items():
            self.set_hparam(name, value)
        return self

    def set_model_structure(self, model_structure):
        """Stores `model_structure` on this object, replacing any previous one."""
        self._model_structure = model_structure

    def get_model_structure(self):
        """Returns the model structure last stored (None if none was given)."""
        return self._model_structure

    def to_json(self, indent=None, separators=None, sort_keys=False):
        """Serializes the hyperparameters into JSON.

        Callable values (and callables nested inside dict or list values) are
        omitted from the output.

        Args:
          indent: Passed to `json.dumps`. If a non-negative integer, JSON array
            elements and object members will be pretty-printed with that indent
            level. `None` (the default) selects the most compact representation.
          separators: Passed to `json.dumps`. Optional
            `(item_separator, key_separator)` tuple.
          sort_keys: If `True`, the output dictionaries will be sorted by key.

        Returns:
          A JSON string.
        """

        def remove_callables(x):
            """Omit callable elements from input with arbitrary nesting."""
            if isinstance(x, dict):
                return {
                    k: remove_callables(v) for k, v in x.items() if not callable(v)
                }
            elif isinstance(x, list):
                return [remove_callables(i) for i in x if not callable(i)]
            return x

        return json.dumps(
            remove_callables(self.values()),
            indent=indent,
            separators=separators,
            sort_keys=sort_keys,
        )

    def parse_json(self, values_json):
        """Override existing hyperparameter values, parsing new values from a json object.

        Args:
          values_json: String containing a json object of name:value pairs.

        Returns:
          The `HParams` instance.

        Raises:
          KeyError: If a hyperparameter in `values_json` doesn't exist.
          ValueError: If `values_json` cannot be parsed.
        """
        values_map = json.loads(values_json)
        return self.override_from_dict(values_map)

    def values(self):
        """Return the hyperparameter values as a Python dictionary.

        Returns:
          A dictionary with hyperparameter names as keys.  The values are the
          hyperparameter values.
        """
        return {n: getattr(self, n) for n in self._hparam_types}

    def get(self, key, default=None):
        """Returns the value of `key` if it exists, else `default`.

        Args:
          key: Name of the hyperparameter.
          default: Value returned when `key` is not a hyperparameter. When `key`
            exists and `default` is not None, `default` is checked for
            compatibility with the hyperparameter's type.

        Returns:
          The hyperparameter value, or `default` if `key` is unknown.

        Raises:
          ValueError: If `key` exists and a non-None `default` is incompatible
            with its type (list vs. scalar, or element type).
        """
        if key in self._hparam_types:
            # Ensure that default is compatible with the parameter type.
            if default is not None:
                param_type, is_param_list = self._hparam_types[key]
                type_str = "list<%s>" % param_type if is_param_list else str(param_type)
                fail_msg = (
                    "Hparam '%s' of type '%s' is incompatible with "
                    "default=%s" % (key, type_str, default)
                )

                is_default_list = isinstance(default, list)
                if is_param_list != is_default_list:
                    raise ValueError(fail_msg)

                try:
                    if is_default_list:
                        for value in default:
                            _cast_to_type_if_compatible(key, param_type, value)
                    else:
                        _cast_to_type_if_compatible(key, param_type, default)
                except ValueError as e:
                    raise ValueError("%s. %s" % (fail_msg, e))

            return getattr(self, key)

        return default

    def __contains__(self, key):
        """Returns True if `key` is the name of a hyperparameter."""
        return key in self._hparam_types

    def __str__(self):
        """Returns the (name, value) pairs sorted by name, formatted as a list."""
        return str(sorted(self.values().items()))

    def __repr__(self):
        """Returns the class name wrapped around the `__str__` output."""
        return "%s(%s)" % (type(self).__name__, self.__str__())

    @staticmethod
    def _get_kind_name(param_type, is_list):
        """Returns the field name given parameter type and is_list.

        Args:
          param_type: Data type of the hparam.
          is_list: Whether this is a list.

        Returns:
          A string representation of the field name: one of "bool", "int64",
          "bytes" or "float", joined by "_" with "list" or "value".

        Raises:
          ValueError: If parameter type is not recognized.
        """
        # bool must be tested before the integer types because bool is a
        # subclass of int.
        if issubclass(param_type, bool):
            typename = "bool"
        elif issubclass(param_type, six.integer_types):
            typename = "int64"
        elif issubclass(param_type, (six.string_types, six.binary_type)):
            typename = "bytes"
        elif issubclass(param_type, float):
            typename = "float"
        else:
            raise ValueError("Unsupported parameter type: %s" % str(param_type))

        suffix = "list" if is_list else "value"
        return "_".join([typename, suffix])
|
|