The grokking challenge?

I found this paper quite interesting. It describes a phenomenon where, if you train a NN for long enough (i.e. tens of thousands of epochs), it suddenly and sharply generalizes perfectly to the test dataset. This was only tested on small datasets, but the authors managed to introspect the network and found that in some cases it had learned an actual algorithm instead of just memorizing the data, which is why it managed to generalize.

Anyway, I found this kind of grokking to be a really interesting challenge: can we make a network that learns these kinds of problems faster?

In particular, I found the (a + b) mod 97 problem interesting because it's so simple and the training data can be generated with a few lines of code, yet it is still extremely hard to get a network to generalize it.

import random

def generate_mod97_dataset(split=0.5):
    # enumerate every (i, j) pair together with its answer (i + j) mod 97
    data = []
    for i in range(97):
        for j in range(97):
            data.append((i, j, (i + j) % 97))

    # deterministic shuffle, then split into train and test portions
    random.seed(0)
    random.shuffle(data)

    i = int(len(data) * split)
    train = data[:i]
    test = data[i:]

    return train, test
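A quick sanity check of the split sizes (nothing assumed beyond the generator above):

train, test = generate_mod97_dataset(split=0.5)
print(len(train), len(test))                       # 4704 and 4705 of the 9409 total pairs
assert all((i + j) % 97 == k for i, j, k in train + test)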
5 Likes

Ah, that’s a tough one. Intriguing paper.

Since it uses transformers, I wonder whether this isn't also what happens in "normal" transformer training.

I know they normally need huge amounts of data (e.g. LLaMA 13B is trained on 1T tokens), but a consequence of that much training data is that they typically use a single training epoch.

Datasets for fine-tuning (Alpaca, Vicuna, OpenAssistant), despite being much smaller, are very efficient - e.g. using the LoRA trick, a 20B-parameter transformer can be fine-tuned in a day with $20 worth of GPU lease.
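(As an aside, the "LoRA trick" just means adding a trainable low-rank update to a frozen weight matrix; a minimal numpy sketch with made-up shapes and rank:)

import numpy as np

d, r = 1024, 8                      # model width and LoRA rank (illustrative only)
W = np.random.randn(d, d)           # frozen pretrained weight
A = np.random.randn(r, d) * 0.01    # trainable low-rank factor
B = np.zeros((d, r))                # starts at zero, so the update is initially a no-op

def lora_forward(x):
    # effective weight is W + B @ A, but only A and B receive gradient updates
    return W @ x + B @ (A @ x)

y = lora_forward(np.random.randn(d))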

What would happen with a small transformer trained on a limited dataset for many epochs past the point where it overfits?

Sorry for the distraction, I guess rambling again about transformers wasn’t what you were looking for.

2 Likes

Regarding your dataset generator: I haven't seen it mentioned in the paper that x and y should be smaller than 97. Of course, learning with larger numbers would be even more difficult.

A few more observations on the paper:

  • do I understand correctly that the token embeddings for x and y were random? I'd think a similarity-preserving representation like a scalar encoder should be much easier to extrapolate over (see the sketch after this list).
  • the model is small indeed (400k parameters), with only two transformer blocks (of width 128) on top of each other, while usual transformers stack dozens of blocks and are 1k-10k wide.
  • it failed to generalize with more complex equations like
    (x**3 + x*y**2 + y) % 97
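To make the scalar-encoder point from the first bullet concrete, here is a hypothetical similarity-preserving encoder (a simple sliding-window scheme, not anything from the paper; note it also doesn't wrap around, which modular arithmetic might actually want):

import numpy as np

def scalar_encode(x, n_values=97, width=64, active=8):
    # hypothetical similarity-preserving encoder: nearby integers share most active bits
    vec = np.zeros(width)
    start = int(x / n_values * (width - active))
    vec[start:start + active] = 1.0
    return vec

# 40 and 41 overlap almost completely; 40 and 90 share no bits
print(scalar_encode(40) @ scalar_encode(41), scalar_encode(40) @ scalar_encode(90))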

I wonder if the latter could be solved with some form of curriculum:

  • have the model learn simple operations first
  • add a couple more blocks on top of the already trained ones.
  • continue training with complex equations.

Even better would be to rethink the transformer metaphor from a simple, very long "ladder" to swappable blocks + recursion using a "router".

Which slowly leads me towards the hive of micro agents concept. I know, I’m biased towards that idea.

2 Likes

That type of training approach (where the algorithm is the goal) is already finding much more complex roles, which are not necessarily constructive, so to speak…

If there are any open research papers available behind this, they may be quite interesting input for you.

3 Likes

In the paper they tested different methods to "wiggle" the model out of its stable, overfit state, and they found weight decay to be the best one (fastest generalization).

Here's an article about how weight decay works.

Which makes sense because it pushes the NN towards a simpler encoding of the algorithm.
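A minimal sketch of the mechanism (plain SGD with an L2-style decay term; learning rate and decay strength are arbitrary): even when the loss gradient is flat, as on an overfit plateau, the decay term keeps shrinking weights that aren't earning their keep.

import numpy as np

def sgd_step(w, grad, lr=0.01, weight_decay=1e-2):
    # weight decay adds a pull towards zero on top of the loss gradient
    return w - lr * (grad + weight_decay * w)

w = np.random.randn(100)
for _ in range(1000):
    grad = np.zeros_like(w)      # pretend the loss is flat (overfit plateau)
    w = sgd_step(w, grad)
print(np.abs(w).mean())          # the weights keep shrinking even with zero gradient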

It's interesting that we can see the AI having an "aha!" moment when the network shifts towards a simpler (and implicitly sparser) parameter state.

That’s cool.

4 Likes

Grokking is just a visible phase transition, which is why it pretty much breaks the classical statistical understanding of ML.

I brought it up quite a few months ago here, so some of my posts on this might still be up. But in a nutshell, it's widely believed that during training LLMs simply go through millions of different phase transitions, which are a function of the compute provided. Grokking/DD requires a lot of it - the graphs are on a log scale, so often orders of magnitude more compute to achieve it.

Not really. Nothing you see here is scalable. It's interesting from a Mechanistic Interpretability point of view, as we have (in general) little understanding of it. The best analogy is splines - you can read some brilliant explanations of that online.

So it doesn't really impact practical training of models in any way. China is definitely not using anything like that. Nor has it much to do with LAWS in any regard, tbh.

3 Likes

It doesn't have to be a transformer, though. I suppose any NN will have a hard time with this; it's probably not architecture-limited but gradient-descent-limited.

That's why I wonder what would happen if we could find an optimizer that works better than gradient descent for this.

1 Like

Thanks to your words, I remembered the original article that introduced me to this phenomenon.

https://www.lesswrong.com/posts/N6WM6hs7RQMKDhYjB/a-mechanistic-interpretability-analysis-of-grokking

in that article:

Input format is x|y|= , where x,y are one-hot encoded inputs, and = is an extra token.

But I guess the key point would be to make it as hard as possible to grasp any pattern in the data and yet have the model generalize, so encoding the input in a way that makes the answer obvious is an option, but not an interesting one IMHO.
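For what it's worth, here is roughly what that x|y|= one-hot format could look like in code (the exact vocabulary layout is my own assumption, not necessarily the article's):

import numpy as np

P = 97
EQ = P                              # assume token ids 0..96 for numbers and 97 for '='

def one_hot(token, vocab=P + 1):
    v = np.zeros(vocab)
    v[token] = 1.0
    return v

def encode_example(x, y):
    # sequence x | y | = as three one-hot vectors; the target is (x + y) % P
    seq = np.stack([one_hot(x), one_hot(y), one_hot(EQ)])
    return seq, (x + y) % P

seq, target = encode_example(3, 95)
print(seq.shape, target)            # (3, 98) 1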

2 Likes

Keep an open mind.

In the 1980s I remember reading an article in New Scientist where a neural network could reverse a computer-simulated, stick-drawn truck into a parking spot. I could have thought that wouldn't scale. Now I'm sitting in the driver's seat of a car, passively watching it drive down the road, with many, many more orders of magnitude of compute than the first computer I had in 1981.

Networks sometimes learn very strange, simple patterns that we don't always realize exist in the sea of complexity.

1 Like

Do you have any thoughts on this?

Looks potentially like a strong claim but I can’t understand the relevance. Are these things actually computable?

1 Like

Yes, but that's not what grokking is about; there is no theoretical backing you can give for LLMs being grokked during training, or being far from their grokking point, without an interpretability framework - the most famous of which is mechanistic interpretability, popularized by Neel Nanda, as in the LW article posted above.

Phase transitions, however, are observed in practice to a degree, but never discretely; we guess there are phase transitions on non-toy tasks, but again we have no real way to confirm that beyond its being likely.

There's an inherent uncertainty about these phenomena, so I would advise against drawing conclusions willy-nilly.

I'm honestly not sure what the contribution of the paper here is. They simply formulate mathematically the task's critical point where the network would undergo a phase transition - which, as they find, is Mε = 1. It just seems like an odd way to rederive the math associated with criticality through a toy task.

They also make the obvious point that as a Markov chain gets longer, the chance of error grows exponentially under a bound. That's not specific to an LLM though - you could plug in any arbitrary function producing a PDF here and the result would hold, because the task itself is framed as a chain of probabilities sampled at each step.
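To spell out the arithmetic behind that: if each step of the chain is independently correct with probability 1 - ε, an M-step chain is entirely correct with probability (1 - ε)^M ≈ e^(-Mε), which is already down to ~1/e around Mε = 1, whatever learner produced the per-step probabilities.

import math

eps = 0.01                                   # per-step error rate (illustrative)
for M in (10, 100, 1000):
    p_chain = (1 - eps) ** M                 # probability the whole chain is correct
    print(M, M * eps, round(p_chain, 4), round(math.exp(-M * eps), 4))
# success probability collapses around M * eps = 1 and beyond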

Overall, what they show is that for any autoregressive task there would be at least one guaranteed phase transition when the network reaches (and exceeds) the critical point. But that has little to no bearing on induction heads or circuits - in fact, as I said, this result holds for any arbitrary learner producing a PDF.

I don't think their claim has much merit beyond being really obvious - I don't even get what the paper was meant to be demonstrating in the first place. It seems… pretty bad IMO, like someone just trying to publish anything to fluff up their CV.

2 Likes

Meh, I'm not sure about that. Using special distorting spectacles just to make problems harder for the brain, because we want to improve intelligence NOT vision… nature doesn't do that.

I would have an algorithm to, programmatically or evolutionarily:

  1. alter spectacles in many ways (transform input encoder)
  2. keep the most promising ones with existing brains (a new generation of spectacles)
  3. alter processor (brain) in many ways.
  4. keep the most promising ones (a new generation of brains better fit for the new spectacles)
  5. repeat.

One might say this is evolution, not intelligence.
I would answer, what’s the actual difference?

How is the whole AI research process of:

  • exposing hundreds and thousands of new ANN structures to billions or trillions of inputs
  • keeping the better ANN structures and altering them further in new ways that might work,
  • then repeating

any different from evolution?

Consider how many experiments not worth publishing there must have been for each paper claiming improvements.

PS: also consider that training a single NN for thousands or millions of cycles in order to improve a desired outcome is not much different from an evolutionary process.

I don't think there are generally applicable shortcuts to intelligence. There were lots of brains stumbling over the problem before Archimedes' "eureka" or Newton's revelation that the earth pulls the apple.
These simple, enlightening revelations do not often pop up in "human" brains just because they are, duh, generally intelligent.

All of them take their shot at looking at the problem in different ways, and a few lucky ones stumble onto a shortcut.

2 Likes

That being said, I would be curious how well the forward-forward algorithm would generalize (if at all) on the addition modulo 97 problem.

2 Likes

Well, I have already tried 1- and 2-layer symmetric k-WTA autoencoders for 60K epochs running FFA; for some hyperparameters it overfits instantly but does not generalize at all. Maybe a different architecture would.
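(For anyone who hasn't met it, k-WTA just keeps the k largest activations in a layer and zeroes out the rest; a minimal sketch of the activation alone, nothing FF-specific:)

import numpy as np

def k_wta(x, k):
    # keep only the k largest activations, zero the rest
    out = np.zeros_like(x)
    idx = np.argsort(x)[-k:]
    out[idx] = x[idx]
    return out

print(k_wta(np.array([0.1, 0.9, -0.3, 0.5, 0.2]), k=2))   # [0.  0.9  0.  0.5  0. ]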

3 Likes

OK, did you attempt weight decay with FF? I don't know how that would go; the basic idea behind it is pushing towards a sparser structure (fewer active weights) that could still solve the same problem the overfit one does.

1 Like

It probably won’t; it barely works outside toy tasks. Modular arithmetic might be a bit outside its domain :wink:

Hm? Why autoencoders? They won't generalize for seq2seq tasks; you'd need a totally different architecture. Even an FF stack would probably be enough.

2 Likes

I did try with and without decay; it changes overall sparsity and hurts memorization a bit, but had no effect on test performance.

But it could just be that the enforced-sparsity networks I'm so fond of are no good for this kind of stuff.

Modular arithmetic is not fundamentally a seq2seq task. It can very well be encoded as one-hot inputs and one-hot outputs in a regular NN.

2 Likes

Yes… that's seq2seq… the sequence is just in a sparser format.

1 Like

I think we're talking more of an NN with 2N inputs and N outputs, where N is e.g. 97.

It can be an MLP or anything else in between.
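That version is easy to sketch; something like this, using the generator from the first post (layer sizes and max_iter are just guesses, and I'm not claiming this setup groks):

import numpy as np
from sklearn.neural_network import MLPClassifier

P = 97
train, test = generate_mod97_dataset(split=0.5)   # generator from the first post

def to_xy(pairs):
    X = np.zeros((len(pairs), 2 * P))
    y = np.zeros(len(pairs), dtype=int)
    for n, (a, b, c) in enumerate(pairs):
        X[n, a] = 1.0          # one-hot for a in the first N slots
        X[n, P + b] = 1.0      # one-hot for b in the second N slots
        y[n] = c
    return X, y

X_tr, y_tr = to_xy(train)
X_te, y_te = to_xy(test)

mlp = MLPClassifier(hidden_layer_sizes=(256, 256), max_iter=2000)
mlp.fit(X_tr, y_tr)
print(mlp.score(X_tr, y_tr), mlp.score(X_te, y_te))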

1 Like

I have played around with this problem a bit, and so far I've had the best results with the following setup:

  • Each input integer from 0 to 96 is encoded as a 400-long dense vector with a VarCycleEncoder(*).
  • Since addition (modulo N) is commutative, the dataset contains only the pairs (X,Y), without the corresponding (Y,X) pairs.
  • For the same reason, instead of having an 800-long dense input for a pair of values, I added the dense representations of Y and X to get a single 400-long dense representation of the pair.
  • the resulting vector is SDR-ified at different sparsification levels, e.g. 50/400, 80/400, 100/400, 133/400, 200/400, to obtain different bit-only input datasets for all possible (97*96/2 = 4656) pairs.
  • a random half of these pairs is used for training, the other half for testing.
  • training was done on a sklearn MLP regressor with 97 outputs instead of a classifier. The reason is that I figured out how to keep the regressor training indefinitely, long after a classifier would have stopped at 100% accuracy on the training data. The MLP classifier probably has similar settings too.
  • the hidden layers for these results were (400,200,200,200,200,200). Various depths and widths might work; adding more depth gives slight improvements.

The best result I got was 99.4% on the testing half after 200 iterations, which takes ~103 sec on an old 2-core/4-thread laptop.
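I can't share the VarCycleEncoder itself here, but the rest of the pipeline above (sum the two dense codes, binarize the top K values, regress onto 97 one-hot outputs) could look roughly like this, with a stand-in random-projection encoder and made-up hyperparameters:

import random
import numpy as np
from sklearn.neural_network import MLPRegressor

P, D, K = 97, 400, 133                       # modulus, dense width, bits kept after SDR-ification
rng = np.random.default_rng(0)
codes = rng.standard_normal((P, D))          # stand-in for the VarCycleEncoder outputs

def encode_pair(x, y):
    dense = codes[x] + codes[y]              # sum of the two dense codes (commutative)
    sdr = np.zeros(D)
    sdr[np.argsort(dense)[-K:]] = 1.0        # keep the K largest values as 1s, rest 0
    return sdr

pairs = [(x, y) for x in range(P) for y in range(x + 1, P)]   # 97*96/2 = 4656 pairs
random.seed(0)
random.shuffle(pairs)
X = np.array([encode_pair(x, y) for x, y in pairs])
Y = np.eye(P)[[(x + y) % P for x, y in pairs]]                # 97-wide one-hot regression targets

half = len(pairs) // 2
# max_iter / tol / n_iter_no_change are the knobs that control how long it keeps training
reg = MLPRegressor(hidden_layer_sizes=(400, 200, 200, 200, 200, 200), max_iter=200)
reg.fit(X[:half], Y[:half])
accuracy = (reg.predict(X[half:]).argmax(axis=1) == Y[half:].argmax(axis=1)).mean()
print(accuracy)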

What I found interesting is:

  • encoding matters. A lot. Other encoders had much poorer performance.
  • "SDR-ification" matters, again a lot. Just feeding the MLP the overlapped dense representations gave much poorer results within the same compute budget than shifting the highest values in the dense vector to 1 and the lower ones to 0.
  • Contrary to Numenta's SDR theory, the best results (in this case of an MLP trained on SDRs) were obtained with very low sparsity; 1/3 of bits set to 1, and even half 1s / half 0s, yielded close to the top results.

Now, sure, this raises the question of why one would use a "special" encoding instead of the "neutral" one-hot encoding. That's a long discussion. In my opinion, searching for and finding efficient encodings which might also be useful for different problems is a valid path, for several reasons I would gladly discuss.

2 Likes