
Bayesian Deep Learning and a Probabilistic Perspective of Generalization

Andrew Gordon Wilson, Pavel Izmailov
New York University

arXiv:2002.08791v3 [cs.LG] 27 Apr 2020

Abstract

The key distinguishing property of a Bayesian approach is marginalization, rather than using a single setting of weights. Bayesian marginalization can particularly improve the accuracy and calibration of modern deep neural networks, which are typically underspecified by the data, and can represent many compelling but different solutions. We show that deep ensembles provide an effective mechanism for approximate Bayesian marginalization, and propose a related approach that further improves the predictive distribution by marginalizing within basins of attraction, without significant overhead. We also investigate the prior over functions implied by a vague distribution over neural network weights, explaining the generalization properties of such models from a probabilistic perspective. From this perspective, we explain results that have been presented as mysterious and distinct to neural network generalization, such as the ability to fit images with random labels, and show that these results can be reproduced with Gaussian processes. We also show that Bayesian model averaging alleviates double descent, resulting in monotonic performance improvements with increased flexibility. Finally, we provide a Bayesian perspective on tempering for calibrating predictive distributions.

1. Introduction

Imagine fitting the airline passenger data in Figure 1. Which model would you choose: (1) $f_1(x) = w_0 + w_1 x$, (2) $f_2(x) = \sum_{j=0}^{3} w_j x^j$, or (3) $f_3(x) = \sum_{j=0}^{10^4} w_j x^j$?

Put this way, most audiences overwhelmingly favour choices (1) and (2), for fear of overfitting. But of these options, choice (3) most honestly represents our beliefs. Indeed, it is likely that the ground truth explanation for the data is out of class for any of these choices, but there is some setting of the coefficients {w_j} in choice (3) which provides a better description of reality than could be managed by choices (1) and (2), which are special cases of choice (3).

Figure 1. Airline passenger numbers recorded monthly. [Axes: Year, 1949 to 1959; Airline Passengers, 100k to 500k.]

Moreover, our beliefs about the generative processes for our observations, which are often very sophisticated, typically ought to be independent of how many data points we happen to observe.

And in modern practice, we are implicitly favouring choice (3): we often use neural networks with millions of parameters to fit datasets with thousands of points. Furthermore, non-parametric methods such as Gaussian processes often involve infinitely many parameters, enabling the flexibility for universal approximation (Rasmussen & Williams, 2006), yet in many cases provide very simple predictive distributions. Indeed, parameter counting is a poor proxy for understanding generalization behaviour.

From a probabilistic perspective, we argue that generalization depends largely on two properties: the support and the inductive biases of a model. Consider Figure 2(a), where on the horizontal axis we have a conceptualization of all possible datasets, and on the vertical axis the Bayesian evidence for a model. The evidence, or marginal likelihood, $p(\mathcal{D}|\mathcal{M}) = \int p(\mathcal{D}|\mathcal{M}, w)\, p(w)\, dw$, is the probability we would generate a dataset if we were to randomly sample from the prior over functions p(f(x)) induced by a prior over parameters p(w). We define the support as the range of datasets for which p(D|M) > 0. We define the inductive biases as the relative prior probabilities of different datasets: the distribution of support given by p(D|M). A similar schematic to Figure 2(a) was used by MacKay (1992) to understand an Occam's razor effect in using the evidence for model selection; we believe it can also be used to reason about model construction and generalization.
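To make the evidence concrete, the following is a minimal sketch that estimates $p(\mathcal{D}|\mathcal{M})$ for polynomial models of the opening example by simple Monte Carlo, averaging the likelihood over draws from the prior on the coefficients; the toy data, prior scale, and noise level are illustrative assumptions rather than settings used in the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data standing in for the airline example: a smooth signal plus noise.
x = np.linspace(-1, 1, 30)
y = np.sin(2 * np.pi * x) + 0.1 * rng.standard_normal(x.shape)

def log_evidence(degree, n_samples=20000, prior_std=1.0, noise_std=0.1):
    """Simple Monte Carlo estimate of log p(D|M) = log E_{w ~ p(w)}[p(D|w)]."""
    X = np.vander(x, degree + 1, increasing=True)      # polynomial features
    w = prior_std * rng.standard_normal((n_samples, degree + 1))
    f = w @ X.T                                         # functions sampled from the prior
    log_lik = -0.5 * np.sum(((y - f) / noise_std) ** 2, axis=1) \
              - x.size * np.log(noise_std * np.sqrt(2 * np.pi))
    m = log_lik.max()                                   # log-mean-exp for stability
    return m + np.log(np.mean(np.exp(log_lik - m)))

for degree in (1, 3, 10):
    print(f"degree {degree}: estimated log evidence {log_evidence(degree):.1f}")
```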


[Figure 2: panel (a) plots the evidence p(D|M) across datasets (corrupted CIFAR-10, CIFAR-10, MNIST, structured image datasets) for a simple model with poor inductive biases (example: linear function), a complex model with poor inductive biases (example: MLP), and a well-specified model with calibrated inductive biases (example: CNN); panels (b)-(d) depict the prior hypothesis space, the posterior, and the true model.]

Figure 2. A probabilistic perspective of generalization. (a) Ideally, a model supports a wide range of datasets, but with inductive biases that provide high prior probability to a particular class of problems being considered. Here the CNN is preferred over the linear model and the fully-connected MLP for CIFAR-10 (while we do not consider MLP models to in general have poor inductive biases, here we are considering a hypothetical example involving images and a very large MLP). (b) By representing a large hypothesis space, a model can contract around a true solution, which in the real world is often very sophisticated. (c) With truncated support, a model will converge to an erroneous solution. (d) Even if the hypothesis space contains the truth, a model will not efficiently contract unless it also has reasonable inductive biases.

From this perspective, we want the support of the model to be large, so that we can represent any hypothesis we believe to be possible, even if it is unlikely. We would even want the model to be able to represent pure noise, such as noisy CIFAR (Zhang et al., 2016), as long as we honestly believe there is some non-zero (but potentially arbitrarily small) probability that the data are simply noise. Crucially, we also need the inductive biases to carefully represent which hypotheses we believe to be a priori likely for a particular problem class. If we are modelling images, then our model should have statistical properties, such as convolutional structure, which are good descriptions of images.

Figure 2(a) illustrates three models. We can imagine the blue curve as a simple linear function, f(x) = w_0 + w_1 x, combined with a distribution over parameters p(w_0, w_1), e.g., N(0, I), which induces a distribution over functions p(f(x)). Parameters we sample from our prior p(w_0, w_1) give rise to functions f(x) that correspond to straight lines with different slopes and intercepts. This model thus has truncated support: it cannot even represent a quadratic function. But because the marginal likelihood must normalize over datasets D, this model assigns much mass to the datasets it does support. The red curve could represent a large fully-connected MLP. This model is highly flexible, but distributes its support across datasets too evenly to be particularly compelling for many image datasets. The green curve could represent a convolutional neural network, which represents a compelling specification of support and inductive biases for image recognition: this model has the flexibility to represent many solutions, but its structural properties provide particularly good support for many image problems.
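As a small illustration of how a prior over parameters induces a distribution over functions for the linear model above, the sketch below draws (w_0, w_1) ~ N(0, I) and plots the resulting straight lines; the input grid and number of samples are arbitrary.

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
x = np.linspace(-3, 3, 100)

# Each draw (w0, w1) ~ N(0, I) gives one function f(x) = w0 + w1 * x:
# a straight line with a random slope and intercept.
for w0, w1 in rng.standard_normal((10, 2)):
    plt.plot(x, w0 + w1 * x, alpha=0.6)

plt.xlabel("x")
plt.ylabel("f(x)")
plt.title("Samples from the induced prior over functions (linear model)")
plt.show()
```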

With large support, we cast a wide enough net that the posterior can contract around the true solution to a given problem, as in Figure 2(b), which in reality we often believe to be very sophisticated. On the other hand, the simple model will have a posterior that contracts around an erroneous solution if the truth is not contained in the hypothesis space, as in Figure 2(c). Moreover, in Figure 2(d), the model has wide support, but does not contract around a good solution because its support is too evenly distributed.

Returning to the opening example, we can justify the high order polynomial by wanting large support. But we would still have to carefully choose the prior on the coefficients to induce a distribution over functions with reasonable inductive biases. Indeed, this Bayesian notion of generalization is not based on a single number, but is a two dimensional concept. From this probabilistic perspective, it is crucial not to conflate the flexibility of a model with the complexity of a model class. Indeed, Gaussian processes with RBF kernels have large support, and are thus flexible, but have inductive biases towards very simple solutions. We also see that parameter counting has no significance in this perspective of generalization: what matters is how a distribution over parameters combines with the functional form of a model to induce a distribution over solutions. Rademacher complexity (Mohri & Rostamizadeh, 2009), VC dimension (Vapnik, 1998), and many conventional metrics are, by contrast, one dimensional notions, corresponding roughly to the support of the model, which is why they have been found to provide an incomplete picture of generalization in deep learning (Zhang et al., 2016).

In this paper, we reason about Bayesian deep learning from a probabilistic perspective of generalization. The key distinguishing property of a Bayesian approach is marginalization instead of optimization, where we represent solutions given by all settings of parameters, weighted by their posterior probabilities, rather than betting everything on a single setting of parameters. Neural networks are typically underspecified by the data, and can represent many different but high performing models, corresponding to different settings of parameters, which is exactly when marginalization will make the biggest difference for accuracy and calibration. Moreover, we clarify that the recent deep ensembles (Lakshminarayanan et al., 2017) are not a competing approach to Bayesian inference, but can be viewed as a compelling mechanism for Bayesian marginalization. Indeed, we empirically demonstrate that deep ensembles can provide a better approximation to the Bayesian predictive distribution than standard Bayesian approaches. We further propose a new method, MultiSWAG, inspired by deep ensembles, which marginalizes within basins of attraction, achieving significantly improved performance with a similar training time.

We then investigate the properties of priors over functions induced by priors over the weights of neural networks, showing that they have reasonable inductive biases. We also show that the mysterious generalization properties recently presented in Zhang et al. (2016) can be understood by reasoning about prior distributions over functions, and are not specific to neural networks. Indeed, we show Gaussian processes can also perfectly fit images with random labels, yet generalize on the noise-free problem. These results are a consequence of large support but reasonable inductive biases for common problem settings. We further show that while Bayesian neural networks can fit the noisy datasets, the marginal likelihood has much better support for the noise-free datasets, in line with Figure 2. We additionally show that the multimodal marginalization in MultiSWAG alleviates double descent, so as to achieve monotonic improvements in performance with model flexibility, in line with our perspective of generalization. MultiSWAG also provides significant improvements in both accuracy and NLL over SGD training and unimodal marginalization. Finally, we provide several perspectives on tempering in Bayesian deep learning.

In the Appendix we provide several additional experiments and results. We also provide code at https://github.com/izmailovpavel/understandingbdl.

2. Related Work

Notable early works on Bayesian neural networks include MacKay (1992), MacKay (1995), and Neal (1996). These works generally argue in favour of making the model class for Bayesian approaches as flexible as possible, in line with Box & Tiao (1973). Accordingly, Neal (1996) pursued the limits of large Bayesian neural networks, showing that as the number of hidden units approached infinity, these models become Gaussian processes with particular kernel functions. This work harmonizes with recent work describing the neural tangent kernel (e.g., Jacot et al., 2018).

The marginal likelihood is often used for Bayesian hypothesis testing, model comparison, and hyperparameter tuning, with Bayes factors used to select between models (Kass & Raftery, 1995). MacKay (2003, Ch. 28) uses a diagram similar to Figure 2(a) to show the marginal likelihood has an Occam's razor property, favouring the simplest model consistent with a given dataset, even if the prior assigns equal probability to the various models. Rasmussen & Ghahramani (2001) reason about how the marginal likelihood can favour large flexible models, as long as such models correspond to a reasonable distribution over functions.

There has been much recent interest in developing Bayesian approaches for modern deep learning, with new challenges and architectures quite different from what had been considered in early work. Recent work has largely focused on scalable inference (e.g., Blundell et al., 2015; Gal & Ghahramani, 2016; Kendall & Gal, 2017; Ritter et al., 2018; Khan et al., 2018; Maddox et al., 2019), function-space inspired priors (e.g., Yang et al., 2019; Louizos et al., 2019; Sun et al., 2019; Hafner et al., 2018), and developing flat objective priors in parameter space, directly leveraging the biases of the neural network functional form (e.g., Nalisnick, 2018). Wilson (2020) provides a note motivating Bayesian deep learning.

Moreover, PAC-Bayes (McAllester, 1999) has emerged as an active direction for generalization bounds on stochastic neural networks (e.g., Guedj, 2019), with distributions over the parameters. Langford & Caruana (2002) devised a PAC-Bayes generalization bound for small stochastic neural networks (two layers with two hidden units), achieving an improvement over the existing deterministic generalization bounds. Dziugaite & Roy (2017) extended this approach, optimizing a PAC-Bayes bound with respect to a parametric distribution over the weights of the network, exploiting the flatness of solutions discovered by SGD. As a result, they derive the first non-vacuous generalization bounds for stochastic over-parameterized fully-connected neural networks on binary MNIST classification. Neyshabur et al. (2017) also discuss the connection between PAC-Bayes bounds and sharpness. Neyshabur et al. (2018) then devise PAC-Bayes bounds based on spectral norms of the layers and the Frobenius norm of the weights of the network. Achille & Soatto (2018) additionally combine PAC-Bayes and information theoretic approaches to argue that flat minima have low information content. Masegosa (2019) also proposes variational and ensemble learning methods based on PAC-Bayes analysis under model misspecification. Jiang et al. (2019) provide a review and comparison of several generalization bounds, including PAC-Bayes.

Early works tend to provide a connection between loss geometry and generalization using minimum description length frameworks (e.g., Hinton & Van Camp, 1993; Hochreiter & Schmidhuber, 1997; MacKay, 1995). Empirically, Keskar et al. (2016) argue that smaller batch SGD provides better generalization than large batch SGD, by finding flatter minima. Chaudhari et al. (2019) and Izmailov et al. (2018) design optimization procedures to specifically find flat minima.

By connecting flat solutions with ensemble approximations, Izmailov et al. (2018) also suggest that functions associated with parameters in flat regions ought to provide different predictions on test data for flatness to be helpful in generalization, which is distinct from the flatness in Dinh et al. (2017). Garipov et al. (2018) also show that there are mode connecting curves, forming loss valleys, which contain a variety of distinct solutions. We argue that flat regions of the loss containing a diversity of solutions is particularly relevant for Bayesian model averaging, since the model average will then contain many compelling and complementary explanations for the data. Additionally, Huang et al. (2019) describe neural networks as having a blessing of dimensionality, since flat regions will occupy much greater volume in a high dimensional space, which we argue means that flat solutions will dominate in the Bayesian model average.

Smith & Le (2018) and MacKay (2003, Chapter 28) additionally connect the width of the posterior with Occam factors; from a Bayesian perspective, larger width corresponds to a smaller Occam factor, and thus ought to provide better generalization. Dziugaite & Roy (2017) and Smith & Le (2018) also provide different perspectives on the results in Zhang et al. (2016), which shows that deep convolutional neural networks can fit CIFAR-10 with random labels and no training error. The PAC-Bayes bound of Dziugaite & Roy (2017) becomes vacuous when applied to randomly-labelled binary MNIST. Smith & Le (2018) show that logistic regression can fit noisy labels on sub-sampled MNIST, interpreting the result from an Occam factor perspective.

In general, PAC-Bayes provides a compelling framework for deriving explicit non-asymptotic generalization bounds. These bounds can be improved by, e.g., fewer parameters and very compact priors, which can be different from what provides optimal generalization. From our perspective, model flexibility and priors with large support, rather than compactness, are desirable. Our focus is complementary and largely prescriptive, aiming to provide intuitions on model construction, inference, generalization, and neural network priors, as well as new connections between Bayesian model averaging and deep ensembles, benefits of Bayesian model averaging specifically in the context of modern deep neural networks, perspectives on tempering in Bayesian deep learning, views of marginalization that contrast with simple Monte Carlo, and new methods for Bayesian marginalization in deep learning.

In other work, Pearce et al. (2018) propose a modification of deep ensembles and argue that it performs approximate Bayesian inference, and Gustafsson et al. (2019) briefly mention how deep ensembles can be viewed as samples from an approximate posterior. In the context of deep ensembles, we believe it is natural to consider the BMA integral separately from the simple Monte Carlo approximation that is often used to approximate this integral; to compute an accurate predictive distribution, we do not need samples from a posterior, or even a faithful approximation to the posterior.

Fort et al. (2019) considered the diversity of predictions produced by models from a single SGD run and models from independent SGD runs, and suggested ensembling averages of SGD iterates. Although MultiSWA (one of the methods considered in Section 4) is related to this idea, the crucial practical difference is that MultiSWA uses a learning rate schedule that selects for flat regions of the loss, the key to the success of the SWA method (Izmailov et al., 2018). Section 4 also shows that MultiSWAG, which we propose for multimodal Bayesian marginalization, outperforms MultiSWA.

Double descent, which describes generalization error that decreases, increases, and then again decreases with model flexibility, was demonstrated early by Opper et al. (1990). Recently, Belkin et al. (2019) extensively demonstrated double descent, leading to a surge of modern interest, with Nakkiran et al. (2019) showing double descent in deep learning. Nakkiran et al. (2020) show that tuned ℓ2 regularization can mitigate double descent. Alternatively, we show that Bayesian model averaging, particularly based on multimodal marginalization, can alleviate even prominent double descent behaviour.

Tempering in Bayesian modelling has been considered under the names Safe Bayes, generalized Bayesian inference, and fractional Bayesian inference (e.g., de Heide et al., 2019; Grünwald et al., 2017; Barron & Cover, 1991; Walker & Hjort, 2001; Zhang, 2006; Bissiri et al., 2016; Grünwald, 2012). We provide several perspectives on tempering in Bayesian deep learning, and analyze the results in a recent paper by Wenzel et al. (2020) that questions tempering for Bayesian neural networks.

3. Bayesian Marginalization

Often the predictive distribution we want to compute is given by
$$p(y|x, \mathcal{D}) = \int p(y|x, w)\, p(w|\mathcal{D})\, dw. \quad (1)$$
The outputs are y (e.g., regression values, class labels, ...), indexed by inputs x (e.g., spatial locations, images, ...), the weights (or parameters) of the neural network f(x; w) are w, and D are the data.

Eq. (1) represents a Bayesian model average (BMA). Rather than bet everything on one hypothesis, with a single setting of parameters w, we want to use all settings of parameters, weighted by their posterior probabilities. This procedure is called marginalization of the parameters w, as the predictive distribution of interest no longer conditions on w. This is not a controversial equation, but simply the sum and product rules of probability.

3.1. Importance of Marginalization in Deep Learning

In general, we can view classical training as performing approximate Bayesian inference, using the approximate posterior $p(w|\mathcal{D}) \approx \delta(w = \hat{w})$ to compute Eq. (1), where δ is a Dirac delta function that is zero everywhere except at $\hat{w} = \mathrm{argmax}_w\, p(w|\mathcal{D})$. In this case, we recover the standard predictive distribution $p(y|x, \hat{w})$. From this perspective, many alternatives, albeit imperfect, will be preferable, including impoverished Gaussian posterior approximations for p(w|D), even if the posterior or likelihood are actually highly non-Gaussian and multimodal.

The difference between a classical and Bayesian approach will depend on how sharp the posterior p(w|D) becomes. If the posterior is sharply peaked, and the conditional predictive distribution p(y|x, w) does not vary significantly where the posterior has mass, there may be almost no difference, since a delta function may then be a reasonable approximation of the posterior for the purpose of the BMA. However, modern neural networks are usually highly underspecified by the available data, and therefore have diffuse likelihoods p(D|w), not strongly favouring any one setting of parameters. Not only are the likelihoods diffuse, but different settings of the parameters correspond to a diverse variety of compelling hypotheses for the data (Garipov et al., 2018; Izmailov et al., 2019). This is exactly the setting when we most want to perform a Bayesian model average, which will lead to an ensemble containing many different but high performing models, for better calibration and accuracy than classical training.

Loss Valleys. Flat regions of low loss (negative log posterior density $-\log p(w|\mathcal{D})$) are associated with good generalization (e.g., Hochreiter & Schmidhuber, 1997; Hinton & Van Camp, 1993; Dziugaite & Roy, 2017; Izmailov et al., 2018; Keskar et al., 2016). While flat solutions that generalize poorly can be contrived through reparametrization (Dinh et al., 2017), the flat regions that lead to good generalization contain a diversity of high performing models on test data (Izmailov et al., 2018), corresponding to different parameter settings in those regions. And indeed, there are large contiguous regions of low loss that contain such solutions, even connecting together different SGD solutions (Garipov et al., 2018; Izmailov et al., 2019) (see also Figure 11, Appendix).

Since these regions of the loss represent a large volume in a high-dimensional space (Huang et al., 2019), and provide a diversity of solutions, they will dominate in forming the predictive distribution in a Bayesian model average. By contrast, if the parameters in these regions provided similar functions, as would be the case for flatness obtained through reparametrization, these functions would be redundant in the model average. That is, although solutions of high posterior density can provide poor generalization, it is the solutions that generalize well that will have the greatest posterior mass, and thus be automatically favoured by the BMA.

Calibration by Epistemic Uncertainty Representation. It has been noticed that modern neural networks are often miscalibrated, in the sense that their predictions are typically overconfident (Guo et al., 2017). For example, in classification, the highest softmax output of a convolutional neural network is typically much larger than the probability of the associated class label. The fundamental reason for miscalibration is ignoring epistemic uncertainty. A neural network can represent many models that are consistent with our observations. By selecting only one in a classical procedure, we lose uncertainty when the models disagree for a test point. In regression, we can visualize epistemic uncertainty by looking at the spread of the predictive distribution: as we move away from the data, there is a greater variety of consistent solutions, leading to larger uncertainty, as in Figure 4. We can further calibrate the model with tempering, which we discuss in the Appendix, Section 8.
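As a toy numerical illustration of this calibration effect (the numbers below are made up for illustration), averaging the predictions of two models that are each individually overconfident, but disagree on a test point, yields a predictive distribution that honestly reflects the disagreement:

```python
import numpy as np

# Two hypothetical models, each consistent with the training data,
# that disagree on a test point; each alone is overconfident.
p_model_a = np.array([0.98, 0.01, 0.01])
p_model_b = np.array([0.02, 0.97, 0.01])

# An equally weighted model average spreads probability mass over the
# classes the models disagree on, representing epistemic uncertainty.
p_bma = 0.5 * (p_model_a + p_model_b)
print(p_bma)          # [0.5  0.49 0.01]: far less confident than either model
```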

Accuracy. An often overlooked benefit of Bayesian model averaging in modern deep learning is improved accuracy. If we average the predictions of many high performing models that disagree in some cases, we should see significantly improved accuracy. This benefit is now starting to be observed in practice (e.g., Izmailov et al., 2019). Improvements in accuracy are very convincingly exemplified by deep ensembles (Lakshminarayanan et al., 2017), which have been perceived as a competing approach to Bayesian methods, but in fact provide a compelling mechanism for approximate Bayesian model averaging, as we show in Section 3.3. We also demonstrate significant accuracy benefits for multimodal Bayesian marginalization in Section 7.

3.2. Beyond Monte Carlo

Nearly all approaches to estimating the integral in Eq. (1), when it cannot be computed in closed form, involve a simple Monte Carlo approximation: $p(y|x, \mathcal{D}) \approx \frac{1}{J} \sum_{j=1}^{J} p(y|x, w_j)$, $w_j \sim p(w|\mathcal{D})$. In practice, the samples from the posterior p(w|D) are also approximate, and found through MCMC or deterministic methods. The deterministic methods approximate p(w|D) with a different, more convenient density $q(w|\mathcal{D}, \theta)$ from which we can sample, often chosen to be Gaussian.

The parameters θ are selected to make q close to p in some sense; for example, variational approximations (e.g., Beal, 2003), which have emerged as a popular deterministic approach, find $\mathrm{argmin}_{\theta}\, \mathrm{KL}(q \| p)$. Other standard deterministic approximations include Laplace (e.g., MacKay, 1995), EP (Minka, 2001a), and INLA (Rue et al., 2009).
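A minimal sketch of the simple Monte Carlo approximation above; predict_fn and posterior_sampler are hypothetical placeholders for whatever model and approximate posterior are in use.

```python
import numpy as np

def monte_carlo_bma(predict_fn, posterior_sampler, x, n_samples=30):
    """p(y|x, D) ~= (1/J) * sum_j p(y|x, w_j), with w_j ~ q(w|D).

    predict_fn(x, w) should return class probabilities for input x under
    weights w; posterior_sampler() should return one posterior draw w_j.
    """
    probs = [predict_fn(x, posterior_sampler()) for _ in range(n_samples)]
    return np.mean(probs, axis=0)
```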

From the perspective of estimating the predictive distribution in Eq. (1), we can view simple Monte Carlo as approximating the posterior with a set of point masses, with locations given by samples from another approximate posterior q, even if q is a continuous distribution: $p(w|\mathcal{D}) \approx \sum_{j=1}^{J} \delta(w = w_j)$, $w_j \sim q(w|\mathcal{D})$.

Ultimately, the goal is to accurately compute the predictive distribution in Eq. (1), rather than find a generally accurate representation of the posterior. In particular, we must carefully represent the posterior in regions that will make the greatest contributions to the BMA integral. In terms of efficiently computing the predictive distribution, we do not necessarily want to place point masses at locations given by samples from the posterior. For example, functional diversity is important for a good approximation to the BMA integral, because we are summing together terms of the form p(y|x, w); if two settings of the weights w_i and w_j each provide high likelihood (and consequently high posterior density), but give rise to similar functions f(x; w_i), f(x; w_j), then they will be largely redundant in the model average, and the second setting of parameters will not contribute much to estimating the BMA integral for the unconditional predictive distribution. In Sections 3.3 and 4 we consider how various approaches approximate the predictive distribution.

3.3. Deep Ensembles are BMA

Deep ensembles (Lakshminarayanan et al., 2017) are fast becoming a gold standard for accurate and well-calibrated predictive distributions. Recent reports (e.g., Ovadia et al., 2019; Ashukha et al., 2020) show that deep ensembles appear to outperform some particular approaches to Bayesian neural networks for uncertainty representation, leading to the confusion that deep ensembles and Bayesian methods are competing approaches. These methods are often explicitly referred to as non-Bayesian (e.g., Lakshminarayanan et al., 2017; Ovadia et al., 2019; Wenzel et al., 2020). To the contrary, we argue that deep ensembles are actually a compelling approach to Bayesian model averaging, in the vein of Section 3.2.

There is a fundamental difference between a Bayesian model average and some approaches to ensembling. The Bayesian model average assumes that one hypothesis (one parameter setting) is correct, and averages over models due to an inability to distinguish between hypotheses given limited information (Minka, 2000). As we observe more data, the posterior collapses onto a single hypothesis. If the true explanation for the data is a combination of hypotheses, then the Bayesian model average may appear to perform worse as we observe more data. Some ensembling methods work by enriching the hypothesis space, and therefore do not collapse in this way. Deep ensembles, however, are formed by MAP or maximum likelihood retraining of the same architecture multiple times, leading to different basins of attraction. The deep ensemble will therefore collapse in the same way as a Bayesian model average, as the posterior concentrates. Since the hypothesis space (support) of a modern neural network is large, containing many different possible explanations for the data, posterior collapse will often be desirable.

Furthermore, by representing multiple basins of attraction, deep ensembles can provide a better approximation to the BMA than the Bayesian approaches in Ovadia et al. (2019). Indeed, functional diversity is important for a good approximation to the BMA integral, as per Section 3.2. The approaches referred to as Bayesian in Ovadia et al. (2019) instead focus their approximation on a single basin, which may contain a lot of redundancy in function space, making a relatively minimal contribution to computing the Bayesian predictive distribution. On the other hand, retraining a neural network multiple times for deep ensembles incurs a significant computational expense. The single basin approaches may be preferred if we are to control for computation. We explore these questions in Section 4.

4. An Empirical Study of Marginalization

We have shown that deep ensembles can be interpreted as an approximate approach to Bayesian marginalization, which selects for functional diversity by representing multiple basins of attraction in the posterior. Most Bayesian deep learning methods instead focus on faithfully approximating a posterior within a single basin of attraction. We propose a new method, MultiSWAG, which combines these two types of approaches. MultiSWAG combines multiple independently trained SWAG approximations (Maddox et al., 2019) to create a mixture of Gaussians approximation to the posterior, with each Gaussian centred on a different basin of attraction. We note that MultiSWAG does not require any additional training time over standard deep ensembles.
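The following is a condensed sketch of this procedure, using only the diagonal part of the SWAG posterior for brevity (the full method also maintains a low-rank covariance component); the tiny model, synthetic data, and hyperparameters are placeholders rather than the settings used in the experiments.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector, vector_to_parameters

# Placeholder model and synthetic data; swap in a real architecture and loader.
def make_model():
    return nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 3))

X = torch.randn(512, 2)
y = (X[:, 0] > 0).long() + (X[:, 1] > 0).long()    # three synthetic classes

def train_swag(n_epochs=100, swag_start=75, lr=0.05):
    """One SGD run; collect first and second weight moments along the
    tail of the trajectory (a diagonal-only SWAG approximation)."""
    model = make_model()
    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    mean = sq_mean = None
    n_collected = 0
    for epoch in range(n_epochs):
        opt.zero_grad()
        loss = nn.functional.cross_entropy(model(X), y)
        loss.backward()
        opt.step()
        if epoch >= swag_start:                     # SWAG collection phase
            w = parameters_to_vector(model.parameters()).detach()
            n_collected += 1
            if mean is None:
                mean, sq_mean = w.clone(), w ** 2
            else:
                mean += (w - mean) / n_collected
                sq_mean += (w ** 2 - sq_mean) / n_collected
    var = (sq_mean - mean ** 2).clamp(min=1e-8)
    return model, mean, var

def multiswag_predict(x, n_runs=3, n_samples_per_run=20):
    """Average predictions over samples from each independently trained
    SWAG Gaussian: a mixture-of-Gaussians approximation to the posterior."""
    probs = []
    for _ in range(n_runs):
        model, mean, var = train_swag()
        for _ in range(n_samples_per_run):
            w = mean + var.sqrt() * torch.randn_like(mean)
            vector_to_parameters(w, model.parameters())
            with torch.no_grad():
                probs.append(torch.softmax(model(x), dim=-1))
    return torch.stack(probs).mean(dim=0)

print(multiswag_predict(torch.randn(5, 2)))
```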

We illustrate the conceptual difference between deep ensembles, a standard variational single basin approach, and MultiSWAG in Figure 3. In the top panel, we have a conceptualization of a multimodal posterior. VI approximates the posterior with multiple samples within a single basin. But we see in the middle panel that the conditional predictive distribution p(y|x, w) does not vary significantly within the basin, and thus each additional sample contributes minimally to computing the marginal predictive distribution p(y|x, D).

Figure 3. Approximating the BMA, $p(y|x,\mathcal{D}) = \int p(y|x,w)\, p(w|\mathcal{D})\, dw$. Top: $p(w|\mathcal{D})$, with representations from VI (orange), deep ensembles (blue), and MultiSWAG (red). Middle: $p(y|x,w)$ as a function of w for a test input x; this function does not vary much within modes, but changes significantly between modes. Bottom: distance between the true predictive distribution and the approximation, as a function of representing the posterior at an additional point w, assuming we have already sampled the mode in dark green. There is more to be gained by exploring new basins than by continuing to explore the same basin.

On the other hand, p(y|x, w) varies significantly between basins, and thus each point mass for deep ensembles contributes significantly to the marginal predictive distribution. By sampling within the basins, MultiSWAG provides additional contributions to the predictive distribution. In the bottom panel, we have the gain in approximating the predictive distribution when adding a point mass to the representation of the posterior, as a function of its location, assuming we have already sampled the mode in dark green. Including samples from different modes provides significant gain over continuing to sample from the same mode, and including weights in wide basins provides relatively more gain than in the narrow ones.

In Figure 4 we evaluate single basin and multi-basin approaches in a case where we can near-exactly compute the predictive distribution. We provide details for generating the data and training the models in Appendix D.1. We see that the predictive distribution given by deep ensembles is qualitatively closer to the true distribution than that of the single basin variational method: between data clusters, the deep ensemble approach provides a similar representation of epistemic uncertainty, whereas the variational method is extremely overconfident in these regions. Moreover, we see that the Wasserstein distance between the true predictive distribution and these two approximations quickly shrinks with the number of samples for deep ensembles, but is roughly independent of the number of samples for the variational approach. Thus the deep ensemble is providing a better approximation of the Bayesian model average in Eq. (1) than the single basin variational approach, which has traditionally been labelled as the Bayesian alternative.

Next, we evaluate MultiSWAG under distribution shift on the CIFAR-10 dataset (Krizhevsky et al., 2014), replicating the setup in Ovadia et al. (2019). We consider 16 data corruptions, each at 5 different levels of severity, introduced by Hendrycks & Dietterich (2019). For each corruption, we evaluate the performance of deep ensembles and MultiSWAG, varying the training budget. For deep ensembles, we show performance as a function of the number of independently trained models in the ensemble. For MultiSWAG, we show performance as a function of the number of independent SWAG approximations that we construct; we then sample 20 models from each of these approximations to construct the final ensemble.

While the training time for MultiSWAG is the same as for deep ensembles, at test time MultiSWAG is more expensive, as the corresponding ensemble consists of a larger number of models. To account for situations when test time is constrained, we also propose MultiSWA, a method that ensembles independently trained SWA solutions (Izmailov et al., 2018). SWA solutions are the means of the corresponding Gaussian SWAG approximations. Izmailov et al. (2018) argue that SWA solutions approximate the local ensembles represented by SWAG with a single model.
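A minimal sketch of the weight averaging underlying SWA, assuming a list of weight snapshots (each a list of parameter tensors) collected along the tail of a single SGD run; the snapshot collection and learning rate schedule are omitted, and in practice BatchNorm statistics are typically recomputed for the averaged weights.

```python
import copy
import torch
from torch.nn.utils import parameters_to_vector, vector_to_parameters

def swa_solution(model, weight_snapshots):
    """Average weight snapshots from one SGD run into a single SWA model;
    MultiSWA then ensembles the predictions of several such models, each
    obtained from an independent run."""
    vecs = [parameters_to_vector(snapshot) for snapshot in weight_snapshots]
    avg = torch.stack(vecs).mean(dim=0)
    swa_model = copy.deepcopy(model)
    vector_to_parameters(avg, swa_model.parameters())
    return swa_model
```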

In Figure 5 we show the negative log-likelihood as a function of the number of independently trained models, for a Preactivation ResNet-20 on CIFAR-10 corrupted with Gaussian blur, with varying levels of intensity (increasing from left to right). MultiSWAG outperforms deep ensembles significantly on highly corrupted data. For lower levels of corruption, MultiSWAG works particularly well when only a small number of independently trained models are available. We note that MultiSWA also outperforms deep ensembles, and has the same computational requirements at training and test time as deep ensembles. We present results for other types of corruption in Appendix Figures 14, 15, 16, and 17, showing similar trends. In general, there is an extensive evaluation of MultiSWAG in the Appendix.

Our perspective of generalization is deeply connected with Bayesian marginalization. In order to best realize the benefits of marginalization in deep learning, we need to consider as many hypotheses as possible, through multimodal posterior approximations such as MultiSWAG. In Section 7 we return to MultiSWAG, showing how it can entirely alleviate prominent double descent behaviour, and lead to striking improvements in generalization over SGD and single basin marginalization, for both accuracy and NLL.

Figure 4. Approximating the true predictive distribution. (a) A close approximation of the true predictive distribution, obtained by combining 200 HMC chains. (b) Deep ensembles predictive distribution, using 50 independently trained networks. (c) Predictive distribution for factorized variational inference (VI). (d) Convergence of the predictive distributions for deep ensembles and variational inference as a function of the number of samples; we measure the average Wasserstein distance between the marginals in the range of input positions. The multi-basin deep ensembles approach provides a more faithful approximation of the Bayesian predictive distribution than the conventional single-basin VI approach, which is overconfident between data clusters. The top panels show the Wasserstein distance between the true predictive distribution and the deep ensemble and VI approximations, as a function of the input x.

Figure 5. Negative log likelihood for deep ensembles, MultiSWA, and MultiSWAG using a PreResNet-20 on CIFAR-10 with varying intensity of the Gaussian blur corruption, as a function of the number of independently trained models (1 to 10); corruption intensity increases from left to right. For all levels of intensity, MultiSWAG and MultiSWA outperform deep ensembles for a small number of independent models. For high levels of corruption, MultiSWAG significantly outperforms the other methods, even for many independent models. We present results for other corruptions in the Appendix.

5. Neural Network Priors

A prior over parameters p(w) combines with the functional form of a model f(x; w) to induce a distribution over functions p(f(x; w)). It is this distribution over functions that controls the generalization properties of the model; the prior over parameters, in isolation, has no meaning. Neural networks are imbued with structural properties that provide good inductive biases, such as translation equivariance, hierarchical representations, and sparsity. In the sense of Figure 2, the prior will have large support, due to the flexibility of neural networks, but its inductive biases provide the most mass to datasets which are representative of problem settings where neural networks are often applied. In this section, we study the properties of the induced distribution over functions. We directly continue the discussion of priors in Section 6, with a focus on examining the noisy CIFAR results in Zhang et al. (2016) from a probabilistic perspective of generalization. These sections are best read together.
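To make the induced distribution over functions concrete, the sketch below draws weights from p(w) = N(0, α²I) for a small fully-connected network with a one-dimensional input, and plots the resulting sample functions; the architecture, grid, and values of α are illustrative choices, not those used in the experiments.

```python
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

def sample_prior_function(alpha, x):
    """Draw w ~ N(0, alpha^2 I) for a small MLP and evaluate f(x; w)."""
    net = nn.Sequential(nn.Linear(1, 100), nn.ReLU(),
                        nn.Linear(100, 100), nn.ReLU(),
                        nn.Linear(100, 1))
    with torch.no_grad():
        for p in net.parameters():
            p.normal_(0.0, alpha)          # vague Gaussian prior over weights
        return net(x).squeeze(-1)

x = torch.linspace(-2, 2, 200).unsqueeze(-1)
for alpha in (0.1, 1.0):
    for _ in range(5):
        plt.plot(x.squeeze(), sample_prior_function(alpha, x), alpha=0.6)
    plt.title(f"Sample functions from the induced prior, alpha = {alpha}")
    plt.xlabel("x")
    plt.ylabel("f(x; w)")
    plt.show()
```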

We also provide several additional experiments in the Appendix. In Section E we present analytic results on the dependence of the prior distribution in function space on the variance of the prior over parameters, considering also layer-wise parameter priors with ReLU activations. As part of a discussion on tempering in Section 8.4, we study the effect of α in p(w) = N(0, α²I) on prior class probabilities for individual sample functions p(f(x; w)), the predictive distribution, and posterior samples as we observe varying amounts of data. In Section F we further study the correlation structure over images induced by neural network priors, subject to perturbations of the images. In Section D.3 we provide additional experimental details.

5.1. Deep Image Prior and Random Network Features

Two recent results provide strong evidence that vague Gaussian priors over parameters, when combined with a neural network architecture, induce a distribution over functions with useful inductive biases. In the deep image prior, Ulyanov et al. (2018) show that randomly initialized convolutional neural networks, without training, provide excellent performance for image denoising, super-resolution, and inpainting. This result demonstrates the ability of a sample function from a random prior over neural networks p(f(x; w)) to capture low-level image statistics, before any training.

Figure 6. Induced prior correlation function. (a)-(c) Average pairwise prior correlations for pairs of objects in classes {0, 1, 2, 4, 7} of MNIST, induced by LeNet-5 for p(f(x; w)) when p(w) = N(0, α²I), for α = 0.02, 0.1, and 1. Images in the same class have higher prior correlations than images from different classes, suggesting that p(f(x; w)) has desirable inductive biases. The correlations slightly decrease with increases in α. (d) NLL of an ensemble of 20 SWAG samples on MNIST as a function of α, using a LeNet-5.

Figure 7. Rethinking generalization. (a) Sample functions from a Gaussian process prior. (b) GP fit (with 95% credible region) to structured data generated as y_green(x) = sin(x · 2π) + ε, ε ~ N(0, 0.2²). (c) GP fit, with no training error, after a significant addition of corrupted data, in red, drawn from Uniform[0.5, 1]. (d) Variational GP marginal likelihood with RBF kernel for two classes of CIFAR-10, as a function of the number of altered labels. (e) Laplace BNN marginal likelihood for a PreResNet-20 on CIFAR-10 with different fractions of random labels. The marginal likelihood for both the GP and BNN decreases as we increase the level of corruption in the labels, suggesting reasonable inductive biases in the prior over functions. Moreover, both the GP and BNN have 100% training accuracy on images with fully corrupted labels.

Similarly, Zhang et al. (2016) show that preprocessing CIFAR-10 with a randomly initialized, untrained convolutional neural network dramatically improves the test performance of a simple Gaussian kernel on pixels, from 54% accuracy to 71%. Adding ℓ2 regularization only improves the accuracy by an additional 2%. These results again indicate that broad Gaussian priors over parameters induce reasonable priors over networks, with a minor additional gain from decreasing the variance of the prior in parameter space, which corresponds to ℓ2 regularization.

5.2. Prior Class Correlations

In Figure 6 we study the prior correlations in the outputs of the LeNet-5 convolutional network (LeCun et al., 1998) on objects of different MNIST classes. We sample networks with weights from p(w) = N(0, α²I), compute the values of the logits corresponding to the first class for all pairs of images, and compute the correlations of these logits. For all levels of α, the correlations between objects of the same class are consistently higher than the correlations between objects of different classes, showing that the network induces a reasonable prior similarity metric over these images. Additionally, we observe that the prior correlations somewhat decrease as we increase α, showing that bounding the norm of the weights has some minor utility, in accordance with Section 5.1. Similarly, in panel (d) we see that the NLL significantly decreases as α increases in [0, 0.5], then slightly increases, but is relatively constant thereafter.
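A condensed sketch of this computation, assuming a LeNet-style `model` and a tensor `images` of MNIST digits (both placeholders); the first-class logit is correlated across prior draws for every pair of images, and the resulting matrix can then be averaged within and between classes using the labels.

```python
import torch

def prior_logit_correlations(model, images, alpha, n_prior_samples=200):
    """Correlate the first-class logit across draws w ~ N(0, alpha^2 I)
    for every pair of images; returns an (N, N) correlation matrix."""
    logits = []
    with torch.no_grad():
        for _ in range(n_prior_samples):
            for p in model.parameters():
                p.normal_(0.0, alpha)             # draw weights from the prior
            logits.append(model(images)[:, 0])    # logit of the first class
    L = torch.stack(logits)                       # shape: (samples, images)
    L = (L - L.mean(dim=0)) / L.std(dim=0)
    return (L.T @ L) / (L.shape[0] - 1)           # pairwise correlations

# corr = prior_logit_correlations(lenet, mnist_images, alpha=0.1)
# Average the entries of corr within and between classes using the labels.
```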

In the Appendix we further describe analytic results, and illustrate the effect of α on sample functions.

5.3. Effect of Prior Variance on CIFAR-10

We further study the effect of the parameter prior standard deviation α, measuring the performance of approximate Bayesian inference for CIFAR-10 with a Preactivation ResNet-20 (He et al., 2016) and VGG-16 (Simonyan & Zisserman, 2014). For each of these architectures, we run SWAG (Maddox et al., 2019) with fixed hyperparameters and varying α. We report the results in Figure 12(d), (h). For both architectures, the performance is near-optimal in the range α ∈ [10⁻², 10⁻¹]. Smaller α constrains the weights too much. Performance is reasonable, and becomes mostly insensitive to α as it continues to increase, due to the inductive biases of the functional form of the neural network.


6. Rethinking Generalization

Zhang et al. (2016) demonstrated that deep neural networks have sufficient capacity to fit randomized labels on popular image classification tasks, and suggest this result requires re-thinking generalization to understand deep learning.

We argue, however, that this behaviour is not puzzling from a probabilistic perspective, is not unique to neural networks, and cannot be used as evidence against Bayesian neural networks (BNNs) with vague parameter priors. Fundamentally, the resolution is the view presented in the introduction: from a probabilistic perspective, generalization is at least a two-dimensional concept, related to support (flexibility), which should be as large as possible, supporting even noisy solutions, and inductive biases that represent relative prior probabilities of solutions.

Indeed, we demonstrate that the behaviour in Zhang et al. (2016) that was treated as mysterious and specific to neural networks can be exactly reproduced by Gaussian processes (GPs). Gaussian processes are an ideal choice for this experiment, because they are popular Bayesian non-parametric models, and they assign a prior directly in function space. Moreover, GPs have remarkable flexibility, providing universal approximation with popular covariance functions such as the RBF kernel. Yet the functions that are a priori likely under a GP with an RBF kernel are relatively simple. We describe GPs further in the Appendix, and Rasmussen & Williams (2006) provide an extensive introduction.

We start with a simple example to illustrate the ability of a GP with an RBF kernel to easily fit a corrupted dataset, yet generalize well on a non-corrupted dataset, in Figure 7. In Figure 7(a), we have sample functions from a GP prior over functions p(f(x)), showing that likely functions under the prior are smooth and well-behaved. In Figure 7(b), we see the GP is able to reasonably fit data from a structured function. And in Figure 7(c), the GP is also able to fit highly corrupted data with essentially no structure; although these data are not a likely draw from the prior, the GP has support for a wide range of solutions, including noise.
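A small sketch reproducing the flavour of the experiment in Figure 7(b)-(c) with scikit-learn; the kernel hyperparameters and data sizes are illustrative choices.

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel

rng = np.random.default_rng(0)

# Structured data: y = sin(2*pi*x) + noise, as in Figure 7(b).
x_clean = rng.uniform(0, 1, 30)
y_clean = np.sin(2 * np.pi * x_clean) + 0.2 * rng.standard_normal(30)

# Corrupted data: targets drawn from Uniform[0.5, 1], as in Figure 7(c).
x_corrupt = rng.uniform(0, 1, 60)
y_corrupt = rng.uniform(0.5, 1.0, 60)

kernel = RBF(length_scale=0.2) + WhiteKernel(noise_level=1e-4)

# The same prior fits the structured data well...
gp_clean = GaussianProcessRegressor(kernel=kernel).fit(x_clean[:, None], y_clean)

# ...and can also drive training error to near zero on corrupted data,
# because the support of the prior includes essentially arbitrary functions.
gp_both = GaussianProcessRegressor(kernel=kernel).fit(
    np.concatenate([x_clean, x_corrupt])[:, None],
    np.concatenate([y_clean, y_corrupt]))
```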

We next show that GPs can replicate the generalization behaviour described in Zhang et al. (2016) (experimental details in the Appendix). When applied to CIFAR-10 images with random labels, Gaussian processes achieve 100% train accuracy and 10.4% test accuracy (at the level of random guessing). However, the same model trained on the true labels achieves a training accuracy of 72.8% and a test accuracy of 54.3%. Thus the generalization behaviour described in Zhang et al. (2016) is not unique to neural networks, and can be described by separately understanding the support and the inductive biases of a model.

Indeed, although Gaussian processes support CIFAR-10 images with random labels, these labellings are not likely under the GP prior. In Figure 7(d), we compute the approximate GP marginal likelihood on a binary CIFAR-10 classification problem, with labels of varying levels of corruption. We see that as the noise in the data increases, the approximate marginal likelihood, and thus the prior support for these data, decreases. In Figure 7(e), we see a similar trend for a Bayesian neural network: again, as the fraction of corrupted labels increases, the approximate marginal likelihood decreases, showing that the prior over functions given by the Bayesian neural network has less support for these noisy datasets. We provide further experimental details in the Appendix.

Dziugaite & Roy (2017) and Smith & Le (2018) provide complementary perspectives on Zhang et al. (2016) for MNIST: Dziugaite & Roy (2017) show non-vacuous PAC-Bayes bounds for noise-free binary MNIST, but not noisy MNIST, and Smith & Le (2018) show that logistic regression can fit noisy labels on subsampled MNIST, interpreting the results from an Occam factor perspective.

7. Double Descent

Double descent (e.g., Belkin et al., 2019) describes generalization error that decreases, increases, and then again decreases with increases in model flexibility. The first decrease and then increase is referred to as the classical regime: models with increasing flexibility are increasingly able to capture structure and perform better, until they begin to overfit. The next regime is referred to as the modern interpolating regime. The existence of the interpolation regime has been presented as mysterious generalization behaviour in deep learning.

However, our perspective of generalization suggests that performance should monotonically improve as we increase model flexibility, when we use Bayesian model averaging with a reasonable prior. Indeed, in the opening example of Figure 1, we would in principle want to use the most flexible possible model. Our results in Section 5 show that standard BNN priors induce structured and useful priors in function space, so we should not expect double descent in Bayesian deep learning models that perform reasonable marginalization.

To test this hypothesis, we evaluate MultiSWAG, SWAG, and standard SGD with ResNet-18 models of varying width, following Nakkiran et al. (2019), measuring both error and negative log likelihood (NLL). For details, see Appendix D. We present the results in Figure 8 and Appendix Figure 13.

First, we observe that models trained with SGD indeed suffer from double descent, especially when the train labels are partially corrupted (see panels 8(c), 8(d)). We also see that SWAG, a unimodal posterior approximation, reduces the extent of double descent.

Figure 8. Bayesian model averaging alleviates double descent. (a) Test error and (b) NLL loss for ResNet-18 with varying width on CIFAR-100 for SGD, SWAG, and MultiSWAG. (c) Test error and (d) NLL loss when 20% of the labels are randomly reshuffled. SWAG reduces double descent, and MultiSWAG, which marginalizes over multiple modes, entirely alleviates double descent, both on the original labels and under label noise, in both accuracy and NLL. (e) Test error for MultiSWAG with a varying number of independent SWAG models (1, 3, 5, 10); error monotonically decreases with an increased number of independent models, alleviating double descent. We also note that MultiSWAG provides significant improvements in accuracy and NLL over SGD and SWAG models. See Appendix Figure 13 for additional results.

performs a more exhaustive multimodal Bayesian model av-erage completely mitigates double descent the performanceof MultiSWAG solutions increases monotonically with thesize of the model showing no double descent even undersignificant label corruption for both accuracy and NLL

Our results highlight the importance of marginalization over multiple modes of the posterior: under 20% label corruption, SWAG clearly suffers from double descent, while MultiSWAG does not. In Figure 8(e) we show how double descent is alleviated with an increased number of independent modes marginalized in MultiSWAG.

These results also clearly show that MultiSWAG provides significant improvements in accuracy over both SGD and SWAG models, in addition to NLL, an often overlooked advantage of Bayesian model averaging that we discuss in Section 3.1.

Recently, Nakkiran et al. (2020) show that carefully tuned l2 regularization can help mitigate double descent. Alternatively, we show that Bayesian model averaging, particularly based on multimodal marginalization, can mitigate prominent double descent behaviour. The perspective in Sections 1 and 3 predicts this result: models with reasonable priors and effective Bayesian model averaging should monotonically improve with increases in flexibility.

8. Temperature Scaling
The standard Bayesian posterior distribution is given by

p(w|D) = (1/Z) p(D|w) p(w),   (2)

where p(D|w) is a likelihood, p(w) is a prior, and Z is a normalizing constant.

In Bayesian deep learning it is typical to consider the tempered posterior

p_T(w|D) = (1/Z(T)) p(D|w)^(1/T) p(w),   (3)

where T is a temperature parameter, and Z(T) is the normalizing constant corresponding to temperature T. The temperature parameter controls how the prior and likelihood interact in the posterior:

• T < 1 corresponds to cold posteriors, where the posterior distribution is more concentrated around solutions with high likelihood.

• T = 1 corresponds to the standard Bayesian posterior distribution.

• T > 1 corresponds to warm posteriors, where the prior effect is stronger and the posterior collapse is slower.
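As a concrete illustration of Equation (3), the tempered (unnormalized) log-posterior can be used directly as a training or sampling objective by reweighting the likelihood relative to the prior. The sketch below is our own minimal illustration with a toy Gaussian likelihood and prior, not the training code used in the paper.

```python
# Minimal sketch: unnormalized tempered log-posterior of Equation (3),
#   log p_T(w | D) = (1/T) * log p(D | w) + log p(w) + const.
import numpy as np

def log_likelihood(w, data):
    # Stand-in Gaussian likelihood for a scalar "model" w; in practice this
    # would be the log-likelihood of a neural network.
    x = np.asarray(data)
    return -0.5 * np.sum((x - w) ** 2)

def log_prior(w, alpha=1.0):
    # Spherical Gaussian prior p(w) = N(0, alpha^2 I).
    return -0.5 * np.sum(np.asarray(w) ** 2) / alpha ** 2

def tempered_log_posterior(w, data, T=1.0):
    return log_likelihood(w, data) / T + log_prior(w)

data = [0.9, 1.1, 1.0]
for T in (0.5, 1.0, 2.0):  # cold, standard, warm
    print(T, tempered_log_posterior(1.0, data, T=T))
```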

Tempering posteriors is a well-known practice in statistics, where it goes by the names Safe Bayes, generalized Bayesian inference, and fractional Bayesian inference (e.g., de Heide et al., 2019; Grunwald et al., 2017; Barron & Cover, 1991; Walker & Hjort, 2001; Zhang, 2006; Bissiri et al., 2016; Grunwald, 2012). Safe Bayes has been shown to be natural from a variety of perspectives, including from prequential learning theory and minimum description length frameworks (e.g., Grunwald et al., 2017).

Concurrently with our work, Wenzel et al. (2020) noticed that successful Bayesian deep learning methods tend to use cold posteriors. They provide an empirical study showing that Bayesian neural networks (BNNs) with cold posteriors outperform models with SGD-based maximum likelihood training, while BNNs with T = 1 can perform worse than the maximum likelihood solution. They claim that cold posteriors sharply deviate from the Bayesian paradigm, and


Figure 9. Effects of the prior variance α². (a)–(e) Average class probabilities over all of CIFAR-10 for two sample prior functions p(f(x;w)) (two top rows) and the predictive distribution (average over 200 samples of weights, bottom row), for varying settings of α in p(w) = N(0, α²I): α = 0.01, 0.1, 0.3, 1, √10. (f) NLL and (g) classification error of an ensemble of 20 SWAG samples on CIFAR-10 as a function of α, using a Preactivation ResNet-20 and VGG-16. The NLL is high for overly small α and near-optimal in the range [0.1, 0.3]. The NLL remains relatively low for vague priors corresponding to large values of α.

Figure 10. Adaptivity of posterior variance with data. We sample two functions f(x;w) from the distribution over functions induced by a distribution over weights, starting with the prior p(w) = N(0, 10·I), in combination with a PreResNet-20. We measure class probabilities averaged across the CIFAR-10 test set as we vary the amount of available training data (prior, 10, 100, and 1000 datapoints). Although the prior variance is too large, such that the softmax saturates for logits sampled from the prior, leading to one class being favoured, we see that the posterior quickly adapts to correct the scale of the logits in the presence of data. In Figure 9 we also show that the prior variance can easily be calibrated such that the prior predictive distribution, even before observing data, has high entropy.


consider possible reasons for why tempering is helpful in Bayesian deep learning.

In this section we provide an alternative view, and argue that tempering is not at odds with Bayesian principles. Moreover, for virtually any realistic model class and dataset, it would be highly surprising if T = 1 were in fact the best setting of this hyperparameter. Indeed, as long as it is practically convenient, we would advocate tempering for essentially any model, especially parametric models that do not scale their capacity automatically with the amount of available information. Our position is that, at a high level, Bayesian methods are trying to combine honest beliefs with data to form a posterior. By reflecting the belief that the model is misspecified, the tempered posterior is often more of a true posterior than the posterior that results from ignoring our belief that the model is misspecified.

Finding that T < 1 helps for Bayesian neural networks is neither surprising nor discouraging. And the actual results of the experiments in Wenzel et al. (2020), which show great improvements over standard SGD training, are in fact very encouraging of deriving inspiration from Bayesian procedures in deep learning.

We consider: (1) tempering under misspecification (Section 8.1); (2) tempering in terms of overcounting data (Section 8.2); (3) how tempering compares to changing the observation model (Section 8.3); (4) the effect of the prior in relation to the experiments of Wenzel et al. (2020) (Section 8.4); (5) the effect of approximate inference, including how tempering can help in efficiently estimating parameters even for the untempered posterior (Section 8.5).

8.1. Tempering Helps with Misspecified Models

Many works explain how tempered posteriors help under model misspecification (e.g., de Heide et al., 2019; Grunwald et al., 2017; Barron & Cover, 1991; Walker & Hjort, 2001; Zhang, 2006; Bissiri et al., 2016; Grunwald, 2012). In fact, de Heide et al. (2019) and Grunwald et al. (2017) provide several simple examples where Bayesian inference fails to provide good convergence behaviour for untempered posteriors. While it is easier to show theoretical results for T > 1, several of these works also show that T < 1 can be preferred, even in well-specified settings, and indeed recommend learning T from data, for example by cross-validation (e.g., Grunwald, 2012; de Heide et al., 2019).

Are we in a misspecified setting for Bayesian neural networks? Of course. And it would be irrational to proceed as if it were otherwise. Every model is misspecified. In the context of Bayesian neural networks specifically, the mass of solutions expressed by the prior, outside of the datasets we typically consider, is likely much larger than desired for

most applications. We can calibrate for this discrepancy through tempering. The resulting tempered posterior will be more in line with our beliefs than pretending the model is not misspecified and finding the untempered posterior.

Non-parametric models such as Gaussian processes attempt to side-step model misspecification by growing the number of free parameters (information capacity) automatically with the amount of available data. In parametric models, we take much more of a manual guess about the model capacity. In the case of deep neural networks, this choice is not even close to a best guess: it was once the case that architectural design was a large component of works involving neural networks, but now it is more standard practice to choose an off-the-shelf architecture, without much consideration of model capacity. We do not believe that knowingly using a misspecified model to find a posterior is more reasonable (or Bayesian) than honestly reflecting the belief that the model is misspecified, and then using a tempered posterior. For parametric models such as neural networks, it is to be expected that the capacity is particularly misspecified.

8.2. Overcounting Data with Cold Posteriors

The criticism of cold posteriors raised by Wenzel et al. (2020) is largely based on the fact that decreasing temperature leads to overcounting data in the posterior distribution.

However, a similar argument can be made against marginal likelihood maximization (also known as empirical Bayes or type-2 maximum likelihood). Indeed, here the prior will depend on the same data as the likelihood, which can lead to miscalibrated predictive distributions (Darnieder, 2011).

Nonetheless, empirical Bayes has been embraced and widely adopted in Bayesian machine learning (e.g., Bishop, 2006; Rasmussen & Williams, 2006; MacKay, 2003; Minka, 2001b) as embodying several Bayesian principles. Empirical Bayes has been particularly embraced in seminal work on Bayesian neural networks (e.g., MacKay, 1992; 1995), where it has been proposed as a principled approach to learning hyperparameters, such as the scale of the variance for the prior over weights, automatically embodying Occam's razor. While there is in this case some deviation from the fully Bayesian paradigm, the procedure, which depends on marginalization, is nonetheless clearly inspired by Bayesian thinking, and it is thus helpful to reflect this inspiration and provide understanding of how it works from a Bayesian perspective.

There is also work showing that the marginal likelihood can lead to miscalibrated Bayes factors under model misspecification. Attempts to calibrate these factors (Xu et al., 2019) as part of the Bayesian paradigm are highly reminiscent of work on safe Bayes.


8.3. Tempered Posterior or Different Likelihood

The tempered posterior for one model is an untempered posterior using a different observation model. In other words, we can trivially change the likelihood function in our model so that the standard Bayesian posterior in the new model is equal to the posterior of temperature T in the original model. Indeed, consider the likelihood function

p_T(D|w) ∝ p(D|w)^(1/T),   (4)

where the posterior distribution for the model M_T with likelihood p_T is exactly the temperature-T posterior for the model M with likelihood p.

The predictive distribution differs for the two models: even though the posteriors coincide, the likelihoods for a new datapoint y* are different:

∫ p(y*|w) p(w|D) dw ≠ ∫ p_T(y*|w) p(w|D) dw.   (5)

As an example, consider a regression model M with a Gaussian likelihood function, y ∼ N(f, σ²), where f is the output of the network and σ² is the noise variance. The predictive distributions for the two models M and M_T for a new input x will have different variance, but the same mean: E_w[f + ε] = E_w[f]. Moreover, in this case the noise variance would typically be learned in either model, M_T or M.

Section 4.1 of Grunwald et al. (2017) considers a related construction.

8.4. Effect of the Prior

While a somewhat misspecified prior will certainly interact with the utility of tempering, we do not believe the experiments in Wenzel et al. (2020) provide evidence that even the prior p(w) = N(0, I) is misspecified to any serious extent. For a relatively wide range of distributions over w, the functional form of the network f(x;w) can produce a generally reasonable distribution over functions p(f(x;w)). In Figure 10, we reproduce the findings in Wenzel et al. (2020) that show that sample functions p(f(x;w)) corresponding to the prior p(w) = N(0, 10·I) strongly favour a single class over the dataset. While this behaviour appears superficially dramatic, we note it is simply an artifact of having a miscalibrated signal variance. A miscalibrated signal variance interacts with a quickly saturating softmax link function to provide a seemingly dramatic preference for a given class. If we instead use p(w) = N(0, α²I), for quite a range of α, then sample functions provide reasonably high entropy across labels, as in Figure 9. The value of α can be easily determined through cross-validation, as in Figure 9, or specified as a standard value used for L2 regularization (α = 0.24 in this case).

However, even with the inappropriate prior scale, we see in the bottom row of panels (a)–(e) of Figure 9 that the unconditional predictive distribution is completely reasonable. Moreover, the prior variance represents a soft prior bias, and will quickly update in the presence of data. In Figure 10 we show the posterior samples after observing 10, 100, and 1000 data points.

Other aspects of the prior, outside of the prior signal variance, will have a much greater effect on the inductive biases of the model. For example, the induced covariance function cov(f(x_i; w), f(x_j; w)) reflects the induced similarity metric over data instances; through the covariance function we can answer, for instance, whether the model believes a priori that a translated image is similar to the original. Unlike the signal variance of the prior, the prior covariance function will continue to have a significant effect on posterior inference for even very large datasets, and strongly reflects the structural properties of the neural network. We explore these structures of the prior in Figure 12.

8.5. The Effect of Inexact Inference

We have to keep in mind what we ultimately use posterior samples to compute. Ultimately, we wish to estimate the predictive distribution given by the integral in Equation (1). With a finite number of samples, the tempered posterior could be used to provide a better approximation to the expectation of the predictive distribution associated with the untempered posterior.

Consider a simple example, where we wish to estimate the mean of a high-dimensional Gaussian distribution N(0, I). Suppose we use J independent samples. The mean of these samples is also Gaussian distributed, μ̂ ∼ N(0, (1/J)·I). In Bayesian deep learning, the dimension d is typically on the order of 10^7, and J would be on the order of 10. The norm of μ̂ would be highly concentrated around √(10^7)/√10 = 1000. In this case, sampling from a tempered posterior with T < 1 would lead to a better approximation of the Bayesian model average associated with an untempered posterior.

Furthermore, no current sampling procedure will be providing samples that are close to independent samples from the true posterior of a Bayesian neural network. The posterior landscape is far too multimodal and complex for there to be any reasonable coverage. The approximations we have are practically useful, and often preferable to conventional training, but we cannot realistically proceed with analysis assuming that we have obtained true samples from a posterior. While we would expect that some value of T ≠ 1 would be preferred for any finite dataset, in practice it is conceivable that some of the results in Wenzel et al. (2020) may be affected by the specifics of the approximate inference technique being used.


We should be wary not to view Bayesian model averaging purely through the prism of simple Monte Carlo, as advised in Section 3.2. Given a finite computational budget, our goal in effectively approximating a Bayesian model average is not equivalent to obtaining good samples from the posterior.

9. Discussion

"It is now common practice for Bayesians to fit models that have more parameters than the number of data points. Incorporate every imaginable possibility into the model space: for example, if it is conceivable that a very simple model might be able to explain the data, one should include simple models; if the noise might have a long-tailed distribution, one should include a hyperparameter which controls the heaviness of the tails of the distribution; if an input variable might be irrelevant to a regression, include it in the regression anyway." MacKay (1995)

We have presented a probabilistic perspective of generalization, which depends on the support and inductive biases of the model. The support should be as large as possible, but the inductive biases must be well-calibrated to a given problem class. We argue that Bayesian neural networks embody these properties, and, through the lens of probabilistic inference, explain generalization behaviour that has previously been viewed as mysterious. Moreover, we argue that Bayesian marginalization is particularly compelling for neural networks, show how deep ensembles provide a practical mechanism for marginalization, and propose a new approach that generalizes deep ensembles to marginalize within basins of attraction. We show that this multimodal approach to Bayesian model averaging, MultiSWAG, can entirely alleviate double descent, to enable monotonic performance improvements with increases in model flexibility, as well as significant improvements in generalization accuracy and log likelihood over SGD and single-basin marginalization.

There are certainly many challenges to estimating the integral for a Bayesian model average in modern deep learning, including a high-dimensional parameter space and a complex posterior landscape. But viewing the challenge indeed as an integration problem, rather than an attempt to obtain posterior samples for a simple Monte Carlo approximation, provides opportunities for future progress. Bayesian deep learning has been making fast practical advances, with approaches that now enable better accuracy and calibration over standard training, with minimal overhead.

We finish with remarks about future developments for Bayesian neural network priors, and approaches to research in Bayesian deep learning.

9.1. The Future for BNN Priors

We provide some brief remarks about future developments for BNN priors. Here we have explored relatively simple parameter priors p(w) = N(0, α²I). While these priors are simple in parameter space, they interact with the neural network architecture to induce a sophisticated prior over functions p(f(x;w)), with many desirable properties, including a reasonable correlation structure over images. However, these parameter priors can certainly still be improved. As we have seen, even tuning the value of the signal variance α², an analogue of the L2 regularization often used in deep learning, can have a noticeable effect on the induced prior over functions, though this effect is quickly modulated by data. Layer-wise priors, such that parameters in each layer have a different signal variance, are intuitive: we would expect later layers to require precise determination, while parameters in earlier layers could reasonably take a range of values. But one has to be cautious; as we show in Appendix Section E, with ReLU activations different signal variances in different layers can be degenerate, combining together to affect only the output scale of the network.

A currently popular sentiment is that we should directly build function-space BNN priors, often taking inspiration from Gaussian processes. While we believe this is a promising direction, one should proceed with caution. If we contrive priors over parameters p(w) to induce distributions over functions p(f) that resemble familiar models such as Gaussian processes with RBF kernels, we could be throwing the baby out with the bathwater. Neural networks are useful as their own model class precisely because they have different inductive biases from other models.

A similar concern applies to taking infinite-width limits in Bayesian neural networks. In these cases we recover Gaussian processes with interpretable kernel functions; because these models are easier to use and analyze, and give rise to interpretable and well-motivated priors, it is tempting to treat them as drop-in replacements for the parametric analogues. However, the kernels for these models are fixed. In order for a model to do effective representation learning, we must learn a similarity metric for the data. Training a neural network in many ways is like learning a kernel, rather than using a fixed kernel. MacKay (1998) has also expressed concerns in treating these limits as replacements for neural networks, due to the loss of representation learning power.

Perhaps the distribution over functions induced by a network in combination with a generic distribution over parameters p(w) may be hard to interpret, but this distribution will contain the equivariance properties, representation learning abilities, and other biases that make neural networks a compelling model class in their own right.


9.2. "But Is It Really Bayesian?"

We finish with an editorial comment about approaches to research within Bayesian deep learning. There is sometimes a tendency to classify work as Bayesian or not Bayesian, with very stringent criteria for what qualifies as Bayesian. Moreover, the implication, and sometimes even explicit recommendation, is that if an approach is not unequivocally Bayesian in every respect, then we should not term it as Bayesian, and we should instead attempt to understand the procedure through entirely different non-Bayesian mechanisms. We believe this mentality encourages tribalism, which is not conducive to the best research, or creating the best performing methods. What matters is not a debate about semantics, but making rational modelling choices given a particular problem setting, and trying to understand these choices. Often these choices can largely be inspired by a Bayesian approach, in which case it is desirable to indicate this source of inspiration. And in the semantics debate, who would be the arbiter of what gets to be called Bayesian? Arguably it ought to be an evolving definition.

Broadly speaking, what makes Bayesian approaches distinctive is a posterior-weighted marginalization over parameters. And at a high level, Bayesian methods are about combining our honest beliefs with data to form a posterior. In actuality, no fair-minded researcher entirely believes the prior over parameters, the functional form of the model (which is part of the prior over functions), or the likelihood. From this perspective, it is broadly compatible with a Bayesian philosophy to reflect misspecification in the modelling procedure itself, which is achieved through tempering. In this sense, the tempered posterior is more reflective of a true posterior than the posterior that results from ignoring our belief that the model is misspecified.

Moreover, basic probability theory indicates that marginalization is desirable. While marginalization cannot in practice be achieved exactly, we can try to improve over conventional training, which, as we have discussed, can be viewed as approximate marginalization. Given computational constraints, effective marginalization is not equivalent to obtaining accurate samples from a posterior. As we have discussed, simple Monte Carlo is only one of many mechanisms for marginalization. Just like how expectation propagation (Minka, 2001a) focuses its approximation on the factors in a posterior where it will most affect the end result, we should focus on representing the posterior where it will make the biggest difference to the model average. As we have shown, deep ensembles are a reasonable mechanism up to a point. After having trained many independent models, there are added benefits to marginalizing within basins, given the computational expense associated with retraining an additional model to find an additional basin of attraction.

We should also not hold Bayesian methods to a double standard. Indeed, it can be hard to interpret or understand the prior, the posterior, and whether the marginalization procedure is optimal. But it is also hard to interpret the choices behind the functional form of the model, or the rationale behind classical procedures, where we bet everything on a single global optimum, when we know there are many global optima, and many of them will perform well but provide different solutions, and many others will not perform well. We should apply the same level of scrutiny to all modelling choices, consider the alternatives, and not be paralyzed if a procedure is not optimal in every respect.

Acknowledgements. AGW and PI are supported by an Amazon Research Award, Facebook Research, NSF I-DISRE 193471, NIH R01 DA048764-01A1, NSF IIS-1563887, and NSF IIS-1910266. We thank Greg Benton for helpful discussions.

References

Achille, A. and Soatto, S. Emergence of invariance and disentanglement in deep representations. The Journal of Machine Learning Research, 19(1):1947–1980, 2018.

Ashukha, A., Lyzhov, A., Molchanov, D., and Vetrov, D. Pitfalls of in-domain uncertainty estimation and ensembling in deep learning. arXiv preprint arXiv:2002.06470, 2020.

Barron, A. R. and Cover, T. M. Minimum complexity density estimation. IEEE Transactions on Information Theory, 37(4):1034–1054, 1991.

Beal, M. J. Variational algorithms for approximate Bayesian inference. University of London, 2003.

Belkin, M., Hsu, D., Ma, S., and Mandal, S. Reconciling modern machine-learning practice and the classical bias–variance trade-off. Proceedings of the National Academy of Sciences, 116(32):15849–15854, 2019.

Bishop, C. M. Pattern Recognition and Machine Learning. Springer, 2006.

Bissiri, P. G., Holmes, C. C., and Walker, S. G. A general framework for updating belief distributions. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 78(5):1103–1130, 2016.

Blundell, C., Cornebise, J., Kavukcuoglu, K., and Wierstra, D. Weight uncertainty in neural networks. arXiv preprint arXiv:1505.05424, 2015.

Box, G. E. and Tiao, G. C. Bayesian inference in statistical analysis. Addison-Wesley, Reading, MA, 1973.


Chaudhari, P., Choromanska, A., Soatto, S., LeCun, Y., Baldassi, C., Borgs, C., Chayes, J., Sagun, L., and Zecchina, R. Entropy-SGD: Biasing gradient descent into wide valleys. Journal of Statistical Mechanics: Theory and Experiment, 2019.

Cobb, A. D., Baydin, A. G., Markham, A., and Roberts, S. J. Introducing an explicit symplectic integration scheme for Riemannian manifold Hamiltonian Monte Carlo. arXiv preprint arXiv:1910.06243, 2019.

Darnieder, W. F. Bayesian methods for data-dependent priors. PhD thesis, The Ohio State University, 2011.

de Heide, R., Kirichenko, A., Mehta, N., and Grunwald, P. Safe-Bayesian generalized linear regression. arXiv preprint arXiv:1910.09227, 2019.

Dinh, L., Pascanu, R., Bengio, S., and Bengio, Y. Sharp minima can generalize for deep nets. arXiv preprint arXiv:1703.04933, 2017.

Draxler, F., Veschgini, K., Salmhofer, M., and Hamprecht, F. A. Essentially no barriers in neural network energy landscape. arXiv preprint arXiv:1803.00885, 2018.

Dziugaite, G. K. and Roy, D. M. Computing nonvacuous generalization bounds for deep (stochastic) neural networks with many more parameters than training data. arXiv preprint arXiv:1703.11008, 2017.

Fort, S. and Jastrzebski, S. Large scale structure of neural network loss landscapes. In Advances in Neural Information Processing Systems, pp. 6706–6714, 2019.

Fort, S., Hu, H., and Lakshminarayanan, B. Deep ensembles: A loss landscape perspective. arXiv preprint arXiv:1912.02757, 2019.

Gal, Y. and Ghahramani, Z. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. In International Conference on Machine Learning, pp. 1050–1059, 2016.

Gardner, J. R., Pleiss, G., Bindel, D., Weinberger, K. Q., and Wilson, A. G. GPyTorch: Blackbox matrix-matrix Gaussian process inference with GPU acceleration. In Neural Information Processing Systems, 2018.

Garipov, T., Izmailov, P., Podoprikhin, D., Vetrov, D. P., and Wilson, A. G. Loss surfaces, mode connectivity, and fast ensembling of DNNs. In Neural Information Processing Systems, 2018.

Grunwald, P. The safe Bayesian. In International Conference on Algorithmic Learning Theory, pp. 169–183. Springer, 2012.

Grunwald, P., Van Ommen, T., et al. Inconsistency of Bayesian inference for misspecified linear models, and a proposal for repairing it. Bayesian Analysis, 12(4):1069–1103, 2017.

Guedj, B. A primer on PAC-Bayesian learning. arXiv preprint arXiv:1901.05353, 2019.

Guo, C., Pleiss, G., Sun, Y., and Weinberger, K. Q. On calibration of modern neural networks. In Proceedings of the 34th International Conference on Machine Learning - Volume 70, pp. 1321–1330. JMLR.org, 2017.

Gustafsson, F. K., Danelljan, M., and Schon, T. B. Evaluating scalable Bayesian deep learning methods for robust computer vision. arXiv preprint arXiv:1906.01620, 2019.

Hafner, D., Tran, D., Irpan, A., Lillicrap, T., and Davidson, J. Reliable uncertainty estimates in deep neural networks using noise contrastive priors. arXiv preprint arXiv:1807.09289, 2018.

He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778, 2016.

Hendrycks, D. and Dietterich, T. Benchmarking neural network robustness to common corruptions and perturbations. arXiv preprint arXiv:1903.12261, 2019.

Hinton, G. E. and Van Camp, D. Keeping the neural networks simple by minimizing the description length of the weights. In Proceedings of the Sixth Annual Conference on Computational Learning Theory, pp. 5–13, 1993.

Hochreiter, S. and Schmidhuber, J. Flat minima. Neural Computation, 9(1):1–42, 1997.

Huang, W. R., Emam, Z., Goldblum, M., Fowl, L., Terry, J. K., Huang, F., and Goldstein, T. Understanding generalization through visualizations. arXiv preprint arXiv:1906.03291, 2019.

Izmailov, P., Podoprikhin, D., Garipov, T., Vetrov, D., and Wilson, A. G. Averaging weights leads to wider optima and better generalization. In Uncertainty in Artificial Intelligence (UAI), 2018.

Izmailov, P., Maddox, W. J., Kirichenko, P., Garipov, T., Vetrov, D., and Wilson, A. G. Subspace inference for Bayesian deep learning. Uncertainty in Artificial Intelligence, 2019.

Jacot, A., Gabriel, F., and Hongler, C. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in Neural Information Processing Systems, pp. 8571–8580, 2018.


Jiang, Y., Neyshabur, B., Mobahi, H., Krishnan, D., and Bengio, S. Fantastic generalization measures and where to find them. arXiv preprint arXiv:1912.02178, 2019.

Kass, R. E. and Raftery, A. E. Bayes factors. Journal of the American Statistical Association, 90(430):773–795, 1995.

Kendall, A. and Gal, Y. What uncertainties do we need in Bayesian deep learning for computer vision? In Advances in Neural Information Processing Systems, pp. 5574–5584, 2017.

Keskar, N. S., Mudigere, D., Nocedal, J., Smelyanskiy, M., and Tang, P. T. P. On large-batch training for deep learning: Generalization gap and sharp minima. arXiv preprint arXiv:1609.04836, 2016.

Khan, M. E., Nielsen, D., Tangkaratt, V., Lin, W., Gal, Y., and Srivastava, A. Fast and scalable Bayesian deep learning by weight-perturbation in Adam. arXiv preprint arXiv:1806.04854, 2018.

Krizhevsky, A., Nair, V., and Hinton, G. The CIFAR-10 dataset, 2014. http://www.cs.toronto.edu/~kriz/cifar.html

Lakshminarayanan, B., Pritzel, A., and Blundell, C. Simple and scalable predictive uncertainty estimation using deep ensembles. In Advances in Neural Information Processing Systems, pp. 6402–6413, 2017.

Langford, J. and Caruana, R. (Not) bounding the true error. In Advances in Neural Information Processing Systems, pp. 809–816, 2002.

LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.

Louizos, C., Shi, X., Schutte, K., and Welling, M. The functional neural process. In Advances in Neural Information Processing Systems, 2019.

MacKay, D. J. Bayesian methods for adaptive models. PhD thesis, California Institute of Technology, 1992.

MacKay, D. J. Probable networks and plausible predictions: a review of practical Bayesian methods for supervised neural networks. Network: Computation in Neural Systems, 6(3):469–505, 1995.

MacKay, D. J. Introduction to Gaussian processes. In Bishop, C. M. (ed.), Neural Networks and Machine Learning, chapter 11, pp. 133–165. Springer-Verlag, 1998.

MacKay, D. J. Information theory, inference and learning algorithms. Cambridge University Press, 2003.

Maddox, W. J., Izmailov, P., Garipov, T., Vetrov, D. P., and Wilson, A. G. A simple baseline for Bayesian uncertainty in deep learning. In Advances in Neural Information Processing Systems, 2019.

Masegosa, A. R. Learning under model misspecification: Applications to variational and ensemble methods, 2019.

McAllester, D. A. PAC-Bayesian model averaging. In Proceedings of the Twelfth Annual Conference on Computational Learning Theory, pp. 164–170, 1999.

Minka, T. Expectation propagation for approximate Bayesian inference. In Uncertainty in Artificial Intelligence, volume 17, pp. 362–369, 2001a.

Minka, T. P. Bayesian model averaging is not model combination, 2000.

Minka, T. P. Automatic choice of dimensionality for PCA. In Advances in Neural Information Processing Systems, pp. 598–604, 2001b.

Mohri, M. and Rostamizadeh, A. Rademacher complexity bounds for non-i.i.d. processes. In Advances in Neural Information Processing Systems, pp. 1097–1104, 2009.

Nakkiran, P., Kaplun, G., Bansal, Y., Yang, T., Barak, B., and Sutskever, I. Deep double descent: Where bigger models and more data hurt. arXiv preprint arXiv:1912.02292, 2019.

Nakkiran, P., Venkat, P., Kakade, S., and Ma, T. Optimal regularization can mitigate double descent. arXiv preprint arXiv:2003.01897, 2020.

Nalisnick, E. On priors for Bayesian neural networks. PhD thesis, University of California, Irvine, 2018.

Neal, R. Bayesian Learning for Neural Networks. Springer Verlag, 1996. ISBN 0387947248.

Neyshabur, B., Bhojanapalli, S., McAllester, D., and Srebro, N. Exploring generalization in deep learning. In Advances in Neural Information Processing Systems, pp. 5947–5956, 2017.

Neyshabur, B., Bhojanapalli, S., and Srebro, N. A PAC-Bayesian approach to spectrally-normalized margin bounds for neural networks. In International Conference on Learning Representations, 2018. URL https://openreview.net/forum?id=Skz_WfbCZ

Opper, M., Kinzel, W., Kleinz, J., and Nehl, R. On the ability of the optimal perceptron to generalise. Journal of Physics A: Mathematical and General, 23(11):L581, 1990.


Ovadia, Y., Fertig, E., Ren, J., Nado, Z., Sculley, D., Nowozin, S., Dillon, J. V., Lakshminarayanan, B., and Snoek, J. Can you trust your model's uncertainty? Evaluating predictive uncertainty under dataset shift. arXiv preprint arXiv:1906.02530, 2019.

Pearce, T., Zaki, M., Brintrup, A., Anastassacos, N., and Neely, A. Uncertainty in neural networks: Bayesian ensembling. arXiv preprint arXiv:1810.05546, 2018.

Rasmussen, C. E. and Ghahramani, Z. Occam's razor. In Neural Information Processing Systems (NIPS), 2001.

Rasmussen, C. E. and Williams, C. K. I. Gaussian Processes for Machine Learning. The MIT Press, 2006.

Ritter, H., Botev, A., and Barber, D. A scalable Laplace approximation for neural networks. In International Conference on Learning Representations (ICLR), 2018.

Rue, H., Martino, S., and Chopin, N. Approximate Bayesian inference for latent Gaussian models by using integrated nested Laplace approximations. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 71(2):319–392, 2009.

Simonyan, K. and Zisserman, A. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.

Smith, S. L. and Le, Q. V. A Bayesian perspective on generalization and stochastic gradient descent. In International Conference on Learning Representations, 2018.

Sun, S., Zhang, G., Shi, J., and Grosse, R. Functional variational Bayesian neural networks. arXiv preprint arXiv:1903.05779, 2019.

Ulyanov, D., Vedaldi, A., and Lempitsky, V. Deep image prior. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 9446–9454, 2018.

Vapnik, V. N. Adaptive and learning systems for signal processing, communications, and control. Statistical Learning Theory, 1998.

Virtanen, P., Gommers, R., Oliphant, T. E., Haberland, M., Reddy, T., Cournapeau, D., Burovski, E., Peterson, P., Weckesser, W., Bright, J., et al. SciPy 1.0: fundamental algorithms for scientific computing in Python. Nature Methods, pp. 1–12, 2020.

Walker, S. and Hjort, N. L. On Bayesian consistency. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 63(4):811–821, 2001.

Welling, M. and Teh, Y. W. Bayesian learning via stochastic gradient Langevin dynamics. In Proceedings of the 28th International Conference on Machine Learning (ICML-11), pp. 681–688, 2011.

Wenzel, F., Roth, K., Veeling, B. S., Swiatkowski, J., Tran, L., Mandt, S., Snoek, J., Salimans, T., Jenatton, R., and Nowozin, S. How good is the Bayes posterior in deep neural networks really? arXiv preprint arXiv:2002.02405, 2020.

Wilson, A. G. The case for Bayesian deep learning. arXiv preprint arXiv:2001.10995, 2020.

Xu, X., Lu, P., MacEachern, S., and Xu, R. Calibrated Bayes factors for model comparison. Journal of Statistical Computation and Simulation, 89(4):591–614, 2019.

Yang, W., Lorch, L., Graule, M. A., Srinivasan, S., Suresh, A., Yao, J., Pradier, M. F., and Doshi-Velez, F. Output-constrained Bayesian neural networks. arXiv preprint arXiv:1905.06287, 2019.

Zhang, C., Bengio, S., Hardt, M., Recht, B., and Vinyals, O. Understanding deep learning requires rethinking generalization. arXiv preprint arXiv:1611.03530, 2016.

Zhang, T. Information-theoretic upper and lower bounds for statistical estimation. IEEE Transactions on Information Theory, 52(4):1307–1321, 2006.


Appendix Outline
This appendix is organized as follows. In Section A, we visualize predictive functions corresponding to weight samples within high posterior density valleys on a regression problem. In Section B, we provide background material on Gaussian processes. In Section C, we present further results comparing MultiSWAG and MultiSWA to Deep Ensembles under data distribution shift on CIFAR-10. In Section D, we provide the details of all experiments presented in the paper. In Section E, we present analytic results on the dependence of the prior distribution in function space on the variance of the prior over parameters. In Section F, we study the prior correlations between BNN logits on perturbed images.

A. Loss Valleys
We demonstrate that different points along the valleys of high posterior density (low loss), connecting pairs of independently trained optima (Garipov et al., 2018; Draxler et al., 2018; Fort & Jastrzebski, 2019), correspond to different predictive functions. We use the regression example from Izmailov et al. (2019) and show the results in Figure 11.

B. Gaussian processes
With a Bayesian neural network, a distribution over parameters p(w) induces a distribution over functions p(f(x;w)) when combined with the functional form of the network. Gaussian processes (GPs) are often used to instead directly specify a distribution over functions.

A Gaussian process is a distribution over functions f(x) ∼ GP(m, k), such that any collection of function values, queried at any finite set of inputs x_1, ..., x_n, has a joint Gaussian distribution:

f(x_1), ..., f(x_n) ∼ N(μ, K).   (6)

The mean vector, μ_i = E[f(x_i)] = m(x_i), and covariance matrix, K_ij = cov(f(x_i), f(x_j)) = k(x_i, x_j), are determined by the mean function m and covariance function (or kernel) k of the Gaussian process.

The popular RBF kernel has the form

k(x_i, x_j) = exp( −(1/(2ℓ²)) ‖x_i − x_j‖² ).   (7)

The length-scale hyperparameter ℓ controls the extent of correlations between function values. If ℓ is large, sample functions from a GP prior are simple and slowly varying with inputs x.
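A minimal sketch of drawing sample functions from a zero-mean GP prior with the RBF kernel of Equation (7), to see the effect of the lengthscale ℓ; this is our own illustrative code, not tied to the paper's experiments:

```python
# Sketch: samples from a zero-mean GP prior with an RBF kernel (Equation (7)).
import numpy as np

def rbf_kernel(X1, X2, lengthscale=1.0):
    sq_dists = (X1[:, None] - X2[None, :]) ** 2
    return np.exp(-0.5 * sq_dists / lengthscale ** 2)

rng = np.random.default_rng(0)
x = np.linspace(-3, 3, 200)
for lengthscale in (0.3, 1.0, 3.0):
    K = rbf_kernel(x, x, lengthscale) + 1e-8 * np.eye(len(x))  # jitter for stability
    f = rng.multivariate_normal(np.zeros(len(x)), K, size=2)   # two prior samples
    print(lengthscale, f.shape)  # larger lengthscales give slowly varying samples
```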

Gaussian processes with RBF kernels (as well as many other standard kernels) assign positive density to any set of observations. Moreover, these models are universal approximators (Rasmussen & Williams, 2006): as the number of observations increases, they are able to approximate any function to arbitrary precision.

Work on Gaussian processes in machine learning was triggered by the observation that Bayesian neural networks become Gaussian processes with particular kernel functions as the number of hidden units approaches infinity (Neal, 1996). This result resembles recent work on the neural tangent kernel (e.g., Jacot et al., 2018).

C. Deep Ensembles and MultiSWAG Under Distribution Shift

In Figures 14, 15, 16, 17, we show the negative log likelihood for Deep Ensembles, MultiSWA, and MultiSWAG using PreResNet-20 on CIFAR-10 with various corruptions, as a function of the number of independently trained models (SGD solutions, SWA solutions, or SWAG models, respectively). For MultiSWAG, we generate 20 samples from each independent SWAG model. Typically, MultiSWA and MultiSWAG significantly outperform Deep Ensembles when a small number of independent models is used, or when the level of corruption is high.

In Figure 18, following Ovadia et al. (2019), we show the distribution of negative log likelihood, accuracy, and expected calibration error as we vary the type of corruption. We use a fixed training time budget: 10 independently trained models for every method. For MultiSWAG we ensemble 20 samples from each of the 10 SWAG approximations. MultiSWAG particularly achieves better NLL than the other two methods, and MultiSWA outperforms Deep Ensembles; the difference is especially pronounced for higher levels of corruption. In terms of ECE, MultiSWAG again outperforms the other two methods for higher corruption intensities.

We note that Ovadia et al. (2019) found Deep Ensembles to be a very strong baseline for prediction quality and calibration under distribution shift. For this reason, we focus on Deep Ensembles in our comparisons.

D. Details of Experiments
In this section we provide additional details of the experiments presented in the paper.

D.1. Approximating the True Predictive Distribution

For the results presented in Figure 4, we used a network with 3 hidden layers of size 10 each. The network takes two inputs: x and x². We pass both x and x² as input to ensure


Figure 11. Diversity of high performing functions. Bottom: a contour plot of the posterior log-density in the subspace containing a pair of independently trained modes (as with deep ensembles), and a path of high posterior density connecting these modes. In each panel, the purple point represents a sample from the posterior in the parameter subspace. Top: the predictive distribution constructed from samples in the subspace. The shaded blue area shows the 3σ-region of the predictive distribution at each of the input locations, and the blue line shows the mean of the predictive distribution. In each panel, the purple line shows the predictive function corresponding to the sample shown in the corresponding bottom row panel. For the details of the experimental setup, see Section 5.1 of Izmailov et al. (2019).

that the network can represent a broader class of functions. The network outputs a single number, y = f(x).

To generate data for the plots, we used a randomly-initialized neural network of the same architecture described above. We sampled the weights from an isotropic Gaussian with variance 0.1², and added isotropic Gaussian noise with variance 0.1² to the outputs:

y = f(x; w) + ε(x),

with w ∼ N(0, 0.1²·I), ε(x) ∼ N(0, 0.1²·I). The training set consists of 120 points, shown in Figure 4.

For estimating the ground truth, we ran 200 chains of Hamiltonian Monte Carlo (HMC) using the hamiltorch package (Cobb et al., 2019). We initialized each chain with a network pre-trained with SGD for 3000 steps, then ran Hamiltonian Monte Carlo (HMC) for 2000 steps, producing 200 samples.

For Deep Ensembles, we independently trained 50 networks with SGD for 20000 steps each. We used minus posterior log-density as the training loss. For SVI, we used a fully-factorized Gaussian approximation initialized at an SGD solution trained for 20000 steps. For all inference methods, we set the prior variance to 10² and the noise variance to 0.02².

Discrepancy with true BMA. For the results presented in panel (d) of Figure 4, we computed the Wasserstein distance between the predictive distribution approximated with HMC and the predictive distributions for Deep Ensembles and SVI. We used the one-dimensional Wasserstein distance function¹ from the scipy package (Virtanen et al., 2020). We computed the Wasserstein distance between marginal distributions at each input location, and averaged the results over the input locations. In the top sub-panels of panels (b), (c) of Figure 4, we additionally visualize the marginal Wasserstein distance between the HMC predictive distribution and the Deep Ensembles and SVI predictive distributions, respectively, for each input location.

D.2. Deep Ensembles and MultiSWAG

We evaluate Deep Ensembles, MultiSWA, and MultiSWAG under distribution shift in Section 4. Following Ovadia et al. (2019), we use a PreResNet-20 network and the CIFAR-10 dataset with different types of corruptions introduced in Hendrycks & Dietterich (2019). For training individual SGD, SWA, and SWAG models, we use the hyper-parameters used for PreResNet-164 in Maddox et al. (2019). For each SWAG model we sample 20 networks and ensemble them. So, Deep Ensembles, MultiSWA, and MultiSWAG are all evaluated under the same training budget. Deep Ensembles and MultiSWA also use the same test-time budget.

For producing the corrupted data, we used the code² released by Hendrycks & Dietterich (2019). We had issues producing the data for the frost corruption type, so we omit it in our evaluation, and include Gaussian blur, which was not included in the evaluation of Hendrycks & Dietterich (2019).

¹https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wasserstein_distance.html
²https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_cifar_c.py


Figure 12. (a)–(c) Average pairwise prior correlations for pairs of objects in classes {0, 1, 2, 4, 7} of MNIST, induced by LeNet-5 for p(f(x;w)) when p(w) = N(0, α²I), for α = 0.02, 0.1, 1. Images in the same class have higher prior correlations than images from different classes, suggesting that p(f(x;w)) has desirable inductive biases. The correlations slightly decrease with increases in α. Panels (e)–(g) show sample functions from LeNet-5 along the direction connecting a pair of MNIST images of 0 and 1 digits, for α = 0.02, 0.1, 1. The complexity of the samples increases with α. (d) NLL and (h) classification error of an ensemble of 20 SWAG samples on MNIST as a function of α, using a LeNet-5. The NLL is high for overly small α and near-optimal for larger values, with an optimum near α = 0.3.

D.3. Neural Network Priors

In the main text we considered different properties of the prior distribution over functions induced by a spherical Gaussian distribution over the weights, with different variance scales.

Prior correlation diagrams. In panels (a)–(c) of Figure 12, we show pairwise correlations of the logits for different pairs of datapoints. To make these plots, we produce S = 100 samples of the weights w_i of a LeNet-5 from the prior distribution N(0, α²I), and compute the logits corresponding to class 0 for each data point and each weight sample. We then compute the correlations for each pair x, x′ of data points as follows:

corr_logit(x, x′) = Σ_{i=1}^S (f(x, w_i) − f̄(x)) (f(x′, w_i) − f̄(x′)) / sqrt( Σ_{i=1}^S (f(x, w_i) − f̄(x))² · Σ_{i=1}^S (f(x′, w_i) − f̄(x′))² ),

where f(x, w) is the logit corresponding to class 0 of the network with weights w on the input x, and f̄(x) is the mean value of the logit, f̄(x) = (1/S) Σ_i f(x, w_i). For evaluation, we use 200 random datapoints per class for classes 0, 1, 2, 4, 7 (a total of 1000 datapoints). We use this set of classes to ensure that the structure is clearly visible in the figure. We combine the correlations into a diagram, additionally showing the average correlation for each pair of classes. We repeat the experiment for different values of α ∈ {0.02, 0.1, 1}. For a discussion of the results, see Section 5.2.
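The correlation computation above can be sketched as follows, given a matrix of class-0 logits with one row per prior weight sample and one column per datapoint; the random `logits` array is a stand-in for actual LeNet-5 outputs:

```python
# Sketch: Monte Carlo estimate of prior correlations between logits on pairs of
# inputs, following the corr_logit formula above. `logits` has shape (S, N):
# S prior weight samples, N datapoints.
import numpy as np

rng = np.random.default_rng(0)
S, N = 100, 6
logits = rng.standard_normal((S, N))      # stand-in for class-0 logits of LeNet-5

centered = logits - logits.mean(axis=0, keepdims=True)
cov = centered.T @ centered               # (N, N) unnormalized covariance over samples
std = np.sqrt(np.diag(cov))
corr = cov / np.outer(std, std)           # corr[i, j] = corr_logit(x_i, x_j)
print(np.round(corr, 2))
```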

Sample functions. In panels (e)–(g) of Figure 12, we visualize the functions sampled from the LeNet-5 network along the direction connecting a pair of MNIST images. In particular, we take a pair of images x_0 and x_1 of digits 0 and 1 respectively, and construct the path x(t) = t·x_0 + (1 − t)·x_1. We then study the samples of the logits z(t) = f(x(t)·‖x_0‖/‖x(t)‖, w) along the path; here we adjusted the norm of the images along the path to be constant, as the values of the logits are sensitive to the norm of the inputs. The complexity of the samples increases as we increase the variance of the prior distribution over the weights. This increased complexity of sample functions explains why we might expect the prior correlations for pairs of images to be lower when we increase the variance of the prior distribution over the weights.

Performance dependence on prior variance. In panels (d), (h) of Figure 12, we show the test negative log-likelihood and accuracy of SWAG applied to LeNet-5 on MNIST. We train the model for 50 epochs, constructing the rank-20 SWAG approximation from the last 25 epochs. We use an initial learning rate of 0.05 and a SWAG learning rate of 0.01, with the learning rate schedule of Maddox et al. (2019). We use posterior log-density as the objective, and vary the prior variance α². In panels (f), (g) of Figure 9


we perform an analogous experiment using a PreResNet-20 and a VGG-16 on CIFAR-10, using the hyper-parameters reported in Maddox et al. (2019) (for PreResNet-20, we use the hyper-parameters used with PreResNet-164 in Maddox et al. (2019)). Both on MNIST and CIFAR-10, we observe that the performance is poor for overly small values of α, close to optimal for intermediate values, and still reasonable for larger values of α. For further discussion of the results, see Section 5.3.
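For reference, a heavily simplified sketch of the diagonal part of the SWAG approximation (Maddox et al., 2019): running first and second moments of the SGD iterates are collected over the last epochs and used as a Gaussian approximate posterior. The full method also maintains a low-rank covariance component, and `sgd_step` below is a hypothetical placeholder for one epoch of training on the real objective.

```python
# Sketch: collecting a diagonal SWAG approximation from SGD iterates.
import numpy as np

rng = np.random.default_rng(0)

def sgd_step(w):
    # Hypothetical placeholder for one epoch of SGD on the real loss.
    return w - 0.05 * (w - 1.0) + 0.01 * rng.standard_normal(w.shape)

d, n_epochs, collect_from = 10, 50, 25
w = rng.standard_normal(d)
mean, sq_mean, n = np.zeros(d), np.zeros(d), 0

for epoch in range(n_epochs):
    w = sgd_step(w)
    if epoch >= collect_from:                   # collect moments over the last epochs
        n += 1
        mean += (w - mean) / n                  # running mean of iterates
        sq_mean += (w ** 2 - sq_mean) / n       # running mean of squared iterates

var = np.clip(sq_mean - mean ** 2, 1e-8, None)  # diagonal posterior variance
sample = mean + np.sqrt(var) * rng.standard_normal(d)  # one approximate posterior sample
print(mean[:3], var[:3], sample[:3])
```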

Predictions from prior samples. Following Wenzel et al. (2020), we study the predictive distributions of prior samples using PreResNet-20 on CIFAR-10. In Figure 9, we show the sample predictive functions averaged over datapoints, for different scales α of the prior distribution. We also show the predictive distribution for each α, which is the average of the sample predictive distributions over 200 samples of weights. In Figure 10, we show how the predictive distribution changes as we vary the number of observed data, for prior scale α = √10. We see that the marginal predictive distribution for all considered values of α is reasonable: roughly uniform across classes, when averaged across the dataset. For the latter experiment, we used stochastic gradient Langevin dynamics (SGLD) (Welling & Teh, 2011) with a cosine learning rate schedule. For each sample we restart SGLD, and we only use the sample obtained at the last iteration. We discuss the results in Section 8.4.

Prior correlations with corrupted images. In Section F and Figure 19, we study the decay of the prior correlations between logits on an original image and a perturbed image, as we increase the intensity of perturbations. For the BNN, we use the PreResNet-20 architecture with the standard Gaussian prior N(0, I). For the linear model, the correlations are not affected by the prior variance α²:

cov(wᵀx, wᵀy) = E[(wᵀx)·(wᵀy)] = E[xᵀw wᵀy] = xᵀ E[w wᵀ] y = α² xᵀy,

and hence

corr(wᵀx, wᵀy) = cov(wᵀx, wᵀy) / sqrt( cov(wᵀx, wᵀx) · cov(wᵀy, wᵀy) ) = xᵀy / (‖x‖·‖y‖).

We use the N(0, I) prior for the weights of the linear model. Finally, we also evaluate the correlations associated with an RBF kernel (see Equation (7)). To set the lengthscale ℓ of the kernel, we evaluate the pairwise correlations for the PreResNet-20 and RBF kernel on the 100 uncorrupted CIFAR-10 images that were used for the experiment, and ensure that the average correlations match. The resulting value of ℓ is 10000, and the average correlation for the RBF kernel and PreResNet was ≈ 0.9; for the linear model the

average correlation was ≈ 0.82. For the perturbations, we used the same set of corruptions introduced in Hendrycks & Dietterich (2019) as in the experiments in Section 4, with the addition of a random translation: for a random translation of intensity i, we pad the image with 2·i zeros on each side and crop the image randomly to 32×32.

D.4. Rethinking Generalization

In Section 6, we experiment with Bayesian neural networks and Gaussian processes on CIFAR-10 with noisy labels, inspired by the results in Zhang et al. (2016) that suggest we need to re-think generalization to understand deep learning.

Following Zhang et al. (2016), we train PreResNet-20 on CIFAR-10 with different fractions of random labels. To ensure that the networks fit the train data, we turn off weight decay and data augmentation, and use a lower initial learning rate of 0.01. Otherwise, we follow the hyper-parameters that were used with PreResNet-164 in Maddox et al. (2019). We use a diagonal Laplace approximation to compute an estimate of the marginal likelihood for each level of label corruption. Following Ritter et al. (2018), we use the diagonal of the Fisher information matrix rather than the Hessian.
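A minimal sketch of how a diagonal Laplace estimate of the log marginal likelihood can be assembled from a MAP log-posterior value and per-parameter curvature estimates (here a diagonal Fisher plus the prior precision); the numbers below are random stand-ins rather than quantities computed from the trained networks.

```python
# Sketch: diagonal Laplace approximation to the log marginal likelihood,
#   log p(D) ~= log p(D|w_hat) + log p(w_hat) + (d/2) log(2*pi)
#               - 0.5 * sum_i log(lambda_i),
# with lambda_i the diagonal curvature of the negative log posterior at w_hat
# (approximated here by the diagonal Fisher plus the prior precision).
import numpy as np

rng = np.random.default_rng(0)
d = 1000
log_lik_at_map = -2500.0                                   # stand-in: log p(D | w_hat)
log_prior_at_map = -0.5 * d * np.log(2 * np.pi) - 300.0    # stand-in: log p(w_hat)
fisher_diag = rng.gamma(2.0, 1.0, size=d)                  # stand-in diagonal Fisher
prior_precision = 1.0                                      # 1/alpha^2 for p(w) = N(0, alpha^2 I)

curvature = fisher_diag + prior_precision
log_marglik = (log_lik_at_map + log_prior_at_map
               + 0.5 * d * np.log(2 * np.pi)
               - 0.5 * np.sum(np.log(curvature)))
print(log_marglik)
```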

We perform a similar experiment with a Gaussian process with an RBF kernel on the binary classification problem for two classes of CIFAR-10. We use variational inference to fit the model, and we use the variational evidence lower bound to approximate the marginal likelihood. We use variational inference to overcome the non-Gaussian likelihood, and not for scalability reasons, i.e. we are not using inducing inputs. We use the GPyTorch package (Gardner et al., 2018) to train the models. We use an RBF kernel with the default initialization from GPyTorch, and divide the inputs by 5000 to get an appropriate input scale. We train the model on a binary classification problem between classes 0 and 1.

For the 10-class GP classification experiment, we train 10 one-vs-all models that classify between a given class and the rest of the data. To reduce computation in training, we subsample the data not belonging to the given class to 10k datapoints, so each model is trained on a total of 15k datapoints. We then combine the 10 models into a single multi-class model: an observation is attributed to the class that corresponds to the one-vs-all model with the highest confidence. We use the same hyper-parameters as in the binary classification experiments.

D.5. Double Descent

In Section 7, we evaluate SGD, SWAG, and MultiSWAG for models of varying width. Following Nakkiran et al. (2019), we use ResNet-18 on CIFAR-100; we consider the original labels, 10% label corruption, and 20% label corruption. For networks of every width, we reuse the hyper-parameters used for PreResNet-


Figure 13. Double descent. (a) Test error and (b) NLL loss for ResNet-18 with varying width on CIFAR-100 for SGD, SWAG, and MultiSWAG, when 10% of the labels are randomly reshuffled. MultiSWAG alleviates double descent both on the original labels and under label noise, both in accuracy and NLL. (c) Test NLLs for MultiSWAG with a varying number of independent models under 20% label corruption; NLL monotonically decreases with an increased number of independent models, alleviating double descent.

For the original labels and 10% label corruption we use 5 independently trained SWAG models with MultiSWAG, and for 20% label corruption we use 10 models; for 20% label corruption we also show performance varying the number of independent models in Figures 8(e) and 13(c). Both for SWAG and MultiSWAG we use an ensemble of 20 sampled models from each of the SWAG solutions; for example, for MultiSWAG with 10 independent SWAG solutions we use an ensemble of 200 networks.
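The resulting ensemble simply averages the predictive distributions of all sampled networks; a minimal sketch, where the `sample_model()` method on a SWAG solution is a hypothetical stand-in for drawing weights from the fitted Gaussian and updating batch-norm statistics:

```python
import torch

def multiswag_predict(swag_solutions, x, samples_per_solution=20):
    """Average softmax predictions over weight samples drawn from each
    independently trained SWAG posterior approximation."""
    probs = []
    with torch.no_grad():
        for swag in swag_solutions:
            for _ in range(samples_per_solution):
                model = swag.sample_model()  # hypothetical sampling interface
                model.eval()
                probs.append(torch.softmax(model(x), dim=-1))
    return torch.stack(probs).mean(dim=0)
```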

E. Analysis of Prior Variance Effect

In this section we provide simple analytic results for the effect of the prior variance in ReLU networks. A related derivation is presented in Appendix Section A.8 of Garipov et al. (2018) about connecting paths from symmetries in parametrization.

We will consider a multilayer network f(x; w) of the form

f(x; {W_i, b_i}_{i=1}^n) = W_n(··· φ(W_2 φ(W_1 x + b_1) + b_2) ···) + b_n,

where φ is the ReLU (or in fact any positively-homogeneous activation function), W_i are weight matrices and b_i are bias vectors. In particular, f can be a regular CNN with ReLU activations up to the logits (with the softmax activation removed).

Now suppose we have a prior distribution of the form

W_i ∼ N(0, α_i² I),   b_i ∼ N(0, β_i² I),

where the identity matrices I are implicitly assumed to be of appropriate shapes, so that each weight matrix and bias vector has a spherical Gaussian distribution. We can reparameterize this distribution as

W_i = α_i E_i,  E_i ∼ N(0, I),   b_i = β_i ε_i,  ε_i ∼ N(0, I).

We can then express the predictions of the network on the input x, for weights sampled from the prior, as the random variable

f(x; {α_i, β_i}_{i=1}^n) = α_n · E_n(··· φ(α_1 E_1 x + β_1 ε_1) ···) + β_n ε_n.   (8)

Through Equation (8) we can observe some simple properties of the dependence between the prior scales α_i, β_i and the induced function-space prior.

Proposition 1. Suppose the network has no bias vectors, i.e. β_1 = ... = β_n = 0. Then the scales α_i of the prior distribution over the weights only affect the output scale of the network.

Proof. In the case when there are no bias vectors, Equation (8) simplifies to

f(x; {α_i, β_i = 0}_{i=1}^n) = α_n · E_n(··· φ(α_1 E_1 x) ···) = α_n ··· α_1 · E_n(··· φ(E_1 x) ···) = α_n ··· α_1 · f(x; {α_i = 1, β_i = 0}_{i=1}^n).

In the derivation above we used the positive homogeneity of ReLU: φ(αz) = αφ(z) for any positive α.

In other words, to sample from the distribution over functions corresponding to a prior with scales {α_i, β_i = 0}_{i=1}^n, we can sample from the spherical Gaussian prior (without bias terms) {α_i = 1, β_i = 0}_{i=1}^n and then rescale the outputs of the network by the product of the scales α_n ··· α_2 · α_1.
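Proposition 1 is easy to verify numerically; a minimal PyTorch sketch with an arbitrary small bias-free ReLU network (the architecture and scales below are illustrative assumptions):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Numerical check of Proposition 1: for a bias-free ReLU network, scaling the
# i-th weight matrix by alpha_i only rescales the outputs by the product of the alphas.
net = nn.Sequential(
    nn.Linear(10, 50, bias=False), nn.ReLU(),
    nn.Linear(50, 50, bias=False), nn.ReLU(),
    nn.Linear(50, 3, bias=False),
)
alphas = [0.3, 2.0, 5.0]  # per-layer prior scales
x = torch.randn(4, 10)

with torch.no_grad():
    base = net(x)
    for layer, alpha in zip((m for m in net if isinstance(m, nn.Linear)), alphas):
        layer.weight.mul_(alpha)
    scaled = net(x)

print(torch.allclose(scaled, base * 0.3 * 2.0 * 5.0, atol=1e-5))  # True
```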

We note that the result above is different from the results for sigmoid networks considered in MacKay (1995), where varying the prior on the weights leads to changing the lengthscale of the sample functions. For ReLU networks without biases, increasing the prior variance only increases the output scale of the network, and not the complexity of the samples. If we apply the softmax activation on the outputs of the last layer of such a network, we will observe increasingly confident predictions as we increase the prior variance. We observe this effect in Figure 9 and discuss it in Section 8.4.

In case bias vectors are present, we can obtain a similar result using a specific scaling of the prior variances with the layer, as in the following proposition.

Proposition 2. Suppose the prior scales depend on the layer of the network as follows, for some γ > 0:

α_i = γ,   β_i = γ^i,

for all layers i = 1, ..., n. Then γ only affects the scale of the predictive distribution at any input x:

f(x; {α_i = γ, β_i = γ^i}_{i=1}^n) = γ^n · f(x; {α_i = 1, β_i = 1}_{i=1}^n).

Proof. The proof is analogous to the proof of Proposition 1: we can use the positive homogeneity of ReLU activations to factor the prior scales outside of the network:

f(x; {α_i = γ, β_i = γ^i}_{i=1}^n) = γ · E_n(··· φ(γ · E_1 x + γ · ε_1) ···) + γ^n · ε_n = γ^n · (E_n(··· φ(E_1 x + ε_1) ···) + ε_n) = γ^n · f(x; {α_i = 1, β_i = 1}_{i=1}^n).

The analysis above can be applied to other simple scaling rules of the prior; e.g.

f(x; {α_i = γ·α′_i, β_i = γ^i·β′_i}_{i=1}^n) = γ^n · f(x; {α_i = α′_i, β_i = β′_i}_{i=1}^n)   (9)

can be shown completely analogously to Proposition 2.

More general types of scaling of the prior affect both the output scale of the network and the relative effect of the weight and bias terms in the prior. For example, by Equation (9) we have

f(x; {α_i = γ, β_i = γ}_{i=1}^n) = f(x; {α_i = γ·1, β_i = γ^i·γ^{1−i}}_{i=1}^n) = γ^n · f(x; {α_i = 1, β_i = γ^{1−i}}_{i=1}^n).

We note that the analysis does not cover residual connections and batch normalization, so it applies to LeNet-5, but cannot be directly applied to the PreResNet-20 networks used in many of our experiments.

F. Prior Correlation Structure under Perturbations

In this section we explore the prior correlations between the logits on different pairs of datapoints induced by a spherical Gaussian prior on the weights of a PreResNet-20. We sample 100 random images from CIFAR-10 (10 from each class) and apply 17 different perturbations introduced by Hendrycks & Dietterich (2019) at 5 different levels of intensity. We then compute correlations between the logits f(x, w) for the original image x and f(x̃, w) for the corrupted image x̃, as we sample the weights of the network from the prior w ∼ N(0, I).
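A minimal sketch of this computation, where `model_factory()` is a hypothetical constructor for the network (e.g. a PreResNet-20) and `x`, `x_corrupted` are a clean image and its corrupted version:

```python
import torch

def prior_logit_correlation(model_factory, x, x_corrupted, num_samples=100):
    """Correlate the logits of an image and its corrupted version under prior
    weight samples w ~ N(0, I), by resampling the network parameters each time."""
    clean, corrupted = [], []
    with torch.no_grad():
        for _ in range(num_samples):
            model = model_factory().eval()
            for p in model.parameters():
                p.normal_(0.0, 1.0)  # spherical Gaussian prior over the weights
            clean.append(model(x[None]).flatten())
            corrupted.append(model(x_corrupted[None]).flatten())
    stacked = torch.stack([torch.cat(clean), torch.cat(corrupted)])
    return torch.corrcoef(stacked)[0, 1]
```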

In Figure 19 we show how the correlations decay with perturbation intensity. For reference, we also show how the correlations decay for a linear model and for an RBF kernel. For the RBF kernel, we set the lengthscale so that the average correlations on the uncorrupted datapoints match those of a PreResNet-20. Further experimental details can be found in Appendix D.3.

For all types of corruptions except saturate, snow, fog and brightness, the PreResNet logit correlations decay more slowly than those of the RBF kernel and the linear model. It appears that the prior samples are sensitive to corruptions that alter the brightness, or more generally the colours, in the image. For many types of corruptions (e.g. Gaussian noise) the prior correlations for the PreResNet are close to 1 for all levels of corruption.

Overall, these results indicate that the prior over functions induced by a vague prior over parameters w in combination with a PreResNet has useful equivariance properties: before seeing data, the model treats images of the same class as highly correlated, even after an image has undergone significant perturbations representative of perturbations we often see in the real world. These types of symmetries are a large part of what makes neural networks a powerful model class for high-dimensional natural signals.

[Figure 14 panels, NLL as a function of the number of independently trained models (1–10) for Deep Ensembles, MultiSWA and MultiSWAG: (a) Gaussian Noise, (b) Impulse Noise, (c) Shot Noise.]

Figure 14. Noise Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).

[Figure 15 panels, NLL as a function of the number of independently trained models (1–10) for Deep Ensembles, MultiSWA and MultiSWAG: (a) Defocus Blur, (b) Glass Blur, (c) Motion Blur, (d) Zoom Blur, (e) Gaussian Blur.]

Figure 15. Blur Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).

[Figure 16 panels, NLL as a function of the number of independently trained models (1–10) for Deep Ensembles, MultiSWA and MultiSWAG: (a) Contrast, (b) Saturate, (c) Elastic Transform, (d) Pixelate, (e) JPEG Compression.]

Figure 16. Digital Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).

[Figure 17 panels, NLL as a function of the number of independently trained models (1–10) for Deep Ensembles, MultiSWA and MultiSWAG: (a) Snow, (b) Fog, (c) Brightness.]

Figure 17. Weather Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).

[Figure 18 panels: NLL, accuracy, and ECE as a function of corruption intensity (1–5) for Deep Ensembles, MultiSWA and MultiSWAG.]

Figure 18. Negative log likelihood, accuracy, and expected calibration error distribution on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the corruption intensity. Following Ovadia et al. (2019), we summarize the results for different types of corruption with a boxplot. For each method we use 10 independently trained models, and for MultiSWAG we sample 20 networks from each model. As in Figures 5, 11–14, there are substantial differences between these three methods, which are hard to see due to the vertical scale on this plot; MultiSWAG particularly outperforms Deep Ensembles and MultiSWA in terms of NLL and ECE for higher corruption intensities.

[Figure 19 panels: prior correlation as a function of corruption intensity (0–5) for PreResNet-20, an RBF kernel, and a linear model, for each corruption type: Translation, Gaussian Noise, Impulse Noise, Shot Noise, Defocus Blur, Glass Blur, Motion Blur, Zoom Blur, Gaussian Blur, Contrast, Saturate, Elastic, Pixelate, JPEG, Snow, Fog, Brightness.]

Figure 19. Prior correlations under corruption. Prior correlations between predictions (logits) for PreResNet-20, a linear model, and an RBF kernel on original and corrupted images, as a function of corruption intensity, for different types of corruptions. The lengthscale of the RBF kernel is calibrated to produce correlations on uncorrupted datapoints similar to those of the PreResNet. We report the mean correlation values over 100 different images and show the 1σ error bars with shaded regions. For all corruptions except snow, saturate, fog, and brightness, the correlations decay more slowly for the PreResNet than for the baselines.

Page 2: Bayesian Deep Learning and a Probabilistic Perspective of ... · Andrew Gordon Wilson Pavel Izmailov New York University Abstract The key distinguishing property of a Bayesian ap-proach

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

p(D|M)

CorruptedCIFAR-10

CIFAR-10 MNIST DatasetStructured Image Datasets

Complex ModelPoor Inductive BiasesExample MLP

Simple ModelPoor Inductive BiasesExample Linear Function

Well-Specified ModelCalibrated Inductive BiasesExample CNN

(a)

True

Model

Prior Hypothesis Space

Posterior

(b)

True Model

Prior Hypothesis Space

Posterior

(c)

True Model

Prior Hypothesis Space

Posterior

(d)

Figure 2 A probabilistic perspective of generalization (a) Ideally a model supports a wide range of datasets but with inductive biasesthat provide high prior probability to a particular class of problems being considered Here the CNN is preferred over the linear modeland the fully-connected MLP for CIFAR-10 (while we do not consider MLP models to in general have poor inductive biases here we areconsidering a hypothetical example involving images and a very large MLP) (b) By representing a large hypothesis space a model cancontract around a true solution which in the real-world is often very sophisticated (c) With truncated support a model will converge to anerroneous solution (d) Even if the hypothesis space contains the truth a model will not efficiently contract unless it also has reasonableinductive biases

From this perspective we want the support of the model tobe large so that we can represent any hypothesis we believeto be possible even if it is unlikely We would even wantthe model to be able to represent pure noise such as noisyCIFAR (Zhang et al 2016) as long as we honestly believethere is some non-zero but potentially arbitrarily smallprobability that the data are simply noise Crucially wealso need the inductive biases to carefully represent whichhypotheses we believe to be a priori likely for a particularproblem class If we are modelling images then our modelshould have statistical properties such as convolutionalstructure which are good descriptions of images

Figure 2(a) illustrates three models We can imagine theblue curve as a simple linear function f(x) = w0 + w1xcombined with a distribution over parameters p(w0 w1)eg N (0 I) which induces a distribution over functionsp(f(x)) Parameters we sample from our prior p(w0 w1)give rise to functions f(x) that correspond to straight lineswith different slopes and intercepts This model thus hastruncated support it cannot even represent a quadratic func-tion But because the marginal likelihood must normal-ize over datasets D this model assigns much mass to thedatasets it does support The red curve could represent alarge fully-connected MLP This model is highly flexiblebut distributes its support across datasets too evenly to beparticularly compelling for many image datasets The greencurve could represent a convolutional neural network whichrepresents a compelling specification of support and induc-tive biases for image recognition this model has the flexibil-ity to represent many solutions but its structural propertiesprovide particularly good support for many image problems

With large support we cast a wide enough net that the poste-rior can contract around the true solution to a given problem

as in Figure 2(b) which in reality we often believe to bevery sophisticated On the other hand the simple model willhave a posterior that contracts around an erroneous solutionif it is not contained in the hypothesis space as in Figure 2(c)Moreover in Figure 2(d) the model has wide support butdoes not contract around a good solution because its supportis too evenly distributed

Returning to the opening example we can justify the highorder polynomial by wanting large support But we wouldstill have to carefully choose the prior on the coefficientsto induce a distribution over functions that would have rea-sonable inductive biases Indeed this Bayesian notion ofgeneralization is not based on a single number but is a twodimensional concept From this probabilistic perspectiveit is crucial not to conflate the flexibility of a model withthe complexity of a model class Indeed Gaussian processeswith RBF kernels have large support and are thus flexiblebut have inductive biases towards very simple solutions Wealso see that parameter counting has no significance in thisperspective of generalization what matters is how a distri-bution over parameters combines with a functional form of amodel to induce a distribution over solutions Rademachercomplexity (Mohri amp Rostamizadeh 2009) VC dimension(Vapnik 1998) and many conventional metrics are by con-trast one dimensional notions corresponding roughly to thesupport of the model which is why they have been foundto provide an incomplete picture of generalization in deeplearning (Zhang et al 2016)

In this paper we reason about Bayesian deep learning froma probabilistic perspective of generalization The key dis-tinguishing property of a Bayesian approach is marginaliza-tion instead of optimization where we represent solutionsgiven by all settings of parameters weighted by their pos-

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

terior probabilities rather than bet everything on a singlesetting of parameters Neural networks are typically under-specified by the data and can represent many different buthigh performing models corresponding to different settingsof parameters which is exactly when marginalization willmake the biggest difference for accuracy and calibrationMoreover we clarify that the recent deep ensembles (Lak-shminarayanan et al 2017) are not a competing approachto Bayesian inference but can be viewed as a compellingmechanism for Bayesian marginalization Indeed we em-pirically demonstrate that deep ensembles can provide abetter approximation to the Bayesian predictive distributionthan standard Bayesian approaches We further proposea new method MultiSWAG inspired by deep ensembleswhich marginalizes within basins of attraction mdash achievingsignificantly improved performance with a similar trainingtime

We then investigate the properties of priors over functionsinduced by priors over the weights of neural networks show-ing that they have reasonable inductive biases We also showthat the mysterious generalization properties recently pre-sented in Zhang et al (2016) can be understood by reasoningabout prior distributions over functions and are not specificto neural networks Indeed we show Gaussian processes canalso perfectly fit images with random labels yet generalizeon the noise-free problem These results are a consequenceof large support but reasonable inductive biases for com-mon problem settings We further show that while Bayesianneural networks can fit the noisy datasets the marginal like-lihood has much better support for the noise free datasetsin line with Figure 2 We additionally show that the mul-timodal marginalization in MultiSWAG alleviates doubledescent so as to achieve monotonic improvements in per-formance with model flexibility in line with our perspectiveof generalization MultiSWAG also provides significant im-provements in both accuracy and NLL over SGD trainingand unimodal marginalization Finally we provide severalperspectives on tempering in Bayesian deep learning

In the Appendix we provide several additional experimentsand results We also provide code at httpsgithubcomizmailovpavelunderstandingbdl

2 Related WorkNotable early works on Bayesian neural networks includeMacKay (1992) MacKay (1995) and Neal (1996) Theseworks generally argue in favour of making the model classfor Bayesian approaches as flexible as possible in line withBox amp Tiao (1973) Accordingly Neal (1996) pursued thelimits of large Bayesian neural networks showing that as thenumber of hidden units approached infinity these modelsbecome Gaussian processes with particular kernel functionsThis work harmonizes with recent work describing the neu-

ral tangent kernel (eg Jacot et al 2018)

The marginal likelihood is often used for Bayesian hypothe-sis testing model comparison and hyperparameter tuningwith Bayes factors used to select between models (Kass ampRaftery 1995) MacKay (2003 Ch 28) uses a diagramsimilar to Fig 2(a) to show the marginal likelihood has anOccamrsquos razor property favouring the simplest model con-sistent with a given dataset even if the prior assigns equalprobability to the various models Rasmussen amp Ghahra-mani (2001) reasons about how the marginal likelihoodcan favour large flexible models as long as such modelscorrespond to a reasonable distribution over functions

There has been much recent interest in developing Bayesianapproaches for modern deep learning with new challengesand architectures quite different from what had been con-sidered in early work Recent work has largely focused onscalable inference (eg Blundell et al 2015 Gal amp Ghahra-mani 2016 Kendall amp Gal 2017 Ritter et al 2018 Khanet al 2018 Maddox et al 2019) function-space inspiredpriors (eg Yang et al 2019 Louizos et al 2019 Sunet al 2019 Hafner et al 2018) and developing flat objec-tive priors in parameter space directly leveraging the biasesof the neural network functional form (eg Nalisnick 2018)Wilson (2020) provides a note motivating Bayesian deeplearning

Moreover PAC-Bayes (McAllester 1999) has emerged asan active direction for generalization bounds on stochas-tic neural networks (eg Guedj 2019) with distributionsover the parameters Langford amp Caruana (2002) devised aPAC-Bayes generalization bound for small stochastic neuralnetworks (two layer with two hidden units) achieving animprovement over the existing deterministic generalizationbounds Dziugaite amp Roy (2017) extended this approachoptimizing a PAC-Bayes bound with respect to a parametricdistribution over the weights of the network exploiting theflatness of solutions discovered by SGD As a result they de-rive the first non-vacuous generalization bounds for stochas-tic over-parameterized fully-connected neural networks onbinary MNIST classification Neyshabur et al (2017) alsodiscuss the connection between PAC-Bayes bounds andsharpness Neyshabur et al (2018) then devises PAC-Bayesbounds based on spectral norms of the layers and the Frobe-nius norm of the weights of the network Achille amp Soatto(2018) additionally combine PAC-Bayes and informationtheoretic approaches to argue that flat minima have lowinformation content Masegosa (2019) also proposes varia-tional and ensemble learning methods based on PAC-Bayesanalysis under model misspecification Jiang et al (2019)provide a review and comparison of several generalizationbounds including PAC-Bayes

Early works tend to provide a connection between lossgeometry and generalization using minimum description

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

length frameworks (eg Hinton amp Van Camp 1993 Hochre-iter amp Schmidhuber 1997 MacKay 1995) EmpiricallyKeskar et al (2016) argue that smaller batch SGD providesbetter generalization than large batch SGD by finding flat-ter minima Chaudhari et al (2019) and Izmailov et al(2018) design optimization procedures to specifically findflat minima

By connecting flat solutions with ensemble approximationsIzmailov et al (2018) also suggest that functions associatedwith parameters in flat regions ought to provide differentpredictions on test data for flatness to be helpful in gener-alization which is distinct from the flatness in Dinh et al(2017) Garipov et al (2018) also show that there are modeconnecting curves forming loss valleys which contain avariety of distinct solutions We argue that flat regions ofthe loss containing a diversity of solutions is particularlyrelevant for Bayesian model averaging since the model aver-age will then contain many compelling and complementaryexplanations for the data Additionally Huang et al (2019)describes neural networks as having a blessing of dimension-ality since flat regions will occupy much greater volume ina high dimensional space which we argue means that flatsolutions will dominate in the Bayesian model average

Smith amp Le (2018) and MacKay (2003 Chapter 28) ad-ditionally connect the width of the posterior with Occamfactors from a Bayesian perspective larger width corre-sponds to a smaller Occam factor and thus ought to providebetter generalization Dziugaite amp Roy (2017) and Smith ampLe (2018) also provide different perspectives on the resultsin Zhang et al (2016) which shows that deep convolutionalneural networks can fit CIFAR-10 with random labels andno training error The PAC-Bayes bound of Dziugaite ampRoy (2017) becomes vacuous when applied to randomly-labelled binary MNIST Smith amp Le (2018) show that logis-tic regression can fit noisy labels on sub-sampled MNISTinterpreting the result from an Occam factor perspective

In general PAC-Bayes provides a compelling frameworkfor deriving explicit non-asymptotic generalization boundsThese bounds can be improved by eg fewer parame-ters and very compact priors which can be different fromwhat provides optimal generalization From our perspec-tive model flexibility and priors with large support ratherthan compactness are desirable Our focus is complemen-tary and largely prescriptive aiming to provide intuitionson model construction inference generalization and neu-ral network priors as well as new connections betweenBayesian model averaging and deep ensembles benefitsof Bayesian model averaging specifically in the context ofmodern deep neural networks perspectives on temperingin Bayesian deep learning views of marginalization thatcontrast with simple Monte Carlo as well as new methodsfor Bayesian marginalization in deep learning

In other work Pearce et al (2018) propose a modificationof deep ensembles and argue that it performs approximateBayesian inference and Gustafsson et al (2019) briefly men-tion how deep ensembles can be viewed as samples froman approximate posterior In the context of deep ensembleswe believe it is natural to consider the BMA integral sepa-rately from the simple Monte Carlo approximation that isoften used to approximate this integral to compute an accu-rate predictive distribution we do not need samples from aposterior or even a faithful approximation to the posterior

Fort et al (2019) considered the diversity of predictionsproduced by models from a single SGD run and modelsfrom independent SGD runs and suggested to ensembleaverages of SGD iterates Although MultiSWA (one of themethods considered in Section 4) is related to this idea thecrucial practical difference is that MultiSWA uses a learningrate schedule that selects for flat regions of the loss the keyto the success of the SWA method (Izmailov et al 2018)Section 4 also shows that MultiSWAG which we proposefor multimodal Bayesian marginalization outperforms Mul-tiSWA

Double descent which describes generalization error thatdecreases increases and then again decreases with modelflexibility was demonstrated early by Opper et al (1990)Recently Belkin et al (2019) extensively demonstrateddouble descent leading to a surge of modern interest withNakkiran et al (2019) showing double descent in deep learn-ing Nakkiran et al (2020) shows that tuned l2 regulariza-tion can mitigate double descent Alternatively we showthat Bayesian model averaging particularly based on multi-modal marginalization can alleviate even prominent doubledescent behaviour

Tempering in Bayesian modelling has been considered underthe names Safe Bayes generalized Bayesian inference andfractional Bayesian inference (eg de Heide et al 2019Grunwald et al 2017 Barron amp Cover 1991 Walker ampHjort 2001 Zhang 2006 Bissiri et al 2016 Grunwald2012) We provide several perspectives of tempering inBayesian deep learning and analyze the results in a recentpaper by Wenzel et al (2020) that questions tempering forBayesian neural networks

3 Bayesian MarginalizationOften the predictive distribution we want to compute isgiven by

p(y|xD) =

intp(y|xw)p(w|D)dw (1)

The outputs are y (eg regression values class labels )indexed by inputs x (eg spatial locations images ) theweights (or parameters) of the neural network f(xw) arew and D are the data Eq (1) represents a Bayesian model

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

average (BMA) Rather than bet everything on one hypoth-esis mdash with a single setting of parameters w mdash we want touse all settings of parameters weighted by their posteriorprobabilities This procedure is called marginalization ofthe parametersw as the predictive distribution of interest nolonger conditions on w This is not a controversial equationbut simply the sum and product rules of probability

31 Importance of Marginalization in Deep Learning

In general we can view classical training as performingapproximate Bayesian inference using the approximate pos-terior p(w|D) asymp δ(w = w) to compute Eq (1) where δis a Dirac delta function that is zero everywhere except atw = argmaxwp(w|D) In this case we recover the standardpredictive distribution p(y|x w) From this perspectivemany alternatives albeit imperfect will be preferable mdashincluding impoverished Gaussian posterior approximationsfor p(w|D) even if the posterior or likelihood are actuallyhighly non-Gaussian and multimodal

The difference between a classical and Bayesian approachwill depend on how sharp the posterior p(w|D) becomes Ifthe posterior is sharply peaked and the conditional predic-tive distribution p(y|xw) does not vary significantly wherethe posterior has mass there may be almost no differencesince a delta function may then be a reasonable approxi-mation of the posterior for the purpose of BMA Howevermodern neural networks are usually highly underspecifiedby the available data and therefore have diffuse likelihoodsp(D|w) not strongly favouring any one setting of param-eters Not only are the likelihoods diffuse but differentsettings of the parameters correspond to a diverse variety ofcompelling hypotheses for the data (Garipov et al 2018Izmailov et al 2019) This is exactly the setting when wemost want to perform a Bayesian model average whichwill lead to an ensemble containing many different but highperforming models for better calibration and accuracy thanclassical training

Loss Valleys Flat regions of low loss (negative log poste-rior density minus log p(w|D)) are associated with good gener-alization (eg Hochreiter amp Schmidhuber 1997 Hinton ampVan Camp 1993 Dziugaite amp Roy 2017 Izmailov et al2018 Keskar et al 2016) While flat solutions that general-ize poorly can be contrived through reparametrization (Dinhet al 2017) the flat regions that lead to good generalizationcontain a diversity of high performing models on test data(Izmailov et al 2018) corresponding to different parametersettings in those regions And indeed there are large con-tiguous regions of low loss that contain such solutions evenconnecting together different SGD solutions (Garipov et al2018 Izmailov et al 2019) (see also Figure 11 Appendix)

Since these regions of the loss represent a large volume in a

high-dimensional space (Huang et al 2019) and providea diversity of solutions they will dominate in forming thepredictive distribution in a Bayesian model average Bycontrast if the parameters in these regions provided similarfunctions as would be the case in flatness obtained throughreparametrization these functions would be redundant inthe model average That is although the solutions of highposterior density can provide poor generalization it is the so-lutions that generalize well that will have greatest posteriormass and thus be automatically favoured by the BMA

Calibration by Epistemic Uncertainty RepresentationIt has been noticed that modern neural networks are oftenmiscalibrated in the sense that their predictions are typicallyoverconfident (Guo et al 2017) For example in classifi-cation the highest softmax output of a convolutional neuralnetwork is typically much larger than the probability of theassociated class label The fundamental reason for miscali-bration is ignoring epistemic uncertainty A neural networkcan represent many models that are consistent with our ob-servations By selecting only one in a classical procedurewe lose uncertainty when the models disagree for a testpoint In regression we can visualize epistemic uncertaintyby looking at the spread of the predictive distribution aswe move away from the data there are a greater varietyof consistent solutions leading to larger uncertainty as inFigure 4 We can further calibrate the model with temperingwhich we discuss in the Appendix Section 8

Accuracy An often overlooked benefit of Bayesian modelaveraging in modern deep learning is improved accuracy Ifwe average the predictions of many high performing modelsthat disagree in some cases we should see significantly im-proved accuracy This benefit is now starting to be observedin practice (eg Izmailov et al 2019) Improvements inaccuracy are very convincingly exemplified by deep en-sembles (Lakshminarayanan et al 2017) which have beenperceived as a competing approach to Bayesian methodsbut in fact provides a compelling mechanism for approxi-mate Bayesian model averaging as we show in Section 33We also demonstrate significant accuracy benefits for multi-modal Bayesian marginalization in Section 7

32 Beyond Monte Carlo

Nearly all approaches to estimating the integral in Eq (1)when it cannot be computed in closed form involvea simple Monte Carlo approximation p(y|xD) asymp1J

sumJj=1 p(y|xwj) wj sim p(w|D) In practice the sam-

ples from the posterior p(w|D) are also approximate andfound through MCMC or deterministic methods The de-terministic methods approximate p(w|D) with a differentmore convenient density q(w|D θ) from which we can sam-ple often chosen to be Gaussian The parameters θ are

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

selected to make q close to p in some sense for exam-ple variational approximations (eg Beal 2003) whichhave emerged as a popular deterministic approach findargminθKL(q||p) Other standard deterministic approxima-tions include Laplace (eg MacKay 1995) EP (Minka2001a) and INLA (Rue et al 2009)

From the perspective of estimating the predictive distribu-tion in Eq (1) we can view simple Monte Carlo as ap-proximating the posterior with a set of point masses withlocations given by samples from another approximate pos-terior q even if q is a continuous distribution That isp(w|D) asympsumJ

j=1 δ(w = wj) wj sim q(w|D)

Ultimately the goal is to accurately compute the predictivedistribution in Eq (1) rather than find a generally accu-rate representation of the posterior In particular we mustcarefully represent the posterior in regions that will makethe greatest contributions to the BMA integral In terms ofefficiently computing the predictive distribution we do notnecessarily want to place point masses at locations givenby samples from the posterior For example functional di-versity is important for a good approximation to the BMAintegral because we are summing together terms of theform p(y|xw) if two settings of the weights wi and wjeach provide high likelihood (and consequently high pos-terior density) but give rise to similar functions f(xwi)f(xwj) then they will be largely redundant in the modelaverage and the second setting of parameters will not con-tribute much to estimating the BMA integral for the uncon-ditional predictive distribution In Sections 33 and 4 weconsider how various approaches approximate the predictivedistribution

33 Deep Ensembles are BMA

Deep ensembles (Lakshminarayanan et al 2017) is fastbecoming a gold standard for accurate and well-calibratedpredictive distributions Recent reports (eg Ovadia et al2019 Ashukha et al 2020) show that deep ensembles ap-pear to outperform some particular approaches to Bayesianneural networks for uncertainty representation leading tothe confusion that deep ensembles and Bayesian methodsare competing approaches These methods are often explic-itly referred to as non-Bayesian (eg Lakshminarayananet al 2017 Ovadia et al 2019 Wenzel et al 2020) Tothe contrary we argue that deep ensembles are actually acompelling approach to Bayesian model averaging in thevein of Section 32

There is a fundamental difference between a Bayesian modelaverage and some approaches to ensembling The Bayesianmodel average assumes that one hypothesis (one parame-ter setting) is correct and averages over models due to aninability to distinguish between hypotheses given limitedinformation (Minka 2000) As we observe more data the

posterior collapses onto a single hypothesis If the trueexplanation for the data is a combination of hypothesesthen the Bayesian model average may appear to performworse as we observe more data Some ensembling methodswork by enriching the hypothesis space and therefore donot collapse in this way Deep ensembles however areformed by MAP or maximum likelihood retraining of thesame architecture multiple times leading to different basinsof attraction The deep ensemble will therefore collapse inthe same way as a Bayesian model average as the posteriorconcentrates Since the hypotheses space (support) for amodern neural network is large containing many differentpossible explanations for the data posterior collapse willoften be desirable

Furthermore by representing multiple basins of attractiondeep ensembles can provide a better approximation to theBMA than the Bayesian approaches in Ovadia et al (2019)Indeed the functional diversity is important for a good ap-proximation to the BMA integral as per Section 32 Theapproaches referred to as Bayesian in Ovadia et al (2019)instead focus their approximation on a single basin whichmay contain a lot of redundancy in function space makinga relatively minimal contribution to computing the Bayesianpredictive distribution On the other hand retraining a neu-ral network multiple times for deep ensembles incurs asignificant computational expense The single basin ap-proaches may be preferred if we are to control for computa-tion We explore these questions in Section 4

4 An Empirical Study of MarginalizationWe have shown that deep ensembles can be interpretedas an approximate approach to Bayesian marginalizationwhich selects for functional diversity by representing multi-ple basins of attraction in the posterior Most Bayesian deeplearning methods instead focus on faithfully approximatinga posterior within a single basin of attraction We propose anew method MultiSWAG which combines these two typesof approaches MultiSWAG combines multiple indepen-dently trained SWAG approximations (Maddox et al 2019)to create a mixture of Gaussians approximation to the pos-terior with each Gaussian centred on a different basin ofattraction We note that MultiSWAG does not require anyadditional training time over standard deep ensembles

We illustrate the conceptual difference between deep en-sembles a standard variational single basin approach andMultiSWAG in Figure 3 In the top panel we have a con-ceptualization of a multimodal posterior VI approximatesthe posterior with multiple samples within a single basinBut we see in the middle panel that the conditional predic-tive distribution p(y|xw) does not vary significantly withinthe basin and thus each additional sample contributes min-imally to computing the marginal predictive distribution

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

Deep Ensembles VI Multi-SWAG

ww

w

w

dist(p q)

p(y|w)

p(w|D)

Figure 3 Approximating the BMAp(y|xD) =

intp(y|xw)p(w|D)dw Top p(w|D) with repre-

sentations from VI (orange) deep ensembles (blue) MultiSWAG(red) Middle p(y|xw) as a function of w for a test input xThis function does not vary much within modes but changes sig-nificantly between modes Bottom Distance between the truepredictive distribution and the approximation as a function of rep-resenting a posterior at an additional point w assuming we havesampled the mode in dark green There is more to be gained byexploring new basins than continuing to explore the same basin

p(y|xD) On the other hand p(y|xw) varies significantlybetween basins and thus each point mass for deep ensem-bles contributes significantly to the marginal predictive dis-tribution By sampling within the basins MultiSWAG pro-vides additional contributions to the predictive distributionIn the bottom panel we have the gain in approximatingthe predictive distribution when adding a point mass to therepresentation of the posterior as a function of its locationassuming we have already sampled the mode in dark greenIncluding samples from different modes provides significantgain over continuing to sample from the same mode andincluding weights in wide basins provide relatively moregain than the narrow ones

In Figure 4 we evaluate single basin and multi-basin ap-proaches in a case where we can near-exactly compute thepredictive distribution We provide details for generatingthe data and training the models in Appendix D1 We seethat the predictive distribution given by deep ensembles isqualitatively closer to the true distribution compared to thesingle basin variational method between data clusters thedeep ensemble approach provides a similar representationof epistemic uncertainty whereas the variational method isextremely overconfident in these regions Moreover we seethat the Wasserstein distance between the true predictive dis-tribution and these two approximations quickly shrinks withnumber of samples for deep ensembles but is roughly inde-pendent of number of samples for the variational approachThus the deep ensemble is providing a better approximation

of the Bayesian model average in Eq (1) than the singlebasin variational approach which has traditionally beenlabelled as the Bayesian alternative

Next we evaluate MultiSWAG under distribution shift onthe CIFAR-10 dataset (Krizhevsky et al 2014) replicatingthe setup in Ovadia et al (2019) We consider 16 data cor-ruptions each at 5 different levels of severity introducedby Hendrycks amp Dietterich (2019) For each corruptionwe evaluate the performance of deep ensembles and Mul-tiSWAG varying the training budget For deep ensembleswe show performance as a function of the number of inde-pendently trained models in the ensemble For MultiSWAGwe show performance as a function of the number of inde-pendent SWAG approximations that we construct we thensample 20 models from each of these approximations toconstruct the final ensemble

While the training time for MultiSWAG is the same as fordeep ensembles at test time MultiSWAG is more expen-sive as the corresponding ensemble consists of a largernumber of models To account for situations when testtime is constrained we also propose MultiSWA a methodthat ensembles independently trained SWA solutions (Iz-mailov et al 2018) SWA solutions are the means of thecorresponding Gaussian SWAG approximations Izmailovet al (2018) argue that SWA solutions approximate the localensembles represented by SWAG with a single model

In Figure 5 we show the negative log-likelihood as a func-tion of the number of independently trained models for aPreactivation ResNet-20 on CIFAR-10 corrupted with Gaus-sian blur with varying levels of intensity (increasing fromleft to right) in Figure 5 MultiSWAG outperforms deepensembles significantly on highly corrupted data For lowerlevels of corruption MultiSWAG works particularly wellwhen only a small number of independently trained modelsare available We note that MultiSWA also outperforms deepensembles and has the same computational requirements attraining and test time as deep ensembles We present resultsfor other types of corruption in Appendix Figures 14 15 1617 showing similar trends In general there is an extensiveevaluation of MultiSWAG in the Appendix

Our perspective of generalization is deeply connected withBayesian marginalization In order to best realize the bene-fits of marginalization in deep learning we need to consideras many hypotheses as possible through multimodal poste-rior approximations such as MultiSWAG In Section 7 wereturn to MultiSWAG showing how it can entirely alleviateprominent double descent behaviour and lead to strikingimprovements in generalization over SGD and single basinmarginalization for both accuracy and NLL

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

minus10 minus5 0 5 10

minus05

00

05

10

15

20

25

30

35

(a) Exact

000

025

Wx(q

p)

minus10 minus5 0 5 10

0

1

2

3

(b) Deep Ensembles

000

025

Wx(q

p)

minus10 minus5 0 5 10

0

1

2

3

(c) Variational Inference

0 10 20 30 40 50

Samples

008

010

012

014

016

018

Avg

Wx(q

p)

Deep Ensembles

SVI

(d) Discrepancy with True BMA

Figure 4 Approximating the true predictive distribution (a) A close approximation of the true predictive distribution obtained bycombining 200 HMC chains (b) Deep ensembles predictive distribution using 50 independently trained networks (c) Predictivedistribution for factorized variational inference (VI) (d) Convergence of the predictive distributions for deep ensembles and variationalinference as a function of the number of samples we measure the average Wasserstein distance between the marginals in the range ofinput positions The multi-basin deep ensembles approach provides a more faithful approximation of the Bayesian predictive distributionthan the conventional single-basin VI approach which is overconfident between data clusters The top panels show the Wassersteindistance between the true predictive distribution and the deep ensemble and VI approximations as a function of inputs x

1 2 3 4 5 6 7 8 9 10 Models

1600

1800

2000

2200

2400

2600

2800

NLL

1 2 3 4 5 6 7 8 9 10 Models

3500

4000

4500

5000

5500

6000

1 2 3 4 5 6 7 8 9 10 Models

7000

8000

9000

10000

11000

12000

1 2 3 4 5 6 7 8 9 10 Models

12000

14000

16000

18000

20000

22000

1 2 3 4 5 6 7 8 9 10 Models

20000

25000

30000

35000

40000

Deep Ensembles MultiSWA MultiSWAG

Figure 5 Negative log likelihood for Deep Ensembles MultiSWAG and MultiSWA using a PreResNet-20 on CIFAR-10 with varyingintensity of the Gaussian blur corruption The image in each plot shows the intensity of corruption For all levels of intensity MultiSWAGand MultiSWA outperform Deep Ensembles for a small number of independent models For high levels of corruption MultiSWAGsignificantly outperforms other methods even for many independent models We present results for other corruptions in the Appendix

5 Neural Network PriorsA prior over parameters p(w) combines with the functionalform of a model f(xw) to induce a distribution over func-tions p(f(xw)) It is this distribution over functions thatcontrols the generalization properties of the model the priorover parameters in isolation has no meaning Neural net-works are imbued with structural properties that providegood inductive biases such as translation equivariance hi-erarchical representations and sparsity In the sense ofFigure 2 the prior will have large support due to the flexi-bility of neural networks but its inductive biases provide themost mass to datasets which are representative of problemsettings where neural networks are often applied In thissection we study the properties of the induced distributionover functions We directly continue the discussion of priorsin Section 6 with a focus on examining the noisy CIFAR re-sults in Zhang et al (2016) from a probabilistic perspectiveof generalization These sections are best read together

We also provide several additional experiments in the Ap-pendix In Section E we present analytic results on thedependence of the prior distribution in function space onthe variance of the prior over parameters considering also

layer-wise parameter priors with ReLU activations As partof a discussion on tempering in Section 84 we study theeffect of α in p(w) = N (0 α2I) on prior class probabilitiesfor individual sample functions p(f(xw)) the predictivedistribution and posterior samples as we observe varyingamounts of data In Section F we further study the correla-tion structure over images induced by neural network priorssubject to perturbations of the images In Section D3 weprovide additional experimental details

51 Deep Image Prior and Random Network Features

Two recent results provide strong evidence that vague Gaus-sian priors over parameters when combined with a neu-ral network architecture induce a distribution over func-tions with useful inductive biases In the deep image priorUlyanov et al (2018) show that randomly initialized convo-lutional neural networks without training provide excellentperformance for image denoising super-resolution and in-painting This result demonstrates the ability for a sam-ple function from a random prior over neural networksp(f(xw)) to capture low-level image statistics before anytraining Similarly Zhang et al (2016) shows that pre-

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

0 1 2 4 7

MNIST Class

0

1

2

4

7

MN

IST

Cla

ss

098

096

097

097

097

096

099

097

097

097

097

097

098

097

097

097

097

097

098

097

097

097

097

097

098

090

092

094

096

098

100

(a) α = 002

0 1 2 4 7

MNIST Class

0

1

2

4

7

MN

IST

Cla

ss

089

075

083

081

081

075

090

082

079

082

083

082

089

083

085

081

079

083

089

084

081

082

085

084

088

05

06

07

08

09

10

(b) α = 01

0 1 2 4 7

MNIST Class

0

1

2

4

7

MN

IST

Cla

ss

085

071

077

076

076

071

089

080

078

079

077

080

084

080

080

076

078

080

085

081

076

079

080

081

085

05

06

07

08

09

10

(c) α = 1

10minus2 10minus1 100 101

Prior std α

5 middot 102

103

5 middot 103

104

NL

L

(d)

Figure 6 Induced prior correlation function Average pairwise prior correlations for pairs of objects in classes 0 1 2 4 7 of MNISTinduced by LeNet-5 for p(f(xw)) when p(w) = N (0 α2I) Images in the same class have higher prior correlations than images fromdifferent classes suggesting that p(f(xw)) has desirable inductive biases The correlations slightly decrease with increases in α (d)NLL of an ensemble of 20 SWAG samples on MNIST as a function of α using a LeNet-5

00 02 04 06 08 104

3

2

1

0

1

2

3

4

(a) Prior Draws

00 02 04 06 08 10

(b) True Labels

00 02 04 06 08 10

(c) Corrupted Labels

0 5k 10k 12k 15k 20k 25k

Altered Labels

minus070

minus065

minus060

minus055

minus050

minus045

Mar

gina

lL

ikel

ihoo

dE

stim

ate

15

20

25

30

35

40

45

Tes

tE

rror

(d) Gaussian Process

0 5k 10k 20k 30k 40k 50k

Resampled Labels

minus5

minus4

minus3

minus2

minus1

Mar

gina

lL

ikel

ihoo

dE

stim

ate times107

20

40

60

80

Tes

tE

rror

(e) PreResNet-20

Figure 7 Rethinking generalization (a) Sample functions from a Gaussian process prior (b) GP fit (with 95 credible region) tostructured data generated as ygreen(x) = sin(x middot 2π)+ ε ε sim N (0 022) (c) GP fit with no training error after a significant addition ofcorrupted data in red drawn from Uniform[05 1] (d) Variational GP marginal likelihood with RBF kernel for two classes of CIFAR-10(e) Laplace BNN marginal likelihood for a PreResNet-20 on CIFAR-10 with different fractions of random labels The marginal likelihoodfor both the GP and BNN decreases as we increase the level of corruption in the labels suggesting reasonable inductive biases in the priorover functions Moreover both the GP and BNN have 100 training accuracy on images with fully corrupted labels

Similarly, Zhang et al. (2016) show that preprocessing CIFAR-10 with a randomly initialized, untrained convolutional neural network dramatically improves the test performance of a simple Gaussian kernel on pixels, from 54% accuracy to 71%. Adding ℓ2 regularization only improves the accuracy by an additional 2%. These results again indicate that broad Gaussian priors over parameters induce reasonable priors over networks, with a minor additional gain from decreasing the variance of the prior in parameter space, which corresponds to ℓ2 regularization.

5.2. Prior Class Correlations

In Figure 6 we study the prior correlations in the outputs of the LeNet-5 convolutional network (LeCun et al., 1998) on objects of different MNIST classes. We sample networks with weights from p(w) = N(0, α²I), compute the values of the logits corresponding to the first class for all pairs of images, and compute the correlations of these logits. For all levels of α, the correlations between objects corresponding to the same class are consistently higher than the correlations between objects of different classes, showing that the network induces a reasonable prior similarity metric over these images. Additionally, we observe that the prior correlations somewhat decrease as we increase α, showing that bounding the norm of the weights has some minor utility, in accordance with Section 5.1. Similarly, in panel (d) we see that the NLL significantly decreases as α increases over [0, 0.5], then slightly increases, but is relatively constant thereafter.

In the Appendix we further describe analytic results and illustrate the effect of α on sample functions.

5.3. Effect of Prior Variance on CIFAR-10

We further study the effect of the parameter prior standard deviation α, measuring the performance of approximate Bayesian inference on CIFAR-10 with a Preactivation ResNet-20 (He et al., 2016) and VGG-16 (Simonyan & Zisserman, 2014). For each of these architectures we run SWAG (Maddox et al., 2019) with fixed hyper-parameters and varying α. We report the results in Figure 9(f), (g). For both architectures the performance is near-optimal in the range α ∈ [10⁻², 10⁻¹]. Smaller α constrains the weights too much. Performance is reasonable, and becomes mostly insensitive to α as it continues to increase, due to the inductive biases of the functional form of the neural network.


6. Rethinking Generalization

Zhang et al. (2016) demonstrated that deep neural networks have sufficient capacity to fit randomized labels on popular image classification tasks, and suggest this result requires re-thinking generalization to understand deep learning.

We argue, however, that this behaviour is not puzzling from a probabilistic perspective, is not unique to neural networks, and cannot be used as evidence against Bayesian neural networks (BNNs) with vague parameter priors. Fundamentally, the resolution is the view presented in the introduction: from a probabilistic perspective, generalization is at least a two-dimensional concept, related to support (flexibility), which should be as large as possible, supporting even noisy solutions, and inductive biases that represent relative prior probabilities of solutions.

Indeed, we demonstrate that the behaviour in Zhang et al. (2016) that was treated as mysterious and specific to neural networks can be exactly reproduced by Gaussian processes (GPs). Gaussian processes are an ideal choice for this experiment because they are popular Bayesian non-parametric models, and they assign a prior directly in function space. Moreover, GPs have remarkable flexibility, providing universal approximation with popular covariance functions such as the RBF kernel. Yet the functions that are a priori likely under a GP with an RBF kernel are relatively simple. We describe GPs further in the Appendix, and Rasmussen & Williams (2006) provide an extensive introduction.

We start with a simple example to illustrate the ability for a GP with an RBF kernel to easily fit a corrupted dataset, yet generalize well on a non-corrupted dataset, in Figure 7. In Fig. 7(a), we have sample functions from a GP prior over functions p(f(x)), showing that likely functions under the prior are smooth and well-behaved. In Fig. 7(b) we see the GP is able to reasonably fit data from a structured function. And in Fig. 7(c) the GP is also able to fit highly corrupted data with essentially no structure; although these data are not a likely draw from the prior, the GP has support for a wide range of solutions, including noise.
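The qualitative behaviour in panels (b) and (c) can be reproduced with a few lines of standard GP regression. The sketch below is a simplified illustration rather than the exact experimental code: the kernel length-scale, noise variance, and number of corrupted points are illustrative choices; only the data-generating distributions follow the description in the Figure 7 caption.

```python
import numpy as np

def rbf_kernel(xa, xb, lengthscale=0.2):
    # Squared-exponential (RBF) kernel between two sets of 1D inputs.
    d = xa[:, None] - xb[None, :]
    return np.exp(-0.5 * (d / lengthscale) ** 2)

def gp_posterior_mean(x_train, y_train, x_test, noise_var=0.05**2):
    # Standard GP regression posterior mean: K_*^T (K + sigma^2 I)^{-1} y.
    K = rbf_kernel(x_train, x_train) + noise_var * np.eye(len(x_train))
    K_star = rbf_kernel(x_train, x_test)
    return K_star.T @ np.linalg.solve(K, y_train)

rng = np.random.default_rng(0)

# Structured data: y = sin(2*pi*x) + noise, as in Figure 7(b).
x = rng.uniform(0, 1, 100)
y = np.sin(2 * np.pi * x) + 0.2 * rng.standard_normal(100)

# Corrupted data: extra targets drawn from Uniform[0.5, 1], as in Figure 7(c).
x_bad = rng.uniform(0, 1, 100)
y_bad = rng.uniform(0.5, 1.0, 100)

x_test = np.linspace(0, 1, 200)
fit_clean = gp_posterior_mean(x, y, x_test)
fit_noisy = gp_posterior_mean(np.concatenate([x, x_bad]),
                              np.concatenate([y, y_bad]), x_test)
# With a small noise variance the GP interpolates the corrupted points (support),
# while draws from its prior remain smooth and simple (inductive bias).
```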

We next show that GPs can replicate the generalization behaviour described in Zhang et al. (2016) (experimental details in the Appendix). When applied to CIFAR-10 images with random labels, Gaussian processes achieve 100% train accuracy, and 10.4% test accuracy (at the level of random guessing). However, the same model trained on the true labels achieves a training accuracy of 72.8% and a test accuracy of 54.3%. Thus the generalization behaviour described in Zhang et al. (2016) is not unique to neural networks, and can be described by separately understanding the support and the inductive biases of a model.

Indeed, although Gaussian processes support CIFAR-10 images with random labels, they are not likely under the GP prior. In Fig. 7(d), we compute the approximate GP marginal likelihood on a binary CIFAR-10 classification problem, with labels of varying levels of corruption. We see that as the noise in the data increases, the approximate marginal likelihood, and thus the prior support for these data, decreases. In Fig. 7(e), we see a similar trend for a Bayesian neural network. Again, as the fraction of corrupted labels increases, the approximate marginal likelihood decreases, showing that the prior over functions given by the Bayesian neural network has less support for these noisy datasets. We provide further experimental details in the Appendix.

Dziugaite & Roy (2017) and Smith & Le (2018) provide complementary perspectives on Zhang et al. (2016) for MNIST; Dziugaite & Roy (2017) show non-vacuous PAC-Bayes bounds for the noise-free binary MNIST, but not noisy MNIST, and Smith & Le (2018) show that logistic regression can fit noisy labels on subsampled MNIST, interpreting the results from an Occam factor perspective.

7. Double Descent

Double descent (e.g., Belkin et al., 2019) describes generalization error that decreases, increases, and then again decreases with increases in model flexibility. The first decrease and then increase is referred to as the classical regime: models with increasing flexibility are increasingly able to capture structure and perform better, until they begin to overfit. The next regime is referred to as the modern interpolating regime. The existence of the interpolation regime has been presented as mysterious generalization behaviour in deep learning.

However, our perspective of generalization suggests that performance should monotonically improve as we increase model flexibility when we use Bayesian model averaging with a reasonable prior. Indeed, in the opening example of Figure 1, we would in principle want to use the most flexible possible model. Our results in Section 5 show that standard BNN priors induce structured and useful priors in the function space, so we should not expect double descent in Bayesian deep learning models that perform reasonable marginalization.

To test this hypothesis, we evaluate MultiSWAG, SWAG and standard SGD with ResNet-18 models of varying width, following Nakkiran et al. (2019), measuring both error and negative log likelihood (NLL). For the details see Appendix D. We present the results in Figure 8 and Appendix Figure 13.

First, we observe that models trained with SGD indeed suffer from double descent, especially when the train labels are partially corrupted (see panels 8(c), 8(d)). We also see that SWAG, a unimodal posterior approximation, reduces the extent of double descent.


Figure 8. Bayesian model averaging alleviates double descent. (a) Test error and (b) NLL loss for ResNet-18 with varying width on CIFAR-100 for SGD, SWAG and MultiSWAG. (c) Test error and (d) NLL loss when 20% of the labels are randomly reshuffled. SWAG reduces double descent, and MultiSWAG, which marginalizes over multiple modes, entirely alleviates double descent, both on the original labels and under label noise, in both accuracy and NLL. (e) Test errors for MultiSWAG with a varying number of independent SWAG models (1, 3, 5, 10); error monotonically decreases with an increased number of independent models, alleviating double descent. We also note that MultiSWAG provides significant improvements in accuracy and NLL over SGD and SWAG models. See Appendix Figure 13 for additional results.

Moreover, MultiSWAG, which performs a more exhaustive multimodal Bayesian model average, completely mitigates double descent: the performance of MultiSWAG solutions increases monotonically with the size of the model, showing no double descent even under significant label corruption, for both accuracy and NLL.

Our results highlight the importance of marginalization over multiple modes of the posterior: under 20% label corruption SWAG clearly suffers from double descent, while MultiSWAG does not. In Figure 8(e) we show how double descent is alleviated by an increased number of independent modes marginalized in MultiSWAG.
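A minimal sketch of how such a multimodal model average can be formed at test time is given below. The `swag_models` list and its `sample()` method are hypothetical placeholders standing in for independently trained SWAG posterior approximations; this is not the exact interface used in our experiments.

```python
import torch

@torch.no_grad()
def multiswag_predict(swag_models, loader, n_samples_per_model=20):
    """Average predictive distributions over samples from several SWAG models.

    `swag_models` is assumed to be a list of objects with a `.sample()` method
    that draws a new set of weights in place (a hypothetical interface), and
    which act as ordinary classifiers when called on a batch of inputs.
    """
    probs_sum, count = None, 0
    for swag in swag_models:                    # one SWAG model per basin of attraction
        for _ in range(n_samples_per_model):    # Monte Carlo samples within the basin
            swag.sample()                       # draw weights from this SWAG posterior
            # (in practice, batch-norm statistics are typically re-estimated
            #  after drawing new weights)
            batch_probs = []
            for x, _ in loader:
                batch_probs.append(torch.softmax(swag(x), dim=-1))
            probs = torch.cat(batch_probs)
            probs_sum = probs if probs_sum is None else probs_sum + probs
            count += 1
    return probs_sum / count                    # simple Monte Carlo estimate of the BMA
```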

These results also clearly show that MultiSWAG provides significant improvements in accuracy over both SGD and SWAG models, in addition to NLL, an often overlooked advantage of Bayesian model averaging that we discuss in Section 3.1.

Recently, Nakkiran et al. (2020) show that carefully tuned ℓ2 regularization can help mitigate double descent. Alternatively, we show that Bayesian model averaging, particularly based on multimodal marginalization, can mitigate prominent double descent behaviour. The perspective in Sections 1 and 3 predicts this result: models with reasonable priors and effective Bayesian model averaging should monotonically improve with increases in flexibility.

8. Temperature Scaling

The standard Bayesian posterior distribution is given by

$$p(w \mid D) = \frac{1}{Z}\, p(D \mid w)\, p(w), \qquad (2)$$

where p(D|w) is a likelihood, p(w) is a prior, and Z is a normalizing constant.

In Bayesian deep learning it is typical to consider the tempered posterior

$$p_T(w \mid D) = \frac{1}{Z(T)}\, p(D \mid w)^{1/T}\, p(w), \qquad (3)$$

where T is a temperature parameter, and Z(T) is the normalizing constant corresponding to temperature T. The temperature parameter controls how the prior and likelihood interact in the posterior:

• T < 1 corresponds to cold posteriors, where the posterior distribution is more concentrated around solutions with high likelihood.
• T = 1 corresponds to the standard Bayesian posterior distribution.
• T > 1 corresponds to warm posteriors, where the prior effect is stronger and the posterior collapse is slower.
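For concreteness, a minimal sketch of an unnormalized tempered log-posterior of the form in Equation (3), as it might be used inside a stochastic-gradient MCMC style sampler, is shown below; `model`, `log_likelihood`, and the Gaussian prior scale are illustrative placeholders rather than the code of any particular method.

```python
import torch

def tempered_log_posterior(model, log_likelihood, data, T=0.5, prior_std=0.1):
    # log p_T(w | D) = (1/T) * log p(D | w) + log p(w) + const, cf. Equation (3).
    log_lik = log_likelihood(model, data)   # log p(D | w)
    log_prior = sum((-0.5 * (p / prior_std) ** 2).sum() for p in model.parameters())
    return log_lik / T + log_prior          # T < 1: "cold", likelihood up-weighted
```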

Tempering posteriors is a well-known practice in statistics, where it goes by the names Safe Bayes, generalized Bayesian inference, and fractional Bayesian inference (e.g., de Heide et al., 2019; Grunwald et al., 2017; Barron & Cover, 1991; Walker & Hjort, 2001; Zhang, 2006; Bissiri et al., 2016; Grunwald, 2012). Safe Bayes has been shown to be natural from a variety of perspectives, including from prequential learning theory and minimum description length frameworks (e.g., Grunwald et al., 2017).

Concurrently with our work, Wenzel et al. (2020) noticed that successful Bayesian deep learning methods tend to use cold posteriors. They provide an empirical study that shows that Bayesian neural networks (BNNs) with cold posteriors outperform models with SGD-based maximum likelihood training, while BNNs with T = 1 can perform worse than the maximum likelihood solution.


Figure 9. Effects of the prior variance α². (a)–(e) Average class probabilities over all of CIFAR-10 for two sample prior functions p(f(x; w)) (two top rows) and the predictive distribution (average over 200 samples of weights, bottom row) for varying settings of α in p(w) = N(0, α²I): (a) α = 0.01, (b) α = 0.1, (c) α = 0.3, (d) α = 1, (e) α = √10. (f) NLL and (g) classification error of an ensemble of 20 SWAG samples on CIFAR-10 as a function of the prior std α, using a Preactivation ResNet-20 and VGG-16. The NLL is high for overly small α, and near-optimal in the range of [0.1, 0.3]. The NLL remains relatively low for vague priors corresponding to large values of α.

Figure 10. Adaptivity of posterior variance with data. We sample two functions f(x; w) from the distribution over functions induced by a distribution over weights, starting with the prior p(w) = N(0, 10 · I), in combination with a PreResNet-20. We measure class probabilities averaged across the CIFAR-10 test set as we vary the amount of available training data: (a) prior (α = √10), (b) 10 datapoints, (c) 100 datapoints, (d) 1000 datapoints. Although the prior variance is too large, such that the softmax saturates for logits sampled from the prior, leading to one class being favoured, we see that the posterior quickly adapts to correct the scale of the logits in the presence of data. In Figure 9 we also show that the prior variance can easily be calibrated such that the prior predictive distribution, even before observing data, is high entropy.


They claim that cold posteriors sharply deviate from the Bayesian paradigm, and consider possible reasons for why tempering is helpful in Bayesian deep learning.

In this section we provide an alternative view, and argue that tempering is not at odds with Bayesian principles. Moreover, for virtually any realistic model class and dataset, it would be highly surprising if T = 1 were in fact the best setting of this hyperparameter. Indeed, as long as it is practically convenient, we would advocate tempering for essentially any model, especially parametric models that do not scale their capacity automatically with the amount of available information. Our position is that, at a high level, Bayesian methods are trying to combine honest beliefs with data to form a posterior. By reflecting the belief that the model is misspecified, the tempered posterior is often more of a true posterior than the posterior that results from ignoring our belief that the model is misspecified.

Finding that T < 1 helps for Bayesian neural networks is neither surprising nor discouraging. And the actual results of the experiments in Wenzel et al. (2020), which show great improvements over standard SGD training, are in fact very encouraging of deriving inspiration from Bayesian procedures in deep learning.

We consider (1) tempering under misspecification (Section 8.1); (2) tempering in terms of overcounting data (Section 8.2); (3) how tempering compares to changing the observation model (Section 8.3); (4) the effect of the prior in relation to the experiments of Wenzel et al. (2020) (Section 8.4); (5) the effect of approximate inference, including how tempering can help in efficiently estimating parameters even for the untempered posterior (Section 8.5).

8.1. Tempering Helps with Misspecified Models

Many works explain how tempered posteriors help under model misspecification (e.g., de Heide et al., 2019; Grunwald et al., 2017; Barron & Cover, 1991; Walker & Hjort, 2001; Zhang, 2006; Bissiri et al., 2016; Grunwald, 2012). In fact, de Heide et al. (2019) and Grunwald et al. (2017) provide several simple examples where Bayesian inference fails to provide good convergence behaviour for untempered posteriors. While it is easier to show theoretical results for T > 1, several of these works also show that T < 1 can be preferred, even in well-specified settings, and indeed recommend learning T from data, for example by cross-validation (e.g., Grunwald, 2012; de Heide et al., 2019).

Are we in a misspecified setting for Bayesian neural networks? Of course. And it would be irrational to proceed as if it were otherwise. Every model is misspecified. In the context of Bayesian neural networks specifically, the mass of solutions expressed by the prior, outside of the datasets we typically consider, is likely much larger than desired for most applications. We can calibrate for this discrepancy through tempering. The resulting tempered posterior will be more in line with our beliefs than pretending the model is not misspecified and finding the untempered posterior.

Non-parametric models, such as Gaussian processes, attempt to side-step model misspecification by growing the number of free parameters (information capacity) automatically with the amount of available data. In parametric models, we take much more of a manual guess about the model capacity. In the case of deep neural networks, this choice is not even close to a best guess: it was once the case that architectural design was a large component of works involving neural networks, but now it is more standard practice to choose an off-the-shelf architecture, without much consideration of model capacity. We do not believe that knowingly using a misspecified model to find a posterior is more reasonable (or Bayesian) than honestly reflecting the belief that the model is misspecified and then using a tempered posterior. For parametric models such as neural networks, it is to be expected that the capacity is particularly misspecified.

8.2. Overcounting Data with Cold Posteriors

The criticism of cold posteriors raised by Wenzel et al. (2020) is largely based on the fact that decreasing temperature leads to overcounting data in the posterior distribution.

However, a similar argument can be made against marginal likelihood maximization (also known as empirical Bayes, or type-2 maximum likelihood). Indeed, here the prior will depend on the same data as the likelihood, which can lead to miscalibrated predictive distributions (Darnieder, 2011).

Nonetheless, empirical Bayes has been embraced and widely adopted in Bayesian machine learning (e.g., Bishop, 2006; Rasmussen & Williams, 2006; MacKay, 2003; Minka, 2001b) as embodying several Bayesian principles. Empirical Bayes has been particularly embraced in seminal work on Bayesian neural networks (e.g., MacKay, 1992; 1995), where it has been proposed as a principled approach to learning hyperparameters, such as the scale of the variance for the prior over weights, automatically embodying Occam's razor. While there is in this case some deviation from the fully Bayesian paradigm, the procedure, which depends on marginalization, is nonetheless clearly inspired by Bayesian thinking; it is thus helpful to reflect this inspiration and provide understanding of how it works from a Bayesian perspective.

There is also work showing that the marginal likelihood can lead to miscalibrated Bayes factors under model misspecification. Attempts to calibrate these factors (Xu et al., 2019) as part of the Bayesian paradigm are highly reminiscent of work on Safe Bayes.


8.3. Tempered Posterior or Different Likelihood

The tempered posterior for one model is an untempered posterior using a different observation model. In other words, we can trivially change the likelihood function in our model so that the standard Bayesian posterior in the new model is equal to the posterior of temperature T in the original model. Indeed, consider the likelihood function

$$p_T(D \mid w) \propto p(D \mid w)^{1/T}, \qquad (4)$$

where the posterior distribution for the model M_T with likelihood p_T is exactly the temperature-T posterior for the model M with likelihood p.

The predictive distribution differs for the two models: even though the posteriors coincide, the likelihoods for a new datapoint y* are different:

$$\int p(y^* \mid w)\, p(w \mid D)\, dw \;\neq\; \int p_T(y^* \mid w)\, p(w \mid D)\, dw. \qquad (5)$$

As an example, consider a regression model M with a Gaussian likelihood function y ∼ N(f, σ²), where f is the output of the network and σ² is the noise variance. The predictive distributions for the two models M and M_T for a new input x will have different variance, but the same mean: E_w[f + ε] = E_w[f]. Moreover, in this case the noise variance would typically be learned in either model M_T or M.
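To make the Gaussian example explicit, raising the likelihood to the power 1/T simply rescales the noise variance (a short derivation, up to constants independent of y):

$$p_T(y \mid w) \propto p(y \mid w)^{1/T} \propto \exp\!\left(-\frac{(y - f(x; w))^2}{2\sigma^2}\right)^{1/T} = \exp\!\left(-\frac{(y - f(x; w))^2}{2 T \sigma^2}\right) \propto \mathcal{N}\!\left(y \mid f(x; w),\, T\sigma^2\right).$$

So M_T is itself a model with an ordinary Gaussian likelihood, but with noise variance Tσ², which is why the two predictive distributions share a mean and differ only in variance.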

Section 4.1 of Grunwald et al. (2017) considers a related construction.

8.4. Effect of the Prior

While a somewhat misspecified prior will certainly interact with the utility of tempering, we do not believe the experiments in Wenzel et al. (2020) provide evidence that even the prior p(w) = N(0, I) is misspecified to any serious extent. For a relatively wide range of distributions over w, the functional form of the network f(x; w) can produce a generally reasonable distribution over functions p(f(x; w)). In Figure 10 we reproduce the findings in Wenzel et al. (2020) showing that sample functions p(f(x; w)) corresponding to the prior p(w) = N(0, 10 · I) strongly favour a single class over the dataset. While this behaviour appears superficially dramatic, we note it is simply an artifact of having a miscalibrated signal variance. A miscalibrated signal variance interacts with a quickly saturating softmax link function to provide a seemingly dramatic preference for a given class. If we instead use p(w) = N(0, α²I), for quite a range of α, then sample functions provide reasonably high entropy across labels, as in Figure 9. The value of α can be easily determined through cross-validation, as in Figure 9, or specified as a standard value used for L2 regularization (α = 0.24 in this case).

However, even with the inappropriate prior scale, we see in the bottom row of panels (a)–(e) of Figure 9 that the unconditional predictive distribution is completely reasonable. Moreover, the prior variance represents a soft prior bias, and will quickly update in the presence of data. In Figure 10 we show the posterior samples after observing 10, 100, and 1000 data points.

Other aspects of the prior, outside of the prior signal variance, will have a much greater effect on the inductive biases of the model. For example, the induced covariance function cov(f(x_i; w), f(x_j; w)) reflects the induced similarity metric over data instances; through the covariance function we can answer, for instance, whether the model believes a priori that a translated image is similar to the original. Unlike the signal variance of the prior, the prior covariance function will continue to have a significant effect on posterior inference even for very large datasets, and strongly reflects the structural properties of the neural network. We explore these structures of the prior in Figure 12.

8.5. The Effect of Inexact Inference

We have to keep in mind what we ultimately use posterior samples to compute. Ultimately, we wish to estimate the predictive distribution given by the integral in Equation (1). With a finite number of samples, the tempered posterior could be used to provide a better approximation to the expectation of the predictive distribution associated with the untempered posterior.

Consider a simple example, where we wish to estimate the mean of a high-dimensional Gaussian distribution N(0, I). Suppose we use J independent samples. The mean of these samples is also Gaussian distributed, μ ∼ N(0, (1/J) I). In Bayesian deep learning the dimension d is typically on the order of 10⁷ and J would be on the order of 10. The norm of μ would be highly concentrated around √(10⁷)/√10 = 1000. In this case, sampling from a tempered posterior with T < 1 would lead to a better approximation of the Bayesian model average associated with an untempered posterior.
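A small simulation illustrates the argument; the dimension and sample count below are simply the orders of magnitude mentioned above.

```python
import numpy as np

rng = np.random.default_rng(0)
d, J = 10**7, 10                      # parameter dimension and number of posterior samples

acc = np.zeros(d)
for _ in range(J):
    acc += rng.standard_normal(d)     # one "posterior sample" from N(0, I)
mu = acc / J                          # Monte Carlo estimate of the mean (true mean is 0)

print(np.linalg.norm(mu))             # concentrates near sqrt(d / J) = 1000

# Sampling instead from the tempered posterior N(0, T*I) with T < 1 scales each
# sample by sqrt(T), so the norm of the estimate shrinks towards the true mean.
T = 0.1
print(np.sqrt(T) * np.linalg.norm(mu))
```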

Furthermore, no current sampling procedure will be providing samples that are close to independent samples from the true posterior of a Bayesian neural network. The posterior landscape is far too multimodal and complex for there to be any reasonable coverage. The approximations we have are practically useful, and often preferable to conventional training, but we cannot realistically proceed with analysis assuming that we have obtained true samples from a posterior. While we would expect that some value of T ≠ 1 would be preferred for any finite dataset, in practice it is conceivable that some of the results in Wenzel et al. (2020) may be affected by the specifics of the approximate inference technique being used.


We should be wary not to view Bayesian model averaging purely through the prism of simple Monte Carlo, as advised in Section 3.2. Given a finite computational budget, our goal in effectively approximating a Bayesian model average is not equivalent to obtaining good samples from the posterior.

9. Discussion

"It is now common practice for Bayesians to fit models that have more parameters than the number of data points. Incorporate every imaginable possibility into the model space: for example, if it is conceivable that a very simple model might be able to explain the data, one should include simple models; if the noise might have a long-tailed distribution, one should include a hyperparameter which controls the heaviness of the tails of the distribution; if an input variable might be irrelevant to a regression, include it in the regression anyway." MacKay (1995)

We have presented a probabilistic perspective of generalization, which depends on the support and inductive biases of the model. The support should be as large as possible, but the inductive biases must be well-calibrated to a given problem class. We argue that Bayesian neural networks embody these properties, and, through the lens of probabilistic inference, explain generalization behaviour that has previously been viewed as mysterious. Moreover, we argue that Bayesian marginalization is particularly compelling for neural networks, show how deep ensembles provide a practical mechanism for marginalization, and propose a new approach that generalizes deep ensembles to marginalize within basins of attraction. We show that this multimodal approach to Bayesian model averaging, MultiSWAG, can entirely alleviate double descent, to enable monotonic performance improvements with increases in model flexibility, as well as significant improvements in generalization accuracy and log likelihood over SGD and single-basin marginalization.

There are certainly many challenges to estimating the integral for a Bayesian model average in modern deep learning, including a high-dimensional parameter space and a complex posterior landscape. But viewing the challenge indeed as an integration problem, rather than an attempt to obtain posterior samples for a simple Monte Carlo approximation, provides opportunities for future progress. Bayesian deep learning has been making fast practical advances, with approaches that now enable better accuracy and calibration over standard training with minimal overhead.

We finish with remarks about future developments for Bayesian neural network priors, and approaches to research in Bayesian deep learning.

9.1. The Future for BNN Priors

We provide some brief remarks about future developments for BNN priors. Here we have explored relatively simple parameter priors p(w) = N(0, α²I). While these priors are simple in parameter space, they interact with the neural network architecture to induce a sophisticated prior over functions p(f(x; w)), with many desirable properties, including a reasonable correlation structure over images. However, these parameter priors can certainly still be improved. As we have seen, even tuning the value of the signal variance α², an analogue of the L2 regularization often used in deep learning, can have a noticeable effect on the induced prior over functions, though this effect is quickly modulated by data. Layer-wise priors, such that parameters in each layer have a different signal variance, are intuitive: we would expect later layers to require precise determination, while parameters in earlier layers could reasonably take a range of values. But one has to be cautious; as we show in Appendix Section E, with ReLU activations different signal variances in different layers can be degenerate, combining together to affect only the output scale of the network.

A currently popular sentiment is that we should directly build function-space BNN priors, often taking inspiration from Gaussian processes. While we believe this is a promising direction, one should proceed with caution. If we contrive priors over parameters p(w) to induce distributions over functions p(f) that resemble familiar models, such as Gaussian processes with RBF kernels, we could be throwing the baby out with the bathwater. Neural networks are useful as their own model class precisely because they have different inductive biases from other models.

A similar concern applies to taking infinite-width limits in Bayesian neural networks. In these cases we recover Gaussian processes with interpretable kernel functions; because these models are easier to use and analyze, and give rise to interpretable and well-motivated priors, it is tempting to treat them as drop-in replacements for the parametric analogues. However, the kernels for these models are fixed. In order for a model to do effective representation learning, we must learn a similarity metric for the data. Training a neural network in many ways is like learning a kernel, rather than using a fixed kernel. MacKay (1998) has also expressed concerns about treating these limits as replacements for neural networks, due to the loss of representation learning power.

Perhaps the distribution over functions induced by a network in combination with a generic distribution over parameters p(w) may be hard to interpret, but this distribution will contain the equivariance properties, representation learning abilities, and other biases that make neural networks a compelling model class in their own right.


9.2. "But is it really Bayesian?"

We finish with an editorial comment about approaches to research within Bayesian deep learning. There is sometimes a tendency to classify work as Bayesian or not Bayesian, with very stringent criteria for what qualifies as Bayesian. Moreover, the implication, and sometimes even explicit recommendation, is that if an approach is not unequivocally Bayesian in every respect, then we should not term it as Bayesian, and we should instead attempt to understand the procedure through entirely different, non-Bayesian, mechanisms. We believe this mentality encourages tribalism, which is not conducive to the best research or creating the best performing methods. What matters is not a debate about semantics, but making rational modelling choices given a particular problem setting, and trying to understand these choices. Often these choices can largely be inspired by a Bayesian approach, in which case it is desirable to indicate this source of inspiration. And in the semantics debate, who would be the arbiter of what gets to be called Bayesian? Arguably it ought to be an evolving definition.

Broadly speaking, what makes Bayesian approaches distinctive is a posterior-weighted marginalization over parameters. And at a high level, Bayesian methods are about combining our honest beliefs with data to form a posterior. In actuality, no fair-minded researcher entirely believes the prior over parameters, the functional form of the model (which is part of the prior over functions), or the likelihood. From this perspective, it is broadly compatible with a Bayesian philosophy to reflect misspecification in the modelling procedure itself, which is achieved through tempering. In this sense, the tempered posterior is more reflective of a true posterior than the posterior that results from ignoring our belief that the model is misspecified.

Moreover, basic probability theory indicates that marginalization is desirable. While marginalization cannot in practice be achieved exactly, we can try to improve over conventional training, which, as we have discussed, can be viewed as approximate marginalization. Given computational constraints, effective marginalization is not equivalent to obtaining accurate samples from a posterior. As we have discussed, simple Monte Carlo is only one of many mechanisms for marginalization. Just as expectation propagation (Minka, 2001a) focuses its approximation on the factors in a posterior where it will most affect the end result, we should focus on representing the posterior where it will make the biggest difference to the model average. As we have shown, deep ensembles are a reasonable mechanism up to a point. After having trained many independent models, there are added benefits to marginalizing within basins, given the computational expense associated with retraining an additional model to find an additional basin of attraction.

We should also not hold Bayesian methods to a double standard. Indeed, it can be hard to interpret or understand the prior, the posterior, and whether the marginalization procedure is optimal. But it is also hard to interpret the choices behind the functional form of the model, or the rationale behind classical procedures, where we bet everything on a single global optimum, when we know there are many global optima, and many of them will perform well but provide different solutions, and many others will not perform well. We should apply the same level of scrutiny to all modelling choices, consider the alternatives, and not be paralyzed if a procedure is not optimal in every respect.

Acknowledgements. AGW and PI are supported by an Amazon Research Award, Facebook Research, NSF I-DISRE 193471, NIH R01 DA048764-01A1, NSF IIS-1563887, and NSF IIS-1910266. We thank Greg Benton for helpful discussions.

References

Achille, A. and Soatto, S. Emergence of invariance and disentanglement in deep representations. The Journal of Machine Learning Research, 19(1):1947–1980, 2018.
Ashukha, A., Lyzhov, A., Molchanov, D., and Vetrov, D. Pitfalls of in-domain uncertainty estimation and ensembling in deep learning. arXiv preprint arXiv:2002.06470, 2020.
Barron, A. R. and Cover, T. M. Minimum complexity density estimation. IEEE Transactions on Information Theory, 37(4):1034–1054, 1991.
Beal, M. J. Variational algorithms for approximate Bayesian inference. University of London, 2003.
Belkin, M., Hsu, D., Ma, S., and Mandal, S. Reconciling modern machine-learning practice and the classical bias–variance trade-off. Proceedings of the National Academy of Sciences, 116(32):15849–15854, 2019.
Bishop, C. M. Pattern Recognition and Machine Learning. Springer, 2006.
Bissiri, P. G., Holmes, C. C., and Walker, S. G. A general framework for updating belief distributions. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 78(5):1103–1130, 2016.
Blundell, C., Cornebise, J., Kavukcuoglu, K., and Wierstra, D. Weight uncertainty in neural networks. arXiv preprint arXiv:1505.05424, 2015.
Box, G. E. and Tiao, G. C. Bayesian Inference in Statistical Analysis. Addison-Wesley, Reading, MA, 1973.
Chaudhari, P., Choromanska, A., Soatto, S., LeCun, Y., Baldassi, C., Borgs, C., Chayes, J., Sagun, L., and Zecchina, R. Entropy-SGD: Biasing gradient descent into wide valleys. Journal of Statistical Mechanics: Theory and Experiment, 2019.
Cobb, A. D., Baydin, A. G., Markham, A., and Roberts, S. J. Introducing an explicit symplectic integration scheme for Riemannian manifold Hamiltonian Monte Carlo. arXiv preprint arXiv:1910.06243, 2019.
Darnieder, W. F. Bayesian Methods for Data-Dependent Priors. PhD thesis, The Ohio State University, 2011.
de Heide, R., Kirichenko, A., Mehta, N., and Grunwald, P. Safe-Bayesian generalized linear regression. arXiv preprint arXiv:1910.09227, 2019.
Dinh, L., Pascanu, R., Bengio, S., and Bengio, Y. Sharp minima can generalize for deep nets. arXiv preprint arXiv:1703.04933, 2017.
Draxler, F., Veschgini, K., Salmhofer, M., and Hamprecht, F. A. Essentially no barriers in neural network energy landscape. arXiv preprint arXiv:1803.00885, 2018.
Dziugaite, G. K. and Roy, D. M. Computing nonvacuous generalization bounds for deep (stochastic) neural networks with many more parameters than training data. arXiv preprint arXiv:1703.11008, 2017.
Fort, S. and Jastrzebski, S. Large scale structure of neural network loss landscapes. In Advances in Neural Information Processing Systems, pp. 6706–6714, 2019.
Fort, S., Hu, H., and Lakshminarayanan, B. Deep ensembles: A loss landscape perspective. arXiv preprint arXiv:1912.02757, 2019.
Gal, Y. and Ghahramani, Z. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. In International Conference on Machine Learning, pp. 1050–1059, 2016.
Gardner, J. R., Pleiss, G., Bindel, D., Weinberger, K. Q., and Wilson, A. G. GPyTorch: Blackbox matrix-matrix Gaussian process inference with GPU acceleration. In Neural Information Processing Systems, 2018.
Garipov, T., Izmailov, P., Podoprikhin, D., Vetrov, D. P., and Wilson, A. G. Loss surfaces, mode connectivity, and fast ensembling of DNNs. In Neural Information Processing Systems, 2018.
Grunwald, P. The safe Bayesian. In International Conference on Algorithmic Learning Theory, pp. 169–183. Springer, 2012.
Grunwald, P., Van Ommen, T., et al. Inconsistency of Bayesian inference for misspecified linear models, and a proposal for repairing it. Bayesian Analysis, 12(4):1069–1103, 2017.
Guedj, B. A primer on PAC-Bayesian learning. arXiv preprint arXiv:1901.05353, 2019.
Guo, C., Pleiss, G., Sun, Y., and Weinberger, K. Q. On calibration of modern neural networks. In Proceedings of the 34th International Conference on Machine Learning - Volume 70, pp. 1321–1330. JMLR.org, 2017.
Gustafsson, F. K., Danelljan, M., and Schon, T. B. Evaluating scalable Bayesian deep learning methods for robust computer vision. arXiv preprint arXiv:1906.01620, 2019.
Hafner, D., Tran, D., Irpan, A., Lillicrap, T., and Davidson, J. Reliable uncertainty estimates in deep neural networks using noise contrastive priors. arXiv preprint arXiv:1807.09289, 2018.
He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778, 2016.
Hendrycks, D. and Dietterich, T. Benchmarking neural network robustness to common corruptions and perturbations. arXiv preprint arXiv:1903.12261, 2019.
Hinton, G. E. and Van Camp, D. Keeping the neural networks simple by minimizing the description length of the weights. In Proceedings of the Sixth Annual Conference on Computational Learning Theory, pp. 5–13, 1993.
Hochreiter, S. and Schmidhuber, J. Flat minima. Neural Computation, 9(1):1–42, 1997.
Huang, W. R., Emam, Z., Goldblum, M., Fowl, L., Terry, J. K., Huang, F., and Goldstein, T. Understanding generalization through visualizations. arXiv preprint arXiv:1906.03291, 2019.
Izmailov, P., Podoprikhin, D., Garipov, T., Vetrov, D., and Wilson, A. G. Averaging weights leads to wider optima and better generalization. In Uncertainty in Artificial Intelligence (UAI), 2018.
Izmailov, P., Maddox, W. J., Kirichenko, P., Garipov, T., Vetrov, D., and Wilson, A. G. Subspace inference for Bayesian deep learning. In Uncertainty in Artificial Intelligence, 2019.
Jacot, A., Gabriel, F., and Hongler, C. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in Neural Information Processing Systems, pp. 8571–8580, 2018.
Jiang, Y., Neyshabur, B., Mobahi, H., Krishnan, D., and Bengio, S. Fantastic generalization measures and where to find them. arXiv preprint arXiv:1912.02178, 2019.
Kass, R. E. and Raftery, A. E. Bayes factors. Journal of the American Statistical Association, 90(430):773–795, 1995.
Kendall, A. and Gal, Y. What uncertainties do we need in Bayesian deep learning for computer vision? In Advances in Neural Information Processing Systems, pp. 5574–5584, 2017.
Keskar, N. S., Mudigere, D., Nocedal, J., Smelyanskiy, M., and Tang, P. T. P. On large-batch training for deep learning: Generalization gap and sharp minima. arXiv preprint arXiv:1609.04836, 2016.
Khan, M. E., Nielsen, D., Tangkaratt, V., Lin, W., Gal, Y., and Srivastava, A. Fast and scalable Bayesian deep learning by weight-perturbation in Adam. arXiv preprint arXiv:1806.04854, 2018.
Krizhevsky, A., Nair, V., and Hinton, G. The CIFAR-10 dataset, 2014. http://www.cs.toronto.edu/~kriz/cifar.html.
Lakshminarayanan, B., Pritzel, A., and Blundell, C. Simple and scalable predictive uncertainty estimation using deep ensembles. In Advances in Neural Information Processing Systems, pp. 6402–6413, 2017.
Langford, J. and Caruana, R. (Not) bounding the true error. In Advances in Neural Information Processing Systems, pp. 809–816, 2002.
LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.
Louizos, C., Shi, X., Schutte, K., and Welling, M. The functional neural process. In Advances in Neural Information Processing Systems, 2019.
MacKay, D. J. Bayesian Methods for Adaptive Models. PhD thesis, California Institute of Technology, 1992.
MacKay, D. J. Probable networks and plausible predictions: a review of practical Bayesian methods for supervised neural networks. Network: Computation in Neural Systems, 6(3):469–505, 1995.
MacKay, D. J. Introduction to Gaussian processes. In Bishop, C. M. (ed.), Neural Networks and Machine Learning, chapter 11, pp. 133–165. Springer-Verlag, 1998.
MacKay, D. J. Information Theory, Inference and Learning Algorithms. Cambridge University Press, 2003.
Maddox, W. J., Izmailov, P., Garipov, T., Vetrov, D. P., and Wilson, A. G. A simple baseline for Bayesian uncertainty in deep learning. In Advances in Neural Information Processing Systems, 2019.
Masegosa, A. R. Learning under model misspecification: Applications to variational and ensemble methods. 2019.
McAllester, D. A. PAC-Bayesian model averaging. In Proceedings of the Twelfth Annual Conference on Computational Learning Theory, pp. 164–170, 1999.
Minka, T. Expectation propagation for approximate Bayesian inference. In Uncertainty in Artificial Intelligence, volume 17, pp. 362–369, 2001a.
Minka, T. P. Bayesian model averaging is not model combination. 2000.
Minka, T. P. Automatic choice of dimensionality for PCA. In Advances in Neural Information Processing Systems, pp. 598–604, 2001b.
Mohri, M. and Rostamizadeh, A. Rademacher complexity bounds for non-i.i.d. processes. In Advances in Neural Information Processing Systems, pp. 1097–1104, 2009.
Nakkiran, P., Kaplun, G., Bansal, Y., Yang, T., Barak, B., and Sutskever, I. Deep double descent: Where bigger models and more data hurt. arXiv preprint arXiv:1912.02292, 2019.
Nakkiran, P., Venkat, P., Kakade, S., and Ma, T. Optimal regularization can mitigate double descent. arXiv preprint arXiv:2003.01897, 2020.
Nalisnick, E. On Priors for Bayesian Neural Networks. PhD thesis, University of California, Irvine, 2018.
Neal, R. Bayesian Learning for Neural Networks. Springer Verlag, 1996. ISBN 0387947248.
Neyshabur, B., Bhojanapalli, S., McAllester, D., and Srebro, N. Exploring generalization in deep learning. In Advances in Neural Information Processing Systems, pp. 5947–5956, 2017.
Neyshabur, B., Bhojanapalli, S., and Srebro, N. A PAC-Bayesian approach to spectrally-normalized margin bounds for neural networks. In International Conference on Learning Representations, 2018. URL https://openreview.net/forum?id=Skz_WfbCZ.
Opper, M., Kinzel, W., Kleinz, J., and Nehl, R. On the ability of the optimal perceptron to generalise. Journal of Physics A: Mathematical and General, 23(11):L581, 1990.
Ovadia, Y., Fertig, E., Ren, J., Nado, Z., Sculley, D., Nowozin, S., Dillon, J. V., Lakshminarayanan, B., and Snoek, J. Can you trust your model's uncertainty? Evaluating predictive uncertainty under dataset shift. arXiv preprint arXiv:1906.02530, 2019.
Pearce, T., Zaki, M., Brintrup, A., Anastassacos, N., and Neely, A. Uncertainty in neural networks: Bayesian ensembling. arXiv preprint arXiv:1810.05546, 2018.
Rasmussen, C. E. and Ghahramani, Z. Occam's razor. In Neural Information Processing Systems (NIPS), 2001.
Rasmussen, C. E. and Williams, C. K. I. Gaussian Processes for Machine Learning. The MIT Press, 2006.
Ritter, H., Botev, A., and Barber, D. A scalable Laplace approximation for neural networks. In International Conference on Learning Representations (ICLR), 2018.
Rue, H., Martino, S., and Chopin, N. Approximate Bayesian inference for latent Gaussian models by using integrated nested Laplace approximations. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 71(2):319–392, 2009.
Simonyan, K. and Zisserman, A. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.
Smith, S. L. and Le, Q. V. A Bayesian perspective on generalization and stochastic gradient descent. In International Conference on Learning Representations, 2018.
Sun, S., Zhang, G., Shi, J., and Grosse, R. Functional variational Bayesian neural networks. arXiv preprint arXiv:1903.05779, 2019.
Ulyanov, D., Vedaldi, A., and Lempitsky, V. Deep image prior. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 9446–9454, 2018.
Vapnik, V. N. Adaptive and learning systems for signal processing, communications, and control. Statistical Learning Theory, 1998.
Virtanen, P., Gommers, R., Oliphant, T. E., Haberland, M., Reddy, T., Cournapeau, D., Burovski, E., Peterson, P., Weckesser, W., Bright, J., et al. SciPy 1.0: fundamental algorithms for scientific computing in Python. Nature Methods, pp. 1–12, 2020.
Walker, S. and Hjort, N. L. On Bayesian consistency. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 63(4):811–821, 2001.
Welling, M. and Teh, Y. W. Bayesian learning via stochastic gradient Langevin dynamics. In Proceedings of the 28th International Conference on Machine Learning (ICML-11), pp. 681–688, 2011.
Wenzel, F., Roth, K., Veeling, B. S., Swiatkowski, J., Tran, L., Mandt, S., Snoek, J., Salimans, T., Jenatton, R., and Nowozin, S. How good is the Bayes posterior in deep neural networks really? arXiv preprint arXiv:2002.02405, 2020.
Wilson, A. G. The case for Bayesian deep learning. arXiv preprint arXiv:2001.10995, 2020.
Xu, X., Lu, P., MacEachern, S., and Xu, R. Calibrated Bayes factors for model comparison. Journal of Statistical Computation and Simulation, 89(4):591–614, 2019.
Yang, W., Lorch, L., Graule, M. A., Srinivasan, S., Suresh, A., Yao, J., Pradier, M. F., and Doshi-Velez, F. Output-constrained Bayesian neural networks. arXiv preprint arXiv:1905.06287, 2019.
Zhang, C., Bengio, S., Hardt, M., Recht, B., and Vinyals, O. Understanding deep learning requires rethinking generalization. arXiv preprint arXiv:1611.03530, 2016.
Zhang, T. Information-theoretic upper and lower bounds for statistical estimation. IEEE Transactions on Information Theory, 52(4):1307–1321, 2006.


Appendix Outline

This appendix is organized as follows. In Section A, we visualize predictive functions corresponding to weight samples within high posterior density valleys on a regression problem. In Section B, we provide background material on Gaussian processes. In Section C, we present further results comparing MultiSWAG and MultiSWA to Deep Ensembles under data distribution shift on CIFAR-10. In Section D, we provide the details of all experiments presented in the paper. In Section E, we present analytic results on the dependence of the prior distribution in function space on the variance of the prior over parameters. In Section F, we study the prior correlations between BNN logits on perturbed images.

A. Loss Valleys

We demonstrate that different points along the valleys of high posterior density (low loss) connecting pairs of independently trained optima (Garipov et al., 2018; Draxler et al., 2018; Fort & Jastrzebski, 2019) correspond to different predictive functions. We use the regression example from Izmailov et al. (2019) and show the results in Figure 11.

B. Gaussian processes

With a Bayesian neural network, a distribution over parameters p(w) induces a distribution over functions p(f(x; w)) when combined with the functional form of the network. Gaussian processes (GPs) are often used to instead directly specify a distribution over functions.

A Gaussian process is a distribution over functions, f(x) ∼ GP(m, k), such that any collection of function values, queried at any finite set of inputs x₁, …, xₙ, has a joint Gaussian distribution:

$$f(x_1), \ldots, f(x_n) \sim \mathcal{N}(\mu, K). \qquad (6)$$

The mean vector, μᵢ = E[f(xᵢ)] = m(xᵢ), and covariance matrix, Kᵢⱼ = cov(f(xᵢ), f(xⱼ)) = k(xᵢ, xⱼ), are determined by the mean function m and covariance function (or kernel) k of the Gaussian process.

The popular RBF kernel has the form

$$k(x_i, x_j) = \exp\left(-\frac{1}{2\ell^2} \|x_i - x_j\|^2\right). \qquad (7)$$

The length-scale hyperparameter ℓ controls the extent of correlations between function values. If ℓ is large, sample functions from a GP prior are simple and slowly varying with inputs x.
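The role of the length-scale can be seen directly by drawing prior samples according to Equations (6)–(7); a minimal sketch, where the specific length-scales and input grid are arbitrary illustrative choices:

```python
import numpy as np

def rbf_kernel(x, lengthscale):
    # k(x_i, x_j) = exp(-||x_i - x_j||^2 / (2 * lengthscale^2)), Equation (7), for 1D inputs.
    d = x[:, None] - x[None, :]
    return np.exp(-0.5 * (d / lengthscale) ** 2)

rng = np.random.default_rng(0)
x = np.linspace(-3, 3, 200)

for lengthscale in [0.2, 2.0]:
    K = rbf_kernel(x, lengthscale) + 1e-6 * np.eye(len(x))   # jitter for numerical stability
    # Draw three functions from the zero-mean GP prior, Equation (6).
    samples = rng.multivariate_normal(np.zeros(len(x)), K, size=3)
    # Small length-scale: wiggly samples; large length-scale: simple, slowly varying samples.
```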

Gaussian processes with RBF kernels (as well as many other standard kernels) assign positive density to any set of observations. Moreover, these models are universal approximators (Rasmussen & Williams, 2006): as the number of observations increases, they are able to approximate any function to arbitrary precision.

Work on Gaussian processes in machine learning was triggered by the observation that Bayesian neural networks become Gaussian processes with particular kernel functions as the number of hidden units approaches infinity (Neal, 1996). This result resembles recent work on the neural tangent kernel (e.g., Jacot et al., 2018).

C. Deep Ensembles and MultiSWAG Under Distribution Shift

In Figure 18 following Ovadia et al (2019) we showthe distribution of negative log likelihood accuracy andexpected calibration error as we vary the type of corrup-tion We use a fixed training time budget 10 independentlytrained models for every method For MultiSWAG we en-semble 20 samples from each of the 10 SWAG approxima-tions MultiSWAG particularly achieves better NLL than theother two methods and MultiSWA outperforms Deep En-sembles the difference is especially pronounced for higherlevels of corruption In terms of ECE MultiSWAG againoutperforms the other two methods for higher corruptionintensities

We note that Ovadia et al (2019) found Deep Ensembles tobe a very strong baseline for prediction quality and calibra-tion under distribution shift For this reason we focus onDeep Ensembles in our comparisons

D Details of ExperimentsIn this section we provide additional details of the experi-ments presented in the paper

D1 Approximating the True Predictive Distribution

For the results presented in Figure 4 we used a network with3 hidden layers of size 10 each The network takes twoinputs x and x2 We pass both x and x2 as input to ensure

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

Figure 11. Diversity of high performing functions. Bottom: a contour plot of the posterior log-density in the subspace containing a pair of independently trained modes (as with deep ensembles) and a path of high posterior density connecting these modes. In each panel the purple point represents a sample from the posterior in the parameter subspace. Top: the predictive distribution constructed from samples in the subspace. The shaded blue area shows the 3σ-region of the predictive distribution at each of the input locations, and the blue line shows the mean of the predictive distribution. In each panel, the purple line shows the predictive function corresponding to the sample shown in the corresponding bottom row panel. For the details of the experimental setup, see Section 5.1 of Izmailov et al. (2019).

We pass both x and x² as input to ensure that the network can represent a broader class of functions. The network outputs a single number, y = f(x).

To generate data for the plots, we used a randomly initialized neural network of the same architecture described above. We sampled the weights from an isotropic Gaussian with variance 0.1², and added isotropic Gaussian noise with variance 0.1² to the outputs:

$$y = f(x; w) + \epsilon(x),$$

with w ∼ N(0, 0.1² · I), ε(x) ∼ N(0, 0.1² · I). The training set consists of 120 points, shown in Figure 4.
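A sketch of this data-generating process is below. The tanh nonlinearity and the input range are illustrative assumptions not specified above; the layer sizes and the variances of the weights and noise match the description.

```python
import numpy as np

rng = np.random.default_rng(0)

def random_mlp(sizes, weight_std=0.1):
    # Weights and biases drawn from N(0, 0.1^2), matching the prior described above.
    return [(rng.normal(0, weight_std, (m, n)), rng.normal(0, weight_std, n))
            for m, n in zip(sizes[:-1], sizes[1:])]

def forward(params, x):
    # The network sees (x, x^2); the tanh nonlinearity is an illustrative assumption.
    h = np.stack([x, x**2], axis=-1)
    for i, (W, b) in enumerate(params):
        h = h @ W + b
        if i < len(params) - 1:
            h = np.tanh(h)
    return h[..., 0]

params = random_mlp([2, 10, 10, 10, 1])          # 3 hidden layers of size 10
x_train = rng.uniform(-10, 10, 120)              # 120 training inputs (range assumed)
y_train = forward(params, x_train) + rng.normal(0, 0.1, size=120)   # noise std 0.1
```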

For estimating the ground truth, we ran 200 chains of Hamiltonian Monte Carlo (HMC) using the hamiltorch package (Cobb et al., 2019). We initialized each chain with a network pre-trained with SGD for 3000 steps, then ran Hamiltonian Monte Carlo (HMC) for 2000 steps, producing 200 samples.

For Deep Ensembles, we independently trained 50 networks with SGD for 20000 steps each. We used the negative posterior log-density as the training loss. For SVI, we used a fully-factorized Gaussian approximation initialized at an SGD solution trained for 20000 steps. For all inference methods, we set the prior variance to 10² and the noise variance to 0.02².

Discrepancy with true BMA. For the results presented in panel (d) of Figure 4, we computed the Wasserstein distance between the predictive distribution approximated with HMC and the predictive distribution for Deep Ensembles and SVI. We used the one-dimensional Wasserstein distance function¹ from the scipy package (Virtanen et al., 2020). We computed the Wasserstein distance between marginal distributions at each input location, and averaged the results over the input locations. In the top sub-panels of panels (b), (c) of Figure 4 we additionally visualize the marginal Wasserstein distance between the HMC predictive distribution and the Deep Ensembles and SVI predictive distributions, respectively, for each input location.

¹https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wasserstein_distance.html
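A minimal sketch of this computation, assuming `hmc_samples` and `approx_samples` are arrays of predictive samples with shape (number of samples, number of input locations):

```python
import numpy as np
from scipy.stats import wasserstein_distance

def mean_marginal_wasserstein(hmc_samples, approx_samples):
    """Average 1D Wasserstein distance between marginal predictive
    distributions at each input location (columns of the sample arrays)."""
    distances = [wasserstein_distance(hmc_samples[:, i], approx_samples[:, i])
                 for i in range(hmc_samples.shape[1])]
    return np.mean(distances), np.array(distances)   # mean and per-location distances
```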

D.2. Deep Ensembles and MultiSWAG

We evaluate Deep Ensembles, MultiSWA and MultiSWAG under distribution shift in Section 4. Following Ovadia et al. (2019), we use a PreResNet-20 network and the CIFAR-10 dataset with different types of corruptions introduced in Hendrycks & Dietterich (2019). For training individual SGD, SWA and SWAG models, we use the hyper-parameters used for PreResNet-164 in Maddox et al. (2019). For each SWAG model we sample 20 networks and ensemble them. So Deep Ensembles, MultiSWA and MultiSWAG are all evaluated under the same training budget; Deep Ensembles and MultiSWA also use the same test-time budget.

For producing the corrupted data we used the code² released by Hendrycks & Dietterich (2019). We had issues producing the data for the frost corruption type, so we omit it in our evaluation, and include Gaussian blur, which was not included in the evaluation of Hendrycks & Dietterich (2019).

²https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_cifar_c.py


0 1 2 4 7

MNIST Class

0

1

2

4

7

MN

IST

Cla

ss

098

096

097

097

097

096

099

097

097

097

097

097

098

097

097

097

097

097

098

097

097

097

097

097

098

090

092

094

096

098

100

(a) α = 002

0 1 2 4 7

MNIST Class

0

1

2

4

7

MN

IST

Cla

ss

089

075

083

081

081

075

090

082

079

082

083

082

089

083

085

081

079

083

089

084

081

082

085

084

088

05

06

07

08

09

10

(b) α = 01

0 1 2 4 7

MNIST Class

0

1

2

4

7

MN

IST

Cla

ss

085

071

077

076

076

071

089

080

078

079

077

080

084

080

080

076

078

080

085

081

076

079

080

081

085

05

06

07

08

09

10

(c) α = 1

10minus2 10minus1 100 101

Prior std α

5 middot 102

103

5 middot 103

104

NL

L

(d)

10 05 00 05 10 15 20

(e) α = 002

10 05 00 05 10 15 20

(f) α = 01

10 05 00 05 10 15 20

(g) α = 1

10minus2 10minus1 100 101

Prior std α

1

5

20

90

Cla

ssifi

cati

onE

rror

(h)

Figure 12. (a)–(c): Average pairwise prior correlations for pairs of objects in classes 0, 1, 2, 4, 7 of MNIST induced by LeNet-5 for p(f(x; w)) when p(w) = N(0, α²I). Images in the same class have higher prior correlations than images from different classes, suggesting that p(f(x; w)) has desirable inductive biases. The correlations slightly decrease with increases in α. Panels (e)–(g) show sample functions from LeNet-5 along the direction connecting a pair of MNIST images of 0 and 1 digits; the complexity of the samples increases with α. (d): NLL and (h): classification error of an ensemble of 20 SWAG samples on MNIST as a function of α, using a LeNet-5. The NLL is high for overly small α and near-optimal for larger values, with an optimum near α = 0.3.

D.3. Neural Network Priors

In the main text we considered different properties of the prior distribution over functions induced by a spherical Gaussian distribution over the weights, with different variance scales.

Prior correlation diagrams. In panels (a)–(c) of Figure 12 we show pairwise correlations of the logits for different pairs of datapoints. To make these plots, we produce S = 100 samples of the weights wᵢ of a LeNet-5 from the prior distribution N(0, α²I), and compute the logits corresponding to class 0 for each data point and each weight sample. We then compute the correlations for each pair x, x′ of data points as follows:

corr_logit(x, x′) = Σ_{i=1}^S (f(x; w_i) − f̄(x)) (f(x′; w_i) − f̄(x′)) / √( Σ_{i=1}^S (f(x; w_i) − f̄(x))² · Σ_{i=1}^S (f(x′; w_i) − f̄(x′))² ),

where f(x; w) is the logit corresponding to class 0 of the network with weights w on the input x, and f̄(x) is the mean value of the logit, f̄(x) = (1/S) Σ_i f(x; w_i). For evaluation we use 200 random datapoints per class for classes 0, 1, 2, 4, 7 (a total of 1000 datapoints). We use this set of classes to ensure that the structure is clearly visible in the figure. We combine the correlations into a diagram, additionally showing the average correlation for each pair of classes. We repeat the experiment for different values of α ∈ {0.02, 0.1, 1}. For a discussion of the results, see Section 5.2.
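A sketch of this computation is given below; the model constructor and the input tensor are assumed to be provided elsewhere.

```python
import numpy as np
import torch

@torch.no_grad()
def prior_logit_correlations(make_model, inputs, alpha, num_samples=100):
    """Correlations of the class-0 logit across weight samples from N(0, alpha^2 I).

    make_model: callable returning a freshly constructed network (e.g. LeNet-5);
    inputs: tensor of datapoints, shape (N, ...). Both are illustrative assumptions.
    """
    logits = []
    for _ in range(num_samples):
        net = make_model()
        for p in net.parameters():
            p.copy_(alpha * torch.randn_like(p))   # w_i ~ N(0, alpha^2 I)
        logits.append(net(inputs)[:, 0])           # class-0 logit for each datapoint
    F = torch.stack(logits).numpy()                # shape (S, N)
    return np.corrcoef(F, rowvar=False)            # (N, N) matrix of corr_logit(x, x')
```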

Sample functions. In panels (e)–(g) of Figure 12 we visualize the functions sampled from the LeNet-5 network along the direction connecting a pair of MNIST images. In particular, we take a pair of images x₀ and x₁ of digits 0 and 1 respectively, and construct the path x(t) = t · x₀ + (1 − t) · x₁. We then study the samples of the logits z(t) = f(x(t) · ‖x₀‖/‖x(t)‖; w) along the path; here we adjusted the norm of the images along the path to be constant, as the values of the logits are sensitive to the norm of the inputs. The complexity of the samples increases as we increase the variance of the prior distribution over the weights. This increased complexity of sample functions explains why we might expect the prior correlations for pairs of images to be lower when we increase the variance of the prior distribution over the weights.
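The norm-adjusted interpolation path can be sketched as follows, assuming x0 and x1 are flattened image tensors.

```python
import torch

def norm_adjusted_path(x0, x1, num_points=50):
    """Interpolate between two images while keeping the input norm fixed at ||x0||."""
    ts = torch.linspace(0.0, 1.0, num_points)
    points = []
    for t in ts:
        x_t = t * x0 + (1 - t) * x1
        points.append(x_t * x0.norm() / x_t.norm())   # rescale to constant norm
    return ts, torch.stack(points)
```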

Performance dependence on prior variance. In panels (d) and (h) of Figure 12 we show the test negative log-likelihood and accuracy of SWAG applied to LeNet-5 on MNIST. We train the model for 50 epochs, constructing the rank-20 SWAG approximation from the last 25 epochs. We use an initial learning rate of 0.05 and a SWAG learning rate of 0.01, with the learning rate schedule of Maddox et al. (2019). We use the posterior log-density as the objective, and vary the prior variance α². In panels (f) and (g) of Figure 9


we perform an analogous experiment using a PreResNet-20 and a VGG-16 on CIFAR-10, using the hyper-parameters reported in Maddox et al. (2019) (for PreResNet-20 we use the hyper-parameters used with PreResNet-164 in Maddox et al. (2019)). Both on MNIST and CIFAR-10 we observe that the performance is poor for overly small values of α, close to optimal for intermediate values, and still reasonable for larger values of α. For further discussion of the results, see Section 5.3.

Predictions from prior samples. Following Wenzel et al. (2020), we study the predictive distributions of prior samples using PreResNet-20 on CIFAR-10. In Figure 9 we show the sample predictive functions averaged over datapoints for different scales α of the prior distribution. We also show the predictive distribution for each α, which is the average of the sample predictive distributions over 200 samples of the weights. In Figure 10 we show how the predictive distribution changes as we vary the number of observed data for prior scale α = √10. We see that the marginal predictive distribution for all considered values of α is reasonable: roughly uniform across classes when averaged across the dataset. For the latter experiment we used stochastic gradient Langevin dynamics (SGLD) (Welling & Teh, 2011) with a cosine learning rate schedule. For each sample we restart SGLD, and we only use the sample obtained at the last iteration. We discuss the results in Section 8.4.
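For reference, a minimal sketch of the SGLD update and the cosine schedule is given below; step sizes and step counts are illustrative.

```python
import math
import torch

def cosine_lr(step, total_steps, lr0):
    """Cosine learning-rate schedule decaying from lr0 to 0."""
    return 0.5 * lr0 * (1.0 + math.cos(math.pi * step / total_steps))

@torch.no_grad()
def sgld_step(params, grads_log_posterior, lr):
    """One SGLD update: w <- w + (lr/2) * grad log p(w|D) + N(0, lr I)."""
    for p, g in zip(params, grads_log_posterior):
        p += 0.5 * lr * g + math.sqrt(lr) * torch.randn_like(p)
```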

Prior correlations with corrupted images. In Section F and Figure 19 we study the decay of the prior correlations between logits on an original image and a perturbed image as we increase the intensity of perturbations. For the BNN we use the PreResNet-20 architecture with the standard Gaussian prior N(0, I). For the linear model, the correlations are not affected by the prior variance α²:

cov(w^T x, w^T y) = E(w^T x · w^T y) = E(x^T w w^T y) = x^T E(w w^T) y = α² x^T y,

and hence

corr(w^T x, w^T y) = cov(w^T x, w^T y) / √( cov(w^T x, w^T x) · cov(w^T y, w^T y) ) = x^T y / (‖x‖ · ‖y‖).

We use the N(0, I) prior for the weights of the linear model. Finally, we also evaluate the correlations associated with an RBF kernel (see Equation (7)). To set the lengthscale ℓ of the kernel, we evaluate the pairwise correlations for the PreResNet-20 and the RBF kernel on the 100 uncorrupted CIFAR-10 images that were used for the experiment, and ensure that the average correlations match. The resulting value of ℓ is 10000; the average correlation for the RBF kernel and PreResNet was ≈ 0.9, and for the linear model the average correlation was ≈ 0.82. For the perturbations, we used the same set of corruptions introduced in Hendrycks & Dietterich (2019) as in the experiments in Section 4, with the addition of a random translation: for a random translation of intensity i, we pad the image with 2·i zeros on each side and crop the image randomly to 32 × 32.
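For the linear model and the RBF kernel these prior correlations reduce to closed forms, sketched below; we assume flattened image vectors and the standard RBF parameterization of Equation (7).

```python
import numpy as np

def linear_prior_corr(x, x_tilde):
    """corr(w^T x, w^T x~) under w ~ N(0, alpha^2 I); alpha cancels, leaving cosine similarity."""
    return float(x @ x_tilde / (np.linalg.norm(x) * np.linalg.norm(x_tilde)))

def rbf_prior_corr(x, x_tilde, lengthscale=10000.0):
    """Correlation implied by an RBF kernel; equals k(x, x~) because k(x, x) = 1."""
    return float(np.exp(-np.sum((x - x_tilde) ** 2) / (2.0 * lengthscale ** 2)))
```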

D.4. Rethinking Generalization

In Section 6 we experiment with Bayesian neural networks and Gaussian processes on CIFAR-10 with noisy labels, inspired by the results in Zhang et al. (2016) that suggest we need to re-think generalization to understand deep learning.

Following Zhang et al. (2016), we train PreResNet-20 on CIFAR-10 with different fractions of random labels. To ensure that the network fits the train data, we turn off weight decay and data augmentation, and use a lower initial learning rate of 0.01. Otherwise, we follow the hyper-parameters that were used with PreResNet-164 in Maddox et al. (2019). We use a diagonal Laplace approximation to compute an estimate of the marginal likelihood for each level of label corruption. Following Ritter et al. (2018), we use the diagonal of the Fisher information matrix rather than the Hessian.
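A sketch of the diagonal Laplace evidence estimate is below; the Fisher diagonal, the log-likelihood, and the log-prior at the trained weights are assumed to be computed elsewhere.

```python
import math
import torch

def diag_laplace_log_evidence(log_lik, log_prior, fisher_diag, prior_var):
    """Diagonal Laplace approximation of log p(D) at a MAP solution.

    fisher_diag approximates the diagonal Hessian of the negative log-likelihood
    (summed over the dataset); the Gaussian prior contributes 1/prior_var.
    """
    d = fisher_diag.numel()
    precision_diag = fisher_diag + 1.0 / prior_var      # Hessian of the negative log-posterior
    log_det = torch.log(precision_diag).sum()
    return log_lik + log_prior + 0.5 * d * math.log(2.0 * math.pi) - 0.5 * log_det
```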

We perform a similar experiment with a Gaussian process with an RBF kernel on the binary classification problem for two classes of CIFAR-10. We use variational inference to fit the model, and we use the variational evidence lower bound to approximate the marginal likelihood. We use variational inference to overcome the non-Gaussian likelihood and not for scalability reasons, i.e. we are not using inducing inputs. We use the GPyTorch package (Gardner et al., 2018) to train the models. We use an RBF kernel with the default initialization from GPyTorch, and divide the inputs by 5000 to get an appropriate input scale. We train the model on a binary classification problem between classes 0 and 1.
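A minimal GPyTorch sketch of this setup is below; train_x and train_y are assumed to hold the rescaled images and binary (0/1) labels, and the training-loop settings are illustrative rather than the exact configuration used here.

```python
import torch
import gpytorch

class GPClassificationModel(gpytorch.models.ApproximateGP):
    """Variational GP classifier; the training inputs themselves serve as the
    variational points, so no sparse inducing-point approximation is used."""
    def __init__(self, train_x):
        var_dist = gpytorch.variational.CholeskyVariationalDistribution(train_x.size(0))
        var_strat = gpytorch.variational.VariationalStrategy(
            self, train_x, var_dist, learn_inducing_locations=False)
        super().__init__(var_strat)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

model = GPClassificationModel(train_x)                      # train_x: flattened images divided by 5000
likelihood = gpytorch.likelihoods.BernoulliLikelihood()
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.numel())
opt = torch.optim.Adam(list(model.parameters()) + list(likelihood.parameters()), lr=0.01)
for _ in range(200):
    opt.zero_grad()
    loss = -mll(model(train_x), train_y)   # negative ELBO, also used as the marginal-likelihood estimate
    loss.backward()
    opt.step()
```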

For the 10-class GP classification experiment, we train 10 one-vs-all models that classify between a given class and the rest of the data. To reduce computation in training, we subsample the data not belonging to the given class to 10k datapoints, so each model is trained on a total of 15k datapoints. We then combine the 10 models into a single multi-class model: an observation is attributed to the class that corresponds to the one-vs-all model with the highest confidence. We use the same hyper-parameters as in the binary classification experiments.
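The one-vs-all combination can be sketched as follows; `models` is assumed to be the list of ten trained one-vs-all classifiers sharing a Bernoulli likelihood.

```python
import torch

@torch.no_grad()
def one_vs_all_predict(models, likelihood, x):
    """Assign x to the class whose one-vs-all model gives the highest probability."""
    confidences = []
    for model in models:
        model.eval()
        confidences.append(likelihood(model(x)).mean)   # P(this class | x)
    return torch.stack(confidences, dim=-1).argmax(dim=-1)
```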

D.5. Double Descent

In Section 7 we evaluate SGD, SWAG, and MultiSWAG for models of varying width. Following Nakkiran et al. (2019), we use ResNet-18 on CIFAR-100; we consider the original labels, 10% and 20% label corruption. For networks of every width we reuse the hyper-parameters used for PreResNet-164 in Maddox et al. (2019).

Figure 13. Double Descent. (a) Test error and (b) NLL loss for ResNet-18 with varying width on CIFAR-100 for SGD, SWAG, and MultiSWAG when 10% of the labels are randomly reshuffled. MultiSWAG alleviates double descent both on the original labels and under label noise, both in accuracy and NLL. (c) Test NLLs for MultiSWAG with varying number of independent models under 20% label corruption; NLL monotonically decreases with increased number of independent models, alleviating double descent.

For the original labels and 10% label corruption, we use 5 independently trained SWAG models with MultiSWAG, and for 20% label corruption we use 10 models. For 20% label corruption we also show performance varying the number of independent models in Figures 8(e) and 13(c). Both for SWAG and MultiSWAG we use an ensemble of 20 sampled models from each of the SWAG solutions; for example, for MultiSWAG with 10 independent SWAG solutions we use an ensemble of 200 networks.

E. Analysis of Prior Variance Effect

In this section we provide simple analytic results for the effect of the prior variance in ReLU networks. A related derivation is presented in Section A.8 of the Appendix of Garipov et al. (2018), about connecting paths arising from symmetries in parametrization.

We will consider a multilayer network f(x; w) of the form

f(x; {W_i, b_i}_{i=1}^n) = W_n(⋯ φ(W_2 φ(W_1 x + b_1) + b_2) ⋯) + b_n,

where φ is the ReLU (or in fact any positively-homogeneous activation function), W_i are weight matrices, and b_i are bias vectors. In particular, f can be a regular CNN with ReLU activations up to the logits (with the softmax activation removed).

Now suppose we have a prior distribution of the form

W_i ∼ N(0, α_i² I),   b_i ∼ N(0, β_i² I),

where the identity matrices I are implicitly assumed to be of appropriate shapes, so each weight matrix and bias vector has a spherical Gaussian distribution. We can reparameterize this distribution as

W_i = α_i E_i,   E_i ∼ N(0, I);   b_i = β_i ε_i,   ε_i ∼ N(0, I).

We can then express the predictions of the network on the input x, for weights sampled from the prior, as the random variable

f(x; {α_i, β_i}_{i=1}^n) = α_n · E_n(⋯ φ(α_1 E_1 x + β_1 · ε_1) ⋯) + β_n · ε_n.   (8)

Through Equation (8) we can observe some simple properties of the dependence between the prior scales α_i, β_i and the induced function-space prior.

Proposition 1. Suppose the network has no bias vectors, i.e. β_1 = … = β_n = 0. Then the scales α_i of the prior distribution over the weights only affect the output scale of the network.

Proof. In the case when there are no bias vectors, Equation (8) simplifies to

f(x; {α_i, β_i = 0}_{i=1}^n) = α_n · E_n(⋯ φ(α_1 E_1 x) ⋯)
= α_n · … · α_1 · E_n(⋯ φ(E_1 x) ⋯)
= α_n · … · α_1 · f(x; {α_i = 1, β_i = 0}_{i=1}^n).

In the derivation above we used the positive homogeneity of ReLU: φ(αz) = αφ(z) for any positive α.

In other words, to sample from the distribution over functions corresponding to a prior with variances {α_i, β_i = 0}_{i=1}^n, we can sample from the spherical Gaussian prior (without bias terms) {α_i = 1, β_i = 0}_{i=1}^n and then rescale the outputs of the network by the product α_n · … · α_2 · α_1 of the prior scales.
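This rescaling property is easy to check numerically; the sketch below uses a small bias-free ReLU network with arbitrary per-layer scales (architecture and scales are illustrative).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def make_net():
    # Bias-free ReLU network; the architecture is illustrative.
    return nn.Sequential(nn.Linear(3, 8, bias=False), nn.ReLU(),
                         nn.Linear(8, 8, bias=False), nn.ReLU(),
                         nn.Linear(8, 1, bias=False))

base, scaled = make_net(), make_net()
alphas = [0.5, 2.0, 3.0]                      # per-layer prior scales alpha_i
x = torch.randn(5, 3)

with torch.no_grad():
    for p_base, p_scaled, a in zip(base.parameters(), scaled.parameters(), alphas):
        p_base.copy_(torch.randn_like(p_base))    # E_i ~ N(0, I)
        p_scaled.copy_(a * p_base)                # W_i = alpha_i * E_i
    # Scaling the weights only rescales the output by alpha_n * ... * alpha_1.
    print(torch.allclose(scaled(x), alphas[0] * alphas[1] * alphas[2] * base(x), atol=1e-5))
```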

We note that the result above is different from the results for sigmoid networks considered in MacKay (1995), where varying the prior on the weights leads to changing the lengthscale of the sample functions. For ReLU networks without biases, increasing the prior variance only increases the output scale of the network and not the complexity of the samples. If we apply the softmax activation on the outputs of the last layer of such a network, we will observe increasingly


confident predictions as we increase the prior variance. We observe this effect in Figure 9 and discuss it in Section 8.4.

In case bias vectors are present, we can obtain a similar result using a specific scaling of the prior variances with layer, as in the following proposition.

Proposition 2. Suppose the prior scales depend on the layer of the network as follows, for some γ > 0:

α_i = γ,   β_i = γ^i,

for all layers i = 1, …, n. Then γ only affects the scale of the predictive distribution at any input x:

f(x; {α_i = γ, β_i = γ^i}_{i=1}^n) = γ^n · f(x; {α_i = 1, β_i = 1}_{i=1}^n).

Proof. The proof is analogous to the proof of Proposition 1. We can use the positive homogeneity of ReLU activations to factor the prior scales outside of the network:

f(x; {α_i = γ, β_i = γ^i}_{i=1}^n) = γ · E_n(⋯ φ(γ · E_1 x + γ · ε_1) ⋯) + γ^n · ε_n
= γ^n · ( E_n(⋯ φ(E_1 x + ε_1) ⋯) + ε_n )
= γ^n · f(x; {α_i = 1, β_i = 1}_{i=1}^n).

The analysis above can be applied to other simple scaling rules of the prior, e.g.

f(x; {α_i = γ α̂_i, β_i = γ^i β̂_i}_{i=1}^n) = γ^n · f(x; {α_i = α̂_i, β_i = β̂_i}_{i=1}^n),   (9)

which can be shown completely analogously to Proposition 2.

More general types of scaling of the prior affect both the output scale of the network and the relative effect of the prior and variance terms. For example, by Equation (9) we have

f(x; {α_i = γ, β_i = γ}_{i=1}^n) = f(x; {α_i = γ · 1, β_i = γ^i · γ^{1−i}}_{i=1}^n) = γ^n · f(x; {α_i = 1, β_i = γ^{1−i}}_{i=1}^n).

We note that the analysis does not cover residual connections and batch normalization, so it applies to LeNet-5 but cannot be directly applied to the PreResNet-20 networks used in many of our experiments.

F. Prior Correlation Structure under Perturbations

In this section we explore the prior correlations between the logits on different pairs of datapoints induced by a spherical Gaussian prior on the weights of a PreResNet-20. We sample 100 random images from CIFAR-10 (10 from each class) and apply 17 different perturbations introduced by Hendrycks & Dietterich (2019) at 5 different levels of intensity. We then compute correlations between the logits f(x; w) for the original image x and f(x̃; w) for the corrupted image x̃, as we sample the weights of the network from the prior w ∼ N(0, I).

In Figure 19 we show how the correlations decay with perturbation intensity. For reference, we also show how the correlations decay for a linear model and for an RBF kernel. For the RBF kernel, we set the lengthscale so that the average correlations on the uncorrupted datapoints match those of a PreResNet-20. Further experimental details can be found in Appendix D.3.

For all types of corruptions except saturate, snow, fog, and brightness, the PreResNet logit correlations decay slower compared to the RBF kernel and the linear model. It appears that the prior samples are sensitive to corruptions that alter the brightness, or more generally the colours, in the image. For many types of corruptions (such as, e.g., Gaussian Noise) the prior correlations for PreResNet are close to 1 for all levels of corruption.

Overall, these results indicate that the prior over functions induced by a vague prior over parameters w in combination with a PreResNet has useful equivariance properties: before seeing data, the model treats images of the same class as highly correlated, even after an image has undergone significant perturbations representative of perturbations we often see in the real world. These types of symmetries are a large part of what makes neural networks a powerful model class for high dimensional natural signals.


[Figure 14 panels: (a) Gaussian Noise, (b) Impulse Noise, (c) Shot Noise. Each panel plots NLL against the number of independently trained models (1–10) for Deep Ensembles, MultiSWA, and MultiSWAG at increasing corruption intensity.]

Figure 14. Noise Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG, and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).


[Figure 15 panels: (a) Defocus Blur, (b) Glass Blur, (c) Motion Blur, (d) Zoom Blur, (e) Gaussian Blur. Each panel plots NLL against the number of independently trained models (1–10) for Deep Ensembles, MultiSWA, and MultiSWAG at increasing corruption intensity.]

Figure 15. Blur Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG, and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).


[Figure 16 panels: (a) Contrast, (b) Saturate, (c) Elastic Transform, (d) Pixelate, (e) JPEG Compression. Each panel plots NLL against the number of independently trained models (1–10) for Deep Ensembles, MultiSWA, and MultiSWAG at increasing corruption intensity.]

Figure 16. Digital Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG, and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).


[Figure 17 panels: (a) Snow, (b) Fog, (c) Brightness. Each panel plots NLL against the number of independently trained models (1–10) for Deep Ensembles, MultiSWA, and MultiSWAG at increasing corruption intensity.]

Figure 17. Weather Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG, and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).


Figure 18. Negative log likelihood, accuracy, and expected calibration error distribution on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG, and MultiSWA as a function of the corruption intensity. Following Ovadia et al. (2019), we summarize the results for different types of corruption with a boxplot. For each method we use 10 independently trained models, and for MultiSWAG we sample 20 networks from each model. As in Figures 5 and 11–14, there are substantial differences between these three methods, which are hard to see due to the vertical scale on this plot; MultiSWAG particularly outperforms Deep Ensembles and MultiSWA in terms of NLL and ECE for higher corruption intensities.


[Figure 19 panels: prior correlation versus corruption intensity (0–5) for Translation, Gaussian Noise, Impulse Noise, Shot Noise, Defocus Blur, Glass Blur, Motion Blur, Zoom Blur, Gaussian Blur, Contrast, Saturate, Elastic, Pixelate, JPEG, Snow, Fog, and Brightness; each panel compares PreResNet-20, RBF, and Linear.]

Figure 19. Prior correlations under corruption. Prior correlations between predictions (logits) for PreResNet-20, a linear model, and an RBF kernel on original and corrupted images, as a function of corruption intensity for different types of corruptions. The lengthscale of the RBF kernel is calibrated to produce similar correlations to PreResNet on uncorrupted datapoints. We report the mean correlation values over 100 different images and show the 1σ error bars with shaded regions. For all corruptions except Snow, Saturate, Fog, and Brightness, the correlations decay slower for PreResNet compared to the baselines.

Page 3: Bayesian Deep Learning and a Probabilistic Perspective of ... · Andrew Gordon Wilson Pavel Izmailov New York University Abstract The key distinguishing property of a Bayesian ap-proach

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

terior probabilities rather than bet everything on a singlesetting of parameters Neural networks are typically under-specified by the data and can represent many different buthigh performing models corresponding to different settingsof parameters which is exactly when marginalization willmake the biggest difference for accuracy and calibrationMoreover we clarify that the recent deep ensembles (Lak-shminarayanan et al 2017) are not a competing approachto Bayesian inference but can be viewed as a compellingmechanism for Bayesian marginalization Indeed we em-pirically demonstrate that deep ensembles can provide abetter approximation to the Bayesian predictive distributionthan standard Bayesian approaches We further proposea new method MultiSWAG inspired by deep ensembleswhich marginalizes within basins of attraction mdash achievingsignificantly improved performance with a similar trainingtime

We then investigate the properties of priors over functionsinduced by priors over the weights of neural networks show-ing that they have reasonable inductive biases We also showthat the mysterious generalization properties recently pre-sented in Zhang et al (2016) can be understood by reasoningabout prior distributions over functions and are not specificto neural networks Indeed we show Gaussian processes canalso perfectly fit images with random labels yet generalizeon the noise-free problem These results are a consequenceof large support but reasonable inductive biases for com-mon problem settings We further show that while Bayesianneural networks can fit the noisy datasets the marginal like-lihood has much better support for the noise free datasetsin line with Figure 2 We additionally show that the mul-timodal marginalization in MultiSWAG alleviates doubledescent so as to achieve monotonic improvements in per-formance with model flexibility in line with our perspectiveof generalization MultiSWAG also provides significant im-provements in both accuracy and NLL over SGD trainingand unimodal marginalization Finally we provide severalperspectives on tempering in Bayesian deep learning

In the Appendix we provide several additional experimentsand results We also provide code at httpsgithubcomizmailovpavelunderstandingbdl

2 Related WorkNotable early works on Bayesian neural networks includeMacKay (1992) MacKay (1995) and Neal (1996) Theseworks generally argue in favour of making the model classfor Bayesian approaches as flexible as possible in line withBox amp Tiao (1973) Accordingly Neal (1996) pursued thelimits of large Bayesian neural networks showing that as thenumber of hidden units approached infinity these modelsbecome Gaussian processes with particular kernel functionsThis work harmonizes with recent work describing the neu-

ral tangent kernel (eg Jacot et al 2018)

The marginal likelihood is often used for Bayesian hypothe-sis testing model comparison and hyperparameter tuningwith Bayes factors used to select between models (Kass ampRaftery 1995) MacKay (2003 Ch 28) uses a diagramsimilar to Fig 2(a) to show the marginal likelihood has anOccamrsquos razor property favouring the simplest model con-sistent with a given dataset even if the prior assigns equalprobability to the various models Rasmussen amp Ghahra-mani (2001) reasons about how the marginal likelihoodcan favour large flexible models as long as such modelscorrespond to a reasonable distribution over functions

There has been much recent interest in developing Bayesianapproaches for modern deep learning with new challengesand architectures quite different from what had been con-sidered in early work Recent work has largely focused onscalable inference (eg Blundell et al 2015 Gal amp Ghahra-mani 2016 Kendall amp Gal 2017 Ritter et al 2018 Khanet al 2018 Maddox et al 2019) function-space inspiredpriors (eg Yang et al 2019 Louizos et al 2019 Sunet al 2019 Hafner et al 2018) and developing flat objec-tive priors in parameter space directly leveraging the biasesof the neural network functional form (eg Nalisnick 2018)Wilson (2020) provides a note motivating Bayesian deeplearning

Moreover PAC-Bayes (McAllester 1999) has emerged asan active direction for generalization bounds on stochas-tic neural networks (eg Guedj 2019) with distributionsover the parameters Langford amp Caruana (2002) devised aPAC-Bayes generalization bound for small stochastic neuralnetworks (two layer with two hidden units) achieving animprovement over the existing deterministic generalizationbounds Dziugaite amp Roy (2017) extended this approachoptimizing a PAC-Bayes bound with respect to a parametricdistribution over the weights of the network exploiting theflatness of solutions discovered by SGD As a result they de-rive the first non-vacuous generalization bounds for stochas-tic over-parameterized fully-connected neural networks onbinary MNIST classification Neyshabur et al (2017) alsodiscuss the connection between PAC-Bayes bounds andsharpness Neyshabur et al (2018) then devises PAC-Bayesbounds based on spectral norms of the layers and the Frobe-nius norm of the weights of the network Achille amp Soatto(2018) additionally combine PAC-Bayes and informationtheoretic approaches to argue that flat minima have lowinformation content Masegosa (2019) also proposes varia-tional and ensemble learning methods based on PAC-Bayesanalysis under model misspecification Jiang et al (2019)provide a review and comparison of several generalizationbounds including PAC-Bayes

Early works tend to provide a connection between lossgeometry and generalization using minimum description

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

length frameworks (eg Hinton amp Van Camp 1993 Hochre-iter amp Schmidhuber 1997 MacKay 1995) EmpiricallyKeskar et al (2016) argue that smaller batch SGD providesbetter generalization than large batch SGD by finding flat-ter minima Chaudhari et al (2019) and Izmailov et al(2018) design optimization procedures to specifically findflat minima

By connecting flat solutions with ensemble approximationsIzmailov et al (2018) also suggest that functions associatedwith parameters in flat regions ought to provide differentpredictions on test data for flatness to be helpful in gener-alization which is distinct from the flatness in Dinh et al(2017) Garipov et al (2018) also show that there are modeconnecting curves forming loss valleys which contain avariety of distinct solutions We argue that flat regions ofthe loss containing a diversity of solutions is particularlyrelevant for Bayesian model averaging since the model aver-age will then contain many compelling and complementaryexplanations for the data Additionally Huang et al (2019)describes neural networks as having a blessing of dimension-ality since flat regions will occupy much greater volume ina high dimensional space which we argue means that flatsolutions will dominate in the Bayesian model average

Smith amp Le (2018) and MacKay (2003 Chapter 28) ad-ditionally connect the width of the posterior with Occamfactors from a Bayesian perspective larger width corre-sponds to a smaller Occam factor and thus ought to providebetter generalization Dziugaite amp Roy (2017) and Smith ampLe (2018) also provide different perspectives on the resultsin Zhang et al (2016) which shows that deep convolutionalneural networks can fit CIFAR-10 with random labels andno training error The PAC-Bayes bound of Dziugaite ampRoy (2017) becomes vacuous when applied to randomly-labelled binary MNIST Smith amp Le (2018) show that logis-tic regression can fit noisy labels on sub-sampled MNISTinterpreting the result from an Occam factor perspective

In general PAC-Bayes provides a compelling frameworkfor deriving explicit non-asymptotic generalization boundsThese bounds can be improved by eg fewer parame-ters and very compact priors which can be different fromwhat provides optimal generalization From our perspec-tive model flexibility and priors with large support ratherthan compactness are desirable Our focus is complemen-tary and largely prescriptive aiming to provide intuitionson model construction inference generalization and neu-ral network priors as well as new connections betweenBayesian model averaging and deep ensembles benefitsof Bayesian model averaging specifically in the context ofmodern deep neural networks perspectives on temperingin Bayesian deep learning views of marginalization thatcontrast with simple Monte Carlo as well as new methodsfor Bayesian marginalization in deep learning

In other work Pearce et al (2018) propose a modificationof deep ensembles and argue that it performs approximateBayesian inference and Gustafsson et al (2019) briefly men-tion how deep ensembles can be viewed as samples froman approximate posterior In the context of deep ensembleswe believe it is natural to consider the BMA integral sepa-rately from the simple Monte Carlo approximation that isoften used to approximate this integral to compute an accu-rate predictive distribution we do not need samples from aposterior or even a faithful approximation to the posterior

Fort et al (2019) considered the diversity of predictionsproduced by models from a single SGD run and modelsfrom independent SGD runs and suggested to ensembleaverages of SGD iterates Although MultiSWA (one of themethods considered in Section 4) is related to this idea thecrucial practical difference is that MultiSWA uses a learningrate schedule that selects for flat regions of the loss the keyto the success of the SWA method (Izmailov et al 2018)Section 4 also shows that MultiSWAG which we proposefor multimodal Bayesian marginalization outperforms Mul-tiSWA

Double descent which describes generalization error thatdecreases increases and then again decreases with modelflexibility was demonstrated early by Opper et al (1990)Recently Belkin et al (2019) extensively demonstrateddouble descent leading to a surge of modern interest withNakkiran et al (2019) showing double descent in deep learn-ing Nakkiran et al (2020) shows that tuned l2 regulariza-tion can mitigate double descent Alternatively we showthat Bayesian model averaging particularly based on multi-modal marginalization can alleviate even prominent doubledescent behaviour

Tempering in Bayesian modelling has been considered underthe names Safe Bayes generalized Bayesian inference andfractional Bayesian inference (eg de Heide et al 2019Grunwald et al 2017 Barron amp Cover 1991 Walker ampHjort 2001 Zhang 2006 Bissiri et al 2016 Grunwald2012) We provide several perspectives of tempering inBayesian deep learning and analyze the results in a recentpaper by Wenzel et al (2020) that questions tempering forBayesian neural networks

3 Bayesian MarginalizationOften the predictive distribution we want to compute isgiven by

p(y|xD) =

intp(y|xw)p(w|D)dw (1)

The outputs are y (eg regression values class labels )indexed by inputs x (eg spatial locations images ) theweights (or parameters) of the neural network f(xw) arew and D are the data Eq (1) represents a Bayesian model

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

average (BMA) Rather than bet everything on one hypoth-esis mdash with a single setting of parameters w mdash we want touse all settings of parameters weighted by their posteriorprobabilities This procedure is called marginalization ofthe parametersw as the predictive distribution of interest nolonger conditions on w This is not a controversial equationbut simply the sum and product rules of probability

31 Importance of Marginalization in Deep Learning

In general we can view classical training as performingapproximate Bayesian inference using the approximate pos-terior p(w|D) asymp δ(w = w) to compute Eq (1) where δis a Dirac delta function that is zero everywhere except atw = argmaxwp(w|D) In this case we recover the standardpredictive distribution p(y|x w) From this perspectivemany alternatives albeit imperfect will be preferable mdashincluding impoverished Gaussian posterior approximationsfor p(w|D) even if the posterior or likelihood are actuallyhighly non-Gaussian and multimodal

The difference between a classical and Bayesian approachwill depend on how sharp the posterior p(w|D) becomes Ifthe posterior is sharply peaked and the conditional predic-tive distribution p(y|xw) does not vary significantly wherethe posterior has mass there may be almost no differencesince a delta function may then be a reasonable approxi-mation of the posterior for the purpose of BMA Howevermodern neural networks are usually highly underspecifiedby the available data and therefore have diffuse likelihoodsp(D|w) not strongly favouring any one setting of param-eters Not only are the likelihoods diffuse but differentsettings of the parameters correspond to a diverse variety ofcompelling hypotheses for the data (Garipov et al 2018Izmailov et al 2019) This is exactly the setting when wemost want to perform a Bayesian model average whichwill lead to an ensemble containing many different but highperforming models for better calibration and accuracy thanclassical training

Loss Valleys Flat regions of low loss (negative log poste-rior density minus log p(w|D)) are associated with good gener-alization (eg Hochreiter amp Schmidhuber 1997 Hinton ampVan Camp 1993 Dziugaite amp Roy 2017 Izmailov et al2018 Keskar et al 2016) While flat solutions that general-ize poorly can be contrived through reparametrization (Dinhet al 2017) the flat regions that lead to good generalizationcontain a diversity of high performing models on test data(Izmailov et al 2018) corresponding to different parametersettings in those regions And indeed there are large con-tiguous regions of low loss that contain such solutions evenconnecting together different SGD solutions (Garipov et al2018 Izmailov et al 2019) (see also Figure 11 Appendix)

Since these regions of the loss represent a large volume in a

high-dimensional space (Huang et al 2019) and providea diversity of solutions they will dominate in forming thepredictive distribution in a Bayesian model average Bycontrast if the parameters in these regions provided similarfunctions as would be the case in flatness obtained throughreparametrization these functions would be redundant inthe model average That is although the solutions of highposterior density can provide poor generalization it is the so-lutions that generalize well that will have greatest posteriormass and thus be automatically favoured by the BMA

Calibration by Epistemic Uncertainty RepresentationIt has been noticed that modern neural networks are oftenmiscalibrated in the sense that their predictions are typicallyoverconfident (Guo et al 2017) For example in classifi-cation the highest softmax output of a convolutional neuralnetwork is typically much larger than the probability of theassociated class label The fundamental reason for miscali-bration is ignoring epistemic uncertainty A neural networkcan represent many models that are consistent with our ob-servations By selecting only one in a classical procedurewe lose uncertainty when the models disagree for a testpoint In regression we can visualize epistemic uncertaintyby looking at the spread of the predictive distribution aswe move away from the data there are a greater varietyof consistent solutions leading to larger uncertainty as inFigure 4 We can further calibrate the model with temperingwhich we discuss in the Appendix Section 8

Accuracy An often overlooked benefit of Bayesian modelaveraging in modern deep learning is improved accuracy Ifwe average the predictions of many high performing modelsthat disagree in some cases we should see significantly im-proved accuracy This benefit is now starting to be observedin practice (eg Izmailov et al 2019) Improvements inaccuracy are very convincingly exemplified by deep en-sembles (Lakshminarayanan et al 2017) which have beenperceived as a competing approach to Bayesian methodsbut in fact provides a compelling mechanism for approxi-mate Bayesian model averaging as we show in Section 33We also demonstrate significant accuracy benefits for multi-modal Bayesian marginalization in Section 7

32 Beyond Monte Carlo

Nearly all approaches to estimating the integral in Eq (1)when it cannot be computed in closed form involvea simple Monte Carlo approximation p(y|xD) asymp1J

sumJj=1 p(y|xwj) wj sim p(w|D) In practice the sam-

ples from the posterior p(w|D) are also approximate andfound through MCMC or deterministic methods The de-terministic methods approximate p(w|D) with a differentmore convenient density q(w|D θ) from which we can sam-ple often chosen to be Gaussian The parameters θ are

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

selected to make q close to p in some sense for exam-ple variational approximations (eg Beal 2003) whichhave emerged as a popular deterministic approach findargminθKL(q||p) Other standard deterministic approxima-tions include Laplace (eg MacKay 1995) EP (Minka2001a) and INLA (Rue et al 2009)

From the perspective of estimating the predictive distribu-tion in Eq (1) we can view simple Monte Carlo as ap-proximating the posterior with a set of point masses withlocations given by samples from another approximate pos-terior q even if q is a continuous distribution That isp(w|D) asympsumJ

j=1 δ(w = wj) wj sim q(w|D)

Ultimately the goal is to accurately compute the predictivedistribution in Eq (1) rather than find a generally accu-rate representation of the posterior In particular we mustcarefully represent the posterior in regions that will makethe greatest contributions to the BMA integral In terms ofefficiently computing the predictive distribution we do notnecessarily want to place point masses at locations givenby samples from the posterior For example functional di-versity is important for a good approximation to the BMAintegral because we are summing together terms of theform p(y|xw) if two settings of the weights wi and wjeach provide high likelihood (and consequently high pos-terior density) but give rise to similar functions f(xwi)f(xwj) then they will be largely redundant in the modelaverage and the second setting of parameters will not con-tribute much to estimating the BMA integral for the uncon-ditional predictive distribution In Sections 33 and 4 weconsider how various approaches approximate the predictivedistribution

33 Deep Ensembles are BMA

Deep ensembles (Lakshminarayanan et al 2017) is fastbecoming a gold standard for accurate and well-calibratedpredictive distributions Recent reports (eg Ovadia et al2019 Ashukha et al 2020) show that deep ensembles ap-pear to outperform some particular approaches to Bayesianneural networks for uncertainty representation leading tothe confusion that deep ensembles and Bayesian methodsare competing approaches These methods are often explic-itly referred to as non-Bayesian (eg Lakshminarayananet al 2017 Ovadia et al 2019 Wenzel et al 2020) Tothe contrary we argue that deep ensembles are actually acompelling approach to Bayesian model averaging in thevein of Section 32

There is a fundamental difference between a Bayesian modelaverage and some approaches to ensembling The Bayesianmodel average assumes that one hypothesis (one parame-ter setting) is correct and averages over models due to aninability to distinguish between hypotheses given limitedinformation (Minka 2000) As we observe more data the

posterior collapses onto a single hypothesis If the trueexplanation for the data is a combination of hypothesesthen the Bayesian model average may appear to performworse as we observe more data Some ensembling methodswork by enriching the hypothesis space and therefore donot collapse in this way Deep ensembles however areformed by MAP or maximum likelihood retraining of thesame architecture multiple times leading to different basinsof attraction The deep ensemble will therefore collapse inthe same way as a Bayesian model average as the posteriorconcentrates Since the hypotheses space (support) for amodern neural network is large containing many differentpossible explanations for the data posterior collapse willoften be desirable

Furthermore by representing multiple basins of attractiondeep ensembles can provide a better approximation to theBMA than the Bayesian approaches in Ovadia et al (2019)Indeed the functional diversity is important for a good ap-proximation to the BMA integral as per Section 32 Theapproaches referred to as Bayesian in Ovadia et al (2019)instead focus their approximation on a single basin whichmay contain a lot of redundancy in function space makinga relatively minimal contribution to computing the Bayesianpredictive distribution On the other hand retraining a neu-ral network multiple times for deep ensembles incurs asignificant computational expense The single basin ap-proaches may be preferred if we are to control for computa-tion We explore these questions in Section 4

4 An Empirical Study of MarginalizationWe have shown that deep ensembles can be interpretedas an approximate approach to Bayesian marginalizationwhich selects for functional diversity by representing multi-ple basins of attraction in the posterior Most Bayesian deeplearning methods instead focus on faithfully approximatinga posterior within a single basin of attraction We propose anew method MultiSWAG which combines these two typesof approaches MultiSWAG combines multiple indepen-dently trained SWAG approximations (Maddox et al 2019)to create a mixture of Gaussians approximation to the pos-terior with each Gaussian centred on a different basin ofattraction We note that MultiSWAG does not require anyadditional training time over standard deep ensembles

We illustrate the conceptual difference between deep en-sembles a standard variational single basin approach andMultiSWAG in Figure 3 In the top panel we have a con-ceptualization of a multimodal posterior VI approximatesthe posterior with multiple samples within a single basinBut we see in the middle panel that the conditional predic-tive distribution p(y|xw) does not vary significantly withinthe basin and thus each additional sample contributes min-imally to computing the marginal predictive distribution

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

Deep Ensembles VI Multi-SWAG

ww

w

w

dist(p q)

p(y|w)

p(w|D)

Figure 3 Approximating the BMAp(y|xD) =

intp(y|xw)p(w|D)dw Top p(w|D) with repre-

sentations from VI (orange) deep ensembles (blue) MultiSWAG(red) Middle p(y|xw) as a function of w for a test input xThis function does not vary much within modes but changes sig-nificantly between modes Bottom Distance between the truepredictive distribution and the approximation as a function of rep-resenting a posterior at an additional point w assuming we havesampled the mode in dark green There is more to be gained byexploring new basins than continuing to explore the same basin

p(y|xD) On the other hand p(y|xw) varies significantlybetween basins and thus each point mass for deep ensem-bles contributes significantly to the marginal predictive dis-tribution By sampling within the basins MultiSWAG pro-vides additional contributions to the predictive distributionIn the bottom panel we have the gain in approximatingthe predictive distribution when adding a point mass to therepresentation of the posterior as a function of its locationassuming we have already sampled the mode in dark greenIncluding samples from different modes provides significantgain over continuing to sample from the same mode andincluding weights in wide basins provide relatively moregain than the narrow ones

In Figure 4 we evaluate single basin and multi-basin ap-proaches in a case where we can near-exactly compute thepredictive distribution We provide details for generatingthe data and training the models in Appendix D1 We seethat the predictive distribution given by deep ensembles isqualitatively closer to the true distribution compared to thesingle basin variational method between data clusters thedeep ensemble approach provides a similar representationof epistemic uncertainty whereas the variational method isextremely overconfident in these regions Moreover we seethat the Wasserstein distance between the true predictive dis-tribution and these two approximations quickly shrinks withnumber of samples for deep ensembles but is roughly inde-pendent of number of samples for the variational approachThus the deep ensemble is providing a better approximation

of the Bayesian model average in Eq (1) than the singlebasin variational approach which has traditionally beenlabelled as the Bayesian alternative

Next we evaluate MultiSWAG under distribution shift onthe CIFAR-10 dataset (Krizhevsky et al 2014) replicatingthe setup in Ovadia et al (2019) We consider 16 data cor-ruptions each at 5 different levels of severity introducedby Hendrycks amp Dietterich (2019) For each corruptionwe evaluate the performance of deep ensembles and Mul-tiSWAG varying the training budget For deep ensembleswe show performance as a function of the number of inde-pendently trained models in the ensemble For MultiSWAGwe show performance as a function of the number of inde-pendent SWAG approximations that we construct we thensample 20 models from each of these approximations toconstruct the final ensemble

While the training time for MultiSWAG is the same as fordeep ensembles at test time MultiSWAG is more expen-sive as the corresponding ensemble consists of a largernumber of models To account for situations when testtime is constrained we also propose MultiSWA a methodthat ensembles independently trained SWA solutions (Iz-mailov et al 2018) SWA solutions are the means of thecorresponding Gaussian SWAG approximations Izmailovet al (2018) argue that SWA solutions approximate the localensembles represented by SWAG with a single model

In Figure 5 we show the negative log-likelihood as a func-tion of the number of independently trained models for aPreactivation ResNet-20 on CIFAR-10 corrupted with Gaus-sian blur with varying levels of intensity (increasing fromleft to right) in Figure 5 MultiSWAG outperforms deepensembles significantly on highly corrupted data For lowerlevels of corruption MultiSWAG works particularly wellwhen only a small number of independently trained modelsare available We note that MultiSWA also outperforms deepensembles and has the same computational requirements attraining and test time as deep ensembles We present resultsfor other types of corruption in Appendix Figures 14 15 1617 showing similar trends In general there is an extensiveevaluation of MultiSWAG in the Appendix

Our perspective of generalization is deeply connected withBayesian marginalization In order to best realize the bene-fits of marginalization in deep learning we need to consideras many hypotheses as possible through multimodal poste-rior approximations such as MultiSWAG In Section 7 wereturn to MultiSWAG showing how it can entirely alleviateprominent double descent behaviour and lead to strikingimprovements in generalization over SGD and single basinmarginalization for both accuracy and NLL

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

minus10 minus5 0 5 10

minus05

00

05

10

15

20

25

30

35

(a) Exact

000

025

Wx(q

p)

minus10 minus5 0 5 10

0

1

2

3

(b) Deep Ensembles

000

025

Wx(q

p)

minus10 minus5 0 5 10

0

1

2

3

(c) Variational Inference

0 10 20 30 40 50

Samples

008

010

012

014

016

018

Avg

Wx(q

p)

Deep Ensembles

SVI

(d) Discrepancy with True BMA

Figure 4 Approximating the true predictive distribution (a) A close approximation of the true predictive distribution obtained bycombining 200 HMC chains (b) Deep ensembles predictive distribution using 50 independently trained networks (c) Predictivedistribution for factorized variational inference (VI) (d) Convergence of the predictive distributions for deep ensembles and variationalinference as a function of the number of samples we measure the average Wasserstein distance between the marginals in the range ofinput positions The multi-basin deep ensembles approach provides a more faithful approximation of the Bayesian predictive distributionthan the conventional single-basin VI approach which is overconfident between data clusters The top panels show the Wassersteindistance between the true predictive distribution and the deep ensemble and VI approximations as a function of inputs x

1 2 3 4 5 6 7 8 9 10 Models

1600

1800

2000

2200

2400

2600

2800

NLL

1 2 3 4 5 6 7 8 9 10 Models

3500

4000

4500

5000

5500

6000

1 2 3 4 5 6 7 8 9 10 Models

7000

8000

9000

10000

11000

12000

1 2 3 4 5 6 7 8 9 10 Models

12000

14000

16000

18000

20000

22000

1 2 3 4 5 6 7 8 9 10 Models

20000

25000

30000

35000

40000

Deep Ensembles MultiSWA MultiSWAG

Figure 5 Negative log likelihood for Deep Ensembles MultiSWAG and MultiSWA using a PreResNet-20 on CIFAR-10 with varyingintensity of the Gaussian blur corruption The image in each plot shows the intensity of corruption For all levels of intensity MultiSWAGand MultiSWA outperform Deep Ensembles for a small number of independent models For high levels of corruption MultiSWAGsignificantly outperforms other methods even for many independent models We present results for other corruptions in the Appendix

5 Neural Network PriorsA prior over parameters p(w) combines with the functionalform of a model f(xw) to induce a distribution over func-tions p(f(xw)) It is this distribution over functions thatcontrols the generalization properties of the model the priorover parameters in isolation has no meaning Neural net-works are imbued with structural properties that providegood inductive biases such as translation equivariance hi-erarchical representations and sparsity In the sense ofFigure 2 the prior will have large support due to the flexi-bility of neural networks but its inductive biases provide themost mass to datasets which are representative of problemsettings where neural networks are often applied In thissection we study the properties of the induced distributionover functions We directly continue the discussion of priorsin Section 6 with a focus on examining the noisy CIFAR re-sults in Zhang et al (2016) from a probabilistic perspectiveof generalization These sections are best read together

We also provide several additional experiments in the Appendix. In Section E we present analytic results on the dependence of the prior distribution in function space on the variance of the prior over parameters, considering also layer-wise parameter priors with ReLU activations. As part of a discussion on tempering, in Section 8.4 we study the effect of α in p(w) = N(0, α²I) on prior class probabilities for individual sample functions p(f(x; w)), the predictive distribution, and posterior samples as we observe varying amounts of data. In Section F we further study the correlation structure over images induced by neural network priors, subject to perturbations of the images. In Section D.3 we provide additional experimental details.

5.1. Deep Image Prior and Random Network Features

Two recent results provide strong evidence that vague Gaussian priors over parameters, when combined with a neural network architecture, induce a distribution over functions with useful inductive biases. In the deep image prior, Ulyanov et al. (2018) show that randomly initialized convolutional neural networks, without training, provide excellent performance for image denoising, super-resolution, and inpainting. This result demonstrates the ability of a sample function from a random prior over neural networks p(f(x; w)) to capture low-level image statistics, before any training.



Figure 6. Induced prior correlation function. Average pairwise prior correlations for pairs of objects in classes 0, 1, 2, 4, 7 of MNIST, induced by LeNet-5 for p(f(x; w)) when p(w) = N(0, α²I). Images in the same class have higher prior correlations than images from different classes, suggesting that p(f(x; w)) has desirable inductive biases. The correlations slightly decrease with increases in α. (d) NLL of an ensemble of 20 SWAG samples on MNIST as a function of α, using a LeNet-5.


Figure 7. Rethinking generalization. (a) Sample functions from a Gaussian process prior. (b) GP fit (with 95% credible region) to structured data generated as y_green(x) = sin(x · 2π) + ε, ε ∼ N(0, 0.2²). (c) GP fit, with no training error, after a significant addition of corrupted data, in red, drawn from Uniform[0.5, 1]. (d) Variational GP marginal likelihood with RBF kernel for two classes of CIFAR-10. (e) Laplace BNN marginal likelihood for a PreResNet-20 on CIFAR-10 with different fractions of random labels. The marginal likelihood for both the GP and BNN decreases as we increase the level of corruption in the labels, suggesting reasonable inductive biases in the prior over functions. Moreover, both the GP and BNN have 100% training accuracy on images with fully corrupted labels.

Similarly, Zhang et al. (2016) show that preprocessing CIFAR-10 with a randomly initialized, untrained convolutional neural network dramatically improves the test performance of a simple Gaussian kernel on pixels, from 54% accuracy to 71%. Adding ℓ2 regularization only improves the accuracy by an additional 2%. These results again indicate that broad Gaussian priors over parameters induce reasonable priors over networks, with a minor additional gain from decreasing the variance of the prior in parameter space, which corresponds to ℓ2 regularization.
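As a rough illustration of the random-features effect described above (not the setup of Zhang et al., 2016), the sketch below extracts features from CIFAR-shaped inputs with an untrained convolutional network and fits an RBF-kernel SVM on those features; the architecture, the SVM settings, and the random stand-in data are all assumptions for illustration, and real CIFAR-10 images and labels would be substituted to reproduce the comparison with raw pixels.

```python
import numpy as np
import torch
import torch.nn as nn
from sklearn.svm import SVC

torch.manual_seed(0)
rng = np.random.default_rng(0)

# Untrained random feature extractor (illustrative stand-in architecture).
features = nn.Sequential(
    nn.Conv2d(3, 32, 3, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(4), nn.Flatten(),
)

# Random stand-in data with CIFAR-10 shapes; replace with real images and labels.
X = torch.randn(256, 3, 32, 32)
y = rng.integers(0, 10, size=256)

with torch.no_grad():
    Z = features(X).numpy()                      # random-network features, no training

clf_feat = SVC(kernel="rbf").fit(Z, y)           # Gaussian kernel on random features
clf_pix = SVC(kernel="rbf").fit(X.reshape(256, -1).numpy(), y)   # Gaussian kernel on raw pixels
print(clf_feat.score(Z, y), clf_pix.score(X.reshape(256, -1).numpy(), y))
```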

5.2. Prior Class Correlations

In Figure 6 we study the prior correlations in the outputs of the LeNet-5 convolutional network (LeCun et al., 1998) on objects of different MNIST classes. We sample networks with weights from p(w) = N(0, α²I), compute the logits corresponding to the first class for all pairs of images, and then compute the correlations of these logits across the weight samples. For all levels of α, the correlations between objects of the same class are consistently higher than the correlations between objects of different classes, showing that the network induces a reasonable prior similarity metric over these images. Additionally, we observe that the prior correlations somewhat decrease as we increase α, showing that bounding the norm of the weights has some minor utility, in accordance with Section 5.1. Similarly, in panel (d) we see that the NLL significantly decreases as α increases in [0, 0.5], and then slightly increases, but is relatively constant thereafter.
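A minimal sketch of this computation is below; it uses a small random fully-connected network and random arrays as stand-ins for LeNet-5 and the MNIST images (both assumptions, for illustration only). With real images of two classes, the within-class average would exceed the between-class average, as in Figure 6.

```python
import numpy as np

rng = np.random.default_rng(0)

def first_class_logit(images, alpha=0.1):
    """Class-0 logit of a small random ReLU network with w ~ N(0, alpha^2 I).

    A stand-in for LeNet-5; only the correlation computation matters here.
    """
    d = images.shape[1]
    W1 = alpha * rng.standard_normal((d, 64)); b1 = alpha * rng.standard_normal(64)
    W2 = alpha * rng.standard_normal((64, 10)); b2 = alpha * rng.standard_normal(10)
    h = np.maximum(images @ W1 + b1, 0.0)
    return (h @ W2 + b2)[:, 0]

# Stand-in "images": the first 20 rows play class A, the next 20 class B.
images = rng.standard_normal((40, 28 * 28))
logits = np.stack([first_class_logit(images) for _ in range(200)])  # (weight samples, images)
corr = np.corrcoef(logits, rowvar=False)                            # pairwise image correlations

within = corr[:20, :20][np.triu_indices(20, k=1)].mean()
between = corr[:20, 20:].mean()
print(f"within-class: {within:.2f}  between-class: {between:.2f}")
```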

In the Appendix we further describe analytic results and illustrate the effect of α on sample functions.

5.3. Effect of Prior Variance on CIFAR-10

We further study the effect of the parameter prior standard deviation α, measuring the performance of approximate Bayesian inference on CIFAR-10 with a Preactivation ResNet-20 (He et al., 2016) and VGG-16 (Simonyan & Zisserman, 2014). For each of these architectures we run SWAG (Maddox et al., 2019) with fixed hyper-parameters and varying α. We report the results in Figure 12(d), (h). For both architectures the performance is near-optimal in the range α ∈ [10⁻², 10⁻¹]. Smaller α constrains the weights too much. Performance is reasonable and becomes mostly insensitive to α as it continues to increase, due to the inductive biases of the functional form of the neural network.


6. Rethinking Generalization

Zhang et al. (2016) demonstrated that deep neural networks have sufficient capacity to fit randomized labels on popular image classification tasks, and suggest this result requires re-thinking generalization to understand deep learning.

We argue, however, that this behaviour is not puzzling from a probabilistic perspective, is not unique to neural networks, and cannot be used as evidence against Bayesian neural networks (BNNs) with vague parameter priors. Fundamentally, the resolution is the view presented in the introduction: from a probabilistic perspective, generalization is at least a two-dimensional concept, related to support (flexibility), which should be as large as possible, supporting even noisy solutions, and inductive biases, which represent the relative prior probabilities of solutions.

Indeed, we demonstrate that the behaviour in Zhang et al. (2016), which was treated as mysterious and specific to neural networks, can be exactly reproduced by Gaussian processes (GPs). Gaussian processes are an ideal choice for this experiment because they are popular Bayesian non-parametric models, and they assign a prior directly in function space. Moreover, GPs have remarkable flexibility, providing universal approximation with popular covariance functions such as the RBF kernel. Yet the functions that are a priori likely under a GP with an RBF kernel are relatively simple. We describe GPs further in the Appendix, and Rasmussen & Williams (2006) provide an extensive introduction.

We start with a simple example, in Figure 7, to illustrate the ability of a GP with an RBF kernel to easily fit a corrupted dataset, yet generalize well on a non-corrupted dataset. In Fig. 7(a) we show sample functions from a GP prior over functions p(f(x)), showing that likely functions under the prior are smooth and well-behaved. In Fig. 7(b) we see the GP is able to reasonably fit data from a structured function. And in Fig. 7(c) the GP is also able to fit highly corrupted data with essentially no structure; although these data are not a likely draw from the prior, the GP has support for a wide range of solutions, including noise.
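The sketch below reproduces the flavour of Figure 7(b)-(c) with exact GP regression under an RBF kernel: structured data drawn from sin(2πx) plus noise, augmented with corrupted points whose labels are drawn from Uniform[0.5, 1]; the length-scale, noise level, and data sizes are illustrative assumptions rather than the paper's settings.

```python
import numpy as np

def rbf(x1, x2, ell=0.2):
    return np.exp(-0.5 * (x1[:, None] - x2[None, :]) ** 2 / ell**2)

rng = np.random.default_rng(0)
x = rng.uniform(0, 1, 30)
y = np.sin(2 * np.pi * x) + 0.2 * rng.standard_normal(30)   # structured data, as in Fig. 7(b)
xc = rng.uniform(0, 1, 30)
yc = rng.uniform(0.5, 1.0, 30)                               # corrupted labels, as in Fig. 7(c)
X, Y = np.concatenate([x, xc]), np.concatenate([y, yc])

noise = 1e-3                            # small noise variance lets the GP (nearly) interpolate
K = rbf(X, X) + noise * np.eye(len(X))
xs = np.linspace(0, 1, 200)
Ks = rbf(xs, X)
mean = Ks @ np.linalg.solve(K, Y)                            # posterior predictive mean
var = np.diag(rbf(xs, xs) - Ks @ np.linalg.solve(K, Ks.T))   # posterior predictive variance
print(float(mean.min()), float(mean.max()), float(np.clip(var, 0, None).max()))
```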

We next show that GPs can replicate the generalization behaviour described in Zhang et al. (2016) (experimental details are in the Appendix). When applied to CIFAR-10 images with random labels, Gaussian processes achieve 100% train accuracy and 10.4% test accuracy (at the level of random guessing). However, the same model trained on the true labels achieves a training accuracy of 72.8% and a test accuracy of 54.3%. Thus the generalization behaviour described in Zhang et al. (2016) is not unique to neural networks, and can be described by separately understanding the support and the inductive biases of a model.

Indeed, although Gaussian processes support CIFAR-10 images with random labels, such labellings are not likely under the GP prior. In Fig. 7(d) we compute the approximate GP marginal likelihood on a binary CIFAR-10 classification problem, with labels of varying levels of corruption. We see that as the noise in the data increases, the approximate marginal likelihood, and thus the prior support for these data, decreases. In Fig. 7(e) we see a similar trend for a Bayesian neural network. Again, as the fraction of corrupted labels increases, the approximate marginal likelihood decreases, showing that the prior over functions given by the Bayesian neural network has less support for these noisy datasets. We provide further experimental details in the Appendix.

Dziugaite & Roy (2017) and Smith & Le (2018) provide complementary perspectives on Zhang et al. (2016) for MNIST: Dziugaite & Roy (2017) show non-vacuous PAC-Bayes bounds for noise-free binary MNIST, but not for noisy MNIST, and Smith & Le (2018) show that logistic regression can fit noisy labels on subsampled MNIST, interpreting the results from an Occam factor perspective.

7. Double Descent

Double descent (e.g., Belkin et al., 2019) describes generalization error that decreases, increases, and then again decreases with increases in model flexibility. The initial decrease and subsequent increase is referred to as the classical regime: models with increasing flexibility are increasingly able to capture structure and perform better, until they begin to overfit. The regime that follows is referred to as the modern interpolating regime. The existence of the interpolation regime has been presented as mysterious generalization behaviour in deep learning.

However, our perspective of generalization suggests that performance should monotonically improve as we increase model flexibility, when we use Bayesian model averaging with a reasonable prior. Indeed, in the opening example of Figure 1, we would in principle want to use the most flexible model possible. Our results in Section 5 show that standard BNN priors induce structured and useful priors in function space, so we should not expect double descent in Bayesian deep learning models that perform reasonable marginalization.

To test this hypothesis, we evaluate MultiSWAG, SWAG, and standard SGD with ResNet-18 models of varying width, following Nakkiran et al. (2019), measuring both error and negative log likelihood (NLL). For details, see Appendix D. We present the results in Figure 8 and Appendix Figure 13.

First, we observe that models trained with SGD indeed suffer from double descent, especially when the train labels are partially corrupted (see panels 8(c), 8(d)). We also see that SWAG, a unimodal posterior approximation, reduces the extent of double descent.



Figure 8. Bayesian model averaging alleviates double descent. (a) Test error and (b) NLL loss for ResNet-18 with varying width on CIFAR-100 for SGD, SWAG, and MultiSWAG. (c) Test error and (d) NLL loss when 20% of the labels are randomly reshuffled. SWAG reduces double descent, and MultiSWAG, which marginalizes over multiple modes, entirely alleviates double descent, both on the original labels and under label noise, in both accuracy and NLL. (e) Test errors for MultiSWAG with varying number of independent SWAG models; error monotonically decreases with an increased number of independent models, alleviating double descent. We also note that MultiSWAG provides significant improvements in accuracy and NLL over SGD and SWAG models. See Appendix Figure 13 for additional results.

Moreover, MultiSWAG, which performs a more exhaustive multimodal Bayesian model average, completely mitigates double descent: the performance of MultiSWAG solutions improves monotonically with the size of the model, showing no double descent even under significant label corruption, for both accuracy and NLL.
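Structurally, the MultiSWAG model average is a simple Monte Carlo estimate pooled over several independently found basins; the sketch below shows only that averaging step, with `weight_samplers` and `predict_probs` as hypothetical placeholders (the actual SWAG posterior construction is described in Maddox et al., 2019).

```python
import numpy as np

def multi_basin_bma(weight_samplers, predict_probs, x, samples_per_basin=20):
    """Average predictions over several posterior basins (MultiSWAG-style).

    weight_samplers: list of callables, each drawing one weight sample from the
        posterior approximation of one basin (e.g. one SWAG model) -- placeholders.
    predict_probs(w, x): class probabilities for inputs x under weights w -- placeholder.
    """
    probs = [
        predict_probs(sampler(), x)
        for sampler in weight_samplers          # loop over basins of attraction
        for _ in range(samples_per_basin)       # and over samples within each basin
    ]
    return np.mean(probs, axis=0)               # p(y|x, D) approximated by a flat average
```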

Our results highlight the importance of marginalization over multiple modes of the posterior: under 20% label corruption, SWAG clearly suffers from double descent while MultiSWAG does not. In Figure 8(e) we show how double descent is alleviated as the number of independent modes marginalized in MultiSWAG increases.

These results also clearly show that MultiSWAG provides significant improvements in accuracy over both SGD and SWAG models, in addition to NLL, an often overlooked advantage of Bayesian model averaging that we discuss in Section 3.1.

Recently, Nakkiran et al. (2020) show that carefully tuned ℓ2 regularization can help mitigate double descent. Alternatively, we show that Bayesian model averaging, particularly based on multimodal marginalization, can mitigate prominent double descent behaviour. The perspective in Sections 1 and 3 predicts this result: models with reasonable priors and effective Bayesian model averaging should monotonically improve with increases in flexibility.

8. Temperature Scaling

The standard Bayesian posterior distribution is given by

p(w|D) = (1/Z) p(D|w) p(w),  (2)

where p(D|w) is a likelihood, p(w) is a prior, and Z is a normalizing constant.

In Bayesian deep learning it is typical to consider the tempered posterior

p_T(w|D) = (1/Z(T)) p(D|w)^(1/T) p(w),  (3)

where T is a temperature parameter, and Z(T) is the normalizing constant corresponding to temperature T. The temperature parameter controls how the prior and likelihood interact in the posterior:

• T < 1 corresponds to cold posteriors, where the posterior distribution is more concentrated around solutions with high likelihood.

• T = 1 corresponds to the standard Bayesian posterior distribution.

• T > 1 corresponds to warm posteriors, where the prior effect is stronger and the posterior collapse is slower.
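As a minimal sketch of how Equation (3) enters a practical procedure, the function below returns the unnormalized tempered log-posterior from user-supplied log-likelihood and log-prior callables, which is what a sampler or optimizer would use; the toy Gaussian terms in the usage example are purely illustrative.

```python
import numpy as np

def log_tempered_posterior(w, log_likelihood, log_prior, T=1.0):
    """Log of the unnormalized tempered posterior: (1/T) log p(D|w) + log p(w), as in Equation (3)."""
    return log_likelihood(w) / T + log_prior(w)

# Toy 1-D example: Gaussian "likelihood" centred at 2.0 and an N(0, 1) prior.
loglik = lambda w: -0.5 * (w - 2.0) ** 2 / 0.5**2
logpri = lambda w: -0.5 * w**2
ws = np.linspace(-1, 4, 501)
for T in (0.5, 1.0, 2.0):
    post = np.exp([log_tempered_posterior(w, loglik, logpri, T) for w in ws])
    post /= post.sum()
    # Colder T concentrates the posterior more tightly around the likelihood peak.
    print(T, round(float(ws[post.argmax()]), 3), round(float((post > post.max() / 2).mean()), 3))
```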

Tempering posteriors is a well-known practice in statistics, where it goes by the names Safe Bayes, generalized Bayesian inference, and fractional Bayesian inference (e.g., de Heide et al., 2019; Grunwald et al., 2017; Barron & Cover, 1991; Walker & Hjort, 2001; Zhang, 2006; Bissiri et al., 2016; Grunwald, 2012). Safe Bayes has been shown to be natural from a variety of perspectives, including prequential learning theory and minimum description length frameworks (e.g., Grunwald et al., 2017).

Concurrently with our work, Wenzel et al. (2020) noticed that successful Bayesian deep learning methods tend to use cold posteriors. They provide an empirical study showing that Bayesian neural networks (BNNs) with cold posteriors outperform models with SGD-based maximum likelihood training, while BNNs with T = 1 can perform worse than the maximum likelihood solution. They claim that cold posteriors sharply deviate from the Bayesian paradigm.



Figure 9. Effects of the prior variance α². (a)–(e) Average class probabilities over all of CIFAR-10 for two sample prior functions p(f(x; w)) (two top rows) and the predictive distribution (average over 200 samples of weights, bottom row), for varying settings of α in p(w) = N(0, α²I). (f) NLL and (g) classification error of an ensemble of 20 SWAG samples on CIFAR-10 as a function of α, using a Preactivation ResNet-20 and VGG-16. The NLL is high for overly small α and near-optimal in the range [0.1, 0.3]. The NLL remains relatively low for vague priors corresponding to large values of α.


Figure 10. Adaptivity of posterior variance with data. We sample two functions f(x; w) from the distribution over functions induced by a distribution over weights, starting with the prior p(w) = N(0, 10 · I), in combination with a PreResNet-20. We measure class probabilities averaged across the CIFAR-10 test set as we vary the amount of available training data. Although the prior variance is too large, such that the softmax saturates for logits sampled from the prior, leading to one class being favoured, we see that the posterior quickly adapts to correct the scale of the logits in the presence of data. In Figure 9 we also show that the prior variance can easily be calibrated such that the prior predictive distribution, even before observing data, has high entropy.


They also consider possible reasons for why tempering is helpful in Bayesian deep learning.

In this section we provide an alternative view, and argue that tempering is not at odds with Bayesian principles. Moreover, for virtually any realistic model class and dataset, it would be highly surprising if T = 1 were in fact the best setting of this hyperparameter. Indeed, as long as it is practically convenient, we would advocate tempering for essentially any model, especially parametric models that do not scale their capacity automatically with the amount of available information. Our position is that, at a high level, Bayesian methods are trying to combine honest beliefs with data to form a posterior. By reflecting the belief that the model is misspecified, the tempered posterior is often more of a true posterior than the posterior that results from ignoring our belief that the model is misspecified.

Finding that T < 1 helps for Bayesian neural networks is neither surprising nor discouraging. And the actual results of the experiments in Wenzel et al. (2020), which show great improvements over standard SGD training, are in fact very encouraging for deriving inspiration from Bayesian procedures in deep learning.

We consider: (1) tempering under misspecification (Section 8.1); (2) tempering in terms of overcounting data (Section 8.2); (3) how tempering compares to changing the observation model (Section 8.3); (4) the effect of the prior in relation to the experiments of Wenzel et al. (2020) (Section 8.4); and (5) the effect of approximate inference, including how tempering can help in efficiently estimating parameters even for the untempered posterior (Section 8.5).

8.1. Tempering Helps with Misspecified Models

Many works explain how tempered posteriors help under model misspecification (e.g., de Heide et al., 2019; Grunwald et al., 2017; Barron & Cover, 1991; Walker & Hjort, 2001; Zhang, 2006; Bissiri et al., 2016; Grunwald, 2012). In fact, de Heide et al. (2019) and Grunwald et al. (2017) provide several simple examples where Bayesian inference fails to provide good convergence behaviour for untempered posteriors. While it is easier to show theoretical results for T > 1, several of these works also show that T < 1 can be preferred, even in well-specified settings, and indeed recommend learning T from data, for example by cross-validation (e.g., Grunwald, 2012; de Heide et al., 2019).

Are we in a misspecified setting for Bayesian neural networks? Of course, and it would be irrational to proceed as if it were otherwise. Every model is misspecified. In the context of Bayesian neural networks specifically, the mass of solutions expressed by the prior outside of the datasets we typically consider is likely much larger than desired for most applications. We can calibrate for this discrepancy through tempering. The resulting tempered posterior will be more in line with our beliefs than pretending the model is not misspecified and finding the untempered posterior.

Non-parametric models, such as Gaussian processes, attempt to side-step model misspecification by growing the number of free parameters (information capacity) automatically with the amount of available data. In parametric models we take much more of a manual guess at the model capacity. In the case of deep neural networks this choice is not even close to a best guess: it was once the case that architectural design was a large component of work involving neural networks, but it is now more standard practice to choose an off-the-shelf architecture, without much consideration of model capacity. We do not believe that knowingly using a misspecified model to find a posterior is more reasonable (or Bayesian) than honestly reflecting the belief that the model is misspecified and then using a tempered posterior. For parametric models such as neural networks, it is to be expected that the capacity is particularly misspecified.

8.2. Overcounting Data with Cold Posteriors

The criticism of cold posteriors raised by Wenzel et al. (2020) is largely based on the fact that decreasing the temperature leads to overcounting data in the posterior distribution.

However, a similar argument can be made against marginal likelihood maximization (also known as empirical Bayes, or type-II maximum likelihood). Indeed, here the prior will depend on the same data as the likelihood, which can lead to miscalibrated predictive distributions (Darnieder, 2011).

Nonetheless, empirical Bayes has been embraced and widely adopted in Bayesian machine learning (e.g., Bishop, 2006; Rasmussen & Williams, 2006; MacKay, 2003; Minka, 2001b) as embodying several Bayesian principles. Empirical Bayes has been particularly embraced in seminal work on Bayesian neural networks (e.g., MacKay, 1992; 1995), where it has been proposed as a principled approach to automatically learning hyperparameters, such as the scale of the variance for the prior over weights, embodying Occam's razor. While there is in this case some deviation from the fully Bayesian paradigm, the procedure, which depends on marginalization, is nonetheless clearly inspired by Bayesian thinking; it is thus helpful to reflect this inspiration and provide an understanding of how it works from a Bayesian perspective.

There is also work showing that the marginal likelihood can lead to miscalibrated Bayes factors under model misspecification. Attempts to calibrate these factors (Xu et al., 2019) as part of the Bayesian paradigm are highly reminiscent of work on Safe Bayes.


8.3. Tempered Posterior or Different Likelihood

The tempered posterior for one model is an untempered posterior for a different observation model. In other words, we can trivially change the likelihood function in our model so that the standard Bayesian posterior in the new model is equal to the temperature-T posterior in the original model. Indeed, consider the likelihood function

p_T(D|w) ∝ p(D|w)^(1/T),  (4)

where the posterior distribution for the model M_T with likelihood p_T is exactly the temperature-T posterior for the model M with likelihood p.

The predictive distribution differs for the two models: even though the posteriors coincide, the likelihoods for a new datapoint y* are different,

∫ p(y*|w) p(w|D) dw ≠ ∫ p_T(y*|w) p(w|D) dw.  (5)

As an example, consider a regression model M with a Gaussian likelihood function, y ∼ N(f, σ²), where f is the output of the network and σ² is the noise variance. The predictive distributions for the two models M and M_T for a new input x will have different variance, but the same mean: E_w[f + ε] = E_w[f]. Moreover, in this case the noise variance would typically be learned in either model, M_T or M.
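To make this concrete in a case that can be worked out in closed form, the sketch below uses a conjugate 1-D Gaussian model (a hypothetical toy, not from the paper): tempering the Gaussian likelihood with temperature T gives the same posterior over w as an untempered model whose noise variance is inflated to Tσ², yet the two models' predictive variances for a new observation differ while their predictive means agree.

```python
import numpy as np

rng = np.random.default_rng(0)
a2, s2, T = 1.0, 0.25, 0.5                # prior variance, noise variance, temperature
y = rng.normal(1.5, np.sqrt(s2), size=20)

def posterior(noise_var, like_power=1.0):
    """Conjugate posterior for w ~ N(0, a2) with Gaussian likelihood N(y_i | w, noise_var)
    raised to the power `like_power` (tempering multiplies the likelihood precision by 1/T)."""
    prec = 1.0 / a2 + like_power * len(y) / noise_var
    mean = (like_power * y.sum() / noise_var) / prec
    return mean, 1.0 / prec

m_T, v_T = posterior(s2, like_power=1.0 / T)     # temperature-T posterior in model M
m_alt, v_alt = posterior(T * s2)                 # untempered posterior with noise T * s2
assert np.allclose([m_T, v_T], [m_alt, v_alt])   # the two posteriors coincide

# Predictive distributions for a new y* nevertheless differ in variance, not in mean:
print("model M  :", m_T, v_T + s2)               # posterior variance + sigma^2
print("model M_T:", m_T, v_T + T * s2)           # posterior variance + T * sigma^2
```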

Section 4.1 of Grunwald et al. (2017) considers a related construction.

8.4. Effect of the Prior

While a somewhat misspecified prior will certainly interact with the utility of tempering, we do not believe the experiments in Wenzel et al. (2020) provide evidence that even the prior p(w) = N(0, I) is misspecified to any serious extent. For a relatively wide range of distributions over w, the functional form of the network f(x; w) can produce a generally reasonable distribution over functions p(f(x; w)). In Figure 10 we reproduce the findings in Wenzel et al. (2020), showing that sample functions p(f(x; w)) corresponding to the prior p(w) = N(0, 10 · I) strongly favour a single class over the dataset. While this behaviour appears superficially dramatic, we note it is simply an artifact of a miscalibrated signal variance. A miscalibrated signal variance interacts with a quickly saturating softmax link function to provide a seemingly dramatic preference for a given class. If we instead use p(w) = N(0, α²I), for quite a range of α, then sample functions provide reasonably high entropy across labels, as in Figure 9. The value of α can easily be determined through cross-validation, as in Figure 9, or specified as a standard value used for L2 regularization (α = 0.24 in this case).

However, even with the inappropriate prior scale, we see in the bottom row of panels (a)–(e) of Figure 9 that the unconditional predictive distribution is entirely reasonable. Moreover, the prior variance represents a soft prior bias, and will quickly update in the presence of data. In Figure 10 we show posterior samples after observing 10, 100, and 1000 data points.

Other aspects of the prior, beyond the prior signal variance, will have a much greater effect on the inductive biases of the model. For example, the induced covariance function cov(f(x_i; w), f(x_j; w)) reflects the induced similarity metric over data instances; through the covariance function we can answer, for instance, whether the model believes a priori that a translated image is similar to the original. Unlike the signal variance of the prior, the prior covariance function will continue to have a significant effect on posterior inference even for very large datasets, and strongly reflects the structural properties of the neural network. We explore these structures of the prior in Figure 12.

8.5. The Effect of Inexact Inference

We have to keep in mind what we ultimately use posterior samples to compute. Ultimately, we wish to estimate the predictive distribution given by the integral in Equation (1). With a finite number of samples, the tempered posterior could be used to provide a better approximation to the expectation of the predictive distribution associated with the untempered posterior.

Consider a simple example, where we wish to estimate the mean of a high-dimensional Gaussian distribution N(0, I). Suppose we use J independent samples. The mean of these samples is also Gaussian distributed, µ ∼ N(0, (1/J) I). In Bayesian deep learning, the dimension d is typically on the order of 10⁷, and J would be on the order of 10. The norm of µ would be highly concentrated around √(10⁷)/√10 = 1000. In this case, sampling from a tempered posterior with T < 1 would lead to a better approximation of the Bayesian model average associated with an untempered posterior.
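The arithmetic in this example is easy to check numerically; the following few lines (an illustrative check, not from the paper) draw the sample mean of J = 10 draws from a d = 10⁷ dimensional standard Gaussian and confirm that its norm concentrates near √(d/J) = 1000.

```python
import numpy as np

rng = np.random.default_rng(0)
d, J = 10**7, 10
# The mean of J iid N(0, I) samples is distributed as N(0, I/J), so we draw it directly.
mu = rng.standard_normal(d) / np.sqrt(J)
print(np.linalg.norm(mu))   # ≈ sqrt(d / J) = sqrt(10^7 / 10) = 1000
```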

Furthermore, no current sampling procedure provides samples that are close to independent samples from the true posterior of a Bayesian neural network. The posterior landscape is far too multimodal and complex for there to be any reasonable coverage. The approximations we have are practically useful, and often preferable to conventional training, but we cannot realistically proceed with an analysis that assumes we have obtained true samples from a posterior. While we would expect that some value of T ≠ 1 would be preferred for any finite dataset in practice, it is conceivable that some of the results in Wenzel et al. (2020) may be affected by the specifics of the approximate inference technique being used.


We should be wary not to view Bayesian model averaging purely through the prism of simple Monte Carlo, as advised in Section 3.2. Given a finite computational budget, our goal in effectively approximating a Bayesian model average is not equivalent to obtaining good samples from the posterior.

9. Discussion

"It is now common practice for Bayesians to fit models that have more parameters than the number of data points. Incorporate every imaginable possibility into the model space: for example, if it is conceivable that a very simple model might be able to explain the data, one should include simple models; if the noise might have a long-tailed distribution, one should include a hyperparameter which controls the heaviness of the tails of the distribution; if an input variable might be irrelevant to a regression, include it in the regression anyway." MacKay (1995)

We have presented a probabilistic perspective of generalization, which depends on the support and inductive biases of the model. The support should be as large as possible, but the inductive biases must be well-calibrated to a given problem class. We argue that Bayesian neural networks embody these properties, and, through the lens of probabilistic inference, explain generalization behaviour that has previously been viewed as mysterious. Moreover, we argue that Bayesian marginalization is particularly compelling for neural networks, show how deep ensembles provide a practical mechanism for marginalization, and propose a new approach that generalizes deep ensembles to marginalize within basins of attraction. We show that this multimodal approach to Bayesian model averaging, MultiSWAG, can entirely alleviate double descent, to enable monotonic performance improvements with increases in model flexibility, as well as significant improvements in generalization accuracy and log likelihood over SGD and single-basin marginalization.

There are certainly many challenges to estimating the integral for a Bayesian model average in modern deep learning, including a high-dimensional parameter space and a complex posterior landscape. But viewing the challenge as an integration problem, rather than an attempt to obtain posterior samples for a simple Monte Carlo approximation, provides opportunities for future progress. Bayesian deep learning has been making fast practical advances, with approaches that now enable better accuracy and calibration over standard training with minimal overhead.

We finish with remarks about future developments for Bayesian neural network priors, and approaches to research in Bayesian deep learning.

9.1. The Future for BNN Priors

We provide some brief remarks about future developments for BNN priors. Here we have explored relatively simple parameter priors p(w) = N(0, α²I). While these priors are simple in parameter space, they interact with the neural network architecture to induce a sophisticated prior over functions p(f(x; w)), with many desirable properties, including a reasonable correlation structure over images. However, these parameter priors can certainly still be improved. As we have seen, even tuning the value of the signal variance α², an analogue of the L2 regularization often used in deep learning, can have a noticeable effect on the induced prior over functions, though this effect is quickly modulated by data. Layer-wise priors, such that parameters in each layer have a different signal variance, are intuitive: we would expect later layers to require precise determination, while parameters in earlier layers could reasonably take a range of values. But one has to be cautious; as we show in Appendix Section E, with ReLU activations different signal variances in different layers can be degenerate, combining together to affect only the output scale of the network.

A currently popular sentiment is that we should directly build function-space BNN priors, often taking inspiration from Gaussian processes. While we believe this is a promising direction, one should proceed with caution. If we contrive priors over parameters p(w) to induce distributions over functions p(f) that resemble familiar models, such as Gaussian processes with RBF kernels, we could be throwing the baby out with the bathwater. Neural networks are useful as their own model class precisely because they have different inductive biases from other models.

A similar concern applies to taking infinite-width limits of Bayesian neural networks. In these cases we recover Gaussian processes with interpretable kernel functions; because these models are easier to use and analyze, and give rise to interpretable and well-motivated priors, it is tempting to treat them as drop-in replacements for their parametric analogues. However, the kernels for these models are fixed. In order for a model to do effective representation learning, we must learn a similarity metric for the data. Training a neural network is in many ways like learning a kernel, rather than using a fixed kernel. MacKay (1998) has also expressed concerns about treating these limits as replacements for neural networks, due to the loss of representation learning power.

Perhaps the distribution over functions induced by a network in combination with a generic distribution over parameters p(w) may be hard to interpret, but this distribution will contain the equivariance properties, representation learning abilities, and other biases that make neural networks a compelling model class in their own right.


9.2. "But is it really Bayesian?"

We finish with an editorial comment about approaches to research within Bayesian deep learning. There is sometimes a tendency to classify work as Bayesian or not Bayesian, with very stringent criteria for what qualifies as Bayesian. Moreover, the implication, and sometimes even explicit recommendation, is that if an approach is not unequivocally Bayesian in every respect, then we should not term it Bayesian, and we should instead attempt to understand the procedure through entirely different, non-Bayesian mechanisms. We believe this mentality encourages tribalism, which is not conducive to the best research or to creating the best performing methods. What matters is not a debate about semantics, but making rational modelling choices given a particular problem setting, and trying to understand these choices. Often these choices can largely be inspired by a Bayesian approach, in which case it is desirable to indicate this source of inspiration. And in the semantics debate, who would be the arbiter of what gets to be called Bayesian? Arguably it ought to be an evolving definition.

Broadly speaking, what makes Bayesian approaches distinctive is a posterior-weighted marginalization over parameters. And at a high level, Bayesian methods are about combining our honest beliefs with data to form a posterior. In actuality, no fair-minded researcher entirely believes the prior over parameters, the functional form of the model (which is part of the prior over functions), or the likelihood. From this perspective, it is broadly compatible with a Bayesian philosophy to reflect misspecification in the modelling procedure itself, which is achieved through tempering. In this sense, the tempered posterior is more reflective of a true posterior than the posterior that results from ignoring our belief that the model is misspecified.

Moreover, basic probability theory indicates that marginalization is desirable. While marginalization cannot in practice be achieved exactly, we can try to improve over conventional training, which, as we have discussed, can be viewed as approximate marginalization. Given computational constraints, effective marginalization is not equivalent to obtaining accurate samples from a posterior. As we have discussed, simple Monte Carlo is only one of many mechanisms for marginalization. Just as expectation propagation (Minka, 2001a) focuses its approximation on the factors of a posterior where they will most affect the end result, we should focus on representing the posterior where it will make the biggest difference to the model average. As we have shown, deep ensembles are a reasonable mechanism up to a point. After having trained many independent models, there are added benefits to marginalizing within basins, given the computational expense associated with retraining an additional model to find an additional basin of attraction.

We should also not hold Bayesian methods to a double standard. Indeed, it can be hard to interpret or understand the prior, the posterior, and whether the marginalization procedure is optimal. But it is also hard to interpret the choices behind the functional form of the model, or the rationale behind classical procedures where we bet everything on a single global optimum, when we know there are many global optima, many of them will perform well but provide different solutions, and many others will not perform well. We should apply the same level of scrutiny to all modelling choices, consider the alternatives, and not be paralyzed if a procedure is not optimal in every respect.

Acknowledgements. AGW and PI are supported by an Amazon Research Award, Facebook Research, NSF I-DISRE 193471, NIH R01 DA048764-01A1, NSF IIS-1563887, and NSF IIS-1910266. We thank Greg Benton for helpful discussions.

References

Achille, A. and Soatto, S. Emergence of invariance and disentanglement in deep representations. The Journal of Machine Learning Research, 19(1):1947–1980, 2018.

Ashukha, A., Lyzhov, A., Molchanov, D., and Vetrov, D. Pitfalls of in-domain uncertainty estimation and ensembling in deep learning. arXiv preprint arXiv:2002.06470, 2020.

Barron, A. R. and Cover, T. M. Minimum complexity density estimation. IEEE Transactions on Information Theory, 37(4):1034–1054, 1991.

Beal, M. J. Variational algorithms for approximate Bayesian inference. University of London, 2003.

Belkin, M., Hsu, D., Ma, S., and Mandal, S. Reconciling modern machine-learning practice and the classical bias–variance trade-off. Proceedings of the National Academy of Sciences, 116(32):15849–15854, 2019.

Bishop, C. M. Pattern Recognition and Machine Learning. Springer, 2006.

Bissiri, P. G., Holmes, C. C., and Walker, S. G. A general framework for updating belief distributions. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 78(5):1103–1130, 2016.

Blundell, C., Cornebise, J., Kavukcuoglu, K., and Wierstra, D. Weight uncertainty in neural networks. arXiv preprint arXiv:1505.05424, 2015.

Box, G. E. and Tiao, G. C. Bayesian Inference in Statistical Analysis. Addison-Wesley, Reading, MA, 1973.


Chaudhari, P., Choromanska, A., Soatto, S., LeCun, Y., Baldassi, C., Borgs, C., Chayes, J., Sagun, L., and Zecchina, R. Entropy-SGD: Biasing gradient descent into wide valleys. Journal of Statistical Mechanics: Theory and Experiment, 2019.

Cobb, A. D., Baydin, A. G., Markham, A., and Roberts, S. J. Introducing an explicit symplectic integration scheme for Riemannian manifold Hamiltonian Monte Carlo. arXiv preprint arXiv:1910.06243, 2019.

Darnieder, W. F. Bayesian methods for data-dependent priors. PhD thesis, The Ohio State University, 2011.

de Heide, R., Kirichenko, A., Mehta, N., and Grunwald, P. Safe-Bayesian generalized linear regression. arXiv preprint arXiv:1910.09227, 2019.

Dinh, L., Pascanu, R., Bengio, S., and Bengio, Y. Sharp minima can generalize for deep nets. arXiv preprint arXiv:1703.04933, 2017.

Draxler, F., Veschgini, K., Salmhofer, M., and Hamprecht, F. A. Essentially no barriers in neural network energy landscape. arXiv preprint arXiv:1803.00885, 2018.

Dziugaite, G. K. and Roy, D. M. Computing nonvacuous generalization bounds for deep (stochastic) neural networks with many more parameters than training data. arXiv preprint arXiv:1703.11008, 2017.

Fort, S. and Jastrzebski, S. Large scale structure of neural network loss landscapes. In Advances in Neural Information Processing Systems, pp. 6706–6714, 2019.

Fort, S., Hu, H., and Lakshminarayanan, B. Deep ensembles: A loss landscape perspective. arXiv preprint arXiv:1912.02757, 2019.

Gal, Y. and Ghahramani, Z. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. In International Conference on Machine Learning, pp. 1050–1059, 2016.

Gardner, J. R., Pleiss, G., Bindel, D., Weinberger, K. Q., and Wilson, A. G. GPyTorch: Blackbox matrix-matrix Gaussian process inference with GPU acceleration. In Neural Information Processing Systems, 2018.

Garipov, T., Izmailov, P., Podoprikhin, D., Vetrov, D. P., and Wilson, A. G. Loss surfaces, mode connectivity, and fast ensembling of DNNs. In Neural Information Processing Systems, 2018.

Grunwald, P. The safe Bayesian. In International Conference on Algorithmic Learning Theory, pp. 169–183. Springer, 2012.

Grunwald, P., Van Ommen, T., et al. Inconsistency of Bayesian inference for misspecified linear models, and a proposal for repairing it. Bayesian Analysis, 12(4):1069–1103, 2017.

Guedj, B. A primer on PAC-Bayesian learning. arXiv preprint arXiv:1901.05353, 2019.

Guo, C., Pleiss, G., Sun, Y., and Weinberger, K. Q. On calibration of modern neural networks. In Proceedings of the 34th International Conference on Machine Learning, Volume 70, pp. 1321–1330. JMLR.org, 2017.

Gustafsson, F. K., Danelljan, M., and Schon, T. B. Evaluating scalable Bayesian deep learning methods for robust computer vision. arXiv preprint arXiv:1906.01620, 2019.

Hafner, D., Tran, D., Irpan, A., Lillicrap, T., and Davidson, J. Reliable uncertainty estimates in deep neural networks using noise contrastive priors. arXiv preprint arXiv:1807.09289, 2018.

He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778, 2016.

Hendrycks, D. and Dietterich, T. Benchmarking neural network robustness to common corruptions and perturbations. arXiv preprint arXiv:1903.12261, 2019.

Hinton, G. E. and Van Camp, D. Keeping the neural networks simple by minimizing the description length of the weights. In Proceedings of the Sixth Annual Conference on Computational Learning Theory, pp. 5–13, 1993.

Hochreiter, S. and Schmidhuber, J. Flat minima. Neural Computation, 9(1):1–42, 1997.

Huang, W. R., Emam, Z., Goldblum, M., Fowl, L., Terry, J. K., Huang, F., and Goldstein, T. Understanding generalization through visualizations. arXiv preprint arXiv:1906.03291, 2019.

Izmailov, P., Podoprikhin, D., Garipov, T., Vetrov, D., and Wilson, A. G. Averaging weights leads to wider optima and better generalization. In Uncertainty in Artificial Intelligence (UAI), 2018.

Izmailov, P., Maddox, W. J., Kirichenko, P., Garipov, T., Vetrov, D., and Wilson, A. G. Subspace inference for Bayesian deep learning. Uncertainty in Artificial Intelligence, 2019.

Jacot, A., Gabriel, F., and Hongler, C. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in Neural Information Processing Systems, pp. 8571–8580, 2018.


Jiang, Y., Neyshabur, B., Mobahi, H., Krishnan, D., and Bengio, S. Fantastic generalization measures and where to find them. arXiv preprint arXiv:1912.02178, 2019.

Kass, R. E. and Raftery, A. E. Bayes factors. Journal of the American Statistical Association, 90(430):773–795, 1995.

Kendall, A. and Gal, Y. What uncertainties do we need in Bayesian deep learning for computer vision? In Advances in Neural Information Processing Systems, pp. 5574–5584, 2017.

Keskar, N. S., Mudigere, D., Nocedal, J., Smelyanskiy, M., and Tang, P. T. P. On large-batch training for deep learning: Generalization gap and sharp minima. arXiv preprint arXiv:1609.04836, 2016.

Khan, M. E., Nielsen, D., Tangkaratt, V., Lin, W., Gal, Y., and Srivastava, A. Fast and scalable Bayesian deep learning by weight-perturbation in Adam. arXiv preprint arXiv:1806.04854, 2018.

Krizhevsky, A., Nair, V., and Hinton, G. The CIFAR-10 dataset, 2014. http://www.cs.toronto.edu/~kriz/cifar.html

Lakshminarayanan, B., Pritzel, A., and Blundell, C. Simple and scalable predictive uncertainty estimation using deep ensembles. In Advances in Neural Information Processing Systems, pp. 6402–6413, 2017.

Langford, J. and Caruana, R. (Not) bounding the true error. In Advances in Neural Information Processing Systems, pp. 809–816, 2002.

LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.

Louizos, C., Shi, X., Schutte, K., and Welling, M. The functional neural process. In Advances in Neural Information Processing Systems, 2019.

MacKay, D. J. Bayesian methods for adaptive models. PhD thesis, California Institute of Technology, 1992.

MacKay, D. J. Probable networks and plausible predictions - a review of practical Bayesian methods for supervised neural networks. Network: Computation in Neural Systems, 6(3):469–505, 1995.

MacKay, D. J. Introduction to Gaussian processes. In Bishop, C. M. (ed.), Neural Networks and Machine Learning, chapter 11, pp. 133–165. Springer-Verlag, 1998.

MacKay, D. J. Information Theory, Inference and Learning Algorithms. Cambridge University Press, 2003.

Maddox, W. J., Izmailov, P., Garipov, T., Vetrov, D. P., and Wilson, A. G. A simple baseline for Bayesian uncertainty in deep learning. In Advances in Neural Information Processing Systems, 2019.

Masegosa, A. R. Learning under model misspecification: Applications to variational and ensemble methods, 2019.

McAllester, D. A. PAC-Bayesian model averaging. In Proceedings of the Twelfth Annual Conference on Computational Learning Theory, pp. 164–170, 1999.

Minka, T. Expectation propagation for approximate Bayesian inference. In Uncertainty in Artificial Intelligence, volume 17, pp. 362–369, 2001a.

Minka, T. P. Bayesian model averaging is not model combination, 2000.

Minka, T. P. Automatic choice of dimensionality for PCA. In Advances in Neural Information Processing Systems, pp. 598–604, 2001b.

Mohri, M. and Rostamizadeh, A. Rademacher complexity bounds for non-i.i.d. processes. In Advances in Neural Information Processing Systems, pp. 1097–1104, 2009.

Nakkiran, P., Kaplun, G., Bansal, Y., Yang, T., Barak, B., and Sutskever, I. Deep double descent: Where bigger models and more data hurt. arXiv preprint arXiv:1912.02292, 2019.

Nakkiran, P., Venkat, P., Kakade, S., and Ma, T. Optimal regularization can mitigate double descent. arXiv preprint arXiv:2003.01897, 2020.

Nalisnick, E. On priors for Bayesian neural networks. PhD thesis, University of California, Irvine, 2018.

Neal, R. Bayesian Learning for Neural Networks. Springer Verlag, 1996. ISBN 0387947248.

Neyshabur, B., Bhojanapalli, S., McAllester, D., and Srebro, N. Exploring generalization in deep learning. In Advances in Neural Information Processing Systems, pp. 5947–5956, 2017.

Neyshabur, B., Bhojanapalli, S., and Srebro, N. A PAC-Bayesian approach to spectrally-normalized margin bounds for neural networks. In International Conference on Learning Representations, 2018. URL https://openreview.net/forum?id=Skz_WfbCZ

Opper, M., Kinzel, W., Kleinz, J., and Nehl, R. On the ability of the optimal perceptron to generalise. Journal of Physics A: Mathematical and General, 23(11):L581, 1990.


Ovadia, Y., Fertig, E., Ren, J., Nado, Z., Sculley, D., Nowozin, S., Dillon, J. V., Lakshminarayanan, B., and Snoek, J. Can you trust your model's uncertainty? Evaluating predictive uncertainty under dataset shift. arXiv preprint arXiv:1906.02530, 2019.

Pearce, T., Zaki, M., Brintrup, A., Anastassacos, N., and Neely, A. Uncertainty in neural networks: Bayesian ensembling. arXiv preprint arXiv:1810.05546, 2018.

Rasmussen, C. E. and Ghahramani, Z. Occam's razor. In Neural Information Processing Systems (NIPS), 2001.

Rasmussen, C. E. and Williams, C. K. I. Gaussian Processes for Machine Learning. The MIT Press, 2006.

Ritter, H., Botev, A., and Barber, D. A scalable Laplace approximation for neural networks. In International Conference on Learning Representations (ICLR), 2018.

Rue, H., Martino, S., and Chopin, N. Approximate Bayesian inference for latent Gaussian models by using integrated nested Laplace approximations. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 71(2):319–392, 2009.

Simonyan, K. and Zisserman, A. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.

Smith, S. L. and Le, Q. V. A Bayesian perspective on generalization and stochastic gradient descent. In International Conference on Learning Representations, 2018.

Sun, S., Zhang, G., Shi, J., and Grosse, R. Functional variational Bayesian neural networks. arXiv preprint arXiv:1903.05779, 2019.

Ulyanov, D., Vedaldi, A., and Lempitsky, V. Deep image prior. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 9446–9454, 2018.

Vapnik, V. N. Adaptive and learning systems for signal processing, communications, and control. Statistical Learning Theory, 1998.

Virtanen, P., Gommers, R., Oliphant, T. E., Haberland, M., Reddy, T., Cournapeau, D., Burovski, E., Peterson, P., Weckesser, W., Bright, J., et al. SciPy 1.0: Fundamental algorithms for scientific computing in Python. Nature Methods, pp. 1–12, 2020.

Walker, S. and Hjort, N. L. On Bayesian consistency. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 63(4):811–821, 2001.

Welling, M. and Teh, Y. W. Bayesian learning via stochastic gradient Langevin dynamics. In Proceedings of the 28th International Conference on Machine Learning (ICML-11), pp. 681–688, 2011.

Wenzel, F., Roth, K., Veeling, B. S., Swiatkowski, J., Tran, L., Mandt, S., Snoek, J., Salimans, T., Jenatton, R., and Nowozin, S. How good is the Bayes posterior in deep neural networks really? arXiv preprint arXiv:2002.02405, 2020.

Wilson, A. G. The case for Bayesian deep learning. arXiv preprint arXiv:2001.10995, 2020.

Xu, X., Lu, P., MacEachern, S., and Xu, R. Calibrated Bayes factors for model comparison. Journal of Statistical Computation and Simulation, 89(4):591–614, 2019.

Yang, W., Lorch, L., Graule, M. A., Srinivasan, S., Suresh, A., Yao, J., Pradier, M. F., and Doshi-Velez, F. Output-constrained Bayesian neural networks. arXiv preprint arXiv:1905.06287, 2019.

Zhang, C., Bengio, S., Hardt, M., Recht, B., and Vinyals, O. Understanding deep learning requires rethinking generalization. arXiv preprint arXiv:1611.03530, 2016.

Zhang, T. Information-theoretic upper and lower bounds for statistical estimation. IEEE Transactions on Information Theory, 52(4):1307–1321, 2006.


Appendix Outline

This appendix is organized as follows. In Section A we visualize predictive functions corresponding to weight samples within high posterior density valleys, on a regression problem. In Section B we provide background material on Gaussian processes. In Section C we present further results comparing MultiSWAG and MultiSWA to Deep Ensembles under data distribution shift on CIFAR-10. In Section D we provide the details of all experiments presented in the paper. In Section E we present analytic results on the dependence of the prior distribution in function space on the variance of the prior over parameters. In Section F we study the prior correlations between BNN logits on perturbed images.

A. Loss Valleys

We demonstrate that different points along the valleys of high posterior density (low loss) connecting pairs of independently trained optima (Garipov et al., 2018; Draxler et al., 2018; Fort & Jastrzebski, 2019) correspond to different predictive functions. We use the regression example from Izmailov et al. (2019) and show the results in Figure 11.

B. Gaussian processes

With a Bayesian neural network, a distribution over parameters p(w) induces a distribution over functions p(f(x; w)) when combined with the functional form of the network. Gaussian processes (GPs) are often used to instead directly specify a distribution over functions.

A Gaussian process is a distribution over functions, f(x) ∼ GP(m, k), such that any collection of function values, queried at any finite set of inputs x_1, ..., x_n, has a joint Gaussian distribution,

f(x_1), ..., f(x_n) ∼ N(µ, K).  (6)

The mean vector, µ_i = E[f(x_i)] = m(x_i), and covariance matrix, K_ij = cov(f(x_i), f(x_j)) = k(x_i, x_j), are determined by the mean function m and covariance function (or kernel) k of the Gaussian process.

The popular RBF kernel has the form

k(x_i, x_j) = exp(−‖x_i − x_j‖² / (2ℓ²)).  (7)

The length-scale hyperparameter ℓ controls the extent of correlations between function values. If ℓ is large, sample functions from a GP prior are simple and slowly varying with the inputs x.
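The effect of the length-scale can be seen by drawing prior functions directly from Equations (6)-(7); the sketch below (with illustrative values of ℓ, an illustrative input grid, and a small diagonal jitter added for numerical stability) draws samples from N(0, K) under an RBF kernel.

```python
import numpy as np

def rbf_kernel(x, ell):
    d = x[:, None] - x[None, :]
    return np.exp(-0.5 * d**2 / ell**2)

rng = np.random.default_rng(0)
x = np.linspace(-5, 5, 200)
for ell in (0.3, 2.0):                                   # short vs. long length-scale
    K = rbf_kernel(x, ell) + 1e-6 * np.eye(len(x))       # jitter for a stable Cholesky
    L = np.linalg.cholesky(K)
    f = L @ rng.standard_normal((len(x), 3))             # three prior draws from N(0, K)
    # Longer length-scales give smoother, more slowly varying sample functions.
    print(ell, float(np.abs(np.diff(f, axis=0)).mean()))
```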

Gaussian processes with RBF kernels (as well as many other standard kernels) assign positive density to any set of observations. Moreover, these models are universal approximators (Rasmussen & Williams, 2006): as the number of observations increases, they are able to approximate any function to arbitrary precision.

Work on Gaussian processes in machine learning was triggered by the observation that Bayesian neural networks become Gaussian processes, with particular kernel functions, as the number of hidden units approaches infinity (Neal, 1996). This result resembles recent work on the neural tangent kernel (e.g., Jacot et al., 2018).

C. Deep Ensembles and MultiSWAG Under Distribution Shift

In Figures 14, 15, 16, and 17, we show the negative log likelihood for Deep Ensembles, MultiSWA, and MultiSWAG using a PreResNet-20 on CIFAR-10 with various corruptions, as a function of the number of independently trained models (SGD solutions, SWA solutions, or SWAG models, respectively). For MultiSWAG, we generate 20 samples from each independent SWAG model. Typically, MultiSWA and MultiSWAG significantly outperform Deep Ensembles when a small number of independent models is used, or when the level of corruption is high.

In Figure 18, following Ovadia et al. (2019), we show the distribution of negative log likelihood, accuracy, and expected calibration error as we vary the type of corruption. We use a fixed training time budget: 10 independently trained models for every method. For MultiSWAG, we ensemble 20 samples from each of the 10 SWAG approximations. MultiSWAG in particular achieves better NLL than the other two methods, and MultiSWA outperforms Deep Ensembles; the difference is especially pronounced for higher levels of corruption. In terms of ECE, MultiSWAG again outperforms the other two methods for higher corruption intensities.
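For reference, expected calibration error (ECE) is commonly computed by binning predictions by confidence and averaging the gap between confidence and accuracy in each bin (Guo et al., 2017); the sketch below is a minimal version of that computation with an assumed number of equal-width bins, not the evaluation code used for Figure 18.

```python
import numpy as np

def expected_calibration_error(probs, labels, n_bins=20):
    """ECE over equal-width confidence bins.

    probs: (n, num_classes) predicted class probabilities; labels: (n,) integer labels.
    """
    conf = probs.max(axis=1)                      # confidence of the predicted class
    correct = (probs.argmax(axis=1) == labels)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        mask = (conf > lo) & (conf <= hi)
        if mask.any():                            # weight each bin by its share of points
            ece += mask.mean() * abs(correct[mask].mean() - conf[mask].mean())
    return ece
```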

We note that Ovadia et al. (2019) found Deep Ensembles to be a very strong baseline for prediction quality and calibration under distribution shift. For this reason, we focus on Deep Ensembles in our comparisons.

D. Details of Experiments

In this section we provide additional details of the experiments presented in the paper.

D.1. Approximating the True Predictive Distribution

For the results presented in Figure 4, we used a network with 3 hidden layers of size 10 each. The network takes two inputs, x and x².



Figure 11. Diversity of high performing functions. Bottom: a contour plot of the posterior log-density in the subspace containing a pair of independently trained modes (as with deep ensembles), and a path of high posterior density connecting these modes. In each panel, the purple point represents a sample from the posterior in the parameter subspace. Top: the predictive distribution constructed from samples in the subspace. The shaded blue area shows the 3σ-region of the predictive distribution at each of the input locations, and the blue line shows the mean of the predictive distribution. In each panel, the purple line shows the predictive function corresponding to the sample shown in the corresponding bottom row panel. For the details of the experimental setup, see Section 5.1 of Izmailov et al. (2019).

that the network can represent a broader class of functionsThe network outputs a single number y = f(x)

To generate data for the plots, we used a randomly-initialized neural network of the same architecture described above. We sampled the weights from an isotropic Gaussian with variance 0.1², and added isotropic Gaussian noise with variance 0.1² to the outputs:

y = f(x; w) + ε(x),

with w ∼ N(0, 0.1² · I) and ε(x) ∼ N(0, 0.1² · I). The training set consists of 120 points, shown in Figure 4.
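As a concrete illustration of this generative process, the sketch below draws a random network of the stated size and produces noisy targets; the tanh activation and the input range are our own illustrative choices, not taken from the paper.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# 2 inputs (x and x^2), 3 hidden layers of size 10, 1 output.
def make_net():
    return nn.Sequential(
        nn.Linear(2, 10), nn.Tanh(),
        nn.Linear(10, 10), nn.Tanh(),
        nn.Linear(10, 10), nn.Tanh(),
        nn.Linear(10, 1),
    )

truth = make_net()
with torch.no_grad():
    for p in truth.parameters():
        p.normal_(0.0, 0.1)                          # weights ~ N(0, 0.1^2)

x = torch.linspace(-10, 10, 120).unsqueeze(-1)       # illustrative input locations
inputs = torch.cat([x, x ** 2], dim=-1)
with torch.no_grad():
    y = truth(inputs) + 0.1 * torch.randn(120, 1)    # y = f(x, w) + eps, eps ~ N(0, 0.1^2)
```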

For estimating the ground truth, we ran 200 chains of Hamiltonian Monte Carlo (HMC) using the hamiltorch package (Cobb et al., 2019). We initialized each chain with a network pre-trained with SGD for 3000 steps, then ran HMC for 2000 steps, producing 200 samples.
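For reference, a single HMC transition has the following generic leapfrog-plus-Metropolis form; this is a minimal sketch of the algorithm rather than the hamiltorch implementation, and the step size and number of leapfrog steps are placeholder values.

```python
import torch

def hmc_step(w, log_post, step_size=1e-3, n_leapfrog=30):
    """One HMC transition targeting exp(log_post(w)); w is a flat parameter tensor."""
    def grad(v):
        v = v.detach().requires_grad_(True)
        return torch.autograd.grad(log_post(v), v)[0]

    p = torch.randn_like(w)                                # resample momentum
    w_new, p_new = w.clone(), p.clone()
    p_new = p_new + 0.5 * step_size * grad(w_new)          # half step for momentum
    for i in range(n_leapfrog):
        w_new = w_new + step_size * p_new                  # full step for position
        if i != n_leapfrog - 1:
            p_new = p_new + step_size * grad(w_new)
    p_new = p_new + 0.5 * step_size * grad(w_new)          # final half step for momentum

    energy = lambda v, m: -log_post(v) + 0.5 * (m ** 2).sum()   # negative log joint
    accept = torch.rand(()) < torch.exp(energy(w, p) - energy(w_new, p_new))
    return w_new.detach() if accept else w
```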

For Deep Ensembles, we independently trained 50 networks with SGD for 20000 steps each, using the negative posterior log-density as the training loss. For SVI, we used a fully-factorized Gaussian approximation initialized at an SGD solution trained for 20000 steps. For all inference methods, we set the prior variance to 10² and the noise variance to 0.02².

Discrepancy with true BMA. For the results presented in panel (d) of Figure 4, we computed the Wasserstein distance between the predictive distribution approximated with HMC and the predictive distributions for Deep Ensembles and SVI. We used the one-dimensional Wasserstein distance function¹ from the scipy package (Virtanen et al., 2020). We computed the Wasserstein distance between the marginal distributions at each input location and averaged the results over the input locations. In the top sub-panels of panels (b) and (c) of Figure 4, we additionally visualize the marginal Wasserstein distance between the HMC predictive distribution and the Deep Ensembles and SVI predictive distributions, respectively, for each input location.

¹https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wasserstein_distance.html
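A minimal sketch of this comparison, assuming pred_a and pred_b are arrays of predictive samples with shape (num_samples, num_inputs), is given below; the function and variable names are our own.

```python
import numpy as np
from scipy.stats import wasserstein_distance

def avg_marginal_wasserstein(pred_a, pred_b):
    """Average one-dimensional Wasserstein distance between the marginal
    predictive distributions of two approximations.

    pred_a, pred_b: arrays of shape (num_samples, num_inputs) containing
    predictive samples at each input location.
    """
    per_input = np.array([
        wasserstein_distance(pred_a[:, i], pred_b[:, i])
        for i in range(pred_a.shape[1])
    ])
    return per_input.mean(), per_input   # averaged and per-input distances
```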

D.2. Deep Ensembles and MultiSWAG

We evaluate Deep Ensembles, MultiSWA and MultiSWAG under distribution shift in Section 4. Following Ovadia et al. (2019), we use a PreResNet-20 network and the CIFAR-10 dataset with different types of corruptions introduced in Hendrycks & Dietterich (2019). For training the individual SGD, SWA and SWAG models we use the hyper-parameters used for PreResNet-164 in Maddox et al. (2019). For each SWAG model we sample 20 networks and ensemble them, so Deep Ensembles, MultiSWA and MultiSWAG are all evaluated under the same training budget; Deep Ensembles and MultiSWA also use the same test-time budget.
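The ensembling step is a simple Monte Carlo average of the predictive distributions over all sampled networks. A minimal sketch is shown below, under the assumption that each SWAG approximation exposes a sample() method returning a network; the names are illustrative rather than taken from any particular codebase.

```python
import torch

def multiswag_predict(swag_models, x, samples_per_model=20):
    """Approximate the BMA predictive distribution by averaging softmax
    outputs over networks sampled from several SWAG approximations."""
    probs = []
    with torch.no_grad():
        for swag in swag_models:                 # independently trained SWAG solutions
            for _ in range(samples_per_model):
                net = swag.sample()              # draw one network from this approximation
                net.eval()
                probs.append(torch.softmax(net(x), dim=-1))
    return torch.stack(probs).mean(dim=0)        # p(y | x, D) ~ average over all samples
```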

For producing the corrupted data, we used the code² released by Hendrycks & Dietterich (2019). We had issues producing the data for the frost corruption type, so we omit it in our evaluation, and include Gaussian blur, which was not included in the evaluation of Hendrycks & Dietterich (2019).

²https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_cifar_c.py

Figure 12. (a)–(c): Average pairwise prior correlations for pairs of objects in classes 0, 1, 2, 4 and 7 of MNIST, induced by LeNet-5 for p(f(x; w)) when p(w) = N(0, α²I). Images in the same class have higher prior correlations than images from different classes, suggesting that p(f(x; w)) has desirable inductive biases. The correlations slightly decrease with increases in α. Panels (e)–(g) show sample functions from LeNet-5 along the direction connecting a pair of MNIST images of 0 and 1 digits; the complexity of the samples increases with α. (d): NLL and (h): classification error of an ensemble of 20 SWAG samples on MNIST as a function of α, using LeNet-5. The NLL is high for overly small α and near-optimal for larger values, with an optimum near α = 0.3.

D.3. Neural Network Priors

In the main text we considered different properties of the prior distribution over functions induced by a spherical Gaussian distribution over the weights, with different variance scales.

Prior correlation diagrams. In panels (a)–(c) of Figure 12 we show pairwise correlations of the logits for different pairs of datapoints. To make these plots, we produce S = 100 samples of the weights wᵢ of a LeNet-5 from the prior distribution N(0, α²I), and compute the logits corresponding to class 0 for each data point and each weight sample. We then compute the correlations for each pair x, x′ of data points as follows:

corr_logit(x, x′) = ∑_{i=1}^S (f(x, wᵢ) − f̄(x)) (f(x′, wᵢ) − f̄(x′)) / sqrt( ∑_{i=1}^S (f(x, wᵢ) − f̄(x))² · ∑_{i=1}^S (f(x′, wᵢ) − f̄(x′))² ),

where f(x, w) is the logit corresponding to class 0 of the network with weights w on the input x, and f̄(x) = (1/S) ∑ᵢ f(x, wᵢ) is the mean value of the logit. For evaluation, we use 200 random datapoints per class for classes 0, 1, 2, 4 and 7 (a total of 1000 datapoints); we use this set of classes to ensure that the structure is clearly visible in the figure. We combine the correlations into a diagram, additionally showing the average correlation for each pair of classes. We repeat the experiment for different values of α ∈ {0.02, 0.1, 1}. For a discussion of the results, see Section 5.2.
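A compact sketch of this computation is shown below, assuming prior_nets is a list of S networks whose weights were drawn from N(0, α²I) and images is a batch of inputs; the names are illustrative.

```python
import torch

def prior_logit_correlations(prior_nets, images, class_idx=0):
    """Correlation matrix, over prior weight samples, of the class-`class_idx`
    logit for every pair of images.

    prior_nets: list of S networks whose weights were sampled from the prior.
    images:     tensor of inputs accepted by the networks.
    """
    with torch.no_grad():
        # logits[i, j] = f(images[j], w_i) for the chosen class
        logits = torch.stack([net(images)[:, class_idx] for net in prior_nets])
    centered = logits - logits.mean(dim=0, keepdim=True)   # subtract mean over weight samples
    cov = centered.T @ centered                            # unnormalized pairwise covariances
    std = centered.norm(dim=0)                             # unnormalized per-image std
    return cov / (std[:, None] * std[None, :])
```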

Sample functions. In panels (e)–(g) of Figure 12 we visualize the functions sampled from the LeNet-5 network along the direction connecting a pair of MNIST images. In particular, we take a pair of images x₀ and x₁ of digits 0 and 1 respectively, and construct the path x(t) = t · x₀ + (1 − t) · x₁. We then study the samples of the logits z(t) = f(x(t) · ‖x₀‖/‖x(t)‖, w) along the path; here we adjust the norm of the images along the path to be constant, as the values of the logits are sensitive to the norm of the inputs. The complexity of the samples increases as we increase the variance of the prior distribution over the weights. This increased complexity of the sample functions explains why we might expect the prior correlations for pairs of images to be lower when we increase the variance of the prior distribution over the weights.
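A sketch of the constant-norm interpolation path is shown below; rescaling by ‖x₀‖/‖x(t)‖ keeps the input norm fixed, as described above, while the number of path points is an arbitrary choice.

```python
import torch

def constant_norm_path(x0, x1, num_points=100):
    """Interpolate between two images while keeping the input norm fixed
    to that of x0, so that logit values are comparable along the path."""
    ts = torch.linspace(0.0, 1.0, num_points)
    path = []
    for t in ts:
        x_t = t * x0 + (1 - t) * x1           # linear interpolation in pixel space
        x_t = x_t * x0.norm() / x_t.norm()    # rescale to constant norm
        path.append(x_t)
    return ts, torch.stack(path)

# Logit samples along the path for one prior draw `net` of the weights:
# ts, path = constant_norm_path(x0, x1)
# z = net(path)[:, class_idx]
```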

Performance dependence on prior variance. In panels (d) and (h) of Figure 12 we show the test negative log-likelihood and accuracy of SWAG applied to LeNet-5 on MNIST. We train the model for 50 epochs, constructing the rank-20 SWAG approximation from the last 25 epochs. We use an initial learning rate of 0.05 and a SWAG learning rate of 0.01, with the learning rate schedule of Maddox et al. (2019). We use the posterior log-density as the objective and vary the prior variance α². In panels (f) and (g) of Figure 9 we perform an analogous experiment using a PreResNet-20 and a VGG-16 on CIFAR-10, using the hyper-parameters reported in Maddox et al. (2019) (for PreResNet-20 we use the hyper-parameters used with PreResNet-164 in Maddox et al. (2019)). Both on MNIST and CIFAR-10 we observe that the performance is poor for overly small values of α, close to optimal for intermediate values, and still reasonable for larger values of α. For further discussion of the results, see Section 5.3.

Predictions from prior samples. Following Wenzel et al. (2020), we study the predictive distributions of prior samples using a PreResNet-20 on CIFAR-10. In Figure 9 we show the sample predictive functions averaged over datapoints, for different scales α of the prior distribution. We also show the predictive distribution for each α, which is the average of the sample predictive distributions over 200 samples of the weights. In Figure 10 we show how the predictive distribution changes as we vary the number of observed data points for prior scale α = √10. We see that the marginal predictive distribution for all considered values of α is reasonable: roughly uniform across classes when averaged across the dataset. For the latter experiment we used stochastic gradient Langevin dynamics (SGLD) (Welling & Teh, 2011) with a cosine learning rate schedule. For each sample we restart SGLD, and we only use the sample obtained at the last iteration. We discuss the results in Section 8.4.
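For reference, the SGLD update follows the standard Welling & Teh (2011) rule; the sketch below shows one epoch of updates under an assumed cosine-decayed step size, with all names illustrative.

```python
import math
import torch

def sgld_epoch(params, minibatches, grad_log_post, lr_max, epoch, total_epochs):
    """One epoch of stochastic gradient Langevin dynamics.

    grad_log_post(params, batch) should return a stochastic estimate of the
    gradient of the log posterior (minibatch log-likelihood rescaled to the
    full dataset, plus the log-prior) with respect to params.
    """
    lr = 0.5 * lr_max * (1 + math.cos(math.pi * epoch / total_epochs))  # cosine schedule
    for batch in minibatches:
        g = grad_log_post(params, batch)
        noise = math.sqrt(lr) * torch.randn_like(params)
        params = params + 0.5 * lr * g + noise     # Langevin update, no MH correction
    return params
```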

Prior correlations with corrupted images. In Section F and Figure 19 we study the decay of the prior correlations between logits on an original image and a perturbed image as we increase the intensity of the perturbations. For the BNN we use the PreResNet-20 architecture with the standard Gaussian prior N(0, I). For the linear model, the correlations are not affected by the prior variance α²:

cov(wᵀx, wᵀy) = E(wᵀx · wᵀy) = E(xᵀw wᵀy) = xᵀ E(wwᵀ) y = α² xᵀy,

and hence

corr(wᵀx, wᵀy) = cov(wᵀx, wᵀy) / sqrt( cov(wᵀx, wᵀx) · cov(wᵀy, wᵀy) ) = xᵀy / (‖x‖ · ‖y‖),

which does not depend on α.

We use the N(0, I) prior for the weights of the linear model. Finally, we also evaluate the correlations associated with an RBF kernel (see Equation (7)). To set the lengthscale ℓ of the kernel, we evaluate the pairwise correlations for the PreResNet-20 and the RBF kernel on the 100 uncorrupted CIFAR-10 images that were used for the experiment, and ensure that the average correlations match. The resulting value of ℓ is 10000; the average correlation for the RBF kernel and PreResNet was ≈ 0.9, and for the linear model the average correlation was ≈ 0.82. For the perturbations, we used the same set of corruptions introduced in Hendrycks & Dietterich (2019) as in the experiments in Section 4, with the addition of a random translation: for a random translation of intensity i, we pad the image with 2·i zeros on each side and crop the image randomly to 32×32.
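As a sketch of the lengthscale calibration: for a zero-mean GP with an RBF kernel, the prior correlation between f(x) and f(x′) is exp(−‖x − x′‖²/(2ℓ²)), so ℓ can be found by bisection to match a target average correlation (here, the average prior correlation of the PreResNet-20). The helper names below are our own.

```python
import numpy as np

def rbf_correlation(x, x_prime, lengthscale):
    """Prior correlation of a zero-mean GP with an RBF kernel."""
    return np.exp(-np.sum((x - x_prime) ** 2) / (2.0 * lengthscale ** 2))

def calibrate_lengthscale(images, target_corr, lo=1.0, hi=1e6, iters=50):
    """Bisection on the RBF lengthscale so that the average pairwise prior
    correlation over `images` matches `target_corr`."""
    flat = images.reshape(len(images), -1)
    pairs = [(i, j) for i in range(len(flat)) for j in range(i + 1, len(flat))]

    def avg_corr(ell):
        return np.mean([rbf_correlation(flat[i], flat[j], ell) for i, j in pairs])

    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        # the correlation increases monotonically with the lengthscale
        if avg_corr(mid) < target_corr:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)
```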

D.4. Rethinking Generalization

In Section 6 we experiment with Bayesian neural networks and Gaussian processes on CIFAR-10 with noisy labels, inspired by the results in Zhang et al. (2016) which suggest we need to re-think generalization to understand deep learning.

Following Zhang et al. (2016), we train PreResNet-20 on CIFAR-10 with different fractions of random labels. To ensure that the network fits the training data, we turn off weight decay and data augmentation, and use a lower initial learning rate of 0.01. Otherwise, we follow the hyper-parameters that were used with PreResNet-164 in Maddox et al. (2019). We use a diagonal Laplace approximation to compute an estimate of the marginal likelihood for each level of label corruption. Following Ritter et al. (2018), we use the diagonal of the Fisher information matrix rather than the Hessian.
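The diagonal Laplace estimate of the log marginal likelihood takes the standard form log p(D) ≈ log p(D | ŵ) + log p(ŵ) + (d/2) log 2π − ½ ∑ᵢ log hᵢ, where hᵢ is the i-th diagonal entry of the approximate Hessian of the negative log posterior at the MAP solution ŵ (here, the diagonal Fisher plus the prior precision). A minimal sketch, with all inputs assumed precomputed, is given below.

```python
import math
import torch

def diag_laplace_log_marginal(log_lik, log_prior, diag_fisher, prior_var):
    """Diagonal Laplace approximation to the log marginal likelihood.

    log_lik:     log p(D | w_hat) at the MAP solution (scalar tensor)
    log_prior:   log p(w_hat) (scalar tensor)
    diag_fisher: per-parameter diagonal of the Fisher information (1-d tensor)
    prior_var:   variance of the Gaussian prior over the weights
    """
    # Hessian of the negative log posterior, approximated by Fisher + prior precision
    diag_hessian = diag_fisher + 1.0 / prior_var
    d = diag_hessian.numel()
    return log_lik + log_prior + 0.5 * d * math.log(2 * math.pi) - 0.5 * torch.log(diag_hessian).sum()
```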

We perform a similar experiment with a Gaussian process with an RBF kernel on a binary classification problem for two classes of CIFAR-10. We use variational inference to fit the model, and we use the variational evidence lower bound to approximate the marginal likelihood. We use variational inference to overcome the non-Gaussian likelihood and not for scalability reasons, i.e. we are not using inducing inputs. We use the GPyTorch package (Gardner et al., 2018) to train the models. We use an RBF kernel with the default initialization from GPyTorch and divide the inputs by 5000 to get an appropriate input scale. We train the model on a binary classification problem between classes 0 and 1.

For the 10-class GP classification experiment, we train 10 one-vs-all models that classify between a given class and the rest of the data. To reduce computation in training, we subsample the data not belonging to the given class to 10k datapoints, so each model is trained on a total of 15k datapoints. We then combine the 10 models into a single multi-class model: an observation is attributed to the class that corresponds to the one-vs-all model with the highest confidence. We use the same hyper-parameters as in the binary classification experiments.
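Combining the one-vs-all models reduces to an argmax over their predicted positive-class probabilities. A minimal sketch, assuming each fitted model exposes a predict_proba(x) method returning p(class | x) (an illustrative name, not the GPyTorch API), is shown below.

```python
import numpy as np

def one_vs_all_predict(models, x):
    """Combine 10 one-vs-all classifiers into a multi-class prediction.

    models: list of fitted binary classifiers, ordered by class index; each
            returns the probability that an input belongs to its class.
    """
    # confidences[c, n] = probability that input n belongs to class c
    confidences = np.stack([m.predict_proba(x) for m in models])
    return confidences.argmax(axis=0)    # the most confident one-vs-all model wins
```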

D.5. Double Descent

In Section 7 we evaluate SGD, SWAG and MultiSWAG for models of varying width. Following Nakkiran et al. (2019), we use ResNet-18 on CIFAR-100; we consider the original labels as well as 10% and 20% label corruption. For networks of every width we reuse the hyper-parameters used for PreResNet-164 in Maddox et al. (2019). For the original labels and 10% label corruption we use 5 independently trained SWAG models with MultiSWAG, and for 20% label corruption we use 10 models; for 20% label corruption we also show performance varying the number of independent models in Figures 8(e) and 13(c). Both for SWAG and MultiSWAG we use an ensemble of 20 sampled models from each of the SWAG solutions; for example, for MultiSWAG with 10 independent SWAG solutions we use an ensemble of 200 networks.

Figure 13. Double descent. (a) Test error and (b) NLL for ResNet-18 with varying width on CIFAR-100 for SGD, SWAG and MultiSWAG when 10% of the labels are randomly reshuffled. MultiSWAG alleviates double descent both on the original labels and under label noise, both in accuracy and NLL. (c) Test NLLs for MultiSWAG with a varying number of independent models under 20% label corruption; NLL monotonically decreases with an increased number of independent models, alleviating double descent.

E. Analysis of Prior Variance Effect

In this section we provide simple analytic results for the effect of the prior variance in ReLU networks. A related derivation is presented in the Appendix Section A.8 of Garipov et al. (2018), about connecting paths from symmetries in parametrization.

We will consider a multilayer network f(x; w), with w = {Wᵢ, bᵢ} for i = 1, …, n, of the form

f(x; {Wᵢ, bᵢ}) = Wₙ(… φ(W₂ φ(W₁ x + b₁) + b₂) …) + bₙ,

where φ is the ReLU (or, in fact, any positively-homogeneous activation function), the Wᵢ are weight matrices, and the bᵢ are bias vectors. In particular, f can be a regular CNN with ReLU activations up to the logits (with the softmax activation removed).

Now suppose we have a prior distribution of the form

Wᵢ ∼ N(0, αᵢ² I), bᵢ ∼ N(0, βᵢ² I),

where the identity matrices I are implicitly assumed to be of the appropriate shapes, so each weight matrix and bias vector has a spherical Gaussian distribution. We can reparameterize this distribution as

Wᵢ = αᵢ Eᵢ, Eᵢ ∼ N(0, I),
bᵢ = βᵢ εᵢ, εᵢ ∼ N(0, I).

We can then express the predictions of the network on the input x, for weights sampled from the prior, as the random variable

f(x; {αᵢ, βᵢ}) = αₙ · Eₙ(… φ(α₁E₁x + β₁ · ε₁) …) + βₙ · εₙ.   (8)

Through Equation (8) we can observe some simple properties of the dependence between the prior scales αᵢ, βᵢ and the induced function-space prior.

Proposition 1. Suppose the network has no bias vectors, i.e. β₁ = … = βₙ = 0. Then the scales αᵢ of the prior distribution over the weights only affect the output scale of the network.

Proof. In the case when there are no bias vectors, Equation (8) simplifies to

f(x; {αᵢ, βᵢ = 0}) = αₙ · Eₙ(… φ(α₁E₁x) …) = αₙ · … · α₁ · Eₙ(… φ(E₁x) …) = αₙ · … · α₁ · f(x; {αᵢ = 1, βᵢ = 0}).

In the derivation above we used the positive homogeneity of ReLU: φ(αz) = αφ(z) for any positive α.

In other words, to sample from the distribution over functions corresponding to a prior with scales {αᵢ, βᵢ = 0}, we can sample from the spherical Gaussian prior (without bias terms) {αᵢ = 1, βᵢ = 0} and then rescale the outputs of the network by the product of the scales αₙ · … · α₂ · α₁.

We note that the result above is different from the results for sigmoid networks considered in MacKay (1995), where varying the prior on the weights leads to changing the lengthscale of the sample functions. For ReLU networks without biases, increasing the prior variance only increases the output scale of the network and not the complexity of the samples. If we apply the softmax activation on the outputs of the last layer of such a network, we will observe increasingly confident predictions as we increase the prior variance. We observe this effect in Figure 9 and discuss it in Section 8.4.

In case bias vectors are present, we can obtain a similar result using a specific scaling of the prior variances with the layer, as in the following proposition.

Proposition 2. Suppose the prior scales depend on the layer of the network as follows, for some γ > 0:

αᵢ = γ, βᵢ = γⁱ,

for all layers i = 1, …, n. Then γ only affects the scale of the predictive distribution at any input x:

f(x; {αᵢ = γ, βᵢ = γⁱ}) = γⁿ · f(x; {αᵢ = 1, βᵢ = 1}).

Proof. The proof is analogous to the proof of Proposition 1: we use the positive homogeneity of the ReLU activations to factor the prior scales outside of the network:

f(x; {αᵢ = γ, βᵢ = γⁱ}) = γ · Eₙ(… φ(γ · E₁x + γ · ε₁) …) + γⁿ · εₙ = γⁿ · (Eₙ(… φ(E₁x + ε₁) …) + εₙ) = γⁿ · f(x; {αᵢ = 1, βᵢ = 1}).

The analysis above can be applied to other simple scaling rules of the prior; e.g.,

f(x; {αᵢ → γαᵢ, βᵢ → γⁱβᵢ}) = γⁿ · f(x; {αᵢ, βᵢ}),   (9)

which can be shown completely analogously to Proposition 2.

More general types of scaling of the prior affect both the output scale of the network and the relative effect of the prior and variance terms. For example, by Equation (9) we have

f(x; {αᵢ = γ, βᵢ = γ}) = f(x; {αᵢ = γ · 1, βᵢ = γⁱ · γ¹⁻ⁱ}) = γⁿ · f(x; {αᵢ = 1, βᵢ = γ¹⁻ⁱ}).

We note that the analysis does not cover residual connections and batch normalization, so it applies to LeNet-5 but cannot be directly applied to the PreResNet-20 networks used in many of our experiments.

F. Prior Correlation Structure under Perturbations

In this section we explore the prior correlations between the logits on different pairs of datapoints induced by a spherical Gaussian prior on the weights of a PreResNet-20. We sample 100 random images from CIFAR-10 (10 from each class) and apply 17 different perturbations introduced by Hendrycks & Dietterich (2019), at 5 different levels of intensity. We then compute correlations between the logits f(x, w) for the original image x and f(x̃, w) for the corrupted image x̃, as we sample the weights of the network from the prior w ∼ N(0, I).

In Figure 19 we show how the correlations decay with perturbation intensity. For reference, we also show how the correlations decay for a linear model and for an RBF kernel. For the RBF kernel, we set the lengthscale so that the average correlations on the uncorrupted datapoints match those of a PreResNet-20. Further experimental details can be found in Appendix D.3.

For all types of corruptions except saturate, snow, fog and brightness, the PreResNet logits decay slower compared to the RBF kernel and linear model. It appears that the prior samples are sensitive to corruptions that alter the brightness or, more generally, the colours in the image. For many types of corruptions (such as, e.g., Gaussian noise) the prior correlations for PreResNet are close to 1 for all levels of corruption.

Overall, these results indicate that the prior over functions induced by a vague prior over parameters w, in combination with a PreResNet, has useful equivariance properties: before seeing data, the model treats images of the same class as highly correlated, even after an image has undergone significant perturbations representative of perturbations we often see in the real world. These types of symmetries are a large part of what makes neural networks a powerful model class for high dimensional natural signals.

Figure 14. Noise corruptions. Negative log-likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).

Figure 15. Blur corruptions. Negative log-likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).

Figure 16. Digital corruptions. Negative log-likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).

Figure 17. Weather corruptions. Negative log-likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).

Figure 18. Negative log-likelihood, accuracy, and expected calibration error distributions on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the corruption intensity. Following Ovadia et al. (2019), we summarize the results for different types of corruption with a boxplot. For each method we use 10 independently trained models, and for MultiSWAG we sample 20 networks from each model. As in Figures 5, 11–14, there are substantial differences between these three methods, which are hard to see due to the vertical scale on this plot; MultiSWAG particularly outperforms Deep Ensembles and MultiSWA in terms of NLL and ECE for higher corruption intensities.

Figure 19. Prior correlations under corruption. Prior correlations between predictions (logits) for PreResNet-20, a linear model, and an RBF kernel on original and corrupted images, as a function of corruption intensity, for different types of corruptions. The lengthscale of the RBF kernel is calibrated to produce similar correlations to PreResNet on uncorrupted datapoints. We report the mean correlation values over 100 different images and show the 1σ error bars with shaded regions. For all corruptions except Snow, Saturate, Fog and Brightness, the correlations decay slower for PreResNet compared to the baselines.


length frameworks (e.g., Hinton & Van Camp, 1993; Hochreiter & Schmidhuber, 1997; MacKay, 1995). Empirically, Keskar et al. (2016) argue that smaller batch SGD provides better generalization than large batch SGD by finding flatter minima. Chaudhari et al. (2019) and Izmailov et al. (2018) design optimization procedures to specifically find flat minima.

By connecting flat solutions with ensemble approximations, Izmailov et al. (2018) also suggest that functions associated with parameters in flat regions ought to provide different predictions on test data for flatness to be helpful in generalization, which is distinct from the flatness in Dinh et al. (2017). Garipov et al. (2018) also show that there are mode connecting curves, forming loss valleys, which contain a variety of distinct solutions. We argue that flat regions of the loss containing a diversity of solutions is particularly relevant for Bayesian model averaging, since the model average will then contain many compelling and complementary explanations for the data. Additionally, Huang et al. (2019) describe neural networks as having a blessing of dimensionality, since flat regions will occupy much greater volume in a high dimensional space, which we argue means that flat solutions will dominate in the Bayesian model average.

Smith & Le (2018) and MacKay (2003, Chapter 28) additionally connect the width of the posterior with Occam factors: from a Bayesian perspective, a larger width corresponds to a smaller Occam factor, and thus ought to provide better generalization. Dziugaite & Roy (2017) and Smith & Le (2018) also provide different perspectives on the results in Zhang et al. (2016), which shows that deep convolutional neural networks can fit CIFAR-10 with random labels and no training error. The PAC-Bayes bound of Dziugaite & Roy (2017) becomes vacuous when applied to randomly-labelled binary MNIST. Smith & Le (2018) show that logistic regression can fit noisy labels on sub-sampled MNIST, interpreting the result from an Occam factor perspective.

In general, PAC-Bayes provides a compelling framework for deriving explicit non-asymptotic generalization bounds. These bounds can be improved by, e.g., fewer parameters and very compact priors, which can be different from what provides optimal generalization. From our perspective, model flexibility and priors with large support, rather than compactness, are desirable. Our focus is complementary and largely prescriptive, aiming to provide intuitions on model construction, inference, generalization, and neural network priors, as well as new connections between Bayesian model averaging and deep ensembles, benefits of Bayesian model averaging specifically in the context of modern deep neural networks, perspectives on tempering in Bayesian deep learning, views of marginalization that contrast with simple Monte Carlo, as well as new methods for Bayesian marginalization in deep learning.

In other work, Pearce et al. (2018) propose a modification of deep ensembles and argue that it performs approximate Bayesian inference, and Gustafsson et al. (2019) briefly mention how deep ensembles can be viewed as samples from an approximate posterior. In the context of deep ensembles, we believe it is natural to consider the BMA integral separately from the simple Monte Carlo approximation that is often used to approximate this integral; to compute an accurate predictive distribution, we do not need samples from a posterior, or even a faithful approximation to the posterior.

Fort et al. (2019) considered the diversity of predictions produced by models from a single SGD run and models from independent SGD runs, and suggested ensembling averages of SGD iterates. Although MultiSWA (one of the methods considered in Section 4) is related to this idea, the crucial practical difference is that MultiSWA uses a learning rate schedule that selects for flat regions of the loss, the key to the success of the SWA method (Izmailov et al., 2018). Section 4 also shows that MultiSWAG, which we propose for multimodal Bayesian marginalization, outperforms MultiSWA.

Double descent, which describes generalization error that decreases, increases, and then again decreases with model flexibility, was demonstrated early by Opper et al. (1990). Recently, Belkin et al. (2019) extensively demonstrated double descent, leading to a surge of modern interest, with Nakkiran et al. (2019) showing double descent in deep learning. Nakkiran et al. (2020) show that tuned l2 regularization can mitigate double descent. Alternatively, we show that Bayesian model averaging, particularly based on multimodal marginalization, can alleviate even prominent double descent behaviour.

Tempering in Bayesian modelling has been considered under the names Safe Bayes, generalized Bayesian inference, and fractional Bayesian inference (e.g., de Heide et al., 2019; Grunwald et al., 2017; Barron & Cover, 1991; Walker & Hjort, 2001; Zhang, 2006; Bissiri et al., 2016; Grunwald, 2012). We provide several perspectives of tempering in Bayesian deep learning, and analyze the results in a recent paper by Wenzel et al. (2020) that questions tempering for Bayesian neural networks.

3. Bayesian Marginalization

Often the predictive distribution we want to compute is given by

p(y | x, D) = ∫ p(y | x, w) p(w | D) dw.   (1)

The outputs are y (e.g., regression values, class labels, …), indexed by inputs x (e.g., spatial locations, images, …), the weights (or parameters) of the neural network f(x; w) are w, and D are the data. Eq. (1) represents a Bayesian model average (BMA). Rather than bet everything on one hypothesis, with a single setting of parameters w, we want to use all settings of parameters, weighted by their posterior probabilities. This procedure is called marginalization of the parameters w, as the predictive distribution of interest no longer conditions on w. This is not a controversial equation, but simply the sum and product rules of probability.

3.1. Importance of Marginalization in Deep Learning

In general, we can view classical training as performing approximate Bayesian inference, using the approximate posterior p(w | D) ≈ δ(w = ŵ) to compute Eq. (1), where δ is a Dirac delta function that is zero everywhere except at ŵ = argmax_w p(w | D). In this case, we recover the standard predictive distribution p(y | x, ŵ). From this perspective, many alternatives, albeit imperfect, will be preferable, including impoverished Gaussian posterior approximations for p(w | D), even if the posterior or likelihood are actually highly non-Gaussian and multimodal.

The difference between a classical and Bayesian approach will depend on how sharp the posterior p(w | D) becomes. If the posterior is sharply peaked, and the conditional predictive distribution p(y | x, w) does not vary significantly where the posterior has mass, there may be almost no difference, since a delta function may then be a reasonable approximation of the posterior for the purpose of the BMA. However, modern neural networks are usually highly underspecified by the available data, and therefore have diffuse likelihoods p(D | w), not strongly favouring any one setting of parameters. Not only are the likelihoods diffuse, but different settings of the parameters correspond to a diverse variety of compelling hypotheses for the data (Garipov et al., 2018; Izmailov et al., 2019). This is exactly the setting when we most want to perform a Bayesian model average, which will lead to an ensemble containing many different but high performing models, for better calibration and accuracy than classical training.

Loss Valleys. Flat regions of low loss (negative log posterior density −log p(w | D)) are associated with good generalization (e.g., Hochreiter & Schmidhuber, 1997; Hinton & Van Camp, 1993; Dziugaite & Roy, 2017; Izmailov et al., 2018; Keskar et al., 2016). While flat solutions that generalize poorly can be contrived through reparametrization (Dinh et al., 2017), the flat regions that lead to good generalization contain a diversity of high performing models on test data (Izmailov et al., 2018), corresponding to different parameter settings in those regions. And indeed, there are large contiguous regions of low loss that contain such solutions, even connecting together different SGD solutions (Garipov et al., 2018; Izmailov et al., 2019) (see also Figure 11, Appendix).

Since these regions of the loss represent a large volume in a high-dimensional space (Huang et al., 2019), and provide a diversity of solutions, they will dominate in forming the predictive distribution in a Bayesian model average. By contrast, if the parameters in these regions provided similar functions, as would be the case in flatness obtained through reparametrization, these functions would be redundant in the model average. That is, although the solutions of high posterior density can provide poor generalization, it is the solutions that generalize well that will have greatest posterior mass, and thus be automatically favoured by the BMA.

Calibration by Epistemic Uncertainty Representation. It has been noticed that modern neural networks are often miscalibrated, in the sense that their predictions are typically overconfident (Guo et al., 2017). For example, in classification, the highest softmax output of a convolutional neural network is typically much larger than the probability of the associated class label. The fundamental reason for miscalibration is ignoring epistemic uncertainty: a neural network can represent many models that are consistent with our observations, and by selecting only one, in a classical procedure, we lose uncertainty when the models disagree for a test point. In regression, we can visualize epistemic uncertainty by looking at the spread of the predictive distribution: as we move away from the data, there are a greater variety of consistent solutions, leading to larger uncertainty, as in Figure 4. We can further calibrate the model with tempering, which we discuss in the Appendix, Section 8.

Accuracy. An often overlooked benefit of Bayesian model averaging in modern deep learning is improved accuracy. If we average the predictions of many high performing models that disagree in some cases, we should see significantly improved accuracy. This benefit is now starting to be observed in practice (e.g., Izmailov et al., 2019). Improvements in accuracy are very convincingly exemplified by deep ensembles (Lakshminarayanan et al., 2017), which have been perceived as a competing approach to Bayesian methods, but in fact provide a compelling mechanism for approximate Bayesian model averaging, as we show in Section 3.3. We also demonstrate significant accuracy benefits for multimodal Bayesian marginalization in Section 7.

3.2. Beyond Monte Carlo

Nearly all approaches to estimating the integral in Eq. (1), when it cannot be computed in closed form, involve a simple Monte Carlo approximation: p(y | x, D) ≈ (1/J) ∑_{j=1}^J p(y | x, wⱼ), wⱼ ∼ p(w | D). In practice, the samples from the posterior p(w | D) are also approximate, and found through MCMC or deterministic methods. The deterministic methods approximate p(w | D) with a different, more convenient density q(w | D, θ) from which we can sample, often chosen to be Gaussian. The parameters θ are selected to make q close to p in some sense; for example, variational approximations (e.g., Beal, 2003), which have emerged as a popular deterministic approach, find argmin_θ KL(q || p). Other standard deterministic approximations include Laplace (e.g., MacKay, 1995), EP (Minka, 2001a), and INLA (Rue et al., 2009).

From the perspective of estimating the predictive distribution in Eq. (1), we can view simple Monte Carlo as approximating the posterior with a set of point masses, with locations given by samples from another approximate posterior q, even if q is a continuous distribution. That is, p(w | D) ≈ ∑_{j=1}^J δ(w = wⱼ), wⱼ ∼ q(w | D).

Ultimately, the goal is to accurately compute the predictive distribution in Eq. (1), rather than find a generally accurate representation of the posterior. In particular, we must carefully represent the posterior in regions that will make the greatest contributions to the BMA integral. In terms of efficiently computing the predictive distribution, we do not necessarily want to place point masses at locations given by samples from the posterior. For example, functional diversity is important for a good approximation to the BMA integral, because we are summing together terms of the form p(y | x, w); if two settings of the weights wᵢ and wⱼ each provide high likelihood (and consequently high posterior density), but give rise to similar functions f(x; wᵢ), f(x; wⱼ), then they will be largely redundant in the model average, and the second setting of parameters will not contribute much to estimating the BMA integral for the unconditional predictive distribution. In Sections 3.3 and 4 we consider how various approaches approximate the predictive distribution.

3.3. Deep Ensembles are BMA

Deep ensembles (Lakshminarayanan et al., 2017) are fast becoming a gold standard for accurate and well-calibrated predictive distributions. Recent reports (e.g., Ovadia et al., 2019; Ashukha et al., 2020) show that deep ensembles appear to outperform some particular approaches to Bayesian neural networks for uncertainty representation, leading to the confusion that deep ensembles and Bayesian methods are competing approaches. These methods are often explicitly referred to as non-Bayesian (e.g., Lakshminarayanan et al., 2017; Ovadia et al., 2019; Wenzel et al., 2020). To the contrary, we argue that deep ensembles are actually a compelling approach to Bayesian model averaging, in the vein of Section 3.2.

There is a fundamental difference between a Bayesian model average and some approaches to ensembling. The Bayesian model average assumes that one hypothesis (one parameter setting) is correct, and averages over models due to an inability to distinguish between hypotheses given limited information (Minka, 2000). As we observe more data, the posterior collapses onto a single hypothesis. If the true explanation for the data is a combination of hypotheses, then the Bayesian model average may appear to perform worse as we observe more data. Some ensembling methods work by enriching the hypothesis space, and therefore do not collapse in this way. Deep ensembles, however, are formed by MAP or maximum likelihood retraining of the same architecture multiple times, leading to different basins of attraction. The deep ensemble will therefore collapse in the same way as a Bayesian model average, as the posterior concentrates. Since the hypothesis space (support) for a modern neural network is large, containing many different possible explanations for the data, posterior collapse will often be desirable.

Furthermore, by representing multiple basins of attraction, deep ensembles can provide a better approximation to the BMA than the Bayesian approaches in Ovadia et al. (2019). Indeed, functional diversity is important for a good approximation to the BMA integral, as per Section 3.2. The approaches referred to as Bayesian in Ovadia et al. (2019) instead focus their approximation on a single basin, which may contain a lot of redundancy in function space, making a relatively minimal contribution to computing the Bayesian predictive distribution. On the other hand, retraining a neural network multiple times for deep ensembles incurs a significant computational expense. The single basin approaches may be preferred if we are to control for computation. We explore these questions in Section 4.

4. An Empirical Study of Marginalization

We have shown that deep ensembles can be interpreted as an approximate approach to Bayesian marginalization, which selects for functional diversity by representing multiple basins of attraction in the posterior. Most Bayesian deep learning methods instead focus on faithfully approximating a posterior within a single basin of attraction. We propose a new method, MultiSWAG, which combines these two types of approaches. MultiSWAG combines multiple independently trained SWAG approximations (Maddox et al., 2019) to create a mixture of Gaussians approximation to the posterior, with each Gaussian centred on a different basin of attraction. We note that MultiSWAG does not require any additional training time over standard deep ensembles.

We illustrate the conceptual difference between deep ensembles, a standard variational single basin approach, and MultiSWAG in Figure 3. In the top panel we have a conceptualization of a multimodal posterior. VI approximates the posterior with multiple samples within a single basin. But we see in the middle panel that the conditional predictive distribution p(y | x, w) does not vary significantly within the basin, and thus each additional sample contributes minimally to computing the marginal predictive distribution p(y | x, D).

Figure 3. Approximating the BMA, p(y | x, D) = ∫ p(y | x, w) p(w | D) dw. Top: p(w | D), with representations from VI (orange), deep ensembles (blue), and MultiSWAG (red). Middle: p(y | x, w) as a function of w for a test input x; this function does not vary much within modes, but changes significantly between modes. Bottom: distance between the true predictive distribution and the approximation, as a function of representing the posterior at an additional point w, assuming we have already sampled the mode in dark green. There is more to be gained by exploring new basins than continuing to explore the same basin.

On the other hand, p(y | x, w) varies significantly between basins, and thus each point mass for deep ensembles contributes significantly to the marginal predictive distribution. By sampling within the basins, MultiSWAG provides additional contributions to the predictive distribution. In the bottom panel we have the gain in approximating the predictive distribution when adding a point mass to the representation of the posterior, as a function of its location, assuming we have already sampled the mode in dark green. Including samples from different modes provides significant gain over continuing to sample from the same mode, and including weights in wide basins provides relatively more gain than in the narrow ones.

In Figure 4 we evaluate single basin and multi-basin approaches in a case where we can near-exactly compute the predictive distribution. We provide details for generating the data and training the models in Appendix D.1. We see that the predictive distribution given by deep ensembles is qualitatively closer to the true distribution compared to the single basin variational method: between data clusters, the deep ensemble approach provides a similar representation of epistemic uncertainty, whereas the variational method is extremely overconfident in these regions. Moreover, we see that the Wasserstein distance between the true predictive distribution and these two approximations quickly shrinks with the number of samples for deep ensembles, but is roughly independent of the number of samples for the variational approach. Thus the deep ensemble is providing a better approximation of the Bayesian model average in Eq. (1) than the single basin variational approach, which has traditionally been labelled as the Bayesian alternative.

Next, we evaluate MultiSWAG under distribution shift on the CIFAR-10 dataset (Krizhevsky et al., 2014), replicating the setup in Ovadia et al. (2019). We consider 16 data corruptions, each at 5 different levels of severity, introduced by Hendrycks & Dietterich (2019). For each corruption, we evaluate the performance of deep ensembles and MultiSWAG, varying the training budget. For deep ensembles, we show performance as a function of the number of independently trained models in the ensemble. For MultiSWAG, we show performance as a function of the number of independent SWAG approximations that we construct; we then sample 20 models from each of these approximations to construct the final ensemble.

While the training time for MultiSWAG is the same as for deep ensembles, at test time MultiSWAG is more expensive, as the corresponding ensemble consists of a larger number of models. To account for situations when test time is constrained, we also propose MultiSWA, a method that ensembles independently trained SWA solutions (Izmailov et al., 2018). SWA solutions are the means of the corresponding Gaussian SWAG approximations. Izmailov et al. (2018) argue that SWA solutions approximate the local ensembles represented by SWAG with a single model.

In Figure 5 we show the negative log-likelihood as a function of the number of independently trained models for a Preactivation ResNet-20 on CIFAR-10 corrupted with Gaussian blur, with varying levels of intensity (increasing from left to right). MultiSWAG outperforms deep ensembles significantly on highly corrupted data. For lower levels of corruption, MultiSWAG works particularly well when only a small number of independently trained models are available. We note that MultiSWA also outperforms deep ensembles, and has the same computational requirements at training and test time as deep ensembles. We present results for other types of corruption in Appendix Figures 14, 15, 16 and 17, showing similar trends. In general, there is an extensive evaluation of MultiSWAG in the Appendix.

Our perspective of generalization is deeply connected with Bayesian marginalization. In order to best realize the benefits of marginalization in deep learning, we need to consider as many hypotheses as possible through multimodal posterior approximations such as MultiSWAG. In Section 7 we return to MultiSWAG, showing how it can entirely alleviate prominent double descent behaviour and lead to striking improvements in generalization over SGD and single basin marginalization, for both accuracy and NLL.

Figure 4. Approximating the true predictive distribution. (a) A close approximation of the true predictive distribution, obtained by combining 200 HMC chains. (b) Deep ensembles predictive distribution using 50 independently trained networks. (c) Predictive distribution for factorized variational inference (VI). (d) Convergence of the predictive distributions for deep ensembles and variational inference as a function of the number of samples; we measure the average Wasserstein distance between the marginals in the range of input positions. The multi-basin deep ensembles approach provides a more faithful approximation of the Bayesian predictive distribution than the conventional single-basin VI approach, which is overconfident between data clusters. The top panels show the Wasserstein distance between the true predictive distribution and the deep ensemble and VI approximations, as a function of the input x.

Figure 5. Negative log-likelihood for Deep Ensembles, MultiSWAG and MultiSWA using a PreResNet-20 on CIFAR-10, with varying intensity of the Gaussian blur corruption. The image in each plot shows the intensity of corruption. For all levels of intensity, MultiSWAG and MultiSWA outperform Deep Ensembles for a small number of independent models. For high levels of corruption, MultiSWAG significantly outperforms the other methods even for many independent models. We present results for other corruptions in the Appendix.

5. Neural Network Priors

A prior over parameters p(w) combines with the functional form of a model f(x; w) to induce a distribution over functions p(f(x; w)). It is this distribution over functions that controls the generalization properties of the model; the prior over parameters, in isolation, has no meaning. Neural networks are imbued with structural properties that provide good inductive biases, such as translation equivariance, hierarchical representations, and sparsity. In the sense of Figure 2, the prior will have large support, due to the flexibility of neural networks, but its inductive biases provide the most mass to datasets which are representative of problem settings where neural networks are often applied. In this section we study the properties of the induced distribution over functions. We directly continue the discussion of priors in Section 6, with a focus on examining the noisy CIFAR results in Zhang et al. (2016) from a probabilistic perspective of generalization. These sections are best read together.

We also provide several additional experiments in the Appendix. In Section E we present analytic results on the dependence of the prior distribution in function space on the variance of the prior over parameters, considering also layer-wise parameter priors with ReLU activations. As part of a discussion on tempering, in Section 8.4 we study the effect of α in p(w) = N(0, α²I) on prior class probabilities for individual sample functions p(f(x; w)), the predictive distribution, and posterior samples as we observe varying amounts of data. In Section F we further study the correlation structure over images induced by neural network priors, subject to perturbations of the images. In Section D.3 we provide additional experimental details.

5.1. Deep Image Prior and Random Network Features

Two recent results provide strong evidence that vague Gaus-sian priors over parameters when combined with a neu-ral network architecture induce a distribution over func-tions with useful inductive biases In the deep image priorUlyanov et al (2018) show that randomly initialized convo-lutional neural networks without training provide excellentperformance for image denoising super-resolution and in-painting This result demonstrates the ability for a sam-ple function from a random prior over neural networksp(f(xw)) to capture low-level image statistics before anytraining Similarly Zhang et al (2016) shows that pre-


[Figure 6 panels: 5x5 heatmaps of average pairwise prior correlations between MNIST classes 0, 1, 2, 4, 7 for (a) α = 0.02, (b) α = 0.1, (c) α = 1, and (d) NLL as a function of the prior std α.]

Figure 6. Induced prior correlation function. Average pairwise prior correlations for pairs of objects in classes {0, 1, 2, 4, 7} of MNIST, induced by LeNet-5 for p(f(x; w)) when p(w) = N(0, α²I). Images in the same class have higher prior correlations than images from different classes, suggesting that p(f(x; w)) has desirable inductive biases. The correlations slightly decrease with increases in α. (d) NLL of an ensemble of 20 SWAG samples on MNIST as a function of α, using LeNet-5.

[Figure 7 plot panels: (a) Prior Draws; (b) True Labels; (c) Corrupted Labels; (d) Gaussian Process marginal likelihood estimate and test error versus number of altered labels; (e) PreResNet-20 marginal likelihood estimate and test error versus number of resampled labels.]

Figure 7. Rethinking generalization. (a) Sample functions from a Gaussian process prior. (b) GP fit (with 95% credible region) to structured data (green) generated as y(x) = sin(2πx) + ε, ε ∼ N(0, 0.2²). (c) GP fit, with no training error, after a significant addition of corrupted data (red) drawn from Uniform[0.5, 1]. (d) Variational GP marginal likelihood with RBF kernel for two classes of CIFAR-10. (e) Laplace BNN marginal likelihood for a PreResNet-20 on CIFAR-10 with different fractions of random labels. The marginal likelihood for both the GP and BNN decreases as we increase the level of corruption in the labels, suggesting reasonable inductive biases in the prior over functions. Moreover, both the GP and BNN have 100% training accuracy on images with fully corrupted labels.

Similarly, Zhang et al. (2016) show that preprocessing CIFAR-10 with a randomly initialized, untrained convolutional neural network dramatically improves the test performance of a simple Gaussian kernel on pixels, from 54% accuracy to 71%. Adding ℓ2 regularization only improves the accuracy by an additional 2%. These results again indicate that broad Gaussian priors over parameters induce reasonable priors over networks, with a minor additional gain from decreasing the variance of the prior in parameter space, which corresponds to ℓ2 regularization.

5.2. Prior Class Correlations

In Figure 6 we study the prior correlations in the outputs of the LeNet-5 convolutional network (LeCun et al., 1998) on objects of different MNIST classes. We sample networks with weights p(w) = N(0, α²I), compute the values of the logits corresponding to the first class for all pairs of images, and compute the correlations of these logits. For all levels of α, the correlations between objects belonging to the same class are consistently higher than the correlations between objects of different classes, showing that the network induces a reasonable prior similarity metric over these images. Additionally, we observe that the prior correlations somewhat decrease as we increase α, showing that bounding the norm of the weights has some minor utility, in accordance with Section 5.1. Similarly, in panel (d) we see that the NLL significantly decreases as α increases in [0, 0.5], then slightly increases, but is relatively constant thereafter.
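A minimal sketch of this kind of procedure is below (PyTorch); `net_builder` is a hypothetical constructor for the architecture of interest (e.g., LeNet-5), and the number of weight samples and the value of α are illustrative.

```python
import torch

def prior_logit_correlations(net_builder, images_a, images_b, alpha=0.1, n_samples=100):
    """Pearson correlation, over weight samples w ~ N(0, alpha^2 I), between the
    first-class logits of paired images in images_a and images_b."""
    logits_a, logits_b = [], []
    for s in range(n_samples):
        torch.manual_seed(s)
        net = net_builder()
        with torch.no_grad():
            for p in net.parameters():
                p.normal_(0.0, alpha)              # sample weights from the prior
            logits_a.append(net(images_a)[:, 0])   # logit for the first class
            logits_b.append(net(images_b)[:, 0])
    A = torch.stack(logits_a)                      # [n_samples, n_pairs]
    B = torch.stack(logits_b)
    A, B = A - A.mean(0), B - B.mean(0)
    return (A * B).sum(0) / (A.norm(dim=0) * B.norm(dim=0))
```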

In the Appendix, we further describe analytic results and illustrate the effect of α on sample functions.

5.3. Effect of Prior Variance on CIFAR-10

We further study the effect of the parameter prior standard deviation α, measuring the performance of approximate Bayesian inference on CIFAR-10 with a Preactivation ResNet-20 (He et al., 2016) and VGG-16 (Simonyan & Zisserman, 2014). For each of these architectures we run SWAG (Maddox et al., 2019) with fixed hyperparameters and varying α. We report the results in Figure 12(d), (h). For both architectures, the performance is near-optimal in the range α ∈ [10⁻², 10⁻¹]. Smaller α constrains the weights too much. Performance is reasonable, and becomes mostly insensitive to α as it continues to increase, due to the inductive biases of the functional form of the neural network.


6. Rethinking Generalization

Zhang et al. (2016) demonstrated that deep neural networks have sufficient capacity to fit randomized labels on popular image classification tasks, and suggest this result requires re-thinking generalization to understand deep learning.

We argue, however, that this behaviour is not puzzling from a probabilistic perspective, is not unique to neural networks, and cannot be used as evidence against Bayesian neural networks (BNNs) with vague parameter priors. Fundamentally, the resolution is the view presented in the introduction: from a probabilistic perspective, generalization is at least a two-dimensional concept, related to support (flexibility), which should be as large as possible, supporting even noisy solutions, and inductive biases, which represent relative prior probabilities of solutions.

Indeed, we demonstrate that the behaviour in Zhang et al. (2016) that was treated as mysterious and specific to neural networks can be exactly reproduced by Gaussian processes (GPs). Gaussian processes are an ideal choice for this experiment because they are popular Bayesian non-parametric models, and they assign a prior directly in function space. Moreover, GPs have remarkable flexibility, providing universal approximation with popular covariance functions such as the RBF kernel. Yet the functions that are a priori likely under a GP with an RBF kernel are relatively simple. We describe GPs further in the Appendix, and Rasmussen & Williams (2006) provides an extensive introduction.

We start with a simple example to illustrate the ability of a GP with an RBF kernel to easily fit a corrupted dataset, yet generalize well on a non-corrupted dataset, in Figure 7. In Fig. 7(a) we have sample functions from a GP prior over functions p(f(x)), showing that likely functions under the prior are smooth and well-behaved. In Fig. 7(b) we see the GP is able to reasonably fit data from a structured function. And in Fig. 7(c) the GP is also able to fit highly corrupted data with essentially no structure; although these data are not a likely draw from the prior, the GP has support for a wide range of solutions, including noise.
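This one-dimensional illustration can be reproduced in a few lines with any standard GP library; below is a sketch using scikit-learn with illustrative length-scale, noise, and dataset sizes (not the exact settings behind Figure 7).

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel

rng = np.random.RandomState(0)

# Structured data: y = sin(2*pi*x) + Gaussian noise.
x_clean = rng.uniform(0.0, 1.0, 30)
y_clean = np.sin(2 * np.pi * x_clean) + 0.2 * rng.randn(30)

# Corrupted points with essentially no structure.
x_corrupt = rng.uniform(0.0, 1.0, 30)
y_corrupt = rng.uniform(-1.0, 1.0, 30)

X = np.concatenate([x_clean, x_corrupt]).reshape(-1, 1)
y = np.concatenate([y_clean, y_corrupt])

kernel = 1.0 * RBF(length_scale=0.2) + WhiteKernel(noise_level=0.05)
gp = GaussianProcessRegressor(kernel=kernel).fit(X, y)
mean, std = gp.predict(np.linspace(0, 1, 200).reshape(-1, 1), return_std=True)
```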

We next show that GPs can replicate the generalization behaviour described in Zhang et al. (2016) (experimental details in the Appendix). When applied to CIFAR-10 images with random labels, Gaussian processes achieve 100% train accuracy and 10.4% test accuracy (at the level of random guessing). However, the same model trained on the true labels achieves a training accuracy of 72.8% and a test accuracy of 54.3%. Thus, the generalization behaviour described in Zhang et al. (2016) is not unique to neural networks, and can be described by separately understanding the support and the inductive biases of a model.

Indeed, although Gaussian processes support CIFAR-10 images with random labels, these images are not likely under the GP prior. In Fig. 7(d), we compute the approximate GP marginal likelihood on a binary CIFAR-10 classification problem, with labels of varying levels of corruption. We see that as the noise in the data increases, the approximate marginal likelihood, and thus the prior support for these data, decreases. In Fig. 7(e), we see a similar trend for a Bayesian neural network. Again, as the fraction of corrupted labels increases, the approximate marginal likelihood decreases, showing that the prior over functions given by the Bayesian neural network has less support for these noisy datasets. We provide further experimental details in the Appendix.

Dziugaite & Roy (2017) and Smith & Le (2018) provide complementary perspectives on Zhang et al. (2016) for MNIST: Dziugaite & Roy (2017) show non-vacuous PAC-Bayes bounds for noise-free binary MNIST, but not noisy MNIST, and Smith & Le (2018) show that logistic regression can fit noisy labels on subsampled MNIST, interpreting the results from an Occam factor perspective.

7. Double Descent

Double descent (e.g., Belkin et al., 2019) describes generalization error that decreases, increases, and then again decreases with increases in model flexibility. The initial decrease and subsequent increase is referred to as the classical regime: models with increasing flexibility are increasingly able to capture structure and perform better, until they begin to overfit. The next regime is referred to as the modern interpolating regime. The existence of the interpolation regime has been presented as mysterious generalization behaviour in deep learning.

However, our perspective of generalization suggests that performance should monotonically improve as we increase model flexibility when we use Bayesian model averaging with a reasonable prior. Indeed, in the opening example of Figure 1, we would in principle want to use the most flexible possible model. Our results in Section 5 show that standard BNN priors induce structured and useful priors in function space, so we should not expect double descent in Bayesian deep learning models that perform reasonable marginalization.

To test this hypothesis, we evaluate MultiSWAG, SWAG, and standard SGD with ResNet-18 models of varying width, following Nakkiran et al. (2019), measuring both error and negative log likelihood (NLL). For details, see Appendix D. We present the results in Figure 8 and Appendix Figure 13.

First, we observe that models trained with SGD indeed suffer from double descent, especially when the train labels are partially corrupted (see panels 8(c), 8(d)). We also see that SWAG, a unimodal posterior approximation, reduces the extent of double descent.


[Figure 8 plot panels: test error and test NLL versus ResNet-18 width for SGD, SWAG, and MultiSWAG, with true and partially corrupted labels; panel (e) shows MultiSWAG test error for varying numbers of SWAG models (1, 3, 5, 10).]

Figure 8. Bayesian model averaging alleviates double descent. (a) Test error and (b) NLL loss for ResNet-18 with varying width on CIFAR-100, for SGD, SWAG, and MultiSWAG. (c) Test error and (d) NLL loss when 20% of the labels are randomly reshuffled. SWAG reduces double descent, and MultiSWAG, which marginalizes over multiple modes, entirely alleviates double descent, both on the original labels and under label noise, in both accuracy and NLL. (e) Test error for MultiSWAG with a varying number of independent SWAG models; error monotonically decreases with an increased number of independent models, alleviating double descent. We also note that MultiSWAG provides significant improvements in accuracy and NLL over SGD and SWAG models. See Appendix Figure 13 for additional results.

Moreover, MultiSWAG, which performs a more exhaustive multimodal Bayesian model average, completely mitigates double descent: the performance of MultiSWAG solutions increases monotonically with the size of the model, showing no double descent even under significant label corruption, for both accuracy and NLL.

Our results highlight the importance of marginalization over multiple modes of the posterior: under 20% label corruption, SWAG clearly suffers from double descent, while MultiSWAG does not. In Figure 8(e) we show how double descent is alleviated as the number of independent modes marginalized in MultiSWAG increases.

These results also clearly show that MultiSWAG provides significant improvements in accuracy over both SGD and SWAG models, in addition to NLL, an often overlooked advantage of Bayesian model averaging that we discuss in Section 3.1.
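Once per-model predictive probabilities are available, the multimodal model average itself is simple to compute; the sketch below (NumPy) averages predictive distributions across independently trained models (or SWAG samples) and evaluates the test NLL. The array names are hypothetical.

```python
import numpy as np

def bma_nll(prob_list, labels):
    """prob_list: list of [n_test, n_classes] softmax outputs, one per model.
    Averages the predictive distributions and returns the test NLL."""
    probs = np.mean(np.stack(prob_list, axis=0), axis=0)
    eps = 1e-12  # guard against log(0)
    return -np.mean(np.log(probs[np.arange(len(labels)), labels] + eps))
```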

Recently, Nakkiran et al. (2020) show that carefully tuned ℓ2 regularization can help mitigate double descent. Alternatively, we show that Bayesian model averaging, particularly based on multimodal marginalization, can mitigate prominent double descent behaviour. The perspective in Sections 1 and 3 predicts this result: models with reasonable priors and effective Bayesian model averaging should monotonically improve with increases in flexibility.

8. Temperature Scaling

The standard Bayesian posterior distribution is given by

p(w|D) = (1/Z) p(D|w) p(w),   (2)

where p(D|w) is a likelihood, p(w) is a prior, and Z is a normalizing constant.

In Bayesian deep learning it is typical to consider the tempered posterior

p_T(w|D) = (1/Z(T)) p(D|w)^(1/T) p(w),   (3)

where T is a temperature parameter, and Z(T) is the normalizing constant corresponding to temperature T. The temperature parameter controls how the prior and likelihood interact in the posterior:

• T < 1 corresponds to cold posteriors, where the posterior distribution is more concentrated around solutions with high likelihood.

• T = 1 corresponds to the standard Bayesian posterior distribution.

• T > 1 corresponds to warm posteriors, where the prior effect is stronger and the posterior collapse is slower.
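In practice, tempering only changes the unnormalized log-posterior that an approximate inference method targets. A minimal sketch, assuming a model and a function returning the dataset negative log likelihood (both hypothetical here), is:

```python
import torch

def tempered_log_posterior(model, dataset_nll_fn, T=0.1, prior_std=0.1):
    """Unnormalized log p_T(w | D) = (1/T) * log p(D | w) + log p(w),
    with p(w) = N(0, prior_std^2 I). dataset_nll_fn(model) should return the
    negative log likelihood summed over the dataset (or an unbiased minibatch
    estimate scaled to the full dataset size)."""
    log_likelihood = -dataset_nll_fn(model)
    log_prior = sum(-0.5 * (p ** 2).sum() / prior_std ** 2
                    for p in model.parameters())
    return log_likelihood / T + log_prior
```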

Tempering posteriors is a well-known practice in statistics, where it goes by the names Safe Bayes, generalized Bayesian inference, and fractional Bayesian inference (e.g., de Heide et al., 2019; Grunwald et al., 2017; Barron & Cover, 1991; Walker & Hjort, 2001; Zhang, 2006; Bissiri et al., 2016; Grunwald, 2012). Safe Bayes has been shown to be natural from a variety of perspectives, including from prequential learning theory and minimum description length frameworks (e.g., Grunwald et al., 2017).

Concurrently with our work, Wenzel et al. (2020) noticed that successful Bayesian deep learning methods tend to use cold posteriors. They provide an empirical study showing that Bayesian neural networks (BNNs) with cold posteriors outperform models with SGD-based maximum likelihood training, while BNNs with T = 1 can perform worse than the maximum likelihood solution.


[Figure 9 plot panels: average class probabilities over CIFAR-10 for two prior sample functions and the prior predictive distribution, at α = 0.01, 0.1, 0.3, 1, and √10; panels (f) and (g) show NLL and classification error versus prior std α for PreResNet-20 and VGG-16, with weight decay baselines.]

Figure 9. Effects of the prior variance α². (a)-(e) Average class probabilities over all of CIFAR-10 for two sample prior functions p(f(x; w)) (two top rows) and the predictive distribution (average over 200 samples of weights, bottom row), for varying settings of α in p(w) = N(0, α²I). (f) NLL and (g) classification error of an ensemble of 20 SWAG samples on CIFAR-10 as a function of α, using a Preactivation ResNet-20 and VGG-16. The NLL is high for overly small α and near-optimal in the range [0.1, 0.3]. The NLL remains relatively low for vague priors corresponding to large values of α.

[Figure 10 plot panels: class probabilities for two sample functions under (a) the prior (α = √10), and the posterior after observing (b) 10, (c) 100, and (d) 1000 datapoints.]

Figure 10. Adaptivity of posterior variance with data. We sample two functions f(x; w) from the distribution over functions induced by a distribution over weights, starting with the prior p(w) = N(0, 10·I), in combination with a PreResNet-20. We measure class probabilities averaged across the CIFAR-10 test set, as we vary the amount of available training data. Although the prior variance is too large, such that the softmax saturates for logits sampled from the prior, leading to one class being favoured, we see that the posterior quickly adapts to correct the scale of the logits in the presence of data. In Figure 9 we also show that the prior variance can easily be calibrated such that the prior predictive distribution, even before observing data, is high entropy.


They claim that cold posteriors sharply deviate from the Bayesian paradigm, and consider possible reasons for why tempering is helpful in Bayesian deep learning.

In this section we provide an alternative view, and argue that tempering is not at odds with Bayesian principles. Moreover, for virtually any realistic model class and dataset, it would be highly surprising if T = 1 were in fact the best setting of this hyperparameter. Indeed, as long as it is practically convenient, we would advocate tempering for essentially any model, especially parametric models that do not scale their capacity automatically with the amount of available information. Our position is that, at a high level, Bayesian methods are trying to combine honest beliefs with data to form a posterior. By reflecting the belief that the model is misspecified, the tempered posterior is often more of a true posterior than the posterior that results from ignoring our belief that the model is misspecified.

Finding that T < 1 helps for Bayesian neural networks is neither surprising nor discouraging. And the actual results of the experiments in Wenzel et al. (2020), which show great improvements over standard SGD training, are in fact very encouraging of deriving inspiration from Bayesian procedures in deep learning.

We consider: (1) tempering under misspecification (Section 8.1); (2) tempering in terms of overcounting data (Section 8.2); (3) how tempering compares to changing the observation model (Section 8.3); (4) the effect of the prior in relation to the experiments of Wenzel et al. (2020) (Section 8.4); (5) the effect of approximate inference, including how tempering can help in efficiently estimating parameters even for the untempered posterior (Section 8.5).

8.1. Tempering Helps with Misspecified Models

Many works explain how tempered posteriors help under model misspecification (e.g., de Heide et al., 2019; Grunwald et al., 2017; Barron & Cover, 1991; Walker & Hjort, 2001; Zhang, 2006; Bissiri et al., 2016; Grunwald, 2012). In fact, de Heide et al. (2019) and Grunwald et al. (2017) provide several simple examples where Bayesian inference fails to provide good convergence behaviour for untempered posteriors. While it is easier to show theoretical results for T > 1, several of these works also show that T < 1 can be preferred, even in well-specified settings, and indeed recommend learning T from data, for example by cross-validation (e.g., Grunwald, 2012; de Heide et al., 2019).

Are we in a misspecified setting for Bayesian neural networks? Of course. And it would be irrational to proceed as if it were otherwise. Every model is misspecified. In the context of Bayesian neural networks specifically, the mass of solutions expressed by the prior, outside of the datasets we typically consider, is likely much larger than desired for most applications. We can calibrate for this discrepancy through tempering. The resulting tempered posterior will be more in line with our beliefs than pretending the model is not misspecified and finding the untempered posterior.

Non-parametric models such as Gaussian processes attempt to side-step model misspecification by growing the number of free parameters (information capacity) automatically with the amount of available data. In parametric models, we take much more of a manual guess about the model capacity. In the case of deep neural networks, this choice is not even close to a best guess: it was once the case that architectural design was a large component of works involving neural networks, but now it is more standard practice to choose an off-the-shelf architecture, without much consideration of model capacity. We do not believe that knowingly using a misspecified model to find a posterior is more reasonable (or Bayesian) than honestly reflecting the belief that the model is misspecified and then using a tempered posterior. For parametric models such as neural networks, it is to be expected that the capacity is particularly misspecified.

8.2. Overcounting Data with Cold Posteriors

The criticism of cold posteriors raised by Wenzel et al. (2020) is largely based on the fact that decreasing temperature leads to overcounting data in the posterior distribution.

However, a similar argument can be made against marginal likelihood maximization (also known as empirical Bayes, or type-II maximum likelihood). Indeed, here the prior will depend on the same data as the likelihood, which can lead to miscalibrated predictive distributions (Darnieder, 2011).

Nonetheless, empirical Bayes has been embraced and widely adopted in Bayesian machine learning (e.g., Bishop, 2006; Rasmussen & Williams, 2006; MacKay, 2003; Minka, 2001b) as embodying several Bayesian principles. Empirical Bayes has been particularly embraced in seminal work on Bayesian neural networks (e.g., MacKay, 1992; 1995), where it has been proposed as a principled approach to automatically learning hyperparameters, such as the scale of the variance for the prior over weights, embodying Occam's razor. While there is in this case some deviation from the fully Bayesian paradigm, the procedure, which depends on marginalization, is nonetheless clearly inspired by Bayesian thinking; it is thus helpful to reflect this inspiration and provide an understanding of how it works from a Bayesian perspective.

There is also work showing that the marginal likelihood can lead to miscalibrated Bayes factors under model misspecification. Attempts to calibrate these factors (Xu et al., 2019) as part of the Bayesian paradigm are highly reminiscent of work on Safe Bayes.


8.3. Tempered Posterior or Different Likelihood

The tempered posterior for one model is an untempered posterior using a different observation model. In other words, we can trivially change the likelihood function in our model so that the standard Bayesian posterior in the new model is equal to the posterior of temperature T in the original model. Indeed, consider the likelihood function

p_T(D|w) ∝ p(D|w)^(1/T),   (4)

where the posterior distribution for the model M_T with likelihood p_T is exactly the temperature T posterior for the model M with likelihood p.

The predictive distributions differ for the two models: even though the posteriors coincide, the likelihoods for a new datapoint y* are different,

∫ p(y*|w) p(w|D) dw ≠ ∫ p_T(y*|w) p(w|D) dw.   (5)

As an example, consider a regression model M with a Gaussian likelihood function, y ∼ N(f, σ²), where f is the output of the network and σ² is the noise variance. The predictive distributions for the two models M and M_T for a new input x will have different variances, but the same mean: E_w[f + ε] = E_w[f]. Moreover, in this case the noise variance would typically be learned in either model, M_T or M.
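For this Gaussian case the equivalence can be made explicit; as a small worked step (keeping only the w-dependent factors),

p(y|w)^(1/T) ∝ exp(−(y − f(x; w))² / (2Tσ²)) ∝ N(y | f(x; w), Tσ²),

so tempering the Gaussian likelihood corresponds to rescaling the noise variance from σ² to Tσ² (shrinking it for cold posteriors, T < 1), which changes the predictive variance but not the predictive mean, consistent with the discussion above.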

Section 4.1 of Grunwald et al. (2017) considers a related construction.

8.4. Effect of the Prior

While a somewhat misspecified prior will certainly interact with the utility of tempering, we do not believe the experiments in Wenzel et al. (2020) provide evidence that even the prior p(w) = N(0, I) is misspecified to any serious extent. For a relatively wide range of distributions over w, the functional form of the network f(x; w) can produce a generally reasonable distribution over functions p(f(x; w)). In Figure 10 we reproduce the finding in Wenzel et al. (2020) that sample functions p(f(x; w)) corresponding to the prior p(w) = N(0, 10·I) strongly favour a single class over the dataset. While this behaviour appears superficially dramatic, we note it is simply an artifact of having a miscalibrated signal variance. A miscalibrated signal variance interacts with a quickly saturating softmax link function to provide a seemingly dramatic preference for a given class. If we instead use p(w) = N(0, α²I), for quite a range of α, then sample functions provide reasonably high entropy across labels, as in Figure 9. The value of α can easily be determined through cross-validation, as in Figure 9, or specified as a standard value used for L2 regularization (α = 0.24 in this case).

However, even with the inappropriate prior scale, we see in the bottom row of panels (a)-(e) of Figure 9 that the unconditional predictive distribution is completely reasonable. Moreover, the prior variance represents a soft prior bias, and will quickly update in the presence of data. In Figure 10 we show the posterior samples after observing 10, 100, and 1000 data points.

Other aspects of the prior, outside of the prior signal variance, will have a much greater effect on the inductive biases of the model. For example, the induced covariance function cov(f(x_i; w), f(x_j; w)) reflects the induced similarity metric over data instances; through the covariance function we can answer, for instance, whether the model believes a priori that a translated image is similar to the original. Unlike the signal variance of the prior, the prior covariance function will continue to have a significant effect on posterior inference for even very large datasets, and strongly reflects the structural properties of the neural network. We explore these structures of the prior in Figure 12.

8.5. The Effect of Inexact Inference

We have to keep in mind what we ultimately use posterior samples to compute. Ultimately, we wish to estimate the predictive distribution given by the integral in Equation (1). With a finite number of samples, the tempered posterior could be used to provide a better approximation to the expectation of the predictive distribution associated with the untempered posterior.

Consider a simple example, where we wish to estimate the mean of a high-dimensional Gaussian distribution N(0, I). Suppose we use J independent samples. The mean of these samples is also Gaussian distributed, µ ∼ N(0, (1/J) I). In Bayesian deep learning, the dimension d is typically on the order of 10⁷, and J would be on the order of 10. The norm of µ would be highly concentrated around √(10⁷)/√10 = 1000. In this case, sampling from a tempered posterior with T < 1 would lead to a better approximation of the Bayesian model average associated with an untempered posterior.
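This concentration is easy to check numerically; a small sketch (NumPy, with a reduced dimension to keep memory modest) is below.

```python
import numpy as np

rng = np.random.default_rng(0)
d, J = 1_000_000, 10                      # dimension and number of samples

samples = rng.standard_normal((J, d))     # "posterior" samples from N(0, I)
mu_hat = samples.mean(axis=0)             # simple Monte Carlo estimate of the mean

# ||mu_hat|| concentrates around sqrt(d / J); here sqrt(1e6 / 10) ~ 316.
# With d = 1e7, as in the text, the norm concentrates near 1000.
print(np.linalg.norm(mu_hat))
```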

Furthermore, no current sampling procedure will provide samples that are close to independent samples from the true posterior of a Bayesian neural network. The posterior landscape is far too multimodal and complex for there to be any reasonable coverage. The approximations we have are practically useful, and often preferable to conventional training, but we cannot realistically proceed with analysis assuming that we have obtained true samples from a posterior. While we would expect that some value of T ≠ 1 would be preferred for any finite dataset in practice, it is conceivable that some of the results in Wenzel et al. (2020) may be affected by the specifics of the approximate inference technique being used.


We should be wary not to view Bayesian model averaging purely through the prism of simple Monte Carlo, as advised in Section 3.2. Given a finite computational budget, our goal in effectively approximating a Bayesian model average is not equivalent to obtaining good samples from the posterior.

9. Discussion

"It is now common practice for Bayesians to fit models that have more parameters than the number of data points. Incorporate every imaginable possibility into the model space: for example, if it is conceivable that a very simple model might be able to explain the data, one should include simple models; if the noise might have a long-tailed distribution, one should include a hyperparameter which controls the heaviness of the tails of the distribution; if an input variable might be irrelevant to a regression, include it in the regression anyway." MacKay (1995)

We have presented a probabilistic perspective of generalization, which depends on the support and inductive biases of the model. The support should be as large as possible, but the inductive biases must be well-calibrated to a given problem class. We argue that Bayesian neural networks embody these properties, and, through the lens of probabilistic inference, explain generalization behaviour that has previously been viewed as mysterious. Moreover, we argue that Bayesian marginalization is particularly compelling for neural networks, show how deep ensembles provide a practical mechanism for marginalization, and propose a new approach that generalizes deep ensembles to marginalize within basins of attraction. We show that this multimodal approach to Bayesian model averaging, MultiSWAG, can entirely alleviate double descent, to enable monotonic performance improvements with increases in model flexibility, as well as significant improvements in generalization accuracy and log likelihood over SGD and single-basin marginalization.

There are certainly many challenges to estimating the integral for a Bayesian model average in modern deep learning, including a high-dimensional parameter space and a complex posterior landscape. But viewing the challenge indeed as an integration problem, rather than an attempt to obtain posterior samples for a simple Monte Carlo approximation, provides opportunities for future progress. Bayesian deep learning has been making fast practical advances, with approaches that now enable better accuracy and calibration over standard training with minimal overhead.

We finish with remarks about future developments for Bayesian neural network priors, and approaches to research in Bayesian deep learning.

9.1. The Future for BNN Priors

We provide some brief remarks about future developments for BNN priors. Here we have explored relatively simple parameter priors p(w) = N(0, α²I). While these priors are simple in parameter space, they interact with the neural network architecture to induce a sophisticated prior over functions p(f(x; w)), with many desirable properties, including a reasonable correlation structure over images. However, these parameter priors can certainly still be improved. As we have seen, even tuning the value of the signal variance α², an analogue of the L2 regularization often used in deep learning, can have a noticeable effect on the induced prior over functions, though this effect is quickly modulated by data. Layer-wise priors, such that parameters in each layer have a different signal variance, are intuitive: we would expect later layers to require precise determination, while parameters in earlier layers could reasonably take a range of values. But one has to be cautious; as we show in Appendix Section E, with ReLU activations, different signal variances in different layers can be degenerate, combining together to affect only the output scale of the network.

A currently popular sentiment is that we should directly build function-space BNN priors, often taking inspiration from Gaussian processes. While we believe this is a promising direction, one should proceed with caution. If we contrive priors over parameters p(w) to induce distributions over functions p(f) that resemble familiar models, such as Gaussian processes with RBF kernels, we could be throwing the baby out with the bathwater. Neural networks are useful as their own model class precisely because they have different inductive biases from other models.

A similar concern applies to taking infinite-width limits in Bayesian neural networks. In these cases we recover Gaussian processes with interpretable kernel functions; because these models are easier to use and analyze, and give rise to interpretable and well-motivated priors, it is tempting to treat them as drop-in replacements for their parametric analogues. However, the kernels for these models are fixed. In order for a model to do effective representation learning, we must learn a similarity metric for the data. Training a neural network is, in many ways, like learning a kernel rather than using a fixed kernel. MacKay (1998) has also expressed concerns about treating these limits as replacements for neural networks, due to the loss of representation learning power.

Perhaps the distribution over functions induced by a network in combination with a generic distribution over parameters p(w) may be hard to interpret, but this distribution will contain the equivariance properties, representation learning abilities, and other biases that make neural networks a compelling model class in their own right.


9.2. "But is it really Bayesian?"

We finish with an editorial comment about approaches to research within Bayesian deep learning. There is sometimes a tendency to classify work as Bayesian or not Bayesian, with very stringent criteria for what qualifies as Bayesian. Moreover, the implication, and sometimes even explicit recommendation, is that if an approach is not unequivocally Bayesian in every respect, then we should not term it Bayesian, and we should instead attempt to understand the procedure through entirely different, non-Bayesian mechanisms. We believe this mentality encourages tribalism, which is not conducive to the best research or to creating the best performing methods. What matters is not a debate about semantics, but making rational modelling choices given a particular problem setting, and trying to understand these choices. Often these choices can largely be inspired by a Bayesian approach, in which case it is desirable to indicate this source of inspiration. And in the semantics debate, who would be the arbiter of what gets to be called Bayesian? Arguably it ought to be an evolving definition.

Broadly speaking, what makes Bayesian approaches distinctive is a posterior-weighted marginalization over parameters. And at a high level, Bayesian methods are about combining our honest beliefs with data to form a posterior. In actuality, no fair-minded researcher entirely believes the prior over parameters, the functional form of the model (which is part of the prior over functions), or the likelihood. From this perspective, it is broadly compatible with a Bayesian philosophy to reflect misspecification in the modelling procedure itself, which is achieved through tempering. In this sense, the tempered posterior is more reflective of a true posterior than the posterior that results from ignoring our belief that the model is misspecified.

Moreover, basic probability theory indicates that marginalization is desirable. While marginalization cannot in practice be achieved exactly, we can try to improve over conventional training, which, as we have discussed, can be viewed as approximate marginalization. Given computational constraints, effective marginalization is not equivalent to obtaining accurate samples from a posterior. As we have discussed, simple Monte Carlo is only one of many mechanisms for marginalization. Just as expectation propagation (Minka, 2001a) focuses its approximation on the factors in a posterior where it will most affect the end result, we should focus on representing the posterior where it will make the biggest difference to the model average. As we have shown, deep ensembles are a reasonable mechanism up to a point. After having trained many independent models, there are added benefits to marginalizing within basins, given the computational expense associated with retraining an additional model to find an additional basin of attraction.

We should also not hold Bayesian methods to a double standard. Indeed, it can be hard to interpret or understand the prior, the posterior, and whether the marginalization procedure is optimal. But it is also hard to interpret the choices behind the functional form of the model, or the rationale behind classical procedures where we bet everything on a single global optimum, when we know there are many global optima, that many of them will perform well but provide different solutions, and that many others will not perform well. We should apply the same level of scrutiny to all modelling choices, consider the alternatives, and not be paralyzed if a procedure is not optimal in every respect.

Acknowledgements. AGW and PI are supported by an Amazon Research Award, Facebook Research, NSF I-DISRE 193471, NIH R01 DA048764-01A1, NSF IIS-1563887, and NSF IIS-1910266. We thank Greg Benton for helpful discussions.

References

Achille, A. and Soatto, S. Emergence of invariance and disentanglement in deep representations. The Journal of Machine Learning Research, 19(1):1947-1980, 2018.

Ashukha, A., Lyzhov, A., Molchanov, D., and Vetrov, D. Pitfalls of in-domain uncertainty estimation and ensembling in deep learning. arXiv preprint arXiv:2002.06470, 2020.

Barron, A. R. and Cover, T. M. Minimum complexity density estimation. IEEE Transactions on Information Theory, 37(4):1034-1054, 1991.

Beal, M. J. Variational algorithms for approximate Bayesian inference. PhD thesis, University of London, 2003.

Belkin, M., Hsu, D., Ma, S., and Mandal, S. Reconciling modern machine-learning practice and the classical bias-variance trade-off. Proceedings of the National Academy of Sciences, 116(32):15849-15854, 2019.

Bishop, C. M. Pattern Recognition and Machine Learning. Springer, 2006.

Bissiri, P. G., Holmes, C. C., and Walker, S. G. A general framework for updating belief distributions. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 78(5):1103-1130, 2016.

Blundell, C., Cornebise, J., Kavukcuoglu, K., and Wierstra, D. Weight uncertainty in neural networks. arXiv preprint arXiv:1505.05424, 2015.

Box, G. E. and Tiao, G. C. Bayesian Inference in Statistical Analysis. Addison-Wesley, Reading, MA, 1973.


Chaudhari, P., Choromanska, A., Soatto, S., LeCun, Y., Baldassi, C., Borgs, C., Chayes, J., Sagun, L., and Zecchina, R. Entropy-SGD: Biasing gradient descent into wide valleys. Journal of Statistical Mechanics: Theory and Experiment, 2019.

Cobb, A. D., Baydin, A. G., Markham, A., and Roberts, S. J. Introducing an explicit symplectic integration scheme for Riemannian manifold Hamiltonian Monte Carlo. arXiv preprint arXiv:1910.06243, 2019.

Darnieder, W. F. Bayesian methods for data-dependent priors. PhD thesis, The Ohio State University, 2011.

de Heide, R., Kirichenko, A., Mehta, N., and Grunwald, P. Safe-Bayesian generalized linear regression. arXiv preprint arXiv:1910.09227, 2019.

Dinh, L., Pascanu, R., Bengio, S., and Bengio, Y. Sharp minima can generalize for deep nets. arXiv preprint arXiv:1703.04933, 2017.

Draxler, F., Veschgini, K., Salmhofer, M., and Hamprecht, F. A. Essentially no barriers in neural network energy landscape. arXiv preprint arXiv:1803.00885, 2018.

Dziugaite, G. K. and Roy, D. M. Computing nonvacuous generalization bounds for deep (stochastic) neural networks with many more parameters than training data. arXiv preprint arXiv:1703.11008, 2017.

Fort, S. and Jastrzebski, S. Large scale structure of neural network loss landscapes. In Advances in Neural Information Processing Systems, pp. 6706-6714, 2019.

Fort, S., Hu, H., and Lakshminarayanan, B. Deep ensembles: A loss landscape perspective. arXiv preprint arXiv:1912.02757, 2019.

Gal, Y. and Ghahramani, Z. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. In International Conference on Machine Learning, pp. 1050-1059, 2016.

Gardner, J. R., Pleiss, G., Bindel, D., Weinberger, K. Q., and Wilson, A. G. GPyTorch: Blackbox matrix-matrix Gaussian process inference with GPU acceleration. In Neural Information Processing Systems, 2018.

Garipov, T., Izmailov, P., Podoprikhin, D., Vetrov, D. P., and Wilson, A. G. Loss surfaces, mode connectivity, and fast ensembling of DNNs. In Neural Information Processing Systems, 2018.

Grunwald, P. The safe Bayesian. In International Conference on Algorithmic Learning Theory, pp. 169-183. Springer, 2012.

Grunwald, P., Van Ommen, T., et al. Inconsistency of Bayesian inference for misspecified linear models, and a proposal for repairing it. Bayesian Analysis, 12(4):1069-1103, 2017.

Guedj, B. A primer on PAC-Bayesian learning. arXiv preprint arXiv:1901.05353, 2019.

Guo, C., Pleiss, G., Sun, Y., and Weinberger, K. Q. On calibration of modern neural networks. In Proceedings of the 34th International Conference on Machine Learning - Volume 70, pp. 1321-1330. JMLR.org, 2017.

Gustafsson, F. K., Danelljan, M., and Schon, T. B. Evaluating scalable Bayesian deep learning methods for robust computer vision. arXiv preprint arXiv:1906.01620, 2019.

Hafner, D., Tran, D., Irpan, A., Lillicrap, T., and Davidson, J. Reliable uncertainty estimates in deep neural networks using noise contrastive priors. arXiv preprint arXiv:1807.09289, 2018.

He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770-778, 2016.

Hendrycks, D. and Dietterich, T. Benchmarking neural network robustness to common corruptions and perturbations. arXiv preprint arXiv:1903.12261, 2019.

Hinton, G. E. and Van Camp, D. Keeping the neural networks simple by minimizing the description length of the weights. In Proceedings of the Sixth Annual Conference on Computational Learning Theory, pp. 5-13, 1993.

Hochreiter, S. and Schmidhuber, J. Flat minima. Neural Computation, 9(1):1-42, 1997.

Huang, W. R., Emam, Z., Goldblum, M., Fowl, L., Terry, J. K., Huang, F., and Goldstein, T. Understanding generalization through visualizations. arXiv preprint arXiv:1906.03291, 2019.

Izmailov, P., Podoprikhin, D., Garipov, T., Vetrov, D., and Wilson, A. G. Averaging weights leads to wider optima and better generalization. In Uncertainty in Artificial Intelligence (UAI), 2018.

Izmailov, P., Maddox, W. J., Kirichenko, P., Garipov, T., Vetrov, D., and Wilson, A. G. Subspace inference for Bayesian deep learning. In Uncertainty in Artificial Intelligence, 2019.

Jacot, A., Gabriel, F., and Hongler, C. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in Neural Information Processing Systems, pp. 8571-8580, 2018.


Jiang, Y., Neyshabur, B., Mobahi, H., Krishnan, D., and Bengio, S. Fantastic generalization measures and where to find them. arXiv preprint arXiv:1912.02178, 2019.

Kass, R. E. and Raftery, A. E. Bayes factors. Journal of the American Statistical Association, 90(430):773-795, 1995.

Kendall, A. and Gal, Y. What uncertainties do we need in Bayesian deep learning for computer vision? In Advances in Neural Information Processing Systems, pp. 5574-5584, 2017.

Keskar, N. S., Mudigere, D., Nocedal, J., Smelyanskiy, M., and Tang, P. T. P. On large-batch training for deep learning: Generalization gap and sharp minima. arXiv preprint arXiv:1609.04836, 2016.

Khan, M. E., Nielsen, D., Tangkaratt, V., Lin, W., Gal, Y., and Srivastava, A. Fast and scalable Bayesian deep learning by weight-perturbation in Adam. arXiv preprint arXiv:1806.04854, 2018.

Krizhevsky, A., Nair, V., and Hinton, G. The CIFAR-10 dataset, 2014. http://www.cs.toronto.edu/~kriz/cifar.html.

Lakshminarayanan, B., Pritzel, A., and Blundell, C. Simple and scalable predictive uncertainty estimation using deep ensembles. In Advances in Neural Information Processing Systems, pp. 6402-6413, 2017.

Langford, J. and Caruana, R. (Not) bounding the true error. In Advances in Neural Information Processing Systems, pp. 809-816, 2002.

LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278-2324, 1998.

Louizos, C., Shi, X., Schutte, K., and Welling, M. The functional neural process. In Advances in Neural Information Processing Systems, 2019.

MacKay, D. J. Bayesian methods for adaptive models. PhD thesis, California Institute of Technology, 1992.

MacKay, D. J. Probable networks and plausible predictions: a review of practical Bayesian methods for supervised neural networks. Network: Computation in Neural Systems, 6(3):469-505, 1995.

MacKay, D. J. Introduction to Gaussian processes. In Bishop, C. M. (ed.), Neural Networks and Machine Learning, chapter 11, pp. 133-165. Springer-Verlag, 1998.

MacKay, D. J. Information Theory, Inference and Learning Algorithms. Cambridge University Press, 2003.

Maddox, W. J., Izmailov, P., Garipov, T., Vetrov, D. P., and Wilson, A. G. A simple baseline for Bayesian uncertainty in deep learning. In Advances in Neural Information Processing Systems, 2019.

Masegosa, A. R. Learning under model misspecification: Applications to variational and ensemble methods. 2019.

McAllester, D. A. PAC-Bayesian model averaging. In Proceedings of the Twelfth Annual Conference on Computational Learning Theory, pp. 164-170, 1999.

Minka, T. Expectation propagation for approximate Bayesian inference. In Uncertainty in Artificial Intelligence, volume 17, pp. 362-369, 2001a.

Minka, T. P. Bayesian model averaging is not model combination. 2000.

Minka, T. P. Automatic choice of dimensionality for PCA. In Advances in Neural Information Processing Systems, pp. 598-604, 2001b.

Mohri, M. and Rostamizadeh, A. Rademacher complexity bounds for non-i.i.d. processes. In Advances in Neural Information Processing Systems, pp. 1097-1104, 2009.

Nakkiran, P., Kaplun, G., Bansal, Y., Yang, T., Barak, B., and Sutskever, I. Deep double descent: Where bigger models and more data hurt. arXiv preprint arXiv:1912.02292, 2019.

Nakkiran, P., Venkat, P., Kakade, S., and Ma, T. Optimal regularization can mitigate double descent. arXiv preprint arXiv:2003.01897, 2020.

Nalisnick, E. On priors for Bayesian neural networks. PhD thesis, University of California, Irvine, 2018.

Neal, R. Bayesian Learning for Neural Networks. Springer Verlag, 1996. ISBN 0387947248.

Neyshabur, B., Bhojanapalli, S., McAllester, D., and Srebro, N. Exploring generalization in deep learning. In Advances in Neural Information Processing Systems, pp. 5947-5956, 2017.

Neyshabur, B., Bhojanapalli, S., and Srebro, N. A PAC-Bayesian approach to spectrally-normalized margin bounds for neural networks. In International Conference on Learning Representations, 2018. URL https://openreview.net/forum?id=Skz_WfbCZ.

Opper, M., Kinzel, W., Kleinz, J., and Nehl, R. On the ability of the optimal perceptron to generalise. Journal of Physics A: Mathematical and General, 23(11):L581, 1990.


Ovadia, Y., Fertig, E., Ren, J., Nado, Z., Sculley, D., Nowozin, S., Dillon, J. V., Lakshminarayanan, B., and Snoek, J. Can you trust your model's uncertainty? Evaluating predictive uncertainty under dataset shift. arXiv preprint arXiv:1906.02530, 2019.

Pearce, T., Zaki, M., Brintrup, A., Anastassacos, N., and Neely, A. Uncertainty in neural networks: Bayesian ensembling. arXiv preprint arXiv:1810.05546, 2018.

Rasmussen, C. E. and Ghahramani, Z. Occam's razor. In Neural Information Processing Systems (NIPS), 2001.

Rasmussen, C. E. and Williams, C. K. I. Gaussian Processes for Machine Learning. The MIT Press, 2006.

Ritter, H., Botev, A., and Barber, D. A scalable Laplace approximation for neural networks. In International Conference on Learning Representations (ICLR), 2018.

Rue, H., Martino, S., and Chopin, N. Approximate Bayesian inference for latent Gaussian models by using integrated nested Laplace approximations. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 71(2):319-392, 2009.

Simonyan, K. and Zisserman, A. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.

Smith, S. L. and Le, Q. V. A Bayesian perspective on generalization and stochastic gradient descent. In International Conference on Learning Representations, 2018.

Sun, S., Zhang, G., Shi, J., and Grosse, R. Functional variational Bayesian neural networks. arXiv preprint arXiv:1903.05779, 2019.

Ulyanov, D., Vedaldi, A., and Lempitsky, V. Deep image prior. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 9446-9454, 2018.

Vapnik, V. N. Statistical Learning Theory. Adaptive and Learning Systems for Signal Processing, Communications, and Control, 1998.

Virtanen, P., Gommers, R., Oliphant, T. E., Haberland, M., Reddy, T., Cournapeau, D., Burovski, E., Peterson, P., Weckesser, W., Bright, J., et al. SciPy 1.0: fundamental algorithms for scientific computing in Python. Nature Methods, pp. 1-12, 2020.

Walker, S. and Hjort, N. L. On Bayesian consistency. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 63(4):811-821, 2001.

Welling, M. and Teh, Y. W. Bayesian learning via stochastic gradient Langevin dynamics. In Proceedings of the 28th International Conference on Machine Learning (ICML-11), pp. 681-688, 2011.

Wenzel, F., Roth, K., Veeling, B. S., Swiatkowski, J., Tran, L., Mandt, S., Snoek, J., Salimans, T., Jenatton, R., and Nowozin, S. How good is the Bayes posterior in deep neural networks really? arXiv preprint arXiv:2002.02405, 2020.

Wilson, A. G. The case for Bayesian deep learning. arXiv preprint arXiv:2001.10995, 2020.

Xu, X., Lu, P., MacEachern, S., and Xu, R. Calibrated Bayes factors for model comparison. Journal of Statistical Computation and Simulation, 89(4):591-614, 2019.

Yang, W., Lorch, L., Graule, M. A., Srinivasan, S., Suresh, A., Yao, J., Pradier, M. F., and Doshi-Velez, F. Output-constrained Bayesian neural networks. arXiv preprint arXiv:1905.06287, 2019.

Zhang, C., Bengio, S., Hardt, M., Recht, B., and Vinyals, O. Understanding deep learning requires rethinking generalization. arXiv preprint arXiv:1611.03530, 2016.

Zhang, T. Information-theoretic upper and lower bounds for statistical estimation. IEEE Transactions on Information Theory, 52(4):1307-1321, 2006.


Appendix Outline

This appendix is organized as follows. In Section A, we visualize predictive functions corresponding to weight samples within high posterior density valleys on a regression problem. In Section B, we provide background material on Gaussian processes. In Section C, we present further results comparing MultiSWAG and MultiSWA to Deep Ensembles under data distribution shift on CIFAR-10. In Section D, we provide the details of all experiments presented in the paper. In Section E, we present analytic results on the dependence of the prior distribution in function space on the variance of the prior over parameters. In Section F, we study the prior correlations between BNN logits on perturbed images.

A. Loss Valleys

We demonstrate that different points along the valleys of high posterior density (low loss) connecting pairs of independently trained optima (Garipov et al., 2018; Draxler et al., 2018; Fort & Jastrzebski, 2019) correspond to different predictive functions. We use the regression example from Izmailov et al. (2019) and show the results in Figure 11.

B. Gaussian Processes

With a Bayesian neural network, a distribution over parameters p(w) induces a distribution over functions p(f(x; w)) when combined with the functional form of the network. Gaussian processes (GPs) are often used to instead directly specify a distribution over functions.

A Gaussian process is a distribution over functions, f(x) ∼ GP(m, k), such that any collection of function values, queried at any finite set of inputs x_1, ..., x_n, has a joint Gaussian distribution:

[f(x_1), ..., f(x_n)] ∼ N(µ, K).   (6)

The mean vector, µ_i = E[f(x_i)] = m(x_i), and covariance matrix, K_ij = cov(f(x_i), f(x_j)) = k(x_i, x_j), are determined by the mean function m and covariance function (or kernel) k of the Gaussian process.

The popular RBF kernel has the form

k(x_i, x_j) = exp( −(1/(2ℓ²)) ‖x_i − x_j‖² ).   (7)

The length-scale hyperparameter ℓ controls the extent of correlations between function values. If ℓ is large, sample functions from a GP prior are simple and slowly varying with inputs x.
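A short NumPy sketch of drawing sample functions from a GP prior with this kernel (the input grid, length-scale, and jitter value are arbitrary choices for illustration):

```python
import numpy as np

def rbf_kernel(x1, x2, lengthscale=1.0):
    """k(x_i, x_j) = exp(-||x_i - x_j||^2 / (2 * lengthscale^2)) for 1D inputs."""
    sqdist = (x1[:, None] - x2[None, :]) ** 2
    return np.exp(-0.5 * sqdist / lengthscale ** 2)

x = np.linspace(-5.0, 5.0, 200)
K = rbf_kernel(x, x, lengthscale=1.0) + 1e-8 * np.eye(len(x))  # jitter for stability
prior_samples = np.random.default_rng(0).multivariate_normal(
    mean=np.zeros(len(x)), cov=K, size=3)
```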

Gaussian processes with RBF kernels (as well as many other standard kernels) assign positive density to any set of observations. Moreover, these models are universal approximators (Rasmussen & Williams, 2006): as the number of observations increases, they are able to approximate any function to arbitrary precision.

Work on Gaussian processes in machine learning was triggered by the observation that Bayesian neural networks become Gaussian processes with particular kernel functions as the number of hidden units approaches infinity (Neal, 1996). This result resembles recent work on the neural tangent kernel (e.g., Jacot et al., 2018).

C. Deep Ensembles and MultiSWAG Under Distribution Shift

In Figures 14, 15, 16, and 17 we show the negative log likelihood for Deep Ensembles, MultiSWA, and MultiSWAG using a PreResNet-20 on CIFAR-10 with various corruptions, as a function of the number of independently trained models (SGD solutions, SWA solutions, or SWAG models, respectively). For MultiSWAG, we generate 20 samples from each independent SWAG model. Typically, MultiSWA and MultiSWAG significantly outperform Deep Ensembles when a small number of independent models is used, or when the level of corruption is high.

In Figure 18, following Ovadia et al. (2019), we show the distribution of negative log likelihood, accuracy, and expected calibration error as we vary the type of corruption. We use a fixed training time budget: 10 independently trained models for every method. For MultiSWAG, we ensemble 20 samples from each of the 10 SWAG approximations. MultiSWAG in particular achieves better NLL than the other two methods, and MultiSWA outperforms Deep Ensembles; the difference is especially pronounced for higher levels of corruption. In terms of ECE, MultiSWAG again outperforms the other two methods for higher corruption intensities.
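For reference, a common binned estimator of the expected calibration error is sketched below (NumPy); the number of bins and the binning scheme are illustrative and may differ from those used for Figure 18.

```python
import numpy as np

def expected_calibration_error(probs, labels, n_bins=20):
    """Binned ECE: |accuracy - confidence| per bin, weighted by bin frequency."""
    confidence = probs.max(axis=1)
    correct = (probs.argmax(axis=1) == labels).astype(float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (confidence > lo) & (confidence <= hi)
        if in_bin.any():
            ece += in_bin.mean() * abs(correct[in_bin].mean() - confidence[in_bin].mean())
    return ece
```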

We note that Ovadia et al. (2019) found Deep Ensembles to be a very strong baseline for prediction quality and calibration under distribution shift. For this reason, we focus on Deep Ensembles in our comparisons.

D. Details of Experiments

In this section we provide additional details of the experiments presented in the paper.

D.1. Approximating the True Predictive Distribution

For the results presented in Figure 4, we used a network with 3 hidden layers of size 10 each. The network takes two inputs, x and x².



Figure 11. Diversity of high performing functions. Bottom: a contour plot of the posterior log-density in the subspace containing a pair of independently trained modes (as with deep ensembles), and a path of high posterior density connecting these modes. In each panel, the purple point represents a sample from the posterior in the parameter subspace. Top: the predictive distribution constructed from samples in the subspace. The shaded blue area shows the 3σ-region of the predictive distribution at each of the input locations, and the blue line shows the mean of the predictive distribution. In each panel, the purple line shows the predictive function corresponding to the sample shown in the corresponding bottom-row panel. For the details of the experimental setup, see Section 5.1 of Izmailov et al. (2019).

We pass both x and x² as input to ensure that the network can represent a broader class of functions. The network outputs a single number, y = f(x).

To generate data for the plots, we used a randomly initialized neural network of the same architecture described above. We sampled the weights from an isotropic Gaussian with variance 0.1², and added isotropic Gaussian noise with variance 0.1² to the outputs:

y = f(x; w) + ε(x),

with w ∼ N(0, 0.1² · I), ε(x) ∼ N(0, 0.1² · I). The training set consists of 120 points, shown in Figure 4.
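A sketch of this data-generating process is below (PyTorch); the activation function and the input range are our own illustrative choices, as they are not specified above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def make_net():
    # 3 hidden layers of 10 units; inputs are (x, x^2), output is a scalar.
    return nn.Sequential(
        nn.Linear(2, 10), nn.Tanh(),   # Tanh is an assumption, not from the paper
        nn.Linear(10, 10), nn.Tanh(),
        nn.Linear(10, 10), nn.Tanh(),
        nn.Linear(10, 1),
    )

true_net = make_net()
with torch.no_grad():
    for p in true_net.parameters():
        p.normal_(0.0, 0.1)                            # w ~ N(0, 0.1^2 I)
    x = torch.linspace(-1.0, 1.0, 120)                 # illustrative input range
    inputs = torch.stack([x, x ** 2], dim=-1)
    y = true_net(inputs).squeeze(-1) + 0.1 * torch.randn(120)   # eps ~ N(0, 0.1^2)
```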

For estimating the ground truth, we ran 200 chains of Hamiltonian Monte Carlo (HMC) using the hamiltorch package (Cobb et al., 2019). We initialized each chain with a network pre-trained with SGD for 3000 steps, then ran HMC for 2000 steps, producing 200 samples.

For Deep Ensembles, we independently trained 50 networks with SGD for 20,000 steps each. We used the negative posterior log-density as the training loss. For SVI, we used a fully-factorized Gaussian approximation initialized at an SGD solution trained for 20,000 steps. For all inference methods, we set the prior variance to 10² and the noise variance to 0.02².

Discrepancy with true BMA. For the results presented in panel (d) of Figure 4, we computed the Wasserstein distance between the predictive distribution approximated with HMC and the predictive distributions for Deep Ensembles and SVI. We used the one-dimensional Wasserstein distance function¹

¹ https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wasserstein_distance.html

from the scipy package (Virtanen et al., 2020). We computed the Wasserstein distance between the marginal distributions at each input location, and averaged the results over the input locations. In the top sub-panels of panels (b), (c) of Figure 4, we additionally visualize the marginal Wasserstein distance between the HMC predictive distribution and the Deep Ensembles and SVI predictive distributions, respectively, for each input location.
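
A sketch of this discrepancy computation, using scipy.stats.wasserstein_distance on the marginal predictive samples at each input location and averaging over locations; the array shapes are our assumption.

```python
import numpy as np
from scipy.stats import wasserstein_distance

def predictive_discrepancy(hmc_preds, approx_preds):
    """hmc_preds, approx_preds: arrays of shape (num_samples, num_inputs) holding
    predictive samples f(x_k; w_j) at each input location x_k."""
    per_input = [
        wasserstein_distance(hmc_preds[:, k], approx_preds[:, k])
        for k in range(hmc_preds.shape[1])
    ]
    # return both the average over input locations and the per-location curve
    return np.mean(per_input), np.array(per_input)
```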

D.2. Deep Ensembles and MultiSWAG

We evaluate Deep Ensembles, MultiSWA, and MultiSWAG under distribution shift in Section 4. Following Ovadia et al. (2019), we use a PreResNet-20 network and the CIFAR-10 dataset with different types of corruptions introduced in Hendrycks & Dietterich (2019). For training individual SGD, SWA, and SWAG models, we use the hyper-parameters used for PreResNet-164 in Maddox et al. (2019). For each SWAG model, we sample 20 networks and ensemble them. So, Deep Ensembles, MultiSWA, and MultiSWAG are all evaluated under the same training budget. Deep Ensembles and MultiSWA also use the same test-time budget.
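
A hedged sketch of the shared test-time protocol: each SWAG solution contributes 20 sampled networks, and predictions are averaged in probability space. The swag.sample() call that draws weights into the model is assumed here (as in common SWAG implementations), not an API defined in this paper.

```python
import torch

@torch.no_grad()
def multiswag_predict(swag_models, loader, num_samples=20):
    """Average softmax predictions over num_samples draws from each SWAG posterior.
    loader must not shuffle, so predictions align across draws."""
    probs, total = None, 0
    for swag in swag_models:              # independently trained SWAG approximations
        for _ in range(num_samples):
            swag.sample()                 # assumed: draws weights into the model in-place
            batch_probs = [torch.softmax(swag(x), dim=-1) for x, _ in loader]
            p = torch.cat(batch_probs)
            probs = p if probs is None else probs + p
            total += 1
    return probs / total
```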

For producing the corrupted data, we used the code² released by Hendrycks & Dietterich (2019). We had issues producing the data for the frost corruption type, so we omit it in our evaluation, and include Gaussian blur, which was not included in the evaluation of Hendrycks & Dietterich (2019).

² https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_cifar_c.py

Figure 12. (a)–(c): Average pairwise prior correlations for pairs of objects in classes 0, 1, 2, 4, 7 of MNIST, induced by LeNet-5 for p(f(x; w)) when p(w) = N(0, α²I). Images in the same class have higher prior correlations than images from different classes, suggesting that p(f(x; w)) has desirable inductive biases. The correlations slightly decrease with increases in α. Panels (e)–(g) show sample functions from LeNet-5 along the direction connecting a pair of MNIST images of 0 and 1 digits; the complexity of the samples increases with α. (d): NLL and (h): classification error of an ensemble of 20 SWAG samples on MNIST as a function of α, using LeNet-5. The NLL is high for overly small α and near-optimal for larger values, with an optimum near α = 0.3.

D.3. Neural Network Priors

In the main text we considered different properties of the prior distribution over functions induced by a spherical Gaussian distribution over the weights, with different variance scales.

Prior correlation diagrams. In panels (a)–(c) of Figure 12 we show pairwise correlations of the logits for different pairs of datapoints. To make these plots, we produce S = 100 samples of the weights w_i of a LeNet-5 from the prior distribution N(0, α²I), and compute the logits corresponding to class 0 for each data point and each weight sample. We then compute the correlations for each pair x, x′ of data points as follows:

corr_logit(x, x′) = Σ_{i=1}^S (f(x; w_i) − f̄(x)) (f(x′; w_i) − f̄(x′)) / √( Σ_{i=1}^S (f(x; w_i) − f̄(x))² · Σ_{i=1}^S (f(x′; w_i) − f̄(x′))² ),

where f(x; w) is the logit corresponding to class 0 of the network with weights w on the input x, and f̄(x) is the mean value of the logit, f̄(x) = (1/S) Σ_i f(x; w_i). For evaluation, we use 200 random datapoints per class for classes 0, 1, 2, 4, 7 (a total of 1000 datapoints). We use this set of classes to ensure that the structure is clearly visible in the figure. We combine the correlations into a diagram, additionally showing the average correlation for each pair of classes. We repeat the experiment for different values of α ∈ {0.02, 0.1, 1}. For a discussion of the results, see Section 5.2.
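
A sketch of how these diagrams can be produced, following the formula above; make_model is assumed to construct a LeNet-5-style network, and the loop over prior samples is written for clarity rather than efficiency.

```python
import torch

@torch.no_grad()
def prior_logit_correlations(make_model, images, alpha, num_samples=100, class_idx=0):
    """Sample weights w_i ~ N(0, alpha^2 I) and return the matrix of correlations
    corr_logit(x, x') between the class-0 logits across weight samples."""
    logits = []
    for _ in range(num_samples):
        model = make_model()               # assumed: builds a LeNet-5-style network
        for p in model.parameters():
            p.normal_(0.0, alpha)
        logits.append(model(images)[:, class_idx])   # shape (num_images,)
    F = torch.stack(logits)                # (S, num_images)
    F = F - F.mean(dim=0, keepdim=True)    # subtract the per-image mean logit
    cov = F.t() @ F
    std = cov.diag().sqrt()
    return cov / (std[:, None] * std[None, :])
```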

Sample functions. In panels (e)–(g) of Figure 12 we visualize the functions sampled from the LeNet-5 network along the direction connecting a pair of MNIST images. In particular, we take a pair of images x_0 and x_1 of digits 0 and 1 respectively, and construct the path x(t) = t · x_0 + (1 − t) · x_1. We then study the samples of the logits z(t) = f(x(t) · ‖x_0‖/‖x(t)‖, w) along the path; here we adjusted the norm of the images along the path to be constant, as the values of the logits are sensitive to the norm of the inputs. The complexity of the samples increases as we increase the variance of the prior distribution over the weights. This increased complexity of sample functions explains why we might expect the prior correlations for pairs of images to be lower when we increase the variance of the prior distribution over the weights.

Performance dependence on prior variance. In panels (d), (h) of Figure 12 we show the test negative log-likelihood and accuracy of SWAG applied to LeNet-5 on MNIST. We train the model for 50 epochs, constructing the rank-20 SWAG approximation from the last 25 epochs. We use an initial learning rate of 0.05 and a SWAG learning rate of 0.01 with the learning rate schedule of Maddox et al. (2019). We use the posterior log-density as the objective and vary the prior variance α². In panels (f), (g) of Figure 9 we perform an analogous experiment using a PreResNet-20 and a VGG-16 on CIFAR-10, using the hyper-parameters reported in Maddox et al. (2019) (for PreResNet-20 we use the hyper-parameters used with PreResNet-164 in Maddox et al. (2019)). Both on MNIST and CIFAR-10, we observe that the performance is poor for overly small values of α, close to optimal for intermediate values, and still reasonable for larger values of α. For further discussion of the results, see Section 5.3.

Predictions from prior samples. Following Wenzel et al. (2020), we study the predictive distributions of prior samples using PreResNet-20 on CIFAR-10. In Figure 9 we show the sample predictive functions averaged over datapoints for different scales α of the prior distribution. We also show the predictive distribution for each α, which is the average of the sample predictive distributions over 200 samples of weights. In Figure 10 we show how the predictive distribution changes as we vary the number of observed data, for prior scale α = √10. We see that the marginal predictive distribution for all considered values of α is reasonable: roughly uniform across classes when averaged across the dataset. For the latter experiment, we used stochastic gradient Langevin dynamics (SGLD) (Welling & Teh, 2011) with a cosine learning rate schedule. For each sample we restart SGLD, and we only use the sample obtained at the last iteration. We discuss the results in Section 8.4.
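
For reference, a single SGLD update in the style of Welling & Teh (2011) is sketched below; grad_log_posterior stands for a minibatch estimate of ∇_w log p(w|D), and the step-size schedule (cosine in our experiments) is applied outside this function.

```python
import torch

def sgld_step(params, grad_log_posterior, lr):
    """One SGLD update: w <- w + (lr/2) * grad log p(w|D) + N(0, lr)."""
    for p, g in zip(params, grad_log_posterior):
        noise = torch.randn_like(p) * lr ** 0.5   # injected Gaussian noise with variance lr
        p.data.add_(0.5 * lr * g + noise)
```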

Prior correlations with corrupted images. In Section F and Figure 19 we study the decay of the prior correlations between logits on an original image and a perturbed image as we increase the intensity of the perturbations. For the BNN, we use the PreResNet-20 architecture with the standard Gaussian prior N(0, I). For the linear model, the correlations are not affected by the prior variance α²:

cov(wᵀx, wᵀy) = E(wᵀx · wᵀy) = E(xᵀw wᵀy) = xᵀ E(wwᵀ) y = α² xᵀy,

and hence

corr(wᵀx, wᵀy) = cov(wᵀx, wᵀy) / √( cov(wᵀy, wᵀy) · cov(wᵀx, wᵀx) ) = xᵀy / (‖x‖ · ‖y‖),

which does not depend on α.

We use the N(0, I) prior for the weights of the linear model. Finally, we also evaluate the correlations associated with an RBF kernel (see Equation (7)). To set the lengthscale ℓ of the kernel, we evaluate the pairwise correlations for the PreResNet-20 and the RBF kernel on the 100 uncorrupted CIFAR-10 images that were used for the experiment, and ensure that the average correlations match. The resulting value of ℓ is 10000; the average correlation for the RBF kernel and PreResNet was ≈ 0.9, while for the linear model the average correlation was ≈ 0.82. For the perturbations, we used the same set of corruptions introduced in Hendrycks & Dietterich (2019) as in the experiments in Section 4, with the addition of a random translation: for a random translation of intensity i, we pad the image with 2·i zeros on each side and crop the image randomly to 32×32.
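
A sketch of the lengthscale-matching step, assuming flattened image arrays and a simple grid search; the grid is a placeholder, and the only values taken from the text are the reported ℓ ≈ 10000 and target average correlation ≈ 0.9.

```python
import numpy as np

def rbf_correlation(x, x_other, lengthscale):
    # RBF prior correlation between two (flattened) images
    return np.exp(-np.sum((x - x_other) ** 2) / (2 * lengthscale ** 2))

def calibrate_lengthscale(pairs, target_corr=0.9, grid=np.logspace(2, 5, 200)):
    """Pick the lengthscale whose average RBF correlation over image pairs best
    matches the average PreResNet-20 prior correlation (target_corr)."""
    avg = np.array([np.mean([rbf_correlation(x, y, l) for x, y in pairs]) for l in grid])
    return grid[np.argmin(np.abs(avg - target_corr))]
```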

D.4. Rethinking Generalization

In Section 6 we experiment with Bayesian neural networks and Gaussian processes on CIFAR-10 with noisy labels, inspired by the results in Zhang et al. (2016) that suggest we need to re-think generalization to understand deep learning.

Following Zhang et al. (2016), we train PreResNet-20 on CIFAR-10 with different fractions of random labels. To ensure that the network fits the training data, we turn off weight decay and data augmentation, and use a lower initial learning rate of 0.01. Otherwise, we follow the hyper-parameters that were used with PreResNet-164 in Maddox et al. (2019). We use a diagonal Laplace approximation to compute an estimate of the marginal likelihood for each level of label corruption. Following Ritter et al. (2018), we use the diagonal of the Fisher information matrix rather than the Hessian.

We perform a similar experiment with a Gaussian process with an RBF kernel on the binary classification problem for two classes of CIFAR-10. We use variational inference to fit the model, and we use the variational evidence lower bound to approximate the marginal likelihood. We use variational inference to overcome the non-Gaussian likelihood, and not for scalability reasons; i.e., we are not using inducing inputs. We use the GPyTorch package (Gardner et al., 2018) to train the models. We use an RBF kernel with the default initialization from GPyTorch, and divide the inputs by 5000 to get an appropriate input scale. We train the model on a binary classification problem between classes 0 and 1.

For the 10-class GP classification experiment, we train 10 one-vs-all models that classify between a given class and the rest of the data. To reduce computation in training, we subsample the data not belonging to the given class to 10k datapoints, so each model is trained on a total of 15k datapoints. We then combine the 10 models into a single multi-class model: an observation is attributed to the class that corresponds to the one-vs-all model with the highest confidence. We use the same hyper-parameters as in the binary classification experiments.
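
A minimal sketch of the one-vs-all combination rule; predict_proba is a placeholder for however each variational GP returns its positive-class probability, not a specific GPyTorch call.

```python
import numpy as np

def one_vs_all_predict(models, x):
    """Combine 10 one-vs-all classifiers: model c returns p(class c | x), and the
    prediction is the class whose one-vs-all model is most confident."""
    confidences = np.stack([m.predict_proba(x) for m in models], axis=-1)  # (N, 10); assumed interface
    return confidences.argmax(axis=-1)
```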

D.5. Double Descent

In Section 7 we evaluate SGD, SWAG, and MultiSWAG for models of varying width. Following Nakkiran et al. (2019), we use ResNet-18 on CIFAR-100; we consider the original labels, as well as 10% and 20% label corruption. For networks of every width, we reuse the hyper-parameters used for PreResNet-164 in Maddox et al. (2019). For the original labels and 10% label corruption, we use 5 independently trained SWAG models with MultiSWAG; for 20% label corruption we use 10 models. For 20% label corruption we also show performance varying the number of independent models in Figures 8(e) and 13(c). Both for SWAG and MultiSWAG we use an ensemble of 20 sampled models from each of the SWAG solutions; for example, for MultiSWAG with 10 independent SWAG solutions we use an ensemble of 200 networks.

Figure 13. Double Descent. (a) Test error and (b) NLL loss for ResNet-18 with varying width on CIFAR-100 for SGD, SWAG, and MultiSWAG when 10% of the labels are randomly reshuffled. MultiSWAG alleviates double descent both on the original labels and under label noise, both in accuracy and NLL. (c) Test NLLs for MultiSWAG with a varying number of independent models under 20% label corruption; NLL monotonically decreases with the number of independent models, alleviating double descent.

E. Analysis of Prior Variance Effect

In this section we provide simple analytic results for the effect of the prior variance in ReLU networks. A related derivation is presented in Appendix Section A.8 of Garipov et al. (2018), about connecting paths arising from symmetries in the parametrization.

We will consider a multilayer network f(x; w) of the form

f(x; {W_i, b_i}_{i=1}^n) = W_n(⋯ φ(W_2 φ(W_1 x + b_1) + b_2) ⋯) + b_n,

where φ is the ReLU (or, in fact, any positively-homogeneous activation function), the W_i are weight matrices, and the b_i are bias vectors. In particular, f can be a regular CNN with ReLU activations, up to the logits (with the softmax activation removed).

Now, suppose we have a prior distribution of the form

W_i ∼ N(0, α_i² I),   b_i ∼ N(0, β_i² I),

where the identity matrices I are implicitly assumed to be of appropriate shapes, so that each weight matrix and bias vector has a spherical Gaussian distribution. We can reparameterize this distribution as

W_i = α_i E_i,   E_i ∼ N(0, I),
b_i = β_i ε_i,   ε_i ∼ N(0, I).

We can then express the predictions of the network on the input x, for weights sampled from the prior, as the random variable

f(x; {α_i, β_i}_{i=1}^n) = α_n · E_n(⋯ φ(α_1 E_1 x + β_1 · ε_1) ⋯) + β_n · ε_n.   (8)

Through Equation (8), we can observe some simple properties of the dependence between the prior scales α_i, β_i and the induced function-space prior.

Proposition 1. Suppose the network has no bias vectors, i.e. β_1 = … = β_n = 0. Then the scales α_i of the prior distribution over the weights only affect the output scale of the network.

Proof. In the case when there are no bias vectors, Equation (8) simplifies to

f(x; {α_i, β_i = 0}_{i=1}^n) = α_n · E_n(⋯ φ(α_1 E_1 x) ⋯) = α_n · … · α_1 · E_n(⋯ φ(E_1 x) ⋯) = α_n · … · α_1 · f(x; {α_i = 1, β_i = 0}_{i=1}^n).

In the derivation above we used the positive homogeneity of ReLU: φ(αz) = αφ(z) for any positive α.

In other words, to sample from the distribution over functions corresponding to a prior with scales {α_i, β_i = 0}_{i=1}^n, we can sample from the spherical Gaussian prior (without bias terms) {α_i = 1, β_i = 0}_{i=1}^n, and then rescale the outputs of the network by the product of the scales α_n · … · α_2 · α_1.
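
A small numerical check of Proposition 1, under the assumption of a bias-free ReLU MLP with arbitrary layer sizes: rescaling the weights of layer i by α_i rescales the outputs by the product of the α_i.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(5, 20, bias=False), nn.ReLU(),
                    nn.Linear(20, 20, bias=False), nn.ReLU(),
                    nn.Linear(20, 1, bias=False))
x = torch.randn(3, 5)

with torch.no_grad():
    base = [p.clone() for p in net.parameters()]
    out_unit = net(x).clone()                  # alpha_i = 1 for all layers
    alphas = [0.5, 2.0, 3.0]                   # arbitrary per-layer scales (our choice)
    for p, p0, a in zip(net.parameters(), base, alphas):
        p.copy_(a * p0)                        # rescale each weight matrix by alpha_i
    out_scaled = net(x)
    scale = 1.0
    for a in alphas:
        scale *= a
    print(torch.allclose(out_scaled, scale * out_unit, atol=1e-5))  # True: only the output scale changes
```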

We note that the result above is different from the results for sigmoid networks considered in MacKay (1995), where varying the prior on the weights leads to changing the lengthscale of the sample functions. For ReLU networks without biases, increasing the prior variance only increases the output scale of the network, and not the complexity of the samples. If we apply the softmax activation on the outputs of the last layer of such a network, we will observe increasingly confident predictions as we increase the prior variance. We observe this effect in Figure 9, and discuss it in Section 8.4.

In case bias vectors are present, we can obtain a similar result using a specific scaling of the prior variances with the layer, as in the following proposition.

Proposition 2. Suppose the prior scales depend on the layer of the network as follows, for some γ > 0:

α_i = γ,   β_i = γ^i,

for all layers i = 1, …, n. Then γ only affects the scale of the predictive distribution at any input x:

f(x; {α_i = γ, β_i = γ^i}_{i=1}^n) = γ^n · f(x; {α_i = 1, β_i = 1}_{i=1}^n).

Proof. The proof is analogous to the proof of Proposition 1. We can use the positive homogeneity of ReLU activations to factor the prior scales outside of the network:

f(x; {α_i = γ, β_i = γ^i}_{i=1}^n) = γ · E_n(⋯ φ(γ · E_1 x + γ · ε_1) ⋯) + γ^n · ε_n = γ^n · (E_n(⋯ φ(E_1 x + ε_1) ⋯) + ε_n) = γ^n · f(x; {α_i = 1, β_i = 1}_{i=1}^n).

The analysis above can be applied to other simple scaling rules of the prior; e.g.,

f(x; {α_i = γ·ᾱ_i, β_i = γ^i·β̄_i}_{i=1}^n) = γ^n · f(x; {α_i = ᾱ_i, β_i = β̄_i}_{i=1}^n)   (9)

can be shown completely analogously to Proposition 2.

More general types of scaling of the prior affect both the output scale of the network and the relative effect of the weight and bias prior terms. For example, by Equation (9) we have

f(x; {α_i = γ, β_i = γ}_{i=1}^n) = f(x; {α_i = γ·1, β_i = γ^i · γ^{1−i}}_{i=1}^n) = γ^n · f(x; {α_i = 1, β_i = γ^{1−i}}_{i=1}^n).

We note that the analysis does not cover residual connections and batch normalization, so it applies to LeNet-5, but cannot be directly applied to the PreResNet-20 networks used in many of our experiments.

F. Prior Correlation Structure under Perturbations

In this section we explore the prior correlations between the logits on different pairs of datapoints, induced by a spherical Gaussian prior on the weights of a PreResNet-20. We sample 100 random images from CIFAR-10 (10 from each class) and apply 17 different perturbations introduced by Hendrycks & Dietterich (2019) at 5 different levels of intensity. We then compute correlations between the logits f(x; w) for the original image x and f(x̃; w) for the corrupted image x̃, as we sample the weights of the network from the prior w ∼ N(0, I).

In Figure 19 we show how the correlations decay with perturbation intensity. For reference, we also show how the correlations decay for a linear model and for an RBF kernel. For the RBF kernel, we set the lengthscale so that the average correlations on the uncorrupted datapoints match those of a PreResNet-20. Further experimental details can be found in Appendix D.3.

For all types of corruptions except saturate, snow, fog, and brightness, the PreResNet logits decay more slowly compared to the RBF kernel and the linear model. It appears that the prior samples are sensitive to corruptions that alter the brightness, or more generally the colours, in the image. For many types of corruptions (such as, e.g., Gaussian Noise), the prior correlations for PreResNet are close to 1 for all levels of corruption.

Overall, these results indicate that the prior over functions induced by a vague prior over parameters w, in combination with a PreResNet, has useful equivariance properties: before seeing data, the model treats images of the same class as highly correlated, even after an image has undergone significant perturbations representative of perturbations we often see in the real world. These types of symmetries are a large part of what makes neural networks a powerful model class for high dimensional natural signals.

Figure 14. Noise Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG, and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Gaussian Noise, (b) Impulse Noise, (c) Shot Noise.

Figure 15. Blur Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG, and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Defocus Blur, (b) Glass Blur, (c) Motion Blur, (d) Zoom Blur, (e) Gaussian Blur.

Figure 16. Digital Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG, and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Contrast, (b) Saturate, (c) Elastic Transform, (d) Pixelate, (e) JPEG Compression.

Figure 17. Weather Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG, and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Snow, (b) Fog, (c) Brightness.

Figure 18. Negative log likelihood, accuracy, and expected calibration error distributions on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG, and MultiSWA as a function of the corruption intensity. Following Ovadia et al. (2019), we summarize the results for different types of corruption with a boxplot. For each method we use 10 independently trained models, and for MultiSWAG we sample 20 networks from each model. As in Figures 5, 11–14, there are substantial differences between these three methods, which are hard to see due to the vertical scale of this plot; MultiSWAG particularly outperforms Deep Ensembles and MultiSWA in terms of NLL and ECE for higher corruption intensities.

Figure 19. Prior correlations under corruption. Prior correlations between predictions (logits) for PreResNet-20, a linear model, and an RBF kernel on original and corrupted images, as a function of corruption intensity, for different types of corruptions. The lengthscale of the RBF kernel is calibrated to produce similar correlations to PreResNet on uncorrupted datapoints. We report the mean correlation values over 100 different images, and show the 1σ error bars with shaded regions. For all corruptions except Snow, Saturate, Fog, and Brightness, the correlations decay more slowly for PreResNet compared to the baselines. Corruption types shown: Translation, Gaussian Noise, Impulse Noise, Shot Noise, Defocus Blur, Glass Blur, Motion Blur, Zoom Blur, Gaussian Blur, Contrast, Saturate, Elastic, Pixelate, JPEG, Snow, Fog, Brightness.


average (BMA). Rather than bet everything on one hypothesis (a single setting of parameters w), we want to use all settings of parameters, weighted by their posterior probabilities. This procedure is called marginalization of the parameters w, as the predictive distribution of interest no longer conditions on w. This is not a controversial equation, but simply the sum and product rules of probability.

3.1. Importance of Marginalization in Deep Learning

In general, we can view classical training as performing approximate Bayesian inference, using the approximate posterior p(w|D) ≈ δ(w = ŵ) to compute Eq. (1), where δ is a Dirac delta function that is zero everywhere except at ŵ = argmax_w p(w|D). In this case, we recover the standard predictive distribution p(y|x, ŵ). From this perspective, many alternatives, albeit imperfect, will be preferable, including impoverished Gaussian posterior approximations for p(w|D), even if the posterior or likelihood are actually highly non-Gaussian and multimodal.

The difference between a classical and a Bayesian approach will depend on how sharp the posterior p(w|D) becomes. If the posterior is sharply peaked, and the conditional predictive distribution p(y|x, w) does not vary significantly where the posterior has mass, there may be almost no difference, since a delta function may then be a reasonable approximation of the posterior for the purpose of the BMA. However, modern neural networks are usually highly underspecified by the available data, and therefore have diffuse likelihoods p(D|w), not strongly favouring any one setting of parameters. Not only are the likelihoods diffuse, but different settings of the parameters correspond to a diverse variety of compelling hypotheses for the data (Garipov et al., 2018; Izmailov et al., 2019). This is exactly the setting when we most want to perform a Bayesian model average, which will lead to an ensemble containing many different but high performing models, for better calibration and accuracy than classical training.

Loss Valleys. Flat regions of low loss (negative log posterior density −log p(w|D)) are associated with good generalization (e.g., Hochreiter & Schmidhuber, 1997; Hinton & Van Camp, 1993; Dziugaite & Roy, 2017; Izmailov et al., 2018; Keskar et al., 2016). While flat solutions that generalize poorly can be contrived through reparametrization (Dinh et al., 2017), the flat regions that lead to good generalization contain a diversity of high performing models on test data (Izmailov et al., 2018), corresponding to different parameter settings in those regions. And indeed, there are large contiguous regions of low loss that contain such solutions, even connecting together different SGD solutions (Garipov et al., 2018; Izmailov et al., 2019) (see also Figure 11, Appendix).

Since these regions of the loss represent a large volume in a high-dimensional space (Huang et al., 2019), and provide a diversity of solutions, they will dominate in forming the predictive distribution in a Bayesian model average. By contrast, if the parameters in these regions provided similar functions, as would be the case for flatness obtained through reparametrization, these functions would be redundant in the model average. That is, although solutions of high posterior density can provide poor generalization, it is the solutions that generalize well that will have the greatest posterior mass, and thus be automatically favoured by the BMA.

Calibration by Epistemic Uncertainty Representation. It has been noticed that modern neural networks are often miscalibrated, in the sense that their predictions are typically overconfident (Guo et al., 2017). For example, in classification, the highest softmax output of a convolutional neural network is typically much larger than the probability of the associated class label. The fundamental reason for miscalibration is ignoring epistemic uncertainty. A neural network can represent many models that are consistent with our observations. By selecting only one, in a classical procedure, we lose uncertainty when the models disagree for a test point. In regression, we can visualize epistemic uncertainty by looking at the spread of the predictive distribution: as we move away from the data, there is a greater variety of consistent solutions, leading to larger uncertainty, as in Figure 4. We can further calibrate the model with tempering, which we discuss in the Appendix, Section 8.

Accuracy. An often overlooked benefit of Bayesian model averaging in modern deep learning is improved accuracy. If we average the predictions of many high performing models that disagree in some cases, we should see significantly improved accuracy. This benefit is now starting to be observed in practice (e.g., Izmailov et al., 2019). Improvements in accuracy are very convincingly exemplified by deep ensembles (Lakshminarayanan et al., 2017), which have been perceived as a competing approach to Bayesian methods, but in fact provide a compelling mechanism for approximate Bayesian model averaging, as we show in Section 3.3. We also demonstrate significant accuracy benefits for multimodal Bayesian marginalization in Section 7.

3.2. Beyond Monte Carlo

Nearly all approaches to estimating the integral in Eq. (1), when it cannot be computed in closed form, involve a simple Monte Carlo approximation: p(y|x, D) ≈ (1/J) Σ_{j=1}^J p(y|x, w_j), with w_j ∼ p(w|D). In practice, the samples from the posterior p(w|D) are also approximate, and found through MCMC or deterministic methods. The deterministic methods approximate p(w|D) with a different, more convenient density q(w|D, θ) from which we can sample, often chosen to be Gaussian. The parameters θ are selected to make q close to p in some sense; for example, variational approximations (e.g., Beal, 2003), which have emerged as a popular deterministic approach, find argmin_θ KL(q||p). Other standard deterministic approximations include Laplace (e.g., MacKay, 1995), EP (Minka, 2001a), and INLA (Rue et al., 2009).
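
A minimal sketch of this simple Monte Carlo estimate, assuming posterior_samples holds J weight settings (e.g., as state dicts) drawn from an approximate posterior.

```python
import torch

@torch.no_grad()
def monte_carlo_bma(model, posterior_samples, x):
    """p(y | x, D) ~= (1/J) sum_j p(y | x, w_j), with w_j drawn from an approximate posterior.
    posterior_samples: list of state_dicts (assumed); x: a batch of inputs."""
    probs = 0.0
    for w in posterior_samples:
        model.load_state_dict(w)
        probs = probs + torch.softmax(model(x), dim=-1)
    return probs / len(posterior_samples)
```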

From the perspective of estimating the predictive distribution in Eq. (1), we can view simple Monte Carlo as approximating the posterior with a set of point masses, with locations given by samples from another approximate posterior q, even if q is a continuous distribution: p(w|D) ≈ Σ_{j=1}^J δ(w = w_j), with w_j ∼ q(w|D).

Ultimately, the goal is to accurately compute the predictive distribution in Eq. (1), rather than to find a generally accurate representation of the posterior. In particular, we must carefully represent the posterior in regions that will make the greatest contributions to the BMA integral. In terms of efficiently computing the predictive distribution, we do not necessarily want to place point masses at locations given by samples from the posterior. For example, functional diversity is important for a good approximation to the BMA integral, because we are summing together terms of the form p(y|x, w); if two settings of the weights w_i and w_j each provide high likelihood (and consequently high posterior density), but give rise to similar functions f(x; w_i), f(x; w_j), then they will be largely redundant in the model average, and the second setting of parameters will not contribute much to estimating the BMA integral for the unconditional predictive distribution. In Sections 3.3 and 4 we consider how various approaches approximate the predictive distribution.

3.3. Deep Ensembles are BMA

Deep ensembles (Lakshminarayanan et al., 2017) are fast becoming a gold standard for accurate and well-calibrated predictive distributions. Recent reports (e.g., Ovadia et al., 2019; Ashukha et al., 2020) show that deep ensembles appear to outperform some particular approaches to Bayesian neural networks for uncertainty representation, leading to the confusion that deep ensembles and Bayesian methods are competing approaches. These methods are often explicitly referred to as non-Bayesian (e.g., Lakshminarayanan et al., 2017; Ovadia et al., 2019; Wenzel et al., 2020). To the contrary, we argue that deep ensembles are actually a compelling approach to Bayesian model averaging, in the vein of Section 3.2.

There is a fundamental difference between a Bayesian model average and some approaches to ensembling. The Bayesian model average assumes that one hypothesis (one parameter setting) is correct, and averages over models due to an inability to distinguish between hypotheses given limited information (Minka, 2000). As we observe more data, the posterior collapses onto a single hypothesis. If the true explanation for the data is a combination of hypotheses, then the Bayesian model average may appear to perform worse as we observe more data. Some ensembling methods work by enriching the hypothesis space, and therefore do not collapse in this way. Deep ensembles, however, are formed by MAP or maximum likelihood retraining of the same architecture multiple times, leading to different basins of attraction. The deep ensemble will therefore collapse in the same way as a Bayesian model average, as the posterior concentrates. Since the hypothesis space (support) for a modern neural network is large, containing many different possible explanations for the data, posterior collapse will often be desirable.

Furthermore, by representing multiple basins of attraction, deep ensembles can provide a better approximation to the BMA than the Bayesian approaches in Ovadia et al. (2019). Indeed, functional diversity is important for a good approximation to the BMA integral, as per Section 3.2. The approaches referred to as Bayesian in Ovadia et al. (2019) instead focus their approximation on a single basin, which may contain a lot of redundancy in function space, making a relatively minimal contribution to computing the Bayesian predictive distribution. On the other hand, retraining a neural network multiple times for deep ensembles incurs a significant computational expense. The single basin approaches may be preferred if we are to control for computation. We explore these questions in Section 4.

4. An Empirical Study of Marginalization

We have shown that deep ensembles can be interpreted as an approximate approach to Bayesian marginalization, which selects for functional diversity by representing multiple basins of attraction in the posterior. Most Bayesian deep learning methods instead focus on faithfully approximating a posterior within a single basin of attraction. We propose a new method, MultiSWAG, which combines these two types of approaches. MultiSWAG combines multiple independently trained SWAG approximations (Maddox et al., 2019) to create a mixture of Gaussians approximation to the posterior, with each Gaussian centred on a different basin of attraction. We note that MultiSWAG does not require any additional training time over standard deep ensembles.

We illustrate the conceptual difference between deep ensembles, a standard variational single basin approach, and MultiSWAG in Figure 3. In the top panel, we have a conceptualization of a multimodal posterior. VI approximates the posterior with multiple samples within a single basin. But we see in the middle panel that the conditional predictive distribution p(y|x, w) does not vary significantly within the basin, and thus each additional sample contributes minimally to computing the marginal predictive distribution p(y|x, D). On the other hand, p(y|x, w) varies significantly between basins, and thus each point mass for deep ensembles contributes significantly to the marginal predictive distribution. By sampling within the basins, MultiSWAG provides additional contributions to the predictive distribution. In the bottom panel, we have the gain in approximating the predictive distribution when adding a point mass to the representation of the posterior, as a function of its location, assuming we have already sampled the mode in dark green. Including samples from different modes provides significant gain over continuing to sample from the same mode, and including weights in wide basins provides relatively more gain than in narrow ones.

Figure 3. Approximating the BMA, p(y|x, D) = ∫ p(y|x, w) p(w|D) dw. Top: p(w|D), with representations from VI (orange), deep ensembles (blue), and MultiSWAG (red). Middle: p(y|x, w) as a function of w for a test input x. This function does not vary much within modes, but changes significantly between modes. Bottom: Distance between the true predictive distribution and the approximation, as a function of representing the posterior at an additional point w, assuming we have already sampled the mode in dark green. There is more to be gained by exploring new basins than continuing to explore the same basin.

In Figure 4 we evaluate single basin and multi-basin approaches in a case where we can near-exactly compute the predictive distribution. We provide details for generating the data and training the models in Appendix D.1. We see that the predictive distribution given by deep ensembles is qualitatively closer to the true distribution compared to the single basin variational method: between data clusters, the deep ensemble approach provides a similar representation of epistemic uncertainty, whereas the variational method is extremely overconfident in these regions. Moreover, we see that the Wasserstein distance between the true predictive distribution and these two approximations quickly shrinks with the number of samples for deep ensembles, but is roughly independent of the number of samples for the variational approach. Thus the deep ensemble is providing a better approximation of the Bayesian model average in Eq. (1) than the single basin variational approach, which has traditionally been labelled as the Bayesian alternative.

Next, we evaluate MultiSWAG under distribution shift on the CIFAR-10 dataset (Krizhevsky et al., 2014), replicating the setup in Ovadia et al. (2019). We consider 16 data corruptions, each at 5 different levels of severity, introduced by Hendrycks & Dietterich (2019). For each corruption, we evaluate the performance of deep ensembles and MultiSWAG, varying the training budget. For deep ensembles, we show performance as a function of the number of independently trained models in the ensemble. For MultiSWAG, we show performance as a function of the number of independent SWAG approximations that we construct; we then sample 20 models from each of these approximations to construct the final ensemble.

While the training time for MultiSWAG is the same as for deep ensembles, at test time MultiSWAG is more expensive, as the corresponding ensemble consists of a larger number of models. To account for situations when test time is constrained, we also propose MultiSWA, a method that ensembles independently trained SWA solutions (Izmailov et al., 2018). SWA solutions are the means of the corresponding Gaussian SWAG approximations. Izmailov et al. (2018) argue that SWA solutions approximate the local ensembles represented by SWAG with a single model.

In Figure 5 we show the negative log-likelihood as a function of the number of independently trained models for a Preactivation ResNet-20 on CIFAR-10 corrupted with Gaussian blur of varying intensity (increasing from left to right). MultiSWAG outperforms deep ensembles significantly on highly corrupted data. For lower levels of corruption, MultiSWAG works particularly well when only a small number of independently trained models are available. We note that MultiSWA also outperforms deep ensembles, and has the same computational requirements at training and test time as deep ensembles. We present results for other types of corruption in Appendix Figures 14, 15, 16, and 17, showing similar trends. In general, there is an extensive evaluation of MultiSWAG in the Appendix.

Our perspective of generalization is deeply connected with Bayesian marginalization. In order to best realize the benefits of marginalization in deep learning, we need to consider as many hypotheses as possible, through multimodal posterior approximations such as MultiSWAG. In Section 7 we return to MultiSWAG, showing how it can entirely alleviate prominent double descent behaviour, and lead to striking improvements in generalization over SGD and single basin marginalization, for both accuracy and NLL.

Figure 4. Approximating the true predictive distribution. (a) A close approximation of the true predictive distribution, obtained by combining 200 HMC chains. (b) Deep ensembles predictive distribution, using 50 independently trained networks. (c) Predictive distribution for factorized variational inference (VI). (d) Convergence of the predictive distributions for deep ensembles and variational inference as a function of the number of samples; we measure the average Wasserstein distance between the marginals over the range of input positions. The multi-basin deep ensembles approach provides a more faithful approximation of the Bayesian predictive distribution than the conventional single-basin VI approach, which is overconfident between data clusters. The top panels show the Wasserstein distance between the true predictive distribution and the deep ensemble and VI approximations as a function of the input x.

Figure 5. Negative log likelihood for Deep Ensembles, MultiSWAG, and MultiSWA using a PreResNet-20 on CIFAR-10 with varying intensity of the Gaussian blur corruption. The image in each plot shows the intensity of corruption. For all levels of intensity, MultiSWAG and MultiSWA outperform Deep Ensembles for a small number of independent models. For high levels of corruption, MultiSWAG significantly outperforms the other methods even for many independent models. We present results for other corruptions in the Appendix.

5. Neural Network Priors

A prior over parameters p(w) combines with the functional form of a model f(x; w) to induce a distribution over functions p(f(x; w)). It is this distribution over functions that controls the generalization properties of the model; the prior over parameters, in isolation, has no meaning. Neural networks are imbued with structural properties that provide good inductive biases, such as translation equivariance, hierarchical representations, and sparsity. In the sense of Figure 2, the prior will have large support, due to the flexibility of neural networks, but its inductive biases provide the most mass to datasets which are representative of problem settings where neural networks are often applied. In this section we study the properties of the induced distribution over functions. We directly continue the discussion of priors in Section 6, with a focus on examining the noisy CIFAR results in Zhang et al. (2016) from a probabilistic perspective of generalization. These sections are best read together.

We also provide several additional experiments in the Appendix. In Section E we present analytic results on the dependence of the prior distribution in function space on the variance of the prior over parameters, considering also layer-wise parameter priors with ReLU activations. As part of a discussion on tempering, in Section 8.4 we study the effect of α in p(w) = N(0, α²I) on prior class probabilities for individual sample functions p(f(x; w)), the predictive distribution, and posterior samples, as we observe varying amounts of data. In Section F we further study the correlation structure over images induced by neural network priors, subject to perturbations of the images. In Section D.3 we provide additional experimental details.

5.1. Deep Image Prior and Random Network Features

Two recent results provide strong evidence that vague Gaussian priors over parameters, when combined with a neural network architecture, induce a distribution over functions with useful inductive biases. In the deep image prior, Ulyanov et al. (2018) show that randomly initialized convolutional neural networks, without training, provide excellent performance for image denoising, super-resolution, and inpainting. This result demonstrates the ability of a sample function from a random prior over neural networks p(f(x; w)) to capture low-level image statistics, before any training.

Figure 6. Induced prior correlation function. Average pairwise prior correlations for pairs of objects in classes 0, 1, 2, 4, 7 of MNIST, induced by LeNet-5 for p(f(x; w)) when p(w) = N(0, α²I). Images in the same class have higher prior correlations than images from different classes, suggesting that p(f(x; w)) has desirable inductive biases. The correlations slightly decrease with increases in α. (d) NLL of an ensemble of 20 SWAG samples on MNIST as a function of α, using LeNet-5.

Figure 7. Rethinking generalization. (a) Sample functions from a Gaussian process prior. (b) GP fit (with 95% credible region) to structured data generated as y_green(x) = sin(x·2π) + ε, ε ∼ N(0, 0.2²). (c) GP fit, with no training error, after a significant addition of corrupted data, in red, drawn from Uniform[0.5, 1]. (d) Variational GP marginal likelihood with RBF kernel for two classes of CIFAR-10. (e) Laplace BNN marginal likelihood for a PreResNet-20 on CIFAR-10 with different fractions of random labels. The marginal likelihood for both the GP and BNN decreases as we increase the level of corruption in the labels, suggesting reasonable inductive biases in the prior over functions. Moreover, both the GP and BNN have 100% training accuracy on images with fully corrupted labels.

Similarly, Zhang et al. (2016) show that preprocessing CIFAR-10 with a randomly initialized untrained convolutional neural network dramatically improves the test performance of a simple Gaussian kernel on pixels, from 54% accuracy to 71%. Adding ℓ2 regularization only improves the accuracy by an additional 2%. These results again indicate that broad Gaussian priors over parameters induce reasonable priors over networks, with a minor additional gain from decreasing the variance of the prior in parameter space, which corresponds to ℓ2 regularization.

5.2. Prior Class Correlations

In Figure 6 we study the prior correlations in the outputs of the LeNet-5 convolutional network (LeCun et al., 1998) on objects of different MNIST classes. We sample networks with weights p(w) = N(0, α²I), compute the values of the logits corresponding to the first class for all pairs of images, and compute the correlations of these logits. For all levels of α, the correlations between objects corresponding to the same class are consistently higher than the correlations between objects of different classes, showing that the network induces a reasonable prior similarity metric over these images. Additionally, we observe that the prior correlations somewhat decrease as we increase α, showing that bounding the norm of the weights has some minor utility, in accordance with Section 5.1. Similarly, in panel (d) we see that the NLL significantly decreases as α increases in [0, 0.5], and then slightly increases, but is relatively constant thereafter.

In the Appendix, we further describe analytic results and illustrate the effect of α on sample functions.

5.3. Effect of Prior Variance on CIFAR-10

We further study the effect of the parameter prior standard deviation α, measuring the performance of approximate Bayesian inference for CIFAR-10 with a Preactivation ResNet-20 (He et al., 2016) and VGG-16 (Simonyan & Zisserman, 2014). For each of these architectures, we run SWAG (Maddox et al., 2019) with fixed hyper-parameters and varying α. We report the results in Figure 12(d), (h). For both architectures, the performance is near-optimal in the range α ∈ [10⁻², 10⁻¹]. Smaller α constrains the weights too much. Performance is reasonable and becomes mostly insensitive to α as it continues to increase, due to the inductive biases of the functional form of the neural network.


6. Rethinking Generalization

Zhang et al. (2016) demonstrated that deep neural networks have sufficient capacity to fit randomized labels on popular image classification tasks, and suggest this result requires re-thinking generalization to understand deep learning.

We argue, however, that this behaviour is not puzzling from a probabilistic perspective, is not unique to neural networks, and cannot be used as evidence against Bayesian neural networks (BNNs) with vague parameter priors. Fundamentally, the resolution is the view presented in the introduction: from a probabilistic perspective, generalization is at least a two-dimensional concept, related to support (flexibility), which should be as large as possible, supporting even noisy solutions, and inductive biases that represent relative prior probabilities of solutions.

Indeed, we demonstrate that the behaviour in Zhang et al. (2016) that was treated as mysterious and specific to neural networks can be exactly reproduced by Gaussian processes (GPs). Gaussian processes are an ideal choice for this experiment, because they are popular Bayesian non-parametric models, and they assign a prior directly in function space. Moreover, GPs have remarkable flexibility, providing universal approximation with popular covariance functions such as the RBF kernel. Yet the functions that are a priori likely under a GP with an RBF kernel are relatively simple. We describe GPs further in the Appendix, and Rasmussen & Williams (2006) provide an extensive introduction.

We start with a simple example to illustrate the ability fora GP with an RBF kernel to easily fit a corrupted datasetyet generalize well on a non-corrupted dataset in Figure 7In Fig 7(a) we have sample functions from a GP prior overfunctions p(f(x)) showing that likely functions under theprior are smooth and well-behaved In Fig 7(b) we see theGP is able to reasonably fit data from a structured functionAnd in Fig 7(c) the GP is also able to fit highly corrupteddata with essentially no structure although these data arenot a likely draw from the prior the GP has support for awide range of solutions including noise
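To make the contrast concrete, here is a minimal NumPy sketch (not the code used for Figure 7; the lengthscale and noise settings are illustrative assumptions) of how a GP regressor with an RBF kernel can drive training error close to zero on both structured and unstructured targets, while only the former fit reflects the prior's inductive biases:

```python
import numpy as np

def rbf(x1, x2, lengthscale=0.2):
    # k(x, x') = exp(-(x - x')^2 / (2 * lengthscale^2)) for 1-d inputs
    return np.exp(-0.5 * (x1[:, None] - x2[None, :]) ** 2 / lengthscale ** 2)

def gp_fit(x_train, y_train, noise=1e-6, lengthscale=0.2):
    # posterior mean of a zero-mean GP, evaluated back at the training inputs
    K = rbf(x_train, x_train, lengthscale) + noise * np.eye(len(x_train))
    return rbf(x_train, x_train, lengthscale) @ np.linalg.solve(K, y_train)

rng = np.random.default_rng(0)
x = np.linspace(-3, 3, 30)
y_structured = np.sin(2 * x)            # data from a smooth function
y_corrupted = rng.normal(size=x.shape)  # "labels" with no structure at all

# Both fits reach very low training error; only the first generalizes sensibly.
print(np.abs(gp_fit(x, y_structured) - y_structured).max())
print(np.abs(gp_fit(x, y_corrupted) - y_corrupted).max())
```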

We next show that GPs can replicate the generalization behaviour described in Zhang et al. (2016) (experimental details in the Appendix). When applied to CIFAR-10 images with random labels, Gaussian processes achieve 100% train accuracy and 10.4% test accuracy (at the level of random guessing). However, the same model trained on the true labels achieves a training accuracy of 72.8% and a test accuracy of 54.3%. Thus the generalization behaviour described in Zhang et al. (2016) is not unique to neural networks, and can be described by separately understanding the support and the inductive biases of a model.

Indeed, although Gaussian processes support CIFAR-10 images with random labels, they are not likely under the GP prior. In Fig. 7(d), we compute the approximate GP marginal likelihood on a binary CIFAR-10 classification problem, with labels of varying levels of corruption. We see that, as the noise in the data increases, the approximate marginal likelihood, and thus the prior support for these data, decreases. In Fig. 7(e), we see a similar trend for a Bayesian neural network. Again, as the fraction of corrupted labels increases, the approximate marginal likelihood decreases, showing that the prior over functions given by the Bayesian neural network has less support for these noisy datasets. We provide further experimental details in the Appendix.

Dziugaite & Roy (2017) and Smith & Le (2018) provide complementary perspectives on Zhang et al. (2016) for MNIST. Dziugaite & Roy (2017) show non-vacuous PAC-Bayes bounds for the noise-free binary MNIST, but not noisy MNIST; and Smith & Le (2018) show that logistic regression can fit noisy labels on subsampled MNIST, interpreting the results from an Occam factor perspective.

7. Double Descent

Double descent (e.g., Belkin et al., 2019) describes generalization error that decreases, increases, and then again decreases with increases in model flexibility. The first decrease and then increase is referred to as the classical regime: models with increasing flexibility are increasingly able to capture structure and perform better, until they begin to overfit. The next regime is referred to as the modern interpolating regime. The existence of the interpolation regime has been presented as mysterious generalization behaviour in deep learning.

However, our perspective of generalization suggests that performance should monotonically improve as we increase model flexibility when we use Bayesian model averaging with a reasonable prior. Indeed, in the opening example of Figure 1, we would in principle want to use the most flexible possible model. Our results in Section 5 show that standard BNN priors induce structured and useful priors in function space, so we should not expect double descent in Bayesian deep learning models that perform reasonable marginalization.

To test this hypothesis, we evaluate MultiSWAG, SWAG and standard SGD with ResNet-18 models of varying width, following Nakkiran et al. (2019), measuring both error and negative log likelihood (NLL). For the details, see Appendix D. We present the results in Figure 8 and Appendix Figure 13.

First, we observe that models trained with SGD indeed suffer from double descent, especially when the train labels are partially corrupted (see panels 8(c), 8(d)). We also see that SWAG, a unimodal posterior approximation, reduces the extent of double descent. Moreover, MultiSWAG, which


[Figure 8: panels (a)–(e). x-axis: ResNet-18 width; y-axes: test error (%) and test NLL for SGD, SWAG, and MultiSWAG; panel (e): MultiSWAG test error (%) for 1, 3, 5, and 10 SWAG models.]

Figure 8. Bayesian model averaging alleviates double descent. (a) Test error and (b) NLL loss for ResNet-18 with varying width on CIFAR-100 for SGD, SWAG and MultiSWAG. (c) Test error and (d) NLL loss when 20% of the labels are randomly reshuffled. SWAG reduces double descent, and MultiSWAG, which marginalizes over multiple modes, entirely alleviates double descent, both on the original labels and under label noise, both in accuracy and NLL. (e) Test errors for MultiSWAG with varying number of independent SWAG models; error monotonically decreases with an increased number of independent models, alleviating double descent. We also note that MultiSWAG provides significant improvements in accuracy and NLL over SGD and SWAG models. See Appendix Figure 13 for additional results.

performs a more exhaustive multimodal Bayesian model average, completely mitigates double descent: the performance of MultiSWAG solutions increases monotonically with the size of the model, showing no double descent even under significant label corruption, for both accuracy and NLL.
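As a reference for how this model average is formed, the following is a schematic PyTorch sketch (with an assumed `swag.sample()` interface and hypothetical names, not the exact code used in our experiments) of the MultiSWAG predictive distribution: a fixed number of weight samples from each of several independently trained SWAG approximations, with the softmax outputs averaged by simple Monte Carlo.

```python
import torch

def multiswag_predict(swag_models, loader, num_samples=20):
    """swag_models: independently trained SWAG posterior approximations, each
    assumed to expose .sample(), which draws weights and returns a model."""
    probs, total = None, 0
    for swag in swag_models:                 # one Gaussian per basin of attraction
        for _ in range(num_samples):         # Monte Carlo within each basin
            model = swag.sample()
            model.eval()
            with torch.no_grad():
                batch = torch.cat([torch.softmax(model(x), dim=-1)
                                   for x, _ in loader])
            probs = batch if probs is None else probs + batch
            total += 1
    return probs / total                     # averaged predictive distribution
```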

Our results highlight the importance of marginalization over multiple modes of the posterior: under 20% label corruption SWAG clearly suffers from double descent, while MultiSWAG does not. In Figure 8(e) we show how the double descent is alleviated with an increased number of independent modes marginalized in MultiSWAG.

These results also clearly show that MultiSWAG provides significant improvements in accuracy over both SGD and SWAG models, in addition to NLL, an often overlooked advantage of Bayesian model averaging that we discuss in Section 3.1.

Recently, Nakkiran et al. (2020) show that carefully tuned ℓ2 regularization can help mitigate double descent. Alternatively, we show that Bayesian model averaging, particularly based on multimodal marginalization, can mitigate prominent double descent behaviour. The perspective in Sections 1 and 3 predicts this result: models with reasonable priors and effective Bayesian model averaging should monotonically improve with increases in flexibility.

8. Temperature Scaling

The standard Bayesian posterior distribution is given by
\[ p(w \mid \mathcal{D}) = \frac{1}{Z}\, p(\mathcal{D} \mid w)\, p(w), \tag{2} \]
where p(D | w) is a likelihood, p(w) is a prior, and Z is a normalizing constant.

In Bayesian deep learning it is typical to consider the tempered posterior
\[ p_T(w \mid \mathcal{D}) = \frac{1}{Z(T)}\, p(\mathcal{D} \mid w)^{1/T}\, p(w), \tag{3} \]
where T is a temperature parameter, and Z(T) is the normalizing constant corresponding to temperature T. The temperature parameter controls how the prior and likelihood interact in the posterior:

• T < 1 corresponds to cold posteriors, where the posterior distribution is more concentrated around solutions with high likelihood.

• T = 1 corresponds to the standard Bayesian posterior distribution.

• T > 1 corresponds to warm posteriors, where the prior effect is stronger and the posterior collapse is slower.
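For concreteness, here is a minimal sketch of the unnormalized log-density of the tempered posterior in Equation (3), as it could be passed to an MCMC or SGLD sampler; `log_likelihood` and `log_prior` are assumed callables supplied by the model.

```python
def tempered_log_posterior(w, log_likelihood, log_prior, T=1.0):
    # log p_T(w | D) = (1 / T) * log p(D | w) + log p(w) - log Z(T);
    # log Z(T) is dropped, as samplers only need the unnormalized density
    return log_likelihood(w) / T + log_prior(w)

# toy usage with scalar w and Gaussian terms: T < 1 upweights the likelihood
log_lik = lambda w: -0.5 * (w - 2.0) ** 2
log_pri = lambda w: -0.5 * w ** 2
print(tempered_log_posterior(1.0, log_lik, log_pri, T=0.5))
```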

Tempering posteriors is a well-known practice in statistics, where it goes by the names Safe Bayes, generalized Bayesian inference, and fractional Bayesian inference (e.g., de Heide et al., 2019; Grunwald et al., 2017; Barron & Cover, 1991; Walker & Hjort, 2001; Zhang, 2006; Bissiri et al., 2016; Grunwald, 2012). Safe Bayes has been shown to be natural from a variety of perspectives, including from prequential, learning theory, and minimum description length frameworks (e.g., Grunwald et al., 2017).

Concurrently with our work, Wenzel et al. (2020) noticed that successful Bayesian deep learning methods tend to use cold posteriors. They provide an empirical study showing that Bayesian neural networks (BNNs) with cold posteriors outperform models with SGD-based maximum likelihood training, while BNNs with T = 1 can perform worse than the maximum likelihood solution. They claim that cold posteriors sharply deviate from the Bayesian paradigm, and


[Figure 9: panels (a)–(e) show class probabilities (classes 0–9) for two prior sample functions and the prior predictive distribution, for α = 0.01, 0.1, 0.3, 1, and √10; panels (f) and (g) show NLL and classification error versus prior std α for PreResNet-20 and VGG-16, with the standard weight decay setting marked.]

Figure 9. Effects of the prior variance α². (a)–(e): Average class probabilities over all of CIFAR-10 for two sample prior functions p(f(x; w)) (two top rows) and the predictive distribution (average over 200 samples of weights, bottom row), for varying settings of α in p(w) = N(0, α²I). (f): NLL and (g): classification error of an ensemble of 20 SWAG samples on CIFAR-10 as a function of α, using a Preactivation ResNet-20 and VGG-16. The NLL is high for overly small α and near-optimal in the range of [0.1, 0.3]. The NLL remains relatively low for vague priors corresponding to large values of α.

[Figure 10: class probabilities (classes 0–9) for two sample functions; panels: (a) prior (α = √10), (b) 10 datapoints, (c) 100 datapoints, (d) 1000 datapoints.]

Figure 10. Adaptivity of posterior variance with data. We sample two functions f(x; w) from the distribution over functions induced by a distribution over weights, starting with the prior p(w) = N(0, 10 · I), in combination with a PreResNet-20. We measure class probabilities averaged across the CIFAR-10 test set as we vary the amount of available training data. Although the prior variance is too large, such that the softmax saturates for logits sampled from the prior, leading to one class being favoured, we see that the posterior quickly adapts to correct the scale of the logits in the presence of data. In Figure 9 we also show that the prior variance can easily be calibrated such that the prior predictive distribution, even before observing data, is high entropy.


consider possible reasons for why tempering is helpful in Bayesian deep learning.

In this section we provide an alternative view, and argue that tempering is not at odds with Bayesian principles. Moreover, for virtually any realistic model class and dataset, it would be highly surprising if T = 1 were in fact the best setting of this hyperparameter. Indeed, as long as it is practically convenient, we would advocate tempering for essentially any model, especially parametric models that do not scale their capacity automatically with the amount of available information. Our position is that, at a high level, Bayesian methods are trying to combine honest beliefs with data to form a posterior. By reflecting the belief that the model is misspecified, the tempered posterior is often more of a true posterior than the posterior that results from ignoring our belief that the model is misspecified.

Finding that T < 1 helps for Bayesian neural networks is neither surprising nor discouraging. And the actual results of the experiments in Wenzel et al. (2020), which show great improvements over standard SGD training, are in fact very encouraging of deriving inspiration from Bayesian procedures in deep learning.

We consider (1) tempering under misspecification (Section 8.1); (2) tempering in terms of overcounting data (Section 8.2); (3) how tempering compares to changing the observation model (Section 8.3); (4) the effect of the prior in relation to the experiments of Wenzel et al. (2020) (Section 8.4); and (5) the effect of approximate inference, including how tempering can help in efficiently estimating parameters even for the untempered posterior (Section 8.5).

8.1. Tempering Helps with Misspecified Models

Many works explain how tempered posteriors help under model misspecification (e.g., de Heide et al., 2019; Grunwald et al., 2017; Barron & Cover, 1991; Walker & Hjort, 2001; Zhang, 2006; Bissiri et al., 2016; Grunwald, 2012). In fact, de Heide et al. (2019) and Grunwald et al. (2017) provide several simple examples where Bayesian inference fails to provide good convergence behaviour for untempered posteriors. While it is easier to show theoretical results for T > 1, several of these works also show that T < 1 can be preferred, even in well-specified settings, and indeed recommend learning T from data, for example by cross-validation (e.g., Grunwald, 2012; de Heide et al., 2019).

Are we in a misspecified setting for Bayesian neural networks? Of course. And it would be irrational to proceed as if it were otherwise. Every model is misspecified. In the context of Bayesian neural networks specifically, the mass of solutions expressed by the prior, outside of the datasets we typically consider, is likely much larger than desired for most applications. We can calibrate for this discrepancy through tempering. The resulting tempered posterior will be more in line with our beliefs than pretending the model is not misspecified and finding the untempered posterior.

Non-parametric models, such as Gaussian processes, attempt to side-step model misspecification by growing the number of free parameters (information capacity) automatically with the amount of available data. In parametric models, we take much more of a manual guess about the model capacity. In the case of deep neural networks, this choice is not even close to a best guess: it was once the case that architectural design was a large component of works involving neural networks, but now it is more standard practice to choose an off-the-shelf architecture, without much consideration of model capacity. We do not believe that knowingly using a misspecified model to find a posterior is more reasonable (or Bayesian) than honestly reflecting the belief that the model is misspecified and then using a tempered posterior. For parametric models such as neural networks, it is to be expected that the capacity is particularly misspecified.

8.2. Overcounting Data with Cold Posteriors

The criticism of cold posteriors raised by Wenzel et al. (2020) is largely based on the fact that decreasing temperature leads to overcounting data in the posterior distribution.

However, a similar argument can be made against marginal likelihood maximization (also known as empirical Bayes or type-II maximum likelihood). Indeed, here the prior will depend on the same data as the likelihood, which can lead to miscalibrated predictive distributions (Darnieder, 2011).

Nonetheless, empirical Bayes has been embraced and widely adopted in Bayesian machine learning (e.g., Bishop, 2006; Rasmussen & Williams, 2006; MacKay, 2003; Minka, 2001b) as embodying several Bayesian principles. Empirical Bayes has been particularly embraced in seminal work on Bayesian neural networks (e.g., MacKay, 1992; 1995), where it has been proposed as a principled approach to learning hyperparameters, such as the scale of the variance for the prior over weights, automatically embodying Occam's razor. While there is in this case some deviation from the fully Bayesian paradigm, the procedure, which depends on marginalization, is nonetheless clearly inspired by Bayesian thinking, and it is thus helpful to reflect this inspiration and provide understanding of how it works from a Bayesian perspective.

There is also work showing that the marginal likelihood can lead to miscalibrated Bayes factors under model misspecification. Attempts to calibrate these factors (Xu et al., 2019) as part of the Bayesian paradigm are highly reminiscent of work on safe Bayes.


8.3. Tempered Posterior or Different Likelihood

The tempered posterior for one model is an untempered posterior using a different observation model. In other words, we can trivially change the likelihood function in our model so that the standard Bayesian posterior in the new model is equal to the posterior of temperature T in the original model. Indeed, consider the likelihood function
\[ p_T(\mathcal{D} \mid w) \propto p(\mathcal{D} \mid w)^{1/T}, \tag{4} \]
where the posterior distribution for the model M_T with likelihood p_T is exactly the temperature T posterior for the model M with likelihood p.

The predictive distribution differs for the two models: even though the posteriors coincide, the likelihoods for a new datapoint y* are different,
\[ \int p(y_* \mid w)\, p(w \mid \mathcal{D})\, dw \;\neq\; \int p_T(y_* \mid w)\, p(w \mid \mathcal{D})\, dw. \tag{5} \]

As an example, consider a regression model M with a Gaussian likelihood function, y ∼ N(f, σ²), where f is the output of the network and σ² is the noise variance. The predictive distributions for the two models M and M_T for a new input x will have different variance, but the same mean: E_w[f + ε] = E_w[f]. Moreover, in this case, the noise variance would typically be learned in either model M_T or M.
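To make this concrete, tempering a Gaussian likelihood simply rescales the noise variance, up to a constant factor, so the temperature-T treatment of M corresponds to a standard model with noise variance Tσ²:
\[
\mathcal{N}(y;\, f,\, \sigma^2)^{1/T} \;\propto\; \exp\!\left(-\frac{(y - f)^2}{2\,T\sigma^2}\right) \;\propto\; \mathcal{N}(y;\, f,\, T\sigma^2).
\]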

Section 4.1 of Grunwald et al. (2017) considers a related construction.

8.4. Effect of the Prior

While a somewhat misspecified prior will certainly interact with the utility of tempering, we do not believe the experiments in Wenzel et al. (2020) provide evidence that even the prior p(w) = N(0, I) is misspecified to any serious extent. For a relatively wide range of distributions over w, the functional form of the network f(x; w) can produce a generally reasonable distribution over functions p(f(x; w)). In Figure 10, we reproduce the findings in Wenzel et al. (2020) that show that sample functions p(f(x; w)) corresponding to the prior p(w) = N(0, 10 · I) strongly favour a single class over the dataset. While this behaviour appears superficially dramatic, we note it is simply an artifact of having a miscalibrated signal variance. A miscalibrated signal variance interacts with a quickly saturating softmax link function to provide a seemingly dramatic preference to a given class. If we instead use p(w) = N(0, α²I), for quite a range of α, then sample functions provide reasonably high entropy across labels, as in Figure 9. The value of α can be easily determined through cross-validation, as in Figure 9, or specified as a standard value used for L2 regularization (α = 0.24 in this case).

However, even with the inappropriate prior scale, we see in the bottom row of panels (a)–(e) of Figure 9 that the unconditional predictive distribution is completely reasonable. Moreover, the prior variance represents a soft prior bias, and will quickly update in the presence of data. In Figure 10 we show the posterior samples after observing 10, 100, and 1000 data points.

Other aspects of the prior, outside of the prior signal variance, will have a much greater effect on the inductive biases of the model. For example, the induced covariance function cov(f(x_i; w), f(x_j; w)) reflects the induced similarity metric over data instances; through the covariance function we can answer, for instance, whether the model believes a priori that a translated image is similar to the original. Unlike the signal variance of the prior, the prior covariance function will continue to have a significant effect on posterior inference for even very large datasets, and strongly reflects the structural properties of the neural network. We explore these structures of the prior in Figure 12.

8.5. The Effect of Inexact Inference

We have to keep in mind what we ultimately use posterior samples to compute. Ultimately, we wish to estimate the predictive distribution given by the integral in Equation (1). With a finite number of samples, the tempered posterior could be used to provide a better approximation to the expectation of the predictive distribution associated with the untempered posterior.

Consider a simple example, where we wish to estimate the mean of a high-dimensional Gaussian distribution N(0, I). Suppose we use J independent samples. The mean of these samples is also Gaussian distributed, μ ∼ N(0, (1/J) I). In Bayesian deep learning the dimension d is typically on the order of 10⁷, and J would be on the order of 10. The norm of μ would be highly concentrated around √(10⁷)/√10 = 1000. In this case, sampling from a tempered posterior with T < 1 would lead to a better approximation of the Bayesian model average associated with an untempered posterior.
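A quick NumPy check of this concentration argument (with the dimensions quoted above) streams the J samples to keep memory modest:

```python
import numpy as np

rng = np.random.default_rng(0)
d, J = 10**7, 10
mean = np.zeros(d)
for _ in range(J):                 # accumulate the sample mean one draw at a time
    mean += rng.standard_normal(d) / J
# the norm of the sample mean concentrates around sqrt(d / J) = 1000
print(np.linalg.norm(mean), np.sqrt(d / J))
```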

Furthermore, no current sampling procedure will be providing samples that are close to independent samples from the true posterior of a Bayesian neural network. The posterior landscape is far too multimodal and complex for there to be any reasonable coverage. The approximations we have are practically useful, and often preferable to conventional training, but we cannot realistically proceed with analysis assuming that we have obtained true samples from a posterior. While we would expect that some value of T ≠ 1 would be preferred for any finite dataset in practice, it is conceivable that some of the results in Wenzel et al. (2020) may be affected by the specifics of the approximate inference technique being used.


We should be wary not to view Bayesian model averaging purely through the prism of simple Monte Carlo, as advised in Section 3.2. Given a finite computational budget, our goal in effectively approximating a Bayesian model average is not equivalent to obtaining good samples from the posterior.

9. Discussion

“It is now common practice for Bayesians to fit models that have more parameters than the number of data points. Incorporate every imaginable possibility into the model space: for example, if it is conceivable that a very simple model might be able to explain the data, one should include simple models; if the noise might have a long-tailed distribution, one should include a hyperparameter which controls the heaviness of the tails of the distribution; if an input variable might be irrelevant to a regression, include it in the regression anyway.” MacKay (1995)

We have presented a probabilistic perspective of generalization which depends on the support and inductive biases of the model. The support should be as large as possible, but the inductive biases must be well-calibrated to a given problem class. We argue that Bayesian neural networks embody these properties, and, through the lens of probabilistic inference, explain generalization behaviour that has previously been viewed as mysterious. Moreover, we argue that Bayesian marginalization is particularly compelling for neural networks, show how deep ensembles provide a practical mechanism for marginalization, and propose a new approach that generalizes deep ensembles to marginalize within basins of attraction. We show that this multimodal approach to Bayesian model averaging, MultiSWAG, can entirely alleviate double descent, to enable monotonic performance improvements with increases in model flexibility, as well as significant improvements in generalization accuracy and log likelihood over SGD and single basin marginalization.

There are certainly many challenges to estimating the integral for a Bayesian model average in modern deep learning, including a high-dimensional parameter space and a complex posterior landscape. But viewing the challenge as an integration problem, rather than an attempt to obtain posterior samples for a simple Monte Carlo approximation, provides opportunities for future progress. Bayesian deep learning has been making fast practical advances, with approaches that now enable better accuracy and calibration over standard training, with minimal overhead.

We finish with remarks about future developments for Bayesian neural network priors, and approaches to research in Bayesian deep learning.

9.1. The Future for BNN Priors

We provide some brief remarks about future developments for BNN priors. Here we have explored relatively simple parameter priors p(w) = N(0, α²I). While these priors are simple in parameter space, they interact with the neural network architecture to induce a sophisticated prior over functions p(f(x; w)), with many desirable properties, including a reasonable correlation structure over images. However, these parameter priors can certainly still be improved. As we have seen, even tuning the value of the signal variance α², an analogue of the L2 regularization often used in deep learning, can have a noticeable effect on the induced prior over functions, though this effect is quickly modulated by data. Layer-wise priors, such that parameters in each layer have a different signal variance, are intuitive: we would expect later layers to require precise determination, while parameters in earlier layers could reasonably take a range of values. But one has to be cautious; as we show in Appendix Section E, with ReLU activations, different signal variances in different layers can be degenerate, combining together to affect only the output scale of the network.
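This degeneracy is easy to verify numerically; a small PyTorch check (a bias-free two-layer network, not the architecture from our experiments) shows that rescaling the first layer by c and the second by 1/c leaves the function unchanged:

```python
import torch

torch.manual_seed(0)
W1, W2 = torch.randn(50, 10), torch.randn(1, 50)
x = torch.randn(10)
c = 3.0
out = W2 @ torch.relu(W1 @ x)
out_rescaled = (W2 / c) @ torch.relu(c * W1 @ x)  # relu(c * z) = c * relu(z) for c > 0
print(torch.allclose(out, out_rescaled, atol=1e-5))  # True: only the product of scales matters
```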

A currently popular sentiment is that we should directly build function-space BNN priors, often taking inspiration from Gaussian processes. While we believe this is a promising direction, one should proceed with caution. If we contrive priors over parameters p(w) to induce distributions over functions p(f) that resemble familiar models, such as Gaussian processes with RBF kernels, we could be throwing the baby out with the bathwater. Neural networks are useful as their own model class precisely because they have different inductive biases from other models.

A similar concern applies to taking infinite width limits in Bayesian neural networks. In these cases we recover Gaussian processes with interpretable kernel functions; because these models are easier to use and analyze, and give rise to interpretable and well-motivated priors, it is tempting to treat them as drop-in replacements for the parametric analogues. However, the kernels for these models are fixed. In order for a model to do effective representation learning, we must learn a similarity metric for the data. Training a neural network in many ways is like learning a kernel, rather than using a fixed kernel. MacKay (1998) has also expressed concerns in treating these limits as replacements for neural networks, due to the loss of representation learning power.

Perhaps the distribution over functions induced by a network in combination with a generic distribution over parameters p(w) may be hard to interpret, but this distribution will contain the equivariance properties, representation learning abilities, and other biases that make neural networks a compelling model class in their own right.


9.2. “But is it really Bayesian?”

We finish with an editorial comment about approaches to research within Bayesian deep learning. There is sometimes a tendency to classify work as Bayesian or not Bayesian, with very stringent criteria for what qualifies as Bayesian. Moreover, the implication, and sometimes even explicit recommendation, is that if an approach is not unequivocally Bayesian in every respect, then we should not term it as Bayesian, and we should instead attempt to understand the procedure through entirely different non-Bayesian mechanisms. We believe this mentality encourages tribalism, which is not conducive to the best research or creating the best performing methods. What matters is not a debate about semantics, but making rational modelling choices given a particular problem setting, and trying to understand these choices. Often these choices can largely be inspired by a Bayesian approach, in which case it is desirable to indicate this source of inspiration. And in the semantics debate, who would be the arbiter of what gets to be called Bayesian? Arguably it ought to be an evolving definition.

Broadly speaking, what makes Bayesian approaches distinctive is a posterior weighted marginalization over parameters. And at a high level, Bayesian methods are about combining our honest beliefs with data to form a posterior. In actuality, no fair-minded researcher entirely believes the prior over parameters, the functional form of the model (which is part of the prior over functions), or the likelihood. From this perspective, it is broadly compatible with a Bayesian philosophy to reflect misspecification in the modelling procedure itself, which is achieved through tempering. In this sense, the tempered posterior is more reflective of a true posterior than the posterior that results from ignoring our belief that the model is misspecified.

Moreover, basic probability theory indicates that marginalization is desirable. While marginalization cannot in practice be achieved exactly, we can try to improve over conventional training, which as we have discussed can be viewed as approximate marginalization. Given computational constraints, effective marginalization is not equivalent to obtaining accurate samples from a posterior. As we have discussed, simple Monte Carlo is only one of many mechanisms for marginalization. Just as expectation propagation (Minka, 2001a) focuses its approximation on the factors in a posterior where it will most affect the end result, we should focus on representing the posterior where it will make the biggest difference to the model average. As we have shown, deep ensembles are a reasonable mechanism up to a point. After having trained many independent models, there are added benefits to marginalizing within basins, given the computational expense associated with retraining an additional model to find an additional basin of attraction.

We should also not hold Bayesian methods to a double standard. Indeed, it can be hard to interpret or understand the prior, the posterior, and whether the marginalization procedure is optimal. But it is also hard to interpret the choices behind the functional form of the model, or the rationale behind classical procedures where we bet everything on a single global optimum, when we know there are many global optima, and many of them will perform well but provide different solutions, and many others will not perform well. We should apply the same level of scrutiny to all modelling choices, consider the alternatives, and not be paralyzed if a procedure is not optimal in every respect.

Acknowledgements. AGW and PI are supported by an Amazon Research Award, Facebook Research, NSF I-DISRE 193471, NIH R01 DA048764-01A1, NSF IIS-1563887, and NSF IIS-1910266. We thank Greg Benton for helpful discussions.

References

Achille, A. and Soatto, S. Emergence of invariance and disentanglement in deep representations. The Journal of Machine Learning Research, 19(1):1947–1980, 2018.

Ashukha, A., Lyzhov, A., Molchanov, D., and Vetrov, D. Pitfalls of in-domain uncertainty estimation and ensembling in deep learning. arXiv preprint arXiv:2002.06470, 2020.

Barron, A. R. and Cover, T. M. Minimum complexity density estimation. IEEE Transactions on Information Theory, 37(4):1034–1054, 1991.

Beal, M. J. Variational algorithms for approximate Bayesian inference. PhD thesis, University of London, 2003.

Belkin, M., Hsu, D., Ma, S., and Mandal, S. Reconciling modern machine-learning practice and the classical bias–variance trade-off. Proceedings of the National Academy of Sciences, 116(32):15849–15854, 2019.

Bishop, C. M. Pattern Recognition and Machine Learning. Springer, 2006.

Bissiri, P. G., Holmes, C. C., and Walker, S. G. A general framework for updating belief distributions. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 78(5):1103–1130, 2016.

Blundell, C., Cornebise, J., Kavukcuoglu, K., and Wierstra, D. Weight uncertainty in neural networks. arXiv preprint arXiv:1505.05424, 2015.

Box, G. E. and Tiao, G. C. Bayesian Inference in Statistical Analysis. Addison-Wesley, Reading, MA, 1973.


Chaudhari, P., Choromanska, A., Soatto, S., LeCun, Y., Baldassi, C., Borgs, C., Chayes, J., Sagun, L., and Zecchina, R. Entropy-SGD: Biasing gradient descent into wide valleys. Journal of Statistical Mechanics: Theory and Experiment, 2019.

Cobb, A. D., Baydin, A. G., Markham, A., and Roberts, S. J. Introducing an explicit symplectic integration scheme for Riemannian manifold Hamiltonian Monte Carlo. arXiv preprint arXiv:1910.06243, 2019.

Darnieder, W. F. Bayesian methods for data-dependent priors. PhD thesis, The Ohio State University, 2011.

de Heide, R., Kirichenko, A., Mehta, N., and Grunwald, P. Safe-Bayesian generalized linear regression. arXiv preprint arXiv:1910.09227, 2019.

Dinh, L., Pascanu, R., Bengio, S., and Bengio, Y. Sharp minima can generalize for deep nets. arXiv preprint arXiv:1703.04933, 2017.

Draxler, F., Veschgini, K., Salmhofer, M., and Hamprecht, F. A. Essentially no barriers in neural network energy landscape. arXiv preprint arXiv:1803.00885, 2018.

Dziugaite, G. K. and Roy, D. M. Computing nonvacuous generalization bounds for deep (stochastic) neural networks with many more parameters than training data. arXiv preprint arXiv:1703.11008, 2017.

Fort, S. and Jastrzebski, S. Large scale structure of neural network loss landscapes. In Advances in Neural Information Processing Systems, pp. 6706–6714, 2019.

Fort, S., Hu, H., and Lakshminarayanan, B. Deep ensembles: A loss landscape perspective. arXiv preprint arXiv:1912.02757, 2019.

Gal, Y. and Ghahramani, Z. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. In International Conference on Machine Learning, pp. 1050–1059, 2016.

Gardner, J. R., Pleiss, G., Bindel, D., Weinberger, K. Q., and Wilson, A. G. GPyTorch: Blackbox matrix-matrix Gaussian process inference with GPU acceleration. In Neural Information Processing Systems, 2018.

Garipov, T., Izmailov, P., Podoprikhin, D., Vetrov, D. P., and Wilson, A. G. Loss surfaces, mode connectivity, and fast ensembling of DNNs. In Neural Information Processing Systems, 2018.

Grunwald, P. The safe Bayesian. In International Conference on Algorithmic Learning Theory, pp. 169–183. Springer, 2012.

Grunwald, P., Van Ommen, T., et al. Inconsistency of Bayesian inference for misspecified linear models, and a proposal for repairing it. Bayesian Analysis, 12(4):1069–1103, 2017.

Guedj, B. A primer on PAC-Bayesian learning. arXiv preprint arXiv:1901.05353, 2019.

Guo, C., Pleiss, G., Sun, Y., and Weinberger, K. Q. On calibration of modern neural networks. In Proceedings of the 34th International Conference on Machine Learning - Volume 70, pp. 1321–1330. JMLR.org, 2017.

Gustafsson, F. K., Danelljan, M., and Schon, T. B. Evaluating scalable Bayesian deep learning methods for robust computer vision. arXiv preprint arXiv:1906.01620, 2019.

Hafner, D., Tran, D., Irpan, A., Lillicrap, T., and Davidson, J. Reliable uncertainty estimates in deep neural networks using noise contrastive priors. arXiv preprint arXiv:1807.09289, 2018.

He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778, 2016.

Hendrycks, D. and Dietterich, T. Benchmarking neural network robustness to common corruptions and perturbations. arXiv preprint arXiv:1903.12261, 2019.

Hinton, G. E. and Van Camp, D. Keeping the neural networks simple by minimizing the description length of the weights. In Proceedings of the Sixth Annual Conference on Computational Learning Theory, pp. 5–13, 1993.

Hochreiter, S. and Schmidhuber, J. Flat minima. Neural Computation, 9(1):1–42, 1997.

Huang, W. R., Emam, Z., Goldblum, M., Fowl, L., Terry, J. K., Huang, F., and Goldstein, T. Understanding generalization through visualizations. arXiv preprint arXiv:1906.03291, 2019.

Izmailov, P., Podoprikhin, D., Garipov, T., Vetrov, D., and Wilson, A. G. Averaging weights leads to wider optima and better generalization. In Uncertainty in Artificial Intelligence (UAI), 2018.

Izmailov, P., Maddox, W. J., Kirichenko, P., Garipov, T., Vetrov, D., and Wilson, A. G. Subspace inference for Bayesian deep learning. Uncertainty in Artificial Intelligence, 2019.

Jacot, A., Gabriel, F., and Hongler, C. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in Neural Information Processing Systems, pp. 8571–8580, 2018.


Jiang, Y., Neyshabur, B., Mobahi, H., Krishnan, D., and Bengio, S. Fantastic generalization measures and where to find them. arXiv preprint arXiv:1912.02178, 2019.

Kass, R. E. and Raftery, A. E. Bayes factors. Journal of the American Statistical Association, 90(430):773–795, 1995.

Kendall, A. and Gal, Y. What uncertainties do we need in Bayesian deep learning for computer vision? In Advances in Neural Information Processing Systems, pp. 5574–5584, 2017.

Keskar, N. S., Mudigere, D., Nocedal, J., Smelyanskiy, M., and Tang, P. T. P. On large-batch training for deep learning: Generalization gap and sharp minima. arXiv preprint arXiv:1609.04836, 2016.

Khan, M. E., Nielsen, D., Tangkaratt, V., Lin, W., Gal, Y., and Srivastava, A. Fast and scalable Bayesian deep learning by weight-perturbation in Adam. arXiv preprint arXiv:1806.04854, 2018.

Krizhevsky, A., Nair, V., and Hinton, G. The CIFAR-10 dataset, 2014. http://www.cs.toronto.edu/~kriz/cifar.html.

Lakshminarayanan, B., Pritzel, A., and Blundell, C. Simple and scalable predictive uncertainty estimation using deep ensembles. In Advances in Neural Information Processing Systems, pp. 6402–6413, 2017.

Langford, J. and Caruana, R. (Not) bounding the true error. In Advances in Neural Information Processing Systems, pp. 809–816, 2002.

LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.

Louizos, C., Shi, X., Schutte, K., and Welling, M. The functional neural process. In Advances in Neural Information Processing Systems, 2019.

MacKay, D. J. Bayesian methods for adaptive models. PhD thesis, California Institute of Technology, 1992.

MacKay, D. J. Probable networks and plausible predictions – a review of practical Bayesian methods for supervised neural networks. Network: Computation in Neural Systems, 6(3):469–505, 1995.

MacKay, D. J. Introduction to Gaussian processes. In Bishop, C. M. (ed.), Neural Networks and Machine Learning, chapter 11, pp. 133–165. Springer-Verlag, 1998.

MacKay, D. J. Information Theory, Inference and Learning Algorithms. Cambridge University Press, 2003.

Maddox, W. J., Izmailov, P., Garipov, T., Vetrov, D. P., and Wilson, A. G. A simple baseline for Bayesian uncertainty in deep learning. In Advances in Neural Information Processing Systems, 2019.

Masegosa, A. R. Learning under model misspecification: Applications to variational and ensemble methods, 2019.

McAllester, D. A. PAC-Bayesian model averaging. In Proceedings of the Twelfth Annual Conference on Computational Learning Theory, pp. 164–170, 1999.

Minka, T. Expectation propagation for approximate Bayesian inference. In Uncertainty in Artificial Intelligence, volume 17, pp. 362–369, 2001a.

Minka, T. P. Bayesian model averaging is not model combination. 2000.

Minka, T. P. Automatic choice of dimensionality for PCA. In Advances in Neural Information Processing Systems, pp. 598–604, 2001b.

Mohri, M. and Rostamizadeh, A. Rademacher complexity bounds for non-i.i.d. processes. In Advances in Neural Information Processing Systems, pp. 1097–1104, 2009.

Nakkiran, P., Kaplun, G., Bansal, Y., Yang, T., Barak, B., and Sutskever, I. Deep double descent: Where bigger models and more data hurt. arXiv preprint arXiv:1912.02292, 2019.

Nakkiran, P., Venkat, P., Kakade, S., and Ma, T. Optimal regularization can mitigate double descent. arXiv preprint arXiv:2003.01897, 2020.

Nalisnick, E. On priors for Bayesian neural networks. PhD thesis, University of California, Irvine, 2018.

Neal, R. Bayesian Learning for Neural Networks. Springer Verlag, 1996. ISBN 0387947248.

Neyshabur, B., Bhojanapalli, S., McAllester, D., and Srebro, N. Exploring generalization in deep learning. In Advances in Neural Information Processing Systems, pp. 5947–5956, 2017.

Neyshabur, B., Bhojanapalli, S., and Srebro, N. A PAC-Bayesian approach to spectrally-normalized margin bounds for neural networks. In International Conference on Learning Representations, 2018. URL https://openreview.net/forum?id=Skz_WfbCZ.

Opper, M., Kinzel, W., Kleinz, J., and Nehl, R. On the ability of the optimal perceptron to generalise. Journal of Physics A: Mathematical and General, 23(11):L581, 1990.


Ovadia, Y., Fertig, E., Ren, J., Nado, Z., Sculley, D., Nowozin, S., Dillon, J. V., Lakshminarayanan, B., and Snoek, J. Can you trust your model's uncertainty? Evaluating predictive uncertainty under dataset shift. arXiv preprint arXiv:1906.02530, 2019.

Pearce, T., Zaki, M., Brintrup, A., Anastassacos, N., and Neely, A. Uncertainty in neural networks: Bayesian ensembling. arXiv preprint arXiv:1810.05546, 2018.

Rasmussen, C. E. and Ghahramani, Z. Occam's razor. In Neural Information Processing Systems (NIPS), 2001.

Rasmussen, C. E. and Williams, C. K. I. Gaussian Processes for Machine Learning. The MIT Press, 2006.

Ritter, H., Botev, A., and Barber, D. A scalable Laplace approximation for neural networks. In International Conference on Learning Representations (ICLR), 2018.

Rue, H., Martino, S., and Chopin, N. Approximate Bayesian inference for latent Gaussian models by using integrated nested Laplace approximations. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 71(2):319–392, 2009.

Simonyan, K. and Zisserman, A. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.

Smith, S. L. and Le, Q. V. A Bayesian perspective on generalization and stochastic gradient descent. In International Conference on Learning Representations, 2018.

Sun, S., Zhang, G., Shi, J., and Grosse, R. Functional variational Bayesian neural networks. arXiv preprint arXiv:1903.05779, 2019.

Ulyanov, D., Vedaldi, A., and Lempitsky, V. Deep image prior. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 9446–9454, 2018.

Vapnik, V. N. Adaptive and learning systems for signal processing, communications, and control. Statistical Learning Theory, 1998.

Virtanen, P., Gommers, R., Oliphant, T. E., Haberland, M., Reddy, T., Cournapeau, D., Burovski, E., Peterson, P., Weckesser, W., Bright, J., et al. SciPy 1.0: Fundamental algorithms for scientific computing in Python. Nature Methods, pp. 1–12, 2020.

Walker, S. and Hjort, N. L. On Bayesian consistency. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 63(4):811–821, 2001.

Welling, M. and Teh, Y. W. Bayesian learning via stochastic gradient Langevin dynamics. In Proceedings of the 28th International Conference on Machine Learning (ICML-11), pp. 681–688, 2011.

Wenzel, F., Roth, K., Veeling, B. S., Swiatkowski, J., Tran, L., Mandt, S., Snoek, J., Salimans, T., Jenatton, R., and Nowozin, S. How good is the Bayes posterior in deep neural networks really? arXiv preprint arXiv:2002.02405, 2020.

Wilson, A. G. The case for Bayesian deep learning. arXiv preprint arXiv:2001.10995, 2020.

Xu, X., Lu, P., MacEachern, S., and Xu, R. Calibrated Bayes factors for model comparison. Journal of Statistical Computation and Simulation, 89(4):591–614, 2019.

Yang, W., Lorch, L., Graule, M. A., Srinivasan, S., Suresh, A., Yao, J., Pradier, M. F., and Doshi-Velez, F. Output-constrained Bayesian neural networks. arXiv preprint arXiv:1905.06287, 2019.

Zhang, C., Bengio, S., Hardt, M., Recht, B., and Vinyals, O. Understanding deep learning requires rethinking generalization. arXiv preprint arXiv:1611.03530, 2016.

Zhang, T. Information-theoretic upper and lower bounds for statistical estimation. IEEE Transactions on Information Theory, 52(4):1307–1321, 2006.


Appendix Outline

This appendix is organized as follows. In Section A, we visualize predictive functions corresponding to weight samples within high posterior density valleys on a regression problem. In Section B, we provide background material on Gaussian processes. In Section C, we present further results comparing MultiSWAG and MultiSWA to Deep Ensembles under data distribution shift on CIFAR-10. In Section D, we provide the details of all experiments presented in the paper. In Section E, we present analytic results on the dependence of the prior distribution in function space on the variance of the prior over parameters. In Section F, we study the prior correlations between BNN logits on perturbed images.

A. Loss Valleys

We demonstrate that different points along the valleys of high posterior density (low loss) connecting pairs of independently trained optima (Garipov et al., 2018; Draxler et al., 2018; Fort & Jastrzebski, 2019) correspond to different predictive functions. We use the regression example from Izmailov et al. (2019) and show the results in Figure 11.

B. Gaussian Processes

With a Bayesian neural network, a distribution over parameters p(w) induces a distribution over functions p(f(x; w)) when combined with the functional form of the network. Gaussian processes (GPs) are often used to instead directly specify a distribution over functions.

A Gaussian process is a distribution over functions f(x) ∼ GP(m, k) such that any collection of function values, queried at any finite set of inputs x₁, ..., xₙ, has a joint Gaussian distribution:
\[ f(x_1), \dots, f(x_n) \sim \mathcal{N}(\mu, K). \tag{6} \]
The mean vector, μ_i = E[f(x_i)] = m(x_i), and covariance matrix, K_{ij} = cov(f(x_i), f(x_j)) = k(x_i, x_j), are determined by the mean function m and covariance function (or kernel) k of the Gaussian process.

The popular RBF kernel has the form
\[ k(x_i, x_j) = \exp\!\left( -\frac{1}{2\ell^2} \lVert x_i - x_j \rVert^2 \right). \tag{7} \]
The length-scale hyperparameter ℓ controls the extent of correlations between function values. If ℓ is large, sample functions from a GP prior are simple and slowly varying with inputs x.
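A minimal NumPy sketch of Equations (6)–(7), drawing sample functions from a zero-mean GP prior with an RBF kernel (the lengthscales and jitter are illustrative choices):

```python
import numpy as np

def rbf_kernel(x1, x2, lengthscale):
    return np.exp(-0.5 * (x1[:, None] - x2[None, :]) ** 2 / lengthscale ** 2)

x = np.linspace(-5, 5, 200)
rng = np.random.default_rng(0)
for lengthscale in (0.3, 2.0):
    K = rbf_kernel(x, x, lengthscale) + 1e-6 * np.eye(len(x))  # jitter for stability
    samples = rng.multivariate_normal(np.zeros(len(x)), K, size=3)
    # larger lengthscales give visibly smoother, more slowly varying samples
    print(lengthscale, np.abs(np.diff(samples, axis=1)).mean())
```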

Gaussian processes with RBF kernels (as well as many other standard kernels) assign positive density to any set of observations. Moreover, these models are universal approximators (Rasmussen & Williams, 2006): as the number of observations increases, they are able to approximate any function to arbitrary precision.

Work on Gaussian processes in machine learning was triggered by the observation that Bayesian neural networks become Gaussian processes with particular kernel functions as the number of hidden units approaches infinity (Neal, 1996). This result resembles recent work on the neural tangent kernel (e.g., Jacot et al., 2018).

C. Deep Ensembles and MultiSWAG Under Distribution Shift

In Figures 14, 15, 16, 17 we show the negative log-likelihood for Deep Ensembles, MultiSWA and MultiSWAG using PreResNet-20 on CIFAR-10 with various corruptions, as a function of the number of independently trained models (SGD solutions, SWA solutions or SWAG models, respectively). For MultiSWAG, we generate 20 samples from each independent SWAG model. Typically MultiSWA and MultiSWAG significantly outperform Deep Ensembles when a small number of independent models is used, or when the level of corruption is high.

In Figure 18, following Ovadia et al. (2019), we show the distribution of negative log likelihood, accuracy, and expected calibration error as we vary the type of corruption. We use a fixed training time budget: 10 independently trained models for every method. For MultiSWAG we ensemble 20 samples from each of the 10 SWAG approximations. MultiSWAG particularly achieves better NLL than the other two methods, and MultiSWA outperforms Deep Ensembles; the difference is especially pronounced for higher levels of corruption. In terms of ECE, MultiSWAG again outperforms the other two methods for higher corruption intensities.

We note that Ovadia et al. (2019) found Deep Ensembles to be a very strong baseline for prediction quality and calibration under distribution shift. For this reason, we focus on Deep Ensembles in our comparisons.

D. Details of Experiments

In this section we provide additional details of the experiments presented in the paper.

D.1. Approximating the True Predictive Distribution

For the results presented in Figure 4, we used a network with 3 hidden layers of size 10 each. The network takes two inputs: x and x². We pass both x and x² as input to ensure


[Figure 11: four columns, each showing a predictive distribution on the regression problem (top) and the posterior log-density in a two-dimensional parameter subspace (bottom).]

Figure 11. Diversity of high performing functions. Bottom: a contour plot of the posterior log-density in the subspace containing a pair of independently trained modes (as with deep ensembles), and a path of high posterior density connecting these modes. In each panel, the purple point represents a sample from the posterior in the parameter subspace. Top: the predictive distribution constructed from samples in the subspace. The shaded blue area shows the 3σ-region of the predictive distribution at each of the input locations, and the blue line shows the mean of the predictive distribution. In each panel, the purple line shows the predictive function corresponding to the sample shown in the corresponding bottom row panel. For the details of the experimental setup, see Section 5.1 of Izmailov et al. (2019).

that the network can represent a broader class of functions. The network outputs a single number y = f(x).

To generate data for the plots, we used a randomly-initialized neural network of the same architecture described above. We sampled the weights from an isotropic Gaussian with variance 0.1², and added isotropic Gaussian noise with variance 0.1² to the outputs:
\[ y = f(x; w) + \epsilon(x), \]
with w ∼ N(0, 0.1² · I), ε(x) ∼ N(0, 0.1² · I). The training set consists of 120 points, shown in Figure 4.
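A sketch of this data-generating process is given below; the activation function and input range are our assumptions for illustration, as they are not specified above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(                      # 3 hidden layers of width 10, inputs (x, x^2)
    nn.Linear(2, 10), nn.Tanh(),
    nn.Linear(10, 10), nn.Tanh(),
    nn.Linear(10, 10), nn.Tanh(),
    nn.Linear(10, 1),
)
with torch.no_grad():
    for p in net.parameters():
        p.normal_(0.0, 0.1)               # w ~ N(0, 0.1^2 I)
    x = torch.linspace(-2, 2, 120).unsqueeze(-1)   # assumed input range
    inputs = torch.cat([x, x ** 2], dim=-1)
    y = net(inputs) + 0.1 * torch.randn_like(x)    # eps ~ N(0, 0.1^2)
```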

For estimating the ground truth, we ran 200 chains of Hamiltonian Monte Carlo (HMC) using the hamiltorch package (Cobb et al., 2019). We initialized each chain with a network pre-trained with SGD for 3000 steps, then ran Hamiltonian Monte Carlo (HMC) for 2000 steps, producing 200 samples.

For Deep Ensembles, we independently trained 50 networks with SGD for 20000 steps each. We used the negative posterior log-density as the training loss. For SVI, we used a fully-factorized Gaussian approximation initialized at an SGD solution trained for 20000 steps. For all inference methods we set the prior variance to 10² and the noise variance to 0.02².

Discrepancy with true BMA. For the results presented in panel (d) of Figure 4, we computed the Wasserstein distance between the predictive distribution approximated with HMC and the predictive distributions for Deep Ensembles and SVI. We used the one-dimensional Wasserstein distance function¹

from the scipy package (Virtanen et al., 2020). We computed the Wasserstein distance between the marginal distributions at each input location, and averaged the results over the input locations. In the top sub-panels of panels (b), (c) of Figure 4, we additionally visualize the marginal Wasserstein distance between the HMC predictive distribution and the Deep Ensembles and SVI predictive distributions, respectively, for each input location.

¹https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wasserstein_distance.html
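A short sketch of this averaged marginal Wasserstein discrepancy; `hmc_samples` and `approx_samples` are assumed arrays of predictive samples with shape (num_samples, num_input_locations):

```python
import numpy as np
from scipy.stats import wasserstein_distance

def mean_marginal_wasserstein(hmc_samples, approx_samples):
    num_inputs = hmc_samples.shape[1]
    distances = [wasserstein_distance(hmc_samples[:, i], approx_samples[:, i])
                 for i in range(num_inputs)]        # 1-d distance per input location
    return float(np.mean(distances))                # averaged over input locations
```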

D.2. Deep Ensembles and MultiSWAG

We evaluate Deep Ensembles, MultiSWA and MultiSWAG under distribution shift in Section 4. Following Ovadia et al. (2019), we use a PreResNet-20 network and the CIFAR-10 dataset with different types of corruptions introduced in Hendrycks & Dietterich (2019). For training individual SGD, SWA and SWAG models we use the hyper-parameters used for PreResNet-164 in Maddox et al. (2019). For each SWAG model we sample 20 networks and ensemble them. So Deep Ensembles, MultiSWA and MultiSWAG are all evaluated under the same training budget. Deep Ensembles and MultiSWA also use the same test-time budget.

For producing the corrupted data, we used the code² released by Hendrycks & Dietterich (2019). We had issues producing the data for the frost corruption type, so we omit it in our evaluation, and include Gaussian blur, which was not included in the evaluation of Hendrycks & Dietterich (2019).

²https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_cifar_c.py


[Figure 12: panels (a)–(c), average pairwise prior correlation matrices over MNIST classes 0, 1, 2, 4, 7 for α = 0.02, 0.1, 1; panel (d), NLL versus prior std α; panels (e)–(g), sample functions along the direction connecting two MNIST images for α = 0.02, 0.1, 1; panel (h), classification error versus prior std α.]

Figure 12. (a)–(c): Average pairwise prior correlations for pairs of objects in classes {0, 1, 2, 4, 7} of MNIST induced by LeNet-5 for p(f(x; w)), when p(w) = N(0, α²I). Images in the same class have higher prior correlations than images from different classes, suggesting that p(f(x; w)) has desirable inductive biases. The correlations slightly decrease with increases in α. Panels (e)–(g) show sample functions from LeNet-5 along the direction connecting a pair of MNIST images of 0 and 1 digits. The complexity of the samples increases with α. (d): NLL and (h): classification error of an ensemble of 20 SWAG samples on MNIST as a function of α, using a LeNet-5. The NLL is high for overly small α and near-optimal for larger values, with an optimum near α = 0.3.

D.3. Neural Network Priors

In the main text, we considered different properties of the prior distribution over functions induced by a spherical Gaussian distribution over the weights, with different variance scales.

Prior correlation diagrams. In panels (a)–(c) of Figure 12, we show pairwise correlations of the logits for different pairs of datapoints. To make these plots, we produce S = 100 samples of the weights wᵢ of a LeNet-5 from the prior distribution N(0, α²I), and compute the logits corresponding to class 0 for each data point and each weight sample. We then compute the correlations for each pair x, x′ of data points as follows:

\[ \mathrm{corr}_{\mathrm{logit}}(x, x') = \frac{\sum_{i=1}^{S} \big(f(x, w_i) - \bar f(x)\big)\big(f(x', w_i) - \bar f(x')\big)}{\sqrt{\sum_{i=1}^{S} \big(f(x, w_i) - \bar f(x)\big)^2 \cdot \sum_{i=1}^{S} \big(f(x', w_i) - \bar f(x')\big)^2}}, \]
where f(x, w) is the logit corresponding to class 0 of the network with weights w on the input x, and \bar f(x) is the mean value of the logit, \bar f(x) = (1/S) Σᵢ f(x, wᵢ). For evaluation, we use 200 random datapoints per class, for classes 0, 1, 2, 4, 7 (a total of 1000 datapoints). We use this set of classes to ensure that the structure is clearly visible in the figure. We combine the correlations into a diagram, additionally showing the average correlation for each pair of classes. We repeat the experiment for different values of α ∈ {0.02, 0.1, 1}. For a discussion of the results, see Section 5.2.
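A schematic PyTorch sketch of this computation follows; `make_model` is a hypothetical constructor (e.g., returning a fresh LeNet-5), not part of the code used for the paper.

```python
import torch

def prior_logit_correlations(make_model, images, alpha=0.1, num_samples=100):
    logits = []
    for _ in range(num_samples):
        model = make_model()                      # fresh network each iteration
        with torch.no_grad():
            for p in model.parameters():
                p.normal_(0.0, alpha)             # w_i ~ N(0, alpha^2 I)
            logits.append(model(images)[:, 0])    # class-0 logit per image
    logits = torch.stack(logits)                  # shape (S, num_images)
    centered = logits - logits.mean(dim=0, keepdim=True)
    cov = centered.T @ centered                   # numerator of corr_logit
    std = centered.norm(dim=0)
    return cov / (std[:, None] * std[None, :])    # corr_logit(x, x') matrix
```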

Sample functions. In panels (e)–(g) of Figure 12, we visualize the functions sampled from the LeNet-5 network along the direction connecting a pair of MNIST images. In particular, we take a pair of images x₀ and x₁ of digits 0 and 1, respectively, and construct the path x(t) = t · x₀ + (1 − t) · x₁. We then study the samples of the logits z(t) = f(x(t) · ‖x₀‖/‖x(t)‖, w) along the path; here we adjusted the norm of the images along the path to be constant, as the values of the logits are sensitive to the norm of the inputs. The complexity of the samples increases as we increase the variance of the prior distribution over the weights. This increased complexity of sample functions explains why we might expect the prior correlations for pairs of images to be lower when we increase the variance of the prior distribution over the weights.

Performance dependence on prior variance. In panels (d), (h) of Figure 12, we show the test negative log-likelihood and accuracy of SWAG applied to LeNet-5 on MNIST. We train the model for 50 epochs, constructing the rank-20 SWAG approximation from the last 25 epochs. We use an initial learning rate of 0.05 and a SWAG learning rate of 0.01, with the learning rate schedule of Maddox et al. (2019). We use the posterior log-density as the objective, and vary the prior variance α². In panels (f), (g) of Figure 9


we perform an analogous experiment using a PreResNet-20and a VGG-16 on CIFAR-10 using the hyper-parametersreported in Maddox et al (2019) (for PreResNet-20 we usethe hyper-parameters used with PreResNet-164 in Maddoxet al (2019)) Both on MNIST and CIFAR-10 we observethat the performance is poor for overly small values of αclose to optimal for intermediate values and still reasonablefor larger values of α For further discussion of the resultssee Section 53

Predictions from prior samples. Following Wenzel et al. (2020), we study the predictive distributions of prior samples using PreResNet-20 on CIFAR-10. In Figure 9 we show the sample predictive functions averaged over datapoints for different scales α of the prior distribution. We also show the predictive distribution for each α, which is the average of the sample predictive distributions over 200 samples of weights. In Figure 10 we show how the predictive distribution changes as we vary the number of observed data for prior scale α = √10. We see that the marginal predictive distribution for all considered values of α is reasonable: roughly uniform across classes when averaged across the dataset. For the latter experiment we used stochastic gradient Langevin dynamics (SGLD) (Welling & Teh, 2011) with a cosine learning rate schedule. For each sample we restart SGLD, and we only use the sample obtained at the last iteration. We discuss the results in Section 8.4.

Prior correlations with corrupted images. In Section F and Figure 19 we study the decay of the prior correlations between logits on an original image and a perturbed image, as we increase the intensity of perturbations. For the BNN we use the PreResNet-20 architecture with the standard Gaussian prior N(0, I). For the linear model, the correlations are not affected by the prior variance α²:

\[
\mathrm{cov}(w^\top x, w^\top y) = \mathbb{E}\left[w^\top x \cdot w^\top y\right] = \mathbb{E}\left[x^\top w w^\top y\right] = x^\top \mathbb{E}\left[w w^\top\right] y = \alpha^2 x^\top y,
\]

and hence

\[
\mathrm{corr}(w^\top x, w^\top y) = \frac{\mathrm{cov}(w^\top x, w^\top y)}{\sqrt{\mathrm{cov}(w^\top x, w^\top x) \cdot \mathrm{cov}(w^\top y, w^\top y)}} = \frac{x^\top y}{\|x\|\,\|y\|}.
\]

We use the N(0, I) prior for the weights of the linear model. Finally, we also evaluate the correlations associated with an RBF kernel (see Equation (7)). To set the lengthscale ℓ of the kernel, we evaluate the pairwise correlations for the PreResNet-20 and the RBF kernel on the 100 uncorrupted CIFAR-10 images that were used for the experiment, and ensure that the average correlations match. The resulting value of ℓ is 10000; the average correlation for the RBF kernel and PreResNet was ≈ 0.9, and for the linear model the average correlation was ≈ 0.82. For the perturbations we used the same set of corruptions introduced in Hendrycks & Dietterich (2019) as in the experiments in Section 4, with the addition of a random translation: for a random translation of intensity i, we pad the image with 2·i zeros on each side and crop the image randomly to 32×32.
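The lengthscale calibration can be sketched as a one-dimensional search; the grid and helper below are illustrative assumptions rather than the exact procedure used for Figure 19.

```python
import numpy as np

def rbf_correlation(x, y, lengthscale):
    # For a zero-mean GP with an RBF kernel, corr(f(x), f(y)) = exp(-||x - y||^2 / (2 l^2)).
    return np.exp(-np.sum((x - y) ** 2) / (2.0 * lengthscale ** 2))

def calibrate_lengthscale(pairs, target_corr=0.9, grid=np.logspace(2, 5, 200)):
    """Pick the lengthscale whose average RBF correlation over `pairs` of flattened
    images best matches `target_corr` (e.g., the average PreResNet-20 prior correlation)."""
    avg = np.array([np.mean([rbf_correlation(x, y, l) for x, y in pairs]) for l in grid])
    return grid[int(np.argmin(np.abs(avg - target_corr)))]
```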

D.4. Rethinking Generalization

In Section 6 we experiment with Bayesian neural networks and Gaussian processes on CIFAR-10 with noisy labels, inspired by the results in Zhang et al. (2016) that suggest we need to re-think generalization to understand deep learning.

Following Zhang et al. (2016), we train PreResNet-20 on CIFAR-10 with different fractions of random labels. To ensure that the network fits the train data, we turn off weight decay and data augmentation, and use a lower initial learning rate of 0.01. Otherwise, we follow the hyper-parameters that were used with PreResNet-164 in Maddox et al. (2019). We use a diagonal Laplace approximation to compute an estimate of the marginal likelihood for each level of label corruption. Following Ritter et al. (2018), we use the diagonal of the Fisher information matrix rather than the Hessian.
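For reference, the diagonal Laplace estimate of the log marginal likelihood takes the following form. This is a hedged sketch: the trained (MAP) solution, its data log-likelihood, and the diagonal of the Fisher information are assumed to be computed elsewhere.

```python
import math
import torch

def diag_laplace_log_evidence(model, log_likelihood, fisher_diag, prior_var):
    """Diagonal Laplace estimate of log p(D):
       log p(D | w_hat) + log p(w_hat) + (d/2) log(2 pi) - (1/2) sum_i log(F_ii + 1 / prior_var),
    where F_ii is the diagonal Fisher information used in place of the Hessian
    (Ritter et al., 2018), and the prior is N(0, prior_var * I)."""
    d = sum(p.numel() for p in model.parameters())
    log_prior = sum(-0.5 * (p ** 2).sum() / prior_var for p in model.parameters()) \
                - 0.5 * d * math.log(2 * math.pi * prior_var)
    # Posterior precision diagonal: Fisher diagonal plus prior precision.
    log_det_precision = sum(torch.log(f + 1.0 / prior_var).sum() for f in fisher_diag)
    return log_likelihood + log_prior + 0.5 * d * math.log(2 * math.pi) - 0.5 * log_det_precision
```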

We perform a similar experiment with a Gaussian process with an RBF kernel on the binary classification problem for two classes of CIFAR-10. We use variational inference to fit the model, and we use the variational evidence lower bound to approximate the marginal likelihood. We use variational inference to overcome the non-Gaussian likelihood, and not for scalability reasons, i.e. we are not using inducing inputs. We use the GPyTorch package (Gardner et al., 2018) to train the models. We use an RBF kernel with the default initialization from GPyTorch, and divide the inputs by 5000 to get an appropriate input scale. We train the model on a binary classification problem between classes 0 and 1.

For the 10-class GP classification experiment, we train 10 one-vs-all models that classify between a given class and the rest of the data. To reduce computation in training, we subsample the data not belonging to the given class to 10k datapoints, so each model is trained on a total of 15k datapoints. We then combine the 10 models into a single multi-class model: an observation is attributed to the class that corresponds to the one-vs-all model with the highest confidence. We use the same hyper-parameters as in the binary classification experiments.
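The combination rule can be sketched as follows, assuming each one-vs-all model exposes a `predict_proba` method returning the probability that an input belongs to its class, and that the models are ordered by class index:

```python
import numpy as np

def one_vs_all_predict(models, x):
    """Combine one-vs-all classifiers: predict the class whose model is most confident."""
    confidences = np.stack([m.predict_proba(x) for m in models], axis=-1)  # (N, n_classes)
    return confidences.argmax(axis=-1)
```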

D.5. Double Descent

In Section 7 we evaluate SGD, SWAG and MultiSWAG for models of varying width. Following Nakkiran et al. (2019), we use ResNet-18 on CIFAR-100; we consider the original labels, 10% and 20% label corruption. For networks of every width we reuse the hyper-parameters used for PreResNet-164 in Maddox et al. (2019).

[Figure 13: test error and test NLL vs. ResNet-18 width for SGD, SWAG and MultiSWAG; panels (a) 10% Corrupted (Err), (b) 10% Corrupted (NLL), (c) 20% Corrupted (# Models, for 1, 3, 5, 10 SWAG models).]

Figure 13. Double Descent. (a) Test error and (b) NLL loss for ResNet-18 with varying width on CIFAR-100, for SGD, SWAG and MultiSWAG, when 10% of the labels are randomly reshuffled. MultiSWAG alleviates double descent both on the original labels and under label noise, both in accuracy and NLL. (c) Test NLLs for MultiSWAG with a varying number of independent models under 20% label corruption. NLL monotonically decreases with an increased number of independent models, alleviating double descent.

For the original labels and 10% label corruption we use 5 independently trained SWAG models with MultiSWAG, and for 20% label corruption we use 10 models; for 20% label corruption we also show performance varying the number of independent models in Figures 8(e) and 13(c). Both for SWAG and MultiSWAG we use an ensemble of 20 sampled models from each of the SWAG solutions; for example, for MultiSWAG with 10 independent SWAG solutions we use an ensemble of 200 networks.
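The final ensembling step can be sketched as follows; `swag.sample()` is an assumed interface that returns a network with weights drawn from one SWAG solution (with batch-norm statistics updated), not the exact API of our implementation.

```python
import torch

@torch.no_grad()
def multiswag_predict(swag_solutions, loader, n_samples_per_solution=20):
    """Average softmax predictions over samples drawn from several SWAG solutions."""
    probs, n_models = None, 0
    for swag in swag_solutions:
        for _ in range(n_samples_per_solution):
            net = swag.sample()          # assumed: draw weights from this SWAG posterior
            net.eval()
            batch_probs = torch.cat([torch.softmax(net(x), dim=-1) for x, _ in loader])
            probs = batch_probs if probs is None else probs + batch_probs
            n_models += 1
    return probs / n_models              # e.g. 10 solutions x 20 samples = 200 networks
```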

E. Analysis of Prior Variance Effect

In this section we provide simple analytic results for the effect of the prior variance in ReLU networks. A related derivation is presented in Appendix Section A.8 of Garipov et al. (2018), about connecting paths from symmetries in parametrization.

We will consider a multilayer network f(x; w) of the form

\[
f(x; \{W_i, b_i\}_{i=1}^{n}) = W_n\, \phi\big(\cdots \phi\big(W_2\,\phi(W_1 x + b_1) + b_2\big) \cdots\big) + b_n,
\]

where φ is the ReLU (or in fact any positively-homogeneous activation function), Wᵢ are weight matrices, and bᵢ are bias vectors. In particular, f can be a regular CNN with ReLU activations up to the logits (with the softmax activation removed).

Now suppose we have a prior distribution of the form

\[
W_i \sim \mathcal{N}(0, \alpha_i^2 I), \qquad b_i \sim \mathcal{N}(0, \beta_i^2 I),
\]

where the identity matrices I are implicitly assumed to be of appropriate shapes, so each weight matrix and bias vector has a spherical Gaussian distribution. We can reparameterize this distribution as

\[
W_i = \alpha_i E_i, \quad E_i \sim \mathcal{N}(0, I), \qquad b_i = \beta_i \epsilon_i, \quad \epsilon_i \sim \mathcal{N}(0, I).
\]

We can then express the predictions of the network on the input x, for weights sampled from the prior, as the random variable

\[
f(x; \{\alpha_i, \beta_i\}_{i=1}^{n}) = \alpha_n E_n\big(\cdots \phi(\alpha_1 E_1 x + \beta_1 \epsilon_1) \cdots\big) + \beta_n \epsilon_n. \tag{8}
\]

Through Equation (8), we can observe some simple properties of the dependence between the prior scales αᵢ, βᵢ and the induced function-space prior.

Proposition 1. Suppose the network has no bias vectors, i.e. β₁ = … = βₙ = 0. Then the scales αᵢ of the prior distribution over the weights only affect the output scale of the network.

Proof. In the case when there are no bias vectors, Equation (8) simplifies to

\[
f(x; \{\alpha_i, \beta_i = 0\}_{i=1}^{n}) = \alpha_n E_n\big(\cdots \phi(\alpha_1 E_1 x) \cdots\big) = \alpha_n \cdots \alpha_1 \cdot E_n\big(\cdots \phi(E_1 x) \cdots\big) = \alpha_n \cdots \alpha_1 \cdot f(x; \{\alpha_i = 1, \beta_i = 0\}_{i=1}^{n}).
\]

In the derivation above we used the positive homogeneity of ReLU: φ(αz) = αφ(z) for any positive α.

In other words, to sample from the distribution over functions corresponding to a prior with scales $\{\alpha_i, \beta_i = 0\}_{i=1}^n$, we can sample from the spherical Gaussian prior (without bias terms) $\{\alpha_i = 1, \beta_i = 0\}_{i=1}^n$, and then rescale the outputs of the network by the product of the prior scales αₙ ⋯ α₂ · α₁.
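Proposition 1 is easy to verify numerically; the snippet below is an illustrative check on a small bias-free ReLU network.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
widths = [10, 20, 20, 5]
alphas = [0.5, 2.0, 3.0]                     # per-layer prior scales alpha_i

# Bias-free ReLU network with unit-scale Gaussian weights E_i.
layers = [nn.Linear(widths[i], widths[i + 1], bias=False) for i in range(3)]
for layer in layers:
    nn.init.normal_(layer.weight, std=1.0)

def forward(x, scales):
    h = x
    for scale, layer in zip(scales[:-1], layers[:-1]):
        h = torch.relu(scale * layer(h))     # hidden layers with weights alpha_i E_i
    return scales[-1] * layers[-1](h)        # final linear layer

x = torch.randn(4, widths[0])
out_scaled = forward(x, alphas)              # network with weights alpha_i E_i
out_unit = forward(x, [1.0, 1.0, 1.0])       # network with weights E_i
print(torch.allclose(out_scaled, alphas[0] * alphas[1] * alphas[2] * out_unit, atol=1e-5))
```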

We note that the result above is different from the results for sigmoid networks considered in MacKay (1995), where varying the prior on the weights leads to changing the lengthscale of the sample functions. For ReLU networks without biases, increasing the prior variance only increases the output scale of the network, and not the complexity of the samples. If we apply the softmax activation on the outputs of the last layer of such a network, we will observe increasingly confident predictions as we increase the prior variance.


We observe this effect in Figure 9, and discuss it in Section 8.4.

In case bias vectors are present, we can obtain a similar result using a specific scaling of the prior variances with the layer, as in the following proposition.

Proposition 2. Suppose the prior scales depend on the layer of the network as follows, for some γ > 0:

\[
\alpha_i = \gamma, \qquad \beta_i = \gamma^i,
\]

for all layers i = 1, …, n. Then γ only affects the scale of the predictive distribution at any input x:

\[
f(x; \{\alpha_i = \gamma, \beta_i = \gamma^i\}_{i=1}^{n}) = \gamma^n \cdot f(x; \{\alpha_i = 1, \beta_i = 1\}_{i=1}^{n}).
\]

Proof. The proof is analogous to the proof of Proposition 1. We can use the positive homogeneity of ReLU activations to factor the prior scales outside of the network:

\[
f(x; \{\alpha_i = \gamma, \beta_i = \gamma^i\}_{i=1}^{n}) = \gamma E_n\big(\cdots \phi(\gamma E_1 x + \gamma \epsilon_1) \cdots\big) + \gamma^n \epsilon_n = \gamma^n \big(E_n(\cdots \phi(E_1 x + \epsilon_1) \cdots) + \epsilon_n\big) = \gamma^n \cdot f(x; \{\alpha_i = 1, \beta_i = 1\}_{i=1}^{n}).
\]

The analysis above can be applied to other simple scaling rules of the prior; e.g.

\[
f(x; \{\alpha_i = \gamma \hat\alpha_i,\ \beta_i = \gamma^i \hat\beta_i\}_{i=1}^{n}) = \gamma^n \cdot f(x; \{\alpha_i = \hat\alpha_i,\ \beta_i = \hat\beta_i\}_{i=1}^{n}) \tag{9}
\]

can be shown completely analogously to Proposition 2.

More general types of scaling of the prior affect both the output scale of the network and the relative effect of the weight and bias terms. For example, by Equation (9) we have

\[
f(x; \{\alpha_i = \gamma, \beta_i = \gamma\}_{i=1}^{n}) = f(x; \{\alpha_i = \gamma \cdot 1,\ \beta_i = \gamma^i \cdot \gamma^{1-i}\}_{i=1}^{n}) = \gamma^n \cdot f(x; \{\alpha_i = 1, \beta_i = \gamma^{1-i}\}_{i=1}^{n}).
\]

We note that the analysis does not cover residual connections and batch normalization, so it applies to LeNet-5 but cannot be directly applied to the PreResNet-20 networks used in many of our experiments.

F. Prior Correlation Structure under Perturbations

In this section we explore the prior correlations between the logits on different pairs of datapoints induced by a spherical Gaussian prior on the weights of a PreResNet-20. We sample 100 random images from CIFAR-10 (10 from each class) and apply 17 different perturbations introduced by Hendrycks & Dietterich (2019) at 5 different levels of intensity. We then compute correlations between the logits f(x, w) for the original image x and f(x̃, w) for the corrupted image x̃, as we sample the weights of the network from the prior w ∼ N(0, I).

In Figure 19 we show how the correlations decay with perturbation intensity. For reference, we also show how the correlations decay for a linear model and for an RBF kernel. For the RBF kernel, we set the lengthscale so that the average correlations on the uncorrupted datapoints match those of a PreResNet-20. Further experimental details can be found in Appendix D.3.

For all types of corruptions except saturate, snow, fog and brightness, the correlations of the PreResNet logits decay more slowly than those of the RBF kernel and the linear model. It appears that the prior samples are sensitive to corruptions that alter the brightness, or more generally the colours, in the image. For many types of corruptions (e.g., Gaussian noise), the prior correlations for PreResNet are close to 1 for all levels of corruption.

Overall, these results indicate that the prior over functions induced by a vague prior over parameters w in combination with a PreResNet has useful equivariance properties: before seeing data, the model treats images of the same class as highly correlated, even after an image has undergone significant perturbations representative of perturbations we often see in the real world. These types of symmetries are a large part of what makes neural networks a powerful model class for high dimensional natural signals.

[Figure 14: NLL vs. number of independently trained models for Deep Ensembles, MultiSWA and MultiSWAG; panels (a) Gaussian Noise, (b) Impulse Noise, (c) Shot Noise, each at five corruption intensities.]

Figure 14. Noise Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA, as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).

[Figure 15: NLL vs. number of independently trained models for Deep Ensembles, MultiSWA and MultiSWAG; panels (a) Defocus Blur, (b) Glass Blur, (c) Motion Blur, (d) Zoom Blur, (e) Gaussian Blur, each at five corruption intensities.]

Figure 15. Blur Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA, as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).

[Figure 16: NLL vs. number of independently trained models for Deep Ensembles, MultiSWA and MultiSWAG; panels (a) Contrast, (b) Saturate, (c) Elastic Transform, (d) Pixelate, (e) JPEG Compression, each at five corruption intensities.]

Figure 16. Digital Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA, as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).

[Figure 17: NLL vs. number of independently trained models for Deep Ensembles, MultiSWA and MultiSWAG; panels (a) Snow, (b) Fog, (c) Brightness, each at five corruption intensities.]

Figure 17. Weather Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA, as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).

[Figure 18: boxplots of NLL, Accuracy and ECE vs. corruption intensity (1-5) for Deep Ensembles, MultiSWA and MultiSWAG.]

Figure 18. Negative log likelihood, accuracy and expected calibration error distribution on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA, as a function of the corruption intensity. Following Ovadia et al. (2019), we summarize the results for different types of corruption with a boxplot. For each method we use 10 independently trained models, and for MultiSWAG we sample 20 networks from each model. As in Figures 5 and 11-14, there are substantial differences between these three methods, which are hard to see due to the vertical scale on this plot; MultiSWAG particularly outperforms Deep Ensembles and MultiSWA in terms of NLL and ECE for higher corruption intensities.

[Figure 19: prior correlation vs. corruption intensity (0-5) for PreResNet-20, RBF kernel and linear model; one panel per corruption: Translation, Gaussian Noise, Impulse Noise, Shot Noise, Defocus Blur, Glass Blur, Motion Blur, Zoom Blur, Gaussian Blur, Contrast, Saturate, Elastic, Pixelate, JPEG, Snow, Fog, Brightness.]

Figure 19. Prior correlations under corruption. Prior correlations between predictions (logits) for PreResNet-20, a linear model, and an RBF kernel, on original and corrupted images, as a function of corruption intensity, for different types of corruptions. The lengthscale of the RBF kernel is calibrated to produce similar correlations to PreResNet on uncorrupted datapoints. We report the mean correlation values over 100 different images, and show the 1σ error bars with shaded regions. For all corruptions except Snow, Saturate, Fog and Brightness, the correlations decay slower for PreResNet compared to the baselines.


selected to make q close to p in some sense; for example, variational approximations (e.g., Beal, 2003), which have emerged as a popular deterministic approach, find $\mathrm{argmin}_\theta\, \mathrm{KL}(q \| p)$. Other standard deterministic approximations include Laplace (e.g., MacKay, 1995), EP (Minka, 2001a), and INLA (Rue et al., 2009).

From the perspective of estimating the predictive distribution in Eq. (1), we can view simple Monte Carlo as approximating the posterior with a set of point masses, with locations given by samples from another approximate posterior q, even if q is a continuous distribution. That is, $p(w|\mathcal{D}) \approx \frac{1}{J}\sum_{j=1}^{J} \delta(w = w_j)$, with $w_j \sim q(w|\mathcal{D})$.
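As an illustration, the simple Monte Carlo approximation to the predictive distribution can be sketched as follows; `sample_weights` and `set_weights` are assumed helpers for the chosen approximate posterior, not part of a specific library.

```python
import torch

@torch.no_grad()
def monte_carlo_predictive(net, x, sample_weights, set_weights, n_samples=30):
    """p(y | x, D) ~= (1/J) sum_j softmax(f(x, w_j)),  w_j ~ q(w | D)."""
    probs = 0.0
    for _ in range(n_samples):
        set_weights(net, sample_weights())   # load a posterior sample into the network
        probs = probs + torch.softmax(net(x), dim=-1)
    return probs / n_samples
```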

Ultimately, the goal is to accurately compute the predictive distribution in Eq. (1), rather than find a generally accurate representation of the posterior. In particular, we must carefully represent the posterior in regions that will make the greatest contributions to the BMA integral. In terms of efficiently computing the predictive distribution, we do not necessarily want to place point masses at locations given by samples from the posterior. For example, functional diversity is important for a good approximation to the BMA integral, because we are summing together terms of the form p(y|x, w); if two settings of the weights wᵢ and wⱼ each provide high likelihood (and consequently high posterior density), but give rise to similar functions f(x, wᵢ), f(x, wⱼ), then they will be largely redundant in the model average, and the second setting of parameters will not contribute much to estimating the BMA integral for the unconditional predictive distribution. In Sections 3.3 and 4 we consider how various approaches approximate the predictive distribution.

3.3. Deep Ensembles are BMA

Deep ensembles (Lakshminarayanan et al., 2017) is fast becoming a gold standard for accurate and well-calibrated predictive distributions. Recent reports (e.g., Ovadia et al., 2019; Ashukha et al., 2020) show that deep ensembles appear to outperform some particular approaches to Bayesian neural networks for uncertainty representation, leading to the confusion that deep ensembles and Bayesian methods are competing approaches. These methods are often explicitly referred to as non-Bayesian (e.g., Lakshminarayanan et al., 2017; Ovadia et al., 2019; Wenzel et al., 2020). To the contrary, we argue that deep ensembles are actually a compelling approach to Bayesian model averaging, in the vein of Section 3.2.

There is a fundamental difference between a Bayesian model average and some approaches to ensembling. The Bayesian model average assumes that one hypothesis (one parameter setting) is correct, and averages over models due to an inability to distinguish between hypotheses given limited information (Minka, 2000). As we observe more data, the posterior collapses onto a single hypothesis. If the true explanation for the data is a combination of hypotheses, then the Bayesian model average may appear to perform worse as we observe more data. Some ensembling methods work by enriching the hypothesis space, and therefore do not collapse in this way. Deep ensembles, however, are formed by MAP or maximum likelihood retraining of the same architecture multiple times, leading to different basins of attraction. The deep ensemble will therefore collapse in the same way as a Bayesian model average, as the posterior concentrates. Since the hypothesis space (support) for a modern neural network is large, containing many different possible explanations for the data, posterior collapse will often be desirable.

Furthermore, by representing multiple basins of attraction, deep ensembles can provide a better approximation to the BMA than the Bayesian approaches in Ovadia et al. (2019). Indeed, the functional diversity is important for a good approximation to the BMA integral, as per Section 3.2. The approaches referred to as Bayesian in Ovadia et al. (2019) instead focus their approximation on a single basin, which may contain a lot of redundancy in function space, making a relatively minimal contribution to computing the Bayesian predictive distribution. On the other hand, retraining a neural network multiple times for deep ensembles incurs a significant computational expense. The single basin approaches may be preferred if we are to control for computation. We explore these questions in Section 4.

4. An Empirical Study of Marginalization

We have shown that deep ensembles can be interpreted as an approximate approach to Bayesian marginalization, which selects for functional diversity by representing multiple basins of attraction in the posterior. Most Bayesian deep learning methods instead focus on faithfully approximating a posterior within a single basin of attraction. We propose a new method, MultiSWAG, which combines these two types of approaches. MultiSWAG combines multiple independently trained SWAG approximations (Maddox et al., 2019) to create a mixture of Gaussians approximation to the posterior, with each Gaussian centred on a different basin of attraction. We note that MultiSWAG does not require any additional training time over standard deep ensembles.
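To make the construction concrete, the sketch below shows a diagonal-covariance simplification of the SWAG building block; the actual SWAG posterior additionally maintains a low-rank covariance component (Maddox et al., 2019). MultiSWAG simply builds one such approximation per independently trained model and samples from all of them.

```python
import torch

class DiagonalSWAG:
    """Minimal diagonal-covariance sketch of a SWAG posterior approximation."""

    def __init__(self, net):
        self.net = net
        self.n = 0
        self.mean = [torch.zeros_like(p) for p in net.parameters()]
        self.sq_mean = [torch.zeros_like(p) for p in net.parameters()]

    @torch.no_grad()
    def collect(self):
        # Call periodically late in SGD training to accumulate running moments.
        for m, sq, p in zip(self.mean, self.sq_mean, self.net.parameters()):
            m.mul_(self.n / (self.n + 1)).add_(p / (self.n + 1))
            sq.mul_(self.n / (self.n + 1)).add_(p ** 2 / (self.n + 1))
        self.n += 1

    @torch.no_grad()
    def sample(self):
        # Draw weights from N(mean, diag(var)) and load them into the network.
        for m, sq, p in zip(self.mean, self.sq_mean, self.net.parameters()):
            var = (sq - m ** 2).clamp_min(1e-30)
            p.copy_(m + var.sqrt() * torch.randn_like(p))
        return self.net
```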

We illustrate the conceptual difference between deep ensembles, a standard variational single basin approach, and MultiSWAG in Figure 3. In the top panel we have a conceptualization of a multimodal posterior. VI approximates the posterior with multiple samples within a single basin. But we see in the middle panel that the conditional predictive distribution p(y|x, w) does not vary significantly within the basin, and thus each additional sample contributes minimally to computing the marginal predictive distribution p(y|x, D).


Figure 3. Approximating the BMA, $p(y|x,\mathcal{D}) = \int p(y|x,w)\,p(w|\mathcal{D})\,dw$. Top: p(w|D), with representations from VI (orange), deep ensembles (blue), and MultiSWAG (red). Middle: p(y|x, w) as a function of w for a test input x. This function does not vary much within modes, but changes significantly between modes. Bottom: Distance between the true predictive distribution and the approximation, as a function of representing the posterior at an additional point w, assuming we have sampled the mode in dark green. There is more to be gained by exploring new basins than continuing to explore the same basin.

On the other hand, p(y|x, w) varies significantly between basins, and thus each point mass for deep ensembles contributes significantly to the marginal predictive distribution. By sampling within the basins, MultiSWAG provides additional contributions to the predictive distribution. In the bottom panel we have the gain in approximating the predictive distribution when adding a point mass to the representation of the posterior, as a function of its location, assuming we have already sampled the mode in dark green. Including samples from different modes provides significant gain over continuing to sample from the same mode, and including weights in wide basins provides relatively more gain than in the narrow ones.

In Figure 4 we evaluate single basin and multi-basin approaches in a case where we can near-exactly compute the predictive distribution. We provide details for generating the data and training the models in Appendix D.1. We see that the predictive distribution given by deep ensembles is qualitatively closer to the true distribution compared to the single basin variational method: between data clusters, the deep ensemble approach provides a similar representation of epistemic uncertainty, whereas the variational method is extremely overconfident in these regions. Moreover, we see that the Wasserstein distance between the true predictive distribution and these two approximations quickly shrinks with the number of samples for deep ensembles, but is roughly independent of the number of samples for the variational approach. Thus the deep ensemble is providing a better approximation of the Bayesian model average in Eq. (1) than the single basin variational approach, which has traditionally been labelled as the Bayesian alternative.

Next, we evaluate MultiSWAG under distribution shift on the CIFAR-10 dataset (Krizhevsky et al., 2014), replicating the setup in Ovadia et al. (2019). We consider 16 data corruptions, each at 5 different levels of severity, introduced by Hendrycks & Dietterich (2019). For each corruption, we evaluate the performance of deep ensembles and MultiSWAG, varying the training budget. For deep ensembles we show performance as a function of the number of independently trained models in the ensemble. For MultiSWAG we show performance as a function of the number of independent SWAG approximations that we construct; we then sample 20 models from each of these approximations to construct the final ensemble.

While the training time for MultiSWAG is the same as for deep ensembles, at test time MultiSWAG is more expensive, as the corresponding ensemble consists of a larger number of models. To account for situations when test time is constrained, we also propose MultiSWA, a method that ensembles independently trained SWA solutions (Izmailov et al., 2018). SWA solutions are the means of the corresponding Gaussian SWAG approximations. Izmailov et al. (2018) argue that SWA solutions approximate the local ensembles represented by SWAG with a single model.

In Figure 5 we show the negative log-likelihood as a function of the number of independently trained models for a Preactivation ResNet-20 on CIFAR-10 corrupted with Gaussian blur, with varying levels of intensity (increasing from left to right). MultiSWAG outperforms deep ensembles significantly on highly corrupted data. For lower levels of corruption, MultiSWAG works particularly well when only a small number of independently trained models are available. We note that MultiSWA also outperforms deep ensembles, and has the same computational requirements at training and test time as deep ensembles. We present results for other types of corruption in Appendix Figures 14, 15, 16, 17, showing similar trends. In general, there is an extensive evaluation of MultiSWAG in the Appendix.

Our perspective of generalization is deeply connected with Bayesian marginalization. In order to best realize the benefits of marginalization in deep learning, we need to consider as many hypotheses as possible through multimodal posterior approximations, such as MultiSWAG. In Section 7 we return to MultiSWAG, showing how it can entirely alleviate prominent double descent behaviour, and lead to striking improvements in generalization over SGD and single basin marginalization, for both accuracy and NLL.

[Figure 4: predictive distributions on a 1D regression problem; panels (a) Exact, (b) Deep Ensembles, (c) Variational Inference, (d) Discrepancy with True BMA (average Wasserstein distance vs. number of samples for Deep Ensembles and SVI).]

Figure 4. Approximating the true predictive distribution. (a) A close approximation of the true predictive distribution obtained by combining 200 HMC chains. (b) Deep ensembles predictive distribution using 50 independently trained networks. (c) Predictive distribution for factorized variational inference (VI). (d) Convergence of the predictive distributions for deep ensembles and variational inference as a function of the number of samples; we measure the average Wasserstein distance between the marginals in the range of input positions. The multi-basin deep ensembles approach provides a more faithful approximation of the Bayesian predictive distribution than the conventional single-basin VI approach, which is overconfident between data clusters. The top panels show the Wasserstein distance between the true predictive distribution and the deep ensemble and VI approximations, as a function of the input x.


Figure 5. Negative log likelihood for Deep Ensembles, MultiSWAG and MultiSWA, using a PreResNet-20 on CIFAR-10, with varying intensity of the Gaussian blur corruption. The image in each plot shows the intensity of corruption. For all levels of intensity, MultiSWAG and MultiSWA outperform Deep Ensembles for a small number of independent models. For high levels of corruption, MultiSWAG significantly outperforms the other methods even for many independent models. We present results for other corruptions in the Appendix.

5. Neural Network Priors

A prior over parameters p(w) combines with the functional form of a model f(x; w) to induce a distribution over functions p(f(x; w)). It is this distribution over functions that controls the generalization properties of the model; the prior over parameters, in isolation, has no meaning. Neural networks are imbued with structural properties that provide good inductive biases, such as translation equivariance, hierarchical representations, and sparsity. In the sense of Figure 2, the prior will have large support, due to the flexibility of neural networks, but its inductive biases provide the most mass to datasets which are representative of problem settings where neural networks are often applied. In this section, we study the properties of the induced distribution over functions. We directly continue the discussion of priors in Section 6, with a focus on examining the noisy CIFAR results in Zhang et al. (2016) from a probabilistic perspective of generalization. These sections are best read together.

We also provide several additional experiments in the Appendix. In Section E we present analytic results on the dependence of the prior distribution in function space on the variance of the prior over parameters, considering also layer-wise parameter priors with ReLU activations. As part of a discussion on tempering, in Section 8.4 we study the effect of α in p(w) = N(0, α²I) on prior class probabilities for individual sample functions p(f(x; w)), the predictive distribution, and posterior samples, as we observe varying amounts of data. In Section F we further study the correlation structure over images induced by neural network priors, subject to perturbations of the images. In Section D.3 we provide additional experimental details.

5.1. Deep Image Prior and Random Network Features

Two recent results provide strong evidence that vague Gaussian priors over parameters, when combined with a neural network architecture, induce a distribution over functions with useful inductive biases. In the deep image prior, Ulyanov et al. (2018) show that randomly initialized convolutional neural networks without training provide excellent performance for image denoising, super-resolution, and inpainting. This result demonstrates the ability for a sample function from a random prior over neural networks p(f(x; w)) to capture low-level image statistics, before any training.

[Figure 6: MNIST class-pair prior correlation matrices for (a) α = 0.02, (b) α = 0.1, (c) α = 1, and (d) NLL as a function of the prior std α.]

Figure 6. Induced prior correlation function. Average pairwise prior correlations for pairs of objects in classes {0, 1, 2, 4, 7} of MNIST, induced by LeNet-5, for p(f(x; w)) when p(w) = N(0, α²I). Images in the same class have higher prior correlations than images from different classes, suggesting that p(f(x; w)) has desirable inductive biases. The correlations slightly decrease with increases in α. (d) NLL of an ensemble of 20 SWAG samples on MNIST as a function of α, using a LeNet-5.

[Figure 7: panels (a) Prior Draws, (b) True Labels, (c) Corrupted Labels, (d) Gaussian Process (marginal likelihood estimate and test error vs. altered labels), (e) PreResNet-20 (marginal likelihood estimate and test error vs. resampled labels).]

Figure 7. Rethinking generalization. (a) Sample functions from a Gaussian process prior. (b) GP fit (with 95% credible region) to structured data generated as y_green(x) = sin(x·2π) + ε, ε ∼ N(0, 0.2²). (c) GP fit, with no training error, after a significant addition of corrupted data, in red, drawn from Uniform[0.5, 1]. (d) Variational GP marginal likelihood with RBF kernel for two classes of CIFAR-10. (e) Laplace BNN marginal likelihood for a PreResNet-20 on CIFAR-10 with different fractions of random labels. The marginal likelihood for both the GP and BNN decreases as we increase the level of corruption in the labels, suggesting reasonable inductive biases in the prior over functions. Moreover, both the GP and BNN have 100% training accuracy on images with fully corrupted labels.

Similarly, Zhang et al. (2016) show that preprocessing CIFAR-10 with a randomly initialized untrained convolutional neural network dramatically improves the test performance of a simple Gaussian kernel on pixels, from 54% accuracy to 71%. Adding ℓ2 regularization only improves the accuracy by an additional 2%. These results again indicate that broad Gaussian priors over parameters induce reasonable priors over networks, with a minor additional gain from decreasing the variance of the prior in parameter space, which corresponds to ℓ2 regularization.

5.2. Prior Class Correlations

In Figure 6 we study the prior correlations in the outputs of the LeNet-5 convolutional network (LeCun et al., 1998) on objects of different MNIST classes. We sample networks with weights p(w) = N(0, α²I), compute the values of the logits corresponding to the first class for all pairs of images, and compute the correlations of these logits. For all levels of α, the correlations between objects corresponding to the same class are consistently higher than the correlations between objects of different classes, showing that the network induces a reasonable prior similarity metric over these images. Additionally, we observe that the prior correlations somewhat decrease as we increase α, showing that bounding the norm of the weights has some minor utility, in accordance with Section 5.1. Similarly, in panel (d) we see that the NLL significantly decreases as α increases in [0, 0.5], and then slightly increases, but is relatively constant thereafter.

In the Appendix we further describe analytic results, and illustrate the effect of α on sample functions.

5.3. Effect of Prior Variance on CIFAR-10

We further study the effect of the parameter prior standard deviation α, measuring the performance of approximate Bayesian inference for CIFAR-10 with a Preactivation ResNet-20 (He et al., 2016) and VGG-16 (Simonyan & Zisserman, 2014). For each of these architectures we run SWAG (Maddox et al., 2019) with fixed hyper-parameters and varying α. We report the results in Figure 12(d), (h). For both architectures, the performance is near-optimal in the range α ∈ [10⁻², 10⁻¹]. Smaller α constrains the weights too much. Performance is reasonable and becomes mostly insensitive to α as it continues to increase, due to the inductive biases of the functional form of the neural network.


6. Rethinking Generalization

Zhang et al. (2016) demonstrated that deep neural networks have sufficient capacity to fit randomized labels on popular image classification tasks, and suggest this result requires re-thinking generalization to understand deep learning.

We argue, however, that this behaviour is not puzzling from a probabilistic perspective, is not unique to neural networks, and cannot be used as evidence against Bayesian neural networks (BNNs) with vague parameter priors. Fundamentally, the resolution is the view presented in the introduction: from a probabilistic perspective, generalization is at least a two-dimensional concept, related to support (flexibility), which should be as large as possible, supporting even noisy solutions, and inductive biases that represent relative prior probabilities of solutions.

Indeed, we demonstrate that the behaviour in Zhang et al. (2016) that was treated as mysterious and specific to neural networks can be exactly reproduced by Gaussian processes (GPs). Gaussian processes are an ideal choice for this experiment, because they are popular Bayesian non-parametric models, and they assign a prior directly in function space. Moreover, GPs have remarkable flexibility, providing universal approximation with popular covariance functions such as the RBF kernel. Yet the functions that are a priori likely under a GP with an RBF kernel are relatively simple. We describe GPs further in the Appendix, and Rasmussen & Williams (2006) provides an extensive introduction.

We start with a simple example to illustrate the ability for a GP with an RBF kernel to easily fit a corrupted dataset, yet generalize well on a non-corrupted dataset, in Figure 7. In Fig. 7(a), we have sample functions from a GP prior over functions p(f(x)), showing that likely functions under the prior are smooth and well-behaved. In Fig. 7(b) we see the GP is able to reasonably fit data from a structured function. And in Fig. 7(c) the GP is also able to fit highly corrupted data with essentially no structure; although these data are not a likely draw from the prior, the GP has support for a wide range of solutions, including noise.
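The qualitative behaviour in Fig. 7(b, c) can be reproduced with a few lines of exact GP regression; the snippet below is an illustrative sketch with assumed lengthscale and noise values, not the code used to produce the figure.

```python
import numpy as np

def rbf_kernel(a, b, lengthscale=0.2, variance=1.0):
    return variance * np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / lengthscale ** 2)

rng = np.random.default_rng(0)
x_clean = np.linspace(0, 1, 30)                       # structured data: noisy sine
y_clean = np.sin(2 * np.pi * x_clean) + 0.2 * rng.standard_normal(30)
x_noise = rng.uniform(0, 1, 30)                       # corrupted data with no structure
y_noise = rng.uniform(0.5, 1.0, 30)
x = np.concatenate([x_clean, x_noise])
y = np.concatenate([y_clean, y_noise])

# Exact GP regression posterior mean: with a small noise variance the GP fits the
# corrupted points as well, even though such data are unlikely under the prior.
K = rbf_kernel(x, x) + 1e-2 * np.eye(len(x))
x_test = np.linspace(0, 1, 200)
posterior_mean = rbf_kernel(x_test, x) @ np.linalg.solve(K, y)
```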

We next show that GPs can replicate the generalization behaviour described in Zhang et al. (2016) (experimental details in the Appendix). When applied to CIFAR-10 images with random labels, Gaussian processes achieve 100% train accuracy, and 10.4% test accuracy (at the level of random guessing). However, the same model trained on the true labels achieves a training accuracy of 72.8% and a test accuracy of 54.3%. Thus, the generalization behaviour described in Zhang et al. (2016) is not unique to neural networks, and can be described by separately understanding the support and the inductive biases of a model.

Indeed, although Gaussian processes support CIFAR-10 images with random labels, they are not likely under the GP prior. In Fig. 7(d), we compute the approximate GP marginal likelihood on a binary CIFAR-10 classification problem, with labels of varying levels of corruption. We see, as the noise in the data increases, the approximate marginal likelihood, and thus the prior support for these data, decreases. In Fig. 7(e), we see a similar trend for a Bayesian neural network. Again, as the fraction of corrupted labels increases, the approximate marginal likelihood decreases, showing that the prior over functions given by the Bayesian neural network has less support for these noisy datasets. We provide further experimental details in the Appendix.

Dziugaite & Roy (2017) and Smith & Le (2018) provide complementary perspectives on Zhang et al. (2016), for MNIST; Dziugaite & Roy (2017) show non-vacuous PAC-Bayes bounds for noise-free binary MNIST, but not noisy MNIST, and Smith & Le (2018) show that logistic regression can fit noisy labels on subsampled MNIST, interpreting the results from an Occam factor perspective.

7. Double Descent

Double descent (e.g., Belkin et al., 2019) describes generalization error that decreases, increases, and then again decreases, with increases in model flexibility. The first decrease and then increase is referred to as the classical regime: models with increasing flexibility are increasingly able to capture structure and perform better, until they begin to overfit. The next regime is referred to as the modern interpolating regime. The existence of the interpolation regime has been presented as mysterious generalization behaviour in deep learning.

However, our perspective of generalization suggests that performance should monotonically improve as we increase model flexibility, when we use Bayesian model averaging with a reasonable prior. Indeed, in the opening example of Figure 1, we would in principle want to use the most flexible possible model. Our results in Section 5 show that standard BNN priors induce structured and useful priors in the function space, so we should not expect double descent in Bayesian deep learning models that perform reasonable marginalization.

To test this hypothesis, we evaluate MultiSWAG, SWAG and standard SGD with ResNet-18 models of varying width, following Nakkiran et al. (2019), measuring both error and negative log likelihood (NLL). For details, see Appendix D. We present the results in Figure 8 and Appendix Figure 13.

First, we observe that models trained with SGD indeed suffer from double descent, especially when the train labels are partially corrupted (see panels 8(c), 8(d)). We also see that SWAG, a unimodal posterior approximation, reduces the extent of double descent.

[Figure 8: test error and test NLL vs. ResNet-18 width for SGD, SWAG and MultiSWAG; panels (a) True Labels (Err), (b) True Labels (NLL), (c) Corrupted (Err), (d) Corrupted (NLL), (e) Corrupted (# Models, for 1, 3, 5, 10 SWAG models).]

Figure 8. Bayesian model averaging alleviates double descent. (a) Test error and (b) NLL loss for ResNet-18 with varying width on CIFAR-100 for SGD, SWAG and MultiSWAG. (c) Test error and (d) NLL loss when 20% of the labels are randomly reshuffled. SWAG reduces double descent, and MultiSWAG, which marginalizes over multiple modes, entirely alleviates double descent, both on the original labels and under label noise, both in accuracy and NLL. (e) Test errors for MultiSWAG with a varying number of independent SWAG models; error monotonically decreases with an increased number of independent models, alleviating double descent. We also note that MultiSWAG provides significant improvements in accuracy and NLL over SGD and SWAG models. See Appendix Figure 13 for additional results.

Moreover, MultiSWAG, which performs a more exhaustive multimodal Bayesian model average, completely mitigates double descent: the performance of MultiSWAG solutions increases monotonically with the size of the model, showing no double descent even under significant label corruption, for both accuracy and NLL.

Our results highlight the importance of marginalization over multiple modes of the posterior: under 20% label corruption, SWAG clearly suffers from double descent while MultiSWAG does not. In Figure 8(e) we show how the double descent is alleviated with an increased number of independent modes marginalized in MultiSWAG.

These results also clearly show that MultiSWAG provides significant improvements in accuracy over both SGD and SWAG models, in addition to NLL, an often overlooked advantage of Bayesian model averaging we discuss in Section 3.1.

Recently, Nakkiran et al. (2020) show that carefully tuned ℓ2 regularization can help mitigate double descent. Alternatively, we show that Bayesian model averaging, particularly based on multimodal marginalization, can mitigate prominent double descent behaviour. The perspective in Sections 1 and 3 predicts this result: models with reasonable priors and effective Bayesian model averaging should monotonically improve with increases in flexibility.

8. Temperature Scaling

The standard Bayesian posterior distribution is given by

\[
p(w|\mathcal{D}) = \frac{1}{Z}\, p(\mathcal{D}|w)\, p(w), \tag{2}
\]

where p(D|w) is a likelihood, p(w) is a prior, and Z is a normalizing constant.

In Bayesian deep learning it is typical to consider the tempered posterior

\[
p_T(w|\mathcal{D}) = \frac{1}{Z(T)}\, p(\mathcal{D}|w)^{1/T} p(w), \tag{3}
\]

where T is a temperature parameter, and Z(T) is the normalizing constant corresponding to temperature T. The temperature parameter controls how the prior and likelihood interact in the posterior:

• T < 1 corresponds to cold posteriors, where the posterior distribution is more concentrated around solutions with high likelihood.

• T = 1 corresponds to the standard Bayesian posterior distribution.

• T > 1 corresponds to warm posteriors, where the prior effect is stronger and the posterior collapse is slower.
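In practice, tempering only changes the unnormalized log-density that a sampler or optimizer targets; a minimal sketch:

```python
def tempered_log_posterior(log_likelihood, log_prior, T):
    """Unnormalized log-density of the tempered posterior in Eq. (3):
       log p_T(w | D) = (1/T) * log p(D | w) + log p(w) + const.
    T < 1 gives a cold posterior, T = 1 the standard posterior, T > 1 a warm posterior."""
    return log_likelihood / T + log_prior
```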

Tempering posteriors is a well-known practice in statistics, where it goes by the names Safe Bayes, generalized Bayesian inference, and fractional Bayesian inference (e.g., de Heide et al., 2019; Grunwald et al., 2017; Barron & Cover, 1991; Walker & Hjort, 2001; Zhang, 2006; Bissiri et al., 2016; Grunwald, 2012). Safe Bayes has been shown to be natural from a variety of perspectives, including from prequential learning theory and minimum description length frameworks (e.g., Grunwald et al., 2017).

Concurrently with our work, Wenzel et al. (2020) noticed that successful Bayesian deep learning methods tend to use cold posteriors. They provide an empirical study that shows that Bayesian neural networks (BNNs) with cold posteriors outperform models with SGD based maximum likelihood training, while BNNs with T = 1 can perform worse than the maximum likelihood solution. They claim that cold posteriors sharply deviate from the Bayesian paradigm, and consider possible reasons for why tempering is helpful in Bayesian deep learning.

[Figure 9: average class probabilities for two prior sample functions and the prior predictive distribution, for (a) α = 0.01, (b) α = 0.1, (c) α = 0.3, (d) α = 1, (e) α = √10; (f) NLL and (g) classification error vs. prior std α for PreResNet-20 and VGG-16, with the weight decay baseline.]

Figure 9. Effects of the prior variance α². (a)–(e) Average class probabilities over all of CIFAR-10 for two sample prior functions p(f(x; w)) (two top rows) and the predictive distribution (average over 200 samples of weights, bottom row), for varying settings of α in p(w) = N(0, α²I). (f) NLL and (g) classification error of an ensemble of 20 SWAG samples on CIFAR-10 as a function of α, using a Preactivation ResNet-20 and VGG-16. The NLL is high for overly small α and near-optimal in the range of [0.1, 0.3]. The NLL remains relatively low for vague priors corresponding to large values of α.

[Figure 10: average class probabilities for two sample functions as the amount of data varies; panels (a) Prior (α = √10), (b) 10 datapoints, (c) 100 datapoints, (d) 1000 datapoints.]

Figure 10. Adaptivity of posterior variance with data. We sample two functions f(x; w) from the distribution over functions induced by a distribution over weights, starting with the prior p(w) = N(0, 10·I), in combination with a PreResNet-20. We measure class probabilities averaged across the CIFAR-10 test set, as we vary the amount of available training data. Although the prior variance is too large, such that the softmax saturates for logits sampled from the prior, leading to one class being favoured, we see that the posterior quickly adapts to correct the scale of the logits in the presence of data. In Figure 9 we also show that the prior variance can easily be calibrated such that the prior predictive distribution, even before observing data, is high entropy.



In this section we provide an alternative view, and argue that tempering is not at odds with Bayesian principles. Moreover, for virtually any realistic model class and dataset, it would be highly surprising if T = 1 were in fact the best setting of this hyperparameter. Indeed, as long as it is practically convenient, we would advocate tempering for essentially any model, especially parametric models that do not scale their capacity automatically with the amount of available information. Our position is that, at a high level, Bayesian methods are trying to combine honest beliefs with data to form a posterior. By reflecting the belief that the model is misspecified, the tempered posterior is often more of a true posterior than the posterior that results from ignoring our belief that the model is misspecified.

Finding that T < 1 helps for Bayesian neural networks is neither surprising nor discouraging. And the actual results of the experiments in Wenzel et al. (2020), which show great improvements over standard SGD training, are in fact very encouraging of deriving inspiration from Bayesian procedures in deep learning.

We consider: (1) tempering under misspecification (Section 8.1); (2) tempering in terms of overcounting data (Section 8.2); (3) how tempering compares to changing the observation model (Section 8.3); (4) the effect of the prior in relation to the experiments of Wenzel et al. (2020) (Section 8.4); (5) the effect of approximate inference, including how tempering can help in efficiently estimating parameters even for the untempered posterior (Section 8.5).

8.1. Tempering Helps with Misspecified Models

Many works explain how tempered posteriors help under model misspecification (e.g., de Heide et al., 2019; Grunwald et al., 2017; Barron & Cover, 1991; Walker & Hjort, 2001; Zhang, 2006; Bissiri et al., 2016; Grunwald, 2012). In fact, de Heide et al. (2019) and Grunwald et al. (2017) provide several simple examples where Bayesian inference fails to provide good convergence behaviour for untempered posteriors. While it is easier to show theoretical results for T > 1, several of these works also show that T < 1 can be preferred, even in well-specified settings, and indeed recommend learning T from data, for example by cross-validation (e.g., Grunwald, 2012; de Heide et al., 2019).

Are we in a misspecified setting for Bayesian neural networks? Of course. And it would be irrational to proceed as if it were otherwise. Every model is misspecified. In the context of Bayesian neural networks specifically, the mass of solutions expressed by the prior outside of the datasets we typically consider is likely much larger than desired for most applications. We can calibrate for this discrepancy through tempering. The resulting tempered posterior will be more in line with our beliefs than pretending the model is not misspecified and finding the untempered posterior.

Non-parametric models such as Gaussian processes attempt to side-step model misspecification by growing the number of free parameters (information capacity) automatically with the amount of available data. In parametric models we make much more of a manual guess about the model capacity. In the case of deep neural networks, this choice is not even close to a best guess: it was once the case that architectural design was a large component of works involving neural networks, but now it is more standard practice to choose an off-the-shelf architecture, without much consideration of model capacity. We do not believe that knowingly using a misspecified model to find a posterior is more reasonable (or Bayesian) than honestly reflecting the belief that the model is misspecified and then using a tempered posterior. For parametric models such as neural networks, it is to be expected that the capacity is particularly misspecified.

8.2. Overcounting Data with Cold Posteriors

The criticism of cold posteriors raised by Wenzel et al. (2020) is largely based on the fact that decreasing the temperature leads to overcounting data in the posterior distribution.

However, a similar argument can be made against marginal likelihood maximization (also known as empirical Bayes, or type-2 maximum likelihood). Indeed, here the prior will depend on the same data as the likelihood, which can lead to miscalibrated predictive distributions (Darnieder, 2011).

Nonetheless, empirical Bayes has been embraced and widely adopted in Bayesian machine learning (e.g., Bishop, 2006; Rasmussen & Williams, 2006; MacKay, 2003; Minka, 2001b) as embodying several Bayesian principles. Empirical Bayes has been particularly embraced in seminal work on Bayesian neural networks (e.g., MacKay, 1992; 1995), where it has been proposed as a principled approach to learning hyperparameters, such as the scale of the variance for the prior over weights, automatically embodying Occam's razor. While there is, in this case, some deviation from the fully Bayesian paradigm, the procedure, which depends on marginalization, is nonetheless clearly inspired by Bayesian thinking, and it is thus helpful to reflect this inspiration and provide understanding of how it works from a Bayesian perspective.

There is also work showing that the marginal likelihood can lead to miscalibrated Bayes factors under model misspecification. Attempts to calibrate these factors (Xu et al., 2019) as part of the Bayesian paradigm are highly reminiscent of work on safe Bayes.


8.3. Tempered Posterior or Different Likelihood

The tempered posterior for one model is an untempered posterior using a different observation model. In other words, we can trivially change the likelihood function in our model so that the standard Bayesian posterior in the new model is equal to the posterior of temperature T in the original model. Indeed, consider the likelihood function

p_T(D | w) ∝ p(D | w)^{1/T},  (4)

where the posterior distribution for the model M_T with likelihood p_T is exactly the temperature-T posterior for the model M with likelihood p.

The predictive distributions differ for the two models: even though the posteriors coincide, the likelihoods for a new datapoint y* are different:

∫ p(y* | w) p(w | D) dw  ≠  ∫ p_T(y* | w) p(w | D) dw.  (5)

As an example, consider a regression model M with a Gaussian likelihood function, y ~ N(f, σ²), where f is the output of the network and σ² is the noise variance. The predictive distributions for the two models M and M_T for a new input x will have different variance, but the same mean: E_w[f + ε] = E_w[f]. Moreover, in this case the noise variance would typically be learned in either model, M_T or M.
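The following sketch (ours, using a conjugate Gaussian toy model rather than a neural network) illustrates the point of Equations (4) and (5): the likelihood-tempered posterior coincides with the standard posterior of the model with noise variance Tσ², while the two predictive distributions share a mean but differ in variance.

```python
import numpy as np

# Toy illustration of Section 8.3: a temperature-T posterior under a Gaussian
# likelihood equals the standard posterior of a model with noise variance
# T * sigma2, yet the predictive variances of the two models differ.
# Model: y = w + eps, eps ~ N(0, sigma2), prior w ~ N(0, 1), so everything is analytic.
rng = np.random.default_rng(0)
sigma2, T, n = 0.25, 0.5, 30
y = rng.normal(1.0, np.sqrt(sigma2), size=n)

# Temperature-T posterior of M is proportional to p(D|w)^(1/T) p(w); for a
# Gaussian likelihood p(y|w)^(1/T) ∝ N(y; w, T * sigma2), so this is exactly
# the standard posterior of M_T with noise variance T * sigma2.
post_prec = 1.0 + n / (T * sigma2)            # prior precision + data precision
post_var = 1.0 / post_prec
post_mean = post_var * y.sum() / (T * sigma2)

# Predictive distributions for a new y*: same mean, different variance (Eq. 5).
pred_M = (post_mean, post_var + sigma2)        # original likelihood p
pred_MT = (post_mean, post_var + T * sigma2)   # tempered likelihood p_T
print(pred_M, pred_MT)
```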

Section 4.1 of Grunwald et al. (2017) considers a related construction.

8.4. Effect of the Prior

While a somewhat misspecified prior will certainly interact with the utility of tempering, we do not believe the experiments in Wenzel et al. (2020) provide evidence that even the prior p(w) = N(0, I) is misspecified to any serious extent. For a relatively wide range of distributions over w, the functional form of the network f(x; w) can produce a generally reasonable distribution over functions p(f(x; w)). In Figure 10 we reproduce the findings in Wenzel et al. (2020) showing that sample functions p(f(x; w)) corresponding to the prior p(w) = N(0, 10·I) strongly favour a single class over the dataset. While this behaviour appears superficially dramatic, we note it is simply an artifact of having a miscalibrated signal variance. A miscalibrated signal variance interacts with a quickly saturating softmax link function to provide a seemingly dramatic preference for a given class. If we instead use p(w) = N(0, α²I), for quite a range of α, then sample functions provide reasonably high entropy across labels, as in Figure 9. The value of α can easily be determined through cross-validation, as in Figure 9, or specified as a standard value used for L2 regularization (α = 0.24 in this case).

However, even with the inappropriate prior scale, we see in the bottom row of panels (a)–(e) of Figure 9 that the unconditional predictive distribution is completely reasonable. Moreover, the prior variance represents a soft prior bias, and will quickly update in the presence of data. In Figure 10 we show the posterior samples after observing 10, 100, and 1000 data points.
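The calibration described here can be checked with a short sketch (ours, not the paper's code): sample weights from p(w) = N(0, α²I) and average softmax outputs over inputs and prior samples, as in the bottom row of Figure 9. The small MLP and random inputs below are stand-ins for PreResNet-20 and CIFAR-10.

```python
import torch, torch.nn as nn

# Sample weights from p(w) = N(0, alpha^2 I), and average softmax probabilities
# over inputs and prior samples.  A small MLP and random inputs stand in for
# PreResNet-20 and CIFAR-10; both are assumptions for illustration only.
def sample_prior_predictive(model, x, alpha, n_samples=200):
    probs = torch.zeros(n_samples, x.shape[0], 10)
    for s in range(n_samples):
        with torch.no_grad():
            for p in model.parameters():              # w ~ N(0, alpha^2 I)
                p.copy_(alpha * torch.randn_like(p))
            probs[s] = torch.softmax(model(x), dim=-1)
    return probs.mean(dim=(0, 1))                     # average class probabilities

model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 256), nn.ReLU(), nn.Linear(256, 10))
x = torch.randn(128, 3, 32, 32)                       # stand-in for CIFAR-10 images
for alpha in [0.02, 0.1, 1.0]:
    print(alpha, sample_prior_predictive(model, x, alpha))
```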

Other aspects of the prior, outside of the prior signal variance, will have a much greater effect on the inductive biases of the model. For example, the induced covariance function cov(f(x_i; w), f(x_j; w)) reflects the induced similarity metric over data instances; through the covariance function we can answer, for instance, whether the model believes a priori that a translated image is similar to the original. Unlike the signal variance of the prior, the prior covariance function will continue to have a significant effect on posterior inference even for very large datasets, and strongly reflects the structural properties of the neural network. We explore these structures of the prior in Figure 12.

8.5. The Effect of Inexact Inference

We have to keep in mind what we ultimately use posterior samples to compute. Ultimately, we wish to estimate the predictive distribution given by the integral in Equation (1). With a finite number of samples, the tempered posterior could be used to provide a better approximation to the expectation of the predictive distribution associated with the untempered posterior.

Consider a simple example, where we wish to estimate the mean of a high-dimensional Gaussian distribution N(0, I). Suppose we use J independent samples. The mean of these samples is also Gaussian distributed, μ ~ N(0, (1/J)·I). In Bayesian deep learning, the dimension d is typically on the order of 10⁷, and J would be on the order of 10. The norm of μ would be highly concentrated around √(10⁷)/√10 = 1000. In this case, sampling from a tempered posterior with T < 1 would lead to a better approximation of the Bayesian model average associated with an untempered posterior.
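This concentration is easy to verify numerically; the sketch below (ours) uses a smaller dimension than the 10⁷ quoted above so that it runs quickly.

```python
import numpy as np

# The sample mean of J draws from N(0, I) in d dimensions has norm close to
# sqrt(d / J); here d = 10^5 and J = 10, so the norm is approximately 100.
rng = np.random.default_rng(0)
d, J = 10**5, 10
samples = rng.standard_normal((J, d))
mu_hat = samples.mean(axis=0)
print(np.linalg.norm(mu_hat), np.sqrt(d / J))  # both close to 100
```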

Furthermore, no current sampling procedure will be providing samples that are close to independent samples from the true posterior of a Bayesian neural network. The posterior landscape is far too multimodal and complex for there to be any reasonable coverage. The approximations we have are practically useful, and often preferable to conventional training, but we cannot realistically proceed with analysis assuming that we have obtained true samples from a posterior. While we would expect that some value of T ≠ 1 would be preferred for any finite dataset in practice, it is conceivable that some of the results in Wenzel et al. (2020) may be affected by the specifics of the approximate inference technique being used.


We should be wary not to view Bayesian model averaging purely through the prism of simple Monte Carlo, as advised in Section 3.2. Given a finite computational budget, our goal in effectively approximating a Bayesian model average is not equivalent to obtaining good samples from the posterior.

9. Discussion

"It is now common practice for Bayesians to fit models that have more parameters than the number of data points. Incorporate every imaginable possibility into the model space: for example, if it is conceivable that a very simple model might be able to explain the data, one should include simple models; if the noise might have a long-tailed distribution, one should include a hyperparameter which controls the heaviness of the tails of the distribution; if an input variable might be irrelevant to a regression, include it in the regression anyway." MacKay (1995)

We have presented a probabilistic perspective of generalization, which depends on the support and inductive biases of the model. The support should be as large as possible, but the inductive biases must be well-calibrated to a given problem class. We argue that Bayesian neural networks embody these properties, and, through the lens of probabilistic inference, explain generalization behaviour that has previously been viewed as mysterious. Moreover, we argue that Bayesian marginalization is particularly compelling for neural networks, show how deep ensembles provide a practical mechanism for marginalization, and propose a new approach that generalizes deep ensembles to marginalize within basins of attraction. We show that this multimodal approach to Bayesian model averaging, MultiSWAG, can entirely alleviate double descent, to enable monotonic performance improvements with increases in model flexibility, as well as significant improvements in generalization accuracy and log likelihood over SGD and single-basin marginalization.

There are certainly many challenges to estimating the integral for a Bayesian model average in modern deep learning, including a high-dimensional parameter space and a complex posterior landscape. But viewing the challenge indeed as an integration problem, rather than an attempt to obtain posterior samples for a simple Monte Carlo approximation, provides opportunities for future progress. Bayesian deep learning has been making fast practical advances, with approaches that now enable better accuracy and calibration over standard training, with minimal overhead.

We finish with remarks about future developments for Bayesian neural network priors, and approaches to research in Bayesian deep learning.

9.1. The Future for BNN Priors

We provide some brief remarks about future developments for BNN priors. Here we have explored relatively simple parameter priors p(w) = N(0, α²I). While these priors are simple in parameter space, they interact with the neural network architecture to induce a sophisticated prior over functions p(f(x; w)), with many desirable properties, including a reasonable correlation structure over images. However, these parameter priors can certainly still be improved. As we have seen, even tuning the value of the signal variance α², an analogue of the L2 regularization often used in deep learning, can have a noticeable effect on the induced prior over functions, though this effect is quickly modulated by data. Layer-wise priors, such that parameters in each layer have a different signal variance, are intuitive: we would expect that later layers require precise determination, while parameters in earlier layers could reasonably take a range of values. But one has to be cautious; as we show in Appendix Section E, with ReLU activations different signal variances in different layers can be degenerate, combining together to affect only the output scale of the network.

A currently popular sentiment is that we should directly build function-space BNN priors, often taking inspiration from Gaussian processes. While we believe this is a promising direction, one should proceed with caution. If we contrive priors over parameters p(w) to induce distributions over functions p(f) that resemble familiar models, such as Gaussian processes with RBF kernels, we could be throwing the baby out with the bathwater. Neural networks are useful as their own model class precisely because they have different inductive biases from other models.

A similar concern applies to taking infinite-width limits in Bayesian neural networks. In these cases we recover Gaussian processes with interpretable kernel functions; because these models are easier to use and analyze, and give rise to interpretable and well-motivated priors, it is tempting to treat them as drop-in replacements for their parametric analogues. However, the kernels for these models are fixed. In order for a model to do effective representation learning, we must learn a similarity metric for the data. Training a neural network in many ways is like learning a kernel, rather than using a fixed kernel. MacKay (1998) has also expressed concerns about treating these limits as replacements for neural networks, due to the loss of representation learning power.

Perhaps the distribution over functions induced by a network in combination with a generic distribution over parameters p(w) may be hard to interpret, but this distribution will contain the equivariance properties, representation learning abilities, and other biases that make neural networks a compelling model class in their own right.


9.2. "But is it really Bayesian?"

We finish with an editorial comment about approaches to research within Bayesian deep learning. There is sometimes a tendency to classify work as Bayesian or not Bayesian, with very stringent criteria for what qualifies as Bayesian. Moreover, the implication, and sometimes even explicit recommendation, is that if an approach is not unequivocally Bayesian in every respect, then we should not term it Bayesian, and we should instead attempt to understand the procedure through entirely different, non-Bayesian mechanisms. We believe this mentality encourages tribalism, which is not conducive to the best research, or to creating the best performing methods. What matters is not a debate about semantics, but making rational modelling choices given a particular problem setting, and trying to understand these choices. Often these choices can largely be inspired by a Bayesian approach, in which case it is desirable to indicate this source of inspiration. And in the semantics debate, who would be the arbiter of what gets to be called Bayesian? Arguably it ought to be an evolving definition.

Broadly speaking, what makes Bayesian approaches distinctive is a posterior weighted marginalization over parameters. And, at a high level, Bayesian methods are about combining our honest beliefs with data to form a posterior. In actuality, no fair-minded researcher entirely believes the prior over parameters, the functional form of the model (which is part of the prior over functions), or the likelihood. From this perspective, it is broadly compatible with a Bayesian philosophy to reflect misspecification in the modelling procedure itself, which is achieved through tempering. In this sense, the tempered posterior is more reflective of a true posterior than the posterior that results from ignoring our belief that the model is misspecified.

Moreover, basic probability theory indicates that marginalization is desirable. While marginalization cannot in practice be achieved exactly, we can try to improve over conventional training, which, as we have discussed, can be viewed as approximate marginalization. Given computational constraints, effective marginalization is not equivalent to obtaining accurate samples from a posterior. As we have discussed, simple Monte Carlo is only one of many mechanisms for marginalization. Just as expectation propagation (Minka, 2001a) focuses its approximation on the factors in a posterior where it will most affect the end result, we should focus on representing the posterior where it will make the biggest difference to the model average. As we have shown, deep ensembles are a reasonable mechanism up to a point. After having trained many independent models, there are added benefits to marginalizing within basins, given the computational expense associated with retraining an additional model to find an additional basin of attraction.

We should also not hold Bayesian methods to a double standard. Indeed, it can be hard to interpret or understand the prior, the posterior, and whether the marginalization procedure is optimal. But it is also hard to interpret the choices behind the functional form of the model, or the rationale behind classical procedures, where we bet everything on a single global optimum, when we know there are many global optima, many of which will perform well but provide different solutions, and many others which will not perform well. We should apply the same level of scrutiny to all modelling choices, consider the alternatives, and not be paralyzed if a procedure is not optimal in every respect.

Acknowledgements. AGW and PI are supported by an Amazon Research Award, Facebook Research, NSF I-DISRE 193471, NIH R01 DA048764-01A1, NSF IIS-1563887, and NSF IIS-1910266. We thank Greg Benton for helpful discussions.

References

Achille, A. and Soatto, S. Emergence of invariance and disentanglement in deep representations. The Journal of Machine Learning Research, 19(1):1947–1980, 2018.
Ashukha, A., Lyzhov, A., Molchanov, D., and Vetrov, D. Pitfalls of in-domain uncertainty estimation and ensembling in deep learning. arXiv preprint arXiv:2002.06470, 2020.
Barron, A. R. and Cover, T. M. Minimum complexity density estimation. IEEE Transactions on Information Theory, 37(4):1034–1054, 1991.
Beal, M. J. Variational algorithms for approximate Bayesian inference. University of London, 2003.
Belkin, M., Hsu, D., Ma, S., and Mandal, S. Reconciling modern machine-learning practice and the classical bias–variance trade-off. Proceedings of the National Academy of Sciences, 116(32):15849–15854, 2019.
Bishop, C. M. Pattern Recognition and Machine Learning. Springer, 2006.
Bissiri, P. G., Holmes, C. C., and Walker, S. G. A general framework for updating belief distributions. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 78(5):1103–1130, 2016.
Blundell, C., Cornebise, J., Kavukcuoglu, K., and Wierstra, D. Weight uncertainty in neural networks. arXiv preprint arXiv:1505.05424, 2015.
Box, G. E. and Tiao, G. C. Bayesian inference in statistical analysis. Addison-Wesley, Reading, MA, 1973.
Chaudhari, P., Choromanska, A., Soatto, S., LeCun, Y., Baldassi, C., Borgs, C., Chayes, J., Sagun, L., and Zecchina, R. Entropy-SGD: Biasing gradient descent into wide valleys. Journal of Statistical Mechanics: Theory and Experiment, 2019.
Cobb, A. D., Baydin, A. G., Markham, A., and Roberts, S. J. Introducing an explicit symplectic integration scheme for Riemannian manifold Hamiltonian Monte Carlo. arXiv preprint arXiv:1910.06243, 2019.
Darnieder, W. F. Bayesian methods for data-dependent priors. PhD thesis, The Ohio State University, 2011.
de Heide, R., Kirichenko, A., Mehta, N., and Grunwald, P. Safe-Bayesian generalized linear regression. arXiv preprint arXiv:1910.09227, 2019.
Dinh, L., Pascanu, R., Bengio, S., and Bengio, Y. Sharp minima can generalize for deep nets. arXiv preprint arXiv:1703.04933, 2017.
Draxler, F., Veschgini, K., Salmhofer, M., and Hamprecht, F. A. Essentially no barriers in neural network energy landscape. arXiv preprint arXiv:1803.00885, 2018.
Dziugaite, G. K. and Roy, D. M. Computing nonvacuous generalization bounds for deep (stochastic) neural networks with many more parameters than training data. arXiv preprint arXiv:1703.11008, 2017.
Fort, S. and Jastrzebski, S. Large scale structure of neural network loss landscapes. In Advances in Neural Information Processing Systems, pp. 6706–6714, 2019.
Fort, S., Hu, H., and Lakshminarayanan, B. Deep ensembles: A loss landscape perspective. arXiv preprint arXiv:1912.02757, 2019.
Gal, Y. and Ghahramani, Z. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. In International Conference on Machine Learning, pp. 1050–1059, 2016.
Gardner, J. R., Pleiss, G., Bindel, D., Weinberger, K. Q., and Wilson, A. G. GPyTorch: Blackbox matrix-matrix Gaussian process inference with GPU acceleration. In Neural Information Processing Systems, 2018.
Garipov, T., Izmailov, P., Podoprikhin, D., Vetrov, D. P., and Wilson, A. G. Loss surfaces, mode connectivity, and fast ensembling of DNNs. In Neural Information Processing Systems, 2018.
Grunwald, P. The safe Bayesian. In International Conference on Algorithmic Learning Theory, pp. 169–183. Springer, 2012.
Grunwald, P., Van Ommen, T., et al. Inconsistency of Bayesian inference for misspecified linear models, and a proposal for repairing it. Bayesian Analysis, 12(4):1069–1103, 2017.
Guedj, B. A primer on PAC-Bayesian learning. arXiv preprint arXiv:1901.05353, 2019.
Guo, C., Pleiss, G., Sun, Y., and Weinberger, K. Q. On calibration of modern neural networks. In Proceedings of the 34th International Conference on Machine Learning - Volume 70, pp. 1321–1330. JMLR.org, 2017.
Gustafsson, F. K., Danelljan, M., and Schon, T. B. Evaluating scalable Bayesian deep learning methods for robust computer vision. arXiv preprint arXiv:1906.01620, 2019.
Hafner, D., Tran, D., Irpan, A., Lillicrap, T., and Davidson, J. Reliable uncertainty estimates in deep neural networks using noise contrastive priors. arXiv preprint arXiv:1807.09289, 2018.
He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778, 2016.
Hendrycks, D. and Dietterich, T. Benchmarking neural network robustness to common corruptions and perturbations. arXiv preprint arXiv:1903.12261, 2019.
Hinton, G. E. and Van Camp, D. Keeping the neural networks simple by minimizing the description length of the weights. In Proceedings of the Sixth Annual Conference on Computational Learning Theory, pp. 5–13, 1993.
Hochreiter, S. and Schmidhuber, J. Flat minima. Neural Computation, 9(1):1–42, 1997.
Huang, W. R., Emam, Z., Goldblum, M., Fowl, L., Terry, J. K., Huang, F., and Goldstein, T. Understanding generalization through visualizations. arXiv preprint arXiv:1906.03291, 2019.
Izmailov, P., Podoprikhin, D., Garipov, T., Vetrov, D., and Wilson, A. G. Averaging weights leads to wider optima and better generalization. In Uncertainty in Artificial Intelligence (UAI), 2018.
Izmailov, P., Maddox, W. J., Kirichenko, P., Garipov, T., Vetrov, D., and Wilson, A. G. Subspace inference for Bayesian deep learning. Uncertainty in Artificial Intelligence, 2019.
Jacot, A., Gabriel, F., and Hongler, C. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in Neural Information Processing Systems, pp. 8571–8580, 2018.
Jiang, Y., Neyshabur, B., Mobahi, H., Krishnan, D., and Bengio, S. Fantastic generalization measures and where to find them. arXiv preprint arXiv:1912.02178, 2019.
Kass, R. E. and Raftery, A. E. Bayes factors. Journal of the American Statistical Association, 90(430):773–795, 1995.
Kendall, A. and Gal, Y. What uncertainties do we need in Bayesian deep learning for computer vision? In Advances in Neural Information Processing Systems, pp. 5574–5584, 2017.
Keskar, N. S., Mudigere, D., Nocedal, J., Smelyanskiy, M., and Tang, P. T. P. On large-batch training for deep learning: Generalization gap and sharp minima. arXiv preprint arXiv:1609.04836, 2016.
Khan, M. E., Nielsen, D., Tangkaratt, V., Lin, W., Gal, Y., and Srivastava, A. Fast and scalable Bayesian deep learning by weight-perturbation in Adam. arXiv preprint arXiv:1806.04854, 2018.
Krizhevsky, A., Nair, V., and Hinton, G. The CIFAR-10 dataset, 2014. http://www.cs.toronto.edu/kriz/cifar.html
Lakshminarayanan, B., Pritzel, A., and Blundell, C. Simple and scalable predictive uncertainty estimation using deep ensembles. In Advances in Neural Information Processing Systems, pp. 6402–6413, 2017.
Langford, J. and Caruana, R. (Not) bounding the true error. In Advances in Neural Information Processing Systems, pp. 809–816, 2002.
LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.
Louizos, C., Shi, X., Schutte, K., and Welling, M. The functional neural process. In Advances in Neural Information Processing Systems, 2019.
MacKay, D. J. Bayesian methods for adaptive models. PhD thesis, California Institute of Technology, 1992.
MacKay, D. J. Probable networks and plausible predictions: a review of practical Bayesian methods for supervised neural networks. Network: Computation in Neural Systems, 6(3):469–505, 1995.
MacKay, D. J. Introduction to Gaussian processes. In Bishop, C. M. (ed.), Neural Networks and Machine Learning, chapter 11, pp. 133–165. Springer-Verlag, 1998.
MacKay, D. J. Information theory, inference and learning algorithms. Cambridge University Press, 2003.
Maddox, W. J., Izmailov, P., Garipov, T., Vetrov, D. P., and Wilson, A. G. A simple baseline for Bayesian uncertainty in deep learning. In Advances in Neural Information Processing Systems, 2019.
Masegosa, A. R. Learning under model misspecification: Applications to variational and ensemble methods. 2019.
McAllester, D. A. PAC-Bayesian model averaging. In Proceedings of the Twelfth Annual Conference on Computational Learning Theory, pp. 164–170, 1999.
Minka, T. Expectation propagation for approximate Bayesian inference. In Uncertainty in Artificial Intelligence, volume 17, pp. 362–369, 2001a.
Minka, T. P. Bayesian model averaging is not model combination. 2000.
Minka, T. P. Automatic choice of dimensionality for PCA. In Advances in Neural Information Processing Systems, pp. 598–604, 2001b.
Mohri, M. and Rostamizadeh, A. Rademacher complexity bounds for non-i.i.d. processes. In Advances in Neural Information Processing Systems, pp. 1097–1104, 2009.
Nakkiran, P., Kaplun, G., Bansal, Y., Yang, T., Barak, B., and Sutskever, I. Deep double descent: Where bigger models and more data hurt. arXiv preprint arXiv:1912.02292, 2019.
Nakkiran, P., Venkat, P., Kakade, S., and Ma, T. Optimal regularization can mitigate double descent. arXiv preprint arXiv:2003.01897, 2020.
Nalisnick, E. On priors for Bayesian neural networks. PhD thesis, University of California, Irvine, 2018.
Neal, R. Bayesian Learning for Neural Networks. Springer Verlag, 1996. ISBN 0387947248.
Neyshabur, B., Bhojanapalli, S., McAllester, D., and Srebro, N. Exploring generalization in deep learning. In Advances in Neural Information Processing Systems, pp. 5947–5956, 2017.
Neyshabur, B., Bhojanapalli, S., and Srebro, N. A PAC-Bayesian approach to spectrally-normalized margin bounds for neural networks. In International Conference on Learning Representations, 2018. URL https://openreview.net/forum?id=Skz_WfbCZ.
Opper, M., Kinzel, W., Kleinz, J., and Nehl, R. On the ability of the optimal perceptron to generalise. Journal of Physics A: Mathematical and General, 23(11):L581, 1990.
Ovadia, Y., Fertig, E., Ren, J., Nado, Z., Sculley, D., Nowozin, S., Dillon, J. V., Lakshminarayanan, B., and Snoek, J. Can you trust your model's uncertainty? Evaluating predictive uncertainty under dataset shift. arXiv preprint arXiv:1906.02530, 2019.
Pearce, T., Zaki, M., Brintrup, A., Anastassacos, N., and Neely, A. Uncertainty in neural networks: Bayesian ensembling. arXiv preprint arXiv:1810.05546, 2018.
Rasmussen, C. E. and Ghahramani, Z. Occam's razor. In Neural Information Processing Systems (NIPS), 2001.
Rasmussen, C. E. and Williams, C. K. I. Gaussian processes for Machine Learning. The MIT Press, 2006.
Ritter, H., Botev, A., and Barber, D. A scalable Laplace approximation for neural networks. In International Conference on Learning Representations (ICLR), 2018.
Rue, H., Martino, S., and Chopin, N. Approximate Bayesian inference for latent Gaussian models by using integrated nested Laplace approximations. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 71(2):319–392, 2009.
Simonyan, K. and Zisserman, A. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.
Smith, S. L. and Le, Q. V. A Bayesian perspective on generalization and stochastic gradient descent. In International Conference on Learning Representations, 2018.
Sun, S., Zhang, G., Shi, J., and Grosse, R. Functional variational Bayesian neural networks. arXiv preprint arXiv:1903.05779, 2019.
Ulyanov, D., Vedaldi, A., and Lempitsky, V. Deep image prior. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 9446–9454, 2018.
Vapnik, V. N. Adaptive and learning systems for signal processing, communications, and control. Statistical learning theory, 1998.
Virtanen, P., Gommers, R., Oliphant, T. E., Haberland, M., Reddy, T., Cournapeau, D., Burovski, E., Peterson, P., Weckesser, W., Bright, J., et al. SciPy 1.0: fundamental algorithms for scientific computing in Python. Nature Methods, pp. 1–12, 2020.
Walker, S. and Hjort, N. L. On Bayesian consistency. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 63(4):811–821, 2001.
Welling, M. and Teh, Y. W. Bayesian learning via stochastic gradient Langevin dynamics. In Proceedings of the 28th International Conference on Machine Learning (ICML-11), pp. 681–688, 2011.
Wenzel, F., Roth, K., Veeling, B. S., Swiatkowski, J., Tran, L., Mandt, S., Snoek, J., Salimans, T., Jenatton, R., and Nowozin, S. How good is the Bayes posterior in deep neural networks really? arXiv preprint arXiv:2002.02405, 2020.
Wilson, A. G. The case for Bayesian deep learning. arXiv preprint arXiv:2001.10995, 2020.
Xu, X., Lu, P., MacEachern, S., and Xu, R. Calibrated Bayes factors for model comparison. Journal of Statistical Computation and Simulation, 89(4):591–614, 2019.
Yang, W., Lorch, L., Graule, M. A., Srinivasan, S., Suresh, A., Yao, J., Pradier, M. F., and Doshi-Velez, F. Output-constrained Bayesian neural networks. arXiv preprint arXiv:1905.06287, 2019.
Zhang, C., Bengio, S., Hardt, M., Recht, B., and Vinyals, O. Understanding deep learning requires rethinking generalization. arXiv preprint arXiv:1611.03530, 2016.
Zhang, T. Information-theoretic upper and lower bounds for statistical estimation. IEEE Transactions on Information Theory, 52(4):1307–1321, 2006.


Appendix Outline

This appendix is organized as follows. In Section A, we visualize predictive functions corresponding to weight samples within high posterior density valleys on a regression problem. In Section B, we provide background material on Gaussian processes. In Section C, we present further results comparing MultiSWAG and MultiSWA to Deep Ensembles under data distribution shift on CIFAR-10. In Section D, we provide the details of all experiments presented in the paper. In Section E, we present analytic results on the dependence of the prior distribution in function space on the variance of the prior over parameters. In Section F, we study the prior correlations between BNN logits on perturbed images.

A. Loss Valleys

We demonstrate that different points along the valleys of high posterior density (low loss) connecting pairs of independently trained optima (Garipov et al., 2018; Draxler et al., 2018; Fort & Jastrzebski, 2019) correspond to different predictive functions. We use the regression example from Izmailov et al. (2019) and show the results in Figure 11.

B. Gaussian processes

With a Bayesian neural network, a distribution over parameters p(w) induces a distribution over functions p(f(x; w)) when combined with the functional form of the network. Gaussian processes (GPs) are often used to instead directly specify a distribution over functions.

A Gaussian process is a distribution over functions, f(x) ~ GP(m, k), such that any collection of function values, queried at any finite set of inputs x_1, ..., x_n, has a joint Gaussian distribution:

f(x_1), ..., f(x_n) ~ N(μ, K).  (6)

The mean vector μ_i = E[f(x_i)] = m(x_i) and covariance matrix K_ij = cov(f(x_i), f(x_j)) = k(x_i, x_j) are determined by the mean function m and covariance function (or kernel) k of the Gaussian process.

The popular RBF kernel has the form

k(x_i, x_j) = exp( −‖x_i − x_j‖² / (2ℓ²) ).  (7)

The length-scale hyperparameter ℓ controls the extent of correlations between function values. If ℓ is large, sample functions from a GP prior are simple and slowly varying with inputs x.
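As a brief illustration (a sketch under the definitions above, not code from the paper), one can draw sample functions from a zero-mean GP prior with the RBF kernel of Equation (7) and see the effect of the lengthscale.

```python
import numpy as np

# Draw sample functions from a zero-mean GP prior with the RBF kernel for two
# lengthscales; a larger lengthscale gives simpler, more slowly varying samples.
def rbf_kernel(x1, x2, lengthscale):
    sq_dists = (x1[:, None] - x2[None, :]) ** 2
    return np.exp(-0.5 * sq_dists / lengthscale ** 2)

rng = np.random.default_rng(0)
x = np.linspace(-5, 5, 200)
for lengthscale in [0.5, 3.0]:
    K = rbf_kernel(x, x, lengthscale) + 1e-8 * np.eye(len(x))  # jitter for stability
    samples = rng.multivariate_normal(np.zeros(len(x)), K, size=3)
    print(lengthscale, samples.shape)  # three prior sample functions per lengthscale
```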

Gaussian processes with RBF kernels (as well as many other standard kernels) assign positive density to any set of observations. Moreover, these models are universal approximators (Rasmussen & Williams, 2006): as the number of observations increases, they are able to approximate any function to arbitrary precision.

Work on Gaussian processes in machine learning was triggered by the observation that Bayesian neural networks become Gaussian processes with particular kernel functions as the number of hidden units approaches infinity (Neal, 1996). This result resembles recent work on the neural tangent kernel (e.g., Jacot et al., 2018).

C. Deep Ensembles and MultiSWAG Under Distribution Shift

In Figures 14, 15, 16, 17, we show the negative log-likelihood for Deep Ensembles, MultiSWA, and MultiSWAG using PreResNet-20 on CIFAR-10 with various corruptions, as a function of the number of independently trained models (SGD solutions, SWA solutions, or SWAG models, respectively). For MultiSWAG, we generate 20 samples from each independent SWAG model. Typically, MultiSWA and MultiSWAG significantly outperform Deep Ensembles when a small number of independent models is used, or when the level of corruption is high.

In Figure 18, following Ovadia et al. (2019), we show the distribution of negative log likelihood, accuracy, and expected calibration error as we vary the type of corruption. We use a fixed training time budget: 10 independently trained models for every method. For MultiSWAG, we ensemble 20 samples from each of the 10 SWAG approximations. MultiSWAG particularly achieves better NLL than the other two methods, and MultiSWA outperforms Deep Ensembles; the difference is especially pronounced for higher levels of corruption. In terms of ECE, MultiSWAG again outperforms the other two methods for higher corruption intensities.

We note that Ovadia et al. (2019) found Deep Ensembles to be a very strong baseline for prediction quality and calibration under distribution shift. For this reason, we focus on Deep Ensembles in our comparisons.

D. Details of Experiments

In this section we provide additional details of the experiments presented in the paper.

D.1. Approximating the True Predictive Distribution

For the results presented in Figure 4, we used a network with 3 hidden layers of size 10 each. The network takes two inputs: x and x². We pass both x and x² as input to ensure that the network can represent a broader class of functions. The network outputs a single number, y = f(x).

Figure 11. Diversity of high performing functions. Bottom: a contour plot of the posterior log-density in the subspace containing a pair of independently trained modes (as with deep ensembles), and a path of high posterior density connecting these modes. In each panel, the purple point represents a sample from the posterior in the parameter subspace. Top: the predictive distribution constructed from samples in the subspace. The shaded blue area shows the 3σ-region of the predictive distribution at each of the input locations, and the blue line shows the mean of the predictive distribution. In each panel, the purple line shows the predictive function corresponding to the sample shown in the corresponding bottom row panel. For the details of the experimental setup, see Section 5.1 of Izmailov et al. (2019).

To generate data for the plots, we used a randomly-initialized neural network of the same architecture described above. We sampled the weights from an isotropic Gaussian with variance 0.1², and added isotropic Gaussian noise with variance 0.1² to the outputs:

y = f(x; w) + ε(x),

with w ~ N(0, 0.1²·I), ε(x) ~ N(0, 0.1²·I). The training set consists of 120 points, shown in Figure 4.

For estimating the ground truth, we ran 200 chains of Hamiltonian Monte Carlo (HMC) using the hamiltorch package (Cobb et al., 2019). We initialized each chain with a network pre-trained with SGD for 3000 steps, then ran HMC for 2000 steps, producing 200 samples.

For Deep Ensembles, we independently trained 50 networks with SGD for 20000 steps each. We used the negative posterior log-density as the training loss. For SVI, we used a fully-factorized Gaussian approximation initialized at an SGD solution trained for 20000 steps. For all inference methods, we set the prior variance to 10² and the noise variance to 0.02².

Discrepancy with true BMA. For the results presented in panel (d) of Figure 4, we computed the Wasserstein distance between the predictive distribution approximated with HMC and the predictive distributions for Deep Ensembles and SVI. We used the one-dimensional Wasserstein distance function¹ from the scipy package (Virtanen et al., 2020). We computed the Wasserstein distance between the marginal distributions at each input location, and averaged the results over the input locations. In the top sub-panels of panels (b), (c) of Figure 4, we additionally visualize the marginal Wasserstein distance between the HMC predictive distribution and the Deep Ensembles and SVI predictive distributions, respectively, for each input location.

¹https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wasserstein_distance.html
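A minimal sketch of this metric (ours; the prediction arrays below are random placeholders, not the paper's data):

```python
import numpy as np
from scipy.stats import wasserstein_distance

# One-dimensional Wasserstein distance between two predictive distributions,
# computed per input location from samples and then averaged.  Arrays have
# shape (n_samples, n_input_locations).
rng = np.random.default_rng(0)
hmc_preds = rng.normal(0.0, 1.0, size=(200, 50))      # stand-in for HMC predictions
ensemble_preds = rng.normal(0.1, 1.2, size=(50, 50))  # stand-in for Deep Ensembles

per_location = [
    wasserstein_distance(hmc_preds[:, i], ensemble_preds[:, i])
    for i in range(hmc_preds.shape[1])
]
print(np.mean(per_location))  # averaged over input locations
```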

D.2. Deep Ensembles and MultiSWAG

We evaluate Deep Ensembles, MultiSWA, and MultiSWAG under distribution shift in Section 4. Following Ovadia et al. (2019), we use a PreResNet-20 network and the CIFAR-10 dataset with different types of corruptions introduced in Hendrycks & Dietterich (2019). For training individual SGD, SWA, and SWAG models, we use the hyper-parameters used for PreResNet-164 in Maddox et al. (2019). For each SWAG model, we sample 20 networks and ensemble them. So, Deep Ensembles, MultiSWA, and MultiSWAG are all evaluated under the same training budget. Deep Ensembles and MultiSWA also use the same test-time budget.

For producing the corrupted data, we used the code² released by Hendrycks & Dietterich (2019). We had issues producing the data for the frost corruption type, so we omit it in our evaluation, and include Gaussian blur, which was not included in the evaluation of Hendrycks & Dietterich (2019).

²https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_cifar_c.py


Figure 12. (a)–(c) Average pairwise prior correlations for pairs of objects in classes 0, 1, 2, 4, 7 of MNIST induced by LeNet-5 for p(f(x; w)), when p(w) = N(0, α²I). Images in the same class have higher prior correlations than images from different classes, suggesting that p(f(x; w)) has desirable inductive biases. The correlations slightly decrease with increases in α. Panels (e)–(g) show sample functions from LeNet-5 along the direction connecting a pair of MNIST images of 0 and 1 digits; the complexity of the samples increases with α. (d) NLL and (h) classification error of an ensemble of 20 SWAG samples on MNIST as a function of α, using a LeNet-5. The NLL is high for overly small α and near-optimal for larger values, with an optimum near α = 0.3. Panel settings: (a), (e) α = 0.02; (b), (f) α = 0.1; (c), (g) α = 1.

D.3. Neural Network Priors

In the main text, we considered different properties of the prior distribution over functions induced by a spherical Gaussian distribution over the weights, with different variance scales.

Prior correlation diagrams. In panels (a)–(c) of Figure 12 we show pairwise correlations of the logits for different pairs of datapoints. To make these plots, we produce S = 100 samples of the weights w_i of a LeNet-5 from the prior distribution N(0, α²I), and compute the logits corresponding to class 0 for each data point and each weight sample. We then compute the correlations for each pair x, x′ of data points as follows:

corr_logit(x, x′) = Σ_{i=1}^S (f(x; w_i) − f̄(x)) (f(x′; w_i) − f̄(x′)) / sqrt( Σ_{i=1}^S (f(x; w_i) − f̄(x))² · Σ_{i=1}^S (f(x′; w_i) − f̄(x′))² ),

where f(x; w) is the logit corresponding to class 0 of the network with weights w on the input x, and f̄(x) is the mean value of the logit, f̄(x) = (1/S) Σ_i f(x; w_i). For evaluation, we use 200 random datapoints per class for classes 0, 1, 2, 4, 7 (a total of 1000 datapoints). We use this set of classes to ensure that the structure is clearly visible in the figure. We combine the correlations into a diagram, additionally showing the average correlation for each pair of classes. We repeat the experiment for different values of α ∈ {0.02, 0.1, 1}. For a discussion of the results, see Section 5.2.
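A minimal sketch of this computation (ours; a small MLP and random inputs stand in for LeNet-5 and MNIST):

```python
import torch, torch.nn as nn

# Sample S weight settings from N(0, alpha^2 I), record the class-0 logit for
# each datapoint, and compute the empirical correlation for a pair of datapoints.
def prior_logit_correlation(model, x, x_prime, alpha, S=100):
    f_x, f_xp = torch.zeros(S), torch.zeros(S)
    for i in range(S):
        with torch.no_grad():
            for p in model.parameters():
                p.copy_(alpha * torch.randn_like(p))   # w_i ~ N(0, alpha^2 I)
            f_x[i] = model(x)[0, 0]                    # class-0 logit for x
            f_xp[i] = model(x_prime)[0, 0]             # class-0 logit for x'
    fx_c, fxp_c = f_x - f_x.mean(), f_xp - f_xp.mean()
    return (fx_c * fxp_c).sum() / torch.sqrt((fx_c ** 2).sum() * (fxp_c ** 2).sum())

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10))
x, x_prime = torch.randn(1, 1, 28, 28), torch.randn(1, 1, 28, 28)
print(prior_logit_correlation(model, x, x_prime, alpha=0.1))
```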

Sample functions. In panels (e)–(g) of Figure 12, we visualize the functions sampled from the LeNet-5 network along the direction connecting a pair of MNIST images. In particular, we take a pair of images, x_0 and x_1, of digits 0 and 1, respectively, and construct the path x(t) = t·x_0 + (1 − t)·x_1. We then study the samples of the logits z(t) = f(x(t)·‖x_0‖/‖x(t)‖; w) along the path; here we adjusted the norm of the images along the path to be constant, as the values of the logits are sensitive to the norm of the inputs. The complexity of the samples increases as we increase the variance of the prior distribution over the weights. This increased complexity of sample functions explains why we might expect the prior correlations for pairs of images to be lower when we increase the variance of the prior distribution over the weights.

Performance dependence on prior variance. In panels (d), (h) of Figure 12 we show the test negative log-likelihood and accuracy of SWAG applied to LeNet-5 on MNIST. We train the model for 50 epochs, constructing the rank-20 SWAG approximation from the last 25 epochs. We use an initial learning rate of 0.05 and a SWAG learning rate of 0.01, with the learning rate schedule of Maddox et al. (2019). We use the posterior log-density as the objective, and vary the prior variance α². In panels (f), (g) of Figure 9,


we perform an analogous experiment using a PreResNet-20 and a VGG-16 on CIFAR-10, using the hyper-parameters reported in Maddox et al. (2019) (for PreResNet-20, we use the hyper-parameters used with PreResNet-164 in Maddox et al. (2019)). Both on MNIST and CIFAR-10, we observe that the performance is poor for overly small values of α, close to optimal for intermediate values, and still reasonable for larger values of α. For further discussion of the results, see Section 5.3.

Predictions from prior samples. Following Wenzel et al. (2020), we study the predictive distributions of prior samples using PreResNet-20 on CIFAR-10. In Figure 9 we show the sample predictive functions averaged over datapoints for different scales α of the prior distribution. We also show the predictive distribution for each α, which is the average of the sample predictive distributions over 200 samples of weights. In Figure 10, we show how the predictive distribution changes as we vary the number of observed data points, for prior scale α = √10. We see that the marginal predictive distribution for all considered values of α is reasonable: roughly uniform across classes, when averaged across the dataset. For the latter experiment, we used stochastic gradient Langevin dynamics (SGLD) (Welling & Teh, 2011) with a cosine learning rate schedule. For each sample we restart SGLD, and we only use the sample obtained at the last iteration. We discuss the results in Section 8.4.

Prior correlations with corrupted images. In Section F and Figure 19, we study the decay of the prior correlations between logits on an original image and a perturbed image, as we increase the intensity of the perturbations. For the BNN, we use the PreResNet-20 architecture with the standard Gaussian prior N(0, I). For the linear model, the correlations are not affected by the prior variance α²:

cov(wᵀx, wᵀy) = E[(wᵀx)(wᵀy)] = xᵀ E[w wᵀ] y = α² xᵀy,

and hence

corr(wᵀx, wᵀy) = cov(wᵀx, wᵀy) / sqrt( cov(wᵀx, wᵀx) · cov(wᵀy, wᵀy) ) = xᵀy / (‖x‖ ‖y‖).

We use the N(0, I) prior for the weights of the linear model. Finally, we also evaluate the correlations associated with an RBF kernel (see Equation (7)). To set the lengthscale ℓ of the kernel, we evaluate the pairwise correlations for the PreResNet-20 and RBF kernel on the 100 uncorrupted CIFAR-10 images that were used for the experiment, and ensure that the average correlations match. The resulting value of ℓ is 10000; the average correlation for the RBF kernel and PreResNet was ≈ 0.9, and for the linear model the average correlation was ≈ 0.82. For the perturbations, we used the same set of corruptions introduced in Hendrycks & Dietterich (2019) as in the experiments in Section 4, with the addition of a random translation: for a random translation of intensity i, we pad the image with 2·i zeros on each side and crop the image randomly to 32×32.
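For reference, a small sketch (ours) of the two baseline correlations used here, with random vectors standing in for images:

```python
import numpy as np

# For a linear model with a Gaussian prior on w, corr(w^T x, w^T y) reduces to
# the cosine similarity x^T y / (||x|| ||y||); for the RBF kernel the prior
# correlation is k(x, y) = exp(-||x - y||^2 / (2 l^2)).
rng = np.random.default_rng(0)
x = rng.standard_normal(3 * 32 * 32)
y = x + 0.1 * rng.standard_normal(3 * 32 * 32)     # a mildly "perturbed" copy of x

linear_corr = x @ y / (np.linalg.norm(x) * np.linalg.norm(y))
lengthscale = 10000.0                              # value quoted in the text
rbf_corr = np.exp(-np.sum((x - y) ** 2) / (2 * lengthscale ** 2))
print(linear_corr, rbf_corr)
```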

D.4. Rethinking Generalization

In Section 6, we experiment with Bayesian neural networks and Gaussian processes on CIFAR-10 with noisy labels, inspired by the results in Zhang et al. (2016) that suggest we need to re-think generalization to understand deep learning.

Following Zhang et al. (2016), we train PreResNet-20 on CIFAR-10 with different fractions of random labels. To ensure that the networks fit the train data, we turn off weight decay and data augmentation, and use a lower initial learning rate of 0.01. Otherwise, we follow the hyper-parameters that were used with PreResNet-164 in Maddox et al. (2019). We use a diagonal Laplace approximation to compute an estimate of the marginal likelihood for each level of label corruption. Following Ritter et al. (2018), we use the diagonal of the Fisher information matrix rather than the Hessian.
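As a rough illustration of how such an estimate can be assembled, the sketch below (ours, not the paper's code) computes a diagonal Laplace approximation to the log marginal likelihood from a MAP solution, a diagonal Fisher estimate, and a prior variance; the placeholder tensors and the helper diag_laplace_log_marginal are assumptions for illustration.

```python
import torch

# Diagonal Laplace estimate of the log marginal likelihood:
#   log p(D) ≈ log p(D | w_MAP) + log p(w_MAP) + (d/2) log(2π) - (1/2) Σ_j log h_j,
# where h_j are the diagonal curvature terms, here a diagonal Fisher plus the
# prior precision.  log_lik, fisher_diag, and prior_var are computed elsewhere.
def diag_laplace_log_marginal(log_lik, w_map, fisher_diag, prior_var):
    d = w_map.numel()
    log_prior = -0.5 * (w_map ** 2).sum() / prior_var \
                - 0.5 * d * torch.log(torch.tensor(2 * torch.pi * prior_var))
    h = fisher_diag + 1.0 / prior_var          # diagonal of the approximate Hessian
    return log_lik + log_prior + 0.5 * d * torch.log(torch.tensor(2 * torch.pi)) \
           - 0.5 * torch.log(h).sum()

# Toy usage with random placeholders for the MAP weights and Fisher diagonal.
w_map = torch.randn(1000)
fisher_diag = torch.rand(1000)
print(diag_laplace_log_marginal(torch.tensor(-2500.0), w_map, fisher_diag, prior_var=0.1))
```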

We perform a similar experiment with a Gaussian process with an RBF kernel on a binary classification problem for two classes of CIFAR-10. We use variational inference to fit the model, and we use the variational evidence lower bound to approximate the marginal likelihood. We use variational inference to overcome the non-Gaussian likelihood, and not for scalability reasons, i.e., we are not using inducing inputs. We use the GPyTorch package (Gardner et al., 2018) to train the models. We use an RBF kernel with the default initialization from GPyTorch, and divide the inputs by 5000 to get an appropriate input scale. We train the model on a binary classification problem between classes 0 and 1.

For the 10-class GP classification experiment, we train 10 one-vs-all models that classify between a given class and the rest of the data. To reduce computation in training, we subsample the data not belonging to the given class to 10k datapoints, so each model is trained on a total of 15k datapoints. We then combine the 10 models into a single multi-class model: an observation is attributed to the class that corresponds to the one-vs-all model with the highest confidence. We use the same hyper-parameters as in the binary classification experiments.

D.5. Double Descent

In Section 7, we evaluate SGD, SWAG, and MultiSWAG for models of varying width. Following Nakkiran et al. (2019), we use ResNet-18 on CIFAR-100; we consider the original labels, and 10% and 20% label corruption. For networks of every width, we reuse the hyper-parameters used for PreResNet-164 in Maddox et al. (2019). For the original labels and 10% label corruption, we use 5 independently trained SWAG models with MultiSWAG, and for 20% label corruption we use 10 models; for 20% label corruption, we also show performance varying the number of independent models in Figures 8(e) and 13(c). Both for SWAG and MultiSWAG, we use an ensemble of 20 sampled models from each of the SWAG solutions; for example, for MultiSWAG with 10 independent SWAG solutions, we use an ensemble of 200 networks.

Figure 13. Double Descent. (a) Test error and (b) NLL loss for ResNet-18 with varying width on CIFAR-100, for SGD, SWAG, and MultiSWAG, when 10% of the labels are randomly reshuffled. MultiSWAG alleviates double descent both on the original labels and under label noise, both in accuracy and NLL. (c) Test NLLs for MultiSWAG with a varying number of independent models under 20% label corruption; NLL monotonically decreases with an increased number of independent models, alleviating double descent.

E. Analysis of Prior Variance Effect

In this section we provide simple analytic results for the effect of prior variance in ReLU networks. A related derivation is presented in Appendix Section A.8 of Garipov et al. (2018), about connecting paths from symmetries in parametrization.

We will consider a multilayer network f(x; w) of the form

f(x; {W_i, b_i}_{i=1}^n) = W_n φ(... φ(W_2 φ(W_1 x + b_1) + b_2) ...) + b_n,

where φ is the ReLU (or, in fact, any positively-homogeneous activation function), W_i are weight matrices, and b_i are bias vectors. In particular, f can be a regular CNN with ReLU activations, up to the logits (with the softmax activation removed).

Now, suppose we have a prior distribution of the form

W_i ~ N(0, α_i² I),   b_i ~ N(0, β_i² I),

where the identity matrices I are implicitly assumed to be of the appropriate shapes, so each weight matrix and bias vector has a spherical Gaussian distribution. We can reparameterize this distribution as

W_i = α_i E_i,  E_i ~ N(0, I),
b_i = β_i ε_i,  ε_i ~ N(0, I).

We can then express the predictions of the network on the input x, for weights sampled from the prior, as the random variable

f(x; {α_i, β_i}_{i=1}^n) = α_n · E_n(... φ(α_1 E_1 x + β_1 · ε_1) ...) + β_n · ε_n.  (8)

Through Equation (8), we can observe some simple properties of the dependence between the prior scales α_i, β_i and the induced function-space prior.

Proposition 1. Suppose the network has no bias vectors, i.e., β_1 = ... = β_n = 0. Then the scales α_i of the prior distribution over the weights only affect the output scale of the network.

Proof. In the case when there are no bias vectors, Equation (8) simplifies to

f(x; {α_i, β_i = 0}_{i=1}^n) = α_n · E_n(... φ(α_1 E_1 x + β_1 · ε_1)) + β_n · ε_n
= α_n · ... · α_1 · E_n(... φ(E_1 x))
= α_n · ... · α_1 · f(x; {α_i = 1, β_i = 0}_{i=1}^n).

In the derivation above, we used the positive homogeneity of ReLU: φ(αz) = αφ(z) for any positive α.

In other words, to sample from the distribution over functions corresponding to a prior with variances {α_i, β_i = 0}_{i=1}^n, we can sample from the spherical Gaussian prior (without bias terms) {α_i = 1, β_i = 0}_{i=1}^n, and then rescale the outputs of the network by the product α_n · ... · α_2 · α_1.
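A quick numerical sanity check of Proposition 1, written as a sketch with a small bias-free ReLU network (not the paper's code):

```python
import torch, torch.nn as nn

# For a bias-free ReLU network, sampling weights with scales alpha_i is
# equivalent to sampling with unit scales and multiplying the output by
# the product of the alpha_i.
torch.manual_seed(0)
alphas = [0.5, 2.0, 3.0]
layers = [nn.Linear(10, 20, bias=False), nn.Linear(20, 20, bias=False), nn.Linear(20, 1, bias=False)]
net = nn.Sequential(layers[0], nn.ReLU(), layers[1], nn.ReLU(), layers[2])

x = torch.randn(4, 10)
E = [torch.randn_like(l.weight) for l in layers]      # E_i ~ N(0, I)

with torch.no_grad():
    for l, Ei in zip(layers, E):                       # unit-scale prior sample
        l.weight.copy_(Ei)
    out_unit = net(x)
    for l, Ei, a in zip(layers, E, alphas):            # scaled prior sample W_i = alpha_i E_i
        l.weight.copy_(a * Ei)
    out_scaled = net(x)

scale = torch.prod(torch.tensor(alphas))
print(torch.allclose(out_scaled, scale * out_unit, atol=1e-5))  # True
```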

We note that the result above is different from the results for sigmoid networks considered in MacKay (1995), where varying the prior on the weights leads to changing the length-scale of the sample functions. For ReLU networks without biases, increasing the prior variance only increases the output scale of the network, and not the complexity of the samples. If we apply the softmax activation on the outputs of the last layer of such a network, we will observe increasingly confident predictions as we increase the prior variance. We observe this effect in Figure 9 and discuss it in Section 8.4.

In case bias vectors are present, we can obtain a similar result using a specific scaling of the prior variances with the layer, as in the following proposition.

Proposition 2. Suppose the prior scales depend on the layer of the network as follows, for some γ > 0:

α_i = γ,   β_i = γ^i,

for all layers i = 1, ..., n. Then γ only affects the scale of the predictive distribution at any input x:

f(x; {α_i = γ, β_i = γ^i}_{i=1}^n) = γ^n · f(x; {α_i = 1, β_i = 1}_{i=1}^n).

Proof. The proof is analogous to the proof of Proposition 1: we can use the positive homogeneity of ReLU activations to factor the prior scales outside of the network:

f(x; {α_i = γ, β_i = γ^i}_{i=1}^n) = γ · E_n(... φ(γ · E_1 x + γ · ε_1)) + γ^n · ε_n
= γ^n · ( E_n(... φ(E_1 x + ε_1)) + ε_n )
= γ^n · f(x; {α_i = 1, β_i = 1}_{i=1}^n).

The analysis above can be applied to other simple scaling rules of the prior; e.g.,

f(x; {α_i = γ·α̂_i, β_i = γ^i·β̂_i}_{i=1}^n) = γ^n · f(x; {α_i = α̂_i, β_i = β̂_i}_{i=1}^n)  (9)

can be shown completely analogously to Proposition 2.

More general types of scaling of the prior affect both the output scale of the network and also the relative effect of the prior and variance terms. For example, by Equation (9), we have

f(x; {α_i = γ, β_i = γ}_{i=1}^n) = f(x; {α_i = γ · 1, β_i = γ^i · γ^{1−i}}_{i=1}^n) = γ^n · f(x; {α_i = 1, β_i = γ^{1−i}}_{i=1}^n).

We note that the analysis does not cover residual connections and batch normalization, so it applies to LeNet-5, but cannot be directly applied to the PreResNet-20 networks used in many of our experiments.

F. Prior Correlation Structure under Perturbations

In this section, we explore the prior correlations between the logits on different pairs of datapoints induced by a spherical Gaussian prior on the weights of a PreResNet-20. We sample 100 random images from CIFAR-10 (10 from each class) and apply 17 different perturbations introduced by Hendrycks & Dietterich (2019) at 5 different levels of intensity. We then compute correlations between the logits f(x; w) for the original image x and f(x̃; w) for the corrupted image x̃, as we sample the weights of the network from the prior w ~ N(0, I).

In Figure 19, we show how the correlations decay with perturbation intensity. For reference, we also show how the correlations decay for a linear model and for an RBF kernel. For the RBF kernel, we set the lengthscale so that the average correlations on the uncorrupted datapoints match those of a PreResNet-20. Further experimental details can be found in Appendix D.3.

For all types of corruptions except saturate, snow, fog, and brightness, the PreResNet logits decay slower compared to the RBF kernel and linear model. It appears that the prior samples are sensitive to corruptions that alter the brightness, or, more generally, the colours in the image. For many types of corruptions (such as, e.g., Gaussian noise), the prior correlations for PreResNet are close to 1 for all levels of corruption.

Overall, these results indicate that the prior over functions induced by a vague prior over parameters w, in combination with a PreResNet, has useful equivariance properties: before seeing data, the model treats images of the same class as highly correlated, even after an image has undergone significant perturbations, representative of perturbations we often see in the real world. These types of symmetries are a large part of what makes neural networks a powerful model class for high dimensional natural signals.


Figure 14. Noise Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA, as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Gaussian Noise, (b) Impulse Noise, (c) Shot Noise.


Figure 15. Blur Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA, as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Defocus Blur, (b) Glass Blur, (c) Motion Blur, (d) Zoom Blur, (e) Gaussian Blur.


Figure 16. Digital Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA, as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Contrast, (b) Saturate, (c) Elastic Transform, (d) Pixelate, (e) JPEG Compression.


Figure 17. Weather Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA, as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Snow, (b) Fog, (c) Brightness.

Figure 18. Negative log likelihood, accuracy and expected calibration error distribution on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA, as a function of the corruption intensity. Following Ovadia et al. (2019), we summarize the results for different types of corruption with a boxplot. For each method, we use 10 independently trained models, and for MultiSWAG we sample 20 networks from each model. As in Figures 5, 11-14, there are substantial differences between these three methods, which are hard to see due to the vertical scale on this plot; MultiSWAG particularly outperforms Deep Ensembles and MultiSWA in terms of NLL and ECE for higher corruption intensities.


Figure 19. Prior correlations under corruption. Prior correlations between predictions (logits) for PreResNet-20, Linear Model and RBF kernel on original and corrupted images, as a function of corruption intensity, for different types of corruptions (Translation, Gaussian Noise, Impulse Noise, Shot Noise, Defocus Blur, Glass Blur, Motion Blur, Zoom Blur, Gaussian Blur, Contrast, Saturate, Elastic, Pixelate, JPEG, Snow, Fog, Brightness). The lengthscale of the RBF kernel is calibrated to produce similar correlations to PreResNet on uncorrupted datapoints. We report the mean correlation values over 100 different images, and show the 1σ error bars with shaded regions. For all corruptions except Snow, Saturate, Fog and Brightness, the correlations decay slower for PreResNet compared to baselines.



Figure 3. Approximating the BMA. p(y|x, D) = ∫ p(y|x, w) p(w|D) dw. Top: p(w|D), with representations from VI (orange), deep ensembles (blue), and MultiSWAG (red). Middle: p(y|x, w) as a function of w for a test input x. This function does not vary much within modes, but changes significantly between modes. Bottom: Distance between the true predictive distribution and the approximation, as a function of representing the posterior at an additional point w, assuming we have sampled the mode in dark green. There is more to be gained by exploring new basins than continuing to explore the same basin.

p(y|x, D). On the other hand, p(y|x, w) varies significantly between basins, and thus each point mass for deep ensembles contributes significantly to the marginal predictive distribution. By sampling within the basins, MultiSWAG provides additional contributions to the predictive distribution. In the bottom panel we have the gain in approximating the predictive distribution when adding a point mass to the representation of the posterior, as a function of its location, assuming we have already sampled the mode in dark green. Including samples from different modes provides significant gain over continuing to sample from the same mode, and including weights in wide basins provides relatively more gain than in narrow ones.

In Figure 4 we evaluate single-basin and multi-basin approaches in a case where we can near-exactly compute the predictive distribution. We provide details for generating the data and training the models in Appendix D.1. We see that the predictive distribution given by deep ensembles is qualitatively closer to the true distribution compared to the single-basin variational method: between data clusters, the deep ensemble approach provides a similar representation of epistemic uncertainty, whereas the variational method is extremely overconfident in these regions. Moreover, we see that the Wasserstein distance between the true predictive distribution and these two approximations quickly shrinks with the number of samples for deep ensembles, but is roughly independent of the number of samples for the variational approach. Thus the deep ensemble is providing a better approximation of the Bayesian model average in Eq. (1) than the single-basin variational approach, which has traditionally been labelled as the Bayesian alternative.
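The discrepancy measure used here could be implemented roughly as follows; the arrays of sampled function values are placeholders with shape (number of samples, number of input locations).

```python
import numpy as np
from scipy.stats import wasserstein_distance

def avg_wasserstein(samples_approx, samples_ref):
    """Average, over input locations, of the 1-D Wasserstein distance between
    the approximate and reference predictive marginals."""
    n_inputs = samples_ref.shape[1]
    dists = [wasserstein_distance(samples_approx[:, i], samples_ref[:, i])
             for i in range(n_inputs)]
    return float(np.mean(dists))
```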

Next, we evaluate MultiSWAG under distribution shift on the CIFAR-10 dataset (Krizhevsky et al., 2014), replicating the setup in Ovadia et al. (2019). We consider 16 data corruptions, each at 5 different levels of severity, introduced by Hendrycks & Dietterich (2019). For each corruption, we evaluate the performance of deep ensembles and MultiSWAG, varying the training budget. For deep ensembles, we show performance as a function of the number of independently trained models in the ensemble. For MultiSWAG, we show performance as a function of the number of independent SWAG approximations that we construct; we then sample 20 models from each of these approximations to construct the final ensemble.

While the training time for MultiSWAG is the same as for deep ensembles, at test time MultiSWAG is more expensive, as the corresponding ensemble consists of a larger number of models. To account for situations when test time is constrained, we also propose MultiSWA, a method that ensembles independently trained SWA solutions (Izmailov et al., 2018). SWA solutions are the means of the corresponding Gaussian SWAG approximations. Izmailov et al. (2018) argue that SWA solutions approximate the local ensembles represented by SWAG with a single model.
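A schematic of the resulting ensembling procedures is sketched below; swag_models is a placeholder list of fitted SWAG approximations, each assumed to expose sample() (draw a network from the Gaussian approximation) and mean() (return the corresponding SWA solution). This is an illustration of the procedure rather than the exact implementation used in the experiments.

```python
import torch

@torch.no_grad()
def multiswag_predict(swag_models, x, samples_per_model=20):
    probs = []
    for swag in swag_models:                    # independently trained SWAG runs
        for _ in range(samples_per_model):      # Monte Carlo samples within each basin
            net = swag.sample()
            probs.append(torch.softmax(net(x), dim=-1))
    return torch.stack(probs).mean(0)           # multimodal Bayesian model average

@torch.no_grad()
def multiswa_predict(swag_models, x):
    probs = [torch.softmax(swag.mean()(x), dim=-1) for swag in swag_models]
    return torch.stack(probs).mean(0)           # one SWA solution per basin
```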

In Figure 5 we show the negative log-likelihood as a function of the number of independently trained models for a Preactivation ResNet-20 on CIFAR-10 corrupted with Gaussian blur of varying intensity (increasing from left to right). MultiSWAG outperforms deep ensembles significantly on highly corrupted data. For lower levels of corruption, MultiSWAG works particularly well when only a small number of independently trained models are available. We note that MultiSWA also outperforms deep ensembles, and has the same computational requirements at training and test time as deep ensembles. We present results for other types of corruption in Appendix Figures 14, 15, 16, 17, showing similar trends. In general, there is an extensive evaluation of MultiSWAG in the Appendix.

Our perspective of generalization is deeply connected with Bayesian marginalization. In order to best realize the benefits of marginalization in deep learning, we need to consider as many hypotheses as possible through multimodal posterior approximations, such as MultiSWAG. In Section 7 we return to MultiSWAG, showing how it can entirely alleviate prominent double descent behaviour, and lead to striking improvements in generalization over SGD and single-basin marginalization, for both accuracy and NLL.


Figure 4. Approximating the true predictive distribution. (a) A close approximation of the true predictive distribution obtained by combining 200 HMC chains. (b) Deep ensembles predictive distribution using 50 independently trained networks. (c) Predictive distribution for factorized variational inference (VI). (d) Convergence of the predictive distributions for deep ensembles and variational inference as a function of the number of samples; we measure the average Wasserstein distance between the marginals in the range of input positions. The multi-basin deep ensembles approach provides a more faithful approximation of the Bayesian predictive distribution than the conventional single-basin VI approach, which is overconfident between data clusters. The top panels show the Wasserstein distance between the true predictive distribution and the deep ensemble and VI approximations, as a function of the inputs x.

Figure 5. Negative log likelihood for Deep Ensembles, MultiSWAG and MultiSWA using a PreResNet-20 on CIFAR-10, with varying intensity of the Gaussian blur corruption. The image in each plot shows the intensity of corruption. For all levels of intensity, MultiSWAG and MultiSWA outperform Deep Ensembles for a small number of independent models. For high levels of corruption, MultiSWAG significantly outperforms the other methods even for many independent models. We present results for other corruptions in the Appendix.

5. Neural Network Priors

A prior over parameters p(w) combines with the functional form of a model f(x; w) to induce a distribution over functions p(f(x; w)). It is this distribution over functions that controls the generalization properties of the model; the prior over parameters, in isolation, has no meaning. Neural networks are imbued with structural properties that provide good inductive biases, such as translation equivariance, hierarchical representations, and sparsity. In the sense of Figure 2, the prior will have large support, due to the flexibility of neural networks, but its inductive biases provide the most mass to datasets which are representative of problem settings where neural networks are often applied. In this section, we study the properties of the induced distribution over functions. We directly continue the discussion of priors in Section 6, with a focus on examining the noisy CIFAR results in Zhang et al. (2016) from a probabilistic perspective of generalization. These sections are best read together.

We also provide several additional experiments in the Appendix. In Section E we present analytic results on the dependence of the prior distribution in function space on the variance of the prior over parameters, considering also layer-wise parameter priors with ReLU activations. As part of a discussion on tempering, in Section 8.4 we study the effect of α in p(w) = N(0, α²I) on prior class probabilities for individual sample functions p(f(x; w)), the predictive distribution, and posterior samples as we observe varying amounts of data. In Section F we further study the correlation structure over images induced by neural network priors, subject to perturbations of the images. In Section D.3 we provide additional experimental details.

5.1. Deep Image Prior and Random Network Features

Two recent results provide strong evidence that vague Gaussian priors over parameters, when combined with a neural network architecture, induce a distribution over functions with useful inductive biases. In the deep image prior, Ulyanov et al. (2018) show that randomly initialized convolutional neural networks without training provide excellent performance for image denoising, super-resolution, and inpainting. This result demonstrates the ability for a sample function from a random prior over neural networks p(f(x; w)) to capture low-level image statistics, before any training. Similarly, Zhang et al. (2016) shows that pre-


Figure 6. Induced prior correlation function. Average pairwise prior correlations for pairs of objects in classes {0, 1, 2, 4, 7} of MNIST, induced by LeNet-5 for p(f(x; w)) when p(w) = N(0, α²I). Images in the same class have higher prior correlations than images from different classes, suggesting that p(f(x; w)) has desirable inductive biases. The correlations slightly decrease with increases in α. Panels (a)–(c) correspond to α = 0.02, 0.1, 1; panel (d) shows the NLL of an ensemble of 20 SWAG samples on MNIST as a function of α, using a LeNet-5.

Figure 7. Rethinking generalization. (a) Sample functions from a Gaussian process prior. (b) GP fit (with 95% credible region) to structured data generated as y_green(x) = sin(x · 2π) + ε, ε ∼ N(0, 0.2²). (c) GP fit, with no training error, after a significant addition of corrupted data, in red, drawn from Uniform[0.5, 1]. (d) Variational GP marginal likelihood with RBF kernel for two classes of CIFAR-10. (e) Laplace BNN marginal likelihood for a PreResNet-20 on CIFAR-10 with different fractions of random labels. The marginal likelihood for both the GP and BNN decreases as we increase the level of corruption in the labels, suggesting reasonable inductive biases in the prior over functions. Moreover, both the GP and BNN have 100% training accuracy on images with fully corrupted labels.

processing CIFAR-10 with a randomly initialized untrained convolutional neural network dramatically improves the test performance of a simple Gaussian kernel on pixels, from 54% accuracy to 71%. Adding ℓ2 regularization only improves the accuracy by an additional 2%. These results again indicate that broad Gaussian priors over parameters induce reasonable priors over networks, with a minor additional gain from decreasing the variance of the prior in parameter space, which corresponds to ℓ2 regularization.

5.2. Prior Class Correlations

In Figure 6 we study the prior correlations in the outputs of the LeNet-5 convolutional network (LeCun et al., 1998) on objects of different MNIST classes. We sample networks with weights p(w) = N(0, α²I), compute the logits corresponding to the first class for all pairs of images, and compute the correlations of these logits. For all levels of α, the correlations between objects corresponding to the same class are consistently higher than the correlations between objects of different classes, showing that the network induces a reasonable prior similarity metric over these images. Additionally, we observe that the prior correlations somewhat decrease as we increase α, showing that bounding the norm of the weights has some minor utility, in accordance with Section 5.1. Similarly, in panel (d) we see that the NLL significantly decreases as α increases in [0, 0.5], and then slightly increases, but is relatively constant thereafter.

In the Appendix, we further describe analytic results and illustrate the effect of α on sample functions.

5.3. Effect of Prior Variance on CIFAR-10

We further study the effect of the parameter prior standard deviation α, measuring performance of approximate Bayesian inference for CIFAR-10 with a Preactivation ResNet-20 (He et al., 2016) and VGG-16 (Simonyan & Zisserman, 2014). For each of these architectures, we run SWAG (Maddox et al., 2019) with fixed hyper-parameters and varying α. We report the results in Figure 12(d), (h). For both architectures, the performance is near-optimal in the range α ∈ [10⁻², 10⁻¹]. Smaller α constrains the weights too much. Performance is reasonable and becomes mostly insensitive to α as it continues to increase, due to the inductive biases of the functional form of the neural network.


6. Rethinking Generalization

Zhang et al. (2016) demonstrated that deep neural networks have sufficient capacity to fit randomized labels on popular image classification tasks, and suggest this result requires re-thinking generalization to understand deep learning.

We argue, however, that this behaviour is not puzzling from a probabilistic perspective, is not unique to neural networks, and cannot be used as evidence against Bayesian neural networks (BNNs) with vague parameter priors. Fundamentally, the resolution is the view presented in the introduction: from a probabilistic perspective, generalization is at least a two-dimensional concept, related to support (flexibility), which should be as large as possible, supporting even noisy solutions, and inductive biases that represent relative prior probabilities of solutions.

Indeed, we demonstrate that the behaviour in Zhang et al. (2016) that was treated as mysterious and specific to neural networks can be exactly reproduced by Gaussian processes (GPs). Gaussian processes are an ideal choice for this experiment because they are popular Bayesian non-parametric models, and they assign a prior directly in function space. Moreover, GPs have remarkable flexibility, providing universal approximation with popular covariance functions such as the RBF kernel. Yet the functions that are a priori likely under a GP with an RBF kernel are relatively simple. We describe GPs further in the Appendix, and Rasmussen & Williams (2006) provides an extensive introduction.

We start with a simple example to illustrate the ability for a GP with an RBF kernel to easily fit a corrupted dataset, yet generalize well on a non-corrupted dataset, in Figure 7. In Fig. 7(a), we have sample functions from a GP prior over functions p(f(x)), showing that likely functions under the prior are smooth and well-behaved. In Fig. 7(b) we see the GP is able to reasonably fit data from a structured function. And in Fig. 7(c) the GP is also able to fit highly corrupted data, with essentially no structure; although these data are not a likely draw from the prior, the GP has support for a wide range of solutions, including noise.
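The flavour of this 1-D example can be reproduced with a few lines of scikit-learn; the data follow the description in the Figure 7 caption, while the fixed kernel hyperparameters below are illustrative choices that let the GP pass close to the corrupted points.

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF

rng = np.random.RandomState(0)
x_clean = rng.uniform(0, 1, 30)
y_clean = np.sin(2 * np.pi * x_clean) + rng.normal(0, 0.2, 30)   # structured data
x_corrupt = rng.uniform(0, 1, 30)
y_corrupt = rng.uniform(0.5, 1.0, 30)                            # corrupted labels

X = np.concatenate([x_clean, x_corrupt])[:, None]
y = np.concatenate([y_clean, y_corrupt])

# Short, fixed lengthscale plus a small jitter: flexible enough to fit the noise.
gp = GaussianProcessRegressor(kernel=RBF(length_scale=0.03), alpha=1e-4, optimizer=None)
gp.fit(X, y)
mean, std = gp.predict(X, return_std=True)   # posterior mean hugs both kinds of points
```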

We next show that GPs can replicate the generalization behaviour described in Zhang et al. (2016) (experimental details in the Appendix). When applied to CIFAR-10 images with random labels, Gaussian processes achieve 100% train accuracy and 10.4% test accuracy (at the level of random guessing). However, the same model trained on the true labels achieves a training accuracy of 72.8% and a test accuracy of 54.3%. Thus the generalization behaviour described in Zhang et al. (2016) is not unique to neural networks, and can be described by separately understanding the support and the inductive biases of a model.

Indeed, although Gaussian processes support CIFAR-10 images with random labels, they are not likely under the GP prior. In Fig. 7(d), we compute the approximate GP marginal likelihood on a binary CIFAR-10 classification problem, with labels of varying levels of corruption. We see that as the noise in the data increases, the approximate marginal likelihood, and thus the prior support for these data, decreases. In Fig. 7(e), we see a similar trend for a Bayesian neural network. Again, as the fraction of corrupted labels increases, the approximate marginal likelihood decreases, showing that the prior over functions given by the Bayesian neural network has less support for these noisy datasets. We provide further experimental details in the Appendix.

Dziugaite & Roy (2017) and Smith & Le (2018) provide complementary perspectives on Zhang et al. (2016) for MNIST; Dziugaite & Roy (2017) show non-vacuous PAC-Bayes bounds for noise-free binary MNIST, but not noisy MNIST, and Smith & Le (2018) show that logistic regression can fit noisy labels on subsampled MNIST, interpreting the results from an Occam factor perspective.

7. Double Descent

Double descent (e.g., Belkin et al., 2019) describes generalization error that decreases, increases, and then again decreases, with increases in model flexibility. The first decrease and then increase is referred to as the classical regime: models with increasing flexibility are increasingly able to capture structure and perform better, until they begin to overfit. The next regime is referred to as the modern interpolating regime. The existence of the interpolation regime has been presented as mysterious generalization behaviour in deep learning.

However, our perspective of generalization suggests that performance should monotonically improve as we increase model flexibility when we use Bayesian model averaging with a reasonable prior. Indeed, in the opening example of Figure 1, we would in principle want to use the most flexible possible model. Our results in Section 5 show that standard BNN priors induce structured and useful priors in the function space, so we should not expect double descent in Bayesian deep learning models that perform reasonable marginalization.

To test this hypothesis, we evaluate MultiSWAG, SWAG and standard SGD with ResNet-18 models of varying width, following Nakkiran et al. (2019), measuring both error and negative log likelihood (NLL). For the details, see Appendix D. We present the results in Figure 8 and Appendix Figure 13.

First, we observe that models trained with SGD indeed suffer from double descent, especially when the train labels are partially corrupted (see panels 8(c), 8(d)). We also see that SWAG, a unimodal posterior approximation, reduces the extent of double descent. Moreover, MultiSWAG, which


Figure 8. Bayesian model averaging alleviates double descent. (a) Test error and (b) NLL loss for ResNet-18 with varying width on CIFAR-100 for SGD, SWAG and MultiSWAG. (c) Test error and (d) NLL loss when 20% of the labels are randomly reshuffled. SWAG reduces double descent, and MultiSWAG, which marginalizes over multiple modes, entirely alleviates double descent, both on the original labels and under label noise, both in accuracy and NLL. (e) Test errors for MultiSWAG with varying number of independent SWAG models; error monotonically decreases with an increased number of independent models, alleviating double descent. We also note that MultiSWAG provides significant improvements in accuracy and NLL over SGD and SWAG models. See Appendix Figure 13 for additional results.

performs a more exhaustive multimodal Bayesian model average, completely mitigates double descent: the performance of MultiSWAG solutions increases monotonically with the size of the model, showing no double descent even under significant label corruption, for both accuracy and NLL.

Our results highlight the importance of marginalization over multiple modes of the posterior: under 20% label corruption, SWAG clearly suffers from double descent while MultiSWAG does not. In Figure 8(e) we show how the double descent is alleviated with an increased number of independent modes marginalized in MultiSWAG.

These results also clearly show that MultiSWAG provides significant improvements in accuracy over both SGD and SWAG models, in addition to NLL, an often overlooked advantage of Bayesian model averaging we discuss in Section 3.1.

Recently, Nakkiran et al. (2020) show that carefully tuned ℓ2 regularization can help mitigate double descent. Alternatively, we show that Bayesian model averaging, particularly based on multimodal marginalization, can mitigate prominent double descent behaviour. The perspective in Sections 1 and 3 predicts this result: models with reasonable priors and effective Bayesian model averaging should monotonically improve with increases in flexibility.

8. Temperature Scaling

The standard Bayesian posterior distribution is given by

p(w|D) = (1/Z) p(D|w) p(w),    (2)

where p(D|w) is a likelihood, p(w) is a prior, and Z is a normalizing constant.

In Bayesian deep learning it is typical to consider the tempered posterior

p_T(w|D) = (1/Z(T)) p(D|w)^{1/T} p(w),    (3)

where T is a temperature parameter, and Z(T) is the normalizing constant corresponding to temperature T. The temperature parameter controls how the prior and likelihood interact in the posterior:

• T < 1 corresponds to cold posteriors, where the posterior distribution is more concentrated around solutions with high likelihood.

• T = 1 corresponds to the standard Bayesian posterior distribution.

• T > 1 corresponds to warm posteriors, where the prior effect is stronger and the posterior collapse is slower.
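As a minimal illustration of Equation (3), an unnormalized tempered log-posterior can be written as follows; log_likelihood and log_prior are placeholders for whatever model is being used, and a sampler or variational scheme would target this density.

```python
def tempered_log_posterior(w, log_likelihood, log_prior, T=1.0):
    # log p_T(w | D) = (1/T) log p(D | w) + log p(w) + const
    # T = 1 recovers the standard posterior; T < 1 gives a cold posterior.
    return log_likelihood(w) / T + log_prior(w)
```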

Tempering posteriors is a well-known practice in statistics, where it goes by the names Safe Bayes, generalized Bayesian inference, and fractional Bayesian inference (e.g., de Heide et al., 2019; Grunwald et al., 2017; Barron & Cover, 1991; Walker & Hjort, 2001; Zhang, 2006; Bissiri et al., 2016; Grunwald, 2012). Safe Bayes has been shown to be natural from a variety of perspectives, including from prequential learning theory and minimum description length frameworks (e.g., Grunwald et al., 2017).

Concurrently with our work, Wenzel et al. (2020) noticed that successful Bayesian deep learning methods tend to use cold posteriors. They provide an empirical study that shows that Bayesian neural networks (BNNs) with cold posteriors outperform models with SGD-based maximum likelihood training, while BNNs with T = 1 can perform worse than the maximum likelihood solution. They claim that cold posteriors sharply deviate from the Bayesian paradigm and


Figure 9. Effects of the prior variance α². (a)–(e): Average class probabilities over all of CIFAR-10 for two sample prior functions p(f(x; w)) (two top rows) and the predictive distribution (average over 200 samples of weights, bottom row), for varying settings of α in p(w) = N(0, α²I); panels (a)–(e) correspond to α = 0.01, 0.1, 0.3, 1, √10. (f): NLL and (g): classification error of an ensemble of 20 SWAG samples on CIFAR-10 as a function of α, using a Preactivation ResNet-20 and VGG-16. The NLL is high for overly small α and near-optimal in the range of [0.1, 0.3]. The NLL remains relatively low for vague priors corresponding to large values of α.

Figure 10. Adaptivity of posterior variance with data. We sample two functions f(x; w) from the distribution over functions induced by a distribution over weights, starting with the prior p(w) = N(0, 10·I), in combination with a PreResNet-20. We measure class probabilities averaged across the CIFAR-10 test set, as we vary the amount of available training data. Panels: (a) Prior (α = √10), (b) 10 datapoints, (c) 100 datapoints, (d) 1000 datapoints. Although the prior variance is too large, such that the softmax saturates for logits sampled from the prior, leading to one class being favoured, we see that the posterior quickly adapts to correct the scale of the logits in the presence of data. In Figure 9 we also show that the prior variance can easily be calibrated such that the prior predictive distribution, even before observing data, is high entropy.


consider possible reasons for why tempering is helpful in Bayesian deep learning.

In this section we provide an alternative view, and argue that tempering is not at odds with Bayesian principles. Moreover, for virtually any realistic model class and dataset, it would be highly surprising if T = 1 were in fact the best setting of this hyperparameter. Indeed, as long as it is practically convenient, we would advocate tempering for essentially any model, especially parametric models that do not scale their capacity automatically with the amount of available information. Our position is that, at a high level, Bayesian methods are trying to combine honest beliefs with data to form a posterior. By reflecting the belief that the model is misspecified, the tempered posterior is often more of a true posterior than the posterior that results from ignoring our belief that the model is misspecified.

Finding that T < 1 helps for Bayesian neural networks is neither surprising nor discouraging. And the actual results of the experiments in Wenzel et al. (2020), which show great improvements over standard SGD training, are in fact very encouraging of deriving inspiration from Bayesian procedures in deep learning.

We consider: (1) tempering under misspecification (Section 8.1); (2) tempering in terms of overcounting data (Section 8.2); (3) how tempering compares to changing the observation model (Section 8.3); (4) the effect of the prior in relation to the experiments of Wenzel et al. (2020) (Section 8.4); (5) the effect of approximate inference, including how tempering can help in efficiently estimating parameters, even for the untempered posterior (Section 8.5).

8.1. Tempering Helps with Misspecified Models

Many works explain how tempered posteriors help under model misspecification (e.g., de Heide et al., 2019; Grunwald et al., 2017; Barron & Cover, 1991; Walker & Hjort, 2001; Zhang, 2006; Bissiri et al., 2016; Grunwald, 2012). In fact, de Heide et al. (2019) and Grunwald et al. (2017) provide several simple examples where Bayesian inference fails to provide good convergence behaviour for untempered posteriors. While it is easier to show theoretical results for T > 1, several of these works also show that T < 1 can be preferred, even in well-specified settings, and indeed recommend learning T from data, for example by cross-validation (e.g., Grunwald, 2012; de Heide et al., 2019).

Are we in a misspecified setting for Bayesian neural networks? Of course. And it would be irrational to proceed as if it were otherwise. Every model is misspecified. In the context of Bayesian neural networks specifically, the mass of solutions expressed by the prior, outside of the datasets we typically consider, is likely much larger than desired for most applications. We can calibrate for this discrepancy through tempering. The resulting tempered posterior will be more in line with our beliefs than pretending the model is not misspecified and finding the untempered posterior.

Non-parametric models, such as Gaussian processes, attempt to side-step model misspecification by growing the number of free parameters (information capacity) automatically with the amount of available data. In parametric models, we take much more of a manual guess about the model capacity. In the case of deep neural networks, this choice is not even close to a best guess: it was once the case that architectural design was a large component of works involving neural networks, but now it is more standard practice to choose an off-the-shelf architecture, without much consideration of model capacity. We do not believe that knowingly using a misspecified model to find a posterior is more reasonable (or Bayesian) than honestly reflecting the belief that the model is misspecified, and then using a tempered posterior. For parametric models such as neural networks, it is to be expected that the capacity is particularly misspecified.

8.2. Overcounting Data with Cold Posteriors

The criticism of cold posteriors raised by Wenzel et al. (2020) is largely based on the fact that decreasing the temperature leads to overcounting data in the posterior distribution.

However, a similar argument can be made against marginal likelihood maximization (also known as empirical Bayes, or type-2 maximum likelihood). Indeed, here the prior will depend on the same data as the likelihood, which can lead to miscalibrated predictive distributions (Darnieder, 2011).

Nonetheless, empirical Bayes has been embraced and widely adopted in Bayesian machine learning (e.g., Bishop, 2006; Rasmussen & Williams, 2006; MacKay, 2003; Minka, 2001b) as embodying several Bayesian principles. Empirical Bayes has been particularly embraced in seminal work on Bayesian neural networks (e.g., MacKay, 1992; 1995), where it has been proposed as a principled approach to learning hyperparameters, such as the scale of the variance for the prior over weights, automatically embodying Occam's razor. While there is in this case some deviation from the fully Bayesian paradigm, the procedure, which depends on marginalization, is nonetheless clearly inspired by Bayesian thinking, and it is thus helpful to reflect this inspiration, and provide understanding of how it works from a Bayesian perspective.

There is also work showing that the marginal likelihood can lead to miscalibrated Bayes factors under model misspecification. Attempts to calibrate these factors (Xu et al., 2019), as part of the Bayesian paradigm, are highly reminiscent of work on Safe Bayes.


8.3. Tempered Posterior or Different Likelihood

The tempered posterior for one model is an untempered posterior using a different observation model. In other words, we can trivially change the likelihood function in our model so that the standard Bayesian posterior in the new model is equal to the posterior of temperature T in the original model. Indeed, consider the likelihood function

p_T(D|w) ∝ p(D|w)^{1/T},    (4)

where the posterior distribution for the model M_T with likelihood p_T is exactly the temperature T posterior for the model M with likelihood p.

The predictive distribution differs for the two models: even though the posteriors coincide, the likelihoods for a new datapoint y* are different:

∫ p(y*|w) p(w) dw ≠ ∫ p_T(y*|w) p(w) dw.    (5)

As an example, consider a regression model M with a Gaussian likelihood function, y ∼ N(f, σ²), where f is the output of the network and σ² is the noise variance. The predictive distributions for the two models M and M_T for a new input x will have different variance, but the same mean: E_w[f + ε] = E_w[f]. Moreover, in this case the noise variance would typically be learned in either model M_T or M.
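For this Gaussian likelihood the correspondence can be seen directly: up to a constant that does not depend on w,

N(y; f(x, w), σ²)^{1/T} ∝ exp(−(y − f(x, w))² / (2Tσ²)) ∝ N(y; f(x, w), Tσ²),

so tempering is equivalent to inflating the noise variance from σ² to Tσ², which changes the predictive variance but leaves the predictive mean unchanged.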

Section 4.1 of Grunwald et al. (2017) considers a related construction.

8.4. Effect of the Prior

While a somewhat misspecified prior will certainly interact with the utility of tempering, we do not believe the experiments in Wenzel et al. (2020) provide evidence that even the prior p(w) = N(0, I) is misspecified to any serious extent. For a relatively wide range of distributions over w, the functional form of the network f(x; w) can produce a generally reasonable distribution over functions p(f(x; w)). In Figure 10, we reproduce the finding in Wenzel et al. (2020) that sample functions p(f(x; w)) corresponding to the prior p(w) = N(0, 10·I) strongly favour a single class over the dataset. While this behaviour appears superficially dramatic, we note it is simply an artifact of having a miscalibrated signal variance. A miscalibrated signal variance interacts with a quickly saturating softmax link function to provide a seemingly dramatic preference to a given class. If we instead use p(w) = N(0, α²I), for quite a range of α, then sample functions provide reasonably high entropy across labels, as in Figure 9. The value of α can be easily determined through cross-validation, as in Figure 9, or specified as a standard value used for L2 regularization (α = 0.24 in this case).
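This diagnostic can be reproduced along the following lines; make_net and test_loader are placeholders, and the sketch is only meant to illustrate the effect of α on a single prior sample.

```python
import torch

@torch.no_grad()
def prior_sample_class_probs(make_net, test_loader, alpha):
    net = make_net()
    for p in net.parameters():
        p.normal_(0.0, alpha)                 # sample w ~ N(0, alpha^2 I)
    probs = [torch.softmax(net(x), dim=-1) for x, _ in test_loader]
    return torch.cat(probs).mean(0)           # average class probabilities

# Large alpha (e.g. sqrt(10)) saturates the softmax, so a single sampled function
# strongly favours one class; smaller alpha gives higher-entropy class probabilities.
```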

However, even with the inappropriate prior scale, we see in the bottom row of panels (a)–(e) of Figure 9 that the unconditional predictive distribution is completely reasonable. Moreover, the prior variance represents a soft prior bias, and will quickly update in the presence of data. In Figure 10 we show the posterior samples after observing 10, 100, and 1000 data points.

Other aspects of the prior, outside of the prior signal variance, will have a much greater effect on the inductive biases of the model. For example, the induced covariance function cov(f(x_i; w), f(x_j; w)) reflects the induced similarity metric over data instances; through the covariance function we can answer, for instance, whether the model believes a priori that a translated image is similar to the original. Unlike the signal variance of the prior, the prior covariance function will continue to have a significant effect on posterior inference for even very large datasets, and strongly reflects the structural properties of the neural network. We explore these structures of the prior in Figure 12.

8.5. The Effect of Inexact Inference

We have to keep in mind what we ultimately use posterior samples to compute. Ultimately, we wish to estimate the predictive distribution given by the integral in Equation (1). With a finite number of samples, the tempered posterior could be used to provide a better approximation to the expectation of the predictive distribution associated with the untempered posterior.

Consider a simple example, where we wish to estimate the mean of a high-dimensional Gaussian distribution N(0, I). Suppose we use J independent samples. The mean of these samples is also Gaussian distributed, µ ∼ N(0, (1/J)·I). In Bayesian deep learning, the dimension d is typically on the order of 10^7, and J would be on the order of 10. The norm of µ would be highly concentrated around √(10^7)/√10 = 1000. In this case, sampling from a tempered posterior with T < 1 would lead to a better approximation of the Bayesian model average associated with an untempered posterior.
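A quick simulation confirms this concentration; the dimensions below are simply the illustrative values from the text.

```python
import numpy as np

d, J = 10**7, 10                      # parameter dimension and number of samples
rng = np.random.default_rng(0)
mu = np.zeros(d)
for _ in range(J):                    # accumulate J samples from N(0, I_d)
    mu += rng.standard_normal(d)
mu /= J                               # Monte Carlo estimate of the mean
print(np.linalg.norm(mu))             # concentrates around sqrt(d / J) = 1000.0
```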

Furthermore, no current sampling procedure will be providing samples that are close to independent samples from the true posterior of a Bayesian neural network. The posterior landscape is far too multimodal and complex for there to be any reasonable coverage. The approximations we have are practically useful, and often preferable to conventional training, but we cannot realistically proceed with analysis assuming that we have obtained true samples from a posterior. While we would expect that some value of T ≠ 1 would be preferred for any finite dataset in practice, it is conceivable that some of the results in Wenzel et al. (2020) may be affected by the specifics of the approximate inference technique being used.


We should be wary not to view Bayesian model averaging purely through the prism of simple Monte Carlo, as advised in Section 3.2. Given a finite computational budget, our goal in effectively approximating a Bayesian model average is not equivalent to obtaining good samples from the posterior.

9. Discussion

“It is now common practice for Bayesians to fit models that have more parameters than the number of data points. Incorporate every imaginable possibility into the model space: for example, if it is conceivable that a very simple model might be able to explain the data, one should include simple models; if the noise might have a long-tailed distribution, one should include a hyperparameter which controls the heaviness of the tails of the distribution; if an input variable might be irrelevant to a regression, include it in the regression anyway.” MacKay (1995)

We have presented a probabilistic perspective of generalization, which depends on the support and inductive biases of the model. The support should be as large as possible, but the inductive biases must be well-calibrated to a given problem class. We argue that Bayesian neural networks embody these properties, and, through the lens of probabilistic inference, explain generalization behaviour that has previously been viewed as mysterious. Moreover, we argue that Bayesian marginalization is particularly compelling for neural networks, show how deep ensembles provide a practical mechanism for marginalization, and propose a new approach that generalizes deep ensembles to marginalize within basins of attraction. We show that this multimodal approach to Bayesian model averaging, MultiSWAG, can entirely alleviate double descent, enabling monotonic performance improvements with increases in model flexibility, as well as significant improvements in generalization accuracy and log likelihood over SGD and single-basin marginalization.

There are certainly many challenges to estimating the integral for a Bayesian model average in modern deep learning, including a high-dimensional parameter space and a complex posterior landscape. But viewing the challenge indeed as an integration problem, rather than an attempt to obtain posterior samples for a simple Monte Carlo approximation, provides opportunities for future progress. Bayesian deep learning has been making fast practical advances, with approaches that now enable better accuracy and calibration over standard training, with minimal overhead.

We finish with remarks about future developments for Bayesian neural network priors, and approaches to research in Bayesian deep learning.

9.1. The Future for BNN Priors

We provide some brief remarks about future developments for BNN priors. Here we have explored relatively simple parameter priors p(w) = N(0, α²I). While these priors are simple in parameter space, they interact with the neural network architecture to induce a sophisticated prior over functions p(f(x; w)), with many desirable properties, including a reasonable correlation structure over images. However, these parameter priors can certainly still be improved. As we have seen, even tuning the value of the signal variance α², an analogue of the L2 regularization often used in deep learning, can have a noticeable effect on the induced prior over functions, though this effect is quickly modulated by data. Layer-wise priors, such that parameters in each layer have a different signal variance, are intuitive: we would expect later layers to require precise determination, while parameters in earlier layers could reasonably take a range of values. But one has to be cautious; as we show in Appendix Section E, with ReLU activations, different signal variances in different layers can be degenerate, combining together to affect only the output scale of the network.

A currently popular sentiment is that we should directly build function-space BNN priors, often taking inspiration from Gaussian processes. While we believe this is a promising direction, one should proceed with caution. If we contrive priors over parameters p(w) to induce distributions over functions p(f) that resemble familiar models, such as Gaussian processes with RBF kernels, we could be throwing the baby out with the bathwater. Neural networks are useful as their own model class precisely because they have different inductive biases from other models.

A similar concern applies to taking infinite width limits in Bayesian neural networks. In these cases we recover Gaussian processes with interpretable kernel functions; because these models are easier to use and analyze, and give rise to interpretable and well-motivated priors, it is tempting to treat them as drop-in replacements for the parametric analogues. However, the kernels for these models are fixed. In order for a model to do effective representation learning, we must learn a similarity metric for the data. Training a neural network in many ways is like learning a kernel, rather than using a fixed kernel. MacKay (1998) has also expressed concerns in treating these limits as replacements for neural networks, due to the loss of representation learning power.

Perhaps the distribution over functions induced by a network in combination with a generic distribution over parameters p(w) may be hard to interpret, but this distribution will contain the equivariance properties, representation learning abilities, and other biases that make neural networks a compelling model class in their own right.


9.2. “But is it really Bayesian?”

We finish with an editorial comment about approaches to research within Bayesian deep learning. There is sometimes a tendency to classify work as Bayesian or not Bayesian, with very stringent criteria for what qualifies as Bayesian. Moreover, the implication, and sometimes even explicit recommendation, is that if an approach is not unequivocally Bayesian in every respect, then we should not term it as Bayesian, and we should instead attempt to understand the procedure through entirely different non-Bayesian mechanisms. We believe this mentality encourages tribalism, which is not conducive to the best research, or to creating the best performing methods. What matters is not a debate about semantics, but making rational modelling choices given a particular problem setting, and trying to understand these choices. Often these choices can largely be inspired by a Bayesian approach, in which case it is desirable to indicate this source of inspiration. And in the semantics debate, who would be the arbiter of what gets to be called Bayesian? Arguably it ought to be an evolving definition.

Broadly speaking, what makes Bayesian approaches distinctive is a posterior weighted marginalization over parameters. And at a high level, Bayesian methods are about combining our honest beliefs with data to form a posterior. In actuality, no fair-minded researcher entirely believes the prior over parameters, the functional form of the model (which is part of the prior over functions), or the likelihood. From this perspective, it is broadly compatible with a Bayesian philosophy to reflect misspecification in the modelling procedure itself, which is achieved through tempering. In this sense, the tempered posterior is more reflective of a true posterior than the posterior that results from ignoring our belief that the model is misspecified.

Moreover, basic probability theory indicates that marginalization is desirable. While marginalization cannot in practice be achieved exactly, we can try to improve over conventional training, which, as we have discussed, can be viewed as approximate marginalization. Given computational constraints, effective marginalization is not equivalent to obtaining accurate samples from a posterior. As we have discussed, simple Monte Carlo is only one of many mechanisms for marginalization. Just as expectation propagation (Minka, 2001a) focuses its approximation on the factors in a posterior where it will most affect the end result, we should focus on representing the posterior where it will make the biggest difference to the model average. As we have shown, deep ensembles are a reasonable mechanism up to a point. After having trained many independent models, there are added benefits to marginalizing within basins, given the computational expense associated with retraining an additional model to find an additional basin of attraction.

We should also not hold Bayesian methods to a double

standard Indeed it can be hard to interpret or understandthe prior the posterior and whether the marginalizationprocedure is optimal But it is also hard to interpret thechoices behind the functional form of the model or the ra-tionale behind classical procedures where we bet everythingon a single global optimum mdash when we know there aremany global optima and many of them will perform wellbut provide different solutions and many others will notperform well We should apply the same level of scrutiny toall modelling choices consider the alternatives and not beparalyzed if a procedure is not optimal in every respect

Acknowledgements. AGW and PI are supported by an Amazon Research Award, Facebook Research, NSF I-DISRE 193471, NIH R01 DA048764-01A1, NSF IIS-1563887, and NSF IIS-1910266. We thank Greg Benton for helpful discussions.

References

Achille, A. and Soatto, S. Emergence of invariance and disentanglement in deep representations. The Journal of Machine Learning Research, 19(1):1947–1980, 2018.
Ashukha, A., Lyzhov, A., Molchanov, D., and Vetrov, D. Pitfalls of in-domain uncertainty estimation and ensembling in deep learning. arXiv preprint arXiv:2002.06470, 2020.
Barron, A. R. and Cover, T. M. Minimum complexity density estimation. IEEE Transactions on Information Theory, 37(4):1034–1054, 1991.
Beal, M. J. Variational algorithms for approximate Bayesian inference. University of London, 2003.
Belkin, M., Hsu, D., Ma, S., and Mandal, S. Reconciling modern machine-learning practice and the classical bias–variance trade-off. Proceedings of the National Academy of Sciences, 116(32):15849–15854, 2019.
Bishop, C. M. Pattern Recognition and Machine Learning. Springer, 2006.
Bissiri, P. G., Holmes, C. C., and Walker, S. G. A general framework for updating belief distributions. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 78(5):1103–1130, 2016.
Blundell, C., Cornebise, J., Kavukcuoglu, K., and Wierstra, D. Weight uncertainty in neural networks. arXiv preprint arXiv:1505.05424, 2015.
Box, G. E. and Tiao, G. C. Bayesian Inference in Statistical Analysis. Addison-Wesley, Reading, MA, 1973.
Chaudhari, P., Choromanska, A., Soatto, S., LeCun, Y., Baldassi, C., Borgs, C., Chayes, J., Sagun, L., and Zecchina, R. Entropy-SGD: Biasing gradient descent into wide valleys. Journal of Statistical Mechanics: Theory and Experiment, 2019.
Cobb, A. D., Baydin, A. G., Markham, A., and Roberts, S. J. Introducing an explicit symplectic integration scheme for Riemannian manifold Hamiltonian Monte Carlo. arXiv preprint arXiv:1910.06243, 2019.
Darnieder, W. F. Bayesian methods for data-dependent priors. PhD thesis, The Ohio State University, 2011.
de Heide, R., Kirichenko, A., Mehta, N., and Grunwald, P. Safe-Bayesian generalized linear regression. arXiv preprint arXiv:1910.09227, 2019.
Dinh, L., Pascanu, R., Bengio, S., and Bengio, Y. Sharp minima can generalize for deep nets. arXiv preprint arXiv:1703.04933, 2017.
Draxler, F., Veschgini, K., Salmhofer, M., and Hamprecht, F. A. Essentially no barriers in neural network energy landscape. arXiv preprint arXiv:1803.00885, 2018.
Dziugaite, G. K. and Roy, D. M. Computing nonvacuous generalization bounds for deep (stochastic) neural networks with many more parameters than training data. arXiv preprint arXiv:1703.11008, 2017.
Fort, S. and Jastrzebski, S. Large scale structure of neural network loss landscapes. In Advances in Neural Information Processing Systems, pp. 6706–6714, 2019.
Fort, S., Hu, H., and Lakshminarayanan, B. Deep ensembles: A loss landscape perspective. arXiv preprint arXiv:1912.02757, 2019.
Gal, Y. and Ghahramani, Z. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. In International Conference on Machine Learning, pp. 1050–1059, 2016.
Gardner, J. R., Pleiss, G., Bindel, D., Weinberger, K. Q., and Wilson, A. G. GPyTorch: Blackbox matrix-matrix Gaussian process inference with GPU acceleration. In Neural Information Processing Systems, 2018.
Garipov, T., Izmailov, P., Podoprikhin, D., Vetrov, D. P., and Wilson, A. G. Loss surfaces, mode connectivity, and fast ensembling of DNNs. In Neural Information Processing Systems, 2018.
Grunwald, P. The safe Bayesian. In International Conference on Algorithmic Learning Theory, pp. 169–183. Springer, 2012.
Grunwald, P., Van Ommen, T., et al. Inconsistency of Bayesian inference for misspecified linear models, and a proposal for repairing it. Bayesian Analysis, 12(4):1069–1103, 2017.
Guedj, B. A primer on PAC-Bayesian learning. arXiv preprint arXiv:1901.05353, 2019.
Guo, C., Pleiss, G., Sun, Y., and Weinberger, K. Q. On calibration of modern neural networks. In Proceedings of the 34th International Conference on Machine Learning - Volume 70, pp. 1321–1330. JMLR.org, 2017.
Gustafsson, F. K., Danelljan, M., and Schon, T. B. Evaluating scalable Bayesian deep learning methods for robust computer vision. arXiv preprint arXiv:1906.01620, 2019.
Hafner, D., Tran, D., Irpan, A., Lillicrap, T., and Davidson, J. Reliable uncertainty estimates in deep neural networks using noise contrastive priors. arXiv preprint arXiv:1807.09289, 2018.
He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778, 2016.
Hendrycks, D. and Dietterich, T. Benchmarking neural network robustness to common corruptions and perturbations. arXiv preprint arXiv:1903.12261, 2019.
Hinton, G. E. and Van Camp, D. Keeping the neural networks simple by minimizing the description length of the weights. In Proceedings of the Sixth Annual Conference on Computational Learning Theory, pp. 5–13, 1993.
Hochreiter, S. and Schmidhuber, J. Flat minima. Neural Computation, 9(1):1–42, 1997.
Huang, W. R., Emam, Z., Goldblum, M., Fowl, L., Terry, J. K., Huang, F., and Goldstein, T. Understanding generalization through visualizations. arXiv preprint arXiv:1906.03291, 2019.
Izmailov, P., Podoprikhin, D., Garipov, T., Vetrov, D., and Wilson, A. G. Averaging weights leads to wider optima and better generalization. In Uncertainty in Artificial Intelligence (UAI), 2018.
Izmailov, P., Maddox, W. J., Kirichenko, P., Garipov, T., Vetrov, D., and Wilson, A. G. Subspace inference for Bayesian deep learning. Uncertainty in Artificial Intelligence, 2019.
Jacot, A., Gabriel, F., and Hongler, C. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in Neural Information Processing Systems, pp. 8571–8580, 2018.
Jiang, Y., Neyshabur, B., Mobahi, H., Krishnan, D., and Bengio, S. Fantastic generalization measures and where to find them. arXiv preprint arXiv:1912.02178, 2019.
Kass, R. E. and Raftery, A. E. Bayes factors. Journal of the American Statistical Association, 90(430):773–795, 1995.
Kendall, A. and Gal, Y. What uncertainties do we need in Bayesian deep learning for computer vision? In Advances in Neural Information Processing Systems, pp. 5574–5584, 2017.
Keskar, N. S., Mudigere, D., Nocedal, J., Smelyanskiy, M., and Tang, P. T. P. On large-batch training for deep learning: Generalization gap and sharp minima. arXiv preprint arXiv:1609.04836, 2016.
Khan, M. E., Nielsen, D., Tangkaratt, V., Lin, W., Gal, Y., and Srivastava, A. Fast and scalable Bayesian deep learning by weight-perturbation in Adam. arXiv preprint arXiv:1806.04854, 2018.
Krizhevsky, A., Nair, V., and Hinton, G. The CIFAR-10 dataset, 2014. http://www.cs.toronto.edu/~kriz/cifar.html.
Lakshminarayanan, B., Pritzel, A., and Blundell, C. Simple and scalable predictive uncertainty estimation using deep ensembles. In Advances in Neural Information Processing Systems, pp. 6402–6413, 2017.
Langford, J. and Caruana, R. (Not) bounding the true error. In Advances in Neural Information Processing Systems, pp. 809–816, 2002.
LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.
Louizos, C., Shi, X., Schutte, K., and Welling, M. The functional neural process. In Advances in Neural Information Processing Systems, 2019.
MacKay, D. J. Bayesian methods for adaptive models. PhD thesis, California Institute of Technology, 1992.
MacKay, D. J. Probable networks and plausible predictions – a review of practical Bayesian methods for supervised neural networks. Network: Computation in Neural Systems, 6(3):469–505, 1995.
MacKay, D. J. Introduction to Gaussian processes. In Bishop, C. M. (ed.), Neural Networks and Machine Learning, chapter 11, pp. 133–165. Springer-Verlag, 1998.
MacKay, D. J. Information Theory, Inference and Learning Algorithms. Cambridge University Press, 2003.
Maddox, W. J., Izmailov, P., Garipov, T., Vetrov, D. P., and Wilson, A. G. A simple baseline for Bayesian uncertainty in deep learning. In Advances in Neural Information Processing Systems, 2019.
Masegosa, A. R. Learning under model misspecification: Applications to variational and ensemble methods. 2019.
McAllester, D. A. PAC-Bayesian model averaging. In Proceedings of the Twelfth Annual Conference on Computational Learning Theory, pp. 164–170, 1999.
Minka, T. Expectation propagation for approximate Bayesian inference. In Uncertainty in Artificial Intelligence, volume 17, pp. 362–369, 2001a.
Minka, T. P. Bayesian model averaging is not model combination. 2000.
Minka, T. P. Automatic choice of dimensionality for PCA. In Advances in Neural Information Processing Systems, pp. 598–604, 2001b.
Mohri, M. and Rostamizadeh, A. Rademacher complexity bounds for non-i.i.d. processes. In Advances in Neural Information Processing Systems, pp. 1097–1104, 2009.
Nakkiran, P., Kaplun, G., Bansal, Y., Yang, T., Barak, B., and Sutskever, I. Deep double descent: Where bigger models and more data hurt. arXiv preprint arXiv:1912.02292, 2019.
Nakkiran, P., Venkat, P., Kakade, S., and Ma, T. Optimal regularization can mitigate double descent. arXiv preprint arXiv:2003.01897, 2020.
Nalisnick, E. On priors for Bayesian neural networks. PhD thesis, University of California, Irvine, 2018.
Neal, R. Bayesian Learning for Neural Networks. Springer Verlag, 1996. ISBN 0387947248.
Neyshabur, B., Bhojanapalli, S., McAllester, D., and Srebro, N. Exploring generalization in deep learning. In Advances in Neural Information Processing Systems, pp. 5947–5956, 2017.
Neyshabur, B., Bhojanapalli, S., and Srebro, N. A PAC-Bayesian approach to spectrally-normalized margin bounds for neural networks. In International Conference on Learning Representations, 2018. URL https://openreview.net/forum?id=Skz_WfbCZ.
Opper, M., Kinzel, W., Kleinz, J., and Nehl, R. On the ability of the optimal perceptron to generalise. Journal of Physics A: Mathematical and General, 23(11):L581, 1990.
Ovadia, Y., Fertig, E., Ren, J., Nado, Z., Sculley, D., Nowozin, S., Dillon, J. V., Lakshminarayanan, B., and Snoek, J. Can you trust your model's uncertainty? Evaluating predictive uncertainty under dataset shift. arXiv preprint arXiv:1906.02530, 2019.
Pearce, T., Zaki, M., Brintrup, A., Anastassacos, N., and Neely, A. Uncertainty in neural networks: Bayesian ensembling. arXiv preprint arXiv:1810.05546, 2018.
Rasmussen, C. E. and Ghahramani, Z. Occam's razor. In Neural Information Processing Systems (NIPS), 2001.
Rasmussen, C. E. and Williams, C. K. I. Gaussian Processes for Machine Learning. The MIT Press, 2006.
Ritter, H., Botev, A., and Barber, D. A scalable Laplace approximation for neural networks. In International Conference on Learning Representations (ICLR), 2018.
Rue, H., Martino, S., and Chopin, N. Approximate Bayesian inference for latent Gaussian models by using integrated nested Laplace approximations. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 71(2):319–392, 2009.
Simonyan, K. and Zisserman, A. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.
Smith, S. L. and Le, Q. V. A Bayesian perspective on generalization and stochastic gradient descent. In International Conference on Learning Representations, 2018.
Sun, S., Zhang, G., Shi, J., and Grosse, R. Functional variational Bayesian neural networks. arXiv preprint arXiv:1903.05779, 2019.
Ulyanov, D., Vedaldi, A., and Lempitsky, V. Deep image prior. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 9446–9454, 2018.
Vapnik, V. N. Adaptive and learning systems for signal processing, communications, and control. Statistical Learning Theory, 1998.
Virtanen, P., Gommers, R., Oliphant, T. E., Haberland, M., Reddy, T., Cournapeau, D., Burovski, E., Peterson, P., Weckesser, W., Bright, J., et al. SciPy 1.0: fundamental algorithms for scientific computing in Python. Nature Methods, pp. 1–12, 2020.
Walker, S. and Hjort, N. L. On Bayesian consistency. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 63(4):811–821, 2001.
Welling, M. and Teh, Y. W. Bayesian learning via stochastic gradient Langevin dynamics. In Proceedings of the 28th International Conference on Machine Learning (ICML-11), pp. 681–688, 2011.
Wenzel, F., Roth, K., Veeling, B. S., Swiatkowski, J., Tran, L., Mandt, S., Snoek, J., Salimans, T., Jenatton, R., and Nowozin, S. How good is the Bayes posterior in deep neural networks really? arXiv preprint arXiv:2002.02405, 2020.
Wilson, A. G. The case for Bayesian deep learning. arXiv preprint arXiv:2001.10995, 2020.
Xu, X., Lu, P., MacEachern, S., and Xu, R. Calibrated Bayes factors for model comparison. Journal of Statistical Computation and Simulation, 89(4):591–614, 2019.
Yang, W., Lorch, L., Graule, M. A., Srinivasan, S., Suresh, A., Yao, J., Pradier, M. F., and Doshi-Velez, F. Output-constrained Bayesian neural networks. arXiv preprint arXiv:1905.06287, 2019.
Zhang, C., Bengio, S., Hardt, M., Recht, B., and Vinyals, O. Understanding deep learning requires rethinking generalization. arXiv preprint arXiv:1611.03530, 2016.
Zhang, T. Information-theoretic upper and lower bounds for statistical estimation. IEEE Transactions on Information Theory, 52(4):1307–1321, 2006.


Appendix Outline

This appendix is organized as follows. In Section A we visualize predictive functions corresponding to weight samples within high posterior density valleys on a regression problem. In Section B we provide background material on Gaussian processes. In Section C we present further results comparing MultiSWAG and MultiSWA to Deep Ensembles under data distribution shift on CIFAR-10. In Section D we provide the details of all experiments presented in the paper. In Section E we present analytic results on the dependence of the prior distribution in function space on the variance of the prior over parameters. In Section F we study the prior correlations between BNN logits on perturbed images.

A. Loss Valleys

We demonstrate that different points along the valleys of high posterior density (low loss) connecting pairs of independently trained optima (Garipov et al., 2018; Draxler et al., 2018; Fort & Jastrzebski, 2019) correspond to different predictive functions. We use the regression example from Izmailov et al. (2019) and show the results in Figure 11.

B. Gaussian Processes

With a Bayesian neural network, a distribution over parameters p(w) induces a distribution over functions p(f(x; w)) when combined with the functional form of the network. Gaussian processes (GPs) are often used to instead directly specify a distribution over functions.

A Gaussian process is a distribution over functions, f(x) ∼ GP(m, k), such that any collection of function values, queried at any finite set of inputs x_1, ..., x_n, has a joint Gaussian distribution:

f(x_1), ..., f(x_n) ∼ N(µ, K).    (6)

The mean vector µ_i = E[f(x_i)] = m(x_i) and covariance matrix K_ij = cov(f(x_i), f(x_j)) = k(x_i, x_j) are determined by the mean function m and covariance function (or kernel) k of the Gaussian process.

The popular RBF kernel has the form

k(x_i, x_j) = exp( −‖x_i − x_j‖² / (2ℓ²) ).    (7)

The length-scale hyperparameter ℓ controls the extent of correlations between function values. If ℓ is large, sample functions from a GP prior are simple and slowly varying with inputs x.

Gaussian processes with RBF kernels (as well as many other standard kernels) assign positive density to any set of observations. Moreover, these models are universal approximators (Rasmussen & Williams, 2006): as the number of observations increases, they are able to approximate any function to arbitrary precision.
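As a concrete illustration of Equations (6) and (7), the sketch below draws sample functions from a zero-mean GP prior with an RBF kernel for two length-scales. This is a minimal NumPy illustration of the prior, not the code used for the experiments in this paper.

```python
import numpy as np

def rbf_kernel(x1, x2, lengthscale):
    # k(x_i, x_j) = exp(-||x_i - x_j||^2 / (2 * lengthscale^2)), Equation (7).
    sq_dists = (x1[:, None] - x2[None, :]) ** 2
    return np.exp(-0.5 * sq_dists / lengthscale ** 2)

def sample_gp_prior(x, lengthscale, n_samples=3, jitter=1e-8):
    # Draw functions f ~ N(0, K) at the query points x, as in Equation (6).
    K = rbf_kernel(x, x, lengthscale) + jitter * np.eye(len(x))
    L = np.linalg.cholesky(K)
    z = np.random.randn(len(x), n_samples)
    return L @ z  # each column is one sample function evaluated at x

x = np.linspace(-5, 5, 200)
wiggly = sample_gp_prior(x, lengthscale=0.5)   # quickly-varying sample functions
smooth = sample_gp_prior(x, lengthscale=3.0)   # simple, slowly-varying samples
```

Plotting the columns of `wiggly` and `smooth` against `x` makes the effect of ℓ on the complexity of prior sample functions immediately visible.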

Work on Gaussian processes in machine learning was triggered by the observation that Bayesian neural networks become Gaussian processes with particular kernel functions as the number of hidden units approaches infinity (Neal, 1996). This result resembles recent work on the neural tangent kernel (e.g., Jacot et al., 2018).

C. Deep Ensembles and MultiSWAG under Distribution Shift

In Figures 14, 15, 16, 17 we show the negative log-likelihood for Deep Ensembles, MultiSWA and MultiSWAG using PreResNet-20 on CIFAR-10 with various corruptions, as a function of the number of independently trained models (SGD solutions, SWA solutions or SWAG models, respectively). For MultiSWAG, we generate 20 samples from each independent SWAG model. Typically, MultiSWA and MultiSWAG significantly outperform Deep Ensembles when a small number of independent models is used, or when the level of corruption is high.
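The MultiSWAG predictive distribution used in these comparisons is a uniform Monte Carlo average of per-sample class probabilities across all samples from all SWAG models. A minimal sketch of that averaging and of the resulting test NLL is shown below; it assumes the per-sample softmax outputs have already been computed by forward passes of the network, which are not reproduced here.

```python
import numpy as np

def multiswag_predictive(per_sample_probs):
    """Combine per-sample class probabilities into the MultiSWAG predictive.

    per_sample_probs: array of shape (n_models, n_samples, n_test, n_classes)
    holding softmax outputs for each weight sample of each SWAG model.
    The Bayesian model average is approximated by a uniform average over
    all samples from all basins.
    """
    return per_sample_probs.mean(axis=(0, 1))  # shape: (n_test, n_classes)

def total_nll(mean_probs, labels):
    # Negative log-likelihood of the ensembled predictive over the test set.
    return -np.sum(np.log(mean_probs[np.arange(len(labels)), labels] + 1e-12))

# Toy usage: 3 SWAG models, 20 samples each, 5 test points, 10 classes.
rng = np.random.default_rng(0)
raw = rng.random((3, 20, 5, 10))
probs = raw / raw.sum(axis=-1, keepdims=True)
labels = rng.integers(0, 10, size=5)
print(total_nll(multiswag_predictive(probs), labels))
```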

In Figure 18, following Ovadia et al. (2019), we show the distribution of negative log likelihood, accuracy and expected calibration error as we vary the type of corruption. We use a fixed training time budget: 10 independently trained models for every method. For MultiSWAG we ensemble 20 samples from each of the 10 SWAG approximations. MultiSWAG particularly achieves better NLL than the other two methods, and MultiSWA outperforms Deep Ensembles; the difference is especially pronounced for higher levels of corruption. In terms of ECE, MultiSWAG again outperforms the other two methods for higher corruption intensities.

We note that Ovadia et al. (2019) found Deep Ensembles to be a very strong baseline for prediction quality and calibration under distribution shift. For this reason, we focus on Deep Ensembles in our comparisons.

D. Details of Experiments

In this section we provide additional details of the experiments presented in the paper.

D.1. Approximating the True Predictive Distribution

For the results presented in Figure 4, we used a network with 3 hidden layers of size 10 each. The network takes two inputs, x and x². We pass both x and x² as input to ensure that the network can represent a broader class of functions. The network outputs a single number, y = f(x).


Figure 11. Diversity of high performing functions. Bottom: a contour plot of the posterior log-density in the subspace containing a pair of independently trained modes (as with deep ensembles) and a path of high posterior density connecting these modes. In each panel the purple point represents a sample from the posterior in the parameter subspace. Top: the predictive distribution constructed from samples in the subspace. The shaded blue area shows the 3σ-region of the predictive distribution at each of the input locations, and the blue line shows the mean of the predictive distribution. In each panel the purple line shows the predictive function corresponding to the sample shown in the corresponding bottom row panel. For the details of the experimental setup, see Section 5.1 of Izmailov et al. (2019).

To generate data for the plots, we used a randomly-initialized neural network of the same architecture described above. We sampled the weights from an isotropic Gaussian with variance 0.1², and added isotropic Gaussian noise with variance 0.1² to the outputs:

y = f(x; w) + ε(x),

with w ∼ N(0, 0.1² · I), ε(x) ∼ N(0, 0.1² · I). The training set consists of 120 points, shown in Figure 4.

For estimating the ground truth, we ran 200 chains of Hamiltonian Monte Carlo (HMC) using the hamiltorch package (Cobb et al., 2019). We initialized each chain with a network pre-trained with SGD for 3000 steps, then ran HMC for 2000 steps, producing 200 samples.

For Deep Ensembles, we independently trained 50 networks with SGD for 20000 steps each, using the negative posterior log-density as the training loss. For SVI, we used a fully-factorized Gaussian approximation initialized at an SGD solution trained for 20000 steps. For all inference methods, we set the prior variance to 10² and the noise variance to 0.02².

Discrepancy with true BMA. For the results presented in panel (d) of Figure 4, we computed the Wasserstein distance between the predictive distribution approximated with HMC and the predictive distributions for Deep Ensembles and SVI. We used the one-dimensional Wasserstein distance function¹ from the scipy package (Virtanen et al., 2020). We computed the Wasserstein distance between the marginal distributions at each input location, and averaged the results over the input locations. In the top sub-panels of panels (b), (c) of Figure 4 we additionally visualize the marginal Wasserstein distance between the HMC predictive distribution and the Deep Ensembles and SVI predictive distributions, respectively, for each input location.

¹https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wasserstein_distance.html
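A minimal sketch of this comparison metric is given below, assuming the predictive samples of each method at each input location are already available as arrays; scipy.stats.wasserstein_distance computes the one-dimensional distance between two empirical samples.

```python
import numpy as np
from scipy.stats import wasserstein_distance

def avg_marginal_wasserstein(hmc_samples, approx_samples):
    """Average 1-D Wasserstein distance between predictive marginals.

    hmc_samples:    array (n_hmc_samples, n_inputs) of f(x) draws from HMC.
    approx_samples: array (n_approx_samples, n_inputs) of draws from the
                    approximation (e.g. a deep ensemble or an SVI posterior).
    Returns the per-input distances and their average over input locations.
    """
    n_inputs = hmc_samples.shape[1]
    per_input = np.array([
        wasserstein_distance(hmc_samples[:, i], approx_samples[:, i])
        for i in range(n_inputs)
    ])
    return per_input, per_input.mean()

# Toy usage: 200 HMC draws vs. 50 ensemble members at 100 input locations.
rng = np.random.default_rng(0)
per_input, avg = avg_marginal_wasserstein(rng.normal(size=(200, 100)),
                                           rng.normal(size=(50, 100)))
```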

D.2. Deep Ensembles and MultiSWAG

We evaluate Deep Ensembles, MultiSWA and MultiSWAG under distribution shift in Section 4. Following Ovadia et al. (2019), we use a PreResNet-20 network and the CIFAR-10 dataset with different types of corruptions introduced in Hendrycks & Dietterich (2019). For training individual SGD, SWA and SWAG models, we use the hyper-parameters used for PreResNet-164 in Maddox et al. (2019). For each SWAG model we sample 20 networks and ensemble them. So Deep Ensembles, MultiSWA and MultiSWAG are all evaluated under the same training budget; Deep Ensembles and MultiSWA also use the same test-time budget.

For producing the corrupted data we used the code² released by Hendrycks & Dietterich (2019). We had issues producing the data for the frost corruption type, so we omit it in our evaluation, and include Gaussian blur, which was not included in the evaluation of Hendrycks & Dietterich (2019).

²https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_cifar_c.py


Figure 12. (a)–(c) Average pairwise prior correlations for pairs of objects in classes {0, 1, 2, 4, 7} of MNIST induced by LeNet-5 for p(f(x; w)) when p(w) = N(0, α²I). Images in the same class have higher prior correlations than images from different classes, suggesting that p(f(x; w)) has desirable inductive biases. The correlations slightly decrease with increases in α. Panels (e)–(g) show sample functions from LeNet-5 along the direction connecting a pair of MNIST images of 0 and 1 digits. The complexity of the samples increases with α. (d) NLL and (h) classification error of an ensemble of 20 SWAG samples on MNIST as a function of α, using a LeNet-5. The NLL is high for overly small α and near-optimal for larger values, with an optimum near α = 0.3.

D.3. Neural Network Priors

In the main text we considered different properties of the prior distribution over functions induced by a spherical Gaussian distribution over the weights, with different variance scales.

Prior correlation diagrams. In panels (a)–(c) of Figure 12 we show pairwise correlations of the logits for different pairs of datapoints. To make these plots, we produce S = 100 samples of the weights w_i of a LeNet-5 from the prior distribution N(0, α²I), and compute the logits corresponding to class 0 for each data point and each weight sample. We then compute the correlations for each pair x, x′ of data points as follows:

corr_logit(x, x′) = Σ_{i=1}^S (f(x, w_i) − f̄(x)) (f(x′, w_i) − f̄(x′)) / sqrt( Σ_{i=1}^S (f(x, w_i) − f̄(x))² · Σ_{i=1}^S (f(x′, w_i) − f̄(x′))² ),

where f(x, w) is the logit corresponding to class 0 of the network with weights w on the input x, and f̄(x) is the mean value of the logit, f̄(x) = (1/S) Σ_i f(x, w_i). For evaluation, we use 200 random datapoints per class for classes 0, 1, 2, 4, 7 (a total of 1000 datapoints). We use this set of classes to ensure that the structure is clearly visible in the figure. We combine the correlations into a diagram, additionally showing the average correlation for each pair of classes. We repeat the experiment for different values of α ∈ {0.02, 0.1, 1}. For a discussion of the results, see Section 5.2.
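The quantity above is just the empirical Pearson correlation of the class-0 logits across the S weight samples. The sketch below computes it for all pairs of inputs at once from a matrix of sampled logits; the logits themselves would come from forward passes of LeNet-5 with weights drawn from N(0, α²I), which are not reproduced here.

```python
import numpy as np

def prior_logit_correlations(logits):
    """Pairwise prior correlations of a single logit across weight samples.

    logits: array of shape (S, N) with the class-0 logit f(x_j, w_i) for each
    of S prior weight samples (rows) and N datapoints (columns).
    Returns an (N, N) matrix of correlations corr_logit(x_j, x_k).
    """
    centered = logits - logits.mean(axis=0, keepdims=True)  # f(x, w_i) - f_bar(x)
    cov = centered.T @ centered                              # unnormalized covariances
    std = np.sqrt(np.diag(cov))
    return cov / np.outer(std, std)

# Toy usage: S = 100 weight samples, N = 10 datapoints of fake logits.
rng = np.random.default_rng(0)
fake_logits = rng.normal(size=(100, 10))
C = prior_logit_correlations(fake_logits)  # C[j, k] in [-1, 1], C[j, j] = 1
```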

Sample functions. In panels (e)–(g) of Figure 12 we visualize the functions sampled from the LeNet-5 network along the direction connecting a pair of MNIST images. In particular, we take a pair of images x_0 and x_1 of digits 0 and 1 respectively, and construct the path x(t) = t · x_0 + (1 − t) · x_1. We then study the samples of the logits z(t) = f(x(t) · ‖x_0‖ / ‖x(t)‖, w) along the path; here we adjusted the norm of the images along the path to be constant, as the values of the logits are sensitive to the norm of the inputs. The complexity of the samples increases as we increase the variance of the prior distribution over the weights. This increased complexity of sample functions explains why we might expect the prior correlations for pairs of images to be lower when we increase the variance of the prior distribution over the weights.

Performance dependence on prior variance. In panels (d), (h) of Figure 12 we show the test negative log-likelihood and accuracy of SWAG applied to LeNet-5 on MNIST. We train the model for 50 epochs, constructing the rank-20 SWAG approximation from the last 25 epochs. We use an initial learning rate of 0.05 and a SWAG learning rate of 0.01, with the learning rate schedule of Maddox et al. (2019). We use the posterior log-density as the objective and vary the prior variance α². In panels (f), (g) of Figure 9 we perform an analogous experiment using a PreResNet-20 and a VGG-16 on CIFAR-10, using the hyper-parameters reported in Maddox et al. (2019) (for PreResNet-20 we use the hyper-parameters used with PreResNet-164 in Maddox et al. (2019)). Both on MNIST and CIFAR-10, we observe that the performance is poor for overly small values of α, close to optimal for intermediate values, and still reasonable for larger values of α. For further discussion of the results, see Section 5.3.

Predictions from prior samples. Following Wenzel et al. (2020), we study the predictive distributions of prior samples using PreResNet-20 on CIFAR-10. In Figure 9 we show the sample predictive functions averaged over datapoints, for different scales α of the prior distribution. We also show the predictive distribution for each α, which is the average of the sample predictive distributions over 200 samples of weights. In Figure 10 we show how the predictive distribution changes as we vary the number of observed data, for prior scale α = √10. We see that the marginal predictive distribution for all considered values of α is reasonable: roughly uniform across classes, when averaged across the dataset. For the latter experiment we used stochastic gradient Langevin dynamics (SGLD) (Welling & Teh, 2011) with a cosine learning rate schedule. For each sample we restart SGLD, and we only use the sample obtained at the last iteration. We discuss the results in Section 8.4.

Prior correlations with corrupted images. In Section F and Figure 19, we study the decay of the prior correlations between logits on an original image and a perturbed image as we increase the intensity of perturbations. For the BNN, we use the PreResNet-20 architecture with the standard Gaussian prior N(0, I). For the linear model, the correlations are not affected by the prior variance α²:

cov(wᵀx, wᵀy) = E(wᵀx · wᵀy) = E(xᵀw wᵀy) = xᵀ E(wwᵀ) y = α² xᵀy,

and hence

corr(wᵀx, wᵀy) = cov(wᵀx, wᵀy) / sqrt( cov(wᵀx, wᵀx) · cov(wᵀy, wᵀy) ) = xᵀy / (‖x‖ · ‖y‖).

We use the N(0, I) prior for the weights of the linear model. Finally, we also evaluate the correlations associated with an RBF kernel (see Equation (7)). To set the lengthscale ℓ of the kernel, we evaluate the pairwise correlations for the PreResNet-20 and RBF kernel on the 100 uncorrupted CIFAR-10 images that were used for the experiment, and ensure that the average correlations match. The resulting value of ℓ is 10000; the average correlation for the RBF kernel and PreResNet was ≈ 0.9, and for the linear model the average correlation was ≈ 0.82.

For the perturbations, we used the same set of corruptions introduced in Hendrycks & Dietterich (2019) as in the experiments in Section 4, with the addition of a random translation: for a random translation of intensity i, we pad the image with 2·i zeros on each side and crop the image randomly to 32 × 32.
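The claim that the linear model's prior correlations do not depend on α² is easy to check by Monte Carlo; the sketch below compares the empirical correlation of wᵀx and wᵀy under w ∼ N(0, α²I) with the closed form xᵀy / (‖x‖ ‖y‖) for two values of α.

```python
import numpy as np

def empirical_linear_prior_corr(x, y, alpha, n_samples=200000, seed=0):
    # Sample w ~ N(0, alpha^2 I) and correlate the linear "logits" w^T x and w^T y.
    rng = np.random.default_rng(seed)
    W = alpha * rng.normal(size=(n_samples, x.size))
    fx, fy = W @ x, W @ y
    return np.corrcoef(fx, fy)[0, 1]

rng = np.random.default_rng(1)
x, y = rng.normal(size=50), rng.normal(size=50)
closed_form = x @ y / (np.linalg.norm(x) * np.linalg.norm(y))
for alpha in [0.02, 1.0]:
    # Both estimates should agree with the closed form, independently of alpha.
    print(alpha, empirical_linear_prior_corr(x, y, alpha), closed_form)
```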

D.4. Rethinking Generalization

In Section 6 we experiment with Bayesian neural networks and Gaussian processes on CIFAR-10 with noisy labels, inspired by the results in Zhang et al. (2016) that suggest we need to re-think generalization to understand deep learning.

Following Zhang et al. (2016), we train PreResNet-20 on CIFAR-10 with different fractions of random labels. To ensure that the networks fit the training data, we turn off weight decay and data augmentation, and use a lower initial learning rate of 0.01. Otherwise, we follow the hyper-parameters that were used with PreResNet-164 in Maddox et al. (2019). We use a diagonal Laplace approximation to compute an estimate of the marginal likelihood for each level of label corruption. Following Ritter et al. (2018), we use the diagonal of the Fisher information matrix rather than the Hessian.
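For reference, a diagonal Laplace evidence estimate of this kind takes the standard form log p(D) ≈ log p(D | ŵ) + log p(ŵ) + (d/2) log 2π − ½ Σ_k log λ_k, where ŵ is the MAP solution and λ_k are the diagonal curvature terms (here diagonal Fisher plus prior precision). The sketch below is a minimal illustration under those assumptions, with made-up inputs rather than the quantities computed in our experiments.

```python
import numpy as np

def diag_laplace_log_evidence(log_lik_map, log_prior_map, fisher_diag, prior_var):
    """Diagonal Laplace estimate of the log marginal likelihood.

    log_lik_map:   log p(D | w_MAP), the training log-likelihood at the MAP solution.
    log_prior_map: log p(w_MAP) under the Gaussian prior.
    fisher_diag:   diagonal of the Fisher information at w_MAP (used in place of
                   the Hessian, as in Ritter et al., 2018).
    prior_var:     prior variance, whose inverse is added to the curvature.
    """
    precision = fisher_diag + 1.0 / prior_var  # diagonal of the posterior precision
    d = precision.size
    log_det = np.sum(np.log(precision))
    return log_lik_map + log_prior_map + 0.5 * d * np.log(2 * np.pi) - 0.5 * log_det

# Toy usage with arbitrary numbers (not values from the experiments).
fisher_diag = np.abs(np.random.default_rng(0).normal(size=1000)) * 10
print(diag_laplace_log_evidence(-2500.0, -700.0, fisher_diag, prior_var=0.1))
```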

We perform a similar experiment with a Gaussian process with an RBF kernel on a binary classification problem for two classes of CIFAR-10. We use variational inference to fit the model, and we use the variational evidence lower bound to approximate the marginal likelihood. We use variational inference to overcome the non-Gaussian likelihood, and not for scalability reasons, i.e. we are not using inducing inputs. We use the GPyTorch package (Gardner et al., 2018) to train the models. We use an RBF kernel with the default initialization from GPyTorch, and divide the inputs by 5000 to get an appropriate input scale. We train the model on a binary classification problem between classes 0 and 1.

For the 10-class GP classification experiment, we train 10 one-vs-all models that classify between a given class and the rest of the data. To reduce computation in training, we subsample the data not belonging to the given class to 10k datapoints, so each model is trained on a total of 15k datapoints. We then combine the 10 models into a single multi-class model: an observation is attributed to the class that corresponds to the one-vs-all model with the highest confidence. We use the same hyper-parameters as in the binary classification experiments.

D.5. Double Descent

In Section 7 we evaluate SGD, SWAG and MultiSWAG for models of varying width. Following Nakkiran et al. (2019), we use ResNet-18 on CIFAR-100; we consider the original labels, and 10% and 20% label corruption. For networks of every width, we reuse the hyper-parameters used for PreResNet-164 in Maddox et al. (2019).


Figure 13. Double descent. (a) Test error and (b) NLL loss for ResNet-18 with varying width on CIFAR-100 for SGD, SWAG and MultiSWAG when 10% of the labels are randomly reshuffled. MultiSWAG alleviates double descent both on the original labels and under label noise, both in accuracy and NLL. (c) Test NLLs for MultiSWAG with varying numbers of independent models under 20% label corruption; NLL monotonically decreases with an increased number of independent models, alleviating double descent.

For the original labels and 10% label corruption, we use 5 independently trained SWAG models with MultiSWAG, and for 20% label corruption we use 10 models; for 20% label corruption we also show performance varying the number of independent models in Figures 8(e) and 13(c). Both for SWAG and MultiSWAG we use an ensemble of 20 sampled models from each of the SWAG solutions; for example, for MultiSWAG with 10 independent SWAG solutions, we use an ensemble of 200 networks.

E. Analysis of Prior Variance Effect

In this section we provide simple analytic results for the effect of the prior variance in ReLU networks. A related derivation is presented in the Appendix (Section A.8) of Garipov et al. (2018), about connecting paths from symmetries in parametrization.

We will consider a multilayer network f(x; w) of the form

f(x; {W_i, b_i}_{i=1}^n) = W_n( ... φ(W_2 φ(W_1 x + b_1) + b_2) ... ) + b_n,

where φ is the ReLU (or, in fact, any positively-homogeneous activation function), W_i are weight matrices, and b_i are bias vectors. In particular, f can be a regular CNN with ReLU activations up to the logits (with the softmax activation removed).

Now suppose we have a prior distribution of the form

W_i ∼ N(0, α_i² I),   b_i ∼ N(0, β_i² I),

where the identity matrices I are implicitly assumed to be of appropriate shapes, so each weight matrix and bias vector has a spherical Gaussian distribution. We can reparameterize this distribution as

W_i = α_i E_i,   E_i ∼ N(0, I),
b_i = β_i ε_i,   ε_i ∼ N(0, I).

We can then express the predictions of the network on the input x, for weights sampled from the prior, as the random variable

f(x; {α_i, β_i}_{i=1}^n) = α_n · E_n( ... φ(α_1 E_1 x + β_1 · ε_1) ... ) + β_n · ε_n.    (8)

Through Equation (8) we can observe some simple properties of the dependence between the prior scales α_i, β_i and the induced function-space prior.

Proposition 1. Suppose the network has no bias vectors, i.e. β_1 = ... = β_n = 0. Then the scales α_i of the prior distribution over the weights only affect the output scale of the network.

Proof. In the case when there are no bias vectors, Equation (8) simplifies to

f(x; {α_i, β_i = 0}_{i=1}^n) = α_n · E_n( ... φ(α_1 E_1 x) ... )
= α_n · ... · α_1 · E_n( ... φ(E_1 x) ... )
= α_n · ... · α_1 · f(x; {α_i = 1, β_i = 0}_{i=1}^n).

In the derivation above we used the positive homogeneity of ReLU: φ(αz) = αφ(z) for any positive α.

In other words, to sample from the distribution over functions corresponding to a prior with scales {α_i, β_i = 0}_{i=1}^n, we can sample from the spherical Gaussian prior (without bias terms) {α_i = 1, β_i = 0}_{i=1}^n, and then rescale the outputs of the network by the product of scales α_n · ... · α_2 · α_1.
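This scaling property is easy to verify numerically. The sketch below builds a small bias-free ReLU network, scales each layer's weights by a factor α_i, and checks that the output equals the unscaled network's output multiplied by the product of the α_i; it is a toy check of Proposition 1, not code from the paper's experiments.

```python
import numpy as np

def relu_net(x, weights):
    # Bias-free ReLU network: W_n phi(... phi(W_1 x) ...), as in Section E.
    h = x
    for W in weights[:-1]:
        h = np.maximum(W @ h, 0.0)
    return weights[-1] @ h

rng = np.random.default_rng(0)
sizes = [10, 32, 32, 3]                       # input dim 10, output dim 3
E = [rng.normal(size=(sizes[i + 1], sizes[i])) for i in range(len(sizes) - 1)]
alphas = [0.5, 2.0, 3.0]                      # per-layer prior scales alpha_i
x = rng.normal(size=10)

scaled = relu_net(x, [a * W for a, W in zip(alphas, E)])
unscaled = relu_net(x, E)
# By positive homogeneity, the output scale is exactly prod(alpha_i).
assert np.allclose(scaled, np.prod(alphas) * unscaled)
```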

We note that the result above is different from the results for sigmoid networks considered in MacKay (1995), where varying the prior on the weights leads to changing the length-scale of the sample functions. For ReLU networks without biases, increasing the prior variance only increases the output scale of the network and not the complexity of the samples. If we apply the softmax activation on the outputs of the last layer of such a network, we will observe increasingly confident predictions as we increase the prior variance. We observe this effect in Figure 9 and discuss it in Section 8.4.

In case bias vectors are present, we can obtain a similar result using a specific scaling of the prior variances with the layer, as in the following proposition.

Proposition 2. Suppose the prior scales depend on the layer of the network as follows, for some γ > 0:

α_i = γ,   β_i = γ^i,

for all layers i = 1, ..., n. Then γ only affects the scale of the predictive distribution at any input x:

f(x; {α_i = γ, β_i = γ^i}_{i=1}^n) = γ^n · f(x; {α_i = 1, β_i = 1}_{i=1}^n).

Proof. The proof is analogous to the proof of Proposition 1. We can use the positive homogeneity of ReLU activations to factor the prior scales out of the network:

f(x; {α_i = γ, β_i = γ^i}_{i=1}^n) = γ · E_n( ... φ(γ · E_1 x + γ · ε_1) ... ) + γ^n · ε_n
= γ^n · ( E_n( ... φ(E_1 x + ε_1) ... ) + ε_n )
= γ^n · f(x; {α_i = 1, β_i = 1}_{i=1}^n).

The analysis above can be applied to other simple scaling rules of the prior, e.g.

f(x; {α_i = γ · ᾱ_i, β_i = γ^i · β̄_i}_{i=1}^n) = γ^n · f(x; {α_i = ᾱ_i, β_i = β̄_i}_{i=1}^n),    (9)

for fixed base scales ᾱ_i, β̄_i, which can be shown completely analogously to Proposition 2.

More general types of scaling of the prior affect both the output scale of the network and also the relative effect of the prior and variance terms. For example, by Equation (9) we have

f(x; {α_i = γ, β_i = γ}_{i=1}^n) = f(x; {α_i = γ · 1, β_i = γ^i · γ^{1−i}}_{i=1}^n) = γ^n · f(x; {α_i = 1, β_i = γ^{1−i}}_{i=1}^n).

We note that the analysis does not cover residual connections and batch normalization, so it applies to LeNet-5, but cannot be directly applied to the PreResNet-20 networks used in many of our experiments.

F. Prior Correlation Structure under Perturbations

In this section we explore the prior correlations between the logits on different pairs of datapoints induced by a spherical Gaussian prior on the weights of a PreResNet-20. We sample 100 random images from CIFAR-10 (10 from each class) and apply 17 different perturbations introduced by Hendrycks & Dietterich (2019) at 5 different levels of intensity. We then compute correlations between the logits f(x, w) for the original image x and f(x̃, w) for the corrupted image x̃, as we sample the weights of the network from the prior w ∼ N(0, I).

In Figure 19 we show how the correlations decay with perturbation intensity. For reference, we also show how the correlations decay for a linear model and for an RBF kernel. For the RBF kernel, we set the lengthscale so that the average correlations on the uncorrupted datapoints match those of a PreResNet-20. Further experimental details can be found in Appendix D.3.

For all types of corruptions except saturate, snow, fog and brightness, the PreResNet logits decay slower compared to the RBF kernel and linear model. It appears that the prior samples are sensitive to corruptions that alter the brightness, or more generally the colours, in the image. For many types of corruptions (such as, e.g., Gaussian noise) the prior correlations for PreResNet are close to 1 for all levels of corruption.

Overall, these results indicate that the prior over functions induced by a vague prior over parameters w in combination with a PreResNet has useful equivariance properties: before seeing data, the model treats images of the same class as highly correlated, even after an image has undergone significant perturbations representative of perturbations we often see in the real world. These types of symmetries are a large part of what makes neural networks a powerful model class for high dimensional natural signals.


Figure 14. Noise corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Gaussian noise, (b) impulse noise, (c) shot noise.


Figure 15. Blur corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) defocus blur, (b) glass blur, (c) motion blur, (d) zoom blur, (e) Gaussian blur.


Figure 16. Digital corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) contrast, (b) saturate, (c) elastic transform, (d) pixelate, (e) JPEG compression.


Figure 17. Weather corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) snow, (b) fog, (c) brightness.

Figure 18. Negative log likelihood, accuracy and expected calibration error distributions on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the corruption intensity. Following Ovadia et al. (2019), we summarize the results for different types of corruption with a boxplot. For each method we use 10 independently trained models, and for MultiSWAG we sample 20 networks from each model. As in Figures 5, 11–14, there are substantial differences between these three methods, which are hard to see due to the vertical scale of this plot; MultiSWAG particularly outperforms Deep Ensembles and MultiSWA in terms of NLL and ECE for higher corruption intensities.


Figure 19. Prior correlations under corruption. Prior correlations between predictions (logits) for PreResNet-20, a linear model, and an RBF kernel on original and corrupted images, as a function of corruption intensity, for different types of corruptions. The lengthscale of the RBF kernel is calibrated to produce similar correlations to PreResNet on uncorrupted datapoints. We report the mean correlation values over 100 different images, and show the 1σ error bars with shaded regions. For all corruptions except snow, saturate, fog and brightness, the correlations decay slower for PreResNet compared to the baselines.

Page 8: Bayesian Deep Learning and a Probabilistic Perspective of ... · Andrew Gordon Wilson Pavel Izmailov New York University Abstract The key distinguishing property of a Bayesian ap-proach

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

minus10 minus5 0 5 10

minus05

00

05

10

15

20

25

30

35

(a) Exact

000

025

Wx(q

p)

minus10 minus5 0 5 10

0

1

2

3

(b) Deep Ensembles

000

025

Wx(q

p)

minus10 minus5 0 5 10

0

1

2

3

(c) Variational Inference

0 10 20 30 40 50

Samples

008

010

012

014

016

018

Avg

Wx(q

p)

Deep Ensembles

SVI

(d) Discrepancy with True BMA

Figure 4 Approximating the true predictive distribution (a) A close approximation of the true predictive distribution obtained bycombining 200 HMC chains (b) Deep ensembles predictive distribution using 50 independently trained networks (c) Predictivedistribution for factorized variational inference (VI) (d) Convergence of the predictive distributions for deep ensembles and variationalinference as a function of the number of samples we measure the average Wasserstein distance between the marginals in the range ofinput positions The multi-basin deep ensembles approach provides a more faithful approximation of the Bayesian predictive distributionthan the conventional single-basin VI approach which is overconfident between data clusters The top panels show the Wassersteindistance between the true predictive distribution and the deep ensemble and VI approximations as a function of inputs x

1 2 3 4 5 6 7 8 9 10 Models

1600

1800

2000

2200

2400

2600

2800

NLL

1 2 3 4 5 6 7 8 9 10 Models

3500

4000

4500

5000

5500

6000

1 2 3 4 5 6 7 8 9 10 Models

7000

8000

9000

10000

11000

12000

1 2 3 4 5 6 7 8 9 10 Models

12000

14000

16000

18000

20000

22000

1 2 3 4 5 6 7 8 9 10 Models

20000

25000

30000

35000

40000

Deep Ensembles MultiSWA MultiSWAG

Figure 5 Negative log likelihood for Deep Ensembles MultiSWAG and MultiSWA using a PreResNet-20 on CIFAR-10 with varyingintensity of the Gaussian blur corruption The image in each plot shows the intensity of corruption For all levels of intensity MultiSWAGand MultiSWA outperform Deep Ensembles for a small number of independent models For high levels of corruption MultiSWAGsignificantly outperforms other methods even for many independent models We present results for other corruptions in the Appendix

5 Neural Network PriorsA prior over parameters p(w) combines with the functionalform of a model f(xw) to induce a distribution over func-tions p(f(xw)) It is this distribution over functions thatcontrols the generalization properties of the model the priorover parameters in isolation has no meaning Neural net-works are imbued with structural properties that providegood inductive biases such as translation equivariance hi-erarchical representations and sparsity In the sense ofFigure 2 the prior will have large support due to the flexi-bility of neural networks but its inductive biases provide themost mass to datasets which are representative of problemsettings where neural networks are often applied In thissection we study the properties of the induced distributionover functions We directly continue the discussion of priorsin Section 6 with a focus on examining the noisy CIFAR re-sults in Zhang et al (2016) from a probabilistic perspectiveof generalization These sections are best read together

We also provide several additional experiments in the Ap-pendix In Section E we present analytic results on thedependence of the prior distribution in function space onthe variance of the prior over parameters considering also

layer-wise parameter priors with ReLU activations As partof a discussion on tempering in Section 84 we study theeffect of α in p(w) = N (0 α2I) on prior class probabilitiesfor individual sample functions p(f(xw)) the predictivedistribution and posterior samples as we observe varyingamounts of data In Section F we further study the correla-tion structure over images induced by neural network priorssubject to perturbations of the images In Section D3 weprovide additional experimental details

51 Deep Image Prior and Random Network Features

Two recent results provide strong evidence that vague Gaus-sian priors over parameters when combined with a neu-ral network architecture induce a distribution over func-tions with useful inductive biases In the deep image priorUlyanov et al (2018) show that randomly initialized convo-lutional neural networks without training provide excellentperformance for image denoising super-resolution and in-painting This result demonstrates the ability for a sam-ple function from a random prior over neural networksp(f(xw)) to capture low-level image statistics before anytraining Similarly Zhang et al (2016) shows that pre-

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

0 1 2 4 7

MNIST Class

0

1

2

4

7

MN

IST

Cla

ss

098

096

097

097

097

096

099

097

097

097

097

097

098

097

097

097

097

097

098

097

097

097

097

097

098

090

092

094

096

098

100

(a) α = 002

0 1 2 4 7

MNIST Class

0

1

2

4

7

MN

IST

Cla

ss

089

075

083

081

081

075

090

082

079

082

083

082

089

083

085

081

079

083

089

084

081

082

085

084

088

05

06

07

08

09

10

(b) α = 01

0 1 2 4 7

MNIST Class

0

1

2

4

7

MN

IST

Cla

ss

085

071

077

076

076

071

089

080

078

079

077

080

084

080

080

076

078

080

085

081

076

079

080

081

085

05

06

07

08

09

10

(c) α = 1

10minus2 10minus1 100 101

Prior std α

5 middot 102

103

5 middot 103

104

NL

L

(d)

Figure 6 Induced prior correlation function Average pairwise prior correlations for pairs of objects in classes 0 1 2 4 7 of MNISTinduced by LeNet-5 for p(f(xw)) when p(w) = N (0 α2I) Images in the same class have higher prior correlations than images fromdifferent classes suggesting that p(f(xw)) has desirable inductive biases The correlations slightly decrease with increases in α (d)NLL of an ensemble of 20 SWAG samples on MNIST as a function of α using a LeNet-5

00 02 04 06 08 104

3

2

1

0

1

2

3

4

(a) Prior Draws

00 02 04 06 08 10

(b) True Labels

00 02 04 06 08 10

(c) Corrupted Labels

0 5k 10k 12k 15k 20k 25k

Altered Labels

minus070

minus065

minus060

minus055

minus050

minus045

Mar

gina

lL

ikel

ihoo

dE

stim

ate

15

20

25

30

35

40

45

Tes

tE

rror

(d) Gaussian Process

0 5k 10k 20k 30k 40k 50k

Resampled Labels

minus5

minus4

minus3

minus2

minus1

Mar

gina

lL

ikel

ihoo

dE

stim

ate times107

20

40

60

80

Tes

tE

rror

(e) PreResNet-20

Figure 7 Rethinking generalization (a) Sample functions from a Gaussian process prior (b) GP fit (with 95 credible region) tostructured data generated as ygreen(x) = sin(x middot 2π)+ ε ε sim N (0 022) (c) GP fit with no training error after a significant addition ofcorrupted data in red drawn from Uniform[05 1] (d) Variational GP marginal likelihood with RBF kernel for two classes of CIFAR-10(e) Laplace BNN marginal likelihood for a PreResNet-20 on CIFAR-10 with different fractions of random labels The marginal likelihoodfor both the GP and BNN decreases as we increase the level of corruption in the labels suggesting reasonable inductive biases in the priorover functions Moreover both the GP and BNN have 100 training accuracy on images with fully corrupted labels

Similarly, Zhang et al. (2016) show that preprocessing CIFAR-10 with a randomly initialized, untrained convolutional neural network dramatically improves the test performance of a simple Gaussian kernel on pixels, from 54% accuracy to 71%. Adding ℓ2 regularization only improves the accuracy by an additional 2%. These results again indicate that broad Gaussian priors over parameters induce reasonable priors over networks, with a minor additional gain from decreasing the variance of the prior in parameter space, which corresponds to ℓ2 regularization.

5.2. Prior Class Correlations

In Figure 6 we study the prior correlations in the outputs of the LeNet-5 convolutional network (LeCun et al., 1998) on objects of different MNIST classes. We sample networks with weights p(w) = N(0, α²I), compute the values of the logits corresponding to the first class for all pairs of images, and compute correlations of these logits. For all levels of α, the correlations between objects corresponding to the same class are consistently higher than the correlations between objects of different classes, showing that the network induces a reasonable prior similarity metric over these images. Additionally, we observe that the prior correlations somewhat decrease as we increase α, showing that bounding the norm of the weights has some minor utility, in accordance with Section 5.1. Similarly, in panel (d) we see that the NLL significantly decreases as α increases in [0, 0.5], then slightly increases, but is relatively constant thereafter.

In the Appendix we further describe analytic results and illustrate the effect of α on sample functions.

5.3. Effect of Prior Variance on CIFAR-10

We further study the effect of the parameter prior standard deviation α, measuring the performance of approximate Bayesian inference for CIFAR-10 with a Preactivation ResNet-20 (He et al., 2016) and VGG-16 (Simonyan & Zisserman, 2014). For each of these architectures we run SWAG (Maddox et al., 2019) with fixed hyper-parameters and varying α. We report the results in Figure 9(f), (g). For both architectures the performance is near-optimal in the range α ∈ [0.01, 0.1]. Smaller α constrains the weights too much. Performance is reasonable and becomes mostly insensitive to α as it continues to increase, due to the inductive biases of the functional form of the neural network.


6. Rethinking Generalization

Zhang et al. (2016) demonstrated that deep neural networks have sufficient capacity to fit randomized labels on popular image classification tasks, and suggest this result requires re-thinking generalization to understand deep learning.

We argue, however, that this behaviour is not puzzling from a probabilistic perspective, is not unique to neural networks, and cannot be used as evidence against Bayesian neural networks (BNNs) with vague parameter priors. Fundamentally, the resolution is the view presented in the introduction: from a probabilistic perspective, generalization is at least a two-dimensional concept, related to support (flexibility), which should be as large as possible, supporting even noisy solutions, and inductive biases, which represent the relative prior probabilities of solutions.

Indeed, we demonstrate that the behaviour in Zhang et al. (2016) that was treated as mysterious and specific to neural networks can be exactly reproduced by Gaussian processes (GPs). Gaussian processes are an ideal choice for this experiment because they are popular Bayesian non-parametric models, and they assign a prior directly in function space. Moreover, GPs have remarkable flexibility, providing universal approximation with popular covariance functions such as the RBF kernel. Yet the functions that are a priori likely under a GP with an RBF kernel are relatively simple. We describe GPs further in the Appendix, and Rasmussen & Williams (2006) provide an extensive introduction.

We start with a simple example to illustrate the ability of a GP with an RBF kernel to easily fit a corrupted dataset, yet generalize well on a non-corrupted dataset, in Figure 7. In Fig. 7(a) we have sample functions from a GP prior over functions p(f(x)), showing that likely functions under the prior are smooth and well-behaved. In Fig. 7(b) we see the GP is able to reasonably fit data from a structured function. And in Fig. 7(c) the GP is also able to fit highly corrupted data with essentially no structure; although these data are not a likely draw from the prior, the GP has support for a wide range of solutions, including noise.
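The following is a minimal sketch of the kind of 1-D illustration in Figure 7(b), (c), using standard GP regression equations with an RBF kernel. The kernel hyperparameters, noise level, and data sizes below are arbitrary illustrative choices, not the settings used for the figure.

```python
import numpy as np

def rbf_kernel(x1, x2, lengthscale=0.2, signal_var=1.0):
    # k(x, x') = s^2 exp(-||x - x'||^2 / (2 lengthscale^2))
    sq_dists = (x1[:, None] - x2[None, :]) ** 2
    return signal_var * np.exp(-0.5 * sq_dists / lengthscale ** 2)

def gp_posterior_mean(x_train, y_train, x_test, noise_var=0.05 ** 2):
    # Standard GP regression equations (Rasmussen & Williams, 2006).
    K = rbf_kernel(x_train, x_train) + noise_var * np.eye(len(x_train))
    K_s = rbf_kernel(x_train, x_test)
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y_train))
    return K_s.T @ alpha

rng = np.random.default_rng(0)
x = rng.uniform(0, 1, 100)
y_clean = np.sin(2 * np.pi * x) + 0.02 * rng.standard_normal(100)  # structured targets
y_noise = rng.uniform(0.5, 1.0, 100)                               # "corrupted" targets
x_test = np.linspace(0, 1, 200)

mean_clean = gp_posterior_mean(x, y_clean, x_test)  # smooth fit that generalizes
mean_noise = gp_posterior_mean(x, y_noise, x_test)  # still interpolates the noise
```

The same prior supports both fits; only the structured data is a likely draw from it.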

We next show that GPs can replicate the generalization behaviour described in Zhang et al. (2016) (experimental details in the Appendix). When applied to CIFAR-10 images with random labels, Gaussian processes achieve 100% train accuracy and 10.4% test accuracy (at the level of random guessing). However, the same model trained on the true labels achieves a training accuracy of 72.8% and a test accuracy of 54.3%. Thus the generalization behaviour described in Zhang et al. (2016) is not unique to neural networks, and can be described by separately understanding the support and the inductive biases of a model.

Indeed, although Gaussian processes support CIFAR-10 images with random labels, these images are not likely under the GP prior. In Fig. 7(d) we compute the approximate GP marginal likelihood on a binary CIFAR-10 classification problem, with labels of varying levels of corruption. We see that as the noise in the data increases, the approximate marginal likelihood, and thus the prior support for these data, decreases. In Fig. 7(e) we see a similar trend for a Bayesian neural network: again, as the fraction of corrupted labels increases, the approximate marginal likelihood decreases, showing that the prior over functions given by the Bayesian neural network has less support for these noisy datasets. We provide further experimental details in the Appendix.

Dziugaite & Roy (2017) and Smith & Le (2018) provide complementary perspectives on Zhang et al. (2016) for MNIST: Dziugaite & Roy (2017) show non-vacuous PAC-Bayes bounds for noise-free binary MNIST but not noisy MNIST, and Smith & Le (2018) show that logistic regression can fit noisy labels on subsampled MNIST, interpreting the results from an Occam factor perspective.

7. Double Descent

Double descent (e.g., Belkin et al., 2019) describes generalization error that decreases, increases, and then again decreases with increases in model flexibility. The first decrease and then increase is referred to as the classical regime: models with increasing flexibility are increasingly able to capture structure and perform better, until they begin to overfit. The next regime is referred to as the modern interpolating regime. The existence of the interpolation regime has been presented as mysterious generalization behaviour in deep learning.

However, our perspective of generalization suggests that performance should monotonically improve as we increase model flexibility when we use Bayesian model averaging with a reasonable prior. Indeed, in the opening example of Figure 1, we would in principle want to use the most flexible possible model. Our results in Section 5 show that standard BNN priors induce structured and useful priors in function space, so we should not expect double descent in Bayesian deep learning models that perform reasonable marginalization.

To test this hypothesis, we evaluate MultiSWAG, SWAG, and standard SGD with ResNet-18 models of varying width, following Nakkiran et al. (2019), measuring both error and negative log likelihood (NLL). For the details, see Appendix D. We present the results in Figure 8 and Appendix Figure 13.

First, we observe that models trained with SGD indeed suffer from double descent, especially when the train labels are partially corrupted (see panels 8(c), 8(d)). We also see that SWAG, a unimodal posterior approximation, reduces the extent of double descent.


Figure 8. Bayesian model averaging alleviates double descent. (a) Test error and (b) NLL loss for ResNet-18 with varying width on CIFAR-100 for SGD, SWAG, and MultiSWAG. (c) Test error and (d) NLL loss when 20% of the labels are randomly reshuffled. SWAG reduces double descent, and MultiSWAG, which marginalizes over multiple modes, entirely alleviates double descent, both on the original labels and under label noise, in both accuracy and NLL. (e) Test error for MultiSWAG with a varying number of independent SWAG models: error monotonically decreases with an increased number of independent models, alleviating double descent. We also note that MultiSWAG provides significant improvements in accuracy and NLL over SGD and SWAG models. See Appendix Figure 13 for additional results.

Moreover, MultiSWAG, which performs a more exhaustive multimodal Bayesian model average, completely mitigates double descent: the performance of MultiSWAG solutions improves monotonically with the size of the model, showing no double descent even under significant label corruption, for both accuracy and NLL.
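As a rough illustration of the multimodal Bayesian model average that MultiSWAG performs, the sketch below averages predictions over samples from several independently trained SWAG approximations. The interface assumed for `swag_models` (with `.sample()` and `.predict_proba()`) is hypothetical and only for illustration; it is not the released implementation.

```python
import numpy as np

def multiswag_predict(swag_models, x, n_samples_per_model=20):
    """Approximate multimodal Bayesian model average (hypothetical interface).

    Each element of swag_models is assumed to expose .sample(), drawing a weight
    vector from that SWAG Gaussian, and .predict_proba(w, x), returning class
    probabilities for inputs x under weights w.
    """
    probs = []
    for swag in swag_models:                  # one SWAG posterior per basin
        for _ in range(n_samples_per_model):  # simple Monte Carlo within a basin
            w = swag.sample()
            probs.append(swag.predict_proba(w, x))
    # Uniformly weighted average over basins and within-basin samples.
    return np.mean(probs, axis=0)
```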

Our results highlight the importance of marginalization over multiple modes of the posterior: under 20% label corruption, SWAG clearly suffers from double descent, while MultiSWAG does not. In Figure 8(e) we show how double descent is alleviated as the number of independent modes marginalized in MultiSWAG increases.

These results also clearly show that MultiSWAG provides significant improvements in accuracy, in addition to NLL, over both SGD and SWAG models, an often overlooked advantage of Bayesian model averaging that we discuss in Section 3.1.

Recently, Nakkiran et al. (2020) show that carefully tuned ℓ2 regularization can help mitigate double descent. Alternatively, we show that Bayesian model averaging, particularly based on multimodal marginalization, can mitigate prominent double descent behaviour. The perspective in Sections 1 and 3 predicts this result: models with reasonable priors and effective Bayesian model averaging should monotonically improve with increases in flexibility.

8. Temperature Scaling

The standard Bayesian posterior distribution is given by

p(w|D) = (1/Z) p(D|w) p(w),    (2)

where p(D|w) is a likelihood, p(w) is a prior, and Z is a normalizing constant.

In Bayesian deep learning it is typical to consider the tempered posterior

p_T(w|D) = (1/Z(T)) p(D|w)^(1/T) p(w),    (3)

where T is a temperature parameter and Z(T) is the normalizing constant corresponding to temperature T. The temperature parameter controls how the prior and likelihood interact in the posterior:

• T < 1 corresponds to cold posteriors, where the posterior distribution is more concentrated around solutions with high likelihood.

• T = 1 corresponds to the standard Bayesian posterior distribution.

• T > 1 corresponds to warm posteriors, where the prior effect is stronger and the posterior collapse is slower.
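As a concrete illustration of Equation (3), a minimal sketch of the corresponding training objective for a classification BNN with prior p(w) = N(0, α²I) is shown below, up to the constant log Z(T). This is our own PyTorch-style sketch, not code from Wenzel et al. (2020); the minibatch rescaling by the training set size is one common convention.

```python
import torch
import torch.nn.functional as F

def negative_log_tempered_posterior(model, x, y, n_train, alpha=0.1, T=1.0):
    """-log p_T(w | D), up to the constant log Z(T), for classification with a
    Gaussian prior p(w) = N(0, alpha^2 I); see Equation (3). The minibatch
    likelihood is rescaled to the full dataset size n_train."""
    logits = model(x)
    # Negative log likelihood, rescaled from the minibatch to the full dataset.
    nll = F.cross_entropy(logits, y, reduction="mean") * n_train
    # log p(w) for an isotropic Gaussian prior (dropping w-independent constants).
    log_prior = -0.5 / alpha ** 2 * sum(p.pow(2).sum() for p in model.parameters())
    # T < 1 (cold posterior) upweights the likelihood relative to the prior.
    return nll / T - log_prior
```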

Tempering posteriors is a well-known practice in statistics, where it goes by the names Safe Bayes, generalized Bayesian inference, and fractional Bayesian inference (e.g., de Heide et al., 2019; Grunwald et al., 2017; Barron & Cover, 1991; Walker & Hjort, 2001; Zhang, 2006; Bissiri et al., 2016; Grunwald, 2012). Safe Bayes has been shown to be natural from a variety of perspectives, including from prequential learning theory and minimum description length frameworks (e.g., Grunwald et al., 2017).

Concurrently with our work, Wenzel et al. (2020) noticed that successful Bayesian deep learning methods tend to use cold posteriors. They provide an empirical study showing that Bayesian neural networks (BNNs) with cold posteriors outperform models with SGD-based maximum likelihood training, while BNNs with T = 1 can perform worse than the maximum likelihood solution. They claim that cold posteriors sharply deviate from the Bayesian paradigm, and consider possible reasons for why tempering is helpful in Bayesian deep learning.


Figure 9. Effects of the prior variance α². (a)–(e): Average class probabilities over all of CIFAR-10 for two sample prior functions p(f(x;w)) (two top rows) and the predictive distribution (average over 200 samples of weights, bottom row) for varying settings of α in p(w) = N(0, α²I): (a) α = 0.01, (b) α = 0.1, (c) α = 0.3, (d) α = 1, (e) α = √10. (f) NLL and (g) classification error of an ensemble of 20 SWAG samples on CIFAR-10 as a function of the prior std α, using a Preactivation ResNet-20 and VGG-16. The NLL is high for overly small α and near-optimal in the range [0.1, 0.3]. The NLL remains relatively low for vague priors corresponding to large values of α.

Figure 10. Adaptivity of posterior variance with data. We sample two functions f(x;w) from the distribution over functions induced by a distribution over weights, starting with the prior p(w) = N(0, 10·I), in combination with a PreResNet-20. We measure class probabilities averaged across the CIFAR-10 test set as we vary the amount of available training data: (a) prior (α = √10), (b) 10 datapoints, (c) 100 datapoints, (d) 1000 datapoints. Although the prior variance is too large, such that the softmax saturates for logits sampled from the prior, leading to one class being favoured, we see that the posterior quickly adapts to correct the scale of the logits in the presence of data. In Figure 9 we also show that the prior variance can easily be calibrated such that the prior predictive distribution, even before observing data, is high entropy.



In this section we provide an alternative view, and argue that tempering is not at odds with Bayesian principles. Moreover, for virtually any realistic model class and dataset it would be highly surprising if T = 1 were in fact the best setting of this hyperparameter. Indeed, as long as it is practically convenient, we would advocate tempering for essentially any model, especially parametric models that do not scale their capacity automatically with the amount of available information. Our position is that, at a high level, Bayesian methods are trying to combine honest beliefs with data to form a posterior. By reflecting the belief that the model is misspecified, the tempered posterior is often more of a true posterior than the posterior that results from ignoring our belief that the model is misspecified.

Finding that T < 1 helps for Bayesian neural networks is neither surprising nor discouraging. And the actual results of the experiments in Wenzel et al. (2020), which show great improvements over standard SGD training, are in fact very encouraging of deriving inspiration from Bayesian procedures in deep learning.

We consider: (1) tempering under misspecification (Section 8.1); (2) tempering in terms of overcounting data (Section 8.2); (3) how tempering compares to changing the observation model (Section 8.3); (4) the effect of the prior in relation to the experiments of Wenzel et al. (2020) (Section 8.4); (5) the effect of approximate inference, including how tempering can help in efficiently estimating parameters even for the untempered posterior (Section 8.5).

8.1. Tempering Helps with Misspecified Models

Many works explain how tempered posteriors help under model misspecification (e.g., de Heide et al., 2019; Grunwald et al., 2017; Barron & Cover, 1991; Walker & Hjort, 2001; Zhang, 2006; Bissiri et al., 2016; Grunwald, 2012). In fact, de Heide et al. (2019) and Grunwald et al. (2017) provide several simple examples where Bayesian inference fails to provide good convergence behaviour for untempered posteriors. While it is easier to show theoretical results for T > 1, several of these works also show that T < 1 can be preferred even in well-specified settings, and indeed recommend learning T from data, for example by cross-validation (e.g., Grunwald, 2012; de Heide et al., 2019).

Are we in a misspecified setting for Bayesian neural networks? Of course. And it would be irrational to proceed as if it were otherwise. Every model is misspecified. In the context of Bayesian neural networks specifically, the mass of solutions expressed by the prior outside of the datasets we typically consider is likely much larger than desired for most applications. We can calibrate for this discrepancy through tempering. The resulting tempered posterior will be more in line with our beliefs than pretending the model is not misspecified and finding the untempered posterior.

Non-parametric models such as Gaussian processes attempt to side-step model misspecification by growing the number of free parameters (information capacity) automatically with the amount of available data. In parametric models we take much more of a manual guess about the model capacity. In the case of deep neural networks, this choice is not even close to a best guess: it was once the case that architectural design was a large component of works involving neural networks, but now it is more standard practice to choose an off-the-shelf architecture, without much consideration of model capacity. We do not believe that knowingly using a misspecified model to find a posterior is more reasonable (or Bayesian) than honestly reflecting the belief that the model is misspecified and then using a tempered posterior. For parametric models such as neural networks, it is to be expected that the capacity is particularly misspecified.

8.2. Overcounting Data with Cold Posteriors

The criticism of cold posteriors raised by Wenzel et al. (2020) is largely based on the fact that decreasing temperature leads to overcounting data in the posterior distribution.

However, a similar argument can be made against marginal likelihood maximization (also known as empirical Bayes, or type-II maximum likelihood). Indeed, here the prior will depend on the same data as the likelihood, which can lead to miscalibrated predictive distributions (Darnieder, 2011).

Nonetheless, empirical Bayes has been embraced and widely adopted in Bayesian machine learning (e.g., Bishop, 2006; Rasmussen & Williams, 2006; MacKay, 2003; Minka, 2001b) as embodying several Bayesian principles. Empirical Bayes has been particularly embraced in seminal work on Bayesian neural networks (e.g., MacKay, 1992; 1995), where it has been proposed as a principled approach to learning hyperparameters, such as the scale of the variance for the prior over weights, automatically embodying Occam's razor. While there is in this case some deviation from the fully Bayesian paradigm, the procedure, which depends on marginalization, is nonetheless clearly inspired by Bayesian thinking; it is thus helpful to reflect this inspiration and provide understanding of how it works from a Bayesian perspective.

There is also work showing that the marginal likelihood can lead to miscalibrated Bayes factors under model misspecification. Attempts to calibrate these factors (Xu et al., 2019), as part of the Bayesian paradigm, are highly reminiscent of work on safe Bayes.


8.3. Tempered Posterior or Different Likelihood

The tempered posterior for one model is an untempered posterior using a different observation model. In other words, we can trivially change the likelihood function in our model so that the standard Bayesian posterior in the new model is equal to the posterior of temperature T in the original model. Indeed, consider the likelihood function

p_T(D|w) ∝ p(D|w)^(1/T),    (4)

where the posterior distribution for the model M_T with likelihood p_T is exactly the temperature T posterior for the model M with likelihood p.

The predictive distribution differs for the two models: even though the posteriors coincide, the likelihoods for a new datapoint y* are different:

∫ p(y*|w) p(w|D) dw ≠ ∫ p_T(y*|w) p(w|D) dw.    (5)

As an example, consider a regression model M with a Gaussian likelihood function y ∼ N(f, σ²), where f is the output of the network and σ² is the noise variance. The predictive distributions for the two models M and M_T for a new input x will have different variance, but the same mean: E_w[f + ε] = E_w[f]. Moreover, in this case the noise variance would typically be learned in either model, M_T or M.
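To make this concrete, a short derivation (our own working, under the Gaussian likelihood above) shows that tempering a single-point Gaussian likelihood is equivalent to inflating the noise variance by a factor of T:

```latex
p(y \mid w)^{1/T}
  \propto \exp\!\left(-\frac{\big(y - f(x; w)\big)^2}{2\,T\sigma^2}\right)
  \propto \mathcal{N}\!\big(y \mid f(x; w),\, T\sigma^2\big).
```

Since the dependence on w is unchanged up to a constant factor, the posteriors for M and M_T coincide, while the predictive distribution of M_T uses the observation noise Tσ² in place of σ² (narrower for T < 1), leaving the predictive mean unchanged.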

Section 4.1 of Grunwald et al. (2017) considers a related construction.

8.4. Effect of the Prior

While a somewhat misspecified prior will certainly interact with the utility of tempering, we do not believe the experiments in Wenzel et al. (2020) provide evidence that even the prior p(w) = N(0, I) is misspecified to any serious extent. For a relatively wide range of distributions over w, the functional form of the network f(x;w) can produce a generally reasonable distribution over functions p(f(x;w)). In Figure 10 we reproduce the findings in Wenzel et al. (2020) showing that sample functions p(f(x;w)) corresponding to the prior p(w) = N(0, 10·I) strongly favour a single class over the dataset. While this behaviour appears superficially dramatic, we note it is simply an artifact of having a miscalibrated signal variance. A miscalibrated signal variance interacts with a quickly saturating softmax link function to provide a seemingly dramatic preference for a given class. If we instead use p(w) = N(0, α²I) for quite a range of α, then sample functions provide reasonably high entropy across labels, as in Figure 9. The value of α can easily be determined through cross-validation, as in Figure 9, or specified as a standard value used for L2 regularization (α = 0.24 in this case).
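A minimal sketch of how such prior class probabilities can be estimated is given below: draw w ∼ N(0, α²I), load it into the network, and average the softmax outputs over a test set. Here `model` and `data_loader` are placeholders (e.g., a PreResNet-20 and the CIFAR-10 test loader); this is our own sketch, not the exact experimental code.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def prior_class_probabilities(model, data_loader, alpha=0.1, n_weight_samples=2):
    """Average softmax class probabilities over a dataset for sample functions
    drawn from p(f(x; w)) with w ~ N(0, alpha^2 I)."""
    model.eval()
    results = []
    for _ in range(n_weight_samples):
        for p in model.parameters():                 # draw w ~ N(0, alpha^2 I)
            p.copy_(alpha * torch.randn_like(p))
        probs = []
        for x, _ in data_loader:
            probs.append(F.softmax(model(x), dim=-1))
        results.append(torch.cat(probs).mean(dim=0))  # average over the dataset
    return torch.stack(results)                       # one row per prior sample
```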

However, even with the inappropriate prior scale, we see in the bottom row of panels (a)–(e) of Figure 9 that the unconditional predictive distribution is completely reasonable. Moreover, the prior variance represents a soft prior bias, and will quickly update in the presence of data. In Figure 10 we show the posterior samples after observing 10, 100, and 1000 data points.

Other aspects of the prior, outside of the prior signal variance, will have a much greater effect on the inductive biases of the model. For example, the induced covariance function cov(f(x_i; w), f(x_j; w)) reflects the induced similarity metric over data instances; through the covariance function we can answer, for instance, whether the model believes a priori that a translated image is similar to the original. Unlike the signal variance of the prior, the prior covariance function will continue to have a significant effect on posterior inference for even very large datasets, and strongly reflects the structural properties of the neural network. We explore these structures of the prior in Figure 12.

8.5. The Effect of Inexact Inference

We have to keep in mind what we ultimately use posterior samples to compute. Ultimately, we wish to estimate the predictive distribution given by the integral in Equation (1). With a finite number of samples, the tempered posterior could be used to provide a better approximation to the expectation of the predictive distribution associated with the untempered posterior.

Consider a simple example, where we wish to estimate the mean of a high-dimensional Gaussian distribution N(0, I). Suppose we use J independent samples. The mean of these samples is also Gaussian distributed, μ ∼ N(0, (1/J)·I). In Bayesian deep learning the dimension d is typically on the order of 10^7 and J would be on the order of 10. The norm of μ would be highly concentrated around √(10^7)/√10 = 1000. In this case, sampling from a tempered posterior with T < 1 would lead to a better approximation of the Bayesian model average associated with an untempered posterior.
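A quick numerical check of this back-of-the-envelope calculation (our own sketch):

```python
import numpy as np

d, J = 10 ** 7, 10
rng = np.random.default_rng(0)
# The mean of J independent N(0, I_d) samples is distributed as N(0, I_d / J);
# sample it directly rather than storing all J draws.
mu = rng.standard_normal(d) / np.sqrt(J)
print(np.linalg.norm(mu))  # concentrates near sqrt(d / J) = sqrt(10^6) = 1000
```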

Furthermore, no current sampling procedure will be providing samples that are close to independent samples from the true posterior of a Bayesian neural network. The posterior landscape is far too multimodal and complex for there to be any reasonable coverage. The approximations we have are practically useful and often preferable to conventional training, but we cannot realistically proceed with analysis assuming that we have obtained true samples from a posterior. While we would expect that some value of T ≠ 1 would be preferred for any finite dataset in practice, it is conceivable that some of the results in Wenzel et al. (2020) may be affected by the specifics of the approximate inference technique being used.


We should be wary not to view Bayesian model averaging purely through the prism of simple Monte Carlo, as advised in Section 3.2. Given a finite computational budget, our goal in effectively approximating a Bayesian model average is not equivalent to obtaining good samples from the posterior.

9. Discussion

"It is now common practice for Bayesians to fit models that have more parameters than the number of data points. Incorporate every imaginable possibility into the model space: for example, if it is conceivable that a very simple model might be able to explain the data, one should include simple models; if the noise might have a long-tailed distribution, one should include a hyperparameter which controls the heaviness of the tails of the distribution; if an input variable might be irrelevant to a regression, include it in the regression anyway." MacKay (1995)

We have presented a probabilistic perspective of generalization, which depends on the support and inductive biases of the model. The support should be as large as possible, but the inductive biases must be well-calibrated to a given problem class. We argue that Bayesian neural networks embody these properties, and, through the lens of probabilistic inference, explain generalization behaviour that has previously been viewed as mysterious. Moreover, we argue that Bayesian marginalization is particularly compelling for neural networks, show how deep ensembles provide a practical mechanism for marginalization, and propose a new approach that generalizes deep ensembles to marginalize within basins of attraction. We show that this multimodal approach to Bayesian model averaging, MultiSWAG, can entirely alleviate double descent, enabling monotonic performance improvements with increases in model flexibility, as well as significant improvements in generalization accuracy and log likelihood over SGD and single-basin marginalization.

There are certainly many challenges to estimating the integral for a Bayesian model average in modern deep learning, including a high-dimensional parameter space and a complex posterior landscape. But viewing the challenge indeed as an integration problem, rather than an attempt to obtain posterior samples for a simple Monte Carlo approximation, provides opportunities for future progress. Bayesian deep learning has been making fast practical advances, with approaches that now enable better accuracy and calibration over standard training with minimal overhead.

We finish with remarks about future developments for Bayesian neural network priors, and approaches to research in Bayesian deep learning.

9.1. The Future for BNN Priors

We provide some brief remarks about future developments for BNN priors. Here we have explored relatively simple parameter priors p(w) = N(0, α²I). While these priors are simple in parameter space, they interact with the neural network architecture to induce a sophisticated prior over functions p(f(x;w)), with many desirable properties, including a reasonable correlation structure over images. However, these parameter priors can certainly still be improved. As we have seen, even tuning the value of the signal variance α², an analogue of the L2 regularization often used in deep learning, can have a noticeable effect on the induced prior over functions, though this effect is quickly modulated by data. Layer-wise priors, such that parameters in each layer have a different signal variance, are intuitive: we would expect later layers to require precise determination, while parameters in earlier layers could reasonably take a range of values. But one has to be cautious; as we show in Appendix Section E, with ReLU activations different signal variances in different layers can be degenerate, combining together to affect only the output scale of the network.

A currently popular sentiment is that we should directly build function-space BNN priors, often taking inspiration from Gaussian processes. While we believe this is a promising direction, one should proceed with caution. If we contrive priors over parameters p(w) to induce distributions over functions p(f) that resemble familiar models, such as Gaussian processes with RBF kernels, we could be throwing the baby out with the bathwater. Neural networks are useful as their own model class precisely because they have different inductive biases from other models.

A similar concern applies to taking infinite width limits in Bayesian neural networks. In these cases we recover Gaussian processes with interpretable kernel functions; because these models are easier to use and analyze, and give rise to interpretable and well-motivated priors, it is tempting to treat them as drop-in replacements for the parametric analogues. However, the kernels for these models are fixed. In order for a model to do effective representation learning, we must learn a similarity metric for the data. Training a neural network in many ways is like learning a kernel, rather than using a fixed kernel. MacKay (1998) has also expressed concerns in treating these limits as replacements for neural networks, due to the loss of representation learning power.

Perhaps the distribution over functions induced by a network in combination with a generic distribution over parameters p(w) may be hard to interpret, but this distribution will contain the equivariance properties, representation learning abilities, and other biases that make neural networks a compelling model class in their own right.


9.2. "But is it really Bayesian?"

We finish with an editorial comment about approaches to research within Bayesian deep learning. There is sometimes a tendency to classify work as Bayesian or not Bayesian, with very stringent criteria for what qualifies as Bayesian. Moreover, the implication, and sometimes even explicit recommendation, is that if an approach is not unequivocally Bayesian in every respect, then we should not term it Bayesian, and we should instead attempt to understand the procedure through entirely different, non-Bayesian mechanisms. We believe this mentality encourages tribalism, which is not conducive to the best research or to creating the best performing methods. What matters is not a debate about semantics, but making rational modelling choices given a particular problem setting, and trying to understand these choices. Often these choices can largely be inspired by a Bayesian approach, in which case it is desirable to indicate this source of inspiration. And in the semantics debate, who would be the arbiter of what gets to be called Bayesian? Arguably it ought to be an evolving definition.

Broadly speaking, what makes Bayesian approaches distinctive is a posterior weighted marginalization over parameters. And at a high level, Bayesian methods are about combining our honest beliefs with data to form a posterior. In actuality, no fair-minded researcher entirely believes the prior over parameters, the functional form of the model (which is part of the prior over functions), or the likelihood. From this perspective, it is broadly compatible with a Bayesian philosophy to reflect misspecification in the modelling procedure itself, which is achieved through tempering. In this sense the tempered posterior is more reflective of a true posterior than the posterior that results from ignoring our belief that the model is misspecified.

Moreover, basic probability theory indicates that marginalization is desirable. While marginalization cannot in practice be achieved exactly, we can try to improve over conventional training, which, as we have discussed, can be viewed as approximate marginalization. Given computational constraints, effective marginalization is not equivalent to obtaining accurate samples from a posterior. As we have discussed, simple Monte Carlo is only one of many mechanisms for marginalization. Just as expectation propagation (Minka, 2001a) focuses its approximation on the factors in a posterior where it will most affect the end result, we should focus on representing the posterior where it will make the biggest difference to the model average. As we have shown, deep ensembles are a reasonable mechanism up to a point. After having trained many independent models, there are added benefits to marginalizing within basins, given the computational expense associated with retraining an additional model to find an additional basin of attraction.

We should also not hold Bayesian methods to a double standard. Indeed, it can be hard to interpret or understand the prior, the posterior, and whether the marginalization procedure is optimal. But it is also hard to interpret the choices behind the functional form of the model, or the rationale behind classical procedures, where we bet everything on a single global optimum, when we know there are many global optima, many of them will perform well but provide different solutions, and many others will not perform well. We should apply the same level of scrutiny to all modelling choices, consider the alternatives, and not be paralyzed if a procedure is not optimal in every respect.

Acknowledgements. AGW and PI are supported by an Amazon Research Award, Facebook Research, NSF I-DISRE 193471, NIH R01 DA048764-01A1, NSF IIS-1563887, and NSF IIS-1910266. We thank Greg Benton for helpful discussions.

References

Achille, A. and Soatto, S. Emergence of invariance and disentanglement in deep representations. The Journal of Machine Learning Research, 19(1):1947–1980, 2018.

Ashukha, A., Lyzhov, A., Molchanov, D., and Vetrov, D. Pitfalls of in-domain uncertainty estimation and ensembling in deep learning. arXiv preprint arXiv:2002.06470, 2020.

Barron, A. R. and Cover, T. M. Minimum complexity density estimation. IEEE Transactions on Information Theory, 37(4):1034–1054, 1991.

Beal, M. J. Variational algorithms for approximate Bayesian inference. PhD thesis, University of London, 2003.

Belkin, M., Hsu, D., Ma, S., and Mandal, S. Reconciling modern machine-learning practice and the classical bias–variance trade-off. Proceedings of the National Academy of Sciences, 116(32):15849–15854, 2019.

Bishop, C. M. Pattern Recognition and Machine Learning. Springer, 2006.

Bissiri, P. G., Holmes, C. C., and Walker, S. G. A general framework for updating belief distributions. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 78(5):1103–1130, 2016.

Blundell, C., Cornebise, J., Kavukcuoglu, K., and Wierstra, D. Weight uncertainty in neural networks. arXiv preprint arXiv:1505.05424, 2015.

Box, G. E. and Tiao, G. C. Bayesian Inference in Statistical Analysis. Addison-Wesley, Reading, MA, 1973.


Chaudhari, P., Choromanska, A., Soatto, S., LeCun, Y., Baldassi, C., Borgs, C., Chayes, J., Sagun, L., and Zecchina, R. Entropy-SGD: Biasing gradient descent into wide valleys. Journal of Statistical Mechanics: Theory and Experiment, 2019.

Cobb, A. D., Baydin, A. G., Markham, A., and Roberts, S. J. Introducing an explicit symplectic integration scheme for Riemannian manifold Hamiltonian Monte Carlo. arXiv preprint arXiv:1910.06243, 2019.

Darnieder, W. F. Bayesian Methods for Data-Dependent Priors. PhD thesis, The Ohio State University, 2011.

de Heide, R., Kirichenko, A., Mehta, N., and Grunwald, P. Safe-Bayesian generalized linear regression. arXiv preprint arXiv:1910.09227, 2019.

Dinh, L., Pascanu, R., Bengio, S., and Bengio, Y. Sharp minima can generalize for deep nets. arXiv preprint arXiv:1703.04933, 2017.

Draxler, F., Veschgini, K., Salmhofer, M., and Hamprecht, F. A. Essentially no barriers in neural network energy landscape. arXiv preprint arXiv:1803.00885, 2018.

Dziugaite, G. K. and Roy, D. M. Computing nonvacuous generalization bounds for deep (stochastic) neural networks with many more parameters than training data. arXiv preprint arXiv:1703.11008, 2017.

Fort, S. and Jastrzebski, S. Large scale structure of neural network loss landscapes. In Advances in Neural Information Processing Systems, pp. 6706–6714, 2019.

Fort, S., Hu, H., and Lakshminarayanan, B. Deep ensembles: A loss landscape perspective. arXiv preprint arXiv:1912.02757, 2019.

Gal, Y. and Ghahramani, Z. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. In International Conference on Machine Learning, pp. 1050–1059, 2016.

Gardner, J. R., Pleiss, G., Bindel, D., Weinberger, K. Q., and Wilson, A. G. GPyTorch: Blackbox matrix-matrix Gaussian process inference with GPU acceleration. In Neural Information Processing Systems, 2018.

Garipov, T., Izmailov, P., Podoprikhin, D., Vetrov, D. P., and Wilson, A. G. Loss surfaces, mode connectivity, and fast ensembling of DNNs. In Neural Information Processing Systems, 2018.

Grunwald, P. The safe Bayesian. In International Conference on Algorithmic Learning Theory, pp. 169–183. Springer, 2012.

Grunwald, P., Van Ommen, T., et al. Inconsistency of Bayesian inference for misspecified linear models, and a proposal for repairing it. Bayesian Analysis, 12(4):1069–1103, 2017.

Guedj, B. A primer on PAC-Bayesian learning. arXiv preprint arXiv:1901.05353, 2019.

Guo, C., Pleiss, G., Sun, Y., and Weinberger, K. Q. On calibration of modern neural networks. In Proceedings of the 34th International Conference on Machine Learning, Volume 70, pp. 1321–1330. JMLR.org, 2017.

Gustafsson, F. K., Danelljan, M., and Schon, T. B. Evaluating scalable Bayesian deep learning methods for robust computer vision. arXiv preprint arXiv:1906.01620, 2019.

Hafner, D., Tran, D., Irpan, A., Lillicrap, T., and Davidson, J. Reliable uncertainty estimates in deep neural networks using noise contrastive priors. arXiv preprint arXiv:1807.09289, 2018.

He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778, 2016.

Hendrycks, D. and Dietterich, T. Benchmarking neural network robustness to common corruptions and perturbations. arXiv preprint arXiv:1903.12261, 2019.

Hinton, G. E. and Van Camp, D. Keeping the neural networks simple by minimizing the description length of the weights. In Proceedings of the Sixth Annual Conference on Computational Learning Theory, pp. 5–13, 1993.

Hochreiter, S. and Schmidhuber, J. Flat minima. Neural Computation, 9(1):1–42, 1997.

Huang, W. R., Emam, Z., Goldblum, M., Fowl, L., Terry, J. K., Huang, F., and Goldstein, T. Understanding generalization through visualizations. arXiv preprint arXiv:1906.03291, 2019.

Izmailov, P., Podoprikhin, D., Garipov, T., Vetrov, D., and Wilson, A. G. Averaging weights leads to wider optima and better generalization. In Uncertainty in Artificial Intelligence (UAI), 2018.

Izmailov, P., Maddox, W. J., Kirichenko, P., Garipov, T., Vetrov, D., and Wilson, A. G. Subspace inference for Bayesian deep learning. Uncertainty in Artificial Intelligence, 2019.

Jacot, A., Gabriel, F., and Hongler, C. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in Neural Information Processing Systems, pp. 8571–8580, 2018.


Jiang, Y., Neyshabur, B., Mobahi, H., Krishnan, D., and Bengio, S. Fantastic generalization measures and where to find them. arXiv preprint arXiv:1912.02178, 2019.

Kass, R. E. and Raftery, A. E. Bayes factors. Journal of the American Statistical Association, 90(430):773–795, 1995.

Kendall, A. and Gal, Y. What uncertainties do we need in Bayesian deep learning for computer vision? In Advances in Neural Information Processing Systems, pp. 5574–5584, 2017.

Keskar, N. S., Mudigere, D., Nocedal, J., Smelyanskiy, M., and Tang, P. T. P. On large-batch training for deep learning: Generalization gap and sharp minima. arXiv preprint arXiv:1609.04836, 2016.

Khan, M. E., Nielsen, D., Tangkaratt, V., Lin, W., Gal, Y., and Srivastava, A. Fast and scalable Bayesian deep learning by weight-perturbation in Adam. arXiv preprint arXiv:1806.04854, 2018.

Krizhevsky, A., Nair, V., and Hinton, G. The CIFAR-10 dataset, 2014. http://www.cs.toronto.edu/~kriz/cifar.html

Lakshminarayanan, B., Pritzel, A., and Blundell, C. Simple and scalable predictive uncertainty estimation using deep ensembles. In Advances in Neural Information Processing Systems, pp. 6402–6413, 2017.

Langford, J. and Caruana, R. (Not) bounding the true error. In Advances in Neural Information Processing Systems, pp. 809–816, 2002.

LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.

Louizos, C., Shi, X., Schutte, K., and Welling, M. The functional neural process. In Advances in Neural Information Processing Systems, 2019.

MacKay, D. J. Bayesian Methods for Adaptive Models. PhD thesis, California Institute of Technology, 1992.

MacKay, D. J. Probable networks and plausible predictions: a review of practical Bayesian methods for supervised neural networks. Network: Computation in Neural Systems, 6(3):469–505, 1995.

MacKay, D. J. Introduction to Gaussian processes. In Bishop, C. M. (ed.), Neural Networks and Machine Learning, chapter 11, pp. 133–165. Springer-Verlag, 1998.

MacKay, D. J. Information Theory, Inference and Learning Algorithms. Cambridge University Press, 2003.

Maddox, W. J., Izmailov, P., Garipov, T., Vetrov, D. P., and Wilson, A. G. A simple baseline for Bayesian uncertainty in deep learning. In Advances in Neural Information Processing Systems, 2019.

Masegosa, A. R. Learning under model misspecification: Applications to variational and ensemble methods, 2019.

McAllester, D. A. PAC-Bayesian model averaging. In Proceedings of the Twelfth Annual Conference on Computational Learning Theory, pp. 164–170, 1999.

Minka, T. Expectation propagation for approximate Bayesian inference. In Uncertainty in Artificial Intelligence, volume 17, pp. 362–369, 2001a.

Minka, T. P. Bayesian model averaging is not model combination, 2000.

Minka, T. P. Automatic choice of dimensionality for PCA. In Advances in Neural Information Processing Systems, pp. 598–604, 2001b.

Mohri, M. and Rostamizadeh, A. Rademacher complexity bounds for non-i.i.d. processes. In Advances in Neural Information Processing Systems, pp. 1097–1104, 2009.

Nakkiran, P., Kaplun, G., Bansal, Y., Yang, T., Barak, B., and Sutskever, I. Deep double descent: Where bigger models and more data hurt. arXiv preprint arXiv:1912.02292, 2019.

Nakkiran, P., Venkat, P., Kakade, S., and Ma, T. Optimal regularization can mitigate double descent. arXiv preprint arXiv:2003.01897, 2020.

Nalisnick, E. On Priors for Bayesian Neural Networks. PhD thesis, University of California, Irvine, 2018.

Neal, R. Bayesian Learning for Neural Networks. Springer Verlag, 1996. ISBN 0387947248.

Neyshabur, B., Bhojanapalli, S., McAllester, D., and Srebro, N. Exploring generalization in deep learning. In Advances in Neural Information Processing Systems, pp. 5947–5956, 2017.

Neyshabur, B., Bhojanapalli, S., and Srebro, N. A PAC-Bayesian approach to spectrally-normalized margin bounds for neural networks. In International Conference on Learning Representations, 2018. URL https://openreview.net/forum?id=Skz_WfbCZ.

Opper, M., Kinzel, W., Kleinz, J., and Nehl, R. On the ability of the optimal perceptron to generalise. Journal of Physics A: Mathematical and General, 23(11):L581, 1990.


Ovadia, Y., Fertig, E., Ren, J., Nado, Z., Sculley, D., Nowozin, S., Dillon, J. V., Lakshminarayanan, B., and Snoek, J. Can you trust your model's uncertainty? Evaluating predictive uncertainty under dataset shift. arXiv preprint arXiv:1906.02530, 2019.

Pearce, T., Zaki, M., Brintrup, A., Anastassacos, N., and Neely, A. Uncertainty in neural networks: Bayesian ensembling. arXiv preprint arXiv:1810.05546, 2018.

Rasmussen, C. E. and Ghahramani, Z. Occam's razor. In Neural Information Processing Systems (NIPS), 2001.

Rasmussen, C. E. and Williams, C. K. I. Gaussian Processes for Machine Learning. The MIT Press, 2006.

Ritter, H., Botev, A., and Barber, D. A scalable Laplace approximation for neural networks. In International Conference on Learning Representations (ICLR), 2018.

Rue, H., Martino, S., and Chopin, N. Approximate Bayesian inference for latent Gaussian models by using integrated nested Laplace approximations. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 71(2):319–392, 2009.

Simonyan, K. and Zisserman, A. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.

Smith, S. L. and Le, Q. V. A Bayesian perspective on generalization and stochastic gradient descent. In International Conference on Learning Representations, 2018.

Sun, S., Zhang, G., Shi, J., and Grosse, R. Functional variational Bayesian neural networks. arXiv preprint arXiv:1903.05779, 2019.

Ulyanov, D., Vedaldi, A., and Lempitsky, V. Deep image prior. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 9446–9454, 2018.

Vapnik, V. N. Adaptive and learning systems for signal processing, communications, and control. Statistical Learning Theory, 1998.

Virtanen, P., Gommers, R., Oliphant, T. E., Haberland, M., Reddy, T., Cournapeau, D., Burovski, E., Peterson, P., Weckesser, W., Bright, J., et al. SciPy 1.0: fundamental algorithms for scientific computing in Python. Nature Methods, pp. 1–12, 2020.

Walker, S. and Hjort, N. L. On Bayesian consistency. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 63(4):811–821, 2001.

Welling, M. and Teh, Y. W. Bayesian learning via stochastic gradient Langevin dynamics. In Proceedings of the 28th International Conference on Machine Learning (ICML-11), pp. 681–688, 2011.

Wenzel, F., Roth, K., Veeling, B. S., Swiatkowski, J., Tran, L., Mandt, S., Snoek, J., Salimans, T., Jenatton, R., and Nowozin, S. How good is the Bayes posterior in deep neural networks really? arXiv preprint arXiv:2002.02405, 2020.

Wilson, A. G. The case for Bayesian deep learning. arXiv preprint arXiv:2001.10995, 2020.

Xu, X., Lu, P., MacEachern, S., and Xu, R. Calibrated Bayes factors for model comparison. Journal of Statistical Computation and Simulation, 89(4):591–614, 2019.

Yang, W., Lorch, L., Graule, M. A., Srinivasan, S., Suresh, A., Yao, J., Pradier, M. F., and Doshi-Velez, F. Output-constrained Bayesian neural networks. arXiv preprint arXiv:1905.06287, 2019.

Zhang, C., Bengio, S., Hardt, M., Recht, B., and Vinyals, O. Understanding deep learning requires rethinking generalization. arXiv preprint arXiv:1611.03530, 2016.

Zhang, T. Information-theoretic upper and lower bounds for statistical estimation. IEEE Transactions on Information Theory, 52(4):1307–1321, 2006.


Appendix Outline

This appendix is organized as follows. In Section A we visualize predictive functions corresponding to weight samples within high posterior density valleys on a regression problem. In Section B we provide background material on Gaussian processes. In Section C we present further results comparing MultiSWAG and MultiSWA to Deep Ensembles under data distribution shift on CIFAR-10. In Section D we provide the details of all experiments presented in the paper. In Section E we present analytic results on the dependence of the prior distribution in function space on the variance of the prior over parameters. In Section F we study the prior correlations between BNN logits on perturbed images.

A. Loss Valleys

We demonstrate that different points along the valleys of high posterior density (low loss) connecting pairs of independently trained optima (Garipov et al., 2018; Draxler et al., 2018; Fort & Jastrzebski, 2019) correspond to different predictive functions. We use the regression example from Izmailov et al. (2019) and show the results in Figure 11.

B. Gaussian processes

With a Bayesian neural network, a distribution over parameters p(w) induces a distribution over functions p(f(x;w)) when combined with the functional form of the network. Gaussian processes (GPs) are often used to instead directly specify a distribution over functions.

A Gaussian process is a distribution over functions f(x) ∼ GP(m, k) such that any collection of function values, queried at any finite set of inputs x_1, ..., x_n, has a joint Gaussian distribution,

f(x_1), ..., f(x_n) ∼ N(μ, K).    (6)

The mean vector μ_i = E[f(x_i)] = m(x_i) and covariance matrix K_ij = cov(f(x_i), f(x_j)) = k(x_i, x_j) are determined by the mean function m and covariance function (or kernel) k of the Gaussian process.

The popular RBF kernel has the form

k(x_i, x_j) = exp( −‖x_i − x_j‖² / (2ℓ²) ).    (7)

The length-scale hyperparameter ℓ controls the extent of correlations between function values. If ℓ is large, sample functions from a GP prior are simple and slowly varying with inputs x.
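A minimal sketch of drawing sample functions from a GP prior with the RBF kernel of Equation (7), as in Figure 7(a); the length-scale and input grid are arbitrary illustrative choices:

```python
import numpy as np

def rbf(x1, x2, lengthscale=1.0):
    # Equation (7): k(x, x') = exp(-||x - x'||^2 / (2 lengthscale^2))
    return np.exp(-0.5 * (x1[:, None] - x2[None, :]) ** 2 / lengthscale ** 2)

x = np.linspace(0, 1, 200)
K = rbf(x, x) + 1e-8 * np.eye(len(x))   # small jitter for numerical stability
L = np.linalg.cholesky(K)
# Each column is one prior sample function evaluated on the grid x.
samples = L @ np.random.default_rng(0).standard_normal((len(x), 5))
```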

Gaussian processes with RBF kernels (as well as many other standard kernels) assign positive density to any set of observations. Moreover, these models are universal approximators (Rasmussen & Williams, 2006): as the number of observations increases, they are able to approximate any function to arbitrary precision.

Work on Gaussian processes in machine learning was triggered by the observation that Bayesian neural networks become Gaussian processes with particular kernel functions as the number of hidden units approaches infinity (Neal, 1996). This result resembles recent work on the neural tangent kernel (e.g., Jacot et al., 2018).

C. Deep Ensembles and MultiSWAG Under Distribution Shift

In Figures 14, 15, 16, and 17 we show the negative log-likelihood for Deep Ensembles, MultiSWA, and MultiSWAG using PreResNet-20 on CIFAR-10 with various corruptions, as a function of the number of independently trained models (SGD solutions, SWA solutions, or SWAG models, respectively). For MultiSWAG, we generate 20 samples from each independent SWAG model. Typically MultiSWA and MultiSWAG significantly outperform Deep Ensembles when a small number of independent models is used, or when the level of corruption is high.

In Figure 18, following Ovadia et al. (2019), we show the distribution of negative log likelihood, accuracy, and expected calibration error as we vary the type of corruption. We use a fixed training time budget: 10 independently trained models for every method. For MultiSWAG we ensemble 20 samples from each of the 10 SWAG approximations. MultiSWAG particularly achieves better NLL than the other two methods, and MultiSWA outperforms Deep Ensembles; the difference is especially pronounced for higher levels of corruption. In terms of ECE, MultiSWAG again outperforms the other two methods for higher corruption intensities.

We note that Ovadia et al. (2019) found Deep Ensembles to be a very strong baseline for prediction quality and calibration under distribution shift. For this reason, we focus on Deep Ensembles in our comparisons.

D. Details of Experiments

In this section we provide additional details of the experiments presented in the paper.

D.1. Approximating the True Predictive Distribution

For the results presented in Figure 4, we used a network with 3 hidden layers of size 10 each. The network takes two inputs, x and x². We pass both x and x² as input to ensure that the network can represent a broader class of functions. The network outputs a single number y = f(x).


Figure 11. Diversity of high performing functions. Bottom: a contour plot of the posterior log-density in the subspace containing a pair of independently trained modes (as with deep ensembles) and a path of high posterior density connecting these modes. In each panel, the purple point represents a sample from the posterior in the parameter subspace. Top: the predictive distribution constructed from samples in the subspace. The shaded blue area shows the 3σ-region of the predictive distribution at each of the input locations, and the blue line shows the mean of the predictive distribution. In each panel, the purple line shows the predictive function corresponding to the sample shown in the corresponding bottom row panel. For the details of the experimental setup, see Section 5.1 of Izmailov et al. (2019).


To generate data for the plots, we used a randomly-initialized neural network of the same architecture described above. We sampled the weights from an isotropic Gaussian with variance 0.1², and added isotropic Gaussian noise with variance 0.1² to the outputs:

y = f(x; w) + ε(x),

with w ∼ N(0, 0.1²·I), ε(x) ∼ N(0, 0.1²·I). The training set consists of 120 points, shown in Figure 4.

For estimating the ground truth, we ran 200 chains of Hamiltonian Monte Carlo (HMC) using the hamiltorch package (Cobb et al., 2019). We initialized each chain with a network pre-trained with SGD for 3000 steps, then ran Hamiltonian Monte Carlo (HMC) for 2000 steps, producing 200 samples.

For Deep Ensembles, we independently trained 50 networks with SGD for 20000 steps each. We used the negative posterior log-density as the training loss. For SVI, we used a fully-factorized Gaussian approximation initialized at an SGD solution trained for 20000 steps. For all inference methods we set the prior variance to 10² and the noise variance to 0.02².

Discrepancy with true BMA. For the results presented in panel (d) of Figure 4, we computed the Wasserstein distance between the predictive distribution approximated with HMC and the predictive distributions for Deep Ensembles and SVI. We used the one-dimensional Wasserstein distance function¹ from the scipy package (Virtanen et al., 2020). We computed the Wasserstein distance between marginal distributions at each input location, and averaged the results over the input locations. In the top sub-panels of panels (b), (c) of Figure 4 we additionally visualize the marginal Wasserstein distance between the HMC predictive distribution and the Deep Ensembles and SVI predictive distributions, respectively, for each input location.

¹https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wasserstein_distance.html
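A minimal sketch of this averaged per-location Wasserstein distance, assuming predictive samples are stored as arrays of shape (number of samples, number of input locations); the array layout is our assumption, not the exact experimental code:

```python
import numpy as np
from scipy.stats import wasserstein_distance

def mean_marginal_wasserstein(pred_samples_a, pred_samples_b):
    """Average 1-D Wasserstein distance between two predictive distributions,
    each represented by samples of shape (n_samples, n_input_locations)."""
    distances = [
        wasserstein_distance(pred_samples_a[:, i], pred_samples_b[:, i])
        for i in range(pred_samples_a.shape[1])
    ]
    return float(np.mean(distances))
```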

D.2. Deep Ensembles and MultiSWAG

We evaluate Deep Ensembles, MultiSWA, and MultiSWAG under distribution shift in Section 4. Following Ovadia et al. (2019), we use a PreResNet-20 network and the CIFAR-10 dataset with different types of corruptions introduced in Hendrycks & Dietterich (2019). For training individual SGD, SWA, and SWAG models we use the hyper-parameters used for PreResNet-164 in Maddox et al. (2019). For each SWAG model we sample 20 networks and ensemble them, so Deep Ensembles, MultiSWA, and MultiSWAG are all evaluated under the same training budget; Deep Ensembles and MultiSWA also use the same test-time budget.

For producing the corrupted data we used the code² released by Hendrycks & Dietterich (2019). We had issues producing the data for the frost corruption type, so we omit it in our evaluation, and include Gaussian blur, which was not included in the evaluation of Hendrycks & Dietterich (2019).

²https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_cifar_c.py


Figure 12. (a)–(c): Average pairwise prior correlations for pairs of objects in classes {0, 1, 2, 4, 7} of MNIST induced by LeNet-5 for p(f(x;w)) when p(w) = N(0, α²I), for (a) α = 0.02, (b) α = 0.1, and (c) α = 1. Images in the same class have higher prior correlations than images from different classes, suggesting that p(f(x;w)) has desirable inductive biases. The correlations slightly decrease with increases in α. Panels (e)–(g) show sample functions from LeNet-5 along the direction connecting a pair of MNIST images of 0 and 1 digits, for (e) α = 0.02, (f) α = 0.1, and (g) α = 1; the complexity of the samples increases with α. (d) NLL and (h) classification error of an ensemble of 20 SWAG samples on MNIST as a function of α, using LeNet-5. The NLL is high for overly small α and near-optimal for larger values, with an optimum near α = 0.3.

D.3. Neural Network Priors

In the main text, we considered different properties of the prior distribution over functions induced by a spherical Gaussian distribution over the weights, with different variance scales.

Prior correlation diagrams. In panels (a)–(c) of Figure 12, we show pairwise correlations of the logits for different pairs of datapoints. To make these plots, we produce S = 100 samples of the weights wᵢ of a LeNet-5 from the prior distribution N(0, α²I), and compute the logits corresponding to class 0 for each data point and each weight sample. We then compute the correlations for each pair x, x′ of data points as follows:

\[
\mathrm{corr}_{\mathrm{logit}}(x, x') = \frac{\sum_{i=1}^{S} \big(f(x; w_i) - \bar{f}(x)\big)\big(f(x'; w_i) - \bar{f}(x')\big)}{\sqrt{\sum_{i=1}^{S} \big(f(x; w_i) - \bar{f}(x)\big)^2 \cdot \sum_{i=1}^{S} \big(f(x'; w_i) - \bar{f}(x')\big)^2}},
\]
where f(x; w) is the logit corresponding to class 0 of the network with weights w on the input x, and f̄(x) is the mean value of the logit, f̄(x) = (1/S) Σᵢ f(x; wᵢ). For evaluation, we use 200 random datapoints per class for classes 0, 1, 2, 4, 7 (a total of 1000 datapoints). We use this set of classes to ensure that the structure is clearly visible in the figure. We combine the correlations into a diagram, additionally showing the average correlation for each pair of classes. We repeat the experiment for different values of α ∈ {0.02, 0.1, 1}. For a discussion of the results, see Section 5.2.
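A minimal sketch of the correlation computation above, assuming the class-0 logits for the S prior weight samples have already been collected into arrays (names are ours, for illustration):

```python
import numpy as np

def prior_logit_correlation(logits_x, logits_xp):
    # logits_x, logits_xp: arrays of shape [S] holding the class-0 logits
    # f(x; w_i), f(x'; w_i) for S prior weight samples w_1, ..., w_S.
    fx = logits_x - logits_x.mean()
    fxp = logits_xp - logits_xp.mean()
    return (fx * fxp).sum() / np.sqrt((fx ** 2).sum() * (fxp ** 2).sum())
```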

Sample functions. In panels (e)–(g) of Figure 12, we visualize the functions sampled from the LeNet-5 network along the direction connecting a pair of MNIST images. In particular, we take a pair of images x₀ and x₁ of digits 0 and 1 respectively, and construct the path x(t) = t · x₀ + (1 − t) · x₁. We then study the samples of the logits z(t) = f(x(t) · ‖x₀‖/‖x(t)‖; w) along the path; here we adjusted the norm of the images along the path to be constant, as the values of the logits are sensitive to the norm of the inputs. The complexity of the samples increases as we increase the variance of the prior distribution over the weights. This increased complexity of sample functions explains why we might expect the prior correlations for pairs of images to be lower when we increase the variance of the prior distribution over the weights.
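A small sketch of the norm-adjusted interpolation path described above (illustrative only; x0 and x1 stand for flattened images):

```python
import numpy as np

def norm_adjusted_path(x0, x1, num_points=20):
    # Convex combinations x(t) = t*x0 + (1-t)*x1, rescaled so every point on the
    # path has the same norm as x0 (the logits are sensitive to the input norm).
    ts = np.linspace(0.0, 1.0, num_points)
    path = [t * x0 + (1.0 - t) * x1 for t in ts]
    return [p * (np.linalg.norm(x0) / np.linalg.norm(p)) for p in path]
```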

Performance dependence on prior variance. In panels (d), (h) of Figure 12, we show the test negative log-likelihood and accuracy of SWAG applied to LeNet-5 on MNIST. We train the model for 50 epochs, constructing the rank-20 SWAG approximation from the last 25 epochs. We use an initial learning rate of 0.05 and a SWAG learning rate of 0.01, with the learning rate schedule of Maddox et al. (2019). We use the posterior log-density as the objective and vary the prior variance α². In panels (f), (g) of Figure 9, we perform an analogous experiment using a PreResNet-20 and a VGG-16 on CIFAR-10, using the hyper-parameters reported in Maddox et al. (2019) (for PreResNet-20 we use the hyper-parameters used with PreResNet-164 in Maddox et al. (2019)). Both on MNIST and CIFAR-10, we observe that the performance is poor for overly small values of α, close to optimal for intermediate values, and still reasonable for larger values of α. For further discussion of the results, see Section 5.3.

Predictions from prior samples. Following Wenzel et al. (2020), we study the predictive distributions of prior samples using PreResNet-20 on CIFAR-10. In Figure 9, we show the sample predictive functions averaged over datapoints for different scales α of the prior distribution. We also show the predictive distribution for each α, which is the average of the sample predictive distributions over 200 samples of the weights. In Figure 10, we show how the predictive distribution changes as we vary the number of observed data, for prior scale α = √10. We see that the marginal predictive distribution for all considered values of α is reasonable: roughly uniform across classes when averaged across the dataset. For the latter experiment, we used stochastic gradient Langevin dynamics (SGLD) (Welling & Teh, 2011) with a cosine learning rate schedule. For each sample we restart SGLD, and we only use the sample obtained at the last iteration. We discuss the results in Section 8.4.

Prior correlations with corrupted images. In Section F and Figure 19, we study the decay of the prior correlations between logits on an original image and a perturbed image as we increase the intensity of perturbations. For the BNN, we use the PreResNet-20 architecture with the standard Gaussian prior N(0, I). For the linear model, the correlations are not affected by the prior variance α²:

\[
\mathrm{cov}(w^\top x, w^\top y) = \mathbb{E}\big(w^\top x \cdot w^\top y\big) = \mathbb{E}\, x^\top w\, w^\top y = x^\top \mathbb{E}\big[w w^\top\big]\, y = \alpha^2 x^\top y,
\]
and hence
\[
\mathrm{corr}(w^\top x, w^\top y) = \frac{\mathrm{cov}(w^\top x, w^\top y)}{\sqrt{\mathrm{cov}(w^\top x, w^\top x) \cdot \mathrm{cov}(w^\top y, w^\top y)}} = \frac{x^\top y}{\|x\| \cdot \|y\|}.
\]

We use the N(0, I) prior for the weights of the linear model. Finally, we also evaluate the correlations associated with an RBF kernel (see Equation (7)). To set the lengthscale ℓ of the kernel, we evaluate the pairwise correlations for the PreResNet-20 and the RBF kernel on the 100 uncorrupted CIFAR-10 images that were used for the experiment, and ensure that the average correlations match. The resulting value of ℓ is 10000, and the average correlation for the RBF kernel and PreResNet was ≈ 0.9; for the linear model, the average correlation was ≈ 0.82. For the perturbations, we used the same set of corruptions introduced in Hendrycks & Dietterich (2019) as in the experiments in Section 4, with the addition of a random translation: for a random translation of intensity i, we pad the image with 2·i zeros on each side and crop the image randomly to 32×32.

D.4. Rethinking Generalization

In Section 6, we experiment with Bayesian neural networks and Gaussian processes on CIFAR-10 with noisy labels, inspired by the results in Zhang et al. (2016) that suggest we need to re-think generalization to understand deep learning.

Following Zhang et al. (2016), we train PreResNet-20 on CIFAR-10 with different fractions of random labels. To ensure that the network fits the train data, we turn off weight decay and data augmentation, and use a lower initial learning rate of 0.01. Otherwise, we follow the hyper-parameters that were used with PreResNet-164 in Maddox et al. (2019). We use a diagonal Laplace approximation to compute an estimate of the marginal likelihood for each level of label corruption. Following Ritter et al. (2018), we use the diagonal of the Fisher information matrix rather than the Hessian.
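As a rough sketch of the quantity being estimated (our own illustration, not the exact code used for the paper), a diagonal Laplace estimate of the log marginal likelihood, with the Hessian approximated by the diagonal empirical Fisher plus the prior precision, could look as follows; it assumes the model is already trained to a MAP solution, nll_fn returns a summed negative log likelihood, and the prior is N(0, prior_var · I):

```python
import math
import torch

def diag_laplace_log_marginal(model, nll_fn, loader, prior_var):
    params = [p for p in model.parameters() if p.requires_grad]
    fisher = [torch.zeros_like(p) for p in params]
    log_lik = 0.0
    for x, y in loader:
        model.zero_grad()
        nll = nll_fn(model(x), y)          # summed negative log likelihood over the batch
        nll.backward()
        log_lik -= nll.item()
        for f, p in zip(fisher, params):
            f += p.grad.detach() ** 2      # diagonal empirical Fisher accumulation
    d = sum(p.numel() for p in params)
    prior_prec = 1.0 / prior_var
    log_prior = -0.5 * prior_prec * sum(p.detach().pow(2).sum().item() for p in params) \
                - 0.5 * d * math.log(2 * math.pi * prior_var)
    # log-determinant of the diagonal posterior precision (Fisher + prior precision)
    log_det = sum(torch.log(f + prior_prec).sum().item() for f in fisher)
    return log_lik + log_prior + 0.5 * d * math.log(2 * math.pi) - 0.5 * log_det
```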

We perform a similar experiment with a Gaussian process with an RBF kernel on a binary classification problem for two classes of CIFAR-10. We use variational inference to fit the model, and we use the variational evidence lower bound to approximate the marginal likelihood. We use variational inference to overcome the non-Gaussian likelihood, and not for scalability reasons, i.e. we are not using inducing inputs. We use the GPyTorch package (Gardner et al., 2018) to train the models. We use an RBF kernel with default initialization from GPyTorch, and divide the inputs by 5000 to get an appropriate input scale. We train the model on a binary classification problem between classes 0 and 1.

For the 10-class GP classification experiment, we train 10 one-vs-all models that classify between a given class and the rest of the data. To reduce computation in training, we subsample the data not belonging to the given class to 10k datapoints, so each model is trained on a total of 15k datapoints. We then combine the 10 models into a single multi-class model: an observation is attributed to the class that corresponds to the one-vs-all model with the highest confidence. We use the same hyper-parameters as in the binary classification experiments.
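A minimal sketch of this combination rule (illustrative only; probs is a hypothetical container for the one-vs-all confidences):

```python
import numpy as np

def combine_one_vs_all(probs):
    # probs: array of shape [num_classes, num_test_points]; entry [c, i] is the
    # confidence of the one-vs-all model for class c on test point i.
    return probs.argmax(axis=0)   # predicted class = most confident one-vs-all model
```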

D.5. Double Descent

In Section 7, we evaluate SGD, SWAG and MultiSWAG for models of varying width. Following Nakkiran et al. (2019), we use ResNet-18 on CIFAR-100; we consider the original labels, and 10% and 20% label corruption. For networks of every width, we reuse the hyper-parameters used for PreResNet-164 in Maddox et al. (2019).

Figure 13. Double Descent. (a) Test error and (b) NLL loss for ResNet-18 with varying width on CIFAR-100, for SGD, SWAG and MultiSWAG, when 10% of the labels are randomly reshuffled. MultiSWAG alleviates double descent both on the original labels and under label noise, both in accuracy and NLL. (c) Test NLLs for MultiSWAG with a varying number of independent models under 20% label corruption; NLL monotonically decreases with an increased number of independent models, alleviating double descent.

For the original labels and 10% label corruption, we use 5 independently trained SWAG models with MultiSWAG, and for 20% label corruption we use 10 models; for 20% label corruption we also show performance varying the number of independent models, in Figures 8(e) and 13(c). Both for SWAG and MultiSWAG, we use an ensemble of 20 sampled models from each of the SWAG solutions; for example, for MultiSWAG with 10 independent SWAG solutions, we use an ensemble of 200 networks.
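A minimal sketch of how the MultiSWAG ensemble prediction is formed from such samples (illustrative; swag_solutions and sample_and_predict are hypothetical stand-ins for the trained SWAG posteriors and a function returning softmax probabilities for one sampled network):

```python
import numpy as np

def multiswag_predict(swag_solutions, sample_and_predict, num_samples=20):
    # Average predicted class probabilities over `num_samples` networks drawn from
    # each independently trained SWAG solution (e.g. 10 solutions x 20 samples = 200 networks).
    all_probs = [
        sample_and_predict(solution)        # shape [num_test_points, num_classes]
        for solution in swag_solutions
        for _ in range(num_samples)
    ]
    return np.mean(all_probs, axis=0)
```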

E. Analysis of Prior Variance Effect

In this section, we provide simple analytic results for the effect of the prior variance in ReLU networks. A related derivation is presented in the Appendix Section A.8 of Garipov et al. (2018), concerning connecting paths arising from symmetries in parametrization.

We will consider a multilayer network f(x; w) of the form
\[
f(x; \{W_i, b_i\}_{i=1}^{n}) = W_n\big(\ldots \phi(W_2\, \phi(W_1 x + b_1) + b_2) \ldots\big) + b_n,
\]
where φ is the ReLU (or, in fact, any positively-homogeneous activation function), the Wᵢ are weight matrices, and the bᵢ are bias vectors. In particular, f can be a regular CNN with ReLU activations up to the logits (with the softmax activation removed).

Now suppose we have a prior distribution of the form
\[
W_i \sim \mathcal{N}(0, \alpha_i^2 I), \qquad b_i \sim \mathcal{N}(0, \beta_i^2 I),
\]
where the identity matrices I are implicitly assumed to be of the appropriate shapes, so that each weight matrix and bias vector has a spherical Gaussian distribution. We can reparameterize this distribution as
\[
W_i = \alpha_i E_i, \quad E_i \sim \mathcal{N}(0, I), \qquad b_i = \beta_i \epsilon_i, \quad \epsilon_i \sim \mathcal{N}(0, I).
\]

We can then express the predictions of the network on the input x, for weights sampled from the prior, as the random variable
\[
f(x; \{\alpha_i, \beta_i\}_{i=1}^{n}) = \alpha_n \cdot E_n\big(\ldots \phi(\alpha_1 E_1 x + \beta_1 \cdot \epsilon_1) \ldots\big) + \beta_n \cdot \epsilon_n. \tag{8}
\]

Through Equation (8), we can observe some simple properties of the dependence between the prior scales αᵢ, βᵢ and the induced function-space prior.

Proposition 1. Suppose the network has no bias vectors, i.e. β₁ = … = βₙ = 0. Then the scales αᵢ of the prior distribution over the weights only affect the output scale of the network.

Proof. In the case when there are no bias vectors, Equation (8) simplifies to
\[
\begin{aligned}
f(x; \{\alpha_i, \beta_i = 0\}_{i=1}^{n}) &= \alpha_n \cdot E_n\big(\ldots \phi(\alpha_1 E_1 x) \ldots\big) \\
&= \alpha_n \cdot \ldots \cdot \alpha_1 \cdot E_n\big(\ldots \phi(E_1 x) \ldots\big) \\
&= \alpha_n \cdot \ldots \cdot \alpha_1 \cdot f(x; \{\alpha_i = 1, \beta_i = 0\}_{i=1}^{n}).
\end{aligned}
\]
In the derivation above, we used the positive homogeneity of ReLU: φ(αz) = αφ(z) for any positive α.

In other words, to sample from the distribution over functions corresponding to a prior with scales {αᵢ, βᵢ = 0}ⁿᵢ₌₁, we can sample from the spherical Gaussian prior (without bias terms) {αᵢ = 1, βᵢ = 0}ⁿᵢ₌₁ and then rescale the outputs of the network by the product of the prior scales αₙ · … · α₂ · α₁.
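A quick numerical check of Proposition 1 on a small bias-free ReLU network (our own illustrative code, not from the paper):

```python
import numpy as np

rng = np.random.default_rng(0)
relu = lambda z: np.maximum(z, 0.0)

def forward(x, Es, alphas):
    # Bias-free ReLU network with weights W_i = alpha_i * E_i, as in Proposition 1.
    h = x
    for E, a in zip(Es[:-1], alphas[:-1]):
        h = relu(a * (E @ h))
    return alphas[-1] * (Es[-1] @ h)

dims = [10, 32, 32, 1]
Es = [rng.standard_normal((d_out, d_in)) for d_in, d_out in zip(dims[:-1], dims[1:])]
x = rng.standard_normal(dims[0])
alphas = [0.3, 2.0, 5.0]

out_scaled = forward(x, Es, alphas)
out_unit = forward(x, Es, [1.0, 1.0, 1.0])
print(np.allclose(out_scaled, np.prod(alphas) * out_unit))  # True: the scales only rescale the output
```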

We note that the result above is different from the results for sigmoid networks considered in MacKay (1995), where varying the prior on the weights leads to changing the lengthscale of the sample functions. For ReLU networks without biases, increasing the prior variance only increases the output scale of the network, and not the complexity of the samples. If we apply the softmax activation to the outputs of the last layer of such a network, we will observe increasingly confident predictions as we increase the prior variance. We observe this effect in Figure 9 and discuss it in Section 8.4.
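The saturation effect is easy to see directly: scaling the logits by a growing factor pushes the softmax towards a one-hot distribution. A tiny illustration with hypothetical logits:

```python
import numpy as np

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

logits = np.array([1.2, -0.3, 0.5])   # hypothetical logits from a unit-scale prior sample
for scale in [0.1, 1.0, 10.0]:
    print(scale, softmax(scale * logits).round(3))
# Larger output scales push the softmax towards a one-hot, overconfident distribution.
```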

In case bias vectors are present, we can obtain a similar result using a specific scaling of the prior variances with layer, as in the following proposition.

Proposition 2. Suppose the prior scales depend on the layer of the network as follows, for some γ > 0:
\[
\alpha_i = \gamma, \qquad \beta_i = \gamma^{i},
\]
for all layers i = 1, …, n. Then γ only affects the scale of the predictive distribution at any input x:
\[
f(x; \{\alpha_i = \gamma, \beta_i = \gamma^{i}\}_{i=1}^{n}) = \gamma^{n} \cdot f(x; \{\alpha_i = 1, \beta_i = 1\}_{i=1}^{n}).
\]

Proof. The proof is analogous to the proof of Proposition 1. We can use the positive homogeneity of ReLU activations to factor the prior scales outside of the network:
\[
\begin{aligned}
f(x; \{\alpha_i = \gamma, \beta_i = \gamma^{i}\}_{i=1}^{n}) &= \gamma \cdot E_n\big(\ldots \phi(\gamma \cdot E_1 x + \gamma \cdot \epsilon_1) \ldots\big) + \gamma^{n} \cdot \epsilon_n \\
&= \gamma^{n} \cdot \Big(E_n\big(\ldots \phi(E_1 x + \epsilon_1) \ldots\big) + \epsilon_n\Big) \\
&= \gamma^{n} \cdot f(x; \{\alpha_i = 1, \beta_i = 1\}_{i=1}^{n}).
\end{aligned}
\]

The analysis above can be applied to other simple scaling rules of the prior; e.g.,
\[
f(x; \{\alpha_i = \gamma \hat{\alpha}_i, \beta_i = \gamma^{i} \hat{\beta}_i\}_{i=1}^{n}) = \gamma^{n} \cdot f(x; \{\alpha_i = \hat{\alpha}_i, \beta_i = \hat{\beta}_i\}_{i=1}^{n}) \tag{9}
\]
can be shown completely analogously to Proposition 2.

More general types of scaling of the prior affect both the output scale of the network and the relative effect of the prior weight and bias terms. For example, by Equation (9) we have
\[
f(x; \{\alpha_i = \gamma, \beta_i = \gamma\}_{i=1}^{n}) = f(x; \{\alpha_i = \gamma \cdot 1, \beta_i = \gamma^{i} \cdot \gamma^{1-i}\}_{i=1}^{n}) = \gamma^{n} \cdot f(x; \{\alpha_i = 1, \beta_i = \gamma^{1-i}\}_{i=1}^{n}).
\]

We note that the analysis does not cover residual connections and batch normalization, so it applies to LeNet-5, but cannot be directly applied to the PreResNet-20 networks used in many of our experiments.

F. Prior Correlation Structure under Perturbations

In this section, we explore the prior correlations between the logits on different pairs of datapoints, induced by a spherical Gaussian prior on the weights of a PreResNet-20. We sample 100 random images from CIFAR-10 (10 from each class) and apply 17 different perturbations introduced by Hendrycks & Dietterich (2019) at 5 different levels of intensity. We then compute correlations between the logits f(x; w) for the original image x and f(x̃; w) for the corrupted image x̃, as we sample the weights of the network from the prior w ∼ N(0, I).

In Figure 19, we show how the correlations decay with perturbation intensity. For reference, we also show how the correlations decay for a linear model and for an RBF kernel. For the RBF kernel, we set the lengthscale so that the average correlations on the uncorrupted datapoints match those of a PreResNet-20. Further experimental details can be found in Appendix D.3.

For all types of corruptions except saturate, snow, fog and brightness, the PreResNet prior correlations decay more slowly than those of the RBF kernel and the linear model. It appears that the prior samples are sensitive to corruptions that alter the brightness, or more generally the colours, in the image. For many types of corruptions (e.g. Gaussian noise), the prior correlations for PreResNet are close to 1 for all levels of corruption.

Overall, these results indicate that the prior over functions induced by a vague prior over parameters w, in combination with a PreResNet, has useful equivariance properties: before seeing data, the model treats images of the same class as highly correlated, even after an image has undergone significant perturbations representative of perturbations we often see in the real world. These types of symmetries are a large part of what makes neural networks a powerful model class for high-dimensional natural signals.

Figure 14. Noise Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Gaussian Noise, (b) Impulse Noise, (c) Shot Noise.

Figure 15. Blur Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Defocus Blur, (b) Glass Blur, (c) Motion Blur, (d) Zoom Blur, (e) Gaussian Blur.

Figure 16. Digital Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Contrast, (b) Saturate, (c) Elastic Transform, (d) Pixelate, (e) JPEG Compression.

Figure 17. Weather Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Snow, (b) Fog, (c) Brightness.

Figure 18. Negative log likelihood, accuracy and expected calibration error distributions on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the corruption intensity. Following Ovadia et al. (2019), we summarize the results for different types of corruption with a boxplot. For each method we use 10 independently trained models, and for MultiSWAG we sample 20 networks from each model. As in Figures 5, 11-14, there are substantial differences between these three methods, which are hard to see due to the vertical scale on this plot; MultiSWAG particularly outperforms Deep Ensembles and MultiSWA in terms of NLL and ECE for higher corruption intensities.

Figure 19. Prior correlations under corruption. Prior correlations between predictions (logits) for PreResNet-20, a linear model and an RBF kernel on original and corrupted images, as a function of corruption intensity, for different types of corruptions. The lengthscale of the RBF kernel is calibrated to produce similar correlations to PreResNet on uncorrupted datapoints. We report the mean correlation values over 100 different images and show the 1σ error bars with shaded regions. For all corruptions except Snow, Saturate, Fog and Brightness, the correlations decay more slowly for PreResNet compared to the baselines.


Figure 6. Induced prior correlation function. Average pairwise prior correlations for pairs of objects in classes 0, 1, 2, 4, 7 of MNIST, induced by LeNet-5 for p(f(x; w)) when p(w) = N(0, α²I); panels (a)–(c) correspond to α = 0.02, 0.1, 1. Images in the same class have higher prior correlations than images from different classes, suggesting that p(f(x; w)) has desirable inductive biases. The correlations slightly decrease with increases in α. (d) NLL of an ensemble of 20 SWAG samples on MNIST as a function of α, using a LeNet-5.

Figure 7. Rethinking generalization. (a) Sample functions from a Gaussian process prior. (b) GP fit (with 95% credible region) to structured data generated as y_green(x) = sin(x·2π) + ε, ε ∼ N(0, 0.2²). (c) GP fit, with no training error, after a significant addition of corrupted data, in red, drawn from Uniform[0.5, 1]. (d) Variational GP marginal likelihood with an RBF kernel for two classes of CIFAR-10. (e) Laplace BNN marginal likelihood for a PreResNet-20 on CIFAR-10 with different fractions of random labels. The marginal likelihood for both the GP and BNN decreases as we increase the level of corruption in the labels, suggesting reasonable inductive biases in the prior over functions. Moreover, both the GP and BNN have 100% training accuracy on images with fully corrupted labels.

processing CIFAR-10 with a randomly initialized untrained convolutional neural network dramatically improves the test performance of a simple Gaussian kernel on pixels, from 54% accuracy to 71%. Adding ℓ2 regularization only improves the accuracy by an additional 2%. These results again indicate that broad Gaussian priors over parameters induce reasonable priors over networks, with a minor additional gain from decreasing the variance of the prior in parameter space, which corresponds to ℓ2 regularization.

5.2. Prior Class Correlations

In Figure 6, we study the prior correlations in the outputs of the LeNet-5 convolutional network (LeCun et al., 1998) on objects of different MNIST classes. We sample networks with weights p(w) = N(0, α²I), compute the values of the logits corresponding to the first class for all pairs of images, and compute the correlations of these logits. For all levels of α, the correlations between objects of the same class are consistently higher than the correlations between objects of different classes, showing that the network induces a reasonable prior similarity metric over these images. Additionally, we observe that the prior correlations somewhat decrease as we increase α, showing that bounding the norm of the weights has some minor utility, in accordance with Section 5.1. Similarly, in panel (d) we see that the NLL significantly decreases as α increases in [0, 0.5], and then slightly increases, but is relatively constant thereafter.

In the Appendix, we further describe analytic results and illustrate the effect of α on sample functions.

5.3. Effect of Prior Variance on CIFAR-10

We further study the effect of the parameter prior standard deviation α, measuring the performance of approximate Bayesian inference for CIFAR-10 with a Preactivation ResNet-20 (He et al., 2016) and VGG-16 (Simonyan & Zisserman, 2014). For each of these architectures, we run SWAG (Maddox et al., 2019) with fixed hyper-parameters and varying α. We report the results in Figure 12(d), (h). For both architectures, the performance is near-optimal in the range α ∈ [10⁻², 10⁻¹]. Smaller α constrains the weights too much. Performance is reasonable and becomes mostly insensitive to α as it continues to increase, due to the inductive biases of the functional form of the neural network.


6. Rethinking Generalization

Zhang et al. (2016) demonstrated that deep neural networks have sufficient capacity to fit randomized labels on popular image classification tasks, and suggest this result requires re-thinking generalization to understand deep learning.

We argue, however, that this behaviour is not puzzling from a probabilistic perspective, is not unique to neural networks, and cannot be used as evidence against Bayesian neural networks (BNNs) with vague parameter priors. Fundamentally, the resolution is the view presented in the introduction: from a probabilistic perspective, generalization is at least a two-dimensional concept, related to support (flexibility), which should be as large as possible, supporting even noisy solutions, and inductive biases that represent relative prior probabilities of solutions.

Indeed, we demonstrate that the behaviour in Zhang et al. (2016) that was treated as mysterious and specific to neural networks can be exactly reproduced by Gaussian processes (GPs). Gaussian processes are an ideal choice for this experiment because they are popular Bayesian non-parametric models, and they assign a prior directly in function space. Moreover, GPs have remarkable flexibility, providing universal approximation with popular covariance functions such as the RBF kernel. Yet the functions that are a priori likely under a GP with an RBF kernel are relatively simple. We describe GPs further in the Appendix, and Rasmussen & Williams (2006) provides an extensive introduction.

We start with a simple example to illustrate the ability of a GP with an RBF kernel to easily fit a corrupted dataset, yet generalize well on a non-corrupted dataset, in Figure 7. In Fig. 7(a), we have sample functions from a GP prior over functions p(f(x)), showing that likely functions under the prior are smooth and well-behaved. In Fig. 7(b), we see the GP is able to reasonably fit data from a structured function. And in Fig. 7(c), the GP is also able to fit highly corrupted data with essentially no structure; although these data are not a likely draw from the prior, the GP has support for a wide range of solutions, including noise.
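A minimal sketch of the kind of GP regression underlying panels (a)–(c), assuming an RBF kernel and a small jitter so the GP can interpolate the corrupted points (our own illustration; the exact data and hyper-parameters differ from the figure):

```python
import numpy as np

def rbf(xa, xb, lengthscale=0.1, variance=1.0):
    # Squared-exponential kernel k(x, x') = variance * exp(-(x - x')^2 / (2 * lengthscale^2)).
    diff = xa[:, None] - xb[None, :]
    return variance * np.exp(-0.5 * (diff / lengthscale) ** 2)

rng = np.random.default_rng(0)
x_clean = np.linspace(0.0, 1.0, 50)
y_clean = np.sin(2 * np.pi * x_clean) + 0.2 * rng.standard_normal(x_clean.size)  # structured data
x_corrupt = rng.uniform(0.0, 1.0, 50)
y_corrupt = rng.uniform(0.5, 1.0, 50)                                            # structureless, corrupted data

x_train = np.concatenate([x_clean, x_corrupt])
y_train = np.concatenate([y_clean, y_corrupt])

# GP posterior mean on a test grid; the small jitter lets the GP interpolate the corrupted points.
jitter = 1e-3
x_test = np.linspace(0.0, 1.0, 200)
K = rbf(x_train, x_train) + jitter * np.eye(x_train.size)
posterior_mean = rbf(x_test, x_train) @ np.linalg.solve(K, y_train)
```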

We next show that GPs can replicate the generalization behaviour described in Zhang et al. (2016) (experimental details in the Appendix). When applied to CIFAR-10 images with random labels, Gaussian processes achieve 100% train accuracy and 10.4% test accuracy (at the level of random guessing). However, the same model trained on the true labels achieves a training accuracy of 72.8% and a test accuracy of 54.3%. Thus the generalization behaviour described in Zhang et al. (2016) is not unique to neural networks, and can be described by separately understanding the support and the inductive biases of a model.

Indeed, although Gaussian processes support CIFAR-10 images with random labels, they are not likely under the GP prior. In Fig. 7(d), we compute the approximate GP marginal likelihood on a binary CIFAR-10 classification problem, with labels of varying levels of corruption. We see that as the noise in the data increases, the approximate marginal likelihood, and thus the prior support for these data, decreases. In Fig. 7(e), we see a similar trend for a Bayesian neural network. Again, as the fraction of corrupted labels increases, the approximate marginal likelihood decreases, showing that the prior over functions given by the Bayesian neural network has less support for these noisy datasets. We provide further experimental details in the Appendix.

Dziugaite & Roy (2017) and Smith & Le (2018) provide complementary perspectives on Zhang et al. (2016) for MNIST; Dziugaite & Roy (2017) show non-vacuous PAC-Bayes bounds for noise-free binary MNIST but not noisy MNIST, and Smith & Le (2018) show that logistic regression can fit noisy labels on subsampled MNIST, interpreting the results from an Occam factor perspective.

7. Double Descent

Double descent (e.g., Belkin et al., 2019) describes generalization error that decreases, increases, and then again decreases, with increases in model flexibility. The first decrease and then increase is referred to as the classical regime: models with increasing flexibility are increasingly able to capture structure and perform better, until they begin to overfit. The next regime is referred to as the modern interpolating regime. The existence of the interpolation regime has been presented as mysterious generalization behaviour in deep learning.

However, our perspective of generalization suggests that performance should monotonically improve as we increase model flexibility when we use Bayesian model averaging with a reasonable prior. Indeed, in the opening example of Figure 1, we would in principle want to use the most flexible possible model. Our results in Section 5 show that standard BNN priors induce structured and useful priors in the function space, so we should not expect double descent in Bayesian deep learning models that perform reasonable marginalization.

To test this hypothesis, we evaluate MultiSWAG, SWAG and standard SGD with ResNet-18 models of varying width, following Nakkiran et al. (2019), measuring both error and negative log likelihood (NLL). For the details, see Appendix D. We present the results in Figure 8 and Appendix Figure 13.

First, we observe that models trained with SGD indeed suffer from double descent, especially when the train labels are partially corrupted (see panels 8(c), 8(d)). We also see that SWAG, a unimodal posterior approximation, reduces the extent of double descent.

Figure 8. Bayesian model averaging alleviates double descent. (a) Test error and (b) NLL loss for ResNet-18 with varying width on CIFAR-100 for SGD, SWAG and MultiSWAG. (c) Test error and (d) NLL loss when 20% of the labels are randomly reshuffled. SWAG reduces double descent, and MultiSWAG, which marginalizes over multiple modes, entirely alleviates double descent, both on the original labels and under label noise, both in accuracy and NLL. (e) Test errors for MultiSWAG with a varying number of independent SWAG models; error monotonically decreases with an increased number of independent models, alleviating double descent. We also note that MultiSWAG provides significant improvements in accuracy and NLL over SGD and SWAG models. See Appendix Figure 13 for additional results.

Moreover, MultiSWAG, which performs a more exhaustive multimodal Bayesian model average, completely mitigates double descent: the performance of MultiSWAG solutions increases monotonically with the size of the model, showing no double descent even under significant label corruption, for both accuracy and NLL.

Our results highlight the importance of marginalization over multiple modes of the posterior: under 20% label corruption, SWAG clearly suffers from double descent while MultiSWAG does not. In Figure 8(e), we show how double descent is alleviated with an increased number of independent modes marginalized in MultiSWAG.

These results also clearly show that MultiSWAG provides significant improvements in accuracy over both SGD and SWAG models, in addition to NLL, an often overlooked advantage of Bayesian model averaging that we discuss in Section 3.1.

Recently, Nakkiran et al. (2020) show that carefully tuned ℓ2 regularization can help mitigate double descent. Alternatively, we show that Bayesian model averaging, particularly based on multimodal marginalization, can mitigate prominent double descent behaviour. The perspective in Sections 1 and 3 predicts this result: models with reasonable priors and effective Bayesian model averaging should monotonically improve with increases in flexibility.

8. Temperature Scaling

The standard Bayesian posterior distribution is given by
\[
p(w | D) = \frac{1}{Z}\, p(D | w)\, p(w), \tag{2}
\]
where p(D|w) is a likelihood, p(w) is a prior, and Z is a normalizing constant.

In Bayesian deep learning, it is typical to consider the tempered posterior
\[
p_T(w | D) = \frac{1}{Z(T)}\, p(D | w)^{1/T}\, p(w), \tag{3}
\]
where T is a temperature parameter, and Z(T) is the normalizing constant corresponding to temperature T. The temperature parameter controls how the prior and likelihood interact in the posterior:

• T < 1 corresponds to cold posteriors, where the posterior distribution is more concentrated around solutions with high likelihood.

• T = 1 corresponds to the standard Bayesian posterior distribution.

• T > 1 corresponds to warm posteriors, where the prior effect is stronger and the posterior collapse is slower.

Tempering posteriors is a well-known practice in statistics, where it goes by the names Safe Bayes, generalized Bayesian inference, and fractional Bayesian inference (e.g., de Heide et al., 2019; Grunwald et al., 2017; Barron & Cover, 1991; Walker & Hjort, 2001; Zhang, 2006; Bissiri et al., 2016; Grunwald, 2012). Safe Bayes has been shown to be natural from a variety of perspectives, including from prequential learning theory and minimum description length frameworks (e.g., Grunwald et al., 2017).

Concurrently with our work, Wenzel et al. (2020) noticed that successful Bayesian deep learning methods tend to use cold posteriors. They provide an empirical study showing that Bayesian neural networks (BNNs) with cold posteriors outperform models with SGD-based maximum likelihood training, while BNNs with T = 1 can perform worse than the maximum likelihood solution. They claim that cold posteriors sharply deviate from the Bayesian paradigm, and consider possible reasons for why tempering is helpful in Bayesian deep learning.

Figure 9. Effects of the prior variance α². (a)–(e): Average class probabilities over all of CIFAR-10 for two sample prior functions from p(f(x; w)) (two top rows) and the predictive distribution (average over 200 samples of the weights, bottom row), for varying settings of α in p(w) = N(0, α²I); panels (a)–(e) correspond to α = 0.01, 0.1, 0.3, 1, √10. (f) NLL and (g) classification error of an ensemble of 20 SWAG samples on CIFAR-10 as a function of α, using a Preactivation ResNet-20 and VGG-16. The NLL is high for overly small α and near-optimal in the range [0.1, 0.3]. The NLL remains relatively low for vague priors corresponding to large values of α.

Figure 10. Adaptivity of the posterior variance with data. We sample two functions f(x; w) from the distribution over functions induced by a distribution over weights, starting with the prior p(w) = N(0, 10·I), in combination with a PreResNet-20. We measure class probabilities averaged across the CIFAR-10 test set as we vary the amount of available training data. Although the prior variance is too large, such that the softmax saturates for logits sampled from the prior, leading to one class being favoured, we see that the posterior quickly adapts to correct the scale of the logits in the presence of data. Panels: (a) prior (α = √10), (b) 10 datapoints, (c) 100 datapoints, (d) 1000 datapoints. In Figure 9, we also show that the prior variance can easily be calibrated such that the prior predictive distribution, even before observing data, has high entropy.


In this section, we provide an alternative view and argue that tempering is not at odds with Bayesian principles. Moreover, for virtually any realistic model class and dataset, it would be highly surprising if T = 1 were in fact the best setting of this hyperparameter. Indeed, as long as it is practically convenient, we would advocate tempering for essentially any model, especially parametric models that do not scale their capacity automatically with the amount of available information. Our position is that, at a high level, Bayesian methods are trying to combine honest beliefs with data to form a posterior. By reflecting the belief that the model is misspecified, the tempered posterior is often more of a true posterior than the posterior that results from ignoring our belief that the model is misspecified.

Finding that T < 1 helps for Bayesian neural networks is neither surprising nor discouraging. And the actual results of the experiments in Wenzel et al. (2020), which show great improvements over standard SGD training, are in fact very encouraging of deriving inspiration from Bayesian procedures in deep learning.

We consider: (1) tempering under misspecification (Section 8.1); (2) tempering in terms of overcounting data (Section 8.2); (3) how tempering compares to changing the observation model (Section 8.3); (4) the effect of the prior in relation to the experiments of Wenzel et al. (2020) (Section 8.4); (5) the effect of approximate inference, including how tempering can help in efficiently estimating parameters even for the untempered posterior (Section 8.5).

8.1. Tempering Helps with Misspecified Models

Many works explain how tempered posteriors help under model misspecification (e.g., de Heide et al., 2019; Grunwald et al., 2017; Barron & Cover, 1991; Walker & Hjort, 2001; Zhang, 2006; Bissiri et al., 2016; Grunwald, 2012). In fact, de Heide et al. (2019) and Grunwald et al. (2017) provide several simple examples where Bayesian inference fails to provide good convergence behaviour for untempered posteriors. While it is easier to show theoretical results for T > 1, several of these works also show that T < 1 can be preferred, even in well-specified settings, and indeed recommend learning T from data, for example by cross-validation (e.g., Grunwald, 2012; de Heide et al., 2019).

Are we in a misspecified setting for Bayesian neural networks? Of course. And it would be irrational to proceed as if it were otherwise. Every model is misspecified. In the context of Bayesian neural networks specifically, the mass of solutions expressed by the prior outside of the datasets we typically consider is likely much larger than desired for most applications. We can calibrate for this discrepancy through tempering. The resulting tempered posterior will be more in line with our beliefs than pretending the model is not misspecified and finding the untempered posterior.

Non-parametric models such as Gaussian processes attempt to side-step model misspecification by growing the number of free parameters (information capacity) automatically with the amount of available data. In parametric models, we take much more of a manual guess about the model capacity. In the case of deep neural networks, this choice is not even close to a best guess: it was once the case that architectural design was a large component of works involving neural networks, but now it is more standard practice to choose an off-the-shelf architecture, without much consideration of model capacity. We do not believe that knowingly using a misspecified model to find a posterior is more reasonable (or Bayesian) than honestly reflecting the belief that the model is misspecified and then using a tempered posterior. For parametric models such as neural networks, it is to be expected that the capacity is particularly misspecified.

8.2. Overcounting Data with Cold Posteriors

The criticism of cold posteriors raised by Wenzel et al. (2020) is largely based on the fact that decreasing the temperature leads to overcounting data in the posterior distribution.

However, a similar argument can be made against marginal likelihood maximization (also known as empirical Bayes or type-2 maximum likelihood). Indeed, here the prior will depend on the same data as the likelihood, which can lead to miscalibrated predictive distributions (Darnieder, 2011).

Nonetheless, empirical Bayes has been embraced and widely adopted in Bayesian machine learning (e.g., Bishop, 2006; Rasmussen & Williams, 2006; MacKay, 2003; Minka, 2001b) as embodying several Bayesian principles. Empirical Bayes has been particularly embraced in seminal work on Bayesian neural networks (e.g., MacKay, 1992; 1995), where it has been proposed as a principled approach to learning hyperparameters, such as the scale of the variance for the prior over weights, automatically embodying Occam's razor. While there is in this case some deviation from the fully Bayesian paradigm, the procedure, which depends on marginalization, is nonetheless clearly inspired by Bayesian thinking, and it is thus helpful to reflect this inspiration and provide an understanding of how it works from a Bayesian perspective.

There is also work showing that the marginal likelihood can lead to miscalibrated Bayes factors under model misspecification. Attempts to calibrate these factors (Xu et al., 2019) as part of the Bayesian paradigm are highly reminiscent of work on safe Bayes.


8.3. Tempered Posterior or Different Likelihood

The tempered posterior for one model is an untempered posterior using a different observation model. In other words, we can trivially change the likelihood function in our model so that the standard Bayesian posterior in the new model is equal to the posterior of temperature T in the original model. Indeed, consider the likelihood function
\[
p_T(D | w) \propto p(D | w)^{1/T}, \tag{4}
\]
where the posterior distribution for the model M_T with likelihood p_T is exactly the temperature-T posterior for the model M with likelihood p.

The predictive distribution differs for the two models: even though the posteriors coincide, the likelihoods for a new datapoint y* are different:
\[
\int p(y^* | w)\, p(w | D)\, dw \;\neq\; \int p_T(y^* | w)\, p(w | D)\, dw. \tag{5}
\]

As an example, consider a regression model M with a Gaussian likelihood function y ∼ N(f, σ²), where f is the output of the network and σ² is the noise variance. The predictive distributions for the two models M and M_T for a new input x will have different variance, but the same mean: E_w[f + ε] = E_w[f]. Moreover, in this case the noise variance would typically be learned in either model M_T or M.
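To make this concrete, the tempered Gaussian likelihood is itself Gaussian with an inflated noise variance (a short derivation in the notation above):
\[
p(y | w)^{1/T} \propto \exp\!\Big(-\frac{(y - f)^2}{2\sigma^2}\Big)^{1/T} = \exp\!\Big(-\frac{(y - f)^2}{2 T \sigma^2}\Big) \propto \mathcal{N}(y;\, f,\, T\sigma^2),
\]
so the predictive variance at an input x is Var_w[f(x; w)] + Tσ² under M_T, versus Var_w[f(x; w)] + σ² under M, while the predictive means coincide.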

Section 4.1 of Grunwald et al. (2017) considers a related construction.

8.4. Effect of the Prior

While a somewhat misspecified prior will certainly interact with the utility of tempering, we do not believe the experiments in Wenzel et al. (2020) provide evidence that even the prior p(w) = N(0, I) is misspecified to any serious extent. For a relatively wide range of distributions over w, the functional form of the network f(x; w) can produce a generally reasonable distribution over functions p(f(x; w)). In Figure 10, we reproduce the findings in Wenzel et al. (2020) showing that sample functions from p(f(x; w)) corresponding to the prior p(w) = N(0, 10·I) strongly favour a single class over the dataset. While this behaviour appears superficially dramatic, we note it is simply an artifact of having a miscalibrated signal variance. A miscalibrated signal variance interacts with a quickly saturating soft-max link function to provide a seemingly dramatic preference to a given class. If we instead use p(w) = N(0, α²I) for quite a range of α, then sample functions provide reasonably high entropy across labels, as in Figure 9. The value of α can be easily determined through cross-validation, as in Figure 9, or specified as a standard value used for L2 regularization (α = 0.24 in this case).

However, even with the inappropriate prior scale, we see in the bottom row of panels (a)–(e) of Figure 9 that the unconditional predictive distribution is completely reasonable. Moreover, the prior variance represents a soft prior bias, and will quickly update in the presence of data. In Figure 10, we show the posterior samples after observing 10, 100, and 1000 data points.

Other aspects of the prior, outside of the prior signal variance, will have a much greater effect on the inductive biases of the model. For example, the induced covariance function cov(f(xᵢ; w), f(xⱼ; w)) reflects the induced similarity metric over data instances; through the covariance function we can answer, for instance, whether the model believes a priori that a translated image is similar to the original. Unlike the signal variance of the prior, the prior covariance function will continue to have a significant effect on posterior inference even for very large datasets, and strongly reflects the structural properties of the neural network. We explore these structures of the prior in Figure 12.

8.5. The Effect of Inexact Inference

We have to keep in mind what we ultimately use posterior samples to compute. Ultimately, we wish to estimate the predictive distribution given by the integral in Equation (1). With a finite number of samples, the tempered posterior could be used to provide a better approximation to the expectation of the predictive distribution associated with the untempered posterior.

Consider a simple example, where we wish to estimate the mean of a high-dimensional Gaussian distribution N(0, I). Suppose we use J independent samples. The mean of these samples is also Gaussian distributed, µ ∼ N(0, (1/J)·I). In Bayesian deep learning, the dimension d is typically on the order of 10⁷, and J would be on the order of 10. The norm of µ would then be highly concentrated around √(10⁷)/√10 = 1000. In this case, sampling from a tempered posterior with T < 1 would lead to a better approximation of the Bayesian model average associated with the untempered posterior.
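A quick numerical illustration of this concentration effect, run at a smaller dimension for memory reasons (with d = 10⁵ the norm concentrates around √(d/J) = 100; with d = 10⁷ it would be 1000):

```python
import numpy as np

d, J = 10**5, 10
rng = np.random.default_rng(0)
mu = rng.standard_normal((J, d)).mean(axis=0)   # simple Monte Carlo mean of J samples
print(np.linalg.norm(mu))                        # ≈ sqrt(d / J) = 100, despite the true mean being 0
```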

Furthermore, no current sampling procedure will be providing samples that are close to independent samples from the true posterior of a Bayesian neural network. The posterior landscape is far too multimodal and complex for there to be any reasonable coverage. The approximations we have are practically useful, and often preferable to conventional training, but we cannot realistically proceed with analysis assuming that we have obtained true samples from a posterior. While we would expect that some value of T ≠ 1 would be preferred for any finite dataset in practice, it is conceivable that some of the results in Wenzel et al. (2020) may be affected by the specifics of the approximate inference technique being used.


We should be wary not to view Bayesian model averaging purely through the prism of simple Monte Carlo, as advised in Section 3.2. Given a finite computational budget, our goal in effectively approximating a Bayesian model average is not equivalent to obtaining good samples from the posterior.

9. Discussion

“It is now common practice for Bayesians to fit models that have more parameters than the number of data points. Incorporate every imaginable possibility into the model space: for example, if it is conceivable that a very simple model might be able to explain the data, one should include simple models; if the noise might have a long-tailed distribution, one should include a hyperparameter which controls the heaviness of the tails of the distribution; if an input variable might be irrelevant to a regression, include it in the regression anyway.” MacKay (1995)

We have presented a probabilistic perspective of generalization, which depends on the support and inductive biases of the model. The support should be as large as possible, but the inductive biases must be well-calibrated to a given problem class. We argue that Bayesian neural networks embody these properties, and, through the lens of probabilistic inference, explain generalization behaviour that has previously been viewed as mysterious. Moreover, we argue that Bayesian marginalization is particularly compelling for neural networks, show how deep ensembles provide a practical mechanism for marginalization, and propose a new approach that generalizes deep ensembles to marginalize within basins of attraction. We show that this multimodal approach to Bayesian model averaging, MultiSWAG, can entirely alleviate double descent, enabling monotonic performance improvements with increases in model flexibility, as well as significant improvements in generalization accuracy and log likelihood over SGD and single-basin marginalization.

There are certainly many challenges to estimating the integral for a Bayesian model average in modern deep learning, including a high-dimensional parameter space and a complex posterior landscape. But viewing the challenge indeed as an integration problem, rather than an attempt to obtain posterior samples for a simple Monte Carlo approximation, provides opportunities for future progress. Bayesian deep learning has been making fast practical advances, with approaches that now enable better accuracy and calibration over standard training, with minimal overhead.

We finish with remarks about future developments for Bayesian neural network priors, and approaches to research in Bayesian deep learning.

9.1. The Future for BNN Priors

We provide some brief remarks about future developments for BNN priors. Here we have explored relatively simple parameter priors p(w) = N(0, α²I). While these priors are simple in parameter space, they interact with the neural network architecture to induce a sophisticated prior over functions p(f(x; w)), with many desirable properties, including a reasonable correlation structure over images. However, these parameter priors can certainly still be improved. As we have seen, even tuning the value of the signal variance α², an analogue of the L2 regularization often used in deep learning, can have a noticeable effect on the induced prior over functions, though this effect is quickly modulated by data. Layer-wise priors, such that parameters in each layer have a different signal variance, are intuitive: we would expect later layers to require precise determination, while parameters in earlier layers could reasonably take a range of values. But one has to be cautious; as we show in Appendix Section E, with ReLU activations different signal variances in different layers can be degenerate, combining together to affect only the output scale of the network.

A currently popular sentiment is that we should directly build function-space BNN priors, often taking inspiration from Gaussian processes. While we believe this is a promising direction, one should proceed with caution. If we contrive priors over parameters p(w) to induce distributions over functions p(f) that resemble familiar models, such as Gaussian processes with RBF kernels, we could be throwing the baby out with the bathwater. Neural networks are useful as their own model class precisely because they have different inductive biases from other models.

A similar concern applies to taking infinite width limits in Bayesian neural networks. In these cases we recover Gaussian processes with interpretable kernel functions; because these models are easier to use and analyze, and give rise to interpretable and well-motivated priors, it is tempting to treat them as drop-in replacements for their parametric analogues. However, the kernels for these models are fixed. In order for a model to do effective representation learning, we must learn a similarity metric for the data. Training a neural network in many ways is like learning a kernel, rather than using a fixed kernel. MacKay (1998) has also expressed concerns about treating these limits as replacements for neural networks, due to the loss of representation learning power.

Perhaps the distribution over functions induced by a network in combination with a generic distribution over parameters p(w) may be hard to interpret, but this distribution will contain the equivariance properties, representation learning abilities, and other biases that make neural networks a compelling model class in their own right.


9.2. "But is it really Bayesian?"

We finish with an editorial comment about approaches to research within Bayesian deep learning. There is sometimes a tendency to classify work as Bayesian or not Bayesian, with very stringent criteria for what qualifies as Bayesian. Moreover, the implication, and sometimes even explicit recommendation, is that if an approach is not unequivocally Bayesian in every respect, then we should not term it as Bayesian, and we should instead attempt to understand the procedure through entirely different non-Bayesian mechanisms. We believe this mentality encourages tribalism, which is not conducive to the best research or creating the best performing methods. What matters is not a debate about semantics, but making rational modelling choices given a particular problem setting, and trying to understand these choices. Often these choices can largely be inspired by a Bayesian approach, in which case it is desirable to indicate this source of inspiration. And in the semantics debate, who would be the arbiter of what gets to be called Bayesian? Arguably it ought to be an evolving definition.

Broadly speaking, what makes Bayesian approaches distinctive is a posterior weighted marginalization over parameters. And at a high level, Bayesian methods are about combining our honest beliefs with data to form a posterior. In actuality, no fair-minded researcher entirely believes the prior over parameters, the functional form of the model (which is part of the prior over functions), or the likelihood. From this perspective, it is broadly compatible with a Bayesian philosophy to reflect misspecification in the modelling procedure itself, which is achieved through tempering. In this sense, the tempered posterior is more reflective of a true posterior than the posterior that results from ignoring our belief that the model is misspecified.

Moreover, basic probability theory indicates that marginalization is desirable. While marginalization cannot in practice be achieved exactly, we can try to improve over conventional training, which as we have discussed can be viewed as approximate marginalization. Given computational constraints, effective marginalization is not equivalent to obtaining accurate samples from a posterior. As we have discussed, simple Monte Carlo is only one of many mechanisms for marginalization. Just as expectation propagation (Minka, 2001a) focuses its approximation on the factors in a posterior where it will most affect the end result, we should focus on representing the posterior where it will make the biggest difference to the model average. As we have shown, deep ensembles are a reasonable mechanism up to a point. After having trained many independent models, there are added benefits to marginalizing within basins, given the computational expense associated with retraining an additional model to find an additional basin of attraction.

We should also not hold Bayesian methods to a double standard. Indeed, it can be hard to interpret or understand the prior, the posterior, and whether the marginalization procedure is optimal. But it is also hard to interpret the choices behind the functional form of the model, or the rationale behind classical procedures, where we bet everything on a single global optimum, when we know there are many global optima and many of them will perform well but provide different solutions, and many others will not perform well. We should apply the same level of scrutiny to all modelling choices, consider the alternatives, and not be paralyzed if a procedure is not optimal in every respect.

Acknowledgements. AGW and PI are supported by an Amazon Research Award, Facebook Research, NSF I-DISRE 193471, NIH R01 DA048764-01A1, NSF IIS-1563887, and NSF IIS-1910266. We thank Greg Benton for helpful discussions.

References

Achille, A. and Soatto, S. Emergence of invariance and disentanglement in deep representations. The Journal of Machine Learning Research, 19(1):1947–1980, 2018.

Ashukha, A., Lyzhov, A., Molchanov, D., and Vetrov, D. Pitfalls of in-domain uncertainty estimation and ensembling in deep learning. arXiv preprint arXiv:2002.06470, 2020.

Barron, A. R. and Cover, T. M. Minimum complexity density estimation. IEEE Transactions on Information Theory, 37(4):1034–1054, 1991.

Beal, M. J. Variational algorithms for approximate Bayesian inference. University of London, 2003.

Belkin, M., Hsu, D., Ma, S., and Mandal, S. Reconciling modern machine-learning practice and the classical bias–variance trade-off. Proceedings of the National Academy of Sciences, 116(32):15849–15854, 2019.

Bishop, C. M. Pattern Recognition and Machine Learning. Springer, 2006.

Bissiri, P. G., Holmes, C. C., and Walker, S. G. A general framework for updating belief distributions. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 78(5):1103–1130, 2016.

Blundell, C., Cornebise, J., Kavukcuoglu, K., and Wierstra, D. Weight uncertainty in neural networks. arXiv preprint arXiv:1505.05424, 2015.

Box, G. E. and Tiao, G. C. Bayesian inference in statistical analysis. Addison-Wesley, Reading, MA, 1973.


Chaudhari, P., Choromanska, A., Soatto, S., LeCun, Y., Baldassi, C., Borgs, C., Chayes, J., Sagun, L., and Zecchina, R. Entropy-SGD: Biasing gradient descent into wide valleys. Journal of Statistical Mechanics: Theory and Experiment, 2019.

Cobb, A. D., Baydin, A. G., Markham, A., and Roberts, S. J. Introducing an explicit symplectic integration scheme for Riemannian manifold Hamiltonian Monte Carlo. arXiv preprint arXiv:1910.06243, 2019.

Darnieder, W. F. Bayesian methods for data-dependent priors. PhD thesis, The Ohio State University, 2011.

de Heide, R., Kirichenko, A., Mehta, N., and Grunwald, P. Safe-Bayesian generalized linear regression. arXiv preprint arXiv:1910.09227, 2019.

Dinh, L., Pascanu, R., Bengio, S., and Bengio, Y. Sharp minima can generalize for deep nets. arXiv preprint arXiv:1703.04933, 2017.

Draxler, F., Veschgini, K., Salmhofer, M., and Hamprecht, F. A. Essentially no barriers in neural network energy landscape. arXiv preprint arXiv:1803.00885, 2018.

Dziugaite, G. K. and Roy, D. M. Computing nonvacuous generalization bounds for deep (stochastic) neural networks with many more parameters than training data. arXiv preprint arXiv:1703.11008, 2017.

Fort, S. and Jastrzebski, S. Large scale structure of neural network loss landscapes. In Advances in Neural Information Processing Systems, pp. 6706–6714, 2019.

Fort, S., Hu, H., and Lakshminarayanan, B. Deep ensembles: A loss landscape perspective. arXiv preprint arXiv:1912.02757, 2019.

Gal, Y. and Ghahramani, Z. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. In International Conference on Machine Learning, pp. 1050–1059, 2016.

Gardner, J. R., Pleiss, G., Bindel, D., Weinberger, K. Q., and Wilson, A. G. GPyTorch: Blackbox matrix-matrix Gaussian process inference with GPU acceleration. In Neural Information Processing Systems, 2018.

Garipov, T., Izmailov, P., Podoprikhin, D., Vetrov, D. P., and Wilson, A. G. Loss surfaces, mode connectivity, and fast ensembling of DNNs. In Neural Information Processing Systems, 2018.

Grunwald, P. The safe Bayesian. In International Conference on Algorithmic Learning Theory, pp. 169–183. Springer, 2012.

Grunwald, P., Van Ommen, T., et al. Inconsistency of Bayesian inference for misspecified linear models, and a proposal for repairing it. Bayesian Analysis, 12(4):1069–1103, 2017.

Guedj, B. A primer on PAC-Bayesian learning. arXiv preprint arXiv:1901.05353, 2019.

Guo, C., Pleiss, G., Sun, Y., and Weinberger, K. Q. On calibration of modern neural networks. In Proceedings of the 34th International Conference on Machine Learning - Volume 70, pp. 1321–1330. JMLR.org, 2017.

Gustafsson, F. K., Danelljan, M., and Schon, T. B. Evaluating scalable Bayesian deep learning methods for robust computer vision. arXiv preprint arXiv:1906.01620, 2019.

Hafner, D., Tran, D., Irpan, A., Lillicrap, T., and Davidson, J. Reliable uncertainty estimates in deep neural networks using noise contrastive priors. arXiv preprint arXiv:1807.09289, 2018.

He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778, 2016.

Hendrycks, D. and Dietterich, T. Benchmarking neural network robustness to common corruptions and perturbations. arXiv preprint arXiv:1903.12261, 2019.

Hinton, G. E. and Van Camp, D. Keeping the neural networks simple by minimizing the description length of the weights. In Proceedings of the Sixth Annual Conference on Computational Learning Theory, pp. 5–13, 1993.

Hochreiter, S. and Schmidhuber, J. Flat minima. Neural Computation, 9(1):1–42, 1997.

Huang, W. R., Emam, Z., Goldblum, M., Fowl, L., Terry, J. K., Huang, F., and Goldstein, T. Understanding generalization through visualizations. arXiv preprint arXiv:1906.03291, 2019.

Izmailov, P., Podoprikhin, D., Garipov, T., Vetrov, D., and Wilson, A. G. Averaging weights leads to wider optima and better generalization. In Uncertainty in Artificial Intelligence (UAI), 2018.

Izmailov, P., Maddox, W. J., Kirichenko, P., Garipov, T., Vetrov, D., and Wilson, A. G. Subspace inference for Bayesian deep learning. Uncertainty in Artificial Intelligence, 2019.

Jacot, A., Gabriel, F., and Hongler, C. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in Neural Information Processing Systems, pp. 8571–8580, 2018.


Jiang, Y., Neyshabur, B., Mobahi, H., Krishnan, D., and Bengio, S. Fantastic generalization measures and where to find them. arXiv preprint arXiv:1912.02178, 2019.

Kass, R. E. and Raftery, A. E. Bayes factors. Journal of the American Statistical Association, 90(430):773–795, 1995.

Kendall, A. and Gal, Y. What uncertainties do we need in Bayesian deep learning for computer vision? In Advances in Neural Information Processing Systems, pp. 5574–5584, 2017.

Keskar, N. S., Mudigere, D., Nocedal, J., Smelyanskiy, M., and Tang, P. T. P. On large-batch training for deep learning: Generalization gap and sharp minima. arXiv preprint arXiv:1609.04836, 2016.

Khan, M. E., Nielsen, D., Tangkaratt, V., Lin, W., Gal, Y., and Srivastava, A. Fast and scalable Bayesian deep learning by weight-perturbation in Adam. arXiv preprint arXiv:1806.04854, 2018.

Krizhevsky, A., Nair, V., and Hinton, G. The CIFAR-10 dataset, 2014. http://www.cs.toronto.edu/~kriz/cifar.html.

Lakshminarayanan, B., Pritzel, A., and Blundell, C. Simple and scalable predictive uncertainty estimation using deep ensembles. In Advances in Neural Information Processing Systems, pp. 6402–6413, 2017.

Langford, J. and Caruana, R. (Not) bounding the true error. In Advances in Neural Information Processing Systems, pp. 809–816, 2002.

LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.

Louizos, C., Shi, X., Schutte, K., and Welling, M. The functional neural process. In Advances in Neural Information Processing Systems, 2019.

MacKay, D. J. Bayesian methods for adaptive models. PhD thesis, California Institute of Technology, 1992.

MacKay, D. J. Probable networks and plausible predictions: a review of practical Bayesian methods for supervised neural networks. Network: Computation in Neural Systems, 6(3):469–505, 1995.

MacKay, D. J. Introduction to Gaussian processes. In Bishop, C. M. (ed.), Neural Networks and Machine Learning, chapter 11, pp. 133–165. Springer-Verlag, 1998.

MacKay, D. J. Information Theory, Inference and Learning Algorithms. Cambridge University Press, 2003.

Maddox, W. J., Izmailov, P., Garipov, T., Vetrov, D. P., and Wilson, A. G. A simple baseline for Bayesian uncertainty in deep learning. In Advances in Neural Information Processing Systems, 2019.

Masegosa, A. R. Learning under model misspecification: Applications to variational and ensemble methods, 2019.

McAllester, D. A. PAC-Bayesian model averaging. In Proceedings of the Twelfth Annual Conference on Computational Learning Theory, pp. 164–170, 1999.

Minka, T. Expectation propagation for approximate Bayesian inference. In Uncertainty in Artificial Intelligence, volume 17, pp. 362–369, 2001a.

Minka, T. P. Bayesian model averaging is not model combination. 2000.

Minka, T. P. Automatic choice of dimensionality for PCA. In Advances in Neural Information Processing Systems, pp. 598–604, 2001b.

Mohri, M. and Rostamizadeh, A. Rademacher complexity bounds for non-i.i.d. processes. In Advances in Neural Information Processing Systems, pp. 1097–1104, 2009.

Nakkiran, P., Kaplun, G., Bansal, Y., Yang, T., Barak, B., and Sutskever, I. Deep double descent: Where bigger models and more data hurt. arXiv preprint arXiv:1912.02292, 2019.

Nakkiran, P., Venkat, P., Kakade, S., and Ma, T. Optimal regularization can mitigate double descent. arXiv preprint arXiv:2003.01897, 2020.

Nalisnick, E. On priors for Bayesian neural networks. PhD thesis, University of California, Irvine, 2018.

Neal, R. Bayesian Learning for Neural Networks. Springer Verlag, 1996. ISBN 0387947248.

Neyshabur, B., Bhojanapalli, S., McAllester, D., and Srebro, N. Exploring generalization in deep learning. In Advances in Neural Information Processing Systems, pp. 5947–5956, 2017.

Neyshabur, B., Bhojanapalli, S., and Srebro, N. A PAC-Bayesian approach to spectrally-normalized margin bounds for neural networks. In International Conference on Learning Representations, 2018. URL https://openreview.net/forum?id=Skz_WfbCZ.

Opper, M., Kinzel, W., Kleinz, J., and Nehl, R. On the ability of the optimal perceptron to generalise. Journal of Physics A: Mathematical and General, 23(11):L581, 1990.


Ovadia, Y., Fertig, E., Ren, J., Nado, Z., Sculley, D., Nowozin, S., Dillon, J. V., Lakshminarayanan, B., and Snoek, J. Can you trust your model's uncertainty? Evaluating predictive uncertainty under dataset shift. arXiv preprint arXiv:1906.02530, 2019.

Pearce, T., Zaki, M., Brintrup, A., Anastassacos, N., and Neely, A. Uncertainty in neural networks: Bayesian ensembling. arXiv preprint arXiv:1810.05546, 2018.

Rasmussen, C. E. and Ghahramani, Z. Occam's razor. In Neural Information Processing Systems (NIPS), 2001.

Rasmussen, C. E. and Williams, C. K. I. Gaussian Processes for Machine Learning. The MIT Press, 2006.

Ritter, H., Botev, A., and Barber, D. A scalable Laplace approximation for neural networks. In International Conference on Learning Representations (ICLR), 2018.

Rue, H., Martino, S., and Chopin, N. Approximate Bayesian inference for latent Gaussian models by using integrated nested Laplace approximations. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 71(2):319–392, 2009.

Simonyan, K. and Zisserman, A. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.

Smith, S. L. and Le, Q. V. A Bayesian perspective on generalization and stochastic gradient descent. In International Conference on Learning Representations, 2018.

Sun, S., Zhang, G., Shi, J., and Grosse, R. Functional variational Bayesian neural networks. arXiv preprint arXiv:1903.05779, 2019.

Ulyanov, D., Vedaldi, A., and Lempitsky, V. Deep image prior. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 9446–9454, 2018.

Vapnik, V. N. Adaptive and learning systems for signal processing, communications, and control. Statistical Learning Theory, 1998.

Virtanen, P., Gommers, R., Oliphant, T. E., Haberland, M., Reddy, T., Cournapeau, D., Burovski, E., Peterson, P., Weckesser, W., Bright, J., et al. SciPy 1.0: fundamental algorithms for scientific computing in Python. Nature Methods, pp. 1–12, 2020.

Walker, S. and Hjort, N. L. On Bayesian consistency. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 63(4):811–821, 2001.

Welling, M. and Teh, Y. W. Bayesian learning via stochastic gradient Langevin dynamics. In Proceedings of the 28th International Conference on Machine Learning (ICML-11), pp. 681–688, 2011.

Wenzel, F., Roth, K., Veeling, B. S., Swiatkowski, J., Tran, L., Mandt, S., Snoek, J., Salimans, T., Jenatton, R., and Nowozin, S. How good is the Bayes posterior in deep neural networks really? arXiv preprint arXiv:2002.02405, 2020.

Wilson, A. G. The case for Bayesian deep learning. arXiv preprint arXiv:2001.10995, 2020.

Xu, X., Lu, P., MacEachern, S., and Xu, R. Calibrated Bayes factors for model comparison. Journal of Statistical Computation and Simulation, 89(4):591–614, 2019.

Yang, W., Lorch, L., Graule, M. A., Srinivasan, S., Suresh, A., Yao, J., Pradier, M. F., and Doshi-Velez, F. Output-constrained Bayesian neural networks. arXiv preprint arXiv:1905.06287, 2019.

Zhang, C., Bengio, S., Hardt, M., Recht, B., and Vinyals, O. Understanding deep learning requires rethinking generalization. arXiv preprint arXiv:1611.03530, 2016.

Zhang, T. Information-theoretic upper and lower bounds for statistical estimation. IEEE Transactions on Information Theory, 52(4):1307–1321, 2006.


Appendix Outline

This appendix is organized as follows. In Section A, we visualize predictive functions corresponding to weight samples within high posterior density valleys on a regression problem. In Section B, we provide background material on Gaussian processes. In Section C, we present further results comparing MultiSWAG and MultiSWA to Deep Ensembles under data distribution shift on CIFAR-10. In Section D, we provide the details of all experiments presented in the paper. In Section E, we present analytic results on the dependence of the prior distribution in function space on the variance of the prior over parameters. In Section F, we study the prior correlations between BNN logits on perturbed images.

A. Loss Valleys

We demonstrate that different points along the valleys of high posterior density (low loss) connecting pairs of independently trained optima (Garipov et al., 2018; Draxler et al., 2018; Fort & Jastrzebski, 2019) correspond to different predictive functions. We use the regression example from Izmailov et al. (2019) and show the results in Figure 11.

B. Gaussian processes

With a Bayesian neural network, a distribution over parameters p(w) induces a distribution over functions p(f(x; w)) when combined with the functional form of the network. Gaussian processes (GPs) are often used to instead directly specify a distribution over functions.

A Gaussian process is a distribution over functions f(x) ∼ GP(m, k), such that any collection of function values, queried at any finite set of inputs x_1, ..., x_n, has a joint Gaussian distribution:

f(x_1), ..., f(x_n) ∼ N(µ, K).  (6)

The mean vector µ_i = E[f(x_i)] = m(x_i) and covariance matrix K_ij = cov(f(x_i), f(x_j)) = k(x_i, x_j) are determined by the mean function m and covariance function (or kernel) k of the Gaussian process.

The popular RBF kernel has the form

k(x_i, x_j) = exp( −(1 / (2ℓ²)) ||x_i − x_j||² ).  (7)

The length-scale hyperparameter ℓ controls the extent of correlations between function values. If ℓ is large, sample functions from a GP prior are simple and slowly varying with inputs x.
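As an illustration of this point, the following sketch (our own, with illustrative lengthscale values) draws sample functions from a zero-mean GP prior with the RBF kernel of Equation (7) at two lengthscales; the samples with the larger lengthscale vary much more slowly.

    import numpy as np

    def rbf_kernel(x, lengthscale):
        d = x[:, None] - x[None, :]
        return np.exp(-0.5 * d**2 / lengthscale**2)

    rng = np.random.default_rng(0)
    x = np.linspace(-5, 5, 200)
    for lengthscale in (0.5, 3.0):
        K = rbf_kernel(x, lengthscale) + 1e-6 * np.eye(len(x))    # jitter for numerical stability
        f = rng.multivariate_normal(np.zeros(len(x)), K, size=3)  # three prior sample functions
        print(lengthscale, np.abs(np.diff(f, axis=1)).mean())     # larger lengthscale -> smoother samples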

Gaussian processes with RBF kernels (as well as many other standard kernels) assign positive density to any set of observations. Moreover, these models are universal approximators (Rasmussen & Williams, 2006): as the number of observations increases, they are able to approximate any function to arbitrary precision.

Work on Gaussian processes in machine learning was triggered by the observation that Bayesian neural networks become Gaussian processes with particular kernel functions as the number of hidden units approaches infinity (Neal, 1996). This result resembles recent work on the neural tangent kernel (e.g., Jacot et al., 2018).

C. Deep Ensembles and MultiSWAG Under Distribution Shift

In Figures 14, 15, 16, 17 we show the negative log-likelihood for Deep Ensembles, MultiSWA and MultiSWAG using PreResNet-20 on CIFAR-10 with various corruptions, as a function of the number of independently trained models (SGD solutions, SWA solutions or SWAG models, respectively). For MultiSWAG we generate 20 samples from each independent SWAG model. Typically MultiSWA and MultiSWAG significantly outperform Deep Ensembles when a small number of independent models is used, or when the level of corruption is high.

In Figure 18, following Ovadia et al. (2019), we show the distribution of negative log likelihood, accuracy and expected calibration error as we vary the type of corruption. We use a fixed training time budget: 10 independently trained models for every method. For MultiSWAG we ensemble 20 samples from each of the 10 SWAG approximations. MultiSWAG particularly achieves better NLL than the other two methods, and MultiSWA outperforms Deep Ensembles; the difference is especially pronounced for higher levels of corruption. In terms of ECE, MultiSWAG again outperforms the other two methods for higher corruption intensities.

We note that Ovadia et al. (2019) found Deep Ensembles to be a very strong baseline for prediction quality and calibration under distribution shift. For this reason, we focus on Deep Ensembles in our comparisons.

D. Details of Experiments

In this section we provide additional details of the experiments presented in the paper.

D.1. Approximating the True Predictive Distribution

For the results presented in Figure 4, we used a network with 3 hidden layers of size 10 each. The network takes two inputs, x and x², and outputs a single number y = f(x). We pass both x and x² as input to ensure that the network can represent a broader class of functions.

Figure 11. Diversity of high performing functions. Bottom: a contour plot of the posterior log-density in the subspace containing a pair of independently trained modes (as with deep ensembles), and a path of high posterior density connecting these modes. In each panel, the purple point represents a sample from the posterior in the parameter subspace. Top: the predictive distribution constructed from samples in the subspace. The shaded blue area shows the 3σ-region of the predictive distribution at each of the input locations, and the blue line shows the mean of the predictive distribution. In each panel, the purple line shows the predictive function corresponding to the sample shown in the corresponding bottom row panel. For the details of the experimental setup, see Section 5.1 of Izmailov et al. (2019).


To generate data for the plots, we used a randomly-initialized neural network of the same architecture described above. We sampled the weights from an isotropic Gaussian with variance 0.1², and added isotropic Gaussian noise with variance 0.1² to the outputs:

y = f(x; w) + ε(x),

with w ∼ N(0, 0.1² · I), ε(x) ∼ N(0, 0.1² · I). The training set consists of 120 points, shown in Figure 4.
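A minimal sketch of this data-generating process is given below. It assumes a small fully-connected ReLU network, and the input range is illustrative rather than taken from the paper.

    import numpy as np

    rng = np.random.default_rng(0)

    def init(shape, std=0.1):                      # weights and biases drawn from N(0, 0.1^2)
        return rng.normal(0.0, std, size=shape)

    layers = [(init((10, 2)), init(10)),           # three hidden layers of size 10
              (init((10, 10)), init(10)),
              (init((10, 10)), init(10))]
    W_out, b_out = init((1, 10)), init(1)

    def f(x):
        h = np.stack([x, x**2], axis=-1)           # the network takes (x, x^2) as input
        for W, b in layers:
            h = np.maximum(h @ W.T + b, 0.0)
        return (h @ W_out.T + b_out).squeeze(-1)

    x = rng.uniform(-10, 10, size=120)
    y = f(x) + rng.normal(0.0, 0.1, size=x.shape)  # additive noise with variance 0.1^2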

For estimating the ground truth, we ran 200 chains of Hamiltonian Monte Carlo (HMC) using the hamiltorch package (Cobb et al., 2019). We initialized each chain with a network pre-trained with SGD for 3000 steps, then ran HMC for 2000 steps, producing 200 samples.

For Deep Ensembles, we independently trained 50 networks with SGD for 20000 steps each, using the negative posterior log-density as the training loss. For SVI, we used a fully-factorized Gaussian approximation initialized at an SGD solution trained for 20000 steps. For all inference methods, we set the prior variance to 10² and the noise variance to 0.02².

Discrepancy with true BMA. For the results presented in panel (d) of Figure 4, we computed the Wasserstein distance between the predictive distribution approximated with HMC and the predictive distributions for Deep Ensembles and SVI. We used the one-dimensional Wasserstein distance function scipy.stats.wasserstein_distance from the scipy package (Virtanen et al., 2020). We computed the Wasserstein distance between marginal distributions at each input location, and averaged the results over the input locations. In the top sub-panels of panels (b), (c) of Figure 4 we additionally visualize the marginal Wasserstein distance between the HMC predictive distribution and the Deep Ensembles and SVI predictive distributions, respectively, for each input location.

D.2. Deep Ensembles and MultiSWAG

We evaluate Deep Ensembles, MultiSWA and MultiSWAG under distribution shift in Section 4. Following Ovadia et al. (2019), we use a PreResNet-20 network and the CIFAR-10 dataset with the different types of corruptions introduced in Hendrycks & Dietterich (2019). For training individual SGD, SWA and SWAG models we use the hyper-parameters used for PreResNet-164 in Maddox et al. (2019). For each SWAG model we sample 20 networks and ensemble them. So Deep Ensembles, MultiSWA and MultiSWAG are all evaluated under the same training budget; Deep Ensembles and MultiSWA also use the same test-time budget.
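The ensembling step itself is just a uniform average of predicted class probabilities over all sampled networks. A schematic sketch is shown below; sample_and_predict is a placeholder for drawing a network from a SWAG posterior approximation and computing its softmax outputs, not an API from our code.

    import numpy as np

    rng = np.random.default_rng(0)

    def sample_and_predict(swag_solution, x):
        # Placeholder: a real implementation would sample weights from the SWAG
        # Gaussian and run the network; here we return random softmax outputs.
        logits = rng.standard_normal((x.shape[0], 10))
        e = np.exp(logits - logits.max(axis=1, keepdims=True))
        return e / e.sum(axis=1, keepdims=True)

    def multiswag_predict(swag_solutions, x, n_samples=20):
        probs = [sample_and_predict(s, x)
                 for s in swag_solutions            # independently trained SWAG solutions
                 for _ in range(n_samples)]         # 20 sampled networks per solution
        return np.mean(probs, axis=0)               # simple Monte Carlo model average

    x_test = np.zeros((5, 3, 32, 32))
    print(multiswag_predict([None] * 3, x_test).shape)   # (5, 10)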

For producing the corrupted data, we used the code released by Hendrycks & Dietterich (2019) (https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_cifar_c.py). We had issues producing the data for the frost corruption type, so we omit it in our evaluation, and include Gaussian blur, which was not included in the evaluation of Hendrycks & Dietterich (2019).


Figure 12. (a)–(c): Average pairwise prior correlations for pairs of objects in classes 0, 1, 2, 4, 7 of MNIST, induced by LeNet-5 for p(f(x; w)) when p(w) = N(0, α²I), with α = 0.02, 0.1, 1 respectively. Images in the same class have higher prior correlations than images from different classes, suggesting that p(f(x; w)) has desirable inductive biases. The correlations slightly decrease with increases in α. Panels (e)–(g) show sample functions from LeNet-5 along the direction connecting a pair of MNIST images of 0 and 1 digits, for the same values of α. The complexity of the samples increases with α. (d): NLL and (h): classification error of an ensemble of 20 SWAG samples on MNIST as a function of α, using a LeNet-5. The NLL is high for overly small α and near-optimal for larger values, with an optimum near α = 0.3.

D.3. Neural Network Priors

In the main text we considered different properties of the prior distribution over functions induced by a spherical Gaussian distribution over the weights, with different variance scales.

Prior correlation diagrams. In panels (a)–(c) of Figure 12 we show pairwise correlations of the logits for different pairs of datapoints. To make these plots, we produce S = 100 samples of the weights w_i of a LeNet-5 from the prior distribution N(0, α²I), and compute the logits corresponding to class 0 for each data point and each weight sample. We then compute the correlations for each pair x, x′ of data points as follows:

corr_logit(x, x′) = Σ_{i=1}^S (f(x; w_i) − f̄(x)) (f(x′; w_i) − f̄(x′)) / sqrt( Σ_{i=1}^S (f(x; w_i) − f̄(x))² · Σ_{i=1}^S (f(x′; w_i) − f̄(x′))² ),

where f(x; w) is the logit corresponding to class 0 of the network with weights w on the input x, and f̄(x) is the mean value of the logit, f̄(x) = (1/S) Σ_i f(x; w_i). For evaluation, we use 200 random datapoints per class for classes 0, 1, 2, 4, 7 (a total of 1000 datapoints). We use this set of classes to ensure that the structure is clearly visible in the figure. We combine the correlations into a diagram, additionally showing the average correlation for each pair of classes. We repeat the experiment for different values

of α ∈ {0.02, 0.1, 1}. For a discussion of the results, see Section 5.2.
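A condensed sketch of this computation is given below; it uses a small fully-connected ReLU network on random inputs as a stand-in for LeNet-5 on MNIST images, so the numbers it prints are only illustrative of the procedure.

    import numpy as np

    rng = np.random.default_rng(0)
    S, n_points, d_in, d_hidden, alpha = 100, 20, 784, 50, 0.1
    X = rng.standard_normal((n_points, d_in))            # placeholder inputs (instead of MNIST)

    def sample_class0_logit(X):
        W1 = rng.normal(0.0, alpha, (d_hidden, d_in))    # weights drawn from N(0, alpha^2 I)
        w2 = rng.normal(0.0, alpha, d_hidden)
        return np.maximum(X @ W1.T, 0.0) @ w2            # "class 0" logit for every input

    F = np.stack([sample_class0_logit(X) for _ in range(S)])   # S x n_points logit samples
    F = F - F.mean(axis=0)                                      # subtract the per-input mean logit
    norms = np.sqrt((F**2).sum(axis=0))
    corr = (F.T @ F) / np.outer(norms, norms)                   # n_points x n_points correlation matrix
    print(corr[0, 1])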

Sample functions. In panels (e)–(g) of Figure 12, we visualize the functions sampled from the LeNet-5 network along the direction connecting a pair of MNIST images. In particular, we take a pair of images x_0 and x_1 of digits 0 and 1 respectively, and construct the path x(t) = t · x_0 + (1 − t) · x_1. We then study the samples of the logits z(t) = f(x(t) · ||x_0|| / ||x(t)||; w) along the path; here we adjusted the norm of the images along the path to be constant, as the values of the logits are sensitive to the norm of the inputs. The complexity of the samples increases as we increase the variance of the prior distribution over the weights. This increased complexity of sample functions explains why we might expect the prior correlations for pairs of images to be lower when we increase the variance of the prior distribution over the weights.

Performance dependence on prior variance. In panels (d), (h) of Figure 12 we show the test negative log-likelihood and accuracy of SWAG applied to LeNet-5 on MNIST. We train the model for 50 epochs, constructing the rank-20 SWAG approximation from the last 25 epochs. We use an initial learning rate of 0.05 and a SWAG learning rate of 0.01, with the learning rate schedule of Maddox et al. (2019). We use the posterior log-density as the objective, and vary the prior variance α². In panels (f), (g) of Figure 9


we perform an analogous experiment using a PreResNet-20 and a VGG-16 on CIFAR-10, using the hyper-parameters reported in Maddox et al. (2019) (for PreResNet-20 we use the hyper-parameters used with PreResNet-164 in Maddox et al. (2019)). Both on MNIST and CIFAR-10, we observe that the performance is poor for overly small values of α, close to optimal for intermediate values, and still reasonable for larger values of α. For further discussion of the results, see Section 5.3.

Predictions from prior samples. Following Wenzel et al. (2020), we study the predictive distributions of prior samples using PreResNet-20 on CIFAR-10. In Figure 9 we show the sample predictive functions averaged over datapoints for different scales α of the prior distribution. We also show the predictive distribution for each α, which is the average of the sample predictive distributions over 200 samples of weights. In Figure 10 we show how the predictive distribution changes as we vary the number of observed data, for prior scale α = √10. We see that the marginal predictive distribution for all considered values of α is reasonable: roughly uniform across classes, when averaged across the dataset. For the latter experiment we used stochastic gradient Langevin dynamics (SGLD) (Welling & Teh, 2011) with a cosine learning rate schedule. For each sample we restart SGLD, and we only use the sample obtained at the last iteration. We discuss the results in Section 8.4.

Prior correlations with corrupted images. In Section F and Figure 19, we study the decay of the prior correlations between logits on an original image and a perturbed image as we increase the intensity of perturbations. For the BNN, we use the PreResNet-20 architecture with the standard Gaussian prior N(0, I). For the linear model, the correlations are not affected by the prior variance α²:

cov(wᵀx, wᵀy) = E(wᵀx · wᵀy) = E(xᵀw wᵀy) = xᵀ E[w wᵀ] y = α² xᵀy,

and hence

corr(wᵀx, wᵀy) = cov(wᵀx, wᵀy) / sqrt( cov(wᵀy, wᵀy) · cov(wᵀx, wᵀx) ) = xᵀy / (||x|| · ||y||).

We use the N(0, I) prior for the weights of the linear model. Finally, we also evaluate the correlations associated with an RBF kernel (see Equation (7)). To set the lengthscale ℓ of the kernel, we evaluate the pairwise correlations for the PreResNet-20 and the RBF kernel on the 100 uncorrupted CIFAR-10 images that were used for the experiment, and ensure that the average correlations match. The resulting value of ℓ is 10000; the average correlation for the RBF kernel and PreResNet was ≈ 0.9, and for the linear model the average correlation was ≈ 0.82. For the perturbations, we used the same set of corruptions introduced in Hendrycks & Dietterich (2019) as in the experiments in Section 4, with the addition of a random translation: for a random translation of intensity i, we pad the image with 2·i zeros on each side and crop the image randomly to 32×32.
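The following small check (ours) confirms numerically that the prior correlation for the linear model is xᵀy / (||x|| ||y||) for any prior scale α.

    import numpy as np

    rng = np.random.default_rng(0)
    x, y = rng.standard_normal(512), rng.standard_normal(512)
    analytic = x @ y / (np.linalg.norm(x) * np.linalg.norm(y))
    for alpha in (0.1, 1.0, 10.0):
        w = rng.normal(0.0, alpha, size=(20000, 512))      # samples from the prior N(0, alpha^2 I)
        print(alpha, np.corrcoef(w @ x, w @ y)[0, 1], analytic)   # empirical value matches for every alpha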

D.4. Rethinking Generalization

In Section 6, we experiment with Bayesian neural networks and Gaussian processes on CIFAR-10 with noisy labels, inspired by the results in Zhang et al. (2016) that suggest we need to re-think generalization to understand deep learning.

Following Zhang et al. (2016), we train PreResNet-20 on CIFAR-10 with different fractions of random labels. To ensure that the network fits the train data, we turn off weight decay and data augmentation, and use a lower initial learning rate of 0.01. Otherwise, we follow the hyper-parameters that were used with PreResNet-164 in Maddox et al. (2019). We use a diagonal Laplace approximation to compute an estimate of the marginal likelihood for each level of label corruption. Following Ritter et al. (2018), we use the diagonal of the Fisher information matrix rather than the Hessian.
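A sketch of the form of this estimate is below; it assumes the MAP log likelihood, log prior, and diagonal Fisher entries have already been computed, and the numbers passed in at the bottom are purely illustrative.

    import numpy as np

    def diag_laplace_log_evidence(log_lik, log_prior, fisher_diag, prior_var):
        # Laplace approximation with the Hessian of the negative log posterior
        # approximated by the diagonal Fisher plus the prior precision.
        a_diag = fisher_diag + 1.0 / prior_var
        d = a_diag.size
        return log_lik + log_prior + 0.5 * d * np.log(2.0 * np.pi) - 0.5 * np.sum(np.log(a_diag))

    rng = np.random.default_rng(0)
    print(diag_laplace_log_evidence(-1200.0, -350.0, np.abs(rng.standard_normal(1000)) + 1e-3, 1.0))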

We perform a similar experiment with a Gaussian process with an RBF kernel on the binary classification problem for two classes of CIFAR-10. We use variational inference to fit the model, and we use the variational evidence lower bound to approximate the marginal likelihood. We use variational inference to overcome the non-Gaussian likelihood, and not for scalability reasons, i.e. we are not using inducing inputs. We use the GPyTorch package (Gardner et al., 2018) to train the models. We use an RBF kernel with the default initialization from GPyTorch, and divide the inputs by 5000 to get an appropriate input scale. We train the model on a binary classification problem between classes 0 and 1.

For the 10-class GP classification experiment, we train 10 one-vs-all models that classify between a given class and the rest of the data. To reduce computation in training, we subsample the data not belonging to the given class to 10k datapoints, so each model is trained on a total of 15k datapoints. We then combine the 10 models into a single multi-class model: an observation is attributed to the class that corresponds to the one-vs-all model with the highest confidence. We use the same hyper-parameters as in the binary classification experiments.

D.5. Double Descent

In Section 7, we evaluate SGD, SWAG and MultiSWAG for models of varying width. Following Nakkiran et al. (2019), we use ResNet-18 on CIFAR-100; we consider the original labels, and 10% and 20% label corruption. For networks of every width, we reuse the hyper-parameters used for PreResNet-164 in Maddox et al. (2019).


Figure 13. Double Descent. (a): Test error and (b): NLL loss for ResNet-18 with varying width on CIFAR-100, for SGD, SWAG and MultiSWAG, when 10% of the labels are randomly reshuffled. MultiSWAG alleviates double descent both on the original labels and under label noise, both in accuracy and NLL. (c): Test NLLs for MultiSWAG with a varying number of independent models under 20% label corruption; the NLL monotonically decreases with an increased number of independent models, alleviating double descent.

For the original labels and 10% label corruption, we use 5 independently trained SWAG models with MultiSWAG, and for 20% label corruption we use 10 models; for 20% label corruption we also show performance varying the number of independent models in Figures 8(e) and 13(c). Both for SWAG and MultiSWAG, we use an ensemble of 20 sampled models from each of the SWAG solutions; for example, for MultiSWAG with 10 independent SWAG solutions, we use an ensemble of 200 networks.

E. Analysis of Prior Variance Effect

In this section we provide simple analytic results for the effect of prior variance in ReLU networks. A related derivation is presented in the Appendix Section A.8 of Garipov et al. (2018), about connecting paths from symmetries in parametrization.

We will consider a multilayer network f(x; w) of the form

f(x; {W_i, b_i}_{i=1}^n) = W_n( ... φ(W_2 φ(W_1 x + b_1) + b_2) ... ) + b_n,

where φ is the ReLU (or in fact any positively-homogeneous activation function), W_i are weight matrices, and b_i are bias vectors. In particular, f can be a regular CNN with ReLU activations up to the logits (with the softmax activation removed).

Now suppose we have a prior distribution of the form

W_i ∼ N(0, α_i² I),   b_i ∼ N(0, β_i² I),

where the identity matrices I are implicitly assumed to be of appropriate shapes, so each weight matrix and bias vector has a spherical Gaussian distribution. We can reparameterize this distribution as

W_i = α_i E_i,   E_i ∼ N(0, I),
b_i = β_i ε_i,   ε_i ∼ N(0, I).

We can then express the predictions of the network on the input x, for weights sampled from the prior, as the random variable

f(x; {α_i, β_i}_{i=1}^n) = α_n · E_n( ... φ(α_1 E_1 x + β_1 · ε_1) ... ) + β_n · ε_n.  (8)

Through Equation (8), we can observe some simple properties of the dependence between the prior scales α_i, β_i and the induced function-space prior.

Proposition 1. Suppose the network has no bias vectors, i.e. β_1 = ... = β_n = 0. Then the scales α_i of the prior distribution over the weights only affect the output scale of the network.

Proof. In the case when there are no bias vectors, Equation (8) simplifies to

f(x; {α_i, β_i = 0}_{i=1}^n) = α_n · E_n( ... φ(α_1 E_1 x) ... )
  = α_n · ... · α_1 · E_n( ... φ(E_1 x) ... )
  = α_n · ... · α_1 · f(x; {α_i = 1, β_i = 0}_{i=1}^n).

In the derivation above we used the positive homogeneity of ReLU: φ(αz) = αφ(z) for any positive α.

In other words, to sample from the distribution over functions corresponding to a prior with variances {α_i, β_i = 0}_{i=1}^n, we can sample from the spherical Gaussian prior (without bias terms) {α_i = 1, β_i = 0}_{i=1}^n, and then rescale the outputs of the network by the product of variances α_n · ... · α_2 · α_1.
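A quick numerical check of Proposition 1 (ours, with arbitrary illustrative layer sizes) is shown below: scaling the weights of each layer of a bias-free ReLU network by α_i rescales the output by the product of the α_i.

    import numpy as np

    rng = np.random.default_rng(0)
    sizes = [5, 20, 20, 3]
    E = [rng.standard_normal((sizes[i + 1], sizes[i])) for i in range(len(sizes) - 1)]
    alphas = [0.5, 2.0, 3.0]

    def forward(weights, x):
        h = x
        for W in weights[:-1]:
            h = np.maximum(W @ h, 0.0)   # ReLU layers, no biases
        return weights[-1] @ h           # linear output layer

    x = rng.standard_normal(sizes[0])
    out_unit = forward(E, x)
    out_scaled = forward([a * W for a, W in zip(alphas, E)], x)
    print(np.allclose(out_scaled, np.prod(alphas) * out_unit))   # True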

We note that the result above is different from the results for sigmoid networks considered in MacKay (1995), where varying the prior on the weights leads to changing the length-scale of the sample functions. For ReLU networks without biases, increasing the prior variance only increases the output scale of the network, and not the complexity of the samples. If we apply the softmax activation on the outputs of the last layer of such a network, we will observe increasingly


confident predictions as we increase the prior variance. We observe this effect in Figure 9 and discuss it in Section 8.4.

In case bias vectors are present, we can obtain a similar result using a specific scaling of the prior variances with layer, as in the following proposition.

Proposition 2. Suppose the prior scales depend on the layer of the network as follows, for some γ > 0:

α_i = γ,   β_i = γ^i,

for all layers i = 1, ..., n. Then γ only affects the scale of the predictive distribution at any input x:

f(x; {α_i = γ, β_i = γ^i}_{i=1}^n) = γ^n · f(x; {α_i = 1, β_i = 1}_{i=1}^n).

Proof. The proof is analogous to the proof of Proposition 1. We can use the positive homogeneity of ReLU activations to factor the prior scales outside of the network:

f(x; {α_i = γ, β_i = γ^i}_{i=1}^n) = γ · E_n( ... φ(γ · E_1 x + γ · ε_1) ... ) + γ^n · ε_n
  = γ^n · ( E_n( ... φ(E_1 x + ε_1) ... ) + ε_n )
  = γ^n · f(x; {α_i = 1, β_i = 1}_{i=1}^n).

The analysis above can be applied to other simple scaling rules of the prior, e.g.

f(x; {α_i = γ · α̂_i, β_i = γ^i · β̂_i}_{i=1}^n) = γ^n · f(x; {α_i = α̂_i, β_i = β̂_i}_{i=1}^n)  (9)

can be shown completely analogously to Proposition 2.

More general types of scaling of the prior affect both the output scale of the network and also the relative effect of prior and variance terms. For example, by Equation (9) we have

f(x; {α_i = γ, β_i = γ}_{i=1}^n) = f(x; {α_i = γ · 1, β_i = γ^i · γ^{1−i}}_{i=1}^n) = γ^n · f(x; {α_i = 1, β_i = γ^{1−i}}_{i=1}^n).

We note that the analysis does not cover residual connections and batch normalization, so it applies to LeNet-5, but cannot be directly applied to the PreResNet-20 networks used in many of our experiments.

F. Prior Correlation Structure under Perturbations

In this section we explore the prior correlations between the logits on different pairs of datapoints induced by a spherical Gaussian prior on the weights of a PreResNet-20. We sample 100 random images from CIFAR-10 (10 from each class), and apply 17 different perturbations introduced by Hendrycks & Dietterich (2019) at 5 different levels of intensity. We then compute correlations between the logits f(x; w) for the original image x and f(x̃; w) for the corrupted image x̃, as we sample the weights of the network from the prior w ∼ N(0, I).

In Figure 19 we show how the correlations decay with perturbation intensity. For reference, we also show how the correlations decay for a linear model and for an RBF kernel. For the RBF kernel, we set the lengthscale so that the average correlations on the uncorrupted datapoints match those of a PreResNet-20. Further experimental details can be found in Appendix D.3.

For all types of corruptions except saturate, snow, fog and brightness, the prior correlations between PreResNet logits decay more slowly than for the RBF kernel and the linear model. It appears that the prior samples are sensitive to corruptions that alter the brightness, or more generally the colours, in the image. For many types of corruptions (such as, e.g., Gaussian noise) the prior correlations for PreResNet are close to 1 for all levels of corruption.

Overall, these results indicate that the prior over functions induced by a vague prior over parameters w in combination with a PreResNet has useful equivariance properties: before seeing data, the model treats images of the same class as highly correlated, even after an image has undergone significant perturbations representative of perturbations we often see in the real world. These types of symmetries are a large part of what makes neural networks a powerful model class for high dimensional natural signals.


Figure 14. Noise Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Gaussian noise, (b) impulse noise, (c) shot noise.


Figure 15. Blur Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) defocus blur, (b) glass blur, (c) motion blur, (d) zoom blur, (e) Gaussian blur.


Figure 16. Digital Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) contrast, (b) saturate, (c) elastic transform, (d) pixelate, (e) JPEG compression.


Figure 17. Weather Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) snow, (b) fog, (c) brightness.

Figure 18. Negative log likelihood, accuracy and expected calibration error distribution on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the corruption intensity. Following Ovadia et al. (2019), we summarize the results for different types of corruption with a boxplot. For each method we use 10 independently trained models, and for MultiSWAG we sample 20 networks from each model. As in Figures 5, 11-14, there are substantial differences between these three methods, which are hard to see due to the vertical scale on this plot; MultiSWAG particularly outperforms Deep Ensembles and MultiSWA in terms of NLL and ECE for higher corruption intensities.


Figure 19. Prior correlations under corruption. Prior correlations between predictions (logits) for PreResNet-20, a linear model, and an RBF kernel on original and corrupted images, as a function of corruption intensity, for different types of corruptions (translation, Gaussian noise, impulse noise, shot noise, defocus blur, glass blur, motion blur, zoom blur, Gaussian blur, contrast, saturate, elastic, pixelate, JPEG, snow, fog, brightness). The lengthscale of the RBF kernel is calibrated to produce similar correlations to PreResNet on uncorrupted datapoints. We report the mean correlation values over 100 different images, and show the 1σ error bars with shaded regions. For all corruptions except snow, saturate, fog and brightness, the correlations decay slower for PreResNet compared to the baselines.



6. Rethinking Generalization

Zhang et al. (2016) demonstrated that deep neural networks have sufficient capacity to fit randomized labels on popular image classification tasks, and suggest this result requires re-thinking generalization to understand deep learning.

We argue, however, that this behaviour is not puzzling from a probabilistic perspective, is not unique to neural networks, and cannot be used as evidence against Bayesian neural networks (BNNs) with vague parameter priors. Fundamentally, the resolution is the view presented in the introduction: from a probabilistic perspective, generalization is at least a two-dimensional concept, related to support (flexibility), which should be as large as possible, supporting even noisy solutions, and inductive biases that represent relative prior probabilities of solutions.

Indeed, we demonstrate that the behaviour in Zhang et al. (2016) that was treated as mysterious and specific to neural networks can be exactly reproduced by Gaussian processes (GPs). Gaussian processes are an ideal choice for this experiment because they are popular Bayesian non-parametric models, and they assign a prior directly in function space. Moreover, GPs have remarkable flexibility, providing universal approximation with popular covariance functions such as the RBF kernel. Yet the functions that are a priori likely under a GP with an RBF kernel are relatively simple. We describe GPs further in the Appendix, and Rasmussen & Williams (2006) provide an extensive introduction.

We start with a simple example in Figure 7 to illustrate the ability of a GP with an RBF kernel to easily fit a corrupted dataset, yet generalize well on a non-corrupted dataset. In Fig. 7(a), we show sample functions from a GP prior over functions p(f(x)), showing that likely functions under the prior are smooth and well-behaved. In Fig. 7(b), we see that the GP is able to reasonably fit data from a structured function. And in Fig. 7(c), the GP is also able to fit highly corrupted data with essentially no structure; although these data are not a likely draw from the prior, the GP has support for a wide range of solutions, including noise.

We next show that GPs can replicate the generalization behaviour described in Zhang et al. (2016) (experimental details in the Appendix). When applied to CIFAR-10 images with random labels, Gaussian processes achieve 100% train accuracy and 10.4% test accuracy (at the level of random guessing). However, the same model trained on the true labels achieves a training accuracy of 72.8% and a test accuracy of 54.3%. Thus the generalization behaviour described in Zhang et al. (2016) is not unique to neural networks, and can be described by separately understanding the support and the inductive biases of a model.

Indeed, although Gaussian processes support CIFAR-10 images with random labels, these data are not likely under the GP prior. In Fig. 7(d), we compute the approximate GP marginal likelihood on a binary CIFAR-10 classification problem, with labels of varying levels of corruption. We see that as the noise in the data increases, the approximate marginal likelihood, and thus the prior support for these data, decreases. In Fig. 7(e), we see a similar trend for a Bayesian neural network: again, as the fraction of corrupted labels increases, the approximate marginal likelihood decreases, showing that the prior over functions given by the Bayesian neural network has less support for these noisy datasets. We provide further experimental details in the Appendix.
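As a rough illustration of this kind of marginal-likelihood diagnostic, the sketch below uses scikit-learn's GP classifier on a small synthetic binary problem; the dataset, kernel settings, and sizes are illustrative assumptions, not the configuration used for the CIFAR-10 experiments above.

import numpy as np
from sklearn.datasets import make_classification
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import RBF

rng = np.random.RandomState(0)
X, y = make_classification(n_samples=400, n_features=20, random_state=0)

for frac in [0.0, 0.2, 0.5, 1.0]:
    y_noisy = y.copy()
    flip = rng.rand(len(y)) < frac
    y_noisy[flip] = rng.randint(0, 2, size=flip.sum())   # corrupt a fraction of the labels
    gp = GaussianProcessClassifier(kernel=1.0 * RBF(length_scale=5.0),
                                   optimizer=None).fit(X, y_noisy)   # keep the kernel fixed
    # Laplace-approximate log marginal likelihood; it decreases as labels become noisier,
    # analogous in spirit to Fig. 7(d).
    print(frac, gp.log_marginal_likelihood_value_)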

Dziugaite & Roy (2017) and Smith & Le (2018) provide complementary perspectives on Zhang et al. (2016) for MNIST: Dziugaite & Roy (2017) show non-vacuous PAC-Bayes bounds for the noise-free binary MNIST, but not noisy MNIST, and Smith & Le (2018) show that logistic regression can fit noisy labels on subsampled MNIST, interpreting the results from an Occam factor perspective.

7. Double Descent

Double descent (e.g., Belkin et al., 2019) describes generalization error that decreases, increases, and then again decreases with increases in model flexibility. The first decrease and then increase is referred to as the classical regime: models with increasing flexibility are increasingly able to capture structure and perform better, until they begin to overfit. The next regime is referred to as the modern interpolating regime. The existence of the interpolation regime has been presented as mysterious generalization behaviour in deep learning.

However, our perspective of generalization suggests that performance should monotonically improve as we increase model flexibility when we use Bayesian model averaging with a reasonable prior. Indeed, in the opening example of Figure 1, we would in principle want to use the most flexible possible model. Our results in Section 5 show that standard BNN priors induce structured and useful priors in the function space, so we should not expect double descent in Bayesian deep learning models that perform reasonable marginalization.

To test this hypothesis, we evaluate MultiSWAG, SWAG, and standard SGD with ResNet-18 models of varying width, following Nakkiran et al. (2019), measuring both error and negative log likelihood (NLL). For the details, see Appendix D. We present the results in Figure 8 and Appendix Figure 13.

First, we observe that models trained with SGD indeed suffer from double descent, especially when the train labels are partially corrupted (see panels 8(c), 8(d)). We also see that SWAG, a unimodal posterior approximation, reduces the extent of double descent.


[Figure 8: panels (a) True Labels (Err), (b) True Labels (NLL), (c) Corrupted (Err), (d) Corrupted (NLL), and (e) Corrupted (# Models); x-axis: ResNet-18 width; curves: SGD, SWAG, MultiSWAG, and in panel (e) MultiSWAG with 1, 3, 5, or 10 independent SWAG models.]

Figure 8. Bayesian model averaging alleviates double descent. (a) Test error and (b) NLL loss for ResNet-18 with varying width on CIFAR-100 for SGD, SWAG, and MultiSWAG. (c) Test error and (d) NLL loss when 20% of the labels are randomly reshuffled. SWAG reduces double descent, and MultiSWAG, which marginalizes over multiple modes, entirely alleviates double descent, both on the original labels and under label noise, both in accuracy and NLL. (e) Test errors for MultiSWAG with a varying number of independent SWAG models: error monotonically decreases with an increased number of independent models, alleviating double descent. We also note that MultiSWAG provides significant improvements in accuracy and NLL over SGD and SWAG models. See Appendix Figure 13 for additional results.

Moreover, MultiSWAG, which performs a more exhaustive multimodal Bayesian model average, completely mitigates double descent: the performance of MultiSWAG solutions increases monotonically with the size of the model, showing no double descent even under significant label corruption, for both accuracy and NLL.
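For intuition, the multimodal model average used here can be sketched as a uniform mixture of single-basin posterior approximations; in the code below, swag.sample() is a placeholder for drawing a network from a fitted SWAG Gaussian (for instance with the code accompanying Maddox et al., 2019), not a specific API.

import torch

def multiswag_predict(swag_models, loader, n_samples=20):
    """Simple Monte Carlo estimate of the Bayesian model average:
    a uniform mixture over independently trained SWAG posteriors."""
    probs, total = 0.0, 0
    for swag in swag_models:                  # one SWAG approximation per basin of attraction
        for _ in range(n_samples):
            net = swag.sample()               # placeholder: draw weights from the SWAG Gaussian
            net.eval()
            with torch.no_grad():
                batch_probs = [torch.softmax(net(x), dim=-1) for x, _ in loader]
            probs = probs + torch.cat(batch_probs)
            total += 1
    return probs / total                      # averaged predictive distribution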

Our results highlight the importance of marginalization over multiple modes of the posterior: under 20% label corruption, SWAG clearly suffers from double descent, while MultiSWAG does not. In Figure 8(e), we show how double descent is alleviated with an increased number of independent modes marginalized in MultiSWAG.

These results also clearly show that MultiSWAG provides significant improvements in accuracy over both SGD and SWAG models, in addition to NLL, an often overlooked advantage of Bayesian model averaging that we discuss in Section 3.1.

Recently, Nakkiran et al. (2020) show that carefully tuned ℓ2 regularization can help mitigate double descent. Alternatively, we show that Bayesian model averaging, particularly based on multimodal marginalization, can mitigate prominent double descent behaviour. The perspective in Sections 1 and 3 predicts this result: models with reasonable priors and effective Bayesian model averaging should monotonically improve with increases in flexibility.

8. Temperature Scaling

The standard Bayesian posterior distribution is given by

p(w | D) = (1/Z) p(D | w) p(w),   (2)

where p(D | w) is a likelihood, p(w) is a prior, and Z is a normalizing constant.

In Bayesian deep learning it is typical to consider the tempered posterior

p_T(w | D) = (1/Z(T)) p(D | w)^{1/T} p(w),   (3)

where T is a temperature parameter, and Z(T) is the normalizing constant corresponding to temperature T. The temperature parameter controls how the prior and likelihood interact in the posterior:

• T < 1 corresponds to cold posteriors, where the posterior distribution is more concentrated around solutions with high likelihood.

• T = 1 corresponds to the standard Bayesian posterior distribution.

• T > 1 corresponds to warm posteriors, where the prior effect is stronger and the posterior concentrates more slowly.
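As a minimal sketch (not the training code used in our experiments), a tempered, unnormalized log-posterior for a network with a Gaussian prior over weights only rescales the likelihood term; here `model` and `nll` are assumed to be a network and its negative log likelihood on the full dataset.

import torch

def tempered_log_posterior(model, nll, T=1.0, prior_std=1.0):
    # log p_T(w | D) = (1/T) log p(D | w) + log p(w) + const,
    # with p(w) = N(0, prior_std^2 I).
    log_lik = -nll
    log_prior = sum(-0.5 * (p ** 2).sum() / prior_std ** 2 for p in model.parameters())
    return log_lik / T + log_prior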

Tempering posteriors is a well-known practice in statistics, where it goes by the names Safe Bayes, generalized Bayesian inference, and fractional Bayesian inference (e.g., de Heide et al., 2019; Grunwald et al., 2017; Barron & Cover, 1991; Walker & Hjort, 2001; Zhang, 2006; Bissiri et al., 2016; Grunwald, 2012). Safe Bayes has been shown to be natural from a variety of perspectives, including from prequential learning theory and minimum description length frameworks (e.g., Grunwald et al., 2017).

Concurrently with our work, Wenzel et al. (2020) noticed that successful Bayesian deep learning methods tend to use cold posteriors. They provide an empirical study showing that Bayesian neural networks (BNNs) with cold posteriors outperform models with SGD-based maximum likelihood training, while BNNs with T = 1 can perform worse than the maximum likelihood solution. They claim that cold posteriors sharply deviate from the Bayesian paradigm, and


[Figure 9: panels (a) α = 0.01, (b) α = 0.1, (c) α = 0.3, (d) α = 1, (e) α = √10 — class probabilities (over classes 0–9) for two prior samples and the prior predictive; panels (f) NLL and (g) classification error versus prior std α for PreResNet-20 and VGG-16, with a weight decay baseline.]

Figure 9. Effects of the prior variance α². (a)–(e) Average class probabilities over all of CIFAR-10 for two sample prior functions p(f(x;w)) (two top rows) and the predictive distribution (average over 200 samples of weights, bottom row), for varying settings of α in p(w) = N(0, α²I). (f) NLL and (g) classification error of an ensemble of 20 SWAG samples on CIFAR-10 as a function of α, using a Preactivation ResNet-20 and VGG-16. The NLL is high for overly small α and near-optimal in the range [0.1, 0.3]. The NLL remains relatively low for vague priors corresponding to large values of α.

[Figure 10: class probabilities (over classes 0–9) for two sampled functions in panels (a) Prior (α = √10), (b) 10 datapoints, (c) 100 datapoints, (d) 1000 datapoints.]

Figure 10. Adaptivity of posterior variance with data. We sample two functions f(x;w) from the distribution over functions induced by a distribution over weights, starting with the prior p(w) = N(0, 10 · I), in combination with a PreResNet-20. We measure class probabilities averaged across the CIFAR-10 test set as we vary the amount of available training data. Although the prior variance is too large, such that the softmax saturates for logits sampled from the prior, leading to one class being favoured, we see that the posterior quickly adapts to correct the scale of the logits in the presence of data. In Figure 9, we also show that the prior variance can easily be calibrated such that the prior predictive distribution, even before observing data, is high entropy.


consider possible reasons for why tempering is helpful in Bayesian deep learning.

In this section we provide an alternative view, and argue that tempering is not at odds with Bayesian principles. Moreover, for virtually any realistic model class and dataset, it would be highly surprising if T = 1 were in fact the best setting of this hyperparameter. Indeed, as long as it is practically convenient, we would advocate tempering for essentially any model, especially parametric models that do not scale their capacity automatically with the amount of available information. Our position is that, at a high level, Bayesian methods are trying to combine honest beliefs with data to form a posterior. By reflecting the belief that the model is misspecified, the tempered posterior is often more of a true posterior than the posterior that results from ignoring our belief that the model is misspecified.

Finding that T < 1 helps for Bayesian neural networks is neither surprising nor discouraging. And the actual results of the experiments in Wenzel et al. (2020), which show great improvements over standard SGD training, are in fact very encouraging of deriving inspiration from Bayesian procedures in deep learning.

We consider: (1) tempering under misspecification (Section 8.1); (2) tempering in terms of overcounting data (Section 8.2); (3) how tempering compares to changing the observation model (Section 8.3); (4) the effect of the prior in relation to the experiments of Wenzel et al. (2020) (Section 8.4); (5) the effect of approximate inference, including how tempering can help in efficiently estimating parameters even for the untempered posterior (Section 8.5).

8.1. Tempering Helps with Misspecified Models

Many works explain how tempered posteriors help under model misspecification (e.g., de Heide et al., 2019; Grunwald et al., 2017; Barron & Cover, 1991; Walker & Hjort, 2001; Zhang, 2006; Bissiri et al., 2016; Grunwald, 2012). In fact, de Heide et al. (2019) and Grunwald et al. (2017) provide several simple examples where Bayesian inference fails to provide good convergence behaviour for untempered posteriors. While it is easier to show theoretical results for T > 1, several of these works also show that T < 1 can be preferred, even in well-specified settings, and indeed recommend learning T from data, for example by cross-validation (e.g., Grunwald, 2012; de Heide et al., 2019).

Are we in a misspecified setting for Bayesian neural networks? Of course. And it would be irrational to proceed as if it were otherwise. Every model is misspecified. In the context of Bayesian neural networks specifically, the mass of solutions expressed by the prior outside of the datasets we typically consider is likely much larger than desired for most applications. We can calibrate for this discrepancy through tempering. The resulting tempered posterior will be more in line with our beliefs than pretending the model is not misspecified and finding the untempered posterior.

Non-parametric models, such as Gaussian processes, attempt to side-step model misspecification by growing the number of free parameters (information capacity) automatically with the amount of available data. In parametric models, we take much more of a manual guess about the model capacity. In the case of deep neural networks, this choice is not even close to a best guess: it was once the case that architectural design was a large component of works involving neural networks, but now it is more standard practice to choose an off-the-shelf architecture, without much consideration of model capacity. We do not believe that knowingly using a misspecified model to find a posterior is more reasonable (or Bayesian) than honestly reflecting the belief that the model is misspecified and then using a tempered posterior. For parametric models such as neural networks, it is to be expected that the capacity is particularly misspecified.

8.2. Overcounting Data with Cold Posteriors

The criticism of cold posteriors raised by Wenzel et al. (2020) is largely based on the fact that decreasing temperature leads to overcounting data in the posterior distribution.

However, a similar argument can be made against marginal likelihood maximization (also known as empirical Bayes or type-2 maximum likelihood). Indeed, here the prior will depend on the same data as the likelihood, which can lead to miscalibrated predictive distributions (Darnieder, 2011).

Nonetheless, empirical Bayes has been embraced and widely adopted in Bayesian machine learning (e.g., Bishop, 2006; Rasmussen & Williams, 2006; MacKay, 2003; Minka, 2001b) as embodying several Bayesian principles. Empirical Bayes has been particularly embraced in seminal work on Bayesian neural networks (e.g., MacKay, 1992; 1995), where it has been proposed as a principled approach to learning hyperparameters, such as the scale of the variance for the prior over weights, automatically embodying Occam's razor. While there is in this case some deviation from the fully Bayesian paradigm, the procedure, which depends on marginalization, is nonetheless clearly inspired by Bayesian thinking, and it is thus helpful to reflect this inspiration and provide understanding of how it works from a Bayesian perspective.

There is also work showing that the marginal likelihood can lead to miscalibrated Bayes factors under model misspecification. Attempts to calibrate these factors (Xu et al., 2019), as part of the Bayesian paradigm, are highly reminiscent of work on safe Bayes.


8.3. Tempered Posterior or Different Likelihood

The tempered posterior for one model is an untempered posterior using a different observation model. In other words, we can trivially change the likelihood function in our model so that the standard Bayesian posterior in the new model is equal to the posterior of temperature T in the original model. Indeed, consider the likelihood function

p_T(D | w) ∝ p(D | w)^{1/T},   (4)

where the posterior distribution for the model M_T with likelihood p_T is exactly the temperature-T posterior for the model M with likelihood p.

The predictive distribution differs for the two models: even though the posteriors coincide, the likelihoods for a new datapoint y* are different:

∫ p(y* | w) p(w | D) dw ≠ ∫ p_T(y* | w) p(w | D) dw.   (5)

As an example, consider a regression model M with a Gaussian likelihood function, y ∼ N(f, σ²), where f is the output of the network and σ² is the noise variance. The predictive distributions for the two models M and M_T for a new input x will have different variance, but the same mean: E_w[f + ε] = E_w[f]. Moreover, in this case the noise variance would typically be learned in either model, M_T or M.
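To make the variance claim explicit for this Gaussian likelihood, a short derivation (treating σ² as fixed):

p(y \mid w)^{1/T} = \mathcal{N}(y \mid f, \sigma^2)^{1/T} \propto \exp\!\left(-\frac{(y - f)^2}{2\,T\sigma^2}\right) \propto \mathcal{N}(y \mid f, T\sigma^2),

so tempering the Gaussian likelihood with temperature T is equivalent to replacing the noise variance σ² with Tσ²; a cold posterior (T < 1) corresponds to assuming less observation noise, which changes the predictive variance but not the predictive mean.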

Section 4.1 of Grunwald et al. (2017) considers a related construction.

8.4. Effect of the Prior

While a somewhat misspecified prior will certainly interact with the utility of tempering, we do not believe the experiments in Wenzel et al. (2020) provide evidence that even the prior p(w) = N(0, I) is misspecified to any serious extent. For a relatively wide range of distributions over w, the functional form of the network f(x;w) can produce a generally reasonable distribution over functions p(f(x;w)). In Figure 10, we reproduce the findings in Wenzel et al. (2020), showing that sample functions p(f(x;w)) corresponding to the prior p(w) = N(0, 10 · I) strongly favour a single class over the dataset. While this behaviour appears superficially dramatic, we note it is simply an artifact of having a miscalibrated signal variance: a miscalibrated signal variance interacts with a quickly saturating softmax link function to provide a seemingly dramatic preference for a given class. If we instead use p(w) = N(0, α²I), for quite a range of α, then sample functions provide reasonably high entropy across labels, as in Figure 9. The value of α can easily be determined through cross-validation, as in Figure 9, or specified as a standard value used for L2 regularization (α = 0.24 in this case).
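The effect of the prior scale can be probed in a few lines of code; the sketch below uses a small MLP and Gaussian inputs as stand-ins for PreResNet-20 and CIFAR-10, so all sizes and the architecture here are illustrative assumptions rather than our experimental setup.

import torch
import torch.nn as nn

def prior_sample_class_probs(alpha, n_weight_samples=20, n_inputs=256, d=3072, n_classes=10):
    # For each prior sample w ~ N(0, alpha^2 I), average the softmax output over a batch of inputs.
    x = torch.randn(n_inputs, d)
    rows = []
    with torch.no_grad():
        for _ in range(n_weight_samples):
            net = nn.Sequential(nn.Linear(d, 200), nn.ReLU(), nn.Linear(200, n_classes))
            for p in net.parameters():
                p.normal_(0.0, alpha)                     # sample parameters from the prior
            rows.append(torch.softmax(net(x), dim=-1).mean(dim=0))
    # With large alpha the softmax saturates and individual rows move far from uniform,
    # while their average (the prior predictive) stays much closer to uniform.
    return torch.stack(rows)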

However, even with the inappropriate prior scale, we see in the bottom row of panels (a)–(e) of Figure 9 that the unconditional predictive distribution is completely reasonable. Moreover, the prior variance represents a soft prior bias, and will quickly update in the presence of data. In Figure 10, we show the posterior samples after observing 10, 100, and 1000 data points.

Other aspects of the prior, outside of the prior signal variance, will have a much greater effect on the inductive biases of the model. For example, the induced covariance function cov(f(x_i; w), f(x_j; w)) reflects the induced similarity metric over data instances; through the covariance function we can answer, for instance, whether the model believes a priori that a translated image is similar to the original. Unlike the signal variance of the prior, the prior covariance function will continue to have a significant effect on posterior inference even for very large datasets, and strongly reflects the structural properties of the neural network. We explore these structures of the prior in Figure 12.

8.5. The Effect of Inexact Inference

We have to keep in mind what we ultimately use posterior samples to compute. Ultimately, we wish to estimate the predictive distribution given by the integral in Equation (1). With a finite number of samples, the tempered posterior could be used to provide a better approximation to the expectation of the predictive distribution associated with the untempered posterior.

Consider a simple example, where we wish to estimate the mean of a high-dimensional Gaussian distribution N(0, I). Suppose we use J independent samples. The mean of these samples is also Gaussian distributed, μ ∼ N(0, (1/J) I). In Bayesian deep learning, the dimension d is typically on the order of 10^7 and J would be on the order of 10. The norm of μ would be highly concentrated around √(10^7)/√10 = 1000. In this case, sampling from a tempered posterior with T < 1 would lead to a better approximation of the Bayesian model average associated with an untempered posterior.
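A quick numerical check of this argument (a sketch; the dimensions simply mirror the orders of magnitude quoted above):

import numpy as np

d, J = 10_000_000, 10                     # ~10^7 parameters, ~10 posterior samples
rng = np.random.default_rng(0)
mu = np.zeros(d)
for _ in range(J):                        # sample mean of J draws from N(0, I)
    mu += rng.standard_normal(d)
mu /= J
print(np.linalg.norm(mu))                 # concentrates near sqrt(d / J) = 1000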

Furthermore, no current sampling procedure will be providing samples that are close to independent samples from the true posterior of a Bayesian neural network. The posterior landscape is far too multimodal and complex for there to be any reasonable coverage. The approximations we have are practically useful, and often preferable to conventional training, but we cannot realistically proceed with analysis assuming that we have obtained true samples from a posterior. While we would expect that some value of T ≠ 1 would be preferred for any finite dataset in practice, it is conceivable that some of the results in Wenzel et al. (2020) may be affected by the specifics of the approximate inference technique being used.


We should be wary not to view Bayesian model averaging purely through the prism of simple Monte Carlo, as advised in Section 3.2. Given a finite computational budget, our goal in effectively approximating a Bayesian model average is not equivalent to obtaining good samples from the posterior.

9. Discussion

"It is now common practice for Bayesians to fit models that have more parameters than the number of data points. Incorporate every imaginable possibility into the model space: for example, if it is conceivable that a very simple model might be able to explain the data, one should include simple models; if the noise might have a long-tailed distribution, one should include a hyperparameter which controls the heaviness of the tails of the distribution; if an input variable might be irrelevant to a regression, include it in the regression anyway." MacKay (1995)

We have presented a probabilistic perspective of generalization, which depends on the support and inductive biases of the model. The support should be as large as possible, but the inductive biases must be well-calibrated to a given problem class. We argue that Bayesian neural networks embody these properties, and, through the lens of probabilistic inference, explain generalization behaviour that has previously been viewed as mysterious. Moreover, we argue that Bayesian marginalization is particularly compelling for neural networks, show how deep ensembles provide a practical mechanism for marginalization, and propose a new approach that generalizes deep ensembles to marginalize within basins of attraction. We show that this multimodal approach to Bayesian model averaging, MultiSWAG, can entirely alleviate double descent, to enable monotonic performance improvements with increases in model flexibility, as well as significant improvements in generalization accuracy and log likelihood over SGD and single-basin marginalization.

There are certainly many challenges to estimating the integral for a Bayesian model average in modern deep learning, including a high-dimensional parameter space and a complex posterior landscape. But viewing the challenge indeed as an integration problem, rather than an attempt to obtain posterior samples for a simple Monte Carlo approximation, provides opportunities for future progress. Bayesian deep learning has been making fast practical advances, with approaches that now enable better accuracy and calibration over standard training, with minimal overhead.

We finish with remarks about future developments for Bayesian neural network priors, and approaches to research in Bayesian deep learning.

9.1. The Future for BNN Priors

We provide some brief remarks about future developments for BNN priors. Here we have explored relatively simple parameter priors p(w) = N(0, α²I). While these priors are simple in parameter space, they interact with the neural network architecture to induce a sophisticated prior over functions p(f(x;w)), with many desirable properties, including a reasonable correlation structure over images. However, these parameter priors can certainly still be improved. As we have seen, even tuning the value of the signal variance α², an analogue of the L2 regularization often used in deep learning, can have a noticeable effect on the induced prior over functions, though this effect is quickly modulated by data. Layer-wise priors, such that parameters in each layer have a different signal variance, are intuitive: we would expect later layers to require precise determination, while parameters in earlier layers could reasonably take a range of values. But one has to be cautious: as we show in Appendix Section E, with ReLU activations, different signal variances in different layers can be degenerate, combining together to affect only the output scale of the network.

A currently popular sentiment is that we should directly build function-space BNN priors, often taking inspiration from Gaussian processes. While we believe this is a promising direction, one should proceed with caution. If we contrive priors over parameters p(w) to induce distributions over functions p(f) that resemble familiar models, such as Gaussian processes with RBF kernels, we could be throwing the baby out with the bathwater. Neural networks are useful as their own model class precisely because they have different inductive biases from other models.

A similar concern applies to taking infinite-width limits in Bayesian neural networks. In these cases, we recover Gaussian processes with interpretable kernel functions; because these models are easier to use and analyze, and give rise to interpretable and well-motivated priors, it is tempting to treat them as drop-in replacements for their parametric analogues. However, the kernels for these models are fixed. In order for a model to do effective representation learning, we must learn a similarity metric for the data. Training a neural network in many ways is like learning a kernel, rather than using a fixed kernel. MacKay (1998) has also expressed concerns about treating these limits as replacements for neural networks, due to the loss of representation learning power.

Perhaps the distribution over functions induced by a network in combination with a generic distribution over parameters p(w) may be hard to interpret, but this distribution will contain the equivariance properties, representation learning abilities, and other biases that make neural networks a compelling model class in their own right.


9.2. "But is it really Bayesian?"

We finish with an editorial comment about approaches to research within Bayesian deep learning. There is sometimes a tendency to classify work as Bayesian or not Bayesian, with very stringent criteria for what qualifies as Bayesian. Moreover, the implication, and sometimes even explicit recommendation, is that if an approach is not unequivocally Bayesian in every respect, then we should not term it Bayesian, and we should instead attempt to understand the procedure through entirely different, non-Bayesian mechanisms. We believe this mentality encourages tribalism, which is not conducive to the best research or to creating the best performing methods. What matters is not a debate about semantics, but making rational modelling choices given a particular problem setting, and trying to understand these choices. Often these choices can largely be inspired by a Bayesian approach, in which case it is desirable to indicate this source of inspiration. And in the semantics debate, who would be the arbiter of what gets to be called Bayesian? Arguably it ought to be an evolving definition.

Broadly speaking, what makes Bayesian approaches distinctive is a posterior-weighted marginalization over parameters. And at a high level, Bayesian methods are about combining our honest beliefs with data to form a posterior. In actuality, no fair-minded researcher entirely believes the prior over parameters, the functional form of the model (which is part of the prior over functions), or the likelihood. From this perspective, it is broadly compatible with a Bayesian philosophy to reflect misspecification in the modelling procedure itself, which is achieved through tempering. In this sense, the tempered posterior is more reflective of a true posterior than the posterior that results from ignoring our belief that the model is misspecified.

Moreover, basic probability theory indicates that marginalization is desirable. While marginalization cannot in practice be achieved exactly, we can try to improve over conventional training, which, as we have discussed, can be viewed as approximate marginalization. Given computational constraints, effective marginalization is not equivalent to obtaining accurate samples from a posterior. As we have discussed, simple Monte Carlo is only one of many mechanisms for marginalization. Just as expectation propagation (Minka, 2001a) focuses its approximation on the factors in a posterior where it will most affect the end result, we should focus on representing the posterior where it will make the biggest difference to the model average. As we have shown, deep ensembles are a reasonable mechanism up to a point. After having trained many independent models, there are added benefits to marginalizing within basins, given the computational expense associated with retraining an additional model to find an additional basin of attraction.

We should also not hold Bayesian methods to a double standard. Indeed, it can be hard to interpret or understand the prior, the posterior, and whether the marginalization procedure is optimal. But it is also hard to interpret the choices behind the functional form of the model, or the rationale behind classical procedures where we bet everything on a single global optimum, when we know there are many global optima, many of them will perform well but provide different solutions, and many others will not perform well. We should apply the same level of scrutiny to all modelling choices, consider the alternatives, and not be paralyzed if a procedure is not optimal in every respect.

Acknowledgements. AGW and PI are supported by an Amazon Research Award, Facebook Research, NSF I-DISRE 193471, NIH R01 DA048764-01A1, NSF IIS-1563887, and NSF IIS-1910266. We thank Greg Benton for helpful discussions.

References

Achille, A. and Soatto, S. Emergence of invariance and disentanglement in deep representations. The Journal of Machine Learning Research, 19(1):1947–1980, 2018.
Ashukha, A., Lyzhov, A., Molchanov, D., and Vetrov, D. Pitfalls of in-domain uncertainty estimation and ensembling in deep learning. arXiv preprint arXiv:2002.06470, 2020.
Barron, A. R. and Cover, T. M. Minimum complexity density estimation. IEEE Transactions on Information Theory, 37(4):1034–1054, 1991.
Beal, M. J. Variational algorithms for approximate Bayesian inference. PhD thesis, University of London, 2003.
Belkin, M., Hsu, D., Ma, S., and Mandal, S. Reconciling modern machine-learning practice and the classical bias–variance trade-off. Proceedings of the National Academy of Sciences, 116(32):15849–15854, 2019.
Bishop, C. M. Pattern Recognition and Machine Learning. Springer, 2006.
Bissiri, P. G., Holmes, C. C., and Walker, S. G. A general framework for updating belief distributions. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 78(5):1103–1130, 2016.
Blundell, C., Cornebise, J., Kavukcuoglu, K., and Wierstra, D. Weight uncertainty in neural networks. arXiv preprint arXiv:1505.05424, 2015.
Box, G. E. and Tiao, G. C. Bayesian Inference in Statistical Analysis. Addison-Wesley, Reading, MA, 1973.


Chaudhari, P., Choromanska, A., Soatto, S., LeCun, Y., Baldassi, C., Borgs, C., Chayes, J., Sagun, L., and Zecchina, R. Entropy-SGD: Biasing gradient descent into wide valleys. Journal of Statistical Mechanics: Theory and Experiment, 2019.
Cobb, A. D., Baydin, A. G., Markham, A., and Roberts, S. J. Introducing an explicit symplectic integration scheme for Riemannian manifold Hamiltonian Monte Carlo. arXiv preprint arXiv:1910.06243, 2019.
Darnieder, W. F. Bayesian methods for data-dependent priors. PhD thesis, The Ohio State University, 2011.
de Heide, R., Kirichenko, A., Mehta, N., and Grunwald, P. Safe-Bayesian generalized linear regression. arXiv preprint arXiv:1910.09227, 2019.
Dinh, L., Pascanu, R., Bengio, S., and Bengio, Y. Sharp minima can generalize for deep nets. arXiv preprint arXiv:1703.04933, 2017.
Draxler, F., Veschgini, K., Salmhofer, M., and Hamprecht, F. A. Essentially no barriers in neural network energy landscape. arXiv preprint arXiv:1803.00885, 2018.
Dziugaite, G. K. and Roy, D. M. Computing nonvacuous generalization bounds for deep (stochastic) neural networks with many more parameters than training data. arXiv preprint arXiv:1703.11008, 2017.
Fort, S. and Jastrzebski, S. Large scale structure of neural network loss landscapes. In Advances in Neural Information Processing Systems, pp. 6706–6714, 2019.
Fort, S., Hu, H., and Lakshminarayanan, B. Deep ensembles: A loss landscape perspective. arXiv preprint arXiv:1912.02757, 2019.
Gal, Y. and Ghahramani, Z. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. In International Conference on Machine Learning, pp. 1050–1059, 2016.
Gardner, J. R., Pleiss, G., Bindel, D., Weinberger, K. Q., and Wilson, A. G. GPyTorch: Blackbox matrix-matrix Gaussian process inference with GPU acceleration. In Neural Information Processing Systems, 2018.
Garipov, T., Izmailov, P., Podoprikhin, D., Vetrov, D. P., and Wilson, A. G. Loss surfaces, mode connectivity, and fast ensembling of DNNs. In Neural Information Processing Systems, 2018.
Grunwald, P. The safe Bayesian. In International Conference on Algorithmic Learning Theory, pp. 169–183. Springer, 2012.
Grunwald, P., Van Ommen, T., et al. Inconsistency of Bayesian inference for misspecified linear models, and a proposal for repairing it. Bayesian Analysis, 12(4):1069–1103, 2017.
Guedj, B. A primer on PAC-Bayesian learning. arXiv preprint arXiv:1901.05353, 2019.
Guo, C., Pleiss, G., Sun, Y., and Weinberger, K. Q. On calibration of modern neural networks. In Proceedings of the 34th International Conference on Machine Learning - Volume 70, pp. 1321–1330. JMLR.org, 2017.
Gustafsson, F. K., Danelljan, M., and Schon, T. B. Evaluating scalable Bayesian deep learning methods for robust computer vision. arXiv preprint arXiv:1906.01620, 2019.
Hafner, D., Tran, D., Irpan, A., Lillicrap, T., and Davidson, J. Reliable uncertainty estimates in deep neural networks using noise contrastive priors. arXiv preprint arXiv:1807.09289, 2018.
He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778, 2016.
Hendrycks, D. and Dietterich, T. Benchmarking neural network robustness to common corruptions and perturbations. arXiv preprint arXiv:1903.12261, 2019.
Hinton, G. E. and Van Camp, D. Keeping the neural networks simple by minimizing the description length of the weights. In Proceedings of the Sixth Annual Conference on Computational Learning Theory, pp. 5–13, 1993.
Hochreiter, S. and Schmidhuber, J. Flat minima. Neural Computation, 9(1):1–42, 1997.
Huang, W. R., Emam, Z., Goldblum, M., Fowl, L., Terry, J. K., Huang, F., and Goldstein, T. Understanding generalization through visualizations. arXiv preprint arXiv:1906.03291, 2019.
Izmailov, P., Podoprikhin, D., Garipov, T., Vetrov, D., and Wilson, A. G. Averaging weights leads to wider optima and better generalization. In Uncertainty in Artificial Intelligence (UAI), 2018.
Izmailov, P., Maddox, W. J., Kirichenko, P., Garipov, T., Vetrov, D., and Wilson, A. G. Subspace inference for Bayesian deep learning. In Uncertainty in Artificial Intelligence, 2019.
Jacot, A., Gabriel, F., and Hongler, C. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in Neural Information Processing Systems, pp. 8571–8580, 2018.


Jiang, Y., Neyshabur, B., Mobahi, H., Krishnan, D., and Bengio, S. Fantastic generalization measures and where to find them. arXiv preprint arXiv:1912.02178, 2019.
Kass, R. E. and Raftery, A. E. Bayes factors. Journal of the American Statistical Association, 90(430):773–795, 1995.
Kendall, A. and Gal, Y. What uncertainties do we need in Bayesian deep learning for computer vision? In Advances in Neural Information Processing Systems, pp. 5574–5584, 2017.
Keskar, N. S., Mudigere, D., Nocedal, J., Smelyanskiy, M., and Tang, P. T. P. On large-batch training for deep learning: Generalization gap and sharp minima. arXiv preprint arXiv:1609.04836, 2016.
Khan, M. E., Nielsen, D., Tangkaratt, V., Lin, W., Gal, Y., and Srivastava, A. Fast and scalable Bayesian deep learning by weight-perturbation in Adam. arXiv preprint arXiv:1806.04854, 2018.
Krizhevsky, A., Nair, V., and Hinton, G. The CIFAR-10 dataset, 2014. http://www.cs.toronto.edu/kriz/cifar.html.
Lakshminarayanan, B., Pritzel, A., and Blundell, C. Simple and scalable predictive uncertainty estimation using deep ensembles. In Advances in Neural Information Processing Systems, pp. 6402–6413, 2017.
Langford, J. and Caruana, R. (Not) bounding the true error. In Advances in Neural Information Processing Systems, pp. 809–816, 2002.
LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.
Louizos, C., Shi, X., Schutte, K., and Welling, M. The functional neural process. In Advances in Neural Information Processing Systems, 2019.
MacKay, D. J. Bayesian methods for adaptive models. PhD thesis, California Institute of Technology, 1992.
MacKay, D. J. Probable networks and plausible predictions – a review of practical Bayesian methods for supervised neural networks. Network: Computation in Neural Systems, 6(3):469–505, 1995.
MacKay, D. J. Introduction to Gaussian processes. In Bishop, C. M. (ed.), Neural Networks and Machine Learning, chapter 11, pp. 133–165. Springer-Verlag, 1998.
MacKay, D. J. Information Theory, Inference and Learning Algorithms. Cambridge University Press, 2003.
Maddox, W. J., Izmailov, P., Garipov, T., Vetrov, D. P., and Wilson, A. G. A simple baseline for Bayesian uncertainty in deep learning. In Advances in Neural Information Processing Systems, 2019.
Masegosa, A. R. Learning under model misspecification: Applications to variational and ensemble methods, 2019.
McAllester, D. A. PAC-Bayesian model averaging. In Proceedings of the Twelfth Annual Conference on Computational Learning Theory, pp. 164–170, 1999.
Minka, T. Expectation propagation for approximate Bayesian inference. In Uncertainty in Artificial Intelligence, volume 17, pp. 362–369, 2001a.
Minka, T. P. Bayesian model averaging is not model combination, 2000.
Minka, T. P. Automatic choice of dimensionality for PCA. In Advances in Neural Information Processing Systems, pp. 598–604, 2001b.
Mohri, M. and Rostamizadeh, A. Rademacher complexity bounds for non-i.i.d. processes. In Advances in Neural Information Processing Systems, pp. 1097–1104, 2009.
Nakkiran, P., Kaplun, G., Bansal, Y., Yang, T., Barak, B., and Sutskever, I. Deep double descent: Where bigger models and more data hurt. arXiv preprint arXiv:1912.02292, 2019.
Nakkiran, P., Venkat, P., Kakade, S., and Ma, T. Optimal regularization can mitigate double descent. arXiv preprint arXiv:2003.01897, 2020.
Nalisnick, E. On priors for Bayesian neural networks. PhD thesis, University of California, Irvine, 2018.
Neal, R. Bayesian Learning for Neural Networks. Springer Verlag, 1996. ISBN 0387947248.
Neyshabur, B., Bhojanapalli, S., McAllester, D., and Srebro, N. Exploring generalization in deep learning. In Advances in Neural Information Processing Systems, pp. 5947–5956, 2017.
Neyshabur, B., Bhojanapalli, S., and Srebro, N. A PAC-Bayesian approach to spectrally-normalized margin bounds for neural networks. In International Conference on Learning Representations, 2018. URL https://openreview.net/forum?id=Skz_WfbCZ.
Opper, M., Kinzel, W., Kleinz, J., and Nehl, R. On the ability of the optimal perceptron to generalise. Journal of Physics A: Mathematical and General, 23(11):L581, 1990.


Ovadia, Y., Fertig, E., Ren, J., Nado, Z., Sculley, D., Nowozin, S., Dillon, J. V., Lakshminarayanan, B., and Snoek, J. Can you trust your model's uncertainty? Evaluating predictive uncertainty under dataset shift. arXiv preprint arXiv:1906.02530, 2019.
Pearce, T., Zaki, M., Brintrup, A., Anastassacos, N., and Neely, A. Uncertainty in neural networks: Bayesian ensembling. arXiv preprint arXiv:1810.05546, 2018.
Rasmussen, C. E. and Ghahramani, Z. Occam's razor. In Neural Information Processing Systems (NIPS), 2001.
Rasmussen, C. E. and Williams, C. K. I. Gaussian Processes for Machine Learning. The MIT Press, 2006.
Ritter, H., Botev, A., and Barber, D. A scalable Laplace approximation for neural networks. In International Conference on Learning Representations (ICLR), 2018.
Rue, H., Martino, S., and Chopin, N. Approximate Bayesian inference for latent Gaussian models by using integrated nested Laplace approximations. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 71(2):319–392, 2009.
Simonyan, K. and Zisserman, A. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.
Smith, S. L. and Le, Q. V. A Bayesian perspective on generalization and stochastic gradient descent. In International Conference on Learning Representations, 2018.
Sun, S., Zhang, G., Shi, J., and Grosse, R. Functional variational Bayesian neural networks. arXiv preprint arXiv:1903.05779, 2019.
Ulyanov, D., Vedaldi, A., and Lempitsky, V. Deep image prior. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 9446–9454, 2018.
Vapnik, V. N. Statistical Learning Theory. Adaptive and Learning Systems for Signal Processing, Communications, and Control, 1998.
Virtanen, P., Gommers, R., Oliphant, T. E., Haberland, M., Reddy, T., Cournapeau, D., Burovski, E., Peterson, P., Weckesser, W., Bright, J., et al. SciPy 1.0: Fundamental algorithms for scientific computing in Python. Nature Methods, pp. 1–12, 2020.
Walker, S. and Hjort, N. L. On Bayesian consistency. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 63(4):811–821, 2001.
Welling, M. and Teh, Y. W. Bayesian learning via stochastic gradient Langevin dynamics. In Proceedings of the 28th International Conference on Machine Learning (ICML-11), pp. 681–688, 2011.
Wenzel, F., Roth, K., Veeling, B. S., Swiatkowski, J., Tran, L., Mandt, S., Snoek, J., Salimans, T., Jenatton, R., and Nowozin, S. How good is the Bayes posterior in deep neural networks really? arXiv preprint arXiv:2002.02405, 2020.
Wilson, A. G. The case for Bayesian deep learning. arXiv preprint arXiv:2001.10995, 2020.
Xu, X., Lu, P., MacEachern, S., and Xu, R. Calibrated Bayes factors for model comparison. Journal of Statistical Computation and Simulation, 89(4):591–614, 2019.
Yang, W., Lorch, L., Graule, M. A., Srinivasan, S., Suresh, A., Yao, J., Pradier, M. F., and Doshi-Velez, F. Output-constrained Bayesian neural networks. arXiv preprint arXiv:1905.06287, 2019.
Zhang, C., Bengio, S., Hardt, M., Recht, B., and Vinyals, O. Understanding deep learning requires rethinking generalization. arXiv preprint arXiv:1611.03530, 2016.
Zhang, T. Information-theoretic upper and lower bounds for statistical estimation. IEEE Transactions on Information Theory, 52(4):1307–1321, 2006.


Appendix Outline

This appendix is organized as follows. In Section A, we visualize predictive functions corresponding to weight samples within high posterior density valleys on a regression problem. In Section B, we provide background material on Gaussian processes. In Section C, we present further results comparing MultiSWAG and MultiSWA to Deep Ensembles under data distribution shift on CIFAR-10. In Section D, we provide the details of all experiments presented in the paper. In Section E, we present analytic results on the dependence of the prior distribution in function space on the variance of the prior over parameters. In Section F, we study the prior correlations between BNN logits on perturbed images.

A. Loss Valleys

We demonstrate that different points along the valleys of high posterior density (low loss) connecting pairs of independently trained optima (Garipov et al., 2018; Draxler et al., 2018; Fort & Jastrzebski, 2019) correspond to different predictive functions. We use the regression example from Izmailov et al. (2019) and show the results in Figure 11.

B. Gaussian processes

With a Bayesian neural network, a distribution over parameters p(w) induces a distribution over functions p(f(x;w)) when combined with the functional form of the network. Gaussian processes (GPs) are often used to instead directly specify a distribution over functions.

A Gaussian process is a distribution over functions, f(x) ∼ GP(m, k), such that any collection of function values, queried at any finite set of inputs x_1, ..., x_n, has a joint Gaussian distribution:

f(x_1), ..., f(x_n) ∼ N(μ, K).   (6)

The mean vector, μ_i = E[f(x_i)] = m(x_i), and covariance matrix, K_ij = cov(f(x_i), f(x_j)) = k(x_i, x_j), are determined by the mean function m and covariance function (or kernel) k of the Gaussian process.

The popular RBF kernel has the form

k(x_i, x_j) = exp(−‖x_i − x_j‖² / (2ℓ²)).   (7)

The length-scale hyperparameter ℓ controls the extent of correlations between function values. If ℓ is large, sample functions from a GP prior are simple and slowly varying with inputs x.
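For illustration, prior functions can be drawn directly from this kernel; the following small sketch is self-contained and not tied to any experiment in the paper.

import numpy as np

def rbf_kernel(x1, x2, lengthscale=1.0):
    # k(x, x') = exp(-(x - x')^2 / (2 * lengthscale^2)), a 1D instance of Equation (7)
    diff = x1[:, None] - x2[None, :]
    return np.exp(-0.5 * diff ** 2 / lengthscale ** 2)

x = np.linspace(-5.0, 5.0, 200)
rng = np.random.default_rng(0)
for ell in (0.5, 2.0):
    K = rbf_kernel(x, x, lengthscale=ell) + 1e-8 * np.eye(len(x))   # jitter for stability
    samples = rng.multivariate_normal(np.zeros(len(x)), K, size=3)  # draws from the GP prior
    # samples drawn with ell = 2.0 vary much more slowly with x than those with ell = 0.5
    print(ell, np.abs(np.diff(samples, axis=1)).mean())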

Gaussian processes with RBF kernels (as well as many other standard kernels) assign positive density to any set of observations. Moreover, these models are universal approximators (Rasmussen & Williams, 2006): as the number of observations increases, they are able to approximate any function to arbitrary precision.

Work on Gaussian processes in machine learning was triggered by the observation that Bayesian neural networks become Gaussian processes with particular kernel functions as the number of hidden units approaches infinity (Neal, 1996). This result resembles recent work on the neural tangent kernel (e.g., Jacot et al., 2018).

C. Deep Ensembles and MultiSWAG Under Distribution Shift

In Figures 14, 15, 16, 17, we show the negative log-likelihood for Deep Ensembles, MultiSWA, and MultiSWAG using PreResNet-20 on CIFAR-10 with various corruptions, as a function of the number of independently trained models (SGD solutions, SWA solutions, or SWAG models, respectively). For MultiSWAG, we generate 20 samples from each independent SWAG model. Typically, MultiSWA and MultiSWAG significantly outperform Deep Ensembles when a small number of independent models is used, or when the level of corruption is high.

In Figure 18, following Ovadia et al. (2019), we show the distribution of negative log likelihood, accuracy, and expected calibration error as we vary the type of corruption. We use a fixed training time budget: 10 independently trained models for every method. For MultiSWAG, we ensemble 20 samples from each of the 10 SWAG approximations. MultiSWAG particularly achieves better NLL than the other two methods, and MultiSWA outperforms Deep Ensembles; the difference is especially pronounced for higher levels of corruption. In terms of ECE, MultiSWAG again outperforms the other two methods for higher corruption intensities.

We note that Ovadia et al. (2019) found Deep Ensembles to be a very strong baseline for prediction quality and calibration under distribution shift. For this reason, we focus on Deep Ensembles in our comparisons.

D. Details of Experiments

In this section we provide additional details of the experiments presented in the paper.

D.1. Approximating the True Predictive Distribution

For the results presented in Figure 4, we used a network with 3 hidden layers of size 10 each. The network takes two inputs: x and x². We pass both x and x² as input to ensure



Figure 11. Diversity of high performing functions. Bottom: a contour plot of the posterior log-density in the subspace containing a pair of independently trained modes (as with deep ensembles) and a path of high posterior density connecting these modes. In each panel, the purple point represents a sample from the posterior in the parameter subspace. Top: the predictive distribution constructed from samples in the subspace. The shaded blue area shows the 3σ-region of the predictive distribution at each of the input locations, and the blue line shows the mean of the predictive distribution. In each panel, the purple line shows the predictive function corresponding to the sample shown in the corresponding bottom row panel. For the details of the experimental setup, see Section 5.1 of Izmailov et al. (2019).

that the network can represent a broader class of functions. The network outputs a single number y = f(x).

To generate data for the plots, we used a randomly-initialized neural network of the same architecture described above. We sampled the weights from an isotropic Gaussian with variance 0.1², and added isotropic Gaussian noise with variance 0.1² to the outputs:

y = f(x; w) + ε(x),

with w ∼ N(0, 0.1² · I), ε(x) ∼ N(0, 0.1² · I). The training set consists of 120 points, shown in Figure 4.

For estimating the ground truth, we ran 200 chains of Hamiltonian Monte Carlo (HMC) using the hamiltorch package (Cobb et al., 2019). We initialized each chain with a network pre-trained with SGD for 3000 steps, then ran Hamiltonian Monte Carlo (HMC) for 2000 steps, producing 200 samples.

For Deep Ensembles, we independently trained 50 networks with SGD for 20000 steps each. We used the negative posterior log-density as the training loss. For SVI, we used a fully-factorized Gaussian approximation initialized at an SGD solution trained for 20000 steps. For all inference methods, we set the prior variance to 10² and the noise variance to 0.02².

Discrepancy with true BMA. For the results presented in panel (d) of Figure 4, we computed the Wasserstein distance between the predictive distribution approximated with HMC and the predictive distributions for Deep Ensembles and SVI. We used the one-dimensional Wasserstein distance function¹ from the scipy package (Virtanen et al., 2020). We computed the Wasserstein distance between the marginal distributions at each input location, and averaged the results over the input locations. In the top sub-panels of panels (b), (c) of Figure 4, we additionally visualize the marginal Wasserstein distance between the HMC predictive distribution and the Deep Ensembles and SVI predictive distributions, respectively, for each input location.

¹ https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wasserstein_distance.html
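A sketch of this computation, assuming arrays of predictive samples of shape [n_samples, n_inputs] (the function and variable names here are ours, not from a code release):

import numpy as np
from scipy.stats import wasserstein_distance

def predictive_discrepancy(hmc_samples, approx_samples):
    """Average one-dimensional Wasserstein distance between the marginal
    predictive distributions at each input location.

    Both arguments are arrays of shape [n_samples, n_inputs]."""
    n_inputs = hmc_samples.shape[1]
    dists = [wasserstein_distance(hmc_samples[:, i], approx_samples[:, i])
             for i in range(n_inputs)]
    return float(np.mean(dists))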

D.2. Deep Ensembles and MultiSWAG

We evaluate Deep Ensembles, MultiSWA, and MultiSWAG under distribution shift in Section 4. Following Ovadia et al. (2019), we use a PreResNet-20 network and the CIFAR-10 dataset with the different types of corruptions introduced in Hendrycks & Dietterich (2019). For training individual SGD, SWA, and SWAG models, we use the hyperparameters used for PreResNet-164 in Maddox et al. (2019). For each SWAG model, we sample 20 networks and ensemble them. So Deep Ensembles, MultiSWA, and MultiSWAG are all evaluated under the same training budget; Deep Ensembles and MultiSWA also use the same test-time budget.

For producing the corrupted data, we used the code² released by Hendrycks & Dietterich (2019). We had issues producing the data for the frost corruption type, so we omit it in our evaluation, and include Gaussian blur, which was not included in the evaluation of Hendrycks & Dietterich (2019).

² https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_cifar_c.py


[Figure 12: panels (a) α = 0.02, (b) α = 0.1, (c) α = 1 — MNIST class-pair prior correlation matrices for classes 0, 1, 2, 4, 7; (d) NLL versus prior std α; panels (e) α = 0.02, (f) α = 0.1, (g) α = 1 — sample functions; (h) classification error versus prior std α, for LeNet-5 and a weight decay baseline.]

Figure 12. (a)–(c) Average pairwise prior correlations for pairs of objects in classes {0, 1, 2, 4, 7} of MNIST induced by LeNet-5 for p(f(x;w)), when p(w) = N(0, α²I). Images in the same class have higher prior correlations than images from different classes, suggesting that p(f(x;w)) has desirable inductive biases. The correlations slightly decrease with increases in α. Panels (e)–(g) show sample functions from LeNet-5 along the direction connecting a pair of MNIST images of 0 and 1 digits. The complexity of the samples increases with α. (d) NLL and (h) classification error of an ensemble of 20 SWAG samples on MNIST as a function of α, using a LeNet-5. The NLL is high for overly small α and near-optimal for larger values, with an optimum near α = 0.3.

D.3. Neural Network Priors

In the main text, we considered different properties of the prior distribution over functions induced by a spherical Gaussian distribution over the weights, with different variance scales.

Prior correlation diagrams. In panels (a)–(c) of Figure 12, we show pairwise correlations of the logits for different pairs of datapoints. To make these plots, we produce S = 100 samples of the weights w_i of a LeNet-5 from the prior distribution N(0, α²I), and compute the logits corresponding to class 0 for each data point and each weight sample. We then compute the correlations for each pair x, x′ of data points as follows:

corr_logit(x, x′) = [ Σ_{i=1}^{S} (f(x; w_i) − f̄(x)) (f(x′; w_i) − f̄(x′)) ] / √[ Σ_{i=1}^{S} (f(x; w_i) − f̄(x))² · Σ_{i=1}^{S} (f(x′; w_i) − f̄(x′))² ],

where f(x; w) is the logit corresponding to class 0 of the network with weights w on the input x, and f̄(x) is the mean value of the logit, f̄(x) = (1/S) Σ_i f(x; w_i). For evaluation, we use 200 random datapoints per class for classes 0, 1, 2, 4, 7 (a total of 1000 datapoints). We use this set of classes to ensure that the structure is clearly visible in the figure. We combine the correlations into a diagram, additionally showing the average correlation for each pair of classes. We repeat the experiment for different values of α ∈ {0.02, 0.1, 1}. For a discussion of the results, see Section 5.2.
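In code, the correlation computation above looks roughly as follows, assuming `logits` is an array of shape [S, n_points] holding the class-0 logit for every prior weight sample and datapoint (the names are ours):

import numpy as np

def prior_logit_correlations(logits):
    """Pairwise correlations of class-0 logits across S prior weight samples.

    logits: array of shape [S, n_points], one row per sampled network.
    Returns an [n_points, n_points] matrix of corr_logit values."""
    centered = logits - logits.mean(axis=0, keepdims=True)
    cov = centered.T @ centered                 # sums over the S samples
    std = np.sqrt(np.diag(cov))
    return cov / np.outer(std, std)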

Sample functions. In panels (e)–(g) of Figure 12, we visualize the functions sampled from the LeNet-5 network along the direction connecting a pair of MNIST images. In particular, we take a pair of images, x_0 and x_1, of digits 0 and 1 respectively, and construct the path x(t) = t · x_0 + (1 − t) · x_1. We then study the samples of the logits z(t) = f(x(t) · ‖x_0‖/‖x(t)‖, w) along the path; here we adjusted the norm of the images along the path to be constant, as the values of the logits are sensitive to the norm of the inputs. The complexity of the samples increases as we increase the variance of the prior distribution over the weights. This increased complexity of sample functions explains why we might expect the prior correlations for pairs of images to be lower when we increase the variance of the prior distribution over the weights.

Performance dependence on prior variance. In panels (d), (h) of Figure 12 we show the test negative log-likelihood and accuracy of SWAG applied to LeNet-5 on MNIST. We train the model for 50 epochs, constructing the rank-20 SWAG approximation from the last 25 epochs. We use an initial learning rate of 0.05 and a SWAG learning rate of 0.01, with the learning rate schedule of Maddox et al. (2019). We use the posterior log-density as the objective and vary the prior variance α².
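As a rough sketch, this objective corresponds to the negative posterior log-density with a N(0, α²I) prior; the per-batch 1/N scaling of the prior term below is one common convention and an assumption of the sketch, not a record of our exact training code.

import torch
import torch.nn.functional as F

def neg_log_posterior(model, x_batch, y_batch, alpha, n_train):
    nll = F.cross_entropy(model(x_batch), y_batch)             # average NLL over the batch
    sq_norm = sum((p ** 2).sum() for p in model.parameters())  # ||w||^2
    prior_term = sq_norm / (2 * alpha ** 2 * n_train)          # -log p(w) scaled by 1/N
    return nll + prior_term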


In panels (f), (g) of Figure 9 we perform an analogous experiment using a PreResNet-20 and a VGG-16 on CIFAR-10, using the hyper-parameters reported in Maddox et al. (2019) (for PreResNet-20 we use the hyper-parameters used with PreResNet-164 in Maddox et al. (2019)). Both on MNIST and CIFAR-10 we observe that the performance is poor for overly small values of α, close to optimal for intermediate values, and still reasonable for larger values of α. For further discussion of the results, see Section 5.3.

Predictions from prior samples. Following Wenzel et al. (2020), we study the predictive distributions of prior samples using PreResNet-20 on CIFAR-10. In Figure 9 we show the sample predictive functions averaged over datapoints for different scales α of the prior distribution. We also show the predictive distribution for each α, which is the average of the sample predictive distributions over 200 samples of weights. In Figure 10 we show how the predictive distribution changes as we vary the number of observed data for prior scale α = √10. We see that the marginal predictive distribution for all considered values of α is reasonable: roughly uniform across classes when averaged across the dataset. For the latter experiment we used stochastic gradient Langevin dynamics (SGLD) (Welling & Teh, 2011) with a cosine learning rate schedule. For each sample we restart SGLD, and we only use the sample obtained at the last iteration. We discuss the results in Section 8.4.

Prior correlations with corrupted images. In Section F and Figure 19 we study the decay of the prior correlations between logits on an original image and a perturbed image as we increase the intensity of perturbations. For the BNN we use the PreResNet-20 architecture with the standard Gaussian prior N(0, I). For the linear model, the correlations are not affected by the prior variance α²:

\mathrm{cov}(w^\top x, w^\top y) = \mathbb{E}\big[(w^\top x)(w^\top y)\big] = \mathbb{E}\big[x^\top w w^\top y\big] = x^\top \mathbb{E}\big[w w^\top\big] y = \alpha^2 x^\top y,

and hence

\mathrm{corr}(w^\top x, w^\top y) = \frac{\mathrm{cov}(w^\top x, w^\top y)}{\sqrt{\mathrm{cov}(w^\top x, w^\top x) \cdot \mathrm{cov}(w^\top y, w^\top y)}} = \frac{x^\top y}{\|x\| \, \|y\|},

which does not depend on α².

We use the N(0, I) prior for the weights of the linear model. Finally, we also evaluate the correlations associated with an RBF kernel (see Equation (7)). To set the lengthscale ℓ of the kernel, we evaluate the pairwise correlations for the PreResNet-20 and the RBF kernel on the 100 uncorrupted CIFAR-10 images that were used for the experiment, and ensure that the average correlations match. The resulting value of ℓ is 10,000; the average correlation for the RBF kernel and PreResNet was ≈ 0.9, and for the linear model the average correlation was ≈ 0.82. For the perturbations we used the same set of corruptions introduced in Hendrycks & Dietterich (2019) as in the experiments in Section 4, with the addition of a random translation: for a random translation of intensity i, we pad the image with 2·i zeros on each side and crop the image randomly to 32 × 32.

D.4. Rethinking Generalization

In Section 6 we experiment with Bayesian neural networks and Gaussian processes on CIFAR-10 with noisy labels, inspired by the results in Zhang et al. (2016) that suggest we need to re-think generalization to understand deep learning.

Following Zhang et al. (2016), we train PreResNet-20 on CIFAR-10 with different fractions of random labels. To ensure that the network fits the train data, we turn off weight decay and data augmentation, and use a lower initial learning rate of 0.01. Otherwise, we follow the hyper-parameters that were used with PreResNet-164 in Maddox et al. (2019). We use a diagonal Laplace approximation to compute an estimate of the marginal likelihood for each level of label corruption. Following Ritter et al. (2018), we use the diagonal of the Fisher information matrix rather than the Hessian.

We perform a similar experiment with a Gaussian process with an RBF kernel on the binary classification problem for two classes of CIFAR-10. We use variational inference to fit the model, and we use the variational evidence lower bound to approximate the marginal likelihood. We use variational inference to overcome the non-Gaussian likelihood, and not for scalability reasons, i.e. we are not using inducing inputs. We use the GPyTorch package (Gardner et al., 2018) to train the models. We use an RBF kernel with default initialization from GPyTorch, and divide the inputs by 5000 to get an appropriate input scale. We train the model on a binary classification problem between classes 0 and 1.
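A hedged sketch of this setup using GPyTorch's standard variational classification components (class and function names such as BinaryGPClassifier and fit_binary_gp, and the optimizer settings, are illustrative assumptions rather than a verbatim record of our scripts); consistent with the text, the variational distribution is placed at the training inputs themselves, with no inducing points.

import torch
import gpytorch

class BinaryGPClassifier(gpytorch.models.ApproximateGP):
    def __init__(self, train_x):
        # Variational distribution at the training inputs (no inducing points).
        var_dist = gpytorch.variational.CholeskyVariationalDistribution(train_x.size(0))
        var_strat = gpytorch.variational.VariationalStrategy(
            self, train_x, var_dist, learn_inducing_locations=False)
        super().__init__(var_strat)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x))

def fit_binary_gp(train_x, train_y, epochs=100, lr=0.01):
    # train_x: flattened images divided by 5000; train_y: float labels in {0., 1.}
    model = BinaryGPClassifier(train_x)
    likelihood = gpytorch.likelihoods.BernoulliLikelihood()
    mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.numel())
    opt = torch.optim.Adam(list(model.parameters()) + list(likelihood.parameters()), lr=lr)
    model.train(); likelihood.train()
    for _ in range(epochs):
        opt.zero_grad()
        loss = -mll(model(train_x), train_y)   # negative variational ELBO
        loss.backward()
        opt.step()
    return model, likelihood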

For the 10-class GP classification experiment, we train 10 one-vs-all models that classify between a given class and the rest of the data. To reduce computation in training, we subsample the data not belonging to the given class to 10k datapoints, so each model is trained on a total of 15k datapoints. We then combine the 10 models into a single multi-class model: an observation is attributed to the class that corresponds to the one-vs-all model with the highest confidence. We use the same hyper-parameters as in the binary classification experiments.

D.5. Double Descent

In Section 7 we evaluate SGD, SWAG and MultiSWAG for models of varying width. Following Nakkiran et al. (2019), we use ResNet-18 on CIFAR-100; we consider the original labels, and 10% and 20% label corruption. For networks of every width we reuse the hyper-parameters used for PreResNet-164 in Maddox et al. (2019).


Figure 13. Double Descent. (a) Test error and (b) NLL loss for ResNet-18 with varying width on CIFAR-100, for SGD, SWAG and MultiSWAG, when 10% of the labels are randomly reshuffled. MultiSWAG alleviates double descent both on the original labels and under label noise, both in accuracy and NLL. (c) Test NLLs for MultiSWAG with a varying number of independent models under 20% label corruption; NLL monotonically decreases with an increased number of independent models, alleviating double descent.

For the original labels and 10% label corruption, we use 5 independently trained SWAG models with MultiSWAG, and for 20% label corruption we use 10 models; for 20% label corruption we also show performance varying the number of independent models in Figures 8(e) and 13(c). Both for SWAG and MultiSWAG, we use an ensemble of 20 sampled models from each of the SWAG solutions; for example, for MultiSWAG with 10 independent SWAG solutions, we use an ensemble of 200 networks.

E. Analysis of Prior Variance Effect

In this section we provide simple analytic results for the effect of the prior variance in ReLU networks. A related derivation is presented in Appendix Section A.8 of Garipov et al. (2018), about connecting paths from symmetries in parametrization.

We will consider a multilayer network f(x; w) of the form

f(x; \{W_i, b_i\}_{i=1}^n) = W_n \phi(\cdots \phi(W_2 \phi(W_1 x + b_1) + b_2) \cdots) + b_n,

where φ is the ReLU (or, in fact, any positively-homogeneous activation function), W_i are weight matrices, and b_i are bias vectors. In particular, f can be a regular CNN with ReLU activations up to the logits (with the softmax activation removed).

Now suppose we have a prior distribution of the form

W_i \sim \mathcal{N}(0, \alpha_i^2 I), \quad b_i \sim \mathcal{N}(0, \beta_i^2 I),

where the identity matrices I are implicitly assumed to be of appropriate shapes, so each weight matrix and bias vector has a spherical Gaussian distribution. We can reparameterize this distribution as

W_i = \alpha_i E_i, \quad E_i \sim \mathcal{N}(0, I),
b_i = \beta_i \epsilon_i, \quad \epsilon_i \sim \mathcal{N}(0, I).

We can then express the predictions of the network on the input x, for weights sampled from the prior, as the random variable

f(x; \{\alpha_i, \beta_i\}_{i=1}^n) = \alpha_n E_n(\cdots \phi(\alpha_1 E_1 x + \beta_1 \epsilon_1) \cdots) + \beta_n \epsilon_n.  (8)

Through Equation (8) we can observe some simple properties of the dependence between the prior scales α_i, β_i and the induced function-space prior.

Proposition 1. Suppose the network has no bias vectors, i.e. β_1 = … = β_n = 0. Then the scales α_i of the prior distribution over the weights only affect the output scale of the network.

Proof. In the case when there are no bias vectors, Equation (8) simplifies to

f(x; \{\alpha_i, \beta_i = 0\}_{i=1}^n) = \alpha_n E_n(\cdots \phi(\alpha_1 E_1 x) \cdots)
  = \alpha_n \cdots \alpha_1 \, E_n(\cdots \phi(E_1 x) \cdots)
  = \alpha_n \cdots \alpha_1 \, f(x; \{\alpha_i = 1, \beta_i = 0\}_{i=1}^n).

In the derivation above we used the positive homogeneity of ReLU: φ(αz) = αφ(z) for any positive α.

In other words, to sample from the distribution over functions corresponding to a prior with variances {α_i, β_i = 0}_{i=1}^n, we can sample from the spherical Gaussian prior (without bias terms) {α_i = 1, β_i = 0}_{i=1}^n, and then rescale the outputs of the network by the product α_n · … · α_2 · α_1 of the prior scales.
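A quick numerical check of Proposition 1 with a small bias-free ReLU network (an illustrative sketch, not one of the architectures used in our experiments):

import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(
    nn.Linear(10, 50, bias=False), nn.ReLU(),
    nn.Linear(50, 50, bias=False), nn.ReLU(),
    nn.Linear(50, 1, bias=False),
)
x = torch.randn(4, 10)
base = net(x).detach()                       # output with alpha_i = 1 for every layer
alphas = [0.3, 2.0, 5.0]                     # one scale per linear layer
with torch.no_grad():
    for layer, a in zip((m for m in net if isinstance(m, nn.Linear)), alphas):
        layer.weight.mul_(a)                 # rescale weights: alpha_i = a
scaled = net(x).detach()
print(torch.allclose(scaled, base * 0.3 * 2.0 * 5.0, rtol=1e-4, atol=1e-5))  # True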

We note that the result above is different from the results for sigmoid networks considered in MacKay (1995), where varying the prior on the weights leads to changing the lengthscale of the sample functions. For ReLU networks without biases, increasing the prior variance only increases the output scale of the network, and not the complexity of the samples. If we apply the softmax activation on the outputs of the last layer of such a network, we will observe increasingly


confident predictions as we increase the prior variance. We observe this effect in Figure 9 and discuss it in Section 8.4.

In case bias vectors are present, we can obtain a similar result using a specific scaling of the prior variances with layer, as in the following proposition.

Proposition 2. Suppose the prior scales depend on the layer of the network as follows, for some γ > 0:

\alpha_i = \gamma, \quad \beta_i = \gamma^i,

for all layers i = 1, …, n. Then γ only affects the scale of the predictive distribution at any input x:

f(x; \{\alpha_i = \gamma, \beta_i = \gamma^i\}_{i=1}^n) = \gamma^n \, f(x; \{\alpha_i = 1, \beta_i = 1\}_{i=1}^n).

Proof. The proof is analogous to the proof of Proposition 1. We can use the positive homogeneity of ReLU activations to factor the prior scales outside of the network:

f(x; \{\alpha_i = \gamma, \beta_i = \gamma^i\}_{i=1}^n) = \gamma E_n(\cdots \phi(\gamma E_1 x + \gamma \epsilon_1) \cdots) + \gamma^n \epsilon_n
  = \gamma^n \big( E_n(\cdots \phi(E_1 x + \epsilon_1) \cdots) + \epsilon_n \big)
  = \gamma^n \, f(x; \{\alpha_i = 1, \beta_i = 1\}_{i=1}^n).
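The same kind of numerical check for Proposition 2, again with a small illustrative network, this time with biases:

import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(
    nn.Linear(10, 50), nn.ReLU(),
    nn.Linear(50, 50), nn.ReLU(),
    nn.Linear(50, 1),
)
x = torch.randn(4, 10)
base = net(x).detach()                         # alpha_i = 1, beta_i = 1
gamma, n = 1.7, 3                              # n = number of linear layers
with torch.no_grad():
    for i, layer in enumerate((m for m in net if isinstance(m, nn.Linear)), start=1):
        layer.weight.mul_(gamma)               # alpha_i = gamma
        layer.bias.mul_(gamma ** i)            # beta_i = gamma^i
scaled = net(x).detach()
print(torch.allclose(scaled, base * gamma ** n, rtol=1e-4, atol=1e-5))  # True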

The analysis above can be applied to other simple scaling rules of the prior; e.g.

f(x; \{\alpha_i = \gamma \hat{\alpha}_i, \beta_i = \gamma^i \hat{\beta}_i\}_{i=1}^n) = \gamma^n \, f(x; \{\alpha_i = \hat{\alpha}_i, \beta_i = \hat{\beta}_i\}_{i=1}^n)  (9)

can be shown completely analogously to Proposition 2.

More general types of scaling of the prior affect both the output scale of the network and also the relative effect of the prior and variance terms. For example, by Equation (9) we have

f(x; \{\alpha_i = \gamma, \beta_i = \gamma\}_{i=1}^n) = f(x; \{\alpha_i = \gamma \cdot 1, \beta_i = \gamma^i \cdot \gamma^{1-i}\}_{i=1}^n) = \gamma^n \, f(x; \{\alpha_i = 1, \beta_i = \gamma^{1-i}\}_{i=1}^n).

We note that the analysis does not cover residual connections and batch normalization, so it applies to LeNet-5 but cannot be directly applied to the PreResNet-20 networks used in many of our experiments.

F. Prior Correlation Structure under Perturbations

In this section we explore the prior correlations between the logits on different pairs of datapoints induced by a spherical Gaussian prior on the weights of a PreResNet-20. We sample 100 random images from CIFAR-10 (10 from each class) and apply 17 different perturbations introduced by Hendrycks & Dietterich (2019) at 5 different levels of intensity. We then compute correlations between the logits f(x; w) for the original image x and f(x̃; w) for the corrupted image x̃, as we sample the weights of the network from the prior w ∼ N(0, I).

In Figure 19 we show how the correlations decay with perturbation intensity. For reference, we also show how the correlations decay for a linear model and for an RBF kernel. For the RBF kernel, we set the lengthscale so that the average correlations on the uncorrupted datapoints match those of a PreResNet-20. Further experimental details can be found in Appendix D.3.

For all types of corruptions except saturate, snow, fog and brightness, the PreResNet logits decay slower compared to the RBF kernel and linear model. It appears that the prior samples are sensitive to corruptions that alter the brightness, or more generally the colours, in the image. For many types of corruptions (such as, e.g., Gaussian noise) the prior correlations for PreResNet are close to 1 for all levels of corruption.

Overall, these results indicate that the prior over functions induced by a vague prior over parameters w in combination with a PreResNet has useful equivariance properties: before seeing data, the model treats images of the same class as highly correlated, even after an image has undergone significant perturbations representative of perturbations we often see in the real world. These types of symmetries are a large part of what makes neural networks a powerful model class for high-dimensional natural signals.


Figure 14. Noise Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).


Figure 15. Blur Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).


Figure 16. Digital Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).


Figure 17. Weather Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).

Figure 18. Negative log likelihood, accuracy and expected calibration error distribution on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the corruption intensity. Following Ovadia et al. (2019), we summarize the results for different types of corruption with a boxplot. For each method we use 10 independently trained models, and for MultiSWAG we sample 20 networks from each model. As in Figures 5 and 11–14, there are substantial differences between these three methods, which are hard to see due to the vertical scale of this plot; MultiSWAG particularly outperforms Deep Ensembles and MultiSWA in terms of NLL and ECE for higher corruption intensities.


Figure 19. Prior correlations under corruption. Prior correlations between predictions (logits) for PreResNet-20, a linear model, and an RBF kernel on original and corrupted images, as a function of corruption intensity, for different types of corruptions. The lengthscale of the RBF kernel is calibrated to produce similar correlations to PreResNet on uncorrupted datapoints. We report the mean correlation values over 100 different images, and show the 1σ error bars with shaded regions. For all corruptions except Snow, Saturate, Fog and Brightness, the correlations decay slower for PreResNet compared to the baselines.



Figure 8. Bayesian model averaging alleviates double descent. (a) Test error and (b) NLL loss for ResNet-18 with varying width on CIFAR-100 for SGD, SWAG and MultiSWAG. (c) Test error and (d) NLL loss when 20% of the labels are randomly reshuffled. SWAG reduces double descent, and MultiSWAG, which marginalizes over multiple modes, entirely alleviates double descent, both on the original labels and under label noise, both in accuracy and NLL. (e) Test errors for MultiSWAG with a varying number of independent SWAG models; error monotonically decreases with an increased number of independent models, alleviating double descent. We also note that MultiSWAG provides significant improvements in accuracy and NLL over SGD and SWAG models. See Appendix Figure 13 for additional results.

performs a more exhaustive multimodal Bayesian model average, completely mitigates double descent: the performance of MultiSWAG solutions increases monotonically with the size of the model, showing no double descent even under significant label corruption, for both accuracy and NLL.

Our results highlight the importance of marginalization over multiple modes of the posterior: under 20% label corruption, SWAG clearly suffers from double descent while MultiSWAG does not. In Figure 8(e) we show how the double descent is alleviated with an increased number of independent modes marginalized in MultiSWAG.

These results also clearly show that MultiSWAG provides significant improvements in accuracy over both SGD and SWAG models, in addition to NLL, an often overlooked advantage of Bayesian model averaging we discuss in Section 3.1.

Recently, Nakkiran et al. (2020) show that carefully tuned l2 regularization can help mitigate double descent. Alternatively, we show that Bayesian model averaging, particularly based on multimodal marginalization, can mitigate prominent double descent behaviour. The perspective in Sections 1 and 3 predicts this result: models with reasonable priors and effective Bayesian model averaging should monotonically improve with increases in flexibility.

8. Temperature Scaling

The standard Bayesian posterior distribution is given by

p(w \mid D) = \frac{1}{Z} \, p(D \mid w) \, p(w),  (2)

where p(D|w) is a likelihood, p(w) is a prior, and Z is a normalizing constant.

In Bayesian deep learning it is typical to consider the tempered posterior

p_T(w \mid D) = \frac{1}{Z(T)} \, p(D \mid w)^{1/T} \, p(w),  (3)

where T is a temperature parameter, and Z(T) is the normalizing constant corresponding to temperature T. The temperature parameter controls how the prior and likelihood interact in the posterior:

• T < 1 corresponds to cold posteriors, where the posterior distribution is more concentrated around solutions with high likelihood.

• T = 1 corresponds to the standard Bayesian posterior distribution.

• T > 1 corresponds to warm posteriors, where the prior effect is stronger and the posterior collapse is slower.
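In terms of implementation, the temperature simply rescales the likelihood term in the unnormalized log-posterior used for sampling or optimization. A minimal sketch, assuming full-dataset log-likelihood and log-prior values are available:

def tempered_log_posterior(log_likelihood, log_prior, T):
    # log p_T(w | D) = (1/T) * log p(D | w) + log p(w) + const; T < 1 upweights the data.
    return log_likelihood / T + log_prior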

Tempering posteriors is a well-known practice in statistics, where it goes by the names Safe Bayes, generalized Bayesian inference, and fractional Bayesian inference (e.g. de Heide et al., 2019; Grunwald et al., 2017; Barron & Cover, 1991; Walker & Hjort, 2001; Zhang, 2006; Bissiri et al., 2016; Grunwald, 2012). Safe Bayes has been shown to be natural from a variety of perspectives, including from prequential learning theory and minimum description length frameworks (e.g. Grunwald et al., 2017).

Concurrently with our work, Wenzel et al. (2020) noticed that successful Bayesian deep learning methods tend to use cold posteriors. They provide an empirical study showing that Bayesian neural networks (BNNs) with cold posteriors outperform models with SGD-based maximum likelihood training, while BNNs with T = 1 can perform worse than the maximum likelihood solution. They claim that cold posteriors sharply deviate from the Bayesian paradigm, and


Figure 9. Effects of the prior variance α². (a)–(e): Average class probabilities over all of CIFAR-10 for two sample prior functions p(f(x; w)) (two top rows) and the predictive distribution (average over 200 samples of weights, bottom row) for varying settings of α in p(w) = N(0, α²I). (f): NLL and (g): classification error of an ensemble of 20 SWAG samples on CIFAR-10 as a function of α, using a Preactivation ResNet-20 and VGG-16. The NLL is high for overly small α and near-optimal in the range [0.1, 0.3]. The NLL remains relatively low for vague priors corresponding to large values of α.

Figure 10. Adaptivity of posterior variance with data. We sample two functions f(x; w) from the distribution over functions induced by a distribution over weights, starting with the prior p(w) = N(0, 10 · I), in combination with a PreResNet-20. We measure class probabilities averaged across the CIFAR-10 test set as we vary the amount of available training data. Although the prior variance is too large, such that the softmax saturates for logits sampled from the prior, leading to one class being favoured, we see that the posterior quickly adapts to correct the scale of the logits in the presence of data. In Figure 9 we also show that the prior variance can easily be calibrated such that the prior predictive distribution, even before observing data, is high entropy.


consider possible reasons for why tempering is helpful in Bayesian deep learning.

In this section we provide an alternative view, and argue that tempering is not at odds with Bayesian principles. Moreover, for virtually any realistic model class and dataset, it would be highly surprising if T = 1 were in fact the best setting of this hyperparameter. Indeed, as long as it is practically convenient, we would advocate tempering for essentially any model, especially parametric models that do not scale their capacity automatically with the amount of available information. Our position is that, at a high level, Bayesian methods are trying to combine honest beliefs with data to form a posterior. By reflecting the belief that the model is misspecified, the tempered posterior is often more of a true posterior than the posterior that results from ignoring our belief that the model is misspecified.

Finding that T < 1 helps for Bayesian neural networks is neither surprising nor discouraging. And the actual results of the experiments in Wenzel et al. (2020), which show great improvements over standard SGD training, are in fact very encouraging of deriving inspiration from Bayesian procedures in deep learning.

We consider: (1) tempering under misspecification (Section 8.1); (2) tempering in terms of overcounting data (Section 8.2); (3) how tempering compares to changing the observation model (Section 8.3); (4) the effect of the prior in relation to the experiments of Wenzel et al. (2020) (Section 8.4); and (5) the effect of approximate inference, including how tempering can help in efficiently estimating parameters even for the untempered posterior (Section 8.5).

8.1. Tempering Helps with Misspecified Models

Many works explain how tempered posteriors help under model misspecification (e.g. de Heide et al., 2019; Grunwald et al., 2017; Barron & Cover, 1991; Walker & Hjort, 2001; Zhang, 2006; Bissiri et al., 2016; Grunwald, 2012). In fact, de Heide et al. (2019) and Grunwald et al. (2017) provide several simple examples where Bayesian inference fails to provide good convergence behaviour for untempered posteriors. While it is easier to show theoretical results for T > 1, several of these works also show that T < 1 can be preferred, even in well-specified settings, and indeed recommend learning T from data, for example by cross-validation (e.g. Grunwald, 2012; de Heide et al., 2019).

Are we in a misspecified setting for Bayesian neural networks? Of course. And it would be irrational to proceed as if it were otherwise. Every model is misspecified. In the context of Bayesian neural networks specifically, the mass of solutions expressed by the prior outside of the datasets we typically consider is likely much larger than desired for most applications. We can calibrate for this discrepancy through tempering. The resulting tempered posterior will be more in line with our beliefs than pretending the model is not misspecified and finding the untempered posterior.

Non-parametric models, such as Gaussian processes, attempt to side-step model misspecification by growing the number of free parameters (information capacity) automatically with the amount of available data. In parametric models we take much more of a manual guess about the model capacity. In the case of deep neural networks, this choice is not even close to a best guess: it was once the case that architectural design was a large component of works involving neural networks, but now it is more standard practice to choose an off-the-shelf architecture without much consideration of model capacity. We do not believe that knowingly using a misspecified model to find a posterior is more reasonable (or Bayesian) than honestly reflecting the belief that the model is misspecified and then using a tempered posterior. For parametric models such as neural networks, it is to be expected that the capacity is particularly misspecified.

8.2. Overcounting Data with Cold Posteriors

The criticism of cold posteriors raised by Wenzel et al. (2020) is largely based on the fact that decreasing the temperature leads to overcounting data in the posterior distribution.

However, a similar argument can be made against marginal likelihood maximization (also known as empirical Bayes, or type-2 maximum likelihood). Indeed, here the prior will depend on the same data as the likelihood, which can lead to miscalibrated predictive distributions (Darnieder, 2011).

Nonetheless, empirical Bayes has been embraced and widely adopted in Bayesian machine learning (e.g. Bishop, 2006; Rasmussen & Williams, 2006; MacKay, 2003; Minka, 2001b) as embodying several Bayesian principles. Empirical Bayes has been particularly embraced in seminal work on Bayesian neural networks (e.g. MacKay, 1992; 1995), where it has been proposed as a principled approach to learning hyperparameters, such as the scale of the variance for the prior over weights, automatically embodying Occam's razor. While there is in this case some deviation from the fully Bayesian paradigm, the procedure, which depends on marginalization, is nonetheless clearly inspired by Bayesian thinking, and it is thus helpful to reflect this inspiration and provide understanding of how it works from a Bayesian perspective.

There is also work showing that the marginal likelihood can lead to miscalibrated Bayes factors under model misspecification. Attempts to calibrate these factors (Xu et al., 2019) as part of the Bayesian paradigm are highly reminiscent of work on safe Bayes.


8.3. Tempered Posterior or Different Likelihood

The tempered posterior for one model is an untempered posterior using a different observation model. In other words, we can trivially change the likelihood function in our model so that the standard Bayesian posterior in the new model is equal to the posterior of temperature T in the original model. Indeed, consider the likelihood function

p_T(D \mid w) \propto p(D \mid w)^{1/T},  (4)

where the posterior distribution for the model M_T with likelihood p_T is exactly the temperature-T posterior for the model M with likelihood p.

The predictive distribution differs for the two models: even though the posteriors coincide, the likelihoods for a new datapoint y* are different,

\int p(y^* \mid w) \, p(w \mid D) \, dw \;\neq\; \int p_T(y^* \mid w) \, p(w \mid D) \, dw.  (5)

As an example, consider a regression model M with a Gaussian likelihood function, y ∼ N(f, σ²), where f is the output of the network and σ² is the noise variance. The predictive distributions for the two models M and M_T for a new input x will have different variance, but the same mean: E_w[f + ε] = E_w[f]. Moreover, in this case the noise variance would typically be learned in either model, M_T or M.
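To make the Gaussian example explicit, tempering the likelihood is equivalent to rescaling the noise variance:

p_T(y \mid w) \;\propto\; \mathcal{N}(y \mid f, \sigma^2)^{1/T} \;\propto\; \exp\!\left(-\frac{(y - f)^2}{2 T \sigma^2}\right) \;\propto\; \mathcal{N}(y \mid f, T\sigma^2),

so the temperature-T posterior in M coincides with the standard posterior in a model M_T whose noise variance is rescaled from σ² to Tσ²; the two predictive means agree, while the predictive variances differ.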

Section 4.1 of Grunwald et al. (2017) considers a related construction.

8.4. Effect of the Prior

While a somewhat misspecified prior will certainly interact with the utility of tempering, we do not believe the experiments in Wenzel et al. (2020) provide evidence that even the prior p(w) = N(0, I) is misspecified to any serious extent. For a relatively wide range of distributions over w, the functional form of the network f(x; w) can produce a generally reasonable distribution over functions p(f(x; w)). In Figure 10 we reproduce the findings in Wenzel et al. (2020), which show that sample functions p(f(x; w)) corresponding to the prior p(w) = N(0, 10 · I) strongly favour a single class over the dataset. While this behaviour appears superficially dramatic, we note it is simply an artifact of having a miscalibrated signal variance. A miscalibrated signal variance interacts with a quickly saturating softmax link function to provide a seemingly dramatic preference for a given class. If we instead use p(w) = N(0, α²I) for quite a range of α, then sample functions provide reasonably high entropy across labels, as in Figure 9. The value of α can be easily determined through cross-validation, as in Figure 9, or specified as a standard value used for L2 regularization (α = 0.24 in this case).
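As a simple illustration of this saturation effect (treating the prior over a datapoint's logits as roughly N(0, α²) per class, which is an idealization and an assumption of this sketch rather than the exact induced prior):

import torch

def mean_max_prob(alpha, n_classes=10, n_samples=1000):
    logits = alpha * torch.randn(n_samples, n_classes)   # idealized prior over logits
    return torch.softmax(logits, dim=-1).max(dim=-1).values.mean().item()

for a in [0.1, 0.3, 1.0, 10.0]:
    print(a, mean_max_prob(a))   # the favoured-class probability approaches 1 as alpha grows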

However, even with the inappropriate prior scale, we see in the bottom row of panels (a)–(e) of Figure 9 that the unconditional predictive distribution is completely reasonable. Moreover, the prior variance represents a soft prior bias, and will quickly update in the presence of data. In Figure 10 we show the posterior samples after observing 10, 100, and 1000 data points.

Other aspects of the prior, outside of the prior signal variance, will have a much greater effect on the inductive biases of the model. For example, the induced covariance function cov(f(x_i; w), f(x_j; w)) reflects the induced similarity metric over data instances; through the covariance function we can answer, for instance, whether the model believes a priori that a translated image is similar to the original. Unlike the signal variance of the prior, the prior covariance function will continue to have a significant effect on posterior inference for even very large datasets, and strongly reflects the structural properties of the neural network. We explore these structures of the prior in Figure 12.

8.5. The Effect of Inexact Inference

We have to keep in mind what we ultimately use posterior samples to compute. Ultimately, we wish to estimate the predictive distribution given by the integral in Equation (1). With a finite number of samples, the tempered posterior could be used to provide a better approximation to the expectation of the predictive distribution associated with the untempered posterior.

Consider a simple example, where we wish to estimate the mean of a high-dimensional Gaussian distribution N(0, I). Suppose we use J independent samples. The mean of these samples is also Gaussian distributed, μ ∼ N(0, (1/J) I). In Bayesian deep learning the dimension d is typically on the order of 10⁷, and J would be on the order of 10. The norm of μ would be highly concentrated around √(10⁷)/√10 = 1000. In this case, sampling from a tempered posterior with T < 1 would lead to a better approximation of the Bayesian model average associated with an untempered posterior.
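This concentration is easy to verify numerically; a sketch using d = 10⁶ rather than 10⁷ simply to keep memory modest:

import numpy as np

d, J = 10**6, 10
rng = np.random.default_rng(0)
mu_hat = rng.standard_normal((J, d)).mean(axis=0)   # mean of J samples from N(0, I_d)
print(np.linalg.norm(mu_hat), np.sqrt(d / J))       # both approximately 316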

Furthermore, no current sampling procedure will be providing samples that are close to independent samples from the true posterior of a Bayesian neural network. The posterior landscape is far too multimodal and complex for there to be any reasonable coverage. The approximations we have are practically useful, and often preferable to conventional training, but we cannot realistically proceed with analysis assuming that we have obtained true samples from a posterior. While we would expect that some value of T ≠ 1 would be preferred for any finite dataset in practice, it is conceivable that some of the results in Wenzel et al. (2020) may be affected by the specifics of the approximate inference technique being used.


We should be wary not to view Bayesian model averaging purely through the prism of simple Monte Carlo, as advised in Section 3.2. Given a finite computational budget, our goal in effectively approximating a Bayesian model average is not equivalent to obtaining good samples from the posterior.

9. Discussion

“It is now common practice for Bayesians to fit models that have more parameters than the number of data points. Incorporate every imaginable possibility into the model space: for example, if it is conceivable that a very simple model might be able to explain the data, one should include simple models; if the noise might have a long-tailed distribution, one should include a hyperparameter which controls the heaviness of the tails of the distribution; if an input variable might be irrelevant to a regression, include it in the regression anyway.” MacKay (1995)

We have presented a probabilistic perspective of generalization, which depends on the support and inductive biases of the model. The support should be as large as possible, but the inductive biases must be well-calibrated to a given problem class. We argue that Bayesian neural networks embody these properties, and, through the lens of probabilistic inference, explain generalization behaviour that has previously been viewed as mysterious. Moreover, we argue that Bayesian marginalization is particularly compelling for neural networks, show how deep ensembles provide a practical mechanism for marginalization, and propose a new approach that generalizes deep ensembles to marginalize within basins of attraction. We show that this multimodal approach to Bayesian model averaging, MultiSWAG, can entirely alleviate double descent, to enable monotonic performance improvements with increases in model flexibility, as well as significant improvements in generalization accuracy and log likelihood over SGD and single-basin marginalization.

There are certainly many challenges to estimating the integral for a Bayesian model average in modern deep learning, including a high-dimensional parameter space and a complex posterior landscape. But viewing the challenge as an integration problem, rather than an attempt to obtain posterior samples for a simple Monte Carlo approximation, provides opportunities for future progress. Bayesian deep learning has been making fast practical advances, with approaches that now enable better accuracy and calibration over standard training, with minimal overhead.

We finish with remarks about future developments for Bayesian neural network priors, and approaches to research in Bayesian deep learning.

9.1. The Future for BNN Priors

We provide some brief remarks about future developments for BNN priors. Here we have explored relatively simple parameter priors p(w) = N(0, α²I). While these priors are simple in parameter space, they interact with the neural network architecture to induce a sophisticated prior over functions p(f(x; w)), with many desirable properties, including a reasonable correlation structure over images. However, these parameter priors can certainly still be improved. As we have seen, even tuning the value of the signal variance α², an analogue of the L2 regularization often used in deep learning, can have a noticeable effect on the induced prior over functions, though this effect is quickly modulated by data. Layer-wise priors, such that parameters in each layer have a different signal variance, are intuitive: we would expect later layers to require precise determination, while parameters in earlier layers could reasonably take a range of values. But one has to be cautious; as we show in Appendix Section E, with ReLU activations, different signal variances in different layers can be degenerate, combining together to affect only the output scale of the network.

A currently popular sentiment is that we should directly build function-space BNN priors, often taking inspiration from Gaussian processes. While we believe this is a promising direction, one should proceed with caution. If we contrive priors over parameters p(w) to induce distributions over functions p(f) that resemble familiar models, such as Gaussian processes with RBF kernels, we could be throwing the baby out with the bathwater. Neural networks are useful as their own model class precisely because they have different inductive biases from other models.

A similar concern applies to taking infinite-width limits in Bayesian neural networks. In these cases we recover Gaussian processes with interpretable kernel functions; because these models are easier to use and analyze, and give rise to interpretable and well-motivated priors, it is tempting to treat them as drop-in replacements for the parametric analogues. However, the kernels for these models are fixed. In order for a model to do effective representation learning, we must learn a similarity metric for the data. Training a neural network in many ways is like learning a kernel, rather than using a fixed kernel. MacKay (1998) has also expressed concerns in treating these limits as replacements for neural networks, due to the loss of representation learning power.

Perhaps the distribution over functions induced by a network in combination with a generic distribution over parameters p(w) may be hard to interpret, but this distribution will contain the equivariance properties, representation learning abilities, and other biases that make neural networks a compelling model class in their own right.


9.2. “But is it really Bayesian?”

We finish with an editorial comment about approaches to research within Bayesian deep learning. There is sometimes a tendency to classify work as Bayesian or not Bayesian, with very stringent criteria for what qualifies as Bayesian. Moreover, the implication, and sometimes even explicit recommendation, is that if an approach is not unequivocally Bayesian in every respect, then we should not term it Bayesian, and we should instead attempt to understand the procedure through entirely different, non-Bayesian mechanisms. We believe this mentality encourages tribalism, which is not conducive to the best research or creating the best-performing methods. What matters is not a debate about semantics, but making rational modelling choices given a particular problem setting, and trying to understand these choices. Often these choices can largely be inspired by a Bayesian approach, in which case it is desirable to indicate this source of inspiration. And in the semantics debate, who would be the arbiter of what gets to be called Bayesian? Arguably it ought to be an evolving definition.

Broadly speaking, what makes Bayesian approaches distinctive is a posterior-weighted marginalization over parameters. And at a high level, Bayesian methods are about combining our honest beliefs with data to form a posterior. In actuality, no fair-minded researcher entirely believes the prior over parameters, the functional form of the model (which is part of the prior over functions), or the likelihood. From this perspective, it is broadly compatible with a Bayesian philosophy to reflect misspecification in the modelling procedure itself, which is achieved through tempering. In this sense, the tempered posterior is more reflective of a true posterior than the posterior that results from ignoring our belief that the model is misspecified.

Moreover, basic probability theory indicates that marginalization is desirable. While marginalization cannot in practice be achieved exactly, we can try to improve over conventional training, which, as we have discussed, can be viewed as approximate marginalization. Given computational constraints, effective marginalization is not equivalent to obtaining accurate samples from a posterior. As we have discussed, simple Monte Carlo is only one of many mechanisms for marginalization. Just like how expectation propagation (Minka, 2001a) focuses its approximation on the factors in a posterior where it will most affect the end result, we should focus on representing the posterior where it will make the biggest difference to the model average. As we have shown, deep ensembles are a reasonable mechanism up to a point. After having trained many independent models, there are added benefits to marginalizing within basins, given the computational expense associated with retraining an additional model to find an additional basin of attraction.

We should also not hold Bayesian methods to a double standard. Indeed, it can be hard to interpret or understand the prior, the posterior, and whether the marginalization procedure is optimal. But it is also hard to interpret the choices behind the functional form of the model, or the rationale behind classical procedures, where we bet everything on a single global optimum, when we know there are many global optima, many of which will perform well but provide different solutions, and many others of which will not perform well. We should apply the same level of scrutiny to all modelling choices, consider the alternatives, and not be paralyzed if a procedure is not optimal in every respect.

Acknowledgements. AGW and PI are supported by an Amazon Research Award, Facebook Research, NSF I-DISRE 193471, NIH R01 DA048764-01A1, NSF IIS-1563887, and NSF IIS-1910266. We thank Greg Benton for helpful discussions.

ReferencesAchille A and Soatto S Emergence of invariance and

disentanglement in deep representations The Journal ofMachine Learning Research 19(1)1947ndash1980 2018

Ashukha A Lyzhov A Molchanov D and Vetrov DPitfalls of in-domain uncertainty estimation and ensem-bling in deep learning arXiv preprint arXiv2002064702020

Barron A R and Cover T M Minimum complexity den-sity estimation IEEE transactions on information theory37(4)1034ndash1054 1991

Beal M J Variational algorithms for approximate Bayesianinference university of London 2003

Belkin M Hsu D Ma S and Mandal S Reconcilingmodern machine-learning practice and the classical biasndashvariance trade-off Proceedings of the National Academyof Sciences 116(32)15849ndash15854 2019

Bishop C M Pattern Recognition and Machine LearningSpringer 2006

Bissiri P G Holmes C C and Walker S G A generalframework for updating belief distributions Journal ofthe Royal Statistical Society Series B (Statistical Method-ology) 78(5)1103ndash1130 2016

Blundell C Cornebise J Kavukcuoglu K and WierstraD Weight uncertainty in neural networks arXiv preprintarXiv150505424 2015

Box G E and Tiao G C Bayesian inference in statisticalanalysis addision-wesley Reading MA 1973

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

Chaudhari P Choromanska A Soatto S LeCun Y Bal-dassi C Borgs C Chayes J Sagun L and ZecchinaR Entropy-sgd Biasing gradient descent into wide val-leys Journal of Statistical Mechanics Theory and Ex-periment 2019

Cobb A D Baydin A G Markham A and Roberts S JIntroducing an explicit symplectic integration schemefor riemannian manifold hamiltonian monte carlo arXivpreprint arXiv191006243 2019

Darnieder W F Bayesian methods for data-dependentpriors PhD thesis The Ohio State University 2011

de Heide R Kirichenko A Mehta N and GrunwaldP Safe-Bayesian generalized linear regression arXivpreprint arXiv191009227 2019

Dinh L Pascanu R Bengio S and Bengio Y Sharpminima can generalize for deep nets arXiv preprintarXiv170304933 2017

Draxler F Veschgini K Salmhofer M and HamprechtF A Essentially no barriers in neural network energylandscape arXiv preprint arXiv180300885 2018

Dziugaite G K and Roy D M Computing nonvacuousgeneralization bounds for deep (stochastic) neural net-works with many more parameters than training dataarXiv preprint arXiv170311008 2017

Fort S and Jastrzebski S Large scale structure of neuralnetwork loss landscapes In Advances in Neural Informa-tion Processing Systems pp 6706ndash6714 2019

Fort S Hu H and Lakshminarayanan B Deep en-sembles A loss landscape perspective arXiv preprintarXiv191202757 2019

Gal Y and Ghahramani Z Dropout as a Bayesian approx-imation Representing model uncertainty in deep learn-ing In international conference on machine learning pp1050ndash1059 2016

Gardner J R Pleiss G Bindel D Weinberger K Qand Wilson A G GPyTorch Blackbox matrix-matrixgaussian process inference with gpu acceleration InNeural Information Processing Systems 2018

Garipov T Izmailov P Podoprikhin D Vetrov D P andWilson A G Loss surfaces mode connectivity and fastensembling of DNNs In Neural Information ProcessingSystems 2018

Grunwald P The safe Bayesian In International Con-ference on Algorithmic Learning Theory pp 169ndash183Springer 2012

Grunwald P Van Ommen T et al Inconsistency ofBayesian inference for misspecified linear models anda proposal for repairing it Bayesian Analysis 12(4)1069ndash1103 2017

Guedj B A primer on pac-bayesian learning arXiv preprintarXiv190105353 2019

Guo C Pleiss G Sun Y and Weinberger K Q Oncalibration of modern neural networks In Proceedings ofthe 34th International Conference on Machine Learning-Volume 70 pp 1321ndash1330 JMLR org 2017

Gustafsson F K Danelljan M and Schon T B Evaluat-ing scalable bayesian deep learning methods for robustcomputer vision arXiv preprint arXiv190601620 2019

Hafner D Tran D Irpan A Lillicrap T and David-son J Reliable uncertainty estimates in deep neuralnetworks using noise contrastive priors arXiv preprintarXiv180709289 2018

He K Zhang X Ren S and Sun J Deep residual learn-ing for image recognition In Proceedings of the IEEEconference on computer vision and pattern recognitionpp 770ndash778 2016

Hendrycks D and Dietterich T Benchmarking neuralnetwork robustness to common corruptions and perturba-tions arXiv preprint arXiv190312261 2019

Hinton G E and Van Camp D Keeping the neural net-works simple by minimizing the description length of theweights In Proceedings of the sixth annual conferenceon Computational learning theory pp 5ndash13 1993

Hochreiter S and Schmidhuber J Flat minima NeuralComputation 9(1)1ndash42 1997

Huang W R Emam Z Goldblum M Fowl L TerryJ K Huang F and Goldstein T Understandinggeneralization through visualizations arXiv preprintarXiv190603291 2019

Izmailov P Podoprikhin D Garipov T Vetrov D andWilson A G Averaging weights leads to wider optimaand better generalization In Uncertainty in ArtificialIntelligence (UAI) 2018

Izmailov P Maddox W J Kirichenko P Garipov TVetrov D and Wilson A G Subspace inference forBayesian deep learning Uncertainty in Artificial Intelli-gence 2019

Jacot A Gabriel F and Hongler C Neural tangent kernelConvergence and generalization in neural networks InAdvances in neural information processing systems pp8571ndash8580 2018

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

Jiang, Y., Neyshabur, B., Mobahi, H., Krishnan, D., and Bengio, S. Fantastic generalization measures and where to find them. arXiv preprint arXiv:1912.02178, 2019.

Kass, R. E. and Raftery, A. E. Bayes factors. Journal of the American Statistical Association, 90(430):773-795, 1995.

Kendall, A. and Gal, Y. What uncertainties do we need in Bayesian deep learning for computer vision? In Advances in Neural Information Processing Systems, pp. 5574-5584, 2017.

Keskar, N. S., Mudigere, D., Nocedal, J., Smelyanskiy, M., and Tang, P. T. P. On large-batch training for deep learning: Generalization gap and sharp minima. arXiv preprint arXiv:1609.04836, 2016.

Khan, M. E., Nielsen, D., Tangkaratt, V., Lin, W., Gal, Y., and Srivastava, A. Fast and scalable Bayesian deep learning by weight-perturbation in Adam. arXiv preprint arXiv:1806.04854, 2018.

Krizhevsky, A., Nair, V., and Hinton, G. The CIFAR-10 dataset, 2014. http://www.cs.toronto.edu/~kriz/cifar.html.

Lakshminarayanan, B., Pritzel, A., and Blundell, C. Simple and scalable predictive uncertainty estimation using deep ensembles. In Advances in Neural Information Processing Systems, pp. 6402-6413, 2017.

Langford, J. and Caruana, R. (Not) bounding the true error. In Advances in Neural Information Processing Systems, pp. 809-816, 2002.

LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278-2324, 1998.

Louizos, C., Shi, X., Schutte, K., and Welling, M. The functional neural process. In Advances in Neural Information Processing Systems, 2019.

MacKay, D. J. Bayesian methods for adaptive models. PhD thesis, California Institute of Technology, 1992.

MacKay, D. J. Probable networks and plausible predictions: a review of practical Bayesian methods for supervised neural networks. Network: Computation in Neural Systems, 6(3):469-505, 1995.

MacKay, D. J. Introduction to Gaussian processes. In Bishop, C. M. (ed.), Neural Networks and Machine Learning, chapter 11, pp. 133-165. Springer-Verlag, 1998.

MacKay, D. J. Information Theory, Inference and Learning Algorithms. Cambridge University Press, 2003.

Maddox, W. J., Izmailov, P., Garipov, T., Vetrov, D. P., and Wilson, A. G. A simple baseline for Bayesian uncertainty in deep learning. In Advances in Neural Information Processing Systems, 2019.

Masegosa, A. R. Learning under model misspecification: Applications to variational and ensemble methods, 2019.

McAllester, D. A. PAC-Bayesian model averaging. In Proceedings of the Twelfth Annual Conference on Computational Learning Theory, pp. 164-170, 1999.

Minka, T. Expectation propagation for approximate Bayesian inference. In Uncertainty in Artificial Intelligence, volume 17, pp. 362-369, 2001a.

Minka, T. P. Bayesian model averaging is not model combination, 2000.

Minka, T. P. Automatic choice of dimensionality for PCA. In Advances in Neural Information Processing Systems, pp. 598-604, 2001b.

Mohri, M. and Rostamizadeh, A. Rademacher complexity bounds for non-i.i.d. processes. In Advances in Neural Information Processing Systems, pp. 1097-1104, 2009.

Nakkiran, P., Kaplun, G., Bansal, Y., Yang, T., Barak, B., and Sutskever, I. Deep double descent: Where bigger models and more data hurt. arXiv preprint arXiv:1912.02292, 2019.

Nakkiran, P., Venkat, P., Kakade, S., and Ma, T. Optimal regularization can mitigate double descent. arXiv preprint arXiv:2003.01897, 2020.

Nalisnick, E. On priors for Bayesian neural networks. PhD thesis, University of California, Irvine, 2018.

Neal, R. Bayesian Learning for Neural Networks. Springer Verlag, 1996. ISBN 0387947248.

Neyshabur, B., Bhojanapalli, S., McAllester, D., and Srebro, N. Exploring generalization in deep learning. In Advances in Neural Information Processing Systems, pp. 5947-5956, 2017.

Neyshabur, B., Bhojanapalli, S., and Srebro, N. A PAC-Bayesian approach to spectrally-normalized margin bounds for neural networks. In International Conference on Learning Representations, 2018. URL https://openreview.net/forum?id=Skz_WfbCZ.

Opper, M., Kinzel, W., Kleinz, J., and Nehl, R. On the ability of the optimal perceptron to generalise. Journal of Physics A: Mathematical and General, 23(11):L581, 1990.


Ovadia, Y., Fertig, E., Ren, J., Nado, Z., Sculley, D., Nowozin, S., Dillon, J. V., Lakshminarayanan, B., and Snoek, J. Can you trust your model's uncertainty? Evaluating predictive uncertainty under dataset shift. arXiv preprint arXiv:1906.02530, 2019.

Pearce, T., Zaki, M., Brintrup, A., Anastassacos, N., and Neely, A. Uncertainty in neural networks: Bayesian ensembling. arXiv preprint arXiv:1810.05546, 2018.

Rasmussen, C. E. and Ghahramani, Z. Occam's razor. In Neural Information Processing Systems (NIPS), 2001.

Rasmussen, C. E. and Williams, C. K. I. Gaussian Processes for Machine Learning. The MIT Press, 2006.

Ritter, H., Botev, A., and Barber, D. A scalable Laplace approximation for neural networks. In International Conference on Learning Representations (ICLR), 2018.

Rue, H., Martino, S., and Chopin, N. Approximate Bayesian inference for latent Gaussian models by using integrated nested Laplace approximations. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 71(2):319-392, 2009.

Simonyan, K. and Zisserman, A. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.

Smith, S. L. and Le, Q. V. A Bayesian perspective on generalization and stochastic gradient descent. In International Conference on Learning Representations, 2018.

Sun, S., Zhang, G., Shi, J., and Grosse, R. Functional variational Bayesian neural networks. arXiv preprint arXiv:1903.05779, 2019.

Ulyanov, D., Vedaldi, A., and Lempitsky, V. Deep image prior. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 9446-9454, 2018.

Vapnik, V. N. Statistical Learning Theory. Adaptive and learning systems for signal processing, communications, and control, 1998.

Virtanen, P., Gommers, R., Oliphant, T. E., Haberland, M., Reddy, T., Cournapeau, D., Burovski, E., Peterson, P., Weckesser, W., Bright, J., et al. SciPy 1.0: fundamental algorithms for scientific computing in Python. Nature Methods, pp. 1-12, 2020.

Walker, S. and Hjort, N. L. On Bayesian consistency. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 63(4):811-821, 2001.

Welling, M. and Teh, Y. W. Bayesian learning via stochastic gradient Langevin dynamics. In Proceedings of the 28th International Conference on Machine Learning (ICML-11), pp. 681-688, 2011.

Wenzel, F., Roth, K., Veeling, B. S., Swiatkowski, J., Tran, L., Mandt, S., Snoek, J., Salimans, T., Jenatton, R., and Nowozin, S. How good is the Bayes posterior in deep neural networks really? arXiv preprint arXiv:2002.02405, 2020.

Wilson, A. G. The case for Bayesian deep learning. arXiv preprint arXiv:2001.10995, 2020.

Xu, X., Lu, P., MacEachern, S., and Xu, R. Calibrated Bayes factors for model comparison. Journal of Statistical Computation and Simulation, 89(4):591-614, 2019.

Yang, W., Lorch, L., Graule, M. A., Srinivasan, S., Suresh, A., Yao, J., Pradier, M. F., and Doshi-Velez, F. Output-constrained Bayesian neural networks. arXiv preprint arXiv:1905.06287, 2019.

Zhang, C., Bengio, S., Hardt, M., Recht, B., and Vinyals, O. Understanding deep learning requires rethinking generalization. arXiv preprint arXiv:1611.03530, 2016.

Zhang, T. Information-theoretic upper and lower bounds for statistical estimation. IEEE Transactions on Information Theory, 52(4):1307-1321, 2006.


Appendix Outline

This appendix is organized as follows. In Section A, we visualize predictive functions corresponding to weight samples within high posterior density valleys on a regression problem. In Section B, we provide background material on Gaussian processes. In Section C, we present further results comparing MultiSWAG and MultiSWA to Deep Ensembles under data distribution shift on CIFAR-10. In Section D, we provide the details of all experiments presented in the paper. In Section E, we present analytic results on the dependence of the prior distribution in function space on the variance of the prior over parameters. In Section F, we study the prior correlations between BNN logits on perturbed images.

A. Loss Valleys

We demonstrate that different points along the valleys of high posterior density (low loss) connecting pairs of independently trained optima (Garipov et al., 2018; Draxler et al., 2018; Fort & Jastrzebski, 2019) correspond to different predictive functions. We use the regression example from Izmailov et al. (2019) and show the results in Figure 11.

B. Gaussian Processes

With a Bayesian neural network, a distribution over parameters p(w) induces a distribution over functions p(f(x; w)) when combined with the functional form of the network. Gaussian processes (GPs) are often used to instead directly specify a distribution over functions.

A Gaussian process is a distribution over functions f(x) ∼ GP(m, k) such that any collection of function values, queried at any finite set of inputs x_1, ..., x_n, has a joint Gaussian distribution:

f(x_1), ..., f(x_n) ∼ N(μ, K).   (6)

The mean vector μ_i = E[f(x_i)] = m(x_i) and covariance matrix K_ij = cov(f(x_i), f(x_j)) = k(x_i, x_j) are determined by the mean function m and covariance function (or kernel) k of the Gaussian process.

The popular RBF kernel has the form

k(x_i, x_j) = exp( -‖x_i - x_j‖² / (2ℓ²) ).   (7)

The length-scale hyperparameter ℓ controls the extent of correlations between function values. If ℓ is large, sample functions from a GP prior are simple and slowly varying with the inputs x.
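To make the effect of the length-scale concrete, the short sketch below (our own illustration, not code from the paper; the grid and length-scale values are arbitrary) draws sample functions from a zero-mean GP prior with the RBF kernel of Equation (7):

```python
import numpy as np

def rbf_kernel(x1, x2, lengthscale):
    # Equation (7): k(x, x') = exp(-||x - x'||^2 / (2 * lengthscale^2))
    sqdist = (x1[:, None] - x2[None, :]) ** 2
    return np.exp(-0.5 * sqdist / lengthscale ** 2)

x = np.linspace(-5.0, 5.0, 200)  # 1D input grid
for lengthscale in [0.3, 1.0, 3.0]:
    K = rbf_kernel(x, x, lengthscale) + 1e-8 * np.eye(len(x))  # jitter for stability
    # Draw sample functions f ~ GP(0, k) evaluated on the grid, as in Equation (6)
    samples = np.random.multivariate_normal(np.zeros(len(x)), K, size=3)
    # Larger length-scales produce smoother, more slowly varying sample functions
    print(lengthscale, float(np.abs(np.diff(samples, axis=1)).mean()))
```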

Gaussian processes with RBF kernels (as well as many other standard kernels) assign positive density to any set of observations. Moreover, these models are universal approximators (Rasmussen & Williams, 2006): as the number of observations increases, they are able to approximate any function to arbitrary precision.

Work on Gaussian processes in machine learning was triggered by the observation that Bayesian neural networks become Gaussian processes with particular kernel functions as the number of hidden units approaches infinity (Neal, 1996). This result resembles recent work on the neural tangent kernel (e.g., Jacot et al., 2018).

C. Deep Ensembles and MultiSWAG Under Distribution Shift

In Figures 14, 15, 16, 17 we show the negative log-likelihood for Deep Ensembles, MultiSWA and MultiSWAG using PreResNet-20 on CIFAR-10 with various corruptions, as a function of the number of independently trained models (SGD solutions, SWA solutions or SWAG models, respectively). For MultiSWAG, we generate 20 samples from each independent SWAG model. Typically, MultiSWA and MultiSWAG significantly outperform Deep Ensembles when a small number of independent models is used, or when the level of corruption is high. The prediction-averaging step used by MultiSWAG is sketched below.
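A minimal sketch of that averaging step is given below. The `swag_models` list and the `sample()` / `predict_proba()` helpers are hypothetical stand-ins for a SWAG implementation, not the interface of the released code:

```python
import numpy as np

def multiswag_predict(swag_models, x_test, num_samples=20):
    """Average predictive distributions within each SWAG solution, then across solutions."""
    per_model_probs = []
    for swag in swag_models:              # independently trained SWAG solutions
        sample_probs = []
        for _ in range(num_samples):      # 20 weight samples per SWAG approximation
            weights = swag.sample()       # hypothetical: draw w from the SWAG posterior
            sample_probs.append(swag.predict_proba(x_test, weights))
        per_model_probs.append(np.mean(sample_probs, axis=0))
    # Uniform mixture over basins of attraction (MultiSWAG); Deep Ensembles and MultiSWA
    # correspond to averaging a single prediction per independently trained solution.
    return np.mean(per_model_probs, axis=0)
```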

In Figure 18, following Ovadia et al. (2019), we show the distribution of negative log likelihood, accuracy and expected calibration error as we vary the type of corruption. We use a fixed training time budget: 10 independently trained models for every method. For MultiSWAG we ensemble 20 samples from each of the 10 SWAG approximations. MultiSWAG particularly achieves better NLL than the other two methods, and MultiSWA outperforms Deep Ensembles; the difference is especially pronounced for higher levels of corruption. In terms of ECE, MultiSWAG again outperforms the other two methods for higher corruption intensities.

We note that Ovadia et al. (2019) found Deep Ensembles to be a very strong baseline for prediction quality and calibration under distribution shift. For this reason, we focus on Deep Ensembles in our comparisons.

D. Details of Experiments

In this section we provide additional details of the experiments presented in the paper.

D.1. Approximating the True Predictive Distribution

For the results presented in Figure 4, we used a network with 3 hidden layers of size 10 each. The network takes two inputs: x and x². We pass both x and x² as input to ensure that the network can represent a broader class of functions. The network outputs a single number y = f(x).


Figure 11. Diversity of high performing functions. Bottom: a contour plot of the posterior log-density in the subspace containing a pair of independently trained modes (as with deep ensembles) and a path of high posterior density connecting these modes; in each panel, the purple point represents a sample from the posterior in the parameter subspace. Top: the predictive distribution constructed from samples in the subspace; the shaded blue area shows the 3σ-region of the predictive distribution at each of the input locations, and the blue line shows the mean of the predictive distribution. In each panel, the purple line shows the predictive function corresponding to the sample shown in the corresponding bottom row panel. For the details of the experimental setup, see Section 5.1 of Izmailov et al. (2019).

To generate data for the plots, we used a randomly-initialized neural network of the same architecture described above. We sampled the weights from an isotropic Gaussian with variance 0.1², and added isotropic Gaussian noise with variance 0.1² to the outputs:

y = f(x; w) + ε(x),

with w ∼ N(0, 0.1² · I), ε(x) ∼ N(0, 0.1² · I). The training set consists of 120 points, shown in Figure 4.

For estimating the ground truth, we ran 200 chains of Hamiltonian Monte Carlo (HMC) using the hamiltorch package (Cobb et al., 2019). We initialized each chain with a network pre-trained with SGD for 3000 steps, then ran HMC for 2000 steps, producing 200 samples.

For Deep Ensembles, we independently trained 50 networks with SGD for 20000 steps each. We used minus the posterior log-density as the training loss. For SVI, we used a fully-factorized Gaussian approximation initialized at an SGD solution trained for 20000 steps. For all inference methods, we set the prior variance to 10² and the noise variance to 0.02².

Discrepancy with true BMA. For the results presented in panel (d) of Figure 4, we computed the Wasserstein distance between the predictive distribution approximated with HMC and the predictive distributions for Deep Ensembles and SVI. We used the one-dimensional Wasserstein distance function¹ from the scipy package (Virtanen et al., 2020). We computed the Wasserstein distance between the marginal distributions at each input location, and averaged the results over the input locations. In the top sub-panels of panels (b), (c) of Figure 4 we additionally visualize the marginal Wasserstein distance between the HMC predictive distribution and the Deep Ensembles and SVI predictive distributions, respectively, for each input location.

¹https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wasserstein_distance.html
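A sketch of this computation is below, assuming `hmc_preds` and `approx_preds` are arrays of sampled predictions of shape (num_samples, num_input_locations); the variable and function names are ours:

```python
import numpy as np
from scipy.stats import wasserstein_distance

def predictive_discrepancy(hmc_preds, approx_preds):
    """Average 1D Wasserstein distance between marginal predictive distributions."""
    num_locations = hmc_preds.shape[1]
    per_location = np.array([
        wasserstein_distance(hmc_preds[:, i], approx_preds[:, i])
        for i in range(num_locations)
    ])
    return per_location.mean(), per_location  # scalar summary and per-location values
```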

D.2. Deep Ensembles and MultiSWAG

We evaluate Deep Ensembles, MultiSWA and MultiSWAG under distribution shift in Section 4. Following Ovadia et al. (2019), we use a PreResNet-20 network and the CIFAR-10 dataset with different types of corruptions introduced in Hendrycks & Dietterich (2019). For training individual SGD, SWA and SWAG models we use the hyper-parameters used for PreResNet-164 in Maddox et al. (2019). For each SWAG model we sample 20 networks and ensemble them. So Deep Ensembles, MultiSWA and MultiSWAG are all evaluated under the same training budget; Deep Ensembles and MultiSWA also use the same test-time budget.

For producing the corrupted data we used the code² released by Hendrycks & Dietterich (2019). We had issues producing the data for the frost corruption type, so we omit it in our evaluation, and include Gaussian blur, which was not included in the evaluation of Hendrycks & Dietterich (2019).

²https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_cifar_c.py


Figure 12. (a)-(c): Average pairwise prior correlations for pairs of objects in classes {0, 1, 2, 4, 7} of MNIST induced by LeNet-5 for p(f(x; w)), when p(w) = N(0, α²I), for α = 0.02, 0.1, 1 respectively. Images in the same class have higher prior correlations than images from different classes, suggesting that p(f(x; w)) has desirable inductive biases. The correlations slightly decrease with increases in α. Panels (e)-(g) show sample functions from LeNet-5 along the direction connecting a pair of MNIST images of 0 and 1 digits; the complexity of the samples increases with α. (d): NLL and (h): classification error of an ensemble of 20 SWAG samples on MNIST as a function of α, using a LeNet-5. The NLL is high for overly small α and near-optimal for larger values, with an optimum near α = 0.3.

D.3. Neural Network Priors

In the main text we considered different properties of the prior distribution over functions induced by a spherical Gaussian distribution over the weights, with different variance scales.

Prior correlation diagrams. In panels (a)-(c) of Figure 12 we show pairwise correlations of the logits for different pairs of datapoints. To make these plots, we produce S = 100 samples of the weights w_i of a LeNet-5 from the prior distribution N(0, α²I), and compute the logits corresponding to class 0 for each data point and each weight sample. We then compute the correlations for each pair x, x′ of data points as follows:

corr_logit(x, x′) = Σ_{i=1}^S (f(x; w_i) − f̄(x)) (f(x′; w_i) − f̄(x′)) / sqrt( Σ_{i=1}^S (f(x; w_i) − f̄(x))² · Σ_{i=1}^S (f(x′; w_i) − f̄(x′))² ),

where f(x; w) is the logit corresponding to class 0 of the network with weights w on the input x, and f̄(x) is the mean value of the logit, f̄(x) = (1/S) Σ_i f(x; w_i). For evaluation, we use 200 random datapoints per class, for classes 0, 1, 2, 4, 7 (a total of 1000 datapoints). We use this set of classes to ensure that the structure is clearly visible in the figure. We combine the correlations into a diagram, additionally showing the average correlation for each pair of classes. We repeat the experiment for different values of α ∈ {0.02, 0.1, 1}. For a discussion of the results, see Section 5.2. A sample-based estimate of this correlation is sketched below.
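A minimal sketch of that estimate follows; `logits_x[i]` is assumed to hold the class-0 logit f(x; w_i) for the i-th prior weight sample:

```python
import numpy as np

def prior_logit_correlation(logits_x, logits_xp):
    """Sample-based estimate of corr_logit(x, x') from S prior weight samples.

    logits_x, logits_xp: arrays of shape (S,) with the class-0 logits
    f(x; w_i) and f(x'; w_i) for weight samples w_i ~ N(0, alpha^2 I).
    """
    fx = logits_x - logits_x.mean()
    fxp = logits_xp - logits_xp.mean()
    return float(np.sum(fx * fxp) / np.sqrt(np.sum(fx ** 2) * np.sum(fxp ** 2)))
```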

Sample functions. In panels (e)-(g) of Figure 12 we visualize the functions sampled from the LeNet-5 network along the direction connecting a pair of MNIST images. In particular, we take a pair of images x_0 and x_1 of digits 0 and 1 respectively, and construct the path x(t) = t · x_0 + (1 − t) · x_1. We then study the samples of the logits z(t) = f(x(t) · ‖x_0‖/‖x(t)‖; w) along the path; here we adjusted the norm of the images along the path to be constant, as the values of the logits are sensitive to the norm of the inputs. The complexity of the samples increases as we increase the variance of the prior distribution over the weights. This increased complexity of sample functions explains why we might expect the prior correlations for pairs of images to be lower when we increase the variance of the prior distribution over the weights.
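A sketch of the path construction is below; rescaling each interpolated image to the norm of x_0 is our reading of the constant-norm adjustment described above:

```python
import numpy as np

def constant_norm_path(x0, x1, num_steps=100):
    """Interpolate x(t) = t*x0 + (1-t)*x1 and rescale each point to the norm of x0."""
    ts = np.linspace(0.0, 1.0, num_steps)
    path = []
    for t in ts:
        xt = t * x0 + (1.0 - t) * x1
        xt = xt * np.linalg.norm(x0) / np.linalg.norm(xt)  # keep the input norm constant
        path.append(xt)
    return ts, np.stack(path)
```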

Performance dependence on prior variance. In panels (d), (h) of Figure 12 we show the test negative log-likelihood and accuracy of SWAG applied to LeNet-5 on MNIST. We train the model for 50 epochs, constructing the rank-20 SWAG approximation from the last 25 epochs. We use an initial learning rate of 0.05 and a SWAG learning rate of 0.01, with the learning rate schedule of Maddox et al. (2019). We use the posterior log-density as the objective, and vary the prior variance α².


In panels (f), (g) of Figure 9 we perform an analogous experiment using a PreResNet-20 and a VGG-16 on CIFAR-10, using the hyper-parameters reported in Maddox et al. (2019) (for PreResNet-20 we use the hyper-parameters used with PreResNet-164 in Maddox et al. (2019)). Both on MNIST and CIFAR-10, we observe that the performance is poor for overly small values of α, close to optimal for intermediate values, and still reasonable for larger values of α. For further discussion of the results, see Section 5.3.

Predictions from prior samples. Following Wenzel et al. (2020), we study the predictive distributions of prior samples using PreResNet-20 on CIFAR-10. In Figure 9 we show the sample predictive functions averaged over datapoints, for different scales α of the prior distribution. We also show the predictive distribution for each α, which is the average of the sample predictive distributions over 200 samples of weights. In Figure 10 we show how the predictive distribution changes as we vary the number of observed data for prior scale α = √10. We see that the marginal predictive distribution for all considered values of α is reasonable: roughly uniform across classes, when averaged across the dataset. For the latter experiment, we used stochastic gradient Langevin dynamics (SGLD) (Welling & Teh, 2011) with a cosine learning rate schedule. For each sample we restart SGLD, and we only use the sample obtained at the last iteration. We discuss the results in Section 8.4.
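For reference, a single SGLD update with a cosine step-size schedule has the form below (Welling & Teh, 2011). This is a generic sketch: `grad_log_posterior` is a hypothetical function returning a stochastic estimate of ∇_w log p(w | D), with the minibatch likelihood gradient rescaled by N/m:

```python
import numpy as np

def cosine_stepsize(step, total_steps, eps_max, eps_min=0.0):
    # cosine schedule decaying from eps_max to eps_min over total_steps
    return eps_min + 0.5 * (eps_max - eps_min) * (1.0 + np.cos(np.pi * step / total_steps))

def sgld_step(w, grad_log_posterior, eps):
    """One SGLD update: a gradient step on log p(w|D) plus Gaussian noise of matching scale."""
    noise = np.sqrt(eps) * np.random.randn(*w.shape)
    return w + 0.5 * eps * grad_log_posterior(w) + noise
```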

Prior correlations with corrupted images. In Section F and Figure 19, we study the decay of the prior correlations between logits on an original image and a perturbed image as we increase the intensity of the perturbations. For the BNN, we use the PreResNet-20 architecture with the standard Gaussian prior N(0, I). For the linear model, the correlations are not affected by the prior variance α²:

cov(wᵀx, wᵀy) = E[wᵀx · wᵀy] = E[xᵀw wᵀy] = xᵀ E[wwᵀ] y = α² xᵀy,

and hence

corr(wᵀx, wᵀy) = cov(wᵀx, wᵀy) / sqrt( cov(wᵀx, wᵀx) · cov(wᵀy, wᵀy) ) = xᵀy / (‖x‖ · ‖y‖).
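This identity is easy to check numerically; a small self-contained sketch (arbitrary dimension and sample count):

```python
import numpy as np

d, S = 50, 200000
rng = np.random.default_rng(0)
x, y = rng.standard_normal(d), rng.standard_normal(d)
w = rng.standard_normal((S, d))              # w ~ N(0, I)
empirical = np.corrcoef(w @ x, w @ y)[0, 1]  # Monte Carlo estimate of corr(w^T x, w^T y)
analytic = x @ y / (np.linalg.norm(x) * np.linalg.norm(y))
print(empirical, analytic)                   # the two values closely agree
```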

We use the N(0, I) prior for the weights of the linear model. Finally, we also evaluate the correlations associated with an RBF kernel (see Equation (7)). To set the lengthscale ℓ of the kernel, we evaluate the pairwise correlations for the PreResNet-20 and the RBF kernel on the 100 uncorrupted CIFAR-10 images that were used for the experiment, and ensure that the average correlations match. The resulting value of ℓ is 10000; the average correlation for the RBF kernel and PreResNet was ≈ 0.9, and for the linear model the average correlation was ≈ 0.82. For the perturbations, we used the same set of corruptions introduced in Hendrycks & Dietterich (2019) as in the experiments in Section 4, with the addition of a random translation: for a random translation of intensity i, we pad the image with 2·i zeros on each side and crop the image randomly to 32 × 32.

D.4. Rethinking Generalization

In Section 6, we experiment with Bayesian neural networks and Gaussian processes on CIFAR-10 with noisy labels, inspired by the results in Zhang et al. (2016) that suggest we need to re-think generalization to understand deep learning.

Following Zhang et al. (2016), we train PreResNet-20 on CIFAR-10 with different fractions of random labels. To ensure that the network fits the training data, we turn off weight decay and data augmentation, and use a lower initial learning rate of 0.01. Otherwise, we follow the hyper-parameters that were used with PreResNet-164 in Maddox et al. (2019). We use a diagonal Laplace approximation to compute an estimate of the marginal likelihood for each level of label corruption. Following Ritter et al. (2018), we use the diagonal of the Fisher information matrix rather than the Hessian.
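A minimal sketch of the diagonal Laplace evidence estimate is below. Here `log_joint_at_mode` is log p(D | ŵ) + log p(ŵ) at the trained solution, and `fisher_diag` is the diagonal Fisher used in place of the Hessian of the data term; the function name and the explicit addition of the prior precision to the curvature are our assumptions, not details taken from the paper:

```python
import numpy as np

def diag_laplace_log_evidence(log_joint_at_mode, fisher_diag, prior_variance):
    """Diagonal Laplace approximation to the log marginal likelihood log p(D).

    log p(D) ~= log p(D|w_hat) + log p(w_hat) + (d/2) log(2*pi) - (1/2) sum_i log A_ii,
    where A is the (diagonal) curvature of the negative log posterior at w_hat.
    """
    precision_diag = fisher_diag + 1.0 / prior_variance  # data curvature + prior precision (assumption)
    d = precision_diag.shape[0]
    return log_joint_at_mode + 0.5 * d * np.log(2.0 * np.pi) - 0.5 * np.sum(np.log(precision_diag))
```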

We perform a similar experiment with a Gaussian process with an RBF kernel on a binary classification problem for two classes of CIFAR-10. We use variational inference to fit the model, and we use the variational evidence lower bound to approximate the marginal likelihood. We use variational inference to overcome the non-Gaussian likelihood, and not for scalability reasons, i.e. we are not using inducing inputs. We use the GPyTorch package (Gardner et al., 2018) to train the models. We use an RBF kernel with the default initialization from GPyTorch and divide the inputs by 5000 to get an appropriate input scale. We train the model on a binary classification problem between classes 0 and 1.

For the 10-class GP classification experiment, we train 10 one-vs-all models that classify between a given class and the rest of the data. To reduce computation in training, we subsample the data not belonging to the given class to 10k datapoints, so each model is trained on a total of 15k datapoints. We then combine the 10 models into a single multi-class model: an observation is attributed to the class that corresponds to the one-vs-all model with the highest confidence (see the sketch below). We use the same hyper-parameters as in the binary classification experiments.
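The combination step is just an argmax over the per-class one-vs-all confidences. A sketch, where each `models[c]` is assumed to expose a `predict_proba` method returning P(class c | x) for a batch of inputs:

```python
import numpy as np

def one_vs_all_predict(models, x_test):
    """Combine 10 one-vs-all classifiers by taking the most confident positive class."""
    # confidences[c, i] = P(x_i belongs to class c) under the c-th one-vs-all model
    confidences = np.stack([model.predict_proba(x_test) for model in models])
    return np.argmax(confidences, axis=0)
```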

D.5. Double Descent

In Section 7, we evaluate SGD, SWAG and MultiSWAG for models of varying width. Following Nakkiran et al. (2019), we use ResNet-18 on CIFAR-100; we consider the original labels, 10% and 20% label corruption. For networks of every width, we reuse the hyper-parameters used for PreResNet-164 in Maddox et al. (2019).


Figure 13. Double Descent. (a): Test error and (b): NLL loss for ResNet-18 with varying width on CIFAR-100 for SGD, SWAG and MultiSWAG when 10% of the labels are randomly reshuffled. MultiSWAG alleviates double descent both on the original labels and under label noise, both in accuracy and NLL. (c): Test NLLs for MultiSWAG with a varying number of independent models under 20% label corruption; NLL monotonically decreases with an increased number of independent models, alleviating double descent.

For the original labels and 10% label corruption, we use 5 independently trained SWAG models with MultiSWAG; for 20% label corruption we use 10 models. For 20% label corruption, we also show performance varying the number of independent models in Figures 8(e) and 13(c). Both for SWAG and MultiSWAG, we use an ensemble of 20 sampled models from each of the SWAG solutions; for example, for MultiSWAG with 10 independent SWAG solutions we use an ensemble of 200 networks.

E. Analysis of Prior Variance Effect

In this section we provide simple analytic results for the effect of the prior variance in ReLU networks. A related derivation is presented in Appendix Section A.8 of Garipov et al. (2018), about connecting paths from symmetries in parametrization.

We will consider a multilayer network f(x; w) of the form

f(x; {W_i, b_i}_{i=1}^n) = W_n(⋯ φ(W_2 φ(W_1 x + b_1) + b_2) ⋯) + b_n,

where φ is the ReLU (or, in fact, any positively-homogeneous activation function), W_i are weight matrices, and b_i are bias vectors. In particular, f can be a regular CNN with ReLU activations up to the logits (with the softmax activation removed).

Now suppose we have a prior distribution of the form

W_i ∼ N(0, α_i² I),   b_i ∼ N(0, β_i² I),

where the identity matrices I are implicitly assumed to be of appropriate shapes, so each weight matrix and bias vector has a spherical Gaussian distribution. We can reparameterize this distribution as

W_i = α_i E_i,   E_i ∼ N(0, I),
b_i = β_i ε_i,   ε_i ∼ N(0, I).

We can then express the predictions of the network on the input x, for weights sampled from the prior, as the random variable

f(x; {α_i, β_i}_{i=1}^n) = α_n · E_n(⋯ φ(α_1 E_1 x + β_1 · ε_1) ⋯) + β_n · ε_n.   (8)

Through Equation (8), we can observe some simple properties of the dependence between the prior scales α_i, β_i and the induced function-space prior.

Proposition 1. Suppose the network has no bias vectors, i.e. β_1 = ... = β_n = 0. Then the scales α_i of the prior distribution over the weights only affect the output scale of the network.

Proof. In the case when there are no bias vectors, Equation (8) simplifies to

f(x; {α_i, β_i = 0}_{i=1}^n) = α_n · E_n(⋯ φ(α_1 E_1 x) ⋯)
= α_n · ⋯ · α_1 · E_n(⋯ φ(E_1 x) ⋯)
= α_n · ⋯ · α_1 · f(x; {α_i = 1, β_i = 0}_{i=1}^n).

In the derivation above, we used the positive homogeneity of ReLU: φ(αz) = αφ(z) for any positive α.

In other words, to sample from the distribution over functions corresponding to a prior with scales {α_i, β_i = 0}_{i=1}^n, we can sample from the spherical Gaussian prior (without bias terms) {α_i = 1, β_i = 0}_{i=1}^n, and then rescale the outputs of the network by the product of scales α_n · ⋯ · α_2 · α_1.
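Proposition 1 is easy to verify numerically with a small bias-free ReLU network; the sketch below (arbitrary layer sizes and scales, chosen only for illustration) checks that scaling the weights of layer i by α_i rescales the output by the product of the α_i:

```python
import numpy as np

def relu_net(x, weights):
    # bias-free network: ReLU hidden layers, linear output layer
    h = x
    for W in weights[:-1]:
        h = np.maximum(W @ h, 0.0)
    return weights[-1] @ h

rng = np.random.default_rng(0)
sizes = [10, 20, 20, 5]
E = [rng.standard_normal((m, n)) for n, m in zip(sizes[:-1], sizes[1:])]  # E_i ~ N(0, I)
alphas = [0.5, 2.0, 3.0]
x = rng.standard_normal(sizes[0])

scaled = relu_net(x, [a * W for a, W in zip(alphas, E)])  # prior scales alpha_i
unit = relu_net(x, E)                                     # unit-scale prior
print(np.allclose(scaled, np.prod(alphas) * unit))        # True, as in Proposition 1
```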

We note that the result above is different from the results for sigmoid networks considered in MacKay (1995), where varying the prior on the weights leads to changing the length-scale of the sample functions. For ReLU networks without biases, increasing the prior variance only increases the output scale of the network and not the complexity of the samples. If we apply the softmax activation on the outputs of the last layer of such a network, we will observe increasingly confident predictions as we increase the prior variance.


We observe this effect in Figure 9 and discuss it in Section 8.4.

In case bias vectors are present, we can obtain a similar result using a specific scaling of the prior variances with the layer, as in the following proposition.

Proposition 2. Suppose the prior scales depend on the layer of the network as follows, for some γ > 0:

α_i = γ,   β_i = γ^i,

for all layers i = 1, ..., n. Then γ only affects the scale of the predictive distribution at any input x:

f(x; {α_i = γ, β_i = γ^i}_{i=1}^n) = γ^n · f(x; {α_i = 1, β_i = 1}_{i=1}^n).

Proof. The proof is analogous to the proof of Proposition 1. We can use the positive homogeneity of ReLU activations to factor the prior scales outside of the network:

f(x; {α_i = γ, β_i = γ^i}_{i=1}^n) = γ · E_n(⋯ φ(γ · E_1 x + γ · ε_1) ⋯) + γ^n · ε_n
= γ^n · ( E_n(⋯ φ(E_1 x + ε_1) ⋯) + ε_n )
= γ^n · f(x; {α_i = 1, β_i = 1}_{i=1}^n).

The analysis above can be applied to other simple scaling rules of the prior, e.g.

f(x; {α_i = γ a_i, β_i = γ^i b_i}_{i=1}^n) = γ^n · f(x; {α_i = a_i, β_i = b_i}_{i=1}^n),   (9)

for any fixed scales a_i, b_i, which can be shown completely analogously to Proposition 2.

More general types of scaling of the prior affect both the output scale of the network and also the relative effect of the prior and variance terms. For example, by Equation (9) we have

f(x; {α_i = γ, β_i = γ}_{i=1}^n) = f(x; {α_i = γ · 1, β_i = γ^i · γ^{1−i}}_{i=1}^n) = γ^n · f(x; {α_i = 1, β_i = γ^{1−i}}_{i=1}^n).

We note that the analysis does not cover residual connections and batch normalization, so it applies to LeNet-5, but cannot be directly applied to the PreResNet-20 networks used in many of our experiments.

F. Prior Correlation Structure under Perturbations

In this section we explore the prior correlations between the logits on different pairs of datapoints induced by a spherical Gaussian prior on the weights of a PreResNet-20. We sample 100 random images from CIFAR-10 (10 from each class) and apply 17 different perturbations introduced by Hendrycks & Dietterich (2019) at 5 different levels of intensity. We then compute correlations between the logits f(x; w) for the original image x and the logits f(x̃; w) for the corrupted image x̃, as we sample the weights of the network from the prior w ∼ N(0, I).

In Figure 19 we show how the correlations decay with perturbation intensity. For reference, we also show how the correlations decay for a linear model and for an RBF kernel. For the RBF kernel, we set the lengthscale so that the average correlations on the uncorrupted datapoints match those of a PreResNet-20. Further experimental details can be found in Appendix D.3.

For all types of corruptions except saturate, snow, fog and brightness, the PreResNet logits decay slower compared to the RBF kernel and linear model. It appears that the prior samples are sensitive to corruptions that alter the brightness, or more generally the colours, in the image. For many types of corruptions (such as, e.g., Gaussian noise) the prior correlations for PreResNet are close to 1 for all levels of corruption.

Overall, these results indicate that the prior over functions induced by a vague prior over parameters w in combination with a PreResNet has useful equivariance properties: before seeing data, the model treats images of the same class as highly correlated, even after an image has undergone significant perturbations representative of perturbations we often see in the real world. These types of symmetries are a large part of what makes neural networks a powerful model class for high dimensional natural signals.


Figure 14. Noise Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Gaussian Noise, (b) Impulse Noise, (c) Shot Noise.


Figure 15. Blur Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Defocus Blur, (b) Glass Blur, (c) Motion Blur, (d) Zoom Blur, (e) Gaussian Blur.


Figure 16. Digital Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Contrast, (b) Saturate, (c) Elastic Transform, (d) Pixelate, (e) JPEG Compression.


Figure 17. Weather Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Snow, (b) Fog, (c) Brightness.

Figure 18. Negative log likelihood, accuracy and expected calibration error distributions on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the corruption intensity. Following Ovadia et al. (2019), we summarize the results for different types of corruption with a boxplot. For each method we use 10 independently trained models, and for MultiSWAG we sample 20 networks from each model. As in Figures 5, 11-14, there are substantial differences between these three methods, which are hard to see due to the vertical scale of this plot; MultiSWAG particularly outperforms Deep Ensembles and MultiSWA in terms of NLL and ECE for higher corruption intensities.


Figure 19. Prior correlations under corruption. Prior correlations between predictions (logits) for PreResNet-20, a linear model and an RBF kernel on original and corrupted images, as a function of corruption intensity, for different types of corruptions (translation, noise, blur, digital and weather corruptions). The lengthscale of the RBF kernel is calibrated to produce similar correlations to PreResNet on uncorrupted datapoints. We report the mean correlation values over 100 different images and show the 1σ error bars with shaded regions. For all corruptions except Snow, Saturate, Fog and Brightness, the correlations decay slower for PreResNet compared to the baselines.



Figure 9. Effects of the prior variance α². (a)-(e): Average class probabilities over all of CIFAR-10 for two sample prior functions p(f(x; w)) (two top rows) and the predictive distribution (average over 200 samples of weights, bottom row), for varying settings of α in p(w) = N(0, α²I): α = 0.01, 0.1, 0.3, 1, √10. (f): NLL and (g): classification error of an ensemble of 20 SWAG samples on CIFAR-10 as a function of α, using a Preactivation ResNet-20 and VGG-16. The NLL is high for overly small α and near-optimal in the range of [0.1, 0.3]. The NLL remains relatively low for vague priors corresponding to large values of α.

Figure 10. Adaptivity of posterior variance with data. We sample two functions f(x; w) from the distribution over functions induced by a distribution over weights, starting with the prior p(w) = N(0, 10·I), in combination with a PreResNet-20. We measure class probabilities averaged across the CIFAR-10 test set, as we vary the amount of available training data (prior, 10, 100 and 1000 datapoints). Although the prior variance is too large, such that the softmax saturates for logits sampled from the prior, leading to one class being favoured, we see that the posterior quickly adapts to correct the scale of the logits in the presence of data. In Figure 9 we also show that the prior variance can easily be calibrated such that the prior predictive distribution, even before observing data, is high entropy.


consider possible reasons for why tempering is helpful in Bayesian deep learning.

In this section we provide an alternative view, and argue that tempering is not at odds with Bayesian principles. Moreover, for virtually any realistic model class and dataset, it would be highly surprising if T = 1 were in fact the best setting of this hyperparameter. Indeed, as long as it is practically convenient, we would advocate tempering for essentially any model, especially parametric models that do not scale their capacity automatically with the amount of available information. Our position is that, at a high level, Bayesian methods are trying to combine honest beliefs with data to form a posterior. By reflecting the belief that the model is misspecified, the tempered posterior is often more of a true posterior than the posterior that results from ignoring our belief that the model is misspecified.

Finding that T < 1 helps for Bayesian neural networks is neither surprising nor discouraging. And the actual results of the experiments in Wenzel et al. (2020), which show great improvements over standard SGD training, are in fact very encouraging of deriving inspiration from Bayesian procedures in deep learning.

We consider: (1) tempering under misspecification (Section 8.1); (2) tempering in terms of overcounting data (Section 8.2); (3) how tempering compares to changing the observation model (Section 8.3); (4) the effect of the prior in relation to the experiments of Wenzel et al. (2020) (Section 8.4); (5) the effect of approximate inference, including how tempering can help in efficiently estimating parameters even for the untempered posterior (Section 8.5).

8.1. Tempering Helps with Misspecified Models

Many works explain how tempered posteriors help under model misspecification (e.g., de Heide et al., 2019; Grunwald et al., 2017; Barron & Cover, 1991; Walker & Hjort, 2001; Zhang, 2006; Bissiri et al., 2016; Grunwald, 2012). In fact, de Heide et al. (2019) and Grunwald et al. (2017) provide several simple examples where Bayesian inference fails to provide good convergence behaviour for untempered posteriors. While it is easier to show theoretical results for T > 1, several of these works also show that T < 1 can be preferred even in well-specified settings, and indeed recommend learning T from data, for example by cross-validation (e.g., Grunwald, 2012; de Heide et al., 2019).

Are we in a misspecified setting for Bayesian neural networks? Of course. And it would be irrational to proceed as if it were otherwise. Every model is misspecified. In the context of Bayesian neural networks specifically, the mass of solutions expressed by the prior outside of the datasets we typically consider is likely much larger than desired for most applications. We can calibrate for this discrepancy through tempering. The resulting tempered posterior will be more in line with our beliefs than pretending the model is not misspecified and finding the untempered posterior.

Non-parametric models, such as Gaussian processes, attempt to side-step model misspecification by growing the number of free parameters (information capacity) automatically with the amount of available data. In parametric models, we take much more of a manual guess about the model capacity. In the case of deep neural networks, this choice is not even close to a best guess: it was once the case that architectural design was a large component of works involving neural networks, but now it is more standard practice to choose an off-the-shelf architecture, without much consideration of model capacity. We do not believe that knowingly using a misspecified model to find a posterior is more reasonable (or Bayesian) than honestly reflecting the belief that the model is misspecified and then using a tempered posterior. For parametric models such as neural networks, it is to be expected that the capacity is particularly misspecified.

8.2. Overcounting Data with Cold Posteriors

The criticism of cold posteriors raised by Wenzel et al. (2020) is largely based on the fact that decreasing the temperature leads to overcounting data in the posterior distribution.

However, a similar argument can be made against marginal likelihood maximization (also known as empirical Bayes, or type-2 maximum likelihood). Indeed, here the prior will depend on the same data as the likelihood, which can lead to miscalibrated predictive distributions (Darnieder, 2011).

Nonetheless, empirical Bayes has been embraced and widely adopted in Bayesian machine learning (e.g., Bishop, 2006; Rasmussen & Williams, 2006; MacKay, 2003; Minka, 2001b) as embodying several Bayesian principles. Empirical Bayes has been particularly embraced in seminal work on Bayesian neural networks (e.g., MacKay, 1992; 1995), where it has been proposed as a principled approach to learning hyperparameters, such as the scale of the variance for the prior over weights, automatically embodying Occam's razor. While there is in this case some deviation from the fully Bayesian paradigm, the procedure, which depends on marginalization, is nonetheless clearly inspired by Bayesian thinking, and it is thus helpful to reflect this inspiration and provide understanding of how it works from a Bayesian perspective.

There is also work showing that the marginal likelihood can lead to miscalibrated Bayes factors under model misspecification. Attempts to calibrate these factors (Xu et al., 2019) as part of the Bayesian paradigm are highly reminiscent of work on safe Bayes.


8.3. Tempered Posterior or Different Likelihood

The tempered posterior for one model is an untempered posterior using a different observation model. In other words, we can trivially change the likelihood function in our model so that the standard Bayesian posterior in the new model is equal to the posterior of temperature T in the original model. Indeed, consider the likelihood function

p_T(D|w) ∝ p(D|w)^(1/T),   (4)

where the posterior distribution for the model M_T with likelihood p_T is exactly the temperature T posterior for the model M with likelihood p.

The predictive distribution differs for the two models: even though the posteriors coincide, the likelihoods for a new datapoint y* are different:

∫ p(y*|w) p(w) dw ≠ ∫ p_T(y*|w) p(w) dw.   (5)

As an example, consider a regression model M with a Gaussian likelihood function, y ∼ N(f, σ²), where f is the output of the network and σ² is the noise variance. The predictive distributions for the two models M and M_T for a new input x will have a different variance, but the same mean: E_w[f + ε] = E_w[f]. Moreover, in this case the noise variance would typically be learned in either model M_T or M.
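To make the correspondence explicit for this example, raising the Gaussian likelihood to the power 1/T only rescales the noise variance (up to a constant that does not depend on y):

```latex
p_T(y \mid w) \propto \mathcal{N}\!\left(y \mid f(x; w), \sigma^2\right)^{1/T}
\propto \exp\!\left(-\frac{(y - f(x; w))^2}{2 T \sigma^2}\right)
\propto \mathcal{N}\!\left(y \mid f(x; w), T\sigma^2\right),
```

so the tempered model M_T behaves like an untempered model with noise variance Tσ², consistent with the two predictive distributions sharing a mean but differing in variance.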

Section 4.1 of Grunwald et al. (2017) considers a related construction.

8.4. Effect of the Prior

While a somewhat misspecified prior will certainly interact with the utility of tempering, we do not believe the experiments in Wenzel et al. (2020) provide evidence that even the prior p(w) = N(0, I) is misspecified to any serious extent. For a relatively wide range of distributions over w, the functional form of the network f(x; w) can produce a generally reasonable distribution over functions p(f(x; w)). In Figure 10, we reproduce the findings in Wenzel et al. (2020) showing that sample functions p(f(x; w)) corresponding to the prior p(w) = N(0, 10·I) strongly favour a single class over the dataset. While this behaviour appears superficially dramatic, we note it is simply an artifact of having a miscalibrated signal variance: a miscalibrated signal variance interacts with a quickly saturating soft-max link function to provide a seemingly dramatic preference to a given class. If we instead use p(w) = N(0, α²I) for quite a range of α, then sample functions provide reasonably high entropy across labels, as in Figure 9. The value of α can easily be determined through cross-validation, as in Figure 9, or specified as a standard value used for L2 regularization (α = 0.24 in this case).

However, even with the inappropriate prior scale, we see in the bottom row of panels (a)-(e) of Figure 9 that the unconditional predictive distribution is completely reasonable. Moreover, the prior variance represents a soft prior bias, and will quickly update in the presence of data. In Figure 10 we show the posterior samples after observing 10, 100, and 1000 data points.

Other aspects of the prior, outside of the prior signal variance, will have a much greater effect on the inductive biases of the model. For example, the induced covariance function cov(f(x_i; w), f(x_j; w)) reflects the induced similarity metric over data instances; through the covariance function we can answer, for instance, whether the model believes a priori that a translated image is similar to the original. Unlike the signal variance of the prior, the prior covariance function will continue to have a significant effect on posterior inference even for very large datasets, and strongly reflects the structural properties of the neural network. We explore these structures of the prior in Figure 12.

8.5. The Effect of Inexact Inference

We have to keep in mind what we ultimately use posterior samples to compute. Ultimately, we wish to estimate the predictive distribution given by the integral in Equation (1). With a finite number of samples, the tempered posterior could be used to provide a better approximation to the expectation of the predictive distribution associated with the untempered posterior.

Consider a simple example, where we wish to estimate the mean of a high-dimensional Gaussian distribution N(0, I). Suppose we use J independent samples. The mean of these samples is also Gaussian distributed, μ̂ ∼ N(0, (1/J) I). In Bayesian deep learning, the dimension d is typically on the order of 10⁷ and J would be on the order of 10. The norm of μ̂ would be highly concentrated around √(10⁷)/√10 = 1000. In this case, sampling from a tempered posterior with T < 1 would lead to a better approximation of the Bayesian model average associated with an untempered posterior.

Furthermore, no current sampling procedure will provide samples that are close to independent samples from the true posterior of a Bayesian neural network. The posterior landscape is far too multimodal and complex for there to be any reasonable coverage. The approximations we have are practically useful, and often preferable to conventional training, but we cannot realistically proceed with analysis assuming that we have obtained true samples from a posterior. While we would expect that some value of T ≠ 1 would be preferred for any finite dataset in practice, it is conceivable that some of the results in Wenzel et al. (2020) may be affected by the specifics of the approximate inference technique being used.


We should be wary not to view Bayesian model averaging purely through the prism of simple Monte Carlo, as advised in Section 3.2. Given a finite computational budget, our goal in effectively approximating a Bayesian model average is not equivalent to obtaining good samples from the posterior.

9. Discussion

"It is now common practice for Bayesians to fit models that have more parameters than the number of data points. Incorporate every imaginable possibility into the model space: for example, if it is conceivable that a very simple model might be able to explain the data, one should include simple models; if the noise might have a long-tailed distribution, one should include a hyperparameter which controls the heaviness of the tails of the distribution; if an input variable might be irrelevant to a regression, include it in the regression anyway." MacKay (1995)

We have presented a probabilistic perspective of generalization, which depends on the support and inductive biases of the model. The support should be as large as possible, but the inductive biases must be well-calibrated to a given problem class. We argue that Bayesian neural networks embody these properties, and, through the lens of probabilistic inference, explain generalization behaviour that has previously been viewed as mysterious. Moreover, we argue that Bayesian marginalization is particularly compelling for neural networks, show how deep ensembles provide a practical mechanism for marginalization, and propose a new approach that generalizes deep ensembles to marginalize within basins of attraction. We show that this multimodal approach to Bayesian model averaging, MultiSWAG, can entirely alleviate double descent, enabling monotonic performance improvements with increases in model flexibility, as well as significant improvements in generalization accuracy and log-likelihood over SGD and single-basin marginalization.

There are certainly many challenges to estimating the integral for a Bayesian model average in modern deep learning, including a high-dimensional parameter space and a complex posterior landscape. But viewing the challenge as an integration problem, rather than an attempt to obtain posterior samples for a simple Monte Carlo approximation, provides opportunities for future progress. Bayesian deep learning has been making fast practical advances, with approaches that now enable better accuracy and calibration over standard training, with minimal overhead.

We finish with remarks about future developments for Bayesian neural network priors, and approaches to research in Bayesian deep learning.

9.1. The Future for BNN Priors

We provide some brief remarks about future developments for BNN priors. Here we have explored relatively simple parameter priors p(w) = N(0, α²I). While these priors are simple in parameter space, they interact with the neural network architecture to induce a sophisticated prior over functions p(f(x; w)), with many desirable properties, including a reasonable correlation structure over images. However, these parameter priors can certainly still be improved. As we have seen, even tuning the value of the signal variance α², an analogue of the L2 regularization often used in deep learning, can have a noticeable effect on the induced prior over functions, though this effect is quickly modulated by data. Layer-wise priors, such that parameters in each layer have a different signal variance, are intuitive: we would expect that later layers require precise determination, while parameters in earlier layers could reasonably take a range of values. But one has to be cautious; as we show in Appendix Section E, with ReLU activations different signal variances in different layers can be degenerate, combining together to affect only the output scale of the network.

A currently popular sentiment is that we should directly build function-space BNN priors, often taking inspiration from Gaussian processes. While we believe this is a promising direction, one should proceed with caution. If we contrive priors over parameters p(w) to induce distributions over functions p(f) that resemble familiar models, such as Gaussian processes with RBF kernels, we could be throwing the baby out with the bathwater. Neural networks are useful as their own model class precisely because they have different inductive biases from other models.

A similar concern applies to taking infinite-width limits in Bayesian neural networks. In these cases we recover Gaussian processes with interpretable kernel functions; because these models are easier to use and analyze, and give rise to interpretable and well-motivated priors, it is tempting to treat them as drop-in replacements for their parametric analogues. However, the kernels for these models are fixed. In order for a model to do effective representation learning, we must learn a similarity metric for the data. Training a neural network in many ways is like learning a kernel, rather than using a fixed kernel. MacKay (1998) has also expressed concerns about treating these limits as replacements for neural networks, due to the loss of representation learning power.

Perhaps the distribution over functions induced by a network in combination with a generic distribution over parameters p(w) may be hard to interpret, but this distribution will contain the equivariance properties, representation learning abilities, and other biases that make neural networks a compelling model class in their own right.


9.2. "But Is It Really Bayesian?"

We finish with an editorial comment about approaches to research within Bayesian deep learning. There is sometimes a tendency to classify work as Bayesian or not Bayesian, with very stringent criteria for what qualifies as Bayesian. Moreover, the implication, and sometimes even explicit recommendation, is that if an approach is not unequivocally Bayesian in every respect, then we should not term it Bayesian, and we should instead attempt to understand the procedure through entirely different, non-Bayesian mechanisms. We believe this mentality encourages tribalism, which is not conducive to the best research, or to creating the best performing methods. What matters is not a debate about semantics, but making rational modelling choices given a particular problem setting, and trying to understand these choices. Often these choices can largely be inspired by a Bayesian approach, in which case it is desirable to indicate this source of inspiration. And in the semantics debate, who would be the arbiter of what gets to be called Bayesian? Arguably it ought to be an evolving definition.

Broadly speaking, what makes Bayesian approaches distinctive is a posterior-weighted marginalization over parameters. And at a high level, Bayesian methods are about combining our honest beliefs with data to form a posterior. In actuality, no fair-minded researcher entirely believes the prior over parameters, the functional form of the model (which is part of the prior over functions), or the likelihood. From this perspective, it is broadly compatible with a Bayesian philosophy to reflect misspecification in the modelling procedure itself, which is achieved through tempering. In this sense, the tempered posterior is more reflective of a true posterior than the posterior that results from ignoring our belief that the model is misspecified.

Moreover, basic probability theory indicates that marginalization is desirable. While marginalization cannot in practice be achieved exactly, we can try to improve over conventional training, which as we have discussed can be viewed as approximate marginalization. Given computational constraints, effective marginalization is not equivalent to obtaining accurate samples from a posterior. As we have discussed, simple Monte Carlo is only one of many mechanisms for marginalization. Just as expectation propagation (Minka, 2001a) focuses its approximation on the factors of a posterior where it will most affect the end result, we should focus on representing the posterior where it will make the biggest difference to the model average. As we have shown, deep ensembles are a reasonable mechanism up to a point. After having trained many independent models, there are added benefits to marginalizing within basins, given the computational expense associated with retraining an additional model to find an additional basin of attraction.

We should also not hold Bayesian methods to a double standard. Indeed, it can be hard to interpret or understand the prior, the posterior, and whether the marginalization procedure is optimal. But it is also hard to interpret the choices behind the functional form of the model, or the rationale behind classical procedures, where we bet everything on a single global optimum, when we know there are many global optima, many of them will perform well but provide different solutions, and many others will not perform well. We should apply the same level of scrutiny to all modelling choices, consider the alternatives, and not be paralyzed if a procedure is not optimal in every respect.

Acknowledgements. AGW and PI are supported by an Amazon Research Award, Facebook Research, NSF I-DISRE 193471, NIH R01 DA048764-01A1, NSF IIS-1563887, and NSF IIS-1910266. We thank Greg Benton for helpful discussions.

References

Achille, A. and Soatto, S. Emergence of invariance and disentanglement in deep representations. The Journal of Machine Learning Research, 19(1):1947–1980, 2018.

Ashukha, A., Lyzhov, A., Molchanov, D., and Vetrov, D. Pitfalls of in-domain uncertainty estimation and ensembling in deep learning. arXiv preprint arXiv:2002.06470, 2020.

Barron, A. R. and Cover, T. M. Minimum complexity density estimation. IEEE Transactions on Information Theory, 37(4):1034–1054, 1991.

Beal, M. J. Variational Algorithms for Approximate Bayesian Inference. PhD thesis, University of London, 2003.

Belkin, M., Hsu, D., Ma, S., and Mandal, S. Reconciling modern machine-learning practice and the classical bias–variance trade-off. Proceedings of the National Academy of Sciences, 116(32):15849–15854, 2019.

Bishop, C. M. Pattern Recognition and Machine Learning. Springer, 2006.

Bissiri, P. G., Holmes, C. C., and Walker, S. G. A general framework for updating belief distributions. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 78(5):1103–1130, 2016.

Blundell, C., Cornebise, J., Kavukcuoglu, K., and Wierstra, D. Weight uncertainty in neural networks. arXiv preprint arXiv:1505.05424, 2015.

Box, G. E. and Tiao, G. C. Bayesian Inference in Statistical Analysis. Addison-Wesley, Reading, MA, 1973.


Chaudhari, P., Choromanska, A., Soatto, S., LeCun, Y., Baldassi, C., Borgs, C., Chayes, J., Sagun, L., and Zecchina, R. Entropy-SGD: Biasing gradient descent into wide valleys. Journal of Statistical Mechanics: Theory and Experiment, 2019.

Cobb, A. D., Baydin, A. G., Markham, A., and Roberts, S. J. Introducing an explicit symplectic integration scheme for Riemannian manifold Hamiltonian Monte Carlo. arXiv preprint arXiv:1910.06243, 2019.

Darnieder, W. F. Bayesian Methods for Data-Dependent Priors. PhD thesis, The Ohio State University, 2011.

de Heide, R., Kirichenko, A., Mehta, N., and Grunwald, P. Safe-Bayesian generalized linear regression. arXiv preprint arXiv:1910.09227, 2019.

Dinh, L., Pascanu, R., Bengio, S., and Bengio, Y. Sharp minima can generalize for deep nets. arXiv preprint arXiv:1703.04933, 2017.

Draxler, F., Veschgini, K., Salmhofer, M., and Hamprecht, F. A. Essentially no barriers in neural network energy landscape. arXiv preprint arXiv:1803.00885, 2018.

Dziugaite, G. K. and Roy, D. M. Computing nonvacuous generalization bounds for deep (stochastic) neural networks with many more parameters than training data. arXiv preprint arXiv:1703.11008, 2017.

Fort, S. and Jastrzebski, S. Large scale structure of neural network loss landscapes. In Advances in Neural Information Processing Systems, pp. 6706–6714, 2019.

Fort, S., Hu, H., and Lakshminarayanan, B. Deep ensembles: A loss landscape perspective. arXiv preprint arXiv:1912.02757, 2019.

Gal, Y. and Ghahramani, Z. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. In International Conference on Machine Learning, pp. 1050–1059, 2016.

Gardner, J. R., Pleiss, G., Bindel, D., Weinberger, K. Q., and Wilson, A. G. GPyTorch: Blackbox matrix-matrix Gaussian process inference with GPU acceleration. In Neural Information Processing Systems, 2018.

Garipov, T., Izmailov, P., Podoprikhin, D., Vetrov, D. P., and Wilson, A. G. Loss surfaces, mode connectivity, and fast ensembling of DNNs. In Neural Information Processing Systems, 2018.

Grunwald, P. The safe Bayesian. In International Conference on Algorithmic Learning Theory, pp. 169–183. Springer, 2012.

Grunwald, P., Van Ommen, T., et al. Inconsistency of Bayesian inference for misspecified linear models, and a proposal for repairing it. Bayesian Analysis, 12(4):1069–1103, 2017.

Guedj, B. A primer on PAC-Bayesian learning. arXiv preprint arXiv:1901.05353, 2019.

Guo, C., Pleiss, G., Sun, Y., and Weinberger, K. Q. On calibration of modern neural networks. In Proceedings of the 34th International Conference on Machine Learning, pp. 1321–1330. JMLR.org, 2017.

Gustafsson, F. K., Danelljan, M., and Schon, T. B. Evaluating scalable Bayesian deep learning methods for robust computer vision. arXiv preprint arXiv:1906.01620, 2019.

Hafner, D., Tran, D., Irpan, A., Lillicrap, T., and Davidson, J. Reliable uncertainty estimates in deep neural networks using noise contrastive priors. arXiv preprint arXiv:1807.09289, 2018.

He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778, 2016.

Hendrycks, D. and Dietterich, T. Benchmarking neural network robustness to common corruptions and perturbations. arXiv preprint arXiv:1903.12261, 2019.

Hinton, G. E. and Van Camp, D. Keeping the neural networks simple by minimizing the description length of the weights. In Proceedings of the Sixth Annual Conference on Computational Learning Theory, pp. 5–13, 1993.

Hochreiter, S. and Schmidhuber, J. Flat minima. Neural Computation, 9(1):1–42, 1997.

Huang, W. R., Emam, Z., Goldblum, M., Fowl, L., Terry, J. K., Huang, F., and Goldstein, T. Understanding generalization through visualizations. arXiv preprint arXiv:1906.03291, 2019.

Izmailov, P., Podoprikhin, D., Garipov, T., Vetrov, D., and Wilson, A. G. Averaging weights leads to wider optima and better generalization. In Uncertainty in Artificial Intelligence (UAI), 2018.

Izmailov, P., Maddox, W. J., Kirichenko, P., Garipov, T., Vetrov, D., and Wilson, A. G. Subspace inference for Bayesian deep learning. In Uncertainty in Artificial Intelligence, 2019.

Jacot, A., Gabriel, F., and Hongler, C. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in Neural Information Processing Systems, pp. 8571–8580, 2018.


Jiang, Y., Neyshabur, B., Mobahi, H., Krishnan, D., and Bengio, S. Fantastic generalization measures and where to find them. arXiv preprint arXiv:1912.02178, 2019.

Kass, R. E. and Raftery, A. E. Bayes factors. Journal of the American Statistical Association, 90(430):773–795, 1995.

Kendall, A. and Gal, Y. What uncertainties do we need in Bayesian deep learning for computer vision? In Advances in Neural Information Processing Systems, pp. 5574–5584, 2017.

Keskar, N. S., Mudigere, D., Nocedal, J., Smelyanskiy, M., and Tang, P. T. P. On large-batch training for deep learning: Generalization gap and sharp minima. arXiv preprint arXiv:1609.04836, 2016.

Khan, M. E., Nielsen, D., Tangkaratt, V., Lin, W., Gal, Y., and Srivastava, A. Fast and scalable Bayesian deep learning by weight-perturbation in Adam. arXiv preprint arXiv:1806.04854, 2018.

Krizhevsky, A., Nair, V., and Hinton, G. The CIFAR-10 dataset, 2014. http://www.cs.toronto.edu/~kriz/cifar.html

Lakshminarayanan, B., Pritzel, A., and Blundell, C. Simple and scalable predictive uncertainty estimation using deep ensembles. In Advances in Neural Information Processing Systems, pp. 6402–6413, 2017.

Langford, J. and Caruana, R. (Not) bounding the true error. In Advances in Neural Information Processing Systems, pp. 809–816, 2002.

LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.

Louizos, C., Shi, X., Schutte, K., and Welling, M. The functional neural process. In Advances in Neural Information Processing Systems, 2019.

MacKay, D. J. Bayesian Methods for Adaptive Models. PhD thesis, California Institute of Technology, 1992.

MacKay, D. J. Probable networks and plausible predictions: a review of practical Bayesian methods for supervised neural networks. Network: Computation in Neural Systems, 6(3):469–505, 1995.

MacKay, D. J. Introduction to Gaussian processes. In Bishop, C. M. (ed.), Neural Networks and Machine Learning, chapter 11, pp. 133–165. Springer-Verlag, 1998.

MacKay, D. J. Information Theory, Inference and Learning Algorithms. Cambridge University Press, 2003.

Maddox, W. J., Izmailov, P., Garipov, T., Vetrov, D. P., and Wilson, A. G. A simple baseline for Bayesian uncertainty in deep learning. In Advances in Neural Information Processing Systems, 2019.

Masegosa, A. R. Learning under model misspecification: Applications to variational and ensemble methods. 2019.

McAllester, D. A. PAC-Bayesian model averaging. In Proceedings of the Twelfth Annual Conference on Computational Learning Theory, pp. 164–170, 1999.

Minka, T. Expectation propagation for approximate Bayesian inference. In Uncertainty in Artificial Intelligence, volume 17, pp. 362–369, 2001a.

Minka, T. P. Bayesian model averaging is not model combination. 2000.

Minka, T. P. Automatic choice of dimensionality for PCA. In Advances in Neural Information Processing Systems, pp. 598–604, 2001b.

Mohri, M. and Rostamizadeh, A. Rademacher complexity bounds for non-i.i.d. processes. In Advances in Neural Information Processing Systems, pp. 1097–1104, 2009.

Nakkiran, P., Kaplun, G., Bansal, Y., Yang, T., Barak, B., and Sutskever, I. Deep double descent: Where bigger models and more data hurt. arXiv preprint arXiv:1912.02292, 2019.

Nakkiran, P., Venkat, P., Kakade, S., and Ma, T. Optimal regularization can mitigate double descent. arXiv preprint arXiv:2003.01897, 2020.

Nalisnick, E. On Priors for Bayesian Neural Networks. PhD thesis, University of California, Irvine, 2018.

Neal, R. Bayesian Learning for Neural Networks. Springer Verlag, 1996. ISBN 0387947248.

Neyshabur, B., Bhojanapalli, S., McAllester, D., and Srebro, N. Exploring generalization in deep learning. In Advances in Neural Information Processing Systems, pp. 5947–5956, 2017.

Neyshabur, B., Bhojanapalli, S., and Srebro, N. A PAC-Bayesian approach to spectrally-normalized margin bounds for neural networks. In International Conference on Learning Representations, 2018. URL https://openreview.net/forum?id=Skz_WfbCZ.

Opper, M., Kinzel, W., Kleinz, J., and Nehl, R. On the ability of the optimal perceptron to generalise. Journal of Physics A: Mathematical and General, 23(11):L581, 1990.


Ovadia, Y., Fertig, E., Ren, J., Nado, Z., Sculley, D., Nowozin, S., Dillon, J. V., Lakshminarayanan, B., and Snoek, J. Can you trust your model's uncertainty? Evaluating predictive uncertainty under dataset shift. arXiv preprint arXiv:1906.02530, 2019.

Pearce, T., Zaki, M., Brintrup, A., Anastassacos, N., and Neely, A. Uncertainty in neural networks: Bayesian ensembling. arXiv preprint arXiv:1810.05546, 2018.

Rasmussen, C. E. and Ghahramani, Z. Occam's razor. In Neural Information Processing Systems (NIPS), 2001.

Rasmussen, C. E. and Williams, C. K. I. Gaussian Processes for Machine Learning. The MIT Press, 2006.

Ritter, H., Botev, A., and Barber, D. A scalable Laplace approximation for neural networks. In International Conference on Learning Representations (ICLR), 2018.

Rue, H., Martino, S., and Chopin, N. Approximate Bayesian inference for latent Gaussian models by using integrated nested Laplace approximations. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 71(2):319–392, 2009.

Simonyan, K. and Zisserman, A. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.

Smith, S. L. and Le, Q. V. A Bayesian perspective on generalization and stochastic gradient descent. In International Conference on Learning Representations, 2018.

Sun, S., Zhang, G., Shi, J., and Grosse, R. Functional variational Bayesian neural networks. arXiv preprint arXiv:1903.05779, 2019.

Ulyanov, D., Vedaldi, A., and Lempitsky, V. Deep image prior. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 9446–9454, 2018.

Vapnik, V. N. Statistical Learning Theory. Adaptive and Learning Systems for Signal Processing, Communications, and Control, 1998.

Virtanen, P., Gommers, R., Oliphant, T. E., Haberland, M., Reddy, T., Cournapeau, D., Burovski, E., Peterson, P., Weckesser, W., Bright, J., et al. SciPy 1.0: Fundamental algorithms for scientific computing in Python. Nature Methods, pp. 1–12, 2020.

Walker, S. and Hjort, N. L. On Bayesian consistency. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 63(4):811–821, 2001.

Welling, M. and Teh, Y. W. Bayesian learning via stochastic gradient Langevin dynamics. In Proceedings of the 28th International Conference on Machine Learning (ICML-11), pp. 681–688, 2011.

Wenzel, F., Roth, K., Veeling, B. S., Swiatkowski, J., Tran, L., Mandt, S., Snoek, J., Salimans, T., Jenatton, R., and Nowozin, S. How good is the Bayes posterior in deep neural networks really? arXiv preprint arXiv:2002.02405, 2020.

Wilson, A. G. The case for Bayesian deep learning. arXiv preprint arXiv:2001.10995, 2020.

Xu, X., Lu, P., MacEachern, S., and Xu, R. Calibrated Bayes factors for model comparison. Journal of Statistical Computation and Simulation, 89(4):591–614, 2019.

Yang, W., Lorch, L., Graule, M. A., Srinivasan, S., Suresh, A., Yao, J., Pradier, M. F., and Doshi-Velez, F. Output-constrained Bayesian neural networks. arXiv preprint arXiv:1905.06287, 2019.

Zhang, C., Bengio, S., Hardt, M., Recht, B., and Vinyals, O. Understanding deep learning requires rethinking generalization. arXiv preprint arXiv:1611.03530, 2016.

Zhang, T. Information-theoretic upper and lower bounds for statistical estimation. IEEE Transactions on Information Theory, 52(4):1307–1321, 2006.


Appendix Outline

This appendix is organized as follows. In Section A, we visualize predictive functions corresponding to weight samples within high posterior density valleys on a regression problem. In Section B, we provide background material on Gaussian processes. In Section C, we present further results comparing MultiSWAG and MultiSWA to Deep Ensembles under data distribution shift on CIFAR-10. In Section D, we provide the details of all experiments presented in the paper. In Section E, we present analytic results on the dependence of the prior distribution in function space on the variance of the prior over parameters. In Section F, we study the prior correlations between BNN logits on perturbed images.

A. Loss Valleys

We demonstrate that different points along the valleys of high posterior density (low loss) connecting pairs of independently trained optima (Garipov et al., 2018; Draxler et al., 2018; Fort & Jastrzebski, 2019) correspond to different predictive functions. We use the regression example from Izmailov et al. (2019) and show the results in Figure 11.

B. Gaussian Processes

With a Bayesian neural network, a distribution over parameters p(w) induces a distribution over functions p(f(x; w)) when combined with the functional form of the network. Gaussian processes (GPs) are often used to instead directly specify a distribution over functions.

A Gaussian process is a distribution over functions f(x) ∼ GP(m, k), such that any collection of function values, queried at any finite set of inputs x_1, ..., x_n, has a joint Gaussian distribution:

f(x_1), ..., f(x_n) ∼ N(μ, K).    (6)

The mean vector μ_i = E[f(x_i)] = m(x_i) and covariance matrix K_ij = cov(f(x_i), f(x_j)) = k(x_i, x_j) are determined by the mean function m and covariance function (or kernel) k of the Gaussian process.

The popular RBF kernel has the form

k(x_i, x_j) = exp( -(1/(2ℓ²)) ‖x_i - x_j‖² ).    (7)

The length-scale hyperparameter ℓ controls the extent of correlations between function values: if ℓ is large, sample functions from a GP prior are simple and slowly varying with the inputs x.
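As a minimal illustration of Equation (7) (our sketch, not part of the paper's experiments; the input grid and length-scale values are arbitrary choices):

```python
import numpy as np

def rbf_kernel(X, lengthscale):
    # k(x_i, x_j) = exp(-||x_i - x_j||^2 / (2 * lengthscale^2)), as in Equation (7)
    sq_dists = ((X[:, None, :] - X[None, :, :]) ** 2).sum(-1)
    return np.exp(-sq_dists / (2 * lengthscale ** 2))

X = np.linspace(-3, 3, 200)[:, None]
rng = np.random.default_rng(0)
for lengthscale in [0.2, 1.0, 5.0]:
    K = rbf_kernel(X, lengthscale) + 1e-6 * np.eye(len(X))   # jitter for numerical stability
    samples = rng.multivariate_normal(np.zeros(len(X)), K, size=3)
    # Larger length-scales give smoother, more slowly varying sample functions.
```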

Gaussian processes with RBF kernels (as well as many other standard kernels) assign positive density to any set of observations. Moreover, these models are universal approximators (Rasmussen & Williams, 2006): as the number of observations increases, they are able to approximate any function to arbitrary precision.

Work on Gaussian processes in machine learning was triggered by the observation that Bayesian neural networks become Gaussian processes with particular kernel functions as the number of hidden units approaches infinity (Neal, 1996). This result resembles recent work on the neural tangent kernel (e.g., Jacot et al., 2018).

C. Deep Ensembles and MultiSWAG under Distribution Shift

In Figures 14, 15, 16, and 17, we show the negative log-likelihood for Deep Ensembles, MultiSWA, and MultiSWAG using PreResNet-20 on CIFAR-10 with various corruptions, as a function of the number of independently trained models (SGD solutions, SWA solutions, or SWAG models, respectively). For MultiSWAG, we generate 20 samples from each independent SWAG model. Typically, MultiSWA and MultiSWAG significantly outperform Deep Ensembles when a small number of independent models is used, or when the level of corruption is high.

In Figure 18, following Ovadia et al. (2019), we show the distribution of negative log-likelihood, accuracy, and expected calibration error as we vary the type of corruption. We use a fixed training time budget: 10 independently trained models for every method. For MultiSWAG, we ensemble 20 samples from each of the 10 SWAG approximations. MultiSWAG in particular achieves better NLL than the other two methods, and MultiSWA outperforms Deep Ensembles; the difference is especially pronounced for higher levels of corruption. In terms of ECE, MultiSWAG again outperforms the other two methods for higher corruption intensities.

We note that Ovadia et al. (2019) found Deep Ensembles to be a very strong baseline for prediction quality and calibration under distribution shift. For this reason, we focus on Deep Ensembles in our comparisons.

D. Details of Experiments

In this section, we provide additional details of the experiments presented in the paper.

D.1. Approximating the True Predictive Distribution

For the results presented in Figure 4, we used a network with 3 hidden layers of size 10 each. The network takes two inputs: x and x². We pass both x and x² as input to ensure that the network can represent a broader class of functions.


Figure 11. Diversity of high-performing functions. Bottom: a contour plot of the posterior log-density in the subspace containing a pair of independently trained modes (as with deep ensembles), and a path of high posterior density connecting these modes. In each panel, the purple point represents a sample from the posterior in the parameter subspace. Top: the predictive distribution constructed from samples in the subspace. The shaded blue area shows the 3σ-region of the predictive distribution at each of the input locations, and the blue line shows the mean of the predictive distribution. In each panel, the purple line shows the predictive function corresponding to the sample shown in the corresponding bottom-row panel. For the details of the experimental setup, see Section 5.1 of Izmailov et al. (2019).

The network outputs a single number y = f(x; w).

To generate data for the plots, we used a randomly initialized neural network of the same architecture described above. We sampled the weights from an isotropic Gaussian with variance 0.1², and added isotropic Gaussian noise with variance 0.1² to the outputs:

y = f(x; w) + ε(x),

with w ∼ N(0, 0.1²·I), ε(x) ∼ N(0, 0.1²·I). The training set consists of 120 points, shown in Figure 4.
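A minimal sketch of this data-generating process, following the description above; the ReLU activation, the uniform input grid, and its range are our assumptions, not details taken from the paper:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def make_net():
    # Three hidden layers of width 10; inputs are (x, x^2), output is a scalar.
    return nn.Sequential(
        nn.Linear(2, 10), nn.ReLU(),
        nn.Linear(10, 10), nn.ReLU(),
        nn.Linear(10, 10), nn.ReLU(),
        nn.Linear(10, 1),
    )

true_net = make_net()
with torch.no_grad():
    for p in true_net.parameters():
        p.normal_(0.0, 0.1)                                  # w ~ N(0, 0.1^2 I)

    x = torch.linspace(-10, 10, 120).unsqueeze(-1)           # 120 training inputs (assumed range)
    features = torch.cat([x, x ** 2], dim=-1)                # pass both x and x^2 to the network
    y = true_net(features) + 0.1 * torch.randn(120, 1)       # additive noise with variance 0.1^2
```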

For estimating the ground truth, we ran 200 chains of Hamiltonian Monte Carlo (HMC) using the hamiltorch package (Cobb et al., 2019). We initialized each chain with a network pre-trained with SGD for 3000 steps, then ran HMC for 2000 steps, producing 200 samples.

For Deep Ensembles, we independently trained 50 networks with SGD for 20000 steps each. We used the negative posterior log-density as the training loss. For SVI, we used a fully-factorized Gaussian approximation initialized at an SGD solution trained for 20000 steps. For all inference methods, we set the prior variance to 10² and the noise variance to 0.02².

Discrepancy with true BMA. For the results presented in panel (d) of Figure 4, we computed the Wasserstein distance between the predictive distribution approximated with HMC and the predictive distributions for Deep Ensembles and SVI. We used the one-dimensional Wasserstein distance function from the scipy package (Virtanen et al., 2020): https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wasserstein_distance.html. We computed the Wasserstein distance between the marginal distributions at each input location, and averaged the results over the input locations. In the top sub-panels of panels (b) and (c) of Figure 4, we additionally visualize the marginal Wasserstein distance between the HMC predictive distribution and the Deep Ensembles and SVI predictive distributions, respectively, for each input location.
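A sketch of this comparison, assuming predictive samples are stored as arrays of shape [num_samples, num_input_locations] (the variable names are ours):

```python
import numpy as np
from scipy.stats import wasserstein_distance

def mean_marginal_wasserstein(hmc_preds, approx_preds):
    # hmc_preds, approx_preds: [num_samples, num_input_locations] predictive samples.
    per_location = [
        wasserstein_distance(hmc_preds[:, i], approx_preds[:, i])
        for i in range(hmc_preds.shape[1])
    ]
    # Return both the average over input locations (panel (d)) and the
    # per-location distances (top sub-panels of (b) and (c)).
    return float(np.mean(per_location)), np.array(per_location)
```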

D.2. Deep Ensembles and MultiSWAG

We evaluate Deep Ensembles, MultiSWA, and MultiSWAG under distribution shift in Section 4. Following Ovadia et al. (2019), we use a PreResNet-20 network and the CIFAR-10 dataset with different types of corruptions introduced in Hendrycks & Dietterich (2019). For training individual SGD, SWA, and SWAG models, we use the hyperparameters used for PreResNet-164 in Maddox et al. (2019). For each SWAG model, we sample 20 networks and ensemble them, as sketched below. So Deep Ensembles, MultiSWA, and MultiSWAG are all evaluated under the same training budget; Deep Ensembles and MultiSWA also use the same test-time budget.
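A sketch of the MultiSWAG test-time ensembling step. It assumes a list of fitted SWAG posteriors, each exposing a `sample()` method that returns a network; these names are ours, not identifiers from the released codebase:

```python
import torch

@torch.no_grad()
def multiswag_predict(swag_posteriors, loader, num_samples=20):
    """Average softmax predictions over `num_samples` networks drawn from each
    independent SWAG posterior (the approximate Bayesian model average)."""
    probs_sum, count = None, 0
    for swag in swag_posteriors:              # independently trained SWAG solutions
        for _ in range(num_samples):          # 20 samples per solution in the paper
            net = swag.sample()               # draw a network from this SWAG Gaussian
            # (in practice, batch-norm statistics are re-estimated for each sampled network)
            net.eval()
            probs = torch.cat([torch.softmax(net(x), dim=-1) for x, _ in loader])
            probs_sum = probs if probs_sum is None else probs_sum + probs
            count += 1
    return probs_sum / count
```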

For producing the corrupted data, we used the code released by Hendrycks & Dietterich (2019): https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_cifar_c.py. We had issues producing the data for the frost corruption type, so we omit it in our evaluation, and include Gaussian blur, which was not included in the evaluation of Hendrycks & Dietterich (2019).


Figure 12. (a)–(c): Average pairwise prior correlations for pairs of objects in classes 0, 1, 2, 4, 7 of MNIST, induced by LeNet-5 for p(f(x; w)) when p(w) = N(0, α²I), with α = 0.02, 0.1, and 1, respectively. Images in the same class have higher prior correlations than images from different classes, suggesting that p(f(x; w)) has desirable inductive biases. The correlations slightly decrease with increases in α. Panels (e)–(g) show sample functions from LeNet-5 along the direction connecting a pair of MNIST images of 0 and 1 digits, for the same values of α; the complexity of the samples increases with α. (d): NLL and (h): classification error of an ensemble of 20 SWAG samples on MNIST as a function of α, using a LeNet-5. The NLL is high for overly small α and near-optimal for larger values, with an optimum near α = 0.3.

D.3. Neural Network Priors

In the main text, we considered different properties of the prior distribution over functions induced by a spherical Gaussian distribution over the weights, with different variance scales.

Prior correlation diagrams. In panels (a)–(c) of Figure 12, we show pairwise correlations of the logits for different pairs of datapoints. To make these plots, we produce S = 100 samples of the weights w_i of a LeNet-5 from the prior distribution N(0, α²I), and compute the logits corresponding to class 0 for each data point and each weight sample. We then compute the correlations for each pair x, x' of data points as

corr_logit(x, x') = [ Σ_{i=1}^S (f(x; w_i) - f̄(x)) (f(x'; w_i) - f̄(x')) ] / sqrt( Σ_{i=1}^S (f(x; w_i) - f̄(x))² · Σ_{i=1}^S (f(x'; w_i) - f̄(x'))² ),

where f(x; w) is the logit corresponding to class 0 of the network with weights w on the input x, and f̄(x) is the mean value of the logit, f̄(x) = (1/S) Σ_i f(x; w_i). For evaluation, we use 200 random datapoints per class, for classes 0, 1, 2, 4, 7 (a total of 1000 datapoints). We use this set of classes to ensure that the structure is clearly visible in the figure. We combine the correlations into a diagram, additionally showing the average correlation for each pair of classes. We repeat the experiment for different values of α ∈ {0.02, 0.1, 1}. For a discussion of the results, see Section 5.2.
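A sketch of this computation; the network constructor and the data tensor are placeholders for whatever implementation is used:

```python
import torch

@torch.no_grad()
def prior_logit_correlations(make_net, data, alpha, num_samples=100):
    """Correlations of the class-0 logit across prior draws w ~ N(0, alpha^2 I),
    for every pair of inputs in `data` (a tensor of shape [n, ...])."""
    net = make_net()
    net.eval()
    logits = []
    for _ in range(num_samples):
        for p in net.parameters():
            p.normal_(0.0, alpha)              # sample weights from the prior
        logits.append(net(data)[:, 0])         # class-0 logit for every datapoint
    logits = torch.stack(logits)               # [num_samples, n]
    centered = logits - logits.mean(dim=0, keepdim=True)
    cov = centered.T @ centered
    std = centered.norm(dim=0)
    return cov / (std[:, None] * std[None, :]) # [n, n] correlation matrix
```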

Sample functions. In panels (e)–(g) of Figure 12, we visualize the functions sampled from the LeNet-5 network along the direction connecting a pair of MNIST images. In particular, we take a pair of images x_0 and x_1 of digits 0 and 1 respectively, and construct the path x(t) = t·x_0 + (1-t)·x_1. We then study the samples of the logits z(t) = f(x(t)·‖x_0‖/‖x(t)‖; w) along the path; here we adjusted the norm of the images along the path to be constant, as the values of the logits are sensitive to the norm of the inputs. The complexity of the samples increases as we increase the variance of the prior distribution over the weights. This increased complexity of sample functions explains why we might expect the prior correlations for pairs of images to be lower when we increase the variance of the prior distribution over the weights.
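A small sketch of evaluating a sampled network along this path; the norm rescaling follows our reading of the description above, and the function names are ours:

```python
import torch

@torch.no_grad()
def logits_along_path(net, x0, x1, num_points=50):
    """Class-0 logit along x(t) = t*x0 + (1-t)*x1, with each point rescaled
    to have the same norm as x0, so the input norm is constant along the path."""
    ts = torch.linspace(0.0, 1.0, num_points)
    zs = []
    for t in ts:
        xt = t * x0 + (1 - t) * x1
        xt = xt * x0.norm() / xt.norm()        # keep the input norm constant
        zs.append(net(xt.unsqueeze(0))[0, 0])  # class-0 logit
    return ts, torch.stack(zs)
```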

Performance dependence on prior variance. In panels (d) and (h) of Figure 12, we show the test negative log-likelihood and accuracy of SWAG applied to LeNet-5 on MNIST. We train the model for 50 epochs, constructing the rank-20 SWAG approximation from the last 25 epochs. We use an initial learning rate of 0.05 and a SWAG learning rate of 0.01, with the learning rate schedule of Maddox et al. (2019). We use the posterior log-density as the objective, and vary the prior variance α². In panels (f) and (g) of Figure 9,


we perform an analogous experiment using a PreResNet-20 and a VGG-16 on CIFAR-10, using the hyperparameters reported in Maddox et al. (2019) (for PreResNet-20, we use the hyperparameters used with PreResNet-164 in Maddox et al. (2019)). Both on MNIST and CIFAR-10, we observe that the performance is poor for overly small values of α, close to optimal for intermediate values, and still reasonable for larger values of α. For further discussion of the results, see Section 5.3.

Predictions from prior samples. Following Wenzel et al. (2020), we study the predictive distributions of prior samples using PreResNet-20 on CIFAR-10. In Figure 9, we show the sample predictive functions, averaged over datapoints, for different scales α of the prior distribution. We also show the predictive distribution for each α, which is the average of the sample predictive distributions over 200 samples of the weights. In Figure 10, we show how the predictive distribution changes as we vary the number of observed data, for prior scale α = √10. We see that the marginal predictive distribution for all considered values of α is reasonable: roughly uniform across classes, when averaged across the dataset. For the latter experiment, we used stochastic gradient Langevin dynamics (SGLD) (Welling & Teh, 2011) with a cosine learning rate schedule. For each sample we restart SGLD, and we only use the sample obtained at the last iteration. We discuss the results in Section 8.4.

Prior correlations with corrupted images. In Section F and Figure 19, we study the decay of the prior correlations between the logits on an original image and a perturbed image, as we increase the intensity of the perturbations. For the BNN, we use the PreResNet-20 architecture with the standard Gaussian prior N(0, I). For the linear model, the correlations are not affected by the prior variance α²:

cov(w^T x, w^T y) = E(w^T x · w^T y) = E(x^T w w^T y) = x^T E(w w^T) y = α² x^T y,

and hence

corr(w^T x, w^T y) = cov(w^T x, w^T y) / sqrt( cov(w^T x, w^T x) · cov(w^T y, w^T y) ) = x^T y / (‖x‖ ‖y‖),

which does not depend on α².

We use the N(0, I) prior for the weights of the linear model. Finally, we also evaluate the correlations associated with an RBF kernel (see Equation (7)). To set the lengthscale ℓ of the kernel, we evaluate the pairwise correlations for the PreResNet-20 and the RBF kernel on the 100 uncorrupted CIFAR-10 images that were used for the experiment, and ensure that the average correlations match. The resulting value of ℓ is 10000; the average correlation for the RBF kernel and PreResNet was ≈ 0.9, while for the linear model the

average correlation was ≈ 0.82. For the perturbations, we used the same set of corruptions introduced in Hendrycks & Dietterich (2019) as in the experiments in Section 4, with the addition of a random translation: for a random translation of intensity i, we pad the image with 2·i zeros on each side and crop the image randomly to 32×32.

D.4. Rethinking Generalization

In Section 6, we experiment with Bayesian neural networks and Gaussian processes on CIFAR-10 with noisy labels, inspired by the results in Zhang et al. (2016) that suggest we need to re-think generalization to understand deep learning.

Following Zhang et al. (2016), we train PreResNet-20 on CIFAR-10 with different fractions of random labels. To ensure that the networks fit the training data, we turn off weight decay and data augmentation, and use a lower initial learning rate of 0.01. Otherwise, we follow the hyperparameters that were used with PreResNet-164 in Maddox et al. (2019). We use a diagonal Laplace approximation to compute an estimate of the marginal likelihood for each level of label corruption. Following Ritter et al. (2018), we use the diagonal of the Fisher information matrix rather than the Hessian.
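A minimal sketch of the diagonal Laplace estimate of the log marginal likelihood; the exact handling of the prior-precision term (and any dataset scaling) is an assumption on our part, not a transcription of the released code:

```python
import math
import torch

def diag_laplace_log_marginal_likelihood(log_joint_at_map, diag_fisher, prior_var):
    """log p(D) ≈ log p(D, w_MAP) + (d/2) log(2*pi) - (1/2) * sum_i log h_i,
    where h_i are diagonal curvature estimates at the MAP solution; here we take
    h = diagonal Fisher + prior precision (an assumption)."""
    h = diag_fisher + 1.0 / prior_var
    d = h.numel()
    return log_joint_at_map + 0.5 * d * math.log(2 * math.pi) - 0.5 * torch.log(h).sum()
```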

We perform a similar experiment with a Gaussian process with an RBF kernel, on a binary classification problem for two classes of CIFAR-10. We use variational inference to fit the model, and we use the variational evidence lower bound to approximate the marginal likelihood. We use variational inference to overcome the non-Gaussian likelihood, and not for scalability reasons; i.e., we are not using inducing inputs. We use the GPyTorch package (Gardner et al., 2018) to train the models. We use an RBF kernel with the default initialization from GPyTorch, and divide the inputs by 5000 to get an appropriate input scale. We train the model on a binary classification problem between classes 0 and 1.
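A sketch of this kind of GPyTorch variational classifier; it mirrors the standard GPyTorch classification setup rather than the authors' exact code, and `train_x`, `train_y` (binary labels in {0, 1}, with inputs already divided by 5000) are assumed to be available:

```python
import torch
import gpytorch

class GPClassificationModel(gpytorch.models.ApproximateGP):
    """Variational GP classifier with the variational distribution placed
    directly on the training inputs (no inducing-point approximation)."""
    def __init__(self, train_x):
        var_dist = gpytorch.variational.CholeskyVariationalDistribution(train_x.size(0))
        var_strat = gpytorch.variational.UnwhitenedVariationalStrategy(
            self, train_x, var_dist, learn_inducing_locations=False)
        super().__init__(var_strat)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x))

model = GPClassificationModel(train_x)
likelihood = gpytorch.likelihoods.BernoulliLikelihood()
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.numel())
optimizer = torch.optim.Adam(list(model.parameters()) + list(likelihood.parameters()), lr=0.1)

model.train(); likelihood.train()
for _ in range(200):
    optimizer.zero_grad()
    loss = -mll(model(train_x), train_y)   # negative ELBO, also used to approximate -log p(D)
    loss.backward()
    optimizer.step()
```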

For the 10-class GP classification experiment, we train 10 one-vs-all models, each of which classifies between a given class and the rest of the data. To reduce computation in training, we subsample the data not belonging to the given class to 10k datapoints, so each model is trained on a total of 15k datapoints. We then combine the 10 models into a single multi-class model: an observation is attributed to the class that corresponds to the one-vs-all model with the highest confidence. We use the same hyperparameters as in the binary classification experiments.

D.5. Double Descent

In Section 7, we evaluate SGD, SWAG, and MultiSWAG for models of varying width. Following Nakkiran et al. (2019), we use ResNet-18 on CIFAR-100; we consider the original labels, as well as 10% and 20% label corruption. For networks of every width, we reuse the hyperparameters used for PreResNet-164 in Maddox et al. (2019).


Figure 13. Double descent. (a): Test error and (b): NLL loss for ResNet-18 with varying width on CIFAR-100, for SGD, SWAG, and MultiSWAG, when 10% of the labels are randomly reshuffled. MultiSWAG alleviates double descent both on the original labels and under label noise, both in accuracy and NLL. (c): Test NLLs for MultiSWAG with a varying number of independent models, under 20% label corruption; NLL monotonically decreases with an increased number of independent models, alleviating double descent.

For the original labels and 10% label corruption, we use 5 independently trained SWAG models with MultiSWAG; for 20% label corruption, we use 10 models. For 20% label corruption, we also show performance varying the number of independent models, in Figures 8(e) and 13(c). Both for SWAG and MultiSWAG, we use an ensemble of 20 sampled models from each of the SWAG solutions; for example, for MultiSWAG with 10 independent SWAG solutions, we use an ensemble of 200 networks.

E. Analysis of the Prior Variance Effect

In this section, we provide simple analytic results for the effect of the prior variance in ReLU networks. A related derivation is presented in Appendix Section A.8 of Garipov et al. (2018), about connecting paths from symmetries in parametrization.

We will consider a multilayer network f(x; w) of the form

f(x; {W_i, b_i}_{i=1}^n) = W_n ( ... φ(W_2 φ(W_1 x + b_1) + b_2) ... ) + b_n,

where φ is the ReLU (or in fact any positively homogeneous activation function), the W_i are weight matrices, and the b_i are bias vectors. In particular, f can be a regular CNN with ReLU activations, up to the logits (with the softmax activation removed).

Now suppose we have a prior distribution of the form

W_i ∼ N(0, α_i² I),   b_i ∼ N(0, β_i² I),

where the identity matrices I are implicitly assumed to be of appropriate shapes, so each weight matrix and bias vector has a spherical Gaussian distribution. We can reparameterize this distribution as

W_i = α_i E_i,   E_i ∼ N(0, I),
b_i = β_i ε_i,   ε_i ∼ N(0, I).

variable

f(x αi βini=1) =

αn middot En( φ(α1E1x+ β1 middot ε1)) + βn middot εn(8)

Through Equation (8) we can observe some simple proper-ties of the dependence between the prior scales αi βi andthe induced function-space prior

Proposition 1 Suppose the network has no bias vectorsie β1 = = βn = 0 Then the scales αi of the priordistribution over the weights only affect the output scale ofthe network

Proof In the case when there are no bias vectors Equation(8) simplifies to

f(xαi βi = 0ni=1) =

αn middot En( φ(α1E1x+ β1 middot ε1)) + βn middot εn =

αn middot middot α1 middot En( φ(E1x)) =

αn middot middot α1 middot f(x αi = 1 βi = 0ni=1)

In the derivation above we used positive homogeneity ofReLU φ(αz) = αφ(z) for any positive α

In other words to sample from the distribution over func-tions corresponding to a prior with variances αi βi =0ni=1 we can sample from the spherical Gaussian prior(without bias terms) αi = 1 βi = 0ni=1 and then rescalethe outputs of the network by the product of variancesαn middot middot α2 middot α1

We note that the result above is different from the resultsfor sigmoid networks considered in MacKay (1995) wherevarying the prior on the weights leads to changing the length-scale of the sample functions For ReLU networks withoutbiases increasing prior variance only increases the outputscale of the network and not the complexity of the samplesIf we apply the softmax activation on the outputs of thelast layer of such network we will observe increasingly


confident predictions as we increase the prior variance. We observe this effect in Figure 9, and discuss it in Section 8.4.

In case bias vectors are present, we can obtain a similar result using a specific scaling of the prior variances with the layer, as in the following proposition.

Proposition 2. Suppose the prior scales depend on the layer of the network as follows, for some γ > 0:

α_i = γ,   β_i = γ^i,

for all layers i = 1, ..., n. Then γ only affects the scale of the predictive distribution at any input x:

f(x; {α_i = γ, β_i = γ^i}_{i=1}^n) = γ^n · f(x; {α_i = 1, β_i = 1}_{i=1}^n).

Proof. The proof is analogous to the proof of Proposition 1: we can use the positive homogeneity of the ReLU activations to factor the prior scales outside of the network:

f(x; {α_i = γ, β_i = γ^i}_{i=1}^n) = γ E_n( ... φ(γ E_1 x + γ ε_1) ... ) + γ^n ε_n = γ^n · ( E_n( ... φ(E_1 x + ε_1) ... ) + ε_n ) = γ^n · f(x; {α_i = 1, β_i = 1}_{i=1}^n).

The analysis above can be applied to other simple scaling rules of the prior, e.g.,

f(x; {α_i = γ·α̂_i, β_i = γ^i·β̂_i}_{i=1}^n) = γ^n · f(x; {α_i = α̂_i, β_i = β̂_i}_{i=1}^n),    (9)

which can be shown completely analogously to Proposition 2.

More general types of scaling of the prior affect both the output scale of the network and the relative effect of the prior and variance terms. For example, by Equation (9) we have

f(x; {α_i = γ, β_i = γ}_{i=1}^n) = f(x; {α_i = γ·1, β_i = γ^i·γ^{1-i}}_{i=1}^n) = γ^n · f(x; {α_i = 1, β_i = γ^{1-i}}_{i=1}^n).

We note that the analysis does not cover residual connections and batch normalization, so it applies to LeNet-5, but cannot be directly applied to the PreResNet-20 networks used in many of our experiments.

F. Prior Correlation Structure under Perturbations

In this section, we explore the prior correlations between the logits on different pairs of datapoints, induced by a spherical Gaussian prior on the weights of a PreResNet-20. We sample 100 random images from CIFAR-10 (10 from each class) and apply 17 different perturbations introduced by Hendrycks & Dietterich (2019), at 5 different levels of intensity. We then compute the correlations between the logits f(x; w) for the original image x and f(x̃; w) for the corrupted image x̃, as we sample the weights of the network from the prior w ∼ N(0, I).

In Figure 19, we show how the correlations decay with perturbation intensity. For reference, we also show how the correlations decay for a linear model and for an RBF kernel. For the RBF kernel, we set the lengthscale so that the average correlations on the uncorrupted datapoints match those of a PreResNet-20. Further experimental details can be found in Appendix D.3.

For all types of corruptions except saturate, snow, fog, and brightness, the PreResNet logit correlations decay more slowly than those of the RBF kernel and the linear model. It appears that the prior samples are sensitive to corruptions that alter the brightness, or more generally the colours, in the image. For many types of corruptions (such as, e.g., Gaussian noise), the prior correlations for PreResNet are close to 1 for all levels of corruption.

Overall, these results indicate that the prior over functions induced by a vague prior over parameters w, in combination with a PreResNet, has useful equivariance properties: before seeing data, the model treats images of the same class as highly correlated, even after an image has undergone significant perturbations, representative of perturbations we often see in the real world. These types of symmetries are a large part of what makes neural networks a powerful model class for high-dimensional natural signals.


Figure 14. Noise corruptions: (a) Gaussian noise, (b) impulse noise, (c) shot noise. Negative log-likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG, and MultiSWA, as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).


Figure 15. Blur corruptions: (a) defocus blur, (b) glass blur, (c) motion blur, (d) zoom blur, (e) Gaussian blur. Negative log-likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG, and MultiSWA, as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).


Figure 16. Digital corruptions: (a) contrast, (b) saturate, (c) elastic transform, (d) pixelate, (e) JPEG compression. Negative log-likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG, and MultiSWA, as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).


Figure 17. Weather corruptions: (a) snow, (b) fog, (c) brightness. Negative log-likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG, and MultiSWA, as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).

Figure 18. Negative log-likelihood, accuracy, and expected calibration error distributions on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG, and MultiSWA, as a function of the corruption intensity. Following Ovadia et al. (2019), we summarize the results for different types of corruption with a boxplot. For each method, we use 10 independently trained models, and for MultiSWAG we sample 20 networks from each model. As in Figures 5, 11–14, there are substantial differences between these three methods, which are hard to see due to the vertical scale of this plot; MultiSWAG particularly outperforms Deep Ensembles and MultiSWA in terms of NLL and ECE for higher corruption intensities.


[Figure 19 plots: prior correlation vs. corruption intensity (0-5) for PreResNet-20, RBF kernel, and linear model; panels for Translation, Gaussian Noise, Impulse Noise, Shot Noise, Defocus Blur, Glass Blur, Motion Blur, Zoom Blur, Gaussian Blur, Contrast, Saturate, Elastic, Pixelate, JPEG, Snow, Fog, and Brightness.]

Figure 19. Prior correlations under corruption. Prior correlations between predictions (logits) for PreResNet-20, a linear model, and an RBF kernel, on original and corrupted images, as a function of corruption intensity for different types of corruptions. The lengthscale of the RBF kernel is calibrated to produce similar correlations to PreResNet on uncorrupted datapoints. We report the mean correlation values over 100 different images, and show the 1σ error bars with shaded regions. For all corruptions except Snow, Saturate, Fog and Brightness, the correlations decay slower for PreResNet compared to the baselines.



consider possible reasons for why tempering is helpful in Bayesian deep learning.

In this section we provide an alternative view, and argue that tempering is not at odds with Bayesian principles. Moreover, for virtually any realistic model class and dataset, it would be highly surprising if T = 1 were in fact the best setting of this hyperparameter. Indeed, as long as it is practically convenient, we would advocate tempering for essentially any model, especially parametric models that do not scale their capacity automatically with the amount of available information. Our position is that, at a high level, Bayesian methods are trying to combine honest beliefs with data to form a posterior. By reflecting the belief that the model is misspecified, the tempered posterior is often more of a true posterior than the posterior that results from ignoring our belief that the model is misspecified.

Finding that T < 1 helps for Bayesian neural networks is neither surprising nor discouraging. And the actual results of the experiments in Wenzel et al. (2020), which show great improvements over standard SGD training, are in fact very encouraging of deriving inspiration from Bayesian procedures in deep learning.

We consider (1) tempering under misspecification (Section 8.1); (2) tempering in terms of overcounting data (Section 8.2); (3) how tempering compares to changing the observation model (Section 8.3); (4) the effect of the prior in relation to the experiments of Wenzel et al. (2020) (Section 8.4); (5) the effect of approximate inference, including how tempering can help in efficiently estimating parameters even for the untempered posterior (Section 8.5).

8.1. Tempering Helps with Misspecified Models

Many works explain how tempered posteriors help under model misspecification (e.g., de Heide et al., 2019; Grunwald et al., 2017; Barron & Cover, 1991; Walker & Hjort, 2001; Zhang, 2006; Bissiri et al., 2016; Grunwald, 2012). In fact, de Heide et al. (2019) and Grunwald et al. (2017) provide several simple examples where Bayesian inference fails to provide good convergence behaviour for untempered posteriors. While it is easier to show theoretical results for T > 1, several of these works also show that T < 1 can be preferred, even in well-specified settings, and indeed recommend learning T from data, for example by cross-validation (e.g., Grunwald, 2012; de Heide et al., 2019).

Are we in a misspecified setting for Bayesian neural networks? Of course. And it would be irrational to proceed as if it were otherwise. Every model is misspecified. In the context of Bayesian neural networks specifically, the mass of solutions expressed by the prior outside of the datasets we typically consider is likely much larger than desired for most applications. We can calibrate for this discrepancy through tempering. The resulting tempered posterior will be more in line with our beliefs than pretending the model is not misspecified and finding the untempered posterior.

Non-parametric models such as Gaussian processes attempt to side-step model misspecification by growing the number of free parameters (information capacity) automatically with the amount of available data. In parametric models, we take much more of a manual guess about the model capacity. In the case of deep neural networks, this choice is not even close to a best guess: it was once the case that architectural design was a large component of works involving neural networks, but now it is more standard practice to choose an off-the-shelf architecture without much consideration of model capacity. We do not believe that knowingly using a misspecified model to find a posterior is more reasonable (or Bayesian) than honestly reflecting the belief that the model is misspecified and then using a tempered posterior. For parametric models such as neural networks, it is to be expected that the capacity is particularly misspecified.

8.2. Overcounting Data with Cold Posteriors

The criticism of cold posteriors raised by Wenzel et al. (2020) is largely based on the fact that decreasing temperature leads to overcounting data in the posterior distribution.

However, a similar argument can be made against marginal likelihood maximization (also known as empirical Bayes, or type-II maximum likelihood). Indeed, here the prior will depend on the same data as the likelihood, which can lead to miscalibrated predictive distributions (Darnieder, 2011).

Nonetheless, empirical Bayes has been embraced and widely adopted in Bayesian machine learning (e.g., Bishop, 2006; Rasmussen & Williams, 2006; MacKay, 2003; Minka, 2001b) as embodying several Bayesian principles. Empirical Bayes has been particularly embraced in seminal work on Bayesian neural networks (e.g., MacKay, 1992; 1995), where it has been proposed as a principled approach to learning hyperparameters, such as the scale of the variance for the prior over weights, automatically embodying Occam's razor. While there is in this case some deviation from the fully Bayesian paradigm, the procedure, which depends on marginalization, is nonetheless clearly inspired by Bayesian thinking; it is thus helpful to reflect this inspiration and provide understanding of how it works from a Bayesian perspective.

There is also work showing that the marginal likelihood can lead to miscalibrated Bayes factors under model misspecification. Attempts to calibrate these factors (Xu et al., 2019), as part of the Bayesian paradigm, are highly reminiscent of work on safe Bayes.


8.3. Tempered Posterior or Different Likelihood

The tempered posterior for one model is an untempered posterior using a different observation model. In other words, we can trivially change the likelihood function in our model so that the standard Bayesian posterior in the new model is equal to the posterior of temperature T in the original model. Indeed, consider the likelihood function

p_T(D | w) ∝ p(D | w)^{1/T},    (4)

where the posterior distribution for the model M_T with likelihood p_T is exactly the temperature-T posterior for the model M with likelihood p.

The predictive distribution differs for the two models: even though the posteriors coincide, the likelihoods for a new datapoint y* are different,

∫ p(y* | w) p(w | D) dw ≠ ∫ p_T(y* | w) p(w | D) dw.    (5)

As an example, consider a regression model M with a Gaussian likelihood function, y ∼ N(f, σ²), where f is the output of the network and σ² is the noise variance. The predictive distributions for the two models M and M_T for a new input x will have different variance, but the same mean: E_w[f + ε] = E_w[f]. Moreover, in this case the noise variance would typically be learned in either model, M_T or M.
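As a minimal numerical illustration of this construction (ours, not from the paper): tempering a Gaussian likelihood with temperature T matches, up to a constant, a Gaussian likelihood whose noise variance is scaled by T, so in this example M_T is simply M with noise variance Tσ².

```python
import numpy as np

def log_gauss(y, f, var):
    # log N(y | f, var)
    return -0.5 * np.log(2 * np.pi * var) - 0.5 * (y - f) ** 2 / var

T, sigma2 = 0.25, 0.5 ** 2
f_values = np.linspace(-2, 2, 5)              # candidate network outputs
y = 1.3                                       # an observation

tempered = log_gauss(y, f_values, sigma2) / T   # (1/T) * log p(y | f, sigma^2)
rescaled = log_gauss(y, f_values, T * sigma2)   # log p(y | f, T * sigma^2)

# The difference is constant in f, so both define the same posterior over f (and w),
# even though the two likelihoods give different predictive variances.
print(np.ptp(tempered - rescaled))            # ~0
```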

Section 4.1 of Grunwald et al. (2017) considers a related construction.

8.4. Effect of the Prior

While a somewhat misspecified prior will certainly interact with the utility of tempering, we do not believe the experiments in Wenzel et al. (2020) provide evidence that even the prior p(w) = N(0, I) is misspecified to any serious extent. For a relatively wide range of distributions over w, the functional form of the network f(x; w) can produce a generally reasonable distribution over functions p(f(x; w)). In Figure 10 we reproduce the findings in Wenzel et al. (2020) showing that sample functions p(f(x; w)) corresponding to the prior p(w) = N(0, 10 · I) strongly favour a single class over the dataset. While this behaviour appears superficially dramatic, we note it is simply an artifact of having a miscalibrated signal variance. A miscalibrated signal variance interacts with a quickly saturating soft-max link function to provide a seemingly dramatic preference for a given class. If we instead use p(w) = N(0, α²I), for quite a range of α, then sample functions provide reasonably high entropy across labels, as in Figure 9. The value of α can easily be determined through cross-validation, as in Figure 9, or specified as a standard value used for L2 regularization (α = 0.24 in this case).
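The interaction between signal variance and the saturating softmax can be seen with a few lines of arithmetic (a toy illustration of ours, not the paper's experiment): scaling the same random logits up pushes the softmax output towards a one-hot, seemingly over-confident, prediction.

```python
import numpy as np

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

rng = np.random.default_rng(0)
logits = rng.normal(size=10)             # logits for 10 classes from a unit-scale prior

for scale in [0.1, 1.0, 10.0]:           # stand-ins for different prior signal variances
    p = softmax(scale * logits)
    print(scale, round(p.max(), 3))      # the favoured class probability grows with scale
```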

However, even with the inappropriate prior scale, we see in the bottom row of panels (a)–(e) of Figure 9 that the unconditional predictive distribution is completely reasonable. Moreover, the prior variance represents a soft prior bias, and will quickly update in the presence of data. In Figure 10 we show the posterior samples after observing 10, 100, and 1000 data points.

Other aspects of the prior, outside of the prior signal variance, will have a much greater effect on the inductive biases of the model. For example, the induced covariance function cov(f(x_i; w), f(x_j; w)) reflects the induced similarity metric over data instances; through the covariance function we can answer, for instance, whether the model believes a priori that a translated image is similar to the original. Unlike the signal variance of the prior, the prior covariance function will continue to have a significant effect on posterior inference for even very large datasets, and strongly reflects the structural properties of the neural network. We explore these structures of the prior in Figure 12.

8.5. The Effect of Inexact Inference

We have to keep in mind what we ultimately use posterior samples to compute. Ultimately, we wish to estimate the predictive distribution given by the integral in Equation (1). With a finite number of samples, the tempered posterior could be used to provide a better approximation to the expectation of the predictive distribution associated with the untempered posterior.

Consider a simple example, where we wish to estimate the mean of a high-dimensional Gaussian distribution N(0, I). Suppose we use J independent samples. The mean of these samples is also Gaussian distributed, μ̂ ∼ N(0, (1/J) I). In Bayesian deep learning, the dimension d is typically on the order of 10^7 and J would be on the order of 10. The norm of μ̂ would be highly concentrated around √d/√J = √(10^7)/√10 = 1000. In this case, sampling from a tempered posterior with T < 1 would lead to a better approximation of the Bayesian model average associated with an untempered posterior.
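A quick numerical check of this calculation (ours), with a smaller dimension to keep memory modest:

```python
import numpy as np

rng = np.random.default_rng(0)
d, J = 10**6, 10                        # dimension and number of samples (d = 10^7 in the text)

samples = rng.standard_normal((J, d))   # J independent samples from N(0, I)
mu_hat = samples.mean(axis=0)           # sample mean, distributed as N(0, I / J)

print(np.linalg.norm(mu_hat))           # concentrates around sqrt(d / J) ~ 316
print(np.sqrt(d / J))
```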

Furthermore, no current sampling procedure will be providing samples that are close to independent samples from the true posterior of a Bayesian neural network. The posterior landscape is far too multimodal and complex for there to be any reasonable coverage. The approximations we have are practically useful, and often preferable to conventional training, but we cannot realistically proceed with analysis assuming that we have obtained true samples from a posterior. While we would expect that some value of T ≠ 1 would be preferred for any finite dataset in practice, it is conceivable that some of the results in Wenzel et al. (2020) may be affected by the specifics of the approximate inference technique being used.


We should be wary not to view Bayesian model averaging purely through the prism of simple Monte Carlo, as advised in Section 3.2. Given a finite computational budget, our goal in effectively approximating a Bayesian model average is not equivalent to obtaining good samples from the posterior.

9. Discussion

"It is now common practice for Bayesians to fit models that have more parameters than the number of data points. Incorporate every imaginable possibility into the model space: for example, if it is conceivable that a very simple model might be able to explain the data, one should include simple models; if the noise might have a long-tailed distribution, one should include a hyperparameter which controls the heaviness of the tails of the distribution; if an input variable might be irrelevant to a regression, include it in the regression anyway." MacKay (1995)

We have presented a probabilistic perspective of generalization, which depends on the support and inductive biases of the model. The support should be as large as possible, but the inductive biases must be well-calibrated to a given problem class. We argue that Bayesian neural networks embody these properties, and, through the lens of probabilistic inference, explain generalization behaviour that has previously been viewed as mysterious. Moreover, we argue that Bayesian marginalization is particularly compelling for neural networks, show how deep ensembles provide a practical mechanism for marginalization, and propose a new approach that generalizes deep ensembles to marginalize within basins of attraction. We show that this multimodal approach to Bayesian model averaging, MultiSWAG, can entirely alleviate double descent, to enable monotonic performance improvements with increases in model flexibility, as well as significant improvements in generalization accuracy and log likelihood over SGD and single-basin marginalization.

There are certainly many challenges to estimating the integral for a Bayesian model average in modern deep learning, including a high-dimensional parameter space and a complex posterior landscape. But viewing the challenge indeed as an integration problem, rather than an attempt to obtain posterior samples for a simple Monte Carlo approximation, provides opportunities for future progress. Bayesian deep learning has been making fast practical advances, with approaches that now enable better accuracy and calibration over standard training, with minimal overhead.

We finish with remarks about future developments for Bayesian neural network priors, and approaches to research in Bayesian deep learning.

9.1. The Future for BNN Priors

We provide some brief remarks about future developments for BNN priors. Here we have explored relatively simple parameter priors p(w) = N(0, α²I). While these priors are simple in parameter space, they interact with the neural network architecture to induce a sophisticated prior over functions p(f(x; w)), with many desirable properties, including a reasonable correlation structure over images. However, these parameter priors can certainly still be improved. As we have seen, even tuning the value of the signal variance α², an analogue of the L2 regularization often used in deep learning, can have a noticeable effect on the induced prior over functions, though this effect is quickly modulated by data. Layer-wise priors, such that parameters in each layer have a different signal variance, are intuitive: we would expect later layers to require precise determination, while parameters in earlier layers could reasonably take a range of values. But one has to be cautious; as we show in Appendix Section E, with ReLU activations different signal variances in different layers can be degenerate, combining together to affect only the output scale of the network.

A currently popular sentiment is that we should directly build function-space BNN priors, often taking inspiration from Gaussian processes. While we believe this is a promising direction, one should proceed with caution. If we contrive priors over parameters p(w) to induce distributions over functions p(f) that resemble familiar models, such as Gaussian processes with RBF kernels, we could be throwing the baby out with the bathwater. Neural networks are useful as their own model class precisely because they have different inductive biases from other models.

A similar concern applies to taking infinite width limits in Bayesian neural networks. In these cases we recover Gaussian processes with interpretable kernel functions; because these models are easier to use and analyze, and give rise to interpretable and well-motivated priors, it is tempting to treat them as drop-in replacements for their parametric analogues. However, the kernels for these models are fixed. In order for a model to do effective representation learning, we must learn a similarity metric for the data. Training a neural network in many ways is like learning a kernel, rather than using a fixed kernel. MacKay (1998) has also expressed concerns about treating these limits as replacements for neural networks, due to the loss of representation learning power.

Perhaps the distribution over functions induced by a network in combination with a generic distribution over parameters p(w) may be hard to interpret, but this distribution will contain the equivariance properties, representation learning abilities, and other biases that make neural networks a compelling model class in their own right.


9.2. "But is it really Bayesian?"

We finish with an editorial comment about approaches to research within Bayesian deep learning. There is sometimes a tendency to classify work as Bayesian or not Bayesian, with very stringent criteria for what qualifies as Bayesian. Moreover, the implication, and sometimes even explicit recommendation, is that if an approach is not unequivocally Bayesian in every respect, then we should not term it Bayesian, and we should instead attempt to understand the procedure through entirely different, non-Bayesian, mechanisms. We believe this mentality encourages tribalism, which is not conducive to the best research, or to creating the best performing methods. What matters is not a debate about semantics, but making rational modelling choices given a particular problem setting, and trying to understand these choices. Often these choices can largely be inspired by a Bayesian approach, in which case it is desirable to indicate this source of inspiration. And in the semantics debate, who would be the arbiter of what gets to be called Bayesian? Arguably it ought to be an evolving definition.

Broadly speaking, what makes Bayesian approaches distinctive is a posterior weighted marginalization over parameters. And at a high level, Bayesian methods are about combining our honest beliefs with data to form a posterior. In actuality, no fair-minded researcher entirely believes the prior over parameters, the functional form of the model (which is part of the prior over functions), or the likelihood. From this perspective, it is broadly compatible with a Bayesian philosophy to reflect misspecification in the modelling procedure itself, which is achieved through tempering. In this sense, the tempered posterior is more reflective of a true posterior than the posterior that results from ignoring our belief that the model is misspecified.

Moreover, basic probability theory indicates that marginalization is desirable. While marginalization cannot in practice be achieved exactly, we can try to improve over conventional training, which, as we have discussed, can be viewed as approximate marginalization. Given computational constraints, effective marginalization is not equivalent to obtaining accurate samples from a posterior. As we have discussed, simple Monte Carlo is only one of many mechanisms for marginalization. Just as expectation propagation (Minka, 2001a) focuses its approximation on the factors in a posterior where it will most affect the end result, we should focus on representing the posterior where it will make the biggest difference to the model average. As we have shown, deep ensembles are a reasonable mechanism up to a point. After having trained many independent models, there are added benefits to marginalizing within basins, given the computational expense associated with retraining an additional model to find an additional basin of attraction.

We should also not hold Bayesian methods to a double standard. Indeed, it can be hard to interpret or understand the prior, the posterior, and whether the marginalization procedure is optimal. But it is also hard to interpret the choices behind the functional form of the model, or the rationale behind classical procedures, where we bet everything on a single global optimum, when we know there are many global optima, many of which will perform well but provide different solutions, and many others that will not perform well. We should apply the same level of scrutiny to all modelling choices, consider the alternatives, and not be paralyzed if a procedure is not optimal in every respect.

Acknowledgements. AGW and PI are supported by an Amazon Research Award, Facebook Research, NSF I-DISRE 193471, NIH R01 DA048764-01A1, NSF IIS-1563887, and NSF IIS-1910266. We thank Greg Benton for helpful discussions.

References

Achille, A. and Soatto, S. Emergence of invariance and disentanglement in deep representations. The Journal of Machine Learning Research, 19(1):1947–1980, 2018.

Ashukha, A., Lyzhov, A., Molchanov, D., and Vetrov, D. Pitfalls of in-domain uncertainty estimation and ensembling in deep learning. arXiv preprint arXiv:2002.06470, 2020.

Barron, A. R. and Cover, T. M. Minimum complexity density estimation. IEEE Transactions on Information Theory, 37(4):1034–1054, 1991.

Beal, M. J. Variational algorithms for approximate Bayesian inference. PhD thesis, University of London, 2003.

Belkin, M., Hsu, D., Ma, S., and Mandal, S. Reconciling modern machine-learning practice and the classical bias–variance trade-off. Proceedings of the National Academy of Sciences, 116(32):15849–15854, 2019.

Bishop, C. M. Pattern Recognition and Machine Learning. Springer, 2006.

Bissiri, P. G., Holmes, C. C., and Walker, S. G. A general framework for updating belief distributions. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 78(5):1103–1130, 2016.

Blundell, C., Cornebise, J., Kavukcuoglu, K., and Wierstra, D. Weight uncertainty in neural networks. arXiv preprint arXiv:1505.05424, 2015.

Box, G. E. and Tiao, G. C. Bayesian Inference in Statistical Analysis. Addison-Wesley, Reading, MA, 1973.


Chaudhari, P., Choromanska, A., Soatto, S., LeCun, Y., Baldassi, C., Borgs, C., Chayes, J., Sagun, L., and Zecchina, R. Entropy-SGD: Biasing gradient descent into wide valleys. Journal of Statistical Mechanics: Theory and Experiment, 2019.

Cobb, A. D., Baydin, A. G., Markham, A., and Roberts, S. J. Introducing an explicit symplectic integration scheme for Riemannian manifold Hamiltonian Monte Carlo. arXiv preprint arXiv:1910.06243, 2019.

Darnieder, W. F. Bayesian Methods for Data-Dependent Priors. PhD thesis, The Ohio State University, 2011.

de Heide, R., Kirichenko, A., Mehta, N., and Grunwald, P. Safe-Bayesian generalized linear regression. arXiv preprint arXiv:1910.09227, 2019.

Dinh, L., Pascanu, R., Bengio, S., and Bengio, Y. Sharp minima can generalize for deep nets. arXiv preprint arXiv:1703.04933, 2017.

Draxler, F., Veschgini, K., Salmhofer, M., and Hamprecht, F. A. Essentially no barriers in neural network energy landscape. arXiv preprint arXiv:1803.00885, 2018.

Dziugaite, G. K. and Roy, D. M. Computing nonvacuous generalization bounds for deep (stochastic) neural networks with many more parameters than training data. arXiv preprint arXiv:1703.11008, 2017.

Fort, S. and Jastrzebski, S. Large scale structure of neural network loss landscapes. In Advances in Neural Information Processing Systems, pp. 6706–6714, 2019.

Fort, S., Hu, H., and Lakshminarayanan, B. Deep ensembles: A loss landscape perspective. arXiv preprint arXiv:1912.02757, 2019.

Gal, Y. and Ghahramani, Z. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. In International Conference on Machine Learning, pp. 1050–1059, 2016.

Gardner, J. R., Pleiss, G., Bindel, D., Weinberger, K. Q., and Wilson, A. G. GPyTorch: Blackbox matrix-matrix Gaussian process inference with GPU acceleration. In Neural Information Processing Systems, 2018.

Garipov, T., Izmailov, P., Podoprikhin, D., Vetrov, D. P., and Wilson, A. G. Loss surfaces, mode connectivity, and fast ensembling of DNNs. In Neural Information Processing Systems, 2018.

Grunwald, P. The safe Bayesian. In International Conference on Algorithmic Learning Theory, pp. 169–183. Springer, 2012.

Grunwald, P., Van Ommen, T., et al. Inconsistency of Bayesian inference for misspecified linear models, and a proposal for repairing it. Bayesian Analysis, 12(4):1069–1103, 2017.

Guedj, B. A primer on PAC-Bayesian learning. arXiv preprint arXiv:1901.05353, 2019.

Guo, C., Pleiss, G., Sun, Y., and Weinberger, K. Q. On calibration of modern neural networks. In Proceedings of the 34th International Conference on Machine Learning, Volume 70, pp. 1321–1330. JMLR.org, 2017.

Gustafsson, F. K., Danelljan, M., and Schon, T. B. Evaluating scalable Bayesian deep learning methods for robust computer vision. arXiv preprint arXiv:1906.01620, 2019.

Hafner, D., Tran, D., Irpan, A., Lillicrap, T., and Davidson, J. Reliable uncertainty estimates in deep neural networks using noise contrastive priors. arXiv preprint arXiv:1807.09289, 2018.

He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778, 2016.

Hendrycks, D. and Dietterich, T. Benchmarking neural network robustness to common corruptions and perturbations. arXiv preprint arXiv:1903.12261, 2019.

Hinton, G. E. and Van Camp, D. Keeping the neural networks simple by minimizing the description length of the weights. In Proceedings of the Sixth Annual Conference on Computational Learning Theory, pp. 5–13, 1993.

Hochreiter, S. and Schmidhuber, J. Flat minima. Neural Computation, 9(1):1–42, 1997.

Huang, W. R., Emam, Z., Goldblum, M., Fowl, L., Terry, J. K., Huang, F., and Goldstein, T. Understanding generalization through visualizations. arXiv preprint arXiv:1906.03291, 2019.

Izmailov, P., Podoprikhin, D., Garipov, T., Vetrov, D., and Wilson, A. G. Averaging weights leads to wider optima and better generalization. In Uncertainty in Artificial Intelligence (UAI), 2018.

Izmailov, P., Maddox, W. J., Kirichenko, P., Garipov, T., Vetrov, D., and Wilson, A. G. Subspace inference for Bayesian deep learning. Uncertainty in Artificial Intelligence, 2019.

Jacot, A., Gabriel, F., and Hongler, C. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in Neural Information Processing Systems, pp. 8571–8580, 2018.


Jiang, Y., Neyshabur, B., Mobahi, H., Krishnan, D., and Bengio, S. Fantastic generalization measures and where to find them. arXiv preprint arXiv:1912.02178, 2019.

Kass, R. E. and Raftery, A. E. Bayes factors. Journal of the American Statistical Association, 90(430):773–795, 1995.

Kendall, A. and Gal, Y. What uncertainties do we need in Bayesian deep learning for computer vision? In Advances in Neural Information Processing Systems, pp. 5574–5584, 2017.

Keskar, N. S., Mudigere, D., Nocedal, J., Smelyanskiy, M., and Tang, P. T. P. On large-batch training for deep learning: Generalization gap and sharp minima. arXiv preprint arXiv:1609.04836, 2016.

Khan, M. E., Nielsen, D., Tangkaratt, V., Lin, W., Gal, Y., and Srivastava, A. Fast and scalable Bayesian deep learning by weight-perturbation in Adam. arXiv preprint arXiv:1806.04854, 2018.

Krizhevsky, A., Nair, V., and Hinton, G. The CIFAR-10 dataset, 2014. http://www.cs.toronto.edu/~kriz/cifar.html.

Lakshminarayanan, B., Pritzel, A., and Blundell, C. Simple and scalable predictive uncertainty estimation using deep ensembles. In Advances in Neural Information Processing Systems, pp. 6402–6413, 2017.

Langford, J. and Caruana, R. (Not) bounding the true error. In Advances in Neural Information Processing Systems, pp. 809–816, 2002.

LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.

Louizos, C., Shi, X., Schutte, K., and Welling, M. The functional neural process. In Advances in Neural Information Processing Systems, 2019.

MacKay, D. J. Bayesian methods for adaptive models. PhD thesis, California Institute of Technology, 1992.

MacKay, D. J. Probable networks and plausible predictions: a review of practical Bayesian methods for supervised neural networks. Network: Computation in Neural Systems, 6(3):469–505, 1995.

MacKay, D. J. Introduction to Gaussian processes. In Bishop, C. M. (ed.), Neural Networks and Machine Learning, chapter 11, pp. 133–165. Springer-Verlag, 1998.

MacKay, D. J. Information Theory, Inference and Learning Algorithms. Cambridge University Press, 2003.

Maddox, W. J., Izmailov, P., Garipov, T., Vetrov, D. P., and Wilson, A. G. A simple baseline for Bayesian uncertainty in deep learning. In Advances in Neural Information Processing Systems, 2019.

Masegosa, A. R. Learning under model misspecification: Applications to variational and ensemble methods, 2019.

McAllester, D. A. PAC-Bayesian model averaging. In Proceedings of the Twelfth Annual Conference on Computational Learning Theory, pp. 164–170, 1999.

Minka, T. Expectation propagation for approximate Bayesian inference. In Uncertainty in Artificial Intelligence, volume 17, pp. 362–369, 2001a.

Minka, T. P. Bayesian model averaging is not model combination, 2000.

Minka, T. P. Automatic choice of dimensionality for PCA. In Advances in Neural Information Processing Systems, pp. 598–604, 2001b.

Mohri, M. and Rostamizadeh, A. Rademacher complexity bounds for non-i.i.d. processes. In Advances in Neural Information Processing Systems, pp. 1097–1104, 2009.

Nakkiran, P., Kaplun, G., Bansal, Y., Yang, T., Barak, B., and Sutskever, I. Deep double descent: Where bigger models and more data hurt. arXiv preprint arXiv:1912.02292, 2019.

Nakkiran, P., Venkat, P., Kakade, S., and Ma, T. Optimal regularization can mitigate double descent. arXiv preprint arXiv:2003.01897, 2020.

Nalisnick, E. On Priors for Bayesian Neural Networks. PhD thesis, University of California, Irvine, 2018.

Neal, R. Bayesian Learning for Neural Networks. Springer Verlag, 1996. ISBN 0387947248.

Neyshabur, B., Bhojanapalli, S., McAllester, D., and Srebro, N. Exploring generalization in deep learning. In Advances in Neural Information Processing Systems, pp. 5947–5956, 2017.

Neyshabur, B., Bhojanapalli, S., and Srebro, N. A PAC-Bayesian approach to spectrally-normalized margin bounds for neural networks. In International Conference on Learning Representations, 2018. URL https://openreview.net/forum?id=Skz_WfbCZ.

Opper, M., Kinzel, W., Kleinz, J., and Nehl, R. On the ability of the optimal perceptron to generalise. Journal of Physics A: Mathematical and General, 23(11):L581, 1990.


Ovadia, Y., Fertig, E., Ren, J., Nado, Z., Sculley, D., Nowozin, S., Dillon, J. V., Lakshminarayanan, B., and Snoek, J. Can you trust your model's uncertainty? Evaluating predictive uncertainty under dataset shift. arXiv preprint arXiv:1906.02530, 2019.

Pearce, T., Zaki, M., Brintrup, A., Anastassacos, N., and Neely, A. Uncertainty in neural networks: Bayesian ensembling. arXiv preprint arXiv:1810.05546, 2018.

Rasmussen, C. E. and Ghahramani, Z. Occam's razor. In Neural Information Processing Systems (NIPS), 2001.

Rasmussen, C. E. and Williams, C. K. I. Gaussian Processes for Machine Learning. The MIT Press, 2006.

Ritter, H., Botev, A., and Barber, D. A scalable Laplace approximation for neural networks. In International Conference on Learning Representations (ICLR), 2018.

Rue, H., Martino, S., and Chopin, N. Approximate Bayesian inference for latent Gaussian models by using integrated nested Laplace approximations. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 71(2):319–392, 2009.

Simonyan, K. and Zisserman, A. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.

Smith, S. L. and Le, Q. V. A Bayesian perspective on generalization and stochastic gradient descent. In International Conference on Learning Representations, 2018.

Sun, S., Zhang, G., Shi, J., and Grosse, R. Functional variational Bayesian neural networks. arXiv preprint arXiv:1903.05779, 2019.

Ulyanov, D., Vedaldi, A., and Lempitsky, V. Deep image prior. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 9446–9454, 2018.

Vapnik, V. N. Statistical Learning Theory. Adaptive and Learning Systems for Signal Processing, Communications, and Control, 1998.

Virtanen, P., Gommers, R., Oliphant, T. E., Haberland, M., Reddy, T., Cournapeau, D., Burovski, E., Peterson, P., Weckesser, W., Bright, J., et al. SciPy 1.0: fundamental algorithms for scientific computing in Python. Nature Methods, pp. 1–12, 2020.

Walker, S. and Hjort, N. L. On Bayesian consistency. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 63(4):811–821, 2001.

Welling, M. and Teh, Y. W. Bayesian learning via stochastic gradient Langevin dynamics. In Proceedings of the 28th International Conference on Machine Learning (ICML-11), pp. 681–688, 2011.

Wenzel, F., Roth, K., Veeling, B. S., Swiatkowski, J., Tran, L., Mandt, S., Snoek, J., Salimans, T., Jenatton, R., and Nowozin, S. How good is the Bayes posterior in deep neural networks really? arXiv preprint arXiv:2002.02405, 2020.

Wilson, A. G. The case for Bayesian deep learning. arXiv preprint arXiv:2001.10995, 2020.

Xu, X., Lu, P., MacEachern, S., and Xu, R. Calibrated Bayes factors for model comparison. Journal of Statistical Computation and Simulation, 89(4):591–614, 2019.

Yang, W., Lorch, L., Graule, M. A., Srinivasan, S., Suresh, A., Yao, J., Pradier, M. F., and Doshi-Velez, F. Output-constrained Bayesian neural networks. arXiv preprint arXiv:1905.06287, 2019.

Zhang, C., Bengio, S., Hardt, M., Recht, B., and Vinyals, O. Understanding deep learning requires rethinking generalization. arXiv preprint arXiv:1611.03530, 2016.

Zhang, T. Information-theoretic upper and lower bounds for statistical estimation. IEEE Transactions on Information Theory, 52(4):1307–1321, 2006.


Appendix Outline

This appendix is organized as follows. In Section A, we visualize predictive functions corresponding to weight samples within high posterior density valleys on a regression problem. In Section B, we provide background material on Gaussian processes. In Section C, we present further results comparing MultiSWAG and MultiSWA to Deep Ensembles under data distribution shift on CIFAR-10. In Section D, we provide the details of all experiments presented in the paper. In Section E, we present analytic results on the dependence of the prior distribution in function space on the variance of the prior over parameters. In Section F, we study the prior correlations between BNN logits on perturbed images.

A. Loss Valleys

We demonstrate that different points along the valleys of high posterior density (low loss) connecting pairs of independently trained optima (Garipov et al., 2018; Draxler et al., 2018; Fort & Jastrzebski, 2019) correspond to different predictive functions. We use the regression example from Izmailov et al. (2019) and show the results in Figure 11.

B. Gaussian processes

With a Bayesian neural network, a distribution over parameters p(w) induces a distribution over functions p(f(x; w)) when combined with the functional form of the network. Gaussian processes (GPs) are often used to instead directly specify a distribution over functions.

A Gaussian process is a distribution over functions f(x) ∼ GP(m, k), such that any collection of function values, queried at any finite set of inputs x_1, ..., x_n, has a joint Gaussian distribution,

f(x_1), ..., f(x_n) ∼ N(μ, K).    (6)

The mean vector μ_i = E[f(x_i)] = m(x_i) and covariance matrix K_ij = cov(f(x_i), f(x_j)) = k(x_i, x_j) are determined by the mean function m and covariance function (or kernel) k of the Gaussian process.

The popular RBF kernel has the form

k(x_i, x_j) = exp( −‖x_i − x_j‖² / (2ℓ²) ).    (7)

The length-scale hyperparameter ℓ controls the extent of correlations between function values. If ℓ is large, sample functions from a GP prior are simple and slowly varying with the inputs x.
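For concreteness, here is a short sketch (ours) of drawing sample functions from a zero-mean GP prior with the RBF kernel of Equation (7), illustrating how ℓ controls smoothness:

```python
import numpy as np

def rbf_kernel(x1, x2, lengthscale):
    # k(x, x') = exp(-||x - x'||^2 / (2 * lengthscale^2)) for scalar inputs
    sq_dists = (x1[:, None] - x2[None, :]) ** 2
    return np.exp(-0.5 * sq_dists / lengthscale ** 2)

x = np.linspace(-5, 5, 200)
rng = np.random.default_rng(0)

for lengthscale in [0.3, 3.0]:
    K = rbf_kernel(x, x, lengthscale) + 1e-8 * np.eye(len(x))   # jitter for stability
    samples = rng.multivariate_normal(np.zeros(len(x)), K, size=3)
    # Larger lengthscales give slowly varying samples (smaller increments between inputs).
    print(lengthscale, np.abs(np.diff(samples, axis=1)).mean())
```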

Gaussian processes with RBF kernels (as well as many other standard kernels) assign positive density to any set of observations. Moreover, these models are universal approximators (Rasmussen & Williams, 2006): as the number of observations increases, they are able to approximate any function to arbitrary precision.

Work on Gaussian processes in machine learning was triggered by the observation that Bayesian neural networks become Gaussian processes with particular kernel functions as the number of hidden units approaches infinity (Neal, 1996). This result resembles recent work on the neural tangent kernel (e.g., Jacot et al., 2018).

C. Deep Ensembles and MultiSWAG Under Distribution Shift

In Figures 14, 15, 16, 17, we show the negative log-likelihood for Deep Ensembles, MultiSWA, and MultiSWAG using a PreResNet-20 on CIFAR-10 with various corruptions, as a function of the number of independently trained models (SGD solutions, SWA solutions, or SWAG models, respectively). For MultiSWAG, we generate 20 samples from each independent SWAG model. Typically, MultiSWA and MultiSWAG significantly outperform Deep Ensembles when a small number of independent models is used, or when the level of corruption is high.

In Figure 18, following Ovadia et al. (2019), we show the distribution of negative log likelihood, accuracy, and expected calibration error as we vary the type of corruption. We use a fixed training time budget: 10 independently trained models for every method. For MultiSWAG, we ensemble 20 samples from each of the 10 SWAG approximations. MultiSWAG particularly achieves better NLL than the other two methods, and MultiSWA outperforms Deep Ensembles; the difference is especially pronounced for higher levels of corruption. In terms of ECE, MultiSWAG again outperforms the other two methods for higher corruption intensities.

We note that Ovadia et al. (2019) found Deep Ensembles to be a very strong baseline for prediction quality and calibration under distribution shift. For this reason, we focus on Deep Ensembles in our comparisons.

D. Details of Experiments

In this section we provide additional details of the experiments presented in the paper.

D.1. Approximating the True Predictive Distribution

For the results presented in Figure 4, we used a network with 3 hidden layers of size 10 each. The network takes two inputs: x and x².



Figure 11. Diversity of high performing functions. Bottom: a contour plot of the posterior log-density in the subspace containing a pair of independently trained modes (as with deep ensembles) and a path of high posterior density connecting these modes. In each panel, the purple point represents a sample from the posterior in the parameter subspace. Top: the predictive distribution constructed from samples in the subspace. The shaded blue area shows the 3σ-region of the predictive distribution at each of the input locations, and the blue line shows the mean of the predictive distribution. In each panel, the purple line shows the predictive function corresponding to the sample shown in the corresponding bottom row panel. For the details of the experimental setup, see Section 5.1 of Izmailov et al. (2019).

We pass both x and x² as input to ensure that the network can represent a broader class of functions. The network outputs a single number, y = f(x).

To generate data for the plots, we used a randomly-initialized neural network of the same architecture described above. We sampled the weights from an isotropic Gaussian with variance 0.1², and added isotropic Gaussian noise with variance 0.1² to the outputs:

y = f(x; w) + ε(x),

with w ∼ N(0, 0.1² · I), ε(x) ∼ N(0, 0.1² · I). The training set consists of 120 points, shown in Figure 4.
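A sketch of this data-generating process under the stated setup (the activation function and the input range below are our assumptions; the text only specifies three hidden layers of width 10 with inputs (x, x²), weight scale 0.1, and noise scale 0.1):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Assumed instantiation of the described architecture: inputs (x, x^2),
# three hidden layers of width 10, scalar output. ReLU is our assumption.
net = nn.Sequential(
    nn.Linear(2, 10), nn.ReLU(),
    nn.Linear(10, 10), nn.ReLU(),
    nn.Linear(10, 10), nn.ReLU(),
    nn.Linear(10, 1),
)

with torch.no_grad():
    for p in net.parameters():
        p.normal_(0.0, 0.1)                              # w ~ N(0, 0.1^2 I)

    x = torch.linspace(-2.0, 2.0, 120).unsqueeze(-1)     # 120 inputs; the range is illustrative
    inputs = torch.cat([x, x ** 2], dim=-1)
    y = net(inputs) + 0.1 * torch.randn(120, 1)          # add N(0, 0.1^2) observation noise
```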

For estimating the ground truth, we ran 200 chains of Hamiltonian Monte Carlo (HMC) using the hamiltorch package (Cobb et al., 2019). We initialized each chain with a network pre-trained with SGD for 3000 steps, then ran HMC for 2000 steps, producing 200 samples.

For Deep Ensembles, we independently trained 50 networks with SGD for 20000 steps each. We used the negative posterior log-density as the training loss. For SVI, we used a fully-factorized Gaussian approximation initialized at an SGD solution trained for 20000 steps. For all inference methods, we set the prior variance to 10² and the noise variance to 0.02².

Discrepancy with true BMA. For the results presented in panel (d) of Figure 4, we computed the Wasserstein distance between the predictive distribution approximated with HMC and the predictive distributions of Deep Ensembles and SVI. We used the one-dimensional Wasserstein distance function¹ from the scipy package (Virtanen et al., 2020). We computed the Wasserstein distance between the marginal distributions at each input location, and averaged the results over the input locations. In the top sub-panels of panels (b), (c) of Figure 4, we additionally visualize the marginal Wasserstein distance between the HMC predictive distribution and the Deep Ensembles and SVI predictive distributions, respectively, for each input location.

¹https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wasserstein_distance.html
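A minimal sketch (ours) of this evaluation, assuming arrays of predictive samples of shape (num_samples, num_inputs) at shared input locations:

```python
import numpy as np
from scipy.stats import wasserstein_distance

def avg_marginal_wasserstein(hmc_preds, approx_preds):
    """Average the 1D Wasserstein distance between marginal predictive
    distributions at each input location; also return the per-input curve."""
    num_inputs = hmc_preds.shape[1]
    dists = np.array([
        wasserstein_distance(hmc_preds[:, i], approx_preds[:, i])
        for i in range(num_inputs)
    ])
    return dists.mean(), dists

# Synthetic stand-ins for the HMC and Deep Ensembles predictive samples.
rng = np.random.default_rng(0)
hmc = rng.normal(0.0, 1.0, size=(200, 50))
ens = rng.normal(0.1, 0.8, size=(50, 50))
mean_dist, per_input = avg_marginal_wasserstein(hmc, ens)
print(mean_dist)
```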

D.2. Deep Ensembles and MultiSWAG

We evaluate Deep Ensembles, MultiSWA, and MultiSWAG under distribution shift in Section 4. Following Ovadia et al. (2019), we use a PreResNet-20 network and the CIFAR-10 dataset with different types of corruptions introduced in Hendrycks & Dietterich (2019). For training individual SGD, SWA, and SWAG models, we use the hyper-parameters used for PreResNet-164 in Maddox et al. (2019). For each SWAG model we sample 20 networks and ensemble them, so Deep Ensembles, MultiSWA, and MultiSWAG are all evaluated under the same training budget. Deep Ensembles and MultiSWA also use the same test-time budget.

For producing the corrupted data, we used the code² released by Hendrycks & Dietterich (2019). We had issues producing the data for the frost corruption type, so we omit it in our evaluation, and include Gaussian blur, which was not included in the evaluation of Hendrycks & Dietterich (2019).

²https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_cifar_c.py



Figure 12. (a)–(c): Average pairwise prior correlations for pairs of objects in classes {0, 1, 2, 4, 7} of MNIST induced by LeNet-5 for p(f(x; w)) when p(w) = N(0, α²I). Images in the same class have higher prior correlations than images from different classes, suggesting that p(f(x; w)) has desirable inductive biases. The correlations slightly decrease with increases in α. Panels (e)–(g) show sample functions from LeNet-5 along the direction connecting a pair of MNIST images of 0 and 1 digits; the complexity of the samples increases with α. (d): NLL and (h): classification error of an ensemble of 20 SWAG samples on MNIST as a function of α, using LeNet-5. The NLL is high for overly small α and near-optimal for larger values, with an optimum near α = 0.3.

D.3. Neural Network Priors

In the main text, we considered different properties of the prior distribution over functions induced by a spherical Gaussian distribution over the weights, with different variance scales.

Prior correlation diagrams. In panels (a)–(c) of Figure 12, we show pairwise correlations of the logits for different pairs of datapoints. To make these plots, we produce S = 100 samples of the weights w_i of a LeNet-5 from the prior distribution N(0, α²I), and compute the logits corresponding to class 0 for each data point and each weight sample. We then compute the correlations for each pair x, x′ of data points as follows:

corr_logit(x, x′) = Σ_{i=1}^{S} (f(x; w_i) − f̄(x)) (f(x′; w_i) − f̄(x′)) / √( Σ_{i=1}^{S} (f(x; w_i) − f̄(x))² · Σ_{i=1}^{S} (f(x′; w_i) − f̄(x′))² ),

where f(x; w) is the logit corresponding to class 0 of the network with weights w on the input x, and f̄(x) is the mean value of the logit, f̄(x) = (1/S) Σ_i f(x; w_i). For evaluation, we use 200 random datapoints per class, for classes 0, 1, 2, 4, 7 (a total of 1000 datapoints). We use this set of classes to ensure that the structure is clearly visible in the figure. We combine the correlations into a diagram, additionally showing the average correlation for each pair of classes. We repeat the experiment for different values of α ∈ {0.02, 0.1, 1}. For a discussion of the results, see Section 5.2.
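A sketch (ours) of the correlation estimate above for a single pair of inputs, assuming a PyTorch `model` returning logits and a hypothetical `resample_prior` helper that redraws all weights from N(0, α²I):

```python
import numpy as np
import torch

def resample_prior(model, alpha=0.1):
    # Hypothetical helper: redraw every parameter from N(0, alpha^2).
    with torch.no_grad():
        for p in model.parameters():
            p.normal_(0.0, alpha)

def prior_logit_correlation(model, x1, x2, alpha=0.1, num_samples=100, cls=0):
    """Monte Carlo estimate of corr_logit(x1, x2) for class `cls` under the prior;
    x1 and x2 are single inputs with a batch dimension."""
    f1, f2 = [], []
    with torch.no_grad():
        for _ in range(num_samples):
            resample_prior(model, alpha)
            f1.append(model(x1)[0, cls].item())
            f2.append(model(x2)[0, cls].item())
    f1 = np.array(f1) - np.mean(f1)
    f2 = np.array(f2) - np.mean(f2)
    return (f1 * f2).sum() / np.sqrt((f1 ** 2).sum() * (f2 ** 2).sum())
```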

Sample functions. In panels (e)–(g) of Figure 12, we visualize the functions sampled from the LeNet-5 network along the direction connecting a pair of MNIST images. In particular, we take a pair of images x_0 and x_1 of digits 0 and 1 respectively, and construct the path x(t) = t · x_0 + (1 − t) · x_1. We then study the samples of the logits z(t) = f(x(t) · ‖x_0‖ / ‖x(t)‖, w) along the path; here we adjusted the norm of the images along the path to be constant, as the values of the logits are sensitive to the norm of the inputs. The complexity of the samples increases as we increase the variance of the prior distribution over the weights. This increased complexity of sample functions explains why we might expect the prior correlations for pairs of images to be lower when we increase the variance of the prior distribution over the weights.

Performance dependence on prior variance. In panels (d), (h) of Figure 12, we show the test negative log-likelihood and accuracy of SWAG applied to LeNet-5 on MNIST. We train the model for 50 epochs, constructing the rank-20 SWAG approximation from the last 25 epochs. We use an initial learning rate of 0.05 and a SWAG learning rate of 0.01, with the learning rate schedule of Maddox et al. (2019). We use the posterior log-density as the objective, and vary the prior variance α². In panels (f), (g) of Figure 9, we perform an analogous experiment using a PreResNet-20 and a VGG-16 on CIFAR-10, using the hyper-parameters reported in Maddox et al. (2019) (for PreResNet-20, we use the hyper-parameters used with PreResNet-164 in Maddox et al. (2019)). Both on MNIST and CIFAR-10, we observe that the performance is poor for overly small values of α, close to optimal for intermediate values, and still reasonable for larger values of α. For further discussion of the results, see Section 5.3.

Predictions from prior samples. Following Wenzel et al. (2020), we study the predictive distributions of prior samples using PreResNet-20 on CIFAR-10. In Figure 9, we show the sample predictive functions averaged over datapoints, for different scales α of the prior distribution. We also show the predictive distribution for each α, which is the average of the sample predictive distributions over 200 samples of weights. In Figure 10, we show how the predictive distribution changes as we vary the number of observed data, for prior scale α = √10. We see that the marginal predictive distribution for all considered values of α is reasonable: roughly uniform across classes, when averaged across the dataset. For the latter experiment, we used stochastic gradient Langevin dynamics (SGLD) (Welling & Teh, 2011) with a cosine learning rate schedule. For each sample we restart SGLD, and we only use the sample obtained at the last iteration. We discuss the results in Section 8.4.
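For reference, a minimal SGLD update of the form used here (Welling & Teh, 2011) is sketched below; the cosine-schedule constants and the gradient estimate are stand-ins, not the exact training configuration:

```python
import math
import torch

def sgld_step(params, grad_log_posterior, lr):
    """One SGLD update: w <- w + (lr / 2) * grad log p(w | D) + N(0, lr I).
    `grad_log_posterior` is a (stochastic) gradient estimate, one tensor per parameter."""
    with torch.no_grad():
        for p, g in zip(params, grad_log_posterior):
            p.add_(0.5 * lr * g + math.sqrt(lr) * torch.randn_like(p))

def cosine_lr(step, total_steps, lr_max):
    # Cosine learning-rate schedule decaying from lr_max to 0 over total_steps.
    return 0.5 * lr_max * (1.0 + math.cos(math.pi * step / total_steps))
```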

Prior correlations with corrupted images. In Section F and Figure 19, we study the decay of the prior correlations between logits on an original image and a perturbed image, as we increase the intensity of the perturbations. For the BNN, we use the PreResNet-20 architecture with the standard Gaussian prior N(0, I). For the linear model, the correlations are not affected by the prior variance α²:

cov(wᵀx, wᵀy) = E[(wᵀx)(wᵀy)] = E[xᵀw wᵀy] = xᵀ E[w wᵀ] y = α² xᵀy,

and hence

corr(wᵀx, wᵀy) = cov(wᵀx, wᵀy) / √( cov(wᵀx, wᵀx) · cov(wᵀy, wᵀy) ) = xᵀy / (‖x‖ · ‖y‖).

We use the N(0, I) prior for the weights of the linear model. Finally, we also evaluate the correlations associated with an RBF kernel (see Equation (7)). To set the lengthscale ℓ of the kernel, we evaluate the pairwise correlations for the PreResNet-20 and the RBF kernel on the 100 uncorrupted CIFAR-10 images that were used for the experiment, and ensure that the average correlations match. The resulting value of ℓ is 10000; the average correlation for the RBF kernel and PreResNet was ≈ 0.9, and for the linear model the average correlation was ≈ 0.82. For the perturbations, we used the same set of corruptions introduced in Hendrycks & Dietterich (2019) as in the experiments in Section 4, with the addition of a random translation: for a random translation of intensity i, we pad the image with 2·i zeros on each side and crop the image randomly to 32×32.

D.4. Rethinking Generalization

In Section 6, we experiment with Bayesian neural networks and Gaussian processes on CIFAR-10 with noisy labels, inspired by the results in Zhang et al. (2016) that suggest we need to re-think generalization to understand deep learning.

Following Zhang et al. (2016), we train PreResNet-20 on CIFAR-10 with different fractions of random labels. To ensure that the networks fit the train data, we turn off weight decay and data augmentation, and use a lower initial learning rate of 0.01. Otherwise, we follow the hyper-parameters that were used with PreResNet-164 in Maddox et al. (2019). We use a diagonal Laplace approximation to compute an estimate of the marginal likelihood for each level of label corruption. Following Ritter et al. (2018), we use the diagonal of the Fisher information matrix rather than the Hessian.
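A sketch (ours) of the resulting estimate: with a diagonal Fisher F standing in for the likelihood Hessian, the standard Laplace approximation to the log marginal likelihood at a MAP solution is log p(D | w_MAP) + log p(w_MAP) + (d/2) log 2π − (1/2) Σ_i log A_ii, with A = diag(F) + (1/α²) I.

```python
import numpy as np

def diag_laplace_log_marglik(log_lik_map, log_prior_map, fisher_diag, prior_var):
    """Diagonal Laplace estimate of the log marginal likelihood at a MAP solution.
    fisher_diag: per-parameter diagonal of the Fisher information (length-d array).
    prior_var: variance alpha^2 of the isotropic Gaussian prior over weights."""
    A_diag = fisher_diag + 1.0 / prior_var   # diagonal curvature of the negative log posterior
    d = fisher_diag.shape[0]
    return (log_lik_map + log_prior_map
            + 0.5 * d * np.log(2.0 * np.pi)
            - 0.5 * np.sum(np.log(A_diag)))
```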

We perform a similar experiment with a Gaussian process with an RBF kernel on a binary classification problem for two classes of CIFAR-10. We use variational inference to fit the model, and we use the variational evidence lower bound to approximate the marginal likelihood. We use variational inference to overcome the non-Gaussian likelihood, and not for scalability reasons, i.e. we are not using inducing inputs. We use the GPyTorch package (Gardner et al., 2018) to train the models. We use an RBF kernel with the default initialization from GPyTorch, and divide the inputs by 5000 to get an appropriate input scale. We train the model on a binary classification problem between classes 0 and 1.

For the 10-class GP classification experiment, we train 10 one-vs-all models that classify between a given class and the rest of the data. To reduce computation in training, we subsample the data not belonging to the given class to 10k datapoints, so each model is trained on a total of 15k datapoints. We then combine the 10 models into a single multi-class model: an observation is attributed to the class that corresponds to the one-vs-all model with the highest confidence. We use the same hyper-parameters as in the binary classification experiments.

D.5. Double Descent

In Section 7, we evaluate SGD, SWAG, and MultiSWAG for models of varying width. Following Nakkiran et al. (2019), we use ResNet-18 on CIFAR-100; we consider the original labels, 10% label corruption, and 20% label corruption. For networks of every width, we reuse the hyper-parameters used for PreResNet-164 in Maddox et al. (2019).



Figure 13. Double Descent. (a): Test error and (b): NLL loss for ResNet-18 with varying width on CIFAR-100, for SGD, SWAG, and MultiSWAG, when 10% of the labels are randomly reshuffled. MultiSWAG alleviates double descent both on the original labels and under label noise, both in accuracy and NLL. (c): Test NLLs for MultiSWAG with a varying number of independent models under 20% label corruption; NLL monotonically decreases with an increased number of independent models, alleviating double descent.

For the original labels and 10% label corruption, we use 5 independently trained SWAG models with MultiSWAG, and for 20% label corruption we use 10 models; for 20% label corruption, we also show performance varying the number of independent models in Figures 8(e) and 13(c). Both for SWAG and MultiSWAG, we use an ensemble of 20 sampled models from each of the SWAG solutions; for example, for MultiSWAG with 10 independent SWAG solutions, we use an ensemble of 200 networks.
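Schematically, the MultiSWAG predictions are assembled as below (a sketch of ours, with hypothetical `sample_weights` and `predict_probs` helpers standing in for SWAG sampling and a forward pass):

```python
import numpy as np

def multiswag_predict(swag_solutions, x, sample_weights, predict_probs,
                      samples_per_solution=20):
    """Average predicted class probabilities over weight samples drawn from each
    independent SWAG solution; 10 solutions with 20 samples each gives an
    ensemble of 200 networks."""
    probs = []
    for solution in swag_solutions:
        for _ in range(samples_per_solution):
            w = sample_weights(solution)        # hypothetical: draw weights from the SWAG Gaussian
            probs.append(predict_probs(w, x))   # hypothetical: softmax outputs on inputs x
    return np.mean(probs, axis=0)
```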

E. Analysis of Prior Variance Effect

In this section, we provide simple analytic results for the effect of the prior variance in ReLU networks. A related derivation is presented in Appendix Section A.8 of Garipov et al. (2018), about connecting paths from symmetries in parametrization.

We will consider a multilayer network f(x; w) of the form

f(x; {W_i, b_i}_{i=1}^{n}) = W_n φ( ... φ(W_2 φ(W_1 x + b_1) + b_2) ... ) + b_n,

where φ is the ReLU (or, in fact, any positively-homogeneous activation function), the W_i are weight matrices, and the b_i are bias vectors. In particular, f can be a regular CNN with ReLU activations, up to the logits (with the softmax activation removed).

Now suppose we have a prior distribution of the form

W_i ∼ N(0, α_i² I),    b_i ∼ N(0, β_i² I),

where the identity matrices I are implicitly assumed to be of appropriate shapes, so that each weight matrix and bias vector has a spherical Gaussian distribution. We can reparameterize this distribution as

W_i = α_i E_i,    E_i ∼ N(0, I),
b_i = β_i ε_i,    ε_i ∼ N(0, I).

We can then express the predictions of the network on the input x, for weights sampled from the prior, as the random variable

f(x; {α_i, β_i}_{i=1}^{n}) = α_n E_n( ... φ(α_1 E_1 x + β_1 ε_1) ... ) + β_n ε_n.    (8)

Through Equation (8), we can observe some simple properties of the dependence between the prior scales α_i, β_i and the induced function-space prior.

Proposition 1. Suppose the network has no bias vectors, i.e. β_1 = ... = β_n = 0. Then the scales α_i of the prior distribution over the weights only affect the output scale of the network.

Proof. In the case when there are no bias vectors, Equation (8) simplifies to

f(x; {α_i, β_i = 0}_{i=1}^{n}) = α_n E_n( ... φ(α_1 E_1 x) ... ) = α_n · ... · α_1 · E_n( ... φ(E_1 x) ... ) = α_n · ... · α_1 · f(x; {α_i = 1, β_i = 0}_{i=1}^{n}).

In the derivation above, we used the positive homogeneity of ReLU: φ(αz) = αφ(z) for any positive α.

In other words, to sample from the distribution over functions corresponding to a prior with scales {α_i, β_i = 0}_{i=1}^{n}, we can sample from the spherical Gaussian prior (without bias terms) {α_i = 1, β_i = 0}_{i=1}^{n}, and then rescale the outputs of the network by the product of the scales α_n · ... · α_2 · α_1.
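A quick numerical check of Proposition 1 (ours): rescaling the weights of a bias-free ReLU network layer by layer is equivalent to rescaling the outputs of the unit-scale network by the product of the scales.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Bias-free ReLU network, as in Proposition 1.
net = nn.Sequential(
    nn.Linear(4, 16, bias=False), nn.ReLU(),
    nn.Linear(16, 16, bias=False), nn.ReLU(),
    nn.Linear(16, 3, bias=False),
)

alphas = [2.0, 0.5, 3.0]                  # per-layer scales alpha_i
x = torch.randn(5, 4)

with torch.no_grad():
    base = net(x)                         # f(x; {alpha_i = 1, beta_i = 0})
    for layer, a in zip([net[0], net[2], net[4]], alphas):
        layer.weight.mul_(a)              # rescale layer i by alpha_i
    scaled = net(x)                       # f(x; {alpha_i, beta_i = 0})

prod = alphas[0] * alphas[1] * alphas[2]
print(torch.allclose(scaled, prod * base, atol=1e-4))   # True: only the output scale changes
```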

We note that the result above is different from the results for sigmoid networks considered in MacKay (1995), where varying the prior on the weights leads to changing the lengthscale of the sample functions. For ReLU networks without biases, increasing the prior variance only increases the output scale of the network, and not the complexity of the samples. If we apply the softmax activation on the outputs of the last layer of such a network, we will observe increasingly confident predictions as we increase the prior variance. We observe this effect in Figure 9 and discuss it in Section 8.4.

In case bias vectors are present, we can obtain a similar result using a specific scaling of the prior variances with the layer, as in the following proposition.

Proposition 2. Suppose the prior scales depend on the layer of the network as follows, for some γ > 0:

α_i = γ,    β_i = γ^i,

for all layers i = 1, ..., n. Then γ only affects the scale of the predictive distribution at any input x:

f(x; {α_i = γ, β_i = γ^i}_{i=1}^{n}) = γ^n · f(x; {α_i = 1, β_i = 1}_{i=1}^{n}).

Proof. The proof is analogous to the proof of Proposition 1. We can use the positive homogeneity of ReLU activations to factor the prior scales outside of the network:

f(x; {α_i = γ, β_i = γ^i}_{i=1}^{n}) = γ E_n( ... φ(γ E_1 x + γ ε_1) ... ) + γ^n ε_n = γ^n ( E_n( ... φ(E_1 x + ε_1) ... ) + ε_n ) = γ^n · f(x; {α_i = 1, β_i = 1}_{i=1}^{n}).

The analysis above can be applied to other simple scaling rules of the prior; e.g.,

f(x; \{\alpha_i = \gamma \hat\alpha_i, \beta_i = \gamma^i \hat\beta_i\}_{i=1}^{n}) = \gamma^n \cdot f(x; \{\alpha_i = \hat\alpha_i, \beta_i = \hat\beta_i\}_{i=1}^{n})   (9)

can be shown completely analogously to Proposition 2.

More general types of scaling of the prior affect both the output scale of the network and the relative effect of the weight and bias terms. For example, by Equation (9) we have

f(x; \{\alpha_i = \gamma, \beta_i = \gamma\}_{i=1}^{n})
= f(x; \{\alpha_i = \gamma \cdot 1, \beta_i = \gamma^i \cdot \gamma^{1-i}\}_{i=1}^{n})
= \gamma^n \cdot f(x; \{\alpha_i = 1, \beta_i = \gamma^{1-i}\}_{i=1}^{n}).

We note that the analysis does not cover residual connections or batch normalization, so it applies to LeNet-5 but cannot be directly applied to the PreResNet-20 networks used in many of our experiments.

F. Prior Correlation Structure under Perturbations

In this section we explore the prior correlations between the logits on different pairs of datapoints induced by a spherical Gaussian prior on the weights of a PreResNet-20. We sample 100 random images from CIFAR-10 (10 from each class) and apply 17 different perturbations introduced by Hendrycks & Dietterich (2019) at 5 different levels of intensity. We then compute correlations between the logits f(x, w) for the original image x and f(x̃, w) for the corrupted image x̃, as we sample the weights of the network from the prior w ∼ N(0, I).

In Figure 19 we show how the correlations decay with perturbation intensity. For reference, we also show how the correlations decay for a linear model and for an RBF kernel. For the RBF kernel, we set the lengthscale so that the average correlations on the uncorrupted datapoints match those of a PreResNet-20. Further experimental details can be found in Appendix D.3.

For all types of corruption except saturate, snow, fog, and brightness, the PreResNet logit correlations decay more slowly than those of the RBF kernel and linear model. It appears that the prior samples are sensitive to corruptions that alter the brightness, or more generally the colours, in the image. For many types of corruption (e.g. Gaussian noise), the prior correlations for PreResNet are close to 1 for all levels of corruption.

Overall, these results indicate that the prior over functions induced by a vague prior over parameters w, in combination with a PreResNet, has useful equivariance properties: before seeing data, the model treats images of the same class as highly correlated, even after an image has undergone significant perturbations representative of those we often see in the real world. These types of symmetries are a large part of what makes neural networks a powerful model class for high-dimensional natural signals.

[Figure 14 plots omitted: NLL vs. number of independently trained models (1–10) at five corruption intensities (increasing left to right) for panels (a) Gaussian Noise, (b) Impulse Noise, (c) Shot Noise; legend: Deep Ensembles, MultiSWA, MultiSWAG.]

Figure 14. Noise Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA, as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).

[Figure 15 plots omitted: NLL vs. number of independently trained models (1–10) at five corruption intensities (increasing left to right) for panels (a) Defocus Blur, (b) Glass Blur, (c) Motion Blur, (d) Zoom Blur, (e) Gaussian Blur; legend: Deep Ensembles, MultiSWA, MultiSWAG.]

Figure 15. Blur Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA, as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).

[Figure 16 plots omitted: NLL vs. number of independently trained models (1–10) at five corruption intensities (increasing left to right) for panels (a) Contrast, (b) Saturate, (c) Elastic Transform, (d) Pixelate, (e) JPEG Compression; legend: Deep Ensembles, MultiSWA, MultiSWAG.]

Figure 16. Digital Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA, as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).

[Figure 17 plots omitted: NLL vs. number of independently trained models (1–10) at five corruption intensities (increasing left to right) for panels (a) Snow, (b) Fog, (c) Brightness; legend: Deep Ensembles, MultiSWA, MultiSWAG.]

Figure 17. Weather Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA, as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).

[Figure 18 plots omitted: NLL, accuracy, and ECE (boxplots over corruption types) vs. corruption intensity 1–5; legend: Deep Ensembles, MultiSWA, MultiSWAG.]

Figure 18. Negative log likelihood, accuracy, and expected calibration error distributions on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA, as a function of the corruption intensity. Following Ovadia et al. (2019), we summarize the results for different types of corruption with a boxplot. For each method we use 10 independently trained models, and for MultiSWAG we sample 20 networks from each model. As in Figures 5, 11–14, there are substantial differences between these three methods, which are hard to see due to the vertical scale on this plot; MultiSWAG particularly outperforms Deep Ensembles and MultiSWA in terms of NLL and ECE for higher corruption intensities.

[Figure 19 plots omitted: prior correlation vs. corruption intensity (0–5) for Translation, Gaussian Noise, Impulse Noise, Shot Noise, Defocus Blur, Glass Blur, Motion Blur, Zoom Blur, Gaussian Blur, Contrast, Saturate, Elastic, Pixelate, JPEG, Snow, Fog, and Brightness; legend: PreResNet-20, RBF, Linear.]

Figure 19. Prior correlations under corruption. Prior correlations between predictions (logits) for PreResNet-20, a linear model, and an RBF kernel on original and corrupted images, as a function of corruption intensity, for different types of corruption. The lengthscale of the RBF kernel is calibrated to produce similar correlations to PreResNet on uncorrupted datapoints. We report the mean correlation values over 100 different images and show 1σ error bars with shaded regions. For all corruptions except Snow, Saturate, Fog, and Brightness, the correlations decay more slowly for PreResNet than for the baselines.


8.3. Tempered Posterior or Different Likelihood

The tempered posterior for one model is an untempered posterior using a different observation model. In other words, we can trivially change the likelihood function in our model so that the standard Bayesian posterior in the new model is equal to the posterior of temperature T in the original model. Indeed, consider the likelihood function

p_T(D | w) \propto p(D | w)^{1/T},   (4)

for which the posterior distribution of the model M_T with likelihood p_T is exactly the temperature-T posterior of the model M with likelihood p.

The predictive distribution differs for the two models: even though the posteriors coincide, the likelihoods for a new datapoint y* are different,

\int p(y^* | w)\, p(w | D)\, dw \neq \int p_T(y^* | w)\, p(w | D)\, dw.   (5)

As an example, consider a regression model M with a Gaussian likelihood function y ∼ N(f, σ²), where f is the output of the network and σ² is the noise variance. The predictive distributions for the two models M and M_T for a new input x will have different variances but the same mean: E_w[f + ε] = E_w[f]. Moreover, in this case the noise variance would typically be learned in either model M_T or M.
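A small numpy simulation (ours, with illustrative values for the posterior moments, noise scale, and temperature) makes this Gaussian-likelihood example concrete: raising a Gaussian likelihood N(f, σ²) to the power 1/T corresponds, up to normalization, to using noise variance Tσ², so the two predictive distributions share a mean but differ in variance.

```python
# Illustrative Monte Carlo check: same predictive mean, different variance.
import numpy as np

rng = np.random.default_rng(0)
m, s = 0.7, 0.3        # hypothetical posterior mean/std of f(x*, w) at a test input
sigma, T = 0.1, 2.0    # noise std and temperature (illustrative values)

f_samples = rng.normal(m, s, size=1_000_000)                                   # posterior samples of f
y_M  = f_samples + rng.normal(0.0, sigma, size=f_samples.shape)                # model M
y_MT = f_samples + rng.normal(0.0, sigma * np.sqrt(T), size=f_samples.shape)   # model M_T

print(y_M.mean(), y_MT.mean())   # approximately equal means
print(y_M.var(),  y_MT.var())    # approximately s^2 + sigma^2  vs  s^2 + T * sigma^2
```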

Section 4.1 of Grunwald et al. (2017) considers a related construction.

8.4. Effect of the Prior

While a somewhat misspecified prior will certainly interact with the utility of tempering, we do not believe the experiments in Wenzel et al. (2020) provide evidence that even the prior p(w) = N(0, I) is misspecified to any serious extent. For a relatively wide range of distributions over w, the functional form of the network f(x, w) can produce a generally reasonable distribution over functions p(f(x, w)). In Figure 10 we reproduce the findings in Wenzel et al. (2020) showing that sample functions p(f(x, w)) corresponding to the prior p(w) = N(0, 10 · I) strongly favour a single class over the dataset. While this behaviour appears superficially dramatic, we note it is simply an artifact of having a miscalibrated signal variance. A miscalibrated signal variance interacts with a quickly saturating softmax link function to provide a seemingly dramatic preference for a given class. If we instead use p(w) = N(0, α²I) for quite a range of α, then sample functions provide reasonably high entropy across labels, as in Figure 9. The value of α can easily be determined through cross-validation, as in Figure 9, or specified as a standard value used for L2 regularization (α = 0.24 in this case).

However, even with the inappropriate prior scale, we see in the bottom row of panels (a)–(e) of Figure 9 that the unconditional predictive distribution is completely reasonable. Moreover, the prior variance represents a soft prior bias and will quickly update in the presence of data. In Figure 10 we show the posterior samples after observing 10, 100, and 1000 data points.

Other aspects of the prior, outside of the prior signal variance, will have a much greater effect on the inductive biases of the model. For example, the induced covariance function cov(f(x_i, w), f(x_j, w)) reflects the induced similarity metric over data instances; through the covariance function we can answer, for instance, whether the model believes a priori that a translated image is similar to the original. Unlike the signal variance of the prior, the prior covariance function will continue to have a significant effect on posterior inference for even very large datasets, and strongly reflects the structural properties of the neural network. We explore these structures of the prior in Figure 12.

8.5. The Effect of Inexact Inference

We have to keep in mind what we ultimately use posterior samples to compute. Ultimately, we wish to estimate the predictive distribution given by the integral in Equation (1). With a finite number of samples, the tempered posterior could be used to provide a better approximation to the expectation of the predictive distribution associated with the untempered posterior.

Consider a simple example, where we wish to estimate the mean of a high-dimensional Gaussian distribution N(0, I). Suppose we use J independent samples. The mean of these samples is also Gaussian distributed, μ ∼ N(0, (1/J) I). In Bayesian deep learning the dimension d is typically on the order of 10⁷, and J would be on the order of 10. The norm of μ would then be highly concentrated around √(10⁷)/√10 = 1000. In this case, sampling from a tempered posterior with T < 1 would lead to a better approximation of the Bayesian model average associated with an untempered posterior.
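The following tiny numpy check (ours; we use d = 10⁶ rather than 10⁷ to keep memory modest) illustrates this concentration: the average of J samples from N(0, I_d) has norm close to √(d/J), far from the true mean at the origin.

```python
# Concentration of the norm of a sample mean of a high-dimensional Gaussian.
import numpy as np

rng = np.random.default_rng(0)
d, J = 10**6, 10                  # smaller d than the 10^7 quoted above
mu_hat = np.zeros(d)
for _ in range(J):                # running average of J independent N(0, I_d) samples
    mu_hat += rng.standard_normal(d) / J

print(np.linalg.norm(mu_hat))     # close to sqrt(d / J) = sqrt(10^5) ~ 316
```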

Furthermore, no current sampling procedure will be providing samples that are close to independent samples from the true posterior of a Bayesian neural network. The posterior landscape is far too multimodal and complex for there to be any reasonable coverage. The approximations we have are practically useful, and often preferable to conventional training, but we cannot realistically proceed with analysis assuming that we have obtained true samples from a posterior. While we would expect that some value of T ≠ 1 would be preferred for any finite dataset in practice, it is conceivable that some of the results in Wenzel et al. (2020) may be affected by the specifics of the approximate inference technique being used.


We should be wary not to view Bayesian model averaging purely through the prism of simple Monte Carlo, as advised in Section 3.2. Given a finite computational budget, our goal in effectively approximating a Bayesian model average is not equivalent to obtaining good samples from the posterior.

9. Discussion

"It is now common practice for Bayesians to fit models that have more parameters than the number of data points. Incorporate every imaginable possibility into the model space: for example, if it is conceivable that a very simple model might be able to explain the data, one should include simple models; if the noise might have a long-tailed distribution, one should include a hyperparameter which controls the heaviness of the tails of the distribution; if an input variable might be irrelevant to a regression, include it in the regression anyway." (MacKay, 1995)

We have presented a probabilistic perspective of generalization, which depends on the support and inductive biases of the model. The support should be as large as possible, but the inductive biases must be well-calibrated to a given problem class. We argue that Bayesian neural networks embody these properties, and, through the lens of probabilistic inference, explain generalization behaviour that has previously been viewed as mysterious. Moreover, we argue that Bayesian marginalization is particularly compelling for neural networks, show how deep ensembles provide a practical mechanism for marginalization, and propose a new approach that generalizes deep ensembles to marginalize within basins of attraction. We show that this multimodal approach to Bayesian model averaging, MultiSWAG, can entirely alleviate double descent, enabling monotonic performance improvements with increases in model flexibility, as well as significant improvements in generalization accuracy and log likelihood over SGD and single-basin marginalization.

There are certainly many challenges to estimating the integral for a Bayesian model average in modern deep learning, including a high-dimensional parameter space and a complex posterior landscape. But viewing the challenge indeed as an integration problem, rather than as an attempt to obtain posterior samples for a simple Monte Carlo approximation, provides opportunities for future progress. Bayesian deep learning has been making fast practical advances, with approaches that now enable better accuracy and calibration over standard training, with minimal overhead.

We finish with remarks about future developments for Bayesian neural network priors, and approaches to research in Bayesian deep learning.

9.1. The Future for BNN Priors

We provide some brief remarks about future developments for BNN priors. Here we have explored relatively simple parameter priors p(w) = N(0, α²I). While these priors are simple in parameter space, they interact with the neural network architecture to induce a sophisticated prior over functions p(f(x, w)), with many desirable properties, including a reasonable correlation structure over images. However, these parameter priors can certainly still be improved. As we have seen, even tuning the value of the signal variance α², an analogue of the L2 regularization often used in deep learning, can have a noticeable effect on the induced prior over functions, though this effect is quickly modulated by data. Layer-wise priors, such that parameters in each layer have a different signal variance, are intuitive: we would expect later layers to require precise determination, while parameters in earlier layers could reasonably take a range of values. But one has to be cautious; as we show in Appendix Section E, with ReLU activations different signal variances in different layers can be degenerate, combining together to affect only the output scale of the network.

A currently popular sentiment is that we should directly build function-space BNN priors, often taking inspiration from Gaussian processes. While we believe this is a promising direction, one should proceed with caution. If we contrive priors over parameters p(w) to induce distributions over functions p(f) that resemble familiar models, such as Gaussian processes with RBF kernels, we could be throwing the baby out with the bathwater. Neural networks are useful as their own model class precisely because they have different inductive biases from other models.

A similar concern applies to taking infinite-width limits in Bayesian neural networks. In these cases we recover Gaussian processes with interpretable kernel functions. Because these models are easier to use and analyze, and give rise to interpretable and well-motivated priors, it is tempting to treat them as drop-in replacements for their parametric analogues. However, the kernels for these models are fixed. In order for a model to do effective representation learning, we must learn a similarity metric for the data. Training a neural network is, in many ways, like learning a kernel rather than using a fixed kernel. MacKay (1998) has also expressed concerns about treating these limits as replacements for neural networks, due to the loss of representation-learning power.

The distribution over functions induced by a network in combination with a generic distribution over parameters p(w) may be hard to interpret, but this distribution will contain the equivariance properties, representation-learning abilities, and other biases that make neural networks a compelling model class in their own right.


9.2. "But is it really Bayesian?"

We finish with an editorial comment about approaches to research within Bayesian deep learning. There is sometimes a tendency to classify work as Bayesian or not Bayesian, with very stringent criteria for what qualifies as Bayesian. Moreover, the implication, and sometimes even explicit recommendation, is that if an approach is not unequivocally Bayesian in every respect, then we should not term it Bayesian, and we should instead attempt to understand the procedure through entirely different, non-Bayesian mechanisms. We believe this mentality encourages tribalism, which is not conducive to the best research or to creating the best performing methods. What matters is not a debate about semantics, but making rational modelling choices given a particular problem setting, and trying to understand these choices. Often these choices can largely be inspired by a Bayesian approach, in which case it is desirable to indicate this source of inspiration. And in the semantics debate, who would be the arbiter of what gets to be called Bayesian? Arguably it ought to be an evolving definition.

Broadly speaking, what makes Bayesian approaches distinctive is a posterior-weighted marginalization over parameters. And at a high level, Bayesian methods are about combining our honest beliefs with data to form a posterior. In actuality, no fair-minded researcher entirely believes the prior over parameters, the functional form of the model (which is part of the prior over functions), or the likelihood. From this perspective, it is broadly compatible with a Bayesian philosophy to reflect misspecification in the modelling procedure itself, which is achieved through tempering. In this sense, the tempered posterior is more reflective of a true posterior than the posterior that results from ignoring our belief that the model is misspecified.

Moreover, basic probability theory indicates that marginalization is desirable. While marginalization cannot in practice be achieved exactly, we can try to improve over conventional training, which, as we have discussed, can be viewed as approximate marginalization. Given computational constraints, effective marginalization is not equivalent to obtaining accurate samples from a posterior. As we have discussed, simple Monte Carlo is only one of many mechanisms for marginalization. Just as expectation propagation (Minka, 2001a) focuses its approximation on the factors in a posterior where it will most affect the end result, we should focus on representing the posterior where it will make the biggest difference to the model average. As we have shown, deep ensembles are a reasonable mechanism up to a point. After having trained many independent models, there are added benefits to marginalizing within basins, given the computational expense associated with retraining an additional model to find an additional basin of attraction.

We should also not hold Bayesian methods to a double standard. Indeed, it can be hard to interpret or understand the prior, the posterior, and whether the marginalization procedure is optimal. But it is also hard to interpret the choices behind the functional form of the model, or the rationale behind classical procedures in which we bet everything on a single global optimum, when we know there are many global optima, many of them will perform well but provide different solutions, and many others will not perform well. We should apply the same level of scrutiny to all modelling choices, consider the alternatives, and not be paralyzed if a procedure is not optimal in every respect.

Acknowledgements. AGW and PI are supported by an Amazon Research Award, Facebook Research, NSF I-DISRE 193471, NIH R01 DA048764-01A1, NSF IIS-1563887, and NSF IIS-1910266. We thank Greg Benton for helpful discussions.

References

Achille, A. and Soatto, S. Emergence of invariance and disentanglement in deep representations. The Journal of Machine Learning Research, 19(1):1947–1980, 2018.
Ashukha, A., Lyzhov, A., Molchanov, D., and Vetrov, D. Pitfalls of in-domain uncertainty estimation and ensembling in deep learning. arXiv preprint arXiv:2002.06470, 2020.
Barron, A. R. and Cover, T. M. Minimum complexity density estimation. IEEE Transactions on Information Theory, 37(4):1034–1054, 1991.
Beal, M. J. Variational algorithms for approximate Bayesian inference. University of London, 2003.
Belkin, M., Hsu, D., Ma, S., and Mandal, S. Reconciling modern machine-learning practice and the classical bias–variance trade-off. Proceedings of the National Academy of Sciences, 116(32):15849–15854, 2019.
Bishop, C. M. Pattern Recognition and Machine Learning. Springer, 2006.
Bissiri, P. G., Holmes, C. C., and Walker, S. G. A general framework for updating belief distributions. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 78(5):1103–1130, 2016.
Blundell, C., Cornebise, J., Kavukcuoglu, K., and Wierstra, D. Weight uncertainty in neural networks. arXiv preprint arXiv:1505.05424, 2015.
Box, G. E. and Tiao, G. C. Bayesian Inference in Statistical Analysis. Addison-Wesley, Reading, MA, 1973.

Chaudhari, P., Choromanska, A., Soatto, S., LeCun, Y., Baldassi, C., Borgs, C., Chayes, J., Sagun, L., and Zecchina, R. Entropy-SGD: Biasing gradient descent into wide valleys. Journal of Statistical Mechanics: Theory and Experiment, 2019.
Cobb, A. D., Baydin, A. G., Markham, A., and Roberts, S. J. Introducing an explicit symplectic integration scheme for Riemannian manifold Hamiltonian Monte Carlo. arXiv preprint arXiv:1910.06243, 2019.
Darnieder, W. F. Bayesian Methods for Data-Dependent Priors. PhD thesis, The Ohio State University, 2011.
de Heide, R., Kirichenko, A., Mehta, N., and Grunwald, P. Safe-Bayesian generalized linear regression. arXiv preprint arXiv:1910.09227, 2019.
Dinh, L., Pascanu, R., Bengio, S., and Bengio, Y. Sharp minima can generalize for deep nets. arXiv preprint arXiv:1703.04933, 2017.
Draxler, F., Veschgini, K., Salmhofer, M., and Hamprecht, F. A. Essentially no barriers in neural network energy landscape. arXiv preprint arXiv:1803.00885, 2018.
Dziugaite, G. K. and Roy, D. M. Computing nonvacuous generalization bounds for deep (stochastic) neural networks with many more parameters than training data. arXiv preprint arXiv:1703.11008, 2017.
Fort, S. and Jastrzebski, S. Large scale structure of neural network loss landscapes. In Advances in Neural Information Processing Systems, pp. 6706–6714, 2019.
Fort, S., Hu, H., and Lakshminarayanan, B. Deep ensembles: A loss landscape perspective. arXiv preprint arXiv:1912.02757, 2019.
Gal, Y. and Ghahramani, Z. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. In International Conference on Machine Learning, pp. 1050–1059, 2016.
Gardner, J. R., Pleiss, G., Bindel, D., Weinberger, K. Q., and Wilson, A. G. GPyTorch: Blackbox matrix-matrix Gaussian process inference with GPU acceleration. In Neural Information Processing Systems, 2018.
Garipov, T., Izmailov, P., Podoprikhin, D., Vetrov, D. P., and Wilson, A. G. Loss surfaces, mode connectivity, and fast ensembling of DNNs. In Neural Information Processing Systems, 2018.
Grunwald, P. The safe Bayesian. In International Conference on Algorithmic Learning Theory, pp. 169–183. Springer, 2012.
Grunwald, P., Van Ommen, T., et al. Inconsistency of Bayesian inference for misspecified linear models, and a proposal for repairing it. Bayesian Analysis, 12(4):1069–1103, 2017.
Guedj, B. A primer on PAC-Bayesian learning. arXiv preprint arXiv:1901.05353, 2019.
Guo, C., Pleiss, G., Sun, Y., and Weinberger, K. Q. On calibration of modern neural networks. In Proceedings of the 34th International Conference on Machine Learning - Volume 70, pp. 1321–1330. JMLR.org, 2017.
Gustafsson, F. K., Danelljan, M., and Schon, T. B. Evaluating scalable Bayesian deep learning methods for robust computer vision. arXiv preprint arXiv:1906.01620, 2019.
Hafner, D., Tran, D., Irpan, A., Lillicrap, T., and Davidson, J. Reliable uncertainty estimates in deep neural networks using noise contrastive priors. arXiv preprint arXiv:1807.09289, 2018.
He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778, 2016.
Hendrycks, D. and Dietterich, T. Benchmarking neural network robustness to common corruptions and perturbations. arXiv preprint arXiv:1903.12261, 2019.
Hinton, G. E. and Van Camp, D. Keeping the neural networks simple by minimizing the description length of the weights. In Proceedings of the Sixth Annual Conference on Computational Learning Theory, pp. 5–13, 1993.
Hochreiter, S. and Schmidhuber, J. Flat minima. Neural Computation, 9(1):1–42, 1997.
Huang, W. R., Emam, Z., Goldblum, M., Fowl, L., Terry, J. K., Huang, F., and Goldstein, T. Understanding generalization through visualizations. arXiv preprint arXiv:1906.03291, 2019.
Izmailov, P., Podoprikhin, D., Garipov, T., Vetrov, D., and Wilson, A. G. Averaging weights leads to wider optima and better generalization. In Uncertainty in Artificial Intelligence (UAI), 2018.
Izmailov, P., Maddox, W. J., Kirichenko, P., Garipov, T., Vetrov, D., and Wilson, A. G. Subspace inference for Bayesian deep learning. In Uncertainty in Artificial Intelligence, 2019.
Jacot, A., Gabriel, F., and Hongler, C. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in Neural Information Processing Systems, pp. 8571–8580, 2018.

Jiang, Y., Neyshabur, B., Mobahi, H., Krishnan, D., and Bengio, S. Fantastic generalization measures and where to find them. arXiv preprint arXiv:1912.02178, 2019.
Kass, R. E. and Raftery, A. E. Bayes factors. Journal of the American Statistical Association, 90(430):773–795, 1995.
Kendall, A. and Gal, Y. What uncertainties do we need in Bayesian deep learning for computer vision? In Advances in Neural Information Processing Systems, pp. 5574–5584, 2017.
Keskar, N. S., Mudigere, D., Nocedal, J., Smelyanskiy, M., and Tang, P. T. P. On large-batch training for deep learning: Generalization gap and sharp minima. arXiv preprint arXiv:1609.04836, 2016.
Khan, M. E., Nielsen, D., Tangkaratt, V., Lin, W., Gal, Y., and Srivastava, A. Fast and scalable Bayesian deep learning by weight-perturbation in Adam. arXiv preprint arXiv:1806.04854, 2018.
Krizhevsky, A., Nair, V., and Hinton, G. The CIFAR-10 dataset, 2014. http://www.cs.toronto.edu/~kriz/cifar.html
Lakshminarayanan, B., Pritzel, A., and Blundell, C. Simple and scalable predictive uncertainty estimation using deep ensembles. In Advances in Neural Information Processing Systems, pp. 6402–6413, 2017.
Langford, J. and Caruana, R. (Not) bounding the true error. In Advances in Neural Information Processing Systems, pp. 809–816, 2002.
LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.
Louizos, C., Shi, X., Schutte, K., and Welling, M. The functional neural process. In Advances in Neural Information Processing Systems, 2019.
MacKay, D. J. Bayesian Methods for Adaptive Models. PhD thesis, California Institute of Technology, 1992.
MacKay, D. J. Probable networks and plausible predictions: a review of practical Bayesian methods for supervised neural networks. Network: Computation in Neural Systems, 6(3):469–505, 1995.
MacKay, D. J. Introduction to Gaussian processes. In Bishop, C. M. (ed.), Neural Networks and Machine Learning, chapter 11, pp. 133–165. Springer-Verlag, 1998.
MacKay, D. J. Information Theory, Inference and Learning Algorithms. Cambridge University Press, 2003.
Maddox, W. J., Izmailov, P., Garipov, T., Vetrov, D. P., and Wilson, A. G. A simple baseline for Bayesian uncertainty in deep learning. In Advances in Neural Information Processing Systems, 2019.
Masegosa, A. R. Learning under model misspecification: Applications to variational and ensemble methods, 2019.
McAllester, D. A. PAC-Bayesian model averaging. In Proceedings of the Twelfth Annual Conference on Computational Learning Theory, pp. 164–170, 1999.
Minka, T. Expectation propagation for approximate Bayesian inference. In Uncertainty in Artificial Intelligence, volume 17, pp. 362–369, 2001a.
Minka, T. P. Bayesian model averaging is not model combination, 2000.
Minka, T. P. Automatic choice of dimensionality for PCA. In Advances in Neural Information Processing Systems, pp. 598–604, 2001b.
Mohri, M. and Rostamizadeh, A. Rademacher complexity bounds for non-i.i.d. processes. In Advances in Neural Information Processing Systems, pp. 1097–1104, 2009.
Nakkiran, P., Kaplun, G., Bansal, Y., Yang, T., Barak, B., and Sutskever, I. Deep double descent: Where bigger models and more data hurt. arXiv preprint arXiv:1912.02292, 2019.
Nakkiran, P., Venkat, P., Kakade, S., and Ma, T. Optimal regularization can mitigate double descent. arXiv preprint arXiv:2003.01897, 2020.
Nalisnick, E. On Priors for Bayesian Neural Networks. PhD thesis, University of California, Irvine, 2018.
Neal, R. Bayesian Learning for Neural Networks. Springer Verlag, 1996. ISBN 0387947248.
Neyshabur, B., Bhojanapalli, S., McAllester, D., and Srebro, N. Exploring generalization in deep learning. In Advances in Neural Information Processing Systems, pp. 5947–5956, 2017.
Neyshabur, B., Bhojanapalli, S., and Srebro, N. A PAC-Bayesian approach to spectrally-normalized margin bounds for neural networks. In International Conference on Learning Representations, 2018. URL https://openreview.net/forum?id=Skz_WfbCZ.
Opper, M., Kinzel, W., Kleinz, J., and Nehl, R. On the ability of the optimal perceptron to generalise. Journal of Physics A: Mathematical and General, 23(11):L581, 1990.

Ovadia, Y., Fertig, E., Ren, J., Nado, Z., Sculley, D., Nowozin, S., Dillon, J. V., Lakshminarayanan, B., and Snoek, J. Can you trust your model's uncertainty? Evaluating predictive uncertainty under dataset shift. arXiv preprint arXiv:1906.02530, 2019.
Pearce, T., Zaki, M., Brintrup, A., Anastassacos, N., and Neely, A. Uncertainty in neural networks: Bayesian ensembling. arXiv preprint arXiv:1810.05546, 2018.
Rasmussen, C. E. and Ghahramani, Z. Occam's razor. In Neural Information Processing Systems (NIPS), 2001.
Rasmussen, C. E. and Williams, C. K. I. Gaussian Processes for Machine Learning. The MIT Press, 2006.
Ritter, H., Botev, A., and Barber, D. A scalable Laplace approximation for neural networks. In International Conference on Learning Representations (ICLR), 2018.
Rue, H., Martino, S., and Chopin, N. Approximate Bayesian inference for latent Gaussian models by using integrated nested Laplace approximations. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 71(2):319–392, 2009.
Simonyan, K. and Zisserman, A. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.
Smith, S. L. and Le, Q. V. A Bayesian perspective on generalization and stochastic gradient descent. In International Conference on Learning Representations, 2018.
Sun, S., Zhang, G., Shi, J., and Grosse, R. Functional variational Bayesian neural networks. arXiv preprint arXiv:1903.05779, 2019.
Ulyanov, D., Vedaldi, A., and Lempitsky, V. Deep image prior. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 9446–9454, 2018.
Vapnik, V. N. Adaptive and learning systems for signal processing, communications, and control. Statistical Learning Theory, 1998.
Virtanen, P., Gommers, R., Oliphant, T. E., Haberland, M., Reddy, T., Cournapeau, D., Burovski, E., Peterson, P., Weckesser, W., Bright, J., et al. SciPy 1.0: fundamental algorithms for scientific computing in Python. Nature Methods, pp. 1–12, 2020.
Walker, S. and Hjort, N. L. On Bayesian consistency. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 63(4):811–821, 2001.
Welling, M. and Teh, Y. W. Bayesian learning via stochastic gradient Langevin dynamics. In Proceedings of the 28th International Conference on Machine Learning (ICML-11), pp. 681–688, 2011.
Wenzel, F., Roth, K., Veeling, B. S., Swiatkowski, J., Tran, L., Mandt, S., Snoek, J., Salimans, T., Jenatton, R., and Nowozin, S. How good is the Bayes posterior in deep neural networks really? arXiv preprint arXiv:2002.02405, 2020.
Wilson, A. G. The case for Bayesian deep learning. arXiv preprint arXiv:2001.10995, 2020.
Xu, X., Lu, P., MacEachern, S., and Xu, R. Calibrated Bayes factors for model comparison. Journal of Statistical Computation and Simulation, 89(4):591–614, 2019.
Yang, W., Lorch, L., Graule, M. A., Srinivasan, S., Suresh, A., Yao, J., Pradier, M. F., and Doshi-Velez, F. Output-constrained Bayesian neural networks. arXiv preprint arXiv:1905.06287, 2019.
Zhang, C., Bengio, S., Hardt, M., Recht, B., and Vinyals, O. Understanding deep learning requires rethinking generalization. arXiv preprint arXiv:1611.03530, 2016.
Zhang, T. Information-theoretic upper and lower bounds for statistical estimation. IEEE Transactions on Information Theory, 52(4):1307–1321, 2006.


Appendix Outline

This appendix is organized as follows. In Section A we visualize predictive functions corresponding to weight samples within high-posterior-density valleys on a regression problem. In Section B we provide background material on Gaussian processes. In Section C we present further results comparing MultiSWAG and MultiSWA to Deep Ensembles under data distribution shift on CIFAR-10. In Section D we provide the details of all experiments presented in the paper. In Section E we present analytic results on the dependence of the prior distribution in function space on the variance of the prior over parameters. In Section F we study the prior correlations between BNN logits on perturbed images.

A. Loss Valleys

We demonstrate that different points along the valleys of high posterior density (low loss) connecting pairs of independently trained optima (Garipov et al., 2018; Draxler et al., 2018; Fort & Jastrzebski, 2019) correspond to different predictive functions. We use the regression example from Izmailov et al. (2019) and show the results in Figure 11.

B. Gaussian processes

With a Bayesian neural network, a distribution over parameters p(w) induces a distribution over functions p(f(x, w)) when combined with the functional form of the network. Gaussian processes (GPs) are often used to instead directly specify a distribution over functions.

A Gaussian process is a distribution over functions f(x) ∼ GP(m, k) such that any collection of function values, queried at any finite set of inputs x_1, …, x_n, has a joint Gaussian distribution,

f(x_1), \ldots, f(x_n) \sim \mathcal{N}(\mu, K).   (6)

The mean vector, μ_i = E[f(x_i)] = m(x_i), and covariance matrix, K_{ij} = cov(f(x_i), f(x_j)) = k(x_i, x_j), are determined by the mean function m and covariance function (or kernel) k of the Gaussian process.

The popular RBF kernel has the form

k(x_i, x_j) = \exp\left( -\frac{1}{2\ell^2} \|x_i - x_j\|^2 \right).   (7)

The length-scale hyperparameter ℓ controls the extent of correlations between function values. If ℓ is large, sample functions from a GP prior are simple and vary slowly with the inputs x.
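As a brief illustration (our sketch, not part of the paper's experiments), the following numpy code draws sample functions from a GP prior with the RBF kernel of Equation (7) for two lengthscales; larger ℓ yields smoother, more slowly varying samples.

```python
# Sample functions from a GP prior with an RBF kernel (1-D inputs).
import numpy as np

def rbf_kernel(X1, X2, ell=1.0):
    # k(x, x') = exp(-||x - x'||^2 / (2 ell^2))
    sq_dists = (X1[:, None] - X2[None, :]) ** 2
    return np.exp(-0.5 * sq_dists / ell**2)

rng = np.random.default_rng(0)
x = np.linspace(-5, 5, 200)
for ell in [0.5, 2.0]:                                       # two illustrative lengthscales
    K = rbf_kernel(x, x, ell=ell) + 1e-8 * np.eye(len(x))    # jitter for numerical stability
    samples = rng.multivariate_normal(np.zeros(len(x)), K, size=3)
    # Larger ell gives smaller average increments, i.e. slower-varying samples.
    print(f"ell={ell}: mean |increment| =", np.abs(np.diff(samples, axis=1)).mean())
```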

Gaussian processes with RBF kernels (as well as many other standard kernels) assign positive density to any set of observations. Moreover, these models are universal approximators (Rasmussen & Williams, 2006): as the number of observations increases, they are able to approximate any function to arbitrary precision.

Work on Gaussian processes in machine learning was triggered by the observation that Bayesian neural networks become Gaussian processes with particular kernel functions as the number of hidden units approaches infinity (Neal, 1996). This result resembles recent work on the neural tangent kernel (e.g., Jacot et al., 2018).

C. Deep Ensembles and MultiSWAG Under Distribution Shift

In Figures 14, 15, 16, and 17 we show the negative log-likelihood for Deep Ensembles, MultiSWA, and MultiSWAG using PreResNet-20 on CIFAR-10 with various corruptions, as a function of the number of independently trained models (SGD solutions, SWA solutions, or SWAG models, respectively). For MultiSWAG, we generate 20 samples from each independent SWAG model. Typically, MultiSWA and MultiSWAG significantly outperform Deep Ensembles when a small number of independent models is used, or when the level of corruption is high.

In Figure 18, following Ovadia et al. (2019), we show the distribution of negative log likelihood, accuracy, and expected calibration error as we vary the type of corruption. We use a fixed training time budget: 10 independently trained models for every method. For MultiSWAG, we ensemble 20 samples from each of the 10 SWAG approximations. MultiSWAG in particular achieves better NLL than the other two methods, and MultiSWA outperforms Deep Ensembles; the difference is especially pronounced for higher levels of corruption. In terms of ECE, MultiSWAG again outperforms the other two methods for higher corruption intensities.

We note that Ovadia et al. (2019) found Deep Ensembles to be a very strong baseline for prediction quality and calibration under distribution shift. For this reason, we focus on Deep Ensembles in our comparisons.

D. Details of Experiments

In this section we provide additional details of the experiments presented in the paper.

D.1. Approximating the True Predictive Distribution

For the results presented in Figure 4, we used a network with 3 hidden layers of size 10 each. The network takes two inputs: x and x². We pass both x and x² as input to ensure that the network can represent a broader class of functions. The network outputs a single number, y = f(x).

[Figure 11 plots omitted: four pairs of panels showing posterior log-density contours in a two-dimensional parameter subspace (bottom row) and the corresponding predictive distributions with a highlighted sample function (top row).]

Figure 11. Diversity of high performing functions. Bottom: a contour plot of the posterior log-density in the subspace containing a pair of independently trained modes (as with deep ensembles), and a path of high posterior density connecting these modes; in each panel, the purple point represents a sample from the posterior in the parameter subspace. Top: the predictive distribution constructed from samples in the subspace; the shaded blue area shows the 3σ-region of the predictive distribution at each of the input locations, and the blue line shows the mean of the predictive distribution. In each panel, the purple line shows the predictive function corresponding to the sample shown in the corresponding bottom-row panel. For the details of the experimental setup, see Section 5.1 of Izmailov et al. (2019).

To generate data for the plots, we used a randomly-initialized neural network of the same architecture described above. We sampled the weights from an isotropic Gaussian with variance 0.1² and added isotropic Gaussian noise with variance 0.1² to the outputs:

y = f(x; w) + \epsilon(x),

with w ∼ N(0, 0.1² · I) and ε(x) ∼ N(0, 0.1² · I). The training set consists of 120 points, shown in Figure 4.
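A hedged PyTorch sketch of this data-generating process is given below; the layer sizes follow the description above, while the activation function, input range, and exact sampling code are our own illustrative choices rather than the paper's implementation.

```python
# Illustrative synthetic-data generation: y = f(x; w) + eps,
# with w ~ N(0, 0.1^2 I) and eps ~ N(0, 0.1^2).
import torch
import torch.nn as nn

torch.manual_seed(0)

def make_net():
    # 3 hidden layers of size 10; inputs are (x, x^2); scalar output.
    # The ReLU activation here is an assumption, not specified in the text.
    return nn.Sequential(
        nn.Linear(2, 10), nn.ReLU(),
        nn.Linear(10, 10), nn.ReLU(),
        nn.Linear(10, 10), nn.ReLU(),
        nn.Linear(10, 1),
    )

true_net = make_net()
with torch.no_grad():
    for p in true_net.parameters():          # sample all parameters from N(0, 0.1^2)
        p.copy_(0.1 * torch.randn_like(p))

    x = torch.linspace(-2, 2, 120).unsqueeze(-1)   # 120 inputs (range is illustrative)
    inputs = torch.cat([x, x**2], dim=-1)
    f = true_net(inputs)
    y = f + 0.1 * torch.randn_like(f)              # add observation noise
```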

For estimating the ground truth, we ran 200 chains of Hamiltonian Monte Carlo (HMC) using the hamiltorch package (Cobb et al., 2019). We initialized each chain with a network pre-trained with SGD for 3000 steps, then ran HMC for 2000 steps, producing 200 samples.

For Deep Ensembles, we independently trained 50 networks with SGD for 20000 steps each, using the negative posterior log-density as the training loss. For SVI, we used a fully-factorized Gaussian approximation initialized at an SGD solution trained for 20000 steps. For all inference methods, we set the prior variance to 10² and the noise variance to 0.02².

Discrepancy with true BMA. For the results presented in panel (d) of Figure 4, we computed the Wasserstein distance between the predictive distribution approximated with HMC and the predictive distributions for Deep Ensembles and SVI. We used the one-dimensional Wasserstein distance function¹ from the scipy package (Virtanen et al., 2020). We computed the Wasserstein distance between the marginal distributions at each input location and averaged the results over the input locations. In the top sub-panels of panels (b), (c) of Figure 4, we additionally visualize the marginal Wasserstein distance between the HMC predictive distribution and the Deep Ensembles and SVI predictive distributions, respectively, for each input location.

¹ https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wasserstein_distance.html
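A minimal sketch of this evaluation (ours, not the paper's code) using scipy's one-dimensional Wasserstein distance is shown below; `hmc_preds` and `ens_preds` stand for arrays of predictive samples with shape (num_samples, num_input_locations).

```python
# Average the 1-D Wasserstein distance between two predictive distributions
# over the input locations.
import numpy as np
from scipy.stats import wasserstein_distance

def mean_marginal_wasserstein(hmc_preds, ens_preds):
    distances = [
        wasserstein_distance(hmc_preds[:, j], ens_preds[:, j])
        for j in range(hmc_preds.shape[1])
    ]
    return np.mean(distances), np.array(distances)  # overall average and per-location values

# Illustrative usage with random stand-ins for the two sets of predictive samples.
rng = np.random.default_rng(0)
hmc_preds = rng.normal(0.0, 1.0, size=(200, 50))
ens_preds = rng.normal(0.1, 0.8, size=(200, 50))
avg, per_location = mean_marginal_wasserstein(hmc_preds, ens_preds)
print(avg)
```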

D.2. Deep Ensembles and MultiSWAG

We evaluate Deep Ensembles, MultiSWA, and MultiSWAG under distribution shift in Section 4. Following Ovadia et al. (2019), we use a PreResNet-20 network and the CIFAR-10 dataset with the different types of corruptions introduced in Hendrycks & Dietterich (2019). For training the individual SGD, SWA, and SWAG models, we use the hyper-parameters used for PreResNet-164 in Maddox et al. (2019). For each SWAG model we sample 20 networks and ensemble them, so Deep Ensembles, MultiSWA, and MultiSWAG are all evaluated under the same training budget. Deep Ensembles and MultiSWA also use the same test-time budget.
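The following sketch shows how MultiSWAG predictions are assembled under this budget; the `swag.sample()` interface is a stand-in for the actual SWAG implementation, used here only to illustrate the averaging.

```python
# Hedged sketch of MultiSWAG prediction: average softmax outputs over 20
# sampled networks from each of several independently trained SWAG solutions.
import torch

@torch.no_grad()
def multiswag_predict(swag_models, x, samples_per_model=20):
    probs = []
    for swag in swag_models:                    # independently trained SWAG solutions
        for _ in range(samples_per_model):      # 20 sampled networks per solution
            net = swag.sample()                 # assumed: returns a network with sampled weights
            probs.append(torch.softmax(net(x), dim=-1))
    return torch.stack(probs).mean(dim=0)       # model average over all sampled networks
```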

For producing the corrupted data, we used the code² released by Hendrycks & Dietterich (2019). We had issues producing the data for the frost corruption type, so we omit it in our evaluation, and we include Gaussian blur, which was not included in the evaluation of Hendrycks & Dietterich (2019).

² https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_cifar_c.py


[Figure 12 plots omitted: (a)–(c) MNIST class-correlation diagrams for prior scales α = 0.02, 0.1, 1; (d) test NLL vs. prior std α; (e)–(g) sample functions along an image-space direction for α = 0.02, 0.1, 1; (h) classification error vs. prior std α.]

Figure 12. (a)–(c): Average pairwise prior correlations for pairs of objects in classes 0, 1, 2, 4, 7 of MNIST, induced by LeNet-5 for p(f(x, w)) when p(w) = N(0, α²I). Images in the same class have higher prior correlations than images from different classes, suggesting that p(f(x, w)) has desirable inductive biases. The correlations slightly decrease with increases in α. Panels (e)–(g) show sample functions from LeNet-5 along the direction connecting a pair of MNIST images of 0 and 1 digits; the complexity of the samples increases with α. (d): NLL and (h): classification error of an ensemble of 20 SWAG samples on MNIST as a function of α, using LeNet-5. The NLL is high for overly small α and near-optimal for larger values, with an optimum near α = 0.3.

D.3. Neural Network Priors

In the main text we considered different properties of the prior distribution over functions induced by a spherical Gaussian distribution over the weights, with different variance scales.

Prior correlation diagrams. In panels (a)–(c) of Figure 12 we show pairwise correlations of the logits for different pairs of datapoints. To make these plots, we produce S = 100 samples of the weights w_i of a LeNet-5 from the prior distribution N(0, α²I), and compute the logits corresponding to class 0 for each datapoint and each weight sample. We then compute the correlations for each pair x, x' of datapoints as follows:

\mathrm{corr}_{\mathrm{logit}}(x, x') = \frac{\sum_{i=1}^{S} (f(x, w_i) - \bar{f}(x)) (f(x', w_i) - \bar{f}(x'))}{\sqrt{\sum_{i=1}^{S} (f(x, w_i) - \bar{f}(x))^2 \cdot \sum_{i=1}^{S} (f(x', w_i) - \bar{f}(x'))^2}},

where f(x, w) is the logit corresponding to class 0 of the network with weights w on the input x, and \bar{f}(x) is the mean value of the logit, \bar{f}(x) = \frac{1}{S} \sum_i f(x, w_i). For evaluation, we use 200 random datapoints per class for classes 0, 1, 2, 4, 7 (a total of 1000 datapoints). We use this set of classes to ensure that the structure is clearly visible in the figure. We combine the correlations into a diagram, additionally showing the average correlation for each pair of classes. We repeat the experiment for different values of α ∈ {0.02, 0.1, 1}. For a discussion of the results, see Section 5.2.
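This computation can be summarized with the following hedged sketch (`make_net` is a placeholder for the LeNet-5 constructor; the data tensor is assumed to be preprocessed appropriately). np.corrcoef over the S × (number of datapoints) logit matrix yields exactly the pairwise corr_logit values defined above.

```python
# Prior correlations of class-0 logits across datapoints, estimated from S
# prior weight samples.
import numpy as np
import torch

def prior_logit_correlation(make_net, data, alpha=0.1, S=100):
    # data: tensor of inputs with shape (num_points, ...)
    logits = []
    with torch.no_grad():
        for _ in range(S):
            net = make_net()
            for p in net.parameters():               # w ~ N(0, alpha^2 I)
                p.copy_(alpha * torch.randn_like(p))
            logits.append(net(data)[:, 0].numpy())   # class-0 logit for every datapoint
    logits = np.stack(logits)                        # shape (S, num_points)
    # Columns are datapoints, rows are weight samples; corrcoef gives corr_logit.
    return np.corrcoef(logits, rowvar=False)
```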

Sample functions. In panels (e)–(g) of Figure 12 we visualize the functions sampled from the LeNet-5 network along the direction connecting a pair of MNIST images. In particular, we take a pair of images x₀ and x₁ of digits 0 and 1 respectively, and construct the path x(t) = t · x₀ + (1 − t) · x₁. We then study the samples of the logits z(t) = f(x(t) · ‖x₀‖ / ‖x(t)‖, w) along the path; here we adjusted the norm of the images along the path to be constant, as the values of the logits are sensitive to the norm of the inputs. The complexity of the samples increases as we increase the variance of the prior distribution over the weights. This increased complexity of sample functions explains why we might expect the prior correlations for pairs of images to be lower when we increase the variance of the prior distribution over the weights.
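A minimal sketch of the norm-preserving interpolation path (ours; x0 and x1 stand for two flattened MNIST images) is given below.

```python
# Interpolation path between two images with the norm held fixed at ||x0||.
import torch

def norm_preserving_path(x0, x1, num_steps=50):
    ts = torch.linspace(0.0, 1.0, num_steps)
    points = []
    for t in ts:
        x_t = t * x0 + (1.0 - t) * x1
        x_t = x_t * x0.norm() / x_t.norm()   # rescale so every point has norm ||x0||
        points.append(x_t)
    return torch.stack(points)

# z(t) is then obtained by passing these points through a prior-sampled network.
```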

Performance dependence on prior variance. In panels (d), (h) of Figure 12 we show the test negative log-likelihood and accuracy of SWAG applied to LeNet-5 on MNIST. We train the model for 50 epochs, constructing the rank-20 SWAG approximation from the last 25 epochs. We use an initial learning rate of 0.05 and a SWAG learning rate of 0.01, with the learning rate schedule of Maddox et al. (2019). We use the posterior log-density as the objective and vary the prior variance α². In panels (f), (g) of Figure 9 we perform an analogous experiment using a PreResNet-20 and a VGG-16 on CIFAR-10, using the hyper-parameters reported in Maddox et al. (2019) (for PreResNet-20 we use the hyper-parameters used with PreResNet-164 in Maddox et al. (2019)). Both on MNIST and CIFAR-10, we observe that the performance is poor for overly small values of α, close to optimal for intermediate values, and still reasonable for larger values of α. For further discussion of the results, see Section 5.3.

Predictions from prior samples. Following Wenzel et al. (2020), we study the predictive distributions of prior samples using PreResNet-20 on CIFAR-10. In Figure 9 we show the sample predictive functions averaged over datapoints for different scales α of the prior distribution. We also show the predictive distribution for each α, which is the average of the sample predictive distributions over 200 samples of the weights. In Figure 10 we show how the predictive distribution changes as we vary the number of observed data, for prior scale α = √10. We see that the marginal predictive distribution for all considered values of α is reasonable: roughly uniform across classes when averaged across the dataset. For the latter experiment, we used stochastic gradient Langevin dynamics (SGLD) (Welling & Teh, 2011) with a cosine learning rate schedule. For each sample we restart SGLD, and we only use the sample obtained at the last iteration. We discuss the results in Section 8.4.

Prior correlations with corrupted images. In Section F and Figure 19, we study the decay of the prior correlations between logits on an original image and a perturbed image as we increase the intensity of the perturbations. For the BNN we use the PreResNet-20 architecture with the standard Gaussian prior N(0, I). For the linear model, the correlations are not affected by the prior variance α²:

\mathrm{cov}(w^T x, w^T y) = E(w^T x \cdot w^T y) = E(x^T w w^T y) = x^T E(w w^T) y = \alpha^2 x^T y,

and hence

\mathrm{corr}(w^T x, w^T y) = \frac{\mathrm{cov}(w^T x, w^T y)}{\sqrt{\mathrm{cov}(w^T x, w^T x) \cdot \mathrm{cov}(w^T y, w^T y)}} = \frac{x^T y}{\|x\| \cdot \|y\|}.

We use the N(0, I) prior for the weights of the linear model. Finally, we also evaluate the correlations associated with an RBF kernel (see Equation (7)). To set the lengthscale ℓ of the kernel, we evaluate the pairwise correlations for the PreResNet-20 and the RBF kernel on the 100 uncorrupted CIFAR-10 images that were used for the experiment, and ensure that the average correlations match. The resulting value of ℓ is 10000; the average correlation for the RBF kernel and PreResNet was ≈ 0.9, and for the linear model the average correlation was ≈ 0.82. For the perturbations, we used the same set of corruptions introduced in Hendrycks & Dietterich (2019) as in the experiments in Section 4, with the addition of a random translation: for a random translation of intensity i, we pad the image with 2·i zeros on each side and crop the image randomly to 32×32.
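A quick numerical check of the linear-model calculation above (ours, with arbitrary dimensionality and sample count) confirms that the empirical correlation between wᵀx and wᵀy matches xᵀy / (‖x‖‖y‖) and does not depend on α.

```python
# Empirical vs. analytic prior correlation for a linear model with w ~ N(0, alpha^2 I).
import numpy as np

rng = np.random.default_rng(0)
d, S, alpha = 256, 20_000, 3.0
x, y = rng.normal(size=d), rng.normal(size=d)

W = alpha * rng.normal(size=(S, d))        # S prior samples of w
fx, fy = W @ x, W @ y                      # w^T x and w^T y for each sample
empirical = np.corrcoef(fx, fy)[0, 1]
analytic = x @ y / (np.linalg.norm(x) * np.linalg.norm(y))
print(empirical, analytic)                 # the two values should be close
```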

D.4. Rethinking Generalization

In Section 6 we experiment with Bayesian neural networks and Gaussian processes on CIFAR-10 with noisy labels, inspired by the results in Zhang et al. (2016) that suggest we need to re-think generalization to understand deep learning.

Following Zhang et al. (2016), we train PreResNet-20 on CIFAR-10 with different fractions of random labels. To ensure that the networks fit the training data, we turn off weight decay and data augmentation, and use a lower initial learning rate of 0.01. Otherwise, we follow the hyper-parameters that were used with PreResNet-164 in Maddox et al. (2019). We use a diagonal Laplace approximation to compute an estimate of the marginal likelihood for each level of label corruption. Following Ritter et al. (2018), we use the diagonal of the Fisher information matrix rather than the Hessian.

We perform a similar experiment with a Gaussian process with an RBF kernel on the binary classification problem for two classes of CIFAR-10. We use variational inference to fit the model, and we use the variational evidence lower bound to approximate the marginal likelihood. We use variational inference to overcome the non-Gaussian likelihood and not for scalability reasons, i.e., we are not using inducing inputs. We use the GPyTorch package (Gardner et al., 2018) to train the models. We use an RBF kernel with the default initialization from GPyTorch and divide the inputs by 5000 to get an appropriate input scale. We train the model on a binary classification problem between classes 0 and 1.
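The sketch below shows one way such a model can be written in GPyTorch: a variational GP classifier with the variational posterior placed at the training inputs themselves, so that no inducing-point approximation is used. The specific classes and settings are an assumption about a reasonable setup, not a copy of the training script used for the experiments.

```python
import torch
import gpytorch

class BinaryGPClassifier(gpytorch.models.ApproximateGP):
    def __init__(self, train_x):
        # variational distribution over f at the training inputs (no inducing-point subset)
        var_dist = gpytorch.variational.CholeskyVariationalDistribution(train_x.size(0))
        var_strat = gpytorch.variational.UnwhitenedVariationalStrategy(
            self, train_x, var_dist, learn_inducing_locations=False)
        super().__init__(var_strat)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x))

# toy binary problem standing in for two CIFAR-10 classes (inputs already rescaled)
train_x = torch.randn(100, 32 * 32 * 3) / 50.0
train_y = torch.randint(0, 2, (100,)).float()
model, likelihood = BinaryGPClassifier(train_x), gpytorch.likelihoods.BernoulliLikelihood()
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.numel())
opt = torch.optim.Adam(list(model.parameters()) + list(likelihood.parameters()), lr=0.01)
for _ in range(100):
    opt.zero_grad()
    loss = -mll(model(train_x), train_y)   # negative ELBO; the ELBO also proxies the marginal likelihood
    loss.backward()
    opt.step()
```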

For the 10-class GP classification experiment we train 10 one-vs-all models that classify between a given class and the rest of the data. To reduce computation in training, we subsample the data not belonging to the given class to 10k datapoints, so each model is trained on a total of 15k datapoints. We then combine the 10 models into a single multi-class model: an observation is attributed to the class that corresponds to the one-vs-all model with the highest confidence. We use the same hyper-parameters as in the binary classification experiments.
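To make the combination step concrete, here is a small hedged sketch of our own: given a list of trained one-vs-rest GP classifiers (as in the sketch above), predictions are stacked and the class whose one-vs-rest model is most confident wins.

```python
import torch

def combine_one_vs_rest(models, likelihoods, x):
    """Stack per-class one-vs-rest probabilities and predict the most confident class."""
    probs = []
    for model, likelihood in zip(models, likelihoods):
        model.eval(); likelihood.eval()
        with torch.no_grad():
            # probability that x belongs to this model's positive class
            probs.append(likelihood(model(x)).mean)
    return torch.stack(probs, dim=-1).argmax(dim=-1)   # predicted class index per input
```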

D.5. Double Descent

In Section 7 we evaluate SGD, SWAG and MultiSWAG for models of varying width. Following Nakkiran et al. (2019), we use ResNet-18 on CIFAR-100; we consider the original labels, 10% and 20% label corruption. For networks of every width we reuse the hyper-parameters used for PreResNet-164 in Maddox et al. (2019).

[Figure 13 (plots omitted): Test Error (%) and Test NLL vs. ResNet-18 width for SGD, SWAG and MultiSWAG. Panels: (a) 10% Corrupted (Err), (b) 10% Corrupted (NLL), (c) 20% Corrupted (# Models), with MultiSWAG test NLL shown for 1, 3, 5 and 10 SWAG models.]

Figure 13. Double Descent. (a) Test error and (b) NLL loss for ResNet-18 with varying width on CIFAR-100 for SGD, SWAG and MultiSWAG when 10% of the labels are randomly reshuffled. MultiSWAG alleviates double descent both on the original labels and under label noise, both in accuracy and NLL. (c) Test NLLs for MultiSWAG with a varying number of independent models under 20% label corruption; NLL monotonically decreases with an increased number of independent models, alleviating double descent.

For the original labels and 10% label corruption we use 5 independently trained SWAG models with MultiSWAG, and for 20% label corruption we use 10 models; for 20% label corruption we also show performance varying the number of independent models in Figures 8(e) and 13(c). Both for SWAG and MultiSWAG we use an ensemble of 20 sampled models from each of the SWAG solutions; for example, for MultiSWAG with 10 independent SWAG solutions we use an ensemble of 200 networks.
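The ensembling step itself is just an average of predictive distributions over all sampled networks. The following sketch is our own illustration of the bookkeeping, with a hypothetical `sample_weights` helper standing in for drawing one network from a fitted SWAG solution and loading it into the model.

```python
import torch

@torch.no_grad()
def multiswag_predict(swag_solutions, model, x, num_samples=20):
    """Average softmax predictions over `num_samples` networks drawn from each SWAG solution."""
    probs = torch.zeros(x.size(0), 10)           # 10 classes for CIFAR-10/CIFAR-100 subsets
    total = 0
    for swag in swag_solutions:                  # e.g. 10 independently trained SWAG models
        for _ in range(num_samples):
            swag.sample_weights(model)           # hypothetical helper: load one sampled weight vector
            probs += torch.softmax(model(x), dim=-1)
            total += 1
    return probs / total                         # Monte Carlo estimate of the BMA predictive distribution
```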

E. Analysis of Prior Variance Effect

In this section we provide simple analytic results for the effect of the prior variance in ReLU networks. A related derivation is presented in the Appendix Section A.8 of Garipov et al. (2018), about connecting paths from symmetries in parametrization.

We consider a multilayer network f(x; w) of the form

f(x; {W_i, b_i}_{i=1}^n) = W_n (· · · φ(W_2 φ(W_1 x + b_1) + b_2) · · ·) + b_n,

where φ is the ReLU (or, in fact, any positively-homogeneous activation function), the W_i are weight matrices, and the b_i are bias vectors. In particular, f can be a regular CNN with ReLU activations up to the logits (with the softmax activation removed).

Now suppose we have a prior distribution of the form

W_i ∼ N(0, α_i² I),   b_i ∼ N(0, β_i² I),

where the identity matrices I are implicitly assumed to be of appropriate shapes, so that each weight matrix and bias vector has a spherical Gaussian distribution. We can reparameterize this distribution as

W_i = α_i E_i,   E_i ∼ N(0, I),
b_i = β_i ε_i,   ε_i ∼ N(0, I).

We can then express the predictions of the network on the input x for weights sampled from the prior as the random variable

f(x; {α_i, β_i}_{i=1}^n) = α_n · E_n (· · · φ(α_1 E_1 x + β_1 · ε_1) · · ·) + β_n · ε_n.   (8)

Through Equation (8) we can observe some simple properties of the dependence between the prior scales α_i, β_i and the induced function-space prior.

Proposition 1. Suppose the network has no bias vectors, i.e., β_1 = ... = β_n = 0. Then the scales α_i of the prior distribution over the weights only affect the output scale of the network.

Proof. In the case when there are no bias vectors, Equation (8) simplifies to

f(x; {α_i, β_i = 0}_{i=1}^n) = α_n · E_n (· · · φ(α_1 E_1 x) · · ·) = α_n · · · α_1 · E_n (· · · φ(E_1 x) · · ·) = α_n · · · α_1 · f(x; {α_i = 1, β_i = 0}_{i=1}^n).

In the derivation above we used the positive homogeneity of ReLU: φ(αz) = αφ(z) for any positive α.

In other words, to sample from the distribution over functions corresponding to a prior with variances {α_i, β_i = 0}_{i=1}^n, we can sample from the spherical Gaussian prior (without bias terms) {α_i = 1, β_i = 0}_{i=1}^n and then rescale the outputs of the network by the product of the scales α_n · · · α_2 · α_1.
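Proposition 1 is easy to verify numerically. The sketch below uses a toy bias-free ReLU MLP of our own (not one of the architectures in the experiments) and checks that sampling with per-layer scales α_i gives exactly the unit-scale output multiplied by ∏_i α_i.

```python
import torch

torch.manual_seed(0)
alphas = [0.5, 2.0, 3.0]                      # per-layer prior scales alpha_1..alpha_n
E = [torch.randn(64, 32), torch.randn(64, 64), torch.randn(10, 64)]  # unit-scale weights E_i
x = torch.randn(32)

def forward(weights):
    h = x
    for W in weights[:-1]:
        h = torch.relu(W @ h)                 # positively homogeneous activation
    return weights[-1] @ h                    # logits; no biases anywhere

out_scaled = forward([a * W for a, W in zip(alphas, E)])   # prior scales applied per layer
out_unit = forward(E)                                      # unit-scale network
prod = torch.prod(torch.tensor(alphas))
print(torch.allclose(out_scaled, prod * out_unit, atol=1e-5))  # True: only the output scale changes
```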

We note that the result above is different from the results for sigmoid networks considered in MacKay (1995), where varying the prior on the weights leads to changing the lengthscale of the sample functions. For ReLU networks without biases, increasing the prior variance only increases the output scale of the network and not the complexity of the samples. If we apply the softmax activation on the outputs of the last layer of such a network, we will observe increasingly confident predictions as we increase the prior variance. We observe this effect in Figure 9 and discuss it in Section 8.4.

In case bias vectors are present, we can obtain a similar result using a specific scaling of the prior variances with layer, as in the following proposition.

Proposition 2. Suppose the prior scales depend on the layer of the network as follows, for some γ > 0:

α_i = γ,   β_i = γ^i,

for all layers i = 1, ..., n. Then γ only affects the scale of the predictive distribution at any input x:

f(x; {α_i = γ, β_i = γ^i}_{i=1}^n) = γ^n · f(x; {α_i = 1, β_i = 1}_{i=1}^n).

Proof. The proof is analogous to the proof of Proposition 1. We can use the positive homogeneity of the ReLU activations to factor the prior scales outside of the network:

f(x; {α_i = γ, β_i = γ^i}_{i=1}^n) = γ · E_n (· · · φ(γ · E_1 x + γ · ε_1) · · ·) + γ^n · ε_n = γ^n · (E_n (· · · φ(E_1 x + ε_1) · · ·) + ε_n) = γ^n · f(x; {α_i = 1, β_i = 1}_{i=1}^n).

The analysis above can be applied to other simple scaling rules of the prior; e.g.,

f(x; {α_i = γ·α̂_i, β_i = γ^i·β̂_i}_{i=1}^n) = γ^n · f(x; {α_i = α̂_i, β_i = β̂_i}_{i=1}^n)   (9)

can be shown completely analogously to Proposition 2.

More general types of scaling of the prior affect both the output scale of the network and also the relative effect of the weight and bias terms. For example, by Equation (9) we have

f(x; {α_i = γ, β_i = γ}_{i=1}^n) = f(x; {α_i = γ · 1, β_i = γ^i · γ^{1−i}}_{i=1}^n) = γ^n · f(x; {α_i = 1, β_i = γ^{1−i}}_{i=1}^n).

We note that the analysis does not cover residual connections and batch normalization, so it applies to LeNet-5 but cannot be directly applied to the PreResNet-20 networks used in many of our experiments.

F. Prior Correlation Structure under Perturbations

In this section we explore the prior correlations between the logits on different pairs of datapoints induced by a spherical Gaussian prior on the weights of a PreResNet-20. We sample 100 random images from CIFAR-10 (10 from each class) and apply 17 different perturbations introduced by Hendrycks & Dietterich (2019) at 5 different levels of intensity. We then compute correlations between the logits f(x; w) for the original image x and f(x̃; w) for the corrupted image x̃, as we sample the weights of the network from the prior w ∼ N(0, I).

In Figure 19 we show how the correlations decay with perturbation intensity. For reference we also show how the correlations decay for a linear model and for an RBF kernel. For the RBF kernel we set the lengthscale so that the average correlations on the uncorrupted datapoints match those of a PreResNet-20. Further experimental details can be found in Appendix D.3.
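A compact sketch of this measurement is shown below. It is our own minimal version: a small stand-in network replaces PreResNet-20 and additive noise stands in for the Hendrycks & Dietterich (2019) corruptions, but the structure of the computation is the same.

```python
import torch
import torch.nn as nn

def prior_logit_correlation(make_model, x, x_corrupted, num_samples=100, class_idx=0):
    """Correlation, across prior draws w ~ N(0, I), between the class-`class_idx`
    logits of an image and its corrupted version."""
    f_orig, f_corr = [], []
    for _ in range(num_samples):
        model = make_model()
        for p in model.parameters():              # sample weights from the N(0, I) prior
            nn.init.normal_(p, mean=0.0, std=1.0)
        with torch.no_grad():
            f_orig.append(model(x)[0, class_idx].item())
            f_corr.append(model(x_corrupted)[0, class_idx].item())
    stacked = torch.tensor([f_orig, f_corr])
    return torch.corrcoef(stacked)[0, 1].item()

# toy usage: a small MLP stands in for PreResNet-20, noise stands in for a corruption
make_model = lambda: nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 128),
                                   nn.ReLU(), nn.Linear(128, 10))
x = torch.randn(1, 3, 32, 32)
x_corrupted = x + 0.3 * torch.randn_like(x)
print(prior_logit_correlation(make_model, x, x_corrupted))
```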

For all types of corruptions except saturate, snow, fog and brightness, the PreResNet logits decay slower compared to the RBF kernel and linear model. It appears that the prior samples are sensitive to corruptions that alter the brightness, or more generally the colours, in the image. For many types of corruptions (such as, e.g., Gaussian noise) the prior correlations for PreResNet are close to 1 for all levels of corruption.

Overall, these results indicate that the prior over functions induced by a vague prior over parameters w in combination with a PreResNet has useful equivariance properties: before seeing data, the model treats images of the same class as highly correlated, even after an image has undergone significant perturbations representative of perturbations we often see in the real world. These types of symmetries are a large part of what makes neural networks a powerful model class for high dimensional natural signals.

[Figure 14 (plots omitted): NLL vs. number of independently trained models for Deep Ensembles, MultiSWA and MultiSWAG. Panels: (a) Gaussian Noise, (b) Impulse Noise, (c) Shot Noise.]

Figure 14. Noise Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).

[Figure 15 (plots omitted): NLL vs. number of independently trained models for Deep Ensembles, MultiSWA and MultiSWAG. Panels: (a) Defocus Blur, (b) Glass Blur, (c) Motion Blur, (d) Zoom Blur, (e) Gaussian Blur.]

Figure 15. Blur Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).

[Figure 16 (plots omitted): NLL vs. number of independently trained models for Deep Ensembles, MultiSWA and MultiSWAG. Panels: (a) Contrast, (b) Saturate, (c) Elastic Transform, (d) Pixelate, (e) JPEG Compression.]

Figure 16. Digital Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).

[Figure 17 (plots omitted): NLL vs. number of independently trained models for Deep Ensembles, MultiSWA and MultiSWAG. Panels: (a) Snow, (b) Fog, (c) Brightness.]

Figure 17. Weather Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).

[Figure 18 (plots omitted): boxplots of NLL, Accuracy and ECE vs. corruption intensity (1-5) for Deep Ensembles, MultiSWA and MultiSWAG.]

Figure 18. Negative log likelihood, accuracy and expected calibration error distribution on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the corruption intensity. Following Ovadia et al. (2019), we summarize the results for different types of corruption with a boxplot. For each method we use 10 independently trained models, and for MultiSWAG we sample 20 networks from each model. As in Figures 5, 11-14, there are substantial differences between these three methods, which are hard to see due to the vertical scale on this plot; MultiSWAG particularly outperforms Deep Ensembles and MultiSWA in terms of NLL and ECE for higher corruption intensities.

[Figure 19 (plots omitted): correlation vs. corruption intensity (0-5) for PreResNet-20, RBF and Linear models. Panels: Translation, Gaussian Noise, Impulse Noise, Shot Noise, Defocus Blur, Glass Blur, Motion Blur, Zoom Blur, Gaussian Blur, Contrast, Saturate, Elastic, Pixelate, JPEG, Snow, Fog, Brightness.]

Figure 19. Prior correlations under corruption. Prior correlations between predictions (logits) for PreResNet-20, a linear model and an RBF kernel on original and corrupted images, as a function of corruption intensity, for different types of corruptions. The lengthscale of the RBF kernel is calibrated to produce similar correlations to PreResNet on uncorrupted datapoints. We report the mean correlation values over 100 different images and show the 1σ error bars with shaded regions. For all corruptions except Snow, Saturate, Fog and Brightness the correlations decay slower for PreResNet compared to the baselines.


We should be wary not to view Bayesian model averaging purely through the prism of simple Monte Carlo, as advised in Section 3.2. Given a finite computational budget, our goal in effectively approximating a Bayesian model average is not equivalent to obtaining good samples from the posterior.

9. Discussion

"It is now common practice for Bayesians to fit models that have more parameters than the number of data points. Incorporate every imaginable possibility into the model space: for example, if it is conceivable that a very simple model might be able to explain the data, one should include simple models; if the noise might have a long-tailed distribution, one should include a hyperparameter which controls the heaviness of the tails of the distribution; if an input variable might be irrelevant to a regression, include it in the regression anyway." MacKay (1995)

We have presented a probabilistic perspective of generalization, which depends on the support and inductive biases of the model. The support should be as large as possible, but the inductive biases must be well-calibrated to a given problem class. We argue that Bayesian neural networks embody these properties, and, through the lens of probabilistic inference, explain generalization behaviour that has previously been viewed as mysterious. Moreover, we argue that Bayesian marginalization is particularly compelling for neural networks, show how deep ensembles provide a practical mechanism for marginalization, and propose a new approach that generalizes deep ensembles to marginalize within basins of attraction. We show that this multimodal approach to Bayesian model averaging, MultiSWAG, can entirely alleviate double descent, to enable monotonic performance improvements with increases in model flexibility, as well as significant improvements in generalization accuracy and log likelihood over SGD and single-basin marginalization.

There are certainly many challenges to estimating the integral for a Bayesian model average in modern deep learning, including a high-dimensional parameter space and a complex posterior landscape. But viewing the challenge indeed as an integration problem, rather than an attempt to obtain posterior samples for a simple Monte Carlo approximation, provides opportunities for future progress. Bayesian deep learning has been making fast practical advances, with approaches that now enable better accuracy and calibration over standard training with minimal overhead.

We finish with remarks about future developments for Bayesian neural network priors, and approaches to research in Bayesian deep learning.

9.1. The Future for BNN Priors

We provide some brief remarks about future developments for BNN priors. Here we have explored relatively simple parameter priors p(w) = N(0, α²I). While these priors are simple in parameter space, they interact with the neural network architecture to induce a sophisticated prior over functions p(f(x; w)), with many desirable properties, including a reasonable correlation structure over images. However, these parameter priors can certainly still be improved. As we have seen, even tuning the value of the signal variance α², an analogue of the L2 regularization often used in deep learning, can have a noticeable effect on the induced prior over functions, though this effect is quickly modulated by data. Layer-wise priors, such that parameters in each layer have a different signal variance, are intuitive: we would expect later layers to require precise determination, while parameters in earlier layers could reasonably take a range of values. But one has to be cautious; as we show in Appendix Section E, with ReLU activations different signal variances in different layers can be degenerate, combining together to affect only the output scale of the network.

A currently popular sentiment is that we should directly build function-space BNN priors, often taking inspiration from Gaussian processes. While we believe this is a promising direction, one should proceed with caution. If we contrive priors over parameters p(w) to induce distributions over functions p(f) that resemble familiar models, such as Gaussian processes with RBF kernels, we could be throwing the baby out with the bathwater. Neural networks are useful as their own model class precisely because they have different inductive biases from other models.

A similar concern applies to taking infinite width limits in Bayesian neural networks. In these cases we recover Gaussian processes with interpretable kernel functions; because these models are easier to use and analyze, and give rise to interpretable and well-motivated priors, it is tempting to treat them as drop-in replacements for the parametric analogues. However, the kernels for these models are fixed. In order for a model to do effective representation learning, we must learn a similarity metric for the data. Training a neural network in many ways is like learning a kernel, rather than using a fixed kernel. MacKay (1998) has also expressed concerns about treating these limits as replacements for neural networks, due to the loss of representation learning power.

Perhaps the distribution over functions induced by a network in combination with a generic distribution over parameters p(w) may be hard to interpret, but this distribution will contain the equivariance properties, representation learning abilities, and other biases that make neural networks a compelling model class in their own right.


9.2. "But is it really Bayesian?"

We finish with an editorial comment about approaches to research within Bayesian deep learning. There is sometimes a tendency to classify work as Bayesian or not Bayesian, with very stringent criteria for what qualifies as Bayesian. Moreover, the implication, and sometimes even explicit recommendation, is that if an approach is not unequivocally Bayesian in every respect, then we should not term it Bayesian, and we should instead attempt to understand the procedure through entirely different non-Bayesian mechanisms. We believe this mentality encourages tribalism, which is not conducive to the best research or creating the best performing methods. What matters is not a debate about semantics, but making rational modelling choices given a particular problem setting, and trying to understand these choices. Often these choices can largely be inspired by a Bayesian approach, in which case it is desirable to indicate this source of inspiration. And in the semantics debate, who would be the arbiter of what gets to be called Bayesian? Arguably it ought to be an evolving definition.

Broadly speaking, what makes Bayesian approaches distinctive is a posterior weighted marginalization over parameters. And at a high level, Bayesian methods are about combining our honest beliefs with data to form a posterior. In actuality, no fair-minded researcher entirely believes the prior over parameters, the functional form of the model (which is part of the prior over functions), or the likelihood. From this perspective, it is broadly compatible with a Bayesian philosophy to reflect misspecification in the modelling procedure itself, which is achieved through tempering. In this sense, the tempered posterior is more reflective of a true posterior than the posterior that results from ignoring our belief that the model is misspecified.

Moreover, basic probability theory indicates that marginalization is desirable. While marginalization cannot in practice be achieved exactly, we can try to improve over conventional training, which, as we have discussed, can be viewed as approximate marginalization. Given computational constraints, effective marginalization is not equivalent to obtaining accurate samples from a posterior. As we have discussed, simple Monte Carlo is only one of many mechanisms for marginalization. Just as expectation propagation (Minka, 2001a) focuses its approximation on the factors in a posterior where it will most affect the end result, we should focus on representing the posterior where it will make the biggest difference to the model average. As we have shown, deep ensembles are a reasonable mechanism up to a point. After having trained many independent models, there are added benefits to marginalizing within basins, given the computational expense associated with retraining an additional model to find an additional basin of attraction.

We should also not hold Bayesian methods to a double standard. Indeed, it can be hard to interpret or understand the prior, the posterior, and whether the marginalization procedure is optimal. But it is also hard to interpret the choices behind the functional form of the model, or the rationale behind classical procedures, where we bet everything on a single global optimum, when we know there are many global optima, many of them will perform well but provide different solutions, and many others will not perform well. We should apply the same level of scrutiny to all modelling choices, consider the alternatives, and not be paralyzed if a procedure is not optimal in every respect.

Acknowledgements. AGW and PI are supported by an Amazon Research Award, Facebook Research, NSF I-DISRE 193471, NIH R01 DA048764-01A1, NSF IIS-1563887, and NSF IIS-1910266. We thank Greg Benton for helpful discussions.

References

Achille, A. and Soatto, S. Emergence of invariance and disentanglement in deep representations. The Journal of Machine Learning Research, 19(1):1947–1980, 2018.

Ashukha, A., Lyzhov, A., Molchanov, D., and Vetrov, D. Pitfalls of in-domain uncertainty estimation and ensembling in deep learning. arXiv preprint arXiv:2002.06470, 2020.

Barron, A. R. and Cover, T. M. Minimum complexity density estimation. IEEE Transactions on Information Theory, 37(4):1034–1054, 1991.

Beal, M. J. Variational algorithms for approximate Bayesian inference. University of London, 2003.

Belkin, M., Hsu, D., Ma, S., and Mandal, S. Reconciling modern machine-learning practice and the classical bias–variance trade-off. Proceedings of the National Academy of Sciences, 116(32):15849–15854, 2019.

Bishop, C. M. Pattern Recognition and Machine Learning. Springer, 2006.

Bissiri, P. G., Holmes, C. C., and Walker, S. G. A general framework for updating belief distributions. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 78(5):1103–1130, 2016.

Blundell, C., Cornebise, J., Kavukcuoglu, K., and Wierstra, D. Weight uncertainty in neural networks. arXiv preprint arXiv:1505.05424, 2015.

Box, G. E. and Tiao, G. C. Bayesian inference in statistical analysis. Addison-Wesley, Reading, MA, 1973.

Chaudhari, P., Choromanska, A., Soatto, S., LeCun, Y., Baldassi, C., Borgs, C., Chayes, J., Sagun, L., and Zecchina, R. Entropy-SGD: Biasing gradient descent into wide valleys. Journal of Statistical Mechanics: Theory and Experiment, 2019.

Cobb, A. D., Baydin, A. G., Markham, A., and Roberts, S. J. Introducing an explicit symplectic integration scheme for Riemannian manifold Hamiltonian Monte Carlo. arXiv preprint arXiv:1910.06243, 2019.

Darnieder, W. F. Bayesian methods for data-dependent priors. PhD thesis, The Ohio State University, 2011.

de Heide, R., Kirichenko, A., Mehta, N., and Grunwald, P. Safe-Bayesian generalized linear regression. arXiv preprint arXiv:1910.09227, 2019.

Dinh, L., Pascanu, R., Bengio, S., and Bengio, Y. Sharp minima can generalize for deep nets. arXiv preprint arXiv:1703.04933, 2017.

Draxler, F., Veschgini, K., Salmhofer, M., and Hamprecht, F. A. Essentially no barriers in neural network energy landscape. arXiv preprint arXiv:1803.00885, 2018.

Dziugaite, G. K. and Roy, D. M. Computing nonvacuous generalization bounds for deep (stochastic) neural networks with many more parameters than training data. arXiv preprint arXiv:1703.11008, 2017.

Fort, S. and Jastrzebski, S. Large scale structure of neural network loss landscapes. In Advances in Neural Information Processing Systems, pp. 6706–6714, 2019.

Fort, S., Hu, H., and Lakshminarayanan, B. Deep ensembles: A loss landscape perspective. arXiv preprint arXiv:1912.02757, 2019.

Gal, Y. and Ghahramani, Z. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. In International Conference on Machine Learning, pp. 1050–1059, 2016.

Gardner, J. R., Pleiss, G., Bindel, D., Weinberger, K. Q., and Wilson, A. G. GPyTorch: Blackbox matrix-matrix Gaussian process inference with GPU acceleration. In Neural Information Processing Systems, 2018.

Garipov, T., Izmailov, P., Podoprikhin, D., Vetrov, D. P., and Wilson, A. G. Loss surfaces, mode connectivity, and fast ensembling of DNNs. In Neural Information Processing Systems, 2018.

Grunwald, P. The safe Bayesian. In International Conference on Algorithmic Learning Theory, pp. 169–183. Springer, 2012.

Grunwald, P., Van Ommen, T., et al. Inconsistency of Bayesian inference for misspecified linear models, and a proposal for repairing it. Bayesian Analysis, 12(4):1069–1103, 2017.

Guedj, B. A primer on PAC-Bayesian learning. arXiv preprint arXiv:1901.05353, 2019.

Guo, C., Pleiss, G., Sun, Y., and Weinberger, K. Q. On calibration of modern neural networks. In Proceedings of the 34th International Conference on Machine Learning - Volume 70, pp. 1321–1330. JMLR.org, 2017.

Gustafsson, F. K., Danelljan, M., and Schon, T. B. Evaluating scalable Bayesian deep learning methods for robust computer vision. arXiv preprint arXiv:1906.01620, 2019.

Hafner, D., Tran, D., Irpan, A., Lillicrap, T., and Davidson, J. Reliable uncertainty estimates in deep neural networks using noise contrastive priors. arXiv preprint arXiv:1807.09289, 2018.

He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778, 2016.

Hendrycks, D. and Dietterich, T. Benchmarking neural network robustness to common corruptions and perturbations. arXiv preprint arXiv:1903.12261, 2019.

Hinton, G. E. and Van Camp, D. Keeping the neural networks simple by minimizing the description length of the weights. In Proceedings of the Sixth Annual Conference on Computational Learning Theory, pp. 5–13, 1993.

Hochreiter, S. and Schmidhuber, J. Flat minima. Neural Computation, 9(1):1–42, 1997.

Huang, W. R., Emam, Z., Goldblum, M., Fowl, L., Terry, J. K., Huang, F., and Goldstein, T. Understanding generalization through visualizations. arXiv preprint arXiv:1906.03291, 2019.

Izmailov, P., Podoprikhin, D., Garipov, T., Vetrov, D., and Wilson, A. G. Averaging weights leads to wider optima and better generalization. In Uncertainty in Artificial Intelligence (UAI), 2018.

Izmailov, P., Maddox, W. J., Kirichenko, P., Garipov, T., Vetrov, D., and Wilson, A. G. Subspace inference for Bayesian deep learning. Uncertainty in Artificial Intelligence, 2019.

Jacot, A., Gabriel, F., and Hongler, C. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in Neural Information Processing Systems, pp. 8571–8580, 2018.

Jiang, Y., Neyshabur, B., Mobahi, H., Krishnan, D., and Bengio, S. Fantastic generalization measures and where to find them. arXiv preprint arXiv:1912.02178, 2019.

Kass, R. E. and Raftery, A. E. Bayes factors. Journal of the American Statistical Association, 90(430):773–795, 1995.

Kendall, A. and Gal, Y. What uncertainties do we need in Bayesian deep learning for computer vision? In Advances in Neural Information Processing Systems, pp. 5574–5584, 2017.

Keskar, N. S., Mudigere, D., Nocedal, J., Smelyanskiy, M., and Tang, P. T. P. On large-batch training for deep learning: Generalization gap and sharp minima. arXiv preprint arXiv:1609.04836, 2016.

Khan, M. E., Nielsen, D., Tangkaratt, V., Lin, W., Gal, Y., and Srivastava, A. Fast and scalable Bayesian deep learning by weight-perturbation in Adam. arXiv preprint arXiv:1806.04854, 2018.

Krizhevsky, A., Nair, V., and Hinton, G. The CIFAR-10 dataset, 2014. http://www.cs.toronto.edu/~kriz/cifar.html.

Lakshminarayanan, B., Pritzel, A., and Blundell, C. Simple and scalable predictive uncertainty estimation using deep ensembles. In Advances in Neural Information Processing Systems, pp. 6402–6413, 2017.

Langford, J. and Caruana, R. (Not) bounding the true error. In Advances in Neural Information Processing Systems, pp. 809–816, 2002.

LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.

Louizos, C., Shi, X., Schutte, K., and Welling, M. The functional neural process. In Advances in Neural Information Processing Systems, 2019.

MacKay, D. J. Bayesian methods for adaptive models. PhD thesis, California Institute of Technology, 1992.

MacKay, D. J. Probable networks and plausible predictions - a review of practical Bayesian methods for supervised neural networks. Network: Computation in Neural Systems, 6(3):469–505, 1995.

MacKay, D. J. Introduction to Gaussian processes. In Bishop, C. M. (ed.), Neural Networks and Machine Learning, chapter 11, pp. 133–165. Springer-Verlag, 1998.

MacKay, D. J. Information Theory, Inference and Learning Algorithms. Cambridge University Press, 2003.

Maddox, W. J., Izmailov, P., Garipov, T., Vetrov, D. P., and Wilson, A. G. A simple baseline for Bayesian uncertainty in deep learning. In Advances in Neural Information Processing Systems, 2019.

Masegosa, A. R. Learning under model misspecification: Applications to variational and ensemble methods, 2019.

McAllester, D. A. PAC-Bayesian model averaging. In Proceedings of the Twelfth Annual Conference on Computational Learning Theory, pp. 164–170, 1999.

Minka, T. Expectation propagation for approximate Bayesian inference. In Uncertainty in Artificial Intelligence, volume 17, pp. 362–369, 2001a.

Minka, T. P. Bayesian model averaging is not model combination, 2000.

Minka, T. P. Automatic choice of dimensionality for PCA. In Advances in Neural Information Processing Systems, pp. 598–604, 2001b.

Mohri, M. and Rostamizadeh, A. Rademacher complexity bounds for non-i.i.d. processes. In Advances in Neural Information Processing Systems, pp. 1097–1104, 2009.

Nakkiran, P., Kaplun, G., Bansal, Y., Yang, T., Barak, B., and Sutskever, I. Deep double descent: Where bigger models and more data hurt. arXiv preprint arXiv:1912.02292, 2019.

Nakkiran, P., Venkat, P., Kakade, S., and Ma, T. Optimal regularization can mitigate double descent. arXiv preprint arXiv:2003.01897, 2020.

Nalisnick, E. On priors for Bayesian neural networks. PhD thesis, University of California, Irvine, 2018.

Neal, R. Bayesian Learning for Neural Networks. Springer Verlag, 1996. ISBN 0387947248.

Neyshabur, B., Bhojanapalli, S., McAllester, D., and Srebro, N. Exploring generalization in deep learning. In Advances in Neural Information Processing Systems, pp. 5947–5956, 2017.

Neyshabur, B., Bhojanapalli, S., and Srebro, N. A PAC-Bayesian approach to spectrally-normalized margin bounds for neural networks. In International Conference on Learning Representations, 2018. URL https://openreview.net/forum?id=Skz_WfbCZ.

Opper, M., Kinzel, W., Kleinz, J., and Nehl, R. On the ability of the optimal perceptron to generalise. Journal of Physics A: Mathematical and General, 23(11):L581, 1990.

Ovadia, Y., Fertig, E., Ren, J., Nado, Z., Sculley, D., Nowozin, S., Dillon, J. V., Lakshminarayanan, B., and Snoek, J. Can you trust your model's uncertainty? Evaluating predictive uncertainty under dataset shift. arXiv preprint arXiv:1906.02530, 2019.

Pearce, T., Zaki, M., Brintrup, A., Anastassacos, N., and Neely, A. Uncertainty in neural networks: Bayesian ensembling. arXiv preprint arXiv:1810.05546, 2018.

Rasmussen, C. E. and Ghahramani, Z. Occam's razor. In Neural Information Processing Systems (NIPS), 2001.

Rasmussen, C. E. and Williams, C. K. I. Gaussian Processes for Machine Learning. The MIT Press, 2006.

Ritter, H., Botev, A., and Barber, D. A scalable Laplace approximation for neural networks. In International Conference on Learning Representations (ICLR), 2018.

Rue, H., Martino, S., and Chopin, N. Approximate Bayesian inference for latent Gaussian models by using integrated nested Laplace approximations. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 71(2):319–392, 2009.

Simonyan, K. and Zisserman, A. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.

Smith, S. L. and Le, Q. V. A Bayesian perspective on generalization and stochastic gradient descent. In International Conference on Learning Representations, 2018.

Sun, S., Zhang, G., Shi, J., and Grosse, R. Functional variational Bayesian neural networks. arXiv preprint arXiv:1903.05779, 2019.

Ulyanov, D., Vedaldi, A., and Lempitsky, V. Deep image prior. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 9446–9454, 2018.

Vapnik, V. N. Adaptive and learning systems for signal processing, communications, and control. Statistical Learning Theory, 1998.

Virtanen, P., Gommers, R., Oliphant, T. E., Haberland, M., Reddy, T., Cournapeau, D., Burovski, E., Peterson, P., Weckesser, W., Bright, J., et al. SciPy 1.0: fundamental algorithms for scientific computing in Python. Nature Methods, pp. 1–12, 2020.

Walker, S. and Hjort, N. L. On Bayesian consistency. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 63(4):811–821, 2001.

Welling, M. and Teh, Y. W. Bayesian learning via stochastic gradient Langevin dynamics. In Proceedings of the 28th International Conference on Machine Learning (ICML-11), pp. 681–688, 2011.

Wenzel, F., Roth, K., Veeling, B. S., Swiatkowski, J., Tran, L., Mandt, S., Snoek, J., Salimans, T., Jenatton, R., and Nowozin, S. How good is the Bayes posterior in deep neural networks really? arXiv preprint arXiv:2002.02405, 2020.

Wilson, A. G. The case for Bayesian deep learning. arXiv preprint arXiv:2001.10995, 2020.

Xu, X., Lu, P., MacEachern, S., and Xu, R. Calibrated Bayes factors for model comparison. Journal of Statistical Computation and Simulation, 89(4):591–614, 2019.

Yang, W., Lorch, L., Graule, M. A., Srinivasan, S., Suresh, A., Yao, J., Pradier, M. F., and Doshi-Velez, F. Output-constrained Bayesian neural networks. arXiv preprint arXiv:1905.06287, 2019.

Zhang, C., Bengio, S., Hardt, M., Recht, B., and Vinyals, O. Understanding deep learning requires rethinking generalization. arXiv preprint arXiv:1611.03530, 2016.

Zhang, T. Information-theoretic upper and lower bounds for statistical estimation. IEEE Transactions on Information Theory, 52(4):1307–1321, 2006.


Appendix Outline

This appendix is organized as follows. In Section A we visualize predictive functions corresponding to weight samples within high posterior density valleys on a regression problem. In Section B we provide background material on Gaussian processes. In Section C we present further results comparing MultiSWAG and MultiSWA to Deep Ensembles under data distribution shift on CIFAR-10. In Section D we provide the details of all experiments presented in the paper. In Section E we present analytic results on the dependence of the prior distribution in function space on the variance of the prior over parameters. In Section F we study the prior correlations between BNN logits on perturbed images.

A. Loss Valleys

We demonstrate that different points along the valleys of high posterior density (low loss) connecting pairs of independently trained optima (Garipov et al., 2018; Draxler et al., 2018; Fort & Jastrzebski, 2019) correspond to different predictive functions. We use the regression example from Izmailov et al. (2019) and show the results in Figure 11.

B. Gaussian processes

With a Bayesian neural network, a distribution over parameters p(w) induces a distribution over functions p(f(x; w)) when combined with the functional form of the network. Gaussian processes (GPs) are often used to instead directly specify a distribution over functions.

A Gaussian process is a distribution over functions f(x) ∼ GP(m, k) such that any collection of function values queried at any finite set of inputs x_1, ..., x_n has a joint Gaussian distribution,

f(x_1), ..., f(x_n) ∼ N(μ, K).   (6)

The mean vector μ_i = E[f(x_i)] = m(x_i) and covariance matrix K_ij = cov(f(x_i), f(x_j)) = k(x_i, x_j) are determined by the mean function m and covariance function (or kernel) k of the Gaussian process.

The popular RBF kernel has the form

k(x_i, x_j) = exp( −‖x_i − x_j‖² / (2ℓ²) ).   (7)

The length-scale hyperparameter ℓ controls the extent of correlations between function values. If ℓ is large, sample functions from a GP prior are simple and slowly varying with inputs x.
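A short sketch of drawing functions from a GP prior with this kernel (our own illustration, for one-dimensional inputs) makes the role of the lengthscale concrete: larger ℓ gives smoother, more slowly varying samples.

```python
import numpy as np

def rbf_kernel(x1, x2, lengthscale):
    # k(x, x') = exp(-||x - x'||^2 / (2 * lengthscale^2)), Equation (7), for scalar inputs
    sq_dists = (x1[:, None] - x2[None, :]) ** 2
    return np.exp(-0.5 * sq_dists / lengthscale**2)

rng = np.random.default_rng(0)
x = np.linspace(-3, 3, 200)
for lengthscale in [0.2, 2.0]:
    K = rbf_kernel(x, x, lengthscale) + 1e-8 * np.eye(len(x))        # jitter for numerical stability
    samples = rng.multivariate_normal(np.zeros(len(x)), K, size=3)    # three prior function draws
    print(lengthscale, np.abs(np.diff(samples, axis=1)).mean())       # larger lengthscale -> smoother samples
```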

Gaussian processes with RBF kernels (as well as many other standard kernels) assign positive density to any set of observations. Moreover, these models are universal approximators (Rasmussen & Williams, 2006): as the number of observations increases, they are able to approximate any function to arbitrary precision.

Work on Gaussian processes in machine learning was triggered by the observation that Bayesian neural networks become Gaussian processes with particular kernel functions as the number of hidden units approaches infinity (Neal, 1996). This result resembles recent work on the neural tangent kernel (e.g., Jacot et al., 2018).

C. Deep Ensembles and MultiSWAG Under Distribution Shift

In Figures 14, 15, 16, 17 we show the negative log-likelihood for Deep Ensembles, MultiSWA and MultiSWAG using PreResNet-20 on CIFAR-10 with various corruptions, as a function of the number of independently trained models (SGD solutions, SWA solutions, or SWAG models, respectively). For MultiSWAG we generate 20 samples from each independent SWAG model. Typically MultiSWA and MultiSWAG significantly outperform Deep Ensembles when a small number of independent models is used, or when the level of corruption is high.

In Figure 18, following Ovadia et al. (2019), we show the distribution of negative log likelihood, accuracy, and expected calibration error as we vary the type of corruption. We use a fixed training time budget: 10 independently trained models for every method. For MultiSWAG we ensemble 20 samples from each of the 10 SWAG approximations. MultiSWAG particularly achieves better NLL than the other two methods, and MultiSWA outperforms Deep Ensembles; the difference is especially pronounced for higher levels of corruption. In terms of ECE, MultiSWAG again outperforms the other two methods for higher corruption intensities.

We note that Ovadia et al. (2019) found Deep Ensembles to be a very strong baseline for prediction quality and calibration under distribution shift. For this reason we focus on Deep Ensembles in our comparisons.

D. Details of Experiments

In this section we provide additional details of the experiments presented in the paper.

D.1. Approximating the True Predictive Distribution

For the results presented in Figure 4 we used a network with 3 hidden layers of size 10 each. The network takes two inputs: x and x².

[Figure 11 (plots omitted): four columns; top row shows predictive distributions over the input range, bottom row shows posterior log-density contours in the two-dimensional parameter subspace.]

Figure 11. Diversity of high performing functions. Bottom: a contour plot of the posterior log-density in the subspace containing a pair of independently trained modes (as with deep ensembles) and a path of high posterior density connecting these modes. In each panel the purple point represents a sample from the posterior in the parameter subspace. Top: the predictive distribution constructed from samples in the subspace. The shaded blue area shows the 3σ-region of the predictive distribution at each of the input locations, and the blue line shows the mean of the predictive distribution. In each panel the purple line shows the predictive function corresponding to the sample shown in the corresponding bottom row panel. For the details of the experimental setup see Section 5.1 of Izmailov et al. (2019).

We pass both x and x² as input to ensure that the network can represent a broader class of functions. The network outputs a single number y = f(x).

To generate data for the plots we used a randomly-initialized neural network of the same architecture described above. We sampled the weights from an isotropic Gaussian with variance 0.1², and added isotropic Gaussian noise with variance 0.1² to the outputs:

y = f(x; w) + ε(x),

with w ∼ N(0, 0.1² · I), ε(x) ∼ N(0, 0.1² · I). The training set consists of 120 points, shown in Figure 4.

For estimating the ground truth we ran 200 chains of Hamiltonian Monte Carlo (HMC) using the hamiltorch package (Cobb et al., 2019). We initialized each chain with a network pre-trained with SGD for 3000 steps, then ran HMC for 2000 steps, producing 200 samples.

For Deep Ensembles we independently trained 50 networks with SGD for 20000 steps each. We used the negative posterior log-density as the training loss. For SVI we used a fully-factorized Gaussian approximation initialized at an SGD solution trained for 20000 steps. For all inference methods we set the prior variance to 10² and the noise variance to 0.02².

Discrepancy with true BMA. For the results presented in panel (d) of Figure 4 we computed the Wasserstein distance between the predictive distribution approximated with HMC and the predictive distributions for Deep Ensembles and SVI. We used the one-dimensional Wasserstein distance function¹ from the scipy package (Virtanen et al., 2020). We computed the Wasserstein distance between the marginal distributions at each input location and averaged the results over the input locations. In the top sub-panels of panels (b), (c) of Figure 4 we additionally visualize the marginal Wasserstein distance between the HMC predictive distribution and the Deep Ensembles and SVI predictive distributions, respectively, for each input location.

¹https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wasserstein_distance.html
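Concretely, this amounts to the following computation; the sketch is our own, using the scipy function cited above on arrays of predictive samples of shape [num_samples, num_input_locations], which are stand-ins for the HMC and approximate-inference predictive samples.

```python
import numpy as np
from scipy.stats import wasserstein_distance

def mean_marginal_wasserstein(pred_samples_a, pred_samples_b):
    """Average the 1-D Wasserstein distance between two sets of predictive samples
    (shape: [num_samples, num_input_locations]) over the input locations."""
    distances = [wasserstein_distance(pred_samples_a[:, i], pred_samples_b[:, i])
                 for i in range(pred_samples_a.shape[1])]
    return float(np.mean(distances)), np.array(distances)   # scalar summary + per-input curve

# toy usage with stand-in predictive samples (e.g. HMC vs. a deep ensemble)
hmc = np.random.randn(200, 50)
ensemble = 0.8 * np.random.randn(50, 50) + 0.1
avg, per_input = mean_marginal_wasserstein(hmc, ensemble)
print(avg)
```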

D.2. Deep Ensembles and MultiSWAG

We evaluate Deep Ensembles, MultiSWA and MultiSWAG under distribution shift in Section 4. Following Ovadia et al. (2019), we use a PreResNet-20 network and the CIFAR-10 dataset with different types of corruptions introduced in Hendrycks & Dietterich (2019). For training individual SGD, SWA and SWAG models we use the hyper-parameters used for PreResNet-164 in Maddox et al. (2019). For each SWAG model we sample 20 networks and ensemble them. So Deep Ensembles, MultiSWA and MultiSWAG are all evaluated under the same training budget; Deep Ensembles and MultiSWA also use the same test-time budget.

For producing the corrupted data we used the code² released by Hendrycks & Dietterich (2019). We had issues producing the data for the frost corruption type, so we omit it in our evaluation, and include Gaussian blur, which was not included in the evaluation of Hendrycks & Dietterich (2019).

²https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_cifar_c.py


[Figure 12 (plots omitted): (a)-(c) class-pair prior correlation matrices over MNIST classes 0, 1, 2, 4, 7 for α = 0.02, 0.1, 1; (d) NLL vs. prior std α; (e)-(g) sample functions for α = 0.02, 0.1, 1; (h) classification error vs. prior std α.]

Figure 12. (a)-(c): Average pairwise prior correlations for pairs of objects in classes {0, 1, 2, 4, 7} of MNIST induced by LeNet-5 for p(f(x; w)) when p(w) = N(0, α²I). Images in the same class have higher prior correlations than images from different classes, suggesting that p(f(x; w)) has desirable inductive biases. The correlations slightly decrease with increases in α. Panels (e)-(g) show sample functions from LeNet-5 along the direction connecting a pair of MNIST images of 0 and 1 digits. The complexity of the samples increases with α. (d): NLL and (h): classification error of an ensemble of 20 SWAG samples on MNIST as a function of α, using a LeNet-5. The NLL is high for overly small α and near-optimal for larger values, with an optimum near α = 0.3.

D.3. Neural Network Priors

In the main text we considered different properties of the prior distribution over functions induced by a spherical Gaussian distribution over the weights, with different variance scales.

Prior correlation diagrams. In panels (a)-(c) of Figure 12 we show pairwise correlations of the logits for different pairs of datapoints. To make these plots we produce S = 100 samples of the weights w_i of a LeNet-5 from the prior distribution N(0, α²I), and compute the logits corresponding to class 0 for each data point and each weight sample. We then compute the correlations for each pair x, x′ of data points as follows:

corr_logit(x, x′) = Σ_{i=1}^S (f(x, w_i) − f̄(x)) (f(x′, w_i) − f̄(x′)) / √( Σ_{i=1}^S (f(x, w_i) − f̄(x))² · Σ_{i=1}^S (f(x′, w_i) − f̄(x′))² ),

where f(x, w) is the logit corresponding to class 0 of the network with weights w on the input x, and f̄(x) is the mean value of the logit, f̄(x) = (1/S) Σ_i f(x, w_i). For evaluation we use 200 random datapoints per class for classes 0, 1, 2, 4, 7 (a total of 1000 datapoints). We use this set of classes to ensure that the structure is clearly visible in the figure. We combine the correlations into a diagram, additionally showing the average correlation for each pair of classes. We repeat the experiment for different values of α ∈ {0.02, 0.1, 1}. For a discussion of the results see Section 5.2.

Sample functions. In panels (e)-(g) of Figure 12 we visualize the functions sampled from the LeNet-5 network along the direction connecting a pair of MNIST images. In particular, we take a pair of images x_0 and x_1 of digits 0 and 1 respectively, and construct the path x(t) = t · x_0 + (1 − t) · x_1. We then study the samples of the logits z(t) = f(x(t) · ‖x_0‖ / ‖x(t)‖, w) along the path; here we adjusted the norm of the images along the path to be constant, as the values of the logits are sensitive to the norm of the inputs. The complexity of the samples increases as we increase the variance of the prior distribution over the weights. This increased complexity of sample functions explains why we might expect the prior correlations for pairs of images to be lower when we increase the variance of the prior distribution over the weights.
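The sketch below reproduces this construction in miniature (our own illustration: a small MLP stands in for LeNet-5, random vectors stand in for the MNIST images, and the norm adjustment follows the description above).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x0, x1 = torch.randn(784), torch.randn(784)        # stand-ins for MNIST digits 0 and 1
net = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))

alpha = 1.0
for p in net.parameters():                          # one prior sample w ~ N(0, alpha^2 I)
    nn.init.normal_(p, std=alpha)

with torch.no_grad():
    for t in torch.linspace(0.0, 1.0, 21):
        x_t = t * x0 + (1 - t) * x1
        x_t = x_t * x0.norm() / x_t.norm()          # keep the input norm constant along the path
        z_t = net(x_t)[0]                           # logit for class 0 along the path
        print(float(t), float(z_t))                 # plotting z_t against t for several prior
                                                    # samples reproduces panels (e)-(g)
```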

Performance dependence on prior variance. In panels (d), (h) of Figure 12 we show the test negative log-likelihood and accuracy of SWAG applied to LeNet-5 on MNIST. We train the model for 50 epochs, constructing the rank-20 SWAG approximation from the last 25 epochs. We use an initial learning rate of 0.05 and a SWAG learning rate of 0.01, with the learning rate schedule of Maddox et al. (2019). We use the posterior log-density as the objective and vary the prior variance α². In panels (f), (g) of Figure 9 we perform an analogous experiment using a PreResNet-20 and a VGG-16 on CIFAR-10, with the hyper-parameters reported in Maddox et al. (2019); both on MNIST and CIFAR-10, performance is poor for overly small values of α, close to optimal for intermediate values, and still reasonable for larger values of α (see Section 5.3).

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

we perform an analogous experiment using a PreResNet-20and a VGG-16 on CIFAR-10 using the hyper-parametersreported in Maddox et al (2019) (for PreResNet-20 we usethe hyper-parameters used with PreResNet-164 in Maddoxet al (2019)) Both on MNIST and CIFAR-10 we observethat the performance is poor for overly small values of αclose to optimal for intermediate values and still reasonablefor larger values of α For further discussion of the resultssee Section 53

Predictions from prior samples Following Wenzel et al(2020) we study the predictive distributions of prior samplesusing PreResNet-20 on CIFAR-10 In Figure 9 we showthe sample predictive functions averaged over datapoints fordifferent scales α of the prior distribution We also showthe predictive distribution for each α which is the averageof the sample predictive distributions over 200 samples ofweights In Figure 10 we show how the predictive distribu-tion changes as we vary the number of observed data forprior scale α =

radic10 We see that the marginal predictive

distribution for all considered values of α is reasonable mdashroughly uniform across classes when averaged across thedataset For the latter experiment we used stochastic gra-dient Langevin dynamics (SGLD) (Welling amp Teh 2011)with a cosine lerning rate schedule For each sample werestart SGLD and we only use the sample obtained at thelast iteration We discuss the results in Section 84

Prior correlations with corrupted images. In Section F and Figure 19 we study the decay of the prior correlations between logits on an original image and a perturbed image as we increase the intensity of perturbations. For the BNN we use the PreResNet-20 architecture with the standard Gaussian prior N(0, I). For the linear model the correlations are not affected by the prior variance α²:

cov(wᵀx, wᵀy) = E(wᵀx · wᵀy) = E[xᵀw wᵀy] = xᵀ E[w wᵀ] y = α² xᵀy,

and hence

corr(wᵀx, wᵀy) = cov(wᵀx, wᵀy) / √( cov(wᵀx, wᵀx) · cov(wᵀy, wᵀy) ) = xᵀy / (‖x‖ · ‖y‖).

We use the N(0, I) prior for the weights of the linear model. Finally, we also evaluate the correlations associated with an RBF kernel (see Equation (7)). To set the lengthscale ℓ of the kernel, we evaluate the pairwise correlations for the PreResNet-20 and RBF kernel on the 100 uncorrupted CIFAR-10 images that were used for the experiment, and ensure that the average correlations match. The resulting value of ℓ is 10000, and the average correlation for the RBF kernel and PreResNet was ≈ 0.9; for the linear model the average correlation was ≈ 0.82. For the perturbations we used the same set of corruptions introduced in Hendrycks & Dietterich (2019) as in the experiments in Section 4, with the addition of a random translation: for a random translation of intensity i we pad the image with 2·i zeros on each side and crop the image randomly to 32×32.
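The random-translation corruption described above is simple to reproduce; the sketch below assumes a 32×32×3 numpy image and is illustrative rather than the exact corruption code.

```python
import numpy as np

def random_translate(image, intensity):
    """Random translation of a 32x32 image: pad with 2*intensity zeros
    on each side, then take a random 32x32 crop."""
    pad = 2 * intensity
    padded = np.pad(image, ((pad, pad), (pad, pad), (0, 0)), mode="constant")
    top = np.random.randint(0, 2 * pad + 1)
    left = np.random.randint(0, 2 * pad + 1)
    return padded[top:top + 32, left:left + 32]
```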

D.4. Rethinking Generalization

In Section 6 we experiment with Bayesian neural networks and Gaussian processes on CIFAR-10 with noisy labels, inspired by the results in Zhang et al. (2016) that suggest we need to re-think generalization to understand deep learning.

Following Zhang et al. (2016), we train PreResNet-20 on CIFAR-10 with different fractions of random labels. To ensure that the networks fit the training data, we turn off weight decay and data augmentation, and use a lower initial learning rate of 0.01. Otherwise, we follow the hyper-parameters that were used with PreResNet-164 in Maddox et al. (2019). We use a diagonal Laplace approximation to compute an estimate of the marginal likelihood for each level of label corruption. Following Ritter et al. (2018), we use the diagonal of the Fisher information matrix rather than the Hessian.
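For reference, the diagonal Laplace evidence estimate has the form log p(D) ≈ log p(D|ŵ) + log p(ŵ) + (d/2) log 2π − ½ log det(F + σ⁻²I), with F the diagonal Fisher approximation to the Hessian. The sketch below is a generic illustration of this formula under those assumptions; the `fisher_diag` accumulation and the prior variance are placeholders, not the exact code used here.

```python
import math
import torch

def diag_laplace_log_evidence(log_lik, params, fisher_diag, prior_var=1.0):
    """Diagonal Laplace estimate of the log marginal likelihood at a trained solution.

    log_lik:     log p(D | w_hat) at the solution (scalar tensor)
    params:      list of parameter tensors at the solution
    fisher_diag: list of tensors holding the diagonal Fisher approximation to the Hessian
    """
    log_prior, log_det, n = 0.0, 0.0, 0
    for p, f in zip(params, fisher_diag):
        # log N(p; 0, prior_var * I)
        log_prior += -0.5 * (p ** 2).sum() / prior_var - 0.5 * p.numel() * math.log(2 * math.pi * prior_var)
        # log-determinant of the approximate posterior precision: Fisher diagonal + prior precision
        log_det += torch.log(f + 1.0 / prior_var).sum()
        n += p.numel()
    return log_lik + log_prior + 0.5 * n * math.log(2 * math.pi) - 0.5 * log_det
```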

We perform a similar experiment with a Gaussian process with an RBF kernel on the binary classification problem for two classes of CIFAR-10. We use variational inference to fit the model, and we use the variational evidence lower bound to approximate the marginal likelihood. We use variational inference to overcome the non-Gaussian likelihood and not for scalability reasons, i.e. we are not using inducing inputs. We use the GPyTorch package (Gardner et al., 2018) to train the models. We use an RBF kernel with the default initialization from GPyTorch and divide the inputs by 5000 to get an appropriate input scale. We train the model on a binary classification problem between classes 0 and 1.
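A minimal sketch of this kind of variational GP classifier in GPyTorch is given below, loosely following the library's non-Gaussian-likelihood tutorial. The training length, learning rate, and choice of variational strategy are assumptions rather than the exact configuration used here; `train_x` stands for flattened CIFAR-10 images divided by 5000 and `train_y` for binary labels in {0, 1}.

```python
import torch
import gpytorch

class GPClassifier(gpytorch.models.ApproximateGP):
    def __init__(self, train_x):
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(train_x.size(0))
        variational_strategy = gpytorch.variational.UnwhitenedVariationalStrategy(
            self, train_x, variational_distribution, learn_inducing_locations=False
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

def fit(train_x, train_y, epochs=100, lr=0.01):
    model = GPClassifier(train_x)
    likelihood = gpytorch.likelihoods.BernoulliLikelihood()
    mll = gpytorch.mlls.VariationalELBO(likelihood, model, train_y.numel())
    optimizer = torch.optim.Adam(list(model.parameters()) + list(likelihood.parameters()), lr=lr)
    model.train(); likelihood.train()
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = -mll(model(train_x), train_y)   # negative ELBO; the ELBO approximates the marginal likelihood
        loss.backward()
        optimizer.step()
    return model, likelihood
```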

For the 10-class GP classification experiment we train 10 one-vs-all models that classify between a given class and the rest of the data. To reduce computation in training, we subsample the data not belonging to the given class to 10k datapoints, so each model is trained on a total of 15k datapoints. We then combine the 10 models into a single multi-class model: an observation is attributed to the class that corresponds to the one-vs-all model with the highest confidence. We use the same hyper-parameters as in the binary classification experiments.
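The combination step can be sketched as follows, assuming a list of fitted (model, likelihood) pairs like the ones returned by the sketch above.

```python
import torch

def one_vs_all_predict(models_and_likelihoods, x):
    """Combine 10 one-vs-all GP classifiers: each model gives the probability that
    x belongs to its class; the prediction is the class with the highest confidence."""
    confidences = []
    with torch.no_grad():
        for model, likelihood in models_and_likelihoods:
            model.eval(); likelihood.eval()
            confidences.append(likelihood(model(x)).mean)   # P(class | x) under the one-vs-all model
    return torch.stack(confidences, dim=1).argmax(dim=1)
```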

D.5. Double Descent

In Section 7 we evaluate SGD, SWAG and MultiSWAG for models of varying width. Following Nakkiran et al. (2019), we use ResNet-18 on CIFAR-100; we consider the original labels, and 10% and 20% label corruption. For networks of every width we reuse the hyper-parameters used for PreResNet-164 in Maddox et al. (2019).


Figure 13. Double Descent. (a) Test error and (b) NLL loss for ResNet-18 with varying width on CIFAR-100 for SGD, SWAG and MultiSWAG when 10% of the labels are randomly reshuffled. MultiSWAG alleviates double descent both on the original labels and under label noise, both in accuracy and NLL. (c) Test NLLs for MultiSWAG with a varying number of independent models under 20% label corruption; NLL monotonically decreases with an increased number of independent models, alleviating double descent.

For the original labels and 10% label corruption we use 5 independently trained SWAG models with MultiSWAG, and for 20% label corruption we use 10 models; for 20% label corruption we also show performance varying the number of independent models in Figures 8(e) and 13(c). Both for SWAG and MultiSWAG we use an ensemble of 20 sampled models from each of the SWAG solutions; for example, for MultiSWAG with 10 independent SWAG solutions we use an ensemble of 200 networks.
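The ensembling procedure can be sketched as follows. Here `swag.sample()` stands in for drawing a weight sample from a fitted SWAG approximation (in the reference SWAG implementation one would also update batch-normalization statistics after each sample), so this is an illustration of the averaging logic rather than the exact code.

```python
import torch

def multiswag_predict(swag_models, loader, n_samples=20):
    """Average predicted class probabilities over n_samples weight samples drawn
    from each of several independent SWAG solutions (e.g. 10 x 20 = 200 networks)."""
    prob_sum, count = 0.0, 0
    for swag in swag_models:
        for _ in range(n_samples):
            swag.sample()                    # assumed interface: draw weights from the SWAG Gaussian
            probs = []
            with torch.no_grad():
                for x, _ in loader:
                    probs.append(torch.softmax(swag(x), dim=-1))
            prob_sum = prob_sum + torch.cat(probs)
            count += 1
    return prob_sum / count                  # Bayesian model average over samples and solutions
```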

E. Analysis of Prior Variance Effect

In this section we provide simple analytic results for the effect of the prior variance in ReLU networks. A related derivation is presented in Appendix Section A.8 of Garipov et al. (2018) about connecting paths from symmetries in parametrization.

We will consider a multilayer network f(x; w) of the form

f(x; {W_i, b_i}_{i=1}^n) = W_n( ... φ(W_2 φ(W_1 x + b_1) + b_2) ... ) + b_n,

where φ is the ReLU (or in fact any positively-homogeneous activation function), W_i are weight matrices and b_i are bias vectors. In particular, f can be a regular CNN with ReLU activations up to the logits (with the softmax activation removed).

Now suppose we have a prior distribution of the form

W_i ∼ N(0, α_i² I),   b_i ∼ N(0, β_i² I),

where the identity matrices I are implicitly assumed to be of appropriate shapes, so each weight matrix and bias vector has a spherical Gaussian distribution. We can reparameterize this distribution as

W_i = α_i E_i,   E_i ∼ N(0, I),
b_i = β_i ε_i,   ε_i ∼ N(0, I).

We can then express the predictions of the network on the input x, for weights sampled from the prior, as the random variable

f(x; {α_i, β_i}_{i=1}^n) = α_n · E_n( ... φ(α_1 E_1 x + β_1 · ε_1) ... ) + β_n · ε_n.   (8)

Through Equation (8) we can observe some simple properties of the dependence between the prior scales α_i, β_i and the induced function-space prior.

Proposition 1. Suppose the network has no bias vectors, i.e. β_1 = . . . = β_n = 0. Then the scales α_i of the prior distribution over the weights only affect the output scale of the network.

Proof. In the case when there are no bias vectors, Equation (8) simplifies to

f(x; {α_i, β_i = 0}_{i=1}^n) = α_n · E_n( ... φ(α_1 E_1 x + β_1 · ε_1) ... ) + β_n · ε_n
  = α_n · . . . · α_1 · E_n( ... φ(E_1 x) ... )
  = α_n · . . . · α_1 · f(x; {α_i = 1, β_i = 0}_{i=1}^n).

In the derivation above we used the positive homogeneity of ReLU: φ(αz) = αφ(z) for any positive α.

In other words, to sample from the distribution over functions corresponding to a prior with variances {α_i, β_i = 0}_{i=1}^n, we can sample from the spherical Gaussian prior (without bias terms) {α_i = 1, β_i = 0}_{i=1}^n and then rescale the outputs of the network by the product of the scales α_n · . . . · α_2 · α_1.
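As an illustration (not from the paper), the following small numerical check confirms that scaling the per-layer prior scales of a bias-free ReLU network only rescales its outputs by the product of the scales; the layer sizes and scale values are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dims = [10, 20, 20, 5]
alphas = [2.0, 0.5, 3.0]          # per-layer prior scales alpha_i

def make_relu_net():
    layers = []
    for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
        layers.append(nn.Linear(d_in, d_out, bias=False))
        if i < len(dims) - 2:
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)

# Sample E_i ~ N(0, I) once, then build the scaled network with W_i = alpha_i * E_i
net_unit = make_relu_net()
net_scaled = make_relu_net()
with torch.no_grad():
    for p in net_unit.parameters():
        p.normal_(0.0, 1.0)
    for alpha, p_unit, p_scaled in zip(alphas, net_unit.parameters(), net_scaled.parameters()):
        p_scaled.copy_(alpha * p_unit)

x = torch.randn(4, 10)
scale = alphas[0] * alphas[1] * alphas[2]
print(torch.allclose(net_scaled(x), scale * net_unit(x), atol=1e-5))   # True
```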

We note that the result above is different from the results for sigmoid networks considered in MacKay (1995), where varying the prior on the weights leads to changing the lengthscale of the sample functions. For ReLU networks without biases, increasing the prior variance only increases the output scale of the network and not the complexity of the samples. If we apply the softmax activation on the outputs of the last layer of such a network, we will observe increasingly confident predictions as we increase the prior variance. We observe this effect in Figure 9, and discuss it in Section 8.4.

In case bias vectors are present, we can obtain a similar result using a specific scaling of the prior variances with layer, as in the following proposition.

Proposition 2. Suppose the prior scales depend on the layer of the network as follows, for some γ > 0:

α_i = γ,   β_i = γ^i,

for all layers i = 1, . . . , n. Then γ only affects the scale of the predictive distribution at any input x:

f(x; {α_i = γ, β_i = γ^i}_{i=1}^n) = γ^n · f(x; {α_i = 1, β_i = 1}_{i=1}^n).

Proof. The proof is analogous to the proof of Proposition 1. We can use the positive homogeneity of ReLU activations to factor the prior scales outside of the network:

f(x; {α_i = γ, β_i = γ^i}_{i=1}^n) = γ · E_n( ... φ(γ · E_1 x + γ · ε_1) ... ) + γ^n · ε_n
  = γ^n · ( E_n( ... φ(E_1 x + ε_1) ... ) + ε_n )
  = γ^n · f(x; {α_i = 1, β_i = 1}_{i=1}^n).

The analysis above can be applied to other simple scaling rules of the prior; e.g.,

f(x; {α_i = γ · α̂_i, β_i = γ^i · β̂_i}_{i=1}^n) = γ^n · f(x; {α_i = α̂_i, β_i = β̂_i}_{i=1}^n)   (9)

can be shown completely analogously to Proposition 2.

More general types of scaling of the prior affect both the output scale of the network and the relative effect of the prior and variance terms. For example, by Equation (9) we have

f(x; {α_i = γ, β_i = γ}_{i=1}^n) = f(x; {α_i = γ · 1, β_i = γ^i · γ^{1−i}}_{i=1}^n) = γ^n · f(x; {α_i = 1, β_i = γ^{1−i}}_{i=1}^n).

We note that the analysis does not cover residual connections and batch normalization, so it applies to LeNet-5 but cannot be directly applied to the PreResNet-20 networks used in many of our experiments.

F. Prior Correlation Structure under Perturbations

In this section we explore the prior correlations between the logits on different pairs of datapoints induced by a spherical Gaussian prior on the weights of a PreResNet-20. We sample 100 random images from CIFAR-10 (10 from each class) and apply 17 different perturbations introduced by Hendrycks & Dietterich (2019) at 5 different levels of intensity. We then compute correlations between the logits f(x, w) for the original image x and f(x̃, w) for the corrupted image x̃, as we sample the weights of the network from the prior w ∼ N(0, I).

In Figure 19 we show how the correlations decay with perturbation intensity. For reference, we also show how the correlations decay for a linear model and for an RBF kernel. For the RBF kernel we set the lengthscale so that the average correlations on the uncorrupted datapoints match those of a PreResNet-20. Further experimental details can be found in Appendix D.3.

For all types of corruptions except saturate, snow, fog and brightness, the PreResNet logits decay slower compared to the RBF kernel and linear model. It appears that the prior samples are sensitive to corruptions that alter the brightness, or more generally the colours, in the image. For many types of corruptions (such as, e.g., Gaussian noise) the prior correlations for PreResNet are close to 1 for all levels of corruption.

Overall, these results indicate that the prior over functions induced by a vague prior over parameters w in combination with a PreResNet has useful equivariance properties: before seeing data, the model treats images of the same class as highly correlated, even after an image has undergone significant perturbations representative of perturbations we often see in the real world. These types of symmetries are a large part of what makes neural networks a powerful model class for high dimensional natural signals.

[Figure 14: NLL vs. number of independently trained models; panels (a) Gaussian Noise, (b) Impulse Noise, (c) Shot Noise, each at five corruption intensities, comparing Deep Ensembles, MultiSWA and MultiSWAG.]

Figure 14. Noise Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).

[Figure 15: NLL vs. number of independently trained models; panels (a) Defocus Blur, (b) Glass Blur, (c) Motion Blur, (d) Zoom Blur, (e) Gaussian Blur, each at five corruption intensities, comparing Deep Ensembles, MultiSWA and MultiSWAG.]

Figure 15. Blur Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).

[Figure 16: NLL vs. number of independently trained models; panels (a) Contrast, (b) Saturate, (c) Elastic Transform, (d) Pixelate, (e) JPEG Compression, each at five corruption intensities, comparing Deep Ensembles, MultiSWA and MultiSWAG.]

Figure 16. Digital Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).

[Figure 17: NLL vs. number of independently trained models; panels (a) Snow, (b) Fog, (c) Brightness, each at five corruption intensities, comparing Deep Ensembles, MultiSWA and MultiSWAG.]

Figure 17. Weather Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).


Figure 18. Negative log likelihood, accuracy and expected calibration error distribution on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the corruption intensity. Following Ovadia et al. (2019), we summarize the results for different types of corruption with a boxplot. For each method we use 10 independently trained models, and for MultiSWAG we sample 20 networks from each model. As in Figures 5, 11–14, there are substantial differences between these three methods, which are hard to see due to the vertical scale on this plot; MultiSWAG particularly outperforms Deep Ensembles and MultiSWA in terms of NLL and ECE for higher corruption intensities.

[Figure 19: prior correlation between logits on original and corrupted images vs. corruption intensity, for PreResNet-20, the RBF kernel, and the linear model; one panel per corruption: Translation, Gaussian Noise, Impulse Noise, Shot Noise, Defocus Blur, Glass Blur, Motion Blur, Zoom Blur, Gaussian Blur, Contrast, Saturate, Elastic, Pixelate, JPEG, Snow, Fog, Brightness.]

Figure 19. Prior correlations under corruption. Prior correlations between predictions (logits) for PreResNet-20, a linear model and an RBF kernel on original and corrupted images, as a function of corruption intensity, for different types of corruptions. The lengthscale of the RBF kernel is calibrated to produce similar correlations to PreResNet on uncorrupted datapoints. We report the mean correlation values over 100 different images and show the 1σ error bars with shaded regions. For all corruptions except Snow, Saturate, Fog and Brightness, the correlations decay slower for PreResNet compared to the baselines.


9.2. "But is it really Bayesian?"

We finish with an editorial comment about approaches to research within Bayesian deep learning. There is sometimes a tendency to classify work as Bayesian or not Bayesian, with very stringent criteria for what qualifies as Bayesian. Moreover, the implication, and sometimes even explicit recommendation, is that if an approach is not unequivocally Bayesian in every respect, then we should not term it Bayesian, and we should instead attempt to understand the procedure through entirely different non-Bayesian mechanisms. We believe this mentality encourages tribalism, which is not conducive to the best research or creating the best performing methods. What matters is not a debate about semantics, but making rational modelling choices given a particular problem setting, and trying to understand these choices. Often these choices can largely be inspired by a Bayesian approach – in which case it is desirable to indicate this source of inspiration. And in the semantics debate, who would be the arbiter of what gets to be called Bayesian? Arguably it ought to be an evolving definition.

Broadly speaking, what makes Bayesian approaches distinctive is a posterior weighted marginalization over parameters. And at a high level, Bayesian methods are about combining our honest beliefs with data to form a posterior. In actuality, no fair-minded researcher entirely believes the prior over parameters, the functional form of the model (which is part of the prior over functions), or the likelihood. From this perspective, it is broadly compatible with a Bayesian philosophy to reflect misspecification in the modelling procedure itself, which is achieved through tempering. In this sense, the tempered posterior is more reflective of a true posterior than the posterior that results from ignoring our belief that the model is misspecified.

Moreover, basic probability theory indicates that marginalization is desirable. While marginalization cannot in practice be achieved exactly, we can try to improve over conventional training, which as we have discussed can be viewed as approximate marginalization. Given computational constraints, effective marginalization is not equivalent to obtaining accurate samples from a posterior. As we have discussed, simple Monte Carlo is only one of many mechanisms for marginalization. Just as expectation propagation (Minka, 2001a) focuses its approximation on the factors in a posterior where it will most affect the end result, we should focus on representing the posterior where it will make the biggest difference to the model average. As we have shown, deep ensembles are a reasonable mechanism up to a point. After having trained many independent models, there are added benefits to marginalizing within basins, given the computational expense associated with retraining an additional model to find an additional basin of attraction.

We should also not hold Bayesian methods to a double standard. Indeed, it can be hard to interpret or understand the prior, the posterior, and whether the marginalization procedure is optimal. But it is also hard to interpret the choices behind the functional form of the model, or the rationale behind classical procedures, where we bet everything on a single global optimum – when we know there are many global optima and many of them will perform well but provide different solutions, and many others will not perform well. We should apply the same level of scrutiny to all modelling choices, consider the alternatives, and not be paralyzed if a procedure is not optimal in every respect.

Acknowledgements. AGW and PI are supported by an Amazon Research Award, Facebook Research, NSF I-DISRE 193471, NIH R01 DA048764-01A1, NSF IIS-1563887, and NSF IIS-1910266. We thank Greg Benton for helpful discussions.

References

Achille, A. and Soatto, S. Emergence of invariance and disentanglement in deep representations. The Journal of Machine Learning Research, 19(1):1947–1980, 2018.
Ashukha, A., Lyzhov, A., Molchanov, D., and Vetrov, D. Pitfalls of in-domain uncertainty estimation and ensembling in deep learning. arXiv preprint arXiv:2002.06470, 2020.
Barron, A. R. and Cover, T. M. Minimum complexity density estimation. IEEE Transactions on Information Theory, 37(4):1034–1054, 1991.
Beal, M. J. Variational algorithms for approximate Bayesian inference. University of London, 2003.
Belkin, M., Hsu, D., Ma, S., and Mandal, S. Reconciling modern machine-learning practice and the classical bias–variance trade-off. Proceedings of the National Academy of Sciences, 116(32):15849–15854, 2019.
Bishop, C. M. Pattern Recognition and Machine Learning. Springer, 2006.
Bissiri, P. G., Holmes, C. C., and Walker, S. G. A general framework for updating belief distributions. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 78(5):1103–1130, 2016.
Blundell, C., Cornebise, J., Kavukcuoglu, K., and Wierstra, D. Weight uncertainty in neural networks. arXiv preprint arXiv:1505.05424, 2015.
Box, G. E. and Tiao, G. C. Bayesian inference in statistical analysis. Addison-Wesley, Reading, MA, 1973.
Chaudhari, P., Choromanska, A., Soatto, S., LeCun, Y., Baldassi, C., Borgs, C., Chayes, J., Sagun, L., and Zecchina, R. Entropy-SGD: biasing gradient descent into wide valleys. Journal of Statistical Mechanics: Theory and Experiment, 2019.
Cobb, A. D., Baydin, A. G., Markham, A., and Roberts, S. J. Introducing an explicit symplectic integration scheme for Riemannian manifold Hamiltonian Monte Carlo. arXiv preprint arXiv:1910.06243, 2019.
Darnieder, W. F. Bayesian methods for data-dependent priors. PhD thesis, The Ohio State University, 2011.
de Heide, R., Kirichenko, A., Mehta, N., and Grünwald, P. Safe-Bayesian generalized linear regression. arXiv preprint arXiv:1910.09227, 2019.
Dinh, L., Pascanu, R., Bengio, S., and Bengio, Y. Sharp minima can generalize for deep nets. arXiv preprint arXiv:1703.04933, 2017.
Draxler, F., Veschgini, K., Salmhofer, M., and Hamprecht, F. A. Essentially no barriers in neural network energy landscape. arXiv preprint arXiv:1803.00885, 2018.
Dziugaite, G. K. and Roy, D. M. Computing nonvacuous generalization bounds for deep (stochastic) neural networks with many more parameters than training data. arXiv preprint arXiv:1703.11008, 2017.
Fort, S. and Jastrzebski, S. Large scale structure of neural network loss landscapes. In Advances in Neural Information Processing Systems, pp. 6706–6714, 2019.
Fort, S., Hu, H., and Lakshminarayanan, B. Deep ensembles: A loss landscape perspective. arXiv preprint arXiv:1912.02757, 2019.
Gal, Y. and Ghahramani, Z. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. In International Conference on Machine Learning, pp. 1050–1059, 2016.
Gardner, J. R., Pleiss, G., Bindel, D., Weinberger, K. Q., and Wilson, A. G. GPyTorch: Blackbox matrix-matrix Gaussian process inference with GPU acceleration. In Neural Information Processing Systems, 2018.
Garipov, T., Izmailov, P., Podoprikhin, D., Vetrov, D. P., and Wilson, A. G. Loss surfaces, mode connectivity, and fast ensembling of DNNs. In Neural Information Processing Systems, 2018.
Grünwald, P. The safe Bayesian. In International Conference on Algorithmic Learning Theory, pp. 169–183. Springer, 2012.
Grünwald, P., Van Ommen, T., et al. Inconsistency of Bayesian inference for misspecified linear models, and a proposal for repairing it. Bayesian Analysis, 12(4):1069–1103, 2017.
Guedj, B. A primer on PAC-Bayesian learning. arXiv preprint arXiv:1901.05353, 2019.
Guo, C., Pleiss, G., Sun, Y., and Weinberger, K. Q. On calibration of modern neural networks. In Proceedings of the 34th International Conference on Machine Learning - Volume 70, pp. 1321–1330. JMLR.org, 2017.
Gustafsson, F. K., Danelljan, M., and Schön, T. B. Evaluating scalable Bayesian deep learning methods for robust computer vision. arXiv preprint arXiv:1906.01620, 2019.
Hafner, D., Tran, D., Irpan, A., Lillicrap, T., and Davidson, J. Reliable uncertainty estimates in deep neural networks using noise contrastive priors. arXiv preprint arXiv:1807.09289, 2018.
He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778, 2016.
Hendrycks, D. and Dietterich, T. Benchmarking neural network robustness to common corruptions and perturbations. arXiv preprint arXiv:1903.12261, 2019.
Hinton, G. E. and Van Camp, D. Keeping the neural networks simple by minimizing the description length of the weights. In Proceedings of the Sixth Annual Conference on Computational Learning Theory, pp. 5–13, 1993.
Hochreiter, S. and Schmidhuber, J. Flat minima. Neural Computation, 9(1):1–42, 1997.
Huang, W. R., Emam, Z., Goldblum, M., Fowl, L., Terry, J. K., Huang, F., and Goldstein, T. Understanding generalization through visualizations. arXiv preprint arXiv:1906.03291, 2019.
Izmailov, P., Podoprikhin, D., Garipov, T., Vetrov, D., and Wilson, A. G. Averaging weights leads to wider optima and better generalization. In Uncertainty in Artificial Intelligence (UAI), 2018.
Izmailov, P., Maddox, W. J., Kirichenko, P., Garipov, T., Vetrov, D., and Wilson, A. G. Subspace inference for Bayesian deep learning. Uncertainty in Artificial Intelligence, 2019.
Jacot, A., Gabriel, F., and Hongler, C. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in Neural Information Processing Systems, pp. 8571–8580, 2018.
Jiang, Y., Neyshabur, B., Mobahi, H., Krishnan, D., and Bengio, S. Fantastic generalization measures and where to find them. arXiv preprint arXiv:1912.02178, 2019.
Kass, R. E. and Raftery, A. E. Bayes factors. Journal of the American Statistical Association, 90(430):773–795, 1995.
Kendall, A. and Gal, Y. What uncertainties do we need in Bayesian deep learning for computer vision? In Advances in Neural Information Processing Systems, pp. 5574–5584, 2017.
Keskar, N. S., Mudigere, D., Nocedal, J., Smelyanskiy, M., and Tang, P. T. P. On large-batch training for deep learning: Generalization gap and sharp minima. arXiv preprint arXiv:1609.04836, 2016.
Khan, M. E., Nielsen, D., Tangkaratt, V., Lin, W., Gal, Y., and Srivastava, A. Fast and scalable Bayesian deep learning by weight-perturbation in Adam. arXiv preprint arXiv:1806.04854, 2018.
Krizhevsky, A., Nair, V., and Hinton, G. The CIFAR-10 dataset, 2014. http://www.cs.toronto.edu/~kriz/cifar.html
Lakshminarayanan, B., Pritzel, A., and Blundell, C. Simple and scalable predictive uncertainty estimation using deep ensembles. In Advances in Neural Information Processing Systems, pp. 6402–6413, 2017.
Langford, J. and Caruana, R. (Not) bounding the true error. In Advances in Neural Information Processing Systems, pp. 809–816, 2002.
LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.
Louizos, C., Shi, X., Schutte, K., and Welling, M. The functional neural process. In Advances in Neural Information Processing Systems, 2019.
MacKay, D. J. Bayesian methods for adaptive models. PhD thesis, California Institute of Technology, 1992.
MacKay, D. J. Probable networks and plausible predictions – a review of practical Bayesian methods for supervised neural networks. Network: Computation in Neural Systems, 6(3):469–505, 1995.
MacKay, D. J. Introduction to Gaussian processes. In Bishop, C. M. (ed.), Neural Networks and Machine Learning, chapter 11, pp. 133–165. Springer-Verlag, 1998.
MacKay, D. J. Information Theory, Inference and Learning Algorithms. Cambridge University Press, 2003.
Maddox, W. J., Izmailov, P., Garipov, T., Vetrov, D. P., and Wilson, A. G. A simple baseline for Bayesian uncertainty in deep learning. In Advances in Neural Information Processing Systems, 2019.
Masegosa, A. R. Learning under model misspecification: Applications to variational and ensemble methods. 2019.
McAllester, D. A. PAC-Bayesian model averaging. In Proceedings of the Twelfth Annual Conference on Computational Learning Theory, pp. 164–170, 1999.
Minka, T. Expectation propagation for approximate Bayesian inference. In Uncertainty in Artificial Intelligence, volume 17, pp. 362–369, 2001a.
Minka, T. P. Bayesian model averaging is not model combination. 2000.
Minka, T. P. Automatic choice of dimensionality for PCA. In Advances in Neural Information Processing Systems, pp. 598–604, 2001b.
Mohri, M. and Rostamizadeh, A. Rademacher complexity bounds for non-i.i.d. processes. In Advances in Neural Information Processing Systems, pp. 1097–1104, 2009.
Nakkiran, P., Kaplun, G., Bansal, Y., Yang, T., Barak, B., and Sutskever, I. Deep double descent: Where bigger models and more data hurt. arXiv preprint arXiv:1912.02292, 2019.
Nakkiran, P., Venkat, P., Kakade, S., and Ma, T. Optimal regularization can mitigate double descent. arXiv preprint arXiv:2003.01897, 2020.
Nalisnick, E. On priors for Bayesian neural networks. PhD thesis, University of California, Irvine, 2018.
Neal, R. Bayesian Learning for Neural Networks. Springer Verlag, 1996. ISBN 0387947248.
Neyshabur, B., Bhojanapalli, S., McAllester, D., and Srebro, N. Exploring generalization in deep learning. In Advances in Neural Information Processing Systems, pp. 5947–5956, 2017.
Neyshabur, B., Bhojanapalli, S., and Srebro, N. A PAC-Bayesian approach to spectrally-normalized margin bounds for neural networks. In International Conference on Learning Representations, 2018. URL https://openreview.net/forum?id=Skz_WfbCZ
Opper, M., Kinzel, W., Kleinz, J., and Nehl, R. On the ability of the optimal perceptron to generalise. Journal of Physics A: Mathematical and General, 23(11):L581, 1990.
Ovadia, Y., Fertig, E., Ren, J., Nado, Z., Sculley, D., Nowozin, S., Dillon, J. V., Lakshminarayanan, B., and Snoek, J. Can you trust your model's uncertainty? Evaluating predictive uncertainty under dataset shift. arXiv preprint arXiv:1906.02530, 2019.
Pearce, T., Zaki, M., Brintrup, A., Anastassacos, N., and Neely, A. Uncertainty in neural networks: Bayesian ensembling. arXiv preprint arXiv:1810.05546, 2018.
Rasmussen, C. E. and Ghahramani, Z. Occam's razor. In Neural Information Processing Systems (NIPS), 2001.
Rasmussen, C. E. and Williams, C. K. I. Gaussian Processes for Machine Learning. The MIT Press, 2006.
Ritter, H., Botev, A., and Barber, D. A scalable Laplace approximation for neural networks. In International Conference on Learning Representations (ICLR), 2018.
Rue, H., Martino, S., and Chopin, N. Approximate Bayesian inference for latent Gaussian models by using integrated nested Laplace approximations. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 71(2):319–392, 2009.
Simonyan, K. and Zisserman, A. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.
Smith, S. L. and Le, Q. V. A Bayesian perspective on generalization and stochastic gradient descent. In International Conference on Learning Representations, 2018.
Sun, S., Zhang, G., Shi, J., and Grosse, R. Functional variational Bayesian neural networks. arXiv preprint arXiv:1903.05779, 2019.
Ulyanov, D., Vedaldi, A., and Lempitsky, V. Deep image prior. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 9446–9454, 2018.
Vapnik, V. N. Adaptive and learning systems for signal processing, communications, and control. Statistical Learning Theory, 1998.
Virtanen, P., Gommers, R., Oliphant, T. E., Haberland, M., Reddy, T., Cournapeau, D., Burovski, E., Peterson, P., Weckesser, W., Bright, J., et al. SciPy 1.0: fundamental algorithms for scientific computing in Python. Nature Methods, pp. 1–12, 2020.
Walker, S. and Hjort, N. L. On Bayesian consistency. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 63(4):811–821, 2001.
Welling, M. and Teh, Y. W. Bayesian learning via stochastic gradient Langevin dynamics. In Proceedings of the 28th International Conference on Machine Learning (ICML-11), pp. 681–688, 2011.
Wenzel, F., Roth, K., Veeling, B. S., Swiatkowski, J., Tran, L., Mandt, S., Snoek, J., Salimans, T., Jenatton, R., and Nowozin, S. How good is the Bayes posterior in deep neural networks really? arXiv preprint arXiv:2002.02405, 2020.
Wilson, A. G. The case for Bayesian deep learning. arXiv preprint arXiv:2001.10995, 2020.
Xu, X., Lu, P., MacEachern, S., and Xu, R. Calibrated Bayes factors for model comparison. Journal of Statistical Computation and Simulation, 89(4):591–614, 2019.
Yang, W., Lorch, L., Graule, M. A., Srinivasan, S., Suresh, A., Yao, J., Pradier, M. F., and Doshi-Velez, F. Output-constrained Bayesian neural networks. arXiv preprint arXiv:1905.06287, 2019.
Zhang, C., Bengio, S., Hardt, M., Recht, B., and Vinyals, O. Understanding deep learning requires rethinking generalization. arXiv preprint arXiv:1611.03530, 2016.
Zhang, T. Information-theoretic upper and lower bounds for statistical estimation. IEEE Transactions on Information Theory, 52(4):1307–1321, 2006.


Appendix Outline

This appendix is organized as follows. In Section A, we visualize predictive functions corresponding to weight samples within high posterior density valleys on a regression problem. In Section B, we provide background material on Gaussian processes. In Section C, we present further results comparing MultiSWAG and MultiSWA to Deep Ensembles under data distribution shift on CIFAR-10. In Section D, we provide the details of all experiments presented in the paper. In Section E, we present analytic results on the dependence of the prior distribution in function space on the variance of the prior over parameters. In Section F, we study the prior correlations between BNN logits on perturbed images.

A. Loss Valleys

We demonstrate that different points along the valleys of high posterior density (low loss) connecting pairs of independently trained optima (Garipov et al., 2018; Draxler et al., 2018; Fort & Jastrzebski, 2019) correspond to different predictive functions. We use the regression example from Izmailov et al. (2019) and show the results in Figure 11.

B. Gaussian processes

With a Bayesian neural network, a distribution over parameters p(w) induces a distribution over functions p(f(x; w)) when combined with the functional form of the network. Gaussian processes (GPs) are often used to instead directly specify a distribution over functions.

A Gaussian process is a distribution over functions f(x) ∼ GP(m, k) such that any collection of function values, queried at any finite set of inputs x_1, . . . , x_n, has a joint Gaussian distribution:

f(x_1), . . . , f(x_n) ∼ N(μ, K).   (6)

The mean vector μ_i = E[f(x_i)] = m(x_i) and covariance matrix K_ij = cov(f(x_i), f(x_j)) = k(x_i, x_j) are determined by the mean function m and covariance function (or kernel) k of the Gaussian process.

The popular RBF kernel has the form

k(x_i, x_j) = exp( −(1 / (2ℓ²)) ‖x_i − x_j‖² ).   (7)

The length-scale hyperparameter ℓ controls the extent of correlations between function values. If ℓ is large, sample functions from a GP prior are simple and slowly varying with the inputs x.
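As a small illustration (not from the paper), the kernel in Equation (7) and the effect of the lengthscale can be computed directly:

```python
import numpy as np

def rbf_kernel(x1, x2, lengthscale=1.0):
    """RBF kernel from Equation (7): k(x, x') = exp(-||x - x'||^2 / (2 * lengthscale^2))."""
    sq_dist = np.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return np.exp(-0.5 * sq_dist / lengthscale ** 2)

# Larger lengthscales give stronger correlations between distant inputs,
# so prior sample functions vary more slowly.
x = np.linspace(-3, 3, 5)[:, None]
for ell in [0.5, 2.0]:
    print(ell, np.round(rbf_kernel(x, x, ell)[0], 3))
```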

Gaussian processes with RBF kernels (as well as many other standard kernels) assign positive density to any set of observations. Moreover, these models are universal approximators (Rasmussen & Williams, 2006): as the number of observations increases, they are able to approximate any function to arbitrary precision.

Work on Gaussian processes in machine learning was triggered by the observation that Bayesian neural networks become Gaussian processes with particular kernel functions as the number of hidden units approaches infinity (Neal, 1996). This result resembles recent work on the neural tangent kernel (e.g., Jacot et al., 2018).

C. Deep Ensembles and MultiSWAG Under Distribution Shift

In Figures 14, 15, 16, 17 we show the negative log-likelihood for Deep Ensembles, MultiSWA and MultiSWAG using PreResNet-20 on CIFAR-10 with various corruptions, as a function of the number of independently trained models (SGD solutions, SWA solutions or SWAG models, respectively). For MultiSWAG we generate 20 samples from each independent SWAG model. Typically, MultiSWA and MultiSWAG significantly outperform Deep Ensembles when a small number of independent models is used, or when the level of corruption is high.

In Figure 18, following Ovadia et al. (2019), we show the distribution of negative log likelihood, accuracy and expected calibration error as we vary the type of corruption. We use a fixed training time budget: 10 independently trained models for every method. For MultiSWAG we ensemble 20 samples from each of the 10 SWAG approximations. MultiSWAG particularly achieves better NLL than the other two methods, and MultiSWA outperforms Deep Ensembles; the difference is especially pronounced for higher levels of corruption. In terms of ECE, MultiSWAG again outperforms the other two methods for higher corruption intensities.

We note that Ovadia et al. (2019) found Deep Ensembles to be a very strong baseline for prediction quality and calibration under distribution shift. For this reason, we focus on Deep Ensembles in our comparisons.

D. Details of Experiments

In this section we provide additional details of the experiments presented in the paper.

D.1. Approximating the True Predictive Distribution

For the results presented in Figure 4, we used a network with 3 hidden layers of size 10 each. The network takes two inputs, x and x². We pass both x and x² as input to ensure that the network can represent a broader class of functions.


Figure 11. Diversity of high performing functions. Bottom: a contour plot of the posterior log-density in the subspace containing a pair of independently trained modes (as with deep ensembles) and a path of high posterior density connecting these modes. In each panel, the purple point represents a sample from the posterior in the parameter subspace. Top: the predictive distribution constructed from samples in the subspace. The shaded blue area shows the 3σ-region of the predictive distribution at each of the input locations, and the blue line shows the mean of the predictive distribution. In each panel, the purple line shows the predictive function corresponding to the sample shown in the corresponding bottom row panel. For the details of the experimental setup, see Section 5.1 of Izmailov et al. (2019).

The network outputs a single number y = f(x).

To generate data for the plots, we used a randomly-initialized neural network of the same architecture described above. We sampled the weights from an isotropic Gaussian with variance 0.1², and added isotropic Gaussian noise with variance 0.1² to the outputs:

y = f(x; w) + ε(x),

with w ∼ N(0, 0.1² · I), ε(x) ∼ N(0, 0.1² · I). The training set consists of 120 points, shown in Figure 4.
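A sketch of this generative process is below; the activation function and the input range are assumptions made for illustration, since they are not specified here.

```python
import torch
import torch.nn as nn

def make_net():
    # 3 hidden layers of 10 units; inputs are (x, x^2), output is a single value
    return nn.Sequential(
        nn.Linear(2, 10), nn.Tanh(),
        nn.Linear(10, 10), nn.Tanh(),
        nn.Linear(10, 10), nn.Tanh(),
        nn.Linear(10, 1),
    )

def generate_data(n=120, weight_std=0.1, noise_std=0.1):
    net = make_net()
    with torch.no_grad():
        for p in net.parameters():
            p.normal_(0.0, weight_std)                         # w ~ N(0, 0.1^2 I)
        x = torch.linspace(-7.5, 7.5, n).unsqueeze(-1)         # assumed input range
        inputs = torch.cat([x, x ** 2], dim=-1)
        y = net(inputs) + noise_std * torch.randn(n, 1)        # y = f(x, w) + eps
    return x, y
```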

For estimating the ground truth, we ran 200 chains of Hamiltonian Monte Carlo (HMC) using the hamiltorch package (Cobb et al., 2019). We initialized each chain with a network pre-trained with SGD for 3000 steps, then ran Hamiltonian Monte Carlo (HMC) for 2000 steps, producing 200 samples.

For Deep Ensembles, we independently trained 50 networks with SGD for 20000 steps each. We used the negative posterior log-density as the training loss. For SVI, we used a fully-factorized Gaussian approximation initialized at an SGD solution trained for 20000 steps. For all inference methods we set the prior variance to 10² and the noise variance to 0.02².

Discrepancy with true BMA. For the results presented in panel (d) of Figure 4, we computed the Wasserstein distance between the predictive distribution approximated with HMC and the predictive distribution for Deep Ensembles and SVI. We used the one-dimensional Wasserstein distance function¹ from the scipy package (Virtanen et al., 2020). We computed the Wasserstein distance between the marginal distributions at each input location, and averaged the results over the input locations. In the top sub-panels of panels (b), (c) of Figure 4 we additionally visualize the marginal Wasserstein distance between the HMC predictive distribution and the Deep Ensembles and SVI predictive distributions, respectively, for each input location.

¹ https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wasserstein_distance.html
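A sketch of this averaged Wasserstein metric, assuming arrays of predictive samples evaluated on a shared grid of input locations:

```python
import numpy as np
from scipy.stats import wasserstein_distance

def predictive_discrepancy(hmc_samples, approx_samples):
    """Average one-dimensional Wasserstein distance between two predictive
    distributions, computed independently at each input location.

    hmc_samples, approx_samples: arrays of shape [num_samples, num_inputs]
    holding function values f(x_j; w_i) for posterior samples w_i.
    """
    distances = [
        wasserstein_distance(hmc_samples[:, j], approx_samples[:, j])
        for j in range(hmc_samples.shape[1])
    ]
    return np.mean(distances)
```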

D.2. Deep Ensembles and MultiSWAG

We evaluate Deep Ensembles, MultiSWA and MultiSWAG under distribution shift in Section 4. Following Ovadia et al. (2019), we use a PreResNet-20 network and the CIFAR-10 dataset with different types of corruptions introduced in Hendrycks & Dietterich (2019). For training individual SGD, SWA and SWAG models we use the hyper-parameters used for PreResNet-164 in Maddox et al. (2019). For each SWAG model we sample 20 networks and ensemble them. So Deep Ensembles, MultiSWA and MultiSWAG are all evaluated under the same training budget; Deep Ensembles and MultiSWA also use the same test-time budget.

For producing the corrupted data we used the code² released by Hendrycks & Dietterich (2019). We had issues producing the data for the frost corruption type, so we omit it in our evaluation, and include Gaussian blur, which was not included in the evaluation of Hendrycks & Dietterich (2019).

² https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_cifar_c.py



Figure 12. (a)–(c): Average pairwise prior correlations for pairs of objects in classes {0, 1, 2, 4, 7} of MNIST, induced by LeNet-5 for p(f(x; w)) when p(w) = N(0, α²I). Images in the same class have higher prior correlations than images from different classes, suggesting that p(f(x; w)) has desirable inductive biases. The correlations slightly decrease with increases in α. Panels (e)–(g) show sample functions from LeNet-5 along the direction connecting a pair of MNIST images of 0 and 1 digits. The complexity of the samples increases with α. (d): NLL and (h): classification error of an ensemble of 20 SWAG samples on MNIST as a function of α, using a LeNet-5. The NLL is high for overly small α and near-optimal for larger values, with an optimum near α = 0.3.

D.3. Neural Network Priors

In the main text we considered different properties of the prior distribution over functions induced by a spherical Gaussian distribution over the weights, with different variance scales.

Prior correlation diagrams. In panels (a)–(c) of Figure 12 we show pairwise correlations of the logits for different pairs of datapoints. To make these plots, we produce S = 100 samples of the weights w_i of a LeNet-5 from the prior distribution N(0, α²I), and compute the logits corresponding to class 0 for each data point and each weight sample. We then compute the correlations for each pair x, x′ of data points as follows:

corr_logit(x, x′) = [ Σ_{i=1}^S (f(x; w_i) − f̄(x)) (f(x′; w_i) − f̄(x′)) ] / √[ Σ_{i=1}^S (f(x; w_i) − f̄(x))² · Σ_{i=1}^S (f(x′; w_i) − f̄(x′))² ],

where f(x; w) is the logit corresponding to class 0 of the network with weights w on the input x, and f̄(x) is the mean value of the logit, f̄(x) = (1/S) Σ_i f(x; w_i). For evaluation, we use 200 random datapoints per class for classes 0, 1, 2, 4, 7 (a total of 1000 datapoints). We use this set of classes to ensure that the structure is clearly visible in the figure. We combine the correlations into a diagram, additionally showing the average correlation for each pair of classes. We repeat the experiment for different values of α ∈ {0.02, 0.1, 1}. For a discussion of the results see Section 5.2.
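A sketch of this computation is given below, assuming a LeNet-5-style `model` and a batch of MNIST `images` (hypothetical names); it implements the correlation formula above by sampling weights from the prior.

```python
import torch

def prior_logit_correlations(model, images, alpha=0.1, n_samples=100):
    """Correlations between class-0 logits on pairs of images under the prior
    w ~ N(0, alpha^2 I), estimated from n_samples prior draws (S = 100 above)."""
    logits = []
    with torch.no_grad():
        for _ in range(n_samples):
            for p in model.parameters():
                p.normal_(0.0, alpha)                 # sample w_i from the prior
            logits.append(model(images)[:, 0])        # class-0 logit for every image
    logits = torch.stack(logits)                      # [n_samples, n_images]
    centered = logits - logits.mean(dim=0, keepdim=True)
    cov = centered.t() @ centered
    std = centered.norm(dim=0)
    return cov / (std[:, None] * std[None, :])        # matrix of corr_logit(x, x')
```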


f(xαi = γαi βi = γiβini=1) =

γn middot f(x αi = αi βi = βini=1)(9)

can be shown completely analogously to Proposition 2

More general types of scaling of the prior affect both theoutput scale of the network and also the relative effect ofprior and variance terms For example by Equation (9) wehave

f(xαi = γ βi = γni=1) =

f(xαi = γ middot 1 βi = γi middot γ1minusini=1) =

γn middot f(x αi = 1 βi = γ1minusini=1)

We note that the analysis does not cover residual connectionsand batch normalization so it applies to LeNet-5 but cannotbe directly applied to PreResNet-20 networks used in manyof our experiments

F Prior Correlation Structure underPerturbations

In this section we explore the prior correlations between thelogits on different pairs of datapoints induced by a spherical

Gaussian prior on the weights of a PreResNet-20 We sam-ple a 100 random images from CIFAR-10 (10 from eachclass) and apply 17 different perturbations introduced byHendrycks amp Dietterich (2019) at 5 different levels of in-tensity We then compute correlations between the logitsf(xw) for the original image x and f(x w) for the cor-rupted image x as we sample the weights of the networkfrom the prior w sim N (0 I)

In Figure 19 we show how the correlations decay with per-turbation intensity For reference we also show how thecorrelations decay for a linear model and for an RBF kernelFor the RBF kernel we set the lengthscale so that the averagecorrelations on the uncorrupted datapoints match those of aPreResNet-20 Further experimental details can be found inAppendix D3

For all types of corruptions except saturate snow fog andbrightness the PreResNet logits decay slower compared tothe RBF kernel and linear model It appears that the priorsamples are sensitive to corruptions that alter the bright-ness or more generally the colours in the image For manytypes of corruptions (such as eg Gaussian Noise) the priorcorrelations for PreResNet are close to 1 for all levels ofcorruption

Overall these results indicate that the prior over functionsinduced by a vague prior over parameters w in combinationwith a PreResNet has useful equivariance properties beforeseeing data the model treats images of the same class ashighly correlated even after an image has undergone signif-icant perturbations representative of perturbations we oftensee in the real world These types of symmetries are a largepart of what makes neural networks a powerful model classfor high dimensional natural signals

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

1 2 3 4 5 6 7 8 9 10 Models

8000

9000

10000

11000

12000

13000

NLL

1 2 3 4 5 6 7 8 9 10 Models

17500

20000

22500

25000

27500

30000

1 2 3 4 5 6 7 8 9 10 Models

25000

30000

35000

40000

45000

1 2 3 4 5 6 7 8 9 10 Models

30000

35000

40000

45000

50000

1 2 3 4 5 6 7 8 9 10 Models

30000

35000

40000

45000

50000

55000

Deep Ensembles MultiSWA MultiSWAG

(a) Gaussian Noise

1 2 3 4 5 6 7 8 9 10 Models

5000

6000

7000

8000

9000

NLL

1 2 3 4 5 6 7 8 9 10 Models

10000

12000

14000

16000

1 2 3 4 5 6 7 8 9 10 Models

12000

14000

16000

18000

20000

1 2 3 4 5 6 7 8 9 10 Models

20000

22500

25000

27500

30000

32500

1 2 3 4 5 6 7 8 9 10 Models

25000

30000

35000

40000

45000

50000

Deep Ensembles MultiSWA MultiSWAG

(b) Impulse Noise

1 2 3 4 5 6 7 8 9 10 Models

5000

6000

7000

8000

NLL

1 2 3 4 5 6 7 8 9 10 Models

8000

10000

12000

14000

1 2 3 4 5 6 7 8 9 10 Models

20000

25000

30000

35000

1 2 3 4 5 6 7 8 9 10 Models

25000

30000

35000

40000

1 2 3 4 5 6 7 8 9 10 Models

30000

35000

40000

45000

50000

Deep Ensembles MultiSWA MultiSWAG

(c) Shot Noise

Figure 14 Noise Corruptions Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles MultiSWAG andMultiSWA as a function of the number of independently trained models for different types of corruption and corruption intensity(increasing from left to right)

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

1 2 3 4 5 6 7 8 9 10 Models

1600

1800

2000

2200

2400

2600

2800

NLL

1 2 3 4 5 6 7 8 9 10 Models

2000

2250

2500

2750

3000

3250

1 2 3 4 5 6 7 8 9 10 Models

3500

4000

4500

5000

5500

6000

1 2 3 4 5 6 7 8 9 10 Models

6000

7000

8000

9000

10000

11000

1 2 3 4 5 6 7 8 9 10 Models

140001600018000200002200024000

Deep Ensembles MultiSWA MultiSWAG

(a) Defocus Blur

1 2 3 4 5 6 7 8 9 10 Models

14000

16000

18000

20000

22000

24000

NLL

1 2 3 4 5 6 7 8 9 10 Models

14000

16000

18000

20000

22000

1 2 3 4 5 6 7 8 9 10 Models

12000

14000

16000

18000

20000

1 2 3 4 5 6 7 8 9 10 Models

18000

20000

22000

24000

26000

28000

30000

1 2 3 4 5 6 7 8 9 10 Models

18000

20000

22000

24000

26000

28000

Deep Ensembles MultiSWA MultiSWAG

(b) Glass Blur

1 2 3 4 5 6 7 8 9 10 Models

3000

3500

4000

4500

5000

NLL

1 2 3 4 5 6 7 8 9 10 Models

5000

6000

7000

8000

9000

1 2 3 4 5 6 7 8 9 10 Models

8000

10000

12000

14000

1 2 3 4 5 6 7 8 9 10 Models

8000

10000

12000

14000

1 2 3 4 5 6 7 8 9 10 Models

12000

14000

16000

18000

20000

Deep Ensembles MultiSWA MultiSWAG

(c) Motion Blur

1 2 3 4 5 6 7 8 9 10 Models

3500400045005000550060006500

NLL

1 2 3 4 5 6 7 8 9 10 Models

5000

6000

7000

8000

9000

1 2 3 4 5 6 7 8 9 10 Models

7000

8000

9000

10000

11000

12000

13000

1 2 3 4 5 6 7 8 9 10 Models

10000

12000

14000

16000

1 2 3 4 5 6 7 8 9 10 Models

12000

14000

16000

18000

20000

22000

24000

Deep Ensembles MultiSWA MultiSWAG

(d) Zoom Blur

1 2 3 4 5 6 7 8 9 10 Models

1600

1800

2000

2200

2400

2600

2800

NLL

1 2 3 4 5 6 7 8 9 10 Models

3500

4000

4500

5000

5500

6000

1 2 3 4 5 6 7 8 9 10 Models

7000

8000

9000

10000

11000

12000

1 2 3 4 5 6 7 8 9 10 Models

12000

14000

16000

18000

20000

22000

1 2 3 4 5 6 7 8 9 10 Models

20000

25000

30000

35000

40000

Deep Ensembles MultiSWA MultiSWAG

(e) Gaussian Blur

Figure 15 Blur Corruptions Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles MultiSWAG andMultiSWA as a function of the number of independently trained models for different types of corruption and corruption intensity(increasing from left to right)

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

1 2 3 4 5 6 7 8 9 10 Models

1750

2000

2250

2500

2750

3000

NLL

1 2 3 4 5 6 7 8 9 10 Models

3000

3500

4000

4500

5000

1 2 3 4 5 6 7 8 9 10 Models

4000

5000

6000

7000

1 2 3 4 5 6 7 8 9 10 Models

8000

10000

12000

14000

16000

1 2 3 4 5 6 7 8 9 10 Models

25000

30000

35000

40000

45000

50000

Deep Ensembles MultiSWA MultiSWAG

(a) Contrast

1 2 3 4 5 6 7 8 9 10 Models

2250

2500

2750

3000

3250

3500

3750

NLL

1 2 3 4 5 6 7 8 9 10 Models

3000

3500

4000

4500

1 2 3 4 5 6 7 8 9 10 Models

1750

2000

2250

2500

2750

1 2 3 4 5 6 7 8 9 10 Models

2500

3000

3500

4000

1 2 3 4 5 6 7 8 9 10 Models

3500

4000

4500

5000

5500

6000

6500

Deep Ensembles MultiSWA MultiSWAG

(b) Saturate

1 2 3 4 5 6 7 8 9 10 Models

3000

3500

4000

4500

5000

NLL

1 2 3 4 5 6 7 8 9 10 Models

3000

3500

4000

4500

5000

1 2 3 4 5 6 7 8 9 10 Models

4000

5000

6000

7000

1 2 3 4 5 6 7 8 9 10 Models

6000

7000

8000

9000

10000

11000

1 2 3 4 5 6 7 8 9 10 Models

7000

8000

9000

10000

11000

12000

Deep Ensembles MultiSWA MultiSWAG

(c) Elastic Transform

1 2 3 4 5 6 7 8 9 10 Models

2250

2500

2750

3000

3250

3500

3750

NLL

1 2 3 4 5 6 7 8 9 10 Models

4000

4500

5000

5500

6000

6500

1 2 3 4 5 6 7 8 9 10 Models

6000

7000

8000

9000

1 2 3 4 5 6 7 8 9 10 Models

14000

16000

18000

20000

22000

1 2 3 4 5 6 7 8 9 10 Models

24000

26000

28000

30000

32000

34000

Deep Ensembles MultiSWA MultiSWAG

(d) Pixelate

1 2 3 4 5 6 7 8 9 10 Models

4000

4500

5000

5500

6000

6500

NLL

1 2 3 4 5 6 7 8 9 10 Models

5000

6000

7000

8000

9000

1 2 3 4 5 6 7 8 9 10 Models

6000

7000

8000

9000

10000

1 2 3 4 5 6 7 8 9 10 Models

7000

8000

9000

10000

11000

1 2 3 4 5 6 7 8 9 10 Models

8000

9000

10000

11000

12000

13000

Deep Ensembles MultiSWA MultiSWAG

(e) JPEG Compression

Figure 16 Digital Corruptions Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles MultiSWAG andMultiSWA as a function of the number of independently trained models for different types of corruption and corruption intensity(increasing from left to right)

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

1 2 3 4 5 6 7 8 9 10 Models

3500

4000

4500

5000NL

L

1 2 3 4 5 6 7 8 9 10 Models

7000

8000

9000

10000

1 2 3 4 5 6 7 8 9 10 Models

5000

6000

7000

8000

1 2 3 4 5 6 7 8 9 10 Models

6000

7000

8000

9000

1 2 3 4 5 6 7 8 9 10 Models

8000

9000

10000

11000

12000

Deep Ensembles MultiSWA MultiSWAG

(a) Snow

1 2 3 4 5 6 7 8 9 10 Models

1600

1800

2000

2200

2400

2600

2800

NLL

1 2 3 4 5 6 7 8 9 10 Models

2000

2250

2500

2750

3000

3250

1 2 3 4 5 6 7 8 9 10 Models

2500

3000

3500

4000

1 2 3 4 5 6 7 8 9 10 Models

3500

4000

4500

5000

5500

6000

1 2 3 4 5 6 7 8 9 10 Models

8000

10000

12000

14000

16000

18000

Deep Ensembles MultiSWA MultiSWAG

(b) Fog

1 2 3 4 5 6 7 8 9 10 Models

1600

1800

2000

2200

2400

2600

2800

NLL

1 2 3 4 5 6 7 8 9 10 Models

1600180020002200240026002800

1 2 3 4 5 6 7 8 9 10 Models

1750

2000

2250

2500

2750

3000

3250

1 2 3 4 5 6 7 8 9 10 Models

2000225025002750300032503500

1 2 3 4 5 6 7 8 9 10 Models

3000

3500

4000

4500

Deep Ensembles MultiSWA MultiSWAG

(c) Brightness

Figure 17 Weather Corruptions Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles MultiSWAG andMultiSWA as a function of the number of independently trained models for different types of corruption and corruption intensity(increasing from left to right)

Figure 18. Negative log likelihood, accuracy, and expected calibration error distributions on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the corruption intensity. Following Ovadia et al. (2019), we summarize the results for different types of corruption with a boxplot. For each method, we use 10 independently trained models, and for MultiSWAG we sample 20 networks from each model. As in Figures 5, 11-14, there are substantial differences between these three methods, which are hard to see due to the vertical scale of this plot; MultiSWAG particularly outperforms Deep Ensembles and MultiSWA in terms of NLL and ECE for higher corruption intensities.

Figure 19. Prior correlations under corruption. Prior correlations between predictions (logits) of a PreResNet-20, a linear model, and an RBF kernel on original and corrupted images, as a function of corruption intensity, for different types of corruptions (translation, Gaussian noise, impulse noise, shot noise, defocus blur, glass blur, motion blur, zoom blur, Gaussian blur, contrast, saturate, elastic transform, pixelate, JPEG compression, snow, fog, and brightness). The lengthscale of the RBF kernel is calibrated to produce similar correlations to PreResNet on uncorrupted datapoints. We report the mean correlation values over 100 different images and show the 1σ error bars with shaded regions. For all corruptions except snow, saturate, fog, and brightness, the correlations decay more slowly for PreResNet than for the baselines.


Chaudhari, P., Choromanska, A., Soatto, S., LeCun, Y., Baldassi, C., Borgs, C., Chayes, J., Sagun, L., and Zecchina, R. Entropy-SGD: Biasing gradient descent into wide valleys. Journal of Statistical Mechanics: Theory and Experiment, 2019.
Cobb, A. D., Baydin, A. G., Markham, A., and Roberts, S. J. Introducing an explicit symplectic integration scheme for Riemannian manifold Hamiltonian Monte Carlo. arXiv preprint arXiv:1910.06243, 2019.
Darnieder, W. F. Bayesian methods for data-dependent priors. PhD thesis, The Ohio State University, 2011.
de Heide, R., Kirichenko, A., Mehta, N., and Grunwald, P. Safe-Bayesian generalized linear regression. arXiv preprint arXiv:1910.09227, 2019.
Dinh, L., Pascanu, R., Bengio, S., and Bengio, Y. Sharp minima can generalize for deep nets. arXiv preprint arXiv:1703.04933, 2017.
Draxler, F., Veschgini, K., Salmhofer, M., and Hamprecht, F. A. Essentially no barriers in neural network energy landscape. arXiv preprint arXiv:1803.00885, 2018.
Dziugaite, G. K. and Roy, D. M. Computing nonvacuous generalization bounds for deep (stochastic) neural networks with many more parameters than training data. arXiv preprint arXiv:1703.11008, 2017.
Fort, S. and Jastrzebski, S. Large scale structure of neural network loss landscapes. In Advances in Neural Information Processing Systems, pp. 6706–6714, 2019.
Fort, S., Hu, H., and Lakshminarayanan, B. Deep ensembles: A loss landscape perspective. arXiv preprint arXiv:1912.02757, 2019.
Gal, Y. and Ghahramani, Z. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. In International Conference on Machine Learning, pp. 1050–1059, 2016.
Gardner, J. R., Pleiss, G., Bindel, D., Weinberger, K. Q., and Wilson, A. G. GPyTorch: Blackbox matrix-matrix Gaussian process inference with GPU acceleration. In Neural Information Processing Systems, 2018.
Garipov, T., Izmailov, P., Podoprikhin, D., Vetrov, D. P., and Wilson, A. G. Loss surfaces, mode connectivity, and fast ensembling of DNNs. In Neural Information Processing Systems, 2018.
Grunwald, P. The safe Bayesian. In International Conference on Algorithmic Learning Theory, pp. 169–183. Springer, 2012.
Grunwald, P., Van Ommen, T., et al. Inconsistency of Bayesian inference for misspecified linear models, and a proposal for repairing it. Bayesian Analysis, 12(4):1069–1103, 2017.
Guedj, B. A primer on PAC-Bayesian learning. arXiv preprint arXiv:1901.05353, 2019.
Guo, C., Pleiss, G., Sun, Y., and Weinberger, K. Q. On calibration of modern neural networks. In Proceedings of the 34th International Conference on Machine Learning - Volume 70, pp. 1321–1330. JMLR.org, 2017.
Gustafsson, F. K., Danelljan, M., and Schon, T. B. Evaluating scalable Bayesian deep learning methods for robust computer vision. arXiv preprint arXiv:1906.01620, 2019.
Hafner, D., Tran, D., Irpan, A., Lillicrap, T., and Davidson, J. Reliable uncertainty estimates in deep neural networks using noise contrastive priors. arXiv preprint arXiv:1807.09289, 2018.
He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778, 2016.
Hendrycks, D. and Dietterich, T. Benchmarking neural network robustness to common corruptions and perturbations. arXiv preprint arXiv:1903.12261, 2019.
Hinton, G. E. and Van Camp, D. Keeping the neural networks simple by minimizing the description length of the weights. In Proceedings of the Sixth Annual Conference on Computational Learning Theory, pp. 5–13, 1993.
Hochreiter, S. and Schmidhuber, J. Flat minima. Neural Computation, 9(1):1–42, 1997.
Huang, W. R., Emam, Z., Goldblum, M., Fowl, L., Terry, J. K., Huang, F., and Goldstein, T. Understanding generalization through visualizations. arXiv preprint arXiv:1906.03291, 2019.
Izmailov, P., Podoprikhin, D., Garipov, T., Vetrov, D., and Wilson, A. G. Averaging weights leads to wider optima and better generalization. In Uncertainty in Artificial Intelligence (UAI), 2018.
Izmailov, P., Maddox, W. J., Kirichenko, P., Garipov, T., Vetrov, D., and Wilson, A. G. Subspace inference for Bayesian deep learning. Uncertainty in Artificial Intelligence, 2019.
Jacot, A., Gabriel, F., and Hongler, C. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in Neural Information Processing Systems, pp. 8571–8580, 2018.
Jiang, Y., Neyshabur, B., Mobahi, H., Krishnan, D., and Bengio, S. Fantastic generalization measures and where to find them. arXiv preprint arXiv:1912.02178, 2019.
Kass, R. E. and Raftery, A. E. Bayes factors. Journal of the American Statistical Association, 90(430):773–795, 1995.
Kendall, A. and Gal, Y. What uncertainties do we need in Bayesian deep learning for computer vision? In Advances in Neural Information Processing Systems, pp. 5574–5584, 2017.
Keskar, N. S., Mudigere, D., Nocedal, J., Smelyanskiy, M., and Tang, P. T. P. On large-batch training for deep learning: Generalization gap and sharp minima. arXiv preprint arXiv:1609.04836, 2016.
Khan, M. E., Nielsen, D., Tangkaratt, V., Lin, W., Gal, Y., and Srivastava, A. Fast and scalable Bayesian deep learning by weight-perturbation in Adam. arXiv preprint arXiv:1806.04854, 2018.
Krizhevsky, A., Nair, V., and Hinton, G. The CIFAR-10 dataset, 2014. http://www.cs.toronto.edu/~kriz/cifar.html.
Lakshminarayanan, B., Pritzel, A., and Blundell, C. Simple and scalable predictive uncertainty estimation using deep ensembles. In Advances in Neural Information Processing Systems, pp. 6402–6413, 2017.
Langford, J. and Caruana, R. (Not) bounding the true error. In Advances in Neural Information Processing Systems, pp. 809–816, 2002.
LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.
Louizos, C., Shi, X., Schutte, K., and Welling, M. The functional neural process. In Advances in Neural Information Processing Systems, 2019.
MacKay, D. J. Bayesian methods for adaptive models. PhD thesis, California Institute of Technology, 1992.
MacKay, D. J. Probable networks and plausible predictions: a review of practical Bayesian methods for supervised neural networks. Network: Computation in Neural Systems, 6(3):469–505, 1995.
MacKay, D. J. Introduction to Gaussian processes. In Bishop, C. M. (ed.), Neural Networks and Machine Learning, chapter 11, pp. 133–165. Springer-Verlag, 1998.
MacKay, D. J. Information Theory, Inference and Learning Algorithms. Cambridge University Press, 2003.
Maddox, W. J., Izmailov, P., Garipov, T., Vetrov, D. P., and Wilson, A. G. A simple baseline for Bayesian uncertainty in deep learning. In Advances in Neural Information Processing Systems, 2019.
Masegosa, A. R. Learning under model misspecification: Applications to variational and ensemble methods. 2019.
McAllester, D. A. PAC-Bayesian model averaging. In Proceedings of the Twelfth Annual Conference on Computational Learning Theory, pp. 164–170, 1999.
Minka, T. Expectation propagation for approximate Bayesian inference. In Uncertainty in Artificial Intelligence, volume 17, pp. 362–369, 2001a.
Minka, T. P. Bayesian model averaging is not model combination. 2000.
Minka, T. P. Automatic choice of dimensionality for PCA. In Advances in Neural Information Processing Systems, pp. 598–604, 2001b.
Mohri, M. and Rostamizadeh, A. Rademacher complexity bounds for non-i.i.d. processes. In Advances in Neural Information Processing Systems, pp. 1097–1104, 2009.
Nakkiran, P., Kaplun, G., Bansal, Y., Yang, T., Barak, B., and Sutskever, I. Deep double descent: Where bigger models and more data hurt. arXiv preprint arXiv:1912.02292, 2019.
Nakkiran, P., Venkat, P., Kakade, S., and Ma, T. Optimal regularization can mitigate double descent. arXiv preprint arXiv:2003.01897, 2020.
Nalisnick, E. On priors for Bayesian neural networks. PhD thesis, University of California, Irvine, 2018.
Neal, R. Bayesian Learning for Neural Networks. Springer Verlag, 1996. ISBN 0387947248.
Neyshabur, B., Bhojanapalli, S., McAllester, D., and Srebro, N. Exploring generalization in deep learning. In Advances in Neural Information Processing Systems, pp. 5947–5956, 2017.
Neyshabur, B., Bhojanapalli, S., and Srebro, N. A PAC-Bayesian approach to spectrally-normalized margin bounds for neural networks. In International Conference on Learning Representations, 2018. URL https://openreview.net/forum?id=Skz_WfbCZ.
Opper, M., Kinzel, W., Kleinz, J., and Nehl, R. On the ability of the optimal perceptron to generalise. Journal of Physics A: Mathematical and General, 23(11):L581, 1990.
Ovadia, Y., Fertig, E., Ren, J., Nado, Z., Sculley, D., Nowozin, S., Dillon, J. V., Lakshminarayanan, B., and Snoek, J. Can you trust your model's uncertainty? Evaluating predictive uncertainty under dataset shift. arXiv preprint arXiv:1906.02530, 2019.
Pearce, T., Zaki, M., Brintrup, A., Anastassacos, N., and Neely, A. Uncertainty in neural networks: Bayesian ensembling. arXiv preprint arXiv:1810.05546, 2018.
Rasmussen, C. E. and Ghahramani, Z. Occam's razor. In Neural Information Processing Systems (NIPS), 2001.
Rasmussen, C. E. and Williams, C. K. I. Gaussian Processes for Machine Learning. The MIT Press, 2006.
Ritter, H., Botev, A., and Barber, D. A scalable Laplace approximation for neural networks. In International Conference on Learning Representations (ICLR), 2018.
Rue, H., Martino, S., and Chopin, N. Approximate Bayesian inference for latent Gaussian models by using integrated nested Laplace approximations. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 71(2):319–392, 2009.
Simonyan, K. and Zisserman, A. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.
Smith, S. L. and Le, Q. V. A Bayesian perspective on generalization and stochastic gradient descent. In International Conference on Learning Representations, 2018.
Sun, S., Zhang, G., Shi, J., and Grosse, R. Functional variational Bayesian neural networks. arXiv preprint arXiv:1903.05779, 2019.
Ulyanov, D., Vedaldi, A., and Lempitsky, V. Deep image prior. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 9446–9454, 2018.
Vapnik, V. N. Adaptive and learning systems for signal processing, communications, and control. Statistical Learning Theory, 1998.
Virtanen, P., Gommers, R., Oliphant, T. E., Haberland, M., Reddy, T., Cournapeau, D., Burovski, E., Peterson, P., Weckesser, W., Bright, J., et al. SciPy 1.0: Fundamental algorithms for scientific computing in Python. Nature Methods, pp. 1–12, 2020.
Walker, S. and Hjort, N. L. On Bayesian consistency. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 63(4):811–821, 2001.
Welling, M. and Teh, Y. W. Bayesian learning via stochastic gradient Langevin dynamics. In Proceedings of the 28th International Conference on Machine Learning (ICML-11), pp. 681–688, 2011.
Wenzel, F., Roth, K., Veeling, B. S., Swiatkowski, J., Tran, L., Mandt, S., Snoek, J., Salimans, T., Jenatton, R., and Nowozin, S. How good is the Bayes posterior in deep neural networks really? arXiv preprint arXiv:2002.02405, 2020.
Wilson, A. G. The case for Bayesian deep learning. arXiv preprint arXiv:2001.10995, 2020.
Xu, X., Lu, P., MacEachern, S., and Xu, R. Calibrated Bayes factors for model comparison. Journal of Statistical Computation and Simulation, 89(4):591–614, 2019.
Yang, W., Lorch, L., Graule, M. A., Srinivasan, S., Suresh, A., Yao, J., Pradier, M. F., and Doshi-Velez, F. Output-constrained Bayesian neural networks. arXiv preprint arXiv:1905.06287, 2019.
Zhang, C., Bengio, S., Hardt, M., Recht, B., and Vinyals, O. Understanding deep learning requires rethinking generalization. arXiv preprint arXiv:1611.03530, 2016.
Zhang, T. Information-theoretic upper and lower bounds for statistical estimation. IEEE Transactions on Information Theory, 52(4):1307–1321, 2006.


Appendix Outline

This appendix is organized as follows. In Section A we visualize predictive functions corresponding to weight samples within high posterior density valleys on a regression problem. In Section B we provide background material on Gaussian processes. In Section C we present further results comparing MultiSWAG and MultiSWA to Deep Ensembles under data distribution shift on CIFAR-10. In Section D we provide the details of all experiments presented in the paper. In Section E we present analytic results on the dependence of the prior distribution in function space on the variance of the prior over parameters. In Section F we study the prior correlations between BNN logits on perturbed images.

A. Loss Valleys

We demonstrate that different points along the valleys of high posterior density (low loss) connecting pairs of independently trained optima (Garipov et al., 2018; Draxler et al., 2018; Fort & Jastrzebski, 2019) correspond to different predictive functions. We use the regression example from Izmailov et al. (2019) and show the results in Figure 11.

B. Gaussian Processes

With a Bayesian neural network, a distribution over parameters p(w) induces a distribution over functions p(f(x; w)) when combined with the functional form of the network. Gaussian processes (GPs) are often used to instead directly specify a distribution over functions.

A Gaussian process is a distribution over functions, f(x) ∼ GP(m, k), such that any collection of function values queried at any finite set of inputs x_1, …, x_n has a joint Gaussian distribution,

[f(x_1), …, f(x_n)] ∼ N(µ, K).   (6)

The mean vector, µ_i = E[f(x_i)] = m(x_i), and covariance matrix, K_ij = cov(f(x_i), f(x_j)) = k(x_i, x_j), are determined by the mean function m and covariance function (or kernel) k of the Gaussian process.

The popular RBF kernel has the form

k(x_i, x_j) = exp( −‖x_i − x_j‖² / (2ℓ²) ).   (7)

The length-scale hyperparameter ℓ controls the extent of correlations between function values. If ℓ is large, sample functions from a GP prior are simple and slowly varying with the inputs x.
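As a concrete illustration of this length-scale effect, the short sketch below (illustrative only; the grid, jitter, and length-scale values are arbitrary choices rather than settings used in our experiments) draws sample functions from a zero-mean GP prior with the RBF kernel of Equation (7):

    import numpy as np

    def rbf_kernel(x1, x2, lengthscale):
        # k(x_i, x_j) = exp(-||x_i - x_j||^2 / (2 * lengthscale^2)), Equation (7)
        d2 = (x1[:, None] - x2[None, :]) ** 2
        return np.exp(-0.5 * d2 / lengthscale ** 2)

    x = np.linspace(-5.0, 5.0, 200)          # query locations
    for ell in [0.5, 2.0]:                   # short vs. long length-scale
        K = rbf_kernel(x, x, ell) + 1e-8 * np.eye(len(x))   # jitter for numerical stability
        # draw three prior sample functions f ~ N(0, K)
        samples = np.random.multivariate_normal(np.zeros(len(x)), K, size=3)
        avg_increment = np.abs(np.diff(samples, axis=1)).mean()
        print(ell, avg_increment)            # larger lengthscale -> smaller increments (smoother samples)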

Gaussian processes with RBF kernels (as well as many other standard kernels) assign positive density to any set of observations. Moreover, these models are universal approximators (Rasmussen & Williams, 2006): as the number of observations increases, they are able to approximate any function to arbitrary precision.

Work on Gaussian processes in machine learning was triggered by the observation that Bayesian neural networks become Gaussian processes with particular kernel functions as the number of hidden units approaches infinity (Neal, 1996). This result resembles recent work on the neural tangent kernel (e.g., Jacot et al., 2018).

C. Deep Ensembles and MultiSWAG under Distribution Shift

In Figures 14, 15, 16 and 17 we show the negative log-likelihood for Deep Ensembles, MultiSWA and MultiSWAG using PreResNet-20 on CIFAR-10 with various corruptions, as a function of the number of independently trained models (SGD solutions, SWA solutions or SWAG models, respectively). For MultiSWAG, we generate 20 samples from each independent SWAG model. Typically MultiSWA and MultiSWAG significantly outperform Deep Ensembles when a small number of independent models is used, or when the level of corruption is high.

In Figure 18, following Ovadia et al. (2019), we show the distribution of negative log likelihood, accuracy and expected calibration error as we vary the type of corruption. We use a fixed training time budget: 10 independently trained models for every method. For MultiSWAG we ensemble 20 samples from each of the 10 SWAG approximations. MultiSWAG particularly achieves better NLL than the other two methods, and MultiSWA outperforms Deep Ensembles; the difference is especially pronounced for higher levels of corruption. In terms of ECE, MultiSWAG again outperforms the other two methods for higher corruption intensities.
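For reference, the expected calibration error we report is the standard binned estimate (Guo et al., 2017); a minimal sketch of that definition (the bin count and inputs below are illustrative) is:

    import numpy as np

    def expected_calibration_error(probs, labels, n_bins=20):
        # probs: (N, C) predicted class probabilities; labels: (N,) true class indices
        confidences = probs.max(axis=1)
        predictions = probs.argmax(axis=1)
        accuracies = (predictions == labels).astype(float)
        bins = np.linspace(0.0, 1.0, n_bins + 1)
        ece = 0.0
        for lo, hi in zip(bins[:-1], bins[1:]):
            mask = (confidences > lo) & (confidences <= hi)
            if mask.any():
                # |average confidence - accuracy| weighted by the fraction of points in the bin
                ece += mask.mean() * abs(confidences[mask].mean() - accuracies[mask].mean())
        return ece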

We note that Ovadia et al. (2019) found Deep Ensembles to be a very strong baseline for prediction quality and calibration under distribution shift. For this reason, we focus on Deep Ensembles in our comparisons.

D. Details of Experiments

In this section we provide additional details of the experiments presented in the paper.

D.1. Approximating the True Predictive Distribution

For the results presented in Figure 4, we used a network with 3 hidden layers of size 10 each. The network takes two inputs: x and x². We pass both x and x² as input to ensure that the network can represent a broader class of functions. The network outputs a single number, y = f(x).

Figure 11. Diversity of high performing functions. Bottom: a contour plot of the posterior log-density in the subspace containing a pair of independently trained modes (as with deep ensembles) and a path of high posterior density connecting these modes; in each panel the purple point represents a sample from the posterior in the parameter subspace. Top: the predictive distribution constructed from samples in the subspace; the shaded blue area shows the 3σ-region of the predictive distribution at each of the input locations, and the blue line shows the mean of the predictive distribution; in each panel, the purple line shows the predictive function corresponding to the sample shown in the corresponding bottom-row panel. For the details of the experimental setup, see Section 5.1 of Izmailov et al. (2019).

To generate data for the plots, we used a randomly-initialized neural network of the same architecture described above. We sampled the weights from an isotropic Gaussian with variance 0.1², and added isotropic Gaussian noise with variance 0.1² to the outputs:

y = f(x; w) + ε(x),

with w ∼ N(0, 0.1²·I), ε(x) ∼ N(0, 0.1²·I). The training set consists of 120 points, shown in Figure 4.

For estimating the ground truth, we ran 200 chains of Hamiltonian Monte Carlo (HMC) using the hamiltorch package (Cobb et al., 2019). We initialized each chain with a network pre-trained with SGD for 3000 steps, then ran HMC for 2000 steps, producing 200 samples.

For Deep Ensembles, we independently trained 50 networks with SGD for 20000 steps each. We used the negative posterior log-density as the training loss. For SVI, we used a fully-factorized Gaussian approximation initialized at an SGD solution trained for 20000 steps. For all inference methods, we set the prior variance to 10² and the noise variance to 0.02².

Discrepancy with true BMA. For the results presented in panel (d) of Figure 4, we computed the Wasserstein distance between the predictive distribution approximated with HMC and the predictive distribution for Deep Ensembles and SVI. We used the one-dimensional Wasserstein distance function¹ from the scipy package (Virtanen et al., 2020). We computed the Wasserstein distance between the marginal distributions at each input location and averaged the results over the input locations. In the top sub-panels of panels (b), (c) of Figure 4, we additionally visualize the marginal Wasserstein distance between the HMC predictive distribution and the Deep Ensembles and SVI predictive distributions, respectively, for each input location.

¹https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wasserstein_distance.html
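A minimal sketch of this averaged marginal Wasserstein computation (array names such as hmc_predictions are placeholders; predictions are stored as [number of samples, number of input locations]):

    import numpy as np
    from scipy.stats import wasserstein_distance

    def mean_marginal_wasserstein(pred_a, pred_b):
        # pred_a, pred_b: arrays of shape (num_samples, num_inputs) holding predictive
        # samples at each input location; compute the 1-D Wasserstein distance between
        # the marginal distributions at every input location, then average.
        distances = [wasserstein_distance(pred_a[:, i], pred_b[:, i])
                     for i in range(pred_a.shape[1])]
        return float(np.mean(distances)), np.array(distances)

    # e.g. mean_dist, per_input = mean_marginal_wasserstein(hmc_predictions, ensemble_predictions)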

D.2. Deep Ensembles and MultiSWAG

We evaluate Deep Ensembles, MultiSWA and MultiSWAG under distribution shift in Section 4. Following Ovadia et al. (2019), we use a PreResNet-20 network and the CIFAR-10 dataset with different types of corruptions introduced in Hendrycks & Dietterich (2019). For training individual SGD, SWA and SWAG models, we use the hyper-parameters used for PreResNet-164 in Maddox et al. (2019). For each SWAG model we sample 20 networks and ensemble them. So Deep Ensembles, MultiSWA and MultiSWAG are all evaluated under the same training budget; Deep Ensembles and MultiSWA also use the same test-time budget.
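Schematically, the MultiSWAG prediction is a uniform average of softmax outputs over all sampled networks. The sketch below assumes a hypothetical swag.sample() interface that returns one network per call; it is an illustration of the averaging, not the released training or evaluation code:

    import torch

    @torch.no_grad()
    def multiswag_predict(swag_models, x, num_samples=20):
        # swag_models: list of independently trained SWAG posterior approximations
        # (assumed to expose .sample() returning a network); x: a batch of inputs.
        probs, total = 0.0, 0
        for swag in swag_models:
            for _ in range(num_samples):
                net = swag.sample()       # draw one network from this SWAG solution
                net.eval()
                probs = probs + torch.softmax(net(x), dim=-1)
                total += 1
        return probs / total              # uniform average over all sampled networks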

For producing the corrupted data, we used the code² released by Hendrycks & Dietterich (2019). We had issues producing the data for the frost corruption type, so we omit it in our evaluation, and include Gaussian blur, which was not included in the evaluation of Hendrycks & Dietterich (2019).

²https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_cifar_c.py

Figure 12. (a)–(c): Average pairwise prior correlations for pairs of objects in classes 0, 1, 2, 4, 7 of MNIST induced by LeNet-5 for p(f(x; w)) when p(w) = N(0, α²I), for α = 0.02, 0.1, 1. Images in the same class have higher prior correlations than images from different classes, suggesting that p(f(x; w)) has desirable inductive biases; the correlations slightly decrease with increases in α. (e)–(g): Sample functions from LeNet-5 along the direction connecting a pair of MNIST images of 0 and 1 digits, for the same values of α; the complexity of the samples increases with α. (d) NLL and (h) classification error of an ensemble of 20 SWAG samples on MNIST as a function of α, using LeNet-5. The NLL is high for overly small α and near-optimal for larger values, with an optimum near α = 0.3.

D.3. Neural Network Priors

In the main text, we considered different properties of the prior distribution over functions induced by a spherical Gaussian distribution over the weights, with different variance scales.

Prior correlation diagrams. In panels (a)–(c) of Figure 12 we show pairwise correlations of the logits for different pairs of datapoints. To make these plots, we produce S = 100 samples of the weights w_i of a LeNet-5 from the prior distribution N(0, α²I), and compute the logits corresponding to class 0 for each data point and each weight sample. We then compute the correlations for each pair x, x′ of data points as follows:

corr_logit(x, x′) = Σ_{i=1}^S (f(x; w_i) − f̄(x)) (f(x′; w_i) − f̄(x′)) / √( Σ_{i=1}^S (f(x; w_i) − f̄(x))² · Σ_{i=1}^S (f(x′; w_i) − f̄(x′))² ),

where f(x; w) is the logit corresponding to class 0 of the network with weights w on the input x, and f̄(x) is the mean value of the logit, f̄(x) = (1/S) Σ_i f(x; w_i). For evaluation we use 200 random datapoints per class for classes 0, 1, 2, 4, 7 (a total of 1000 datapoints). We use this set of classes to ensure that the structure is clearly visible in the figure. We combine the correlations into a diagram, additionally showing the average correlation for each pair of classes. We repeat the experiment for different values of α ∈ {0.02, 0.1, 1}. For a discussion of the results, see Section 5.2.
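A minimal Monte Carlo sketch of this estimate (make_net is a placeholder constructor for the untrained network; x1 and x2 are single-image batches, and we simply reinitialize all parameters from N(0, α²)):

    import numpy as np
    import torch

    @torch.no_grad()
    def prior_logit_correlation(make_net, x1, x2, alpha=1.0, num_samples=100, class_idx=0):
        # Sample weights w_i ~ N(0, alpha^2 I) by reinitializing the network, record the
        # class-0 logit on both inputs, and compute the empirical correlation.
        f1, f2 = [], []
        for _ in range(num_samples):
            net = make_net()
            for p in net.parameters():
                p.normal_(0.0, alpha)      # spherical Gaussian prior over all parameters
            f1.append(net(x1)[0, class_idx].item())
            f2.append(net(x2)[0, class_idx].item())
        return float(np.corrcoef(np.array(f1), np.array(f2))[0, 1])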

Sample functions. In panels (e)–(g) of Figure 12 we visualize the functions sampled from the LeNet-5 network along the direction connecting a pair of MNIST images. In particular, we take a pair of images x_0 and x_1 of digits 0 and 1, respectively, and construct the path x(t) = t·x_0 + (1 − t)·x_1. We then study the samples of the logits z(t) = f(x(t)·‖x_0‖/‖x(t)‖; w) along the path; here we adjusted the norm of the images along the path to be constant, as the values of the logits are sensitive to the norm of the inputs. The complexity of the samples increases as we increase the variance of the prior distribution over the weights. This increased complexity of sample functions explains why we might expect the prior correlations for pairs of images to be lower when we increase the variance of the prior distribution over the weights.

Performance dependence on prior variance. In panels (d), (h) of Figure 12 we show the test negative log-likelihood and accuracy of SWAG applied to LeNet-5 on MNIST. We train the model for 50 epochs, constructing the rank-20 SWAG approximation from the last 25 epochs. We use an initial learning rate of 0.05 and a SWAG learning rate of 0.01, with the learning rate schedule of Maddox et al. (2019). We use the posterior log-density as the objective and vary the prior variance α². In panels (f), (g) of Figure 9 we perform an analogous experiment using a PreResNet-20 and a VGG-16 on CIFAR-10, using the hyper-parameters reported in Maddox et al. (2019) (for PreResNet-20 we use the hyper-parameters used with PreResNet-164 in Maddox et al. (2019)). Both on MNIST and CIFAR-10 we observe that the performance is poor for overly small values of α, close to optimal for intermediate values, and still reasonable for larger values of α. For further discussion of the results, see Section 5.3.

Predictions from prior samples. Following Wenzel et al. (2020), we study the predictive distributions of prior samples using PreResNet-20 on CIFAR-10. In Figure 9 we show the sample predictive functions averaged over datapoints, for different scales α of the prior distribution. We also show the predictive distribution for each α, which is the average of the sample predictive distributions over 200 samples of the weights. In Figure 10 we show how the predictive distribution changes as we vary the number of observed data, for prior scale α = √10. We see that the marginal predictive distribution for all considered values of α is reasonable: roughly uniform across classes when averaged across the dataset. For the latter experiment we used stochastic gradient Langevin dynamics (SGLD) (Welling & Teh, 2011) with a cosine learning rate schedule. For each sample we restart SGLD, and we only use the sample obtained at the last iteration. We discuss the results in Section 8.4.
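For reference, a single SGLD update with a cosine schedule can be sketched as follows (the maximum step size and the exact schedule form are illustrative, not the values used in our runs):

    import math
    import torch

    def sgld_step(params, grads_log_posterior, step, total_steps, lr_max=1e-5):
        # One SGLD update (Welling & Teh, 2011): w <- w + (lr/2) * grad log p(w|D) + N(0, lr),
        # with a cosine-annealed step size.
        lr = 0.5 * lr_max * (1.0 + math.cos(math.pi * step / total_steps))
        with torch.no_grad():
            for p, g in zip(params, grads_log_posterior):
                p.add_(0.5 * lr * g + math.sqrt(lr) * torch.randn_like(p))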

Prior correlations with corrupted images. In Section F and Figure 19 we study the decay of the prior correlations between logits on an original image and a perturbed image as we increase the intensity of the perturbations. For the BNN, we use the PreResNet-20 architecture with the standard Gaussian prior N(0, I). For the linear model, the correlations are not affected by the prior variance α²:

cov(wᵀx, wᵀy) = E(wᵀx · wᵀy) = E[xᵀw wᵀy] = xᵀ E[w wᵀ] y = α² xᵀy,

and hence

corr(wᵀx, wᵀy) = cov(wᵀx, wᵀy) / √( cov(wᵀx, wᵀx) · cov(wᵀy, wᵀy) ) = xᵀy / (‖x‖ ‖y‖).

We use the N(0, I) prior for the weights of the linear model. Finally, we also evaluate the correlations associated with an RBF kernel (see Equation (7)). To set the lengthscale ℓ of the kernel, we evaluate the pairwise correlations for the PreResNet-20 and the RBF kernel on the 100 uncorrupted CIFAR-10 images that were used for the experiment, and ensure that the average correlations match. The resulting value of ℓ is 10000, and the average correlation for the RBF kernel and PreResNet was ≈ 0.9; for the linear model the average correlation was ≈ 0.82. For the perturbations, we used the same set of corruptions introduced in Hendrycks & Dietterich (2019) as in the experiments in Section 4, with the addition of a random translation: for a random translation of intensity i, we pad the image with 2·i zeros on each side and crop the image randomly to 32×32.
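This identity is easy to check numerically; a minimal sketch (with arbitrary random vectors standing in for images) is:

    import numpy as np

    rng = np.random.default_rng(0)
    x, y = rng.normal(size=784), rng.normal(size=784)   # two arbitrary inputs
    w = rng.normal(size=(100000, 784))                  # weight samples w ~ N(0, I)
    empirical = np.corrcoef(w @ x, w @ y)[0, 1]
    analytic = x @ y / (np.linalg.norm(x) * np.linalg.norm(y))
    print(empirical, analytic)   # the two values agree up to Monte Carlo error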

D.4. Rethinking Generalization

In Section 6 we experiment with Bayesian neural networks and Gaussian processes on CIFAR-10 with noisy labels, inspired by the results in Zhang et al. (2016) that suggest we need to re-think generalization to understand deep learning.

Following Zhang et al. (2016), we train PreResNet-20 on CIFAR-10 with different fractions of random labels. To ensure that the networks fit the train data, we turn off weight decay and data augmentation, and use a lower initial learning rate of 0.01. Otherwise, we follow the hyper-parameters that were used with PreResNet-164 in Maddox et al. (2019). We use a diagonal Laplace approximation to compute an estimate of the marginal likelihood for each level of label corruption. Following Ritter et al. (2018), we use the diagonal of the Fisher information matrix rather than the Hessian.
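A sketch of this estimate (a standard form of the Laplace evidence approximation, with the diagonal Fisher standing in for the Hessian; variable names are placeholders and this is not the exact code used for the experiments):

    import numpy as np

    def diag_laplace_log_evidence(log_lik_map, log_prior_map, fisher_diag, prior_precision):
        # Laplace approximation around the MAP solution w*:
        #   log p(D) ~= log p(D | w*) + log p(w*) + (d/2) log(2*pi) - (1/2) log det(H),
        # with H approximated by diag(F + prior_precision), F the diagonal Fisher
        # (used in place of the Hessian, following Ritter et al., 2018).
        h = np.asarray(fisher_diag) + prior_precision
        d = h.size
        return (log_lik_map + log_prior_map
                + 0.5 * d * np.log(2.0 * np.pi) - 0.5 * np.sum(np.log(h)))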

We perform a similar experiment with a Gaussian process with an RBF kernel on the binary classification problem for two classes of CIFAR-10. We use variational inference to fit the model, and we use the variational evidence lower bound to approximate the marginal likelihood. We use variational inference to overcome the non-Gaussian likelihood, and not for scalability reasons; i.e., we are not using inducing inputs. We use the GPyTorch package (Gardner et al., 2018) to train the models. We use an RBF kernel with the default initialization from GPyTorch, and divide the inputs by 5000 to get an appropriate input scale. We train the model on a binary classification problem between classes 0 and 1.
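A compact sketch of this setup using standard GPyTorch components (the exact optimization settings and initialization may differ from those used in our experiments; the variational strategy is placed on the full training set, so no inducing-point approximation is introduced):

    import torch
    import gpytorch

    class GPClassifier(gpytorch.models.ApproximateGP):
        def __init__(self, train_x):
            # Variational distribution over the function values at all training inputs.
            var_dist = gpytorch.variational.CholeskyVariationalDistribution(train_x.size(0))
            strategy = gpytorch.variational.VariationalStrategy(
                self, train_x, var_dist, learn_inducing_locations=False)
            super().__init__(strategy)
            self.mean_module = gpytorch.means.ConstantMean()
            self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

        def forward(self, x):
            return gpytorch.distributions.MultivariateNormal(
                self.mean_module(x), self.covar_module(x))

    # train_x: flattened images divided by 5000; train_y: binary labels in {0, 1}
    # model = GPClassifier(train_x)
    # likelihood = gpytorch.likelihoods.BernoulliLikelihood()
    # mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.numel())
    # Maximize mll(model(train_x), train_y) with a standard optimizer; the resulting
    # ELBO is the marginal-likelihood approximation described above.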

For the 10-class GP classification experiment, we train 10 one-vs-all models that classify between a given class and the rest of the data. To reduce computation in training, we subsample the data not belonging to the given class to 10k datapoints, so each model is trained on a total of 15k datapoints. We then combine the 10 models into a single multi-class model: an observation is attributed to the class that corresponds to the one-vs-all model with the highest confidence. We use the same hyper-parameters as in the binary classification experiments.

D.5. Double Descent

In Section 7 we evaluate SGD, SWAG and MultiSWAG for models of varying width. Following Nakkiran et al. (2019), we use ResNet-18 on CIFAR-100; we consider the original labels, and 10% and 20% label corruption. For networks of every width, we reuse the hyper-parameters used for PreResNet-164 in Maddox et al. (2019). For the original labels and 10% label corruption we use 5 independently trained SWAG models with MultiSWAG, and for 20% label corruption we use 10 models; for 20% label corruption we also show performance varying the number of independent models in Figures 8(e) and 13(c). Both for SWAG and MultiSWAG we use an ensemble of 20 sampled models from each of the SWAG solutions; for example, for MultiSWAG with 10 independent SWAG solutions, we use an ensemble of 200 networks.

Figure 13. Double Descent. (a) Test error and (b) NLL loss for ResNet-18 with varying width on CIFAR-100 for SGD, SWAG and MultiSWAG when 10% of the labels are randomly reshuffled. MultiSWAG alleviates double descent, both on the original labels and under label noise, both in accuracy and NLL. (c) Test NLLs for MultiSWAG with a varying number of independent models under 20% label corruption; NLL monotonically decreases with an increased number of independent models, alleviating double descent.

E. Analysis of Prior Variance Effect

In this section we provide simple analytic results for the effect of the prior variance in ReLU networks. A related derivation is presented in Appendix Section A.8 of Garipov et al. (2018), about connecting paths from symmetries in parametrization.

We will consider a multilayer network f(x; w) of the form

f(x; {W_i, b_i}_{i=1}^n) = W_n(⋯ φ(W_2 φ(W_1 x + b_1) + b_2) ⋯) + b_n,

where φ is the ReLU (or, in fact, any positively-homogeneous activation function), W_i are weight matrices, and b_i are bias vectors. In particular, f can be a regular CNN with ReLU activations up to the logits (with the softmax activation removed).

Now suppose we have a prior distribution of the form

W_i ∼ N(0, α_i² I),   b_i ∼ N(0, β_i² I),

where the identity matrices I are implicitly assumed to be of appropriate shapes, so that each weight matrix and bias vector has a spherical Gaussian distribution. We can reparameterize this distribution as

W_i = α_i E_i,   E_i ∼ N(0, I),
b_i = β_i ε_i,   ε_i ∼ N(0, I).

We can then express the predictions of the network on the input x, for weights sampled from the prior, as the random variable

f(x; {α_i, β_i}_{i=1}^n) = α_n · E_n(⋯ φ(α_1 E_1 x + β_1 · ε_1) ⋯) + β_n · ε_n.   (8)

Through Equation (8), we can observe some simple properties of the dependence between the prior scales α_i, β_i and the induced function-space prior.

Proposition 1. Suppose the network has no bias vectors, i.e. β_1 = … = β_n = 0. Then the scales α_i of the prior distribution over the weights only affect the output scale of the network.

Proof. In the case when there are no bias vectors, Equation (8) simplifies to

f(x; {α_i, β_i = 0}_{i=1}^n) = α_n · E_n(⋯ φ(α_1 E_1 x) ⋯) = α_n ⋯ α_1 · E_n(⋯ φ(E_1 x) ⋯) = α_n ⋯ α_1 · f(x; {α_i = 1, β_i = 0}_{i=1}^n).

In the derivation above we used the positive homogeneity of ReLU: φ(αz) = αφ(z) for any positive α.

In other words, to sample from the distribution over functions corresponding to a prior with scales {α_i, β_i = 0}_{i=1}^n, we can sample from the spherical Gaussian prior (without bias terms) {α_i = 1, β_i = 0}_{i=1}^n, and then rescale the outputs of the network by the product of the prior scales α_n ⋯ α_2 · α_1.
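Proposition 1 is easy to verify numerically; the sketch below (with arbitrary layer sizes and scales) checks that scaling the weights of a bias-free ReLU network only rescales its outputs:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    sizes = [(50, 10), (50, 50), (3, 50)]                    # three bias-free layers
    E = [torch.randn(out_f, in_f) for out_f, in_f in sizes]  # E_i ~ N(0, I)
    alphas = [0.3, 1.7, 2.0]                                 # per-layer prior scales alpha_i
    x = torch.randn(4, 10)                                   # a batch of inputs

    def net(x, scales):
        h = x
        for i, (e, a) in enumerate(zip(E, scales)):
            h = h @ (a * e).t()
            if i < len(E) - 1:
                h = F.relu(h)                                # positively homogeneous activation
        return h

    scaled = net(x, alphas)
    rescaled = torch.prod(torch.tensor(alphas)) * net(x, [1.0, 1.0, 1.0])
    print(torch.allclose(scaled, rescaled, atol=1e-5))       # True: only the output scale changes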

We note that the result above is different from the results for sigmoid networks considered in MacKay (1995), where varying the prior on the weights leads to changing the length-scale of the sample functions. For ReLU networks without biases, increasing the prior variance only increases the output scale of the network, and not the complexity of the samples. If we apply the softmax activation on the outputs of the last layer of such a network, we will observe increasingly confident predictions as we increase the prior variance. We observe this effect in Figure 9, and discuss it in Section 8.4.

In case bias vectors are present, we can obtain a similar result using a specific scaling of the prior variances with the layer, as in the following proposition.

Proposition 2. Suppose the prior scales depend on the layer of the network as follows, for some γ > 0:

α_i = γ,   β_i = γ^i,

for all layers i = 1, …, n. Then γ only affects the scale of the predictive distribution at any input x:

f(x; {α_i = γ, β_i = γ^i}_{i=1}^n) = γ^n · f(x; {α_i = 1, β_i = 1}_{i=1}^n).

Proof. The proof is analogous to the proof of Proposition 1: we can use the positive homogeneity of ReLU activations to factor the prior scales outside of the network:

f(x; {α_i = γ, β_i = γ^i}_{i=1}^n) = γ · E_n(⋯ φ(γ · E_1 x + γ · ε_1) ⋯) + γ^n · ε_n = γ^n · (E_n(⋯ φ(E_1 x + ε_1) ⋯) + ε_n) = γ^n · f(x; {α_i = 1, β_i = 1}_{i=1}^n).

The analysis above can be applied to other simple scaling rules of the prior; e.g.,

f(x; {α_i = γ · α̂_i, β_i = γ^i · β̂_i}_{i=1}^n) = γ^n · f(x; {α_i = α̂_i, β_i = β̂_i}_{i=1}^n)   (9)

can be shown completely analogously to Proposition 2.

More general types of scaling of the prior affect both the output scale of the network and the relative effect of the weight and bias terms. For example, by Equation (9) we have

f(x; {α_i = γ, β_i = γ}_{i=1}^n) = f(x; {α_i = γ · 1, β_i = γ^i · γ^{1−i}}_{i=1}^n) = γ^n · f(x; {α_i = 1, β_i = γ^{1−i}}_{i=1}^n).

We note that the analysis does not cover residual connections and batch normalization, so it applies to LeNet-5 but cannot be directly applied to the PreResNet-20 networks used in many of our experiments.

F. Prior Correlation Structure under Perturbations

In this section we explore the prior correlations between the logits on different pairs of datapoints induced by a spherical Gaussian prior on the weights of a PreResNet-20. We sample 100 random images from CIFAR-10 (10 from each class) and apply 17 different perturbations introduced by Hendrycks & Dietterich (2019) at 5 different levels of intensity. We then compute the correlations between the logits f(x; w) for the original image x and f(x̃; w) for the corrupted image x̃, as we sample the weights of the network from the prior w ∼ N(0, I).

In Figure 19 we show how the correlations decay with perturbation intensity. For reference, we also show how the correlations decay for a linear model and for an RBF kernel. For the RBF kernel, we set the lengthscale so that the average correlations on the uncorrupted datapoints match those of a PreResNet-20. Further experimental details can be found in Appendix D.3.

For all types of corruptions except saturate, snow, fog and brightness, the PreResNet logits decay more slowly compared to the RBF kernel and linear model. It appears that the prior samples are sensitive to corruptions that alter the brightness, or more generally the colours, in the image. For many types of corruptions (such as, e.g., Gaussian noise), the prior correlations for PreResNet are close to 1 for all levels of corruption.

Overall, these results indicate that the prior over functions induced by a vague prior over parameters w in combination with a PreResNet has useful equivariance properties: before seeing data, the model treats images of the same class as highly correlated, even after an image has undergone significant perturbations representative of perturbations we often see in the real world. These types of symmetries are a large part of what makes neural networks a powerful model class for high dimensional natural signals.

Figure 14. Noise Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Gaussian Noise, (b) Impulse Noise, (c) Shot Noise.

Figure 15. Blur Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Defocus Blur, (b) Glass Blur, (c) Motion Blur, (d) Zoom Blur, (e) Gaussian Blur.

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

1 2 3 4 5 6 7 8 9 10 Models

1750

2000

2250

2500

2750

3000

NLL

1 2 3 4 5 6 7 8 9 10 Models

3000

3500

4000

4500

5000

1 2 3 4 5 6 7 8 9 10 Models

4000

5000

6000

7000

1 2 3 4 5 6 7 8 9 10 Models

8000

10000

12000

14000

16000

1 2 3 4 5 6 7 8 9 10 Models

25000

30000

35000

40000

45000

50000

Deep Ensembles MultiSWA MultiSWAG

(a) Contrast

1 2 3 4 5 6 7 8 9 10 Models

2250

2500

2750

3000

3250

3500

3750

NLL

1 2 3 4 5 6 7 8 9 10 Models

3000

3500

4000

4500

1 2 3 4 5 6 7 8 9 10 Models

1750

2000

2250

2500

2750

1 2 3 4 5 6 7 8 9 10 Models

2500

3000

3500

4000

1 2 3 4 5 6 7 8 9 10 Models

3500

4000

4500

5000

5500

6000

6500

Deep Ensembles MultiSWA MultiSWAG

(b) Saturate

1 2 3 4 5 6 7 8 9 10 Models

3000

3500

4000

4500

5000

NLL

1 2 3 4 5 6 7 8 9 10 Models

3000

3500

4000

4500

5000

1 2 3 4 5 6 7 8 9 10 Models

4000

5000

6000

7000

1 2 3 4 5 6 7 8 9 10 Models

6000

7000

8000

9000

10000

11000

1 2 3 4 5 6 7 8 9 10 Models

7000

8000

9000

10000

11000

12000

Deep Ensembles MultiSWA MultiSWAG

(c) Elastic Transform

1 2 3 4 5 6 7 8 9 10 Models

2250

2500

2750

3000

3250

3500

3750

NLL

1 2 3 4 5 6 7 8 9 10 Models

4000

4500

5000

5500

6000

6500

1 2 3 4 5 6 7 8 9 10 Models

6000

7000

8000

9000

1 2 3 4 5 6 7 8 9 10 Models

14000

16000

18000

20000

22000

1 2 3 4 5 6 7 8 9 10 Models

24000

26000

28000

30000

32000

34000

Deep Ensembles MultiSWA MultiSWAG

(d) Pixelate

1 2 3 4 5 6 7 8 9 10 Models

4000

4500

5000

5500

6000

6500

NLL

1 2 3 4 5 6 7 8 9 10 Models

5000

6000

7000

8000

9000

1 2 3 4 5 6 7 8 9 10 Models

6000

7000

8000

9000

10000

1 2 3 4 5 6 7 8 9 10 Models

7000

8000

9000

10000

11000

1 2 3 4 5 6 7 8 9 10 Models

8000

9000

10000

11000

12000

13000

Deep Ensembles MultiSWA MultiSWAG

(e) JPEG Compression

Figure 16 Digital Corruptions Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles MultiSWAG andMultiSWA as a function of the number of independently trained models for different types of corruption and corruption intensity(increasing from left to right)

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

1 2 3 4 5 6 7 8 9 10 Models

3500

4000

4500

5000NL

L

1 2 3 4 5 6 7 8 9 10 Models

7000

8000

9000

10000

1 2 3 4 5 6 7 8 9 10 Models

5000

6000

7000

8000

1 2 3 4 5 6 7 8 9 10 Models

6000

7000

8000

9000

1 2 3 4 5 6 7 8 9 10 Models

8000

9000

10000

11000

12000

Deep Ensembles MultiSWA MultiSWAG

(a) Snow

1 2 3 4 5 6 7 8 9 10 Models

1600

1800

2000

2200

2400

2600

2800

NLL

1 2 3 4 5 6 7 8 9 10 Models

2000

2250

2500

2750

3000

3250

1 2 3 4 5 6 7 8 9 10 Models

2500

3000

3500

4000

1 2 3 4 5 6 7 8 9 10 Models

3500

4000

4500

5000

5500

6000

1 2 3 4 5 6 7 8 9 10 Models

8000

10000

12000

14000

16000

18000

Deep Ensembles MultiSWA MultiSWAG

(b) Fog

1 2 3 4 5 6 7 8 9 10 Models

1600

1800

2000

2200

2400

2600

2800

NLL

1 2 3 4 5 6 7 8 9 10 Models

1600180020002200240026002800

1 2 3 4 5 6 7 8 9 10 Models

1750

2000

2250

2500

2750

3000

3250

1 2 3 4 5 6 7 8 9 10 Models

2000225025002750300032503500

1 2 3 4 5 6 7 8 9 10 Models

3000

3500

4000

4500

Deep Ensembles MultiSWA MultiSWAG

(c) Brightness

Figure 17 Weather Corruptions Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles MultiSWAG andMultiSWA as a function of the number of independently trained models for different types of corruption and corruption intensity(increasing from left to right)

1 2 3 4 5Corruption Intensity

0

5000

10000

15000

20000

25000

30000

35000

NLL

Deep Ensembles MultiSWA MultiSWAG

1 2 3 4 5Corruption Intensity

02

03

04

05

06

07

08

09

ACcu

racy

Deep Ensembles MultiSWA MultiSWAG

1 2 3 4 5Corruption Intensity

00

01

02

03

04

ECE

Deep Ensembles MultiSWA MultiSWAG

Figure 18 Negative log likelihood accuracy and expected calibration error distribution on CIFAR-10 with a PreResNet-20 for DeepEnsembles MultiSWAG and MultiSWA as a function of the corruption intensity Following Ovadia et al (2019) we summarize the resultsfor different types of corruption with a boxplot For each method we use 10 independently trained models and for MultiSWAG wesample 20 networks from each model As in Figures 5 11-14 there are substantial differences between these three methods which arehard to see due to the vertical scale on this plot MultiSWAG particularly outperforms Deep Ensembles and MultiSWA in terms of NLLand ECE for higher corruption intensities

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

0 1 2 3 4 5Corruption Intensity

065070075080085090095100

Corre

latio

n

Translation

0 1 2 3 4 5Corruption Intensity

09650970097509800985099009951000

Corre

latio

n

Gaussian Noise

0 1 2 3 4 5Corruption Intensity

092

094

096

098

100

Corre

latio

n

Impulse Noise

0 1 2 3 4 5Corruption Intensity

0975

0980

0985

0990

0995

1000

Corre

latio

n

Shot Noise

0 1 2 3 4 5Corruption Intensity

0990

0992

0994

0996

0998

1000

Corre

latio

n

Defocus Blur

0 1 2 3 4 5Corruption Intensity

094

095

096

097

098

099

100

Corre

latio

n

Glass Blur

0 1 2 3 4 5Corruption Intensity

0970

0975

0980

0985

0990

0995

1000Co

rrela

tion

Motion Blur

0 1 2 3 4 5Corruption Intensity

09650970097509800985099009951000

Corre

latio

n

Zoom Blur

0 1 2 3 4 5Corruption Intensity

0992

0994

0996

0998

1000

Corre

latio

n

Gaussian Blur

0 1 2 3 4 5Corruption Intensity

090

092

094

096

098

100

Corre

latio

n

Contrast

0 1 2 3 4 5Corruption Intensity

092

094

096

098

100

Corre

latio

n

Saturate

0 1 2 3 4 5Corruption Intensity

093094095096097098099100

Corre

latio

n

Elastic

0 1 2 3 4 5Corruption Intensity

0990

0992

0994

0996

0998

1000

Corre

latio

nPixelate

0 1 2 3 4 5Corruption Intensity

099209930994099509960997099809991000

Corre

latio

n

JPEG

0 1 2 3 4 5Corruption Intensity

090

092

094

096

098

100

Corre

latio

n

Snow

0 1 2 3 4 5Corruption Intensity

090

092

094

096

098

100

Corre

latio

n

Fog

0 1 2 3 4 5Corruption Intensity

088

090

092

094

096

098

100

Corre

latio

n

Brightness

PreResNet-20 RBF Linear

Figure 19 Prior correlations under corruption Prior correlations between predictions (logits) for PreResNet-20 Linear Model andRBF kernel on original and corrupted images as a function of corruption intensity for different types of corruptions The lengthscale of theRBF kernell is calibrated to produce similar correlations to PreResNet on uncorrupted datapoints We report the mean correlation valuesover 100 different images and show the 1σ error bars with shaded regions For all corruptions except Snow Saturate Fog and Brightnessthe correlations decay slower for PreResNet compared to baselines

Page 18: Bayesian Deep Learning and a Probabilistic Perspective of ... · Andrew Gordon Wilson Pavel Izmailov New York University Abstract The key distinguishing property of a Bayesian ap-proach

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

Jiang, Y., Neyshabur, B., Mobahi, H., Krishnan, D., and Bengio, S. Fantastic generalization measures and where to find them. arXiv preprint arXiv:1912.02178, 2019.

Kass, R. E. and Raftery, A. E. Bayes factors. Journal of the American Statistical Association, 90(430):773–795, 1995.

Kendall, A. and Gal, Y. What uncertainties do we need in Bayesian deep learning for computer vision? In Advances in Neural Information Processing Systems, pp. 5574–5584, 2017.

Keskar, N. S., Mudigere, D., Nocedal, J., Smelyanskiy, M., and Tang, P. T. P. On large-batch training for deep learning: Generalization gap and sharp minima. arXiv preprint arXiv:1609.04836, 2016.

Khan, M. E., Nielsen, D., Tangkaratt, V., Lin, W., Gal, Y., and Srivastava, A. Fast and scalable Bayesian deep learning by weight-perturbation in Adam. arXiv preprint arXiv:1806.04854, 2018.

Krizhevsky, A., Nair, V., and Hinton, G. The CIFAR-10 dataset, 2014. http://www.cs.toronto.edu/~kriz/cifar.html

Lakshminarayanan, B., Pritzel, A., and Blundell, C. Simple and scalable predictive uncertainty estimation using deep ensembles. In Advances in Neural Information Processing Systems, pp. 6402–6413, 2017.

Langford, J. and Caruana, R. (Not) bounding the true error. In Advances in Neural Information Processing Systems, pp. 809–816, 2002.

LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.

Louizos, C., Shi, X., Schutte, K., and Welling, M. The functional neural process. In Advances in Neural Information Processing Systems, 2019.

MacKay, D. J. Bayesian methods for adaptive models. PhD thesis, California Institute of Technology, 1992.

MacKay, D. J. Probable networks and plausible predictions: a review of practical Bayesian methods for supervised neural networks. Network: Computation in Neural Systems, 6(3):469–505, 1995.

MacKay, D. J. Introduction to Gaussian processes. In Bishop, C. M. (ed.), Neural Networks and Machine Learning, chapter 11, pp. 133–165. Springer-Verlag, 1998.

MacKay, D. J. Information Theory, Inference and Learning Algorithms. Cambridge University Press, 2003.

Maddox, W. J., Izmailov, P., Garipov, T., Vetrov, D. P., and Wilson, A. G. A simple baseline for Bayesian uncertainty in deep learning. In Advances in Neural Information Processing Systems, 2019.

Masegosa, A. R. Learning under model misspecification: Applications to variational and ensemble methods, 2019.

McAllester, D. A. PAC-Bayesian model averaging. In Proceedings of the Twelfth Annual Conference on Computational Learning Theory, pp. 164–170, 1999.

Minka, T. Expectation propagation for approximate Bayesian inference. In Uncertainty in Artificial Intelligence, volume 17, pp. 362–369, 2001a.

Minka, T. P. Bayesian model averaging is not model combination, 2000.

Minka, T. P. Automatic choice of dimensionality for PCA. In Advances in Neural Information Processing Systems, pp. 598–604, 2001b.

Mohri, M. and Rostamizadeh, A. Rademacher complexity bounds for non-i.i.d. processes. In Advances in Neural Information Processing Systems, pp. 1097–1104, 2009.

Nakkiran, P., Kaplun, G., Bansal, Y., Yang, T., Barak, B., and Sutskever, I. Deep double descent: Where bigger models and more data hurt. arXiv preprint arXiv:1912.02292, 2019.

Nakkiran, P., Venkat, P., Kakade, S., and Ma, T. Optimal regularization can mitigate double descent. arXiv preprint arXiv:2003.01897, 2020.

Nalisnick, E. On priors for Bayesian neural networks. PhD thesis, University of California, Irvine, 2018.

Neal, R. Bayesian Learning for Neural Networks. Springer Verlag, 1996. ISBN 0387947248.

Neyshabur, B., Bhojanapalli, S., McAllester, D., and Srebro, N. Exploring generalization in deep learning. In Advances in Neural Information Processing Systems, pp. 5947–5956, 2017.

Neyshabur, B., Bhojanapalli, S., and Srebro, N. A PAC-Bayesian approach to spectrally-normalized margin bounds for neural networks. In International Conference on Learning Representations, 2018. URL https://openreview.net/forum?id=Skz_WfbCZ

Opper, M., Kinzel, W., Kleinz, J., and Nehl, R. On the ability of the optimal perceptron to generalise. Journal of Physics A: Mathematical and General, 23(11):L581, 1990.

Ovadia, Y., Fertig, E., Ren, J., Nado, Z., Sculley, D., Nowozin, S., Dillon, J. V., Lakshminarayanan, B., and Snoek, J. Can you trust your model's uncertainty? Evaluating predictive uncertainty under dataset shift. arXiv preprint arXiv:1906.02530, 2019.

Pearce, T., Zaki, M., Brintrup, A., Anastassacos, N., and Neely, A. Uncertainty in neural networks: Bayesian ensembling. arXiv preprint arXiv:1810.05546, 2018.

Rasmussen, C. E. and Ghahramani, Z. Occam's razor. In Neural Information Processing Systems (NIPS), 2001.

Rasmussen, C. E. and Williams, C. K. I. Gaussian Processes for Machine Learning. The MIT Press, 2006.

Ritter, H., Botev, A., and Barber, D. A scalable Laplace approximation for neural networks. In International Conference on Learning Representations (ICLR), 2018.

Rue, H., Martino, S., and Chopin, N. Approximate Bayesian inference for latent Gaussian models by using integrated nested Laplace approximations. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 71(2):319–392, 2009.

Simonyan, K. and Zisserman, A. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.

Smith, S. L. and Le, Q. V. A Bayesian perspective on generalization and stochastic gradient descent. In International Conference on Learning Representations, 2018.

Sun, S., Zhang, G., Shi, J., and Grosse, R. Functional variational Bayesian neural networks. arXiv preprint arXiv:1903.05779, 2019.

Ulyanov, D., Vedaldi, A., and Lempitsky, V. Deep image prior. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 9446–9454, 2018.

Vapnik, V. N. Adaptive and learning systems for signal processing, communications, and control: Statistical Learning Theory, 1998.

Virtanen, P., Gommers, R., Oliphant, T. E., Haberland, M., Reddy, T., Cournapeau, D., Burovski, E., Peterson, P., Weckesser, W., Bright, J., et al. SciPy 1.0: fundamental algorithms for scientific computing in Python. Nature Methods, pp. 1–12, 2020.

Walker, S. and Hjort, N. L. On Bayesian consistency. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 63(4):811–821, 2001.

Welling, M. and Teh, Y. W. Bayesian learning via stochastic gradient Langevin dynamics. In Proceedings of the 28th International Conference on Machine Learning (ICML-11), pp. 681–688, 2011.

Wenzel, F., Roth, K., Veeling, B. S., Swiatkowski, J., Tran, L., Mandt, S., Snoek, J., Salimans, T., Jenatton, R., and Nowozin, S. How good is the Bayes posterior in deep neural networks really? arXiv preprint arXiv:2002.02405, 2020.

Wilson, A. G. The case for Bayesian deep learning. arXiv preprint arXiv:2001.10995, 2020.

Xu, X., Lu, P., MacEachern, S., and Xu, R. Calibrated Bayes factors for model comparison. Journal of Statistical Computation and Simulation, 89(4):591–614, 2019.

Yang, W., Lorch, L., Graule, M. A., Srinivasan, S., Suresh, A., Yao, J., Pradier, M. F., and Doshi-Velez, F. Output-constrained Bayesian neural networks. arXiv preprint arXiv:1905.06287, 2019.

Zhang, C., Bengio, S., Hardt, M., Recht, B., and Vinyals, O. Understanding deep learning requires rethinking generalization. arXiv preprint arXiv:1611.03530, 2016.

Zhang, T. Information-theoretic upper and lower bounds for statistical estimation. IEEE Transactions on Information Theory, 52(4):1307–1321, 2006.


Appendix Outline

This appendix is organized as follows. In Section A we visualize predictive functions corresponding to weight samples within high posterior density valleys on a regression problem. In Section B we provide background material on Gaussian processes. In Section C we present further results comparing MultiSWAG and MultiSWA to Deep Ensembles under data distribution shift on CIFAR-10. In Section D we provide the details of all experiments presented in the paper. In Section E we present analytic results on the dependence of the prior distribution in function space on the variance of the prior over parameters. In Section F we study the prior correlations between BNN logits on perturbed images.

A. Loss Valleys

We demonstrate that different points along the valleys of high posterior density (low loss) connecting pairs of independently trained optima (Garipov et al., 2018; Draxler et al., 2018; Fort & Jastrzebski, 2019) correspond to different predictive functions. We use the regression example from Izmailov et al. (2019) and show the results in Figure 11.

B. Gaussian processes

With a Bayesian neural network, a distribution over parameters p(w) induces a distribution over functions p(f(x; w)) when combined with the functional form of the network. Gaussian processes (GPs) are often used to instead directly specify a distribution over functions.

A Gaussian process is a distribution over functions, f(x) ~ GP(m, k), such that any collection of function values queried at any finite set of inputs x_1, ..., x_n has a joint Gaussian distribution,

f(x_1), ..., f(x_n) ~ N(μ, K).   (6)

The mean vector μ_i = E[f(x_i)] = m(x_i) and covariance matrix K_ij = cov(f(x_i), f(x_j)) = k(x_i, x_j) are determined by the mean function m and covariance function (or kernel) k of the Gaussian process.

The popular RBF kernel has the form

k(x_i, x_j) = exp( −‖x_i − x_j‖² / (2ℓ²) ).   (7)

The length-scale hyperparameter ℓ controls the extent of correlations between function values. If ℓ is large, sample functions from a GP prior are simple and slowly varying with the inputs x.
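To make Equations (6) and (7) concrete, here is a small self-contained sketch (ours, with an arbitrary input grid and a small jitter term added for numerical stability) that draws prior functions from a zero-mean GP with an RBF kernel; larger lengthscales ℓ give smoother, more slowly varying samples:

import numpy as np

def rbf_kernel(x1, x2, lengthscale):
    # k(x_i, x_j) = exp(-||x_i - x_j||^2 / (2 * lengthscale^2)), as in Equation (7).
    diff = x1[:, None] - x2[None, :]
    return np.exp(-0.5 * diff ** 2 / lengthscale ** 2)

def sample_gp_prior(x, lengthscale, n_samples=3, jitter=1e-8):
    # f(x_1), ..., f(x_n) ~ N(0, K), Equation (6) with a zero mean function.
    K = rbf_kernel(x, x, lengthscale) + jitter * np.eye(len(x))
    L = np.linalg.cholesky(K)
    return L @ np.random.randn(len(x), n_samples)

x = np.linspace(-5.0, 5.0, 200)
wiggly = sample_gp_prior(x, lengthscale=0.5)   # short lengthscale: quickly varying samples
smooth = sample_gp_prior(x, lengthscale=3.0)   # long lengthscale: slowly varying samples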

Gaussian processes with RBF kernels (as well as many other standard kernels) assign positive density to any set of observations. Moreover, these models are universal approximators (Rasmussen & Williams, 2006): as the number of observations increases, they are able to approximate any function to arbitrary precision.

Work on Gaussian processes in machine learning was triggered by the observation that Bayesian neural networks become Gaussian processes with particular kernel functions as the number of hidden units approaches infinity (Neal, 1996). This result resembles recent work on the neural tangent kernel (e.g. Jacot et al., 2018).

C. Deep Ensembles and MultiSWAG Under Distribution Shift

In Figures 14, 15, 16, 17 we show the negative log-likelihood for Deep Ensembles, MultiSWA and MultiSWAG using PreResNet-20 on CIFAR-10 with various corruptions, as a function of the number of independently trained models (SGD solutions, SWA solutions, or SWAG models, respectively). For MultiSWAG we generate 20 samples from each independent SWAG model. Typically MultiSWA and MultiSWAG significantly outperform Deep Ensembles when a small number of independent models is used, or when the level of corruption is high.
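For concreteness, curves of this kind can be produced by accumulating the ensemble predictive distribution one model at a time; the sketch below is our own illustration (not the released evaluation code), assuming probs_per_model holds each model's predicted class probabilities on the test set:

import numpy as np

def nll_vs_num_models(probs_per_model, labels):
    # probs_per_model: (M, N, C) predicted class probabilities from M independent models
    # labels: (N,) integer test labels
    # Returns the test NLL of the ensemble of the first m models, for m = 1..M.
    nlls = []
    running_sum = np.zeros_like(probs_per_model[0])
    for m, probs in enumerate(probs_per_model, start=1):
        running_sum += probs
        ensemble = running_sum / m                    # average the predictive distributions
        p_true = ensemble[np.arange(len(labels)), labels]
        nlls.append(-np.sum(np.log(p_true + 1e-12)))  # total NLL over the test set
    return nlls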

In Figure 18, following Ovadia et al. (2019), we show the distribution of negative log likelihood, accuracy, and expected calibration error as we vary the type of corruption. We use a fixed training time budget: 10 independently trained models for every method. For MultiSWAG we ensemble 20 samples from each of the 10 SWAG approximations. MultiSWAG particularly achieves better NLL than the other two methods, and MultiSWA outperforms Deep Ensembles; the difference is especially pronounced for higher levels of corruption. In terms of ECE, MultiSWAG again outperforms the other two methods for higher corruption intensities.

We note that Ovadia et al. (2019) found Deep Ensembles to be a very strong baseline for prediction quality and calibration under distribution shift. For this reason we focus on Deep Ensembles in our comparisons.

D. Details of Experiments

In this section we provide additional details of the experiments presented in the paper.

D.1. Approximating the True Predictive Distribution

For the results presented in Figure 4, we used a network with 3 hidden layers of size 10 each. The network takes two inputs: x and x². We pass both x and x² as input to ensure that the network can represent a broader class of functions. The network outputs a single number y = f(x).

Figure 11. Diversity of high performing functions. Bottom: a contour plot of the posterior log-density in the subspace containing a pair of independently trained modes (as with deep ensembles), and a path of high posterior density connecting these modes. In each panel the purple point represents a sample from the posterior in the parameter subspace. Top: the predictive distribution constructed from samples in the subspace. The shaded blue area shows the 3σ-region of the predictive distribution at each of the input locations, and the blue line shows the mean of the predictive distribution. In each panel, the purple line shows the predictive function corresponding to the sample shown in the corresponding bottom-row panel. For the details of the experimental setup, see Section 5.1 of Izmailov et al. (2019).

To generate data for the plots, we used a randomly-initialized neural network of the same architecture described above. We sampled the weights from an isotropic Gaussian with variance 0.1², and added isotropic Gaussian noise with variance 0.1² to the outputs:

y = f(x; w) + ε(x),

with w ~ N(0, 0.1² · I), ε(x) ~ N(0, 0.1² · I). The training set consists of 120 points, shown in Figure 4.
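As a point of reference, a minimal sketch of this generative process (our own re-implementation of the stated architecture; the input range and random seed are arbitrary choices, not taken from the paper's code):

import torch
import torch.nn as nn

def make_net():
    # Three hidden layers of width 10; the input is the pair (x, x^2).
    return nn.Sequential(
        nn.Linear(2, 10), nn.ReLU(),
        nn.Linear(10, 10), nn.ReLU(),
        nn.Linear(10, 10), nn.ReLU(),
        nn.Linear(10, 1),
    )

torch.manual_seed(0)
net = make_net()
with torch.no_grad():
    for p in net.parameters():
        p.normal_(0.0, 0.1)                                   # w ~ N(0, 0.1^2 I)
    x = torch.linspace(-10.0, 10.0, 120).unsqueeze(-1)        # arbitrary input locations
    inputs = torch.cat([x, x ** 2], dim=-1)
    y = net(inputs) + 0.1 * torch.randn(120, 1)               # eps ~ N(0, 0.1^2)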

For estimating the ground truth, we ran 200 chains of Hamiltonian Monte Carlo (HMC) using the hamiltorch package (Cobb et al., 2019). We initialized each chain with a network pre-trained with SGD for 3000 steps, then ran HMC for 2000 steps, producing 200 samples.

For Deep Ensembles, we independently trained 50 networks with SGD for 20000 steps each. We used the negative posterior log-density as the training loss. For SVI, we used a fully-factorized Gaussian approximation initialized at an SGD solution trained for 20000 steps. For all inference methods we set the prior variance to 10² and the noise variance to 0.02².

Discrepancy with true BMA. For the results presented in panel (d) of Figure 4, we computed the Wasserstein distance between the predictive distribution approximated with HMC and the predictive distributions for Deep Ensembles and SVI. We used the one-dimensional Wasserstein distance function¹ from the scipy package (Virtanen et al., 2020). We computed the Wasserstein distance between the marginal distributions at each input location, and averaged the results over the input locations. In the top sub-panels of panels (b), (c) of Figure 4 we additionally visualize the marginal Wasserstein distance between the HMC predictive distribution and the Deep Ensembles and SVI predictive distributions, respectively, for each input location.

¹ https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wasserstein_distance.html
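A sketch of this computation, assuming hmc_samples and approx_samples are arrays of predictive samples with shape (number of samples, number of input locations); only scipy.stats.wasserstein_distance is required:

import numpy as np
from scipy.stats import wasserstein_distance

def avg_marginal_wasserstein(hmc_samples, approx_samples):
    # Both arguments: (n_samples, n_locations) arrays of predictive samples.
    # Returns the per-location distances and their average over input locations.
    n_locations = hmc_samples.shape[1]
    distances = np.array([
        wasserstein_distance(hmc_samples[:, j], approx_samples[:, j])
        for j in range(n_locations)
    ])
    return distances, distances.mean()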

D.2. Deep Ensembles and MultiSWAG

We evaluate Deep Ensembles, MultiSWA and MultiSWAG under distribution shift in Section 4. Following Ovadia et al. (2019), we use a PreResNet-20 network and the CIFAR-10 dataset with different types of corruptions introduced in Hendrycks & Dietterich (2019). For training individual SGD, SWA and SWAG models we use the hyper-parameters used for PreResNet-164 in Maddox et al. (2019). For each SWAG model we sample 20 networks and ensemble them. So Deep Ensembles, MultiSWA and MultiSWAG are all evaluated under the same training budget; Deep Ensembles and MultiSWA also use the same test-time budget.
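For reference, drawing each of these 20 networks follows the SWAG sampling rule of Maddox et al. (2019); the sketch below is a simplified version with the weights flattened into a single vector, and the variable names (theta_swa, diag_var, D_hat) are ours rather than those of the released code:

import torch

def sample_swag(theta_swa, diag_var, D_hat):
    # theta_swa: (d,) SWA mean of the flattened weights
    # diag_var:  (d,) diagonal variance estimate
    # D_hat:     (d, K) deviation matrix for the low-rank covariance term
    d, K = D_hat.shape
    z1 = torch.randn(d)
    z2 = torch.randn(K)
    # The diagonal and low-rank parts are each scaled so that together they
    # match the SWAG covariance 0.5 * Sigma_diag + 0.5 * D_hat D_hat^T / (K - 1).
    return (theta_swa
            + diag_var.sqrt() * z1 / (2 ** 0.5)
            + (D_hat @ z2) / ((2 * (K - 1)) ** 0.5))

# Each sample is loaded back into the network and the resulting softmax
# predictions are averaged across samples and across independent SWAG solutions.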

For producing the corrupted data we used the code² released by Hendrycks & Dietterich (2019). We had issues producing the data for the frost corruption type, so we omit it in our evaluation, and include Gaussian blur, which was not included in the evaluation of Hendrycks & Dietterich (2019).

² https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_cifar_c.py

Figure 12. (a)-(c): Average pairwise prior correlations for pairs of objects in classes 0, 1, 2, 4, 7 of MNIST induced by LeNet-5 for p(f(x; w)) when p(w) = N(0, α²I). Images in the same class have higher prior correlations than images from different classes, suggesting that p(f(x; w)) has desirable inductive biases. The correlations slightly decrease with increases in α. Panels (e)-(g) show sample functions from LeNet-5 along the direction connecting a pair of MNIST images of 0 and 1 digits; the complexity of the samples increases with α. (d): NLL and (h): classification error of an ensemble of 20 SWAG samples on MNIST as a function of α, using a LeNet-5. The NLL is high for overly small α and near-optimal for larger values, with an optimum near α = 0.3. Panels (a)-(c) correspond to α = 0.02, 0.1, 1, and panels (e)-(g) to α = 0.02, 0.1, 1.

D.3. Neural Network Priors

In the main text we considered different properties of the prior distribution over functions induced by a spherical Gaussian distribution over the weights, with different variance scales.

Prior correlation diagrams. In panels (a)-(c) of Figure 12 we show pairwise correlations of the logits for different pairs of datapoints. To make these plots, we produce S = 100 samples of the weights w_i of a LeNet-5 from the prior distribution N(0, α²I), and compute the logits corresponding to class 0 for each data point and each weight sample. We then compute the correlations for each pair x, x′ of data points as follows:

corr_logit(x, x′) = Σ_{i=1}^S (f(x; w_i) − f̄(x)) (f(x′; w_i) − f̄(x′)) / sqrt( Σ_{i=1}^S (f(x; w_i) − f̄(x))² · Σ_{i=1}^S (f(x′; w_i) − f̄(x′))² ),

where f(x; w) is the logit corresponding to class 0 of the network with weights w on the input x, and f̄(x) is the mean value of the logit, f̄(x) = (1/S) Σ_i f(x; w_i). For evaluation we use 200 random datapoints per class for classes 0, 1, 2, 4, 7 (a total of 1000 datapoints). We use this set of classes to ensure that the structure is clearly visible in the figure. We combine the correlations into a diagram, additionally showing the average correlation for each pair of classes. We repeat the experiment for different values of α ∈ {0.02, 0.1, 1}. For a discussion of the results, see Section 5.2.
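A compact sketch of this procedure is given below; model_fn and sample_prior are hypothetical stand-ins for a LeNet-5 forward pass returning class-0 logits and for drawing weights from N(0, α²I):

import numpy as np

def prior_logit_correlations(logit_samples):
    # logit_samples: (S, N) class-0 logits for S prior weight samples and N datapoints.
    # Returns the (N, N) matrix of pairwise correlations, as in the formula above.
    centered = logit_samples - logit_samples.mean(axis=0, keepdims=True)
    cov = centered.T @ centered
    std = np.sqrt(np.diag(cov))
    return cov / np.outer(std, std)

# Hypothetical usage:
# logits = np.stack([model_fn(images, sample_prior(alpha)) for _ in range(100)])
# corr = prior_logit_correlations(logits)
# then average the entries of corr over pairs of images from each pair of classes.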

Sample functions. In panels (e)-(g) of Figure 12 we visualize the functions sampled from the LeNet-5 network along the direction connecting a pair of MNIST images. In particular, we take a pair of images x₀ and x₁ of digits 0 and 1, respectively, and construct the path x(t) = t · x₀ + (1 − t) · x₁. We then study the samples of the logits z(t) = f(x(t) · ‖x₀‖/‖x(t)‖; w) along the path; here we adjusted the norm of the images along the path to be constant, as the values of the logits are sensitive to the norm of the inputs. The complexity of the samples increases as we increase the variance of the prior distribution over the weights. This increased complexity of sample functions explains why we might expect the prior correlations for pairs of images to be lower when we increase the variance of the prior distribution over the weights.

Performance dependence on prior variance. In panels (d), (h) of Figure 12 we show the test negative log-likelihood and accuracy of SWAG applied to LeNet-5 on MNIST. We train the model for 50 epochs, constructing the rank-20 SWAG approximation from the last 25 epochs. We use an initial learning rate of 0.05 and a SWAG learning rate of 0.01, with the learning rate schedule of Maddox et al. (2019). We use the posterior log-density as the objective, and vary the prior variance α². In panels (f), (g) of Figure 9 we perform an analogous experiment using a PreResNet-20 and a VGG-16 on CIFAR-10, using the hyper-parameters reported in Maddox et al. (2019) (for PreResNet-20 we use the hyper-parameters used with PreResNet-164 in Maddox et al. (2019)). Both on MNIST and CIFAR-10 we observe that the performance is poor for overly small values of α, close to optimal for intermediate values, and still reasonable for larger values of α. For further discussion of the results, see Section 5.3.

Predictions from prior samples. Following Wenzel et al. (2020), we study the predictive distributions of prior samples using PreResNet-20 on CIFAR-10. In Figure 9 we show the sample predictive functions averaged over datapoints for different scales α of the prior distribution. We also show the predictive distribution for each α, which is the average of the sample predictive distributions over 200 samples of weights. In Figure 10 we show how the predictive distribution changes as we vary the number of observed data for prior scale α = √10. We see that the marginal predictive distribution for all considered values of α is reasonable: roughly uniform across classes when averaged across the dataset. For the latter experiment we used stochastic gradient Langevin dynamics (SGLD) (Welling & Teh, 2011) with a cosine learning rate schedule. For each sample we restart SGLD, and we only use the sample obtained at the last iteration. We discuss the results in Section 8.4.

Prior correlations with corrupted images. In Section F and Figure 19 we study the decay of the prior correlations between logits on an original image and a perturbed image as we increase the intensity of perturbations. For the BNN we use the PreResNet-20 architecture with the standard Gaussian prior N(0, I). For the linear model, the correlations are not affected by the prior variance α²:

cov(wᵀx, wᵀy) = E[wᵀx · wᵀy] = E[xᵀw wᵀy] = xᵀ E[wwᵀ] y = α² xᵀy,

and hence

corr(wᵀx, wᵀy) = cov(wᵀx, wᵀy) / sqrt( cov(wᵀx, wᵀx) · cov(wᵀy, wᵀy) ) = xᵀy / (‖x‖ · ‖y‖).

We use the N(0, I) prior for the weights of the linear model. Finally, we also evaluate the correlations associated with an RBF kernel (see Equation (7)). To set the lengthscale ℓ of the kernel, we evaluate the pairwise correlations for the PreResNet-20 and RBF kernel on the 100 uncorrupted CIFAR-10 images that were used for the experiment, and ensure that the average correlations match. The resulting value of ℓ is 10⁴; the average correlation for the RBF kernel and PreResNet was ≈ 0.9, and for the linear model the average correlation was ≈ 0.82. For the perturbations, we used the same set of corruptions introduced in Hendrycks & Dietterich (2019) as in the experiments in Section 4, with the addition of a random translation: for a random translation of intensity i we pad the image with 2·i zeros on each side and crop the image randomly to 32×32.
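The two baselines can be computed in closed form; a small sketch, with x_orig and x_corr standing for flattened original and corrupted images:

import numpy as np

def linear_prior_corr(x_orig, x_corr):
    # For w ~ N(0, alpha^2 I): corr(w^T x, w^T x') = x^T x' / (||x|| ||x'||),
    # independent of alpha.
    return x_orig @ x_corr / (np.linalg.norm(x_orig) * np.linalg.norm(x_corr))

def rbf_prior_corr(x_orig, x_corr, lengthscale=1e4):
    # For a zero-mean GP prior with an RBF kernel, the prior correlation of the
    # function values equals k(x, x'), since k(x, x) = 1.
    return np.exp(-0.5 * np.sum((x_orig - x_corr) ** 2) / lengthscale ** 2)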

D.4. Rethinking Generalization

In Section 6 we experiment with Bayesian neural networks and Gaussian processes on CIFAR-10 with noisy labels, inspired by the results in Zhang et al. (2016) that suggest we need to re-think generalization to understand deep learning.

Following Zhang et al. (2016), we train PreResNet-20 on CIFAR-10 with different fractions of random labels. To ensure that the network fits the train data, we turn off weight decay and data augmentation, and use a lower initial learning rate of 0.01. Otherwise, we follow the hyper-parameters that were used with PreResNet-164 in Maddox et al. (2019). We use a diagonal Laplace approximation to compute an estimate of the marginal likelihood for each level of label corruption. Following Ritter et al. (2018), we use the diagonal of the Fisher information matrix rather than the Hessian.

We perform a similar experiment with a Gaussian process with an RBF kernel on the binary classification problem for two classes of CIFAR-10. We use variational inference to fit the model, and we use the variational evidence lower bound to approximate the marginal likelihood. We use variational inference to overcome the non-Gaussian likelihood, and not for scalability reasons, i.e. we are not using inducing inputs. We use the GPyTorch package (Gardner et al., 2018) to train the models. We use an RBF kernel with the default initialization from GPyTorch, and divide the inputs by 5000 to get an appropriate input scale. We train the model on a binary classification problem between classes 0 and 1.

For the 10-class GP classification experiment we train 10 one-vs-all models that classify between a given class and the rest of the data. To reduce computation in training, we subsample the data not belonging to the given class to 10k datapoints, so each model is trained on a total of 15k datapoints. We then combine the 10 models into a single multi-class model: an observation is attributed to the class that corresponds to the one-vs-all model with the highest confidence. We use the same hyper-parameters as in the binary classification experiments.
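A condensed sketch of the binary GP classifier described above, written against GPyTorch's standard variational classification interface; the optimizer, number of training steps, and the placeholder data below are our own choices rather than the settings used for the experiments:

import torch
import gpytorch

class BinaryGPClassifier(gpytorch.models.ApproximateGP):
    def __init__(self, train_x):
        # Variational distribution over the function values at the training inputs
        # (no separate inducing points: the training inputs themselves are used).
        var_dist = gpytorch.variational.CholeskyVariationalDistribution(train_x.size(0))
        var_strat = gpytorch.variational.VariationalStrategy(
            self, train_x, var_dist, learn_inducing_locations=False
        )
        super().__init__(var_strat)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )

# Placeholder data: flattened images from two CIFAR-10 classes, divided by 5000,
# with labels in {0, 1}; random tensors stand in for the real data here.
train_x = torch.randn(500, 3072) / 5000.0
train_y = (torch.rand(500) > 0.5).float()

model = BinaryGPClassifier(train_x)
likelihood = gpytorch.likelihoods.BernoulliLikelihood()
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.numel())
optimizer = torch.optim.Adam(
    list(model.parameters()) + list(likelihood.parameters()), lr=0.01
)

model.train()
likelihood.train()
for _ in range(200):
    optimizer.zero_grad()
    loss = -mll(model(train_x), train_y)   # negative evidence lower bound
    loss.backward()
    optimizer.step()
# The final (negative) loss approximates the ELBO, used as a proxy for the marginal likelihood.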

D.5. Double Descent

In Section 7 we evaluate SGD, SWAG and MultiSWAG for models of varying width. Following Nakkiran et al. (2019), we use ResNet-18 on CIFAR-100; we consider the original labels, and 10% and 20% label corruption. For networks of every width we reuse the hyper-parameters used for PreResNet-164 in Maddox et al. (2019). For the original labels and 10% label corruption we use 5 independently trained SWAG models with MultiSWAG, and for 20% label corruption we use 10 models; for 20% label corruption we also show performance varying the number of independent models in Figures 8(e) and 13(c). Both for SWAG and MultiSWAG we use an ensemble of 20 sampled models from each of the SWAG solutions; for example, for MultiSWAG with 10 independent SWAG solutions, we use an ensemble of 200 networks.

Figure 13. Double Descent. (a): Test error and (b): NLL loss for ResNet-18 with varying width on CIFAR-100 for SGD, SWAG and MultiSWAG when 10% of the labels are randomly reshuffled. MultiSWAG alleviates double descent both on the original labels and under label noise, both in accuracy and NLL. (c): Test NLL for MultiSWAG with a varying number of independent models under 20% label corruption; the NLL monotonically decreases with the number of independent models, alleviating double descent.

E. Analysis of Prior Variance Effect

In this section we provide simple analytic results for the effect of prior variance in ReLU networks. A related derivation is presented in Appendix Section A.8 of Garipov et al. (2018), about connecting paths from symmetries in parametrization.

We will consider a multilayer network f(x; w) of the form

f(x; {W_i, b_i}_{i=1}^n) = W_n φ( ... φ(W_2 φ(W_1 x + b_1) + b_2) ... ) + b_n,

where φ is the ReLU (or in fact any positively-homogeneous activation function), W_i are weight matrices, and b_i are bias vectors. In particular, f can be a regular CNN with ReLU activations up to the logits (with the softmax activation removed).

Now suppose we have a prior distribution of the form

W_i ~ N(0, α_i² I),   b_i ~ N(0, β_i² I),

where the identity matrices I are implicitly assumed to be of appropriate shapes, so each weight matrix and bias vector has a spherical Gaussian distribution. We can reparameterize this distribution as

W_i = α_i E_i,   E_i ~ N(0, I),
b_i = β_i ε_i,   ε_i ~ N(0, I).

We can then express the predictions of the network on the input x, for weights sampled from the prior, as the random variable

f(x; {α_i, β_i}_{i=1}^n) = α_n E_n( ... φ(α_1 E_1 x + β_1 ε_1) ... ) + β_n ε_n.   (8)

Through Equation (8) we can observe some simple properties of the dependence between the prior scales α_i, β_i and the induced function-space prior.

Proposition 1. Suppose the network has no bias vectors, i.e. β_1 = ... = β_n = 0. Then the scales α_i of the prior distribution over the weights only affect the output scale of the network.

Proof. In the case when there are no bias vectors, Equation (8) simplifies to

f(x; {α_i, β_i = 0}_{i=1}^n) = α_n E_n( ... φ(α_1 E_1 x) ... )
  = α_n · ... · α_1 · E_n( ... φ(E_1 x) ... )
  = α_n · ... · α_1 · f(x; {α_i = 1, β_i = 0}_{i=1}^n).

In the derivation above we used the positive homogeneity of ReLU: φ(αz) = αφ(z) for any positive α.

In other words, to sample from the distribution over functions corresponding to a prior with scales {α_i, β_i = 0}_{i=1}^n, we can sample from the spherical Gaussian prior (without bias terms) {α_i = 1, β_i = 0}_{i=1}^n, and then rescale the outputs of the network by the product of the scales α_n · ... · α_2 · α_1.
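The output-scaling property of Proposition 1 is easy to verify numerically; a small sketch with a bias-free ReLU network and arbitrary prior scales:

import math
import torch
import torch.nn as nn

torch.manual_seed(0)
widths = [5, 20, 20, 3]
x = torch.randn(4, widths[0])

# Standard-normal weight matrices E_i, no bias vectors.
layers = [nn.Linear(widths[i], widths[i + 1], bias=False) for i in range(len(widths) - 1)]
for layer in layers:
    nn.init.normal_(layer.weight, 0.0, 1.0)

def net(scales, x):
    # f(x; {alpha_i}) with W_i = alpha_i * E_i and ReLU between layers.
    h = x
    for i, layer in enumerate(layers):
        h = scales[i] * layer(h)
        if i < len(layers) - 1:
            h = torch.relu(h)
    return h

alphas = [0.3, 2.0, 1.5]
with torch.no_grad():
    out_scaled = net(alphas, x)
    out_unit = net([1.0] * len(layers), x)
# Proposition 1: the two outputs differ only by the product of the scales.
assert torch.allclose(out_scaled, math.prod(alphas) * out_unit, atol=1e-5)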

We note that the result above is different from the results for sigmoid networks considered in MacKay (1995), where varying the prior on the weights leads to changing the lengthscale of the sample functions. For ReLU networks without biases, increasing the prior variance only increases the output scale of the network and not the complexity of the samples. If we apply the softmax activation on the outputs of the last layer of such a network, we will observe increasingly confident predictions as we increase the prior variance. We observe this effect in Figure 9, and discuss it in Section 8.4.

In case bias vectors are present, we can obtain a similar result using a specific scaling of the prior variances with the layer, as in the following proposition.

Proposition 2. Suppose the prior scales depend on the layer of the network as follows, for some γ > 0:

α_i = γ,   β_i = γ^i,

for all layers i = 1, ..., n. Then γ only affects the scale of the predictive distribution at any input x:

f(x; {α_i = γ, β_i = γ^i}_{i=1}^n) = γ^n · f(x; {α_i = 1, β_i = 1}_{i=1}^n).

Proof. The proof is analogous to the proof of Proposition 1. We can use the positive homogeneity of ReLU activations to factor the prior scales outside of the network:

f(x; {α_i = γ, β_i = γ^i}_{i=1}^n)
  = γ E_n( ... φ(γ E_1 x + γ ε_1) ... ) + γ^n ε_n
  = γ^n · ( E_n( ... φ(E_1 x + ε_1) ... ) + ε_n )
  = γ^n · f(x; {α_i = 1, β_i = 1}_{i=1}^n).

The analysis above can be applied to other simple scaling rules of the prior; e.g.

f(x; {α_i = γ·α̂_i, β_i = γ^i·β̂_i}_{i=1}^n) = γ^n · f(x; {α_i = α̂_i, β_i = β̂_i}_{i=1}^n)   (9)

can be shown completely analogously to Proposition 2.

More general types of scaling of the prior affect both the output scale of the network and also the relative effect of the prior and variance terms. For example, by Equation (9) we have

f(x; {α_i = γ, β_i = γ}_{i=1}^n) = f(x; {α_i = γ·1, β_i = γ^i·γ^{1−i}}_{i=1}^n) = γ^n · f(x; {α_i = 1, β_i = γ^{1−i}}_{i=1}^n).

We note that the analysis does not cover residual connections and batch normalization, so it applies to LeNet-5, but cannot be directly applied to the PreResNet-20 networks used in many of our experiments.

F. Prior Correlation Structure under Perturbations

In this section we explore the prior correlations between the logits on different pairs of datapoints induced by a spherical Gaussian prior on the weights of a PreResNet-20. We sample 100 random images from CIFAR-10 (10 from each class) and apply 17 different perturbations introduced by Hendrycks & Dietterich (2019) at 5 different levels of intensity. We then compute correlations between the logits f(x; w) for the original image x and f(x̃; w) for the corrupted image x̃, as we sample the weights of the network from the prior w ~ N(0, I).

In Figure 19 we show how the correlations decay with perturbation intensity. For reference, we also show how the correlations decay for a linear model and for an RBF kernel. For the RBF kernel we set the lengthscale so that the average correlations on the uncorrupted datapoints match those of a PreResNet-20. Further experimental details can be found in Appendix D.3.

For all types of corruptions except saturate, snow, fog and brightness, the prior correlations for PreResNet decay more slowly than those for the RBF kernel and the linear model. It appears that the prior samples are sensitive to corruptions that alter the brightness, or more generally the colours, in the image. For many types of corruptions (such as, e.g., Gaussian Noise) the prior correlations for PreResNet are close to 1 for all levels of corruption.

Overall, these results indicate that the prior over functions induced by a vague prior over parameters w in combination with a PreResNet has useful equivariance properties: before seeing data, the model treats images of the same class as highly correlated, even after an image has undergone significant perturbations representative of perturbations we often see in the real world. These types of symmetries are a large part of what makes neural networks a powerful model class for high dimensional natural signals.

Figure 14. Noise Corruptions: Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Gaussian Noise, (b) Impulse Noise, (c) Shot Noise.

Figure 15. Blur Corruptions: Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Defocus Blur, (b) Glass Blur, (c) Motion Blur, (d) Zoom Blur, (e) Gaussian Blur.

Figure 16. Digital Corruptions: Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Contrast, (b) Saturate, (c) Elastic Transform, (d) Pixelate, (e) JPEG Compression.

Figure 17. Weather Corruptions: Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Snow, (b) Fog, (c) Brightness.

Figure 18. Negative log likelihood, accuracy and expected calibration error distribution on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the corruption intensity. Following Ovadia et al. (2019), we summarize the results for different types of corruption with a boxplot. For each method we use 10 independently trained models, and for MultiSWAG we sample 20 networks from each model. As in Figures 5, 11-14, there are substantial differences between these three methods, which are hard to see due to the vertical scale on this plot; MultiSWAG particularly outperforms Deep Ensembles and MultiSWA in terms of NLL and ECE for higher corruption intensities. Panels show NLL, accuracy and ECE versus corruption intensity (1-5).

Figure 19. Prior correlations under corruption: Prior correlations between predictions (logits) for PreResNet-20, a linear model and an RBF kernel on original and corrupted images, as a function of corruption intensity, for different types of corruptions. The lengthscale of the RBF kernel is calibrated to produce similar correlations to PreResNet on uncorrupted datapoints. We report the mean correlation values over 100 different images, and show the 1σ error bars with shaded regions. For all corruptions except Snow, Saturate, Fog and Brightness, the correlations decay slower for PreResNet compared to the baselines. Panels (correlation versus corruption intensity, 0-5): Translation, Gaussian Noise, Impulse Noise, Shot Noise, Defocus Blur, Glass Blur, Motion Blur, Zoom Blur, Gaussian Blur, Contrast, Saturate, Elastic, Pixelate, JPEG, Snow, Fog, Brightness.

Page 19: Bayesian Deep Learning and a Probabilistic Perspective of ... · Andrew Gordon Wilson Pavel Izmailov New York University Abstract The key distinguishing property of a Bayesian ap-proach

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

Ovadia Y Fertig E Ren J Nado Z Sculley DNowozin S Dillon J V Lakshminarayanan B andSnoek J Can you trust your modelrsquos uncertainty eval-uating predictive uncertainty under dataset shift arXivpreprint arXiv190602530 2019

Pearce T Zaki M Brintrup A Anastassacos N andNeely A Uncertainty in neural networks Bayesianensembling arXiv preprint arXiv181005546 2018

Rasmussen C E and Ghahramani Z Occamrsquos razor InNeural Information Processing Systems (NIPS) 2001

Rasmussen C E and Williams C K I Gaussian processesfor Machine Learning The MIT Press 2006

Ritter H Botev A and Barber D A scalable Laplaceapproximation for neural networks In International Con-ference on Learning Representations (ICLR) 2018

Rue H Martino S and Chopin N Approximate Bayesianinference for latent gaussian models by using integratednested laplace approximations Journal of the royal sta-tistical society Series b (statistical methodology) 71(2)319ndash392 2009

Simonyan K and Zisserman A Very deep convolu-tional networks for large-scale image recognition arXivpreprint arXiv14091556 2014

Smith S L and Le Q V A Bayesian perspective on gener-alization and stochastic gradient descent In InternationalConference on Learning Representations 2018

Sun S Zhang G Shi J and Grosse R Functionalvariational Bayesian neural networks arXiv preprintarXiv190305779 2019

Ulyanov D Vedaldi A and Lempitsky V Deep imageprior In Proceedings of the IEEE Conference on Com-puter Vision and Pattern Recognition pp 9446ndash94542018

Vapnik V N Adaptive and learning systems for signal pro-cessing communications and control Statistical learningtheory 1998

Virtanen P Gommers R Oliphant T E Haberland MReddy T Cournapeau D Burovski E Peterson PWeckesser W Bright J et al Scipy 10 fundamentalalgorithms for scientific computing in python NatureMethods pp 1ndash12 2020

Walker S and Hjort N L On Bayesian consistency Jour-nal of the Royal Statistical Society Series B (StatisticalMethodology) 63(4)811ndash821 2001

Welling M and Teh Y W Bayesian learning via stochasticgradient langevin dynamics In Proceedings of the 28thinternational conference on machine learning (ICML-11)pp 681ndash688 2011

Wenzel F Roth K Veeling B S Swiatkowski J TranL Mandt S Snoek J Salimans T Jenatton R andNowozin S How good is the Bayes posterior in deepneural networks really arXiv preprint arXiv2002024052020

Wilson A G The case for Bayesian deep learning arXivpreprint arXiv200110995 2020

Xu X Lu P MacEachern S and Xu R Calibratedbayes factors for model comparison Journal of StatisticalComputation and Simulation 89(4)591ndash614 2019

Yang W Lorch L Graule M A Srinivasan S SureshA Yao J Pradier M F and Doshi-Velez F Output-constrained Bayesian neural networks arXiv preprintarXiv190506287 2019

Zhang C Bengio S Hardt M Recht B and Vinyals OUnderstanding deep learning requires rethinking general-ization arXiv preprint arXiv161103530 2016

Zhang T Information-theoretic upper and lower bounds forstatistical estimation IEEE Transactions on InformationTheory 52(4)1307ndash1321 2006

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

Appendix OutlineThis appendix is organized as follows In Section A wevisualize predictive functions corresponding to weight sam-ples within high posterior density valleys on a regressionproblem In Section B we provide background material onGaussian processes In Section C we present further resultscomparing MultiSWAG and MultiSWA to Deep Ensemblesunder data distribution shift on CIFAR-10 In Section D weprovide the details of all experiments presented in the paperIn Section E we present analytic results on the dependenceof the prior distribution in function space on the variance ofthe prior over parameters In Section F we study the priorcorrelations between BNN logits on perturbed images

A Loss ValleysWe demonstrate that different points along the valleys ofhigh posterior density (low loss) connecting pairs of in-dependently trained optima (Garipov et al 2018 Draxleret al 2018 Fort amp Jastrzebski 2019) correspond to dif-ferent predictive functions We use the regression examplefrom Izmailov et al (2019) and show the results in Figure11

B Gaussian processesWith a Bayesian neural network a distribution over parame-ters p(w) induces a distribution over functions p(f(xw))when combined with the functional form of the networkGaussian processes (GPs) are often used to instead directlyspecify a distribution over functions

A Gaussian process is a distribution over functions f(x) simGP(m k) such that any collection of function valuesqueried at any finite set of inputs x1 xn has a jointGaussian distribution

f(x1) f(xn) sim N (microK) (6)

The mean vector microi = E[f(xi)] = m(xi) and covariancematrix Kij = cov(f(xi) f(xj)) = k(xi xj) are deter-mined by the mean function m and covariance function (orkernel) k of the Gaussian process

The popular RBF kernel has the form

k(xi xj) = exp

(minus 1

2`2xi minus xj2

) (7)

The length-scale hyperparameter ` controls the extent ofcorrelations between function values If ` is large samplefunctions from a GP prior are simple and slowly varyingwith inputs x

Gaussian processes with RBF kernels (as well as manyother standard kernels) assign positive density to any setof observations Moreover these models are universal ap-proximators (Rasmussen amp Williams 2006) as the numberof observations increase they are able to approximate anyfunction to arbitrary precision

Work on Gaussian processes in machine learning was triggered by the observation that Bayesian neural networks become Gaussian processes with particular kernel functions as the number of hidden units approaches infinity (Neal, 1996). This result resembles recent work on the neural tangent kernel (e.g., Jacot et al., 2018).

C. Deep Ensembles and MultiSWAG Under Distribution Shift

In Figures 14, 15, 16, and 17 we show the negative log-likelihood for Deep Ensembles, MultiSWA, and MultiSWAG using PreResNet-20 on CIFAR-10 with various corruptions, as a function of the number of independently trained models (SGD solutions, SWA solutions, or SWAG models, respectively). For MultiSWAG we generate 20 samples from each independent SWAG model. Typically, MultiSWA and MultiSWAG significantly outperform Deep Ensembles when a small number of independent models is used, or when the level of corruption is high.

In Figure 18, following Ovadia et al. (2019), we show the distribution of negative log likelihood, accuracy, and expected calibration error as we vary the type of corruption. We use a fixed training time budget: 10 independently trained models for every method. For MultiSWAG we ensemble 20 samples from each of the 10 SWAG approximations. MultiSWAG particularly achieves better NLL than the other two methods, and MultiSWA outperforms Deep Ensembles; the difference is especially pronounced for higher levels of corruption. In terms of ECE, MultiSWAG again outperforms the other two methods for higher corruption intensities.

We note that Ovadia et al. (2019) found Deep Ensembles to be a very strong baseline for prediction quality and calibration under distribution shift. For this reason, we focus on Deep Ensembles in our comparisons.

D. Details of Experiments

In this section we provide additional details of the experiments presented in the paper.

D.1. Approximating the True Predictive Distribution

For the results presented in Figure 4, we used a network with 3 hidden layers of size 10 each. The network takes two inputs: x and x². We pass both x and x² as input to ensure that the network can represent a broader class of functions.



Figure 11. Diversity of high performing functions. Bottom: a contour plot of the posterior log-density in the subspace containing a pair of independently trained modes (as with deep ensembles), and a path of high posterior density connecting these modes. In each panel, the purple point represents a sample from the posterior in the parameter subspace. Top: the predictive distribution constructed from samples in the subspace. The shaded blue area shows the 3σ-region of the predictive distribution at each of the input locations, and the blue line shows the mean of the predictive distribution. In each panel, the purple line shows the predictive function corresponding to the sample shown in the corresponding bottom row panel. For the details of the experimental setup, see Section 5.1 of Izmailov et al. (2019).

The network outputs a single number y = f(x).

To generate data for the plots, we used a randomly-initialized neural network of the same architecture described above. We sampled the weights from an isotropic Gaussian with variance 0.1², and added isotropic Gaussian noise with variance 0.1² to the outputs:

y = f(x; w) + ε(x),

with w ∼ N(0, 0.1² · I), ε(x) ∼ N(0, 0.1² · I). The training set consists of 120 points, shown in Figure 4.
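A minimal sketch of this data-generating process, assuming a small fully-connected PyTorch network with inputs (x, x²) as described above; the input range and random seed are illustrative choices, not the paper's exact configuration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# MLP with 3 hidden layers of width 10, taking (x, x^2) as its two inputs.
net = nn.Sequential(
    nn.Linear(2, 10), nn.ReLU(),
    nn.Linear(10, 10), nn.ReLU(),
    nn.Linear(10, 10), nn.ReLU(),
    nn.Linear(10, 1),
)

with torch.no_grad():
    # Ground-truth weights w ~ N(0, 0.1^2 I).
    for p in net.parameters():
        p.normal_(mean=0.0, std=0.1)

    # 120 training points with observation noise eps ~ N(0, 0.1^2).
    x = torch.linspace(-10, 10, 120).unsqueeze(-1)
    inputs = torch.cat([x, x ** 2], dim=-1)
    y = net(inputs) + 0.1 * torch.randn(120, 1)
```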

For estimating the ground truth, we ran 200 chains of Hamiltonian Monte Carlo (HMC) using the hamiltorch package (Cobb et al., 2019). We initialized each chain with a network pre-trained with SGD for 3000 steps, then ran HMC for 2000 steps, producing 200 samples.

For Deep Ensembles, we independently trained 50 networks with SGD for 20000 steps each. We used the negative posterior log-density as the training loss. For SVI, we used a fully-factorized Gaussian approximation initialized at an SGD solution trained for 20000 steps. For all inference methods, we set the prior variance to 10² and the noise variance to 0.02².

Discrepancy with true BMA. For the results presented in panel (d) of Figure 4, we computed the Wasserstein distance between the predictive distribution approximated with HMC and the predictive distributions for Deep Ensembles and SVI. We used the one-dimensional Wasserstein distance function¹

¹https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wasserstein_distance.html

from the scipy package (Virtanen et al., 2020). We computed the Wasserstein distance between marginal distributions at each input location, and averaged the results over the input locations. In the top sub-panels of panels (b), (c) of Figure 4, we additionally visualize the marginal Wasserstein distance between the HMC predictive distribution and the Deep Ensembles and SVI predictive distributions, respectively, for each input location.
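A minimal sketch of this evaluation, assuming predictive samples are stored as arrays of shape (num_samples, num_locations); the variable names are ours.

```python
import numpy as np
from scipy.stats import wasserstein_distance

def mean_marginal_wasserstein(hmc_preds, approx_preds):
    """Average the 1D Wasserstein distance between marginal predictive
    distributions at each input location.

    hmc_preds, approx_preds: arrays of shape (num_samples, num_locations)
    holding predictive samples from HMC and from the approximate method.
    """
    num_locations = hmc_preds.shape[1]
    per_location = np.array([
        wasserstein_distance(hmc_preds[:, j], approx_preds[:, j])
        for j in range(num_locations)
    ])
    return per_location.mean(), per_location

# Example with synthetic stand-ins for the predictive samples.
rng = np.random.default_rng(0)
hmc = rng.normal(size=(200, 50))
ensemble = rng.normal(loc=0.1, size=(50, 50))
avg_dist, per_loc = mean_marginal_wasserstein(hmc, ensemble)
```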

D.2. Deep Ensembles and MultiSWAG

We evaluate Deep Ensembles, MultiSWA, and MultiSWAG under distribution shift in Section 4. Following Ovadia et al. (2019), we use a PreResNet-20 network and the CIFAR-10 dataset with different types of corruptions introduced in Hendrycks & Dietterich (2019). For training individual SGD, SWA, and SWAG models we use the hyper-parameters used for PreResNet-164 in Maddox et al. (2019). For each SWAG model we sample 20 networks and ensemble them. So Deep Ensembles, MultiSWA, and MultiSWAG are all evaluated under the same training budget; Deep Ensembles and MultiSWA also use the same test-time budget.

For producing the corrupted data we used the code² released by Hendrycks & Dietterich (2019). We had issues producing the data for the frost corruption type, so we omit it in our evaluation, and include Gaussian blur, which was not included in the evaluation of Hendrycks & Dietterich (2019).

²https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_cifar_c.py



Figure 12. (a)–(c): Average pairwise prior correlations for pairs of objects in classes 0, 1, 2, 4, 7 of MNIST, induced by LeNet-5 for p(f(x; w)) when p(w) = N(0, α²I). Images in the same class have higher prior correlations than images from different classes, suggesting that p(f(x; w)) has desirable inductive biases. The correlations slightly decrease with increases in α. Panels (e)–(g) show sample functions from LeNet-5 along the direction connecting a pair of MNIST images of 0 and 1 digits. The complexity of the samples increases with α. (d): NLL and (h): classification error of an ensemble of 20 SWAG samples on MNIST as a function of α, using a LeNet-5. The NLL is high for overly small α and near-optimal for larger values, with an optimum near α = 0.3.

D.3. Neural Network Priors

In the main text we considered different properties of the prior distribution over functions induced by a spherical Gaussian distribution over the weights, with different variance scales.

Prior correlation diagrams. In panels (a)–(c) of Figure 12 we show pairwise correlations of the logits for different pairs of datapoints. To make these plots, we produce S = 100 samples of the weights w_i of a LeNet-5 from the prior distribution N(0, α²I), and compute the logits corresponding to class 0 for each data point and each weight sample. We then compute the correlations for each pair x, x' of data points as follows:

\text{corr}_{\text{logit}}(x, x') = \frac{\sum_{i=1}^{S} (f(x; w_i) - \bar{f}(x))\,(f(x'; w_i) - \bar{f}(x'))}{\sqrt{\sum_{i=1}^{S} (f(x; w_i) - \bar{f}(x))^2 \cdot \sum_{i=1}^{S} (f(x'; w_i) - \bar{f}(x'))^2}},

where f(x; w) is the logit corresponding to class 0 of the network with weights w on the input x, and \bar{f}(x) is the mean value of the logit, \bar{f}(x) = \frac{1}{S} \sum_i f(x; w_i). For evaluation, we use 200 random datapoints per class for classes 0, 1, 2, 4, 7 (a total of 1000 datapoints). We use this set of classes to ensure that the structure is clearly visible in the figure. We combine the correlations into a diagram, additionally showing the average correlation for each pair of classes. We repeat the experiment for different values of α ∈ {0.02, 0.1, 1}. For a discussion of the results, see Section 5.2.
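A sketch of this Monte Carlo estimate of the prior logit correlation; here `model` and `sample_prior_weights` are placeholders for the user's own network and prior-sampling routine, not functions defined in the paper.

```python
import numpy as np
import torch

def prior_logit_correlation(model, sample_prior_weights, x, x_prime, n_samples=100):
    """Estimate corr_logit(x, x') from S prior weight samples, following the
    formula above. `sample_prior_weights(model)` should draw w ~ N(0, alpha^2 I)
    and load it into `model`; both are illustrative placeholders."""
    f_x, f_xp = [], []
    for _ in range(n_samples):
        sample_prior_weights(model)                 # w_i ~ N(0, alpha^2 I)
        with torch.no_grad():
            f_x.append(model(x)[0, 0].item())       # class-0 logit f(x; w_i)
            f_xp.append(model(x_prime)[0, 0].item())
    f_x, f_xp = np.array(f_x), np.array(f_xp)
    cx, cxp = f_x - f_x.mean(), f_xp - f_xp.mean()
    return (cx * cxp).sum() / np.sqrt((cx ** 2).sum() * (cxp ** 2).sum())
```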

Sample functions. In panels (e)–(g) of Figure 12 we visualize the functions sampled from the LeNet-5 network along the direction connecting a pair of MNIST images. In particular, we take a pair of images x_0 and x_1 of digits 0 and 1 respectively, and construct the path x(t) = t · x_0 + (1 − t) · x_1. We then study the samples of the logits z(t) = f(x(t) · ‖x_0‖/‖x(t)‖; w) along the path; here we adjusted the norm of the images along the path to be constant, as the values of the logits are sensitive to the norm of the inputs. The complexity of the samples increases as we increase the variance of the prior distribution over the weights. This increased complexity of sample functions explains why we might expect the prior correlations for pairs of images to be lower when we increase the variance of the prior distribution over the weights.

Performance dependence on prior variance. In panels (d), (h) of Figure 12 we show the test negative log-likelihood and accuracy of SWAG applied to LeNet-5 on MNIST. We train the model for 50 epochs, constructing the rank-20 SWAG approximation from the last 25 epochs. We use an initial learning rate of 0.05 and a SWAG learning rate of 0.01, with the learning rate schedule of Maddox et al. (2019). We use the posterior log-density as the objective, and vary the prior variance α². In panels (f), (g) of Figure 9


we perform an analogous experiment using a PreResNet-20 and a VGG-16 on CIFAR-10, using the hyper-parameters reported in Maddox et al. (2019) (for PreResNet-20 we use the hyper-parameters used with PreResNet-164 in Maddox et al. (2019)). Both on MNIST and CIFAR-10, we observe that the performance is poor for overly small values of α, close to optimal for intermediate values, and still reasonable for larger values of α. For further discussion of the results, see Section 5.3.

Predictions from prior samples. Following Wenzel et al. (2020), we study the predictive distributions of prior samples using PreResNet-20 on CIFAR-10. In Figure 9 we show the sample predictive functions averaged over datapoints, for different scales α of the prior distribution. We also show the predictive distribution for each α, which is the average of the sample predictive distributions over 200 samples of weights. In Figure 10 we show how the predictive distribution changes as we vary the number of observed data for prior scale α = √10. We see that the marginal predictive distribution for all considered values of α is reasonable: roughly uniform across classes when averaged across the dataset. For the latter experiment, we used stochastic gradient Langevin dynamics (SGLD) (Welling & Teh, 2011) with a cosine learning rate schedule. For each sample we restart SGLD, and we only use the sample obtained at the last iteration. We discuss the results in Section 8.4.
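For reference, a minimal sketch of one SGLD update in the spirit of Welling & Teh (2011); the closures for the minibatch negative log-likelihood and the negative log-prior, and the step-size schedule, are generic placeholders rather than the exact configuration used here.

```python
import math
import torch

def sgld_step(params, minibatch_neg_log_lik, neg_log_prior, lr, dataset_size, batch_size):
    """One SGLD update: a gradient step on the stochastic estimate of the
    posterior energy, plus Gaussian noise with variance equal to the step size."""
    # U(w) ~= -log p(w) - (N / n) * sum of log p(y | x, w) over the minibatch.
    energy = neg_log_prior() + (dataset_size / batch_size) * minibatch_neg_log_lik()
    grads = torch.autograd.grad(energy, params)
    with torch.no_grad():
        for p, g in zip(params, grads):
            p.add_(-0.5 * lr * g + math.sqrt(lr) * torch.randn_like(p))

def cosine_lr(step, total_steps, lr_max, lr_min=1e-6):
    # Cosine step-size schedule decaying from lr_max to lr_min.
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * step / total_steps))
```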

Prior correlations with corrupted images. In Section F and Figure 19 we study the decay of the prior correlations between logits on an original image and a perturbed image as we increase the intensity of perturbations. For the BNN, we use the PreResNet-20 architecture with the standard Gaussian prior N(0, I). For the linear model, the correlations are not affected by the prior variance α²:

cov(w^\top x, w^\top y) = E[w^\top x \cdot w^\top y] = E[x^\top w w^\top y] = x^\top E[w w^\top] y = \alpha^2 x^\top y,

and hence

corr(w^\top x, w^\top y) = \frac{cov(w^\top x, w^\top y)}{\sqrt{cov(w^\top y, w^\top y) \cdot cov(w^\top x, w^\top x)}} = \frac{x^\top y}{\|x\|\,\|y\|}.

We use the N(0, I) prior for the weights of the linear model. Finally, we also evaluate the correlations associated with an RBF kernel (see Equation (7)). To set the lengthscale \ell of the kernel, we evaluate the pairwise correlations for the PreResNet-20 and RBF kernel on the 100 uncorrupted CIFAR-10 images that were used for the experiment, and ensure that the average correlations match. The resulting value of \ell is 10000, and the average correlation for the RBF kernel and PreResNet was ≈ 0.9; for the linear model the

average correlation was ≈ 0.82. For the perturbations, we used the same set of corruptions introduced in Hendrycks & Dietterich (2019) as in the experiments in Section 4, with the addition of a random translation: for a random translation of intensity i, we pad the image with 2 · i zeros on each side and crop the image randomly to 32 × 32.
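A small sketch of this translation corruption, assuming images are numpy arrays of shape (32, 32, 3); the function name is ours.

```python
import numpy as np

def random_translation(image, intensity, rng=None):
    """Pad with 2 * intensity zeros on each side, then take a random 32x32 crop,
    mimicking the translation perturbation described above."""
    rng = np.random.default_rng() if rng is None else rng
    pad = 2 * intensity
    padded = np.pad(image, ((pad, pad), (pad, pad), (0, 0)), mode="constant")
    top = rng.integers(0, 2 * pad + 1)
    left = rng.integers(0, 2 * pad + 1)
    return padded[top:top + 32, left:left + 32]

# Example: intensity-3 translation of a CIFAR-10-sized image.
img = np.random.rand(32, 32, 3).astype(np.float32)
shifted = random_translation(img, intensity=3)
```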

D.4. Rethinking Generalization

In Section 6 we experiment with Bayesian neural networks and Gaussian processes on CIFAR-10 with noisy labels, inspired by the results in Zhang et al. (2016) that suggest we need to re-think generalization to understand deep learning.

Following Zhang et al. (2016), we train PreResNet-20 on CIFAR-10 with different fractions of random labels. To ensure that the networks fit the train data, we turn off weight decay and data augmentation, and use a lower initial learning rate of 0.01. Otherwise, we follow the hyper-parameters that were used with PreResNet-164 in Maddox et al. (2019). We use a diagonal Laplace approximation to compute an estimate of the marginal likelihood for each level of label corruption. Following Ritter et al. (2018), we use the diagonal of the Fisher information matrix rather than the Hessian.

We perform a similar experiment with a Gaussian process with an RBF kernel on the binary classification problem for two classes of CIFAR-10. We use variational inference to fit the model, and we use the variational evidence lower bound to approximate the marginal likelihood. We use variational inference to overcome the non-Gaussian likelihood, and not for scalability reasons, i.e. we are not using inducing inputs. We use the GPyTorch package (Gardner et al., 2018) to train the models. We use an RBF kernel with default initialization from GPyTorch, and divide the inputs by 5000 to get an appropriate input scale. We train the model on a binary classification problem between classes 0 and 1.

For the 10-class GP classification experiment, we train 10 one-vs-all models that classify between a given class and the rest of the data. To reduce computation in training, we subsample the data not belonging to the given class to 10k datapoints, so each model is trained on a total of 15k datapoints. We then combine the 10 models into a single multi-class model: an observation is attributed to the class that corresponds to the one-vs-all model with the highest confidence. We use the same hyper-parameters as in the binary classification experiments.
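A sketch of this one-vs-all combination rule; the array layout is an assumption for illustration.

```python
import numpy as np

def combine_one_vs_all(confidences):
    """Combine one-vs-all binary classifiers into multi-class predictions.

    confidences: array of shape (num_classes, num_inputs), where entry [c, j]
    is model c's confidence that input j belongs to class c. Each input is
    assigned to the class whose one-vs-all model is most confident.
    """
    return confidences.argmax(axis=0)

# Example: 10 one-vs-all GP models scoring 4 test inputs.
conf = np.random.rand(10, 4)
predicted_classes = combine_one_vs_all(conf)
```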

D.5. Double Descent

In Section 7 we evaluate SGD, SWAG, and MultiSWAG for models of varying width. Following Nakkiran et al. (2019), we use ResNet-18 on CIFAR-100; we consider the original labels, 10% label corruption, and 20% label corruption. For networks of every width we reuse the hyper-parameters used for PreResNet-164 in Maddox et al. (2019).



Figure 13. Double Descent. (a) Test error and (b) NLL loss for ResNet-18 with varying width on CIFAR-100 for SGD, SWAG, and MultiSWAG when 10% of the labels are randomly reshuffled. MultiSWAG alleviates double descent both on the original labels and under label noise, both in accuracy and NLL. (c) Test NLLs for MultiSWAG with a varying number of independent models under 20% label corruption; NLL monotonically decreases with an increased number of independent models, alleviating double descent.

For the original labels and 10% label corruption we use 5 independently trained SWAG models with MultiSWAG, and for 20% label corruption we use 10 models; for 20% label corruption we also show performance varying the number of independent models in Figures 8(e) and 13(c). Both for SWAG and MultiSWAG we use an ensemble of 20 sampled models from each of the SWAG solutions; for example, for MultiSWAG with 10 independent SWAG solutions we use an ensemble of 200 networks.
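A sketch of the Bayesian model average used for MultiSWAG here: average the softmax predictions over all sampled networks from all SWAG solutions (the array shapes are illustrative).

```python
import numpy as np

def multiswag_predict(per_sample_probs):
    """Average predictive distributions over all sampled networks.

    per_sample_probs: array of shape
    (num_swag_solutions, num_samples_per_solution, num_inputs, num_classes),
    e.g. (10, 20, N, 100) for 10 SWAG solutions with 20 samples each.
    Returns averaged class probabilities of shape (num_inputs, num_classes).
    """
    return per_sample_probs.mean(axis=(0, 1))

# Example: 10 SWAG solutions x 20 samples = an ensemble of 200 networks.
probs = np.random.dirichlet(np.ones(100), size=(10, 20, 8))
bma_probs = multiswag_predict(probs)
predicted = bma_probs.argmax(-1)
```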

E. Analysis of Prior Variance Effect

In this section we provide simple analytic results for the effect of the prior variance in ReLU networks. A related derivation is presented in Appendix Section A.8 of Garipov et al. (2018), about connecting paths from symmetries in parametrization.

We will consider a multilayer network f(x; w) of the form

f(x; \{W_i, b_i\}_{i=1}^n) = W_n(\ldots \phi(W_2 \phi(W_1 x + b_1) + b_2) \ldots) + b_n,

where \phi is the ReLU (or, in fact, any positively-homogeneous activation function), W_i are weight matrices, and b_i are bias vectors. In particular, f can be a regular CNN with ReLU activations up to the logits (with the softmax activation removed).

Now suppose we have a prior distribution of the form

W_i \sim \mathcal{N}(0, \alpha_i^2 I), \qquad b_i \sim \mathcal{N}(0, \beta_i^2 I),

where the identity matrices I are implicitly assumed to be of the appropriate shapes, so each weight matrix and bias vector has a spherical Gaussian distribution. We can reparameterize this distribution as

W_i = \alpha_i E_i, \quad E_i \sim \mathcal{N}(0, I), \qquad b_i = \beta_i \epsilon_i, \quad \epsilon_i \sim \mathcal{N}(0, I).

We can then express the predictions of the network on the input x, for weights sampled from the prior, as the random variable

f(x; \{\alpha_i, \beta_i\}_{i=1}^n) = \alpha_n E_n(\ldots \phi(\alpha_1 E_1 x + \beta_1 \epsilon_1) \ldots) + \beta_n \epsilon_n. \qquad (8)

Through Equation (8) we can observe some simple properties of the dependence between the prior scales \alpha_i, \beta_i and the induced function-space prior.

Proposition 1. Suppose the network has no bias vectors, i.e. \beta_1 = \ldots = \beta_n = 0. Then the scales \alpha_i of the prior distribution over the weights only affect the output scale of the network.

Proof. In the case when there are no bias vectors, Equation (8) simplifies to

f(x; \{\alpha_i, \beta_i = 0\}_{i=1}^n) = \alpha_n E_n(\ldots \phi(\alpha_1 E_1 x) \ldots) = \alpha_n \cdots \alpha_1 \cdot E_n(\ldots \phi(E_1 x) \ldots) = \alpha_n \cdots \alpha_1 \cdot f(x; \{\alpha_i = 1, \beta_i = 0\}_{i=1}^n).

In the derivation above we used the positive homogeneity of ReLU: \phi(\alpha z) = \alpha \phi(z) for any positive \alpha.

In other words, to sample from the distribution over functions corresponding to a prior with variances \{\alpha_i, \beta_i = 0\}_{i=1}^n, we can sample from the spherical Gaussian prior (without bias terms) \{\alpha_i = 1, \beta_i = 0\}_{i=1}^n, and then rescale the outputs of the network by the product of the scales \alpha_n \cdots \alpha_2 \cdot \alpha_1.
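A quick numerical check of Proposition 1 (our own sketch): in a bias-free ReLU MLP, rescaling each layer's weights by \alpha_i rescales the output by the product of the \alpha_i.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Bias-free ReLU MLP: f(x) = W3 phi(W2 phi(W1 x)).
linear_layers = [
    nn.Linear(4, 16, bias=False),
    nn.Linear(16, 16, bias=False),
    nn.Linear(16, 1, bias=False),
]
net = nn.Sequential(linear_layers[0], nn.ReLU(), linear_layers[1], nn.ReLU(), linear_layers[2])

alphas = [0.5, 2.0, 3.0]
x = torch.randn(8, 4)

with torch.no_grad():
    base_out = net(x)
    for layer, a in zip(linear_layers, alphas):
        layer.weight.mul_(a)        # W_i -> alpha_i * W_i
    scaled_out = net(x)

# Outputs differ exactly by the product of the scales: 0.5 * 2.0 * 3.0 = 3.0.
assert torch.allclose(scaled_out, 3.0 * base_out, atol=1e-5)
```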

We note that the result above is different from the results for sigmoid networks considered in MacKay (1995), where varying the prior on the weights leads to changing the length-scale of the sample functions. For ReLU networks without biases, increasing the prior variance only increases the output scale of the network, and not the complexity of the samples. If we apply the softmax activation on the outputs of the last layer of such a network, we will observe increasingly


confident predictions as we increase the prior variance. We observe this effect in Figure 9, and discuss it in Section 8.4.

In case bias vectors are present, we can obtain a similar result using a specific scaling of the prior variances with the layer, as in the following proposition.

Proposition 2. Suppose the prior scales depend on the layer of the network as follows, for some \gamma > 0:

\alpha_i = \gamma, \qquad \beta_i = \gamma^i,

for all layers i = 1, \ldots, n. Then \gamma only affects the scale of the predictive distribution at any input x:

f(x; \{\alpha_i = \gamma, \beta_i = \gamma^i\}_{i=1}^n) = \gamma^n \cdot f(x; \{\alpha_i = 1, \beta_i = 1\}_{i=1}^n).

Proof. The proof is analogous to the proof of Proposition 1. We can use the positive homogeneity of ReLU activations to factor the prior scales outside of the network:

f(x; \{\alpha_i = \gamma, \beta_i = \gamma^i\}_{i=1}^n) = \gamma \cdot E_n(\ldots \phi(\gamma \cdot E_1 x + \gamma \cdot \epsilon_1) \ldots) + \gamma^n \cdot \epsilon_n = \gamma^n \cdot \big( E_n(\ldots \phi(E_1 x + \epsilon_1) \ldots) + \epsilon_n \big) = \gamma^n \cdot f(x; \{\alpha_i = 1, \beta_i = 1\}_{i=1}^n).

The analysis above can be applied to other simple scaling rules of the prior, e.g.

f(x; \{\alpha_i = \gamma \hat{\alpha}_i, \beta_i = \gamma^i \hat{\beta}_i\}_{i=1}^n) = \gamma^n \cdot f(x; \{\alpha_i = \hat{\alpha}_i, \beta_i = \hat{\beta}_i\}_{i=1}^n) \qquad (9)

can be shown completely analogously to Proposition 2.

More general types of scaling of the prior affect both the output scale of the network and also the relative effect of prior and variance terms. For example, by Equation (9) we have

f(x; \{\alpha_i = \gamma, \beta_i = \gamma\}_{i=1}^n) = f(x; \{\alpha_i = \gamma \cdot 1, \beta_i = \gamma^i \cdot \gamma^{1-i}\}_{i=1}^n) = \gamma^n \cdot f(x; \{\alpha_i = 1, \beta_i = \gamma^{1-i}\}_{i=1}^n).

We note that the analysis does not cover residual connections and batch normalization, so it applies to LeNet-5, but cannot be directly applied to the PreResNet-20 networks used in many of our experiments.

F. Prior Correlation Structure under Perturbations

In this section we explore the prior correlations between the logits on different pairs of datapoints induced by a spherical Gaussian prior on the weights of a PreResNet-20. We sample 100 random images from CIFAR-10 (10 from each class) and apply 17 different perturbations introduced by Hendrycks & Dietterich (2019) at 5 different levels of intensity. We then compute correlations between the logits f(x; w) for the original image x and f(x̃; w) for the corrupted image x̃, as we sample the weights of the network from the prior w ∼ N(0, I).

In Figure 19 we show how the correlations decay with perturbation intensity. For reference, we also show how the correlations decay for a linear model and for an RBF kernel. For the RBF kernel, we set the lengthscale so that the average correlations on the uncorrupted datapoints match those of a PreResNet-20. Further experimental details can be found in Appendix D.3.

For all types of corruptions except saturate, snow, fog, and brightness, the PreResNet logit correlations decay slower than those of the RBF kernel and the linear model. It appears that the prior samples are sensitive to corruptions that alter the brightness, or more generally the colours, in the image. For many types of corruptions (such as, e.g., Gaussian noise) the prior correlations for PreResNet are close to 1 for all levels of corruption.

Overall, these results indicate that the prior over functions induced by a vague prior over parameters w in combination with a PreResNet has useful equivariance properties: before seeing data, the model treats images of the same class as highly correlated, even after an image has undergone significant perturbations representative of perturbations we often see in the real world. These types of symmetries are a large part of what makes neural networks a powerful model class for high dimensional natural signals.


[Figure 14 panels: (a) Gaussian Noise, (b) Impulse Noise, (c) Shot Noise; each panel plots NLL against the number of independently trained models (1–10) for Deep Ensembles, MultiSWA, and MultiSWAG.]

Figure 14. Noise Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG, and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).


[Figure 15 panels: (a) Defocus Blur, (b) Glass Blur, (c) Motion Blur, (d) Zoom Blur, (e) Gaussian Blur; each panel plots NLL against the number of independently trained models (1–10) for Deep Ensembles, MultiSWA, and MultiSWAG.]

Figure 15. Blur Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG, and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).


[Figure 16 panels: (a) Contrast, (b) Saturate, (c) Elastic Transform, (d) Pixelate, (e) JPEG Compression; each panel plots NLL against the number of independently trained models (1–10) for Deep Ensembles, MultiSWA, and MultiSWAG.]

Figure 16. Digital Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG, and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).


[Figure 17 panels: (a) Snow, (b) Fog, (c) Brightness; each panel plots NLL against the number of independently trained models (1–10) for Deep Ensembles, MultiSWA, and MultiSWAG.]

Figure 17. Weather Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG, and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).


Figure 18. Negative log likelihood, accuracy, and expected calibration error distributions on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG, and MultiSWA as a function of the corruption intensity. Following Ovadia et al. (2019), we summarize the results for different types of corruption with a boxplot. For each method we use 10 independently trained models, and for MultiSWAG we sample 20 networks from each model. As in Figures 5 and 11–14, there are substantial differences between these three methods, which are hard to see due to the vertical scale on this plot; MultiSWAG particularly outperforms Deep Ensembles and MultiSWA in terms of NLL and ECE for higher corruption intensities.


[Figure 19 panels: prior correlation vs. corruption intensity for Translation, Gaussian Noise, Impulse Noise, Shot Noise, Defocus Blur, Glass Blur, Motion Blur, Zoom Blur, Gaussian Blur, Contrast, Saturate, Elastic, Pixelate, JPEG, Snow, Fog, and Brightness, comparing PreResNet-20, RBF, and Linear.]

Figure 19. Prior correlations under corruption. Prior correlations between predictions (logits) of PreResNet-20, a linear model, and an RBF kernel on original and corrupted images, as a function of corruption intensity, for different types of corruptions. The lengthscale of the RBF kernel is calibrated to produce similar correlations to PreResNet on uncorrupted datapoints. We report the mean correlation values over 100 different images and show the 1σ error bars with shaded regions. For all corruptions except Snow, Saturate, Fog, and Brightness, the correlations decay slower for PreResNet compared to the baselines.

Page 20: Bayesian Deep Learning and a Probabilistic Perspective of ... · Andrew Gordon Wilson Pavel Izmailov New York University Abstract The key distinguishing property of a Bayesian ap-proach

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

Appendix OutlineThis appendix is organized as follows In Section A wevisualize predictive functions corresponding to weight sam-ples within high posterior density valleys on a regressionproblem In Section B we provide background material onGaussian processes In Section C we present further resultscomparing MultiSWAG and MultiSWA to Deep Ensemblesunder data distribution shift on CIFAR-10 In Section D weprovide the details of all experiments presented in the paperIn Section E we present analytic results on the dependenceof the prior distribution in function space on the variance ofthe prior over parameters In Section F we study the priorcorrelations between BNN logits on perturbed images

A Loss ValleysWe demonstrate that different points along the valleys ofhigh posterior density (low loss) connecting pairs of in-dependently trained optima (Garipov et al 2018 Draxleret al 2018 Fort amp Jastrzebski 2019) correspond to dif-ferent predictive functions We use the regression examplefrom Izmailov et al (2019) and show the results in Figure11

B Gaussian processesWith a Bayesian neural network a distribution over parame-ters p(w) induces a distribution over functions p(f(xw))when combined with the functional form of the networkGaussian processes (GPs) are often used to instead directlyspecify a distribution over functions

A Gaussian process is a distribution over functions f(x) simGP(m k) such that any collection of function valuesqueried at any finite set of inputs x1 xn has a jointGaussian distribution

f(x1) f(xn) sim N (microK) (6)

The mean vector microi = E[f(xi)] = m(xi) and covariancematrix Kij = cov(f(xi) f(xj)) = k(xi xj) are deter-mined by the mean function m and covariance function (orkernel) k of the Gaussian process

The popular RBF kernel has the form

k(xi xj) = exp

(minus 1

2`2xi minus xj2

) (7)

The length-scale hyperparameter ` controls the extent ofcorrelations between function values If ` is large samplefunctions from a GP prior are simple and slowly varyingwith inputs x

Gaussian processes with RBF kernels (as well as manyother standard kernels) assign positive density to any setof observations Moreover these models are universal ap-proximators (Rasmussen amp Williams 2006) as the numberof observations increase they are able to approximate anyfunction to arbitrary precision

Work on Gaussian processes in machine learning was trig-gered by the observation that Bayesian neural networksbecome Gaussian processes with particular kernel functionsas the number of hidden units approaches infinity (Neal1996) This result resembles recent work on the neuraltangent kernel (eg Jacot et al 2018)

C Deep Ensembles and MultiSWAG UnderDistribution Shift

In Figures 14 15 16 17 we show the negative log-likelihood for Deep Ensembles MultiSWA and MultiSWAGusing PreResNet-20 on CIFAR-10 with various corruptionsas a function of the number of independently trained models(SGD solutions SWA solutions or SWAG models respec-tively) For MultiSWAG we generate 20 samples from eachindependent SWAG model Typically MultiSWA and Mul-tiSWAG significantly outperform Deep Ensembles when asmall number of independent models is used or when thelevel of corruption is high

In Figure 18 following Ovadia et al (2019) we showthe distribution of negative log likelihood accuracy andexpected calibration error as we vary the type of corrup-tion We use a fixed training time budget 10 independentlytrained models for every method For MultiSWAG we en-semble 20 samples from each of the 10 SWAG approxima-tions MultiSWAG particularly achieves better NLL than theother two methods and MultiSWA outperforms Deep En-sembles the difference is especially pronounced for higherlevels of corruption In terms of ECE MultiSWAG againoutperforms the other two methods for higher corruptionintensities

We note that Ovadia et al (2019) found Deep Ensembles tobe a very strong baseline for prediction quality and calibra-tion under distribution shift For this reason we focus onDeep Ensembles in our comparisons

D Details of ExperimentsIn this section we provide additional details of the experi-ments presented in the paper

D1 Approximating the True Predictive Distribution

For the results presented in Figure 4 we used a network with3 hidden layers of size 10 each The network takes twoinputs x and x2 We pass both x and x2 as input to ensure

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

minus10 minus5 0 5 10minus10

minus05

00

05

10

15

20

minus10 minus5 0 5 10minus10

minus05

00

05

10

15

20

minus10 minus5 0 5 10minus10

minus05

00

05

10

15

20

minus10 minus5 0 5 10minus10

minus05

00

05

10

15

20

0 10 20 30

minus3

minus2

minus1

0

1

2

0 10 20 30

minus3

minus2

minus1

0

1

2

0 10 20 30

minus3

minus2

minus1

0

1

2

0 10 20 30

minus3

minus2

minus1

0

1

2

Figure 11 Diversity of high performing functions Bottom a contour plot of the posterior log-density in the subspace containing apair of independently trained modes (as with deep ensembles) and a path of high posterior density connecting these modes In eachpanel the purple point represents a sample from the posterior in the parameter subspace Top the predictive distribution constructed fromsamples in the subspace The shaded blue area shows the 3σ-region of the predictive distribution at each of the input locations and theblue line shows the mean of the predictive distribution In each panel the purple line shows the predictive function corresponding to thesample shown in the corresponding bottom row panel For the details of the experimental setup see Section 51 of Izmailov et al (2019)

that the network can represent a broader class of functionsThe network outputs a single number y = f(x)

To generate data for the plots we used a randomly-initialized neural network of the same architecture describedabove We sampled the weights from an isotropic Gaussianwith variance 012 and added isotropic Gaussian noise withvariance 012 to the outputs

y = f(xw) + ε(x)

withw sim N (0 012 middotI) ε(x) sim N (0 012 middotI) The trainingset consists of 120 points shown in Figure 4

For estimating the ground truth we ran 200 chains of Hamil-tonian Monte Carlo (HMC) using the hamiltorch pack-age (Cobb et al 2019) We initialized each chain witha network pre-trained with SGD for 3000 steps then ranHamiltonian Monte Carlo (HMC) for 2000 steps producing200 samples

For Deep Ensembles we independently trained 50 networkswith SGD for 20000 steps each We used minus posteriorlog-density as the training loss For SVI we used a fully-factorized Gaussian approximation initialized at an SGDsolution trained for 20000 steps For all inference methodswe set prior variance to 102 and noise variance to 0022

Discrepancy with true BMA For the results presentedin panel (d) of Figure 4 we computed Wasserstein distancebetween the predictive distribution approximated with HMCand the predictive distribution for Deep Ensembles and SVIWe used the one-dimensional Wasserstein distance function1

1httpsdocsscipyorgdocscipyreferencegeneratedscipystats

from the scipy package (Virtanen et al 2020) We computedthe Wasserstein distance between marginal distributionsat each input location and averaged the results over theinput locations In the top sub-panels of panels (b) (c) ofFigure 4 we additionally visualize the marginal Wassersteindistance between the HMC predictive distribution and DeepEnsembles and SVI predictive distrbutions respectively foreach input location

D2 Deep Ensembles and MultiSWAG

We evaluate Deep Ensembles MultiSWA and MultiSWAGunder distribution shift in Section 4 Following Ovadia et al(2019) we use a PreResNet-20 network and the CIFAR-10 dataset with different types of corruptions introducedin Hendrycks amp Dietterich (2019) For training individualSGD SWA and SWAG models we use the hyper-parametersused for PreResNet-164 in Maddox et al (2019) For eachSWAG model we sample 20 networks and ensemble themSo Deep Ensembles MultiSWA and MultiSWAG are allevaluated under the same training budget Deep Ensemblesand MultiSWA also use the same test-time budget

For producing the corrupted data we used the code2 releasedby Hendrycks amp Dietterich (2019) We had issues produc-ing the data for the frost corruption type so we omit itin our evaluation and include Gaussian blur which wasnot included in the evaluation of Hendrycks amp Dietterich(2019)

wasserstein_distancehtml2httpsgithubcomhendrycksrobustness

blobmasterImageNet-Ccreate_cmake_cifar_cpy

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

0 1 2 4 7

MNIST Class

0

1

2

4

7

MN

IST

Cla

ss

098

096

097

097

097

096

099

097

097

097

097

097

098

097

097

097

097

097

098

097

097

097

097

097

098

090

092

094

096

098

100

(a) α = 002

0 1 2 4 7

MNIST Class

0

1

2

4

7

MN

IST

Cla

ss

089

075

083

081

081

075

090

082

079

082

083

082

089

083

085

081

079

083

089

084

081

082

085

084

088

05

06

07

08

09

10

(b) α = 01

0 1 2 4 7

MNIST Class

0

1

2

4

7

MN

IST

Cla

ss

085

071

077

076

076

071

089

080

078

079

077

080

084

080

080

076

078

080

085

081

076

079

080

081

085

05

06

07

08

09

10

(c) α = 1

10minus2 10minus1 100 101

Prior std α

5 middot 102

103

5 middot 103

104

NL

L

(d)

10 05 00 05 10 15 20

(e) α = 002

10 05 00 05 10 15 20

(f) α = 01

10 05 00 05 10 15 20

(g) α = 1

10minus2 10minus1 100 101

Prior std α

1

5

20

90

Cla

ssifi

cati

onE

rror

(h)

Figure 12 (a)ndash(c) Average pairwise prior correlations for pairs of objects in classes 0 1 2 4 7 of MNIST induced by LeNet-5for p(f(xw)) when p(w) = N (0 α2I) Images in the same class have higher prior correlations than images from different classessuggesting that p(f(xw)) has desirable inductive biases The correlations slightly decrease with increases in α Panels (e)ndash(g) showsample functions from LeNet-5 along the direction connecting a pair of MNIST images of 0 and 1 digits The complexity of the samplesincreases with α (d) NLL and (h) classification error of an ensemble of 20 SWAG samples on MNIST as a function of α using aLeNet-5 The NLL is high for overly small α and near-optimal for larger values with an optimum near α = 03

D3 Neural Network Priors

In the main text we considered different properties of theprior distribution over functions induced by a sphericalGaussian distribution over the weights with different vari-ance scales

Prior correlation diagrams In panels (a)ndash(c) of Figure12 we show pairwise correlations of the logits for differ-ent pairs of datapoints To make these plots we produceS = 100 samples of the weights wi of a LeNet-5 from theprior distribution N (0 α2I) and compute the logits corre-sponding to class 0 for each data point and each weightsample We then compute the correlations for each pair xxprime of data points as follows

corrlogit(x xprime) =sumS

i=1(f(xwi)minus f(x))(f(xprime wi)minus f(xprime))radicsumSi=1(f(xwi)minus f(x))2 middotsumS

i=1(f(xprime wi)minus f(xprime))2

where f(xw) is the logit corresponding to class 0 of thenetwork with weights w on the input x and f(x) is themean value of the logit f(x) = 1

S

sumi f(xwi) For evalu-

ation we use 200 random datapoints per class for classes0 1 2 4 7 (a total of 1000 datapoints) We use this setof classes to ensure that the structure is clearly visible inthe figure We combine the correlations into a diagramadditionally showing the average correlation for each pairof classes We repeat the experiment for different values

of α isin 002 01 1 For a discussion of the results seeSection 52

Sample functions In panels (e)ndash(g) of Figure 12 we vi-sualize the functions sampled from the LeNet-5 networkalong the direction connecting a pair of MNIST imagesIn particular we take a pair of images x0 and x1 of dig-its 0 and 1 respectively and construct the path x(t) =t middot x0 + (1 minus t) middot x1 We then study the samples of thelogits z(t) = f(x(t) middot x0x(t) w) along the path herewe adjusted the norm of the images along the path to beconstant as the values of the logits are sensitive to the normof the inputs The complexity of the samples increases aswe increase the variance of the prior distribution over theweights This increased complexity of sample functions ex-plains why we might expect the prior correlations for pairsof images to be lower when we increase the variance of theprior distribution over the weights

Performance dependence on prior variance In panels(d) (h) of Figure 12 we show the test negative log-likelihoodand accuracy of SWAG applied to LeNet-5 on MNIST Wetrain the model for 50 epochs constructing the rank-20SWAG approximation from the last 25 epochs We usean initial learning rate of 005 and SWAG learning rateof 001 with the learning rate schedule of Maddox et al(2019) We use posterior log-density as the objective andvary the prior variance α2 In panels (f) (g) of Figure 9

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

we perform an analogous experiment using a PreResNet-20and a VGG-16 on CIFAR-10 using the hyper-parametersreported in Maddox et al (2019) (for PreResNet-20 we usethe hyper-parameters used with PreResNet-164 in Maddoxet al (2019)) Both on MNIST and CIFAR-10 we observethat the performance is poor for overly small values of αclose to optimal for intermediate values and still reasonablefor larger values of α For further discussion of the resultssee Section 53

Predictions from prior samples Following Wenzel et al(2020) we study the predictive distributions of prior samplesusing PreResNet-20 on CIFAR-10 In Figure 9 we showthe sample predictive functions averaged over datapoints fordifferent scales α of the prior distribution We also showthe predictive distribution for each α which is the averageof the sample predictive distributions over 200 samples ofweights In Figure 10 we show how the predictive distribu-tion changes as we vary the number of observed data forprior scale α =

radic10 We see that the marginal predictive

distribution for all considered values of α is reasonable mdashroughly uniform across classes when averaged across thedataset For the latter experiment we used stochastic gra-dient Langevin dynamics (SGLD) (Welling amp Teh 2011)with a cosine lerning rate schedule For each sample werestart SGLD and we only use the sample obtained at thelast iteration We discuss the results in Section 84

Prior correlations with corrupted images In Section Fand Figure 19 we study the decay of the prior correlationsbetween logits on an original image and a perturbed imageas we increase the intensity of perturbations For the BNNwe use PreResNet-20 architecture with the standard Gaus-sian prior N (0 I) For the linear model the correlationsare not affected by the prior variance α2

cov(wTxwT y) = E(wTx middot wT y) =

ExTwwT y = xTEwwT y = α2xT y

and hence

corr(wTxwT y) =

cov(wTxwT y)radiccov(wT y wT y) middot cov(wTxwTx)

= xT y

We use theN (0 I) prior for the weights of the linear modelFinally we also evaluate the correlations associated withan RBF kernel (see Equation (7)) To set the lengthscale` of the kernel we evaluate the pairwise correlations forthe PreResnet-20 and RBF kernel on the 100 uncorruptedCIFAR-10 images that were used for the experiment andensure that the average correlations match The resultingvalue of ` is 10000 and the average correlation for the RBFkernel and PreResNet was asymp 09 for the linear model the

average correlation was asymp 082 For the perturbations weused the same set of corruptions introduced in Hendrycks ampDietterich (2019) as in the experiments in Section 4 with theaddition of a random translation for a random translationof intensity i we pad the image with 2 middot i zeros on each sideand crop the image randomly to 32times 32

D4 Rethinking Generalization

In Section 6 we experiment with Bayesian neural networksand Gaussian processes on CIFAR-10 with noisy labelsinspired by the results in Zhang et al (2016) that suggest weneed to re-think generalization to understand deep learning

Following Zhang et al (2016) we train PreResNet-20 onCIFAR-10 with different fractions of random labels To en-sure that the networks fits the train data we turn off weightdecay and data augmentation and use a lower initial learn-ing rate of 001 Otherwise we follow the hyper-parametersthat were used with PreResNet-164 in Maddox et al (2019)We use diagonal Laplace approximation to compute an esti-mate of marginal likelihood for each level of label corrup-tion Following Ritter et al (2018) we use the diagonal ofthe Fisher information matrix rather than the Hessian

We perform a similar experiment with a Gaussian processwith RBF kernel on the binary classification problem fortwo classes of CIFAR-10 We use variational inference to fitthe model and we use the variational evidence lower boundto approximate the marginal likelihood We use variationalinference to overcome the non-Gaussian likelihood and notfor scalability reasons ie we are not using inducing inputsWe use the GPyTorch package (Gardner et al 2018) totrain the models We use an RBF kernel with default ini-tialization from GPyTorch and divide the inputs by 5000to get an appropriate input scale We train the model on abinary classification problem between classes 0 and 1

For the 10-class GP classification experiment we train 10one-vs-all models that classify between a given class andthe rest of the data To reduce computation in trainingwe subsample the data not belonging to the given class to10k datapoints so each model is trained on a total of 15kdatapoints We then combine the 10 models into a singlemulti-class model an observation is attributed to the classthat corresponds to the one-vs-all model with the highestconfidence We use the same hyper-parameters as in thebinary classification experiments

D5 Double Descent

In Section 7 we evaluate SGD SWAG and MultiSWAG formodels of varying width Following Nakkiran et al (2019)we use ResNet-18 on CIFAR-100 we consider original la-bels 10 and 20 label corruption For networks of everywidth we reuse the hyper-paramerers used for PreResNet-

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

10 20 30 40 50

ResNet-18 width

250

275

300

325

350

375

400

425

450

Tes

tE

rror

()

SGD

SWAG

Multi-SWAG

(a) 10 Corrupted (Err)

10 20 30 40 50

ResNet-18 width

10

12

14

16

18

20

22

Tes

tN

LL

(b) 10 Corrupted (NLL)

10 20 30 40 50

ResNet-18 width

12

14

16

18

20

Mul

tiS

WA

GT

est

NL

L

SWAG Models

1 3 5 10

(c) 20 Corrupted ( Models)

Figure 13 Double Descent (a) Test error and (b) NLL loss for ResNet-18 with varying width on CIFAR-100 for SGD SWAG andMultiSWAG when 10 of the labels are randomly reshuffled MultiSWAG alleviates double descent both on the original labels and underlabel noise both in accuracy and NLL (e) Test NLLs for MultiSWAG with varying number of independent models under 20 labelcorruption NLL monotonically decreases with increased number of independent models alleviating double descent

164 in Maddox et al (2019) For original labels and 10label corruption we use 5 independently trained SWAG mod-els with MultiSWAG and for 20 label corruption we use10 models for 20 label corruption we also show perfor-mance varying the number of independent models in Figures8(e) and 13(c) Both for SWAG and MultiSWAG we use anensemble of 20 sampled models from each of the SWAG so-lutions for example for MultiSWAG with 10 independentSWAG solutions we use an ensemble of 200 networks

E Analysis of Prior Variance EffectIn this section we provide simple analytic results for theeffect of prior variance in ReLU networks A related deriva-tion is presented in the Appendix Section A8 of Garipovet al (2018) about connecting paths from symmetries inparametrization

We will consider a multilayer network f(xw) of the form

f(x Wibini=1) =

Wn( φ(W2φ(W1x+ b1) + b2)) + bn

where φ is the ReLU (or in fact any positively-homogeneousactivation function) Wi are weight matrices and bi are biasvectors In particular f can be a regular CNN with ReLU ac-tivations up to the logits (with softmax activation removed)

Now suppose we have a prior distribution of the form

Wi sim N (0 α2i I) bi sim N (0 β2

i I)

where the identity matrices I are implicitly assumed to beof appropriate shapes so each weight matrix and bias vectorhas a spherical Gaussian distribution We can reparameter-ize this distribution as

Wi = αiEi Ei sim N (0 I)

bi = βiεi εi sim N (0 I)

We can then express the predictions of the network on theinput x for weights sampled from the prior as the random

variable

f(x αi βini=1) =

αn middot En( φ(α1E1x+ β1 middot ε1)) + βn middot εn(8)

Through Equation (8) we can observe some simple proper-ties of the dependence between the prior scales αi βi andthe induced function-space prior

Proposition 1 Suppose the network has no bias vectorsie β1 = = βn = 0 Then the scales αi of the priordistribution over the weights only affect the output scale ofthe network

Proof In the case when there are no bias vectors Equation(8) simplifies to

f(xαi βi = 0ni=1) =

αn middot En( φ(α1E1x+ β1 middot ε1)) + βn middot εn =

αn middot middot α1 middot En( φ(E1x)) =

αn middot middot α1 middot f(x αi = 1 βi = 0ni=1)

In the derivation above we used positive homogeneity ofReLU φ(αz) = αφ(z) for any positive α

In other words to sample from the distribution over func-tions corresponding to a prior with variances αi βi =0ni=1 we can sample from the spherical Gaussian prior(without bias terms) αi = 1 βi = 0ni=1 and then rescalethe outputs of the network by the product of variancesαn middot middot α2 middot α1

We note that the result above is different from the resultsfor sigmoid networks considered in MacKay (1995) wherevarying the prior on the weights leads to changing the length-scale of the sample functions For ReLU networks withoutbiases increasing prior variance only increases the outputscale of the network and not the complexity of the samplesIf we apply the softmax activation on the outputs of thelast layer of such network we will observe increasingly

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

confident predictions as we increase the prior variance Weobserve this effect in Figure 9 and discuss it in Section 84

In case bias vectors are present we can obtain a similarresult using a specific scaling of the prior variances withlayer as in the following proposition

Proposition 2. Suppose the prior scales depend on the layer of the network as follows, for some γ > 0:
$$\alpha_i = \gamma, \qquad \beta_i = \gamma^i$$
for all layers i = 1, …, n. Then γ only affects the scale of the predictive distribution at any input x:
$$f(x; \{\alpha_i = \gamma, \beta_i = \gamma^i\}_{i=1}^n) = \gamma^n \cdot f(x; \{\alpha_i = 1, \beta_i = 1\}_{i=1}^n).$$

Proof. The proof is analogous to the proof of Proposition 1. We can use the positive homogeneity of ReLU activations to factor the prior scales outside of the network:
$$f(x; \{\alpha_i = \gamma, \beta_i = \gamma^i\}_{i=1}^n) = \gamma E_n(\ldots \phi(\gamma E_1 x + \gamma \epsilon_1) \ldots) + \gamma^n \epsilon_n = \gamma^n \big( E_n(\ldots \phi(E_1 x + \epsilon_1) \ldots) + \epsilon_n \big) = \gamma^n \cdot f(x; \{\alpha_i = 1, \beta_i = 1\}_{i=1}^n).$$
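The same kind of numerical check works for Proposition 2; here is a minimal PyTorch sketch with bias terms included (again with arbitrary illustrative sizes and an arbitrary γ).

```python
import torch

# ReLU network with biases: f(x) = W_n(... phi(W_1 x + b_1) ...) + b_n
def relu_net(x, weights, biases):
    h = x
    for W, b in zip(weights[:-1], biases[:-1]):
        h = torch.relu(h @ W.t() + b)
    return h @ weights[-1].t() + biases[-1]

torch.manual_seed(0)
dims = [10, 50, 50, 1]
E = [torch.randn(d_out, d_in) for d_in, d_out in zip(dims[:-1], dims[1:])]  # E_i ~ N(0, I)
eps = [torch.randn(d_out) for d_out in dims[1:]]                            # eps_i ~ N(0, I)
gamma, n = 1.7, len(E)
x = torch.randn(8, dims[0])

lhs = relu_net(x, [gamma * W for W in E],                          # alpha_i = gamma
               [gamma ** (i + 1) * b for i, b in enumerate(eps)])  # beta_i = gamma^i
rhs = gamma ** n * relu_net(x, E, eps)                             # gamma^n * f(x; alpha_i = beta_i = 1)
print(torch.allclose(lhs, rhs, atol=1e-3))  # True
```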

The analysis above can be applied to other simple scaling rules of the prior; e.g.
$$f(x; \{\alpha_i = \gamma \hat\alpha_i, \beta_i = \gamma^i \hat\beta_i\}_{i=1}^n) = \gamma^n \cdot f(x; \{\alpha_i = \hat\alpha_i, \beta_i = \hat\beta_i\}_{i=1}^n) \tag{9}$$
can be shown completely analogously to Proposition 2.

More general types of scaling of the prior affect both the output scale of the network and also the relative effect of the prior and variance terms. For example, by Equation (9) we have
$$f(x; \{\alpha_i = \gamma, \beta_i = \gamma\}_{i=1}^n) = f(x; \{\alpha_i = \gamma \cdot 1, \beta_i = \gamma^i \cdot \gamma^{1-i}\}_{i=1}^n) = \gamma^n \cdot f(x; \{\alpha_i = 1, \beta_i = \gamma^{1-i}\}_{i=1}^n).$$

We note that the analysis does not cover residual connections and batch normalization, so it applies to LeNet-5 but cannot be directly applied to the PreResNet-20 networks used in many of our experiments.

F. Prior Correlation Structure under Perturbations

In this section we explore the prior correlations between the logits on different pairs of datapoints induced by a spherical Gaussian prior on the weights of a PreResNet-20. We sample 100 random images from CIFAR-10 (10 from each class) and apply 17 different perturbations introduced by Hendrycks & Dietterich (2019) at 5 different levels of intensity. We then compute correlations between the logits f(x, w) for the original image x and f(x̃, w) for the corrupted image x̃, as we sample the weights of the network from the prior w ∼ N(0, I).

In Figure 19 we show how the correlations decay with perturbation intensity. For reference, we also show how the correlations decay for a linear model and for an RBF kernel. For the RBF kernel, we set the lengthscale so that the average correlations on the uncorrupted datapoints match those of a PreResNet-20. Further experimental details can be found in Appendix D.3.

For all types of corruptions except saturate, snow, fog and brightness, the correlations of the PreResNet logits decay slower than those of the RBF kernel and the linear model. It appears that the prior samples are sensitive to corruptions that alter the brightness, or more generally the colours, in the image. For many types of corruptions (such as, e.g., Gaussian noise), the prior correlations for PreResNet are close to 1 for all levels of corruption.

Overall, these results indicate that the prior over functions induced by a vague prior over parameters w, in combination with a PreResNet, has useful equivariance properties: before seeing data, the model treats images of the same class as highly correlated, even after an image has undergone significant perturbations representative of perturbations we often see in the real world. These types of symmetries are a large part of what makes neural networks a powerful model class for high dimensional natural signals.

Figure 14. Noise Corruptions: Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Gaussian Noise, (b) Impulse Noise, (c) Shot Noise.

Figure 15. Blur Corruptions: Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Defocus Blur, (b) Glass Blur, (c) Motion Blur, (d) Zoom Blur, (e) Gaussian Blur.

Figure 16. Digital Corruptions: Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Contrast, (b) Saturate, (c) Elastic Transform, (d) Pixelate, (e) JPEG Compression.

Figure 17. Weather Corruptions: Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Snow, (b) Fog, (c) Brightness.

Figure 18. Negative log likelihood, accuracy, and expected calibration error distributions on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the corruption intensity. Following Ovadia et al. (2019), we summarize the results for different types of corruption with a boxplot. For each method, we use 10 independently trained models, and for MultiSWAG we sample 20 networks from each model. As in Figures 5, 11-14, there are substantial differences between these three methods, which are hard to see due to the vertical scale on this plot; MultiSWAG particularly outperforms Deep Ensembles and MultiSWA in terms of NLL and ECE for higher corruption intensities.

Figure 19. Prior correlations under corruption: Prior correlations between predictions (logits) for PreResNet-20, a linear model, and an RBF kernel on original and corrupted images, as a function of corruption intensity, for different types of corruptions (Translation, Gaussian Noise, Impulse Noise, Shot Noise, Defocus Blur, Glass Blur, Motion Blur, Zoom Blur, Gaussian Blur, Contrast, Saturate, Elastic, Pixelate, JPEG, Snow, Fog, Brightness). The lengthscale of the RBF kernel is calibrated to produce similar correlations to PreResNet on uncorrupted datapoints. We report the mean correlation values over 100 different images and show the 1σ error bars with shaded regions. For all corruptions except Snow, Saturate, Fog and Brightness, the correlations decay slower for PreResNet compared to the baselines.


Figure 11. Diversity of high performing functions. Bottom: a contour plot of the posterior log-density in the subspace containing a pair of independently trained modes (as with deep ensembles) and a path of high posterior density connecting these modes. In each panel, the purple point represents a sample from the posterior in the parameter subspace. Top: the predictive distribution constructed from samples in the subspace. The shaded blue area shows the 3σ-region of the predictive distribution at each of the input locations, and the blue line shows the mean of the predictive distribution. In each panel, the purple line shows the predictive function corresponding to the sample shown in the corresponding bottom row panel. For the details of the experimental setup, see Section 5.1 of Izmailov et al. (2019).

that the network can represent a broader class of functions. The network outputs a single number y = f(x).

To generate data for the plots, we used a randomly-initialized neural network of the same architecture described above. We sampled the weights from an isotropic Gaussian with variance 0.1² and added isotropic Gaussian noise with variance 0.1² to the outputs:
$$y = f(x; w) + \epsilon(x),$$
with w ∼ N(0, 0.1² · I), ε(x) ∼ N(0, 0.1² · I). The training set consists of 120 points, shown in Figure 4.
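A minimal sketch of this data-generating process is below; the specific hidden-layer sizes and the input range are assumptions for illustration, while the weight and noise standard deviations of 0.1 follow the text.

```python
import torch

torch.manual_seed(0)
# Assumed small fully-connected ReLU architecture for illustration.
net = torch.nn.Sequential(
    torch.nn.Linear(1, 10), torch.nn.ReLU(),
    torch.nn.Linear(10, 10), torch.nn.ReLU(),
    torch.nn.Linear(10, 1),
)
with torch.no_grad():
    for p in net.parameters():
        p.normal_(0.0, 0.1)                      # w ~ N(0, 0.1^2 I)
    x = torch.linspace(-10, 10, 120).unsqueeze(-1)
    f = net(x)
    y = f + 0.1 * torch.randn_like(f)            # y = f(x; w) + eps, eps ~ N(0, 0.1^2)
```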

For estimating the ground truth, we ran 200 chains of Hamiltonian Monte Carlo (HMC) using the hamiltorch package (Cobb et al., 2019). We initialized each chain with a network pre-trained with SGD for 3000 steps, then ran HMC for 2000 steps, producing 200 samples.

For Deep Ensembles, we independently trained 50 networks with SGD for 20000 steps each. We used the negative posterior log-density as the training loss. For SVI, we used a fully-factorized Gaussian approximation initialized at an SGD solution trained for 20000 steps. For all inference methods we set the prior variance to 10² and the noise variance to 0.02².

Discrepancy with true BMA. For the results presented in panel (d) of Figure 4, we computed the Wasserstein distance between the predictive distribution approximated with HMC and the predictive distributions for Deep Ensembles and SVI. We used the one-dimensional Wasserstein distance function¹ from the scipy package (Virtanen et al., 2020). We computed the Wasserstein distance between the marginal distributions at each input location and averaged the results over the input locations. In the top sub-panels of panels (b), (c) of Figure 4 we additionally visualize the marginal Wasserstein distance between the HMC predictive distribution and the Deep Ensembles and SVI predictive distributions, respectively, for each input location.

¹https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wasserstein_distance.html
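A sketch of this metric computation is shown below; the array shapes are assumptions, with rows holding predictive samples and columns indexing the input locations.

```python
import numpy as np
from scipy.stats import wasserstein_distance

def mean_marginal_wasserstein(hmc_preds, other_preds):
    # hmc_preds: (S1, T), other_preds: (S2, T); average the 1D Wasserstein
    # distance between the two sets of predictive samples over input locations.
    return np.mean([
        wasserstein_distance(hmc_preds[:, t], other_preds[:, t])
        for t in range(hmc_preds.shape[1])
    ])

rng = np.random.default_rng(0)
print(mean_marginal_wasserstein(rng.normal(size=(200, 50)),
                                rng.normal(size=(50, 50))))
```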

D.2. Deep Ensembles and MultiSWAG

We evaluate Deep Ensembles, MultiSWA and MultiSWAG under distribution shift in Section 4. Following Ovadia et al. (2019), we use a PreResNet-20 network and the CIFAR-10 dataset with different types of corruptions introduced in Hendrycks & Dietterich (2019). For training the individual SGD, SWA and SWAG models, we use the hyper-parameters used for PreResNet-164 in Maddox et al. (2019). For each SWAG model we sample 20 networks and ensemble them. So Deep Ensembles, MultiSWA and MultiSWAG are all evaluated under the same training budget; Deep Ensembles and MultiSWA also use the same test-time budget.

For producing the corrupted data we used the code² released by Hendrycks & Dietterich (2019). We had issues producing the data for the frost corruption type, so we omit it in our evaluation, and include Gaussian blur, which was not included in the evaluation of Hendrycks & Dietterich (2019).

²https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_cifar_c.py

Figure 12. (a)–(c): Average pairwise prior correlations for pairs of objects in classes {0, 1, 2, 4, 7} of MNIST induced by LeNet-5 for p(f(x; w)) when p(w) = N(0, α²I). Images in the same class have higher prior correlations than images from different classes, suggesting that p(f(x; w)) has desirable inductive biases. The correlations slightly decrease with increases in α. Panels (e)–(g) show sample functions from LeNet-5 along the direction connecting a pair of MNIST images of 0 and 1 digits. The complexity of the samples increases with α. (d): NLL and (h): classification error of an ensemble of 20 SWAG samples on MNIST as a function of α, using a LeNet-5. The NLL is high for overly small α and near-optimal for larger values, with an optimum near α = 0.3.

D.3. Neural Network Priors

In the main text we considered different properties of the prior distribution over functions induced by a spherical Gaussian distribution over the weights, with different variance scales.

Prior correlation diagrams. In panels (a)–(c) of Figure 12 we show pairwise correlations of the logits for different pairs of datapoints. To make these plots, we produce S = 100 samples of the weights w_i of a LeNet-5 from the prior distribution N(0, α²I), and compute the logits corresponding to class 0 for each data point and each weight sample. We then compute the correlations for each pair x, x′ of data points as follows:
$$\mathrm{corr}_{\text{logit}}(x, x') = \frac{\sum_{i=1}^S (f(x, w_i) - \bar f(x))(f(x', w_i) - \bar f(x'))}{\sqrt{\sum_{i=1}^S (f(x, w_i) - \bar f(x))^2 \cdot \sum_{i=1}^S (f(x', w_i) - \bar f(x'))^2}},$$
where f(x, w) is the logit corresponding to class 0 of the network with weights w on the input x, and $\bar f(x)$ is the mean value of the logit, $\bar f(x) = \frac{1}{S}\sum_i f(x, w_i)$. For evaluation, we use 200 random datapoints per class for classes {0, 1, 2, 4, 7} (a total of 1000 datapoints). We use this set of classes to ensure that the structure is clearly visible in the figure. We combine the correlations into a diagram, additionally showing the average correlation for each pair of classes. We repeat the experiment for different values of α ∈ {0.02, 0.1, 1}. For a discussion of the results see Section 5.2.
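A minimal sketch of this computation is below; the model constructor and the two inputs are assumptions, e.g. a LeNet-5 together with two MNIST images batched as tensors of shape (1, 1, 28, 28).

```python
import torch

def prior_logit_correlation(make_model, x, x_prime, alpha=0.1, n_samples=100):
    # Monte Carlo estimate of the prior correlation between the class-0 logits
    # of two inputs, with weights sampled from N(0, alpha^2 I).
    logits_x, logits_xp = [], []
    for _ in range(n_samples):
        model = make_model()
        with torch.no_grad():
            for p in model.parameters():
                p.normal_(0.0, alpha)                 # w_i ~ N(0, alpha^2 I)
            logits_x.append(model(x)[0, 0].item())    # class-0 logit f(x, w_i)
            logits_xp.append(model(x_prime)[0, 0].item())
    samples = torch.tensor([logits_x, logits_xp])
    return torch.corrcoef(samples)[0, 1].item()       # corr_logit(x, x')
```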

Sample functions. In panels (e)–(g) of Figure 12 we visualize the functions sampled from the LeNet-5 network along the direction connecting a pair of MNIST images. In particular, we take a pair of images x₀ and x₁ of digits 0 and 1, respectively, and construct the path x(t) = t · x₀ + (1 − t) · x₁. We then study the samples of the logits z(t) = f(x(t) · ‖x₀‖/‖x(t)‖, w) along the path; here we adjusted the norm of the images along the path to be constant, as the values of the logits are sensitive to the norm of the inputs. The complexity of the samples increases as we increase the variance of the prior distribution over the weights. This increased complexity of sample functions explains why we might expect the prior correlations for pairs of images to be lower when we increase the variance of the prior distribution over the weights.
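The constant-norm path can be constructed as in the following sketch, where x0 and x1 are assumed to be two MNIST image tensors.

```python
import torch

def constant_norm_path(x0, x1, t):
    # Convex combination of the two images, rescaled so its norm matches ||x0||.
    x_t = t * x0 + (1 - t) * x1
    return x_t * x0.norm() / x_t.norm()
```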

Performance dependence on prior variance. In panels (d), (h) of Figure 12 we show the test negative log-likelihood and accuracy of SWAG applied to LeNet-5 on MNIST. We train the model for 50 epochs, constructing the rank-20 SWAG approximation from the last 25 epochs. We use an initial learning rate of 0.05 and a SWAG learning rate of 0.01, with the learning rate schedule of Maddox et al. (2019). We use the posterior log-density as the objective and vary the prior variance α². In panels (f), (g) of Figure 9 we perform an analogous experiment using a PreResNet-20 and a VGG-16 on CIFAR-10, using the hyper-parameters reported in Maddox et al. (2019) (for PreResNet-20 we use the hyper-parameters used with PreResNet-164 in Maddox et al. (2019)). Both on MNIST and CIFAR-10 we observe that the performance is poor for overly small values of α, close to optimal for intermediate values, and still reasonable for larger values of α. For further discussion of the results see Section 5.3.

Predictions from prior samples. Following Wenzel et al. (2020), we study the predictive distributions of prior samples using PreResNet-20 on CIFAR-10. In Figure 9 we show the sample predictive functions averaged over datapoints for different scales α of the prior distribution. We also show the predictive distribution for each α, which is the average of the sample predictive distributions over 200 samples of weights. In Figure 10 we show how the predictive distribution changes as we vary the number of observed data for prior scale α = √10. We see that the marginal predictive distribution for all considered values of α is reasonable: roughly uniform across classes when averaged across the dataset. For the latter experiment we used stochastic gradient Langevin dynamics (SGLD) (Welling & Teh, 2011) with a cosine learning rate schedule. For each sample we restart SGLD, and we only use the sample obtained at the last iteration. We discuss the results in Section 8.4.

Prior correlations with corrupted images. In Section F and Figure 19 we study the decay of the prior correlations between logits on an original image and a perturbed image as we increase the intensity of perturbations. For the BNN we use the PreResNet-20 architecture with the standard Gaussian prior N(0, I). For the linear model, the correlations are not affected by the prior variance α²:
$$\mathrm{cov}(w^T x, w^T y) = \mathbb{E}(w^T x \cdot w^T y) = \mathbb{E}\, x^T w w^T y = x^T \mathbb{E}[w w^T] y = \alpha^2 x^T y,$$
and hence
$$\mathrm{corr}(w^T x, w^T y) = \frac{\mathrm{cov}(w^T x, w^T y)}{\sqrt{\mathrm{cov}(w^T x, w^T x) \cdot \mathrm{cov}(w^T y, w^T y)}} = \frac{x^T y}{\|x\| \cdot \|y\|}.$$
We use the N(0, I) prior for the weights of the linear model. Finally, we also evaluate the correlations associated with an RBF kernel (see Equation (7)). To set the lengthscale ℓ of the kernel, we evaluate the pairwise correlations for the PreResNet-20 and RBF kernel on the 100 uncorrupted CIFAR-10 images that were used for the experiment and ensure that the average correlations match. The resulting value of ℓ is 10000; the average correlation for the RBF kernel and PreResNet was ≈ 0.9, and for the linear model the average correlation was ≈ 0.82. For the perturbations we used the same set of corruptions introduced in Hendrycks & Dietterich (2019) as in the experiments in Section 4, with the addition of a random translation: for a random translation of intensity i, we pad the image with 2·i zeros on each side and crop the image randomly to 32×32.
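A quick Monte Carlo sketch of the fact that the linear-model prior correlation reduces to the cosine similarity of the inputs, independently of α (the dimensions and sample counts are arbitrary illustrative choices):

```python
import numpy as np

rng = np.random.default_rng(0)
x, y = rng.normal(size=100), rng.normal(size=100)
cosine = x @ y / (np.linalg.norm(x) * np.linalg.norm(y))
for alpha in [0.1, 1.0, 10.0]:
    w = alpha * rng.normal(size=(100_000, 100))   # rows are samples w ~ N(0, alpha^2 I)
    corr = np.corrcoef(w @ x, w @ y)[0, 1]
    print(f"alpha={alpha}: corr={corr:.3f}, cosine={cosine:.3f}")
```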

D.4. Rethinking Generalization

In Section 6 we experiment with Bayesian neural networks and Gaussian processes on CIFAR-10 with noisy labels, inspired by the results in Zhang et al. (2016) that suggest we need to re-think generalization to understand deep learning.

Following Zhang et al. (2016), we train PreResNet-20 on CIFAR-10 with different fractions of random labels. To ensure that the network fits the train data, we turn off weight decay and data augmentation and use a lower initial learning rate of 0.01. Otherwise, we follow the hyper-parameters that were used with PreResNet-164 in Maddox et al. (2019). We use a diagonal Laplace approximation to compute an estimate of the marginal likelihood for each level of label corruption. Following Ritter et al. (2018), we use the diagonal of the Fisher information matrix rather than the Hessian.

We perform a similar experiment with a Gaussian process with an RBF kernel on the binary classification problem for two classes of CIFAR-10. We use variational inference to fit the model, and we use the variational evidence lower bound to approximate the marginal likelihood. We use variational inference to overcome the non-Gaussian likelihood and not for scalability reasons, i.e. we are not using inducing inputs. We use the GPyTorch package (Gardner et al., 2018) to train the models. We use an RBF kernel with the default initialization from GPyTorch and divide the inputs by 5000 to get an appropriate input scale. We train the model on a binary classification problem between classes 0 and 1.

For the 10-class GP classification experiment, we train 10 one-vs-all models that classify between a given class and the rest of the data. To reduce computation in training, we subsample the data not belonging to the given class to 10k datapoints, so each model is trained on a total of 15k datapoints. We then combine the 10 models into a single multi-class model: an observation is attributed to the class that corresponds to the one-vs-all model with the highest confidence. We use the same hyper-parameters as in the binary classification experiments.
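A minimal sketch of this combination step is shown below; the per-class probabilities are assumed to come from the 10 fitted one-vs-all models.

```python
import numpy as np

def combine_one_vs_all(probs):
    # probs: list of length n_classes, each an array of P(class == c) per test point.
    probs = np.stack(probs, axis=1)        # shape (n_test, n_classes)
    return probs.argmax(axis=1)            # pick the class of the most confident model

rng = np.random.default_rng(0)
preds = combine_one_vs_all([rng.uniform(size=500) for _ in range(10)])
print(np.bincount(preds, minlength=10))
```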

D.5. Double Descent

In Section 7 we evaluate SGD, SWAG and MultiSWAG for models of varying width. Following Nakkiran et al. (2019), we use ResNet-18 on CIFAR-100; we consider original labels, 10% and 20% label corruption. For networks of every width we reuse the hyper-parameters used for PreResNet-164 in Maddox et al. (2019).



Page 22: Bayesian Deep Learning and a Probabilistic Perspective of ... · Andrew Gordon Wilson Pavel Izmailov New York University Abstract The key distinguishing property of a Bayesian ap-proach

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

0 1 2 4 7

MNIST Class

0

1

2

4

7

MN

IST

Cla

ss

098

096

097

097

097

096

099

097

097

097

097

097

098

097

097

097

097

097

098

097

097

097

097

097

098

090

092

094

096

098

100

(a) α = 002

0 1 2 4 7

MNIST Class

0

1

2

4

7

MN

IST

Cla

ss

089

075

083

081

081

075

090

082

079

082

083

082

089

083

085

081

079

083

089

084

081

082

085

084

088

05

06

07

08

09

10

(b) α = 01

0 1 2 4 7

MNIST Class

0

1

2

4

7

MN

IST

Cla

ss

085

071

077

076

076

071

089

080

078

079

077

080

084

080

080

076

078

080

085

081

076

079

080

081

085

05

06

07

08

09

10

(c) α = 1

10minus2 10minus1 100 101

Prior std α

5 middot 102

103

5 middot 103

104

NL

L

(d)

10 05 00 05 10 15 20

(e) α = 002

10 05 00 05 10 15 20

(f) α = 01

10 05 00 05 10 15 20

(g) α = 1

10minus2 10minus1 100 101

Prior std α

1

5

20

90

Cla

ssifi

cati

onE

rror

(h)

Figure 12 (a)ndash(c) Average pairwise prior correlations for pairs of objects in classes 0 1 2 4 7 of MNIST induced by LeNet-5for p(f(xw)) when p(w) = N (0 α2I) Images in the same class have higher prior correlations than images from different classessuggesting that p(f(xw)) has desirable inductive biases The correlations slightly decrease with increases in α Panels (e)ndash(g) showsample functions from LeNet-5 along the direction connecting a pair of MNIST images of 0 and 1 digits The complexity of the samplesincreases with α (d) NLL and (h) classification error of an ensemble of 20 SWAG samples on MNIST as a function of α using aLeNet-5 The NLL is high for overly small α and near-optimal for larger values with an optimum near α = 03

D3 Neural Network Priors

In the main text we considered different properties of theprior distribution over functions induced by a sphericalGaussian distribution over the weights with different vari-ance scales

Prior correlation diagrams In panels (a)ndash(c) of Figure12 we show pairwise correlations of the logits for differ-ent pairs of datapoints To make these plots we produceS = 100 samples of the weights wi of a LeNet-5 from theprior distribution N (0 α2I) and compute the logits corre-sponding to class 0 for each data point and each weightsample We then compute the correlations for each pair xxprime of data points as follows

corrlogit(x xprime) =sumS

i=1(f(xwi)minus f(x))(f(xprime wi)minus f(xprime))radicsumSi=1(f(xwi)minus f(x))2 middotsumS

i=1(f(xprime wi)minus f(xprime))2

where f(xw) is the logit corresponding to class 0 of thenetwork with weights w on the input x and f(x) is themean value of the logit f(x) = 1

S

sumi f(xwi) For evalu-

ation we use 200 random datapoints per class for classes0 1 2 4 7 (a total of 1000 datapoints) We use this setof classes to ensure that the structure is clearly visible inthe figure We combine the correlations into a diagramadditionally showing the average correlation for each pairof classes We repeat the experiment for different values

of α isin 002 01 1 For a discussion of the results seeSection 52

Sample functions In panels (e)ndash(g) of Figure 12 we vi-sualize the functions sampled from the LeNet-5 networkalong the direction connecting a pair of MNIST imagesIn particular we take a pair of images x0 and x1 of dig-its 0 and 1 respectively and construct the path x(t) =t middot x0 + (1 minus t) middot x1 We then study the samples of thelogits z(t) = f(x(t) middot x0x(t) w) along the path herewe adjusted the norm of the images along the path to beconstant as the values of the logits are sensitive to the normof the inputs The complexity of the samples increases aswe increase the variance of the prior distribution over theweights This increased complexity of sample functions ex-plains why we might expect the prior correlations for pairsof images to be lower when we increase the variance of theprior distribution over the weights

Performance dependence on prior variance In panels(d) (h) of Figure 12 we show the test negative log-likelihoodand accuracy of SWAG applied to LeNet-5 on MNIST Wetrain the model for 50 epochs constructing the rank-20SWAG approximation from the last 25 epochs We usean initial learning rate of 005 and SWAG learning rateof 001 with the learning rate schedule of Maddox et al(2019) We use posterior log-density as the objective andvary the prior variance α2 In panels (f) (g) of Figure 9

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

we perform an analogous experiment using a PreResNet-20and a VGG-16 on CIFAR-10 using the hyper-parametersreported in Maddox et al (2019) (for PreResNet-20 we usethe hyper-parameters used with PreResNet-164 in Maddoxet al (2019)) Both on MNIST and CIFAR-10 we observethat the performance is poor for overly small values of αclose to optimal for intermediate values and still reasonablefor larger values of α For further discussion of the resultssee Section 53

Predictions from prior samples Following Wenzel et al(2020) we study the predictive distributions of prior samplesusing PreResNet-20 on CIFAR-10 In Figure 9 we showthe sample predictive functions averaged over datapoints fordifferent scales α of the prior distribution We also showthe predictive distribution for each α which is the averageof the sample predictive distributions over 200 samples ofweights In Figure 10 we show how the predictive distribu-tion changes as we vary the number of observed data forprior scale α =

radic10 We see that the marginal predictive

distribution for all considered values of α is reasonable mdashroughly uniform across classes when averaged across thedataset For the latter experiment we used stochastic gra-dient Langevin dynamics (SGLD) (Welling amp Teh 2011)with a cosine lerning rate schedule For each sample werestart SGLD and we only use the sample obtained at thelast iteration We discuss the results in Section 84

Prior correlations with corrupted images In Section Fand Figure 19 we study the decay of the prior correlationsbetween logits on an original image and a perturbed imageas we increase the intensity of perturbations For the BNNwe use PreResNet-20 architecture with the standard Gaus-sian prior N (0 I) For the linear model the correlationsare not affected by the prior variance α2

cov(wTxwT y) = E(wTx middot wT y) =

ExTwwT y = xTEwwT y = α2xT y

and hence

corr(wTxwT y) =

cov(wTxwT y)radiccov(wT y wT y) middot cov(wTxwTx)

= xT y

We use theN (0 I) prior for the weights of the linear modelFinally we also evaluate the correlations associated withan RBF kernel (see Equation (7)) To set the lengthscale` of the kernel we evaluate the pairwise correlations forthe PreResnet-20 and RBF kernel on the 100 uncorruptedCIFAR-10 images that were used for the experiment andensure that the average correlations match The resultingvalue of ` is 10000 and the average correlation for the RBFkernel and PreResNet was asymp 09 for the linear model the

average correlation was asymp 082 For the perturbations weused the same set of corruptions introduced in Hendrycks ampDietterich (2019) as in the experiments in Section 4 with theaddition of a random translation for a random translationof intensity i we pad the image with 2 middot i zeros on each sideand crop the image randomly to 32times 32

D4 Rethinking Generalization

In Section 6 we experiment with Bayesian neural networksand Gaussian processes on CIFAR-10 with noisy labelsinspired by the results in Zhang et al (2016) that suggest weneed to re-think generalization to understand deep learning

Following Zhang et al (2016) we train PreResNet-20 onCIFAR-10 with different fractions of random labels To en-sure that the networks fits the train data we turn off weightdecay and data augmentation and use a lower initial learn-ing rate of 001 Otherwise we follow the hyper-parametersthat were used with PreResNet-164 in Maddox et al (2019)We use diagonal Laplace approximation to compute an esti-mate of marginal likelihood for each level of label corrup-tion Following Ritter et al (2018) we use the diagonal ofthe Fisher information matrix rather than the Hessian

We perform a similar experiment with a Gaussian processwith RBF kernel on the binary classification problem fortwo classes of CIFAR-10 We use variational inference to fitthe model and we use the variational evidence lower boundto approximate the marginal likelihood We use variationalinference to overcome the non-Gaussian likelihood and notfor scalability reasons ie we are not using inducing inputsWe use the GPyTorch package (Gardner et al 2018) totrain the models We use an RBF kernel with default ini-tialization from GPyTorch and divide the inputs by 5000to get an appropriate input scale We train the model on abinary classification problem between classes 0 and 1

For the 10-class GP classification experiment we train 10one-vs-all models that classify between a given class andthe rest of the data To reduce computation in trainingwe subsample the data not belonging to the given class to10k datapoints so each model is trained on a total of 15kdatapoints We then combine the 10 models into a singlemulti-class model an observation is attributed to the classthat corresponds to the one-vs-all model with the highestconfidence We use the same hyper-parameters as in thebinary classification experiments

D5 Double Descent

In Section 7 we evaluate SGD SWAG and MultiSWAG formodels of varying width Following Nakkiran et al (2019)we use ResNet-18 on CIFAR-100 we consider original la-bels 10 and 20 label corruption For networks of everywidth we reuse the hyper-paramerers used for PreResNet-

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

10 20 30 40 50

ResNet-18 width

250

275

300

325

350

375

400

425

450

Tes

tE

rror

()

SGD

SWAG

Multi-SWAG

(a) 10 Corrupted (Err)

10 20 30 40 50

ResNet-18 width

10

12

14

16

18

20

22

Tes

tN

LL

(b) 10 Corrupted (NLL)

10 20 30 40 50

ResNet-18 width

12

14

16

18

20

Mul

tiS

WA

GT

est

NL

L

SWAG Models

1 3 5 10

(c) 20 Corrupted ( Models)

Figure 13 Double Descent (a) Test error and (b) NLL loss for ResNet-18 with varying width on CIFAR-100 for SGD SWAG andMultiSWAG when 10 of the labels are randomly reshuffled MultiSWAG alleviates double descent both on the original labels and underlabel noise both in accuracy and NLL (e) Test NLLs for MultiSWAG with varying number of independent models under 20 labelcorruption NLL monotonically decreases with increased number of independent models alleviating double descent

164 in Maddox et al (2019) For original labels and 10label corruption we use 5 independently trained SWAG mod-els with MultiSWAG and for 20 label corruption we use10 models for 20 label corruption we also show perfor-mance varying the number of independent models in Figures8(e) and 13(c) Both for SWAG and MultiSWAG we use anensemble of 20 sampled models from each of the SWAG so-lutions for example for MultiSWAG with 10 independentSWAG solutions we use an ensemble of 200 networks

E. Analysis of Prior Variance Effect

In this section we provide simple analytic results for the effect of prior variance in ReLU networks. A related derivation is presented in Appendix A.8 of Garipov et al. (2018) about connecting paths from symmetries in parametrization.

We will consider a multilayer network $f(x; w)$ of the form
\[
f(x; \{W_i, b_i\}_{i=1}^n) = W_n(\ldots\phi(W_2\,\phi(W_1 x + b_1) + b_2)\ldots) + b_n,
\]
where $\phi$ is the ReLU (or in fact any positively-homogeneous activation function), the $W_i$ are weight matrices and the $b_i$ are bias vectors. In particular, $f$ can be a regular CNN with ReLU activations up to the logits (with the softmax activation removed).

Now suppose we have a prior distribution of the form
\[
W_i \sim \mathcal{N}(0, \alpha_i^2 I), \qquad b_i \sim \mathcal{N}(0, \beta_i^2 I),
\]
where the identity matrices $I$ are implicitly assumed to be of appropriate shapes, so that each weight matrix and bias vector has a spherical Gaussian distribution. We can reparameterize this distribution as
\[
W_i = \alpha_i E_i, \quad E_i \sim \mathcal{N}(0, I), \qquad
b_i = \beta_i \epsilon_i, \quad \epsilon_i \sim \mathcal{N}(0, I).
\]
We can then express the predictions of the network on the input $x$ for weights sampled from the prior as the random variable
\[
f(x; \{\alpha_i, \beta_i\}_{i=1}^n) = \alpha_n \cdot E_n(\ldots\phi(\alpha_1 E_1 x + \beta_1 \cdot \epsilon_1)\ldots) + \beta_n \cdot \epsilon_n. \tag{8}
\]

Through Equation (8) we can observe some simple properties of the dependence between the prior scales $\alpha_i, \beta_i$ and the induced function-space prior.

Proposition 1. Suppose the network has no bias vectors, i.e., $\beta_1 = \ldots = \beta_n = 0$. Then the scales $\alpha_i$ of the prior distribution over the weights only affect the output scale of the network.

Proof. In the case when there are no bias vectors, Equation (8) simplifies to
\[
f(x; \{\alpha_i, \beta_i = 0\}_{i=1}^n)
= \alpha_n \cdot E_n(\ldots\phi(\alpha_1 E_1 x)\ldots)
= \alpha_n \cdot \ldots \cdot \alpha_1 \cdot E_n(\ldots\phi(E_1 x)\ldots)
= \alpha_n \cdot \ldots \cdot \alpha_1 \cdot f(x; \{\alpha_i = 1, \beta_i = 0\}_{i=1}^n).
\]

In the derivation above we used the positive homogeneity of the ReLU: $\phi(\alpha z) = \alpha\phi(z)$ for any positive $\alpha$.

In other words, to sample from the distribution over functions corresponding to a prior with scales $\{\alpha_i, \beta_i = 0\}_{i=1}^n$, we can sample from the spherical Gaussian prior (without bias terms) $\{\alpha_i = 1, \beta_i = 0\}_{i=1}^n$ and then rescale the outputs of the network by the product of scales $\alpha_n \cdot \ldots \cdot \alpha_2 \cdot \alpha_1$.
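
This output-rescaling equivalence is easy to verify numerically. The following small self-contained check uses a toy fully-connected ReLU network (the layer sizes and prior scales are arbitrary illustrative choices, not the architectures from our experiments), and also shows how rescaled logits translate into more confident softmax predictions.

```python
import torch
import torch.nn.functional as F

torch.set_default_dtype(torch.float64)
torch.manual_seed(0)

def relu_net(x, weights):
    """Bias-free ReLU network: f(x) = W_n phi(... phi(W_1 x) ...)."""
    h = x
    for W in weights[:-1]:
        h = F.relu(h @ W.t())
    return h @ weights[-1].t()

dims, alphas = [32, 64, 64, 10], [0.5, 2.0, 3.0]                      # toy sizes, prior scales
E = [torch.randn(dims[i + 1], dims[i]) for i in range(len(alphas))]   # E_i ~ N(0, I)
x = torch.randn(5, dims[0])

f_scaled = relu_net(x, [a * Ei for a, Ei in zip(alphas, E)])  # weights W_i = alpha_i * E_i
f_unit = relu_net(x, E)                                       # weights with alpha_i = 1
scale = torch.tensor(alphas).prod()
print(torch.allclose(f_scaled, scale * f_unit))               # True: only the output scale changes

# Larger output scales make the softmax predictions increasingly confident:
for s in (1.0, 10.0, 100.0):
    print(s, F.softmax(s * f_unit, dim=-1).max(dim=-1).values.mean().item())
```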

We note that the result above is different from the results for sigmoid networks considered in MacKay (1995), where varying the prior on the weights leads to changing the length-scale of the sample functions. For ReLU networks without biases, increasing the prior variance only increases the output scale of the network and not the complexity of the samples. If we apply the softmax activation to the outputs of the last layer of such a network, we will observe increasingly confident predictions as we increase the prior variance. We observe this effect in Figure 9 and discuss it in Section 8.4.

In case bias vectors are present, we can obtain a similar result using a specific scaling of the prior variances with layer, as in the following proposition.

Proposition 2. Suppose the prior scales depend on the layer of the network as follows, for some $\gamma > 0$:
\[
\alpha_i = \gamma, \qquad \beta_i = \gamma^i,
\]
for all layers $i = 1, \ldots, n$. Then $\gamma$ only affects the scale of the predictive distribution at any input $x$:
\[
f(x; \{\alpha_i = \gamma, \beta_i = \gamma^i\}_{i=1}^n) = \gamma^n \cdot f(x; \{\alpha_i = 1, \beta_i = 1\}_{i=1}^n).
\]

Proof. The proof is analogous to the proof of Proposition 1: we can use the positive homogeneity of ReLU activations to factor the prior scales outside of the network:
\[
f(x; \{\alpha_i = \gamma, \beta_i = \gamma^i\}_{i=1}^n)
= \gamma \cdot E_n(\ldots\phi(\gamma \cdot E_1 x + \gamma \cdot \epsilon_1)\ldots) + \gamma^n \cdot \epsilon_n
= \gamma^n \cdot \big(E_n(\ldots\phi(E_1 x + \epsilon_1)\ldots) + \epsilon_n\big)
= \gamma^n \cdot f(x; \{\alpha_i = 1, \beta_i = 1\}_{i=1}^n).
\]

The analysis above can be applied to other simple scaling rules of the prior; e.g.,
\[
f(x; \{\alpha_i = \gamma\hat\alpha_i, \beta_i = \gamma^i\hat\beta_i\}_{i=1}^n) = \gamma^n \cdot f(x; \{\alpha_i = \hat\alpha_i, \beta_i = \hat\beta_i\}_{i=1}^n) \tag{9}
\]
can be shown completely analogously to Proposition 2.
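
Proposition 2 can also be checked numerically, analogously to the check for Proposition 1 above. The sketch below uses a toy fully-connected ReLU network with biases and an arbitrary $\gamma$; the layer sizes and the value of $\gamma$ are illustrative choices.

```python
import torch
import torch.nn.functional as F

torch.set_default_dtype(torch.float64)
torch.manual_seed(0)

def relu_net(x, weights, biases):
    """ReLU network with biases: f(x) = W_n(... phi(W_1 x + b_1) ...) + b_n."""
    h = x
    for W, b in zip(weights[:-1], biases[:-1]):
        h = F.relu(h @ W.t() + b)
    return h @ weights[-1].t() + biases[-1]

dims, gamma, n = [32, 64, 64, 10], 2.0, 3
E = [torch.randn(dims[i + 1], dims[i]) for i in range(n)]   # E_i ~ N(0, I)
eps = [torch.randn(dims[i + 1]) for i in range(n)]          # eps_i ~ N(0, I)
x = torch.randn(5, dims[0])

f_scaled = relu_net(x, [gamma * Ei for Ei in E],
                    [gamma ** (i + 1) * eps[i] for i in range(n)])  # alpha_i = gamma, beta_i = gamma^i
f_unit = relu_net(x, E, eps)                                        # alpha_i = beta_i = 1
print(torch.allclose(f_scaled, gamma ** n * f_unit))                # True: gamma only rescales the output
```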

More general types of scaling of the prior affect both the output scale of the network and the relative effect of the prior and variance terms. For example, by Equation (9) we have
\[
f(x; \{\alpha_i = \gamma, \beta_i = \gamma\}_{i=1}^n)
= f(x; \{\alpha_i = \gamma \cdot 1, \beta_i = \gamma^i \cdot \gamma^{1-i}\}_{i=1}^n)
= \gamma^n \cdot f(x; \{\alpha_i = 1, \beta_i = \gamma^{1-i}\}_{i=1}^n).
\]

We note that the analysis does not cover residual connections and batch normalization, so it applies to LeNet-5 but cannot be directly applied to the PreResNet-20 networks used in many of our experiments.

F. Prior Correlation Structure under Perturbations

In this section we explore the prior correlations between the logits on different pairs of datapoints induced by a spherical Gaussian prior on the weights of a PreResNet-20. We sample 100 random images from CIFAR-10 (10 from each class) and apply 17 different perturbations introduced by Hendrycks & Dietterich (2019) at 5 different levels of intensity. We then compute correlations between the logits $f(x; w)$ for the original image $x$ and $f(\tilde{x}; w)$ for the corrupted image $\tilde{x}$, as we sample the weights of the network from the prior $w \sim \mathcal{N}(0, I)$.
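
One way to estimate such a correlation for a single image pair is sketched below. Pooling the logits over output classes and prior samples before computing the correlation coefficient, as well as the number of prior samples, are illustrative assumptions about the estimator rather than a description of the exact procedure.

```python
import torch

@torch.no_grad()
def prior_logit_correlation(model, x, x_tilde, prior_std=1.0, num_prior_samples=100):
    """Correlation between the logits f(x, w) and f(x_tilde, w) of an original and a
    corrupted image, with all network weights sampled from w ~ N(0, prior_std^2 I)."""
    model.eval()
    f_orig, f_corr = [], []
    for _ in range(num_prior_samples):
        for p in model.parameters():          # draw a fresh prior sample of every parameter
            p.normal_(0.0, prior_std)
        f_orig.append(model(x.unsqueeze(0)).flatten())
        f_corr.append(model(x_tilde.unsqueeze(0)).flatten())
    stacked = torch.stack([torch.cat(f_orig), torch.cat(f_corr)])
    return torch.corrcoef(stacked)[0, 1].item()
```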

In Figure 19 we show how the correlations decay with perturbation intensity. For reference, we also show how the correlations decay for a linear model and for an RBF kernel. For the RBF kernel, we set the lengthscale so that the average correlations on the uncorrupted datapoints match those of a PreResNet-20. Further experimental details can be found in Appendix D.3.

For all types of corruptions except saturate, snow, fog and brightness, the prior correlations of the PreResNet logits decay more slowly than those of the RBF kernel and linear model. It appears that the prior samples are sensitive to corruptions that alter the brightness or, more generally, the colours in the image. For many types of corruptions (e.g., Gaussian noise) the prior correlations for PreResNet are close to 1 for all levels of corruption.

Overall, these results indicate that the prior over functions induced by a vague prior over parameters $w$ in combination with a PreResNet has useful equivariance properties: before seeing data, the model treats images of the same class as highly correlated, even after an image has undergone significant perturbations representative of perturbations we often see in the real world. These types of symmetries are a large part of what makes neural networks a powerful model class for high-dimensional natural signals.

[Figure 14. Noise Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Gaussian Noise, (b) Impulse Noise, (c) Shot Noise.]

[Figure 15. Blur Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Defocus Blur, (b) Glass Blur, (c) Motion Blur, (d) Zoom Blur, (e) Gaussian Blur.]

[Figure 16. Digital Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Contrast, (b) Saturate, (c) Elastic Transform, (d) Pixelate, (e) JPEG Compression.]

[Figure 17. Weather Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right). Panels: (a) Snow, (b) Fog, (c) Brightness.]

[Figure 18. Negative log likelihood, accuracy and expected calibration error distributions on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG and MultiSWA as a function of the corruption intensity. Following Ovadia et al. (2019), we summarize the results for different types of corruption with a boxplot. For each method we use 10 independently trained models, and for MultiSWAG we sample 20 networks from each model. As in Figures 5, 11-14, there are substantial differences between these three methods, which are hard to see due to the vertical scale on this plot; MultiSWAG particularly outperforms Deep Ensembles and MultiSWA in terms of NLL and ECE for higher corruption intensities.]

[Figure 19. Prior correlations under corruption. Prior correlations between predictions (logits) for PreResNet-20, Linear Model and RBF kernel on original and corrupted images as a function of corruption intensity, for different types of corruptions. The lengthscale of the RBF kernel is calibrated to produce similar correlations to PreResNet on uncorrupted datapoints. We report the mean correlation values over 100 different images and show the 1σ error bars with shaded regions. For all corruptions except Snow, Saturate, Fog and Brightness the correlations decay slower for PreResNet compared to the baselines.]


In Section 6 we experiment with Bayesian neural networksand Gaussian processes on CIFAR-10 with noisy labelsinspired by the results in Zhang et al (2016) that suggest weneed to re-think generalization to understand deep learning

Following Zhang et al (2016) we train PreResNet-20 onCIFAR-10 with different fractions of random labels To en-sure that the networks fits the train data we turn off weightdecay and data augmentation and use a lower initial learn-ing rate of 001 Otherwise we follow the hyper-parametersthat were used with PreResNet-164 in Maddox et al (2019)We use diagonal Laplace approximation to compute an esti-mate of marginal likelihood for each level of label corrup-tion Following Ritter et al (2018) we use the diagonal ofthe Fisher information matrix rather than the Hessian

We perform a similar experiment with a Gaussian processwith RBF kernel on the binary classification problem fortwo classes of CIFAR-10 We use variational inference to fitthe model and we use the variational evidence lower boundto approximate the marginal likelihood We use variationalinference to overcome the non-Gaussian likelihood and notfor scalability reasons ie we are not using inducing inputsWe use the GPyTorch package (Gardner et al 2018) totrain the models We use an RBF kernel with default ini-tialization from GPyTorch and divide the inputs by 5000to get an appropriate input scale We train the model on abinary classification problem between classes 0 and 1

For the 10-class GP classification experiment we train 10one-vs-all models that classify between a given class andthe rest of the data To reduce computation in trainingwe subsample the data not belonging to the given class to10k datapoints so each model is trained on a total of 15kdatapoints We then combine the 10 models into a singlemulti-class model an observation is attributed to the classthat corresponds to the one-vs-all model with the highestconfidence We use the same hyper-parameters as in thebinary classification experiments

D5 Double Descent

In Section 7 we evaluate SGD SWAG and MultiSWAG formodels of varying width Following Nakkiran et al (2019)we use ResNet-18 on CIFAR-100 we consider original la-bels 10 and 20 label corruption For networks of everywidth we reuse the hyper-paramerers used for PreResNet-

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

10 20 30 40 50

ResNet-18 width

250

275

300

325

350

375

400

425

450

Tes

tE

rror

()

SGD

SWAG

Multi-SWAG

(a) 10 Corrupted (Err)

10 20 30 40 50

ResNet-18 width

10

12

14

16

18

20

22

Tes

tN

LL

(b) 10 Corrupted (NLL)

10 20 30 40 50

ResNet-18 width

12

14

16

18

20

Mul

tiS

WA

GT

est

NL

L

SWAG Models

1 3 5 10

(c) 20 Corrupted ( Models)

Figure 13 Double Descent (a) Test error and (b) NLL loss for ResNet-18 with varying width on CIFAR-100 for SGD SWAG andMultiSWAG when 10 of the labels are randomly reshuffled MultiSWAG alleviates double descent both on the original labels and underlabel noise both in accuracy and NLL (e) Test NLLs for MultiSWAG with varying number of independent models under 20 labelcorruption NLL monotonically decreases with increased number of independent models alleviating double descent

164 in Maddox et al (2019) For original labels and 10label corruption we use 5 independently trained SWAG mod-els with MultiSWAG and for 20 label corruption we use10 models for 20 label corruption we also show perfor-mance varying the number of independent models in Figures8(e) and 13(c) Both for SWAG and MultiSWAG we use anensemble of 20 sampled models from each of the SWAG so-lutions for example for MultiSWAG with 10 independentSWAG solutions we use an ensemble of 200 networks

E Analysis of Prior Variance EffectIn this section we provide simple analytic results for theeffect of prior variance in ReLU networks A related deriva-tion is presented in the Appendix Section A8 of Garipovet al (2018) about connecting paths from symmetries inparametrization

We will consider a multilayer network f(xw) of the form

f(x Wibini=1) =

Wn( φ(W2φ(W1x+ b1) + b2)) + bn

where φ is the ReLU (or in fact any positively-homogeneousactivation function) Wi are weight matrices and bi are biasvectors In particular f can be a regular CNN with ReLU ac-tivations up to the logits (with softmax activation removed)

Now suppose we have a prior distribution of the form

Wi sim N (0 α2i I) bi sim N (0 β2

i I)

where the identity matrices I are implicitly assumed to beof appropriate shapes so each weight matrix and bias vectorhas a spherical Gaussian distribution We can reparameter-ize this distribution as

Wi = αiEi Ei sim N (0 I)

bi = βiεi εi sim N (0 I)

We can then express the predictions of the network on theinput x for weights sampled from the prior as the random

variable

f(x αi βini=1) =

αn middot En( φ(α1E1x+ β1 middot ε1)) + βn middot εn(8)

Through Equation (8) we can observe some simple proper-ties of the dependence between the prior scales αi βi andthe induced function-space prior

Proposition 1 Suppose the network has no bias vectorsie β1 = = βn = 0 Then the scales αi of the priordistribution over the weights only affect the output scale ofthe network

Proof In the case when there are no bias vectors Equation(8) simplifies to

f(xαi βi = 0ni=1) =

αn middot En( φ(α1E1x+ β1 middot ε1)) + βn middot εn =

αn middot middot α1 middot En( φ(E1x)) =

αn middot middot α1 middot f(x αi = 1 βi = 0ni=1)

In the derivation above we used positive homogeneity ofReLU φ(αz) = αφ(z) for any positive α

In other words to sample from the distribution over func-tions corresponding to a prior with variances αi βi =0ni=1 we can sample from the spherical Gaussian prior(without bias terms) αi = 1 βi = 0ni=1 and then rescalethe outputs of the network by the product of variancesαn middot middot α2 middot α1

We note that the result above is different from the resultsfor sigmoid networks considered in MacKay (1995) wherevarying the prior on the weights leads to changing the length-scale of the sample functions For ReLU networks withoutbiases increasing prior variance only increases the outputscale of the network and not the complexity of the samplesIf we apply the softmax activation on the outputs of thelast layer of such network we will observe increasingly

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

confident predictions as we increase the prior variance Weobserve this effect in Figure 9 and discuss it in Section 84

In case bias vectors are present we can obtain a similarresult using a specific scaling of the prior variances withlayer as in the following proposition

Proposition 2 Suppose the prior scales depend on thelayer of the network as follows for some γ gt 0

αi = γ βi = γi

for all layers i = 1 n Then γ only affects the scale ofthe predictive distribution at any input x

f(xαi = γ βi = γini=1) = γn middot f(x αi = 1 βi = 1ni=1)

Proof The proof is analogous to the proof of Proposition 1We can use the positive homogenety of ReLU activations tofactor the prior scales outside of the network

f(xαi = γ βi = γini=1) =

γ middot En( φ(γ middot E1x+ γ middot ε1)) + γn middot εn =

γn middot(En( φ(E1x+ ε1))) + εn

)=

γn middot f(x αi = 1 βi = 1ni=1)

The analysis above can be applied to other simple scalingrules of the prior eg

f(xαi = γαi βi = γiβini=1) =

γn middot f(x αi = αi βi = βini=1)(9)

can be shown completely analogously to Proposition 2

More general types of scaling of the prior affect both theoutput scale of the network and also the relative effect ofprior and variance terms For example by Equation (9) wehave

f(xαi = γ βi = γni=1) =

f(xαi = γ middot 1 βi = γi middot γ1minusini=1) =

γn middot f(x αi = 1 βi = γ1minusini=1)

We note that the analysis does not cover residual connectionsand batch normalization so it applies to LeNet-5 but cannotbe directly applied to PreResNet-20 networks used in manyof our experiments

F Prior Correlation Structure underPerturbations

In this section we explore the prior correlations between thelogits on different pairs of datapoints induced by a spherical

Gaussian prior on the weights of a PreResNet-20 We sam-ple a 100 random images from CIFAR-10 (10 from eachclass) and apply 17 different perturbations introduced byHendrycks amp Dietterich (2019) at 5 different levels of in-tensity We then compute correlations between the logitsf(xw) for the original image x and f(x w) for the cor-rupted image x as we sample the weights of the networkfrom the prior w sim N (0 I)

In Figure 19 we show how the correlations decay with per-turbation intensity For reference we also show how thecorrelations decay for a linear model and for an RBF kernelFor the RBF kernel we set the lengthscale so that the averagecorrelations on the uncorrupted datapoints match those of aPreResNet-20 Further experimental details can be found inAppendix D3

For all types of corruptions except saturate snow fog andbrightness the PreResNet logits decay slower compared tothe RBF kernel and linear model It appears that the priorsamples are sensitive to corruptions that alter the bright-ness or more generally the colours in the image For manytypes of corruptions (such as eg Gaussian Noise) the priorcorrelations for PreResNet are close to 1 for all levels ofcorruption

Overall these results indicate that the prior over functionsinduced by a vague prior over parameters w in combinationwith a PreResNet has useful equivariance properties beforeseeing data the model treats images of the same class ashighly correlated even after an image has undergone signif-icant perturbations representative of perturbations we oftensee in the real world These types of symmetries are a largepart of what makes neural networks a powerful model classfor high dimensional natural signals

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

1 2 3 4 5 6 7 8 9 10 Models

8000

9000

10000

11000

12000

13000

NLL

1 2 3 4 5 6 7 8 9 10 Models

17500

20000

22500

25000

27500

30000

1 2 3 4 5 6 7 8 9 10 Models

25000

30000

35000

40000

45000

1 2 3 4 5 6 7 8 9 10 Models

30000

35000

40000

45000

50000

1 2 3 4 5 6 7 8 9 10 Models

30000

35000

40000

45000

50000

55000

Deep Ensembles MultiSWA MultiSWAG

(a) Gaussian Noise

1 2 3 4 5 6 7 8 9 10 Models

5000

6000

7000

8000

9000

NLL

1 2 3 4 5 6 7 8 9 10 Models

10000

12000

14000

16000

1 2 3 4 5 6 7 8 9 10 Models

12000

14000

16000

18000

20000

1 2 3 4 5 6 7 8 9 10 Models

20000

22500

25000

27500

30000

32500

1 2 3 4 5 6 7 8 9 10 Models

25000

30000

35000

40000

45000

50000

Deep Ensembles MultiSWA MultiSWAG

(b) Impulse Noise

1 2 3 4 5 6 7 8 9 10 Models

5000

6000

7000

8000

NLL

1 2 3 4 5 6 7 8 9 10 Models

8000

10000

12000

14000

1 2 3 4 5 6 7 8 9 10 Models

20000

25000

30000

35000

1 2 3 4 5 6 7 8 9 10 Models

25000

30000

35000

40000

1 2 3 4 5 6 7 8 9 10 Models

30000

35000

40000

45000

50000

Deep Ensembles MultiSWA MultiSWAG

(c) Shot Noise

Figure 14 Noise Corruptions Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles MultiSWAG andMultiSWA as a function of the number of independently trained models for different types of corruption and corruption intensity(increasing from left to right)

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

1 2 3 4 5 6 7 8 9 10 Models

1600

1800

2000

2200

2400

2600

2800

NLL

1 2 3 4 5 6 7 8 9 10 Models

2000

2250

2500

2750

3000

3250

1 2 3 4 5 6 7 8 9 10 Models

3500

4000

4500

5000

5500

6000

1 2 3 4 5 6 7 8 9 10 Models

6000

7000

8000

9000

10000

11000

1 2 3 4 5 6 7 8 9 10 Models

140001600018000200002200024000

Deep Ensembles MultiSWA MultiSWAG

(a) Defocus Blur

1 2 3 4 5 6 7 8 9 10 Models

14000

16000

18000

20000

22000

24000

NLL

1 2 3 4 5 6 7 8 9 10 Models

14000

16000

18000

20000

22000

1 2 3 4 5 6 7 8 9 10 Models

12000

14000

16000

18000

20000

1 2 3 4 5 6 7 8 9 10 Models

18000

20000

22000

24000

26000

28000

30000

1 2 3 4 5 6 7 8 9 10 Models

18000

20000

22000

24000

26000

28000

Deep Ensembles MultiSWA MultiSWAG

(b) Glass Blur

1 2 3 4 5 6 7 8 9 10 Models

3000

3500

4000

4500

5000

NLL

1 2 3 4 5 6 7 8 9 10 Models

5000

6000

7000

8000

9000

1 2 3 4 5 6 7 8 9 10 Models

8000

10000

12000

14000

1 2 3 4 5 6 7 8 9 10 Models

8000

10000

12000

14000

1 2 3 4 5 6 7 8 9 10 Models

12000

14000

16000

18000

20000

Deep Ensembles MultiSWA MultiSWAG

(c) Motion Blur

1 2 3 4 5 6 7 8 9 10 Models

3500400045005000550060006500

NLL

1 2 3 4 5 6 7 8 9 10 Models

5000

6000

7000

8000

9000

1 2 3 4 5 6 7 8 9 10 Models

7000

8000

9000

10000

11000

12000

13000

1 2 3 4 5 6 7 8 9 10 Models

10000

12000

14000

16000

1 2 3 4 5 6 7 8 9 10 Models

12000

14000

16000

18000

20000

22000

24000

Deep Ensembles MultiSWA MultiSWAG

(d) Zoom Blur

1 2 3 4 5 6 7 8 9 10 Models

1600

1800

2000

2200

2400

2600

2800

NLL

1 2 3 4 5 6 7 8 9 10 Models

3500

4000

4500

5000

5500

6000

1 2 3 4 5 6 7 8 9 10 Models

7000

8000

9000

10000

11000

12000

1 2 3 4 5 6 7 8 9 10 Models

12000

14000

16000

18000

20000

22000

1 2 3 4 5 6 7 8 9 10 Models

20000

25000

30000

35000

40000

Deep Ensembles MultiSWA MultiSWAG

(e) Gaussian Blur

Figure 15 Blur Corruptions Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles MultiSWAG andMultiSWA as a function of the number of independently trained models for different types of corruption and corruption intensity(increasing from left to right)

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

1 2 3 4 5 6 7 8 9 10 Models

1750

2000

2250

2500

2750

3000

NLL

1 2 3 4 5 6 7 8 9 10 Models

3000

3500

4000

4500

5000

1 2 3 4 5 6 7 8 9 10 Models

4000

5000

6000

7000

1 2 3 4 5 6 7 8 9 10 Models

8000

10000

12000

14000

16000

1 2 3 4 5 6 7 8 9 10 Models

25000

30000

35000

40000

45000

50000

Deep Ensembles MultiSWA MultiSWAG

(a) Contrast

1 2 3 4 5 6 7 8 9 10 Models

2250

2500

2750

3000

3250

3500

3750

NLL

1 2 3 4 5 6 7 8 9 10 Models

3000

3500

4000

4500

1 2 3 4 5 6 7 8 9 10 Models

1750

2000

2250

2500

2750

1 2 3 4 5 6 7 8 9 10 Models

2500

3000

3500

4000

1 2 3 4 5 6 7 8 9 10 Models

3500

4000

4500

5000

5500

6000

6500

Deep Ensembles MultiSWA MultiSWAG

(b) Saturate

1 2 3 4 5 6 7 8 9 10 Models

3000

3500

4000

4500

5000

NLL

1 2 3 4 5 6 7 8 9 10 Models

3000

3500

4000

4500

5000

1 2 3 4 5 6 7 8 9 10 Models

4000

5000

6000

7000

1 2 3 4 5 6 7 8 9 10 Models

6000

7000

8000

9000

10000

11000

1 2 3 4 5 6 7 8 9 10 Models

7000

8000

9000

10000

11000

12000

Deep Ensembles MultiSWA MultiSWAG

(c) Elastic Transform

1 2 3 4 5 6 7 8 9 10 Models

2250

2500

2750

3000

3250

3500

3750

NLL

1 2 3 4 5 6 7 8 9 10 Models

4000

4500

5000

5500

6000

6500

1 2 3 4 5 6 7 8 9 10 Models

6000

7000

8000

9000

1 2 3 4 5 6 7 8 9 10 Models

14000

16000

18000

20000

22000

1 2 3 4 5 6 7 8 9 10 Models

24000

26000

28000

30000

32000

34000

Deep Ensembles MultiSWA MultiSWAG

(d) Pixelate

1 2 3 4 5 6 7 8 9 10 Models

4000

4500

5000

5500

6000

6500

NLL

1 2 3 4 5 6 7 8 9 10 Models

5000

6000

7000

8000

9000

1 2 3 4 5 6 7 8 9 10 Models

6000

7000

8000

9000

10000

1 2 3 4 5 6 7 8 9 10 Models

7000

8000

9000

10000

11000

1 2 3 4 5 6 7 8 9 10 Models

8000

9000

10000

11000

12000

13000

Deep Ensembles MultiSWA MultiSWAG

(e) JPEG Compression

Figure 16 Digital Corruptions Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles MultiSWAG andMultiSWA as a function of the number of independently trained models for different types of corruption and corruption intensity(increasing from left to right)

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

1 2 3 4 5 6 7 8 9 10 Models

3500

4000

4500

5000NL

L

1 2 3 4 5 6 7 8 9 10 Models

7000

8000

9000

10000

1 2 3 4 5 6 7 8 9 10 Models

5000

6000

7000

8000

1 2 3 4 5 6 7 8 9 10 Models

6000

7000

8000

9000

1 2 3 4 5 6 7 8 9 10 Models

8000

9000

10000

11000

12000

Deep Ensembles MultiSWA MultiSWAG

(a) Snow

1 2 3 4 5 6 7 8 9 10 Models

1600

1800

2000

2200

2400

2600

2800

NLL

1 2 3 4 5 6 7 8 9 10 Models

2000

2250

2500

2750

3000

3250

1 2 3 4 5 6 7 8 9 10 Models

2500

3000

3500

4000

1 2 3 4 5 6 7 8 9 10 Models

3500

4000

4500

5000

5500

6000

1 2 3 4 5 6 7 8 9 10 Models

8000

10000

12000

14000

16000

18000

Deep Ensembles MultiSWA MultiSWAG

(b) Fog

1 2 3 4 5 6 7 8 9 10 Models

1600

1800

2000

2200

2400

2600

2800

NLL

1 2 3 4 5 6 7 8 9 10 Models

1600180020002200240026002800

1 2 3 4 5 6 7 8 9 10 Models

1750

2000

2250

2500

2750

3000

3250

1 2 3 4 5 6 7 8 9 10 Models

2000225025002750300032503500

1 2 3 4 5 6 7 8 9 10 Models

3000

3500

4000

4500

Deep Ensembles MultiSWA MultiSWAG

(c) Brightness

Figure 17 Weather Corruptions Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles MultiSWAG andMultiSWA as a function of the number of independently trained models for different types of corruption and corruption intensity(increasing from left to right)

1 2 3 4 5Corruption Intensity

0

5000

10000

15000

20000

25000

30000

35000

NLL

Deep Ensembles MultiSWA MultiSWAG

1 2 3 4 5Corruption Intensity

02

03

04

05

06

07

08

09

ACcu

racy

Deep Ensembles MultiSWA MultiSWAG

1 2 3 4 5Corruption Intensity

00

01

02

03

04

ECE

Deep Ensembles MultiSWA MultiSWAG

Figure 18 Negative log likelihood accuracy and expected calibration error distribution on CIFAR-10 with a PreResNet-20 for DeepEnsembles MultiSWAG and MultiSWA as a function of the corruption intensity Following Ovadia et al (2019) we summarize the resultsfor different types of corruption with a boxplot For each method we use 10 independently trained models and for MultiSWAG wesample 20 networks from each model As in Figures 5 11-14 there are substantial differences between these three methods which arehard to see due to the vertical scale on this plot MultiSWAG particularly outperforms Deep Ensembles and MultiSWA in terms of NLLand ECE for higher corruption intensities

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

0 1 2 3 4 5Corruption Intensity

065070075080085090095100

Corre

latio

n

Translation

0 1 2 3 4 5Corruption Intensity

09650970097509800985099009951000

Corre

latio

n

Gaussian Noise

0 1 2 3 4 5Corruption Intensity

092

094

096

098

100

Corre

latio

n

Impulse Noise

0 1 2 3 4 5Corruption Intensity

0975

0980

0985

0990

0995

1000

Corre

latio

n

Shot Noise

0 1 2 3 4 5Corruption Intensity

0990

0992

0994

0996

0998

1000

Corre

latio

n

Defocus Blur

0 1 2 3 4 5Corruption Intensity

094

095

096

097

098

099

100

Corre

latio

n

Glass Blur

0 1 2 3 4 5Corruption Intensity

0970

0975

0980

0985

0990

0995

1000Co

rrela

tion

Motion Blur

0 1 2 3 4 5Corruption Intensity

09650970097509800985099009951000

Corre

latio

n

Zoom Blur

0 1 2 3 4 5Corruption Intensity

0992

0994

0996

0998

1000

Corre

latio

n

Gaussian Blur

0 1 2 3 4 5Corruption Intensity

090

092

094

096

098

100

Corre

latio

n

Contrast

0 1 2 3 4 5Corruption Intensity

092

094

096

098

100

Corre

latio

n

Saturate

0 1 2 3 4 5Corruption Intensity

093094095096097098099100

Corre

latio

n

Elastic

0 1 2 3 4 5Corruption Intensity

0990

0992

0994

0996

0998

1000

Corre

latio

nPixelate

0 1 2 3 4 5Corruption Intensity

099209930994099509960997099809991000

Corre

latio

n

JPEG

0 1 2 3 4 5Corruption Intensity

090

092

094

096

098

100

Corre

latio

n

Snow

0 1 2 3 4 5Corruption Intensity

090

092

094

096

098

100

Corre

latio

n

Fog

0 1 2 3 4 5Corruption Intensity

088

090

092

094

096

098

100

Corre

latio

n

Brightness

PreResNet-20 RBF Linear

Figure 19 Prior correlations under corruption Prior correlations between predictions (logits) for PreResNet-20 Linear Model andRBF kernel on original and corrupted images as a function of corruption intensity for different types of corruptions The lengthscale of theRBF kernell is calibrated to produce similar correlations to PreResNet on uncorrupted datapoints We report the mean correlation valuesover 100 different images and show the 1σ error bars with shaded regions For all corruptions except Snow Saturate Fog and Brightnessthe correlations decay slower for PreResNet compared to baselines

Page 24: Bayesian Deep Learning and a Probabilistic Perspective of ... · Andrew Gordon Wilson Pavel Izmailov New York University Abstract The key distinguishing property of a Bayesian ap-proach

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

10 20 30 40 50

ResNet-18 width

250

275

300

325

350

375

400

425

450

Tes

tE

rror

()

SGD

SWAG

Multi-SWAG

(a) 10 Corrupted (Err)

10 20 30 40 50

ResNet-18 width

10

12

14

16

18

20

22

Tes

tN

LL

(b) 10 Corrupted (NLL)

10 20 30 40 50

ResNet-18 width

12

14

16

18

20

Mul

tiS

WA

GT

est

NL

L

SWAG Models

1 3 5 10

(c) 20 Corrupted ( Models)

Figure 13 Double Descent (a) Test error and (b) NLL loss for ResNet-18 with varying width on CIFAR-100 for SGD SWAG andMultiSWAG when 10 of the labels are randomly reshuffled MultiSWAG alleviates double descent both on the original labels and underlabel noise both in accuracy and NLL (e) Test NLLs for MultiSWAG with varying number of independent models under 20 labelcorruption NLL monotonically decreases with increased number of independent models alleviating double descent

164 in Maddox et al (2019) For original labels and 10label corruption we use 5 independently trained SWAG mod-els with MultiSWAG and for 20 label corruption we use10 models for 20 label corruption we also show perfor-mance varying the number of independent models in Figures8(e) and 13(c) Both for SWAG and MultiSWAG we use anensemble of 20 sampled models from each of the SWAG so-lutions for example for MultiSWAG with 10 independentSWAG solutions we use an ensemble of 200 networks

E Analysis of Prior Variance EffectIn this section we provide simple analytic results for theeffect of prior variance in ReLU networks A related deriva-tion is presented in the Appendix Section A8 of Garipovet al (2018) about connecting paths from symmetries inparametrization

We will consider a multilayer network f(xw) of the form

f(x Wibini=1) =

Wn( φ(W2φ(W1x+ b1) + b2)) + bn

where φ is the ReLU (or in fact any positively-homogeneousactivation function) Wi are weight matrices and bi are biasvectors In particular f can be a regular CNN with ReLU ac-tivations up to the logits (with softmax activation removed)

Now suppose we have a prior distribution of the form

Wi sim N (0 α2i I) bi sim N (0 β2

i I)

where the identity matrices I are implicitly assumed to beof appropriate shapes so each weight matrix and bias vectorhas a spherical Gaussian distribution We can reparameter-ize this distribution as

Wi = αiEi Ei sim N (0 I)

bi = βiεi εi sim N (0 I)

We can then express the predictions of the network on theinput x for weights sampled from the prior as the random

variable

f(x αi βini=1) =

αn middot En( φ(α1E1x+ β1 middot ε1)) + βn middot εn(8)

Through Equation (8) we can observe some simple proper-ties of the dependence between the prior scales αi βi andthe induced function-space prior

Proposition 1 Suppose the network has no bias vectorsie β1 = = βn = 0 Then the scales αi of the priordistribution over the weights only affect the output scale ofthe network

Proof In the case when there are no bias vectors Equation(8) simplifies to

f(xαi βi = 0ni=1) =

αn middot En( φ(α1E1x+ β1 middot ε1)) + βn middot εn =

αn middot middot α1 middot En( φ(E1x)) =

αn middot middot α1 middot f(x αi = 1 βi = 0ni=1)

In the derivation above we used positive homogeneity ofReLU φ(αz) = αφ(z) for any positive α

In other words to sample from the distribution over func-tions corresponding to a prior with variances αi βi =0ni=1 we can sample from the spherical Gaussian prior(without bias terms) αi = 1 βi = 0ni=1 and then rescalethe outputs of the network by the product of variancesαn middot middot α2 middot α1

We note that the result above is different from the resultsfor sigmoid networks considered in MacKay (1995) wherevarying the prior on the weights leads to changing the length-scale of the sample functions For ReLU networks withoutbiases increasing prior variance only increases the outputscale of the network and not the complexity of the samplesIf we apply the softmax activation on the outputs of thelast layer of such network we will observe increasingly

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

confident predictions as we increase the prior variance Weobserve this effect in Figure 9 and discuss it in Section 84

In case bias vectors are present we can obtain a similarresult using a specific scaling of the prior variances withlayer as in the following proposition

Proposition 2 Suppose the prior scales depend on thelayer of the network as follows for some γ gt 0

αi = γ βi = γi

for all layers i = 1 n Then γ only affects the scale ofthe predictive distribution at any input x

f(xαi = γ βi = γini=1) = γn middot f(x αi = 1 βi = 1ni=1)

Proof The proof is analogous to the proof of Proposition 1We can use the positive homogenety of ReLU activations tofactor the prior scales outside of the network

f(xαi = γ βi = γini=1) =

γ middot En( φ(γ middot E1x+ γ middot ε1)) + γn middot εn =

γn middot(En( φ(E1x+ ε1))) + εn

)=

γn middot f(x αi = 1 βi = 1ni=1)

The analysis above can be applied to other simple scalingrules of the prior eg

f(xαi = γαi βi = γiβini=1) =

γn middot f(x αi = αi βi = βini=1)(9)

can be shown completely analogously to Proposition 2

More general types of scaling of the prior affect both theoutput scale of the network and also the relative effect ofprior and variance terms For example by Equation (9) wehave

f(xαi = γ βi = γni=1) =

f(xαi = γ middot 1 βi = γi middot γ1minusini=1) =

γn middot f(x αi = 1 βi = γ1minusini=1)

We note that the analysis does not cover residual connectionsand batch normalization so it applies to LeNet-5 but cannotbe directly applied to PreResNet-20 networks used in manyof our experiments

F Prior Correlation Structure underPerturbations

In this section we explore the prior correlations between thelogits on different pairs of datapoints induced by a spherical

Gaussian prior on the weights of a PreResNet-20 We sam-ple a 100 random images from CIFAR-10 (10 from eachclass) and apply 17 different perturbations introduced byHendrycks amp Dietterich (2019) at 5 different levels of in-tensity We then compute correlations between the logitsf(xw) for the original image x and f(x w) for the cor-rupted image x as we sample the weights of the networkfrom the prior w sim N (0 I)

In Figure 19 we show how the correlations decay with per-turbation intensity For reference we also show how thecorrelations decay for a linear model and for an RBF kernelFor the RBF kernel we set the lengthscale so that the averagecorrelations on the uncorrupted datapoints match those of aPreResNet-20 Further experimental details can be found inAppendix D3

For all types of corruptions except saturate snow fog andbrightness the PreResNet logits decay slower compared tothe RBF kernel and linear model It appears that the priorsamples are sensitive to corruptions that alter the bright-ness or more generally the colours in the image For manytypes of corruptions (such as eg Gaussian Noise) the priorcorrelations for PreResNet are close to 1 for all levels ofcorruption

Overall these results indicate that the prior over functionsinduced by a vague prior over parameters w in combinationwith a PreResNet has useful equivariance properties beforeseeing data the model treats images of the same class ashighly correlated even after an image has undergone signif-icant perturbations representative of perturbations we oftensee in the real world These types of symmetries are a largepart of what makes neural networks a powerful model classfor high dimensional natural signals

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

1 2 3 4 5 6 7 8 9 10 Models

8000

9000

10000

11000

12000

13000

NLL

1 2 3 4 5 6 7 8 9 10 Models

17500

20000

22500

25000

27500

30000

1 2 3 4 5 6 7 8 9 10 Models

25000

30000

35000

40000

45000

1 2 3 4 5 6 7 8 9 10 Models

30000

35000

40000

45000

50000

1 2 3 4 5 6 7 8 9 10 Models

30000

35000

40000

45000

50000

55000

Deep Ensembles MultiSWA MultiSWAG

(a) Gaussian Noise

1 2 3 4 5 6 7 8 9 10 Models

5000

6000

7000

8000

9000

NLL

1 2 3 4 5 6 7 8 9 10 Models

10000

12000

14000

16000

1 2 3 4 5 6 7 8 9 10 Models

12000

14000

16000

18000

20000

1 2 3 4 5 6 7 8 9 10 Models

20000

22500

25000

27500

30000

32500

1 2 3 4 5 6 7 8 9 10 Models

25000

30000

35000

40000

45000

50000

Deep Ensembles MultiSWA MultiSWAG

(b) Impulse Noise

1 2 3 4 5 6 7 8 9 10 Models

5000

6000

7000

8000

NLL

1 2 3 4 5 6 7 8 9 10 Models

8000

10000

12000

14000

1 2 3 4 5 6 7 8 9 10 Models

20000

25000

30000

35000

1 2 3 4 5 6 7 8 9 10 Models

25000

30000

35000

40000

1 2 3 4 5 6 7 8 9 10 Models

30000

35000

40000

45000

50000

Deep Ensembles MultiSWA MultiSWAG

(c) Shot Noise

Figure 14 Noise Corruptions Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles MultiSWAG andMultiSWA as a function of the number of independently trained models for different types of corruption and corruption intensity(increasing from left to right)

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

1 2 3 4 5 6 7 8 9 10 Models

1600

1800

2000

2200

2400

2600

2800

NLL

1 2 3 4 5 6 7 8 9 10 Models

2000

2250

2500

2750

3000

3250

1 2 3 4 5 6 7 8 9 10 Models

3500

4000

4500

5000

5500

6000

1 2 3 4 5 6 7 8 9 10 Models

6000

7000

8000

9000

10000

11000

1 2 3 4 5 6 7 8 9 10 Models

140001600018000200002200024000

Deep Ensembles MultiSWA MultiSWAG

(a) Defocus Blur

1 2 3 4 5 6 7 8 9 10 Models

14000

16000

18000

20000

22000

24000

NLL

1 2 3 4 5 6 7 8 9 10 Models

14000

16000

18000

20000

22000

1 2 3 4 5 6 7 8 9 10 Models

12000

14000

16000

18000

20000

1 2 3 4 5 6 7 8 9 10 Models

18000

20000

22000

24000

26000

28000

30000

1 2 3 4 5 6 7 8 9 10 Models

18000

20000

22000

24000

26000

28000

Deep Ensembles MultiSWA MultiSWAG

(b) Glass Blur

1 2 3 4 5 6 7 8 9 10 Models

3000

3500

4000

4500

5000

NLL

1 2 3 4 5 6 7 8 9 10 Models

5000

6000

7000

8000

9000

1 2 3 4 5 6 7 8 9 10 Models

8000

10000

12000

14000

1 2 3 4 5 6 7 8 9 10 Models

8000

10000

12000

14000

1 2 3 4 5 6 7 8 9 10 Models

12000

14000

16000

18000

20000

Deep Ensembles MultiSWA MultiSWAG

(c) Motion Blur

1 2 3 4 5 6 7 8 9 10 Models

3500400045005000550060006500

NLL

1 2 3 4 5 6 7 8 9 10 Models

5000

6000

7000

8000

9000

1 2 3 4 5 6 7 8 9 10 Models

7000

8000

9000

10000

11000

12000

13000

1 2 3 4 5 6 7 8 9 10 Models

10000

12000

14000

16000

1 2 3 4 5 6 7 8 9 10 Models

12000

14000

16000

18000

20000

22000

24000

Deep Ensembles MultiSWA MultiSWAG

(d) Zoom Blur

1 2 3 4 5 6 7 8 9 10 Models

1600

1800

2000

2200

2400

2600

2800

NLL

1 2 3 4 5 6 7 8 9 10 Models

3500

4000

4500

5000

5500

6000

1 2 3 4 5 6 7 8 9 10 Models

7000

8000

9000

10000

11000

12000

1 2 3 4 5 6 7 8 9 10 Models

12000

14000

16000

18000

20000

22000

1 2 3 4 5 6 7 8 9 10 Models

20000

25000

30000

35000

40000

Deep Ensembles MultiSWA MultiSWAG

(e) Gaussian Blur

Figure 15 Blur Corruptions Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles MultiSWAG andMultiSWA as a function of the number of independently trained models for different types of corruption and corruption intensity(increasing from left to right)

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

1 2 3 4 5 6 7 8 9 10 Models

1750

2000

2250

2500

2750

3000

NLL

1 2 3 4 5 6 7 8 9 10 Models

3000

3500

4000

4500

5000

1 2 3 4 5 6 7 8 9 10 Models

4000

5000

6000

7000

1 2 3 4 5 6 7 8 9 10 Models

8000

10000

12000

14000

16000

1 2 3 4 5 6 7 8 9 10 Models

25000

30000

35000

40000

45000

50000

Deep Ensembles MultiSWA MultiSWAG

(a) Contrast

1 2 3 4 5 6 7 8 9 10 Models

2250

2500

2750

3000

3250

3500

3750

NLL

1 2 3 4 5 6 7 8 9 10 Models

3000

3500

4000

4500

1 2 3 4 5 6 7 8 9 10 Models

1750

2000

2250

2500

2750

1 2 3 4 5 6 7 8 9 10 Models

2500

3000

3500

4000

1 2 3 4 5 6 7 8 9 10 Models

3500

4000

4500

5000

5500

6000

6500

Deep Ensembles MultiSWA MultiSWAG

(b) Saturate

1 2 3 4 5 6 7 8 9 10 Models

3000

3500

4000

4500

5000

NLL

1 2 3 4 5 6 7 8 9 10 Models

3000

3500

4000

4500

5000

1 2 3 4 5 6 7 8 9 10 Models

4000

5000

6000

7000

1 2 3 4 5 6 7 8 9 10 Models

6000

7000

8000

9000

10000

11000

1 2 3 4 5 6 7 8 9 10 Models

7000

8000

9000

10000

11000

12000

Deep Ensembles MultiSWA MultiSWAG

(c) Elastic Transform

1 2 3 4 5 6 7 8 9 10 Models

2250

2500

2750

3000

3250

3500

3750

NLL

1 2 3 4 5 6 7 8 9 10 Models

4000

4500

5000

5500

6000

6500

1 2 3 4 5 6 7 8 9 10 Models

6000

7000

8000

9000

1 2 3 4 5 6 7 8 9 10 Models

14000

16000

18000

20000

22000

1 2 3 4 5 6 7 8 9 10 Models

24000

26000

28000

30000

32000

34000

Deep Ensembles MultiSWA MultiSWAG

(d) Pixelate

1 2 3 4 5 6 7 8 9 10 Models

4000

4500

5000

5500

6000

6500

NLL

1 2 3 4 5 6 7 8 9 10 Models

5000

6000

7000

8000

9000

1 2 3 4 5 6 7 8 9 10 Models

6000

7000

8000

9000

10000

1 2 3 4 5 6 7 8 9 10 Models

7000

8000

9000

10000

11000

1 2 3 4 5 6 7 8 9 10 Models

8000

9000

10000

11000

12000

13000

Deep Ensembles MultiSWA MultiSWAG

(e) JPEG Compression

Figure 16 Digital Corruptions Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles MultiSWAG andMultiSWA as a function of the number of independently trained models for different types of corruption and corruption intensity(increasing from left to right)

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

1 2 3 4 5 6 7 8 9 10 Models

3500

4000

4500

5000NL

L

1 2 3 4 5 6 7 8 9 10 Models

7000

8000

9000

10000

1 2 3 4 5 6 7 8 9 10 Models

5000

6000

7000

8000

1 2 3 4 5 6 7 8 9 10 Models

6000

7000

8000

9000

1 2 3 4 5 6 7 8 9 10 Models

8000

9000

10000

11000

12000

Deep Ensembles MultiSWA MultiSWAG

(a) Snow

1 2 3 4 5 6 7 8 9 10 Models

1600

1800

2000

2200

2400

2600

2800

NLL

1 2 3 4 5 6 7 8 9 10 Models

2000

2250

2500

2750

3000

3250

1 2 3 4 5 6 7 8 9 10 Models

2500

3000

3500

4000

1 2 3 4 5 6 7 8 9 10 Models

3500

4000

4500

5000

5500

6000

1 2 3 4 5 6 7 8 9 10 Models

8000

10000

12000

14000

16000

18000

Deep Ensembles MultiSWA MultiSWAG

(b) Fog

1 2 3 4 5 6 7 8 9 10 Models

1600

1800

2000

2200

2400

2600

2800

NLL

1 2 3 4 5 6 7 8 9 10 Models

1600180020002200240026002800

1 2 3 4 5 6 7 8 9 10 Models

1750

2000

2250

2500

2750

3000

3250

1 2 3 4 5 6 7 8 9 10 Models

2000225025002750300032503500

1 2 3 4 5 6 7 8 9 10 Models

3000

3500

4000

4500

Deep Ensembles MultiSWA MultiSWAG

(c) Brightness

Figure 17 Weather Corruptions Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles MultiSWAG andMultiSWA as a function of the number of independently trained models for different types of corruption and corruption intensity(increasing from left to right)

1 2 3 4 5Corruption Intensity

0

5000

10000

15000

20000

25000

30000

35000

NLL

Deep Ensembles MultiSWA MultiSWAG

1 2 3 4 5Corruption Intensity

02

03

04

05

06

07

08

09

ACcu

racy

Deep Ensembles MultiSWA MultiSWAG

1 2 3 4 5Corruption Intensity

00

01

02

03

04

ECE

Deep Ensembles MultiSWA MultiSWAG

Figure 18 Negative log likelihood accuracy and expected calibration error distribution on CIFAR-10 with a PreResNet-20 for DeepEnsembles MultiSWAG and MultiSWA as a function of the corruption intensity Following Ovadia et al (2019) we summarize the resultsfor different types of corruption with a boxplot For each method we use 10 independently trained models and for MultiSWAG wesample 20 networks from each model As in Figures 5 11-14 there are substantial differences between these three methods which arehard to see due to the vertical scale on this plot MultiSWAG particularly outperforms Deep Ensembles and MultiSWA in terms of NLLand ECE for higher corruption intensities

Bayesian Deep Learning and a Probabilistic Perspective of Generalization

[Figure 19 panels: prior correlation vs. corruption intensity (0-5) for PreResNet-20, RBF, and Linear, one panel per corruption type (Translation, Gaussian Noise, Impulse Noise, Shot Noise, Defocus Blur, Glass Blur, Motion Blur, Zoom Blur, Gaussian Blur, Contrast, Saturate, Elastic, Pixelate, JPEG, Snow, Fog, Brightness).]

Figure 19. Prior correlations under corruption. Prior correlations between predictions (logits) for PreResNet-20, a linear model, and an RBF kernel on original and corrupted images, as a function of corruption intensity for different types of corruptions. The lengthscale of the RBF kernel is calibrated to produce similar correlations to PreResNet on uncorrupted datapoints. We report the mean correlation values over 100 different images and show the 1σ error bars with shaded regions. For all corruptions except Snow, Saturate, Fog, and Brightness, the correlations decay slower for PreResNet compared to the baselines.


confident predictions as we increase the prior variance. We observe this effect in Figure 9, and discuss it in Section 8.4.

When bias vectors are present, we can obtain a similar result using a specific scaling of the prior variances with layer, as in the following proposition.

Proposition 2. Suppose the prior scales depend on the layer of the network as follows, for some $\gamma > 0$:
$$\alpha_i = \gamma, \qquad \beta_i = \gamma^i,$$
for all layers $i = 1, \ldots, n$. Then $\gamma$ only affects the scale of the predictive distribution at any input $x$:
$$f\big(x \mid \{\alpha_i = \gamma,\ \beta_i = \gamma^i\}_{i=1}^n\big) = \gamma^n \cdot f\big(x \mid \{\alpha_i = 1,\ \beta_i = 1\}_{i=1}^n\big).$$

Proof. The proof is analogous to the proof of Proposition 1. We use the positive homogeneity of ReLU activations to factor the prior scales outside of the network:
$$
\begin{aligned}
f\big(x \mid \{\alpha_i = \gamma,\ \beta_i = \gamma^i\}_{i=1}^n\big)
&= \gamma \cdot E_n\big(\ldots \phi(\gamma \cdot E_1 x + \gamma \cdot \epsilon_1) \ldots\big) + \gamma^n \cdot \epsilon_n \\
&= \gamma^n \cdot \Big( E_n\big(\ldots \phi(E_1 x + \epsilon_1) \ldots\big) + \epsilon_n \Big) \\
&= \gamma^n \cdot f\big(x \mid \{\alpha_i = 1,\ \beta_i = 1\}_{i=1}^n\big).
\end{aligned}
$$

The analysis above can be applied to other simple scaling rules of the prior; e.g.,
$$f\big(x \mid \{\alpha_i = \gamma\tilde\alpha_i,\ \beta_i = \gamma^i\tilde\beta_i\}_{i=1}^n\big) = \gamma^n \cdot f\big(x \mid \{\alpha_i = \tilde\alpha_i,\ \beta_i = \tilde\beta_i\}_{i=1}^n\big) \qquad (9)$$
can be shown completely analogously to Proposition 2.

More general types of scaling of the prior affect both the output scale of the network and the relative effect of the prior and variance terms. For example, by Equation (9) we have
$$
\begin{aligned}
f\big(x \mid \{\alpha_i = \gamma,\ \beta_i = \gamma\}_{i=1}^n\big)
&= f\big(x \mid \{\alpha_i = \gamma \cdot 1,\ \beta_i = \gamma^i \cdot \gamma^{1-i}\}_{i=1}^n\big) \\
&= \gamma^n \cdot f\big(x \mid \{\alpha_i = 1,\ \beta_i = \gamma^{1-i}\}_{i=1}^n\big).
\end{aligned}
$$

We note that the analysis does not cover residual connections or batch normalization, so it applies to LeNet-5 but cannot be directly applied to the PreResNet-20 networks used in many of our experiments.
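As a concrete sanity check (our own illustration, not code from the paper), the scaling property in Proposition 2 can be verified numerically for a small fully connected ReLU network. The sketch below draws unit-scale weight matrices and biases, applies the layer-wise prior scales $\alpha_i = \gamma$, $\beta_i = \gamma^i$, and confirms the output equals $\gamma^n$ times the output of the unit-scale network; all names (`forward`, `Es`, `eps`) are ours.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def forward(x, Es, eps, gamma_w, gamma_b):
    """n-layer ReLU network whose layer-i weights are scaled by gamma_w[i]
    and biases by gamma_b[i] (the prior scales); no nonlinearity on the output."""
    h, n = x, len(Es)
    for i in range(n):
        pre = gamma_w[i] * Es[i] @ h + gamma_b[i] * eps[i]
        h = relu(pre) if i < n - 1 else pre
    return h

rng = np.random.default_rng(0)
n, gamma = 3, 2.5
dims = [8, 16, 16, 1]                                              # input dim 8, scalar output
Es  = [rng.standard_normal((dims[i + 1], dims[i])) for i in range(n)]  # unit-scale weights
eps = [rng.standard_normal(dims[i + 1]) for i in range(n)]             # unit-scale biases
x = rng.standard_normal(dims[0])

scaled = forward(x, Es, eps, [gamma] * n, [gamma ** (i + 1) for i in range(n)])
unit   = forward(x, Es, eps, [1.0] * n, [1.0] * n)
assert np.allclose(scaled, gamma ** n * unit)   # Proposition 2: outputs differ only by gamma^n
```

Because each positive scale factors through the ReLUs, the check holds for any $\gamma > 0$ and any draw of the unit-scale parameters.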

F. Prior Correlation Structure under Perturbations

In this section we explore the prior correlations between the logits on different pairs of datapoints induced by a spherical Gaussian prior on the weights of a PreResNet-20. We sample 100 random images from CIFAR-10 (10 from each class) and apply 17 different perturbations introduced by Hendrycks & Dietterich (2019) at 5 different levels of intensity. We then compute correlations between the logits $f(x; w)$ for the original image $x$ and $f(\tilde{x}; w)$ for the corrupted image $\tilde{x}$, as we sample the weights of the network from the prior $w \sim \mathcal{N}(0, I)$.
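A minimal sketch of this measurement, under our own assumptions about the aggregation: given a PyTorch module `net` (e.g., a PreResNet-20 producing logits) and batches `x_orig`, `x_corr` of clean and corrupted images, we repeatedly resample every parameter from $\mathcal{N}(0, \text{prior\_std}^2 I)$ and compute, per image, the Pearson correlation over prior draws between corresponding logits. The helper name, the `prior_std` argument, and the averaging over the class dimension are illustrative choices, not the paper's exact protocol.

```python
import torch

@torch.no_grad()
def prior_logit_correlations(net, x_orig, x_corr, n_samples=100, prior_std=1.0):
    """Correlation, over prior draws w ~ N(0, prior_std^2 I), between logits
    f(x; w) for clean images and f(x_tilde; w) for their corrupted versions.
    Returns one value per image (averaged over the class dimension)."""
    net.eval()                                  # so BatchNorm layers use running statistics
    f_orig, f_corr = [], []
    for _ in range(n_samples):
        for p in net.parameters():              # draw a fresh prior sample of all weights
            p.normal_(mean=0.0, std=prior_std)
        f_orig.append(net(x_orig))              # [batch, classes]
        f_corr.append(net(x_corr))
    f_orig = torch.stack(f_orig)                # [samples, batch, classes]
    f_corr = torch.stack(f_corr)
    a = f_orig - f_orig.mean(dim=0)             # center over prior samples
    b = f_corr - f_corr.mean(dim=0)
    cov = (a * b).mean(dim=0)
    corr = cov / (a.std(dim=0, unbiased=False) * b.std(dim=0, unbiased=False) + 1e-12)
    return corr.mean(dim=1)                     # average over classes -> one value per image
```

Running this for each corruption type and intensity, and averaging over the 100 images, yields curves of the kind summarized in Figure 19.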

In Figure 19 we show how the correlations decay with perturbation intensity. For reference, we also show how the correlations decay for a linear model and for an RBF kernel. For the RBF kernel, we set the lengthscale so that the average correlations on the uncorrupted datapoints match those of a PreResNet-20. Further experimental details can be found in Appendix D.3.
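One way this lengthscale calibration could be implemented (our own sketch; the paper does not spell out the search procedure) is a bisection on the lengthscale so that the mean RBF correlation over pairs of uncorrupted images matches the mean PreResNet prior correlation on the same pairs; `sq_dists` and `target_corr` below are assumed inputs.

```python
import numpy as np

def rbf_corr(sq_dists, lengthscale):
    """Prior correlation under an RBF kernel: k(x, x') / sqrt(k(x,x) k(x',x'))
    reduces to exp(-||x - x'||^2 / (2 l^2)) because k(x, x) = 1."""
    return np.exp(-sq_dists / (2.0 * lengthscale ** 2))

def calibrate_lengthscale(sq_dists, target_corr, lo=1e-2, hi=1e4, iters=100):
    """Bisection (geometric midpoint) on the lengthscale so that the mean RBF
    correlation over the given uncorrupted image pairs matches the PreResNet
    target; the mean correlation is monotonically increasing in the lengthscale."""
    for _ in range(iters):
        mid = np.sqrt(lo * hi)
        if rbf_corr(sq_dists, mid).mean() < target_corr:
            lo = mid            # lengthscale too short: correlations too low
        else:
            hi = mid
    return np.sqrt(lo * hi)
```

Here `sq_dists` would hold squared Euclidean distances between flattened uncorrupted images, and `target_corr` the corresponding mean PreResNet prior correlation.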

For all types of corruptions except saturate, snow, fog, and brightness, the correlations for PreResNet decay more slowly than for the RBF kernel and linear model. It appears that the prior samples are sensitive to corruptions that alter the brightness, or more generally the colours, in the image. For many types of corruptions (such as Gaussian noise), the prior correlations for PreResNet are close to 1 for all levels of corruption.

Overall, these results indicate that the prior over functions induced by a vague prior over parameters $w$, in combination with a PreResNet, has useful equivariance properties: before seeing data, the model treats images of the same class as highly correlated, even after an image has undergone significant perturbations representative of those we often see in the real world. These types of symmetries are a large part of what makes neural networks a powerful model class for high-dimensional natural signals.


[Figure 14 panels (a) Gaussian Noise, (b) Impulse Noise, (c) Shot Noise: NLL vs. number of independently trained models (1-10) for Deep Ensembles, MultiSWA, and MultiSWAG at five corruption intensities.]

Figure 14. Noise Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG, and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).


[Figure 15 panels (a) Defocus Blur, (b) Glass Blur, (c) Motion Blur, (d) Zoom Blur, (e) Gaussian Blur: NLL vs. number of independently trained models (1-10) for Deep Ensembles, MultiSWA, and MultiSWAG at five corruption intensities.]

Figure 15. Blur Corruptions. Negative log likelihood on CIFAR-10 with a PreResNet-20 for Deep Ensembles, MultiSWAG, and MultiSWA as a function of the number of independently trained models, for different types of corruption and corruption intensity (increasing from left to right).
