
Transcript of "A Gram-Gauss-Newton Method: Learning Overparameterized Deep Neural Networks for Regression Problems"

Page 1

A Gram-Gauss-Newton Method: Learning Overparameterized Deep Neural Networks for Regression Problems

Tianle Cai

Peking University

June 21, 2019

Joint work with: Ruiqi Gao, Jikai Hou, Siyu Chen, Dong Wang, Di He, Zhihua Zhang, Liwei Wang


Page 2

Overparameterized Deep Learning

Outline

1 Overparameterized Deep Learning

2 Second-order Methods for Deep Learning

3 Gram-Gauss-Newton Method


Page 3

Overparameterized Deep Learning

Modern Deep Learning

Characterized by heavy overparameterization, i.e.,

Overparameterization

The number of parameters m is much larger than the number of data points n.

and high nonconvexity.


Page 4

Overparameterized Deep Learning

Strong Expressivity under Optimization

Deep networks can be optimized to zero training error even for noisy problems [Zhang et al., 2016].

Rethinking Optimization

Why can networks find a global optimum despite the nonconvexity?


Page 5

Overparameterized Deep Learning

Interpolation Regime

Overparameterized networks provably converge to zero training loss under gradient descent [Du et al., 2018, Allen-Zhu et al., 2018, Zou et al., 2018].

Key Insights

Overparameterized networks are approximately linear w.r.t. the parameters in a neighborhood of the initialization (in parameter space).

There is a global optimum inside this neighborhood.


Page 6

Overparameterized Deep Learning

Neural Taylor Expansion

Let $f(w, x)$ denote the network with parameters $w \in \mathbb{R}^m$ and input $x \in \mathbb{R}^d$. We assume the output of $f(w, x)$ is a scalar.

Neural Taylor Expansion / Feature Map of the Neural Tangent Kernel (NTK)

Within a neighborhood of the initialization $w_0$, $\nabla_w f(w, x) \approx \nabla_w f(w_0, x)$. Then we have the approximate neural Taylor expansion [Chizat and Bach, 2018]:

$$f(w, x) \approx \underbrace{f(w_0, x)}_{\text{bias term}} + \underbrace{(w - w_0)}_{\text{linear parameters}} \cdot \underbrace{\nabla_w f(w_0, x)}_{\text{feature of } x},$$

where $\nabla_w f(w, x)$ is the gradient of $f(w, x)$ w.r.t. $w$.
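To make the approximation concrete, below is a minimal numpy sketch (not from the slides; the two-layer ReLU architecture, width, and step size are illustrative assumptions) comparing the network value after a small random parameter perturbation with its Taylor linearization around the initialization.

```python
import numpy as np

rng = np.random.default_rng(0)
d, M = 10, 4096                      # input dimension and width (large M = overparameterized)

W0 = rng.normal(size=(M, d))         # two-layer ReLU net f(w, x) = a . relu(W x) / sqrt(M)
a0 = rng.normal(size=M)

def f(W, a, x):
    return a @ np.maximum(W @ x, 0.0) / np.sqrt(M)

def grad_f(W, a, x):
    """Gradient of f w.r.t. all parameters (W, a), flattened into one vector."""
    h = W @ x
    gW = np.outer(a * (h > 0), x) / np.sqrt(M)   # df/dW
    ga = np.maximum(h, 0.0) / np.sqrt(M)         # df/da
    return np.concatenate([gW.ravel(), ga])

x = rng.normal(size=d)
w0 = np.concatenate([W0.ravel(), a0])

dw = 1e-3 * rng.normal(size=w0.shape)            # a small step, as early training would take
W1 = (w0 + dw)[: M * d].reshape(M, d)
a1 = (w0 + dw)[M * d:]

exact = f(W1, a1, x)
linear = f(W0, a0, x) + dw @ grad_f(W0, a0, x)   # bias term + linear term of the expansion
print("change in f:        ", abs(exact - f(W0, a0, x)))
print("linearization error:", abs(exact - linear))   # a small fraction of the change
```

The wider the network and the smaller the step, the more closely the linear model tracks the actual network.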


Page 7

Overparameterized Deep Learning

NTK

Neural Tangent Kernel (NTK) [Jacot et al., 2018]

Kernel induced by the feature map $\mathbb{R}^d \to \mathbb{R}^m$, $x \mapsto \nabla_w f(w_0, x)$:

$$K(x, x') = \langle \nabla_w f(w_0, x),\, \nabla_w f(w_0, x') \rangle.$$

Remark

Describes the behavior of infinite-width networks [Jacot et al., 2018].

Exact (kernelized) linear model.

Can be solved exactly by kernel regression [Arora et al., 2019].
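As a small illustration (again an assumption-laden sketch, not code from the talk), the NTK Gram matrix on a handful of inputs can be formed directly from this feature map; by construction it is symmetric and positive semidefinite.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d, M = 6, 5, 512

X = rng.normal(size=(n, d))
W0 = rng.normal(size=(M, d))                # same toy two-layer ReLU net as above
a0 = rng.normal(size=M)

def feature(x):
    """NTK feature of x: flattened gradient of f(w0, x) w.r.t. (W, a)."""
    h = W0 @ x
    gW = np.outer(a0 * (h > 0), x) / np.sqrt(M)
    ga = np.maximum(h, 0.0) / np.sqrt(M)
    return np.concatenate([gW.ravel(), ga])

Phi = np.stack([feature(x) for x in X])     # n x m feature matrix, i.e. J(w0)
K = Phi @ Phi.T                             # n x n kernel (Gram) matrix with entries K(x_i, x_j)
print(np.allclose(K, K.T), np.linalg.eigvalsh(K).min() >= -1e-9)   # symmetric, PSD
```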


Page 8

Overparameterized Deep Learning

Interpolation Regime (cont.)

Key Insights

Overparameterized networks are approximately linear w.r.t. the parameters in a neighborhood of the initialization (in parameter space).
⇒ Gradient descent operates within a "not so nonconvex" regime.
⇒ Gradient descent can approximately find the optimum within this regime.

There is a global optimum inside this neighborhood.
⇒ The optimum reached by gradient descent is an (approximate) global optimum :)

Question

How about other optimization methods, e.g. second-order methods?


Page 9

Second-order Methods for Deep Learning

Outline

1 Overparameterized Deep Learning

2 Second-order Methods for Deep Learning

3 Gram-Gauss-Newton Method


Page 10

Second-order Methods for Deep Learning

Newton Type Methods

Let $L(w) = \sum_{i=1}^{n} \ell(f(w, x_i), y_i)$ denote the loss function on the dataset $\{x_i, y_i\}_{i=1}^{n}$ with parameters $w$.

Newton type method

$$w_{k+1} = w_k - H(w_k)^{-1} \nabla_w L(w_k),$$

where $\nabla_w L$ is the gradient of the loss function w.r.t. $w$ and $H \in \mathbb{R}^{m \times m}$ is the (approximate) Hessian matrix.
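For readers who have not seen a Newton-type step in code, here is a tiny, hedged sketch on a toy regularized logistic-regression loss (the problem, data, and ridge term are my own choices, unrelated to the networks discussed in this talk); it only illustrates the update rule above.

```python
import numpy as np

rng = np.random.default_rng(2)
n, m, lam = 100, 3, 1e-3
X = rng.normal(size=(n, m))
y = np.sign(X @ rng.normal(size=m) + 0.3 * rng.normal(size=n))   # labels in {-1, +1}

w = np.zeros(m)
for k in range(6):
    p = 1.0 / (1.0 + np.exp(-(X @ w) * y))                        # prob. of the observed label
    grad = -X.T @ ((1.0 - p) * y) / n + lam * w                   # gradient of loss + ridge
    H = X.T @ (X * (p * (1.0 - p))[:, None]) / n + lam * np.eye(m)  # exact Hessian here
    w = w - np.linalg.solve(H, grad)                              # w_{k+1} = w_k - H^{-1} grad
    print(k, np.mean(np.log1p(np.exp(-(X @ w) * y))))             # loss drops in a few steps
```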


Page 11

Second-order Methods for Deep Learning

Issues of Second-order Methods for Deep Learning

The number of parameters m is very large: storing the Hessian matrix $H \in \mathbb{R}^{m \times m}$ and computing its inverse is intractable.

Deep networks have a complicated structure, so it is hard to obtain a good approximation of the Hessian matrix.

Our Goal

Utilize the properties of deep overparameterized networks.

Make second-order methods tractable for deep learning.


Page 12

Gram-Gauss-Newton Method

Outline

1 Overparameterized Deep Learning

2 Second-order Methods for Deep Learning

3 Gram-Gauss-Newton Method


Page 13

Gram-Gauss-Newton Method

Recap: Interpolation Regime

Key Insights

Overparameterized networks are approximately linear in a neighborhood of the initialization, i.e., there is a benign neighborhood.

There is a global optimum inside this neighborhood, i.e., the neighborhood is just large enough.


Page 14

Gram-Gauss-Newton Method

Idea: Exploit the Nearly Linear Behavior within the Interpolation Regime

Regression with Square Loss

$$\min_w \; L(w) = \frac{1}{2} \sum_{i=1}^{n} \big(f(w, x_i) - y_i\big)^2.$$

The Hessian H(w) can be calculated as:

$$\nabla_w^2 L(w) = J(w)^\top J(w) + \sum_{i=1}^{n} \big(f(w, x_i) - y_i\big)\, \nabla_w^2 f(w, x_i),$$

where $J(w) = \big(\nabla_w f(w, x_1), \cdots, \nabla_w f(w, x_n)\big)^\top \in \mathbb{R}^{n \times m}$ denotes the Jacobian matrix.


Page 15

Gram-Gauss-Newton Method

Idea: Exploit the Nearly Linear Behavior within the Interpolation Regime (cont.)

When $f(w, x)$ is nearly linear w.r.t. $w$, i.e. $\nabla_w^2 f(w, x) \approx 0$, we have

$$H(w) = J(w)^\top J(w) + \sum_{i=1}^{n} \big(f(w, x_i) - y_i\big)\, \nabla_w^2 f(w, x_i) \approx J(w)^\top J(w),$$

which suggests that $J(w)^\top J(w) \in \mathbb{R}^{m \times m}$ is a good approximation of the Hessian.

Remark

Same idea as Gauss-Newton method.

Good convergence rate for mildly nonlinear $f$ [Golub, 1965].
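A quick numerical check of the approximation above, using a deliberately simple nonlinear model of my own choosing, $f(w, x) = \exp(w \cdot x)$, for which both the exact Hessian of the square loss and the Gauss-Newton matrix are easy to write down.

```python
import numpy as np

rng = np.random.default_rng(3)
n, m = 30, 4
X = rng.normal(size=(n, m))
w_true = rng.normal(size=m)
y = np.exp(X @ w_true)                      # targets generated by the model itself

w = w_true + 0.05 * rng.normal(size=m)      # near the optimum, so residuals are small
f = np.exp(X @ w)
r = f - y                                   # residuals f(w, x_i) - y_i
J = f[:, None] * X                          # row i: grad_w f(w, x_i) = f(w, x_i) * x_i

gauss_newton = J.T @ J
hessian = gauss_newton + sum(ri * fi * np.outer(xi, xi) for ri, fi, xi in zip(r, f, X))
print(np.linalg.norm(hessian - gauss_newton) / np.linalg.norm(hessian))  # small ratio
```

The closer the residuals are to zero (the interpolation regime), the smaller the neglected second-derivative term.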


Page 16

Gram-Gauss-Newton Method

Idea: Exploit the Nearly Linear Behavior within the Interpolation Regime (cont.)

Update Rule of Classic Gauss-Newton Method

$$
\begin{aligned}
w \;\to\; & w - \big(\underbrace{J(w)^\top J(w)}_{\text{approximate Hessian}}\big)^{-1} \nabla_w L(w) \\
 = \; & w - \big(\underbrace{J(w)^\top J(w)}_{\text{approximate Hessian}}\big)^{-1} J(w)^\top \big(f(w) - y\big),
\end{aligned}
$$

where we denote $f(w) = (f(w, x_1), \cdots, f(w, x_n))^\top$ and $y = (y_1, \cdots, y_n)^\top$ for convenience. The second equality comes from $\nabla_w L(w) = J(w)^\top (f(w) - y)$.


Page 17

Gram-Gauss-Newton Method

Problems of Approximate Hessian

The approximate Hessian $J(w)^\top J(w) \in \mathbb{R}^{m \times m}$ is still too large to store and invert.

For an overparameterized model, $J(w) \in \mathbb{R}^{n \times m}$ has rank at most the number of data points $n$ ($\ll m$). So $J(w)^\top J(w)$ is not invertible.


Page 18

Gram-Gauss-Newton Method

Recap: Neural Taylor Expansion and NTK

Neural Taylor Expansion and NTK

$$f(w, x) \approx \underbrace{f(w_0, x)}_{\text{bias term}} + \underbrace{(w - w_0)}_{\text{linear parameters}} \cdot \underbrace{\nabla_w f(w_0, x)}_{\text{feature of } x},$$

$$K(x, x') = \langle \nabla_w f(w_0, x),\, \nabla_w f(w_0, x') \rangle.$$

Remark

The linear-expansion approximation used in NTK is the same as in the Gauss-Newton method.

The solution of kernel regression is given by

$$w = w_0 - J(w_0)^\top G(w_0)^{-1} \big(f(w_0) - y\big),$$

where $G(w_0)_{ij} = \langle \nabla_w f(w_0, x_i), \nabla_w f(w_0, x_j) \rangle$ is the Gram matrix of the NTK.


Page 19

Gram-Gauss-Newton Method

Idea: From Hessian to Gram Matrix

Gram Matrix

$$G(w) = J(w)\, J(w)^\top \in \mathbb{R}^{n \times n},$$

where $n$ is the number of data points.

Relation between Gram and Approximate Hessian

$$J(w)^\top \big(\underbrace{J(w)\, J(w)^\top}_{\text{Gram}}\big)^{-1} = \big(\underbrace{J(w)^\top J(w)}_{\text{approximate Hessian}}\big)^{\dagger} J(w)^\top,$$

where $(\cdot)^\dagger$ is the pseudo-inverse.
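This identity is easy to verify numerically for a random full-row-rank Jacobian with $n \ll m$; the sizes below are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(4)
n, m = 8, 200
J = rng.normal(size=(n, m))                 # full row rank with probability 1

lhs = J.T @ np.linalg.inv(J @ J.T)          # uses only the small n x n Gram matrix
rhs = np.linalg.pinv(J.T @ J) @ J.T         # uses the large m x m pseudo-inverse
print(np.max(np.abs(lhs - rhs)))            # agreement up to floating-point error
```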


Page 20

Gram-Gauss-Newton Method

Gram-Gauss-Newton Method (GGN)

Update Rule of Gram-Gauss-Newton Method

$$w \;\to\; w - J(w)^\top \big(\underbrace{J(w)\, J(w)^\top}_{\text{Gram}}\big)^{-1} \big(f(w) - y\big).$$

Remark

Can be viewed as solving NTK kernel regression based on current parameters.

The Gram matrix is $n \times n$, where $n \ll m$.

Combined with a mini-batch scheme, we only need to compute the Gram matrix of each batch (see the sketch below).
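Putting the pieces together, here is a minimal mini-batch GGN loop on a toy two-layer ReLU network, offered as an illustrative sketch rather than the authors' implementation; the architecture, batch size, and the small damping term `lam` (added only for numerical stability of the inverse, not part of the update rule above) are assumptions of this example.

```python
import numpy as np

rng = np.random.default_rng(5)
n, d, M, B, lam = 256, 8, 1024, 64, 1e-3    # data size, input dim, width, batch size, damping

X = rng.normal(size=(n, d))
y = np.sin(X[:, 0]) + 0.5 * X[:, 1]         # toy regression targets

W = rng.normal(size=(M, d))                 # two-layer ReLU net f(w, x) = a . relu(W x) / sqrt(M)
a = rng.normal(size=M)

def forward(Xb):
    hidden = np.maximum(Xb @ W.T, 0.0)      # B x M hidden activations
    return hidden @ a / np.sqrt(M), hidden

def jacobian(Xb, hidden):
    """Per-example gradients w.r.t. (W, a), stacked into a B x m Jacobian."""
    mask = (hidden > 0).astype(float)
    JW = (a * mask)[:, :, None] * Xb[:, None, :] / np.sqrt(M)   # B x M x d
    Ja = hidden / np.sqrt(M)                                    # B x M
    return np.concatenate([JW.reshape(len(Xb), -1), Ja], axis=1)

for step in range(20):
    idx = rng.choice(n, size=B, replace=False)
    Xb, yb = X[idx], y[idx]
    fb, hidden = forward(Xb)
    J = jacobian(Xb, hidden)
    G = J @ J.T + lam * np.eye(B)           # B x B (damped) Gram matrix
    dw = J.T @ np.linalg.solve(G, fb - yb)  # GGN direction: J^T (J J^T)^{-1} (f - y)
    W -= dw[: M * d].reshape(M, d)
    a -= dw[M * d:]
    if step % 5 == 0:
        f_all, _ = forward(X)
        print(step, 0.5 * np.mean((f_all - y) ** 2))   # full-data training loss
```

Each step only forms and inverts a $B \times B$ matrix, which is the point of moving from the Hessian to the Gram matrix.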


Page 21

Gram-Gauss-Newton Method

Second-order Convergence for Overparameterized Networks

Theorem

For a two-layer ReLU network, if the width $M = \Omega\big(\mathrm{poly}(n/\lambda_0)\big)$, then with high probability over the random initialization, the full-batch GGN satisfies:

1 The Gram matrix $G$ at each iteration is invertible;

2 The loss converges to zero in a way that

$$\|f(w_{t+1}) - y\|_2 \le \frac{C}{\sqrt{M}}\, \|f(w_t) - y\|_2^2$$

for some constant $C$ that is independent of $M$.


Page 22

Gram-Gauss-Newton Method

Computational Cost

Scale of Key Items

Let m denote the number of parameters and B the batch size (e.g. 128).

$J(w) \in \mathbb{R}^{B \times m}$.

$G(w) = J(w)\, J(w)^\top \in \mathbb{R}^{B \times B}$.

Main Overhead

Computation of the Gram matrix.

Computation of the inverse of the Gram matrix.


Page 23

Gram-Gauss-Newton Method

Experimental Results I

We conduct experiments on two regression datasets: the AFAD-LITE task (human age prediction from images) and RSNA Bone Age regression.

[Figure: Training curves of GGN and SGD on AFAD-LITE. (a) Loss-time curve; (b) loss-epoch curve. Curves: GGN+ResNetFixup, SGD+ResNetFixup, SGD+ResNetBN. Y-axis: training loss.]


Page 24

Gram-Gauss-Newton Method

Experimental Results II

[Figure: Training curves of GGN and SGD on RSNA Bone Age. (a) Loss-time curve; (b) loss-epoch curve. Curves: GGN+ResNetFixup, SGD+ResNetFixup, SGD+ResNetBN. Y-axis: training loss.]


Page 25

Gram-Gauss-Newton Method

Takeaways

Get along with overparameterization: data-related quantities (e.g. the Gram matrix) are cheaper than parameter-related ones (e.g. the Hessian).

Enjoy overparameterization: exploit the local near-linearity of overparameterized networks to achieve a faster convergence rate.


Page 26

Gram-Gauss-Newton Method

Q&A

Thank you! See our paper at https://arxiv.org/abs/1905.11675 and the slides at http://suo.im/5fsOXj.


Page 27

Gram-Gauss-Newton Method

References I

Allen-Zhu, Z., Li, Y., and Song, Z. (2018). On the convergence rate of training recurrent neural networks. arXiv preprint arXiv:1810.12065.

Arora, S., Du, S. S., Hu, W., Li, Z., Salakhutdinov, R., and Wang, R. (2019). On exact computation with an infinitely wide neural net. arXiv preprint arXiv:1904.11955.

Chizat, L. and Bach, F. (2018). A note on lazy training in supervised differentiable programming. arXiv preprint arXiv:1812.07956.

Du, S. S., Zhai, X., Poczos, B., and Singh, A. (2018). Gradient descent provably optimizes over-parameterized neural networks. arXiv preprint arXiv:1810.02054.


Page 28

Gram-Gauss-Newton Method

References II

Golub, G. (1965). Numerical methods for solving linear least squares problems. Numerische Mathematik, 7(3):206-216.

Jacot, A., Gabriel, F., and Hongler, C. (2018). Neural tangent kernel: Convergence and generalization in neural networks. In Advances in Neural Information Processing Systems, pages 8571-8580.

Zhang, C., Bengio, S., Hardt, M., Recht, B., and Vinyals, O. (2016). Understanding deep learning requires rethinking generalization. arXiv preprint arXiv:1611.03530.

Zou, D., Cao, Y., Zhou, D., and Gu, Q. (2018). Stochastic gradient descent optimizes over-parameterized deep ReLU networks. arXiv preprint arXiv:1811.08888.
