Stochastic Gradient Descent with Importance Sampling
Joint work with Deanna Needell (Claremont McKenna College)
and Nathan Srebro (TTIC / Technion)
Rachel Ward (UT Austin)
Recall: Gradient descent

Problem: Minimize $F(x) = \frac{1}{n}\sum_{i=1}^{n} f_i(x)$, $n$ very large.

Gradient descent: initialize $x^{(0)}$; for $j = 0, 1, 2, \ldots$
$$x^{(j+1)} = x^{(j)} - \gamma \nabla F(x^{(j)}) = x^{(j)} - \frac{\gamma}{n}\sum_{i=1}^{n} \nabla f_i(x^{(j)})$$

Not practical when $n$ is huge!
Stochastic gradient descent

Minimize $F(x) = \frac{1}{n}\sum_{i=1}^{n} f_i(x)$

Initialize $x^{(0)}$. For $j = 0, 1, 2, \ldots$: draw an index $i = i_j$ at random and set
$$x^{(j+1)} = x^{(j)} - \gamma \nabla f_{i_j}(x^{(j)})$$

Goal: nonasymptotic bounds on $\mathbb{E}\|x^{(j)} - x_*\|^2$
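The update above can be sketched in a few lines. The following is a minimal illustration on a hypothetical 1-D least-squares instance of my own (the data, step size, and step count are my choices, not the talk's):

```python
import random

# Minimal SGD sketch on a hypothetical 1-D least-squares problem:
# F(x) = (1/n) sum_i f_i(x) with f_i(x) = (1/2)(a_i x - b_i)^2.
a = [1.0, 2.0, 3.0, 4.0]
b = [2.0, 4.0, 6.0, 8.0]      # consistent system: x* = 2 solves every equation
n = len(a)

def sgd(steps, gamma, x0=0.0, seed=0):
    rng = random.Random(seed)
    x = x0
    for _ in range(steps):
        i = rng.randrange(n)                   # draw index i uniformly at random
        x -= gamma * a[i] * (a[i] * x - b[i])  # step along -grad f_i(x)
    return x

print(abs(sgd(steps=2000, gamma=0.05) - 2.0) < 1e-3)   # converges to x* = 2
```

Because this toy system is consistent ($\sigma^2 = 0$), a constant step size drives the error to zero, matching the consistent case discussed later in the talk.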
Minimize $F(x) = \frac{1}{n}\sum_{i=1}^{n} f_i(x) = \mathbb{E}\, f_i(x)$

We begin by assuming:
1. Each $\nabla f_i$ is Lipschitz: $\|\nabla f_i(x) - \nabla f_i(y)\| \le L_i \|x - y\|$
2. $F$ is strongly convex: $\langle x - y, \nabla F(x) - \nabla F(y)\rangle \ge \mu \|x - y\|^2$

The convergence rate of SGD should depend on:
1. A condition number,
$$\kappa_{\mathrm{av}} = \frac{\frac{1}{n}\sum_i L_i}{\mu}, \qquad \kappa_{\max} = \frac{\max_i L_i}{\mu}, \qquad \kappa_2 = \frac{\sqrt{\frac{1}{n}\sum_i L_i^2}}{\mu}$$
2. Consistency: $\sigma^2 = \mathbb{E}\|\nabla f_i(x_*)\|^2$
SGD - convergence rates

(Needell, Srebro, W, 2013): Under the smoothness and strong convexity assumptions above, the SGD iterates satisfy
$$\mathbb{E}\|x^{(k)} - x_*\|^2 \le \left[1 - 2\gamma\mu(1 - \gamma \sup_i L_i)\right]^k \|x^{(0)} - x_*\|^2 + \frac{\gamma \sigma^2}{\mu(1 - \gamma \sup_i L_i)}.$$

Corollary: $\mathbb{E}\|x^{(k)} - x_*\|^2 \le \varepsilon$ after
$$k = 2\log(\varepsilon_0/\varepsilon)\left(\frac{\sup_i L_i}{\mu} + \frac{\sigma^2}{\mu^2\varepsilon}\right)$$
SGD iterations, with optimized constant step-size.
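The recursion above can be sanity-checked by simulation. The sketch below uses a toy 1-D instance of my own (quadratic $f_i$, hand-picked step size and step count), not code from the talk:

```python
import random

# Monte-Carlo check of the error recursion on a toy 1-D problem
# f_i(x) = (L_i / 2)(x - c_i)^2, F = (1/n) sum_i f_i.
L = [1.0, 2.0, 5.0]
c = [0.0, 1.0, 2.0]
n = len(L)
x_star = sum(li * ci for li, ci in zip(L, c)) / sum(L)   # minimizer of F
mu = sum(L) / n               # strong convexity constant: F''(x) = (1/n) sum_i L_i
sup_L = max(L)
sigma2 = sum(li ** 2 * (x_star - ci) ** 2 for li, ci in zip(L, c)) / n

gamma, k, trials = 0.05, 50, 5000
rng = random.Random(0)
total = 0.0
for _ in range(trials):
    x = 0.0                                   # x^(0) = 0
    for _ in range(k):
        i = rng.randrange(n)
        x -= gamma * L[i] * (x - c[i])        # SGD step along -grad f_i
    total += (x - x_star) ** 2
empirical = total / trials

contraction = 1 - 2 * gamma * mu * (1 - gamma * sup_L)
bound = contraction ** k * x_star ** 2 + gamma * sigma2 / (mu * (1 - gamma * sup_L))
print(empirical <= bound)   # empirical error sits below the theoretical bound
```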
SGD - convergence rates

We showed:
$$k = 2\log(\varepsilon_0/\varepsilon)\left(\frac{\sup_i L_i}{\mu} + \frac{\sigma^2}{\mu^2\varepsilon}\right)$$
SGD iterations suffice for $\mathbb{E}\|x^{(k)} - x_*\|^2 \le \varepsilon$.

Compare to (Bach and Moulines, 2011):
$$k = 2\log(\varepsilon_0/\varepsilon)\left(\left(\frac{\sqrt{\frac{1}{n}\sum_i L_i^2}}{\mu}\right)^2 + \frac{\sigma^2}{\mu^2\varepsilon}\right)$$
SGD iterations suffice for $\mathbb{E}\|x^{(k)} - x_*\|^2 \le \varepsilon$.
$$\begin{aligned}
\|x^{(k+1)} - x_*\|^2 &= \|x^{(k)} - x_* - \gamma \nabla f_i(x^{(k)})\|^2 \\
&\le \|x^{(k)} - x_*\|^2 - 2\gamma \langle x^{(k)} - x_*, \nabla f_i(x^{(k)})\rangle + 2\gamma^2\|\nabla f_i(x^{(k)}) - \nabla f_i(x_*)\|^2 + 2\gamma^2\|\nabla f_i(x_*)\|^2 \\
&\le \|x^{(k)} - x_*\|^2 - 2\gamma \langle x^{(k)} - x_*, \nabla f_i(x^{(k)})\rangle + 2\gamma^2 L_i \langle x^{(k)} - x_*, \nabla f_i(x^{(k)}) - \nabla f_i(x_*)\rangle + 2\gamma^2\|\nabla f_i(x_*)\|^2
\end{aligned}$$

The difference in the proof is that we use the co-coercivity lemma for smooth functions with Lipschitz gradient: $\|\nabla f(x) - \nabla f(y)\|^2 \le L\langle x - y, \nabla f(x) - \nabla f(y)\rangle$.

This yields $k \propto \frac{\sup_i L_i}{\mu}$ vs. $k \propto \left(\frac{\sqrt{\frac{1}{n}\sum_i L_i^2}}{\mu}\right)^2$ steps.
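The co-coercivity inequality invoked here can be spot-checked numerically. The example function below (the 1-D logistic loss, which is convex and $1/4$-smooth) is my own illustration, not from the talk:

```python
import math

# Numerical sanity check of the co-coercivity lemma: for convex f with
# L-Lipschitz gradient, ||f'(x) - f'(y)||^2 <= L * <x - y, f'(x) - f'(y)>.
# Example: f(x) = log(1 + e^x), whose gradient is the sigmoid, L = 1/4.
def grad(x):
    return 1.0 / (1.0 + math.exp(-x))

L = 0.25
ok = True
for x, y in [(-3.0, 2.0), (0.5, 0.1), (-1.0, -0.9), (4.0, -4.0)]:
    diff = grad(x) - grad(y)
    ok = ok and diff ** 2 <= L * (x - y) * diff + 1e-12
print(ok)   # True: the inequality holds at every tested pair
```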
Consider the least squares case: with
$$A = \begin{pmatrix} -\; a_1 \;- \\ -\; a_2 \;- \\ -\; a_3 \;- \\ \vdots \\ -\; a_n \;- \end{pmatrix}, \qquad F(x) = \frac{1}{2}\sum_{i=1}^{n} (\langle a_i, x\rangle - b_i)^2 = \frac{1}{2}\|Ax - b\|^2$$

Assume consistency: $Ax_* = b$, so $\sigma^2 = 0$. Then
$$\frac{\sup_i L_i}{\mu} = \left(n \sup_i \|a_i\|^2\right)\|A^\dagger\|^2$$
These convergence rates are tight

Consider the system
$$\begin{pmatrix} 1 & 0 \\ 0 & 1/\sqrt{n} \\ 0 & 1/\sqrt{n} \\ \vdots & \vdots \\ 0 & 1/\sqrt{n} \end{pmatrix}\begin{pmatrix} 1 \\ 0 \end{pmatrix} = \begin{pmatrix} 1 \\ 0 \\ \vdots \\ 0 \end{pmatrix}$$

Here, $\frac{\sup_i L_i}{\mu} = n \sup_i \|a_i\|^2 \|A^\dagger\|^2 = n$.

In this example, we need $k = n$ steps to get any accuracy.
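The claimed ratio can be computed directly, since $A^\top A$ is diagonal for this matrix. The sketch below is my own arithmetic for the system with one row $(1,0)$ and $n-1$ rows $(0, 1/\sqrt{n})$, not code from the talk:

```python
# Condition ratios for the tight example: A^T A = diag(1, (n-1)/n),
# so the singular values can be read off directly.
def condition_ratios(n):
    row_norms_sq = [1.0] + [1.0 / n] * (n - 1)
    pinv_norm_sq = n / (n - 1)                 # ||A^dagger||^2 = 1 / sigma_min(A)^2
    L = [n * r for r in row_norms_sq]          # L_i = n ||a_i||^2
    sup_ratio = max(L) * pinv_norm_sq          # sup_i L_i / mu, approximately n
    avg_ratio = (sum(L) / n) * pinv_norm_sq    # (1/n) sum_i L_i / mu, approximately 2
    return sup_ratio, avg_ratio

sup_r, avg_r = condition_ratios(1000)
print(sup_r > 999, abs(avg_r - 2) < 0.01)   # ~n vs ~2
```

The second ratio, close to 2, is the average-conditioning quantity that the weighted sampling schemes below exploit.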
Better convergence rates using weighted sampling strategies?

SGD with weighted sampling

Observe:
$$F(x) = \frac{1}{n}\sum_i f_i(x) = \mathbb{E}^{(w)} \frac{1}{w_i} f_i(x)$$
given weights $w_i$ such that $\sum_i w_i = n$.

SGD unbiased update with weighted sampling: let $\tilde{f}_i = \frac{1}{w_i} f_i$ and
$$x^{(j+1)} = x^{(j)} - \gamma \nabla \tilde{f}_{i_j}(x^{(j)}) = x^{(j)} - \frac{\gamma}{w_{i_j}} \nabla f_{i_j}(x^{(j)}),$$
where $\mathbb{P}(i_j = i) = \frac{w_i}{\sum_j w_j}$.

Weighted sampling strategies in stochastic optimization are not new:
• (Strohmer, Vershynin 2009): randomized Kaczmarz algorithm
• (Nesterov 2010; Lee, Sidford 2013): biased sampling for accelerated stochastic coordinate descent
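Unbiasedness of the reweighted estimator is easy to confirm numerically: with $\mathbb{P}(i) = w_i/n$ and $\sum_i w_i = n$, the weighted expectation of $\frac{1}{w_i} g_i$ equals the plain average. The toy values below stand in for gradients and are my own illustration:

```python
# Check that the weighted estimator is unbiased:
# E^(w)[(1/w_i) g_i] = sum_i (w_i/n) * (g_i/w_i) = (1/n) sum_i g_i.
g = [3.0, -1.0, 4.0, 1.0, 5.0]     # stand-ins for gradients grad f_i(x)
n = len(g)
w = [0.5, 0.5, 1.5, 1.0, 1.5]      # weights with sum_i w_i = n
assert abs(sum(w) - n) < 1e-12

expectation = sum((w[i] / n) * (g[i] / w[i]) for i in range(n))
plain_average = sum(g) / n
print(abs(expectation - plain_average) < 1e-12)   # True: estimator is unbiased
```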
SGD with weighted sampling

Our previous result: for $F(x) = \mathbb{E} f_i(x)$ and $x^{(k+1)} = x^{(k)} - \gamma \nabla f_i(x^{(k)})$, $\mathbb{E}\|x^{(k)} - x_*\|^2 \le \varepsilon$ after
$$k = 2\log(\varepsilon_0/\varepsilon)\left(\frac{L_{\max}[f_i]}{\mu} + \frac{\sigma^2[f_i]}{\mu^2\varepsilon}\right) \text{ steps.}$$

Corollary for weighted sampling: for $F(x) = \mathbb{E}^{(w)} \frac{1}{w_i} f_i(x)$, $\mathbb{E}^{(w)}\|x^{(k)} - x_*\|^2 \le \varepsilon$ after
$$k = 2\log(\varepsilon_0/\varepsilon)\left(\frac{L_{\max}\left[\frac{1}{w_i} f_i\right]}{\mu} + \frac{\sigma^2\left[\frac{1}{w_i} f_i\right]}{\mu^2\varepsilon}\right) \text{ steps.}$$
Choice of weights

$\mathbb{E}^{(w)}\|x^{(k)} - x_*\|^2 \le \varepsilon$ after
$$k = 2\log(\varepsilon_0/\varepsilon)\left(\frac{L_{\max}\left[\frac{1}{w_i} f_i\right]}{\mu} + \frac{\sigma^2\left[\frac{1}{w_i} f_i\right]}{\mu^2\varepsilon}\right) \text{ steps.}$$

If $\sigma^2 = 0$, choose weights to minimize $L_{\max}\left[\frac{1}{w_i} f_i\right]$: using weights $w_i = \frac{n L_i}{\sum_i L_i}$, $\mathbb{E}^{(w)}\|x^{(k)} - x_*\|^2 \le \varepsilon$ after
$$k = 2\log(\varepsilon_0/\varepsilon)\,\frac{\frac{1}{n}\sum_i L_i}{\mu} \text{ steps.}$$
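The effect of the weights $w_i = nL_i/\sum_j L_j$ can be checked numerically: every reweighted Lipschitz constant $L_i/w_i$ collapses to the average of the $L_i$, so $L_{\max}$ drops from the supremum to the mean. The $L_i$ values below are my own toy illustration:

```python
# With w_i = n L_i / sum_j L_j, every L_i / w_i equals (1/n) sum_j L_j,
# so the max reweighted constant is the mean rather than the sup.
L = [100.0, 1.0, 1.0, 1.0, 1.0]
n = len(L)
w = [n * li / sum(L) for li in L]
reweighted = [li / wi for li, wi in zip(L, w)]   # L_i / w_i, all equal
print(max(reweighted), sum(L) / n, max(L))      # mean vs sup: 20.8 vs 100
```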
Improved convergence rate with weighted sampling

Recall the example:
$$\begin{pmatrix} 1 & 0 \\ 0 & 1/\sqrt{n} \\ 0 & 1/\sqrt{n} \\ \vdots & \vdots \\ 0 & 1/\sqrt{n} \end{pmatrix}\begin{pmatrix} 1 \\ 0 \end{pmatrix} = \begin{pmatrix} 1 \\ 0 \\ \vdots \\ 0 \end{pmatrix}$$

Since this system is consistent, $\sigma^2 = 0$, and
$$\frac{L_{\max}(A)}{\mu} = n \;\Rightarrow\; O(n) \text{ steps using uniform sampling,}$$
$$\frac{\frac{1}{n}\sum_i L_i(A)}{\mu} = 2 \;\Rightarrow\; O(1) \text{ steps using biased sampling.}$$
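The gap between uniform and biased sampling on this example can also be seen empirically. The sketch below uses my own choices of $n$, step sizes, and step count, not code from the talk:

```python
import random, math

# Uniform vs importance-sampled SGD on the example system:
# one row (1, 0) and n-1 rows (0, 1/sqrt(n)); consistent, x* = (1, 0).
n = 100
rows = [(1.0, 0.0)] + [(0.0, 1.0 / math.sqrt(n))] * (n - 1)
b = [1.0] + [0.0] * (n - 1)
L = [n * (a0 * a0 + a1 * a1) for a0, a1 in rows]    # L_i = n ||a_i||^2

def run_sgd(weights, gamma, steps, seed=1):
    rng = random.Random(seed)
    x = [0.0, 0.0]
    for _ in range(steps):
        i = rng.choices(range(n), weights=weights)[0]   # P(i) proportional to w_i
        a0, a1 = rows[i]
        resid = a0 * x[0] + a1 * x[1] - b[i]
        scale = (gamma / weights[i]) * n * resid        # (gamma / w_i) * grad f_i
        x = [x[0] - scale * a0, x[1] - scale * a1]
    return math.hypot(x[0] - 1.0, x[1])                 # distance to x* = (1, 0)

err_uniform = run_sgd([1.0] * n, gamma=1.0 / (2 * n), steps=2 * n)
err_biased = run_sgd([n * li / sum(L) for li in L], gamma=0.25, steps=2 * n)
print(err_biased < err_uniform)   # biased sampling hits the hard first row far more often
```

Under uniform sampling the informative first row is drawn with probability $1/n$, so $2n$ steps barely reduce the error; under the biased weights it is drawn about half the time.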
Improved convergence rate with weighted sampling

Recall the example:
$$\begin{pmatrix} 1 & 0 \\ 0 & 1/\sqrt{n} \\ 0 & 1/\sqrt{n} \\ \vdots & \vdots \\ 0 & 1/\sqrt{n} \end{pmatrix}\begin{pmatrix} 1 \\ 0 \end{pmatrix} = \begin{pmatrix} 1 \\ 0 \\ \vdots \\ 0 \end{pmatrix}$$

But what if the system is not consistent?
$\mathbb{E}^{(w)}\|x^{(k)} - x_*\|^2 \le \varepsilon$ after
$$k = 2\log(\varepsilon_0/\varepsilon)\left(\frac{L_{\max}\left[\frac{1}{w_i} f_i\right]}{\mu} + \frac{\sigma^2\left[\frac{1}{w_i} f_i\right]}{\mu^2\varepsilon}\right) \text{ steps.}$$

Choosing weights $w_i = \frac{n L_i}{\sum_i L_i}$ gives $L_{\max}\left[\frac{1}{w_i} f_i\right] = \frac{1}{n}\sum_i L_i$.

Choosing weights $w_i = 1$ gives $\sigma^2\left[\frac{1}{w_i} f_i\right] = \sigma^2[f_i]$.

Partially-biased sampling: choosing weights $w_i = \frac{1}{2} + \frac{1}{2}\cdot\frac{n L_i}{\sum_i L_i}$ gives
$$L_{\max}\left[\tfrac{1}{w_i} f_i\right] \le 2\cdot\frac{1}{n}\sum_i L_i \quad \text{and} \quad \sigma^2\left[\tfrac{1}{w_i} f_i\right] \le 2\sigma^2[f_i].$$

Partially biased sampling gives a strictly better convergence rate, up to a factor of 2.
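The factor-of-2 claims for the partially biased weights can be verified numerically. The $L_i$ values below are my own toy choice:

```python
# Partially biased weights w_i = 1/2 + (1/2) n L_i / sum_j L_j keep both
# quantities within a factor of 2: L_i / w_i <= 2 * mean(L), and 1/w_i <= 2
# (so the residual variance is inflated by at most 2).
L = [50.0, 10.0, 2.0, 2.0, 1.0]
n = len(L)
mean_L = sum(L) / n
w = [0.5 + 0.5 * n * li / sum(L) for li in L]

print(abs(sum(w) - n) < 1e-12)                              # weights still sum to n
print(max(li / wi for li, wi in zip(L, w)) <= 2 * mean_L)   # L_max bound
print(max(1.0 / wi for wi in w) <= 2.0)                     # variance bound
```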
SGD - convergence rates

Uniform sampling:

(Bach and Moulines, 2011):
$$k = 2\log(\varepsilon_0/\varepsilon)\left(\left(\frac{\sqrt{\frac{1}{n}\sum_i L_i^2}}{\mu}\right)^2 + \frac{\sigma^2}{\mu^2\varepsilon}\right)$$
SGD iterations suffice for $\mathbb{E}\|x^{(k)} - x_*\|^2 \le \varepsilon$.

(Needell, Srebro, W, 2013):
$$k = 2\log(\varepsilon_0/\varepsilon)\left(\frac{\sup_i L_i}{\mu} + \frac{\sigma^2}{\mu^2\varepsilon}\right)$$
SGD iterations suffice for $\mathbb{E}\|x^{(k)} - x_*\|^2 \le \varepsilon$.

(Partially) biased sampling:
$$k = 4\log(\varepsilon_0/\varepsilon)\left(\frac{\frac{1}{n}\sum_i L_i}{\mu} + \frac{\sigma^2}{\mu^2\varepsilon}\right)$$
SGD iterations suffice for $\mathbb{E}\|x^{(k)} - x_*\|^2 \le \varepsilon$.
We have been operating in the setting:
1. Each $\nabla f_i$ has Lipschitz constant $L_i$: $\|\nabla f_i(x) - \nabla f_i(y)\| \le L_i\|x - y\|$
2. $F$ has strong convexity parameter $\mu$: $\langle x - y, \nabla F(x) - \nabla F(y)\rangle \ge \mu\|x - y\|^2$

Other settings and weaker assumptions:
• Removing the strong convexity assumption
• Non-smooth $f_i$
Smoothness, but no strong convexity:
• Each $\nabla f_i$ has Lipschitz constant $L_i$: $\|\nabla f_i(x) - \nabla f_i(y)\| \le L_i\|x - y\|$
• (Srebro et al. 2010): Number of iterations of SGD:
$$k = O\left(\frac{(\sup_i L_i)\|x_*\|^2}{\varepsilon} \cdot \frac{F(x_*) + \varepsilon}{\varepsilon}\right)$$
• (Foygel and Srebro 2011): Cannot replace $\sup_i L_i$ with $\frac{1}{n}\sum_i L_i$
• Using weights $w_i = \frac{L_i}{\frac{1}{n}\sum_i L_i}$,
$$k = O\left(\frac{\left(\frac{1}{n}\sum_i L_i\right)\|x_*\|^2}{\varepsilon} \cdot \frac{F(x_*) + \varepsilon}{\varepsilon}\right)$$
• (Srebro et al. 2010): The number of iterations of SGD still depends linearly on $\sup_i L_i$.
• (Foygel and Srebro 2011): The dependence cannot be replaced with $\frac{1}{n}\sum_i L_i$ (using uniform sampling).
• Using weights $w_i = \frac{L_i}{\frac{1}{n}\sum_i L_i}$, the dependence is replaced by $\frac{1}{n}\sum_i L_i$.
Even less restrictive, we assume now only that:
1. Each $f_i$ has Lipschitz constant $G_i$: $|f_i(x) - f_i(y)| \le G_i\|x - y\|$

Here, the SGD convergence rate depends linearly on $\frac{1}{n}\sum_i G_i^2$. Using weights $w_i = \frac{G_i}{\frac{1}{n}\sum_i G_i}$, the dependence is reduced to
$$\left(\frac{1}{n}\sum_i G_i\right)^2 \le \frac{1}{n}\sum_i G_i^2,$$
since
$$\frac{1}{n}\sum_i G_i^2 = \left(\frac{1}{n}\sum_i G_i\right)^2 + \mathrm{Var}[G_i].$$

(Zhao, Zhang 2013) also consider importance sampling in this setting.
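The variance identity behind this gain is easy to confirm numerically. The $G_i$ values below are my own toy choice:

```python
# Check the identity (1/n) sum_i G_i^2 = ((1/n) sum_i G_i)^2 + Var[G_i],
# where Var is the population variance over the n values.
G = [10.0, 2.0, 2.0, 1.0, 5.0]
n = len(G)
mean = sum(G) / n
mean_sq = mean ** 2
second_moment = sum(g * g for g in G) / n
variance = sum((g - mean) ** 2 for g in G) / n   # population variance

print(abs(second_moment - (mean_sq + variance)) < 1e-9)   # True: identity holds
```

So importance sampling removes exactly the variance of the Lipschitz constants $G_i$ from the rate.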
Future work:
• Strategies for sampling blocks of indices at each iteration. Optimal block size?
• Optimal adaptive sampling strategies given limited or no information about $L_i$?
• Simulations on real data

Thanks!
References
[1] Needell, Srebro, Ward. “Stochastic gradient descent and the randomized
Kaczmarz algorithm.” arXiv preprint arXiv:1310.5715 (2013).