sklearn support vector machine is not learning


I am trying to classify images using sklearn's svm.SVC classifier, but it is not learning. After training I get an accuracy of 0.1 (there are 10 classes, so 0.1 is the same as random guessing).

I am using the CIFAR-10 dataset: 10000 images, each represented as 3072 uint8 values. The first 1024 values are the red pixels, the next 1024 are the green pixels, and the last 1024 are the blue pixels.

Each image also has a label, which is a number from 0 to 9.
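
To make the row layout concrete, here is a minimal sketch of turning one 3072-value row back into an image array. It assumes the images are 32x32 (1024 = 32 * 32) and stored row by row within each channel; `row_to_image` and `fake_row` are hypothetical names used only for illustration.

```python
import numpy as np


def row_to_image(row):
    """Convert one 3072-value row into a (32, 32, 3) height/width/channel array."""
    row = np.asarray(row, dtype=np.uint8)
    # Channel-major layout: 1024 red, then 1024 green, then 1024 blue values.
    channels = row.reshape(3, 32, 32)
    # Move the channel axis last so each pixel holds its (R, G, B) triple.
    return channels.transpose(1, 2, 0)


# Hypothetical stand-in row: a pure red image (not real CIFAR-10 data).
fake_row = np.concatenate(
    [np.full(1024, 255), np.zeros(1024), np.zeros(1024)]
).astype(np.uint8)

image = row_to_image(fake_row)
print(image.shape)
print(image[0, 0])
```

With the real data, the same function could be applied to a single row of X_train.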

Here is my code:

import pickle

import joblib  # sklearn.externals.joblib was removed from newer scikit-learn releases
import numpy as np
from sklearn import svm

# Load one CIFAR-10 training batch and the test batch
with open('data_batch_1', 'rb') as f:
    train_data = pickle.load(f, encoding='latin1')
with open('test_batch', 'rb') as f:
    test_data = pickle.load(f, encoding='latin1')

X_train = np.array(train_data['data'])
y_train = np.array(train_data['labels'])
X_test = np.array(test_data['data'])
y_test = np.array(test_data['labels'])

# SVC with default parameters
clf = svm.SVC(verbose=True)
clf.fit(X_train, y_train)

accuracy = clf.score(X_test, y_test)

# Save the trained classifier to disk
joblib.dump(clf, 'Cifar-10-clf.pickle')

print(accuracy)

Does anyone know what my problem could be or can point me to resources to solve this?

1 Answer

I'm not sure, but I think you need to tune the parameters of SVC.

I tried a few parameter settings and got an accuracy of 0.318.

Here is the code:

import pickle

import numpy as np
from sklearn import svm

# Load one CIFAR-10 training batch and the test batch
with open('data/data_batch_1', 'rb') as f:
    train_data = pickle.load(f, encoding='latin1')
with open('data/test_batch', 'rb') as f:
    test_data = pickle.load(f, encoding='latin1')

X_train = np.array(train_data['data'])
y_train = np.array(train_data['labels'])
# Score on the first 1000 test images only
X_test = np.array(test_data['data'][:1000])
y_test = np.array(test_data['labels'][:1000])

# Note: gamma is ignored by the linear kernel
clf = svm.SVC(kernel='linear', C=10, gamma=0.01)
clf.fit(X_train, y_train)

accuracy = clf.score(X_test, y_test)

print("Accuracy: ", accuracy)

I also recommend using grid search to tune the hyperparameters automatically.

The scikit-learn documentation has a section on tuning hyperparameters.
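
As a rough sketch of what such a grid search could look like, the snippet below uses scikit-learn's GridSearchCV over a few SVC settings. The parameter values, the `tune_svc` helper, the random stand-in data, and the scaling of pixels to [0, 1] are all assumptions made for illustration, not settings I have tested on CIFAR-10.

```python
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC


def tune_svc(X, y, cv=3):
    """Cross-validate a small, illustrative grid of SVC settings."""
    param_grid = [
        {"kernel": ["linear"], "C": [0.1, 1, 10]},
        {"kernel": ["rbf"], "C": [1, 10], "gamma": [0.001, 0.01]},
    ]
    search = GridSearchCV(SVC(), param_grid, cv=cv)
    search.fit(X, y)
    return search


# Hypothetical stand-in data shaped like a tiny CIFAR-10 subset:
# 120 random "images", 12 per class, pixels scaled to [0, 1] (an assumption).
rng = np.random.RandomState(0)
y_small = np.repeat(np.arange(10), 12)
X_small = rng.randint(0, 256, size=(120, 3072)) / 255.0

search = tune_svc(X_small, y_small)
print("Best parameters:", search.best_params_)
print("Best cross-validated accuracy:", search.best_score_)
```

On the real data you would pass X_train and y_train instead of the stand-in arrays. Fitting an SVC on 10000 images is slow, so it may be worth starting the search on a subset.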