from transformers import PretrainedConfig


class ProbUNetConfig(PretrainedConfig):
    """Hugging Face configuration object holding ProbUNet hyperparameters.

    Every named constructor argument is stored unchanged on the instance under
    the same attribute name. Nothing is validated here. Any additional keyword
    arguments are passed through to ``PretrainedConfig.__init__``.

    Args:
        dim: Stored as ``self.dim``; presumably the spatial dimensionality of
            the network (e.g. 2 or 3) -- TODO confirm against the model code.
        in_channels: Stored as ``self.in_channels``; presumably the number of
            input channels.
        out_channels: Stored as ``self.out_channels``; presumably the number
            of output channels.
        num_feature_maps: Stored as ``self.num_feature_maps``; presumably the
            base feature-map count of the U-Net -- TODO confirm.
        latent_size: Stored as ``self.latent_size``; presumably the latent
            space dimensionality -- TODO confirm.
        depth: Stored as ``self.depth``; presumably the number of U-Net
            levels -- TODO confirm.
        latent_distribution: Stored as ``self.latent_distribution``
            (default ``"normal"``); the accepted names are defined by the
            model code, not checked here.
        no_outact_op: Stored as ``self.no_outact_op``; presumably disables the
            output activation when true -- verify against the model code.
        prob_injection_at: Stored as ``self.prob_injection_at``
            (default ``"end"``); the accepted values are defined by the model
            code, not checked here.
        **kwargs: Forwarded verbatim to ``PretrainedConfig.__init__``.
    """

    # Identifier string for this config class; presumably used when
    # registering with / resolving through the HF Auto* machinery.
    model_type = "ProbUNet"

    def __init__(
        self,
        dim: int = 2,
        in_channels: int = 1,
        out_channels: int = 1,
        num_feature_maps: int = 24,
        latent_size: int = 3,
        depth: int = 5,
        latent_distribution: str = "normal",
        no_outact_op: bool = False,
        prob_injection_at: str = "end",
        **kwargs,
    ):
        # Model-specific fields, kept in their original order so the
        # instance's attribute insertion order is unchanged.
        self.dim = dim
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_feature_maps = num_feature_maps
        self.latent_size = latent_size
        self.depth = depth
        self.latent_distribution = latent_distribution
        self.no_outact_op = no_outact_op
        self.prob_injection_at = prob_injection_at

        # The base class handles the remaining generic config options;
        # it is called only after the custom fields have been set.
        super().__init__(**kwargs)