Bibliographic Details
Main Authors: Tseng, Albert; Yu, Tao; Park, Youngsuk
Format: Preprint
Published: 2025
Subjects: Machine Learning
Online Access: https://arxiv.org/abs/2502.20586
author Tseng, Albert
Yu, Tao
Park, Youngsuk
author_facet Tseng, Albert
Yu, Tao
Park, Youngsuk
contents Low precision (LP) datatypes such as MXFP4 can accelerate matrix multiplications (GEMMs) and reduce training costs. However, directly using MXFP4 instead of BF16 during training significantly degrades model quality. In this work, we present the first near-lossless training recipe that uses MXFP4 GEMMs, which are $2\times$ faster than FP8 on supported hardware. Our key insight is to compute unbiased gradient estimates with stochastic rounding (SR), resulting in more accurate model updates. However, directly applying SR to MXFP4 can result in high variance from block-level outliers, harming convergence. To overcome this, we use the random Hadamard transform to theoretically bound the variance of SR. We train GPT models up to 6.7B parameters and find that our method induces minimal degradation over mixed-precision BF16 training. Our recipe computes $>1/2$ the training FLOPs in MXFP4, enabling an estimated speedup of $>1.3\times$ over FP8 and $>1.7\times$ over BF16 during backpropagation.
format Preprint
id arxiv_https___arxiv_org_abs_2502_20586
institution arXiv
publishDate 2025
record_format arxiv
title Training LLMs with MXFP4
topic Machine Learning
url https://arxiv.org/abs/2502.20586