Disclaimer: This is part of my notes on AI research papers. I do this to learn and communicate what I understand. Feel free to comment if you have any suggestions; they would be very much appreciated.
The following post is a comment on the paper Mixed-Privacy Forgetting in Deep Networks [1] by Aditya Golatkar, Alessandro Achille, Avinash Ravichandran, Marzia Polito, and Stefano Soatto.
Golatkar et al. introduce a novel method for forgetting in a mixed-privacy setting, where a core subset of the training samples will not be forgotten. Their method allows efficient removal of all non-core (a.k.a. user) data by simply setting a subset of the model's weights to zero, with minimal loss in performance. To do so, they introduce Mixed-Linear Forgetting (ML-Forgetting), which they claim is the first algorithm to achieve forgetting for deep networks trained on large-scale computer vision problems without compromising accuracy.
Contributions
- Introduce the problem of mixed-privacy forgetting in deep networks.
- Propose ML-Forgetting, which trains a set of non-linear core weights and a set of linear user weights.
- By design, ML-Forgetting can forget all user data by setting the user weights to zero.
- First algorithm to achieve forgetting for deep networks trained on large-scale computer vision problems without compromising the accuracy.
- Can handle multiple sequential forgetting requests, as well as class removal.
Mixed-Linear Forgetting
The main idea behind the method is quadratic forgetting, a concept that comes from forgetting in linear regression, where the training loss is quadratic. The user data is learned with such a loss, taking advantage of its convexity. First, they introduce the Mixed-Linear model, and then they discuss the forgetting mechanism.
Mixed-Linear Model
Two separate minimization problems are solved, one for the core data ($\mathcal{D}_c$) and one for the user data ($\mathcal{D}_u$). If $f_{\textbf{w}}$ is the model with parameters $\textbf{w}$, we have:
$$\textbf{w}_c^* = \arg\min_{\textbf{w}_c} L_{\mathcal{D}_c}(f_{\textbf{w}_c})$$
$$\textbf{w}_u^* = \arg\min_{\textbf{w}_u} L_{\mathcal{D}_u}(f_{\textbf{w}_c^*+\textbf{w}_u})$$
where $L_{\mathcal{D}}$ is the loss function on the dataset $\mathcal{D}$. Since the deep network $f_{\textbf{w}}$ is non-linear, the loss $L_{\mathcal{D}_u}(f_{\textbf{w}_c^*+\textbf{w}_u})$ can be highly non-convex in $\textbf{w}_u$. In light of [2], if $\textbf{w}_u$ is a small perturbation, we can hope that a linear approximation of $f_{\textbf{w}}$ around $\textbf{w}_c^*$ performs similarly to fine-tuning the entire model. Thus, the Mixed-Linear model is defined as the first-order Taylor expansion of $f_{\textbf{w}}$ around $\textbf{w}_c^*$:
$$f^{\text{ML}}_{\textbf{w}_c^*+\textbf{w}_u} (\textbf{x}) = f_{\textbf{w}_c^*}(\textbf{x}) + \nabla_{\textbf{w}} f_{\textbf{w}_c^*}(\textbf{x}) \cdot \textbf{w}_u$$
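To make the linearization concrete, here is a minimal sketch in plain Python. The two-weight "network" $f_{\textbf{w}}(x) = w_0 \tanh(w_1 x)$ and the helper names (`f`, `grad_f`, `f_ml`, `linearization_error`) are hypothetical, chosen only for illustration; on a real deep network the term $\nabla_{\textbf{w}} f_{\textbf{w}_c^*}(\textbf{x}) \cdot \textbf{w}_u$ would be computed with a Jacobian-vector product rather than an explicit gradient. The sketch checks two properties: setting $\textbf{w}_u = 0$ recovers the core model exactly, and the linear model stays close to the true network only while $\textbf{w}_u$ is small.

```python
import math

# Hypothetical two-weight toy "network": f_w(x) = w[0] * tanh(w[1] * x).
def f(w, x):
    return w[0] * math.tanh(w[1] * x)

# Gradient of f with respect to its weights, i.e. the Jacobian row for input x.
def grad_f(w, x):
    t = math.tanh(w[1] * x)
    return [t, w[0] * x * (1.0 - t * t)]

# Mixed-Linear model: f_{w_c}(x) + grad_w f_{w_c}(x) . w_u
def f_ml(w_c, w_u, x):
    return f(w_c, x) + sum(g * u for g, u in zip(grad_f(w_c, x), w_u))

def linearization_error(w_c, w_u, x):
    """Gap between the Mixed-Linear model and the true network at w_c + w_u."""
    w = [c + u for c, u in zip(w_c, w_u)]
    return abs(f_ml(w_c, w_u, x) - f(w, x))

w_c, x = [1.5, 0.8], 0.7  # pretend w_c was obtained by training on the core data
print(f_ml(w_c, [0.0, 0.0], x) == f(w_c, x))       # True: zeroing w_u gives the core model
print(linearization_error(w_c, [0.01, -0.02], x))  # small perturbation, small gap
print(linearization_error(w_c, [1.0, 2.0], x))     # large perturbation, much larger gap
```

Zeroing the user weights is how ML-Forgetting removes all user data at once, and the growing gap for a large perturbation is why the small-$\textbf{w}_u$ assumption matters.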
Furthermore, they train the core weights with the cross-entropy (CE) loss and the user weights with the mean squared error (MSE) loss, leading to the following minimization problems:
$$\textbf{w}_c^* = \arg\min_{\textbf{w}_c} L^{\text{CE}}_{\mathcal{D}_c}(f_{\textbf{w}_c})$$
$$\textbf{w}_u^* = \arg\min_{\textbf{w}_u} L^{\text{MSE}}_{\mathcal{D}_u}(f^{\text{ML}}_{\textbf{w}_c^*+\textbf{w}_u})$$
Since $f^{\text{ML}}$ is linear in $\textbf{w}_u$, the MSE loss is a convex quadratic function of $\textbf{w}_u$. Hence $\textbf{w}_u^*$ is the solution of a quadratic minimization problem, which can in principle be solved in closed form.
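As a sketch of that closed form, assuming a hypothetical two-weight toy network as a stand-in for the deep network: because $f^{\text{ML}}$ is linear in $\textbf{w}_u$, minimizing the MSE reduces to the normal equations $(A + \lambda I)\,\textbf{w}_u = b$, where $A$ is the average of $J J^\top$ over the user samples, $J = \nabla_{\textbf{w}} f_{\textbf{w}_c^*}(\textbf{x})$, and $b$ is the average of $J\,(y - f_{\textbf{w}_c^*}(\textbf{x}))$. The ridge term $\lambda \lVert \textbf{w}_u \rVert^2$ (`lam`) is my own assumption to keep the system well posed, and `fit_user_weights` and `solve` are hypothetical helpers. For a real network with millions of weights one would not form $A$ explicitly.

```python
import math

# Hypothetical two-weight toy network standing in for f_w.
def f(w, x):
    return w[0] * math.tanh(w[1] * x)

def grad_f(w, x):
    t = math.tanh(w[1] * x)
    return [t, w[0] * x * (1.0 - t * t)]

def solve(A, b):
    """Solve A u = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    M = [list(row) + [bi] for row, bi in zip(A, b)]  # augmented matrix [A | b]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(col + 1, n):
            factor = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= factor * M[col][c]
    u = [0.0] * n
    for r in reversed(range(n)):  # back substitution
        u[r] = (M[r][n] - sum(M[r][c] * u[c] for c in range(r + 1, n))) / M[r][r]
    return u

def fit_user_weights(w_c, xs, ys, lam=1e-3):
    """Closed-form minimizer of mean((f_ML(x) - y)^2) + lam * ||w_u||^2."""
    n, N = len(w_c), len(xs)
    A = [[0.0] * n for _ in range(n)]
    b = [0.0] * n
    for x, y in zip(xs, ys):
        J = grad_f(w_c, x)  # Jacobian row at the core weights
        r = y - f(w_c, x)   # residual left for the linear part to explain
        for i in range(n):
            b[i] += J[i] * r / N
            for j in range(n):
                A[i][j] += J[i] * J[j] / N
    for i in range(n):
        A[i][i] += lam      # ridge term (assumption); lam > 0 guarantees invertibility
    return solve(A, b)

# Usage: targets generated by the Mixed-Linear model with known user weights.
w_c = [1.5, 0.8]
xs = [-1.0, -0.3, 0.4, 1.2]
w_true = [0.05, -0.1]
ys = [f(w_c, x) + sum(g * u for g, u in zip(grad_f(w_c, x), w_true)) for x in xs]
w_u = fit_user_weights(w_c, xs, ys, lam=0.0)  # recovers w_true up to rounding
```

With `lam=0.0` and targets produced by the linear model itself, the solve recovers `w_true`; a positive `lam` shrinks the solution toward zero.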
Forgetting Mechanism
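As seen in [3] and [4], when the training loss is quadratic, the optimal update to forget $\mathcal{D}_f \subset \mathcal{D}$ is:
$$\textbf{w}_u \mapsto \textbf{w}_u - \Delta\textbf{w}_u + \sigma^2 \epsilon$$
where $\Delta\textbf{w}_u = H^{-1}_{\mathcal{D}_r}(\textbf{w}_c^*)\nabla_{\textbf{w}_u}L_{\mathcal{D}_r}(f^{\text{ML}}_{\textbf{w}_c^*+\textbf{w}_u})$ is the forgetting step, $\mathcal{D}_r=\mathcal{D}-\mathcal{D}_f$ is the retained data, $H_{\mathcal{D}_r}(\textbf{w}_c^*)$ is the Hessian of $L_{\mathcal{D}_r}$ with respect to $\textbf{w}_u$ (it depends only on $\textbf{w}_c^*$, since the loss is quadratic in $\textbf{w}_u$), and $\epsilon \sim N(0,I)$ is a random noise vector. $\Delta\textbf{w}_u$ is a Newton step: because the loss is quadratic, it moves $\textbf{w}_u$ exactly to the minimizer of $L_{\mathcal{D}_r}$, i.e., to the user weights that retraining on $\mathcal{D}_r$ would produce. In practice, however, $\Delta\textbf{w}_u$ can only be computed approximately, so they add noise to destroy any information about $\mathcal{D}_f$ that may still leak. Moreover, computing (let alone inverting) the Hessian is not feasible for a deep network, so they use Jacobian-Vector Products (JVPs) instead, which avoid materializing the Jacobian or the Hessian (see [2]).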
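The following sketch illustrates why the step $\Delta\textbf{w}_u = H^{-1}_{\mathcal{D}_r}\nabla L_{\mathcal{D}_r}$, a Newton step on the retained loss, forgets exactly when the loss is quadratic, and how it can be computed without ever forming the Hessian. Assumptions: each user sample is represented directly by a fixed feature vector `phi` (standing in for the Jacobian row $\nabla_{\textbf{w}} f_{\textbf{w}_c^*}(\textbf{x})$) and a residual target `r`; the loss is the MSE plus a small ridge term `lam`; and the linear system is solved with conjugate gradient using only Hessian-vector products. Conjugate gradient is a standard matrix-free solver chosen here for illustration, not necessarily what the paper uses, and `noise_scale` plays the role of the $\sigma^2$ factor. All helper names are hypothetical.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def loss_grad(u, data, lam):
    """Gradient of L(u) = mean((phi . u - r)^2) + lam * ||u||^2."""
    g = [2.0 * lam * ui for ui in u]
    for phi, r in data:
        e = dot(phi, u) - r
        for i in range(len(u)):
            g[i] += 2.0 * e * phi[i] / len(data)
    return g

def hvp(v, data, lam):
    """Hessian-vector product H v, computed without ever forming H."""
    out = [2.0 * lam * vi for vi in v]
    for phi, _ in data:
        s = dot(phi, v)  # directional derivative of the linear model along v
        for i in range(len(v)):
            out[i] += 2.0 * s * phi[i] / len(data)
    return out

def cg_solve(matvec, b, tol=1e-10, max_iter=100):
    """Conjugate gradient for H x = b, with H symmetric positive definite."""
    x = [0.0] * len(b)
    r = list(b)  # residual b - H x for x = 0
    p = list(r)
    rs = dot(r, r)
    b_norm = math.sqrt(rs)
    for _ in range(max_iter):
        if math.sqrt(rs) <= tol * b_norm:
            break
        Hp = matvec(p)
        alpha = rs / dot(p, Hp)
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * hi for ri, hi in zip(r, Hp)]
        rs_new = dot(r, r)
        p = [ri + (rs_new / rs) * pi for ri, pi in zip(r, p)]
        rs = rs_new
    return x

def newton_step(u, data, lam):
    """u - H^{-1} grad L(u): the minimizer of L (up to CG tolerance), since L is quadratic."""
    delta = cg_solve(lambda v: hvp(v, data, lam), loss_grad(u, data, lam))
    return [ui - di for ui, di in zip(u, delta)]

def forget(u, retained, lam, noise_scale=0.0, rng=random):
    """Forgetting update u -> u - Delta u + noise_scale * eps, with eps ~ N(0, I)."""
    return [ui + noise_scale * rng.gauss(0.0, 1.0)
            for ui in newton_step(u, retained, lam)]

# Usage on hypothetical synthetic data: 40 user samples, forget the first 10.
rng = random.Random(0)
true_u = [0.5, -1.0, 2.0]
data = []
for _ in range(40):
    phi = [rng.gauss(0.0, 1.0) for _ in range(3)]
    data.append((phi, dot(phi, true_u) + 0.1 * rng.gauss(0.0, 1.0)))
lam = 1e-3
retained = data[10:]  # D_r; data[:10] plays the role of D_f

u_full = newton_step([0.0] * 3, data, lam)           # train on all user data
u_scrubbed = forget(u_full, retained, lam)           # forget D_f (no noise)
u_retrained = newton_step([0.0] * 3, retained, lam)  # retrain on D_r only
```

Because the retained loss is quadratic, the Newton step lands on the same weights as retraining on $\mathcal{D}_r$, regardless of the starting point; with `noise_scale > 0` the result is additionally perturbed by Gaussian noise. In a deep network, the product `dot(phi, v)` would be replaced by a JVP through the network, followed by a vector-Jacobian product to map the result back to weight space.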
Personal Thoughts
- Although the method is interesting, I am not sure how practical it is. The theoretical framework relies heavily on the assumption that the perturbation $\textbf{w}_u$ is small, which may not hold in practice. I find the idea of using core data to train a “foundational” (or core) model and then fine-tuning it with user data useful (this is actually the trend in SOTA models, e.g., in generative AI). However, if fitting the user data requires a perturbation $\textbf{w}_u$ that is far from “small enough”, the linear approximation breaks down and the method may not work as expected.
References
[1] Golatkar, A., Achille, A., Ravichandran, A., Polito, M., & Soatto, S. (2021). Mixed-Privacy Forgetting in Deep Networks. arXiv:2012.13431.
[2] Mu, F., Liang, Y., & Li, Y. (2020). Gradients as Features for Deep Representation Learning. arXiv:2004.05529.
[3] Guo, C., Goldstein, T., Hannun, A., & Van Der Maaten, L. (2020). Certified Data Removal from Machine Learning Models. arXiv:1911.03030.
[4] Golatkar, A., Achille, A., & Soatto, S. (2020). Eternal Sunshine of the Spotless Net: Selective Forgetting in Deep Networks. arXiv:1911.04933.