Towards the cutest neural network
Published: 2025 April 28

I recently needed to use a microcontroller to estimate the pose (translation and orientation) of an object using readings from six different sensors. Since the readings were non-linear and coupled with each other, an explicit analytical solution was out of the question.
I figured I’d have a go at using a simple neural network to approximate it:
- generate training data (on my computer) using a forward simulation (pose to sensor readings)
- train a lil’ neural network (a few dense layers, 100’s of parameters, tops) to approximate the inverse mapping function (sensor readings to pose)
- deploy this network on my microcontroller (Cortex-M0, 16 kB RAM, 32 kB flash) to actually do the inference
Neural networks have been around since the 1980’s, so I figured it’d be straightforward. A quick background search uncovered lots of promising leads too, especially regarding “quantization”, which I wanted to do as my microcontroller doesn’t have hardware support for floating point operations.
However, this turned out to be much more difficult than I’d anticipated. It seems like my use case — end-to-end training of a simple dense neural network with integer-only inference — is quite uncommon.
The vast majority of papers and software libraries I found turned out to be complex, heavyweight (in terms of inference code size), and have lots of unstated assumptions and background requirements.
To make a web design analogy: It felt like I kept falling into npm create-react-app rabbit holes rather than what I wanted: The moral equivalent of opening index.html with notepad.exe and typing <h1>Welcome to my Homepage</h1>.
I’m writing up my notes to:
- checkpoint my understanding of the space and current best solution
- help anyone else who simply wants a wholesome lil’ universal function approximator with low conceptual and hardware overhead
- solicit “why don’t you just …” emails from experienced practitioners who can point me to the library/tutorial I’ve been missing =D (see the alternatives-considered down the page for what I struck out on)
tl;dr, how to do it?
- use TensorFlow to do quantization-aware training and save the resulting model out to a .tflite flatbuffer file
- use the microflow-rs crate for inference — it’s basically a proc macro that reads the .tflite file and generates straightforward Rust code that uses nalgebra to multiply matrices
Props to Matteo for being seemingly the only person in this space who can put a clear “hello world” inference example in their README.md:
use microflow::model;
#[model("path/to/model.tflite")]
struct MyModel;
fn main() {
    let prediction = MyModel::predict(input_data);
}
It should be cuter.
Unfortunately, MicroFlow (and TensorFlow’s out-of-the-box quantization routines, from what I can tell) require floating point operations for inference.
While I could use software floating point, I want to make the cutest possible neural network, which means integer arithmetic operations only.
(The code size and speed of software floating point depend on the routines used. qfplib-m0-full has a good overview and shows that, e.g., GCC’s software floating point multiplication takes 166 cycles.)
On the MicroFlow side, removing floating point operations would require a big redesign, as:
- they’re baked into the API (predict assumes f32 inputs and outputs).
- the internals themselves rely on floating point:
  - the MicroFlow proc macro reads the .tflite flatbuffer (source)
  - when reading a fully-connected layer, it combines the various quantization scale factors and zero points into f32 constants
  - these constants are then used during inference at runtime
On the TensorFlow side, I can’t find any documentation/options about whether the quantization-aware training routines use floating point activation scaling or if they can model quantized multipliers.
But I’m getting ahead of myself — let’s first review the mathematics of neural networks and the meaning of “quantization”.
Neural network overview
The basic idea of neural networks is that you have a bunch of training data:
- $x$, the input data you have (in my case, readings from six sensors)
- $y$, the output value you are trying to predict (the object pose)
Since you have no idea how to go from $x$ to $y$ yourself, you throw together $n$ general equations (“layers”) like:
\begin{align} a_1 &= \sigma(W_1 x + b_1) \newline a_2 &= \sigma(W_2 a_1 + b_2) \newline \vdots \newline \hat{y} &= \sigma(W_{n} a_{n-1} + b_n) \end{align}
These equations take your input $x$ and make a prediction $\hat{y}$ which, hopefully, is close to $y$.
The equations require:
- free parameters $W$ and $b$ (the “weights” and “biases”, respectively); in our case these are a matrix and vector rather than scalar values, since our input data $x$ comes from multiple sensors and the output pose $y$ has six components (three spatial positions and three spatial rotations)
- $\sigma$, some nonlinear “activation” function (traditionally sigmoid is used, but these days folks like ReLU, $\sigma(x) = \mathrm{max}(0, x)$)
You start with random values for all of your $W$ and $b$, see how well they work for all the $x$ and $y$ pairs of your training dataset, then keep adjusting $W$ and $b$ until $\hat{y}$ (the prediction) is close to $y$ (the actual value).
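In code, the whole forward pass is just a few matrix-vector multiplies with a nonlinearity in between. A minimal sketch in JAX (the parameter layout here is made up for illustration):
import jax.numpy as jnp
import jax.nn as nn

def predict(params, x):
    # params: a list of (W, b) pairs, one per layer; x: the six sensor readings
    for W, b in params:
        x = nn.relu(W @ x + b)  # one layer: sigma(W x + b), with ReLU as sigma
    return x  # (for regression, the last layer is often left linear instead)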
For more details on how and why this works, see Michael Nielsen’s awesome Neural networks and deep learning.
Quantization overview
The real numbers of mathematical neural network theory are usually implemented on a computer with floating point numbers. These are already quantized: The real numbers are infinite, but 32 bits of memory on a computer can represent, at most, $2^{32}$ different things. How you map 32, 8, or 4 bits of memory to the real number line is up to you, the programmer. See:
The “quantization” of neural networks refers to replacing some (or all) of the typically 32-bit floating point numbers of the network parameters with smaller representations.
One method, called “fake quantization”, is to store the parameters as, e.g., 8-bit integers, then convert them back to 32-bit floating point when doing the actual math. This reduces storage and memory bandwidth requirements, but since it still involves floating point it’s not what I’m interested in for my microcontroller application. (Floating point: Not Cute.)
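For concreteness, the usual affine scheme (the one TensorFlow Lite uses) maps a real value $r$ to an integer $q$ via a scale $s$ and zero point $z$, so that $r \approx s(q - z)$. A rough sketch with made-up numbers:
import numpy as np

scale, zero_point = 0.05, 12  # made up; chosen so the expected range of real values fits in i8

def quantize(r):
    # real -> i8
    return np.clip(np.round(r / scale) + zero_point, -128, 127).astype(np.int8)

def dequantize(q):
    # i8 -> real; doing the actual math on this f32 result is the "fake quantization" path
    return scale * (q.astype(np.float32) - zero_point)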
The quantization I want is calculating $\sigma(Wx + b)$ using only integer arithmetic. Typically this is done by storing the weights as i8 and the biases as i32. (The biases are larger because they’re used as the initial accumulator value for the corresponding row of the matrix-vector product $Wx$, which is the sum of a bunch of i8 * i8 products, and so would likely overflow an i8 accumulator.)
Once you have the accumulated i32 sum and have passed it through the activation function $\sigma$, you then need to convert it back into an i8 so it can be used as $x$ in the next layer’s $\sigma(Wx + b)$.
This is known as “activation scaling” and, as we saw earlier in MicroFlow, it is usually implemented with floating point multiplication.
However, activation scaling can also be done using a “quantized multiplier” instead, where the scaling factor is $2^{-n} M_0$, with $0.5 \le M_0 \le 1$ and $n$ a natural number. This allows the scaling to be done with a fixed-point multiplication and arithmetic shift.
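Concretely, that scaling comes down to one widening multiply and a shift. A rough sketch in plain Python integers (a simplification of what gemmlowp/TFLite actually do; the names and rounding choices are illustrative):
import math

def quantize_multiplier(scale):
    # split a float scale (assumed < 1) into (M0 as Q0.31 fixed point, right shift n),
    # so that scale = M0 * 2**-n with 0.5 <= M0 < 1
    m0, exp = math.frexp(scale)  # scale = m0 * 2**exp
    return round(m0 * (1 << 31)), -exp

def requantize(acc, m0_fixed, n, zero_point):
    # acc: the i32 accumulator from the integer matrix-vector product
    prod = (acc * m0_fixed + (1 << 30)) >> 31  # fixed-point multiply by M0, with rounding
    if n > 0:
        prod = (prod + (1 << (n - 1))) >> n    # rounding arithmetic shift right by n
    return max(-128, min(127, prod + zero_point))  # saturate to the i8 range
The float scale being folded into this form is the usual composite $M = s_x s_W / s_a$ (input scale times weight scale divided by output activation scale), as described in the integer-arithmetic-only inference paper linked below.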
There’s one other large aspect of quantization to consider: How to find the suitable, e.g., 8-bit weights for the network. There are two approaches:
post-training quantization: You basically just round the parameters of an already-trained neural network to the closest quantized number. It’s up to you to select the quantization sizes (bit-widths) and how to map those bits to real numbers. The overall accuracy of the network will go down compared to the, e.g., original f32 parameter network, but your network will be much smaller and faster to run. You also don’t need to do any training, which is usually a hassle for, e.g., large vision models that require terabytes of Internet cat photos.
quantization-aware training: You train the network with the forward pass modeling the quantization of the parameters. This leads to better accuracy, since you are optimizing your actual use case (quantized inference) rather than optimizing for f32 inference and then unceremoniously rounding away those carefully trained weights later. The downsides are that:
- you have to train a network (not a problem in my case, since my network is small and I have to do that anyway)
- you have to do some fancy stuff so gradients can flow through the quantization functions (the literature calls this “the straight-through estimator”, but I grug programmer call this “pretend bad function not here”)
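In JAX, the straight-through estimator can be expressed either with a custom VJP (as in the worked example further down) or with a one-line stop_gradient trick; a minimal sketch:
import jax
import jax.numpy as jnp

def round_ste(x):
    # forward pass: round(x); backward pass: gradient flows through as if round were the identity
    return x + jax.lax.stop_gradient(jnp.round(x) - x)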
For more on the details of quantization / numerics, see:
- Lei Mao’s Quantization for Neural Networks
- this Stack Overflow answer about .tflite inference
- gemmlowp’s quantization.md, which explains in full detail the numerical computer bits
- this Tiny ML walkthrough/notebook
- this paper: Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
- this .tflite model visualizer
Recap + next steps
I’m currently able to use TensorFlow to perform quantization-aware training of my two-layer dense neural network, yielding quantization parameters in a .tflite file that I can run on my tiny microcontroller using the lovely MicroFlow crate.
This is pretty cute (and sufficient for my actual application), but it requires floating point operations (for input, activation, and output scaling), which add latency and code size (not cute).
A maximally cute neural network would efficiently use the instructions available on my hardware to transform the incoming sensor readings (six u16 values) to outgoing estimates of the pose (six i16 values).
No runtime allocators, flatbuffers, or Extensible Compiler Frameworks.
How could this be accomplished?
Unless I had an expert guide who was absolutely sure it’d be straightforward (email me!), I’d avoid high-level frameworks like TensorFlow and PyTorch and instead implement the quantization-aware training myself.
I’d use JAX, since it provides both:
- automatic differentiation (gradients are necessary for training)
- just-in-time compilation (so the training goes fast on my computer)
In fact, their docs show a simple demo of training a neural network using gradient descent, with no other framework or dependencies.
I spent an hour with my friend Niko figuring out how to define a custom gradient and we managed to descend our way to solving a toy quantized problem:
import jax
import jax.numpy as jnp
import jax.nn as nn

# Quantization function that rounds x to one decimal point.
# Define a custom gradient so it's ignored during automatic differentiation.
@jax.custom_vjp
def quantize(x):
    return jnp.round(x * 10.0) / 10.0

def quantize_fwd(x):
    return quantize(x), (None,)

def quantize_bwd(res, g):
    return (1.0 * g,)

quantize.defvjp(quantize_fwd, quantize_bwd)

def predict(weights, biases, x):
    a1 = nn.relu(quantize(weights) @ x + quantize(biases))
    return a1

def loss(weights, biases, x, y):
    yhat = predict(weights, biases, x)
    return jnp.mean((yhat - y)**2)

loss_and_grad = jax.value_and_grad(loss, argnums=(0, 1))

weights = jnp.array([0.5])
biases = jnp.array([0.5])
input = jnp.array([1.])
target = jnp.array([0.2])
learning_rate = 0.01

for i in range(100):
    loss_value, (dloss_weights, dloss_biases) = loss_and_grad(weights, biases, input, target)
    weights = weights - learning_rate * dloss_weights
    biases = biases - learning_rate * dloss_biases
    print(loss_value, (weights, biases), quantize(weights), quantize(biases))

# this prints out decreasing loss values, reaching 0 when the quantized weights and biases sum to the target value (0.2).
It’s all straightforward math! I didn’t even need to bring in that “Adam” guy.
So to do the full quantization-aware training, I’d write out:
- quantization of the weights, biases, and activations
- custom gradients for the backward pass through this quantization
- the Learned Step Size Quantization to find the dyadic rational activation scaling
then train to find network parameters.
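To make the first two bullets concrete, here’s a rough sketch (my own, not from any library) of fake-quantizing a weight tensor to i8 with a per-tensor step size. Making the step size a learned parameter gives essentially the LSQ idea, minus its gradient-scaling details and the dyadic-rational constraint:
import jax
import jax.numpy as jnp

def round_ste(x):
    # straight-through round: round() in the forward pass, identity gradient in the backward pass
    return x + jax.lax.stop_gradient(jnp.round(x) - x)

def fake_quant_i8(w, step):
    # simulate i8 storage of w with step size `step`; gradients flow to both w and step
    q = jnp.clip(round_ste(w / step), -127, 127)
    return q * step
For the activation scaling I’d additionally constrain the step size to a dyadic rational $2^{-n} M_0$ (for example by snapping it to that form after each optimizer step), so the inference side never needs a floating point multiply.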
To make sure I didn’t mess up the implementation, I’d compare my fully-integer-quantized task loss with the mostly-integer-quantized task loss from TensorFlow.
Then, for the inference, I’d manually write the Rust that does the matrix multiplication and activation scaling. I’d have my Python training notebook write out the weights as a string of Rust that my firmware can include! into the binary.
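The codegen can be as dumb as string formatting. A sketch of what I have in mind (the constant name, shape, and file path here are made up):
import numpy as np

def emit_rust_i8_matrix(name, w):
    # turn a 2-D numpy array of i8 weights into a Rust `const` declaration
    rows = ",\n    ".join("[" + ", ".join(str(int(v)) for v in row) + "]" for row in w)
    return f"pub const {name}: [[i8; {w.shape[1]}]; {w.shape[0]}] = [\n    {rows},\n];\n"

# hypothetical usage at the end of the training notebook:
w1 = np.array([[1, -2, 3], [4, 5, -6]], dtype=np.int8)
with open("weights.rs", "w") as f:
    f.write(emit_rust_i8_matrix("W1", w1))
# ...and the firmware pulls it in with include!("weights.rs") (exact path depends on the build setup)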
I suspect the entire training notebook and inference code would be less than 200 lines total. Most importantly, I would actually understand what’s going on, which is aesthetically much more satisfying than messing with complex frameworks =D
Until then!
Thanks
Thanks to Rick Lamers and Matteo Carnelos for reviewing these notes.
Appendix: TensorFlow notes
“TensorFlow” is an OG deep neural network framework from Google. “TensorFlow Lite” is when they got inference working on mobile phones. “TensorFlow Lite Micro” is when they got inference working on microcontrollers (presumably so you can say “Hello Google” or whatever to their smart speakers or your Nest thermostat or something).
The best introduction and conceptual documentation I found was this 2023 lecture series on TensorFlow Lite Micro.
In a very Google move, at some point they renamed the project and docs to “LiteRT” but didn’t finish the job, so the various code repos, python packages, and file formats are all still various permutations of the earlier names.
I managed to follow their tflite-micro/tensorflow/lite/micro/examples/hello_world, but here’s the pyproject.toml I needed to cobble together to do it:
requires-python = ">=3.12"
dependencies = [
    "absl-py>=1.2",
    "ai-edge-litert>=1.2.0",
    "jupyter>=1.1.1",
    "numpy>=1.10",
    "tensorflow>=2.18.1",
    "tensorflow-model-optimization>=0.8.0",
    "tf-keras>=2.18.0",
]
And yeah, go learn how to use uv, because Google does not keep tensorflow-model-optimization (required for quantization) compatible with tensorflow — just trying to pull the latest of everything will give you transitive dependency conflicts.
The Bazel scripts used in the hello_world example didn’t work with Bazel 8 and when I tried Bazel 7 it blew up with some python errors. ¯\_(ツ)_/¯ They do have makefiles, but those required me to upgrade make — first time I’ve ever run into that in my career.
They don’t have it anywhere in the docs, but to build the runtime so you can link to it in your firmware, you need to run something like:
make -f tensorflow/lite/micro/tools/make/Makefile OPTIMIZED_KERNEL_DIR=cmsis_nn TARGET=cortex_m_generic TARGET_ARCH=cortex-m0 microlite
Unfortunately, this all turned out for naught because even just linking to the tensorflow lite micro runtime blew up my program size to 37 kB, which is too much for my 32 kB flash microcontroller (especially considering that I need to also include, uh, my program and the actual neural network model).
Appendix: Alternatives considered
Here are some other libraries/frameworks I found but wasn’t able to use:
CMSIS-NN is an ARM-specific neural network kernel. Maybe I could string together the inference code myself, but I couldn’t find any examples or “howto”-style documentation, just reference docs for (presumably) people who are writing machine learning compilers. The one example linked in the repo requires tflite micro, which was too big for my microcontroller anyway.
IREE looks active and “scales […] down to satisfy the constraints and special considerations of mobile and edge deployments” but the graphic on their readme suggests the runtime size is > 25kB, which is too big for my application so I didn’t even try.
MicroTVM came up in a Google search but there weren’t any simple usage examples. I then discovered on a forum post that it had been phased out.
uTensor looked promising — their readme calls out their runtime size of 2kB — but I didn’t pursue it because their overview doc spends a bunch of time discussing concepts I know I don’t need, so the whole thing felt a bit too heavy/complex for my “I just wanna matmul a simple fixed network architecture” situation.
TinyEngine seems to be a well-documented academic project, but all of their examples and discussion involve complex network architectures for vision/audio and on-device training; I couldn’t find an example of a simple dense network inference, nor did I see anything about integer-only inference.