Hyperparameters, LassoRandomForestRegressor and grid_search (solution)


This notebook explores hyperparameter optimization for the LassoRandomForestRegressor model. It varies the number of trees and the alpha parameter.

from jyquickhelper import add_notebook_menu
add_notebook_menu()
%matplotlib inline

Data

from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
data = load_diabetes()
X, y = data.data, data.target
X_train, X_test, y_train, y_test = train_test_split(X, y)
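train_test_split draws a new random split each time it runs. This is one reason the scores in this notebook change from one execution to the next. The sketch below assumes reproducibility is wanted and fixes random_state. The value 0 is arbitrary. With the default test_size of 0.25, the 442 samples of the diabetes dataset are split into 331 training rows and 111 test rows.

```python
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split

# return_X_y=True returns the (data, target) pair directly
X, y = load_diabetes(return_X_y=True)

# random_state fixes the shuffling, so every run gets the same split
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
print(X_train.shape, X_test.shape)  # (331, 10) (111, 10)
```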

First models

from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score

rf = RandomForestRegressor()
rf.fit(X_train, y_train)
r2_score(y_test, rf.predict(X_test))
0.3166064611454491
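An R² computed on a single test split is a noisy estimate. The sketch below is not part of the original notebook. It uses cross_val_score to refit the baseline forest on 5 folds and report the mean and spread of the score, which shows how much the score moves from one split to another. The value of random_state is an arbitrary choice.

```python
from sklearn.datasets import load_diabetes
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import cross_val_score

X, y = load_diabetes(return_X_y=True)

# 5-fold cross-validation; for a regressor the default score is R²
scores = cross_val_score(RandomForestRegressor(random_state=0), X, y, cv=5)
print("R² = %1.3f +/- %1.3f" % (scores.mean(), scores.std()))
```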

For the model, it is enough to copy and paste the code from the file lasso_random_forest_regressor.py (part of the ensae_teaching_cs package, imported below).

from ensae_teaching_cs.ml.lasso_random_forest_regressor import LassoRandomForestRegressor
lrf = LassoRandomForestRegressor()
lrf.fit(X_train, y_train)
r2_score(y_test, lrf.predict(X_test))
0.20558896981102492

The model has reduced the number of trees.

len(lrf.estimators_)
97
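Presumably the tree count drops because of a Lasso step. The actual implementation lives in ensae_teaching_cs and is not reproduced here. The class below, LassoRFSketch (a hypothetical name), is a minimal sketch based on one assumption: the model fits a random forest, then fits a Lasso on the matrix of per-tree predictions, and keeps only the trees whose coefficient is non-zero. The constructor arguments rf_estimator and lasso_estimator match the rf_estimator__ and lasso_estimator__ prefixes that appear in grid.cv_results_. This lets GridSearchCV tune the nested parameters.

```python
import numpy
from sklearn.base import BaseEstimator, RegressorMixin, clone
from sklearn.datasets import load_diabetes
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import Lasso


class LassoRFSketch(RegressorMixin, BaseEstimator):
    """Hypothetical sketch: a random forest whose trees are selected and weighted by a Lasso."""

    def __init__(self, rf_estimator, lasso_estimator):
        self.rf_estimator = rf_estimator
        self.lasso_estimator = lasso_estimator

    @staticmethod
    def _tree_predictions(X, trees):
        # one column per tree: that tree's prediction for every row of X
        return numpy.column_stack([tree.predict(X) for tree in trees])

    def fit(self, X, y):
        # clone so that the estimators passed as parameters are never modified
        rf = clone(self.rf_estimator).fit(X, y)
        lasso = clone(self.lasso_estimator)
        lasso.fit(self._tree_predictions(X, rf.estimators_), y)
        # the L1 penalty drives some coefficients to exactly 0: those trees are dropped
        keep = lasso.coef_ != 0
        self.estimators_ = [tree for tree, k in zip(rf.estimators_, keep) if k]
        self.weights_ = lasso.coef_[keep]
        self.intercept_ = lasso.intercept_
        return self

    def predict(self, X):
        if not self.estimators_:
            # every coefficient is zero: the model predicts a constant
            return numpy.full(X.shape[0], self.intercept_)
        return self._tree_predictions(X, self.estimators_) @ self.weights_ + self.intercept_


X, y = load_diabetes(return_X_y=True)
model = LassoRFSketch(RandomForestRegressor(n_estimators=100, random_state=0),
                      Lasso(alpha=1.0, max_iter=10000))
model.fit(X, y)
print(len(model.estimators_), "trees kept out of 100")
```

With a very large alpha, every coefficient becomes zero and the sketch simply predicts the mean of y. Smaller values of alpha keep more trees.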

Performance as a function of the parameters
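The cell that creates grid is not part of this export. Its grid can be read back from the params entry of grid.cv_results_: 6 values of alpha times 6 numbers of trees, each evaluated on 5 folds (split0 to split4). The sketch below assumes the missing cell used exactly these values. ParameterGrid, which GridSearchCV uses internally, lists the combinations in the same order as cv_results_, with alpha in the outer loop and n_estimators in the inner loop.

```python
from sklearn.model_selection import ParameterGrid

# values read back from grid.cv_results_['params'] (assumed identical to the missing cell)
params = {
    "lasso_estimator__alpha": [0.25, 0.5, 0.75, 1.0, 1.25, 1.5],
    "rf_estimator__n_estimators": [20, 40, 60, 80, 100, 120],
}

combinations = list(ParameterGrid(params))
print(len(combinations))  # 36
print(combinations[0])    # {'lasso_estimator__alpha': 0.25, 'rf_estimator__n_estimators': 20}
print(combinations[6])    # {'lasso_estimator__alpha': 0.5, 'rf_estimator__n_estimators': 20}

# The search itself presumably looked like this (requires ensae_teaching_cs, not run here):
# from sklearn.model_selection import GridSearchCV
# from ensae_teaching_cs.ml.lasso_random_forest_regressor import LassoRandomForestRegressor
# grid = GridSearchCV(LassoRandomForestRegressor(), params, cv=5)
# grid.fit(X_train, y_train)
```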

# grid: GridSearchCV over LassoRandomForestRegressor, fitted in a cell not included in this export
grid.cv_results_
{'mean_fit_time': array([0.051863  , 0.11151867, 0.16286798, 0.20638132, 0.24587946,
        0.30230732, 0.04886923, 0.10883999, 0.1585783 , 0.21171408,
        0.25670881, 0.30813308, 0.04687281, 0.10599108, 0.16779151,
        0.21490512, 0.24286323, 0.37416844, 0.04798951, 0.10375576,
        0.13916297, 0.19486108, 0.23168812, 0.35405369, 0.04832931,
        0.10837116, 0.17046494, 0.21563282, 0.250454  , 0.30722728,
        0.0500711 , 0.10197167, 0.14489303, 0.19933763, 0.31132407,
        0.69930143]),
 'std_fit_time': array([0.00362419, 0.01626225, 0.00804797, 0.01572331, 0.00662523,
        0.01574959, 0.00169066, 0.0097691 , 0.0132841 , 0.0106317 ,
        0.01988724, 0.02359756, 0.00126011, 0.00448715, 0.00627981,
        0.02519122, 0.02605425, 0.09337497, 0.01102544, 0.00824485,
        0.00715579, 0.01587819, 0.006515  , 0.04939259, 0.00602516,
        0.00652839, 0.01898743, 0.01727985, 0.01794094, 0.02079929,
        0.00562965, 0.00345422, 0.00807745, 0.00482911, 0.09500837,
        0.11143193]),
 'mean_score_time': array([0.00239778, 0.00359111, 0.00518904, 0.00718164, 0.00817652,
        0.01257362, 0.0021884 , 0.00339103, 0.00539336, 0.00738797,
        0.00917087, 0.00998683, 0.00199485, 0.00379586, 0.00599022,
        0.0103807 , 0.01236439, 0.00837784, 0.00431471, 0.00392194,
        0.00887637, 0.00752082, 0.00937295, 0.01437345, 0.00079789,
        0.00312424, 0.00479422, 0.00718193, 0.00958648, 0.01098609,
        0.00199614, 0.0039938 , 0.0049974 , 0.00697622, 0.01322117,
        0.02559528]),
 'std_score_time': array([8.11351379e-04, 4.87586231e-04, 3.98946617e-04, 3.98891227e-04,
        4.01356881e-04, 4.68445598e-03, 3.86144056e-04, 4.75930831e-04,
        4.96522489e-04, 1.36387385e-03, 1.15770100e-03, 1.41214662e-05,
        1.39020727e-06, 4.03363736e-04, 6.28333254e-04, 9.76193348e-03,
        6.18748536e-03, 4.21257447e-03, 5.70546749e-03, 6.04969222e-03,
        6.08895072e-03, 4.95836569e-03, 7.65298131e-03, 2.73983497e-03,
        9.77213669e-04, 6.24847412e-03, 2.48103089e-03, 3.95754917e-04,
        2.06222335e-03, 1.41556299e-03, 1.58579723e-06, 1.09920549e-03,
        1.70908708e-05, 6.18028043e-04, 2.94616536e-03, 1.07247410e-02]),
 'param_lasso_estimator__alpha': masked_array(data=[0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.5, 0.5,
                    0.5, 0.5, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 1.0, 1.0,
                    1.0, 1.0, 1.0, 1.0, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25,
                    1.5, 1.5, 1.5, 1.5, 1.5, 1.5],
              mask=[False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False],
        fill_value='?',
             dtype=object),
 'param_rf_estimator__n_estimators': masked_array(data=[20, 40, 60, 80, 100, 120, 20, 40, 60, 80, 100, 120, 20,
                    40, 60, 80, 100, 120, 20, 40, 60, 80, 100, 120, 20, 40,
                    60, 80, 100, 120, 20, 40, 60, 80, 100, 120],
              mask=[False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False],
        fill_value='?',
             dtype=object),
 'params': [{'lasso_estimator__alpha': 0.25, 'rf_estimator__n_estimators': 20},
  {'lasso_estimator__alpha': 0.25, 'rf_estimator__n_estimators': 40},
  {'lasso_estimator__alpha': 0.25, 'rf_estimator__n_estimators': 60},
  {'lasso_estimator__alpha': 0.25, 'rf_estimator__n_estimators': 80},
  {'lasso_estimator__alpha': 0.25, 'rf_estimator__n_estimators': 100},
  {'lasso_estimator__alpha': 0.25, 'rf_estimator__n_estimators': 120},
  {'lasso_estimator__alpha': 0.5, 'rf_estimator__n_estimators': 20},
  {'lasso_estimator__alpha': 0.5, 'rf_estimator__n_estimators': 40},
  {'lasso_estimator__alpha': 0.5, 'rf_estimator__n_estimators': 60},
  {'lasso_estimator__alpha': 0.5, 'rf_estimator__n_estimators': 80},
  {'lasso_estimator__alpha': 0.5, 'rf_estimator__n_estimators': 100},
  {'lasso_estimator__alpha': 0.5, 'rf_estimator__n_estimators': 120},
  {'lasso_estimator__alpha': 0.75, 'rf_estimator__n_estimators': 20},
  {'lasso_estimator__alpha': 0.75, 'rf_estimator__n_estimators': 40},
  {'lasso_estimator__alpha': 0.75, 'rf_estimator__n_estimators': 60},
  {'lasso_estimator__alpha': 0.75, 'rf_estimator__n_estimators': 80},
  {'lasso_estimator__alpha': 0.75, 'rf_estimator__n_estimators': 100},
  {'lasso_estimator__alpha': 0.75, 'rf_estimator__n_estimators': 120},
  {'lasso_estimator__alpha': 1.0, 'rf_estimator__n_estimators': 20},
  {'lasso_estimator__alpha': 1.0, 'rf_estimator__n_estimators': 40},
  {'lasso_estimator__alpha': 1.0, 'rf_estimator__n_estimators': 60},
  {'lasso_estimator__alpha': 1.0, 'rf_estimator__n_estimators': 80},
  {'lasso_estimator__alpha': 1.0, 'rf_estimator__n_estimators': 100},
  {'lasso_estimator__alpha': 1.0, 'rf_estimator__n_estimators': 120},
  {'lasso_estimator__alpha': 1.25, 'rf_estimator__n_estimators': 20},
  {'lasso_estimator__alpha': 1.25, 'rf_estimator__n_estimators': 40},
  {'lasso_estimator__alpha': 1.25, 'rf_estimator__n_estimators': 60},
  {'lasso_estimator__alpha': 1.25, 'rf_estimator__n_estimators': 80},
  {'lasso_estimator__alpha': 1.25, 'rf_estimator__n_estimators': 100},
  {'lasso_estimator__alpha': 1.25, 'rf_estimator__n_estimators': 120},
  {'lasso_estimator__alpha': 1.5, 'rf_estimator__n_estimators': 20},
  {'lasso_estimator__alpha': 1.5, 'rf_estimator__n_estimators': 40},
  {'lasso_estimator__alpha': 1.5, 'rf_estimator__n_estimators': 60},
  {'lasso_estimator__alpha': 1.5, 'rf_estimator__n_estimators': 80},
  {'lasso_estimator__alpha': 1.5, 'rf_estimator__n_estimators': 100},
  {'lasso_estimator__alpha': 1.5, 'rf_estimator__n_estimators': 120}],
 'split0_test_score': array([0.48765423, 0.47007607, 0.41456128, 0.34332073, 0.36888898,
        0.370369  , 0.50945042, 0.52478666, 0.45136636, 0.38160909,
        0.46464615, 0.52426278, 0.33309356, 0.50165501, 0.5023884 ,
        0.47159884, 0.40443694, 0.42850669, 0.35280764, 0.41274937,
        0.41530415, 0.35325067, 0.461381  , 0.40458056, 0.45063697,
        0.47402597, 0.39203225, 0.58405673, 0.43074069, 0.36958539,
        0.35946403, 0.46811327, 0.43129582, 0.47471034, 0.31616108,
        0.43820558]),
 'split1_test_score': array([0.31748269, 0.32402775, 0.31309735, 0.36776797, 0.36291097,
        0.25860886, 0.32332546, 0.28310914, 0.34370404, 0.29429633,
        0.32531769, 0.30070425, 0.32083858, 0.31018103, 0.28147265,
        0.36096592, 0.33612201, 0.34993859, 0.31710402, 0.34449814,
        0.32729745, 0.29203103, 0.3028285 , 0.40849055, 0.35384028,
        0.35159579, 0.30777994, 0.34548216, 0.29892216, 0.32126091,
        0.30904616, 0.30511572, 0.30571425, 0.356684  , 0.32693294,
        0.33647908]),
 'split2_test_score': array([0.36714477, 0.28075098, 0.27797057, 0.28236282, 0.30276893,
        0.21700352, 0.38350757, 0.3370075 , 0.31649401, 0.20121556,
        0.30713851, 0.28664918, 0.33362753, 0.30618393, 0.36897318,
        0.24307011, 0.33060169, 0.32188143, 0.35355399, 0.32021347,
        0.35526908, 0.25476369, 0.26570208, 0.16455204, 0.4154126 ,
        0.30368747, 0.27953113, 0.32737498, 0.23057391, 0.31069444,
        0.36235946, 0.2807269 , 0.33147417, 0.2414187 , 0.2822582 ,
        0.24876048]),
 'split3_test_score': array([0.4043803 , 0.31910819, 0.23721216, 0.30117822, 0.24160984,
        0.29643875, 0.29444929, 0.36670958, 0.29294625, 0.35849669,
        0.28732813, 0.06164115, 0.27354921, 0.30412114, 0.31082146,
        0.23641828, 0.29371034, 0.34239524, 0.39866027, 0.36307616,
        0.2895736 , 0.31561043, 0.41537819, 0.25744729, 0.39204788,
        0.35827202, 0.3558286 , 0.25123577, 0.22871596, 0.36031404,
        0.33534641, 0.31542919, 0.29505816, 0.30829603, 0.27520299,
        0.20069686]),
 'split4_test_score': array([0.37299925, 0.29360033, 0.35534609, 0.34508877, 0.3955746 ,
        0.24485609, 0.32355244, 0.40128887, 0.25337656, 0.26202744,
        0.2442764 , 0.12475539, 0.36143398, 0.25855855, 0.27470568,
        0.37247721, 0.26957179, 0.28886332, 0.34711816, 0.35216452,
        0.30793447, 0.26319255, 0.22076315, 0.197187  , 0.29571515,
        0.30295817, 0.27574516, 0.32196883, 0.32617658, 0.23406369,
        0.30742707, 0.37246999, 0.1981131 , 0.35704234, 0.26689645,
        0.29602189]),
 'mean_test_score': array([0.38993225, 0.33751266, 0.31963749, 0.3279437 , 0.33435067,
        0.27745524, 0.36685703, 0.38258035, 0.33157744, 0.29952902,
        0.32574138, 0.25960255, 0.32450857, 0.33613993, 0.34767227,
        0.33690607, 0.32688855, 0.34631705, 0.35384882, 0.35854033,
        0.33907575, 0.29576967, 0.33321058, 0.28645149, 0.38153057,
        0.35810788, 0.32218342, 0.3660237 , 0.30302586, 0.31918369,
        0.33472862, 0.34837101, 0.3123311 , 0.34763028, 0.29349033,
        0.30403278]),
 'std_test_score': array([0.05623747, 0.06818182, 0.06141412, 0.03133811, 0.05541701,
        0.05303893, 0.07697179, 0.0809888 , 0.06683068, 0.06528952,
        0.07450222, 0.16114489, 0.02874254, 0.08486524, 0.08420844,
        0.08819219, 0.04582314, 0.04622048, 0.0260949 , 0.03054825,
        0.04389069, 0.03592898, 0.09088902, 0.10248598, 0.05322657,
        0.06242197, 0.04515318, 0.11364082, 0.07434396, 0.04807043,
        0.0235824 , 0.06700877, 0.0747087 , 0.07635179, 0.02366514,
        0.08105877]),
 'rank_test_score': array([ 1, 14, 26, 21, 18, 35,  4,  2, 20, 31, 23, 36, 24, 16, 10, 15, 22,
        12,  8,  6, 13, 32, 19, 34,  3,  7, 25,  5, 30, 27, 17,  9, 28, 11,
        33, 29])}
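import numpy
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib versions

# one point per combination of the grid
xs = numpy.array([el['lasso_estimator__alpha'] for el in grid.cv_results_['params']])
ys = numpy.array([el['rf_estimator__n_estimators'] for el in grid.cv_results_['params']])
zs = numpy.array(grid.cv_results_['mean_test_score'])

fig = plt.figure(figsize=(14, 6))

# left: mean R² as a function of (alpha, n_estimators)
ax = fig.add_subplot(131, projection='3d')
ax.scatter(xs, ys, zs)
ax.set_xlabel("alpha")
ax.set_ylabel("n_estimators")
ax.set_zlabel("mean R²")
ax.set_title("Mean R² over the grid")

# middle: one curve per value of alpha
ax = fig.add_subplot(132)
for alpha in sorted(set(xs)):
    mask = xs == alpha
    ax.plot(ys[mask], zs[mask], label="alpha=%1.2f" % alpha, lw=alpha * 2)
ax.set_xlabel("n_estimators")
ax.set_ylabel("mean R²")
ax.legend()

# right: one curve per number of trees (loop variable renamed so it no longer overwrites the target y)
ax = fig.add_subplot(133)
for n in sorted(set(ys)):
    mask = ys == n
    ax.plot(xs[mask], zs[mask], label="n_estimators=%d" % n, lw=n / 40)
ax.set_xlabel("alpha")
ax.set_ylabel("mean R²")
ax.legend();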
[Figure: mean R² over the grid (3D view, one curve per alpha, one curve per number of trees)]

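In the results shown above, the value of alpha seems to matter little. A larger number of trees does not improve the score either. The best mean score is obtained with 20 trees (alpha=0.25, rank 1), while several 120-tree configurations are among the worst ranked. However, the standard deviation across folds (about 0.02 to 0.16) is of the same order as these differences, so they should not be over-interpreted. With another random split, the ranking may change.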
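The curves are easier to read once averaged. The sketch below uses the mean_test_score values printed in grid.cv_results_, copied by hand so that it runs without grid. Alpha is the outer loop of the grid, so reshaping the 36 scores into a 6×6 array puts alpha on the rows and n_estimators on the columns. Averaging along each axis then isolates the effect of one parameter. With grid available, numpy.array(grid.cv_results_['mean_test_score']).reshape(6, 6) gives the same array.

```python
import numpy

alphas = [0.25, 0.5, 0.75, 1.0, 1.25, 1.5]
n_trees = [20, 40, 60, 80, 100, 120]

# mean_test_score copied from grid.cv_results_ (alpha in the outer loop of the grid)
mean_test_score = numpy.array([
    0.38993225, 0.33751266, 0.31963749, 0.3279437, 0.33435067, 0.27745524,
    0.36685703, 0.38258035, 0.33157744, 0.29952902, 0.32574138, 0.25960255,
    0.32450857, 0.33613993, 0.34767227, 0.33690607, 0.32688855, 0.34631705,
    0.35384882, 0.35854033, 0.33907575, 0.29576967, 0.33321058, 0.28645149,
    0.38153057, 0.35810788, 0.32218342, 0.3660237, 0.30302586, 0.31918369,
    0.33472862, 0.34837101, 0.3123311, 0.34763028, 0.29349033, 0.30403278,
])
scores = mean_test_score.reshape(len(alphas), len(n_trees))  # rows: alpha, columns: n_estimators

by_alpha = scores.mean(axis=1)  # effect of alpha, averaged over the numbers of trees
by_trees = scores.mean(axis=0)  # effect of the number of trees, averaged over alpha
for a, s in zip(alphas, by_alpha):
    print("alpha=%1.2f  mean R²=%1.3f" % (a, s))
for n, s in zip(n_trees, by_trees):
    print("n_estimators=%3d  mean R²=%1.3f" % (n, s))
```

With these numbers, the averages per alpha stay within a band of about 0.02. The average per number of trees drops from about 0.36 with 20 trees to about 0.30 with 120 trees.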