reach-vb's picture
reach-vb HF staff
8274b78f238537f55a2fdac5624687534650241ed111eeb7b3bab8b2371bbf9b
d2e7957
raw
history blame
5.27 kB
from ..libmp.backend import xrange
from .functions import defun, defun_wrapped
@defun
def gammaprod(ctx, a, b, _infsign=False):
    """
    Evaluate prod(gamma(x) for x in a) / prod(gamma(x) for x in b).

    Arguments for which ``ctx.isnpint`` is true are treated as poles of
    the gamma function. If the denominator holds more poles than the
    numerator the result is ``ctx.zero``; if the numerator holds more,
    the result is ``ctx.inf`` (or a signed infinity when ``_infsign`` is
    true). When both sides hold equally many poles they are cancelled
    pairwise through the limit

        lim G(i)/G(j) = (-1)**(i+j) * gamma(1-j) / gamma(1-i)
    """
    numer_args = [ctx.convert(t) for t in a]
    denom_args = [ctx.convert(t) for t in b]

    # Separate each side into pole arguments and regular arguments.
    poles_num, regular_num = [], []
    for t in numer_args:
        if ctx.isnpint(t):
            poles_num.append(t)
        else:
            regular_num.append(t)
    poles_den, regular_den = [], []
    for t in denom_args:
        if ctx.isnpint(t):
            poles_den.append(t)
        else:
            regular_den.append(t)

    pole_balance = len(poles_num) - len(poles_den)

    # Surplus poles in the denominator force the quotient to zero.
    if pole_balance < 0:
        return ctx.zero

    # Surplus poles in the numerator make the quotient infinite.
    if pole_balance > 0:
        if not _infsign:
            return ctx.inf

        # XXX: hack carried over from the original implementation. Every
        # pole is nudged off the singularity (relative eps, or absolute
        # eps for a pole at zero) and the sign of the now-finite quotient
        # is attached to the infinity.
        def nudge(t):
            return t*(1+ctx.eps) if t else t+ctx.eps

        shifted_num = [nudge(t) for t in poles_num]
        shifted_den = [nudge(t) for t in poles_den]
        finite = ctx.gammaprod(shifted_num+regular_num, shifted_den+regular_den)
        return ctx.sign(finite) * ctx.inf

    # Equal pole counts: cancel them in pairs, last pair first, while
    # running with 15 extra bits of working precision.
    result = ctx.one
    saved_prec = ctx.prec
    try:
        ctx.prec = saved_prec + 15
        for i, j in zip(reversed(poles_num), reversed(poles_den)):
            result *= (-1)**(i+j) * ctx.gamma(1-j) / ctx.gamma(1-i)
        for t in regular_num:
            result *= ctx.gamma(t)
        for t in regular_den:
            result /= ctx.gamma(t)
    finally:
        ctx.prec = saved_prec
    # Unary plus is applied after the original precision is back in place.
    return +result
@defun
def beta(ctx, x, y):
    """
    Compute gamma(x)*gamma(y)/gamma(x+y) through ``ctx.gammaprod``.

    Infinite arguments are handled up front: only the pairing of +inf
    with an argument whose imaginary part is zero can give something
    other than ``ctx.nan``.
    """
    x = ctx.convert(x)
    y = ctx.convert(y)
    # If y is infinite, swap so that the infinite argument sits in x.
    if ctx.isinf(y):
        x, y = y, x

    if not ctx.isinf(x):
        # Ordinary case. The sum is formed at doubled precision before
        # being handed to gammaprod.
        total = ctx.fadd(x, y, prec=2*ctx.prec)
        return ctx.gammaprod([x, y], [total])

    # From here on x is infinite.
    if not (x == ctx.inf and not ctx._im(y)):
        return ctx.nan
    if y == ctx.ninf:
        return ctx.nan
    if y > 0:
        return ctx.zero
    if ctx.isint(y):
        return ctx.nan
    if y < 0:
        # Infinite magnitude, carrying the sign of gamma(y).
        return ctx.sign(ctx.gamma(y)) * ctx.inf
    # Nothing above matched (for instance, no ordering comparison held).
    return ctx.nan
@defun
def binomial(ctx, n, k):
    """
    Compute gamma(n+1) / (gamma(k+1) * gamma(n-k+1)) through
    ``ctx.gammaprod``.

    The shifted arguments are formed at twice the current precision.
    """
    wp = 2*ctx.prec
    n_plus_1 = ctx.fadd(n, 1, prec=wp)
    k_plus_1 = ctx.fadd(k, 1, prec=wp)
    # n - k + 1, obtained as (n + 1) - k
    n_minus_k_plus_1 = ctx.fsub(n_plus_1, k, prec=wp)
    return ctx.gammaprod([n_plus_1], [k_plus_1, n_minus_k_plus_1])
@defun
def rf(ctx, x, n):
    """
    Rising factorial, evaluated as gamma(x+n) / gamma(x) through
    ``ctx.gammaprod``.

    The sum x+n is formed at twice the current precision.
    """
    shifted = ctx.fadd(x, n, prec=2*ctx.prec)
    numerator = [shifted]
    denominator = [x]
    return ctx.gammaprod(numerator, denominator)
@defun
def ff(ctx, x, n):
    """
    Falling factorial, evaluated as gamma(x+1) / gamma(x-n+1) through
    ``ctx.gammaprod``.

    Both shifted arguments are formed at twice the current precision.
    """
    wp = 2*ctx.prec
    x_plus_1 = ctx.fadd(x, 1, prec=wp)
    # x - n + 1, obtained as (x - n) + 1
    difference = ctx.fsub(x, n, prec=wp)
    x_minus_n_plus_1 = ctx.fadd(difference, 1, prec=wp)
    return ctx.gammaprod([x_plus_1], [x_minus_n_plus_1])
@defun_wrapped
def fac2(ctx, x):
    """
    Double factorial, evaluated with the closed form

        2**(x/2) * (pi/2)**((cos(pi*x)-1)/4) * gamma(x/2+1)

    Returns +inf unchanged and ``ctx.nan`` for any other infinity.
    """
    if ctx.isinf(x):
        return x if x == ctx.inf else ctx.nan
    half = x/2
    power_of_two = 2**half
    pi_factor = (ctx.pi/2)**((ctx.cospi(x)-1)/4)
    gamma_factor = ctx.gamma(half+1)
    return power_of_two*pi_factor*gamma_factor
@defun_wrapped
def barnesg(ctx, z):
    """
    Barnes G-function.

    Special values are handled first: +inf is returned unchanged, other
    infinities give ``ctx.nan``, NaN is returned unchanged, and real
    nonpositive integers give zero. For arguments far to the left a
    reflection formula is applied. Otherwise the argument is shifted
    upward, dividing out gamma factors, until an asymptotic series can
    be summed.
    """
    if ctx.isinf(z):
        return z if z == ctx.inf else ctx.nan
    if ctx.isnan(z):
        return z
    # Zeros: real, nonpositive, integer arguments.
    if (not ctx._im(z)) and ctx._re(z) <= 0 and ctx.isint(ctx._re(z)):
        return z*0

    # Account for size (would not be needed if computing log(G)).
    # NOTE(review): dps is raised here and never lowered inside this
    # function; presumably the defun_wrapped decorator restores it -- confirm.
    magnitude = abs(z)
    if magnitude > 5:
        ctx.dps += 2*ctx.log(magnitude, 2)

    # Reflection formula for arguments far in the left half-plane.
    if ctx.re(z) < -ctx.dps:
        reflected = 1-z
        two_pi = 2*ctx.pi
        phase = ctx.expjpi(2*reflected)
        exponent = ctx.j*ctx.pi/12 - ctx.j*ctx.pi*reflected**2/2 + \
            reflected*ctx.ln(1-phase) - ctx.j*ctx.polylog(2, phase)/two_pi
        value = ctx.barnesg(2-z)*ctx.exp(exponent)/two_pi**reflected
        # A real-typed input gets a real result back.
        if ctx._is_real_type(z):
            value = ctx._re(value)
        return value

    # Estimate terms for asymptotic expansion
    # TODO: fixme, obviously
    nterms = ctx.dps // 2 + 5

    # Shift z upward until it is large enough, collecting the reciprocal
    # gamma factors; the final z -= 1 steps back by one after the loop.
    prefactor = 1
    while abs(z) < nterms or ctx.re(z) < 1:
        prefactor /= ctx.gamma(z)
        z += 1
    z -= 1

    z_squared = z**2
    one_twelfth = ctx.mpf(1)/12

    # Leading part of the exponent, accumulated in the original order.
    log_sum = one_twelfth
    log_sum -= ctx.log(ctx.glaisher)
    log_sum += z*ctx.log(2*ctx.pi)/2
    log_sum += (z_squared/2-one_twelfth)*ctx.log(z)
    log_sum -= 3*z_squared/4

    # Bernoulli-number correction terms in inverse powers of z**2; stop at
    # the first term smaller than eps (that term is not added).
    z_power = z_squared
    for k in xrange(1, nterms+1):
        term = ctx.bernoulli(2*k+2) / (4*k*(k+1)*z_power)
        if abs(term) < ctx.eps:
            break
        z_power *= z_squared
        log_sum += term
    # No diagnostic is issued if all nterms terms are used without
    # reaching eps.
    return prefactor*ctx.exp(log_sum)
@defun
def superfac(ctx, z):
    """Superfactorial, evaluated as ``ctx.barnesg(z+2)``."""
    shifted = z+2
    return ctx.barnesg(shifted)
@defun_wrapped
def hyperfac(ctx, z):
    """
    Hyperfactorial, evaluated as exp(z*loggamma(z+1)) / barnesg(z+1).

    +inf is returned unchanged. Real negative integers are mapped onto
    the value at -z-1, with a sign that depends on the parity of
    (z+1)//2.
    """
    if z == ctx.inf:
        return z

    # XXX: estimate needed extra bits accurately
    extra = 4*int(ctx.log(abs(z), 2)) if abs(z) > 5 else 0
    ctx.prec += extra

    is_negative_integer = (
        not ctx._im(z) and ctx._re(z) < 0 and ctx.isint(ctx._re(z))
    )
    if is_negative_integer:
        # NOTE(review): this branch returns without undoing the precision
        # increase above; presumably the defun_wrapped decorator restores
        # ctx.prec -- confirm.
        n = int(ctx.re(z))
        mirrored = ctx.hyperfac(-n-1)
        if ((n+1)//2) % 2:
            mirrored = -mirrored
        # Complex-typed input gets a complex-typed result.
        if ctx._is_complex_type(z):
            return mirrored + 0j
        return mirrored

    z_plus_1 = z+1
    # exp(z*loggamma(z+1)) is used instead of gamma(z+1)**z, which the
    # original author marked as having the wrong branch cut.
    numerator = ctx.exp(z*ctx.loggamma(z_plus_1))
    # The extra precision is dropped before the final division.
    ctx.prec -= extra
    return numerator / ctx.barnesg(z_plus_1)
'''
@defun
def psi0(ctx, z):
"""Shortcut for psi(0,z) (the digamma function)"""
return ctx.psi(0, z)
@defun
def psi1(ctx, z):
"""Shortcut for psi(1,z) (the trigamma function)"""
return ctx.psi(1, z)
@defun
def psi2(ctx, z):
"""Shortcut for psi(2,z) (the tetragamma function)"""
return ctx.psi(2, z)
@defun
def psi3(ctx, z):
"""Shortcut for psi(3,z) (the pentagamma function)"""
return ctx.psi(3, z)
'''