DBN implementation for the MNIST dataset
Let's look at how the DBN class implemented earlier is used for the MNIST dataset.
Loading the dataset
First, we load the dataset from the idx3 and idx1 formats into training, validation, and test sets. To do so, we import TensorFlow along with the DBN model and the helper utilities defined in the common module:
import tensorflow as tf
from common.models.boltzmann import dbn
from common.utils import datasets, utilities
trainX, trainY, validX, validY, testX, testY = \
    datasets.load_mnist_dataset(mode='supervised')
You can find details about load_mnist_dataset() in the following code listing. Because mode='supervised' is set, the training, validation, and test labels are returned along with the images:
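# TensorFlow 1.x tutorial helper that provides read_data_sets()
from tensorflow.examples.tutorials.mnist import input_data

def load_mnist_dataset(mode='supervised', one_hot=True):
    # Read MNIST from the MNIST_data/ directory; labels are one-hot encoded if one_hot=True
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=one_hot)
    # Training set
    trX = mnist.train.images
    trY = mnist.train.labels
    # Validation set
    vlX = mnist.validation.images
    vlY = mnist.validation.labels
    # Test set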
...
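To make the idx3 and idx1 formats mentioned above more concrete, here is a minimal sketch of how such files are laid out and what one-hot encoding does to a label. It assumes the standard IDX layout: two zero bytes, a type code (0x08 for unsigned bytes), the number of dimensions, then one big-endian 32-bit size per dimension, followed by the raw data. The names parse_idx and one_hot are hypothetical helpers written for illustration only; they are not part of the common module or of TensorFlow, and the byte strings are small synthetic stand-ins rather than real MNIST data.

```python
import struct


def parse_idx(data):
    """Parse an unsigned-byte IDX byte string.

    Returns (dims, values): dims is a tuple of dimension sizes and
    values is a flat list of ints in row-major order.
    """
    # Header: two zero bytes, a type code, and the number of dimensions.
    zero, dtype, ndim = struct.unpack(">HBB", data[:4])
    if zero != 0 or dtype != 0x08:
        raise ValueError("not an unsigned-byte IDX file")
    # Each dimension size is a big-endian 32-bit unsigned integer.
    dims = struct.unpack(">" + "I" * ndim, data[4:4 + 4 * ndim])
    count = 1
    for d in dims:
        count *= d
    payload = data[4 + 4 * ndim:]
    if len(payload) != count:
        raise ValueError("payload size does not match header")
    return dims, list(payload)


def one_hot(label, num_classes=10):
    """Return a list with 1.0 at index `label` and 0.0 elsewhere."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


# Synthetic idx3-style image data: 2 images of 2x3 pixels (assumption, not MNIST).
images_raw = struct.pack(">HBBIII", 0, 0x08, 3, 2, 2, 3) + bytes(range(12))
# Synthetic idx1-style label data: 2 labels.
labels_raw = struct.pack(">HBBI", 0, 0x08, 1, 2) + bytes([7, 3])

image_dims, pixels = parse_idx(images_raw)
label_dims, labels = parse_idx(labels_raw)
encoded = [one_hot(label) for label in labels]

print(image_dims, label_dims, labels)
```

In the idx3 case the three dimensions are the image count, rows, and columns, while an idx1 file has a single dimension holding the label count. With one_hot=True, each integer label becomes a length-10 vector of this kind, which is the form the labels take when returned by the loader.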