Description
I have implemented versions of both of these:
from __future__ import annotations
from typing import Union, cast
import jax
from jax import lax
from jax import random
import jax.numpy as jnp
from jax.random import PRNGKey
from jax.scipy.linalg import cho_solve
from jax.scipy.special import gammaln
from numpy.typing import NDArray
from numpyro.distributions import (
Chi2,
Distribution,
MultivariateNormal,
MultivariateStudentT,
Normal,
constraints,
)
from numpyro.distributions.util import is_prng_key, promote_shapes, validate_sample
def delta(skewers_: NDArray[float], cov_: NDArray[float]):
return (jnp.einsum("...ij,...j->...i", cov_, skewers_)) / jnp.sqrt(
1 + jnp.einsum("...j,...jk,...k->...", skewers_, cov_, skewers_)[..., jnp.newaxis]
)
# Efficient computation of the distribution functions of Student's t, chi-squared and F to moderate accuracy
# https://sci-hub.se/10.1080/00949658208810542
# We have to use an approximation because `betainc` doesn't have gradients defined,
# which means we can't use the official `StudentT.cdf`.
@jax.jit
def t_cdf_approx(df: Union[NDArray[float], float], t: Union[NDArray[float], float]):
a = df - 1 / 2
b = 48 * a**2
# Add epsilon to avoid undefined gradient at 0
z = jnp.sqrt(a * jnp.log(1 + t**2 / df) + 1e-24)
u = (
z
+ (z**3 + 3 * z) / b
- (4 * z**7 + 33 * z**5 + 240 * z**3 + 855 * z) / (10 * b * (b + 0.8 * z**4 + 100))
)
return Normal(loc=0, scale=1).cdf(u * jnp.sign(t))
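# Illustrative comparison (my addition, not part of the implementation): `betainc` only
# lacks gradients, so the official `StudentT.cdf` should still be usable for a
# forward-only check of the approximation, e.g.:
#     from numpyro.distributions import StudentT
#     ts = jnp.linspace(-4.0, 4.0, 9)
#     jnp.max(jnp.abs(t_cdf_approx(10.0, ts) - StudentT(10.0).cdf(ts)))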
# Regularized Multivariate Regression Models with Skew-t Error Distributions
# https://epublications.marquette.edu/cgi/viewcontent.cgi?article=1225&context=mscs_fac
class SkewMultivariateNormal(Distribution): # type: ignore # pylint: disable=too-many-instance-attributes
arg_constraints = {
"loc": constraints.real_vector,
"scale_tril": constraints.lower_cholesky,
"skewers": constraints.real_vector,
}
support = constraints.real_vector
reparametrized_params = ["loc", "scale_tril", "skewers"]
uv_norm = Normal(0.0, 1.0)
@staticmethod
def mk_big_mv_norm(loc: NDArray[float], skewers: NDArray[float], scale_tril: NDArray[float]):
cov = jnp.einsum("...ij,...hj->...ih", scale_tril, scale_tril)
delta_ = delta(skewers, cov)
cov_star = jnp.block(
[
[jnp.ones(skewers.shape[:-1] + (1, 1)), jnp.expand_dims(delta_, axis=-2)],
[jnp.expand_dims(delta_, axis=-1), cov],
]
)
return MultivariateNormal(loc=jnp.zeros(loc.shape[-1] + 1), scale_tril=jnp.linalg.cholesky(cov_star))
def __init__(
self,
scale_tril: NDArray[float],
skewers: NDArray[float],
loc: Union[NDArray[float], float] = 0,
validate_args: None = None,
):
if jnp.ndim(loc) == 0:
(loc_,) = promote_shapes(loc, shape=(1,))
else:
loc_ = cast(NDArray[float], loc)
batch_shape = lax.broadcast_shapes(
jnp.shape(loc_)[:-1], jnp.shape(scale_tril)[:-2], jnp.shape(skewers)[:-1]
)
(self.loc,) = promote_shapes(loc_, shape=batch_shape + loc_.shape[-1:])
(self.skewers,) = promote_shapes(skewers, shape=batch_shape + skewers.shape[-1:])
(self.scale_tril,) = promote_shapes(scale_tril, shape=batch_shape + scale_tril.shape[-2:])
cov_batch = jnp.einsum("...ij,...hj->...ih", self.scale_tril, self.scale_tril)
self._std_devs = jnp.sqrt(jnp.diagonal(cov_batch, axis1=-2, axis2=-1))
# Used for sampling
self._big_mv_norm = self.mk_big_mv_norm(
# The blog post just uses unstandardized skewers here but that leads to
# a discrepancy between sampling and log_prob
loc=self.loc,
skewers=skewers / self._std_devs,
scale_tril=scale_tril,
)
# Used for log_prob
self._mv_norm = MultivariateNormal(loc_, scale_tril=scale_tril)
skew_mean = jnp.sqrt(2 / jnp.pi) * delta(self.skewers / self._std_devs, cov_batch)
self._mean = self.loc + skew_mean
# The paper just uses `mean` here but that's definitely not right because
# it potentially leads to covariance matrices which are not positive semi-definite
self._covariance = cov_batch - jnp.einsum("...i,...j->...ij", skew_mean, skew_mean)
event_shape = jnp.shape(self.scale_tril)[-1:]
super().__init__(
batch_shape=batch_shape,
event_shape=event_shape,
validate_args=validate_args,
)
@validate_sample
def log_prob(self, value: NDArray[float]) -> NDArray[float]:
return (
jnp.log(2)
+ self._mv_norm.log_prob(value)
+ jnp.log(
self.uv_norm.cdf(jnp.einsum("...k,...k->...", (value - self.loc) / self._std_devs, self.skewers))
)
)
@staticmethod
def infer_shapes(loc: NDArray[float], scale_tril: NDArray[float], skewers: NDArray[float]):
event_shape = (scale_tril[-1],)
batch_shape = lax.broadcast_shapes(loc[:-1], scale_tril[:-2], skewers[:-1])
return batch_shape, event_shape
# https://gregorygundersen.com/blog/2020/12/29/multivariate-skew-normal/
def sample(self, key: PRNGKey, sample_shape: tuple[int, ...] = ()) -> NDArray[float]:
assert is_prng_key(key)
x = self._big_mv_norm.sample(key, sample_shape=sample_shape)
sign_bit, samples = x[..., 0, jnp.newaxis], x[..., 1:]
return jnp.where(sign_bit <= 0, -1 * samples, samples) + self.loc
@property
def mean(self):
return jnp.broadcast_to(self._mean, self.shape())
@property
def covariance_matrix(self):
return self._covariance
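# Illustrative Monte Carlo spot check (my addition, not part of the implementation): the
# corrected `_covariance` above should roughly match the empirical covariance of samples,
# and `mean` should roughly match the empirical mean.
def _check_skew_mv_normal(key: PRNGKey, n: int = 100_000) -> None:
    dist = SkewMultivariateNormal(
        scale_tril=jnp.linalg.cholesky(jnp.array([[1.0, 0.3], [0.3, 2.0]])),
        skewers=jnp.array([2.0, -1.0]),
        loc=jnp.array([0.0, 1.0]),
    )
    xs = dist.sample(key, (n,))
    # Both differences should shrink toward zero as `n` grows.
    print(jnp.max(jnp.abs(xs.mean(axis=0) - dist.mean)))
    print(jnp.max(jnp.abs(jnp.cov(xs, rowvar=False) - dist.covariance_matrix)))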
# https://epublications.marquette.edu/cgi/viewcontent.cgi?article=1225&context=mscs_fac
class SkewMultivariateStudentT(Distribution): # type: ignore # pylint: disable=too-many-instance-attributes
arg_constraints = {
"df": constraints.positive,
"loc": constraints.real_vector,
"scale_tril": constraints.lower_cholesky,
"skewers": constraints.real_vector,
}
support = constraints.real_vector
reparametrized_params = ["df", "loc", "scale_tril", "skewers"]
def __init__( # pylint: disable=too-many-arguments
self,
df: float,
scale_tril: NDArray[float],
skewers: NDArray[float],
loc: Union[NDArray[float], float] = 0,
validate_args: None = None,
):
if jnp.ndim(loc) == 0:
(loc_,) = promote_shapes(loc, shape=(1,))
else:
loc_ = cast(NDArray[float], loc)
batch_shape = lax.broadcast_shapes(
jnp.shape(df), jnp.shape(loc_)[:-1], jnp.shape(scale_tril)[:-2], jnp.shape(skewers)[:-1]
)
(self.df,) = promote_shapes(df, shape=batch_shape)
(self.loc,) = promote_shapes(loc_, shape=batch_shape + loc_.shape[-1:])
(self.skewers,) = promote_shapes(skewers, shape=batch_shape + skewers.shape[-1:])
(self.scale_tril,) = promote_shapes(scale_tril, shape=batch_shape + scale_tril.shape[-2:])
self._width = scale_tril.shape[-1]
# For log_prob
self._mv_t = MultivariateStudentT(df=df, scale_tril=scale_tril, loc=loc)
eye = jnp.broadcast_to(jnp.eye(self._width), shape=batch_shape + scale_tril.shape[-2:])
prec_scale_tril = jnp.linalg.cholesky(cho_solve((self.scale_tril, True), eye))
self.prec = jnp.einsum("...ij,...hj->...ih", prec_scale_tril, prec_scale_tril)
cov_batch = jnp.einsum("...ij,...hj->...ih", self.scale_tril, self.scale_tril)
self._std_devs = jnp.sqrt(jnp.diagonal(cov_batch, axis1=-2, axis2=-1))
# For sample
self._mv_skew_norm = SkewMultivariateNormal(
scale_tril=scale_tril, loc=jnp.zeros(self._width), skewers=skewers
)
self._chi2 = Chi2(self.df)
# Mean
b = jnp.sqrt(self.df / jnp.pi) * jnp.exp(gammaln((self.df - 1) / 2) - gammaln(self.df / 2))
skew_mean = b[..., jnp.newaxis] * delta(self.skewers / self._std_devs, cov_batch)
self._mean = self.loc + skew_mean
# The paper says we should multiply by the std devs but that produces results that
# disagree with `sample` and with `SkewMultivariateNormal`
# It also says we should use `_mean` instead of `skew_mean` but that allows for
# covariance matrices which are not positive semi-definite
self._covariance = jnp.array((self.df / (self.df - 2)))[
..., jnp.newaxis, jnp.newaxis
] * cov_batch - jnp.einsum("...i,...j->...ij", skew_mean, skew_mean)
event_shape = jnp.shape(self.scale_tril)[-1:]
super().__init__(
batch_shape=batch_shape,
event_shape=event_shape,
validate_args=validate_args,
)
@validate_sample
def log_prob(self, value: NDArray[float]) -> NDArray[float]:
distance = value - self.loc
Qy = jnp.einsum("...j,...jk,...k->...", distance, self.prec, distance)
        # We have to use an approximation because `betainc` doesn't have gradients defined,
        # which means we can't use the official `StudentT.cdf`.
skew = t_cdf_approx(
self.df + self._width,
jnp.einsum(
"...k,...k->...",
self.skewers,
jnp.einsum(
"...i,...->...i", distance / self._std_devs, jnp.sqrt((self.df + self._width) / (Qy + self.df))
),
),
)
return jnp.log(2) + self._mv_t.log_prob(value) + jnp.log(skew)
@staticmethod
def infer_shapes(df: float, loc: NDArray[float], scale_tril: NDArray[float], skewers: NDArray[float]):
event_shape = (scale_tril[-1],)
batch_shape = lax.broadcast_shapes(df, loc[:-1], scale_tril[:-2], skewers[:-1])
return batch_shape, event_shape
def sample(self, key: PRNGKey, sample_shape: tuple[int, ...] = ()) -> NDArray[float]:
assert is_prng_key(key)
key_normal, key_chi2 = random.split(key)
normal = self._mv_skew_norm.sample(key_normal, sample_shape=sample_shape)
chi = self._chi2.sample(key_chi2, sample_shape)
return self.loc + jnp.einsum("...i,...->...i", normal, jnp.sqrt(self.df / chi))
@property
def mean(self):
return jnp.broadcast_to(self._mean, self.shape())
@property
def covariance_matrix(self):
return self._covariance
(I also have some code testing them.)
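A minimal usage sketch of `SkewMultivariateStudentT` (the parameter values below are arbitrary placeholders, not taken from my tests):
key = random.PRNGKey(0)
dist = SkewMultivariateStudentT(
    df=5.0,
    scale_tril=jnp.linalg.cholesky(jnp.array([[1.0, 0.3], [0.3, 2.0]])),
    skewers=jnp.array([2.0, -1.0]),
    loc=jnp.array([0.0, 1.0]),
)
xs = dist.sample(key, (100_000,))
print(xs.mean(axis=0), dist.mean)  # should roughly agree
print(jnp.cov(xs, rowvar=False), dist.covariance_matrix)  # should roughly agree
print(dist.log_prob(xs[:5]))  # finite log densities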
- Is there interest in upstreaming these?
- Are there obvious simplifications?
- `SkewMultivariateStudentT` is notably slower than `MultivariateStudentT` in some circumstances. Are there any obvious performance improvements available?