Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- coding: utf-8 -*-
2"""
3@file
4@brief Defines a competition.
5"""
6from io import StringIO
7import numpy
8import pandas
9from .metrics import mse, sklearn_metric, roc_auc_score_macro, roc_auc_score_micro
class Competition:
    """
    Defines a competition: an identifier, a page link, a name,
    a description, the metrics used to score submissions and the
    expected values submissions are compared to.
    """

    def __init__(self, cpt_id, link, name, description, metric,
                 datafile=None, expected_values=None):
        """
        @param      cpt_id              competition id
        @param      link                link to the page, something like ``/competition``
        @param      name                name of the competition
        @param      description         description
        @param      metric              metric name, comma-separated string of metric
                                        names, or list/tuple of metrics to compute
        @param      datafile            data file
        @param      expected_values     expected values: inline CSV text, a CSV filename,
                                        a list/tuple/numpy array of values or of dicts,
                                        a pandas DataFrame, or None when not known yet
        """
        self.link = link
        self.name = name
        self.cpt_id = cpt_id
        self.metrics = self._normalize_metrics(metric)
        # datafile must be set before loading the expected values:
        # loading them from a CSV filename records that filename here
        # when no data file was given.
        self.datafile = datafile
        self.description = description
        # None means "no expected values yet"; _load_values rejects None,
        # so it is handled here to keep the keyword default usable.
        if expected_values is None:
            self.expected_values = None
        else:
            self.expected_values = self._load_values(expected_values)

    @staticmethod
    def _normalize_metrics(metric):
        """
        Converts the *metric* argument into a list of metrics.

        @param      metric      a string (possibly comma-separated), a list,
                                a tuple, or any single metric object
        @return                 list of metrics; names taken from a string are
                                stripped and empty entries are dropped
        """
        if isinstance(metric, str):
            # "mse, r2_score," -> ["mse", "r2_score"]
            return [part.strip() for part in metric.split(',') if part.strip()]
        if isinstance(metric, tuple):
            return list(metric)
        if isinstance(metric, list):
            return metric
        return [metric]

    def _load_values(self, values, set_datafile=True):
        """
        Converts values into a DataFrame, one column per expected value set.

        @param      values          inline CSV text (contains a newline), a CSV
                                    filename, a non-empty list/tuple/numpy array
                                    of values or of dicts, or a DataFrame
        @param      set_datafile    when True and *values* is a filename, store it
                                    in ``self.datafile`` if that attribute is None
        @return                     pandas DataFrame
        @raises     ValueError      if *values* is an empty sequence
        @raises     TypeError       if *values* has an unsupported type
        """
        if isinstance(values, str):
            if '\n' in values:
                # Inline CSV content, e.g. the text produced by to_dict.
                return pandas.read_csv(StringIO(values))
            if set_datafile and self.datafile is None:
                self.datafile = values
            return pandas.read_csv(values)
        if isinstance(values, pandas.DataFrame):
            return values
        if isinstance(values, (list, tuple, numpy.ndarray)):
            if len(values) == 0:
                raise ValueError("values cannot be empty")
            if isinstance(values[0], dict):
                return pandas.DataFrame(list(values), dtype=float)
            # numpy.array copies, so later changes to the caller's
            # container do not leak into the stored frame.
            res = pandas.DataFrame(numpy.array(values), dtype=float)
            if res.shape[0] < res.shape[1]:
                # More columns than rows: the values were given row-wise.
                res = res.T.reset_index(drop=True)
            res.columns = ["exp%d" % i for i in range(res.shape[1])]
            return res
        raise TypeError(
            "Unexpected type for expected_values: {0}".format(type(values)))

    def evaluate(self, values):
        """
        Evaluates received values against the expected values.

        @param      values      received values, any format accepted by _load_values
        @return                 dictionary {metric: res}
        @raises     ValueError  if the competition has no expected values
        """
        if self.expected_values is None:
            raise ValueError(
                "expected values are not defined for this competition")
        # set_datafile=False: a submitted filename must not become the
        # competition's data file.
        received = self._load_values(values, set_datafile=False)
        return {met: self.evaluate_metric(met, self.expected_values, received)
                for met in self.metrics}

    def evaluate_metric(self, met, exp, val):
        """
        Evaluates a metric.

        @param      met     metric name
        @param      exp     expected values
        @param      val     received values
        @return             result of the metric
        """
        if met == "mse":
            return mse(exp, val)
        elif met == "roc_auc_score_micro":
            return roc_auc_score_micro(exp, val)
        elif met == "roc_auc_score_macro":
            return roc_auc_score_macro(exp, val)
        else:
            # Any other name is delegated to sklearn_metric.
            return sklearn_metric(met, exp, val)

    @property
    def metric(self):
        """
        Returns the metrics in a single comma-separated string.
        """
        return ",".join(self.metrics)

    @staticmethod
    def _frame_to_csv(frame):
        """
        Serializes a DataFrame to CSV text without the index.

        @param      frame   DataFrame or None
        @return             CSV text, or None when *frame* is None
        """
        if frame is None:
            return None
        buffer = StringIO()
        frame.to_csv(buffer, index=False)
        return buffer.getvalue()

    def to_dict(self):
        """
        Converts a competition into a dictionary whose keys match
        the constructor's parameters.
        """
        return dict(cpt_id=self.cpt_id, link=self.link, name=self.name,
                    description=self.description,
                    expected_values=Competition._frame_to_csv(
                        self.expected_values),
                    metric=self.metric, datafile=self.datafile)

    @staticmethod
    def to_records(list_cpt):
        """
        Converts a list of competitions into a list of dictionaries,
        one per (competition, metric) pair.
        """
        res = []
        for cpt in list_cpt:
            # Serialize once per competition, not once per metric.
            val = Competition._frame_to_csv(cpt.expected_values)
            res.extend(dict(link=cpt.link, cpt_name=cpt.name, metric=met,
                            description=cpt.description, expected_values=val)
                       for met in cpt.metrics)
        return res