Hyperparameters, LassoRandomForestRegressor and grid_search (corrected)


This notebook explores the optimization of the hyperparameters of the LassoRandomForestRegressor model, varying the number of trees and the parameter alpha.

from jyquickhelper import add_notebook_menu
add_notebook_menu()
%matplotlib inline

Data

from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
data = load_boston()
X, y = data.data, data.target
X_train, X_test, y_train, y_test = train_test_split(X, y)

First models

from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score

rf = RandomForestRegressor()
rf.fit(X_train, y_train)
r2_score(y_test, rf.predict(X_test))
0.7968308255996621

For the model, it is enough to copy the code from the file lasso_random_forest_regressor.py.

from ensae_teaching_cs.ml.lasso_random_forest_regressor import LassoRandomForestRegressor
lrf = LassoRandomForestRegressor()
lrf.fit(X_train, y_train)
r2_score(y_test, lrf.predict(X_test))
C:\xavierdupre\__home_\github_fork\scikit-learn\sklearn\linear_model\coordinate_descent.py:475: ConvergenceWarning: Objective did not converge. You might want to increase the number of iterations. Duality gap: 20.049045743243255, tolerance: 3.3370377783641163
  positive)
0.7935381412849827

The model reduced the number of trees.

len(lrf.estimators_)
40
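The idea behind the model can be sketched in a few lines: fit a random forest, then fit a Lasso on the predictions of the individual trees and keep only the trees with a nonzero coefficient. This is only a hypothetical illustration of the mechanism with plain scikit-learn estimators (using load_diabetes as a stand-in dataset); the real implementation is the one in lasso_random_forest_regressor.py.

```python
import numpy
from sklearn.datasets import load_diabetes
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import Lasso

X, y = load_diabetes(return_X_y=True)
rf = RandomForestRegressor(n_estimators=100).fit(X, y)

# one column per tree: the prediction of that tree on every sample
tree_preds = numpy.stack([t.predict(X) for t in rf.estimators_], axis=1)

# the Lasso zeroes out the coefficients of redundant trees
lasso = Lasso(alpha=1.0).fit(tree_preds, y)
kept = [t for t, c in zip(rf.estimators_, lasso.coef_) if c != 0]
print(len(kept), "trees kept out of", len(rf.estimators_))
```

Predicting then amounts to a weighted sum of the kept trees' outputs with the Lasso coefficients, which is why the fitted model above ends up with fewer estimators than the plain forest.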

Evolution of the performance as a function of the parameters
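The cell defining `grid` does not appear in this extract. Judging from the `cv_results_` output below (five splits, parameters `lasso_estimator__alpha` in {0.25, …, 1.5} and `rf_estimator__n_estimators` in {20, …, 120}), it was presumably a `GridSearchCV(LassoRandomForestRegressor(), params, cv=5)` over those two grids. The following runnable sketch shows the same mechanism with a plain RandomForestRegressor and a one-parameter grid, so that it only needs scikit-learn; the dataset (load_diabetes) is a stand-in.

```python
from sklearn.datasets import load_diabetes
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import GridSearchCV

X, y = load_diabetes(return_X_y=True)
# same n_estimators grid as in the cv_results_ output below
params = {"n_estimators": [20, 40, 60, 80, 100, 120]}
grid = GridSearchCV(RandomForestRegressor(), params, cv=5)
grid.fit(X, y)
print(grid.best_params_)
```

After fitting, `grid.cv_results_` holds one entry per parameter combination: fit times, per-split scores, their mean and standard deviation, and a rank.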

grid.cv_results_
{'mean_fit_time': array([0.05663986, 0.10452576, 0.16595559, 0.22539721, 0.27168226,
        0.33429265, 0.05565419, 0.10729852, 0.17114191, 0.23077893,
        0.28304234, 0.33431249, 0.05604181, 0.1063158 , 0.17093563,
        0.22779045, 0.26768265, 0.35245051, 0.06323156, 0.14401121,
        0.19288373, 0.25152674, 0.29241681, 0.32273645, 0.05465369,
        0.10490837, 0.17751675, 0.24574986, 0.27805681, 0.32074413,
        0.06244035, 0.10291305, 0.16654687, 0.21403375, 0.2806407 ,
        0.3367074 ]),
 'std_fit_time': array([0.00683437, 0.00293675, 0.00649722, 0.00437011, 0.00134492,
        0.00543854, 0.00469317, 0.00437525, 0.01169219, 0.00994201,
        0.00682215, 0.01900377, 0.00391298, 0.00280159, 0.00885593,
        0.01138138, 0.00560277, 0.02105583, 0.00461818, 0.04150351,
        0.00681097, 0.02927737, 0.02726511, 0.00809343, 0.00521671,
        0.00265498, 0.02124133, 0.0162238 , 0.00868051, 0.00878851,
        0.00323922, 0.00170647, 0.00951473, 0.00600412, 0.01469903,
        0.01193352]),
 'mean_score_time': array([0.00199561, 0.00278964, 0.00359268, 0.00418863, 0.00398006,
        0.00439715, 0.00179882, 0.00240097, 0.00259347, 0.00359073,
        0.00379052, 0.00458755, 0.00179596, 0.00260038, 0.00338974,
        0.00378971, 0.00419002, 0.00399599, 0.00239301, 0.00439   ,
        0.00339122, 0.00399537, 0.00379014, 0.00378962, 0.0026103 ,
        0.00280147, 0.00319166, 0.00537782, 0.00378947, 0.00398731,
        0.00179567, 0.00240479, 0.00319138, 0.00319166, 0.00399151,
        0.00478716]),
 'std_score_time': array([2.30860108e-06, 7.35515012e-04, 8.01309620e-04, 1.16260605e-03,
        8.85867125e-04, 8.03303461e-04, 3.83656712e-04, 4.95055787e-04,
        4.88870129e-04, 7.98297389e-04, 7.46073178e-04, 4.88831366e-04,
        7.47298023e-04, 4.94227888e-04, 4.86466750e-04, 1.16299839e-03,
        7.47195419e-04, 6.31409439e-04, 7.98106470e-04, 1.85030289e-03,
        7.97952120e-04, 6.31174838e-04, 7.46799029e-04, 3.99208553e-04,
        5.00776113e-04, 4.03363522e-04, 3.99852179e-04, 1.85512089e-03,
        7.46620842e-04, 6.41733605e-04, 3.98707686e-04, 4.94055409e-04,
        3.98802796e-04, 3.99852435e-04, 1.09654222e-03, 7.46786568e-04]),
 'param_lasso_estimator__alpha': masked_array(data=[0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.5, 0.5, 0.5, 0.5,
                    0.5, 0.5, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 1.0, 1.0,
                    1.0, 1.0, 1.0, 1.0, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25,
                    1.5, 1.5, 1.5, 1.5, 1.5, 1.5],
              mask=[False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False],
        fill_value='?',
             dtype=object),
 'param_rf_estimator__n_estimators': masked_array(data=[20, 40, 60, 80, 100, 120, 20, 40, 60, 80, 100, 120, 20,
                    40, 60, 80, 100, 120, 20, 40, 60, 80, 100, 120, 20, 40,
                    60, 80, 100, 120, 20, 40, 60, 80, 100, 120],
              mask=[False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False, False, False, False, False,
                    False, False, False, False],
        fill_value='?',
             dtype=object),
 'params': [{'lasso_estimator__alpha': 0.25, 'rf_estimator__n_estimators': 20},
  {'lasso_estimator__alpha': 0.25, 'rf_estimator__n_estimators': 40},
  {'lasso_estimator__alpha': 0.25, 'rf_estimator__n_estimators': 60},
  {'lasso_estimator__alpha': 0.25, 'rf_estimator__n_estimators': 80},
  {'lasso_estimator__alpha': 0.25, 'rf_estimator__n_estimators': 100},
  {'lasso_estimator__alpha': 0.25, 'rf_estimator__n_estimators': 120},
  {'lasso_estimator__alpha': 0.5, 'rf_estimator__n_estimators': 20},
  {'lasso_estimator__alpha': 0.5, 'rf_estimator__n_estimators': 40},
  {'lasso_estimator__alpha': 0.5, 'rf_estimator__n_estimators': 60},
  {'lasso_estimator__alpha': 0.5, 'rf_estimator__n_estimators': 80},
  {'lasso_estimator__alpha': 0.5, 'rf_estimator__n_estimators': 100},
  {'lasso_estimator__alpha': 0.5, 'rf_estimator__n_estimators': 120},
  {'lasso_estimator__alpha': 0.75, 'rf_estimator__n_estimators': 20},
  {'lasso_estimator__alpha': 0.75, 'rf_estimator__n_estimators': 40},
  {'lasso_estimator__alpha': 0.75, 'rf_estimator__n_estimators': 60},
  {'lasso_estimator__alpha': 0.75, 'rf_estimator__n_estimators': 80},
  {'lasso_estimator__alpha': 0.75, 'rf_estimator__n_estimators': 100},
  {'lasso_estimator__alpha': 0.75, 'rf_estimator__n_estimators': 120},
  {'lasso_estimator__alpha': 1.0, 'rf_estimator__n_estimators': 20},
  {'lasso_estimator__alpha': 1.0, 'rf_estimator__n_estimators': 40},
  {'lasso_estimator__alpha': 1.0, 'rf_estimator__n_estimators': 60},
  {'lasso_estimator__alpha': 1.0, 'rf_estimator__n_estimators': 80},
  {'lasso_estimator__alpha': 1.0, 'rf_estimator__n_estimators': 100},
  {'lasso_estimator__alpha': 1.0, 'rf_estimator__n_estimators': 120},
  {'lasso_estimator__alpha': 1.25, 'rf_estimator__n_estimators': 20},
  {'lasso_estimator__alpha': 1.25, 'rf_estimator__n_estimators': 40},
  {'lasso_estimator__alpha': 1.25, 'rf_estimator__n_estimators': 60},
  {'lasso_estimator__alpha': 1.25, 'rf_estimator__n_estimators': 80},
  {'lasso_estimator__alpha': 1.25, 'rf_estimator__n_estimators': 100},
  {'lasso_estimator__alpha': 1.25, 'rf_estimator__n_estimators': 120},
  {'lasso_estimator__alpha': 1.5, 'rf_estimator__n_estimators': 20},
  {'lasso_estimator__alpha': 1.5, 'rf_estimator__n_estimators': 40},
  {'lasso_estimator__alpha': 1.5, 'rf_estimator__n_estimators': 60},
  {'lasso_estimator__alpha': 1.5, 'rf_estimator__n_estimators': 80},
  {'lasso_estimator__alpha': 1.5, 'rf_estimator__n_estimators': 100},
  {'lasso_estimator__alpha': 1.5, 'rf_estimator__n_estimators': 120}],
 'split0_test_score': array([0.8940202 , 0.87745659, 0.89267006, 0.88773409, 0.88764713,
        0.88721727, 0.88030235, 0.88229147, 0.88338305, 0.8841744 ,
        0.87568931, 0.88040145, 0.88224818, 0.86803421, 0.8782666 ,
        0.88402512, 0.87983085, 0.89479441, 0.87792973, 0.8886184 ,
        0.8786744 , 0.8775296 , 0.88777394, 0.8873304 , 0.88272801,
        0.87319492, 0.88929082, 0.87901683, 0.87689482, 0.87903615,
        0.88157204, 0.87444327, 0.88596569, 0.88347627, 0.87777914,
        0.88931924]),
 'split1_test_score': array([0.79704484, 0.84485213, 0.80023779, 0.8581134 , 0.81708241,
        0.88364671, 0.78465591, 0.85299111, 0.8530547 , 0.85618624,
        0.87052386, 0.88326918, 0.85063816, 0.79320864, 0.85897771,
        0.88518101, 0.84289528, 0.83146353, 0.87010914, 0.84243895,
        0.86785386, 0.81390465, 0.86590262, 0.83099966, 0.8427512 ,
        0.85999748, 0.87778341, 0.80600007, 0.82874117, 0.86055027,
        0.78864461, 0.87648724, 0.86358345, 0.84660906, 0.86923607,
        0.8515646 ]),
 'split2_test_score': array([0.91389781, 0.92490052, 0.93859937, 0.90891467, 0.94038985,
        0.93379336, 0.92483562, 0.94347406, 0.92691771, 0.92913799,
        0.93273884, 0.93366021, 0.89996886, 0.94124611, 0.92968597,
        0.92822391, 0.93998711, 0.9297982 , 0.92370424, 0.94017039,
        0.93025174, 0.94268562, 0.92654441, 0.93051024, 0.92906086,
        0.9266844 , 0.93396496, 0.93268742, 0.94347578, 0.93399707,
        0.92104428, 0.93674692, 0.92924889, 0.91370101, 0.92417574,
        0.93610135]),
 'split3_test_score': array([0.87870123, 0.87949269, 0.90185329, 0.89336318, 0.88080773,
        0.88884397, 0.88776187, 0.89833678, 0.89136132, 0.90388563,
        0.89813607, 0.90224458, 0.846408  , 0.90441696, 0.89895751,
        0.88550841, 0.88942613, 0.90262442, 0.84258899, 0.86109812,
        0.88786268, 0.91796995, 0.90608978, 0.91226647, 0.86771125,
        0.8885238 , 0.9012345 , 0.90757908, 0.9049025 , 0.89832926,
        0.90723511, 0.89945538, 0.90380857, 0.90244211, 0.91072623,
        0.90605944]),
 'split4_test_score': array([0.87000685, 0.88453956, 0.86157008, 0.87584732, 0.86631058,
        0.87267254, 0.85968581, 0.85786009, 0.88255153, 0.83106687,
        0.88603786, 0.87960543, 0.86538772, 0.87571858, 0.88723025,
        0.88003698, 0.87751858, 0.88549839, 0.86013719, 0.86229449,
        0.89194402, 0.88812706, 0.87156138, 0.86776254, 0.8888626 ,
        0.85864187, 0.86215456, 0.86389698, 0.86713372, 0.88347567,
        0.86179769, 0.88432831, 0.88138982, 0.86985576, 0.85695583,
        0.87839447]),
 'mean_test_score': array([0.87073419, 0.8822483 , 0.87898612, 0.88479453, 0.87844754,
        0.89323477, 0.86744831, 0.8869907 , 0.88745366, 0.88089023,
        0.89262519, 0.89583617, 0.86893018, 0.8765249 , 0.89062361,
        0.89259509, 0.88593159, 0.88883579, 0.87489386, 0.87892407,
        0.89131734, 0.88804337, 0.89157443, 0.88577386, 0.88222278,
        0.88140849, 0.89288565, 0.87783608, 0.8842296 , 0.89107769,
        0.87205874, 0.89429222, 0.89279928, 0.88321684, 0.8877746 ,
        0.89228782]),
 'std_test_score': array([0.03974785, 0.02550878, 0.046408  , 0.01706329, 0.03959808,
        0.02104863, 0.04644568, 0.03269786, 0.02364982, 0.03452557,
        0.02217278, 0.02064617, 0.01997581, 0.04895129, 0.02349892,
        0.01792068, 0.03128684, 0.03226845, 0.0271143 , 0.03397223,
        0.02115936, 0.04356253, 0.02239465, 0.03472822, 0.02830809,
        0.02507887, 0.02425949, 0.04301028, 0.0383732 , 0.02461451,
        0.04645572, 0.02297537, 0.02227203, 0.02373862, 0.02547628,
        0.02817624]),
 'rank_test_score': array([34, 23, 27, 20, 29,  3, 36, 17, 16, 26,  6,  1, 35, 31, 12,  7, 18,
        13, 32, 28, 10, 14,  9, 19, 24, 25,  4, 30, 21, 11, 33,  2,  5, 22,
        15,  8])}
import numpy
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(14, 6))
ax = fig.add_subplot(131, projection='3d')
xs = numpy.array([el['lasso_estimator__alpha'] for el in grid.cv_results_['params']])
ys = numpy.array([el['rf_estimator__n_estimators'] for el in grid.cv_results_['params']])
zs = numpy.array(grid.cv_results_['mean_test_score'])
ax.scatter(xs, ys, zs)
ax.set_title("3D...")

ax = fig.add_subplot(132)
for x in sorted(set(xs)):
    y2 = ys[xs == x]
    z2 = zs[xs == x]
    ax.plot(y2, z2, label="alpha=%1.2f" % x, lw=x*2)
ax.legend();

ax = fig.add_subplot(133)
for y in sorted(set(ys)):
    x2 = xs[ys == y]
    z2 = zs[ys == y]
    ax.plot(x2, z2, label="n_estimators=%d" % y, lw=y/40)
ax.legend();
../_images/ml_lasso_rf_grid_search_correction_22_0.png

The value of alpha seems to matter little, while a larger number of trees has a positive impact. That said, one should not forget the standard deviation of these scores, which is not negligible.
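This caveat can be checked numerically with values copied from the `cv_results_` output above: the gap between the best and worst mean test scores is barely larger than one standard deviation of the cross-validation scores, so the ranking of the configurations is fragile.

```python
# values copied from grid.cv_results_ above
best_mean = 0.89583617   # rank 1: alpha=0.5, n_estimators=120
best_std = 0.02064617    # its std_test_score
worst_mean = 0.86744831  # rank 36: alpha=0.5, n_estimators=20

# gap between best and worst configurations, in units of the best one's std
gap_in_stds = (best_mean - worst_mean) / best_std
print(gap_in_stds)  # about 1.4 standard deviations
```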