File size: 1,527 Bytes
885da14 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
import jax.numpy as jnp
import jax
import torch
from dataclasses import dataclass
import sympy
import sympy as sp
from sympy import Matrix, Symbol
import math
from sde_redefined_param import SDEDimension
@dataclass
class SDEPolynomialConfig:
    """Symbolic configuration for a scalar SDE with polynomial coefficients.

    Drift and diffusion are polynomials in ``variable`` (``t``) with no
    constant term and symbolic coefficients:

        drift(t)     = -| sum_{i=1}^{drift_degree} f_{i-1} * t**i |
        diffusion(t) =    sum_{i=1}^{diffusion_degree} l_{i-1}**2 * t**i

    NOTE: none of the attributes below carry type annotations, so
    ``@dataclass`` generates no fields for them; they are plain class
    attributes shared by every instance. Annotating them would turn them
    into dataclass fields and change the generated ``__init__``.
    """

    name = "Custom"
    # Independent variable of the SDE, declared real and non-negative.
    variable = Symbol('t', nonnegative=True, real=True)
    drift_dimension = SDEDimension.SCALAR
    diffusion_dimension = SDEDimension.SCALAR
    diffusion_matrix_dimension = SDEDimension.SCALAR
    # Number of polynomial terms (powers t**1 .. t**degree) per coefficient.
    drift_degree = 20
    diffusion_degree = 20
    # Row vector of symbolic coefficients f0 .. f{drift_degree-1}.
    # Evaluated once when the class body runs: a subclass that overrides
    # drift_degree must also redefine drift_parameters.
    drift_parameters = Matrix([sympy.symbols(f"f:{drift_degree}", real=True)])
    # Row vector of symbolic coefficients l0 .. l{diffusion_degree-1}.
    # They are squared in ``diffusion`` so every coefficient is non-negative;
    # the same override caveat as drift_parameters applies.
    diffusion_parameters = Matrix([sympy.symbols(f"l:{diffusion_degree}", real=True)])

    @staticmethod
    def _polynomial(variable, coefficients, degree):
        """Return ``sum(c_i * variable**i for i = 1..degree)``.

        ``coefficients`` is iterated in row-major order, so coefficient k
        (0-based) multiplies ``variable**(k + 1)``.

        Raises:
            ValueError: if ``coefficients`` does not hold exactly ``degree``
                entries, e.g. when a degree was changed without redefining
                the matching parameter matrix.
        """
        count = len(coefficients)
        if count != degree:
            raise ValueError(
                f"expected {degree} polynomial coefficients, got {count}; "
                "redefine the parameter matrix when changing the degree"
            )
        return sum(
            coeff * variable**power
            for power, coeff in enumerate(coefficients, start=1)
        )

    @property
    def drift(self):
        """Drift ``-|sum_i f_{i-1} * t**i|``; the outer ``-Abs`` keeps it <= 0."""
        return -sympy.Abs(
            self._polynomial(self.variable, self.drift_parameters, self.drift_degree)
        )

    @property
    def diffusion(self):
        """Diffusion ``sum_i l_{i-1}**2 * t**i``.

        Squaring the coefficients, together with ``variable`` being declared
        non-negative, makes every term (and the sum) non-negative.
        """
        squared = self.diffusion_parameters.applyfunc(lambda x: x**2)
        return self._polynomial(self.variable, squared, self.diffusion_degree)

    # TODO(KLAUS): in SDE sampling, changing Q changes how we sample
    # z ~ N(0, Q * delta_t). Q presumably refers to diffusion_matrix below.
    diffusion_matrix = 1
    initial_variable_value = 0
    max_variable_value = 1  # math.inf
    min_sample_value = 1e-6
    module = 'jax'
    drift_integral_form = True
    diffusion_integral_form = True
    diffusion_integral_decomposition = 'cholesky'  # ldl
    target = "epsilon"  # x0
|