Statistical learning and optimal control: A framework for biological learning and motor control
Lecture 1: Iterative learning and the Kalman filter
Reza Shadmehr
Johns Hopkins School of Medicine
[Figure: block diagram of the framework. Body + environment (state change); sensory system (proprioception, vision, audition) producing measured sensory consequences; forward model producing predicted sensory consequences; integration of measured and predicted consequences into a belief about the state of body and world; goal selector; motor command generator. Labels: stochastic optimal control, parameter estimation, Kalman filter.]
Results from classical conditioning
Effect of time on memory: spontaneous recovery
Effect of time on memory: inter-trial interval and retention
[Figure: results for inter-trial intervals ITI = 2, 14, and 98. Panels: performance during training; test at 1 week; testing at 1 day or 1 week (averaged together).]
Integration of predicted state with sensory feedback
Choice of motor commands: optimality in saccades and reaching movements
[Figure: eye velocity (deg/sec) vs. time (sec) for saccade sizes of 5, 10, 15, 30, 40, and 50.]
Linear regression, maximum likelihood, and parameter uncertainty
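$$y^{*(i)} = \mathbf{w}^{*T}\mathbf{x}^{(i)}, \qquad y^{(i)} = y^{*(i)} + \varepsilon^{(i)}, \qquad \varepsilon \sim N\left(0, \sigma^2\right)$$
A noisy process produces n data points, and we form an ML estimate of w:
$$D^{(1)} = \left\{\left(\mathbf{x}^{(1)}, y^{(1,1)}\right), \left(\mathbf{x}^{(2)}, y^{(1,2)}\right), \ldots, \left(\mathbf{x}^{(n)}, y^{(1,n)}\right)\right\}, \qquad \mathbf{w}_{ML}^{(1)} = \left(X^TX\right)^{-1}X^T\mathbf{y}^{(1)}$$
We run the noisy process again with the same sequence of x's and re-estimate w:
$$D^{(2)} = \left\{\left(\mathbf{x}^{(1)}, y^{(2,1)}\right), \left(\mathbf{x}^{(2)}, y^{(2,2)}\right), \ldots, \left(\mathbf{x}^{(n)}, y^{(2,n)}\right)\right\}, \qquad \mathbf{w}_{ML}^{(2)} = \left(X^TX\right)^{-1}X^T\mathbf{y}^{(2)}$$
The distribution of the resulting w has a var-cov matrix that depends only on the sequence of inputs, the bases that encode those inputs, and the noise σ:
$$\mathbf{w}_{ML} \sim N\left(\mathbf{w}^*, \left(X^TX\right)^{-1}\sigma^2\right)$$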
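The repeated-experiment argument can be simulated directly. The sketch below assumes a hypothetical two-column design X, true weights w_star, and noise level sigma (none of these values come from the lecture); it re-runs the noisy process many times with the same x's and compares the empirical mean and var-cov of w_ML with w* and σ²(X^TX)^{-1}.

```python
import random

def gram_inverse(X):
    """(X^T X)^{-1} for a design matrix X with two columns."""
    a = sum(r[0] * r[0] for r in X)
    b = sum(r[0] * r[1] for r in X)
    d = sum(r[1] * r[1] for r in X)
    det = a * d - b * b
    return [[d / det, -b / det], [-b / det, a / det]]

def ml_estimate(X, y):
    """ML (least-squares) estimate w_ML = (X^T X)^{-1} X^T y."""
    inv = gram_inverse(X)
    xty = [sum(r[j] * yj for r, yj in zip(X, y)) for j in range(2)]
    return [inv[i][0] * xty[0] + inv[i][1] * xty[1] for i in range(2)]

# Hypothetical design, true weights, and noise level (not from the lecture).
X = [[1.0, 0.0], [1.0, 1.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
w_star = [1.0, -0.5]
sigma = 0.5

rng = random.Random(0)
runs = []
for _ in range(20000):  # re-run the noisy process with the same sequence of x's
    y = [r[0] * w_star[0] + r[1] * w_star[1] + rng.gauss(0.0, sigma) for r in X]
    runs.append(ml_estimate(X, y))

m = len(runs)
mean_w = [sum(w[j] for w in runs) / m for j in range(2)]
emp_cov = [[sum((w[i] - mean_w[i]) * (w[j] - mean_w[j]) for w in runs) / (m - 1)
            for j in range(2)] for i in range(2)]
inv = gram_inverse(X)
theory_cov = [[sigma ** 2 * inv[i][j] for j in range(2)] for i in range(2)]
print(mean_w)      # close to w_star
print(emp_cov)     # close to theory_cov
print(theory_cov)  # sigma^2 (X^T X)^{-1}
```

Only X and sigma enter theory_cov; the observed y values affect the estimates but not their spread.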
Bias of the parameter estimates for a given X
• How does the ML estimate behave in the presence of noise in y?
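The “true” underlying process: $\mathbf{y}^* = X\mathbf{w}^*$
What we measured: $\mathbf{y} = \mathbf{y}^* + \boldsymbol{\varepsilon}$, where $\boldsymbol{\varepsilon}$ is an n×1 vector with $\boldsymbol{\varepsilon} \sim N\left(\mathbf{0}, \sigma^2 I\right)$
Our model of the process: $\mathbf{y} = X\mathbf{w} + \boldsymbol{\varepsilon}$
ML estimate:
$$\mathbf{w}_{ML} = \left(X^TX\right)^{-1}X^T\mathbf{y} = \left(X^TX\right)^{-1}X^T\left(X\mathbf{w}^* + \boldsymbol{\varepsilon}\right) = \left(X^TX\right)^{-1}X^TX\mathbf{w}^* + \left(X^TX\right)^{-1}X^T\boldsymbol{\varepsilon} = \mathbf{w}^* + \left(X^TX\right)^{-1}X^T\boldsymbol{\varepsilon}$$
Because $\boldsymbol{\varepsilon}$ is normally distributed (and $\mathbf{w}_{ML}$ is a linear function of it):
$$\mathbf{w}_{ML} \sim N\left(\mathbf{w}^*, \mathrm{var}\left[\left(X^TX\right)^{-1}X^T\boldsymbol{\varepsilon}\right]\right)$$
In other words, since $E[\boldsymbol{\varepsilon}] = \mathbf{0}$:
$$E\left[\mathbf{w}_{ML}\right] = \mathbf{w}^*$$
so, for a given X, the ML estimate is unbiased.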
Variance of the parameter estimates for a given X
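For a given X, the ML (or least-squares) estimate of our parameter has this normal distribution:
$$\mathbf{w}_{ML} \sim N\left(\mathbf{w}^*, \mathrm{var}\left[\left(X^TX\right)^{-1}X^T\boldsymbol{\varepsilon}\right]\right)$$
Assume: $\mathrm{var}[\boldsymbol{\varepsilon}] = E\left[\boldsymbol{\varepsilon}\boldsymbol{\varepsilon}^T\right] = \sigma^2 I$. For a matrix of constants A and a vector of random variables x, $\mathrm{var}[A\mathbf{x}] = A\,\mathrm{var}[\mathbf{x}]\,A^T$. Therefore, using the symmetry of $\left(X^TX\right)^{-1}$:
$$\mathrm{var}\left[\left(X^TX\right)^{-1}X^T\boldsymbol{\varepsilon}\right] = \left(X^TX\right)^{-1}X^T\,\mathrm{var}[\boldsymbol{\varepsilon}]\,X\left(X^TX\right)^{-1} = \sigma^2\left(X^TX\right)^{-1}X^TX\left(X^TX\right)^{-1} = \sigma^2\left(X^TX\right)^{-1}$$
$$\mathbf{w}_{ML} \sim N\left(\mathbf{w}^*, \sigma^2\left(X^TX\right)^{-1}\right)$$
where the var-cov matrix $\sigma^2\left(X^TX\right)^{-1}$ is m×m.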
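The Gaussian distribution and its var-cov matrix
A 1-D Gaussian distribution is defined as
$$p(x) = \frac{1}{\sqrt{2\pi\sigma^2}}\exp\left(-\frac{(x-\mu)^2}{2\sigma^2}\right)$$
In n dimensions, it generalizes to
$$p(\mathbf{x}) = \frac{1}{\sqrt{(2\pi)^n|C|}}\exp\left(-\frac{1}{2}(\mathbf{x}-\boldsymbol{\mu})^T C^{-1}(\mathbf{x}-\boldsymbol{\mu})\right)$$
When x is a vector, the variance is expressed in terms of a covariance matrix C, where $\rho_{ij}$ corresponds to the degree of correlation between variables $x_i$ and $x_j$:
$$c_{ij} = E\left[(x_i - \mu_i)(x_j - \mu_j)\right], \qquad C = \begin{bmatrix} \sigma_1^2 & \rho_{12}\sigma_1\sigma_2 & \cdots & \rho_{1n}\sigma_1\sigma_n \\ \rho_{12}\sigma_1\sigma_2 & \sigma_2^2 & \cdots & \rho_{2n}\sigma_2\sigma_n \\ \vdots & \vdots & \ddots & \vdots \\ \rho_{1n}\sigma_1\sigma_n & \rho_{2n}\sigma_2\sigma_n & \cdots & \sigma_n^2 \end{bmatrix}$$
$$\rho_{xy} = \frac{C_{xy}}{\sqrt{C_{xx}C_{yy}}} = \frac{\sum_i (x_i - \mu_x)(y_i - \mu_y)}{\sqrt{\sum_i (x_i - \mu_x)^2 \sum_i (y_i - \mu_y)^2}}, \qquad C = \begin{bmatrix} C_{xx} & C_{xy} \\ C_{yx} & C_{yy} \end{bmatrix}$$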
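A small numerical check of these formulas, written as a sketch with hypothetical helper names (`gauss_1d`, `cov_2d`, `gauss_2d`): it evaluates the n-D density for n = 2, confirms that with ρ12 = 0 it factors into the product of two 1-D densities, and recovers ρ from C as C_xy/sqrt(C_xx C_yy).

```python
import math

def gauss_1d(x, mu, var):
    """1-D Gaussian density with mean mu and variance var."""
    return math.exp(-(x - mu) ** 2 / (2.0 * var)) / math.sqrt(2.0 * math.pi * var)

def cov_2d(s1, s2, rho):
    """2x2 covariance matrix from standard deviations s1, s2 and correlation rho_12."""
    return [[s1 * s1, rho * s1 * s2], [rho * s1 * s2, s2 * s2]]

def gauss_2d(x, mu, C):
    """n-D Gaussian density for n = 2: exp(-0.5 d^T C^{-1} d) / sqrt((2 pi)^2 |C|)."""
    det = C[0][0] * C[1][1] - C[0][1] * C[1][0]
    inv = [[C[1][1] / det, -C[0][1] / det], [-C[1][0] / det, C[0][0] / det]]
    d = [x[0] - mu[0], x[1] - mu[1]]
    quad = sum(d[i] * inv[i][j] * d[j] for i in range(2) for j in range(2))
    return math.exp(-0.5 * quad) / math.sqrt((2.0 * math.pi) ** 2 * det)

# With rho_12 = 0, C is diagonal and the joint density is a product of 1-D densities.
C0 = cov_2d(1.0, math.sqrt(2.0), 0.0)
p_joint = gauss_2d([0.3, -1.2], [0.0, 0.0], C0)
p_product = gauss_1d(0.3, 0.0, 1.0) * gauss_1d(-1.2, 0.0, 2.0)

# The correlation is recovered from the var-cov matrix.
C9 = cov_2d(1.0, math.sqrt(2.0), 0.9)
rho = C9[0][1] / math.sqrt(C9[0][0] * C9[1][1])
print(p_joint, p_product, rho)
```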
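$$\mathbf{x} \sim N\left(\boldsymbol{\mu}, C\right), \qquad C = \begin{bmatrix} \sigma_1^2 & \rho_{12}\sigma_1\sigma_2 \\ \rho_{12}\sigma_1\sigma_2 & \sigma_2^2 \end{bmatrix}$$
[Figure: samples of $\mathbf{x} = [x_1, x_2]^T$ in three panels, axes $x_1$ and $x_2$. $\mathbf{x} \sim N\left(\begin{bmatrix}0\\0\end{bmatrix}, \begin{bmatrix}1 & 0.9\sqrt{2}\\ 0.9\sqrt{2} & 2\end{bmatrix}\right)$: x1 and x2 are positively correlated. $\mathbf{x} \sim N\left(\begin{bmatrix}0\\0\end{bmatrix}, \begin{bmatrix}1 & 0.1\sqrt{2}\\ 0.1\sqrt{2} & 2\end{bmatrix}\right)$: x1 and x2 are nearly uncorrelated. $\mathbf{x} \sim N\left(\begin{bmatrix}0\\0\end{bmatrix}, \begin{bmatrix}1 & -0.9\sqrt{2}\\ -0.9\sqrt{2} & 2\end{bmatrix}\right)$: x1 and x2 are negatively correlated.]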
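The three panels can be regenerated by sampling. This sketch assumes the panels use σ1² = 1, σ2² = 2, and ρ12 = 0.9, 0.1, and -0.9, as the matrices above indicate. It builds each sample from the 2x2 Cholesky factor of C and checks that the sample correlation matches ρ12. The helper names are hypothetical.

```python
import math
import random

def sample_bivariate(n, s1, s2, rho, rng):
    """Draw n samples of [x1, x2] ~ N(0, C), C = [[s1^2, rho s1 s2], [rho s1 s2, s2^2]].

    Uses the Cholesky factor of the 2x2 C:
    x1 = s1 z1,  x2 = s2 (rho z1 + sqrt(1 - rho^2) z2),  with z1, z2 ~ N(0, 1).
    """
    pts = []
    for _ in range(n):
        z1, z2 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
        pts.append((s1 * z1, s2 * (rho * z1 + math.sqrt(1.0 - rho * rho) * z2)))
    return pts

def sample_corr(pts):
    """Sample correlation coefficient rho_xy of a list of (x, y) points."""
    n = len(pts)
    mx = sum(p[0] for p in pts) / n
    my = sum(p[1] for p in pts) / n
    cxy = sum((p[0] - mx) * (p[1] - my) for p in pts)
    cxx = sum((p[0] - mx) ** 2 for p in pts)
    cyy = sum((p[1] - my) ** 2 for p in pts)
    return cxy / math.sqrt(cxx * cyy)

rng = random.Random(1)
estimated = {}
for rho in (0.9, 0.1, -0.9):  # one entry per panel; sigma1^2 = 1, sigma2^2 = 2
    pts = sample_bivariate(20000, 1.0, math.sqrt(2.0), rho, rng)
    estimated[rho] = sample_corr(pts)
print(estimated)
```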
Parameter uncertainty: Example 1
• Input history:
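x1   x2   y*
1    0    0.5
1    0    0.5
1    0    0.5
1    0    0.5
0    1    0.5
$$X = \begin{bmatrix}1&0\\1&0\\1&0\\1&0\\0&1\end{bmatrix}, \qquad y = w_1x_1 + w_2x_2$$
$$\mathbf{w} \sim N\left(\mathbf{w}_{ML}, \sigma^2\left(X^TX\right)^{-1}\right) = N\left(E[\mathbf{w}], \begin{bmatrix}\mathrm{var}[w_1] & \mathrm{cov}[w_1,w_2]\\ \mathrm{cov}[w_2,w_1] & \mathrm{var}[w_2]\end{bmatrix}\right) = N\left(\begin{bmatrix}0.5\\0.5\end{bmatrix}, \begin{bmatrix}0.25 & 0\\ 0 & 1\end{bmatrix}\right) \quad \left(\sigma^2 = 1\right)$$
x1 was “on” most of the time. I’m pretty certain about w1. However, x2 was “on” only once, so I’m uncertain about w2.
[Figure: the distribution of the weights in the (w1, w2) plane.]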
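The numbers in Example 1 can be reproduced in a few lines. This is a sketch assuming σ² = 1, the value for which the matrix above equals (X^TX)^{-1}; `posterior_2d` is a hypothetical helper, not part of the lecture.

```python
def posterior_2d(X, y, sigma2=1.0):
    """w_ML = (X^T X)^{-1} X^T y and its var-cov sigma^2 (X^T X)^{-1}, for two weights."""
    a = sum(r[0] * r[0] for r in X)
    b = sum(r[0] * r[1] for r in X)
    d = sum(r[1] * r[1] for r in X)
    det = a * d - b * b
    inv = [[d / det, -b / det], [-b / det, a / det]]
    xty = [sum(r[j] * yj for r, yj in zip(X, y)) for j in range(2)]
    w_ml = [inv[i][0] * xty[0] + inv[i][1] * xty[1] for i in range(2)]
    cov = [[sigma2 * inv[i][j] for j in range(2)] for i in range(2)]
    return w_ml, cov

# Example 1: x1 on in four trials, x2 on in one trial; y* = 0.5 on every trial.
X = [[1, 0], [1, 0], [1, 0], [1, 0], [0, 1]]
y = [0.5, 0.5, 0.5, 0.5, 0.5]
w_ml, cov = posterior_2d(X, y)
print(w_ml)  # [0.5, 0.5]
print(cov)   # [[0.25, 0.0], [0.0, 1.0]]
```

The diagonal of cov is 1/(number of trials in which each input was on), which is why w2 is four times as uncertain as w1.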
Parameter uncertainty: Example 2
• Input history:
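x1   x2   y*
1    1    1
1    1    1
1    1    1
1    1    1
1    0    0.5
$$X = \begin{bmatrix}1&1\\1&1\\1&1\\1&1\\1&0\end{bmatrix}$$
$$\mathbf{w} \sim N\left(\mathbf{w}_{ML}, \sigma^2\left(X^TX\right)^{-1}\right) = N\left(E[\mathbf{w}], \begin{bmatrix}\mathrm{var}[w_1] & \mathrm{cov}[w_1,w_2]\\ \mathrm{cov}[w_2,w_1] & \mathrm{var}[w_2]\end{bmatrix}\right) = N\left(\begin{bmatrix}0.5\\0.5\end{bmatrix}, \begin{bmatrix}1 & -1\\ -1 & 1.25\end{bmatrix}\right) \quad \left(\sigma^2 = 1\right)$$
x1 and x2 were “on” mostly together. The weight var-cov matrix shows that what I learned is that $w_1 + w_2 = 1$, but I do not know the individual values of w1 and w2 with much certainty.
x1 appeared slightly more often than x2, so I’m a little more certain about the value of w1.
[Figure: the distribution of the weights in the (w1, w2) plane, with the line w1 + w2 = 1.]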
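The same computation for Example 2 is shown below. It is a sketch that again assumes σ² = 1 and repeats the hypothetical `posterior_2d` helper so that the block runs on its own. It also computes var[w1 + w2] = var[w1] + var[w2] + 2 cov[w1, w2], which makes the statement "I learned that w1 + w2 = 1" quantitative.

```python
def posterior_2d(X, y, sigma2=1.0):
    """w_ML = (X^T X)^{-1} X^T y and its var-cov sigma^2 (X^T X)^{-1}, for two weights."""
    a = sum(r[0] * r[0] for r in X)
    b = sum(r[0] * r[1] for r in X)
    d = sum(r[1] * r[1] for r in X)
    det = a * d - b * b
    inv = [[d / det, -b / det], [-b / det, a / det]]
    xty = [sum(r[j] * yj for r, yj in zip(X, y)) for j in range(2)]
    w_ml = [inv[i][0] * xty[0] + inv[i][1] * xty[1] for i in range(2)]
    cov = [[sigma2 * inv[i][j] for j in range(2)] for i in range(2)]
    return w_ml, cov

# Example 2: x1 and x2 on together in four trials, x1 alone in one trial.
X = [[1, 1], [1, 1], [1, 1], [1, 1], [1, 0]]
y = [1.0, 1.0, 1.0, 1.0, 0.5]
w_ml, cov = posterior_2d(X, y)
# Variance of the combination w1 + w2: var[w1] + var[w2] + 2 cov[w1, w2]
var_sum = cov[0][0] + cov[1][1] + 2 * cov[0][1]
print(w_ml)     # [0.5, 0.5]
print(cov)      # [[1.0, -1.0], [-1.0, 1.25]]
print(var_sum)  # 0.25
```

The sum w1 + w2 has variance 0.25, much smaller than the variance of either weight alone (1 and 1.25). The strong negative covariance is what encodes this.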
Parameter uncertainty: Example 3
• Input history:
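x1   x2   y*
0    1    0.5
0    1    0.5
0    1    0.5
0    1    0.5
1    1    1
$$X = \begin{bmatrix}0&1\\0&1\\0&1\\0&1\\1&1\end{bmatrix}$$
$$\mathbf{w} \sim N\left(\mathbf{w}_{ML}, \sigma^2\left(X^TX\right)^{-1}\right) = N\left(E[\mathbf{w}], \begin{bmatrix}\mathrm{var}[w_1] & \mathrm{cov}[w_1,w_2]\\ \mathrm{cov}[w_2,w_1] & \mathrm{var}[w_2]\end{bmatrix}\right) = N\left(\begin{bmatrix}0.5\\0.5\end{bmatrix}, \begin{bmatrix}1.25 & -0.25\\ -0.25 & 0.25\end{bmatrix}\right) \quad \left(\sigma^2 = 1\right)$$
x2 was mostly “on”. I’m pretty certain about w2, but I am very uncertain about w1. Occasionally x1 and x2 were on together, so I have some reason to believe that $w_1 + w_2 = 1$.
[Figure: the distribution of the weights in the (w1, w2) plane, with the line w1 + w2 = 1.]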
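For Example 3, the sketch below assumes σ² = 1 and repeats the hypothetical `posterior_2d` helper. It converts the var-cov matrix into the correlation ρ12 = cov[w1, w2]/sqrt(var[w1] var[w2]). That correlation measures how strongly the data tie w1 to w2.

```python
import math

def posterior_2d(X, y, sigma2=1.0):
    """w_ML = (X^T X)^{-1} X^T y and its var-cov sigma^2 (X^T X)^{-1}, for two weights."""
    a = sum(r[0] * r[0] for r in X)
    b = sum(r[0] * r[1] for r in X)
    d = sum(r[1] * r[1] for r in X)
    det = a * d - b * b
    inv = [[d / det, -b / det], [-b / det, a / det]]
    xty = [sum(r[j] * yj for r, yj in zip(X, y)) for j in range(2)]
    w_ml = [inv[i][0] * xty[0] + inv[i][1] * xty[1] for i in range(2)]
    cov = [[sigma2 * inv[i][j] for j in range(2)] for i in range(2)]
    return w_ml, cov

# Example 3: x2 alone in four trials, x1 and x2 together in one trial.
X = [[0, 1], [0, 1], [0, 1], [0, 1], [1, 1]]
y = [0.5, 0.5, 0.5, 0.5, 1.0]
w_ml, cov = posterior_2d(X, y)
rho12 = cov[0][1] / math.sqrt(cov[0][0] * cov[1][1])
print(w_ml)   # [0.5, 0.5]
print(cov)    # [[1.25, -0.25], [-0.25, 0.25]]
print(rho12)  # about -0.447, i.e. -1/sqrt(5)
```

The weights are still negatively correlated, but less strongly than in Example 2, where ρ12 = -1/sqrt(1.25), or about -0.894. A single joint trial gives only "some reason" to believe w1 + w2 = 1.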
Effect of uncertainty on learning rate
• When you observe an error in trial n, the amount that you should change w should depend on how certain you are about w. The more certain you are, the less you should be influenced by the error. The less certain you are, the more you should “pay attention” to the error.
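$$\mathbf{w}^{(n+1)} = \mathbf{w}^{(n)} + \mathbf{k}^{(n)}\left(y^{(n)} - \mathbf{x}^{(n)T}\mathbf{w}^{(n)}\right)$$
Here $\mathbf{w}$ and the Kalman gain $\mathbf{k}^{(n)}$ are m×1 vectors, and $y^{(n)} - \mathbf{x}^{(n)T}\mathbf{w}^{(n)}$ is the error.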
Rudolph E. Kalman (1960) A new approach to linear filtering and prediction problems. Transactions of the ASME–Journal of Basic Engineering, 82 (Series D): 35-45.
Research Institute for Advanced Study, 7212 Bellona Ave, Baltimore, MD
Example of the Kalman gain: running estimate of average
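$$y^{(i)} = w^*x^{(i)} + \varepsilon^{(i)}, \qquad x^{(i)} = 1, \qquad \varepsilon \sim N\left(0, \sigma^2\right)$$
$$w^{(n)} = \left(X^TX\right)^{-1}X^T\mathbf{y} = \frac{1}{n}\sum_{i=1}^{n} y^{(i)}, \qquad w^{(n-1)} = \frac{1}{n-1}\sum_{i=1}^{n-1} y^{(i)}$$
$$w^{(n)} = \frac{1}{n}\left(\sum_{i=1}^{n-1} y^{(i)} + y^{(n)}\right) = \frac{1}{n}\left((n-1)\,w^{(n-1)} + y^{(n)}\right) = w^{(n-1)} - \frac{1}{n}w^{(n-1)} + \frac{1}{n}y^{(n)}$$
$$w^{(n)} = \underbrace{w^{(n-1)}}_{\text{past estimate}} + \underbrace{\frac{1}{n}}_{\text{Kalman gain}}\left(\underbrace{y^{(n)}}_{\text{new measure}} - w^{(n-1)}\right)$$
Kalman gain: the learning rate decreases as the number of samples increases. As n increases, we trust our past estimate w(n-1) a lot more than the new observation y(n).
w(n) is the online estimate of the mean of y.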
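The recursion can be run directly to see that the gain-1/n update reproduces the batch mean after every sample. This is a minimal sketch; `running_mean` and the data list are hypothetical.

```python
def running_mean(ys):
    """Online estimate of the mean: w(n) = w(n-1) + (1/n) (y(n) - w(n-1))."""
    w = 0.0  # the initial value does not matter: the first gain is 1
    history = []
    for n, y in enumerate(ys, start=1):
        k = 1.0 / n          # Kalman gain for this model
        w = w + k * (y - w)  # past estimate + gain * (new measure - past estimate)
        history.append(w)
    return history

ys = [2.0, 4.0, 3.0, 7.0, 4.0]
est = running_mean(ys)
print(est)  # [2.0, 3.0, 3.0, 4.0, 4.0]; the last value equals the batch mean
```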
Example of the Kalman gain: running estimate of variance
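$$y^{(i)} = w^* + \varepsilon^{(i)}, \qquad \varepsilon \sim N\left(0, \sigma^2\right)$$
$$\sigma^2 = E\left[\left(y - E[y]\right)^2\right] \approx \hat{\sigma}^2_{(n)} = \frac{1}{n}\sum_{i=1}^{n}\left(y^{(i)} - w^{(n)}\right)^2 = \frac{1}{n}\sum_{i=1}^{n-1}\left(y^{(i)} - w^{(n)}\right)^2 + \frac{1}{n}\left(y^{(n)} - w^{(n)}\right)^2$$
Because $w^{(n)} - w^{(n-1)} = \frac{1}{n}\left(y^{(n)} - w^{(n-1)}\right)$ becomes small as n grows, we can replace $w^{(n)}$ with $w^{(n-1)}$:
$$\hat{\sigma}^2_{(n)} \approx \frac{n-1}{n}\hat{\sigma}^2_{(n-1)} + \frac{1}{n}\left(y^{(n)} - w^{(n-1)}\right)^2 = \left(1 - \frac{1}{n}\right)\hat{\sigma}^2_{(n-1)} + \frac{1}{n}\left(y^{(n)} - w^{(n-1)}\right)^2$$
$$\hat{\sigma}^2_{(n)} = \hat{\sigma}^2_{(n-1)} + \frac{1}{n}\left[\left(y^{(n)} - w^{(n-1)}\right)^2 - \hat{\sigma}^2_{(n-1)}\right]$$
$\hat{\sigma}^2_{(n)}$ is the online estimate of the variance of y.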
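Both recursions can be run side by side. The sketch below assumes hypothetical Gaussian data with mean 3 and variance 4. Because the variance update uses w(n-1) in place of w(n), it only approximates the batch (1/n) variance. The gap shrinks as n grows.

```python
import random

def running_mean_var(ys):
    """Online mean w(n) and variance sigma_hat^2(n), both with gain 1/n.

    The variance update uses the previous mean w(n-1), so it approximates
    (rather than equals) the batch (1/n) variance.
    """
    w, s2 = 0.0, 0.0
    for n, y in enumerate(ys, start=1):
        k = 1.0 / n
        s2 = s2 + k * ((y - w) ** 2 - s2)  # uses w(n-1): updated before w
        w = w + k * (y - w)
    return w, s2

rng = random.Random(2)
ys = [rng.gauss(3.0, 2.0) for _ in range(50000)]  # hypothetical data: mean 3, variance 4
w, s2 = running_mean_var(ys)
batch_mean = sum(ys) / len(ys)
batch_var = sum((y - batch_mean) ** 2 for y in ys) / len(ys)
print(w, batch_mean)
print(s2, batch_var)  # close, but not identical
```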
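Objective: adjust learning gain in order to minimize model uncertainty
Hypothesis about data observation in trial n:
$$y^{(n)} = \mathbf{x}^{(n)T}\mathbf{w}^* + \varepsilon^{(n)}, \qquad \varepsilon^{(n)} \sim N\left(0, \sigma^2\right)$$
My estimate of w* before I see y in trial n, given that I have seen y up to n-1, is $\mathbf{w}^{(n|n-1)}$, and it predicts $\hat{y}^{(n)} = \mathbf{x}^{(n)T}\mathbf{w}^{(n|n-1)}$. The error in trial n is $y^{(n)} - \mathbf{x}^{(n)T}\mathbf{w}^{(n|n-1)}$, and my estimate after I see y in trial n is
$$\mathbf{w}^{(n|n)} = \mathbf{w}^{(n|n-1)} + \mathbf{k}^{(n)}\left(y^{(n)} - \mathbf{x}^{(n)T}\mathbf{w}^{(n|n-1)}\right)$$
Parameter error before I saw the data (a priori error): $\tilde{\mathbf{w}}^{(n|n-1)} = \mathbf{w}^* - \mathbf{w}^{(n|n-1)}$
Parameter error after I saw the data point (a posteriori error): $\tilde{\mathbf{w}}^{(n|n)} = \mathbf{w}^* - \mathbf{w}^{(n|n)}$
A priori var-cov of parameter error: $P^{(n|n-1)} = E\left[\tilde{\mathbf{w}}^{(n|n-1)}\tilde{\mathbf{w}}^{(n|n-1)T}\right]$
A posteriori var-cov of parameter error: $P^{(n|n)} = E\left[\tilde{\mathbf{w}}^{(n|n)}\tilde{\mathbf{w}}^{(n|n)T}\right]$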
Some observations about model uncertainty
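$$P^{(n|n)} = E\left[\left(\mathbf{w}^* - \mathbf{w}^{(n|n)}\right)\left(\mathbf{w}^* - \mathbf{w}^{(n|n)}\right)^T\right] = E\left[\mathbf{w}^{(n|n)}\mathbf{w}^{(n|n)T}\right] - E\left[\mathbf{w}^{(n|n)}\right]E\left[\mathbf{w}^{(n|n)}\right]^T = \mathrm{var}\left[\mathbf{w}^{(n|n)}\right]$$
where the second step uses $E\left[\mathbf{w}^{(n|n)}\right] = \mathbf{w}^*$.
We note that $P^{(n|n)}$ is simply the var-cov matrix of our model weights. It represents the uncertainty in our model.
We want to update the weights so as to minimize a measure of this uncertainty.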
Trace of parameter var-cov matrix is the sum of squared parameter errors
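$$\tilde{\mathbf{w}} = \mathbf{w}^* - \mathbf{w} \sim N\left(\mathbf{0}, P\right), \qquad P = E\left[\tilde{\mathbf{w}}\tilde{\mathbf{w}}^T\right] = \begin{bmatrix}\mathrm{var}[w_1] & \mathrm{cov}[w_1, w_2]\\ \mathrm{cov}[w_2, w_1] & \mathrm{var}[w_2]\end{bmatrix}$$
$$\mathrm{var}[w_1] = E\left[\tilde{w}_1^2\right] \approx \frac{1}{n}\sum_{i=1}^{n}\left(\tilde{w}_1^{(i)}\right)^2$$
$$\mathrm{trace}(P) = \mathrm{var}[w_1] + \mathrm{var}[w_2] = E\left[\tilde{w}_1^2 + \tilde{w}_2^2\right] \approx \frac{1}{n}\sum_{i=1}^{n}\left(\left(\tilde{w}_1^{(i)}\right)^2 + \left(\tilde{w}_2^{(i)}\right)^2\right)$$
Our objective is to find the learning rate k (the Kalman gain) that minimizes the sum of the squared errors in our parameter estimates. This sum is the trace of the P matrix. Therefore, given observation y(n), we want to find the k that minimizes the variance of our estimate w.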
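A Monte Carlo check of the identity trace(P) = E[w̃₁² + w̃₂²] is given below. This sketch assumes P is the Example 2 matrix (σ² = 1) purely for illustration. It draws parameter errors w̃ ~ N(0, P) through a 2x2 Cholesky factor and averages the squared error.

```python
import math
import random

# Var-cov of the parameter error; the Example 2 matrix is used only for illustration.
P = [[1.0, -1.0], [-1.0, 1.25]]

# 2x2 Cholesky factor L (L L^T = P), used to draw w_tilde ~ N(0, P).
l11 = math.sqrt(P[0][0])
l21 = P[1][0] / l11
l22 = math.sqrt(P[1][1] - l21 * l21)

rng = random.Random(4)
n = 100000
sq_err = 0.0
for _ in range(n):
    z1, z2 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
    e1, e2 = l11 * z1, l21 * z1 + l22 * z2  # one draw of the parameter error
    sq_err += e1 * e1 + e2 * e2             # squared error w_tilde^T w_tilde
mean_sq_err = sq_err / n
trace_P = P[0][0] + P[1][1]
print(mean_sq_err, trace_P)  # agree up to sampling noise
```

The off-diagonal covariance changes how the errors are spread, but it does not enter the trace. Minimizing the trace therefore targets the total squared error only.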
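$$\mathbf{w}^{(n|n)} = \mathbf{w}^{(n|n-1)} + \mathbf{k}^{(n)}\left(y^{(n)} - \mathbf{x}^{(n)T}\mathbf{w}^{(n|n-1)}\right)$$
$$\mathbf{w}^{(n|n)} = \mathbf{w}^{(n|n-1)} + \mathbf{k}^{(n)}\left(\mathbf{x}^{(n)T}\mathbf{w}^* + \varepsilon^{(n)} - \mathbf{x}^{(n)T}\mathbf{w}^{(n|n-1)}\right) = \left(I - \mathbf{k}^{(n)}\mathbf{x}^{(n)T}\right)\mathbf{w}^{(n|n-1)} + \mathbf{k}^{(n)}\mathbf{x}^{(n)T}\mathbf{w}^* + \mathbf{k}^{(n)}\varepsilon^{(n)}$$
Because $\mathbf{w}^*$ is a constant and $\varepsilon^{(n)}$ is independent of $\mathbf{w}^{(n|n-1)}$:
$$P^{(n|n)} = \mathrm{var}\left[\mathbf{w}^{(n|n)}\right] = \left(I - \mathbf{k}^{(n)}\mathbf{x}^{(n)T}\right)\mathrm{var}\left[\mathbf{w}^{(n|n-1)}\right]\left(I - \mathbf{k}^{(n)}\mathbf{x}^{(n)T}\right)^T + \mathbf{k}^{(n)}\,\mathrm{var}\left[\varepsilon^{(n)}\right]\mathbf{k}^{(n)T}$$
$$P^{(n|n)} = \left(I - \mathbf{k}^{(n)}\mathbf{x}^{(n)T}\right)P^{(n|n-1)}\left(I - \mathbf{k}^{(n)}\mathbf{x}^{(n)T}\right)^T + \sigma^2\,\mathbf{k}^{(n)}\mathbf{k}^{(n)T}$$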
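This expression holds for any gain k, not only the optimal one. The sketch below is a scalar Monte Carlo check. It assumes hypothetical values for w*, x, P(n|n-1), σ², and a deliberately non-optimal k. It simulates prior estimates with error variance P(n|n-1), applies the update, and compares the resulting error variance with the formula.

```python
import random

def posterior_var(k, x, P_prior, sigma2):
    """Scalar P(n|n) = (1 - k x) P(n|n-1) (1 - k x) + sigma^2 k^2, valid for any gain k."""
    return (1.0 - k * x) ** 2 * P_prior + sigma2 * k ** 2

# Hypothetical scalar values; k is deliberately not the optimal gain.
w_star, x, P_prior, sigma2, k = 1.0, 2.0, 0.5, 1.0, 0.3

rng = random.Random(3)
n = 100000
sq = 0.0
for _ in range(n):
    w_prior = w_star + rng.gauss(0.0, P_prior ** 0.5)  # prior estimate, error variance P_prior
    y = x * w_star + rng.gauss(0.0, sigma2 ** 0.5)     # noisy observation
    w_post = w_prior + k * (y - x * w_prior)           # update with gain k
    sq += (w_star - w_post) ** 2
mc_var = sq / n
print(mc_var, posterior_var(k, x, P_prior, sigma2))  # both near 0.17
```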
Find K to minimize trace of uncertainty
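Expanding the product:
$$P^{(n|n)} = P^{(n|n-1)} - P^{(n|n-1)}\mathbf{x}^{(n)}\mathbf{k}^{(n)T} - \mathbf{k}^{(n)}\mathbf{x}^{(n)T}P^{(n|n-1)} + \mathbf{k}^{(n)}\mathbf{x}^{(n)T}P^{(n|n-1)}\mathbf{x}^{(n)}\mathbf{k}^{(n)T} + \sigma^2\,\mathbf{k}^{(n)}\mathbf{k}^{(n)T}$$
$$\mathrm{tr}\left[P^{(n|n)}\right] = \mathrm{tr}\left[P^{(n|n-1)}\right] - \mathrm{tr}\left[P^{(n|n-1)}\mathbf{x}^{(n)}\mathbf{k}^{(n)T}\right] - \mathrm{tr}\left[\mathbf{k}^{(n)}\mathbf{x}^{(n)T}P^{(n|n-1)}\right] + \mathrm{tr}\left[\mathbf{k}^{(n)}\mathbf{x}^{(n)T}P^{(n|n-1)}\mathbf{x}^{(n)}\mathbf{k}^{(n)T}\right] + \sigma^2\,\mathrm{tr}\left[\mathbf{k}^{(n)}\mathbf{k}^{(n)T}\right]$$
Using $\mathrm{tr}[A] = \mathrm{tr}\left[A^T\right]$, $P^{(n|n-1)T} = P^{(n|n-1)}$, and $\mathrm{tr}[aB] = a\,\mathrm{tr}[B]$ for the scalar $a = \mathbf{x}^{(n)T}P^{(n|n-1)}\mathbf{x}^{(n)}$:
$$\mathrm{tr}\left[P^{(n|n)}\right] = \mathrm{tr}\left[P^{(n|n-1)}\right] - 2\,\mathrm{tr}\left[P^{(n|n-1)}\mathbf{x}^{(n)}\mathbf{k}^{(n)T}\right] + \left(\mathbf{x}^{(n)T}P^{(n|n-1)}\mathbf{x}^{(n)} + \sigma^2\right)\mathrm{tr}\left[\mathbf{k}^{(n)}\mathbf{k}^{(n)T}\right]$$
where $\mathbf{x}^{(n)T}P^{(n|n-1)}\mathbf{x}^{(n)} + \sigma^2$ is a scalar.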
The Kalman gain
Using $\frac{d}{dA}\mathrm{tr}[AB] = B^T$ and $\frac{d}{d\mathbf{k}}\mathrm{tr}\left[\mathbf{k}\mathbf{k}^T\right] = 2\mathbf{k}$, we set the derivative of the trace with respect to the gain to zero:
$$\frac{d\,\mathrm{tr}\left[P^{(n|n)}\right]}{d\mathbf{k}^{(n)}} = -2P^{(n|n-1)}\mathbf{x}^{(n)} + 2\left(\mathbf{x}^{(n)T}P^{(n|n-1)}\mathbf{x}^{(n)} + \sigma^2\right)\mathbf{k}^{(n)} = \mathbf{0}$$
$$\mathbf{k}^{(n)} = \frac{P^{(n|n-1)}\mathbf{x}^{(n)}}{\mathbf{x}^{(n)T}P^{(n|n-1)}\mathbf{x}^{(n)} + \sigma^2}$$
If I have a lot of uncertainty about my model, P is large compared to σ², and I will learn a lot from the current error.
If I am pretty certain about my model, P is small compared to σ², and I will tend to ignore the current error.
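The minimization can be checked numerically in the 2-D case. This sketch assumes a hypothetical P(n|n-1), input x, and σ². It evaluates tr[P(n|n)] from the expression above at the Kalman gain and at perturbed gains. The trace is quadratic in k, so any perturbation δ should raise it by exactly (x^T P x + σ²)·|δ|².

```python
def trace_post(k, x, P, sigma2):
    """tr P(n|n) = tr P - 2 tr(P x k^T) + (x^T P x + sigma^2) tr(k k^T), for 2-D x and k."""
    Px = [P[0][0] * x[0] + P[0][1] * x[1], P[1][0] * x[0] + P[1][1] * x[1]]
    s = x[0] * Px[0] + x[1] * Px[1] + sigma2
    return (P[0][0] + P[1][1]) - 2.0 * (Px[0] * k[0] + Px[1] * k[1]) + s * (k[0] ** 2 + k[1] ** 2)

def kalman_gain(x, P, sigma2):
    """k = P x / (x^T P x + sigma^2)."""
    Px = [P[0][0] * x[0] + P[0][1] * x[1], P[1][0] * x[0] + P[1][1] * x[1]]
    s = x[0] * Px[0] + x[1] * Px[1] + sigma2
    return [Px[0] / s, Px[1] / s]

# Hypothetical prior uncertainty, input, and measurement noise.
P = [[2.0, 0.5], [0.5, 1.0]]
x = [1.0, 1.0]
sigma2 = 0.5

k_opt = kalman_gain(x, P, sigma2)
best = trace_post(k_opt, x, P, sigma2)
increases = []
for dk in ([0.1, 0.0], [0.0, -0.1], [0.05, 0.05], [-0.2, 0.1]):
    k_try = [k_opt[0] + dk[0], k_opt[1] + dk[1]]
    increases.append(trace_post(k_try, x, P, sigma2) - best)
print(k_opt, best, increases)  # every increase is positive
```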
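Update of model uncertainty
Substituting the Kalman gain into the expanded expression for $P^{(n|n)}$:
$$P^{(n|n)} = P^{(n|n-1)} - P^{(n|n-1)}\mathbf{x}^{(n)}\mathbf{k}^{(n)T} - \mathbf{k}^{(n)}\mathbf{x}^{(n)T}P^{(n|n-1)} + \mathbf{k}^{(n)}\left(\mathbf{x}^{(n)T}P^{(n|n-1)}\mathbf{x}^{(n)} + \sigma^2\right)\mathbf{k}^{(n)T}$$
$$P^{(n|n)} = P^{(n|n-1)} - 2\,\frac{P^{(n|n-1)}\mathbf{x}^{(n)}\mathbf{x}^{(n)T}P^{(n|n-1)}}{\mathbf{x}^{(n)T}P^{(n|n-1)}\mathbf{x}^{(n)} + \sigma^2} + \frac{P^{(n|n-1)}\mathbf{x}^{(n)}\mathbf{x}^{(n)T}P^{(n|n-1)}}{\mathbf{x}^{(n)T}P^{(n|n-1)}\mathbf{x}^{(n)} + \sigma^2}$$
$$P^{(n|n)} = P^{(n|n-1)} - \frac{P^{(n|n-1)}\mathbf{x}^{(n)}\mathbf{x}^{(n)T}P^{(n|n-1)}}{\mathbf{x}^{(n)T}P^{(n|n-1)}\mathbf{x}^{(n)} + \sigma^2} = P^{(n|n-1)} - \mathbf{k}^{(n)}\mathbf{x}^{(n)T}P^{(n|n-1)}$$
$$P^{(n|n)} = \left(I - \mathbf{k}^{(n)}\mathbf{x}^{(n)T}\right)P^{(n|n-1)}$$
Model uncertainty decreases with every data point that you observe: the subtracted term is positive semidefinite.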
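The simplification can be checked numerically. This sketch uses hypothetical 2-D values. It compares the general expression (I - kx^T)P(I - kx^T)^T + σ²kk^T with the short form (I - kx^T)P. The two agree at the Kalman gain, and the trace falls after the observation. For any other gain they differ, so the short form should only be used with the optimal k.

```python
def matmul(A, B):
    """2x2 matrix product."""
    return [[sum(A[i][m] * B[m][j] for m in range(2)) for j in range(2)] for i in range(2)]

def transpose(A):
    """2x2 matrix transpose."""
    return [[A[j][i] for j in range(2)] for i in range(2)]

def kalman_gain(P, x, sigma2):
    """k = P x / (x^T P x + sigma^2)."""
    Px = [P[i][0] * x[0] + P[i][1] * x[1] for i in range(2)]
    s = x[0] * Px[0] + x[1] * Px[1] + sigma2
    return [Px[0] / s, Px[1] / s]

def i_minus_kxT(k, x):
    """The matrix I - k x^T."""
    return [[(1.0 if i == j else 0.0) - k[i] * x[j] for j in range(2)] for i in range(2)]

def full_update(P, x, k, sigma2):
    """(I - k x^T) P (I - k x^T)^T + sigma^2 k k^T, valid for any gain k."""
    M = i_minus_kxT(k, x)
    MPMt = matmul(matmul(M, P), transpose(M))
    return [[MPMt[i][j] + sigma2 * k[i] * k[j] for j in range(2)] for i in range(2)]

def short_update(P, x, k):
    """(I - k x^T) P, which equals full_update only when k is the Kalman gain."""
    return matmul(i_minus_kxT(k, x), P)

# Hypothetical prior uncertainty, input, and measurement noise.
P = [[2.0, 0.5], [0.5, 1.0]]
x = [1.0, 1.0]
sigma2 = 0.5

k = kalman_gain(P, x, sigma2)
P_full = full_update(P, x, k, sigma2)
P_short = short_update(P, x, k)
gap = max(abs(P_full[i][j] - P_short[i][j]) for i in range(2) for j in range(2))
print(gap)                                               # ~0 at the Kalman gain
print(P[0][0] + P[1][1], P_short[0][0] + P_short[1][1])  # the trace drops
```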
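$$\mathbf{w}^{*(n+1)} = \mathbf{w}^{*(n)}$$
$$y^{(n)} = \mathbf{x}^{(n)T}\mathbf{w}^{*(n)} + \varepsilon^{(n)}, \qquad \varepsilon^{(n)} \sim N\left(0, \sigma^2\right)$$
A priori estimate of mean and variance of the hidden variable before I observe the first data point: $\mathbf{w}^{(1|0)}$, $P^{(1|0)}$
Update of the estimate of the hidden variable after I observed the data point:
$$\mathbf{k}^{(n)} = \frac{P^{(n|n-1)}\mathbf{x}^{(n)}}{\mathbf{x}^{(n)T}P^{(n|n-1)}\mathbf{x}^{(n)} + \sigma^2}, \qquad \mathbf{w}^{(n|n)} = \mathbf{w}^{(n|n-1)} + \mathbf{k}^{(n)}\left(y^{(n)} - \mathbf{x}^{(n)T}\mathbf{w}^{(n|n-1)}\right), \qquad P^{(n|n)} = \left(I - \mathbf{k}^{(n)}\mathbf{x}^{(n)T}\right)P^{(n|n-1)}$$
Forward projection of the estimate to the next trial:
$$\mathbf{w}^{(n+1|n)} = \mathbf{w}^{(n|n)}, \qquad P^{(n+1|n)} = P^{(n|n)}$$
[Figure: graphical model. The hidden variable w* is carried from trial to trial; on each trial it combines with the input x to produce the observed variable y.]
In this model, we hypothesize that the hidden variables, i.e., the “true” weights, do not change from trial to trial.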
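Running these equations on the Example 1 data shows how the filter relates to the earlier regression results. This sketch assumes a hypothetical broad prior, w(1|0) = 0 and P(1|0) = 10⁶·I, with σ² = 1. After a single pass, w lands close to the ML estimate and P lands close to σ²(X^TX)^{-1} from Example 1. The function name `kalman_static` is hypothetical.

```python
def kalman_static(X, y, sigma2, w0, P0):
    """Kalman filter for constant hidden weights (2-D), one pass over the data."""
    w = list(w0)                # w(1|0)
    P = [row[:] for row in P0]  # P(1|0)
    for x, yn in zip(X, y):
        Px = [P[i][0] * x[0] + P[i][1] * x[1] for i in range(2)]
        s = x[0] * Px[0] + x[1] * Px[1] + sigma2  # x^T P x + sigma^2
        k = [Px[0] / s, Px[1] / s]                # Kalman gain
        err = yn - (x[0] * w[0] + x[1] * w[1])    # error in trial n
        w = [w[i] + k[i] * err for i in range(2)] # w(n|n)
        # P(n|n) = (I - k x^T) P = P - k (P x)^T, using the symmetry of P
        P = [[P[i][j] - k[i] * Px[j] for j in range(2)] for i in range(2)]
        # Static world: w(n+1|n) = w(n|n) and P(n+1|n) = P(n|n), so no projection step.
    return w, P

# Example 1 data with a hypothetical broad prior.
X = [[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
y = [0.5, 0.5, 0.5, 0.5, 0.5]
w, P = kalman_static(X, y, 1.0, [0.0, 0.0], [[1e6, 0.0], [0.0, 1e6]])
print(w)  # close to the ML estimate [0.5, 0.5]
print(P)  # close to sigma^2 (X^T X)^{-1} = [[0.25, 0], [0, 1]]
```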
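$$\mathbf{w}^{*(n+1)} = A\mathbf{w}^{*(n)} + \boldsymbol{\varepsilon}_w^{(n)}, \qquad \boldsymbol{\varepsilon}_w \sim N\left(\mathbf{0}, Q\right)$$
$$y^{(n)} = \mathbf{x}^{(n)T}\mathbf{w}^{*(n)} + \varepsilon_y^{(n)}, \qquad \varepsilon_y \sim N\left(0, \sigma^2\right)$$
A priori estimate of mean and variance of the hidden variable before I observe the first data point: $\mathbf{w}^{(1|0)}$, $P^{(1|0)}$
Update of the estimate of the hidden variable after I observed the data point:
$$\mathbf{k}^{(n)} = \frac{P^{(n|n-1)}\mathbf{x}^{(n)}}{\mathbf{x}^{(n)T}P^{(n|n-1)}\mathbf{x}^{(n)} + \sigma^2}, \qquad \mathbf{w}^{(n|n)} = \mathbf{w}^{(n|n-1)} + \mathbf{k}^{(n)}\left(y^{(n)} - \mathbf{x}^{(n)T}\mathbf{w}^{(n|n-1)}\right), \qquad P^{(n|n)} = \left(I - \mathbf{k}^{(n)}\mathbf{x}^{(n)T}\right)P^{(n|n-1)}$$
Forward projection of the estimate to the next trial:
$$\mathbf{w}^{(n+1|n)} = A\mathbf{w}^{(n|n)}, \qquad P^{(n+1|n)} = AP^{(n|n)}A^T + Q$$
[Figure: graphical model. The hidden variable w* changes from trial to trial; on each trial it combines with the input x to produce the observed variable y.]
In this model, we hypothesize that the hidden variables change from trial to trial.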
$$P^{(n+1|n)} = AP^{(n|n)}A^T + Q$$
• The learning rate increases with the ratio between two uncertainties: my model vs. my measurement.
• After we observe an input x, the uncertainty associated with the weight of that input decreases.
• Because of state update noise Q, uncertainty increases as we form the prior for the next trial.
$$\mathbf{k}^{(n)} = \frac{P^{(n|n-1)}\mathbf{x}^{(n)}}{\mathbf{x}^{(n)T}P^{(n|n-1)}\mathbf{x}^{(n)} + \sigma^2}$$
Here $P^{(n|n-1)}$ is the uncertainty about my model parameters and $\sigma^2$ is the uncertainty about my measurement.
$$P^{(n|n)} = \left(I - \mathbf{k}^{(n)}\mathbf{x}^{(n)T}\right)P^{(n|n-1)}$$
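One trial of the filter with a changing state illustrates the last two bullets. This sketch uses hypothetical values: A = I, Q = 0.1·I, P(n|n-1) = I, σ² = 1, and an input in which only x1 is on. Only the uncertainty of the active weight drops during the update. The projection then adds Q to both weights.

```python
def kalman_step(w, P, x, y, sigma2, A, Q):
    """One trial of the Kalman filter with a changing hidden state (2-D).

    Returns the gain, the posterior (n|n), and the prior for the next trial (n+1|n).
    """
    Px = [P[i][0] * x[0] + P[i][1] * x[1] for i in range(2)]
    s = x[0] * Px[0] + x[1] * Px[1] + sigma2
    k = [Px[0] / s, Px[1] / s]
    err = y - (x[0] * w[0] + x[1] * w[1])
    w_post = [w[i] + k[i] * err for i in range(2)]
    P_post = [[P[i][j] - k[i] * Px[j] for j in range(2)] for i in range(2)]
    # Forward projection: w(n+1|n) = A w(n|n), P(n+1|n) = A P(n|n) A^T + Q
    w_next = [A[i][0] * w_post[0] + A[i][1] * w_post[1] for i in range(2)]
    AP = [[sum(A[i][m] * P_post[m][j] for m in range(2)) for j in range(2)] for i in range(2)]
    P_next = [[sum(AP[i][m] * A[j][m] for m in range(2)) + Q[i][j] for j in range(2)]
              for i in range(2)]
    return k, w_post, P_post, w_next, P_next

# Hypothetical values: only x1 is on in this trial.
A = [[1.0, 0.0], [0.0, 1.0]]
Q = [[0.1, 0.0], [0.0, 0.1]]
k, w_post, P_post, w_next, P_next = kalman_step(
    [0.0, 0.0], [[1.0, 0.0], [0.0, 1.0]], [1.0, 0.0], 1.0, 1.0, A, Q)
print(k)       # gain 0.5 for w1, 0 for w2: only the active input's weight is corrected
print(P_post)  # var[w1] drops to 0.5, var[w2] stays 1
print(P_next)  # Q raises both variances for the next trial (0.6 and 1.1)
```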
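Comparison of Kalman gain to LMS
$$\mathbf{w}^{*(n+1)} = \mathbf{w}^{*(n)}, \qquad y^{(n)} = \mathbf{x}^{(n)T}\mathbf{w}^{*(n)} + \varepsilon^{(n)}, \qquad \varepsilon \sim N\left(0, \sigma^2\right)$$
Kalman gain (see derivation of this in homework):
$$\mathbf{k}^{(n)} = \frac{P^{(n|n)}\mathbf{x}^{(n)}}{\sigma^2}$$
$$\mathbf{w}^{(n+1|n)} = \mathbf{w}^{(n|n-1)} + \mathbf{k}^{(n)}\left(y^{(n)} - \mathbf{x}^{(n)T}\mathbf{w}^{(n|n-1)}\right) = \mathbf{w}^{(n|n-1)} + \frac{P^{(n|n)}}{\sigma^2}\left(y^{(n)} - \mathbf{x}^{(n)T}\mathbf{w}^{(n|n-1)}\right)\mathbf{x}^{(n)}$$
LMS:
$$\mathbf{w}^{(n+1)} = \mathbf{w}^{(n)} + \eta\left(y^{(n)} - \mathbf{x}^{(n)T}\mathbf{w}^{(n)}\right)\mathbf{x}^{(n)}$$
In the Kalman gain approach, the P matrix depends on the history of all previous and current inputs. In LMS, the learning rate η is simply a constant that does not depend on past history.
With the Kalman gain, our estimate converges in a single pass over the data set. In LMS, we don't estimate the var-cov matrix P on each trial, but we need multiple passes before our estimate converges.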
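Both rules can be compared on the same data. This sketch uses the Example 2 inputs with noise-free observations generated by w* = [0.5, 0.5]. It assumes a hypothetical broad prior P(1|0) = 10⁴·I, σ² = 1 for the Kalman filter, and a constant η = 0.1 for LMS. During the Kalman pass it also checks the identity k = P(n|n)x/σ² on every trial.

```python
import math

def kalman_one_pass(X, y, sigma2, p0=1e4):
    """Static-model Kalman filter, one pass; also checks k = P(n|n) x / sigma^2."""
    w = [0.0, 0.0]
    P = [[p0, 0.0], [0.0, p0]]  # hypothetical broad prior P(1|0)
    identity_gap = 0.0
    for x, yn in zip(X, y):
        Px = [P[i][0] * x[0] + P[i][1] * x[1] for i in range(2)]
        s = x[0] * Px[0] + x[1] * Px[1] + sigma2
        k = [Px[0] / s, Px[1] / s]  # gain computed from P(n|n-1)
        err = yn - (x[0] * w[0] + x[1] * w[1])
        w = [w[i] + k[i] * err for i in range(2)]
        P = [[P[i][j] - k[i] * Px[j] for j in range(2)] for i in range(2)]
        k_alt = [(P[i][0] * x[0] + P[i][1] * x[1]) / sigma2 for i in range(2)]  # P(n|n) x / sigma^2
        identity_gap = max(identity_gap, abs(k[0] - k_alt[0]), abs(k[1] - k_alt[1]))
    return w, identity_gap

def lms(X, y, eta, passes):
    """LMS with a constant learning rate eta, repeated over the data set."""
    w = [0.0, 0.0]
    for _ in range(passes):
        for x, yn in zip(X, y):
            err = yn - (x[0] * w[0] + x[1] * w[1])
            w = [w[i] + eta * err * x[i] for i in range(2)]
    return w

# Example 2 inputs, noise-free observations from w* = [0.5, 0.5].
X = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 0.0]]
y = [1.0, 1.0, 1.0, 1.0, 0.5]
w_star = [0.5, 0.5]

def dist(w):
    """Euclidean distance from w to w*."""
    return math.hypot(w[0] - w_star[0], w[1] - w_star[1])

w_kf, gap = kalman_one_pass(X, y, sigma2=1.0)
print(dist(w_kf), gap)            # one Kalman pass: already close; gap ~ 0
print(dist(lms(X, y, 0.1, 1)))    # one LMS pass: still far off
print(dist(lms(X, y, 0.1, 500)))  # LMS needs many passes
```

The design difference is the matrix P(n|n)/σ². It acts as a per-direction learning rate that shrinks as evidence accumulates. LMS uses the same η for every direction, so the direction that the data constrain weakly (here w1 - w2) converges slowly.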
Effect of state and measurement noise on the Kalman gain
$$w^{*(n+1)} = a\,w^{*(n)} + \varepsilon_w^{(n)}, \qquad a = 0.99, \qquad \varepsilon_w \sim N\left(0, q^2\right)$$
$$y^{(n)} = x^{(n)}w^{*(n)} + \varepsilon_y^{(n)}, \qquad x^{(n)} = 1, \qquad \varepsilon_y \sim N\left(0, \sigma^2\right)$$
[Figure: Kalman gain $k^{(n)}$ and prior uncertainty $P^{(n|n-1)}$ over trials n = 1 to 10, for $q^2 = 2, \sigma^2 = 1$; $q^2 = 1, \sigma^2 = 1$; and $q^2 = 1, \sigma^2 = 2$.]
High noise in the state update model produces increased uncertainty in model parameters. This produces high learning rates.
High noise in the measurement also increases parameter uncertainty. But this increase is small relative to measurement uncertainty. Higher measurement noise leads to lower learning rates.
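The curves in this figure follow from iterating the scalar gain and projection equations. This sketch assumes a hypothetical initial uncertainty P(1|0) = 5, which the lecture does not state. It therefore reproduces the ordering of the curves rather than their exact values. The function `gain_history` is hypothetical.

```python
def gain_history(a, q2, sigma2, p_prior=5.0, x=1.0, trials=10):
    """Scalar Kalman gain k(n) and prior uncertainty P(n|n-1) over trials.

    p_prior is a hypothetical initial uncertainty P(1|0).
    """
    ks, Ps = [], []
    P = p_prior
    for _ in range(trials):
        Ps.append(P)
        k = P * x / (x * P * x + sigma2)  # Kalman gain
        ks.append(k)
        P_post = (1.0 - k * x) * P        # P(n|n)
        P = a * P_post * a + q2           # P(n+1|n) = a P(n|n) a + q^2
    return ks, Ps

base = gain_history(0.99, 1.0, 1.0)        # q^2 = 1, sigma^2 = 1
high_q = gain_history(0.99, 2.0, 1.0)      # q^2 = 2, sigma^2 = 1
high_sigma = gain_history(0.99, 1.0, 2.0)  # q^2 = 1, sigma^2 = 2
for name, (ks, Ps) in (("q2=1, s2=1", base), ("q2=2, s2=1", high_q), ("q2=1, s2=2", high_sigma)):
    print(name, "k(10) =", round(ks[-1], 3), "P(10|9) =", round(Ps[-1], 3))
```

Doubling σ² raises P(n|n-1) by only about a quarter in this run, far less than the doubling of σ². That is why the gain P/(P + σ²) falls even though P rises.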
Effect of state transition auto-correlation on the Kalman gain
$$w^{*(n+1)} = a\,w^{*(n)} + \varepsilon_w^{(n)}, \qquad \varepsilon_w \sim N\left(0, q^2\right), \qquad q^2 = 1$$
$$y^{(n)} = x^{(n)}w^{*(n)} + \varepsilon_y^{(n)}, \qquad x^{(n)} = 1, \qquad \varepsilon_y \sim N\left(0, \sigma^2\right), \qquad \sigma^2 = 1$$
[Figure: Kalman gain $k^{(n)}$ and prior uncertainty $P^{(n|n-1)}$ over trials n = 1 to 10, for a = 0.99, 0.50, and 0.10.]
Learning rate is higher in a state model that has high auto-correlations (larger a). That is, if the learner assumes that the world is changing slowly (a is close to 1), then the learner will have a large learning rate.
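The level each curve approaches can be computed in closed form. With x = 1, the scalar recursions P(n|n) = Pσ²/(P + σ²) and P(n+1|n) = a²P(n|n) + q² have a fixed point at the positive root of P² + (σ²(1 - a²) - q²)P - q²σ² = 0. This sketch assumes the curves settle to that steady state, using q² = σ² = 1 as in the slide. The helper `steady_state` is hypothetical.

```python
import math

def steady_state(a, q2=1.0, sigma2=1.0):
    """Steady-state prior uncertainty P and gain k for the scalar model with x = 1.

    P is the positive root of P^2 + (sigma^2 (1 - a^2) - q^2) P - q^2 sigma^2 = 0.
    """
    b = sigma2 * (1.0 - a * a) - q2
    P = (-b + math.sqrt(b * b + 4.0 * q2 * sigma2)) / 2.0
    return P, P / (P + sigma2)

results = {a: steady_state(a) for a in (0.99, 0.50, 0.10)}
for a, (P, k) in results.items():
    print(f"a = {a:.2f}: P = {P:.3f}, k = {k:.3f}")
```

A larger a carries more of the posterior uncertainty a²P(n|n) into the next prior, so P(n|n-1), and with it the gain, ends up higher. With q² = σ² = 1 the gain never falls below 0.5, because the prior always contains at least q² = σ².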