Using HTM for branch prediction

Hey! I know a fair bit about HTM, and know a fair bit about microarchitecture! I’ll throw in my two cents here! Feel free to ask me any questions as well!

First off, it seems like most people here don’t have a great idea of how modern branch predictors work, and what actually makes them work as well as they do.

These predictors don't actually consist of just one predictor. They're normally a large table of thousands of small predictors. Every branch gets mapped to a table entry based on a hash of its memory address, and that entry then learns patterns for that specific branch (and any others that happen to map to the same entry).

A table of 2-bit saturating counters (literally working off "which way has this branch gone the last few times it appeared?") can get up to about 93% accuracy on most code.
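To make that concrete, here's a toy software sketch of that kind of table. The class name, table size, and addresses are all made up for illustration; nothing here is from a real chip:

```python
# Toy sketch of a "bimodal" predictor: a table of 2-bit saturating counters
# indexed by a hash of the branch's address (here just a modulo).
# Counter values 0-1 predict not-taken, 2-3 predict taken.

class BimodalPredictor:
    def __init__(self, table_size=4096):
        self.table = [1] * table_size  # start every counter weakly not-taken
        self.size = table_size

    def predict(self, branch_addr):
        return self.table[branch_addr % self.size] >= 2

    def update(self, branch_addr, taken):
        i = branch_addr % self.size
        if taken:
            self.table[i] = min(3, self.table[i] + 1)  # saturate at 3
        else:
            self.table[i] = max(0, self.table[i] - 1)  # saturate at 0
```

Run a typical loop branch through it (taken 9 times, then not taken once, repeated) and after one warm-up pass it only mispredicts the loop exit, which is exactly the behavior that gets you into the ~90% range.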

There are more sophisticated schemes as well. For example, most Intel CPUs from the past decade or so have used some form of two-level predictor. Instead of the address mapping to a single predictor in a 1D table, it maps to a row of predictors in a 2D table, and the outcomes of the previous N branches (the history) select exactly which predictor from that row (which column) to use.
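Here's the same toy model extended with that 2D indexing (again, the sizes and names are mine, picked for illustration):

```python
# Toy sketch of a two-level predictor: the branch address selects a row of
# 2-bit counters, and the global history of the last HISTORY_BITS outcomes
# selects the column within that row.

HISTORY_BITS = 4

class TwoLevelPredictor:
    def __init__(self, rows=1024):
        # 2**HISTORY_BITS counters per row: this is the exponential cost
        self.table = [[1] * (2 ** HISTORY_BITS) for _ in range(rows)]
        self.rows = rows
        self.history = 0  # last HISTORY_BITS outcomes packed into an int

    def predict(self, branch_addr):
        return self.table[branch_addr % self.rows][self.history] >= 2

    def update(self, branch_addr, taken):
        row = self.table[branch_addr % self.rows]
        if taken:
            row[self.history] = min(3, row[self.history] + 1)
        else:
            row[self.history] = max(0, row[self.history] - 1)
        # shift the actual outcome into the global history register
        self.history = ((self.history << 1) | int(taken)) % (2 ** HISTORY_BITS)
```

The payoff: a strictly alternating branch (taken, not-taken, taken, ...) defeats a single 2-bit counter, but here the two history states it cycles through land on different counters, so after a short warm-up it predicts the pattern perfectly.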

The reason AMD is using neural networks now is that the whole problem here is getting as much predictive power out of as little circuitry as possible. The problem with two-level predictors is that the size of each row in the table scales exponentially with the length of the history (for N history bits, you need 2^N predictors per row). Throw a perceptron at the problem instead and you can get similar predictive power with only linear scaling (one weight per history bit). Each individual predictor is more expensive to implement, but the overall storage grows linearly rather than exponentially.

It’s also important to remember that these aren’t very sophisticated neural networks. Basically they’re just “take the last N history bits, multiply each by a corresponding weight, sum them up, and check whether the sum is positive or negative. Positive = branch taken, negative = branch not taken. Adjust the weights based on the actual outcome.” No deep neural networks, no fancy backprop, just a single-layer perceptron: a weighted sum of the last N branch outcomes.
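Here's roughly what that looks like as a toy software model, loosely in the style of the published perceptron-predictor work (the history length, table size, and training threshold are all made-up illustrative values):

```python
# Toy sketch of a perceptron predictor: one small weight vector per table
# entry, dotted with the global history (outcomes encoded as +1 / -1).

HISTORY_LEN = 8
THRESHOLD = 2 * HISTORY_LEN  # stop training once the sum is confident enough

class PerceptronPredictor:
    def __init__(self, entries=256):
        # one bias weight plus one weight per history bit: linear scaling
        self.weights = [[0] * (HISTORY_LEN + 1) for _ in range(entries)]
        self.entries = entries
        self.history = [1] * HISTORY_LEN  # +1 = taken, -1 = not taken

    def _output(self, branch_addr):
        w = self.weights[branch_addr % self.entries]
        return w[0] + sum(wi * hi for wi, hi in zip(w[1:], self.history))

    def predict(self, branch_addr):
        return self._output(branch_addr) >= 0

    def update(self, branch_addr, taken):
        w = self.weights[branch_addr % self.entries]
        y = self._output(branch_addr)
        t = 1 if taken else -1
        # train only on a mispredict, or while the sum is still small
        if (y >= 0) != taken or abs(y) <= THRESHOLD:
            w[0] += t
            for i in range(HISTORY_LEN):
                w[i + 1] += t * self.history[i]
        self.history = self.history[1:] + [t]  # shift in the real outcome
```

Feed it a repeating pattern and the weights quickly converge: the weight on whichever history bit correlates with the outcome grows until the sum clears the threshold, and training stops.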

It’s also important to bring up some other aspects here. There are many cases where branches are unpredictable. Say we have the following code:

array arr;
for(int i = 0; i < length(arr); i++){
    if(arr[i] < n){
        //do something
        //do something else
    }
}
If the values of arr are randomly distributed around n, then the branch predictor can’t learn anything here, because the pattern to be learned isn’t in the code, it’s in the data. Any benefit you gain is basically the branch predictor equivalent of overfitting, and even then it’ll only actually work if arr is short and frequently iterated over.
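You can see this directly in a quick simulation. Here's a toy comparison (all numbers picked arbitrarily) of a single 2-bit counter on a heavily biased branch versus one driven by 50/50 random data:

```python
import random

# Toy illustration: one 2-bit saturating counter facing (a) a branch taken
# 90% of the time and (b) a branch decided by a fair coin flip in the data.

def run_counter(outcomes):
    counter, correct = 1, 0
    for taken in outcomes:
        if (counter >= 2) == taken:
            correct += 1
        counter = min(3, counter + 1) if taken else max(0, counter - 1)
    return correct / len(outcomes)

random.seed(0)
biased = [random.random() < 0.9 for _ in range(10000)]
coinflip = [random.random() < 0.5 for _ in range(10000)]

print(run_counter(biased))    # lands close to 0.9
print(run_counter(coinflip))  # lands close to 0.5, i.e. nothing to learn
```

No amount of extra cleverness in the predictor fixes the second case; the information it would need simply isn't in the branch history.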

So it’s not really a matter of learning that “the code goes 0011001100110011 through branches”. It’s about learning complex patterns of how each branch in the code influences every other branch, and doing all of that with as few transistors as possible.

If you wanted to use HTM for branch prediction, here’s how it would have to work:

  • The HTM network would have to be quite compact; no wasted transistors! You can’t throw a million-neuron network with 10k synapses per neuron at this problem. You might only be able to manage a few hundred neurons with a couple hundred synapses each before this thing gets more expensive than even the largest traditional predictors. Anyone here willing to learn some Verilog or VHDL could try making one.

  • Encoding the history bits isn’t enough. You also need to give the predictor hints about exactly where in the code the current branch is. Other hints, such as a few bits from registers and what kinds of instructions are nearby, might be useful too, though as far as I’m aware traditional branch predictors don’t take those into account.

  • You’ll need some way to rewind the network to a previous state on a mispredict. Modern CPUs are deeply pipelined, so the CPU might not know a branch was mispredicted until it’s already predicted the 5-10 branches that come after it. Not to mention, the whole reason for building more accurate predictors is to make longer pipelines feasible.

  • You’ll also have to get a clear taken/not-taken decision out of the network, again on a tight transistor budget. This might be a case where reinforcement learning has to come in.
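To illustrate the rewind point from the list above, here's a toy model of checkpointing a global history register. (A real design would also squash every branch younger than the mispredicted one; this sketch only shows restoring the register itself, and all the names are made up.)

```python
# Toy sketch of speculative history with rollback: each prediction shifts a
# *predicted* outcome into the history register, and a mispredict restores
# the register to the state it had when that branch was fetched.

HISTORY_BITS = 16

class SpeculativeHistory:
    def __init__(self):
        self.history = 0
        self.checkpoints = {}  # branch tag -> history before that prediction

    def on_predict(self, tag, predicted_taken):
        # snapshot the register, then speculatively shift in the prediction
        self.checkpoints[tag] = self.history
        self.history = ((self.history << 1) | int(predicted_taken)) % (2 ** HISTORY_BITS)

    def on_resolve(self, tag, actual_taken, predicted_taken):
        if actual_taken != predicted_taken:
            # mispredict: roll back and shift in the real outcome instead
            self.history = ((self.checkpoints[tag] << 1) | int(actual_taken)) % (2 ** HISTORY_BITS)
        del self.checkpoints[tag]
```

For an HTM network the equivalent would be snapshotting (or reconstructing) the network's internal state per in-flight branch, which is a much bigger ask than saving one shift register.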

I’m not confident that HTM would actually provide a large competitive advantage here, but hey, if anyone wants to learn some Verilog/VHDL and embed it in a RISC-V sim or something, I’d love to see it. If it is useful, it’ll probably be from making it easier to incorporate other information from around the CPU into making better decisions.

Then of course, all of this completely ignores Branch Target Buffers, where HTM would likely provide much less of an advantage, and memory prefetching, where HTM (done right) would likely actually do a better job and make a far bigger impact than a better branch predictor. After all, a branch mispredict costs 15-20 cycles. A cache miss costs about 300.