|
from tensorflow import keras |
|
|
|
from maxim import maxim |
|
from maxim.configs import MAXIM_CONFIGS |
|
|
|
|
|
def Model(variant=None, input_resolution=(256, 256), **kw) -> keras.Model:
    """Build a Keras functional model wrapping a MAXIM network.

    Args:
        variant: Optional key into ``MAXIM_CONFIGS`` (documented options:
            'S-1' | 'S-2' | 'S-3' | 'M-1' | 'M-2' | 'M-3'). When given, every
            entry of that config is used as a *default* for ``kw``; values
            passed explicitly in ``kw`` take precedence.
        input_resolution: Spatial size ``(height, width)`` of the input
            images. A single int ``n`` is treated as ``(n, n)``.
        **kw: Keyword arguments forwarded to ``maxim.MAXIM``. A ``"name"``
            entry must be available (passed here or supplied by the variant
            config); it is consumed by this function to name the returned
            model and is not forwarded.

    Returns:
        A ``keras.Model`` named ``f"{name}_model"`` whose input has shape
        ``(*input_resolution, 3)`` (batch dimension implicit) and whose
        outputs are whatever the MAXIM layer returns for that input.

    Raises:
        KeyError: If ``variant`` is not a key of ``MAXIM_CONFIGS``, or if no
            ``"name"`` entry is available.
    """
    if variant is not None:
        try:
            config = MAXIM_CONFIGS[variant]
        except KeyError:
            # Keep KeyError for backward compatibility, but say what is valid.
            raise KeyError(
                f"Unknown MAXIM variant {variant!r}; expected one of "
                f"{list(MAXIM_CONFIGS)}"
            ) from None
        for key, value in config.items():
            # Explicit caller kwargs win over the variant's defaults.
            kw.setdefault(key, value)

    # Keys named like this factory's own parameters can only reach ``kw`` via
    # the variant config (Python binds caller-passed ones to the parameters).
    # Strip them so they are not forwarded to maxim.MAXIM; the explicit
    # ``input_resolution`` argument is what is used below.
    kw.pop("variant", None)
    kw.pop("input_resolution", None)

    if "name" not in kw:
        raise KeyError(
            "Model() needs a 'name': pass name=... or choose a variant whose "
            "config provides one"
        )
    model_name = kw.pop("name")

    # Accept a bare int as a square resolution.
    if isinstance(input_resolution, int):
        input_resolution = (input_resolution, input_resolution)

    maxim_model = maxim.MAXIM(**kw)

    inputs = keras.Input((*input_resolution, 3))
    outputs = maxim_model(inputs)
    return keras.Model(inputs, outputs, name=f"{model_name}_model")
|
|