Grid Cell Inspired Scalar Encoder

I had a similar idea after watching the talk "Circuitry and Mathematical Codes for Navigation in the Brain", given by Ila Fiete, about grid cells. In her talk, she mentions a modulo-based number system (the residue number system), which is like an abstraction of how grid cells work. To explore the advantages of a modulo-based number system for the arithmetic operations she talks about, I wrote some simple code that converts a number between a fixed-base number system and a modulo-based number system. Once you have the residue values, it is straightforward to make an SDR out of them. The problem I realized later is that in the encoding process it is desirable to have similar SDRs (high overlap) for semantically similar entities; however, with this approach, even a single increment gives a completely different SDR. Or maybe I am missing something. This method may have some other benefits that I have not foreseen yet, though.

Here is the code I wrote:

    from functools import reduce
    import numpy as np


    def gcd(a, b):
        """Compute the greatest common divisor of ``a`` and ``b``.

        Uses the iterative form of Euclid's algorithm: the pair ``(a, b)`` is
        repeatedly replaced by ``(b, a % b)`` until the second member is zero,
        at which point the first member is the result.
        """
        while b:
            remainder = a % b
            a = b
            b = remainder
        return a


    def lcm(a, b):
        """Compute the lowest common multiple of ``a`` and ``b``.

        The product of the two numbers is divided (with floor division) by
        their greatest common divisor.
        """
        product = a * b
        divisor = gcd(a, b)
        return product // divisor


    def lcmm(*args):
        """Return the lowest common multiple of all positional arguments.

        The arguments are folded pairwise with ``lcm``.  Called with no
        arguments the function returns 1, the identity element of lcm, rather
        than letting ``reduce`` raise ``TypeError`` on an empty sequence; this
        is the case that arises when the "other moduli" of a single-modulus
        system are combined.
        """
        if not args:
            return 1
        return reduce(lcm, args)


    class ResidueNumberSystem:
        """Residue number system (RNS) defined by a set of moduli.

        An integer ``n`` is represented by the list of residues ``n % m``, one
        per modulus ``m`` (moduli are kept in ascending order).  When the
        moduli are pairwise coprime, every integer in
        ``range(least_common_multiple)`` has a unique representation and can be
        recovered with ``decode``.

        Instance attributes:
            modulos: the moduli, sorted ascending.
            least_common_multiple: lcm of all moduli (size of the value range).
            weights: per-modulus reconstruction weights used by ``decode``;
                an entry is ``None`` when no weight exists for that modulus.
        """

        def __init__(self, *modulos):
            """Build the system from one or more positive integer moduli.

            Raises:
                ValueError: if no modulus is given or any modulus is not
                    positive.
            """
            if not modulos:
                raise ValueError('at least one modulus is required')
            if any(m <= 0 for m in modulos):
                raise ValueError('moduli must be positive integers')
            self.modulos = sorted(modulos)
            self.least_common_multiple = lcmm(*self.modulos)
            self.weights = self.find_weights()

        def find_weights(self):
            """Return one reconstruction weight per modulus.

            The weight for modulus ``m`` is a multiple of the lcm of all the
            other moduli that is congruent to 1 modulo ``m``, so it contributes
            a digit to its own residue only.  The returned list always has the
            same length as ``self.modulos``; positions for which no such
            multiple exists (moduli that are not coprime to the others) hold
            ``None`` so that digits and weights never get misaligned.
            """
            weights = []
            for i, m in enumerate(self.modulos):
                rest = self.modulos[:i] + self.modulos[i + 1:]
                # A single-modulus system has no "other" moduli; their lcm is 1.
                M = lcmm(*rest) if rest else 1
                # ``1 % m`` instead of ``1`` so that modulus 1 (whose only
                # residue is 0) also gets a weight.
                target = 1 % m
                # Any inverse of M modulo m lies in 1..m, so this search is
                # exhaustive.
                inverse = next(
                    (j for j in range(1, m + 1) if (M * j) % m == target),
                    None,
                )
                weights.append(None if inverse is None else M * inverse)
            return weights

        def encode(self, n):
            """Return the residues of ``n``, one per modulus, as a list."""
            return [n % m for m in self.modulos]

        def decode(self, digits):
            """Return the integer in ``range(least_common_multiple)`` whose
            residues are ``digits``.

            Digits are converted to Python ints before being multiplied by
            their weights, so fixed-width NumPy digits cannot overflow.

            Raises:
                ValueError: if some modulus has no reconstruction weight
                    (the moduli are not pairwise coprime).
            """
            if any(w is None for w in self.weights):
                raise ValueError(
                    'cannot decode: moduli {} are not pairwise coprime'.format(
                        tuple(self.modulos)))
            total = sum(int(d) * w for d, w in zip(digits, self.weights))
            return total % self.least_common_multiple

        def to_sdr(self, digits):
            """Return a binary array with one active bit per modulus.

            The array has ``sum(self.modulos)`` entries of dtype ``uint8``.
            Each modulus owns a contiguous segment as long as the modulus, and
            the bit at position ``digit`` inside that segment is set.
            """
            sdr = np.zeros(sum(self.modulos), dtype=np.uint8)
            offset = 0
            for i, modulus in enumerate(self.modulos):
                # Work with a Python int reduced into the segment so that
                # narrow NumPy digits cannot overflow when the offset is added
                # and an unreduced digit cannot land in another segment.
                position = int(digits[i]) % modulus
                sdr[offset + position] = 1
                offset += modulus

            return sdr

        def tabulate(self):
            """Print every value in ``range(least_common_multiple)`` next to
            its residues, under a header row listing the moduli."""
            columns = len(self.modulos)
            header_format = '{:8s}' + '{:4d}' * columns
            trow_format = '{:8d}' + '{:4d}' * columns
            print(header_format.format('n', *self.modulos))
            # The rule is as wide as a table row: 8 characters for ``n`` plus
            # 4 per modulus.
            print('=' * (8 + 4 * columns))
            for i in range(self.least_common_multiple):
                print(trow_format.format(i, *self.encode(i)))


    class RN:
        """A number held as its residues in a residue number system.

        ``rns`` is any object exposing ``modulos`` (a sequence of moduli) and
        ``decode(digits)``.  ``digits`` holds one residue per modulus, stored
        as an ``int64`` NumPy array and always reduced into ``0..m-1``.

        Arithmetic is carried out digit by digit.  Operands are reduced modulo
        each modulus *before* being combined, so intermediate results stay
        within ``int64`` as long as the squared moduli do.
        """

        def __init__(self, rns, digits):
            """Store ``rns`` and the normalized form of ``digits``."""
            self.digits = self.normalize(rns.modulos, digits)
            self.rns = rns

        @staticmethod
        def normalize(modulos, digits):
            """Return ``digits`` as an ``int64`` array reduced by ``modulos``.

            When fewer digits than moduli are given, the digits are aligned to
            the right and the leading positions are filled with zeros.  Each
            value is reduced as a Python int before it is stored, so large or
            negative inputs are not corrupted by a narrow array dtype.

            Raises:
                ValueError: if more digits than moduli are given.
            """
            num_modulos = len(modulos)
            num_digits = len(digits)
            if num_digits > num_modulos:
                raise ValueError(
                    'got {} digits for {} moduli'.format(num_digits, num_modulos))
            padding = num_modulos - num_digits
            reduced = [0] * padding
            reduced.extend(
                int(d) % m for d, m in zip(digits, modulos[padding:]))
            return np.array(reduced, dtype=np.int64)

        def _moduli(self):
            """Return the moduli of the owning system as an ``int64`` array."""
            return np.asarray(self.rns.modulos, dtype=np.int64)

        def _coerce(self, other):
            """Return ``other`` as residues that can be combined with
            ``self.digits``.

            ``RN`` operands contribute their digits, integers are reduced by
            every modulus, and lists/tuples are treated as per-modulus values
            (negative entries are reduced correctly).  Any other object is
            returned unchanged and left to NumPy.
            """
            if isinstance(other, RN):
                return other.digits.astype(np.int64)
            if isinstance(other, (int, np.integer)):
                return np.array(
                    [int(other) % m for m in self.rns.modulos], dtype=np.int64)
            if isinstance(other, (list, tuple)):
                return np.mod(np.array(other, dtype=np.int64), self._moduli())
            return other

        def __str__(self):
            digits = tuple(int(d) for d in self.digits)
            return '{} % {}'.format(str(digits), str(tuple(self.rns.modulos)))

        def __repr__(self):
            moduli = ', '.join(str(m) for m in self.rns.modulos)
            digits = tuple(int(d) for d in self.digits)
            return 'RN(ResidueNumberSystem({}), {})'.format(moduli, digits)

        def __neg__(self):
            return RN(self.rns, -self.digits)

        def __pos__(self):
            return self

        def __abs__(self):
            return self

        def __invert__(self):
            # modulus - digit, reduced again by the constructor.
            return RN(self.rns, self._moduli() - self.digits)

        def __int__(self):
            # Hand plain Python ints to the decoder so the weighted sum cannot
            # overflow a fixed-width NumPy integer.
            return int(self.rns.decode([int(d) for d in self.digits]))

        def __add__(self, other):
            return RN(self.rns, self.digits + self._coerce(other))

        def __radd__(self, other):
            return self + other

        def __sub__(self, other):
            return RN(self.rns, self.digits - self._coerce(other))

        def __rsub__(self, other):
            # other - self == (-self) + other
            return -self + other

        def __mul__(self, other):
            return RN(self.rns, self.digits * self._coerce(other))

        def __rmul__(self, other):
            return self * other

        def __pow__(self, n):
            if isinstance(n, (int, np.integer)):
                # Three-argument pow keeps every intermediate value below the
                # modulus instead of overflowing the digit array.
                powered = [
                    pow(int(d), int(n), m)
                    for d, m in zip(self.digits, self.rns.modulos)
                ]
                return RN(self.rns, powered)
            return RN(self.rns, self.digits ** n)

        def __eq__(self, other):
            if not isinstance(other, RN):
                return NotImplemented
            # Equal digits only mean equal numbers within the same system.
            same_system = list(self.rns.modulos) == list(other.rns.modulos)
            return bool(same_system and np.array_equal(self.digits, other.digits))

        def __ne__(self, other):
            outcome = self.__eq__(other)
            if outcome is NotImplemented:
                return outcome
            return not outcome

        # ``__neq__`` is not a Python special method; the name is kept as an
        # alias for code that calls it directly.
        __neq__ = __ne__

        # Instances compare by value and ``digits`` is a mutable array, so they
        # are explicitly unhashable.
        __hash__ = None

        def __iadd__(self, other):
            return self + other

        def __isub__(self, other):
            return self - other

        def __imul__(self, other):
            return self * other

        def __idiv__(self, other):
            # NOTE(review): this class defines no division operator, so the
            # expression below raises TypeError when this method is called.
            return self / other

        def __ipow__(self, other):
            return self ** other


    if __name__ == '__main__':
        # Demo: encode two consecutive integers in the (2, 3, 5) system and
        # print each value with its residues and its SDR.
        rns = ResidueNumberSystem(2, 3, 5)
        for n in (10, 11):
            digits = rns.encode(n)
            sdr = rns.to_sdr(digits)
            print(n, digits, sdr)
7 Likes