.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "gyexamples/ml_basic/plot_regression.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note :ref:`Go to the end ` to download the full example code .. rst-class:: sphx-glr-example-title .. _sphx_glr_gyexamples_ml_basic_plot_regression.py: Régression ========== Un problème de régression consiste à construire une fonction qui prédit une quantité réelle *Y* en fonction de variables *X*. C'est une façon d'exprimer un lien entre deux quantités en fonction des observations (voir `régression `_). La `régression linéaire `_ est le modèle le plus simple et consiste à supposer que la relation est linéaire. .. contents:: :local: .. GENERATED FROM PYTHON SOURCE LINES 22-26 Principe -------- On commence par générer un jeu de données artificiel. .. GENERATED FROM PYTHON SOURCE LINES 27-36 .. code-block:: default from sklearn.tree import DecisionTreeRegressor from sklearn.metrics import mean_squared_error from sklearn.model_selection import train_test_split from sklearn.linear_model import LinearRegression import matplotlib.pyplot as plt from sklearn.datasets import make_friedman1 X, Y = make_friedman1(n_samples=500, n_features=5) .. GENERATED FROM PYTHON SOURCE LINES 37-38 On représente ces données. .. GENERATED FROM PYTHON SOURCE LINES 38-43 .. code-block:: default fig = plt.figure(figsize=(5, 5)) ax = plt.subplot() ax.plot(X[:, 0], Y, '.') .. image-sg:: /gyexamples/ml_basic/images/sphx_glr_plot_regression_001.png :alt: plot regression :srcset: /gyexamples/ml_basic/images/sphx_glr_plot_regression_001.png :class: sphx-glr-single-img .. rst-class:: sphx-glr-script-out .. code-block:: none [] .. GENERATED FROM PYTHON SOURCE LINES 44-51 D'un point de vue géométrique, un problème de régression consiste à trouver la courbe qui s'approche au plus de tous les points. Le plus simple est de supposer que c'est une droite. Dans ce cas, on choisira un modèle de régression linéaire : `LinearRegression `_. .. GENERATED FROM PYTHON SOURCE LINES 51-55 .. code-block:: default reglin = LinearRegression() reglin.fit(X, Y) .. raw:: html
LinearRegression()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.


.. GENERATED FROM PYTHON SOURCE LINES 56-58 L'optimisation du modèle produit une droite dont les coefficients sont : .. GENERATED FROM PYTHON SOURCE LINES 58-60 .. code-block:: default print(reglin.coef_, reglin.intercept_) .. rst-class:: sphx-glr-script-out .. code-block:: none [ 5.93230458 6.3621584 -0.87201488 9.60041851 5.02910954] 1.342482246756484 .. GENERATED FROM PYTHON SOURCE LINES 61-64 On reprend le premier graphe est on y ajoute la droite qui correspond à la régression linéaire uniquement sur la première dimension. .. GENERATED FROM PYTHON SOURCE LINES 64-75 .. code-block:: default reglin = LinearRegression() reglin.fit(X[:, :1], Y) fig = plt.figure(figsize=(5, 5)) ax = plt.subplot() x = list(sorted(X[:, :1])) y = reglin.predict(x) ax.plot(X[:, 0], Y, '.') ax.plot(x, y) .. image-sg:: /gyexamples/ml_basic/images/sphx_glr_plot_regression_002.png :alt: plot regression :srcset: /gyexamples/ml_basic/images/sphx_glr_plot_regression_002.png :class: sphx-glr-single-img .. rst-class:: sphx-glr-script-out .. code-block:: none [] .. GENERATED FROM PYTHON SOURCE LINES 76-93 Evaluation ---------- Le critère d'erreur le plus utilisé est l'erreur quadratique. Si :math:`y_i` est la valeur à prédire, et :math:`y_i^*` la valeur prédite, l'erreur est : .. math:: err = \sum_i(y_i - y_i^*)^2 La plupart des problèmes sont multidimensionnelles et s'appuie sur de nombreuses variables mais bien souvent la valeur à prédire est réelle. On représente donc le graphique *XY* où l'axe des abscisses représente la valeur à prédire et l'axe des ordonnées la valeur prédite. On découpe d'abord en train/test. .. GENERATED FROM PYTHON SOURCE LINES 94-97 .. code-block:: default X_train, X_test, y_train, y_test = train_test_split(X, Y) .. GENERATED FROM PYTHON SOURCE LINES 98-99 On apprend un modèle. .. GENERATED FROM PYTHON SOURCE LINES 99-103 .. code-block:: default reglin = LinearRegression() reglin.fit(X_train, y_train) .. raw:: html
LinearRegression()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.


.. GENERATED FROM PYTHON SOURCE LINES 104-105 On prédit. .. GENERATED FROM PYTHON SOURCE LINES 105-107 .. code-block:: default pred = reglin.predict(X_test) .. GENERATED FROM PYTHON SOURCE LINES 108-109 On calcule l'erreur. .. GENERATED FROM PYTHON SOURCE LINES 109-112 .. code-block:: default err = mean_squared_error(y_test, pred) print(err) .. rst-class:: sphx-glr-script-out .. code-block:: none 7.448159243320464 .. GENERATED FROM PYTHON SOURCE LINES 113-114 On dessine. .. GENERATED FROM PYTHON SOURCE LINES 114-122 .. code-block:: default fig = plt.figure(figsize=(5, 5)) ax = plt.subplot() ax.plot(y_test, pred, '.') ax.set_xlabel("expected") ax.set_ylabel("predicted") ax.set_title(f"ERR2={err}") .. image-sg:: /gyexamples/ml_basic/images/sphx_glr_plot_regression_003.png :alt: ERR2=7.448159243320464 :srcset: /gyexamples/ml_basic/images/sphx_glr_plot_regression_003.png :class: sphx-glr-single-img .. rst-class:: sphx-glr-script-out .. code-block:: none Text(0.5, 1.0, 'ERR2=7.448159243320464') .. GENERATED FROM PYTHON SOURCE LINES 123-130 Comparer deux modèles --------------------- On peut bien évidemment les comparer numériquement. Graphiquement, les nuages de points s'entremêlent et deviennent peu visible. Pour y remédier, on trie les erreurs par ordre croissant. .. GENERATED FROM PYTHON SOURCE LINES 131-136 .. code-block:: default regtree = DecisionTreeRegressor() regtree.fit(X_train, y_train) pred_tree = regtree.predict(X_test) .. GENERATED FROM PYTHON SOURCE LINES 137-138 On dessine. .. GENERATED FROM PYTHON SOURCE LINES 138-150 .. code-block:: default y1 = list(sorted(pred - y_test)) y2 = list(sorted(pred_tree - y_test)) fig = plt.figure(figsize=(5, 5)) ax = plt.subplot() ax.plot([0, len(y1)], [0, 0], '-') ax.plot(y1, '.', label="linear") ax.plot(y2, '.', label="tree") ax.legend() plt.show() .. image-sg:: /gyexamples/ml_basic/images/sphx_glr_plot_regression_004.png :alt: plot regression :srcset: /gyexamples/ml_basic/images/sphx_glr_plot_regression_004.png :class: sphx-glr-single-img .. rst-class:: sphx-glr-timing **Total running time of the script:** ( 0 minutes 1.655 seconds) .. _sphx_glr_download_gyexamples_ml_basic_plot_regression.py: .. only:: html .. container:: sphx-glr-footer sphx-glr-footer-example .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: plot_regression.py ` .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: plot_regression.ipynb ` .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery `_