
SkewMultivariateNormal and SkewMultivariateStudentT #1452

Open
@colehaus

Description

I have implemented versions of both of these:

from __future__ import annotations

from typing import Union, cast

import jax
from jax import lax
from jax import random
import jax.numpy as jnp
from jax.random import PRNGKey
from jax.scipy.linalg import cho_solve
from jax.scipy.special import gammaln
from numpy.typing import NDArray
from numpyro.distributions import (
    Chi2,
    Distribution,
    MultivariateNormal,
    MultivariateStudentT,
    Normal,
    constraints,
)
from numpyro.distributions.util import is_prng_key, promote_shapes, validate_sample

def delta(skewers_: NDArray[float], cov_: NDArray[float]):
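    # Skewness-to-delta transform of the skew normal:
    # delta = cov @ skewers / sqrt(1 + skewers^T cov skewers)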
    return (jnp.einsum("...ij,...j->...i", cov_, skewers_)) / jnp.sqrt(
        1 + jnp.einsum("...j,...jk,...k->...", skewers_, cov_, skewers_)[..., jnp.newaxis]
    )

# Efficient computation of the distribution functions of Student's t, chi-squared and F to moderate accuracy
# https://sci-hub.se/10.1080/00949658208810542
# Have to use an approximation because `betainc` doesn't have grads defined,
# which means we can't use the official `StudentT.cdf`.
@jax.jit
def t_cdf_approx(df: Union[NDArray[float], float], t: Union[NDArray[float], float]):
    a = df - 1 / 2
    b = 48 * a**2
    # Add epsilon to avoid undefined gradient at 0
    z = jnp.sqrt(a * jnp.log(1 + t**2 / df) + 1e-24)
    u = (
        z
        + (z**3 + 3 * z) / b
        - (4 * z**7 + 33 * z**5 + 240 * z**3 + 855 * z) / (10 * b * (b + 0.8 * z**4 + 100))
    )
    return Normal(loc=0, scale=1).cdf(u * jnp.sign(t))
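
# Quick spot check (illustrative; assumes scipy is available):
#   from scipy.stats import t as student_t
#   t_cdf_approx(10.0, 1.5)  # should closely match student_t.cdf(1.5, 10)
# The referenced paper only targets moderate accuracy, so expect small discrepancies.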

# Regularized Multivariate Regression Models with Skew-t Error Distributions
# https://epublications.marquette.edu/cgi/viewcontent.cgi?article=1225&context=mscs_fac
class SkewMultivariateNormal(Distribution):  # type: ignore # pylint: disable=too-many-instance-attributes
    arg_constraints = {
        "loc": constraints.real_vector,
        "scale_tril": constraints.lower_cholesky,
        "skewers": constraints.real_vector,
    }
    support = constraints.real_vector
    reparametrized_params = ["loc", "scale_tril", "skewers"]
    uv_norm = Normal(0.0, 1.0)

    @staticmethod
    def mk_big_mv_norm(loc: NDArray[float], skewers: NDArray[float], scale_tril: NDArray[float]):
        cov = jnp.einsum("...ij,...hj->...ih", scale_tril, scale_tril)
        delta_ = delta(skewers, cov)
        cov_star = jnp.block(
            [
                [jnp.ones(skewers.shape[:-1] + (1, 1)), jnp.expand_dims(delta_, axis=-2)],
                [jnp.expand_dims(delta_, axis=-1), cov],
            ]
        )

        return MultivariateNormal(loc=jnp.zeros(loc.shape[-1] + 1), scale_tril=jnp.linalg.cholesky(cov_star))

    def __init__(
        self,
        scale_tril: NDArray[float],
        skewers: NDArray[float],
        loc: Union[NDArray[float], float] = 0,
        validate_args: None = None,
    ):
        if jnp.ndim(loc) == 0:
            (loc_,) = promote_shapes(loc, shape=(1,))
        else:
            loc_ = cast(NDArray[float], loc)
        batch_shape = lax.broadcast_shapes(
            jnp.shape(loc_)[:-1], jnp.shape(scale_tril)[:-2], jnp.shape(skewers)[:-1]
        )
        (self.loc,) = promote_shapes(loc_, shape=batch_shape + loc_.shape[-1:])
        (self.skewers,) = promote_shapes(skewers, shape=batch_shape + skewers.shape[-1:])
        (self.scale_tril,) = promote_shapes(scale_tril, shape=batch_shape + scale_tril.shape[-2:])
        cov_batch = jnp.einsum("...ij,...hj->...ih", self.scale_tril, self.scale_tril)
        self._std_devs = jnp.sqrt(jnp.diagonal(cov_batch, axis1=-2, axis2=-1))

        # Used for sampling
        self._big_mv_norm = self.mk_big_mv_norm(
            # The blog post just uses unstandardized skewers here but that leads to
            # a discrepancy between sampling and log_prob
            loc=self.loc,
            skewers=skewers / self._std_devs,
            scale_tril=scale_tril,
        )
        # Used for log_prob
        self._mv_norm = MultivariateNormal(loc_, scale_tril=scale_tril)

        skew_mean = jnp.sqrt(2 / jnp.pi) * delta(self.skewers / self._std_devs, cov_batch)
        self._mean = self.loc + skew_mean
        # The paper just uses `mean` here but that's definitely not right because
        # it potentially leads to covariance matrices which are not positive semi-definite
        self._covariance = cov_batch - jnp.einsum("...i,...j->...ij", skew_mean, skew_mean)

        event_shape = jnp.shape(self.scale_tril)[-1:]
        super().__init__(
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=validate_args,
        )
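
    # Density: 2 * N(value; loc, cov) * Phi(skewers . z), where z is the
    # componentwise standardized residual (value - loc) / std_devs
    # (the Azzalini skew-normal form).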
    @validate_sample
    def log_prob(self, value: NDArray[float]) -> NDArray[float]:
        return (
            jnp.log(2)
            + self._mv_norm.log_prob(value)
            + jnp.log(
                self.uv_norm.cdf(jnp.einsum("...k,...k->...", (value - self.loc) / self._std_devs, self.skewers))
            )
        )

    @staticmethod
    def infer_shapes(loc: NDArray[float], scale_tril: NDArray[float], skewers: NDArray[float]):
        event_shape = (scale_tril[-1],)
        batch_shape = lax.broadcast_shapes(loc[:-1], scale_tril[:-2], skewers[:-1])
        return batch_shape, event_shape

    # https://gregorygundersen.com/blog/2020/12/29/multivariate-skew-normal/
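    # Sampling uses the hidden-truncation construction: draw from the
    # (d+1)-dimensional normal with covariance `cov_star` and negate the last d
    # coordinates whenever the auxiliary first coordinate is non-positive.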
    def sample(self, key: PRNGKey, sample_shape: tuple[int, ...] = ()) -> NDArray[float]:
        assert is_prng_key(key)
        x = self._big_mv_norm.sample(key, sample_shape=sample_shape)
        sign_bit, samples = x[..., 0, jnp.newaxis], x[..., 1:]
        return jnp.where(sign_bit <= 0, -1 * samples, samples) + self.loc

    @property
    def mean(self):
        return jnp.broadcast_to(self._mean, self.shape())

    @property
    def covariance_matrix(self):
        return self._covariance

# https://epublications.marquette.edu/cgi/viewcontent.cgi?article=1225&context=mscs_fac
class SkewMultivariateStudentT(Distribution):  # type: ignore # pylint: disable=too-many-instance-attributes
    arg_constraints = {
        "df": constraints.positive,
        "loc": constraints.real_vector,
        "scale_tril": constraints.lower_cholesky,
        "skewers": constraints.real_vector,
    }
    support = constraints.real_vector
    reparametrized_params = ["df", "loc", "scale_tril", "skewers"]

    def __init__(  # pylint: disable=too-many-arguments
        self,
        df: float,
        scale_tril: NDArray[float],
        skewers: NDArray[float],
        loc: Union[NDArray[float], float] = 0,
        validate_args: None = None,
    ):
        if jnp.ndim(loc) == 0:
            (loc_,) = promote_shapes(loc, shape=(1,))
        else:
            loc_ = cast(NDArray[float], loc)
        batch_shape = lax.broadcast_shapes(
            jnp.shape(df), jnp.shape(loc_)[:-1], jnp.shape(scale_tril)[:-2], jnp.shape(skewers)[:-1]
        )
        (self.df,) = promote_shapes(df, shape=batch_shape)
        (self.loc,) = promote_shapes(loc_, shape=batch_shape + loc_.shape[-1:])
        (self.skewers,) = promote_shapes(skewers, shape=batch_shape + skewers.shape[-1:])
        (self.scale_tril,) = promote_shapes(scale_tril, shape=batch_shape + scale_tril.shape[-2:])

        self._width = scale_tril.shape[-1]

        # For log_prob
        self._mv_t = MultivariateStudentT(df=df, scale_tril=scale_tril, loc=loc)
        eye = jnp.broadcast_to(jnp.eye(self._width), shape=batch_shape + scale_tril.shape[-2:])
        prec_scale_tril = jnp.linalg.cholesky(cho_solve((self.scale_tril, True), eye))
        self.prec = jnp.einsum("...ij,...hj->...ih", prec_scale_tril, prec_scale_tril)
        cov_batch = jnp.einsum("...ij,...hj->...ih", self.scale_tril, self.scale_tril)
        self._std_devs = jnp.sqrt(jnp.diagonal(cov_batch, axis1=-2, axis2=-1))

        # For sample
        self._mv_skew_norm = SkewMultivariateNormal(
            scale_tril=scale_tril, loc=jnp.zeros(self._width), skewers=skewers
        )
        self._chi2 = Chi2(self.df)

        # Mean
        b = jnp.sqrt(self.df / jnp.pi) * jnp.exp(gammaln((self.df - 1) / 2) - gammaln(self.df / 2))
        skew_mean = b[..., jnp.newaxis] * delta(self.skewers / self._std_devs, cov_batch)
        self._mean = self.loc + skew_mean
        # The paper says we should multiply by the std devs, but that produces results
        # that disagree with `sample` and with `SkewMultivariateNormal`.
        # It also says we should use `_mean` instead of `skew_mean`, but that allows for
        # covariance matrices which are not positive semi-definite.
        self._covariance = jnp.array((self.df / (self.df - 2)))[
            ..., jnp.newaxis, jnp.newaxis
        ] * cov_batch - jnp.einsum("...i,...j->...ij", skew_mean, skew_mean)

        event_shape = jnp.shape(self.scale_tril)[-1:]
        super().__init__(
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=validate_args,
        )
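
    # Density: 2 * t_d(value; df, loc, cov) * T_1(q; df + d), where q rescales the
    # standardized residual by sqrt((df + d) / (Qy + df)) (the standard skew-t form).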
    @validate_sample
    def log_prob(self, value: NDArray[float]) -> NDArray[float]:
        distance = value - self.loc
        Qy = jnp.einsum("...j,...jk,...k->...", distance, self.prec, distance)
        # Have to use the approximation because `betainc` doesn't have grads defined,
        # which means we can't use the official `StudentT.cdf`.
        skew = t_cdf_approx(
            self.df + self._width,
            jnp.einsum(
                "...k,...k->...",
                self.skewers,
                jnp.einsum(
                    "...i,...->...i", distance / self._std_devs, jnp.sqrt((self.df + self._width) / (Qy + self.df))
                ),
            ),
        )
        return jnp.log(2) + self._mv_t.log_prob(value) + jnp.log(skew)

    @staticmethod
    def infer_shapes(df: float, loc: NDArray[float], scale_tril: NDArray[float], skewers: NDArray[float]):
        event_shape = (scale_tril[-1],)
        batch_shape = lax.broadcast_shapes(df, loc[:-1], scale_tril[:-2], skewers[:-1])
        return batch_shape, event_shape
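
    # Sampling uses the scale-mixture representation: a skew-normal draw scaled
    # by sqrt(df / chi2) yields a skew-t draw.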
    def sample(self, key: PRNGKey, sample_shape: tuple[int, ...] = ()) -> NDArray[float]:
        assert is_prng_key(key)
        key_normal, key_chi2 = random.split(key)
        normal = self._mv_skew_norm.sample(key_normal, sample_shape=sample_shape)
        chi = self._chi2.sample(key_chi2, sample_shape)
        return self.loc + jnp.einsum("...i,...->...i", normal, jnp.sqrt(self.df / chi))

    @property
    def mean(self):
        return jnp.broadcast_to(self._mean, self.shape())

    @property
    def covariance_matrix(self):
        return self._covariance

(I also have some code testing them; a minimal usage sketch follows the questions below.)

  1. Is there interest in upstreaming these?
  2. Are there obvious simplifications?
  3. SkewMultivariateStudentT is notably slower than MultivariateStudentT in some circumstances. Are there any obvious performance improvements available?
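
For anyone who wants to kick the tires, here's a minimal usage sketch (the parameter values are just illustrative) that draws samples and compares the empirical moments against the analytic `mean` and `covariance_matrix`:

from jax import random
import jax.numpy as jnp

key = random.PRNGKey(0)
scale_tril = jnp.array([[1.0, 0.0], [0.5, 1.0]])  # illustrative 2-d scale
skewers = jnp.array([2.0, -1.0])

dist = SkewMultivariateNormal(scale_tril=scale_tril, skewers=skewers, loc=jnp.zeros(2))
samples = dist.sample(key, sample_shape=(100_000,))

# Empirical moments should roughly match the analytic ones
print(jnp.mean(samples, axis=0), dist.mean)
print(jnp.cov(samples, rowvar=False), dist.covariance_matrix)

t_dist = SkewMultivariateStudentT(df=5.0, scale_tril=scale_tril, skewers=skewers, loc=jnp.zeros(2))
print(t_dist.log_prob(samples[:5]))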

Labels: enhancement