
Beyond Sparsity: Tree Regularization of Deep Models for Interpretability

Mike Wu1, Michael C. Hughes2, Sonali Parbhoo3, Maurizio Zazzi4, Volker Roth3, and Finale Doshi-Velez2

1 Stanford University, [email protected]   2 Harvard University SEAS, [email protected], [email protected]

3 University of Basel, {sonali.parbhoo,volker.roth}@unibas.ch   4 University of Siena, [email protected]

Abstract

The lack of interpretability remains a key barrier to the adoption of deep models in many applications. In this work, we explicitly regularize deep models so human users might step through the process behind their predictions in little time. Specifically, we train deep time-series models so their class-probability predictions have high accuracy while being closely modeled by decision trees with few nodes. Using intuitive toy examples as well as medical tasks for treating sepsis and HIV, we demonstrate that this new tree regularization yields models that are easier for humans to simulate than simpler L1 or L2 penalties without sacrificing predictive power.

1 Introduction

Deep models have become the de-facto approach for prediction in a variety of applications such as image classification (e.g. (Krizhevsky, Sutskever, and Hinton 2012)) and machine translation (e.g. (Bahdanau, Cho, and Bengio 2014; Sutskever, Vinyals, and Le 2014)). However, many practitioners are reluctant to adopt deep models because their predictions are difficult to interpret. In this work, we seek a specific form of interpretability known as human-simulability. A human-simulatable model is one in which a human user can "take in input data together with the parameters of the model and in reasonable time step through every calculation required to produce a prediction" (Lipton 2016). For example, small decision trees with only a few nodes are easy for humans to simulate and thus understand and trust. In contrast, even simple deep models like multi-layer perceptrons with a few dozen units can have far too many parameters and connections for a human to easily step through. Deep models for sequences are even more challenging. Of course, decision trees with too many nodes are also hard to simulate. Our key research question is: can we create deep models that are well-approximated by compact, human-simulatable models?

The question of creating accurate yet human-simulatable models is an important one, because in many domains simulatability is paramount. For example, despite advances in deep learning for clinical decision support (e.g. (Miotto et al. 2016; Choi et al. 2016; Che et al. 2015)), the clinical community remains skeptical of machine learning systems (Chen and Asch 2017).

A version of this work will appear in AAAI 2018 (https://aaai.org/Conferences/AAAI-18/). This paper includes an extended appendix with supplementary material.

Simulatability allows clinicians to audit predictions easily. They can manually inspect changes to outputs under slightly-perturbed inputs, check substeps against their expert knowledge, and identify when predictions are made due to systemic bias in the data rather than real causes. Similar needs for simulatability exist in many decision-critical domains such as disaster response or recidivism prediction.

To address this need for interpretability, a number of works have been developed to assist in the interpretation of already-trained models. Craven and Shavlik (1996) train decision trees that mimic the predictions of a fixed, pretrained neural network, but do not train the network itself to be simpler. Other post-hoc interpretations typically evaluate the sensitivity of predictions to local perturbations of inputs or the input gradient (Ribeiro, Singh, and Guestrin 2016; Selvaraju et al. 2016; Adler et al. 2016; Lundberg and Lee 2016; Erhan et al. 2009). In parallel, research efforts have emphasized that simple lists of (perhaps locally) important features are not sufficient: Singh, Ribeiro, and Guestrin (2016) provide explanations in the form of programs; Lakkaraju, Bach, and Leskovec (2016) learn decision sets and show benefits over other rule-based methods.

These techniques focus on understanding already-learned models rather than finding models that are more interpretable. However, it is well-known that deep models often have multiple optima of similar predictive accuracy (Goodfellow, Bengio, and Courville 2016), and thus one might hope to find more interpretable models with equal predictive accuracy. Yet the field of optimizing deep models for interpretability remains nascent. Ross, Hughes, and Doshi-Velez (2017) penalize input sensitivity to features marked as less relevant. Lei, Barzilay, and Jaakkola (2016) train deep models that make predictions from text and simultaneously highlight contiguous subsets of words, called a "rationale," to justify each prediction. While both works optimize their deep models to expose relevant features, lists of features are not sufficient to simulate the prediction.

Contributions. In this work, we take steps toward optimizing deep models for human-simulatability via a new model-complexity penalty function we call tree regularization. Tree regularization favors models whose decision boundaries can


be well-approximated by small decision trees, thus penalizing models that would require many calculations to simulate predictions. We first demonstrate how this technique can be used to train simple multi-layer perceptrons to have tree-like decision boundaries. We then focus on time-series applications and show that gated recurrent unit (GRU) models trained with strong tree regularization reach a high-accuracy-at-low-complexity sweet spot that is not possible with any strength of L1 or L2 regularization. Prediction quality can be further boosted by training new hybrid models, GRU-HMMs, which explain the residuals of interpretable discrete HMMs via tree-regularized GRUs. We further show that the approximate decision trees for our tree-regularized deep models are useful for human simulation and interpretability. We demonstrate our approach on a speech recognition task and two medical treatment prediction tasks for patients with sepsis in the intensive care unit (ICU) and for patients with human immunodeficiency virus (HIV). Throughout, we also show that standalone decision trees as a baseline are noticeably less accurate than our tree-regularized deep models. We have released an open-source Python toolbox to allow others to experiment with tree regularization.^1

Related work. While there is little work (as mentioned above) on optimizing models for interpretability, there are some related threads. The first is model compression, which trains smaller models that perform similarly to large, black-box models (e.g. (Hinton, Vinyals, and Dean 2015; Balan et al. 2015; Han et al. 2015)). Other efforts specifically train very sparse networks via L1 penalties (Zhang, Lee, and Jordan 2016) or even binary neural networks (Tang, Hua, and Wang 2017; Rastegari et al. 2016) with the goal of faster computation. Edge and node regularization is commonly used to improve prediction accuracy (Drucker and Le Cun 1992; Ochiai et al. 2017), and recently Hu et al. (2016) improve prediction accuracy by training neural networks so that predictions match a small list of known domain-specific first-order logic rules. Sometimes, these regularizations, which all smooth or simplify decision boundaries, can have the effect of also improving interpretability. However, there is no guarantee that these regularizations will improve interpretability; we emphasize that specifically training deep models to have easily-simulatable decision boundaries is (to our knowledge) novel.

2 Background and Notation

We consider supervised learning tasks given datasets of N labeled examples, where each example (indexed by n) has an input feature vector x_n and a target output vector y_n. We shall assume the targets y_n are binary, though it is simple to extend to other types. When modeling time series, each example sequence n contains T_n timesteps indexed by t, which each have a feature vector x_nt and an output y_nt. Formally, we write $x_n = [x_{n1} \ldots x_{nT_n}]$ and $y_n = [y_{n1} \ldots y_{nT_n}]$. Each value y_nt could be a prediction about the next timestep (e.g. the character at time t+1) or some other task-related annotation (e.g. whether the patient became septic at time t).

^1 https://github.com/dtak/tree-regularization-public

Simple neural networks. A multi-layer perceptron (MLP) makes predictions ŷ_n of the target y_n via a function ŷ_n(x_n, W), where the vector W represents all parameters of the network. Given a dataset {(x_n, y_n)}, our goal is to learn the parameters W to minimize the objective

$$\min_W \; \lambda \Psi(W) + \sum_{n=1}^{N} \mathrm{loss}(y_n, \hat{y}_n(x_n, W)) \qquad (1)$$

For binary targets y_n, the logistic loss (binary cross entropy) is an effective choice. The regularization term Ψ(W) can represent L1 or L2 penalties (e.g. (Drucker and Le Cun 1992; Goodfellow, Bengio, and Courville 2016; Ochiai et al. 2017)) or our new regularization.
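As a concrete illustration (a minimal sketch, not the released implementation), the objective in Eq. (1) with a logistic loss can be written as follows; `forward` and `psi` are placeholder callables for the network's forward pass and the chosen penalty:

import numpy as np

def objective(W, X, Y, lam, psi, forward):
    # Eq. (1): lam * Psi(W) + sum_n loss(y_n, y_hat_n(x_n, W)).
    # `forward(X, W)` returns predicted probabilities in (0, 1);
    # `psi(W)` is the penalty (L1, L2, or the tree regularizer of Section 3).
    p = forward(X, W)
    bce = -np.sum(Y * np.log(p + 1e-8) + (1.0 - Y) * np.log(1.0 - p + 1e-8))
    return lam * psi(W) + bce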

Recurrent Neural Networks with Gated Recurrent Units. A recurrent neural network (RNN) takes as input an arbitrary-length sequence $x_n = [x_{n1} \ldots x_{nT_n}]$ and produces a "hidden state" sequence $h_n = [h_{n1} \ldots h_{nT_n}]$ of the same length as the input. Each hidden state vector at timestep t represents a location in a (possibly low-dimensional) "state space" with K dimensions: $h_{nt} \in \mathbb{R}^K$. RNNs perform sequential nonlinear embedding of the form $h_{nt} = f(x_{nt}, h_{nt-1})$ in the hope that the state-space location h_nt is a useful summary statistic for making predictions of the target y_nt at timestep t.

Many different variants of the transition function architecture f have been proposed to solve the challenge of capturing long-term dependencies. In this paper, we use gated recurrent units (GRUs) (Cho et al. 2014), which are simpler than other alternatives such as long short-term memory units (LSTMs) (Hochreiter and Schmidhuber 1997). While GRUs are convenient, any differentiable RNN architecture is compatible with our new tree-regularization approach.

Below we describe the evolution of a single GRU sequence, dropping the sequence index n for readability. The GRU transition function f produces the state vector $h_t = [h_{t1} \ldots h_{tK}]$ from a previous state $h_{t-1}$ and an input vector $x_t$, via the following feed-forward architecture:

$$
\begin{aligned}
\text{output state:} \quad & h_{tk} = (1 - z_{tk})\, h_{t-1,k} + z_{tk}\, \tilde{h}_{tk} \qquad (2)\\
\text{candidate state:} \quad & \tilde{h}_{tk} = \tanh\!\big(V^h_k x_t + U^h_k (r_t \odot h_{t-1})\big)\\
\text{update gate:} \quad & z_{tk} = \sigma\!\big(V^z_k x_t + U^z_k h_{t-1}\big)\\
\text{reset gate:} \quad & r_{tk} = \sigma\!\big(V^r_k x_t + U^r_k h_{t-1}\big)
\end{aligned}
$$

The internal network nodes include candidate state gates h̃, update gates z, and reset gates r, which have the same cardinality as the state vector h. Reset gates allow the network to forget past state vectors when set near zero via the logistic sigmoid nonlinearity σ(·). Update gates allow the network to either pass along the previous state vector unchanged or use the new candidate state vector instead. This architecture is diagrammed in Figure 1.

The predicted probability of the binary label y_t for time t is a sigmoid transformation of the state at time t:

$$\hat{y}_t = \sigma(w^\top h_t) \qquad (3)$$

Here, the weight vector $w \in \mathbb{R}^K$ represents the parameters of this output layer.
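The per-timestep computation in Eqs. (2)-(3) can be sketched directly in NumPy. The parameter names (Vz, Uz, and so on) mirror the matrices above; this is an illustrative sketch rather than the paper's released code:

import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_step(x_t, h_prev, params, w_out):
    # One GRU transition (Eq. 2) followed by the output layer (Eq. 3).
    z = sigmoid(params["Vz"].dot(x_t) + params["Uz"].dot(h_prev))            # update gate
    r = sigmoid(params["Vr"].dot(x_t) + params["Ur"].dot(h_prev))            # reset gate
    h_cand = np.tanh(params["Vh"].dot(x_t) + params["Uh"].dot(r * h_prev))   # candidate state
    h_t = (1.0 - z) * h_prev + z * h_cand                                    # output state
    y_t = sigmoid(w_out.dot(h_t))                                            # predicted probability
    return h_t, y_t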

Figure 1: Diagram of the gated recurrent unit (GRU) used at each timestep of our neural time-series model. The orange triangle indicates the predicted output ŷ_t at time t.

We denote the parameters for the entire GRU-RNN model as W = (w, U, V), concatenating all component parameters. We can train GRU-RNN time-series models (hereafter often just called GRUs) via the following loss-minimization objective:

$$\min_W \; \lambda \Psi(W) + \sum_{n=1}^{N} \sum_{t=1}^{T_n} \mathrm{loss}(y_{nt}, \hat{y}_{nt}(x_n, W)) \qquad (4)$$

where again Ψ(W) defines a regularization cost.

3 Tree Regularization for Deep Models

We now propose a novel tree regularization function Ω(W) for the parameters of a differentiable model which attempts to penalize models whose predictions are not easily simulatable. Of course, it is difficult to measure "simulatability" directly for an arbitrary network, so we take inspiration from decision trees. Our chosen method has two stages: first, find a single binary decision tree which accurately reproduces the network's thresholded binary predictions ŷ_n given input x_n. Second, measure the complexity of this decision tree as the output of Ω(W). We measure complexity as the average decision path length, that is, the average number of decision nodes that must be touched to make a prediction for an input example x_n. We compute the average with respect to some designated reference dataset of example inputs D = {x_n} from the training set. While many ways to measure complexity exist, we find average path length is most relevant to our notion of simulatability. Remember that for us, human simulation requires stepping through every calculation required to make a prediction. Average path length exactly counts the number of true-or-false Boolean calculations needed to make an average prediction, assuming the model is a decision tree. Total number of nodes could be used as a metric, but might penalize more accurate trees that have short paths for most examples but need more involved logic for a few outliers.

Our true-average-path-length cost function Ω(W) is detailed in Alg. 1. It requires two subroutines, TRAINTREE and PATHLENGTH. TRAINTREE trains a binary decision tree to accurately reproduce the provided labeled examples {x_n, ŷ_n}. We use the DecisionTree module distributed in Python's scikit-learn (Pedregosa et al. 2011) with post-pruning to simplify the tree. These trees can give probabilistic predictions at each leaf. (Complete decision-tree training details are in the supplement.)

Algorithm 1 Average-Path-Length Cost Function

Require:
  ŷ(·, W): binary prediction function, with parameters W
  D = {x_n}_{n=1}^N: reference dataset with N examples

1: function Ω(W)
2:   tree ← TRAINTREE({x_n, ŷ(x_n, W)})
3:   return (1/N) Σ_n PATHLENGTH(tree, x_n)

Next, PATHLENGTH counts how many decision nodes must be traversed to route a specific input to its output leaf in the provided decision tree. In our evaluations, we will apply our average-decision-tree-path-length regularization, or simply "tree regularization," to several neural models. Alg. 1 defines our average-path-length cost function Ω(W), which can be plugged into the abstract regularization term Ψ(W) in the objectives in Equations 1 and 4.
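For concreteness, here is a minimal sketch of the Ω(W) computation using scikit-learn; the released toolbox additionally post-prunes the tree and may use a slightly different node-counting convention:

import numpy as np
from sklearn.tree import DecisionTreeClassifier

def average_path_length_cost(predict_fn, X_reference, min_samples_leaf=5):
    # TRAINTREE: fit a decision tree that mimics the network's thresholded predictions.
    y_hat = predict_fn(X_reference)
    tree = DecisionTreeClassifier(min_samples_leaf=min_samples_leaf)
    tree.fit(X_reference, y_hat)
    # PATHLENGTH, averaged: decision_path marks every node visited per example,
    # so the number of nonzeros in each row is that example's path length
    # (including the leaf; subtract 1 to count only the true/false splits).
    node_indicator = tree.decision_path(X_reference)
    path_lengths = np.asarray(node_indicator.sum(axis=1)).ravel()
    return float(path_lengths.mean())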

Making the Decision-Tree Loss Differentiable. Training decision trees is not differentiable, and thus Ω(W) as defined in Alg. 1 is not differentiable with respect to the network parameters W (unlike standard regularizers such as the L1 or L2 norm). While one could resort to derivative-free optimization techniques (Audet and Kokkolaras 2016), gradient descent has been an extremely fast and robust way of training networks (Goodfellow, Bengio, and Courville 2016).

A key technical contribution of our work is introducing and training a surrogate regularization function $\hat{\Omega}(W): \mathrm{supp}(W) \rightarrow \mathbb{R}^+$ that maps each candidate neural model parameter vector W to an estimate of the average path length. Our approximate function Ω̂ is implemented as a standalone multi-layer perceptron network and is thus differentiable. Let the vector ξ of size k denote the parameters of this chosen MLP approximator. We can train Ω̂ to be a good estimator by minimizing a squared-error loss function:

$$\min_\xi \; \sum_{j=1}^{J} \big( \Omega(W_j) - \hat{\Omega}(W_j, \xi) \big)^2 + \epsilon \|\xi\|_2^2 \qquad (5)$$

where the W_j are entire parameter vectors of our model, ε > 0 is a regularization strength, and we assume we have a dataset of J known parameter vectors and their associated true path lengths: $\{W_j, \Omega(W_j)\}_{j=1}^J$. This dataset can be assembled using the candidate W vectors obtained while training our target neural model ŷ(·, W), as well as by evaluating Ω(W) for randomly generated W. Importantly, one can train the surrogate function Ω̂ in parallel with our network. In the supplement, we show evidence that our surrogate predictor Ω̂(·) tracks the true average path length as we train the target predictor ŷ(·, W).
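A minimal sketch of the surrogate Ω̂(W, ξ) and the objective in Eq. (5), written with Autograd (which the paper uses for gradients). The one-hidden-layer architecture follows the description later in Section 4; the helper names are illustrative:

import autograd.numpy as np
from autograd import grad
from numpy.random import RandomState

def init_surrogate(n_weights, n_hidden=25, seed=0):
    rng = RandomState(seed)
    return {"W1": 0.01 * rng.randn(n_weights, n_hidden),
            "b1": np.zeros(n_hidden),
            "W2": 0.01 * rng.randn(n_hidden),
            "b2": 0.0}

def surrogate_predict(xi, W_batch):
    # Omega_hat(W, xi): map flattened model weight vectors to estimated path lengths.
    h = np.tanh(np.dot(W_batch, xi["W1"]) + xi["b1"])
    return np.dot(h, xi["W2"]) + xi["b2"]

def surrogate_loss(xi, W_batch, omega_true, eps=1e-4):
    # Squared-error objective of Eq. (5) with an L2 penalty on xi.
    resid = omega_true - surrogate_predict(xi, W_batch)
    l2 = sum(np.sum(p ** 2) for p in xi.values())
    return np.sum(resid ** 2) + eps * l2

surrogate_grad = grad(surrogate_loss)   # gradient w.r.t. xi, to feed into Adam updates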

Training the Surrogate Loss. Even moderately-sized GRUs can have parameter vectors W with thousands of dimensions. Our labeled dataset for surrogate training, $\{W_j, \Omega(W_j)\}_{j=1}^J$, will only gain one W_j example from each target network training iteration. Thus, in early iterations, we will have only a few examples from which to learn a good surrogate function Ω̂(W). We resolve this challenge


by augmenting our training set with additional examples: we randomly sample weight vectors W and calculate the true average path length Ω(W), and we also perform several random restarts on the unregularized GRU and use those weights in our training set.

A second challenge occurs later in training: as the model parameters W shift away from their initial values, those early parameters may not be as relevant in characterizing the current decision function of the GRU. To address this, for each epoch we use examples only from the past E epochs (in addition to augmentation), where E is chosen empirically in practice. Using examples from a fixed window of epochs also speeds up training. The supplement shows a comparison of the importance of these heuristics for efficient and accurate training; empirically, data augmentation for stabilizing surrogate training allows us to scale to GRUs with hundreds of nodes. GRUs of this size are sufficient for many real problems, such as those we encounter in healthcare domains.

Typically, we use J = 50 labeled pairs for surrogate training on toy datasets and J = 100 on real-world datasets. Optimization of our surrogate objective is done via gradient descent. We use Autograd to compute gradients of the loss in Eq. (5) with respect to ξ, then use Adam to compute descent directions, with step sizes set to 0.01 for toy datasets and 0.001 for real-world datasets.
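A sketch of the bookkeeping described above, i.e. keeping only weight vectors from a recent window of epochs and augmenting them with randomly sampled weights. The window size, sample count, and noise scale below are illustrative choices, not values from the paper:

from collections import deque
import numpy as np

recent_weights = deque(maxlen=50)   # flattened W_j collected over the past E epochs (E=50 here)

def build_surrogate_dataset(recent_weights, true_omega, n_random=25, scale=1.0, seed=0):
    rng = np.random.RandomState(seed)
    W_list = list(recent_weights)
    dim = W_list[0].size
    # Augmentation: add randomly sampled weight vectors, labeled with their true path lengths.
    W_list += [scale * rng.randn(dim) for _ in range(n_random)]
    omega = np.array([true_omega(W) for W in W_list])   # true Omega(W_j) via Alg. 1
    return np.vstack(W_list), omega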

4 Tree-Regularized MLPs: A Demonstration

While time-series models are the main focus of this work, we first demonstrate tree regularization on a simple binary classification task to build intuition. We call this task the 2D Parabola problem because, as Fig. 2(a) shows, the training data consists of 2D input points whose two-class decision boundary is roughly shaped like a parabola. The true decision function is defined by $y = 5(x - 0.5)^2 + 0.4$. We sampled 500 input points x_n uniformly within the unit square [0, 1] × [0, 1] and labeled those above the decision function as positive. To make it easy for models to overfit, we flipped 10% of the points in a region near the boundary. A random 30% were held out for testing.
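A sketch of this data-generation process; the width of the "near boundary" band and the exact flipping rule are not specified in the text, so the values below are assumptions:

import numpy as np

rng = np.random.RandomState(0)
n = 500
X = rng.uniform(0.0, 1.0, size=(n, 2))             # points in the unit square
boundary = 5.0 * (X[:, 0] - 0.5) ** 2 + 0.4        # true parabolic decision function
y = (X[:, 1] > boundary).astype(int)               # label points above the parabola positive

# Flip roughly 10% of labels near the boundary so an over-expressive MLP can overfit
# (the band width of 0.05 is an assumption).
near = np.abs(X[:, 1] - boundary) < 0.05
flip = near & (rng.rand(n) < 0.10)
y[flip] = 1 - y[flip]

# Hold out a random 30% for testing.
test_mask = rng.rand(n) < 0.30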

For the classifier ŷ, we train a 3-layer MLP with 100 first-layer nodes, 100 second-layer nodes, and 10 third-layer nodes. This MLP is intentionally overly expressive to encourage overfitting and expose the impact of different forms of regularization. We compare our proposed tree regularization, Ψ(W) = Ω(W), against two baselines: an L2 penalty on the weights, Ψ(W) = ||W||_2, and an L1 penalty on the weights, Ψ(W) = ||W||_1. For each regularization function, we train models at many different regularization strengths λ chosen to explore the full range of decision boundary complexities possible under each technique.

For our tree regularization, we model our surrogate Ω̂(W) with a 1-hidden-layer MLP with 25 units. We find this simple architecture works well, but certainly more complex MLPs could be used on more complex problems. The objective in Equation 1 was optimized via Adam gradient descent (Kingma and Ba 2014) using a batch size of 100 and a learning rate of 1e-3 for 250 epochs, and hyperparameters were set via cross-validation using grid search (see supplement for full experimental details).

Fig. 2(b) shows each trained model as a single point in a 2D fitness space: the x-axis measures model complexity via our average-path-length metric, and the y-axis measures AUC prediction performance. These results show that simple L1 or L2 regularization does not produce models with both small node count and good predictions at any value of the regularization strength λ. As expected, large λ values for L1 and L2 only produce far-too-simple linear decision boundaries with poor accuracies. In contrast, our proposed tree regularization directly optimizes the MLP to have simple tree-like boundaries at high λ values which can still yield good predictions.

The lower panes of Fig. 2 show these boundaries. Our tree regularization is uniquely able to create axis-aligned functions, because decision trees prefer functions that are axis-aligned splits. These axis-aligned functions require very few nodes but are more effective than their L1 and L2 counterparts. The L1 boundary is sharper, whereas the L2 boundary is rounder.

5 Tree-Regularized Time-Series Models

We now evaluate our tree-regularization approach on time-series models. We focus on GRU-RNN models, with some later experiments on new hybrid GRU-HMM models. As with the MLP, each regularization technique (tree, L2, L1) can be applied to the output node of the GRU across a range of strength parameters λ. Importantly, Algorithm 1 can compute the average decision-tree path length for any fixed deep model given its parameters, and can hence be used to measure decision-boundary complexity under any regularization, including L1 or L2. This means that when training any model, we can track both the predictive performance (as measured by area under the ROC curve (AUC); higher values mean better predictions) and the complexity of the decision tree required to explain each model (as measured by our average path length metric; lower values mean more interpretable models). We also show results for a baseline standalone decision tree classifier without any associated deep model, sweeping a range of parameters controlling leaf size to explore how this baseline trades off path length and prediction quality. Further details of our experimental protocol are in the supplement, as well as more extensive results with additional baselines.
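Both axes of the tradeoff plots can be computed for any trained model, whatever its regularizer. A small sketch follows; path_length_fn could be the average_path_length_cost sketch from Section 3, and the other names are illustrative:

from sklearn.metrics import roc_auc_score

def fitness_point(predict_proba_fn, path_length_fn, X_test, y_test):
    # One (complexity, AUC) point: prediction quality on the test set, plus the
    # average path length of a decision tree fit to the model's hard predictions.
    auc = roc_auc_score(y_test, predict_proba_fn(X_test))
    hard_fn = lambda X: (predict_proba_fn(X) > 0.5).astype(int)
    complexity = path_length_fn(hard_fn, X_test)
    return complexity, auc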

5.1 Tasks

Synthetic Task: Signal-and-noise HMM. We generated a toy dataset of N = 100 sequences, each with T = 50 timesteps. Each timestep has a data vector x_nt of 14 binary features and a single binary output label y_nt. The data comes from two separate HMM processes. First, a "signal" HMM generates the first 7 data dimensions from 5 well-separated states. Second, an independent "noise" HMM generates the remaining 7 data dimensions from a different set of 5 states. Each timestep's output label y_nt is produced by a rule involving both the signal data and the signal hidden state: the target is 1 at timestep t only if both the first signal state is active and the first observation is turned on. We deliberately designed the generation process so that neither logistic regression with x as features nor an RNN model that makes predictions from hidden states alone can perfectly separate this data.


Figure 2: 2D Parabola task: (a) Each training data point in 2D space, overlaid with the true parabolic class boundary. (b) Each method's prediction quality (AUC) and complexity (path length) metrics, across a range of regularization strengths λ. In the small-path-length regime between 0 and 5, tree regularization produces models with higher AUC than L1 or L2. (c)-(e) Decision boundaries (black lines) have qualitatively different shapes for different regularization schemes as regularization strength λ increases; panel (c) shows L1 (λ from 0.0005 to 0.0045), panel (d) L2 (λ from 0.05 to 2.0), and panel (e) tree regularization (λ from 0.01 to 15 000). We color predictions as true positive (red), true negative (yellow), false negative (green), and false positive (blue).


Real-World Tasks: We tested our approach on several real tasks: predicting medical outcomes of hospitalized septic patients, predicting HIV therapy outcomes, and identifying stop phonemes in English speech recordings. To normalize scales, we independently standardized features x via z-scoring.

• Sepsis Critical Care: We study time-series data for 11 786 septic ICU patients from the public MIMIC-III dataset (Johnson et al. 2016). We observe at each hour t a data vector x_nt of 35 vital signs and lab results as well as a label vector y_nt of 5 binary outcomes. Hourly data x_nt measures continuous features such as respiration rate (RR), blood oxygen levels (paO2), fluid levels, and more. Hourly binary labels y_nt include whether the patient died in hospital and whether mechanical ventilation was applied. Models are trained to predict all 5 output dimensions concurrently from one shared embedding. The average sequence length is 15 hours. 7 070 patients are used for training, 1 769 for validation, and 294 for testing.

• HIV Therapy Outcome (HIV): We use the EuResist Integrated Database (Zazzi et al. 2012) for 53 236 patients diagnosed with HIV. We consider 4-6 month intervals (corresponding to hospital visits) as time steps. Each data vector x_nt has 40 features, including blood counts, viral load measurements, and lab results. Each output vector y_nt has 15 binary labels, including whether a therapy was successful in reducing viral load to below detection limits, whether therapy caused CD4 blood cell counts to drop to dangerous levels (indicating AIDS), and whether the patient suffered adherence issues with medication. The average sequence length is 14 steps. 37 618 patients are used for training, 7 986 for testing, and 7 632 for validation.

• Phonetic Speech (TIMIT): We have recordings of 630 speakers of eight major dialects of American English reading ten phonetically rich sentences (Garofolo et al. 1993). Each sentence contains time-aligned transcriptions of 60 phonemes. We focus on distinguishing stop phonemes (those that stop the flow of air, such as "b" or "g") from non-stops. Each timestep has one binary label y_nt indicating whether a stop phoneme occurs or not. Each input x_nt has 26 continuous features: the acoustic signal's Mel-frequency cepstral coefficients and derivatives. There are 6 303 sequences, split into 3 697 for training, 925 for validation, and 1 681 for testing. The average sequence length is 614.

5.2 Results

The major conclusions of our experiments comparing GRUs with various regularizations are outlined below.

Tree-regularized models have fewer nodes than other forms of regularization. Across tasks, we see that in the target regime of small decision trees (low average path lengths), our proposed tree regularization achieves higher prediction quality (higher AUCs). In the signal-and-noise HMM task, tree regularization (green line in Fig. 3(d)) achieves AUC values near 0.9 when its trees have an average path length of 10. Similar models with L1 or L2 regularization reach this AUC only with trees that are nearly double in complexity (path length over 25). On the Sepsis task (Fig. 4), we see AUC gains of 0.05-0.1 at path lengths of 2-10. On the TIMIT task (Fig. 5a), we see AUC gains of 0.05-0.1 at path lengths of 20-30. Finally, on the HIV CD4 blood cell count task in Fig. 5b, we see AUC differences of between 0.03 and 0.15 for path lengths of 10-15.


Figure 3: Toy Signal-and-Noise HMM Task: (a)-(c) Decision trees trained to mimic predictions of GRU models with 25 hidden states at different regularization strengths λ (1, 800, and 1 000); as expected, increasing λ decreases the size of the learned trees (see supplement for more trees). Decision tree (c) suggests the model learns to predict positive output (blue) if and only if "x[0] == 1 and x[3] == 1 and x[4] == 0", which is consistent with the true rule we used to generate labels: assign a positive label only if the first dimension is on (x[0] == 1) and the first state is active (emission probabilities for this state: [.5 .5 .5 .5 0 ...]). (d) Tree-regularized GRU models reach a sweet spot of small path lengths yet high AUC predictions that alternatives cannot reach at any tested value of λ.

Figure 4: Sepsis task: Study of different regularization techniques for a GRU model with 100 states, trained to jointly predict 5 binary outcomes. Panels (a) and (c) show AUC vs. average path length for 2 of the 5 outcomes, in-hospital mortality and mechanical ventilation (remainder in the supplement); in both cases, tree regularization provides higher accuracy in the target regime of low-complexity decision trees. Panels (b) and (d) show the associated decision trees for λ = 2 000; these were found to be clinically interpretable by an ICU clinician (see main text).

Figure 5: TIMIT and HIV tasks: Study of different regularization techniques for a GRU model with 75 states. Panels (a)-(c) are tradeoff curves showing how AUC predictive power and decision-tree complexity evolve with increasing regularization strength under L1, L2, or tree regularization on the TIMIT stop-phoneme task, the HIV CD4+ ≤ 200 cells/ml task, and the HIV therapy adherence task. The GRU is trained to jointly predict 15 binary outcomes for HIV, of which 2 are shown here in panels (b)-(c). The GRU's decision-tree proxy for HIV adherence is shown in (d).


The HIV adherence task in Fig. 5d has AUC gains of between 0.03 and 0.05 in the path-length range of 19 to 25, while at smaller path lengths all methods are quite poor, indicating the problem's difficulty. Overall, these AUC gains are particularly useful in determining how to administer subsequent HIV therapies.

We emphasize that our tree regularization usually achieves a sweet spot of high AUCs at short path lengths not possible with standalone decision trees (orange lines), L1-regularized deep models (red lines), or L2-regularized deep models (blue lines). In experiments not shown here, we also tested elastic net regularization (Zou and Hastie 2005), a linear combination of L1 and L2 penalties. We found elastic nets to follow the same trend lines as L1 and L2, with no visible differences. In domains where human-simulatability is required, increases in prediction accuracy in the small-complexity regime can mean the difference between models that provide value on a task and models that are unusable, either because performance is too poor or predictions are uninterpretable.

Our learned decision-tree proxies are interpretable. Across all tasks, the decision trees which mimic the predictions of tree-regularized deep models are small enough to simulate by hand (path length ≤ 25) and help users grasp the model's nonlinear prediction logic. Intuitively, the trees for our synthetic task in Fig. 3(a)-(c) decrease in size as the strength λ increases. The logic of these trees also matches the true labeling process: even the simplest tree (c) checks a relevant subset of input dimensions necessary to verify that both the first state and the first output dimension are active.

In Fig. 4, we show decision-tree proxies for our deep models on two sepsis prediction tasks: mortality and need for ventilation. We consulted a clinical expert on sepsis treatment, who noted that the trees helped him understand what the models might be doing and thus determine if he would trust the deep model. For example, he said that using FiO2, RR, CO2, and paO2 to predict the need for mechanical ventilation (Fig. 4d) was sensible, as these all measure breathing quality. In contrast, the in-hospital mortality tree (Fig. 4b) predicts that some young patients with no organ failure have high mortality rates while other young patients with organ failure have low mortality. These counter-intuitive results led to hypotheses about how uncaptured variables impact the training process. Such reasoning would not be possible from simple sensitivity analyses of the deep model.

Finally, we have verified that the decision-tree proxies of our tree-regularized deep models on the HIV task in Fig. 5d are interpretable for understanding why a patient has trouble adhering to a prescription, that is, taking drugs regularly as directed. Our clinical collaborators confirm that the baseline viral load and the number of prior treatment lines, which are prominent attributes for the decisions in Fig. 5d, are useful predictors of a patient with adherence issues. Several medical studies (Langford, Ananworanich, and Cooper 2007; Socías et al. 2011) suggest that patients with higher baseline viral loads tend to have faster disease progression, and hence have to take several drug cocktails to combat resistance. Juggling many drugs typically makes it difficult for these patients to adhere as directed. We hope interpretable predictive models for adherence could help assess a patient's overall prognosis (Paterson et al. 2000) and offer opportunities for intervention (e.g. with alternative single-tablet regimens).

Decision trees trained to mimic deep models make faithful predictions. Across datasets, we find that each tree-regularized deep time-series model has predictions that agree with its corresponding decision-tree proxy on about 85-90% of test examples. Table 1 shows exact fidelity scores for each dataset. Thus, the simulatable paths of the decision tree will be trustworthy in the majority of cases.
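Fidelity, as reported in Table 1, is straightforward to compute. A minimal sketch (thresholding both models at 0.5 is an assumption):

import numpy as np

def fidelity(deep_probs, tree_probs, threshold=0.5):
    # Fraction of test examples where the decision-tree proxy and the deep model
    # make the same thresholded binary prediction (Craven and Shavlik 1996).
    return float(np.mean((deep_probs > threshold) == (tree_probs > threshold)))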

Practical runtimes for tree regularization are less than twice those of simpler L2 regularization. While our tree-regularized GRU with 10 states takes 3977 seconds per epoch on TIMIT, a similar L2-regularized GRU takes 2116 seconds per epoch. Thus, our new method costs less than twice the baseline even when the surrogate is computed serially. Because the surrogate Ω̂(W) will in general be a much smaller model than the predictor ŷ(x, W), we expect one could get faster per-epoch times by parallelizing the creation of (W, Ω(W)) training pairs and the training of the surrogate Ω̂(W). Additionally, 3977 seconds includes the time needed to train the surrogate. In practice, we do this sparingly, only once every 25 epochs, yielding an amortized per-epoch cost of 2191 seconds (more runtime results are in the supplement).

Decision trees are stable over multiple optimization runs. When tree regularization is strong (high λ), the decision trees trained to match the predictions of deep models are stable. For both the signal-and-noise and sepsis tasks, multiple runs from different random restarts have nearly identical tree shape and size, perhaps differing by a few nodes. This stability is crucial to building trust in our method. On the signal-and-noise task (λ = 7000), 7 of 10 independent runs with random initializations resulted in trees of exactly the same structure, and the others closely resembled them, sharing the same subtrees and features (more details in the supplement).

The deep residual GRU-HMM achieves high AUC with less complexity. So far, we have focused on regularizing standard deep models, such as MLPs or GRUs. Another option is to use a deep model as a residual on another model that is already interpretable: for example, discrete HMMs partition timesteps into clusters, each of which can be inspected, but their predictions might have limited accuracy. In Fig. 6, we show the performance of jointly training a GRU-HMM, a new model which combines an HMM with a tree-regularized GRU to improve its predictions (details and further results in the supplement). Here, the ideal path length is zero, indicating that only the HMM makes predictions. For small average path lengths, the GRU-HMM improves the original HMM's predictions and has simulatability gains over the earlier GRUs. On the mechanical ventilation task, the GRU-HMM requires an average path length of only 28 to reach an AUC of 0.88, while the GRU alone with the same number of states requires a path length of 60 to reach the same AUC. This suggests that jointly-trained deep residual models may provide even better interpretability.


Dataset                           Fidelity
signal-and-noise HMM              0.88
SEPSIS (In-Hospital Mortality)    0.81
SEPSIS (90-Day Mortality)         0.88
SEPSIS (Mech. Vent.)              0.90
SEPSIS (Median Vaso.)             0.92
SEPSIS (Max Vaso.)                0.93
HIV (CD4+ below 200)              0.84
HIV (Therapy Success)             0.88
HIV (Mortality)                   0.93
HIV (Poor Adherence)              0.90
HIV (AIDS Onset)                  0.93
TIMIT                             0.85

Table 1: Fidelity of predictions from our trained deep GRU-RNN and its corresponding decision tree. Fidelity is defined as the percentage of test examples on which the prediction made by the tree agrees with the deep model (Craven and Shavlik 1996). We used 20 hidden GRU states for the signal-and-noise task and 50 states for all others.


Figure 6: Fitness curves for the GRU-HMM, showing prediction quality (AUC) vs. complexity (path length) across a range of regularization strengths λ. Panels: (a) Signal-and-noise 20+5, (b) In-Hospital Mortality 50+50, (c) Mechanical Ventilation 50+50, (d) Stop Phonemes 50+25, where captions show the number of HMM states plus the number of GRU states. See earlier figures to compare these GRU-HMM numbers to the simpler GRU and decision tree baselines.

6 Discussion and Conclusion

We have introduced a novel tree-regularization technique that encourages the complex decision boundaries of any differentiable model to be well-approximated by human-simulatable functions, allowing domain experts to quickly understand and approximately compute what the more complex model is doing. Overall, our training procedure is robust and efficient; future work could continue to explore and increase the stability of the learned models as well as identify ways to apply our approach to situations in which the inputs are not inherently interpretable (e.g. pixels in an image).

Across three complex, real-world domains (HIV treatment, sepsis treatment, and human speech processing), our tree-regularized models provide gains in prediction accuracy in the regime of simpler, approximately human-simulatable models. Future work could apply tree regularization to local, example-specific approximations of a loss (Ribeiro, Singh, and Guestrin 2016) or to representation learning tasks (encouraging embeddings with simple boundaries). More broadly, our general training procedure could apply tree regularization or other procedure regularization to a wide class of popular models, helping us move beyond sparsity toward models humans can easily simulate and thus trust.

Acknowledgements

MW is supported by the U.S. National Science Foundation. MCH is supported by Oracle Labs. SP is supported by the Swiss National Science Foundation project 51MRP0 158328. The authors thank the EuResist Network for providing HIV data for this study, and thank Matthieu Komorowski for the preprocessed sepsis data (Raghu et al. 2017). Computations were supported by the FAS Research Computing Group at Harvard and the sciCORE (http://scicore.unibas.ch/) scientific computing core facility at the University of Basel.

References

Adler, P.; Falk, C.; Friedler, S. A.; Rybeck, G.; Scheidegger, C.; Smith, B.; and Venkatasubramanian, S. 2016. Auditing black-box models for indirect influence. In ICDM.
Audet, C., and Kokkolaras, M. 2016. Blackbox and Derivative-Free Optimization: Theory, Algorithms and Applications. Springer.
Bahdanau, D.; Cho, K.; and Bengio, Y. 2014. Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473.
Balan, A. K.; Rathod, V.; Murphy, K. P.; and Welling, M. 2015. Bayesian dark knowledge. In NIPS.
Che, Z.; Kale, D.; Li, W.; Bahadori, M. T.; and Liu, Y. 2015. Deep computational phenotyping. In KDD.
Chen, J. H., and Asch, S. M. 2017. Machine learning and prediction in medicine: beyond the peak of inflated expectations. N Engl J Med 376(26):2507–2509.
Cho, K.; van Merrienboer, B.; Gulcehre, C.; Bahdanau, D.; Bougares, F.; Schwenk, H.; and Bengio, Y. 2014. Learning phrase representations using RNN encoder–decoder for statistical machine translation. In EMNLP.
Choi, E.; Bahadori, M. T.; Schuetz, A.; Stewart, W. F.; and Sun, J. 2016. Doctor AI: Predicting clinical events via recurrent neural networks. In Machine Learning for Healthcare Conference.
Craven, M., and Shavlik, J. W. 1996. Extracting tree-structured representations of trained networks. In NIPS.
Drucker, H., and Le Cun, Y. 1992. Improving generalization performance using double backpropagation. IEEE Transactions on Neural Networks 3(6):991–997.


Erhan, D.; Bengio, Y.; Courville, A.; and Vincent, P. 2009. Visualizing higher-layer features of a deep network. Technical Report 1341, Department of Computer Science and Operations Research, University of Montreal.
Garofolo, J. S.; Lamel, L. F.; Fisher, W. M.; Fiscus, J. G.; Pallett, D. S.; Dahlgren, N. L.; and Zue, V. 1993. TIMIT acoustic-phonetic continuous speech corpus. Linguistic Data Consortium 10(5).
Goodfellow, I.; Bengio, Y.; and Courville, A. 2016. Deep Learning. MIT Press. http://www.deeplearningbook.org.
Han, S.; Pool, J.; Tran, J.; and Dally, W. 2015. Learning both weights and connections for efficient neural network. In NIPS.
Hinton, G.; Vinyals, O.; and Dean, J. 2015. Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531.
Hochreiter, S., and Schmidhuber, J. 1997. Long short-term memory. Neural Computation 9(8):1735–1780.
Hu, Z.; Ma, X.; Liu, Z.; Hovy, E.; and Xing, E. 2016. Harnessing deep neural networks with logic rules. In ACL.
Johnson, A. E.; Pollard, T. J.; Shen, L.; Lehman, L. H.; Feng, M.; Ghassemi, M.; Moody, B.; Szolovits, P.; Celi, L. A.; and Mark, R. G. 2016. MIMIC-III, a freely accessible critical care database. Scientific Data 3.
Kingma, D., and Ba, J. 2014. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980.
Krizhevsky, A.; Sutskever, I.; and Hinton, G. E. 2012. ImageNet classification with deep convolutional neural networks. In NIPS.
Lakkaraju, H.; Bach, S. H.; and Leskovec, J. 2016. Interpretable decision sets: A joint framework for description and prediction. In KDD.
Langford, S. E.; Ananworanich, J.; and Cooper, D. A. 2007. Predictors of disease progression in HIV infection: a review. AIDS Research and Therapy 4(1):11.
Lei, T.; Barzilay, R.; and Jaakkola, T. 2016. Rationalizing neural predictions. arXiv preprint arXiv:1606.04155.
Lipton, Z. C. 2016. The mythos of model interpretability. In ICML Workshop on Human Interpretability in Machine Learning.
Lundberg, S., and Lee, S.-I. 2016. An unexpected unity among methods for interpreting model predictions. arXiv preprint arXiv:1611.07478.
Miotto, R.; Li, L.; Kidd, B. A.; and Dudley, J. T. 2016. Deep Patient: An unsupervised representation to predict the future of patients from the electronic health records. Scientific Reports 6:26094.
Ochiai, T.; Matsuda, S.; Watanabe, H.; and Katagiri, S. 2017. Automatic node selection for deep neural networks using group lasso regularization. In ICASSP.
Paterson, D. L.; Swindells, S.; Mohr, J.; Brester, M.; Vergis, E. N.; Squier, C.; Wagener, M. M.; and Singh, N. 2000. Adherence to protease inhibitor therapy and outcomes in patients with HIV infection. Annals of Internal Medicine 133(1):21–30.
Pedregosa, F.; Varoquaux, G.; Gramfort, A.; Michel, V.; et al. 2011. Scikit-learn: Machine learning in Python. Journal of Machine Learning Research 12:2825–2830.
Raghu, A.; Komorowski, M.; Celi, L. A.; Szolovits, P.; and Ghassemi, M. 2017. Continuous state-space models for optimal sepsis treatment: a deep reinforcement learning approach. In Machine Learning for Healthcare Conference.
Rastegari, M.; Ordonez, V.; Redmon, J.; and Farhadi, A. 2016. XNOR-Net: ImageNet classification using binary convolutional neural networks. In ECCV.
Ribeiro, M. T.; Singh, S.; and Guestrin, C. 2016. "Why should I trust you?": Explaining the predictions of any classifier. In KDD.
Ross, A.; Hughes, M. C.; and Doshi-Velez, F. 2017. Right for the right reasons: Training differentiable models by constraining their explanations. In IJCAI.
Selvaraju, R. R.; Das, A.; Vedantam, R.; Cogswell, M.; Parikh, D.; and Batra, D. 2016. Grad-CAM: Why did you say that? Visual explanations from deep networks via gradient-based localization. arXiv preprint arXiv:1610.02391.
Singh, S.; Ribeiro, M. T.; and Guestrin, C. 2016. Programs as black-box explanations. arXiv preprint arXiv:1611.07579.
Socías, M. E.; Sued, O.; Laufer, N.; Lázaro, M. E.; Mingrone, H.; Pryluka, D.; Remondegui, C.; Figueroa, M. I.; Cesar, C.; Gun, A.; Turk, G.; Bouzas, M. B.; Kavasery, R.; Krolewiecki, A.; Pérez, H.; Salomón, H.; Cahn, P.; and the Grupo Argentino de Seroconversión Study Group. 2011. Acute retroviral syndrome and high baseline viral load are predictors of rapid HIV progression among untreated Argentinean seroconverters. Journal of the International AIDS Society 14(1):40.
Sutskever, I.; Vinyals, O.; and Le, Q. V. 2014. Sequence to sequence learning with neural networks. In NIPS.
Tang, W.; Hua, G.; and Wang, L. 2017. How to train a compact binary neural network with high accuracy? In AAAI.
Zazzi, M.; Incardona, F.; Rosen-Zvi, M.; Prosperi, M.; Lengauer, T.; Altmann, A.; Sonnerborg, A.; Lavee, T.; Schulter, E.; and Kaiser, R. 2012. Predicting response to antiretroviral treatment by machine learning: the EuResist project. Intervirology 55(2):123–127.
Zhang, Y.; Lee, J. D.; and Jordan, M. I. 2016. L1-regularized neural networks are improperly learnable in polynomial time. In ICML.
Zou, H., and Hastie, T. 2005. Regularization and variable selection via the elastic net. Journal of the Royal Statistical Society: Series B (Statistical Methodology) 67(2):301–320.


Supplementary Material

A Details for Decision-Tree TrainingTraining decision trees with post-pruning. Our average path length function Ω(W ) for determining the complexity of a deepmodel with parameters W – defined in the main paper in Alg. 1 – assumes that we have a robust, black-box way to train binarydecision-trees called TRAINTREE given a labeled dataset xn, yn. For this we use the DecisionTree module distributed inPython’s sci-kit learn, which optimizes information gain with Gini impurity. The specific syntax we use (for reproducibility) is:tree = DecisionTree(min_sample_count=5)tree.fit(x_train, y_train)tree = prune_tree(tree, x_valid, y_valid)

The provided keyword option forces the tree to have at least 5 examples from the training set in every leaf. We found that tuning hyperparameters of the TRAINTREE subprocedure, such as the minimum size of a leaf node, was important for producing useful trees.

Generally, the runtime cost of sklearn's fitting procedure scales superlinearly with the number of examples N and linearly with the number of features F, for a total complexity of O(FN log N). In practice, we found that with N = 1000 examples and F = 10 features, tree construction takes 15.3 microseconds.

The pruning procedure is a heuristic to create simpler trees, summarized in Algorithm 2. After TRAINTREE delivers a working decision tree, we iteratively propose removing each remaining leaf node, accepting the proposal if the squared prediction error on a validation set improves. This pruning removes subtrees that do not generalize to unseen data.

Algorithm 2 Post-pruning for training decision trees.

Require: T: initial decision tree; ERRONVAL(·): squared error on validation data,
ERRONVAL(T) := \sum_{n=1}^{N} (T(x_n) − y_n)^2

1: procedure PRUNETREE(T, err)
2:   e ← ERRONVAL(T)
3:   for node n ∈ SORTLEAFTOROOT(T.nodes) do
4:     T′ ← REMOVENODE(T, n)
5:     e_new ← ERRONVAL(T′)
6:     if e_new < e then T ← T′
7:   return T
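
The following Python sketch mirrors Algorithm 2. It is illustrative only: sort_leaf_to_root and collapse_to_leaf are hypothetical helpers for traversing and editing the fitted tree, and the error threshold is taken from the unpruned tree, exactly as in the pseudocode above.

import numpy as np
from copy import deepcopy

def err_on_val(tree, x_valid, y_valid):
    # Squared prediction error on the validation set.
    return float(np.sum((tree.predict(x_valid) - y_valid) ** 2))

def prune_tree(tree, x_valid, y_valid):
    # Greedily propose removing each node, visiting leaves before roots,
    # and keep a proposal whenever it beats the unpruned tree's error.
    e = err_on_val(tree, x_valid, y_valid)
    for node in sort_leaf_to_root(tree):        # hypothetical traversal helper
        candidate = deepcopy(tree)
        collapse_to_leaf(candidate, node)       # hypothetical: drop the subtree rooted at `node`
        if err_on_val(candidate, x_valid, y_valid) < e:
            tree = candidate
    return tree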

Sanity check: Surrogate path length closely follows true path length. Fig. A.1 shows that our surrogate predictor Ω(·) tracks the true average path length as we train the target predictor y(·, W) on several different datasets.

Sensitivity to different choices for surrogate training. In Fig. A.2, we show sample learning curves for variations of methods for approximating the average path length (also called "node count") of a decision tree. In blue is the true value. Each of the other three lines uses the same surrogate model: an MLP with 25 hidden nodes. Increasing its capacity too much (e.g., 100 hidden nodes) leads to overfitting, where the surrogate predicts the average path length extremely well for a small number of iterations but then quickly decays. With an MLP of the right capacity, four additional tricks greatly improve the accuracy of the average-path-length predictions: (1) weight augmentation, (2) random restarts with an unregularized model, (3) a fixed window of data, and (4) surrogate retraining.

Normally, if our differentiable model is a GRU, we compile examples using the GRU weights at every batch and calculate the true average path length. This dataset is used to train the surrogate model. If examples are very sparse, surrogate predictions may be unstable. Augmentation addresses this by randomly sampling weight vectors and computing their average path lengths, artificially creating a larger dataset. Early epochs are especially short on data. In addition to augmentation, we use random restarts to separately train unregularized GRUs (each with a different weight initialization) and grow a dataset of weight vectors prior to training the regularized model.

As the GRU parameters take steps away from their initial values, examples from those early epochs no longer describe the current state of the model. Retraining and a fixed window of data address this by re-learning the surrogate function at a fixed frequency, using examples only from the last J epochs. In practice, the augmentation size, the retraining frequency, and J are functions of the learning rate and the dataset size. See Table B.1 for exact values.
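
A minimal sketch of this bookkeeping, assuming a helper true_average_path_length(W) that trains a decision tree on the model's current predictions and measures its average path length. The MLPRegressor here is only a stand-in for the surrogate (the real surrogate must be differentiable in W so its gradient can serve as a regularizer); the noise scale and augmentation count are illustrative.

import numpy as np
from sklearn.neural_network import MLPRegressor

surrogate = MLPRegressor(hidden_layer_sizes=(25,))   # 25-unit MLP, as in the text
examples = []                                        # (epoch, weight vector, true path length)

def record(W, epoch, n_augment=10, noise_scale=0.1):
    # Store the true example plus a few randomly perturbed copies of W
    # (weight augmentation) to stabilize the surrogate in early epochs.
    for w in [W] + [W + noise_scale * np.random.randn(*W.shape) for _ in range(n_augment)]:
        examples.append((epoch, w, true_average_path_length(w)))   # hypothetical helper

def maybe_retrain(epoch, retrain_freq=25, J=100):
    # Periodically re-fit the surrogate on a fixed window of recent examples.
    if epoch % retrain_freq == 0 and examples:
        recent = [(w, p) for (e, w, p) in examples if e >= epoch - J]
        X = np.stack([w.ravel() for w, _ in recent])
        y = np.array([p for _, p in recent])
        surrogate.fit(X, y)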


(a) Path length estimates Ω for 2D Parabola task

(b) Path length estimates Ω for Signal-and-noise HMM task

Figure A.1: True average path lengths (yellow) and surrogate estimates Ω (green) across many iterations of network parameter training.


Figure A.2: This figure shows the effects of weight augmentation and retraining. The blue line is the true average path length of the decision tree at each epoch. All other lines show predicted path lengths using the surrogate MLP. By randomly sampling weights and intermittently retraining the surrogate, we significantly improve the ability of the surrogate model to track the changes in the ground truth.


B Experimental Protocol

See Table B.1 for model hyperparameters for each dataset. For standard recurrent models such as the HMM or GRU, the decision trees were trained on the input data and the predictions of the model's output node. For our deep residual GRU-HMM, the decision trees were trained on the predictions of the GRU's output node only. For both synthetic and real-world datasets, our surrogate to the tree loss is a multilayer perceptron with 1 hidden layer of 25 nodes. For each dataset, when we investigated several regularization strengths (λ), we initialized the model weights using the same random seed. We use the Adam algorithm (Kingma and Ba 2014) for all optimization.
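
As a concrete illustration of this protocol, the ground-truth average path length for a given model can be measured with scikit-learn roughly as follows. This is a sketch: thresholding the soft outputs at 0.5 and skipping the post-pruning step of Algorithm 2 are simplifications, and the minimum leaf size should match Table B.1 for each dataset.

import numpy as np
from sklearn.tree import DecisionTreeClassifier

def average_path_length(x_train, model_probs, min_samples_leaf=25):
    # Fit a tree that mimics the deep model's output-node predictions,
    # then count how many nodes each training example visits on average.
    y_mimic = (model_probs >= 0.5).astype(int)
    tree = DecisionTreeClassifier(min_samples_leaf=min_samples_leaf)
    tree.fit(x_train, y_mimic)
    node_indicator = tree.decision_path(x_train)     # sparse (n_samples, n_nodes)
    return float(node_indicator.sum(axis=1).mean())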

Dataset | Total Num. Sequences | Avg. Seq. Length | Learning Rate | Batch Size | Minimum Leaf Sample | Post-pruned | Epochs (Model) | Epochs (Surrogate) | Retraining Freq. | J
parabola | n/a | n/a | 1e-2 | 32 | 0 | N | 250 | 500 | 100 | n/a
signal-and-noise HMM | 100 | 50 | 1e-2 | 10 | 25 | Y | 300 | 1000 | 50 | 50
HIV | 53 236 | 14 | 1e-3 | 256 | 1 000 | Y | 300 | 5000 | 25 | 100
SEPSIS | 11 786 | 15 | 1e-3 | 256 | 1 000 | Y | 300 | 5000 | 25 | 100
TIMIT | 6 303 | 614 | 1e-3 | 256 | 5 000 | Y | 200 | 5000 | 25 | 100

Table B.1: Dataset summaries and training parameters used in our experiments.

B.1 2D Parabola

Dataset generation. The training data consists of 2D input points whose two-class decision boundary is roughly shaped like a parabola. The true decision function is defined by y = 5(x − 0.5)^2 + 0.4. We sampled all 200 input points x_n uniformly within the unit square [0, 1] × [0, 1] and labeled those above the decision function as positive. To add randomness, we flipped 10% of the points in the region near the boundary, between y = 5(x − 0.5)^2 + 0.2 and y = 5(x − 0.5)^2 + 0.6.
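
A sketch of this generation process; the exact scheme for the 10% label noise in the band around the boundary is our interpretation of the description above.

import numpy as np

def make_parabola_dataset(n=200, seed=0):
    rng = np.random.RandomState(seed)
    x = rng.uniform(0.0, 1.0, size=(n, 2))                # points in the unit square
    boundary = 5.0 * (x[:, 0] - 0.5) ** 2 + 0.4
    y = (x[:, 1] > boundary).astype(int)                  # positive above the parabola
    # Flip roughly 10% of labels for points lying in the band between
    # the +0.2 and +0.6 offset parabolas.
    in_band = np.abs(x[:, 1] - boundary) <= 0.2
    flip = in_band & (rng.rand(n) < 0.1)
    y[flip] = 1 - y[flip]
    return x, y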

Regularization strengths. Tested values of regularization strength parameter λ: 0.1, 0.5, 1, 5, 10, 25, 50, 75, 100, 250, 500,750, 1 000, 2 500, 5 000, 7 500, 10 000, 25 000, 50 000, 75 000, 100 000

B.2 Signal-and-noise HMM

Dataset generation. The transition and emission matrices describing the generative process used to create the signal-and-noise HMM are shown in Fig. B.1. The output y_n at every timestep is created by concatenating a one-hot vector of an emitted state and the 7-dimensional binary input vector. We emphasize that to output 1, the HMM must be in state 1 and the first input feature must be 1.

(a) Signal HMM emission probabilities (5 states × 7 features):
.5 .5 .5 .5  0  0  0
.5 .5 .5 .5 .5  0  0
.5 .5 .5  0 .5  0  0
.5 .5 .5  0  0 .5  0
.5 .5 .5  0  0  0 .5

(b) Signal HMM transition probabilities (5 × 5):
.7 .3   0   0   0
.5 .25 .25  0   0
 0 .25 .5  .25  0
 0  0  .25 .25 .5
 0  0   0  .5  .5

(c) Noise HMM emission probabilities (5 states × 7 features):
.5 .5 .5  0  0  0  0
 0 .5 .5 .5  0  0  0
 0  0 .5 .5 .5  0  0
 0  0  0 .5 .5 .5  0
 0  0  0  0 .5 .5 .5

(d) Noise HMM transition probabilities (5 × 5): all entries equal to .2.

Figure B.1: Emission (5 states vs 7 features) and transition probabilities for the signal HMM (a, b) and noise HMM (c, d).

Training Details. With synthetic datasets, we explore (1, 5, 6, 10, 15, 20) GRU nodes, (5, 6, 20) HMM states, and GRU-HMMs with 5 HMM states and (1, 5, 10, 15) GRU nodes.

B.3 Sepsis

Training Details. We explore (1, 5, 6, 10, 11, 15, 20, 25, 26, 30, 35, 50, 51, 55, 60, 75, 100) GRU nodes, (5, 6, 10, 11, 15, 20, 25, 26, 30, 35, 50, 51, 55, 60, 75, 100) HMM states, and GRU-HMMs with (5, 10, 25, 50) HMM states and (1, 5, 10, 25, 50) GRU nodes. The input features are z-scored prior to training.

B.4 HIV

Training Details. We explore (1, 5, 6, 10, 11, 15, 20, 25, 26, 30, 35, 50, 51, 55, 60, 75) GRU nodes, (5, 6, 10, 11, 15, 20, 25, 26, 30, 35, 50, 51, 55, 60, 75) HMM states, and GRU-HMMs with (5, 10, 25) HMM states and (1, 5, 10, 25, 50) GRU nodes.


B.5 TIMIT

Training Details. We explore (1, 5, 6, 10, 11, 15, 20, 25, 26, 30, 35, 50, 51, 55, 60, 75) GRU nodes, (5, 6, 10, 11, 15, 20, 25, 26, 30, 35, 50, 51, 55, 60, 75) HMM states, and GRU-HMMs with (5, 10, 25) HMM states and (1, 5, 10, 25, 50) GRU nodes. Like Sepsis, the input features are z-scored prior to training.

C Extended Results

For the signal-and-noise HMM, Sepsis, and TIMIT, we first show expanded versions of the fitness trace plots and the tree visualizations. For Sepsis and HIV, we show the additional output dimensions not included in the paper.

We also include tables of the test AUC performance for our synthetic and real-world datasets over a wide range of parameter settings (GRU node counts, HMM state counts, regularization strengths). Consistent with the common wisdom of training deep models, we found that larger models, with regularization, tended to perform the best.

C.1 Signal-and-noise HMM: Plots

(a) GRU: Signal-and-noise HMM   (b) GRU-HMM: Signal-and-noise HMM

Figure C.1: Performance and complexity trade-offs using L1, L2, and Tree regularization on (a) GRU and (b) GRU-HMM performance on the signal-and-noise HMM dataset. Note the differences in scale.


C.2 Signal-and-noise HMM: Tree Visualization

(a) GRU:0.1   (b) GRU:0.1   (c) GRU:1.0   (d) GRU:10   (e) GRU:20   (f) GRU:100   (g) GRU:400   (h) GRU:800   (i) GRU:1 000   (j) GRU:10 000

Figure C.2: Decision trees trained under varying tree regularization strengths for GRU models on the signal-and-noise HMM dataset. As the tree regularization strength increases, the tree collapses to a single node. If we focus on (h), we see that the tree resembles the ground-truth data-generating function quite closely.


C.3 Signal-and-noise HMM: AUCs

Model | AUC (Test) | Average Path Length | Parameter Count
logreg | 0.91832 | 17.302 | 6
decision tree | 0.92050 | 29.4424 | -
hmm (5) | 0.93591 | 25.5736 | 71
hmm (20) | 0.94177 | 27.2784 | 581
gru (1) | 0.65049 | 1.8876 | 29
gru (5) | 0.94812 | 26.304 | 205
gru (6) | 0.94883 | 27.2118 | 264
gru (10) | 0.94962 | 28.563 | 560
gru (15) | 0.93982 | 30.7172 | 1 065
gru (20) | 0.93368 | 37.0844 | 1 720
grutree (20/10.0) | 0.94226 | 28.1850 | 1 720
grutree (20/200.0) | 0.94806 | 26.8140 | 1 720
grutree (20/7 000.0) | 0.94431 | 22.4646 | 1 720
grutree (20/9 000.0) | 0.90555 | 9.1127 | 1 720
grutree (20/10 000.0) | 0.82770 | 3.4400 | 1 720
gruhmm (5/1) | 0.95146 | 18.2202 | 100
gruhmm (5/5) | 0.95584 | 27.258 | 276
gruhmm (5/10) | 0.95773 | 30.9624 | 631
gruhmm (5/15) | 0.94857 | 36.7188 | 1 136
gruhmmtree (5/15/1.0) | 0.95382 | 24.115 | 1 136
gruhmmtree (5/15/10.0) | 0.95180 | 16.883 | 1 136
gruhmmtree (5/15/50.0) | 0.95258 | 12.573 | 1 136
gruhmmtree (5/15/200.0) | 0.95145 | 8.926 | 1 136
gruhmmtree (5/15/500.0) | 0.95769 | 5.231 | 1 136
gruhmmtree (5/15/900.0) | 0.95708 | 3.942 | 1 136
gruhmmtree (5/15/2 000.0) | 0.95648 | 2.694 | 1 136
gruhmmtree (5/15/5 000.0) | 0.95399 | 1.896 | 1 136
gruhmmtree (5/15/7 000.0) | 0.93591 | 0.000 | 1 136

Table C.1: Performance metrics across models on the signal-and-noise HMM dataset. The parameter count is included as a measure of the model capacity.


C.4 Sepsis: Plots

(a) In-Hospital Mortality   (b) 90-Day Mortality   (c) Mechanical Ventilation   (d) Median Vasopressor   (e) Max Vasopressor

Figure C.3: Performance and complexity trade-offs using L1, L2, and Tree regularization on GRU performance on the Sepsis dataset.

C.5 Sepsis: Tree Visualization

(a) In-Hospital Mortality   (b) 90-Day Mortality   (c) Mechanical Ventilation   (d) Median Vasopressor   (e) Max Vasopressor

Figure C.4: Decision trees trained using λ = 800.0 for a GRU model on the Sepsis dataset. The 5 output dimensions are jointly trained.


C.6 Sepsis: AUCs

Model | In-Hospital Mortality | 90-Day Mortality | Mechanical Ventilation | Median Vasopressor | Max Vasopressor | Total Average Path Length | Parameter Count
logreg | 0.6980 | 0.6986 | 0.8242 | 0.7392 | 0.7392 | 32.489 | 180
decision tree | 0.7017 | 0.7016 | 0.8509 | 0.7439 | 0.7427 | 76.242 | -
hmm (5) | 0.7128 | 0.7095 | 0.6979 | 0.7295 | 0.7290 | 35.125 | 405
hmm (10) | 0.7227 | 0.7297 | 0.8237 | 0.7409 | 0.7405 | 57.629 | 860
hmm (15) | 0.7216 | 0.7282 | 0.8188 | 0.7346 | 0.7341 | 61.832 | 1 365
hmm (20) | 0.7233 | 0.7350 | 0.8218 | 0.7371 | 0.7364 | 62.353 | 1 920
hmm (25) | 0.7147 | 0.7321 | 0.8089 | 0.7313 | 0.7310 | 63.415 | 2 525
hmm (30) | 0.7164 | 0.7297 | 0.8099 | 0.7316 | 0.7311 | 65.164 | 3 180
hmm (35) | 0.7177 | 0.7237 | 0.8095 | 0.7201 | 0.7195 | 65.474 | 3 885
hmm (50) | 0.7267 | 0.7357 | 0.8373 | 0.7335 | 0.7328 | 66.317 | 6 300
hmm (75) | 0.7254 | 0.7361 | 0.8059 | 0.7434 | 0.7430 | 72.553 | 11 325
hmm (100) | 0.7294 | 0.7354 | 0.8129 | 0.7408 | 0.7403 | 80.415 | 17 600
gru (1) | 0.3897 | 0.6400 | 0.4761 | 0.7414 | 0.7411 | 31.816 | 117
gru (5) | 0.7357 | 0.7296 | 0.8795 | 0.7866 | 0.7862 | 45.395 | 645
gru (10) | 0.7488 | 0.7445 | 0.8892 | 0.7983 | 0.7979 | 58.102 | 1 440
gru (15) | 0.7529 | 0.7450 | 0.8912 | 0.8020 | 0.8021 | 61.025 | 2 385
gru (20) | 0.7535 | 0.7497 | 0.8887 | 0.8018 | 0.8017 | 61.214 | 3 480
gru (25) | 0.7578 | 0.7486 | 0.8902 | 0.8113 | 0.8114 | 62.029 | 4 725
gru (30) | 0.7602 | 0.7508 | 0.8927 | 0.8063 | 0.8061 | 72.854 | 6 120
gru (35) | 0.7522 | 0.7483 | 0.8900 | 0.8095 | 0.8091 | 74.091 | 7 665
gru (50) | 0.7431 | 0.7390 | 0.8895 | 0.8054 | 0.8051 | 76.543 | 13 200
gru (75) | 0.7408 | 0.7239 | 0.8837 | 0.8006 | 0.8000 | 87.422 | 25 425
gru (100) | 0.7325 | 0.7273 | 0.8781 | 0.7977 | 0.7975 | 94.161 | 41 400
grutree (100/0.01) | 0.7276 | 0.7314 | 0.8776 | 0.7873 | 0.7867 | 91.797 | 41 400
grutree (100/1.0) | 0.7147 | 0.7040 | 0.8741 | 0.7812 | 0.7810 | 82.019 | 41 400
grutree (100/8.0) | 0.7232 | 0.7203 | 0.8763 | 0.7845 | 0.7840 | 73.767 | 41 400
grutree (100/20.0) | 0.7123 | 0.7085 | 0.8733 | 0.7813 | 0.7813 | 65.035 | 41 400
grutree (100/70.0) | 0.7360 | 0.7376 | 0.8813 | 0.7988 | 0.7986 | 61.012 | 41 400
grutree (100/300.0) | 0.7210 | 0.7197 | 0.8681 | 0.7676 | 0.7678 | 54.177 | 41 400
grutree (100/2 000.0) | 0.7230 | 0.7167 | 0.8335 | 0.7616 | 0.7619 | 48.206 | 41 400
grutree (100/5 000.0) | 0.6546 | 0.6552 | 0.6752 | 0.6668 | 0.6530 | 26.085 | 41 400
grutree (100/7 000.0) | 0.6063 | 0.6554 | 0.6565 | 0.6230 | 0.6138 | 20.214 | 41 400
grutree (100/8 000.0) | 0.5298 | 0.5242 | 0.5025 | 0.5026 | 0.5057 | 13.383 | 41 400
gruhmm (1/5) | 0.4222 | 0.6472 | 0.4678 | 0.7478 | 0.7477 | 41.583 | 722
gruhmm (1/10) | 0.4007 | 0.6295 | 0.4730 | 0.7418 | 0.7419 | 61.041 | 1 517
gruhmm (1/25) | 0.4019 | 0.6207 | 0.4773 | 0.7353 | 0.7352 | 65.955 | 4 802
gruhmm (1/50) | 0.3999 | 0.6162 | 0.4772 | 0.7120 | 0.7121 | 70.534 | 13 277
gruhmm (5/5) | 0.7430 | 0.7372 | 0.8798 | 0.8009 | 0.8006 | 47.639 | 1 050
gruhmm (5/10) | 0.7408 | 0.7320 | 0.8819 | 0.7991 | 0.7988 | 63.627 | 1 845
gruhmm (5/25) | 0.7365 | 0.7279 | 0.8776 | 0.7955 | 0.7952 | 68.215 | 5 130
gruhmm (5/50) | 0.7222 | 0.7107 | 0.8660 | 0.7814 | 0.7811 | 71.572 | 13 605
gruhmm (10/5) | 0.7468 | 0.7467 | 0.8949 | 0.8098 | 0.8097 | 50.902 | 1 505
gruhmm (10/10) | 0.7490 | 0.7478 | 0.8958 | 0.8098 | 0.8096 | 63.522 | 2 300
gruhmm (10/25) | 0.7422 | 0.7407 | 0.8916 | 0.8055 | 0.8054 | 70.919 | 5 585
gruhmm (10/50) | 0.7254 | 0.7221 | 0.8824 | 0.7903 | 0.7903 | 71.297 | 14 060
gruhmm (25/5) | 0.7580 | 0.7568 | 0.8941 | 0.8236 | 0.8235 | 51.794 | 3 170
gruhmm (25/10) | 0.7592 | 0.7563 | 0.8945 | 0.8225 | 0.8225 | 64.223 | 3 965
gruhmm (25/25) | 0.7525 | 0.7508 | 0.8912 | 0.8186 | 0.8184 | 72.480 | 7 250
gruhmm (25/50) | 0.7604 | 0.7583 | 0.8954 | 0.8106 | 0.8103 | 79.127 | 11 025
gruhmm (50/5) | 0.7655 | 0.7592 | 0.9006 | 0.8228 | 0.8226 | 64.229 | 6 945
gruhmm (50/10) | 0.7648 | 0.7568 | 0.9003 | 0.8220 | 0.8219 | 69.281 | 7 740
gruhmm (50/25) | 0.7600 | 0.7555 | 0.8981 | 0.8205 | 0.8203 | 85.503 | 11 025
gruhmm (50/50) | 0.7412 | 0.7373 | 0.8910 | 0.8056 | 0.8055 | 101.637 | 19 500
gruhmmtree (50/50/0.5) | 0.7432 | 0.7492 | 0.879 | 0.7854 | 0.7849 | 84.188 | 19 500
gruhmmtree (50/50/20.0) | 0.7435 | 0.747 | 0.8826 | 0.7914 | 0.7906 | 77.815 | 19 500
gruhmmtree (50/50/50.0) | 0.7384 | 0.7548 | 0.8914 | 0.7922 | 0.7918 | 71.719 | 19 500
gruhmmtree (50/50/200.0) | 0.747 | 0.7502 | 0.8767 | 0.7832 | 0.7824 | 69.715 | 19 500
gruhmmtree (50/50/300.0) | 0.7539 | 0.7623 | 0.8942 | 0.8092 | 0.8091 | 66.9 | 19 500
gruhmmtree (50/50/600.0) | 0.7435 | 0.7453 | 0.8821 | 0.7909 | 0.7905 | 63.703 | 19 500
gruhmmtree (50/50/1 000.0) | 0.7575 | 0.7502 | 0.8739 | 0.7882 | 0.7873 | 60.949 | 19 500
gruhmmtree (50/50/3 000.0) | 0.7396 | 0.7484 | 0.8926 | 0.8013 | 0.8011 | 54.751 | 19 500
gruhmmtree (50/50/4 000.0) | 0.7432 | 0.7511 | 0.8915 | 0.802 | 0.8024 | 44.868 | 19 500
gruhmmtree (50/50/7 000.0) | 0.7308 | 0.7477 | 0.8813 | 0.7881 | 0.7882 | 27.836 | 19 500
gruhmmtree (50/50/9 000.0) | 0.7132 | 0.7319 | 0.8261 | 0.7301 | 0.7299 | 0.0 | 19 500

Table C.2: Performance metrics for multi-dimensional classification on a held-out portion of the Sepsis dataset. Total Average Path Length refers to the summed average path lengths across the 5 output dimensions. Refer to Fig. C.3 for average path lengths split across dimensions.


C.7 HIV: Plots

(a) Therapy Success   (b) CD4+ ≤ 200 cells/ml   (c) Adherence   (d) Mortality   (e) Onset of AIDS

Figure C.5: Performance and complexity trade-offs using L1, L2, and Tree regularization on GRU for the HIV dataset. The 5 outputs shown here were trained jointly.


C.8 HIV: AUCs

Model | Poor Adherence | Mortality | CD4+ Count ≤ 200 | Therapy Success | Total Average Path Length | Parameter Count
logreg | 0.6884 | 0.7031 | 0.5741 | 0.6092 | 38.942 | 1155
decision tree | 0.7100 | 0.7601 | 0.5937 | 0.6286 | 62.150 | -
hmm (5) | 0.7106 | 0.7611 | 0.6012 | 0.6265 | 41.864 | 865
hmm (10) | 0.7287 | 0.7627 | 0.6237 | 0.6409 | 46.309 | 1780
hmm (25) | 0.7243 | 0.7627 | 0.6327 | 0.6384 | 56.159 | 4825
hmm (50) | 0.7181 | 0.7639 | 0.6412 | 0.6370 | 69.014 | 10900
hmm (75) | 0.7244 | 0.7661 | 0.6294 | 0.6518 | 70.476 | 18225
hmm (100) | 0.7261 | 0.7657 | 0.6287 | 0.6524 | 71.159 | 26800
gru (5) | 0.6457 | 0.6814 | 0.6695 | 0.6834 | 58.347 | 1310
gru (25) | 0.7516 | 0.7986 | 0.7073 | 0.6991 | 60.072 | 8050
gru (50) | 0.7011 | 0.8290 | 0.6995 | 0.7054 | 67.513 | 19850
gru (75) | 0.7623 | 0.8514 | 0.7117 | 0.7490 | 64.870 | 35400
gru (100) | 0.7340 | 0.8216 | 0.6981 | 0.7235 | 67.183 | 54700
grutree (100/0.01) | 0.7176 | 0.7948 | 0.7046 | 0.6803 | 91.020 | 54700
grutree (100/1.0) | 0.7134 | 0.7997 | 0.7138 | 0.6892 | 86.774 | 54700
grutree (100/20.0) | 0.7157 | 0.8066 | 0.7216 | 0.7114 | 76.025 | 54700
grutree (100/70.0) | 0.7485 | 0.8210 | 0.7413 | 0.7060 | 68.952 | 54700
grutree (100/300.0) | 0.7251 | 0.8178 | 0.7264 | 0.6746 | 54.058 | 54700
grutree (100/2 000.0) | 0.7030 | 0.8169 | 0.6342 | 0.6627 | 49.839 | 54700
grutree (100/5 000.0) | 0.6549 | 0.7582 | 0.6142 | 0.6352 | 23.895 | 54700
grutree (100/7 000.0) | 0.6167 | 0.7524 | 0.5740 | 0.5634 | 15.283 | 54700
grutree (100/8 000.0) | 0.5874 | 0.7412 | 0.5003 | 0.5027 | 7.391 | 54700
gruhmm (5/5) | 0.6430 | 0.6647 | 0.5418 | 0.6479 | 67.619 | 2175
gruhmm (5/10) | 0.6708 | 0.6720 | 0.5879 | 0.6517 | 72.137 | 3090
gruhmm (5/25) | 0.6951 | 0.6981 | 0.6476 | 0.6955 | 68.200 | 6135
gruhmm (5/50) | 0.6810 | 0.7002 | 0.6760 | 0.7114 | 71.518 | 12210
gruhmm (10/5) | 0.7018 | 0.7147 | 0.7049 | 0.7208 | 64.852 | 3635
gruhmm (10/10) | 0.7190 | 0.7378 | 0.7136 | 0.7578 | 73.252 | 4550
gruhmm (10/25) | 0.7264 | 0.7457 | 0.7217 | 0.7951 | 70.884 | 7595
gruhmm (10/50) | 0.7570 | 0.7522 | 0.7224 | 0.8234 | 69.726 | 13670
gruhmm (25/10) | 0.7462 | 0.7861 | 0.7152 | 0.8217 | 68.241 | 9830
gruhmm (25/25) | 0.7435 | 0.8102 | 0.7425 | 0.8186 | 79.261 | 12875
gruhmm (25/50) | 0.7484 | 0.7714 | 0.7501 | 0.8006 | 76.174 | 18950
gruhmm (50/10) | 0.7437 | 0.7668 | 0.7813 | 0.8260 | 70.081 | 21630
gruhmm (50/25) | 0.7380 | 0.7557 | 0.7824 | 0.8215 | 88.617 | 24675
gruhmm (50/50) | 0.7317 | 0.7684 | 0.7920 | 0.8007 | 97.864 | 30750
gruhmmtree (50/50/0.5) | 0.7432 | 0.7692 | 0.8790 | 0.7804 | 73.168 | 30750
gruhmmtree (50/50/50.0) | 0.7426 | 0.8152 | 0.8914 | 0.7979 | 67.729 | 30750
gruhmmtree (50/50/200.0) | 0.7461 | 0.8308 | 0.8767 | 0.8032 | 59.025 | 30750
gruhmmtree (50/50/600.0) | 0.7467 | 0.8820 | 0.8821 | 0.8293 | 52.128 | 30750
gruhmmtree (50/50/1 000.0) | 0.7375 | 0.8951 | 0.8739 | 0.7882 | 48.247 | 30750
gruhmmtree (50/50/4 000.0) | 0.7242 | 0.8461 | 0.8515 | 0.8030 | 14.868 | 30750
gruhmmtree (50/50/7 000.0) | 0.7280 | 0.8462 | 0.8313 | 0.7484 | 1.836 | 30750

Table C.3: Performance metrics for multi-dimensional classification on a held-out portion of the HIV dataset. Total Average Path Length refers to the summed average path lengths across the output dimensions.


C.9 TIMIT: Plots/Tree Visualization

(a) TIMIT Stop Phonemes   (b) GRU:500

Figure C.6: (a) Performance and complexity trade-offs using L1, L2, and Tree regularization for GRU models on TIMIT. (b) Decision tree trained using λ = 500.0 tree regularization on GRU.


C.10 TIMIT: AUCs

Model | AUC | Average Path Length | Parameter Count
logreg | 0.7747 | 23.460 | 27
decision tree | 0.8668 | 59.2061 | -
hmm (5) | 0.8900 | 51.911 | 295
hmm (10) | 0.8981 | 56.273 | 640
hmm (25) | 0.9129 | 57.602 | 1 975
hmm (50) | 0.9189 | 63.752 | 5 200
hmm (75) | 0.9251 | 71.473 | 9 675
gru (1) | 0.9169 | 42.602 | 86
gru (5) | 0.9451 | 49.275 | 490
gru (10) | 0.9509 | 60.079 | 1 130
gru (25) | 0.9547 | 62.051 | 3 950
gru (50) | 0.9578 | 64.957 | 11 650
gru (75) | 0.9620 | 68.998 | 23 100
gruhmm (1/5) | 0.9419 | 54.9723 | 381
gruhmm (1/10) | 0.9535 | 53.5642 | 726
gruhmm (1/25) | 0.9636 | 57.3290 | 2601
gruhmm (5/5) | 0.9569 | 55.9531 | 785
gruhmm (5/10) | 0.9575 | 57.6199 | 1 130
gruhmm (5/25) | 0.9603 | 59.9925 | 2 465
gruhmm (10/5) | 0.9626 | 57.0652 | 1 425
gruhmm (10/10) | 0.9641 | 60.7877 | 1 770
gruhmm (10/25) | 0.9651 | 61.0018 | 3 105
gruhmm (25/5) | 0.9635 | 57.5288 | 4 245
gruhmm (25/10) | 0.9657 | 60.5212 | 4 590
gruhmm (25/25) | 0.9663 | 65.0161 | 5 925
gruhmm (50/5) | 0.9676 | 62.2378 | 11 945
gruhmm (50/10) | 0.9679 | 65.1191 | 12 290
gruhmm (50/25) | 0.9685 | 67.4301 | 13 625
grutree (75/0.01) | 0.9517 | 66.2801 | 23 100
grutree (75/0.1) | 0.9466 | 62.4316 | 23 100
grutree (75/0.5) | 0.9367 | 60.8764 | 23 100
grutree (75/2.0) | 0.9311 | 58.3659 | 23 100
grutree (75/5.0) | 0.9302 | 55.7588 | 23 100
grutree (75/10.0) | 0.9288 | 46.6616 | 23 100
grutree (75/100.0) | 0.8911 | 40.1123 | 23 100
grutree (75/500.0) | 0.8998 | 28.4240 | 23 100
grutree (75/700.0) | 0.8628 | 25.136 | 23 100
grutree (75/800.0) | 0.7471 | 22.6671 | 23 100
grutree (75/1 000.0) | 0.7082 | 17.1523 | 23 100
grutree (75/6 000.0) | 0.5441 | 11.1108 | 23 100
grutree (75/7 000.0) | 0.5088 | 8.9910 | 23 100
gruhmmtree (50/25/0.1) | 0.9507 | 69.1110 | 13 625
gruhmmtree (50/25/1.0) | 0.9465 | 67.5773 | 13 625
gruhmmtree (50/25/6.0) | 0.9515 | 65.1494 | 13 625
gruhmmtree (50/25/20.0) | 0.9449 | 64.0072 | 13 625
gruhmmtree (50/25/30.0) | 0.9482 | 62.5406 | 13 625
gruhmmtree (50/25/70.0) | 0.9460 | 58.0111 | 13 625
gruhmmtree (50/25/100.0) | 0.9470 | 51.2417 | 13 625
gruhmmtree (50/25/500.0) | 0.9401 | 42.1882 | 13 625
gruhmmtree (50/25/700.0) | 0.9352 | 40.1281 | 13 625
gruhmmtree (50/25/1 000.0) | 0.9390 | 38.0072 | 13 625
gruhmmtree (50/25/3 000.0) | 0.9280 | 25.9120 | 13 625
gruhmmtree (50/25/4 000.0) | 0.9311 | 21.7170 | 13 625
gruhmmtree (50/25/7 000.0) | 0.9290 | 10.1122 | 13 625
gruhmmtree (50/25/9 000.0) | 0.9134 | 1.0563 | 13 625
gruhmmtree (50/25/10 000.0) | 0.9125 | 0.0000 | 13 625

Table C.4: Performance metrics across models on a held-out portion of the TIMIT dataset.


D GRU-HMM: Deep Residual Timeseries Model

Hidden Markov Model. For our purposes, Hidden Markov Models (HMMs) can be viewed as stochastic RNNs which can be interpreted as probabilistic generative models. In this work, we consider an HMM that generates a latent variable sequence z = [z_1, ..., z_T] via a Markov chain, where each latent indicates one of K possible discrete states: z_t ∈ {1, ..., K}. This state sequence is then used to jointly produce the "data" x_t and "outcomes" y_t observed at each timestep. The joint distribution over z, x, y factorizes as:

p(z, x, y) = \pi_0(z_0) \prod_{t=1}^{T} p(z_t \mid z_{t-1}, A) \, p(x_t \mid z_t, \phi) \, \mathrm{Bern}\!\left( y_t \mid \sigma\!\left( \sum_k w_k \delta_k(z_t) \right) \right),    (6)

where A is a transition matrix such that A_{i,j} = Pr(z_t = i | z_{t-1} = j), \pi_0 = p(z_0) is the initial state distribution, and {\phi_k}_{k=1}^{K} are the emission parameters that generate the data. We can then apply the same objective as above for training.

GRU-HMM: Modeling the residuals of an HMM. We now consider an additional model, the GRU-HMM, designed for interpretability. The idea is to use a GRU to model the residual errors when predicting the binary target via the HMM belief states. We can further penalize the complexity of the GRU predictions via our tree regularization, so that higher-quality predictions do not come at the price of a much less interpretable model.
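
Schematically, the prediction at each timestep combines the HMM belief state with the GRU hidden state before the output sigmoid. The sketch below is illustrative only: hmm_belief_update and gru_step stand in for the usual recurrences, and summing the two linear scores is our reading of the residual construction.

import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_hmm_forward(x_seq, params):
    belief = params["pi0"].copy()            # HMM belief over K discrete states
    h = np.zeros(params["n_gru"])            # GRU hidden state
    y_hat = []
    for x_t in x_seq:
        belief = hmm_belief_update(belief, x_t, params)   # hypothetical helper
        h = gru_step(h, x_t, params)                      # hypothetical helper
        # The HMM belief carries the main signal; the (tree-regularized)
        # GRU output models the residual.
        y_hat.append(sigmoid(params["w_hmm"] @ belief + params["w_gru"] @ h))
    return np.array(y_hat)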


Figure D.1: Deep residual model: GRU-HMM. The orange triangle indicates the output used in surrogate training for tree regularization.

We train the deep residual model on the same suite of synthetic and real-world datasets. See Tables C.1, C.2, C.4 for a comparison of GRU-HMM with vanilla GRU and HMM models under different regularization and expressiveness parameters. We can see that across the datasets, deep residual models perform around 1% better than their vanilla equivalents with roughly the same number of model parameters.

By nature of being a residual model, decision trees were trained only on the GRU output node, leaving the HMM unconstrained. See Figure D.1 for a pictorial representation. Similar to what we did for GRU models, Figures C.1b and D.2 compare model performance as the λ parameter for L1, L2, and Tree regularization increases. We can see a similar, albeit less pronounced, effect where Tree regularization dominates the other methods in the low-node-count region. It is important to note the range of the AUC axis in these figures: the worst the residual model can perform is the HMM-only AUC. Figure D.3 shows the regularized trees produced by the GRU-HMM. Although they share some structure with Figure C.4, there are important distinctions that lead us to conclude that the GRU in a residual model performs a different role than when trained alone.


D.1 GRU-HMM: Sepsis Plots

(a) In-Hospital Mortality   (b) 90-Day Mortality   (c) Mechanical Ventilation   (d) Median Vasopressor   (e) Max Vasopressor

Figure D.2: Performance and complexity trade-offs using L1, L2, and Tree regularization on GRU-HMM performance on the Sepsis dataset.

D.2 GRU-HMM: Sepsis Tree Visualization

(a) In-Hospital Mortality   (b) 90-Day Mortality   (c) Mechanical Ventilation   (d) Median Vasopressor   (e) Max Vasopressor

Figure D.3: Decision trees trained using Tree regularization (λ = 2000.0) from GRU-HMM predictions on the Sepsis dataset.


D.3 GRU-HMM: HIV Plots/Tree Visualization

(a) GRU-HMM: CD4+ ≤ 200 cells/ml   (b) GRU-HMM: CD4+ ≤ 200 cells/ml

Figure D.4: HIV task: Study of different regularization techniques for the GRU-HMM model with 75 GRU nodes and 25 HMM states, trained to predict whether CD4+ ≤ 200 cells/ml. (a) Example decision tree for λ = 1000.0. (b) Example decision tree for λ = 3000.0. The tree in (b) is slightly smaller than the tree in (a) as a result of the regularization.

D.4 GRU-HMM: TIMIT Plots/Tree Visualization

(a) GRU-HMM: Stop vs Non-Stop   (b) GRU-HMM: Stop vs Non-Stop   (c) GRU-HMM: Stop vs Non-Stop

Figure D.5: TIMIT task: Study of different regularization techniques for the GRU-HMM model with 75 GRU nodes and 25 HMM states, trained to predict STOP phonemes. (a) Tradeoff curves showing how AUC predictive power and decision-tree complexity evolve with increasing regularization strength under L1, L2, or Tree regularization. (b) Example decision tree for λ = 3000.0. (c) Example decision tree for λ = 7000.0. When comparing with Figure C.6b, this tree is significantly smaller, suggesting that the GRU performs a different role in the residual model.


E Runtime Comparisons

Training time for tree-regularized models. Table E.1 shows the wall time for training one epoch of each of the models presented in this paper on each of the datasets. Note that the wall times for GRU-TREE and GRU-HMM-TREE include the cost of surrogate training. If the surrogate is retrained infrequently, this amortized cost is small.

Dataset                 Model           Epoch Time (Sec.)
Signal-and-noise HMM    HMM               16.66 ±   2.53
Signal-and-noise HMM    GRU               30.48 ±   1.92
Signal-and-noise HMM    GRU-HMM           50.40 ±   5.56
Signal-and-noise HMM    GRU-TREE          43.83 ±   3.84
Signal-and-noise HMM    GRU-HMM-TREE      73.24 ±   7.86
SEPSIS                  HMM              589.80 ±  24.11
SEPSIS                  GRU              822.27 ±  11.17
SEPSIS                  GRU-HMM         1666.98 ± 147.00
SEPSIS                  GRU-TREE        2015.15 ± 388.12
SEPSIS                  GRU-HMM-TREE    2443.66 ± 351.22
TIMIT                   HMM             1668.96 ± 126.96
TIMIT                   GRU             2116.83 ± 438.83
TIMIT                   GRU-HMM         3207.16 ± 651.85
TIMIT                   GRU-TREE        3977.01 ± 812.11
TIMIT                   GRU-HMM-TREE    4601.44 ± 805.88

Table E.1: Training time for recurrent models on all datasets used in this paper. Epoch time denotes the number of seconds for a single pass through all the training data. The epoch times for GRU-TREE and GRU-HMM-TREE include surrogate training. If we retrain the surrogate sparsely, its cost is amortized, and the epoch times for GRU and GRU-TREE, and for GRU-HMM and GRU-HMM-TREE, are approximately the same. To measure epoch time, we used 10 HMM states, 10 GRU states, and 5 of each for GRU-HMM models. We trained the surrogate model for 5000 epochs. These tests were run on a single Intel Core i5 CPU.
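As a rough illustration of this amortization argument, the sketch below times each training epoch while retraining the surrogate only every retrain_every epochs. The names model_step and fit_surrogate are placeholders for the deep-model gradient update and surrogate training, not functions from the paper's codebase.

import time

def timed_epochs(model_step, fit_surrogate, batches, n_epochs=100,
                 retrain_every=25):
    # Illustrative timing loop: the surrogate is refit only occasionally,
    # so its cost is spread across many cheap epochs.
    epoch_times = []
    for epoch in range(n_epochs):
        start = time.time()
        for X, y in batches:
            model_step(X, y)          # one gradient update (placeholder)
        if epoch % retrain_every == 0:
            fit_surrogate()           # occasional surrogate refit (placeholder)
        epoch_times.append(time.time() - start)
    return epoch_times

With a large retrain_every, the mean of epoch_times for a tree-regularized model approaches that of its unregularized counterpart, which is the sense in which GRU and GRU-TREE (and GRU-HMM and GRU-HMM-TREE) epoch times become approximately equal.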



F Extended Stability Tests

In the main paper, we noted that the learned decision trees are stable across multiple runs. Here we show that, on the signal-and-noise HMM dataset, 10 independent runs with random initializations and λ = 1000.0 produce either identical or comparable trees. Additionally, we show that with weak regularization (λ = 0.01), the variability of the learned decision trees is high. Figures F.1 and F.2 include examples of such trees on the signal-and-noise dataset. Similar results hold for the real-world datasets.
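Concretely, the stability check amounts to: retrain the model from several random initializations, fit a small decision tree to each run's predictions, and count how often the same structure recurs. Below is a hedged sketch of that procedure, assuming scikit-learn; get_predictions(seed) is a hypothetical stand-in for retraining the tree-regularized model from a given seed and returning its inputs and predicted labels.

from collections import Counter
from sklearn.tree import DecisionTreeClassifier, export_text

def tree_signature(X, y_hat, min_samples_leaf=25):
    # Fit a small CART tree to the model's predictions and return a textual
    # rendering of its structure.
    tree = DecisionTreeClassifier(min_samples_leaf=min_samples_leaf)
    tree.fit(X, y_hat)
    return export_text(tree)

def stability_report(get_predictions, n_runs=10):
    # Count how many of n_runs yield an identical tree rendering. Exact string
    # matching is a crude proxy: thresholds differing only in the last decimal
    # would register as different trees.
    signatures = [tree_signature(*get_predictions(seed)) for seed in range(n_runs)]
    return Counter(signatures).most_common()

A result such as [(tree_A, 7), (tree_B, 2), (tree_C, 1)] would correspond to the 7/10, 2/10, and 1/10 groupings shown in Figure F.1.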

[Decision-tree diagram (root split on X[0]); full node listing omitted from this text rendering.]

(a) 7/10 Runs

[Decision-tree diagram (root split on X[0], with one additional subtree); full node listing omitted from this text rendering.]

(b) 2/10 Runs

[Decision-tree diagram (root split on X[13]); full node listing omitted from this text rendering.]

(c) 1/10 Runs

Figure F.1: Decision trees from 10 independent runs on the signal-and-noise HMM dataset with λ = 1000.0. Seven of the ten runs resulted in a tree of the same structure. The other three trees are similar, often having additional subtrees but sharing the same splits and features.

[Decision-tree diagram (root split on X[2]); full node listing omitted from this text rendering.]

(a)

[Decision-tree diagram (root split on X[4]); full node listing omitted from this text rendering.]

(b)

[Decision-tree diagram (root split on X[12]); full node listing omitted from this text rendering.]

(c)

[Decision-tree diagram (root split on X[1]); full node listing omitted from this text rendering.]

(d)

Figure F.2: Decision trees from 10 independent runs on the signal-and-noise HMM dataset with λ = 0.01. With low regularization, the variance in tree size and shape is high.
