"""Helpers for loading MNIST-style IDX files and displaying a single image."""

from struct import error as StructError, unpack

import numpy as np

try:
    import matplotlib.pyplot as plt
except ImportError:
    # Plotting is optional: only showimage() needs matplotlib, so the loader
    # stays importable on machines without a plotting stack.
    plt = None


def loadmnist(imagefile, labelfile):
    """Load an IDX image file and its IDX label file into NumPy arrays.

    Adapted from the gist at
    https://gist.github.com/mGalarnyk/aa79813d7ecb0049c7b926d53f588ae1

    Both files are opened with plain ``open(..., 'rb')``; gzip-compressed
    files are not decompressed here. The magic numbers are skipped without
    validation, and the image count stored in the image header is ignored:
    the count from the label header decides how many images are read.

    Args:
        imagefile: Path to the image file (16-byte big-endian header of
            magic, image count, rows, cols, followed by one byte per pixel).
        labelfile: Path to the label file (8-byte big-endian header of
            magic, label count, followed by one byte per label).

    Returns:
        A tuple ``(x, y)`` where ``x`` is a writable ``uint8`` array of shape
        ``(N, rows * cols)`` and ``y`` is a writable ``uint8`` array of shape
        ``(N,)``, with ``N`` taken from the label header.

    Raises:
        OSError: If either file cannot be opened.
        struct.error: If a header is incomplete or a file holds fewer than
            ``N`` images / labels.
    """
    # Context managers guarantee both handles are closed even when a read or
    # an unpack fails part-way through.
    with open(imagefile, 'rb') as images, open(labelfile, 'rb') as labels:
        # unpack() raises struct.error by itself when a header is too short.
        _magic, _image_count, rows, cols = unpack('>IIII', images.read(16))
        _magic, N = unpack('>II', labels.read(8))

        pixels_per_image = rows * cols
        # One bulk read per file instead of one read() call per byte.
        raw_pixels = images.read(N * pixels_per_image)
        raw_labels = labels.read(N)

    if len(raw_pixels) != N * pixels_per_image:
        raise StructError(
            'image file is truncated: expected %d pixel bytes, got %d'
            % (N * pixels_per_image, len(raw_pixels))
        )
    if len(raw_labels) != N:
        raise StructError(
            'label file is truncated: expected %d label bytes, got %d'
            % (N, len(raw_labels))
        )

    # frombuffer() yields read-only views over the bytes objects; copy() so
    # the caller receives independent, writable arrays.
    x = np.frombuffer(raw_pixels, dtype=np.uint8).reshape(N, pixels_per_image).copy()
    y = np.frombuffer(raw_labels, dtype=np.uint8).copy()
    return (x, y)


def showimage(image, shape=(28, 28)):
    """Render one flattened image in grayscale with ``plt.imshow``.

    Only ``imshow`` is called; displaying the figure (e.g. ``plt.show()``)
    is left to the caller.

    Args:
        image: Array-like holding exactly ``shape[0] * shape[1]`` values.
        shape: ``(rows, cols)`` to reshape ``image`` to before drawing.
            Defaults to ``(28, 28)``.

    Raises:
        ImportError: If matplotlib is not installed.
        ValueError: If ``image`` cannot be reshaped to ``shape``.
    """
    if plt is None:
        raise ImportError('matplotlib is required for showimage()')
    plt.imshow(np.reshape(image, shape), cmap=plt.cm.gray)