
haifeng-jin/readable-ml-framework


Readable ML Framework

A machine learning framework with readable source code.

Check out this blog post for a detailed introduction.

Machine learning frameworks can be intimidating. Their codebases are often massive and complex, which makes the source code nearly impossible to read.

The Readable ML Framework is different: it contains only about 800 lines of actual code (not counting comments), written in Python and C++. The code is thoroughly documented, which brings the total to around 2,000 lines.

The features of this framework are just enough to implement a simple neural network for a basic classification problem. By reading through it, you can easily grasp the fundamentals of how an ML framework works.

Here is a basic example of what it can do:

```python
import numpy as np

from framework import ops
from framework.tensor import Tensor

# Create input tensors
x = Tensor.from_numpy(np.array([[2.0, 3.0]], dtype=np.float32))  # shape (1, 2)
y = Tensor.from_numpy(
    np.array([[4.0], [5.0]], dtype=np.float32)
)  # shape (2, 1)

# Perform matrix multiplication
z = ops.matmul(x, y)  # Expected: [[2*4 + 3*5]] = [[23.0]]
s = ops.sum(z)

# Trigger backward propagation
s.backward()

# Print gradients
print("x.grad:", x.grad.numpy())  # Expected: [[4.0, 5.0]]
print("y.grad:", y.grad.numpy())  # Expected: [[2.0], [3.0]]
```

Also, feel free to check the full classification example.
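To illustrate what a call like `s.backward()` has to do, here is a minimal sketch of reverse-mode automatic differentiation in plain NumPy. This is not the framework's implementation, which does the heavy lifting in C++. `MiniTensor`, `matmul`, and `tensor_sum` are hypothetical names chosen for this sketch. For `z = x @ y` with upstream gradient `G`, the gradient for `x` is `G @ y^T` and the gradient for `y` is `x^T @ G`. A sum passes a gradient of 1 back to every element. Together these rules reproduce the expected gradients in the example above.

```python
import numpy as np


class MiniTensor:
    """Hypothetical NumPy-backed tensor; not the framework's Tensor class."""

    def __init__(self, data, parents=(), backward_fn=None):
        self.data = np.asarray(data, dtype=np.float32)
        self.grad = None
        self.parents = parents          # tensors this one was computed from
        self.backward_fn = backward_fn  # maps this tensor's grad -> parents' grads

    def backward(self):
        # Depth-first post-order gives inputs before outputs; walking it in
        # reverse guarantees a node's grad is complete before it is propagated.
        order, visited = [], set()

        def visit(t):
            if id(t) in visited:
                return
            visited.add(id(t))
            for p in t.parents:
                visit(p)
            order.append(t)

        visit(self)
        self.grad = np.ones_like(self.data)  # d(self)/d(self) = 1
        for t in reversed(order):
            if t.backward_fn is None:  # leaf tensor: nothing to propagate
                continue
            for parent, g in zip(t.parents, t.backward_fn(t.grad)):
                parent.grad = g if parent.grad is None else parent.grad + g


def matmul(a, b):
    # Eager execution: the product is computed right away.
    out = MiniTensor(a.data @ b.data, parents=(a, b))
    # For out = a @ b with upstream grad G: dA = G @ b^T, dB = a^T @ G.
    out.backward_fn = lambda g: (g @ b.data.T, a.data.T @ g)
    return out


def tensor_sum(a):
    out = MiniTensor(a.data.sum(), parents=(a,))
    # Each element contributes with weight 1, so broadcast G back to a's shape.
    out.backward_fn = lambda g: (np.full_like(a.data, g),)
    return out


# Same computation as the framework example above.
x = MiniTensor([[2.0, 3.0]])    # shape (1, 2)
y = MiniTensor([[4.0], [5.0]])  # shape (2, 1)
s = tensor_sum(matmul(x, y))    # value 23.0
s.backward()
print("x.grad:", x.grad)  # same values as the example: [[4.0, 5.0]]
print("y.grad:", y.grad)  # same values as the example: [[2.0], [3.0]]
```

Each operator records how to map its output gradient to its inputs' gradients at the moment it runs. `backward()` then only needs to replay those rules in reverse order.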

Disclaimer

This repo is for educational purposes only and is nowhere near a feature-complete ML framework. It is for people who want to learn the internal mechanisms of ML frameworks like TensorFlow, PyTorch, and JAX.

It implements eager execution: the tensor data structure and the operators are written in C++ and exposed through Python APIs. The operators use multi-threading to run faster.
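The framework's multi-threading lives in its C++ operators. The general idea can still be sketched in a few lines of Python: split an operator's output into disjoint chunks and let each thread fill one chunk. The sketch assumes NumPy and the standard-library `concurrent.futures` module. `threaded_matmul` is a hypothetical helper, and the framework's C++ operators may partition their work differently.

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor


def threaded_matmul(a, b, num_threads=4):
    """Hypothetical helper: compute a @ b by splitting output rows across threads."""
    n = a.shape[0]
    out = np.empty((n, b.shape[1]), dtype=np.result_type(a, b))
    # num_threads + 1 row boundaries; chunk i covers rows bounds[i]:bounds[i + 1].
    bounds = np.linspace(0, n, num_threads + 1).astype(int)

    def work(start, stop):
        # Each thread writes a disjoint block of rows, so no lock is needed.
        out[start:stop] = a[start:stop] @ b

    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        futures = [pool.submit(work, s, e) for s, e in zip(bounds[:-1], bounds[1:])]
        for f in futures:
            f.result()  # re-raise any exception from a worker
    return out


rng = np.random.default_rng(0)
a = rng.standard_normal((10, 7)).astype(np.float32)
b = rng.standard_normal((7, 3)).astype(np.float32)
c = threaded_matmul(a, b, num_threads=3)
print(np.allclose(c, a @ b, atol=1e-5))
```

The threads never write to overlapping memory, so the result matches a single-threaded product. NumPy typically releases the GIL during matrix products, so the threads can overlap here. In C++, there is no GIL to work around.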

The code is structured to be as easy to read as possible. More complex features, such as sanity checks for function arguments, GPU support, distributed training, data types of different precisions, asynchronous dispatch, compilers, model serialization, and model export, are not implemented.

How to use

You can read the codebase in the following steps:

Install for development

I used a conda environment for easier setup.

Install the dependencies:

```shell
conda install -c conda-forge cxx-compiler clang-format
pip install -r requirements.txt
```

Install the project for dev mode:

```shell
pip install -e .
```
