KNN algorithm in my OCR project has trouble recognizing certain digits

49 views

I'm trying to code an OCR project that identifies handwritten digits without using any extra libraries. I'm using Pygame to draw and render the number to identify. I managed to get all of it working, but for some strange reason, some specific digits, like 9 or 6, are never recognized by the algorithm (not even when I set k really high — none of the k candidates is ever that digit). I have absolutely no idea why, and any help would be really appreciated.

import time
import pygame
from sys import exit
import random
import math
# Number of training images/labels main() loads (passed as n_max_images /
# n_max_labels to read_images / read_labels).
number_comparisons = 60000
# NOTE(review): hard-coded, user-specific Windows paths; adjust per machine.
DATA_DIR = r"C:/Users/----/Downloads/OCR/"
# NOTE(review): TEST_DIR is not referenced in the visible code; the test-set
# filenames below are built from DATA_DIR instead.
TEST_DIR = r"C:/Users/----/Downloads/OCR/test/"
# MNIST IDX files: test split ("t10k-*") and training split ("train-*").
TEST_DATA_FILENAME = DATA_DIR + "t10k-images.idx3-ubyte"
TEST_LABELS_FILENAME = DATA_DIR + "t10k-labels.idx1-ubyte"
TRAIN_DATA_FILENAME = DATA_DIR + "train-images.idx3-ubyte"
TRAIN_LABELS_FILENAME = DATA_DIR + "train-labels.idx1-ubyte"
# NOTE(review): start_time is not referenced in the visible code.
start_time = time.time()
# Debug flag.
DEBUG = True
# Start pygame (graphics) and its mixer (sound).
pygame.init()
pygame.mixer.init()
# Create the window. 784 px = a grid of 28 cells of 28 px each; the extra
# 66 px at the bottom hold the red strip drawn in the main loop.
width = 784
height = 850
screen = pygame.display.set_mode((width,height))

# Set the window title.
pygame.display.set_caption("Optical Character Recognition")

# Clock used to cap the frame rate.
clock = pygame.time.Clock()


def read_labels(filename, n_max_labels=None):
    """Read digit labels from an MNIST IDX1 label file.

    File layout: a 4-byte magic number (2049), a 4-byte big-endian label
    count, then one unsigned byte per label.

    Args:
        filename: path to the ``*-labels.idx1-ubyte`` file.
        n_max_labels: if truthy, read at most this many labels.

    Returns:
        A list of 1-byte ``bytes`` objects, one per label, in file order.
        Callers decode them with ``bytes_to_int``.

    Raises:
        ValueError: if the file does not start with the IDX1 magic number
            (e.g. an image file was passed by mistake).
    """
    with open(filename, "rb") as f:
        magic = int.from_bytes(f.read(4), "big")
        if magic != 2049:
            raise ValueError(
                f"{filename!r} is not an IDX1 label file (magic number {magic})"
            )
        n_labels = int.from_bytes(f.read(4), "big")
        if n_max_labels:
            # Never ask for more labels than the file holds; reading past the
            # end used to pad the result with empty b"" entries.
            n_labels = min(n_labels, n_max_labels)
        # One bulk read instead of one f.read(1) call per label.
        data = f.read(n_labels)
    # Slice (not index) so each element stays a 1-byte bytes object, which is
    # the type the rest of the program expects.
    return [data[i:i + 1] for i in range(len(data))]

count = 0  # total number of pixels read by read_images so far (printed by main)


def read_images(filename, n_max_images=None):
    """Read images from an MNIST IDX3 image file.

    File layout: a 4-byte magic number (2051), then big-endian 4-byte image
    count, row count and column count, then every image's pixels stored row
    by row, one unsigned byte per pixel.

    Args:
        filename: path to the ``*-images.idx3-ubyte`` file.
        n_max_images: if truthy, read at most this many images.

    Returns:
        A list of images. Each image is a list of rows, and each row is a
        list of 1-byte ``bytes`` pixels, so ``image[r][c]`` is the pixel at
        row r, column c. Callers decode pixels with ``bytes_to_int``.

    Raises:
        ValueError: if the file does not start with the IDX3 magic number.

    Side effects:
        Adds the number of pixels read to the module-level ``count``.
    """
    global count
    images = []
    with open(filename, "rb") as f:
        magic = int.from_bytes(f.read(4), "big")
        if magic != 2051:
            raise ValueError(
                f"{filename!r} is not an IDX3 image file (magic number {magic})"
            )
        n_images = int.from_bytes(f.read(4), "big")
        if n_max_images:
            # Don't read past the images the file actually contains.
            n_images = min(n_images, n_max_images)
        n_rows = int.from_bytes(f.read(4), "big")
        n_columns = int.from_bytes(f.read(4), "big")
        image_size = n_rows * n_columns
        for _ in range(n_images):
            # One read per image instead of one f.read(1) call per pixel.
            buf = f.read(image_size)
            if len(buf) < image_size:
                break  # truncated file: stop instead of emitting empty pixels
            count += image_size
            image = []
            for r in range(n_rows):
                start = r * n_columns
                # Slice so each pixel stays a 1-byte bytes object.
                image.append([buf[start + c:start + c + 1] for c in range(n_columns)])
            images.append(image)
    return images

def bytes_to_int(byte_data):
    """Convert a file field, pixel or drawing-grid cell to a plain int.

    Two kinds of input reach this function:

    * bytes read from the MNIST files (header fields, pixels, labels), which
      are decoded as a big-endian unsigned integer;
    * int cells from the on-screen drawing grid, where 0 is background and
      200 is the "ink" value written by draw(). Ink is promoted to 255 so it
      matches MNIST's full-intensity strokes; other ints are returned as-is.
    """
    if isinstance(byte_data, int):
        # The original had a second `elif byte_data == 200: return 200`
        # branch that could never run; it has been removed.
        return 255 if byte_data == 200 else byte_data
    return int.from_bytes(byte_data, "big")

def pasar_lista_unidimensional(X):
    """Flatten a 2-D grid row by row and wrap it as a one-sample batch.

    The width of the first row is used for every row. The result is a list
    holding a single flat list, i.e. the shape knn() expects for X_test.
    """
    if not X:
        return [[]]
    n_cols = len(X[0])
    flat = [row[col] for row in X for col in range(n_cols)]
    return [flat]
def pasar_lista_unidimensional2(X):
    """Flatten every 2-D sample in X, returning one flat list per sample."""
    flat_samples = []
    for sample in X:
        flat_samples.append(aplanar_lista(sample))
    return flat_samples
def aplanar_lista(l):
    """Concatenate the sub-sequences of l into a single flat list."""
    flattened = []
    for chunk in l:
        flattened.extend(chunk)
    return flattened
def dist(x, y):
    """Return the Euclidean distance between two flattened images.

    Each element of x and y is decoded with bytes_to_int, so MNIST pixels
    (1-byte bytes) and drawing-grid ints (0 / 200) can be compared directly.
    Elements are paired with zip, so any extra elements of the longer input
    are ignored.
    """
    # A generator replaces the temporary list. The second, unreachable
    # `return` that duplicated this formula has been removed.
    # NOTE(review): this runs once per training image per guess and decodes
    # every training pixel each time; pre-decoding X_train would be faster.
    return sum(
        (bytes_to_int(x_i) - bytes_to_int(y_i)) ** 2 for x_i, y_i in zip(x, y)
    ) ** 0.5
def distancia_entre_samples(X_train,test_sample):
    """Return the distance from test_sample to each training sample, in order."""
    distances = []
    for train_sample in X_train:
        distances.append(dist(train_sample, test_sample))
    return distances

def most_frequent_element(list):
    """Return the value that occurs most often in the given sequence.

    Ties are won by the value that appears first. An empty input raises
    ValueError. (The parameter name shadows the builtin ``list``; it is kept
    because callers may pass it by keyword.)
    """
    winner = None
    winner_votes = 0
    seen_any = False
    for candidate in list:
        votes = list.count(candidate)
        # Strictly greater, so the earliest value keeps a tie.
        if not seen_any or votes > winner_votes:
            winner, winner_votes, seen_any = candidate, votes, True
    if not seen_any:
        raise ValueError("max() arg is an empty sequence")
    return winner


def knn(X_train, y_train, X_test, k=3):
    """Classify each test image by majority vote of its k nearest neighbours.

    Distances are Euclidean on raw pixels (see dist / distancia_entre_samples).

    Args:
        X_train: flattened training images, one pixel sequence per image.
        y_train: training labels, decoded with bytes_to_int.
        X_test: flattened images to classify.
        k: number of nearest neighbours that vote; must be at least 1.

    Returns:
        A list with one predicted label per entry of X_test.

    Raises:
        ValueError: if k < 1 (there would be no candidates to vote on).
    """
    import heapq

    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")
    y_pred = []  # predictions for X_test, in order
    print("Using the knn algorithm to determine k nearest numbers to drawing...")
    for test_sample in X_test:
        # Distance from this test image to every training image.
        training_distances = distancia_entre_samples(X_train, test_sample)
        # nsmallest(k, ...) equals sorted(...)[:k], including tie order, but
        # avoids sorting every distance just to keep k of them.
        nearest = heapq.nsmallest(
            k, range(len(training_distances)), key=training_distances.__getitem__
        )
        candidates = [bytes_to_int(y_train[idx]) for idx in nearest]
        print("Top k choices were", candidates)
        y_pred.append(most_frequent_element(candidates))
    return y_pred
_training_cache = None  # (X_train, y_train) after the first load; see main()


def _load_training_data():
    """Read and flatten the training set on first use, then reuse it."""
    global _training_cache
    if _training_cache is None:
        print("Reading training files...")
        # "X" is the image data and "y" the label assigned to each image.
        X_train = read_images(TRAIN_DATA_FILENAME, number_comparisons)
        y_train = read_labels(TRAIN_LABELS_FILENAME, number_comparisons)
        print("Converting drawing to 2D grid")
        # Turn each training image (a list of rows) into a single pixel list.
        X_train = pasar_lista_unidimensional2(X_train)
        _training_cache = (X_train, y_train)
    return _training_cache


def main():
    """Classify the current drawing with KNN (k=15) and print the result.

    Reads the module-level X_test, which the event loop sets before calling
    main(). The training set is loaded and flattened only once and cached
    for later guesses; it used to be re-read on every call. The unused read
    of the test-set labels has been removed.

    NOTE(review): MNIST digits are size-normalised and centred in the 28x28
    frame, while the drawing is compared as-is. Raw-pixel Euclidean KNN is
    sensitive to position, scale and stroke thickness, which may explain
    digits that are never recognised. Consider centring and scaling the
    drawing first (unverified).
    """
    X_train, y_train = _load_training_data()
    y_pred = knn(X_train, y_train, X_test, 15)
    print("The number you have just written is: ", y_pred)
    print("Number of iterations: ", count)
# Drawing canvas: 28 rows x 28 columns of ints, indexed image_array[row][col]
# (row = y // 28, col = x // 28). 0 is background; draw() writes 200 for ink.
# Built explicitly: it used to be filled inside the vertical-line loop, so
# its row count silently depended on the window width.
image_array = [[0] * 28 for _ in range(28)]
font = pygame.font.Font(None, 24)

# Grid lines over the 784x784 drawing area, one cell every 28 px.
# (A no-op `if __name__ == "__main__": pass` was removed here.)
for i in range(0, 784, 28):
    pygame.draw.line(screen, "white", (0, i), (784, i))
for j in range(0, width, 28):
    pygame.draw.line(screen, "white", (j, 0), (j, 784))
if DEBUG:
    print(type(image_array), type(image_array[0]), type(image_array[0][0]))
def draw(x, y):
    """Paint a plus-shaped brush centred on the grid cell under pixel (x, y).

    The cell containing (x, y) and its four orthogonal neighbours are filled
    white on screen and set to 200 (ink) in image_array. Cells outside the
    28x28 grid are skipped, both on screen and in the array.
    """
    row, col = y // 28, x // 28
    # Centre, right, left, below, above: the same order as the original code.
    for d_row, d_col in ((0, 0), (0, 1), (0, -1), (1, 0), (-1, 0)):
        r, c = row + d_row, col + d_col
        if 0 <= r < 28 and 0 <= c < 28:
            pygame.draw.rect(screen, "white", (c * 28, r * 28, 28, 28))
            image_array[r][c] = 200
    
button = False  # True while a mouse button is held down
text = font.render("Guess Number", True, (0, 0, 0))
while True:
    # Process window events so the app can be closed.
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            # pygame.QUIT is an event-type constant, not a function: calling
            # it raised TypeError. pygame.quit() shuts pygame down.
            pygame.quit()
            exit()
        if event.type == pygame.MOUSEBUTTONDOWN:
            button = True
        elif event.type == pygame.MOUSEBUTTONUP:
            button = False
    if button:
        pos = pygame.mouse.get_pos()
        # The grid covers y < 784 (rows 0-27). The old threshold of 756 made
        # the last grid row undrawable and turned it into a guess button.
        if pos[1] < 784:
            draw(pos[0], pos[1])
        else:
            # Click in the red strip below the grid: classify the drawing.
            X_test = pasar_lista_unidimensional(image_array)
            if DEBUG:
                print(X_test)
            print("Loading...")
            main()
            # Require a fresh click before guessing again; otherwise the
            # still-held button reruns main() on every frame.
            button = False
    pygame.draw.rect(screen, "red", (0, 784, 784, 66))
    # The label was rendered but never drawn; centre it on the red strip.
    screen.blit(text, text.get_rect(center=(784 // 2, 784 + 66 // 2)))
    pygame.display.update()

    # Cap the frame rate at 60 FPS.
    clock.tick(60)
0

There are no answers below yet.