Making rand gradients work for more distributions #123

Open
@dcjones

Description

Trying to compute gradients of the rand function with respect to the parameters of certain distributions will produce incorrect results, because some of these functions use branching or iterative algorithms, and AD can't take into account how the parameters affect control flow.

A simple demonstration of this is just trying to estimate d/dθ E[x] by estimating E[d/dθ x].

Normal of course works: d/dμ E[x] = d/dμ μ = 1, and

julia> mean(gradient(μ -> rand(Normal(μ, 1.0)), 1.0)[1] for _ in 1:10000) # should be ≈ 1.0
1.0

(which works for any values of μ, σ)

Gamma will not return a gradient for some values, and returns incorrect results for others. E.g. d/dα E[x] = d/dα αβ = β, yet

julia> mean(gradient(α -> rand(Gamma(α, 2.0)), 1.01)[1] for _ in 1:10000) # should be ≈ 2.0
2.782440982911109
julia> mean(gradient(α -> rand(Gamma(α, 2.0)), 1.00)[1] for _ in 1:10000) # should be ≈ 2.0
ERROR: MethodError: no method matching /(::Nothing, ::Int64)

Beta behaves similarly: d/dα E[x] = d/dα α/(α+β) = β/(α+β)^2, yet

julia> mean(gradient(α -> rand(Beta(α, 2.0)), 2.0)[1] for _ in 1:10000) # should be ≈ 0.125
0.14264055366214703
julia> mean(gradient(α -> rand(Beta(α, 3.0)), 1.0)[1] for _ in 1:10000) # should be ≈ 0.1875
ERROR: MethodError: no method matching /(::Nothing, ::Int64)
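For intuition on why Normal succeeds where Gamma and Beta fail: Normal sampling is (implicitly) the explicit reparameterization x = μ + σ·ε with ε ~ N(0,1), so the pathwise derivatives dx/dμ = 1 and dx/dσ = ε are exact for every draw, with no parameter-dependent control flow for AD to miss. A minimal sketch (in Python for illustration, not Turing code):

```python
import random

# Explicit reparameterization of Normal(mu, sigma): x = mu + sigma * eps,
# eps ~ N(0, 1).  Per-sample pathwise derivatives are exact:
#   dx/dmu = 1,  dx/dsigma = eps
# so E[dx/dmu] = 1 = d/dmu E[x], and E[dx/dsigma] = E[eps] = 0 = d/dsigma E[x].
random.seed(0)
mu, sigma = 1.0, 1.0
dmu_samples, dsigma_samples = [], []
for _ in range(10000):
    eps = random.gauss(0.0, 1.0)
    x = mu + sigma * eps          # the reparameterized sample
    dmu_samples.append(1.0)       # dx/dmu for this draw
    dsigma_samples.append(eps)    # dx/dsigma for this draw

grad_mu = sum(dmu_samples) / len(dmu_samples)
grad_sigma = sum(dsigma_samples) / len(dsigma_samples)
print(grad_mu)     # exactly 1.0
print(grad_sigma)  # ≈ 0.0
```

Gamma and Beta samplers, by contrast, use rejection loops whose iteration count depends on the parameters, so naive AD through the sampler sees only one branch of the control flow.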

It's well known that some distributions (e.g. Gamma, Beta, Dirichlet) don't lend themselves easily to this kind of pathwise gradient, which makes them infrequently used as surrogate posteriors for VI, but there have been some papers on working around this using numerical approximations and other techniques. See for example:

Figurnov, Mikhail, Shakir Mohamed, and Andriy Mnih. 2018. “Implicit Reparameterization Gradients.” In Advances in Neural Information Processing Systems 31, edited by S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett, 441–52. Curran Associates, Inc.

Jankowiak, Martin, and Fritz Obermeyer. 2018. “Pathwise Derivatives Beyond the Reparameterization Trick.” arXiv [stat.ML]. arXiv. http://arxiv.org/abs/1806.01851.
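The core idea in Figurnov et al. is that if F(x; θ) is the CDF, a sample x drawn with parameter θ has pathwise derivative dx/dθ = −(∂F/∂θ) / (∂F/∂x), where the denominator is just the density f(x; θ); no explicit reparameterization is needed. A hedged sketch applying this to Exponential(rate λ), where everything is closed-form, with ∂F/∂θ taken by finite differences to mimic the generic numerical recipe:

```python
import math
import random

# Implicit reparameterization gradient (Figurnov et al. 2018):
#   dx/dtheta = - (dF/dtheta) / f(x; theta)
# Demonstrated for Exponential(rate lam): F(x) = 1 - exp(-lam*x),
# f(x) = lam * exp(-lam*x).  Analytically dx/dlam = -x/lam, so
# E[dx/dlam] = -E[x]/lam = -1/lam^2, matching d/dlam E[x] = d/dlam (1/lam).

def cdf(x, lam):
    return 1.0 - math.exp(-lam * x)

def pdf(x, lam):
    return lam * math.exp(-lam * x)

def implicit_grad(x, lam, h=1e-5):
    # central finite difference for dF/dlam, as a stand-in for the
    # numerical CDF derivatives one would need for Gamma or Beta
    dF_dlam = (cdf(x, lam + h) - cdf(x, lam - h)) / (2 * h)
    return -dF_dlam / pdf(x, lam)

random.seed(0)
lam = 2.0
grads = [implicit_grad(random.expovariate(lam), lam) for _ in range(10000)]
grad_est = sum(grads) / len(grads)
print(grad_est)  # ≈ -1/lam^2 = -0.25
```

For Gamma and Beta the CDF derivatives have no elementary closed form, which is exactly where the approximations in those papers come in.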

I'd love to help improve the rand situation, but I'm still getting my bearings with this code, so I was hoping for some pointers.

My vague thought was that there might be a TuringGamma, TuringBeta, etc., that implement alternative rand functions that are correctly differentiated. Is there a nicer approach, or is this the best option?
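One possible shape for such a type (sketched in Python with made-up names, not Turing's or Distributions.jl's API): a distribution whose sampler returns the draw together with its analytic pathwise gradient, so the AD system can splice in the correct derivative instead of differentiating through rejection-sampler control flow. Shown for Exponential, where the inverse-CDF reparameterization x = −log(1−u)/λ gives dx/dλ = −x/λ in closed form; a TuringGamma/TuringBeta would compute the gradient via the implicit formula −(∂F/∂θ)/f(x) instead:

```python
import math
import random

class ReparamExponential:
    """Hypothetical sketch: rand returns (sample, d sample / d lam)."""

    def __init__(self, lam):
        self.lam = lam

    def rand_with_grad(self, rng):
        u = rng.random()
        x = -math.log1p(-u) / self.lam   # inverse-CDF sample
        dx_dlam = -x / self.lam          # exact pathwise derivative
        return x, dx_dlam

rng = random.Random(0)
d = ReparamExponential(1.0)
grads = [d.rand_with_grad(rng)[1] for _ in range(10000)]
grad_est = sum(grads) / len(grads)
print(grad_est)  # ≈ d/dlam E[x] = -1/lam^2 = -1.0
```

In Julia the same effect could presumably be had with a custom adjoint for rand on the wrapper type rather than a changed return value, but the gradient computation would be the same.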

Second, for distributions where there is no viable way to AD rand, is there something better that can be done than reporting incorrect numbers? Should the remedy be in Distributions?

(Related issue is #113)
