Learning Machine Learning
I’m finally making progress with machine learning. I’ve spent the last week hacking on some Rust neural network code and I feel like I now have a good grasp of how to program neural networks. Over the weekend I got into a crazy flow state and wrote a huge amount of code. Then for the last few days I have undone most of that and taken a more considered approach.
The code is on GitHub: neural-net-rs. There are two interesting branches¹:
Update 18 April 2025: All code has been merged into main.

main

This is my “stream of consciousness” code. It became overly complicated because I added multithreading at the beginning and then tried to back it out. Threading isn’t the best way to optimise neural networks, not immediately anyway. I have abandoned this branch.

mnist

This was a fresh start where I concentrated on getting the code correct first. I’m pretty happy with where this code is and I’ll use this branch as the basis for future development.
This code is based on a repo I stumbled upon a couple of months ago, neural-net-rs. It’s a nicely structured Rust project with no dependencies. I forked it, fixed a bunch of stuff, and then added code so I could use it to train on the MNIST data set.
MNIST is a database of handwritten digits. It’s sometimes referred to as the “hello world” of machine learning: not so simple that you learn nothing, and not so complex that you spend more time on the theory than the code. It’s also the dataset used in one of the books I’m reading about machine learning, Programming Machine Learning by Paolo Perrotta.
Update 22 Mar 2025:
Here’s an online demo of digit recognition using an MNIST-trained model. It’s really bad at 6s.
My repo on GitHub doesn’t include the MNIST training data. The original data repository is empty (whoever owns that server probably didn’t want to pay for the traffic), so I don’t know what the license for the data is. There are copies of the data scattered around the internet; it’s not too hard to find.
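If you want to wire the data up yourself, the files use the simple IDX format: a big-endian header followed by raw bytes. Here’s a minimal loader sketch; the function names are mine, not from neural-net-rs, and the magic numbers (2051 for images, 2049 for labels) come from the IDX spec:

```rust
use std::fs::File;
use std::io::{self, Read};

/// Read one big-endian u32 (IDX headers are big-endian).
fn read_u32(r: &mut impl Read) -> io::Result<u32> {
    let mut buf = [0u8; 4];
    r.read_exact(&mut buf)?;
    Ok(u32::from_be_bytes(buf))
}

/// Load an IDX image file: magic 2051, count, rows, cols, then pixels.
fn load_images(path: &str) -> io::Result<Vec<Vec<f64>>> {
    let mut f = File::open(path)?;
    assert_eq!(read_u32(&mut f)?, 2051, "not an IDX image file");
    let count = read_u32(&mut f)? as usize;
    let rows = read_u32(&mut f)? as usize;
    let cols = read_u32(&mut f)? as usize;
    let mut pixels = vec![0u8; count * rows * cols];
    f.read_exact(&mut pixels)?;
    // Scale each 28x28 image from 0..255 into [0, 1] for the network.
    Ok(pixels
        .chunks(rows * cols)
        .map(|img| img.iter().map(|&p| p as f64 / 255.0).collect())
        .collect())
}

/// Load an IDX label file: magic 2049, count, then one byte per label.
fn load_labels(path: &str) -> io::Result<Vec<u8>> {
    let mut f = File::open(path)?;
    assert_eq!(read_u32(&mut f)?, 2049, "not an IDX label file");
    let count = read_u32(&mut f)? as usize;
    let mut labels = vec![0u8; count];
    f.read_exact(&mut labels)?;
    Ok(labels)
}
```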
I have been wanting to learn how to program neural networks for a while, but none of the teaching materials I found taught it at the level I wanted. The subject seems to be taught in one of two ways: heavy mathematical theory or Python. My goal is to understand how to code neural networks from first principles. I want a high-level understanding of the maths, not to wade through PhD-level formulas. And I didn’t want to simply grab a bunch of Python libraries, wire them together, and call it a day; I wanted to grasp the underlying mechanics. I also wanted to use Rust.
Programming Machine Learning is the closest I have found, but the code is in Python. I was working through the book porting his code to Rust, but pretty soon he jumped into numpy and I got stuck. It felt a bit like the “how to draw an owl” meme. There are Rust equivalents of numpy² but I didn’t want to use them. So I was really pleased to find the neural-net-rs repo: it’s at the level I wanted and almost exactly what I had been trying to write myself.
I now have a neural network written in Rust that can recognise handwritten digits and, more importantly, I understand it, so I can build on it as I tackle harder problems.
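For a flavour of what “understanding it” means: the forward pass of a fully connected network is just repeated weighted sums pushed through a nonlinearity. This is a simplified sketch using plain vectors and assuming sigmoid activations, not the actual neural-net-rs code (which has its own matrix type):

```rust
/// Sigmoid activation: squashes any real number into (0, 1).
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

/// One layer's forward pass: output = sigmoid(weights * input + bias).
/// `weights` is row-major, one row per output neuron.
fn forward(weights: &[Vec<f64>], bias: &[f64], input: &[f64]) -> Vec<f64> {
    weights
        .iter()
        .zip(bias)
        .map(|(row, b)| {
            let sum: f64 = row.iter().zip(input).map(|(w, x)| w * x).sum();
            sigmoid(sum + *b)
        })
        .collect()
}

/// For MNIST: 784 inputs -> hidden layer(s) -> 10 outputs.
/// The predicted digit is the index of the largest output.
fn predict(layers: &[(Vec<Vec<f64>>, Vec<f64>)], image: &[f64]) -> usize {
    let mut activation = image.to_vec();
    for (w, b) in layers {
        activation = forward(w, b, &activation);
    }
    activation
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
        .map(|(i, _)| i)
        .unwrap()
}
```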
```
Confusion Matrix:

                           Predicted
Actual     0    1    2    3    4    5    6    7    8    9
      +--------------------------------------------------
    0 |  972    0    0    2    0    2    1    1    2    0
    1 |    0 1126    1    2    0    1    2    1    2    0
    2 |    7    1 1003    1    3    1    1    8    7    0
    3 |    0    0    5  984    0    7    0    7    5    2
    4 |    1    0    4    0  953    1    2    1    2   18
    5 |    4    1    0   10    1  863    4    0    7    2
    6 |    6    3    0    1    2    4  936    2    4    0
    7 |    2    9   13    2    0    0    0  991    0   11
    8 |    5    1    2    2    8    5    3    3  940    5
    9 |    7    5    1    7   11    1    2    6    2  967

Per-digit Metrics:

Digit | Accuracy | Precision | Recall | F1 Score
------+----------+-----------+--------+---------
  0   |  99.2%   |   96.8%   | 99.2%  |  98.0%
  1   |  99.2%   |   98.3%   | 99.2%  |  98.7%
  2   |  97.2%   |   97.5%   | 97.2%  |  97.3%
  3   |  97.4%   |   97.3%   | 97.4%  |  97.4%
  4   |  97.0%   |   97.4%   | 97.0%  |  97.2%
  5   |  96.7%   |   97.5%   | 96.7%  |  97.1%
  6   |  97.7%   |   98.4%   | 97.7%  |  98.1%
  7   |  96.4%   |   97.2%   | 96.4%  |  96.8%
  8   |  96.5%   |   96.8%   | 96.5%  |  96.7%
  9   |  95.8%   |   96.2%   | 95.8%  |  96.0%

Overall Accuracy: 97.35%
```
I have no idea if this is good or not 😂
Why am I bothering? It’s really interesting and I have a couple of projects that I want to try tackling with machine learning. More about that later.
Metrics Explained
Confusion Matrix
The matrix is a 10x10 grid where:
- Rows represent the actual digit (0-9)
- Columns represent what the model predicted (0-9)
- Each cell [i][j] contains the count of how many times:
  - The actual digit was i
  - The model predicted j
Example:

```
Actual 5: |   12    1    0   39    4  808    8    3   11    6
          |    ↑    ↑    ↑    ↑    ↑    ↑    ↑    ↑    ↑    ↑
Predicted |    0    1    2    3    4    5    6    7    8    9
```
This row shows that when the actual digit was 5:
- 808 times it was correctly predicted as 5
- 39 times it was incorrectly predicted as 3
- 12 times it was incorrectly predicted as 0
- etc.
The diagonal (where row index equals column index) shows correct predictions. Everything off the diagonal represents mistakes.
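Building the matrix from the test results is tiny; here’s a sketch using the row/column convention described above (the function name is mine):

```rust
/// Accumulate a 10x10 confusion matrix: rows are actual digits,
/// columns are predicted digits.
fn confusion_matrix(actual: &[usize], predicted: &[usize]) -> [[u32; 10]; 10] {
    let mut matrix = [[0u32; 10]; 10];
    for (&a, &p) in actual.iter().zip(predicted) {
        matrix[a][p] += 1; // cell [actual][predicted]
    }
    matrix
}
```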
Per-digit Metrics
- Accuracy - Percentage of correct predictions for a specific digit
- Precision - Of all cases predicted as digit X, what percentage were actually X
- Recall - Of all actual cases of digit X, what percentage were correctly identified
- F1 Score - Harmonic mean of precision and recall: (2 × precision × recall) / (precision + recall); see the sketch below
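These all fall straight out of the matrix: recall for digit X divides the diagonal cell by its row total, precision divides it by its column total, and the overall accuracy is the diagonal sum over the grand total. A sketch (function names are mine):

```rust
/// Per-digit (precision, recall, F1) from a confusion matrix where
/// rows are actual digits and columns are predictions.
fn per_digit_metrics(m: &[[u32; 10]; 10]) -> Vec<(f64, f64, f64)> {
    (0..10)
        .map(|d| {
            let correct = m[d][d] as f64;
            let row_total: u32 = m[d].iter().sum(); // all actual d
            let col_total: u32 = (0..10).map(|r| m[r][d]).sum(); // all predicted d
            let recall = correct / row_total as f64;
            let precision = correct / col_total as f64;
            let f1 = 2.0 * precision * recall / (precision + recall);
            (precision, recall, f1)
        })
        .collect()
}

/// Overall accuracy: correct predictions (the diagonal) over everything.
fn overall_accuracy(m: &[[u32; 10]; 10]) -> f64 {
    let correct: u32 = (0..10).map(|d| m[d][d]).sum();
    let total: u32 = m.iter().flatten().sum();
    correct as f64 / total as f64
}
```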
¹ At some point I might merge or swap them ¯\_(ツ)_/¯