.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "gyexamples/ml_basic/plot_grid_search.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_gyexamples_ml_basic_plot_grid_search.py:


Grid Search
===========

In most cases, training a model depends on hyperparameters. The default
values chosen by the authors of machine learning libraries work well most
of the time, but it can sometimes be worth tuning them. The simplest method
is to try several values and keep the one that minimises the error on the
test set.

.. contents::
    :local:

.. GENERATED FROM PYTHON SOURCE LINES 19-24

Default parameters
------------------

We start by generating an artificial regression dataset with
``make_friedman1``, in which the target ``Y`` depends non-linearly on the
five features.

.. GENERATED FROM PYTHON SOURCE LINES 25-33

.. code-block:: default

    import pandas
    from sklearn.model_selection import GridSearchCV
    from sklearn.linear_model import Lasso
    import matplotlib.pyplot as plt
    from sklearn.datasets import make_friedman1

    # 500 samples, 5 features drawn uniformly in [0, 1]
    X, Y = make_friedman1(n_samples=500, n_features=5)

.. GENERATED FROM PYTHON SOURCE LINES 34-35

We plot the target ``Y`` against the first feature.

.. GENERATED FROM PYTHON SOURCE LINES 35-40

.. code-block:: default

    fig = plt.figure(figsize=(5, 5))
    ax = plt.subplot()
    ax.plot(X[:, 0], Y, '.')

.. image-sg:: /gyexamples/ml_basic/images/sphx_glr_plot_grid_search_001.png
   :alt: plot grid search
   :srcset: /gyexamples/ml_basic/images/sphx_glr_plot_grid_search_001.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    []

.. GENERATED FROM PYTHON SOURCE LINES 41-45

We use a linear regression model with a constraint on its coefficients,
`Lasso `_, which adds an L1 penalty whose strength is set by the ``alpha``
parameter.

.. GENERATED FROM PYTHON SOURCE LINES 45-49

.. code-block:: default

    reglin = Lasso()  # default alpha=1.0
    reglin.fit(X, Y)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Lasso()

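In ``Lasso``, ``alpha`` sets the strength of the L1 penalty (its default
value is 1.0): the larger it is, the more coefficients are driven exactly to
zero. The short sketch below counts the non-zero coefficients for a few
values of ``alpha``. It regenerates comparable data with ``random_state=0``,
an assumption added only so that the block runs on its own; the exact counts
depend on the data.

```python
import numpy
from sklearn.datasets import make_friedman1
from sklearn.linear_model import Lasso

# Comparable data; random_state=0 is an assumption added for reproducibility
X, Y = make_friedman1(n_samples=500, n_features=5, random_state=0)

n_nonzero = {}
for alpha in [1e-5, 0.01, 0.1, 0.5, 1.0]:
    model = Lasso(alpha=alpha).fit(X, Y)
    # The L1 penalty sets some coefficients exactly to zero
    n_nonzero[alpha] = int(numpy.count_nonzero(model.coef_))
    print(f"alpha={alpha:g}: {n_nonzero[alpha]} non-zero coefficient(s), "
          f"intercept={model.intercept_:.3f}")
```

A very small ``alpha`` behaves almost like ordinary least squares, while a
large one can cancel every coefficient and leave only the intercept.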
.. GENERATED FROM PYTHON SOURCE LINES 50-52

Fitting the model gives the following coefficients and intercept:

.. GENERATED FROM PYTHON SOURCE LINES 52-54

.. code-block:: default

    print(reglin.coef_, reglin.intercept_)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [0. 0. 0. 0. 0.] 14.409882036390346

All the coefficients are zero: with the default ``alpha=1.0``, the L1
penalty is strong enough to cancel every coefficient, so the model predicts
a constant equal to the intercept (which is then the mean of ``Y``).

.. GENERATED FROM PYTHON SOURCE LINES 55-58

We reuse the first plot and add the line given by a linear regression
fitted on the first feature only.

.. GENERATED FROM PYTHON SOURCE LINES 58-69

.. code-block:: default

    # Fit on the first feature only (2D input of shape (n, 1))
    reglin = Lasso()
    reglin.fit(X[:, :1], Y)

    fig = plt.figure(figsize=(5, 5))
    ax = plt.subplot()
    # Sort the rows by the first feature so the line is drawn left to right
    x = X[X[:, 0].argsort(), :1]
    y = reglin.predict(x)
    ax.plot(X[:, 0], Y, '.')
    ax.plot(x, y)

.. image-sg:: /gyexamples/ml_basic/images/sphx_glr_plot_grid_search_002.png
   :alt: plot grid search
   :srcset: /gyexamples/ml_basic/images/sphx_glr_plot_grid_search_002.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    []

.. GENERATED FROM PYTHON SOURCE LINES 70-75

Grid Search
-----------

We optimise the value of the parameter :math:`\alpha` by trying several
values between :math:`10^{-5}` and 1, still using the first feature only.
Each candidate is evaluated by 5-fold cross-validation, which makes 30 fits
in total.

.. GENERATED FROM PYTHON SOURCE LINES 76-82

.. code-block:: default

    grid = GridSearchCV(
        Lasso(), {'alpha': [1e-5, 0.01, 0.1, 0.5, 0.8, 1]},
        verbose=3)
    grid.fit(X[:, :1], Y)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Fitting 5 folds for each of 6 candidates, totalling 30 fits
    [CV 1/5] END .......................alpha=1e-05;, score=0.151 total time= 0.0s
    [CV 2/5] END .......................alpha=1e-05;, score=0.198 total time= 0.0s
    [CV 3/5] END .......................alpha=1e-05;, score=0.144 total time= 0.0s
    [CV 4/5] END .......................alpha=1e-05;, score=0.085 total time= 0.0s
    [CV 5/5] END .......................alpha=1e-05;, score=0.175 total time= 0.0s
    [CV 1/5] END ........................alpha=0.01;, score=0.151 total time= 0.0s
    [CV 2/5] END ........................alpha=0.01;, score=0.196 total time= 0.0s
    [CV 3/5] END ........................alpha=0.01;, score=0.143 total time= 0.0s
    [CV 4/5] END ........................alpha=0.01;, score=0.085 total time= 0.0s
    [CV 5/5] END ........................alpha=0.01;, score=0.175 total time= 0.0s
    [CV 1/5] END .........................alpha=0.1;, score=0.150 total time= 0.0s
    [CV 2/5] END .........................alpha=0.1;, score=0.179 total time= 0.0s
    [CV 3/5] END .........................alpha=0.1;, score=0.136 total time= 0.0s
    [CV 4/5] END .........................alpha=0.1;, score=0.089 total time= 0.0s
    [CV 5/5] END .........................alpha=0.1;, score=0.174 total time= 0.0s
    [CV 1/5] END .........................alpha=0.5;, score=0.012 total time= 0.0s
    [CV 2/5] END .........................alpha=0.5;, score=0.009 total time= 0.0s
    [CV 3/5] END ........................alpha=0.5;, score=-0.002 total time= 0.0s
    [CV 4/5] END .........................alpha=0.5;, score=0.028 total time= 0.0s
    [CV 5/5] END .........................alpha=0.5;, score=0.038 total time= 0.0s
    [CV 1/5] END ........................alpha=0.8;, score=-0.033 total time= 0.0s
    [CV 2/5] END ........................alpha=0.8;, score=-0.010 total time= 0.0s
    [CV 3/5] END ........................alpha=0.8;, score=-0.036 total time= 0.0s
    [CV 4/5] END ........................alpha=0.8;, score=-0.005 total time= 0.0s
    [CV 5/5] END ........................alpha=0.8;, score=-0.004 total time= 0.0s
    [CV 1/5] END ..........................alpha=1;, score=-0.033 total time= 0.0s
    [CV 2/5] END ..........................alpha=1;, score=-0.010 total time= 0.0s
    [CV 3/5] END ..........................alpha=1;, score=-0.036 total time= 0.0s
    [CV 4/5] END ..........................alpha=1;, score=-0.005 total time= 0.0s
    [CV 5/5] END ..........................alpha=1;, score=-0.004 total time= 0.0s

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    GridSearchCV(estimator=Lasso(),
                 param_grid={'alpha': [1e-05, 0.01, 0.1, 0.5, 0.8, 1]}, verbose=3)

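For a regressor and an integer ``cv=5``, ``GridSearchCV`` splits the data
with a non-shuffled ``KFold(n_splits=5)``, scores each candidate on every
held-out fold with the estimator's default ``score`` (R² for ``Lasso``),
averages these scores and keeps the candidate with the best mean. A minimal
sketch of that loop with ``cross_val_score`` follows; the data are
regenerated with ``random_state=0``, an assumption not present in the
original example, and the result is cross-checked against ``GridSearchCV``
on the same data and splits.

```python
import numpy
from sklearn.datasets import make_friedman1
from sklearn.linear_model import Lasso
from sklearn.model_selection import GridSearchCV, KFold, cross_val_score

# Comparable data; random_state=0 is an assumption added for reproducibility
X, Y = make_friedman1(n_samples=500, n_features=5, random_state=0)
X1 = X[:, :1]  # first feature only, as in the example
alphas = [1e-5, 0.01, 0.1, 0.5, 0.8, 1]

# Same splitting strategy that cv=5 resolves to for a regressor
cv = KFold(n_splits=5)
mean_scores = [
    cross_val_score(Lasso(alpha=alpha), X1, Y, cv=cv).mean()  # default score: R²
    for alpha in alphas
]
best_alpha = alphas[int(numpy.argmax(mean_scores))]
print("mean R² per alpha:", numpy.round(mean_scores, 3))
print("best alpha:", best_alpha)

# Cross-check against GridSearchCV on the same data and splits
grid = GridSearchCV(Lasso(), {"alpha": alphas}, cv=cv).fit(X1, Y)
print("GridSearchCV best alpha:", grid.best_params_["alpha"])
```

``GridSearchCV`` adds to this loop the bookkeeping of ``cv_results_`` and,
by default, a final refit of the best candidate on all the data.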
.. GENERATED FROM PYTHON SOURCE LINES 83-84

We display the results.

.. GENERATED FROM PYTHON SOURCE LINES 84-89

.. code-block:: default

    # One row per candidate value of alpha
    df = pandas.DataFrame(grid.cv_results_)
    df["alpha"] = df.params.apply(lambda x: x["alpha"])
    print(df)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

       mean_fit_time  std_fit_time  ...  rank_test_score    alpha
    0       0.003209      0.000036  ...                2  0.00001
    1       0.003204      0.000043  ...                1  0.01000
    2       0.003193      0.000045  ...                3  0.10000
    3       0.003191      0.000028  ...                4  0.50000
    4       0.003179      0.000027  ...                5  0.80000
    5       0.003200      0.000036  ...                5  1.00000

    [6 rows x 15 columns]

.. GENERATED FROM PYTHON SOURCE LINES 90-91

On this example, strengthening the Lasso constraint degrades performance
considerably: the mean score (R², the default score of ``Lasso``) is about
0.15 for ``alpha`` up to 0.1, drops to about 0.02 for ``alpha=0.5`` and
becomes slightly negative for ``alpha`` of 0.8 and above.

.. GENERATED FROM PYTHON SOURCE LINES 91-98

.. code-block:: default

    fig = plt.figure(figsize=(5, 5))
    ax = plt.subplot()
    df.set_index("alpha")[["mean_test_score"]].plot(ax=ax)
    ax.set_xlabel("alpha")
    ax.set_ylabel("mean_test_score")

.. image-sg:: /gyexamples/ml_basic/images/sphx_glr_plot_grid_search_003.png
   :alt: plot grid search
   :srcset: /gyexamples/ml_basic/images/sphx_glr_plot_grid_search_003.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Text(-3.777777777777777, 0.5, 'mean_test_score')

.. GENERATED FROM PYTHON SOURCE LINES 99-100

We plot the predictions of the best model. With the default
``refit=True``, ``GridSearchCV`` refits the best candidate on the whole
dataset, and ``grid.predict`` uses this refitted model.

.. GENERATED FROM PYTHON SOURCE LINES 100-109

.. code-block:: default

    fig = plt.figure(figsize=(5, 5))
    ax = plt.subplot()
    # Sort the rows by the first feature so the line is drawn left to right
    x = X[X[:, 0].argsort(), :1]
    y = grid.predict(x)
    ax.plot(X[:, 0], Y, '.')
    ax.plot(x, y)

.. image-sg:: /gyexamples/ml_basic/images/sphx_glr_plot_grid_search_004.png
   :alt: plot grid search
   :srcset: /gyexamples/ml_basic/images/sphx_glr_plot_grid_search_004.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    []

.. GENERATED FROM PYTHON SOURCE LINES 110-122

Overfitting
-----------

By default, :epkg:`scikit-learn` tunes the hyperparameters with
cross-validation. Without it, the model would fit its coefficients on the
training set and its hyperparameters on the test set, so that all the data
would end up being used to tune some parameter and none would remain for an
unbiased evaluation. Cross-validation limits this risk by checking that
training is stable over several splits of the data. One can also split the
data into train / validation / test sets, but this reduces the amount of
data available for training accordingly.

.. GENERATED FROM PYTHON SOURCE LINES 123-126

.. code-block:: default

    # plt.show()

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 1.887 seconds)


.. _sphx_glr_download_gyexamples_ml_basic_plot_grid_search.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_grid_search.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_grid_search.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_
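As a complement to the Overfitting section, here is a minimal sketch of
keeping a separate test set: the hyperparameter is chosen by
cross-validation on the training part only, then the refitted best model is
scored once on data that played no role in any choice. The 25 % hold-out,
``random_state=0`` and the regenerated data are assumptions made for the
sketch, not part of the original example.

```python
from sklearn.datasets import make_friedman1
from sklearn.linear_model import Lasso
from sklearn.model_selection import GridSearchCV, train_test_split

# Comparable data; random_state=0 is an assumption added for reproducibility
X, Y = make_friedman1(n_samples=500, n_features=5, random_state=0)
# Keep 25 % of the data aside; it is never seen during the grid search
X_train, X_test, Y_train, Y_test = train_test_split(
    X[:, :1], Y, test_size=0.25, random_state=0)

grid = GridSearchCV(Lasso(), {"alpha": [1e-5, 0.01, 0.1, 0.5, 0.8, 1]})
grid.fit(X_train, Y_train)  # cross-validation on the training part only

cv_score = grid.best_score_  # mean R² over the CV folds, used to pick alpha
test_score = grid.score(X_test, Y_test)  # R² of the refitted model on unseen data
print(f"best alpha={grid.best_params_['alpha']}, "
      f"CV R²={cv_score:.3f}, test R²={test_score:.3f}")
```

The price is the one mentioned above: the 125 held-out samples are no
longer available for training.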