Fourier Head:
Helping Large Language Models Learn
Complex Probability Distributions

Brown University, Google DeepMind
*Equal contribution



The Fourier head is a neural network layer that learns a continuous probability density function using a Fourier series and returns a discrete approximation of it.

Large language models are often adapted to model non-linguistic tokens. If these tokens have an underlying continuous structure, then replacing the linear classification head with the Fourier head can boost downstream performance.


Toy Example: Fourier Head Learns a Square Wave

We demonstrate how the Fourier head can learn a square wave. As we increase the number of frequencies, the Fourier head approximates the square wave more closely, but the learned density becomes less smooth. This trade-off illustrates the Fourier Head Scaling Law. In this example, we consider a Fourier PDF with \(N = 1, \dots, 64\) frequencies and \(128\) output dimensions.
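The same accuracy-versus-smoothness trade-off already shows up in classical Fourier approximation, which is what drives this scaling behavior. The snippet below is our own illustration (not one of the paper's experiments): it computes partial Fourier sums of the square wave \( \operatorname{sign}(\sin(\pi z)) \) on \([-1, 1]\) and prints how the mean squared error shrinks as the number of frequencies grows, even as Gibbs oscillations make the approximation less smooth near the jumps.

import math
import torch

# Partial Fourier sums of the square wave sign(sin(pi * z)) on [-1, 1].
z = torch.linspace(-1.0, 1.0, 1025)
square = torch.sign(torch.sin(math.pi * z))

for N in (4, 16, 64):
    # Only odd frequencies contribute to the square wave's Fourier series:
    # sum over odd k of (4 / (pi * k)) * sin(pi * k * z).
    k = torch.arange(1, N + 1, 2, dtype=torch.float32)
    partial = ((4 / (math.pi * k)) * torch.sin(math.pi * z[:, None] * k)).sum(dim=-1)
    mse = ((partial - square) ** 2).mean().item()
    print(f"N = {N:2d}  mean squared error = {mse:.4f}")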



Toy Example: Fourier Head Learns a Gaussian Mixture Model

We demonstrate how the Fourier head can learn a Gaussian mixture model (GMM). As we increase the number of frequencies, the Fourier head approximates the GMM more closely, but the learned density becomes less smooth. This trade-off illustrates the Fourier Head Scaling Law. In this example, we consider a Fourier PDF with \(N = 1, \dots, 64\) frequencies and \(128\) output dimensions.



Toy Example: Fourier Head Learns a Complicated Gaussian Mixture Model

We demonstrate how the Fourier head can learn a complicated GMM. As we increase the number of frequencies, the Fourier head approximates the GMM more closely, but the learned density becomes less smooth. This trade-off illustrates the Fourier Head Scaling Law. In this example, we consider a Fourier PDF with \(N = 1, \dots, 64\) frequencies and \(128\) output dimensions.



Toy Example: Fourier head outperforms the linear classification head


As a low-dimensional demonstration, we task an MLP with learning to approximate a continuous bimodal density using a categorical distribution and a cross entropy objective. We observe that a standard linear classification head fails to distinguish between the two modes and overfits to high-frequency noise in the training set. In contrast, our proposed Fourier head learns a smoother, more accurate categorical distribution. Our paper provides theoretical justification for the Fourier head, as well as empirical justification on a large-scale imitation learning task and a time series foundation model pretraining task.
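For readers who want the flavor of this experiment in code, here is a sketch of the baseline setup; the data, architecture, and hyperparameters are placeholder assumptions on our part rather than the paper's exact configuration. Samples from a bimodal Gaussian mixture are quantized into bins over \([-1, 1]\), and an MLP with a linear classification head is trained with cross entropy; the Fourier head would replace the final nn.Linear layer.

import torch
import torch.nn as nn

torch.manual_seed(0)
num_bins = 64

# Sample a continuous bimodal density (two-component Gaussian mixture) and
# quantize each sample into one of `num_bins` equal bins over [-1, 1].
component = torch.randint(0, 2, (4096,))
samples = torch.tensor([-0.5, 0.5])[component] + 0.1 * torch.randn(4096)
samples = samples.clamp(-1.0, 1.0 - 1e-6)
targets = ((samples + 1.0) / 2.0 * num_bins).long()

# An MLP whose final layer is a linear classification head over the bins,
# trained with a cross entropy objective; the Fourier head is a drop-in
# replacement for the final nn.Linear.
inputs = torch.randn(4096, 16)  # placeholder conditioning inputs
model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, num_bins))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for step in range(2000):
    loss = nn.functional.cross_entropy(model(inputs), targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Averaging the learned categorical distributions gives an estimate of the
# bimodal density; a linear head tends to leave this histogram jagged, while
# the Fourier head smooths it.
with torch.no_grad():
    learned_pmf = model(inputs).softmax(dim=-1).mean(dim=0)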

Abstract

As the quality of large language models has improved, there has been increased interest in using them to model non-linguistic tokens. For example, the Decision Transformer recasts agentic decision making as a sequence modeling problem, using a decoder-only LLM to model the distribution over the discrete action space for an Atari agent. However, when adapting LLMs to non-linguistic domains, it remains unclear if softmax over discrete bins captures the continuous structure of the tokens and the potentially complex distributions needed for high-quality token generation. We introduce a neural network layer, constructed using Fourier series, which we can easily substitute for any linear layer if we want the outputs to have a more continuous structure. We perform extensive analysis on synthetic datasets, as well as on large-scale decision making and time series forecasting tasks. We also provide theoretical evidence that this layer can better learn signal from data while ignoring high-frequency noise. All of our results support the effectiveness of our proposed Fourier head in scenarios where the underlying data distribution has a natural continuous structure. For example, the Fourier head improves a Decision Transformer agent's returns by 46% on the Atari Seaquest game, and increases a state-of-the-art time series foundation model's forecasting performance by 3.5% across 20 benchmarks unseen during training.


Large-scale example #1: Fourier head increases returns of an Atari agent by 46%.



We replace the linear classification head in the Decision Transformer agent with a Fourier head. We find that the Decision Transformer agent with the Fourier head achieves larger returns than the baseline agent with the linear head, so long as the Fourier head has sufficiently many frequencies. We also find that the Fourier agent's returns have lower variance. For normalized returns, higher is better; for smoothness, lower is better.


We present next-action distribution examples for two different Decision Transformer agents: one trained to predict the next action using a linear classification head, as in the original implementation, and the other using our proposed Fourier head. The Fourier agent produces a "clump" of actions that is semantically meaningful: it almost certainly wants to shoot in the down-right or right direction, presumably because there is a submarine in that direction. In contrast, the linear agent's next-action distribution doesn't clearly depict a strategy, and assigns higher likelihoods to incorrect actions. Because the Fourier head outputs a smoother PMF, it concentrates more probability mass near the correct action, resulting in better returns.


Large-scale example #2: Fourier head improves time series forecasting accuracy.



We replace the linear classification head in the Chronos time series foundation model with a Fourier head. We find that the model with the Fourier head achieves better forecasting accuracy than the baseline model with the linear head.


We present next-value distribution examples for two different time series foundation models: one trained to predict the next quantized value using a linear classification head, and the other using our proposed Fourier head. The Fourier head model produces a smoother next-token distribution than the linear head model.


Overview of Procedure: Forward Pass Through the Fourier Head



At a high level, the Fourier head takes an input \( x \in \mathbb{R}^n \), uses a linear layer to learn the coefficients of a Fourier series with \( N \) frequencies over \([-1, 1]\), and quantizes the interval \([-1, 1]\) into \( m \) equal bins. The Fourier head then evaluates the learned Fourier PDF at the \( m \) bin center points and returns those \( m \) likelihoods as a categorical distribution. We include a PyTorch implementation in our GitHub repository, along with many examples of its usage.
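The sketch below follows that recipe; it is a minimal illustration, not the reference implementation from our repository. The class name FourierHeadSketch and its constructor arguments are our own, and it guarantees a nonnegative density by squaring the magnitude of a learned trigonometric polynomial before normalizing over the bins (the reference implementation may differ in details such as normalization and regularization).

import math
import torch
import torch.nn as nn


class FourierHeadSketch(nn.Module):
    """Maps x in R^n to a categorical distribution over m equal bins of
    [-1, 1], parameterized by a Fourier density with N frequencies."""

    def __init__(self, input_dim: int, num_bins: int, num_frequencies: int):
        super().__init__()
        self.num_frequencies = num_frequencies
        # A linear layer learns the real and imaginary parts of the N + 1
        # complex Fourier coefficients c_0, ..., c_N from the input.
        self.coeff_proj = nn.Linear(input_dim, 2 * (num_frequencies + 1))
        # Centers of the m equal-width bins covering [-1, 1].
        edges = torch.linspace(-1.0, 1.0, num_bins + 1)
        self.register_buffer("bin_centers", (edges[:-1] + edges[1:]) / 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, input_dim) -> complex coefficients, shape (batch, N + 1)
        re, im = self.coeff_proj(x).chunk(2, dim=-1)
        coeffs = torch.complex(re, im)
        # Evaluate g(z) = sum_k c_k exp(i * pi * k * z) at the m bin centers.
        k = torch.arange(self.num_frequencies + 1, device=x.device, dtype=x.dtype)
        angles = math.pi * self.bin_centers[:, None] * k[None, :]  # (m, N + 1)
        phases = torch.polar(torch.ones_like(angles), angles)      # exp(i * angles)
        g = coeffs @ phases.T                                      # (batch, m)
        # |g(z)|^2 is a nonnegative trigonometric polynomial, so it serves as an
        # unnormalized density; normalizing over the bins yields a categorical PMF.
        density = g.abs() ** 2
        return density / density.sum(dim=-1, keepdim=True).clamp_min(1e-12)


head = FourierHeadSketch(input_dim=256, num_bins=128, num_frequencies=32)
probs = head(torch.randn(4, 256))  # shape (4, 128); each row sums to 1
# probs.log() can be fed to nn.NLLLoss wherever a softmax + cross entropy
# head would otherwise be used.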

BibTeX

@misc{gillman2024fourierheadhelpinglarge,
  title={Fourier Head: Helping Large Language Models Learn Complex Probability Distributions}, 
  author={Nate Gillman and Daksh Aggarwal and Michael Freeman and Saurabh Singh and Chen Sun},
  year={2024},
  eprint={2410.22269},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2410.22269}, 
}