Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- coding: utf-8 -*-
2"""
3@file
4@brief Jeux de données reliés aux catégories.
5"""
6import os
7from io import StringIO, BytesIO
8import pandas
9from pyquickhelper.filehelper import read_content_ufs, ungzip_files
10from .data_helper import get_data_folder
def load_adult_dataset(download=True, small=False, url='uci'):
    """
    Returns the
    `Adult Data Set <https://archive.ics.uci.edu/ml/datasets/adult>`_.
    Most variables are categorical.
    Notebooks associated with this dataset:

    .. runpython::
        :rst:

        from papierstat.datasets.documentation import list_notebooks_rst_links
        links = list_notebooks_rst_links('lectures', 'adult')
        links = [' * %s' % s for s in links]
        print('\\n'.join(links))

    @param      download    download the dataset; when both *download* and
                            *small* are False, ``NotImplementedError`` is
                            raised because there is no local copy of the
                            full dataset
    @param      small       load the reduced version stored in the package
                            data folder (checked before *download*)
    @param      url         ``'uci'`` reads the raw files from the UCI
                            repository; any other value reads the gzipped
                            copies from ``http://www.xavierdupre.fr`` (the
                            value itself is not used as an address)
    @return                 :epkg:`pandas:DataFrame` (train, test)

    Surrounding whitespace is stripped from every string column of both
    frames, and surrounding ``'.'`` characters are removed from the labels.
    """
    columns = ['age', 'workclass', 'fnlwgt', 'education', 'education_num',
               'marital_status', 'occupation', 'relationship', 'race', 'sex',
               'capital_gain', 'capital_loss', 'hours_per_week',
               'native_country', '<=50K']
    # the last column holds the target
    label = columns[-1]

    if small:
        fold = get_data_folder()
        # read_csv infers gzip compression from the '.gz' extension
        train = pandas.read_csv(os.path.join(fold, 'adult.data.gz'),
                                header=None)
        test = pandas.read_csv(os.path.join(fold, 'adult.test.gz'),
                               header=None)
    elif download:
        if url == 'uci':
            base = ("https://archive.ics.uci.edu/ml/"
                    "machine-learning-databases/adult/")
            train = pandas.read_csv(base + "adult.data", header=None)
            # the first line of the downloaded test file is not a data row
            test = pandas.read_csv(base + "adult.test", header=None,
                                   skiprows=1)
        else:
            base = "http://www.xavierdupre.fr/enseignement/complements/"
            train = _read_remote_gzip_csv(base + "adult.data.gz",
                                          min_size=400000)
            test = _read_remote_gzip_csv(base + "adult.test.gz",
                                         min_size=200000, skiprows=1)
    else:
        raise NotImplementedError(  # pragma: no cover
            "No local copy")

    train.columns = columns
    test.columns = columns
    _strip_string_columns(train, label)
    _strip_string_columns(test, label)
    return train, test


def _read_remote_gzip_csv(url, min_size, skiprows=None):
    """
    Downloads a gzipped CSV file and parses it without a header row.

    @param      url         address of the ``.gz`` file
    @param      min_size    forwarded to ``read_content_ufs``
    @param      skiprows    forwarded to ``pandas.read_csv``
    @return                 :epkg:`pandas:DataFrame`
    """
    content = read_content_ufs(url, asbytes=True, encoding=None,
                               min_size=min_size)
    raw = ungzip_files(BytesIO(content), unzip=False)
    return pandas.read_csv(StringIO(raw.decode('ascii')), header=None,
                           skiprows=skiprows)


def _strip_string_columns(df, label):
    """
    Cleans *df* in place.

    Surrounding spaces and ``'.'`` are removed from column *label*, then
    surrounding whitespace is removed from every ``object`` column.
    """
    df[label] = df[label].str.strip(' .')
    # dtypes come from this frame itself, so a column parsed with a
    # different dtype in the other file can neither break ``.str`` here
    # nor be skipped
    for col in df.select_dtypes(object).columns:
        df[col] = df[col].str.strip()