$L_0$ norm regularisation1 is a pretty fascinating technique for neural network pruning or for training sparse networks, where weights are encouraged to be completely 0. It is easy to implement, with only a few lines of code (see below), but getting there conceptually is not so easy.

Several ML tricks are needed to achieve gradient flow through the network, because the default $L_0$ regularisation loss is non-differentiable, and each fix introduces a new source of non-differentiability, so the loss function “evolves” as we go.

The form of the loss is $\mathcal{L}(f(x; \tilde{\theta} \odot z), y) + \mathcal{L}_{\mathrm{reg}}$, where $z$ is a discrete binary mask on the parameters $\tilde{\theta}$. We use $\tilde{\theta}$ because ultimately the parameters that we care about are not $\tilde{\theta}$ exactly, but $\theta = \tilde{\theta} \odot z$.

The final solution involves sampling from a hard-concrete distribution, which is obtained by stretching a binary-concrete distribution and then transforming the samples with a hard-sigmoid.


Preliminaries

$L_p$ regularisation

Regularisation adds a term $\mathcal{L}_{\mathrm{reg}}$ to the loss function, which penalises the complexity of the solution (the $\theta$ weights); it is typically used to avoid overfitting and reduce generalisation error. The (penalised) maximum likelihood estimate of the model parameters $\theta$ is given by

\[\hat{\theta}_{\mathrm{MLE}} = \mathrm{argmin}_{\theta} \frac{1}{N} \sum_{i=1}^N \mathcal{L}(f(x_i, \theta), y_i) + \lambda \mathcal{L}_{\mathrm{reg}}\]

$L_p$ regularisation is a penalising cost based on the p-norm of the $\theta$ vector, $\mathcal{L}_{\mathrm{reg}} = \mid \mid \theta \mid \mid_p$, where $\mid \mid \theta \mid \mid_p = (\mid\theta_1 \mid^p + \mid\theta_2 \mid^p + \cdots)^{\frac{1}{p}}$. $L_1$ and $L_2$ regularisation are typically used in gradient-based methods, but $L_0$ regularisation involves counting non-zero weights and is non-differentiable.

Note: the $L_2$ norm is continuously differentiable, but the $L_1$ norm is not differentiable at $\theta=0$.
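To make this concrete, here is a minimal PyTorch sketch of adding an $L_p$ penalty to a loss (`model`, `task_loss` and `lam` are placeholder names, not from any particular codebase):

import torch

def lp_penalty(params, p):
  # p-norm of all parameters, flattened into a single vector
  return torch.cat([w.flatten() for w in params]).norm(p=p)

# usage: loss = task_loss + lam * lp_penalty(model.parameters(), p=2)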

Reparameterisation Trick

The reparameterisation trick is used when we want to sample from a distribution and learn the parameters of that distribution. The “trick” is to reparameterise the distribution such that a sample has a deterministic (differentiable) component and a noise (non-differentiable) component.2 This means re-expressing the sampling function as dependent on trainable parameters and some independent noise.

For example, a sample from $\mathcal{N}(\mu, \sigma^2)$ can be obtained by sampling $u$ from the standard normal distribution, $u \sim \mathcal{N}(0, 1)$, and then transforming it with $\mu + \sigma u$. This reparameterisation reduces the problem of estimating gradients with respect to the parameters of a distribution, to estimating gradients with respect to the parameters of a deterministic function.
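In code, the trick determines whether gradients can reach $\mu$ and $\sigma$; a minimal PyTorch sketch (the values are illustrative):

import torch

mu = torch.tensor(0.5, requires_grad=True)
sigma = torch.tensor(1.2, requires_grad=True)

u = torch.randn(())         # independent noise, no trainable parameters involved
x = mu + sigma * u          # deterministic, differentiable transform of the noise
x.backward()                # gradients flow back to mu and sigma
print(mu.grad, sigma.grad)  # dx/dmu = 1, dx/dsigma = u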

Concrete Distributions

The class of “Concrete” distributions was invented to enable discrete distributions to use the reparameterisation trick, by approximating discrete distributions with continuous ones.3 The high-level strategy is: first, relax the state of a discrete variable into a probability vector by adding noise; second, use a softmax (or logistic, in the binary case) instead of an argmax over the probabilities. Sampling from the Concrete distribution then amounts to taking the softmax of logits perturbed by independent additive noise.

Note: Don’t overthink the semantics of “Concrete”; it’s just a (in my opinion poor) name and stands for a “CONtinuous relaxation of disCRETE random variables”.
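A minimal sketch of sampling from a Concrete distribution over $k$ categories (the logits and temperature below are illustrative):

import torch

def sample_concrete(logits, temperature):
  # Gumbel(0, 1) noise via the inverse CDF: -log(-log(u)), u ~ Uniform(0, 1)
  u = torch.rand_like(logits)
  gumbel = -torch.log(-torch.log(u))
  # softmax over noise-perturbed logits replaces the argmax of a discrete sample
  return torch.softmax((logits + gumbel) / temperature, dim=-1)

logits = torch.tensor([1.0, 2.0, 0.5], requires_grad=True)
sample = sample_concrete(logits, temperature=0.5)  # near one-hot, but differentiable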



Method

Problem: $L_0$ Regularisation Cost is Non-differentiable
Solution: Use the probability of the weights being non-zero, rather than the raw count

Writing out $L_0$ regularisation, the maximum likelihood estimate is given by

\[\hat{\theta} = \mathrm{argmin}_{\theta} \frac{1}{N}(\sum_{i=1}^N \mathcal{L}(f(x_i; \theta), y_i)) + \lambda \mid \mid \theta \mid \mid_0 \tag{eq 1}\label{eq:1}\]

where $\mid \mid \theta \mid \mid_0 = \sum_{j=1}^{\mid \theta \mid} \mathbb{1} [\theta_j \neq 0]$. This loss is non-differentiable because counting non-zero parameters (the indicator function) is non-differentiable.

To work around this, a soft form of counting is required, i.e., the probability of the weights being non-zero. We thus consider $\theta = \tilde{\theta} \odot z$, where $\odot$ is element-wise multiplication. The variable $z \sim \mathrm{Bernoulli}(\pi)$ can be viewed as a set of $\{ 0,1 \}$ gates, which determine whether each parameter $\theta$ is effectively present or absent. The probability of $z$ being 0 or 1 is controlled by the parameter $\pi$. We therefore need to learn $\pi$.

\[\pi^* = \mathrm{argmin}_{\pi} \mathbb{E}_{z \sim \mathrm{Bern}(\pi)} \frac{1}{N} \sum_{i=1}^N \mathcal{L} (f(x_i; \tilde{\theta} \odot z), y_i) + \lambda \sum_{j=1}^{\mid \theta \mid} \pi_j \tag{eq 2}\label{eq:2}\]

The regularisation cost is now differentiable because instead of the raw count of non-zero $\theta$ in \eqref{eq:1}, we sum the probabilities ($\pi_j$) of the gates $z_j$ being 1, and thus of the parameters $\theta=\tilde{\theta} \odot z$ being non-zero. Each $\pi_j$ is the parameter of the Bernoulli distribution behind the corresponding binary gate.

At this point we have solved the problem of parameter counting, but we still cannot use gradient-based optimisation for $\pi$, because the $z$ we introduced is a discrete stochastic random variable.
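To see why, note that a discrete sample is a step function of its parameter, and PyTorch cuts the graph at the non-differentiable step (a small illustrative check):

import torch

pi = torch.tensor(0.7, requires_grad=True)
u = torch.rand(())
z = (u < pi).float()    # Bernoulli(pi) sample via the inverse CDF: a step function of pi
print(z.requires_grad)  # False: the comparison is non-differentiable, so the graph is cut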


Problem 2: The gated parameters $\tilde{\theta}\odot z$ are non-differentiable because the masks $z \in \{0, 1\}$ are i) discrete, ii) stochastic
Solution 2i: Sample random variables from Binary Concrete Distribution
Solution 2ii: Apply Reparameterisation Trick

We have solved the first problem, of the regularisation term $L_{\mathrm{reg}}$ being non-differentiable, by reformulating $\mid \mid \theta \mid \mid_0 \rightarrow \sum_{j=1}^{\mid \theta \mid} \pi_j$. But in doing so, we rewrote the term $f(x; \theta) \rightarrow f(x; \tilde{\theta}\odot z)$. Since $z$ is stochastic, gradient does not flow, and we would like to employ the reparameterisation trick. However, we cannot reparameterise the discrete distribution directly, due to the discontinuous nature of discrete states. Therefore, we first approximate the Bernoulli with a Binary Concrete distribution.

Next we apply the reparameterisation trick to the Binary Concrete distribution, resulting in a learnable parameter ($\log \alpha$) plus some independent noise. The noise takes the form $\log u - \log(1-u)$, where $u \sim \mathrm{Uniform}(0,1)$; this is logistic noise, i.e. the difference of two Gumbel-distributed variables.

Let $s$ be a random variable in the $(0, 1)$ interval, sampled from a Binary Concrete distribution. After applying the reparameterisation trick (details in Louizos et al. 2017), we can sample

\[s = \mathrm{Sigmoid}((\log u - \log (1-u) + \log \alpha) / \beta)\]

where $u \sim \mathrm{Uniform}(0, 1)$. Here $\log\alpha$ is the location parameter and $\beta$ is the temperature, which controls the degree of approximation: as $\beta \rightarrow 0$ we recover the original Bernoulli r.v. (but lose the differentiable properties). $\alpha$ and $\beta$ are now trainable parameters, while the stochasticity comes from $u \sim \mathrm{Uniform}(0, 1)$.


Problem: The continuous distribution places no probability mass exactly at 0 and 1.
Solution: “Stretch” the distribution beyond (0, 1) and “fold” the samples back.

We “stretch” the samples from the distribution to the $(\gamma, \zeta)$ interval, where $\gamma < 0$ and $\zeta > 1$: $\tilde{s} = s(\zeta - \gamma) + \gamma$. We then apply a hard-sigmoid to fold the samples back into the $[0, 1]$ interval: $z=\mathrm{min}(1, \mathrm{max}(0, \tilde{s}))$. This gives the gates non-zero probability mass at exactly 0 and exactly 1.

def sample_z(self):
  if self.training:
    # sample s from the binary concrete distribution:
    # s = sigmoid((log u - log(1-u) + log_alpha) / beta)
    u = torch.rand(self.num_heads, device=self.log_alpha.device)
    s_ = torch.sigmoid((torch.log(u) - torch.log(1 - u) + self.log_alpha) / self.beta)
  else:
    # test time: use the distribution's location, without noise
    s_ = torch.sigmoid(self.log_alpha)

  # stretch the values to (gamma, zeta), then fold them back into [0, 1]
  s_ = s_ * (self.zeta - self.gamma) + self.gamma
  z = torch.clamp(s_, min=0, max=1)
  return z
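The snippet above assumes the gate parameters live on a module. A minimal sketch of that state, using the hyperparameters from Louizos et al. (2017) ($\beta = 2/3$, $\gamma = -0.1$, $\zeta = 1.1$); the class name, num_heads and the initialisation scale are my own illustrative choices:

import math
import torch
import torch.nn as nn

class L0Gate(nn.Module):
  def __init__(self, num_heads, beta=2/3, gamma=-0.1, zeta=1.1):
    super().__init__()
    self.num_heads = num_heads
    # beta could also be made a trainable nn.Parameter; it is fixed here for simplicity
    self.beta, self.gamma, self.zeta = beta, gamma, zeta
    # log_alpha initialised well above 0 so the gates start near 1 (see concluding note 3)
    self.log_alpha = nn.Parameter(torch.randn(num_heads) * 0.01 + 2.5)
    # precomputed constant used by the regularisation cost below
    self.log_ratio_ = math.log(-gamma / zeta)

sample_z above and get_reg_cost (shown later) would be methods of this class.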


Problem: $z$ is no longer drawn from a Bernoulli, so what should the new regularisation term be?
Solution: Compute the probability of $z$ being non-zero, using the CDF of $s$.

Recall that the regularisation term $L_{\mathrm{reg}}$ has evolved from the number of non-zero parameters \eqref{eq:1} to the probability of the gates being non-zero under a Bernoulli distribution \eqref{eq:2}.

We still want to compute the probability of a gate being non-zero, but since we now have a continuous distribution instead of a discrete Bernoulli, we need its cumulative distribution function (CDF) $Q(s \mid \alpha, \beta)$: the probability of being non-zero is $1 - Q(s \leq 0 \mid \alpha, \beta)$.

\[\alpha^* = \mathrm{argmin}_{\alpha} \mathbb{E}_{u \sim \mathrm{Uniform}(0,1)} \frac{1}{N} \sum_{i=1}^N \mathcal{L} (f(x_i; \tilde{\theta} \odot z), y_i) + \lambda \sum_{j=1}^{\mid \theta \mid} (1-Q(s_j \leq0 \mid \alpha_j, \beta_j)) \tag{eq 3}\label{eq:3}\]

The regularisation cost works out to be

\[\sum_{j=1}^{\mid \theta \mid}(1-Q_{s_j}(0 \mid \alpha_j, \beta)) = \sum_{j=1}^{\mid \theta \mid} \mathrm{Sigmoid}(\log \alpha_j - \beta \log\frac{-\gamma}{\zeta})\]
# precomputed once, e.g. in __init__ (requires `import math`)
self.log_ratio_ = math.log(-self.gamma / self.zeta)

def get_reg_cost(self):
  # expected number of open gates: sum over j of P(z_j != 0)
  return torch.sigmoid(self.log_alpha - self.beta * self.log_ratio_).sum()


Concluding Notes (mostly for implementation)

  1. When someone writes “Hard Concrete”, they mean hard-sigmoid clamping on a continuous relaxation of a Bernoulli (Concrete) distribution.

  2. $\alpha$ and $\beta$ are the parameters that we need to train.

  3. Start with gates initialised near 1, not 0 or 0.5; I find that this is the only initialisation from which the gates train to a reasonable value.

  4. Disable early stopping callbacks, or increase the patience level for early stopping. Unlike training a model from scratch, where we expect performance to continuously increase, here we expect performance to drop rather than increase; as long as it doesn’t drop too far, we’re happy.

  5. Consider scaling the $L_0$ regularisation loss to be in a similar range as the task objective, e.g., normalise by batch size and total number of heads (see the sketch below).
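Putting the pieces together, a sketch of how the scaled cost might enter a training step (gate, model, criterion and lam are illustrative names, and the head-mask plumbing depends on your model):

z = gate.sample_z()                              # soft gates in [0, 1], one per head
task_loss = criterion(model(x, head_mask=z), y)  # assumed to be a mean over the batch
reg_loss = gate.get_reg_cost() / gate.num_heads  # normalise to roughly [0, 1]
loss = task_loss + lam * reg_loss
loss.backward()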


References

  1. Louizos, Welling and Kingma. (2017). Learning Sparse Neural Networks Through L0 Regularization.

  2. Kingma and Welling. (2013). Auto-Encoding Variational Bayes. Note: the reparameterisation trick was popularised in ML by this paper, but was not invented by its authors.

  3. Maddison, Mnih and Teh. (2016). The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables.