Training a Handwritten Digits classifier in Pytorch with Cassandra

Obioma on September 15, 2023

Handwritten digit recognition is one of the classic tasks undertaken by students when learning the basics of Neural Networks and Computer Vision. The basic idea is to take a number of labeled images of handwritten digits and use those to train a neural network that is able to classify new unlabeled images. Using this repo, we’ll show how to use data stored in a large-scale database as our training data (the demo uses AstraDB, a managed Cassandra service, as a quick-start option for this). We also explain how to use that same database as a basic model registry. This addition can enable model serving as well as future retraining.


MNIST is a set of datasets that share a particular format useful for educating students about neural networks while presenting them with diverse problems. The MNIST datasets for this demo are a collection of 28 by 28 pixel grayscale images as data and classifications 0-9 as potential labels. This demo works with the original MNIST handwritten digits dataset as well as the MNIST fashion dataset. 

The use of both of these datasets will help calibrate models, testing whether they are affected by the domain of the classification or not. If a neural net is good at classifying digits, but bad at classifying clothing and accessories, even though in this case the datasets have the same structure, it is evidence that something about the training or structure of the network contains knowledge on digits, or on handwriting, or is more suited to simple rather than complex shapes, etc. 

PyTorch is a Python library that contains data types and methods for working with neural networks. We also make use of torchvision, a related library specifically meant for computer-vision tasks. PyTorch works with data typed as Tensors and can define different types of layers, which can be combined to build deep networks that do things a network made of a single layer type cannot. PyTorch also provides utilities that help us define, train, and test our models and make predictions with them.

Astra is a managed database based on Apache Cassandra. It retains Cassandra’s horizontal scalability and distributed nature. Every feature that we use in today’s tutorial will be part of Cassandra 5.0, which is releasing soon. To preview these new features, we use DataStax Astra. Today, you can use Astra as the primary data store and as a primitive model registry.


For this tutorial you will need:

  • An Astra Account from DataStax, or to be familiar enough with Cassandra to use an alternative Cassandra database. Sign up for a Free Tier Astra account here
  • An environment in which to run Python code. Make sure that you can install new pip modules in this environment. We’d recommend Gitpod or your local machine.
  • A Jupyter Notebook server or access to the VS Code Jupyter plugin is also necessary.

For an effortless setup, we have provided a Gitpod quickstart option– though you’ll still need to fill in your own credentials before it will run seamlessly. Simply click on the “Open in Gitpod” button found in our GitHub repository to get started. Alternatively go to this link to open the repo in Gitpod. When creating the workspace for this project it is advantageous to select the large class of machine in order to access more RAM and have the data loader run faster. The rest of this article will assume the reader is using Gitpod to follow along unless stated otherwise.

Environment Setup

Assuming you opened the repo in Gitpod, the first thing that will happen is the required python modules will install. If you are not using Gitpod you will need to clone the repo into your environment, change into the working directory and install the prerequisites.

cd pytorch-astra-demo && pip3 install -r requirements.txt

Before you can proceed further you will need to set up your Astra database. After creating a free account you will need to create a database within that account and create a Keyspace within that database. All of this can be done purely using the Astra UI. 

Setting up your Astra DB

There is a ton of great documentation for how to create an Astra Database which is included in the Knowledge Resources section at the bottom of this blog.

In brief: create or sign in to your Astra account, create a new DB (they are free for most users), then create a keyspace and run the schema creation script (located at setup/create_schema.cql in the CassioML repo) in the CLI for your database.

For this demo, we use a database called cassio_db and a keyspace named mnist_digits. To create that database, select the “Databases” tab (shown on the left menu), then click on the “Create Database” button (shown on the right) and fill out the needed information for your database. The “Vector Database” option comes with vector search enabled, but for this example it does not matter which option you choose.

Once you’ve created your database, you’ll need to generate the Token or Secure Connect Bundle used to connect to it from the Connect tab. Choose the permissions that make the most sense for your use case. For this demo, there’s nothing wrong with choosing Database Administrator, but you can also go as simple as a Read/Write Service Account to get the functionality you need.

Never share your token or bundle with anyone. It is a bundle of several pieces of data about your database, and can be used to access it. 

Reminder: for this demo, the assumed name of the keyspace is mnist_digits.

Establishing the Schema

Once the keyspace has been created, we need to create the Tables that we will be using.  Open the CQL Console for the database by clicking on the “CQL Console” tab in the database view.

We can create the raw_train, raw_test, models_train, and models_test tables, as well as a raw_predict table holding data with no labels attached, using the commands below.

CREATE TABLE mnist_digits.raw_train (id int PRIMARY KEY, label int, pixels list<int>);
CREATE TABLE mnist_digits.raw_test (id int PRIMARY KEY, label int, pixels list<int>);
CREATE TABLE mnist_digits.models_train (id uuid PRIMARY KEY, network blob, optimizer blob, upload_date timestamp, epoch int, batch_percent text, loss float);
CREATE TABLE mnist_digits.models_test (id uuid PRIMARY KEY, network blob, optimizer blob, upload_date timestamp, loss float, accuracy float);
CREATE TABLE mnist_digits.raw_predict (id int PRIMARY KEY, label int, pixels list<int>);

It is useful to define storage attached indexes (SAIs) for the models_test and models_train tables so that once the training is completed we can easily identify the best models. This will allow us to search for the models with the minimum loss and maximum accuracy.

CREATE CUSTOM INDEX loss_train_sai_idx on mnist_digits.models_train (loss) using 'StorageAttachedIndex';
CREATE CUSTOM INDEX loss_test_sai_idx on mnist_digits.models_test (loss) using 'StorageAttachedIndex';
CREATE CUSTOM INDEX accuracy_test_sai_idx on mnist_digits.models_test (accuracy) using 'StorageAttachedIndex';
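
With the indexes above in place, a range filter on loss or accuracy becomes possible. The sketch below is only an illustration of how a shortlist of models might be built: the 0.95 threshold and the helper name pick_best are hypothetical, and the rows are plain dictionaries standing in for whatever a driver would return from the query.

```python
# Hypothetical CQL that relies on the SAI on models_test.accuracy; the 0.95
# threshold is an arbitrary example value, not one taken from the repo.
CANDIDATES_CQL = (
    "SELECT id, loss, accuracy FROM mnist_digits.models_test "
    "WHERE accuracy > 0.95"
)


def pick_best(rows):
    """Return the row with the highest accuracy, breaking ties by lowest loss."""
    return max(rows, key=lambda r: (r["accuracy"], -r["loss"]))


# Stand-in rows shaped like the models_test table (blob columns omitted).
rows = [
    {"id": "a", "loss": 0.12, "accuracy": 0.96},
    {"id": "b", "loss": 0.08, "accuracy": 0.98},
    {"id": "c", "loss": 0.10, "accuracy": 0.98},
]
best = pick_best(rows)
print(best["id"])
```

The final choice is made client-side here, so the query only needs to narrow the candidates rather than sort them.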

Next, we need to create the resources necessary to connect to the newly created Astra database. Hit the Connect button on the UI and download the Secure Connect Bundle (SCB). Then hit the “Create a Token” button to create a Database Administrator token and download the text that it returns. 

Load the SCB into the environment and put the path to it in the file’s first line, between the single quotes. Put the generated id (Client_ID) for the Database Admin token in the second line. Put the generated secret (Client_Secret) for the token in the third line.
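
For readers wiring up the connection themselves, the sketch below shows one way the Secure Connect Bundle path, Client ID, and Client Secret could be used with the DataStax Python driver (cassandra-driver). It is an assumption that the repo connects this way; connect_to_astra is a hypothetical helper name, and the driver is imported inside the function so the definition loads even where the driver is not installed.

```python
def connect_to_astra(bundle_path, client_id, client_secret, keyspace="mnist_digits"):
    """Open a session against Astra using the Secure Connect Bundle and token."""
    # Imported lazily so this definition can be loaded without the driver.
    from cassandra.auth import PlainTextAuthProvider
    from cassandra.cluster import Cluster

    cluster = Cluster(
        cloud={"secure_connect_bundle": bundle_path},
        auth_provider=PlainTextAuthProvider(client_id, client_secret),
    )
    # Connecting with a keyspace makes unqualified table names resolve to it.
    return cluster.connect(keyspace)
```

Calling the function with your own bundle path and token values returns a session on which CQL statements can be executed.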

Then run the data loader called using this line. Modify the train_split variable in the file if you want something other than an 80/20 train/test split.
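
To make the effect of train_split concrete, here is a minimal sketch of an 80/20 split. It is not the repo's loader: split_rows and seed are hypothetical names, and the samples are fake rows shaped like (id, label, pixels).

```python
import random


def split_rows(rows, train_split=0.8, seed=0):
    """Shuffle rows and split them into train and test lists."""
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)
    cut = int(len(shuffled) * train_split)
    return shuffled[:cut], shuffled[cut:]


# Ten fake samples: (id, label, pixels) with 784 pixel values each.
samples = [(i, i % 10, [0] * 784) for i in range(10)]
train, test = split_rows(samples, train_split=0.8)
print(len(train), len(test))
```

Raising train_split sends more rows to the training table and fewer to the test table.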


The data loader will populate the train, test, and predict tables that we created earlier. It may take an hour or more to complete because there are close to 800 values for each data sample (this is why we asked you to select the large machine class for your Gitpod workspace). Once it is complete, make sure that the data was created by running these commands in the CQL Console of the Astra UI.

SELECT id, label from mnist_digits.raw_train limit 5;
SELECT id, label from mnist_digits.raw_test limit 5;
SELECT id, label from mnist_digits.raw_predict limit 5;

After that you should be able to step through the model_training_full_sequence.ipynb notebook without issue, following the comments to train and store models.

Running the Notebook

An IPython notebook is a collection of cells containing either Markdown or Python code. Each code cell can be run individually, though they share an environment so variables and imports are carried over between cells. 

If you are running the notebook in Gitpod with a VS Code editor, on your first run a pop-up will ask you to select a kernel. Select “Install Python Environments”. Then click on “Select Kernel” in the top right corner, or run a cell again. Select “Python Environments”, and then “Python 3.11.1”.

When first opening the notebook, all cells should have some blank space between them and the next cell. If this is not the case, click Clear All Outputs at the top of the screen. To run an individual cell, click on that cell and press Shift+Enter or click the Run button. You can also use the Run All button at the top of the notebook which will run the cells sequentially. 

A successfully run cell should have a green check mark on it afterwards. A still-running cell will have a loading symbol somewhere in it. Each code cell has comments that describe what that particular cell does. 

The notebook will first walk the user through importing the necessary libraries, and then creating a custom Pytorch data loader that connects to Astra. After importing the rest of the required modules for model definition and training we create the data loaders for our training and testing data sets. 

Adding New DataSets

This example uses the MNIST handwritten digits dataset. This dataset consists of a set of 28 by 28 pixel grayscale images depicting digits from 0-9, meant to be classified into those 10 categories. This repo is easily modified to work with other datasets with this format. The most compatible will be other MNIST datasets, which promise to have the same 28 by 28 image size, the same grayscale pixel values, and the same 10 categories. In fact, the fashion MNIST dataset here can be substituted almost exactly for the train and test csv files included in the repo. If you switch the filenames in, the rest of the repo can be used as normal.

Once we have done this and set some constants that will be used in model training, we test our data loaders by loading in a single batch of examples, extracting a single data point, and examining its shape (the dimensionality of the tensor that represents the hand-drawn number). Here we should see the 28 by 28 pixel nature of the images appear as part of the shape of the data. We change the dimensionality of the data object (it was originally 1 by 784) so that the model can process it correctly. If you see an error instead, the data loader is unable to load the data properly; check your Astra database for missing or empty tables.
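
The shape check described above can be sketched as follows. This is not the notebook's Astra-backed loader: RowDataset is a hypothetical stand-in that reads in-memory (label, pixels) rows instead of querying the database, and the batch size of 4 is arbitrary.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RowDataset(Dataset):
    """Wraps (label, pixels) rows; pixels is a flat list of 784 values."""

    def __init__(self, rows):
        self.rows = rows

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        label, pixels = self.rows[idx]
        # Reshape the flat 784-value list to (channels, height, width).
        image = torch.tensor(pixels, dtype=torch.float32).view(1, 28, 28)
        return image, label


# Eight fake all-black images with labels 0-7.
rows = [(i % 10, [0] * 784) for i in range(8)]
loader = DataLoader(RowDataset(rows), batch_size=4)
images, labels = next(iter(loader))
print(images.shape)     # batch dimension first, then 1 x 28 x 28
print(images[0].shape)  # a single data point
```

If the reshape fails, the row did not contain exactly 784 pixel values, which usually points back to a problem in the loaded table.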

Changing the details of how we are training the model:

Before we train the model, we define a number of constants that change how that training takes place. The first is n_epochs, which defines how many training epochs we put the model through. 

During each epoch, we feed in a number of training examples before stopping, at which point we test and save the model. 

  • batch_size_train tells us how many training examples are fed to the model in each training step; an epoch works through the training data one batch at a time. 
  • batch_size_test defines how many examples are in each batch used for testing the model after training. 
  • learning_rate is a property of the optimizer that scales the backpropagation step, causing it to make bigger or smaller changes to the model weights. 
  • momentum determines how much of the previous change to the model weights carries over into the next backpropagation step. 

Because backpropagation uses gradients to determine model weight changes, momentum lets the gradient from the previous backpropagation step influence the magnitude of the current change.
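
To show how learning_rate and momentum interact, here is a small pure-Python sketch of a momentum update on a toy one-weight problem. The update has the same form as the one documented for torch.optim.SGD (velocity = momentum * velocity + gradient), but the function names and the toy loss f(w) = w**2 are illustrative assumptions, not code from the notebook.

```python
def sgd_momentum_step(weight, grad, velocity, learning_rate, momentum):
    """One SGD-with-momentum update: the velocity remembers earlier gradients."""
    velocity = momentum * velocity + grad
    weight = weight - learning_rate * velocity
    return weight, velocity


def minimise(momentum, steps=20, learning_rate=0.1):
    """Minimise f(w) = w**2 (gradient 2*w) starting from w = 1.0."""
    weight, velocity = 1.0, 0.0
    for _ in range(steps):
        weight, velocity = sgd_momentum_step(
            weight, 2 * weight, velocity, learning_rate, momentum
        )
    return weight


print(minimise(momentum=0.0))
print(minimise(momentum=0.5))
```

On this toy problem the run with momentum ends closer to the minimum at 0 after the same number of steps; on a real network the best values have to be found by experiment.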

After that is complete, we define a class for our neural net that holds its component layers and defines how those layers are connected to each other. We then define train and test functions that perform an optimization step on our model, store the resulting model back into the Astra database, and test the resulting model on the test data set.
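
Storing a model in the registry amounts to serializing the state dictionaries into blobs and inserting a row. The sketch below assumes pickle is used for serialization and follows the models_train column order from the schema; build_model_row is a hypothetical helper, and plain dictionaries stand in for network.state_dict() and optimizer.state_dict().

```python
import datetime
import pickle as pkl
import uuid

# Column order follows the models_train table defined earlier.
INSERT_MODEL_CQL = (
    "INSERT INTO mnist_digits.models_train "
    "(id, network, optimizer, upload_date, epoch, batch_percent, loss) "
    "VALUES (%s, %s, %s, %s, %s, %s, %s)"
)


def build_model_row(network_state, optimizer_state, epoch, batch_percent, loss):
    """Serialize both state dictionaries and return values for INSERT_MODEL_CQL."""
    return (
        uuid.uuid4(),
        pkl.dumps(network_state),
        pkl.dumps(optimizer_state),
        datetime.datetime.now(datetime.timezone.utc),
        epoch,
        batch_percent,
        float(loss),
    )


# Plain dictionaries stand in for the real state dictionaries.
row = build_model_row({"fc.weight": [0.1, 0.2]}, {"lr": 0.01}, 1, "50%", 0.42)
print(pkl.loads(row[1]))
```

With a driver session available, the row could be written with session.execute(INSERT_MODEL_CQL, row).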

Changing the structure of the model

When we create the Net class we define the structure of our model. In this example repo we set up two convolutional layers, a dropout layer, and two linear layers. The convolutional layers take a number of 2d planes as input, perform a 2d convolution, and output a different number of planes. Because our flat grayscale image fits into a single plane, our first Conv2d layer has a single input channel. Conv2d layers are specialized for image processing. 

To use a traditional RGB image as an input we would raise the number of input channels to 3, one for each color. The dropout layer randomly zeroes out some channels. The linear layers apply a linear transformation to incoming data. Because our final output has 10 categories, the final linear layer has 10 output features. The network returns a value between 0 and 1 for each category, roughly corresponding to a probability or confidence score, and we take the highest one and count that as the prediction. The layers are applied in the order: first convolutional layer, second convolutional layer, dropout layer, first linear layer, second linear layer. This order can be changed by modifying the order in which they are used in the Net class’s forward method.
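
A network with this layer order might look like the sketch below. Only the layer types and their order come from the description above; the channel counts, kernel sizes, pooling steps, hidden size, and the final softmax are assumptions chosen so the shapes work out for 28 by 28 inputs, and the notebook's actual values may differ.

```python
import torch
import torch.nn.functional as F
from torch import nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)   # 1 input channel: grayscale
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()               # zeroes whole channels
        self.fc1 = nn.Linear(320, 50)                  # 20 channels * 4 * 4 pixels
        self.fc2 = nn.Linear(50, 10)                   # 10 output categories

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))                   # 28 -> 24 -> 12
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))  # 12 -> 8 -> 4
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.softmax(x, dim=1)  # one score between 0 and 1 per category


network = Net()
network.eval()  # disable dropout for a deterministic forward pass
with torch.no_grad():
    scores = network(torch.zeros(2, 1, 28, 28))
predictions = scores.argmax(dim=1)  # index of the highest score per image
print(scores.shape, predictions.shape)
```

Reordering the calls inside forward changes the order in which the layers are applied, as long as the tensor shapes between them still match.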

Once all of these are complete, you should see the accuracy of your model increasing in the space under your final cell like this.

Using the Updated Model on New Data 

In order to use this model on new data the first step is to pull the row concerning the particular model you desire out of Astra. Then you would use pkl.load on the network state object that was saved to turn it back into a dictionary object. Then we create the Net class and network object the same way we did in the notebook. Next, we call network.load_state_dict and pass it the state dictionary that we just loaded as input.

Now we have a network object with the same weights as the one we stored in Astra. We can then load new data, whether from the test loader we create in the notebook, from the data we placed in the raw_predict table, or from somewhere else. Once we have the data and the model we can call network(data) to run the new data through the model and look through the results it gives for the predictions.
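
The restore-and-predict steps can be sketched end to end as below. TinyNet is a hypothetical stand-in for the notebook's Net class so the example runs on its own, and the blob is created locally rather than read from Astra; pkl.loads is used because the stand-in blob is a bytes object, which is an assumption about how the stored value is handled.

```python
import pickle as pkl

import torch
from torch import nn


class TinyNet(nn.Module):
    """Small stand-in for the notebook's Net class."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(784, 10)

    def forward(self, x):
        return self.fc(x.view(x.size(0), -1))


# Pretend this blob came from the network column of a stored model row.
trained = TinyNet()
stored_blob = pkl.dumps(trained.state_dict())

# Rebuild the network and restore the saved weights.
network = TinyNet()
network.load_state_dict(pkl.loads(stored_blob))
network.eval()  # switch off training-only behaviour such as dropout

with torch.no_grad():
    output = network(torch.zeros(3, 1, 28, 28))
predicted_digits = output.argmax(dim=1)  # highest-scoring category per image
print(predicted_digits.shape)
```

Only unpickle blobs from a database you control, since unpickling untrusted data can execute arbitrary code.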


Congratulations, you have successfully trained a machine learning model to recognize hand-written numbers! If you were to use the same model to identify the number associated with a new set of hand-written numbers, your model would recognize those numbers better than an untrained model.

Getting help

You can reach out to us on the Planet Cassandra Discord Server to get specific support for this demo. You can also reach out to the Astra team through the chat on Astra’s website. Pytorch is a widely used library in the Machine Learning ecosystem. Now that you’ve gotten started with it, you can use it in a ton of ways; let us know how it helps you and your enterprise. Happy coding!

