Tracking model confidence by Deja Vu counters

Here's the problem, and possibly a solution to it.

Generally, besides being trained to predict some output for a given input, the confidence in an ML model's predictions is measured globally (across the whole training set) via a confusion matrix and subsequent interpretations of it through metrics such as accuracy, precision, recall, etc.

Based on these globally measured parameters, all the user of a model can do is skew its probability-like predictions towards a desired threshold, favoring recall, precision, or whatever metric they consider more important.

One problem with these limited metrics is that in deployment, when the model's signal falls below the threshold, there is no way to extract details about why the model's confidence is low for that particular input.

Another problem is that they can hardly be used for problems other than 1-WTA classification; e.g. they do not provide a level of confidence for a predicted "correct" regression output, such as accelerator pedal position or steering wheel angle.

A DejaVu counter is able to predict, for any given input pattern, two important values:

  • how often the current input pattern was encountered in the past (or during training for off-line learning)
  • and how often those encounters required learning corrections (or, alternatively, the accumulated loss related to the pattern).

What are these counts - encounters and corrections - useful for?

  • a low encounter count can inform the model (or the agent using it) that a given input is a new (or very rare) pattern, so it can report a low confidence value for the prediction, regardless of how confident the underlying model claims to be.
  • a high encounters/corrections ratio signals the model is within known territory: the input lies within the learned space, and there is high confidence it will make a correct prediction for the current input.
  • the lower (closer to 1) the encounters/corrections ratio, the higher the chance that, despite the current pattern being familiar, it will not yield a good prediction. This could indicate e.g. unpredictable randomness present in the training dataset, or an otherwise useless data point. (A rough sketch of this logic follows below.)
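As a rough illustration only (hypothetical function and threshold names, not a prescribed implementation), the two counts could be turned into a per-pattern confidence value along these lines:

def confidence(encounters, corrections, rarity_threshold=5):
    '''
    encounters:  how often this pattern was seen in the past.
    corrections: how many of those encounters needed a learning correction.
    Returns a confidence value in [0, 1].
    '''
    if encounters < rarity_threshold:
        return 0.0                                   # novel or very rare pattern: low confidence
    # familiar pattern: confidence grows as corrections become rare relative to encounters
    return max(0.0, 1.0 - corrections / encounters)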

Moreover, a DejaVu counter can provide causal insight into which sub-patterns in the input contribute most to these global measures of confidence, so the agent can adapt its subsequent learning/exploration strategies, e.g.:

  • gather different data (use a new perspective or new sensors) for unpredictable yet frequent patterns (change perspective)
  • gather more data for infrequent patterns (explore unknown territory)
  • steer away from situations leading to high confusion (play safe)
  • or simply attempt to ignore certain sub-features (or lower their weight) that are prone to increasing confusion (filtering out noise). A toy dispatch over these strategies is sketched below.
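As a toy illustration only (hypothetical names and thresholds), such a policy could be dispatched from the two counts like this:

def pick_strategy(encounters, corrections, rare=5, confusing=0.8):
    '''Map the DejaVu counts to a coarse learning/exploration strategy.'''
    if encounters < rare:
        return "explore"                # infrequent pattern: gather more data
    if corrections / encounters > confusing:
        return "change_perspective"     # frequent but unpredictable: gather different data
    return "exploit"                    # familiar and predictable territory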

This is interesting. Do you have any article, paper, etc. references that describe it (or an implementation of it) in more detail?

How does it handle changes over time? Does it have (or can/should it add) a trailing moving average for the "encounters" and/or "corrections" variables? What I am thinking here is "a long time ago this pattern was highly reliable in predicting X, but lately it seems more likely to predict Y, suggesting a shift in the data or environment; I am going to give you a lower confidence until it becomes stable predicting Y".
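Purely as a hypothetical sketch of what I mean (e.g. exponentially decayed counters rather than a strict moving window):

import numpy as np

def decayed_update(counters, sdr, decay=0.999):
    '''counters: float numpy vector; old evidence fades so recent encounters dominate.'''
    counters *= decay            # every counter forgets a little on each step
    counters[sdr] += 1.0         # the current pattern's ON bits are reinforced
    return counters[sdr].mean()  # recency-weighted familiarity of this pattern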


LLMs learn to align themselves with tokens to express their epistemic uncertainty in the tokens themselves, so I can imagine kWTA could learn a similar mechanism (assuming it's Turing complete).


@klokare I have some code within an "sdr_machine" ML framework, which unfortunately I haven't published yet.
The most trivial implementation of this would simply increment (and track) a counter for every ON bit in an SDR:

import numpy as np

def increment(sdr, counters, amount=1):
    '''
    sdr: a list of ON bit indices, aka "sparse" SDR encoding.
    counters: integer numpy vector of length equal to the maximum SDR size.
    Returns the average counter value over the ON bits (how familiar the pattern is).
    '''
    counters[sdr] += amount      # bump the counter of every ON bit in the pattern
    return counters[sdr].mean()  # average encounter count for this pattern
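For example, assuming a total SDR size of 2048 (an arbitrary choice):

counters = np.zeros(2048, dtype=np.int64)    # one counter per possible bit
sdr = [3, 17, 101, 999]                      # ON bit indices of the current pattern
familiarity = increment(sdr, counters)       # grows as the same bits keep recurring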

What I actually implemented is a slightly more complex variant which, instead of counting individual ON bits, counts bit-pair occurrences. It uses a memory large enough to hold one counter for each possible bit pair.
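Schematically (ignoring the actual memory layout and optimizations, so take this only as a sketch with assumed names), the pair-counting version boils down to something like:

from itertools import combinations
import numpy as np

def increment_pairs(sdr, pair_counters, amount=1):
    '''
    sdr: list of ON bit indices.
    pair_counters: 2-D integer array of shape (N, N); only the upper triangle (i < j) is used.
    Returns the average pair count for this pattern.
    '''
    pairs = np.array(list(combinations(sorted(sdr), 2)))   # all ON-bit pairs (i, j) with i < j
    pair_counters[pairs[:, 0], pairs[:, 1]] += amount
    return pair_counters[pairs[:, 0], pairs[:, 1]].mean()

For 32 ON bits that is 496 pairs per SDR.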

Using the numba JIT compiler it is still reasonably fast: for 32-ON-bit SDRs it reaches about 100k writes (or reads) per second on a single sluggish core.

The conundrum with that kind of detail is what constitutes a relevant trail length. Yes, I made a version that keeps counts for a fixed-length queue (it automatically decrements old SDRs); it halves the performance.
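In spirit (again just a sketch with assumed names, not the actual numba version), the fixed-length variant looks like:

from collections import deque
import numpy as np

class WindowedCounters:
    '''Keeps counters only for the last `window` SDRs seen.'''
    def __init__(self, size, window=10000):
        self.counters = np.zeros(size, dtype=np.int64)
        self.history = deque(maxlen=window)

    def update(self, sdr):
        if len(self.history) == self.history.maxlen:
            self.counters[self.history[0]] -= 1   # forget the SDR about to fall out of the window
        self.history.append(list(sdr))            # appending drops the oldest SDR when full
        self.counters[sdr] += 1
        return self.counters[sdr].mean()          # average recent-encounter count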

Combining several counters for different history lengths could be interesting too; no, I haven't got to that level of sophistication yet.
This is still shifting, playground territory for a few ideas.

It would be interesting to have an internal sense of confidence during normal LLM generation, in order to internally trigger more complex processing, like chain- or tree-of-thought, a domain-specific vector database search, or whatever other external tool usage.

A kind of “hmm, I’m not sure, how should I handle this?” reasoning.


It kind of already does, just that we haven't aligned it that way. Current LLMs have a pretty accurate notion of their epistemic uncertainty (which is something we can't probe out as effectively as the LLM telling us itself). It's interesting because somehow it has understood what the tokens indicating percentages mean - that 45% is actually smaller than 50% - even though it has no prior knowledge for that; they are just tokens #65 and #9925 to it.

For example: if you intercept its CoT while it is solving something and swap the final answer it thinks it computed with something absurd, it will just ignore its own CoT and provide an approximate answer.

But if the intercepted answer is close to reasonable, then it will accept it and give us the final answer in the requested format.

So it's still leveraging its own epistemic uncertainty, just not in the way we want. Which is why it's an alignment problem 🙂


Slightly related, here's a work about a sense of uncertainty in a model's predictions, with applications in robotics (collision avoidance & free-path finding):

Link: Ajna
Video: https://youtu.be/VTEaJFb9AaE
Paper: https://www.science.org/doi/10.1126/scirobotics.add5139

What I find intriguing is that while it seems more difficult to integrate “confidence sensing” with raw data, the resulting model requirements are reduced, allowing for on-board, on-line navigation on light drone hardware.


A follow-up dissertation: https://www.youtube.com/watch?v=XThD7UUksv4
