Riemannian Stein Variational Gradient Descent for Bayesian Inference

Chang Liu, Jun Zhu (corresponding author)

Dept. of Comp. Sci. & Tech., TNList Lab; Center for Bio-Inspired Computing Research; State Key Lab for Intell. Tech. & Systems, Tsinghua University, Beijing, China

{chang-li14@mails., dcszj@}tsinghua.edu.cn

AAAI'18 @ New Orleans


1 Introduction

2 Preliminaries

3 Riemannian SVGD
Derivation of the Directional Derivative
Derivation of the Functional Gradient
Expression in the Embedded Space

4 Experiments
Bayesian Logistic Regression
Spherical Admixture Model

Introduction


Bayesian inference: given a dataset $D$ and a Bayesian model $p(x, D)$, estimate the posterior $p(x|D)$ of the latent variable $x$.

Comparison of current inference methods: model-based variational inference methods (M-VIs), Monte Carlo methods (MCs) and particle-based variational inference methods (P-VIs)

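| Methods                   | M-VIs            | MCs       | P-VIs     |
|---------------------------|------------------|-----------|-----------|
| Asymptotic Accuracy       | No               | Yes       | Promising |
| Approximation Flexibility | Limited          | Unlimited | Unlimited |
| Iteration Effectiveness   | Yes              | Weak      | Strong    |
| Particle Efficiency       | (does not apply) | Weak      | Strong    |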

Stein Variational Gradient Descent (SVGD) [7]: a P-VI with minimal assumptions and impressive performance.


In this work:

We generalize SVGD to the Riemann manifold setting, so that we can:

Purpose 1: Adapt SVGD to tasks on Riemann manifolds and introduce the first P-VI to the Riemannian world.

Purpose 2: Improve SVGD efficiency for usual tasks (ones on Euclidean space) by exploiting information geometry.


Preliminaries

Stein Variational Gradient Descent (SVGD)

The idea of SVGD:

A deterministic continuous-time dynamics $\frac{d}{dt}x(t) = \phi(x(t))$ on $M = \mathbb{R}^m$ (where $\phi : \mathbb{R}^m \to \mathbb{R}^m$) induces a continuously evolving distribution $q_t$ on $M$.

At some instant $t$, for a fixed dynamics $\phi$, find the rate of decrease of $\mathrm{KL}(q_t\|p)$, i.e. the Directional Derivative $-\frac{d}{dt}\mathrm{KL}(q_t\|p)$ in the "direction" of $\phi$.

Find the $\phi$ that maximizes the directional derivative, i.e. the Functional Gradient $\phi^*$ (the steepest-ascent "direction"). To obtain a closed-form solution, $\phi^*$ is chosen from $\mathcal{H}^m$, where $\mathcal{H}$ is the reproducing kernel Hilbert space (RKHS) of some kernel.

Apply the dynamics $\phi^*$ to samples $\{x^{(s)}\}_{s=1}^S$ of $q_t$: $\{x^{(s)} + \varepsilon\phi^*(x^{(s)})\}_{s=1}^S$ forms a set of samples of $q_{t+\varepsilon}$.
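As a rough illustration of the steps above, the following NumPy sketch performs one SVGD update. It assumes the closed-form functional gradient from [7], $\phi^*(x') = E_q[K(x,x')\nabla\log p(x) + \nabla_x K(x,x')]$, which these slides do not restate, together with a Gaussian kernel of fixed bandwidth. The names `rbf_kernel`, `svgd_step`, `eps` and `h` are illustrative choices, not taken from the talk.

```python
import numpy as np

def rbf_kernel(x, h):
    """Gaussian kernel matrix and its gradient w.r.t. the first argument."""
    diff = x[:, None, :] - x[None, :, :]          # diff[i, j] = x_i - x_j
    K = np.exp(-np.sum(diff**2, axis=-1) / (2 * h**2))
    grad_K = -diff / h**2 * K[:, :, None]         # grad_K[i, j] = d K(x_i, x_j) / d x_i
    return K, grad_K

def svgd_step(x, grad_log_p, eps=0.1, h=1.0):
    """One update x_j <- x_j + eps * phi(x_j) for particles x of shape (S, m)."""
    S = x.shape[0]
    K, grad_K = rbf_kernel(x, h)
    score = grad_log_p(x)                         # (S, m), rows are grad log p(x_i)
    # phi(x_j) = (1/S) * sum_i [ K(x_i, x_j) * score_i + grad_{x_i} K(x_i, x_j) ]
    phi = (K.T @ score + grad_K.sum(axis=0)) / S
    return x + eps * phi

# Hypothetical usage: move 20 particles toward a standard normal target,
# whose score function is grad log p(x) = -x.
particles = np.linspace(2.0, 4.0, 20).reshape(-1, 1)
for _ in range(200):
    particles = svgd_step(particles, lambda x: -x)
```

The first term pulls particles toward high-density regions of $p$, while the kernel-gradient term pushes them apart, which is why the particles spread out rather than collapse onto the mode.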


Riemannian SVGD

Roadmap

For a general Riemann manifold $M$:

Any deterministic continuous-time dynamics on $M$ is described by a vector field $X$ on $M$. It induces a continuously evolving distribution on $M$ with density $q_t$ (w.r.t. the Riemann volume form).

Derive the Directional Derivative $-\frac{d}{dt}\mathrm{KL}(q_t\|p)$ under dynamics $X$.

Derive the Functional Gradient $X^* := (\max\cdot\arg\max)_{\|X\|_{\mathfrak{X}}=1}\, -\frac{d}{dt}\mathrm{KL}(q_t\|p)$, i.e. the maximizing unit-norm vector field scaled by the maximum value.

Moreover, for Purpose 1, express $X^*$ in the Embedded Space of $M$ when $M$ has no global coordinate system (c.s.), e.g. hyperspheres.

Finally, simulate the dynamics $X^*$ for a small time step $\varepsilon$ to update samples.


Derivation of the Directional Derivative

Let $q_t$ be the evolving density under dynamics $X$.

Lemma (Continuity Equation on Riemann Manifold)

$$\frac{\partial q_t}{\partial t} = -\mathrm{div}(q_t X) = -X[q_t] - q_t\,\mathrm{div}(X).$$

$X[q_t]$: the action of the vector field $X$ on the smooth function $q_t$. In any c.s., $X[q_t] = X^i\partial_i q_t$.

$\mathrm{div}(X)$: the divergence of the vector field $X$. In any c.s., $\mathrm{div}(X) = \partial_i(\sqrt{|G|}\,X^i)/\sqrt{|G|}$, where $G$ is the matrix expression of the Riemann metric of $M$ under the c.s.

Theorem (Directional Derivative)

Let $p$ be a fixed distribution. Then the directional derivative is

$$-\frac{d}{dt}\mathrm{KL}(q_t\|p) = E_{q_t}[\mathrm{div}(pX)/p] = E_{q_t}\big[X[\log p] + \mathrm{div}(X)\big].$$
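One way to see how the theorem follows from the lemma is sketched below. The sketch assumes that the integral of a divergence over $M$ vanishes (for instance, $M$ compact without boundary, or $q_t X$ decaying fast enough), and it writes $\mu$ for the Riemann volume measure; both are assumptions of this sketch rather than statements from the slides.

```latex
\begin{align*}
-\frac{d}{dt}\mathrm{KL}(q_t\|p)
 &= -\int_M \frac{\partial q_t}{\partial t}\Big(\log\frac{q_t}{p} + 1\Big)\,d\mu
  = \int_M \mathrm{div}(q_t X)\,\log\frac{q_t}{p}\,d\mu \\
 &= -\int_M q_t\,X\Big[\log\frac{q_t}{p}\Big]\,d\mu
  = \int_M q_t\,X[\log p]\,d\mu - \int_M X[q_t]\,d\mu \\
 &= E_{q_t}\big[X[\log p]\big] + \int_M q_t\,\mathrm{div}(X)\,d\mu
  = E_{q_t}\big[X[\log p] + \mathrm{div}(X)\big].
\end{align*}
```

The first line uses the continuity equation and drops the constant term because $\int_M \mathrm{div}(q_t X)\,d\mu = 0$; the second is integration by parts; the third uses $\int_M X[q_t]\,d\mu = -\int_M q_t\,\mathrm{div}(X)\,d\mu$. The form $\mathrm{div}(pX)/p$ is the same quantity by the product rule $\mathrm{div}(pX) = X[p] + p\,\mathrm{div}(X)$.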


Derivation of the Functional Gradient

The task now:

$$X^* := (\max\cdot\arg\max)_{X\in\mathfrak{X},\ \|X\|_{\mathfrak{X}}=1}\, J(X), \qquad J(X) := E_q\big[X[\log p] + \mathrm{div}(X)\big],$$

where $\mathfrak{X}$ is some subspace of the space of vector fields on $M$, chosen such that the following requirements are met.

Requirements on $X^*$, thus on $\mathfrak{X}$

R1: $X^*$ is a valid vector field on $M$;

R2: $X^*$ is coordinate invariant;

R3: $X^*$ can be expressed in closed form, in which $q$ appears only through expectations.


R1: $X^*$ is a valid vector field on $M$.

Why needed: the derivations above assume valid vector fields.

Note: non-trivial to guarantee!

Example (Vector fields on hyperspheres)

Any vector field on an even-dimensional hypersphere must have at least one point where it is the zero vector (a critical point), due to the hairy ball theorem ([1], Theorem 8.5.13). The choice in SVGD, $\mathfrak{X} = \mathcal{H}^m$, cannot guarantee R1.


R2: $X^*$ is coordinate invariant.

Concept: the object on $M$ is the same whichever c.s. is used to express it. E.g. vector field, gradient and divergence.

Why needed: necessary to avoid ambiguity or arbitrariness of the solution. The vector field $X^*$ should be independent of the choice of c.s. in which it is expressed.

Note: the choice in SVGD, $\mathfrak{X} = \mathcal{H}^m$, cannot guarantee R2.

R3: $X^*$ can be expressed in closed form, in which $q$ appears only through expectations.

Why needed: for tractable implementation, and to avoid making restrictive assumptions on $q$.


Our Solution

$\mathfrak{X} = \{\mathrm{grad}\,f \mid f \in \mathcal{H}\}$, where $\mathcal{H}$ is the RKHS of some kernel.

$\mathrm{grad}\,f$ is the gradient of the smooth function $f$. In any c.s., $(\mathrm{grad}\,f)^j = g^{ij}\partial_i f$, where $g^{ij}$ is the $(i,j)$ entry of $G^{-1}$ under the c.s.

Theorem

For a Gaussian RKHS, $\mathfrak{X}$ is isometrically isomorphic to $\mathcal{H}$, thus it is a Hilbert space.

Our solution guarantees all the requirements:

The gradient is a well-defined object on $M$ and it is guaranteed to be a valid vector field and coordinate invariant (see the paper for a detailed interpretation).

A closed-form solution can be derived (see next).
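To make the coordinate invariance of the gradient concrete, here is a small NumPy check on the Euclidean plane expressed in polar coordinates. The example function $f(x, y) = x$, the point $(r, \theta)$ and the helper name `riemannian_grad` are illustrative assumptions and do not come from the slides.

```python
import numpy as np

def riemannian_grad(G, df):
    """Coordinate expression (grad f)^j = g^{ij} d_i f, i.e. G^{-1} applied to df."""
    return np.linalg.solve(G, df)

# f(x, y) = x on the plane; in polar coordinates f = r * cos(theta).
r, theta = 2.0, 0.7
G = np.diag([1.0, r**2])                                   # Euclidean metric in polar coordinates
df = np.array([np.cos(theta), -r * np.sin(theta)])         # (df/dr, df/dtheta)
grad_polar = riemannian_grad(G, df)

# Push the polar components forward with the Jacobian d(x, y)/d(r, theta).
J = np.array([[np.cos(theta), -r * np.sin(theta)],
              [np.sin(theta),  r * np.cos(theta)]])
grad_cart = J @ grad_polar                                 # equals (1, 0), the Cartesian gradient of f
```

The raw partial derivatives `df` differ between the two coordinate systems, but once the inverse metric is applied both describe the same tangent vector.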


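The closed-form solution:

Theorem (Functional Gradient)

$$X^{*\prime} = \mathrm{grad}'\,f^{*\prime}, \qquad f^{*\prime} = E_q\big[(\mathrm{grad}\,K)[\log p] + \Delta K\big],$$

where notations with a prime take $x'$ as argument while the others take $x$, $K$ takes both, and $\Delta f := \mathrm{div}(\mathrm{grad}\,f)$. In any c.s.,

$$X^{*\prime i} = g'^{ij}\,\partial'_j\,E_q\Big[\big(g^{ab}\,\partial_a\log(p\sqrt{|G|}) + \partial_a g^{ab}\big)\,\partial_b K + g^{ab}\,\partial_a\partial_b K\Big].$$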
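The coordinate expression can be recovered from the coordinate-free one by expanding the two terms with the c.s. formulas for the gradient and the divergence given earlier. The expansion below is a sketch of that bookkeeping, relying only on the symmetry of $g^{ab}$:

```latex
\begin{align*}
(\mathrm{grad}\,K)[\log p] &= g^{ab}\,\partial_a\log p\;\partial_b K, \\
\Delta K = \mathrm{div}(\mathrm{grad}\,K)
 &= \frac{1}{\sqrt{|G|}}\,\partial_a\big(\sqrt{|G|}\,g^{ab}\,\partial_b K\big) \\
 &= \big(g^{ab}\,\partial_a\log\sqrt{|G|} + \partial_a g^{ab}\big)\,\partial_b K
    + g^{ab}\,\partial_a\partial_b K .
\end{align*}
```

Adding the two lines merges $\log p$ and $\log\sqrt{|G|}$ into $\log(p\sqrt{|G|})$, which gives the bracket inside the expectation; applying $g'^{ij}\partial'_j$ is the coordinate form of $\mathrm{grad}'$.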


Purpose 2: Improve efficiency for the usual inference tasks on Euclidean space $\mathbb{R}^m$.

Apply the idea of information geometry [3, 2]: for a Bayesian model with prior $p(x)$ and likelihood $p(D|x)$, take $M = \{p(\cdot|x) : x \in \mathbb{R}^m\}$ and treat $x$ as the coordinate of $p(\cdot|x)$. In this global c.s., $G(x)$ is the Fisher information matrix of $p(\cdot|x)$ (typically with the Hessian of $\log p(x)$ subtracted).

Calculate the tangent vector at each sample using the c.s. expression

$$X^{*\prime i} = g'^{ij}\,\partial'_j\,E_q\Big[\big(g^{ab}\,\partial_a\log(p\sqrt{|G|}) + \partial_a g^{ab}\big)\,\partial_b K + g^{ab}\,\partial_a\partial_b K\Big],$$

where the target distribution is $p = p(x|D) \propto p(x)\,p(D|x)$ and the expectation is estimated by averaging over samples.

Simulate the dynamics for a small time step $\varepsilon$ to update samples:

$$x^{(s)} \leftarrow x^{(s)} + \varepsilon X^*(x^{(s)}).$$
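The sketch below evaluates the c.s. expression in a deliberately simplified setting: the metric $G$ is assumed constant in $x$, so the $\partial_a g^{ab}$ and $\partial_a\log\sqrt{|G|}$ terms vanish, and $K$ is assumed to be a Gaussian kernel with bandwidth `h` in the coordinate space. In the talk's setting $G$ generally depends on $x$, so this is not the full method; the function names are hypothetical.

```python
import numpy as np

def _pairwise_terms(xq, x, score, G_inv, h):
    d = x[:, None, :] - xq[None, :, :]                 # d[s, t] = x_s - x'_t
    K = np.exp(-np.sum(d**2, axis=-1) / (2 * h**2))    # Gaussian kernel K(x_s, x'_t)
    u = score @ G_inv                                  # u_s = G^{-1} grad log p(x_s), G_inv symmetric
    Ad = d @ G_inv                                     # G^{-1} (x_s - x'_t)
    # c * K = g^{ab} d_a log p d_b K + g^{ab} d_a d_b K, derivatives taken w.r.t. x
    c = (-np.einsum('sa,sta->st', u, d) / h**2
         + np.einsum('sta,sta->st', d, Ad) / h**4
         - np.trace(G_inv) / h**2)
    return d, K, u, Ad, c

def f_star(xq, x, score, G_inv, h=1.0):
    """Sample estimate of f*(x') at query points xq of shape (T, m)."""
    _, K, _, _, c = _pairwise_terms(xq, x, score, G_inv, h)
    return np.mean(K * c, axis=0)

def rsvgd_direction(xq, x, score, G_inv, h=1.0):
    """X*(x') = G^{-1} grad' f*(x') for a constant metric, shape (T, m)."""
    d, K, u, Ad, c = _pairwise_terms(xq, x, score, G_inv, h)
    # Gradient of K * c w.r.t. x', averaged over the samples x_s.
    grad_f = np.mean(
        K[:, :, None] * (c[:, :, None] * d / h**2 + u[:, None, :] / h**2 - 2 * Ad / h**4),
        axis=0)
    return grad_f @ G_inv

# Hypothetical usage: one update of particles x with step size eps.
rng = np.random.default_rng(0)
x = rng.normal(size=(10, 2))
G_inv = np.linalg.inv(np.array([[2.0, 0.3], [0.3, 1.0]]))
eps = 0.01
x_new = x + eps * rsvgd_direction(x, x, -x, G_inv)     # score of a standard normal target is -x
```

Even with $G$ equal to the identity, this direction differs from the vanilla SVGD one, because here the function space is made of gradients of RKHS functions rather than $\mathcal{H}^m$.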


Expression in the Embedded Space

Purpose 1: Enable applicability to inference tasks on non-linear Riemann manifolds.

In the coordinate space of $M$:

Some manifolds have no global c.s., e.g. the hypersphere $S^{n-1} := \{x \in \mathbb{R}^n : \|x\| = 1\}$ and the Stiefel manifold [5]. Switching among local c.s. is cumbersome.

$G$ would be singular near the edge of the coordinate space.

In the embedded space of $M$:

$M$ can be expressed globally, which is natural for $S^{n-1}$ and the Stiefel manifold.

No singularity problems.

Requires the exponential map and the density w.r.t. the Hausdorff measure, which are available for $S^{n-1}$ and the Stiefel manifold.


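Proposition (Functional Gradient in the Embedded Space)

Let the $m$-dimensional Riemann manifold $M$ be isometrically embedded in $\mathbb{R}^n$ (with orthonormal basis $\{y_\alpha\}_{\alpha=1}^n$) via $\Xi : M \to \mathbb{R}^n$. Let $p$ be the density w.r.t. the Hausdorff measure on $\Xi(M)$. Then $X^{*\prime} = (I_n - N'N'^\top)\nabla' f^{*\prime}$,

$$f^{*\prime} = E_q\Big[\big(\nabla\log(p\sqrt{|G|})\big)^\top (I_n - NN^\top)(\nabla K) + \nabla^\top\nabla K - \mathrm{tr}\big(N^\top(\nabla\nabla^\top K)N\big) + \big((M^\top\nabla)^\top(G^{-1}M^\top)\big)(\nabla K)\Big],$$

where $I_n \in \mathbb{R}^{n\times n}$ is the identity matrix, $\nabla = (\partial_{y_1}, \dots, \partial_{y_n})^\top$, $M \in \mathbb{R}^{n\times m}$ with $M_{\alpha i} = \partial y_\alpha / \partial x^i$, the columns of $N \in \mathbb{R}^{n\times(n-m)}$ form an orthonormal basis of the orthogonal complement of $\Xi_*(T_x M)$, and $\mathrm{tr}(\cdot)$ is the trace.

Simulating the dynamics requires the exponential map $\mathrm{Exp}$ of $M$:

$$y^{(s)} \leftarrow \mathrm{Exp}_{y^{(s)}}\big(\varepsilon X^*(y^{(s)})\big).$$

$\mathrm{Exp}_y(v)$: moves $y$ on $\Xi(M)$ "straight" (along the geodesic) in the direction of $v$.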
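The factor $(I_n - NN^\top)$ in the proposition is the orthogonal projection onto the tangent space. The following NumPy sketch applies it for a generic normal basis; the random $5\times 2$ matrix `N`, built with a QR factorization, stands in for the normal space of some hypothetical 3-dimensional manifold in $\mathbb{R}^5$ and is purely illustrative.

```python
import numpy as np

def project_to_tangent(v, N):
    """Apply (I_n - N N^T) to v; the columns of N are an orthonormal basis of the normal space."""
    return v - N @ (N.T @ v)

rng = np.random.default_rng(0)
N, _ = np.linalg.qr(rng.normal(size=(5, 2)))   # orthonormal columns spanning a 2-dim normal space
v = rng.normal(size=5)                          # arbitrary vector in the embedding space
v_tan = project_to_tangent(v, N)                # component of v tangent to the manifold
```

Computing `N @ (N.T @ v)` avoids forming the $n\times n$ projection matrix, which matters when $n$ is large.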


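Proposition (Functional Gradient for Embedded Hyperspheres)

For $S^{n-1}$ isometrically embedded in $\mathbb{R}^n$ with orthonormal basis $\{y_\alpha\}_{\alpha=1}^n$, we have $X^{*\prime} = (I_n - y'y'^\top)\nabla' f^{*\prime}$, where

$$f^{*\prime} = E_q\Big[(\nabla\log p)^\top(\nabla K) + \nabla^\top\nabla K - y^\top(\nabla\nabla^\top K)\,y - (y^\top\nabla\log p + n - 1)\,y^\top\nabla K\Big].$$

Exponential map on $S^{n-1}$:

$$\mathrm{Exp}_y(v) = y\cos(\|v\|) + (v/\|v\|)\sin(\|v\|).$$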
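A minimal NumPy sketch of the hypersphere case follows. It assumes the vMF kernel $K(y, y') = \exp(\kappa y^\top y')$ used later in the talk, for which $\nabla K = \kappa y' K$, $\nabla\nabla^\top K = \kappa^2 y'y'^\top K$ and $\nabla^\top\nabla K = \kappa^2 K$ on the sphere, and it assumes the user supplies $\nabla\log p$ at each particle. The function names and the vMF-shaped toy target in the usage lines are illustrative, not from the slides.

```python
import numpy as np

def sphere_exp(y, v):
    """Exp_y(v) = y cos(|v|) + (v / |v|) sin(|v|), applied row-wise; returns y where v = 0."""
    norm = np.linalg.norm(v, axis=-1, keepdims=True)
    safe = np.where(norm > 1e-12, norm, 1.0)
    return y * np.cos(norm) + v / safe * np.sin(norm)

def rsvgd_sphere_direction(y, score, kappa):
    """X* at every particle; y: (S, n) unit vectors, score: (S, n) rows grad log p(y_s)."""
    S, n = y.shape
    a = y @ y.T                                    # a[s, t] = y_s . y'_t
    K = np.exp(kappa * a)                          # vMF kernel values
    w = (np.sum(y * score, axis=1) + n - 1)[:, None]   # y_s . grad log p(y_s) + n - 1
    gy = score @ y.T                               # gy[s, t] = grad log p(y_s) . y'_t
    # Integrand of f*(y'_t) is K * c with the vMF derivatives substituted.
    c = kappa * gy + kappa**2 - kappa**2 * a**2 - w * kappa * a
    # Gradient of K * c w.r.t. y'_t: coefficient of y_s plus coefficient of grad log p(y_s).
    coef_y = K * (kappa * c - 2 * kappa**2 * a - kappa * w)
    coef_g = K * kappa
    grad_f = (coef_y.T @ y + coef_g.T @ score) / S
    # Tangent projection (I - y' y'^T) at each particle.
    return grad_f - np.sum(grad_f * y, axis=1, keepdims=True) * y

# Hypothetical usage: toy target p(y) proportional to exp(5 * mu . y) on S^2.
rng = np.random.default_rng(0)
y = rng.normal(size=(8, 3))
y /= np.linalg.norm(y, axis=1, keepdims=True)
score = np.tile(5.0 * np.array([0.0, 0.0, 1.0]), (8, 1))
eps = 0.01
y_new = sphere_exp(y, eps * rsvgd_sphere_direction(y, score, kappa=2.0))
```

Because the direction is tangent and the update uses the exponential map, the particles stay on the sphere without any renormalization step.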


Experiments


BLR: for Purpose 2

Model: Bayesian Logistic Regression (BLR)
$w \sim \mathcal{N}(0, \alpha I_m)$, $y_d \sim \mathrm{Bern}(\sigma(w^\top x_d))$, where $\sigma(x) = 1/(1 + e^{-x})$.

Euclidean task: $w \in \mathbb{R}^m$.
Posterior: $p(w|\{(x_d, y_d)\})$, log-density gradient known.
Riemann metric tensor $G$: Fisher information minus the Hessian of the log-prior, known in closed form.

Kernel: Gaussian kernel in the coordinate space.

Baseline: vanilla SVGD.

Evaluation: averaged test accuracy.
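For this model both the log-density gradient and the metric have short closed forms, sketched below in NumPy. The sketch reads $\alpha$ as the prior variance, exactly as the slide writes $\mathcal{N}(0, \alpha I_m)$, and assumes labels in $\{0, 1\}$; the function name `blr_score_and_metric` and the random data in the usage lines are hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def blr_score_and_metric(w, X, y, alpha):
    """Return grad log p(w | data) and G = Fisher information - Hessian of the log-prior.

    X: (D, m) features, y: (D,) labels in {0, 1}, alpha: prior variance (assumed reading).
    """
    p = sigmoid(X @ w)                                  # predicted P(y_d = 1)
    score = X.T @ (y - p) - w / alpha                   # likelihood gradient + prior gradient
    fisher = (X * (p * (1.0 - p))[:, None]).T @ X       # sum_d p_d (1 - p_d) x_d x_d^T
    G = fisher + np.eye(len(w)) / alpha                 # Hessian of the log-prior is -I / alpha
    return score, G

# Hypothetical usage on random data.
rng = np.random.default_rng(0)
X_data = rng.normal(size=(30, 3))
y_data = (rng.random(30) < 0.5).astype(float)
score, G = blr_score_and_metric(rng.normal(size=3), X_data, y_data, alpha=2.0)
```

For logistic regression the Fisher information coincides with the negative Hessian of the log-likelihood, so this $G$ is also the negative Hessian of the log-posterior and is positive definite.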


Results:

[Figure: test accuracy (y-axis) against iteration (x-axis) for SVGD and RSVGD (Ours). (a) On Splice19 dataset. (b) On Covertype dataset.]

Figure: Test accuracy along iteration for BLR. Both methods are run 20 times on Splice19 and 10 times on Covertype. Each run on Covertype uses a random train (80%) / test (20%) split as in [7].


SAM: for Purpose 1

Model: Spherical Admixture Model (SAM) [8]
Observed var.: tf-idf representation of documents: $v_d \in S^{V-1}$.
Latent var.: spherical topics: $\beta_t \in S^{V-1}$.

Non-linear Riemann manifold task: $\beta \in (S^{V-1})^T$.
Posterior: $p(\beta|v)$ (w.r.t. the Hausdorff measure), log-density gradient can be estimated [6].

Kernel: von Mises-Fisher (vMF) kernel $K(y, y') = \exp(\kappa y^\top y')$, the restriction of the Gaussian kernel in $\mathbb{R}^n$ to $S^{n-1}$.

Baselines:
Variational Inference (VI) [8]: the vanilla inference method of SAM.
Geodesic Monte Carlo (GMC) [4]: MCMC for Riemann manifolds in the embedded space.
Stochastic Gradient GMC (SGGMC) [6]: SG-MCMC for Riemann manifolds in the embedded space (-b: mini-batch gradient estimate; -f: full-batch gradient estimate).
For MCMCs, -seq: samples from one chain; -par: newest samples from multiple chains.

Evaluation: log-perplexity (negative log-likelihood of the test dataset under the trained model) [6].
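The statement that the vMF kernel is the restriction of the Gaussian kernel can be checked numerically: for unit vectors $\|y - y'\|^2 = 2 - 2y^\top y'$, so a Gaussian kernel $\exp(-\|y - y'\|^2/(2h^2))$ equals $e^{-1/h^2}\exp(\kappa y^\top y')$ with $\kappa = 1/h^2$. The bandwidth parametrization with `h` is an assumption of this sketch; the slides do not fix one.

```python
import numpy as np

def gaussian_kernel(y, yp, h):
    """Gaussian kernel in the embedding space R^n with bandwidth h (assumed parametrization)."""
    return np.exp(-np.sum((y - yp)**2) / (2 * h**2))

def vmf_kernel(y, yp, kappa):
    """von Mises-Fisher kernel K(y, y') = exp(kappa * y . y')."""
    return np.exp(kappa * np.dot(y, yp))

# Two points on the unit sphere S^2.
y = np.array([1.0, 0.0, 0.0])
yp = np.array([0.6, 0.8, 0.0])
h = 0.5
kappa = 1.0 / h**2
ratio = gaussian_kernel(y, yp, h) / vmf_kernel(y, yp, kappa)   # constant exp(-1 / h^2) on the sphere
```

The constant factor rescales $f^{*\prime}$ uniformly, so it only changes the effective step size and not the direction of the update.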


Results:

[Figure: log-perplexity (y-axis). (a) Results with 100 particles, against epochs, for VI, GMC-seq, GMC-par, SGGMCb-seq, SGGMCb-par, SGGMCf-seq, SGGMCf-par and RSVGD (Ours). (b) Results at 200 epochs, against number of particles, for GMC-seq, GMC-par, SGGMCb-seq, SGGMCb-par, SGGMCf-seq, SGGMCf-par and RSVGD (Ours).]

Figure: Results on the SAM inference task on the 20News-different dataset, in log-perplexity. We run SGGMCf with the full batch and SGGMCb with a mini-batch size of 50.


Thank you!


References

[1] Ralph Abraham, Jerrold E. Marsden, and Tudor Ratiu. Manifolds, tensor analysis, and applications, volume 75. Springer Science & Business Media, 2012.

[2] Shun-Ichi Amari. Information geometry and its applications. Springer, 2016.

[3] Shun-Ichi Amari and Hiroshi Nagaoka. Methods of information geometry, volume 191. American Mathematical Soc., 2007.

[4] Simon Byrne and Mark Girolami. Geodesic Monte Carlo on embedded manifolds. Scandinavian Journal of Statistics, 40(4):825–845, 2013.

[5] I. M. James. The topology of Stiefel manifolds, volume 24. Cambridge University Press, 1976.

[6] Chang Liu, Jun Zhu, and Yang Song. Stochastic gradient geodesic MCMC methods. In Advances in Neural Information Processing Systems, pages 3009–3017, 2016.

[7] Qiang Liu and Dilin Wang. Stein variational gradient descent: A general purpose Bayesian inference algorithm. In Advances in Neural Information Processing Systems, pages 2370–2378, 2016.

[8] Joseph Reisinger, Austin Waters, Bryan Silverthorn, and Raymond J. Mooney. Spherical topic models. In Proceedings of the 27th International Conference on Machine Learning (ICML-10), pages 903–910, 2010.