I am studying sklearn and wrote a class `Classifier` to do common classification. It needs a way to decide which estimator to use:
```python
# Classifier
from sklearn import metrics
from sklearn.svm import SVC
from sklearn.svm import LinearSVC
from sklearn.linear_model import SGDClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier


class Classifier(object):
    def __init__(self, method='LinearSVC', *args, **kwargs):
        # Look up the estimator class by its name: what should xxx be?
        Estimator = getattr(xxx, method, None)
        self.Estimator = Estimator
        self._model = Estimator(*args, **kwargs)

    def fit(self, data, target):
        return self._model.fit(data, target)

    def predict(self, data):
        return self._model.predict(data)

    def score(self, X, y, sample_weight=None):
        # Forward the caller's sample_weight instead of always passing None.
        return self._model.score(X, y, sample_weight=sample_weight)

    def persist_model(self):
        pass

    def get_model(self):
        return self._model

    def classification_report(self, expected, predicted):
        return metrics.classification_report(expected, predicted)

    def confusion_matrix(self, expected, predicted):
        return metrics.confusion_matrix(expected, predicted)
```
I want to get the estimator class by its name, but what should `xxx` be?
Or is there a better way to do this?
I could build a dict that maps names to the imported estimators, but that approach doesn't seem very good.
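For comparison, the explicit dict idea could look like the minimal sketch below. The registry name `ESTIMATORS` and the helper `make_estimator` are hypothetical, chosen only for illustration, and for simplicity only keyword arguments are forwarded to the estimator. The advantage of this design is that only the classes you list can be instantiated, and an unknown name fails with a readable message instead of `None`.

```python
from sklearn.svm import SVC, LinearSVC
from sklearn.linear_model import SGDClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier

# Hypothetical registry: maps the name passed as `method` to an estimator class.
ESTIMATORS = {
    cls.__name__: cls
    for cls in (SVC, LinearSVC, SGDClassifier, KNeighborsClassifier,
                GaussianNB, DecisionTreeClassifier)
}


def make_estimator(method='LinearSVC', **kwargs):
    """Instantiate the estimator registered under `method` (hypothetical helper)."""
    try:
        Estimator = ESTIMATORS[method]
    except KeyError:
        # Turn the bare KeyError into a message listing the valid names.
        raise ValueError(
            f"unknown method {method!r}; choose one of {sorted(ESTIMATORS)}"
        ) from None
    return Estimator(**kwargs)


model = make_estimator('KNeighborsClassifier', n_neighbors=3)
print(type(model).__name__)  # KNeighborsClassifier
```

Adding a new estimator then means one import plus one entry in the tuple.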
The built-in function `globals()` does the trick: you can check that
`globals()['LogisticRegression'] is LogisticRegression` returns `True`.
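A minimal sketch of the `globals()` lookup applied to a trimmed-down version of the question's `Classifier`, assuming the estimator classes are imported at module level in the same file. Keep in mind that `globals()` returns the namespace of the module where the function is *defined*, not the caller's, so the imports have to live in the module that defines `Classifier`. The toy data and `max_iter=200` are illustrative only.

```python
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC  # needed so the default method='LinearSVC' resolves

# The imported name maps to the very same class object in the module namespace.
assert globals()['LogisticRegression'] is LogisticRegression


class Classifier(object):
    def __init__(self, method='LinearSVC', **kwargs):
        # Look the class up among this module's global names;
        # raises KeyError if `method` was never imported here.
        Estimator = globals()[method]
        self.Estimator = Estimator
        self._model = Estimator(**kwargs)

    def fit(self, data, target):
        return self._model.fit(data, target)

    def predict(self, data):
        return self._model.predict(data)


clf = Classifier('LogisticRegression', max_iter=200)
clf.fit([[0.0], [1.0], [2.0], [3.0]], [0, 0, 1, 1])
print(clf.predict([[-1.0], [4.0]]))
```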
ADDENDUM

As for `globals()[method]` versus a hand-built `some_method_dict[method]`: `globals()[method]` is just the shortest answer to the question. Whether it is Pythonic or not, I don't know, but the `globals()` builtin is there to be used, so why choose a more complicated solution?

To be explicit, `Estimator = getattr(xxx, method, None)` can be implemented as `Estimator = globals().get(method, None)` if a `None` return is preferred to a `KeyError` exception when `method` was not imported.
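The difference between the two lookups can be seen directly in a short sketch. Here `RandomForestClassifier` is used only as an example of a real scikit-learn class that the question's file does not import.

```python
from sklearn.naive_bayes import GaussianNB

# Bracket lookup: fails loudly if the name was never imported.
try:
    globals()['RandomForestClassifier']
except KeyError as exc:
    print('KeyError:', exc)

# .get() lookup: returns None instead, like getattr(xxx, method, None).
Estimator = globals().get('RandomForestClassifier')
print(Estimator)  # None

# A name that *was* imported resolves to the class itself.
Estimator = globals().get('GaussianNB')
print(Estimator is GaussianNB)  # True
```

With `.get()`, the caller still has to check for `None` before calling `Estimator(...)`; otherwise the failure just moves to a less informative `TypeError: 'NoneType' object is not callable`.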