I'm trying to use the Genetic Algorithm library, PyGAD to cross validate hyper-parameters in non-DL machine learning applications.
Data generation
## Create synthetic data
f = np.vectorize(lambda x: int(x>0.5))
X = np.random.uniform(low=0, high=1, size=3000)
X = np.apply_along_axis(f, 0, X).reshape([1000,3])
def gen_y(X):
if np.sum(X) > 2:
return 1
elif np.sum(X) < 2:
return 0
else:
return int(np.random.uniform(0,1) > 0.5)
y = np.apply_along_axis(gen_y, 1, X)
import pandas as pd
data = pd.DataFrame(data=X, columns=['x1','x2','x3'])
data['y'] = y
Genetic Algorithm to ID optimal hyper-parameters
gene_space = [
# n_estimators
np.linspace(50,200,25, dtype='int'),
# min_samples_split,
np.linspace(2,10,5, dtype='int'),
# min_samples_leaf,
np.linspace(1,10,5, dtype='int'),
# min_impurity_decrease
np.linspace(0,1,10, dtype='float')
]
def fitness_function_factory(data=data, y_name='y', sample_size=100):
def fitness_function(solution, solution_idx):
model = RandomForestClassifier(
n_estimators=solution[0],
min_samples_split=solution[1],
min_samples_leaf=solution[2],
min_impurity_decrease=solution[3]
)
X = data.drop(columns=[y_name])
y = data[y_name]
X_train, X_test, y_train, y_test = train_test_split(X, y,
test_size=0.5)
train_idx = sample_without_replacement(n_population=len(X_train),
n_samples=sample_size)
test_idx = sample_without_replacement(n_population=len(X_test),
n_samples=sample_size)
model.fit(X_train[train_idx], y_train[train_idx])
fitness = model.score(X_test[test_idx], y_test[test_idx])
return fitness
return fitness_function
cross_validate = pygad.GA(gene_space=gene_space,
fitness_func=fitness_function_factory(),
num_generations=100,
num_parents_mating=2,
sol_per_pop=8,
num_genes=len(gene_space),
parent_selection_type='sss',
keep_parents=2,
crossover_type="single_point",
mutation_type="random",
mutation_percent_genes=25)
The last step is identifying the optimal combination via best_solution
method.
cross_validate.best_solution()
>>>
KeyError: "None of [Int64Index([119, 342, 34, 80, 94, 270, 443, 468, 401, 133, 400, 362, 455,\n 480, 449, 271, 303, 399, 462, 237, 152, 264, 281, 301, 435, 386,\n 92, 453, 378, 290, 235, 64, 394, 70, 174, 215, 22, 244, 155,\n 207, 74, 147, 178, 267, 347, 97, 396, 292, 120, 375, 113, 169,\n 460, 43, 168, 298, 37, 300, 91, 331, 388, 321, 481, 96, 308,\n 211, 478, 464, 8, 170, 73, 175, 172, 487, 263, 213, 146, 479,\n 336, 346, 67, 160, 277, 397, 38, 7, 247, 128, 47, 428, 454,\n 313, 257, 338, 199, 381, 60, 245, 324, 223],\n dtype='int64')] are in the [columns]"
So I believe that this error means that optimal configuration cannot be found. But how is this possible and how can this be remedied?