
Composing graphical models and neural networks for structured representations and fast inference

modeling idea

use PGM priors to organize the latent space, along with neural net observation models for flexible representations
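To make the modeling idea concrete, here is a minimal sketch (not the authors' code; all names and shapes are illustrative assumptions) of sampling from such a composed model, with a switching linear dynamical prior over latents and an arbitrary neural net decoder as the observation model:

```python
import numpy as np

def sample_slds_vae(T, pi, A, Q, decoder, x_dim, noise_scale=0.1):
    """pi: K x K Markov transition matrix; A, Q: per-state dynamics matrices
    and noise covariances; decoder: any neural net mapping latents to
    observation means. All arguments are hypothetical stand-ins."""
    z = np.zeros(T, dtype=int)                 # discrete switching states
    x = np.zeros((T, x_dim))                   # continuous latent trajectory
    for t in range(1, T):
        z[t] = np.random.choice(len(pi), p=pi[z[t - 1]])   # PGM prior: Markov chain
        x[t] = A[z[t]] @ x[t - 1] + np.random.multivariate_normal(
            np.zeros(x_dim), Q[z[t]])                      # per-state linear-Gaussian dynamics
    mean = decoder(x)                          # flexible neural net observation model
    y = mean + noise_scale * np.random.randn(*mean.shape)
    return z, x, y
```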

inference idea

use PGMs to synthesize information from recognition nets instead of making a single inference net do everything

TL;DR variational autoencoders + latent graphical models

probabilistic graphical models

+ structured representations

+ priors and uncertainty

+ data and computational efficiency within rigid model classes

– rigid assumptions may not fit

– feature engineering

– more flexible models can require slow top-down inference

deep neural networks

– neural net “goo”

– difficult parameterization

– can require lots of data

+ flexible, high capacity

+ feature learning

+ recognition networks for fast bottom-up inference

automatically learn representations in which structured PGMs fit well

supervised learning

unsupervised learning

motivation

natural gradient SVI: $q^{\star}(x) \triangleq \arg\max_{q(x)} \mathcal{L}[\, q(\theta)\, q(x) \,]$

variational autoencoder: $q^{\star}(x) \triangleq \mathcal{N}\big(x \mid \mu(y; \phi),\, \Sigma(y; \phi)\big)$

structured VAE (this work): $q^{\star}(x) \triangleq\; ?$

natural gradient SVI
+ optimal local factor
– expensive for general observations
+ exploit conjugate graph structure
+ arbitrary inference queries
+ natural gradients

variational autoencoders
– suboptimal local factor
+ fast for general observations
– does all local inference
– limited inference queries
– no natural gradients

structured VAEs (this work)
± optimal given conjugate evidence
+ fast for general observations
+ exploit conjugate graph structure
+ arbitrary inference queries
+ some natural gradients
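The difference between the last two columns can be seen in code. A hedged sketch for the Gaussian case (`recog_net` is a hypothetical stand-in, not the repo's API): a VAE's recognition net must output q(x) directly, while an SVAE's net outputs a conjugate potential whose natural parameters are simply added to the prior's, so the optimal local factor falls out of PGM inference:

```python
import numpy as np

def vae_local_factor(y, recog_net):
    mu, log_sigma = recog_net(y)              # the net does all local inference itself
    return mu, np.exp(2 * log_sigma)          # q*(x) = N(mu(y; phi), Sigma(y; phi))

def svae_local_factor(y, recog_net, prior_J, prior_h):
    J_hat, h_hat = recog_net(y)               # conjugate potential, natural parameters
    J, h = prior_J + J_hat, prior_h + h_hat   # conjugacy: natural parameters just add
    return np.linalg.solve(J, h), np.linalg.inv(J)   # optimal q*(x) given the evidence
```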

[Figure: SVAE graphical model — discrete latent states $z_1, \ldots, z_7$, continuous latents $x_1, \ldots, x_7$, and observations $y_1, \ldots, y_7$; example frames labeled "rear" and "dart".]

inference

$$\mathcal{L}(\eta_\theta, \eta_\gamma, \eta_x) \;\triangleq\; \mathbb{E}_{q(\theta)q(\gamma)q(x)}\!\left[ \log \frac{p(\theta, \gamma, x)\, p(y \mid x, \gamma)}{q(\theta)\, q(\gamma)\, q(x)} \right]$$

$$\widehat{\mathcal{L}}(\eta_\theta, \eta_x, \phi) \;\triangleq\; \mathbb{E}_{q(\theta)q(\gamma)q(x)}\!\left[ \log \frac{p(\theta, \gamma, x)\, \exp\{\psi(x; y, \phi)\}}{q(\theta)\, q(\gamma)\, q(x)} \right]$$

where $\psi(x; y, \phi)$ is a conjugate potential for $p(x \mid \theta)$,

$$\eta_x^{*}(\eta_\theta, \phi) \;\triangleq\; \arg\max_{\eta_x} \widehat{\mathcal{L}}(\eta_\theta, \eta_x, \phi), \qquad \mathcal{L}_{\mathrm{SVAE}}(\eta_\theta, \eta_\gamma, \phi) \;\triangleq\; \mathcal{L}(\eta_\theta, \eta_\gamma, \eta_x^{*}(\eta_\theta, \phi))$$
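For the simplest fully Gaussian case (one datapoint, $\mathcal{N}(0, I)$ prior on $x$), this construction can be written out directly. A minimal sketch, with `recog_net` and `log_lik` as hypothetical stand-ins:

```python
import numpy as np

def svae_bound_gaussian(y, recog_net, log_lik, n_samples=1):
    """Sketch of L_SVAE for one datapoint with a N(0, I) prior on x.
    recog_net and log_lik are illustrative, not the repo's API."""
    J_hat, h_hat = recog_net(y)            # conjugate potential psi(x; y, phi)
    d = len(h_hat)
    J = np.eye(d) + J_hat                  # prior + potential natural parameters
    Sigma = np.linalg.inv(J)
    mu = Sigma @ h_hat                     # eta*_x(eta_theta, phi) in closed form
    # evaluate the true bound L at eta*_x: E_q[log p(y | x)] - KL(q(x) || p(x))
    kl = 0.5 * (np.trace(Sigma) + mu @ mu - d
                - np.linalg.slogdet(Sigma)[1])
    L = np.linalg.cholesky(Sigma)
    xs = mu + (L @ np.random.randn(d, n_samples)).T   # reparameterized samples
    return np.mean([log_lik(y, x) for x in xs]) - kl
```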

learning to parse mouse behavior from depth video

step 1: compute local evidence with recognition networks
1(a) sample minibatch
1(b) apply recognition networks...
1(c) ...to get PGM potentials

step 2: run fast PGM inference
2(a) run local mean field...
2(b) ...and message passing...
2(c) ...to fixed point

step 3: compute unbiased ELBO gradients with respect to all parameters
3(a) sample and compute flat gradients w.r.t. $\gamma$ and $\phi$
3(b) use already-computed values to get natural gradient w.r.t. $\eta_\theta$
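Put together, one training iteration follows the three steps above. This is an illustrative sketch, not the API of github.com/mattjj/svae; the callables stand in for the machinery the poster describes:

```python
def svae_step(minibatch, params, recognition_net, pgm_inference,
              elbo_and_grads, step_size):
    """One SVAE update following steps 1-3 above (hypothetical interfaces)."""
    eta_theta, gamma, phi = params
    # step 1: recognition networks turn raw observations into PGM potentials
    potentials = [recognition_net(y, phi) for y in minibatch]
    # step 2: fast PGM inference (local mean field + message passing to a
    # fixed point) combines the potentials with the prior to get eta*_x
    local_factors, stats = pgm_inference(eta_theta, potentials)
    # step 3: unbiased ELBO gradients for all parameters
    elbo, g_gamma, g_phi, nat_g_theta = elbo_and_grads(
        minibatch, local_factors, stats, gamma, phi)
    return (eta_theta + step_size * nat_g_theta,   # natural gradient step, 3(b)
            gamma + step_size * g_gamma,           # flat gradient steps, 3(a)
            phi + step_size * g_phi), elbo
```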

Matthew James Johnson · David Duvenaud · Alexander B. Wiltschko · Sandeep R. Datta · Ryan P. Adams

[email protected] · [email protected] · [email protected] · [email protected] · [email protected]

Proposition (log evidence lower bound)

$$\mathcal{L}_{\mathrm{SVAE}}(\eta_\theta, \eta_\gamma, \phi) \;\le\; \max_{\eta_x} \mathcal{L}(\eta_\theta, \eta_\gamma, \eta_x) \;\le\; \log p(y) \qquad \forall\, \eta_\theta, \eta_\gamma, \phi$$

with the first bound tight in the sense that

$$\max_{\phi} \mathcal{L}_{\mathrm{SVAE}}(\eta_\theta, \eta_\gamma, \phi) \;=\; \max_{\eta_x} \mathcal{L}(\eta_\theta, \eta_\gamma, \eta_x) \quad \text{if } \exists\, \phi \in \mathbb{R}^m \text{ with } \psi(x; y, \phi) = \mathbb{E}_{q(\gamma)} \log p(y \mid x, \gamma)$$
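The second inequality is the standard variational argument; a minimal sketch via Jensen's inequality:

```latex
\log p(y)
  = \log \mathbb{E}_{q(\theta)q(\gamma)q(x)}\!\left[
        \frac{p(\theta,\gamma,x)\, p(y \mid x,\gamma)}{q(\theta)\,q(\gamma)\,q(x)}\right]
  \ge \mathbb{E}_{q(\theta)q(\gamma)q(x)}\!\left[
        \log \frac{p(\theta,\gamma,x)\, p(y \mid x,\gamma)}{q(\theta)\,q(\gamma)\,q(x)}\right]
  = \mathcal{L}(\eta_\theta, \eta_\gamma, \eta_x)
```

The first inequality follows because $\mathcal{L}_{\mathrm{SVAE}}$ is $\mathcal{L}$ evaluated at one particular choice $\eta_x = \eta_x^{*}(\eta_\theta, \phi)$, so it can only lower-bound the maximum over $\eta_x$.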

[Figure: per-datapoint factor graphs — the exact local evidence $\mathbb{E}_{q(\gamma)} \log p(y_n \mid x_n, \gamma)$ on $x_n$ is replaced by the recognition potential $\psi(x_n; y_n, \phi)$.]

[Figure: example results — a phoneme sequence parsed as /b/ /ax/ /n/ /ae/ /n/ /ax/, and mouse depth-video frames (axes in mm) segmented into behaviors labeled "dart", "pause", and "rear".]

main idea: learn to summarize complicated evidence with simple conjugate potentials (as in CRFs)

fit a latent switching linear dynamical system (SLDS) and a neural network image model for observations
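As a sketch of that CRF-style evidence summary (all weight names are assumptions for illustration, not the authors' code), a recognition network can emit a valid Gaussian conjugate potential by parameterizing a positive precision term:

```python
import numpy as np

def gaussian_potential_net(y, W1, b1, Wh, bh, WJ, bJ):
    """Recognition net emitting a Gaussian conjugate potential psi(x; y, phi)
    in natural parameters (J_hat, h_hat). Weights are illustrative."""
    hidden = np.tanh(W1 @ y + b1)
    h_hat = Wh @ hidden + bh                   # linear natural parameter
    prec = np.log1p(np.exp(WJ @ hidden + bJ))  # softplus keeps precisions > 0
    return np.diag(prec), h_hat                # precision term PSD by construction
```

The output plugs directly into the earlier `svae_local_factor` sketch: the potential's natural parameters add to the prior's, which is exactly the conjugacy the main idea relies on.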

SVAE approach

[Plot: comparison of flat gradient vs. natural gradient optimization.]

github.com/mattjj/svae