Bibliographic Details
Main Authors: Singh, Sidak Pal; Mobahi, Hossein; Agarwala, Atish; Dauphin, Yann
Format: Preprint
Published: 2025
Subjects: Machine Learning; Computation and Language
Online Access: https://arxiv.org/abs/2502.02407
author Singh, Sidak Pal
Mobahi, Hossein
Agarwala, Atish
Dauphin, Yann
contents Curvature regularization techniques like Sharpness Aware Minimization (SAM) have shown great promise in improving generalization on vision tasks. However, we find that SAM performs poorly in domains like natural language processing (NLP), often degrading performance even with twice the compute budget. We investigate the discrepancy across domains and find that in the NLP setting, SAM is dominated by regularization of the logit statistics rather than by improvement of the geometry of the function itself. We use this observation to develop an alternative algorithm we call Functional-SAM, which regularizes curvature only through modification of the statistics of the overall function implemented by the neural network, and avoids spurious minimization through logit manipulation. Furthermore, we argue that preconditioning the SAM perturbation also prevents spurious minimization, and when combined with Functional-SAM, it gives further improvements. Our proposed algorithms show improved performance over AdamW and SAM baselines when trained for an equal number of steps, in both fixed-length and Chinchilla-style training settings, at various model scales (including billion-parameter scale). On the whole, our work highlights the importance of more precise characterizations of sharpness in broadening the applicability of curvature regularization to large language models (LLMs).
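For readers unfamiliar with the baseline the abstract builds on, here is a minimal NumPy sketch of the standard SAM update as originally formulated: first take an ascent step of radius rho along the normalized gradient, then descend using the gradient evaluated at that perturbed point. The optional `precond` argument is a hypothetical diagonal preconditioner added only to illustrate what "preconditioning the SAM perturbation" can mean in general; it is an assumption, not the paper's preconditioned variant, and nothing here implements Functional-SAM. The names `sam_step`, `grad_fn`, and `precond` are illustrative.

```python
import numpy as np


def sam_step(w, grad_fn, lr=0.1, rho=0.05, precond=None, eps=1e-12):
    """One standard SAM update on a flat parameter vector `w`.

    grad_fn: callable returning the loss gradient at a parameter vector.
    precond: optional positive diagonal (same shape as `w`). If given, the
        ascent direction is g / precond. This is a hypothetical illustration
        of a preconditioned perturbation, not the paper's exact scheme.
    """
    g = grad_fn(w)
    direction = g if precond is None else g / precond
    # Perturbation of norm (approximately) rho toward higher loss.
    e = rho * direction / (np.linalg.norm(direction) + eps)
    # Gradient evaluated at the perturbed weights drives the actual update.
    g_adv = grad_fn(w + e)
    return w - lr * g_adv


if __name__ == "__main__":
    # Toy ill-conditioned quadratic loss 0.5 * w^T A w, whose gradient is A w.
    A = np.diag([1.0, 10.0])
    w = np.array([1.0, 1.0])
    for _ in range(5):
        w = sam_step(w, lambda v: A @ v)
    print(w)
```

With `rho=0.0` the perturbation vanishes and the step reduces to plain gradient descent, which is a quick sanity check. Each step costs two gradient evaluations, which corresponds to the "twice the compute budget" comparison in the abstract.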
format Preprint
id arxiv_https___arxiv_org_abs_2502_02407
institution arXiv
publishDate 2025
record_format arxiv
title Avoiding spurious sharpness minimization broadens applicability of SAM
topic Machine Learning
Computation and Language
url https://arxiv.org/abs/2502.02407