I’m working on a brand new HTM implementation for my graduation project… It’s a lot of work, but I have to write an entire framework for GPU support.
I decided that since it’s brand new, I should try a more modern/DoD approach to API design. This is what I’ve come up with and have partially implemented. I don’t think it is a bad design: it is more Pythonic and more closely resembles what DL frameworks look like. I hope I can get my framework mature enough to release in the next few months.
Encoding
DoD design, less stack access. You generally only use an encoder in one place anyway.
//Now (tiny-htm/HTMHelper/NuPIC)
ScalarEncoder encoder(/*min=*/0, /*max=*/1, /*num_cells=*/64, /*num_active=*/4);
Tensor t = encoder.encode(/*value=*/0.1);
//New
Tensor t = encoder::scalar(/*value=*/0.1, /*min=*/0,
/*max=*/1, /*num_cells=*/64, /*num_active=*/4);
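To make the stateless form concrete, here is a minimal sketch of what a free-function scalar encoder could look like. Everything in it is an assumption for illustration: `SDR` is a plain `std::vector<bool>` standing in for the framework's `Tensor`, and the bucketing rule (a contiguous run of active bits whose start scales linearly with the value) is the textbook scalar encoding, not necessarily what the framework does.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Stand-in for the framework's Tensor type (assumption: a dense bit vector).
using SDR = std::vector<bool>;

namespace encoder {

// Stateless scalar encoder: maps `value` in [min, max] to a contiguous run of
// `num_active` set bits inside a vector of `num_cells` bits.
inline SDR scalar(double value, double min, double max,
                  std::size_t num_cells, std::size_t num_active)
{
    assert(max > min);
    assert(num_active <= num_cells);

    // Clamp so out-of-range inputs land in the first or last bucket.
    const double clamped = std::clamp(value, min, max);
    const double fraction = (clamped - min) / (max - min);

    // Index of the first active bit; the run always fits inside the vector.
    const auto start = static_cast<std::size_t>(
        std::llround(fraction * static_cast<double>(num_cells - num_active)));

    SDR out(num_cells, false);
    std::fill_n(out.begin() + static_cast<std::ptrdiff_t>(start), num_active, true);
    return out;
}

} // namespace encoder
```

Since the function keeps no state, there is no encoder object to construct or keep alive; the call site from the //New snippet, `encoder::scalar(0.1, 0, 1, 64, 4)`, works as written against this sketch.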
Spatial Pooler
Makes compute() const, for better performance.
SpatialPooler sp(/*input_shape=*/{32,32}, /*output_shape=*/{16,16}
, /*global_inhibition=*/true, /*global_density=*/0.1);
//Now
Tensor y = sp.compute(/*input=*/x, /*enable_learning=*/true);
//New
Tensor y = sp.compute(/*input=*/x);
sp.learn(/*x=*/x, /*y=*/y);
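The compute()/learn() split can be sketched with a toy pooler. This is not the real Spatial Pooler algorithm and not the framework's code: `ToySpatialPooler`, its constructor parameters, the fixed seed, and the permanence constants are all hypothetical, and it only implements overlap scoring with global top-k inhibition. The point is just that compute() can be `const` once learning lives in its own method.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

using SDR = std::vector<bool>; // stand-in for the framework's Tensor (assumption)

class ToySpatialPooler {
public:
    ToySpatialPooler(std::size_t num_inputs, std::size_t num_columns,
                     double density, unsigned seed = 42)
        : num_inputs_(num_inputs), num_columns_(num_columns),
          num_active_(std::max<std::size_t>(
              1, static_cast<std::size_t>(density * static_cast<double>(num_columns)))),
          permanence_(num_inputs * num_columns)
    {
        assert(num_active_ <= num_columns_);
        std::mt19937 rng(seed); // fixed seed keeps the sketch reproducible
        std::uniform_real_distribution<float> dist(0.0f, 1.0f);
        for (auto& p : permanence_) p = dist(rng);
    }

    // Inference only: reads the permanences, never writes them.
    SDR compute(const SDR& x) const
    {
        assert(x.size() == num_inputs_);
        // Overlap = number of active inputs with a connected synapse.
        std::vector<int> overlap(num_columns_, 0);
        for (std::size_t c = 0; c < num_columns_; ++c)
            for (std::size_t i = 0; i < num_inputs_; ++i)
                if (x[i] && permanence(c, i) >= connected_threshold_) ++overlap[c];

        // Global inhibition: keep the num_active_ columns with the highest
        // overlap, breaking ties by the lower column index.
        std::vector<std::size_t> order(num_columns_);
        std::iota(order.begin(), order.end(), std::size_t{0});
        std::partial_sort(order.begin(),
                          order.begin() + static_cast<std::ptrdiff_t>(num_active_),
                          order.end(),
                          [&](std::size_t a, std::size_t b) {
                              return overlap[a] != overlap[b] ? overlap[a] > overlap[b]
                                                              : a < b;
                          });
        SDR y(num_columns_, false);
        for (std::size_t k = 0; k < num_active_; ++k) y[order[k]] = true;
        return y;
    }

    // Learning: for every active column, reinforce synapses to active inputs
    // and weaken synapses to inactive ones.
    void learn(const SDR& x, const SDR& y)
    {
        assert(x.size() == num_inputs_);
        assert(y.size() == num_columns_);
        for (std::size_t c = 0; c < num_columns_; ++c) {
            if (!y[c]) continue;
            for (std::size_t i = 0; i < num_inputs_; ++i) {
                float& p = permanence_[c * num_inputs_ + i];
                p = std::clamp(p + (x[i] ? increment_ : -decrement_), 0.0f, 1.0f);
            }
        }
    }

    float permanence(std::size_t column, std::size_t input) const
    {
        return permanence_[column * num_inputs_ + input];
    }

private:
    std::size_t num_inputs_, num_columns_, num_active_;
    std::vector<float> permanence_;
    float connected_threshold_ = 0.5f, increment_ = 0.05f, decrement_ = 0.02f;
};
```

One thing the sketch makes visible: the `const` qualifier by itself is unlikely to change the generated code much. What it does buy is a guarantee that inference never touches the permanences, so a trained pooler can be passed around as `const&` and shared between readers, and the learning branch disappears from the inference path.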
Temporal Memory
Again, keeps compute() const for better performance.
TemporalMemory tm(...);
//Now
for(const auto& sdr : input_series) {
tm.compute(/*input=*/sdr, /*enable_learning=*/true);
auto predictive_cells = tm.getPredictiveCells();
auto active_cells = tm.getActiveCells();
//Do something with prediction
}
//New
Tensor last_active = zeros();
for(const auto& sdr : input_series) {
auto [predictive_cells, active_cells] = tm.compute(/*input=*/sdr);
tm.learn(last_active, active_cells);
last_active = active_cells;
//Do something with prediction
}
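To show the caller-owned-state loop end to end, here is a deliberately simplified sketch. It is not Temporal Memory: `ToySequenceMemory` is a hypothetical first-order memory with one cell per input bit and a dense weight matrix, with no columns, segments, or bursting, and its threshold and increment are made-up constants. It only demonstrates the shape of the API, where compute() returns a `{predictive, active}` pair and learn() takes the previous and current active cells.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

using SDR = std::vector<bool>; // stand-in for the framework's Tensor (assumption)

class ToySequenceMemory {
public:
    explicit ToySequenceMemory(std::size_t num_cells)
        : num_cells_(num_cells), weight_(num_cells * num_cells, 0.0f) {}

    // Returns {predictive_cells, active_cells} without modifying the model.
    // In this toy the active cells are simply the input bits.
    std::pair<SDR, SDR> compute(const SDR& input) const
    {
        assert(input.size() == num_cells_);
        SDR predictive(num_cells_, false);
        for (std::size_t to = 0; to < num_cells_; ++to) {
            // Sum the learned weights from every currently active cell.
            float support = 0.0f;
            for (std::size_t from = 0; from < num_cells_; ++from)
                if (input[from]) support += weight_[from * num_cells_ + to];
            predictive[to] = support >= threshold_;
        }
        return {predictive, input};
    }

    // Strengthens every connection from a previously active cell to a
    // currently active cell.
    void learn(const SDR& last_active, const SDR& active)
    {
        assert(last_active.size() == num_cells_);
        assert(active.size() == num_cells_);
        for (std::size_t from = 0; from < num_cells_; ++from) {
            if (!last_active[from]) continue;
            for (std::size_t to = 0; to < num_cells_; ++to) {
                if (!active[to]) continue;
                float& w = weight_[from * num_cells_ + to];
                w = std::min(1.0f, w + increment_);
            }
        }
    }

private:
    std::size_t num_cells_;
    std::vector<float> weight_; // weight_[from * num_cells_ + to]
    float threshold_ = 0.5f, increment_ = 0.1f;
};
```

With this shape the only temporal context compute() sees is its argument, so the toy can only do first-order prediction. A real TM would presumably need the previous active/winner cells passed in as well, or kept in some explicit state object the caller owns.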
Serializing
Non-intrusive saving: allows saving in any format easily (as long as someone codes it).
SpatialPooler sp(...);
//Now
std::ofstream out("sp.txt");
sp.save(out);
//New
auto state_dict = sp.states(); //Copies the internal state to a std::map<std::string, std::any>
save_to_XXX(state_dict, "sp.msgpack");
This also has the benefit of being able to save a tree of states.
//Or make it a tree then save it
network_states["SP0"] = state_dict;
save_to_XXX(network_states, "network.msgpack");
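A rough sketch of how a format-specific saver could walk such a tree, assuming the state dictionary really is a `std::map<std::string, std::any>`. `ToyLayer` and `dump` are hypothetical names, and `dump` writes an indented text form rather than msgpack. Since `std::any` erases the type, each saver has to probe with `std::any_cast` for every value type it supports.

```cpp
#include <any>
#include <cstddef>
#include <map>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

using StateDict = std::map<std::string, std::any>;

// Hypothetical layer exposing its state non-intrusively.
struct ToyLayer {
    std::vector<float> permanence{0.1f, 0.9f};
    bool global_inhibition = true;

    // Copies the internal state into a dictionary; the layer itself knows
    // nothing about file formats.
    StateDict states() const
    {
        StateDict s;
        s["permanence"] = permanence;
        s["global_inhibition"] = global_inhibition;
        return s;
    }
};

// Writes a state tree as indented text. Nested StateDicts become sub-trees;
// value types the saver does not know are reported instead of silently dropped.
void dump(const StateDict& dict, std::ostream& os, std::size_t depth = 0)
{
    const std::string pad(depth * 2, ' ');
    for (const auto& [key, value] : dict) {
        os << pad << key << ":";
        if (const auto* child = std::any_cast<StateDict>(&value)) {
            os << "\n";
            dump(*child, os, depth + 1);
        } else if (const auto* v = std::any_cast<std::vector<float>>(&value)) {
            os << " [";
            for (std::size_t i = 0; i < v->size(); ++i)
                os << (i ? ", " : "") << (*v)[i];
            os << "]\n";
        } else if (const auto* b = std::any_cast<bool>(&value)) {
            os << " " << (*b ? "true" : "false") << "\n";
        } else {
            os << " <unsupported type>\n";
        }
    }
}
```

The trade-off this exposes: with `std::any`, an unsupported type only shows up at run time. A `std::variant` over the allowed value types would let the compiler check that every saver handles every case, at the cost of a closed set of types.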
Any flaws/things I can change?
Thanks.
Edit: Fixed some code. I have written too much Python, apparently.