Do AI Weather Models Learn Physics?
How do AI weather models learn to predict weather? Are they actually learning physics?
Questions like these belong to the field of model interpretability. It's a famously hard problem, much harder than making a good model in the first place.
Let's take a look at WindBorne's WeatherMesh family of models and try to peer into what's going on inside, with one question in mind: is the model really learning physics, or is it just pattern-matching?
The Architecture
The model has three main parts:
┌─────────┐     ┌───────────┐     ┌─────────┐
│ Encoder │ ──▶ │ Processor │ ──▶ │ Decoder │
└─────────┘     └───────────┘     └─────────┘
The Encoder maps physical weather variables into what's called a latent space, a representation that the model learns. The latent space is like the "mind" of the model: it's where it "thinks". The Processor evolves the latent space forward in time, and it operates only in the latent space. Finally, the Decoder maps the latent space back into something useful: physical weather variables.
WeatherMesh can forecast for any lead time. Encoding and decoding are both computationally expensive and a source of error, so rather than decoding and re-encoding at every step, the processor is chained multiple times in a row inside the latent space. Use it once to move forward 3 hours, or chain it 8 times to move forward 24 hours.
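To make the chaining concrete, here is a minimal sketch in Python. It is not WeatherMesh code: the random weight matrices, the sizes, and the function names (encode, process, decode, forecast) are all hypothetical stand-ins. The only things taken from the description above are the shape of the pipeline (encode once, apply the processor repeatedly, decode once) and the 3-hour step.

```python
import numpy as np

rng = np.random.default_rng(0)

N_PHYSICAL = 5   # hypothetical number of physical variables per grid point
N_LATENT = 16    # hypothetical width of the latent space
STEP_HOURS = 3   # one processor pass moves the latent state forward 3 hours

# Random stand-in weights. A real model learns these during training.
W_enc = rng.normal(size=(N_PHYSICAL, N_LATENT))
W_proc = rng.normal(size=(N_LATENT, N_LATENT)) / np.sqrt(N_LATENT)
W_dec = rng.normal(size=(N_LATENT, N_PHYSICAL))


def encode(x):
    """Physical variables -> latent space (toy linear map)."""
    return x @ W_enc


def process(z):
    """Latent space -> latent space, one 3-hour step (toy nonlinear map)."""
    return np.tanh(z @ W_proc)


def decode(z):
    """Latent space -> physical variables (toy linear map)."""
    return z @ W_dec


def forecast(x, lead_hours):
    """Encode once, chain the processor, decode once."""
    if lead_hours <= 0 or lead_hours % STEP_HOURS != 0:
        raise ValueError("this sketch only handles positive multiples of 3 hours")
    n_steps = lead_hours // STEP_HOURS
    z = encode(x)
    for _ in range(n_steps):
        z = process(z)  # stays in latent space between steps
    return decode(z), n_steps


# Ten grid points of made-up input data.
x0 = rng.normal(size=(10, N_PHYSICAL))
y24, steps = forecast(x0, 24)
print(steps, y24.shape)
```

The point of the structure is that the expensive and lossy encode/decode pair runs once per forecast, no matter how many processor steps sit in between.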
Inside the Processor
The processor has multiple parts. "Thinking" about how weather is going to evolve is hard, and so there are 6 transformer layers that make up the processor. These transformers are similar to the transformer layers that make up ChatGPT, but they operate in 3D rather than 1D.
┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐
input ──▶│ Layer 1 │──▶│ Layer 2 │──▶│ Layer 3 │──▶│ Layer 4 │──▶│ Layer 5 │──▶│ Layer 6 │──▶ output
└─────────┘ └─────────┘ └─────────┘ └─────────┘ └─────────┘ └─────────┘
After each layer, the model produces a new intermediate latent state. So we can ask: does the model change the latent space smoothly from one layer to the next? If 6 layers add up to +3 hours, does 1 layer correspond to +30 minutes?
The model isn't explicitly trained for this to be the case. In training, it only sees forecasts at 3-hour intervals. If it learns to evolve in smaller, smoother increments anyway, that is a very good sign that the model is truly learning and emulating physics.
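The arithmetic behind that question can be written down as a small helper. This is only a sketch of the hypothesis, not something the model is known to do: it assumes the 3-hour step is split evenly across the 6 layers, and the function name and its arguments are made up for illustration.

```python
STEP_MINUTES = 180  # one full processor pass = 3 hours
N_LAYERS = 6        # transformer layers per processor pass


def implied_lead_minutes(completed_passes, layer):
    """Lead time a layer's output would represent IF each layer
    contributed an equal share of the 3-hour step (the hypothesis).

    completed_passes: number of full processor passes already applied (0, 1, ...)
    layer: which layer of the current pass we are looking at (1..6)
    """
    if not 1 <= layer <= N_LAYERS:
        raise ValueError("layer must be between 1 and 6")
    return completed_passes * STEP_MINUTES + layer * STEP_MINUTES // N_LAYERS


# Layer 1 of the first pass: +30 minutes under the equal-split hypothesis.
print(implied_lead_minutes(0, 1))   # 30
# Layer 6 of the first pass is the real 3-hour output.
print(implied_lead_minutes(0, 6))   # 180
# Layer 6 of the eighth pass is the 24-hour forecast.
print(implied_lead_minutes(7, 6))   # 1440
```

Only the multiples of 180 minutes are times the model was actually trained on. Everything in between is what we want to inspect.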
Visualizing the Latent Space
There is one more thing we need before we can look at this. How do we "see" the latent space? It's a high-dimensional vector space that we can't read directly! Fortunately, there is a linear algebra tool we can use called Principal Component Analysis (PCA). It finds the directions in the latent space along which the values vary the most, which tend to be the most interesting perspectives. We can then project the latent space onto one of those directions, plot the resulting values over a map of the globe, and see what it looks like.
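Here is a rough sketch of that idea using NumPy's SVD. The layout of the latent array (latitude by longitude by channels), the sizes, and the function name pca_maps are assumptions made for illustration. The real latent space is 3D and this is not the code behind the visualization.

```python
import numpy as np


def pca_maps(latent, n_components=3):
    """Project a latent field onto its top principal components.

    latent: array of shape (n_lat, n_lon, n_channels), a hypothetical layout.
    Returns (maps, components, variance_ratio), where maps has shape
    (n_components, n_lat, n_lon) and can be plotted over the globe.
    """
    n_lat, n_lon, n_channels = latent.shape

    # Treat every grid point as one sample with n_channels features.
    X = latent.reshape(-1, n_channels)
    X = X - X.mean(axis=0)  # PCA needs centered data

    # Rows of Vt are the principal directions, ordered by variance.
    _, S, Vt = np.linalg.svd(X, full_matrices=False)
    components = Vt[:n_components]

    # Project each grid point onto each direction, then fold back into maps.
    scores = X @ components.T
    maps = scores.T.reshape(n_components, n_lat, n_lon)

    variance_ratio = (S ** 2) / np.sum(S ** 2)
    return maps, components, variance_ratio[:n_components]


# Made-up latent field: an 8 x 16 grid with 32 latent channels.
rng = np.random.default_rng(0)
fake_latent = rng.normal(size=(8, 16, 32))
maps, components, ratio = pca_maps(fake_latent)
print(maps.shape)
```

One practical note: the sign and ordering of principal components can change from one fit to the next, so for an animation it makes sense to fit the components once and reuse the same directions for every frame.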
If we do that for a bunch of different components, we can then animate them and see how they change. The video below does exactly that over a 24-hour forecast. The decimal after the time tells which transformer layer we are looking at.
We see that the principal components show gradual motion of atmospheric properties within the processor! This suggests the model is implicitly learning the continuous nature of physics. There are lots of other cool things here as well, but too many to cover in this post.