Neural Tangent Kernel, Every Model trained by GD is a kernel machine (Review)
Motivation (why we should care):
- Shallow networks can be reduced to kernel machines, but we have no such theoretical results for arbitrarily deep networks and architectures.
- If DNNs are reducible to kernel machines, then it would appear that DNNs are “memorising” the data, because in a kernel machine each new test-time example is compared against the training examples; this suggests how deep models could be working under the hood. I quote the author:
This paper shows that all gradient descent does is memorize the training examples, and then make predictions about new examples by comparing them with the training ones.[1]
The high-level idea
The central claim of this paper[2] is that every model trained by Gradient Descent, regardless of its architecture, is approximately a kernel machine of the form:
\begin{equation} y = g(\sum_i a_i K(x, x_i) + b) \end{equation}
where $g$ is an optional non-linearity, $a_i$ weighs each training example, $b$ is some constant, and $K$ is a kernel, introduced in the paper as the “path kernel”. The path kernel (new in this paper) is defined as \begin{equation} K(x, x') = \int_{c(t)} \nabla_{w_t} y(x) \cdot \nabla_{w_t} y(x')\, dt \end{equation}
Note: the author wrote $\nabla_{w}$ in the paper, but he probably means $\nabla_{w_t}$. (Thanks to Tim Viera for asking about this in RG.)
Intuitively, this says the following: take datapoints $x$ and $x_i$, and every set of model parameters $w_t$ obtained along infinitesimally small steps of gradient descent until the model converges; if the inner products $\langle \nabla_{w_t} f_w(x), \nabla_{w_t} f_w(x_i) \rangle$ are large along this trajectory, then $x$ and $x_i$ are considered close.
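To make this concrete, here is a minimal numerical sketch (my own illustration, not code from the paper): a tiny two-parameter model is trained by gradient descent, and along the trajectory we accumulate the tangent-kernel terms that discretise the path kernel integral, together with the per-example weights $a_i$ and offset $b$ of eq(1). The model, data, and hyperparameters are arbitrary choices, and the kernel-machine form only holds approximately for a finite step size.

```python
import numpy as np

# A minimal numerical sketch (my own illustration, not code from the paper).
# Train a tiny two-parameter model f(x; w) = w[1] * tanh(w[0] * x) by gradient
# descent on a sum-of-squares loss. Along the trajectory we accumulate
#   (i)  the path kernel K(x_test, x_i), discretising the integral over c(t)
#   (ii) the per-example terms sum_t lr * K^g_t(x_test, x_i) * L'(y_i*, y_i)
# which together recover the kernel-machine form of eq(1) (approximately, since
# the result is exact only in the gradient-flow limit of infinitesimal steps).

rng = np.random.default_rng(0)
X = rng.normal(size=16)           # training inputs x_i (arbitrary toy data)
Y = np.sin(X)                     # training targets y_i*

def predict(w, x):
    return w[1] * np.tanh(w[0] * x)

def grad_f(w, x):
    """Gradient of the model *output* wrt the parameters, evaluated at x."""
    t = np.tanh(w[0] * x)
    return np.stack([w[1] * x * (1.0 - t ** 2), t])  # shape (2,) or (2, n)

w = np.array([0.5, 0.5])          # initial parameters
lr, steps = 0.01, 2000
x_test = 0.7                      # an arbitrary test-time input

b = predict(w, x_test)            # offset b: the prediction at initialisation
path_K = np.zeros_like(X)         # discretised path kernel K(x_test, x_i)
contrib = np.zeros_like(X)        # per-example contribution a_i * K(x_test, x_i)

for _ in range(steps):
    loss_grad = 2.0 * (predict(w, X) - Y)   # L'(y_i*, y_i) for squared loss
    G = grad_f(w, X)                        # (2, 16): output gradients at each x_i
    Kg = grad_f(w, x_test) @ G              # tangent kernel K^g_t(x_test, x_i)
    path_K += lr * Kg
    contrib += -lr * Kg * loss_grad         # sign folded into a_i, as in the paper
    w = w - lr * (G @ loss_grad)            # gradient-descent step on the loss

a = contrib / path_K                        # the weights a_i of eq(1)
print("model prediction   :", float(predict(w, x_test)))
print("kernel machine form:", float(np.sum(a * path_K) + b))
```

With a small enough learning rate the two printed numbers agree closely; the residual gap is just the discretisation error relative to the gradient-flow limit.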
Discussion
- The proof (refer to the paper) takes the following form: first express $\frac{dy}{dt}$ in terms of the gradient flow (the limit $\epsilon \rightarrow 0$); then take the integral $\int_{c(t)} \frac{dy}{dt}\, dt$; then multiply and divide by the same kernel; and finally substitute the definition of the Neural Tangent Kernel (NTK)[3,4], $K^g_{f, w_t}(x, x_i) = \nabla_{w_t} f_w(x) \cdot \nabla_{w_t} f_w(x_i)$, and of the path kernel (a step-by-step sketch is given after this list). This multiply-and-divide trick is more than slightly questionable: we could have had ANY kernel as $K(x, x_i)$ in eq(1), as long as we multiply and divide by that kernel.
- Regarding the weight $a_i$: even if the denominator can be any generic kernel $K(x, x_i)$, we can still consider the numerator (see the equation below). Weighting examples by the path kernel makes sense, but the other term in the numerator, $\mathcal{L}'(y_i^*, y_i)$, suggests that examples with high loss should be weighted more heavily, which seems strange. It could indicate that parameter updates have been taken with respect to that training example, but in the absence of some kind of decay factor it could also simply mean that the model keeps incurring high loss on those examples.
\begin{equation} a_i = \frac{\int_{c(t)} K^g_{f, w_t}(x, x_i)\, \mathcal{L}'(y_i^*, y_i)\, dt}{K(x, x_i)} \end{equation}
- Crucially, the weighting $a_i$ depends on the new test-time input $x$, and not only on the training examples $x_i$. In fact, the path kernel is contained implicitly in the definition of $a_i$.
- $K$ in eq(1) actually depends on the entire dataset (through the evolution of the parameters from initialisation until convergence), which limits the applicability of known theoretical results about kernel machines.
- The intuition that we should consider points that are close in gradient space (instead of, say, Euclidean space) is interesting, but it is not new and is often mentioned in the interpretability literature. If we consider only the contribution of the “new” path kernel, it seems strange that there is no decay on earlier steps of gradient descent: shouldn't the part of the path closer to model convergence contribute much more to the similarity of datapoints than the earlier parts (see the weighted variant sketched after this list)?
- I disagree with the suggestion that “Deep learning also relies on such features, namely the gradients of a predefined function, and uses them for prediction via dot products in feature space”: arguably the gradients of the function are not “predefined”, because they are data-dependent and the weights update in hard-to-predict ways at each step of gradient descent.
- A final minor point, just for note-taking’s sake: one of the required properties of a kernel function is that it is symmetric, $K(x, x_i) = K(x_i, x)$. Based on the definition of the path kernel, the model would undertake a slightly different parameter trajectory, so the two evaluations are not exactly equivalent, but they will probably be close enough.
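For reference, here is my reconstruction of the proof steps sketched in the first bullet above (my reading of the paper, not the paper verbatim). Under gradient flow, the chain rule through the parameters gives
\begin{equation} \frac{dy}{dt} = \nabla_{w_t} y(x) \cdot \frac{dw_t}{dt} = -\sum_i \mathcal{L}'(y_i^*, y_i)\, \nabla_{w_t} y(x) \cdot \nabla_{w_t} y(x_i) = -\sum_i \mathcal{L}'(y_i^*, y_i)\, K^g_{f, w_t}(x, x_i) \end{equation}
Integrating along the path $c(t)$ and then multiplying and dividing each term by $K(x, x_i)$ gives
\begin{equation} y = y_0 - \sum_i \frac{\int_{c(t)} K^g_{f, w_t}(x, x_i)\, \mathcal{L}'(y_i^*, y_i)\, dt}{K(x, x_i)}\, K(x, x_i) \end{equation}
which is eq(1) (with $g$ the identity here), where $b = y_0$ is the prediction at initialisation and $a_i$ is the ratio multiplying $K(x, x_i)$, i.e. the expression for $a_i$ given above up to the sign, which the paper folds into $a_i$.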
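To make the decay suggestion concrete, one hypothetical variant (my own formalisation, not something proposed in the paper) would down-weight the earlier parts of the path, e.g.
\begin{equation} K_{\gamma}(x, x') = \int_{c(t)} e^{-\gamma (T - t)}\, \nabla_{w_t} y(x) \cdot \nabla_{w_t} y(x')\, dt \end{equation}
where $T$ is the total training time and $\gamma \ge 0$ controls how strongly earlier steps are discounted; $\gamma = 0$ recovers the original path kernel.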
Regardless, I still think this is a fresh way to think about the similarity of datapoints: aggregating the similarity of gradients across the entire trajectory of weight updates made by gradient descent. Figure 1 of the paper provides a nice illustration.
References
[2] Domingos, P. (2020). Every Model Learned by Gradient Descent Is Approximately a Kernel Machine.
[4] Ultra-Wide Deep Nets and the Neural Tangent Kernel (NTK)