
HYPERPARAMETER LEARNING FOR CONDITIONAL MEAN EMBEDDINGS WITH RADEMACHER COMPLEXITY BOUNDS

Kelvin Hsu, Richard Nock, Fabio Ramos

RESEARCH SUMMARY
Background: Conditional mean embeddings (CMEs) are kernel models that nonparametrically encode expectations under conditional distributions, forming a flexible and powerful framework for probabilistic inference.

Problem: Their hyperparameters are notoriously difficult to tune or learn.

Question: Can we design a scalable hyperparameter learning algorithm for CMEs to ensure good generalization?

Contribution: We show that when CMEs are used to estimate multiclass probabilities, there are learning-theoretic bounds based on Rademacher complexities that result in a complexity-based hyperparameter learning algorithm which 1) balances data fit and model complexity, 2) is amenable to batch stochastic gradient updates, and 3) can learn more flexible kernels such as those constructed from neural networks.

TOY EXAMPLE: NON-SEPARABLE IRIS

[Figure: Rademacher complexity bound log(r(θ, λ)) and training/test accuracy (%) over 500 iterations, starting from an initially overfitted model and an initially underfitted model.]

Setup: The data is non-separable by any means – the same x ∈ R² may be assigned different labels y ∈ {1, 2, 3}. It is very easy for models to overfit by forcing a pattern, or to underfit by giving up.

Result: Our learning algorithm can drive the model from any initial state, overfitted (left) or underfitted (right), to a complexity-balanced state where generalization accuracy is highest.

ALGORITHM
1: Input: kernel family k_θ : X × X → R, dataset {x_i, y_i}_{i=1}^{n}, initial kernel hyperparameters θ_0, initial regularization hyperparameter λ_0, learning rate η, cross entropy loss threshold ε, batch size n_b
2: θ ← θ_0, λ ← λ_0
3: repeat
4:   Sample the next batch I_b ⊆ N_n, |I_b| = n_b
5:   Y ← {δ(y_i, c) : i ∈ I_b, c ∈ N_m} ∈ {0, 1}^{n_b × m}
6:   K_θ ← {k_θ(x_i, x_j) : i ∈ I_b, j ∈ I_b} ∈ R^{n_b × n_b}
7:   L_{θ,λ} ← cholesky(K_θ + n_b λ I_{n_b}) ∈ R^{n_b × n_b}
8:   V_{θ,λ} ← L_{θ,λ}^T \ (L_{θ,λ} \ Y) ∈ R^{n_b × m}
9:   P_{θ,λ} ← K_θ V_{θ,λ} ∈ R^{n_b × m}
10:  r(θ, λ) ← α(θ) √(trace(V_{θ,λ}^T K_θ V_{θ,λ}))
11:  q(θ, λ) ← (1/n_b) Σ_{i=1}^{n_b} L_ε((Y)_i, (P_{θ,λ})_i) + 4e r(θ, λ)
12:  (θ, λ) ← GradientBasedUpdate(q, θ, λ; η)
13: until maximum iterations reached
14: Output: kernel hyperparameters θ, regularisation hyperparameter λ
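As a concrete illustration, the batch computation in steps 5-11 can be written directly in NumPy. The sketch below is illustrative only: the Gaussian kernel, the choice α(θ) = 1 (a kernel with k_θ(x, x) = 1), and the thresholded cross entropy L_ε(y, p) = −log(max(p_y, ε)) are assumptions made for this example, not specifications taken from the poster. Step 12 would then differentiate the returned objective with respect to θ and λ, e.g. with an automatic differentiation framework.

import numpy as np
from scipy.linalg import cho_factor, cho_solve

def gaussian_kernel(A, B, length_scale):
    """Gaussian (RBF) kernel matrix between the rows of A and B (assumed kernel family)."""
    sq_dists = ((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-0.5 * sq_dists / length_scale ** 2)

def batch_objective(X_b, y_b, m, length_scale, lam, eps=1e-3):
    """Steps 5-11: complexity-regularized cross entropy q(theta, lambda) on one batch."""
    n_b = X_b.shape[0]
    Y = np.eye(m)[y_b]                                        # step 5: one-hot label matrix
    K = gaussian_kernel(X_b, X_b, length_scale)               # step 6: batch Gram matrix
    L = cho_factor(K + n_b * lam * np.eye(n_b), lower=True)   # step 7: Cholesky factorization
    V = cho_solve(L, Y)                                       # step 8: (K + n_b*lam*I)^{-1} Y
    P = K @ V                                                 # step 9: estimated class probabilities
    r = np.sqrt(np.trace(V.T @ K @ V))                        # step 10: RCB, assuming alpha(theta) = 1
    p_true = np.clip(P[np.arange(n_b), y_b], eps, None)       # step 11: thresholded cross entropy ...
    return -np.log(p_true).mean() + 4 * np.e * r              # ... plus the complexity penalty 4e*r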

METHOD
Idea: Consider the CME in the multiclass setting, referred to as the multiclass conditional embedding (MCE), whose empirical form is:

p̂(x) = f_{θ,λ}(x) := Y^T (K_θ + nλI)^{-1} k_θ(x).   (1)
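For concreteness, equation (1) can be evaluated at a batch of test points with a few lines of NumPy; the Gaussian kernel and all names below are assumptions made for this sketch, not part of the original formulation.

import numpy as np

def mce_predict(X_train, y_train, X_test, m, length_scale, lam):
    """Empirical MCE of equation (1): each output row is Y^T (K + n*lam*I)^{-1} k(x) at a test point."""
    n = X_train.shape[0]
    Y = np.eye(m)[y_train]                                   # one-hot label matrix Y
    def k(A, B):                                             # assumed Gaussian kernel k_theta
        sq = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
        return np.exp(-0.5 * sq / length_scale ** 2)
    W = np.linalg.solve(k(X_train, X_train) + n * lam * np.eye(n), Y)
    return k(X_test, X_train) @ W                            # decision probability estimates p_hat(x)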

Use Rademacher complexity bounds (RCB) to bound its expected risk:

Theorem 4.1  For any n ∈ N_+ and observations {x_i, y_i}_{i=1}^{n} used to define f_{θ,λ} (1), with probability 1 − β over iid samples {X_i, Y_i}_{i=1}^{n} of length n from P_XY, every θ ∈ Θ, λ ∈ Λ, and ε ∈ (0, e^{-1}) satisfies

E[L_{e^{-1}}(Y, f_{θ,λ}(X))] ≤ (1/n) Σ_{i=1}^{n} L_ε(Y_i, f_{θ,λ}(X_i)) + 4e r(θ, λ) + √((8/n) log(2/β)),

where the RCB is r(θ, λ) := √( sup_{x∈X} k_θ(x, x) · tr(Y^T (K_θ + nλI)^{-1} K_θ (K_θ + nλI)^{-1} Y) ).
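To make the bound concrete, its right-hand side can be evaluated from the empirical loss and the RCB already computed during training; the helper below is a minimal sketch with assumed names and a default confidence level.

import numpy as np

def risk_upper_bound(empirical_loss, rcb, n, beta=0.05):
    """Right-hand side of Theorem 4.1: empirical loss + 4e * r(theta, lambda) + confidence term."""
    return empirical_loss + 4 * np.e * rcb + np.sqrt((8.0 / n) * np.log(2.0 / beta))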

Propose an objective based on this bound, and extensions thereof, to ensure good generalization by balancing data fit and model complexity:

q(θ, λ) := (1/n) Σ_{i=1}^{n} L_ε(y_i, f_{θ,λ}(x_i)) + 4e r(θ, λ).   (2)

DEEP MNIST & ARD MNIST
Deep MNIST: Apply our learning algorithm to an MCE with kernels constructed from deep convolutional neural networks. Result:

• Highly scalable with flexible representations

• Improved test accuracy: 99.48% vs. 99.26%

• Faster convergence in network training
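One reading of "kernels constructed from deep convolutional neural networks" is k_θ(x, x') = ⟨φ_θ(x), φ_θ(x')⟩ for a learned feature map φ_θ whose weights are part of θ. The sketch below substitutes a tiny fully connected feature map in NumPy for the convolutional network, so the architecture, sizes, and names are illustrative assumptions only.

import numpy as np

class FeatureMapKernel:
    """k_theta(x, x') = <phi_theta(x), phi_theta(x')> for a small learned feature map phi_theta."""

    def __init__(self, d_in, d_hidden=32, d_out=16, seed=0):
        rng = np.random.default_rng(seed)
        # These weights stand in for the convolutional network; during training
        # they belong to theta and are updated through the RCB-based objective.
        self.W1 = rng.standard_normal((d_in, d_hidden)) / np.sqrt(d_in)
        self.W2 = rng.standard_normal((d_hidden, d_out)) / np.sqrt(d_hidden)

    def features(self, X):
        return np.tanh(X @ self.W1) @ self.W2                # phi_theta(X), one row per input

    def __call__(self, A, B):
        return self.features(A) @ self.features(B).T         # Gram matrix K_theta between A and B

Plugging such a kernel into the batch objective sketched earlier lets the feature-map weights and λ be updated jointly by batch stochastic gradients.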

[Figure: Test performance by learning conv. features: test accuracy (%) and loss over 800 epochs for the conditional embedding network vs. the original convolutional network.]

ARD MNIST: Apply our learning algorithm to perform automatic relevance determination (ARD) on MNIST pixels. Result:

[Figure: Test accuracy (%) by learning Gaussian kernels, for training set sizes from 50 to 5000: SVC, GPC, MCE learned without RCB, and MCE learned with RCB.]
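ARD on pixels corresponds to an anisotropic Gaussian kernel with one length-scale per pixel, where a large learned length-scale marks a pixel as irrelevant; the function below is a minimal sketch of such a kernel, with its names and exact form assumed for illustration.

import numpy as np

def ard_gaussian_kernel(A, B, length_scales):
    """Anisotropic Gaussian kernel with one learnable length-scale per input dimension (pixel)."""
    A_s = A / length_scales                                  # scale each pixel by its length-scale
    B_s = B / length_scales
    sq_dists = ((A_s[:, None, :] - B_s[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-0.5 * sq_dists)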

UCI EXPERIMENTS
Experimental Setup: Compare our learning algorithm to existing hyperparameter tuning algorithms on UCI datasets, as well as to other models as standard benchmarks. GMCE, GMCE-SGD, and CEN-1/2 are variations of our approach. GMCE and GMCE-SGD use anisotropic Gaussian kernels with full gradient and batch stochastic gradient updates, respectively. CEN-1 and CEN-2 employ kernels constructed from fully connected neural networks with 16-32-8 and 96-32 hidden units, respectively.

Result: Our learning algorithm achieves higher test accuracy across a range of datasets compared to other methods such as empirical risk minimization (ERM), cross validation (CV), and the median heuristic (MED). Compared to benchmark models, our algorithm performs on par with benchmarks using neural networks (a, c), probabilistic binary trees (b), decision trees (d), and regularized discriminant analysis (e).

Table 1: Test accuracy (%) of multiclass conditional embeddings on UCI datasets against benchmarks

Method       banknote      ecoli         robot         segment       wine          yeast
GMCE         99.9 ± 0.2    87.5 ± 4.4    96.7 ± 0.9    98.4 ± 0.8    97.2 ± 3.7    52.5 ± 2.1
GMCE-SGD     98.8 ± 0.9    84.5 ± 5.0    95.5 ± 0.9    96.1 ± 1.5    93.3 ± 6.0    60.3 ± 4.4
CEN-1        99.5 ± 1.0    87.5 ± 3.2    82.3 ± 7.1    94.6 ± 1.6    96.1 ± 5.0    55.8 ± 5.0
CEN-2        99.4 ± 0.9    86.3 ± 6.0    94.5 ± 0.8    96.7 ± 1.1    97.2 ± 5.1    59.6 ± 4.0
ERM          99.9 ± 0.2    72.1 ± 20.5   91.0 ± 3.7    98.1 ± 1.1    93.9 ± 5.2    45.9 ± 6.4
CV           99.9 ± 0.2    73.8 ± 23.8   90.9 ± 3.4    98.3 ± 1.3    93.3 ± 7.4    58.0 ± 5.8
MED          92.0 ± 4.3    42.1 ± 47.7   81.1 ± 6.2    27.3 ± 26.4   93.3 ± 7.8    31.2 ± 14.1
Benchmarks   99.78 (a)     81.1 (b)      97.59 (c)     96.83 (d)     100 (e)       55.0 (b)