Stochastic Gradient Descent with Importance Sampling
Joint work with Deanna Needell (Claremont McKenna College) and Nathan Srebro (TTIC / Technion)
Rachel Ward, UT Austin

Transcript of "Stochastic Gradient Descent with Importance Sampling" (24 pages)

Page 1:

Stochastic Gradient Descent with Importance Sampling

Joint work with Deanna Needell (Claremont McKenna College)

and Nathan Srebro (TTIC / Technion)

Rachel Ward, UT Austin

Page 2:

Recall: Gradient descent

Problem: Minimize $F(x) = \frac{1}{n}\sum_{i=1}^n f_i(x)$, $n$ very large

Gradient descent: Initialize $x_0$; iterate

$$x^{(j+1)} = x^{(j)} - \gamma\,\nabla F(x^{(j)}) = x^{(j)} - \frac{\gamma}{n}\sum_{i=1}^n \nabla f_i(x^{(j)})$$

Not practical when $n$ is huge!
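For concreteness, a minimal NumPy sketch of this full-batch update (representing the finite sum as a list of per-component gradient callables is an illustrative choice, not from the slides):

```python
import numpy as np

def gradient_descent(grad_fs, x0, gamma, iters):
    """Full-batch gradient descent: every step evaluates all n component
    gradients, which is exactly the per-step cost the slide objects to."""
    x = x0.copy()
    for _ in range(iters):
        g = np.mean([g_i(x) for g_i in grad_fs], axis=0)  # = grad F(x)
        x = x - gamma * g
    return x
```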

Page 3:

Stochastic gradient descent

Minimize $F(x) = \frac{1}{n}\sum_{i=1}^n f_i(x)$

Initialize $x_0$. For $j = 1, 2, \dots$: draw index $i = i_j$ at random, then

$$x^{(j+1)} = x^{(j)} - \gamma\,\nabla f_i(x^{(j)})$$

Goal: nonasymptotic bounds on $\mathbb{E}\|x^{(j)} - x^*\|^2$
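A matching sketch of the stochastic update, again with hypothetical helper names; the per-step cost drops from $n$ gradients to one:

```python
import numpy as np

def sgd_uniform(grad_fs, x0, gamma, iters, seed=0):
    """SGD with uniform sampling: one component gradient per step, at the
    price of a noisy search direction."""
    rng = np.random.default_rng(seed)
    x, n = x0.copy(), len(grad_fs)
    for _ in range(iters):
        i = rng.integers(n)             # draw i_j uniformly from {0,...,n-1}
        x = x - gamma * grad_fs[i](x)   # unbiased: E[grad f_i(x)] = grad F(x)
    return x
```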

Page 4:

Minimize $F(x) = \frac{1}{n}\sum_{i=1}^n f_i(x) = \mathbb{E}\, f_i(x)$

We begin by assuming

1. Each $\nabla f_i$ is Lipschitz: $\|\nabla f_i(x) - \nabla f_i(y)\|_2 \le L_i \|x - y\|_2$

2. $F$ is strongly convex: $\langle x - y,\, \nabla F(x) - \nabla F(y)\rangle \ge \mu \|x - y\|^2$

Convergence rate of SGD should depend on

1. A condition number:

$$\kappa_{\mathrm{av}} = \frac{\frac{1}{n}\sum_i L_i}{\mu}, \qquad \kappa_{\max} = \frac{\max_i L_i}{\mu}, \qquad \kappa_2 = \frac{\sqrt{\frac{1}{n}\sum_i L_i^2}}{\mu}$$

2. Consistency: $\sigma^2 = \mathbb{E}\|\nabla f_i(x^*)\|^2$
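A tiny helper to compare the three condition numbers, given the $L_i$ and $\mu$ (the function name is illustrative). By the mean–RMS–max inequalities, $\kappa_{\mathrm{av}} \le \kappa_2 \le \kappa_{\max}$ always holds:

```python
import numpy as np

def condition_numbers(L, mu):
    """The three condition numbers from the slide, given component
    Lipschitz constants L = [L_1, ..., L_n] and strong convexity mu."""
    L = np.asarray(L, dtype=float)
    kappa_av  = L.mean() / mu                 # (1/n) sum_i L_i / mu
    kappa_max = L.max() / mu                  # max_i L_i / mu
    kappa_2   = np.sqrt((L**2).mean()) / mu   # sqrt((1/n) sum_i L_i^2) / mu
    return kappa_av, kappa_2, kappa_max       # always av <= 2 <= max
```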

Page 5:

SGD - convergence rates

(Needell, Srebro, W., 2013): Under the smoothness and strong convexity assumptions above, the SGD iterates satisfy

$$\mathbb{E}\|x^{(k)} - x^*\|^2 \le \left[1 - 2\gamma\mu\left(1 - \gamma \sup_i L_i\right)\right]^k \|x_0 - x^*\|^2 + \frac{\gamma \sigma^2}{\mu\left(1 - \gamma \sup_i L_i\right)}$$

Corollary: $\mathbb{E}\|x^{(k)} - x^*\|_2^2 \le \varepsilon$ after

$$k = 2\log(\varepsilon_0/\varepsilon)\left(\frac{\sup_i L_i}{\mu} + \frac{\sigma^2}{\mu^2 \varepsilon}\right)$$

SGD iterations, with optimized constant step-size.

Page 6:

SGD - convergence rates

We showed:

$$k = 2\log(\varepsilon_0/\varepsilon)\left(\frac{\sup_i L_i}{\mu} + \frac{\sigma^2}{\mu^2 \varepsilon}\right)$$

SGD iterations suffice for $\mathbb{E}\|x^{(k)} - x^*\|_2^2 \le \varepsilon$

Compare to (Bach and Moulines, 2011):

$$k = 2\log(\varepsilon_0/\varepsilon)\left(\left(\frac{\sqrt{\frac{1}{n}\sum_i L_i^2}}{\mu}\right)^2 + \frac{\sigma^2}{\mu^2 \varepsilon}\right)$$

SGD iterations suffice for $\mathbb{E}\|x^{(k)} - x^*\|_2^2 \le \varepsilon$

Page 7:

$$\begin{aligned}
\|x^{(k+1)} - x^*\|^2 &= \|x^{(k)} - x^* - \gamma\,\nabla f_i(x^{(k)})\|^2 \\
&\le \|x^{(k)} - x^*\|^2 - 2\gamma\,\langle x^{(k)} - x^*,\, \nabla f_i(x^{(k)})\rangle \\
&\quad + 2\gamma^2\|\nabla f_i(x^{(k)}) - \nabla f_i(x^*)\|^2 + 2\gamma^2\|\nabla f_i(x^*)\|^2 \\
&\le \|x^{(k)} - x^*\|^2 - 2\gamma\,\langle x^{(k)} - x^*,\, \nabla f_i(x^{(k)})\rangle \\
&\quad + 2\gamma^2 L_i\,\langle x^{(k)} - x^*,\, \nabla f_i(x^{(k)}) - \nabla f_i(x^*)\rangle + 2\gamma^2\|\nabla f_i(x^*)\|^2
\end{aligned}$$

Difference in proof is that we use the co-coercivity lemma for smooth functions with Lipschitz gradient: $\|\nabla f(x) - \nabla f(y)\|^2 \le L\,\langle x - y,\, \nabla f(x) - \nabla f(y)\rangle$

Result: $k \propto \frac{\sup_i L_i}{\mu}$ vs. $k \propto \left(\frac{\sqrt{\frac{1}{n}\sum_i L_i^2}}{\mu}\right)^2$ steps

Page 8:

Consider the least squares case: let $A$ be the matrix with rows $a_1, a_2, \dots, a_n$, and

$$F(x) = \frac{1}{2}\sum_{i=1}^n \left(\langle a_i, x\rangle - b_i\right)^2 = \frac{1}{2}\|Ax - b\|^2$$

Assume consistency: $Ax^* = b$, so $\sigma^2 = 0$. Then

$$\frac{\sup_i L_i}{\mu} = \left(n \sup_i \|a_i\|^2\right)\left(\|A^\dagger\|^2\right)$$

These convergence rates are tight

Page 9:

These convergence rates are tight

Consider the system

$$\begin{pmatrix} 1 & 0 \\ 0 & 1/\sqrt{n} \\ \vdots & \vdots \\ 0 & 1/\sqrt{n} \end{pmatrix} \begin{pmatrix} 1 \\ 0 \end{pmatrix} = \begin{pmatrix} 1 \\ 0 \\ \vdots \\ 0 \end{pmatrix}$$

Here, $\frac{\sup_i L_i}{\mu} = n \sup_i \|a_i\|^2 \|A^\dagger\|^2 = n$

In this example, we need $k = n$ steps to get any accuracy
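A small numerical check of this example (the builder name `tight_example` is an illustrative choice; it uses $L_i = n\|a_i\|^2$ and $\mu = \sigma_{\min}(A)^2 = 1/\|A^\dagger\|^2$ from the previous slide):

```python
import numpy as np

def tight_example(n):
    """One heavy row and n-1 light rows: the worst case for uniform sampling."""
    A = np.zeros((n, 2))
    A[0, 0] = 1.0
    A[1:, 1] = 1.0 / np.sqrt(n)
    L = n * np.sum(A**2, axis=1)                       # L_i = n ||a_i||^2
    mu = np.linalg.svd(A, compute_uv=False)[-1] ** 2   # 1 / ||A^dagger||^2
    return L.max() / mu, L.mean() / mu

print(tight_example(1000))   # approx (1001.0, 2.0): sup_i L_i / mu ~ n
```

The second ratio, $\kappa_{\mathrm{av}} \approx 2$, is exactly what weighted sampling will exploit a few slides later.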

Page 10:

Better convergence rates using weighted sampling strategies?

Page 11:

SGD with weighted sampling

Observe: given weights $w_i$ such that $\sum_i w_i = n$,

$$F(x) = \frac{1}{n}\sum_i f_i(x) = \mathbb{E}^{(w)}\left[\frac{1}{w_i} f_i(x)\right]$$

SGD unbiased update with weighted sampling: let $\tilde f_i = \frac{1}{w_i} f_i$, and

$$x^{(j+1)} = x^{(j)} - \gamma\,\nabla \tilde f_{i_k}(x^{(j)}) = x^{(j)} - \frac{\gamma}{w_{i_k}}\,\nabla f_{i_k}(x^{(j)}), \quad \text{where } \mathbb{P}(i_k = i) = \frac{w_i}{\sum_j w_j}$$
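A sketch of this reweighted update (helper names are illustrative). With $\sum_i w_i = n$, rescaling by $1/w_i$ keeps the step unbiased: $\mathbb{E}^{(w)}[(1/w_i)\nabla f_i(x)] = \nabla F(x)$.

```python
import numpy as np

def sgd_weighted(grad_fs, w, x0, gamma, iters, seed=0):
    """SGD with weighted (importance) sampling: draw i with probability
    w_i / sum_j w_j, then rescale the gradient by 1/w_i."""
    rng = np.random.default_rng(seed)
    w = np.asarray(w, dtype=float)
    p = w / w.sum()                   # P(i_k = i) = w_i / sum_j w_j
    x, n = x0.copy(), len(grad_fs)
    for _ in range(iters):
        i = rng.choice(n, p=p)
        x = x - (gamma / w[i]) * grad_fs[i](x)   # tilde f_i = f_i / w_i
    return x
```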

Page 12:

SGD with weighted sampling

$$F(x) = \frac{1}{n}\sum_i f_i(x) = \mathbb{E}^{(w)}\left[\frac{1}{w_i} f_i(x)\right], \quad \text{given weights } w_i \text{ such that } \sum_i w_i = n$$

Weighted sampling strategies in stochastic optimization are not new:

• (Strohmer, Vershynin 2009): randomized Kaczmarz algorithm

• (Nesterov 2010; Lee, Sidford 2013): biased sampling for accelerated stochastic coordinate descent

Page 13:

SGD with weighted sampling

Our previous result: for $F(x) = \mathbb{E}\, f_i(x)$ and $x^{k+1} = x^k - \gamma\,\nabla f_i(x^k)$, $\mathbb{E}\|x^k - x^*\|_2^2 \le \varepsilon$ after

$$k = 2\log(\varepsilon_0/\varepsilon)\left(\frac{L_{\max}[f_i]}{\mu} + \frac{\sigma^2[f_i]}{\mu^2 \varepsilon}\right) \text{ steps}$$

Corollary for weighted sampling: $\mathbb{E}^{(w)}\|x^k - x^*\|_2^2 \le \varepsilon$ after

$$k = 2\log(\varepsilon_0/\varepsilon)\left(\frac{L_{\max}\left[\frac{1}{w_i} f_i\right]}{\mu} + \frac{\sigma^2\left[\frac{1}{w_i} f_i\right]}{\mu^2 \varepsilon}\right) \text{ steps}$$

Page 14:

Choice of weights

For $F(x) = \mathbb{E}^{(w)} \tilde f_i(x)$: $\mathbb{E}^{(w)}\|x^k - x^*\|_2^2 \le \varepsilon$ after

$$k = 2\log(\varepsilon_0/\varepsilon)\left(\frac{L_{\max}\left[\frac{1}{w_i} f_i\right]}{\mu} + \frac{\sigma^2\left[\frac{1}{w_i} f_i\right]}{\mu^2 \varepsilon}\right) \text{ steps}$$

If $\sigma^2 = 0$, choose weights to minimize $L_{\max}\left[\frac{1}{w_i} f_i\right]$: using weights $w_i = \frac{n L_i}{\sum_i L_i}$, $\mathbb{E}^{(w)}\|x^k - x^*\|_2^2 \le \varepsilon$ after

$$k = 2\log(\varepsilon_0/\varepsilon)\left(\frac{\frac{1}{n}\sum_i L_i}{\mu}\right) \text{ steps}$$

Page 15:

Improved convergence rate with weighted sampling

Recall the example:

$$\begin{pmatrix} 1 & 0 \\ 0 & 1/\sqrt{n} \\ \vdots & \vdots \\ 0 & 1/\sqrt{n} \end{pmatrix} \begin{pmatrix} 1 \\ 0 \end{pmatrix} = \begin{pmatrix} 1 \\ 0 \\ \vdots \\ 0 \end{pmatrix}$$

Since this system is consistent, $\sigma^2 = 0$, and

$$\frac{L_{\max}}{\mu}(A) = n \ \Rightarrow \ O(n) \text{ steps using uniform sampling}$$

$$\frac{\frac{1}{n}\sum_i L_i}{\mu}(A) = 2 \ \Rightarrow \ O(1) \text{ steps using biased sampling}$$
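The gap shows up numerically. A hypothetical experiment reusing the `sgd_uniform` and `sgd_weighted` sketches above; the step-sizes $\gamma = 1/(2\sup_i L_i)$ and $\gamma = 1/(2 \cdot \frac{1}{n}\sum_i L_i)$ are illustrative choices consistent with the theorem (safe here since $\sigma^2 = 0$):

```python
import numpy as np

# Build the example system: rows a_i, consistent b, so sigma^2 = 0.
n, k = 200, 400
A = np.zeros((n, 2)); A[0, 0] = 1.0; A[1:, 1] = 1.0 / np.sqrt(n)
x_star = np.array([1.0, 0.0]); b = A @ x_star
# f_i(x) = (n/2)(<a_i,x> - b_i)^2, so grad f_i(x) = n a_i (<a_i,x> - b_i)
grad_fs = [lambda x, a=A[i], bi=b[i]: n * a * (a @ x - bi) for i in range(n)]
L = n * np.sum(A**2, axis=1)                # L_i = n ||a_i||^2

x_uni = sgd_uniform(grad_fs, np.zeros(2), gamma=1 / (2 * L.max()), iters=k)
x_bia = sgd_weighted(grad_fs, n * L / L.sum(), np.zeros(2),
                     gamma=1 / (2 * L.mean()), iters=k)
print(np.linalg.norm(x_uni - x_star))  # ~0.1-0.5: the one informative row
                                       # is drawn only ~k/n times
print(np.linalg.norm(x_bia - x_star))  # near machine precision
```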

Page 16:

Improved convergence rate with weighted sampling

Recall the example:

$$\begin{pmatrix} 1 & 0 \\ 0 & 1/\sqrt{n} \\ \vdots & \vdots \\ 0 & 1/\sqrt{n} \end{pmatrix} \begin{pmatrix} 1 \\ 0 \end{pmatrix} = \begin{pmatrix} 1 \\ 0 \\ \vdots \\ 0 \end{pmatrix}$$

But what if the system is not consistent?

Page 17:

$\mathbb{E}^{(w)}\|x^k - x^*\|_2^2 \le \varepsilon$ after

$$k = 2\log(\varepsilon_0/\varepsilon)\left(\frac{L_{\max}\left[\frac{1}{w_i} f_i\right]}{\mu} + \frac{\sigma^2\left[\frac{1}{w_i} f_i\right]}{\mu^2 \varepsilon}\right) \text{ steps}$$

Choosing weights $w_i = \frac{n L_i}{\sum_i L_i}$ gives $L_{\max}\left[\frac{1}{w_i} f_i\right] = \frac{1}{n}\sum_i L_i$

Choosing weights $w_i = 1$ gives $\sigma^2\left[\frac{1}{w_i} f_i\right] = \sigma^2[f_i]$

Partially-biased sampling: choosing weights $w_i = \frac{1}{2} + \frac{1}{2} \cdot \frac{n L_i}{\sum_i L_i}$ gives

$$L_{\max}\left[\frac{1}{w_i} f_i\right] \le 2 \cdot \frac{1}{n}\sum_i L_i \quad \text{and} \quad \sigma^2\left[\frac{1}{w_i} f_i\right] \le 2\sigma^2[f_i]$$

Partially biased sampling gives a strictly better convergence rate, up to a factor of 2
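A sketch of the weight computation (the function name is hypothetical), with the slide's factor-of-2 guarantees checked numerically on random $L_i$:

```python
import numpy as np

def partially_biased_weights(L):
    """w_i = 1/2 + (1/2) n L_i / sum_j L_j; note sum_i w_i = n, so the
    reweighted update stays unbiased."""
    L = np.asarray(L, dtype=float)
    return 0.5 + 0.5 * len(L) * L / L.sum()

L = np.random.default_rng(0).uniform(0.1, 10.0, size=50)
w = partially_biased_weights(L)
assert np.isclose(w.sum(), len(L))
assert (L / w).max() <= 2 * L.mean() + 1e-12  # L_max[(1/w_i)f_i] <= 2 (1/n) sum L_i
assert (1 / w).max() <= 2 + 1e-12             # w_i >= 1/2 drives sigma^2[(1/w_i)f_i] <= 2 sigma^2[f_i]
```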

Page 18:

SGD - convergence rates

Uniform sampling:

(Bach and Moulines, 2011):
$$k = 2\log(\varepsilon_0/\varepsilon)\left(\left(\frac{\sqrt{\frac{1}{n}\sum_i L_i^2}}{\mu}\right)^2 + \frac{\sigma^2}{\mu^2\varepsilon}\right)$$
SGD iterations suffice for $\mathbb{E}\|x^{(k)} - x^*\|_2^2 \le \varepsilon$

(Needell, Srebro, W., 2013):
$$k = 2\log(\varepsilon_0/\varepsilon)\left(\frac{\sup_i L_i}{\mu} + \frac{\sigma^2}{\mu^2\varepsilon}\right)$$
SGD iterations suffice for $\mathbb{E}\|x^{(k)} - x^*\|_2^2 \le \varepsilon$

(Partially) biased sampling:
$$k = 4\log(\varepsilon_0/\varepsilon)\left(\frac{\frac{1}{n}\sum_i L_i}{\mu} + \frac{\sigma^2}{\mu^2\varepsilon}\right)$$
SGD iterations suffice for $\mathbb{E}\|x^{(k)} - x^*\|_2^2 \le \varepsilon$

Page 19:

We have been operating in the setting

1. Each $\nabla f_i$ has Lipschitz constant $L_i$: $\|\nabla f_i(x) - \nabla f_i(y)\|_2 \le L_i\|x - y\|_2$

2. $F$ has strong convexity parameter $\mu$: $\langle x - y,\, \nabla F(x) - \nabla F(y)\rangle \ge \mu\|x - y\|^2$

Other settings and weaker assumptions:

• Removing the strong convexity assumption

• Non-smooth objectives

Page 20:

Smoothness, but no strong convexity:

• Each $\nabla f_i$ has Lipschitz constant $L_i$: $\|\nabla f_i(x) - \nabla f_i(y)\|_2 \le L_i\|x - y\|_2$

• (Srebro et al. 2010): Number of iterations of SGD:
$$k = O\left(\frac{\left(\sup_i L_i\right)\|x^*\|^2}{\varepsilon} \cdot \frac{F(x^*) + \varepsilon}{\varepsilon}\right)$$

• (Foygel and Srebro 2011): Cannot replace $\sup_i L_i$ with $\frac{1}{n}\sum_i L_i$

Using weights $w_i = \frac{L_i}{\frac{1}{n}\sum_i L_i}$,
$$k = O\left(\frac{\left(\frac{1}{n}\sum_i L_i\right)\|x^*\|^2}{\varepsilon} \cdot \frac{F(x^*) + \varepsilon}{\varepsilon}\right)$$

Page 21:

• (Srebro et al. 2010): Number of iterations of SGD still depends linearly on $\sup_i L_i$.

• (Foygel and Srebro 2011): The dependence cannot be replaced with $\frac{1}{n}\sum_i L_i$ (using uniform sampling)

Using weights $w_i = \frac{L_i}{\frac{1}{n}\sum_i L_i}$, the dependence is replaced by $\frac{1}{n}\sum_i L_i$

Even less restrictive, we assume now only that

1. Each $f_i$ has Lipschitz constant $G_i$: $|f_i(x) - f_i(y)| \le G_i\|x - y\|_2$

Page 22:

Less restrictive:

$$|f_i(x) - f_i(y)| \le G_i\|x - y\|_2$$

Here, the SGD convergence rate depends linearly on $\frac{1}{n}\sum_i G_i^2$

Using weights $w_i = \frac{G_i}{\frac{1}{n}\sum_i G_i}$, the dependence is reduced to

$$\left(\frac{1}{n}\sum_i G_i\right)^2 \le \frac{1}{n}\sum_i G_i^2$$

since

$$\frac{1}{n}\sum_i G_i^2 = \left(\frac{1}{n}\sum_i G_i\right)^2 + \mathrm{Var}[G_i]$$

(Zhao, Zhang 2013) also consider importance sampling in this setting.
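A quick numerical check of this identity, which makes the gain explicit: importance sampling removes exactly the variance of the $G_i$ (the random draw of $G$ here is purely illustrative):

```python
import numpy as np

G = np.random.default_rng(1).uniform(0.1, 5.0, size=1000)
lhs = np.mean(G**2)              # uniform-sampling dependence
rhs = np.mean(G)**2 + np.var(G)  # weighted dependence plus Var[G_i] (ddof=0)
assert np.isclose(lhs, rhs)      # (1/n) sum G_i^2 = ((1/n) sum G_i)^2 + Var[G_i]
```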

Page 23:

Thanks!

Future work:

• Strategies for sampling blocks of indices at each iteration. Optimal block size?

• Optimal adaptive sampling strategies given limited or no information about the $L_i$?

• Simulations on real data

Page 24:

References

[1] Needell, Srebro, Ward. "Stochastic gradient descent and the randomized Kaczmarz algorithm." arXiv preprint arXiv:1310.5715 (2013).