Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9from ..shape_object import ShapeObject 

10 

11 

class ZipMapDictionary(dict):
    """
    Custom dictionary class much faster for this runtime,
    it implements a subset of the same methods.

    Nothing is stored in the underlying :class:`dict`: the instance keeps
    a mapping *{key: column index}* and either a sequence of values
    or a matrix plus a row index. Every :class:`dict` read accessor
    used on it must therefore be overridden here, otherwise it would
    operate on an always-empty dictionary.
    """
    __slots__ = ['_rev_keys', '_values', '_mat']

    @staticmethod
    def build_rev_keys(keys):
        """
        Builds the mapping expected by the constructor.

        @param      keys        sequence of keys (class labels)
        @return                 dictionary *{key: position of key in keys}*,
                                the last position wins for duplicated keys
        """
        return {k: i for i, k in enumerate(keys)}

    def __init__(self, rev_keys, values, mat=None):
        """
        @param      rev_keys        returns by @see me build_rev_keys,
                                    *{keys: column index}*
        @param      values          values, or a row index if *mat* is given
        @param      mat             matrix if values is a row index,
                                    two or three dimensions

        @raises     TypeError       if *mat* is not a :epkg:`numpy` array
        @raises     ValueError      if *mat* does not have 2 or 3 dimensions
        """
        if mat is not None:
            if not isinstance(mat, numpy.ndarray):
                raise TypeError(  # pragma: no cover
                    'matrix is expected, got {}.'.format(type(mat)))
            if len(mat.shape) not in (2, 3):
                raise ValueError(  # pragma: no cover
                    "matrix must have two or three dimensions but got {}"
                    ".".format(mat.shape))
        dict.__init__(self)
        self._rev_keys = rev_keys
        self._values = values
        self._mat = mat

    def __getstate__(self):
        """
        For pickle.
        """
        return dict(_rev_keys=self._rev_keys,
                    _values=self._values,
                    _mat=self._mat)

    def __setstate__(self, state):
        """
        For pickle. Accepts either the dictionary produced by
        @see me __getstate__ or a tuple whose second element is
        that dictionary.
        """
        if isinstance(state, tuple):
            state = state[1]
        self._rev_keys = state['_rev_keys']
        self._values = state['_values']
        self._mat = state['_mat']

    def __getitem__(self, key):
        """
        Returns the item mapped to keys.
        Raises KeyError if *key* is not a known key.
        """
        if self._mat is None:
            return self._values[self._rev_keys[key]]
        return self._mat[self._values, self._rev_keys[key]]

    def get(self, key, default=None):
        """
        Returns the item mapped to *key*, or *default* if *key* is unknown.
        ``dict.get`` cannot be inherited because the underlying
        dictionary is always empty.
        """
        if key not in self._rev_keys:
            return default
        return self[key]

    def __setitem__(self, pos, value):
        """
        Does nothing. pickle calls it for every pair returned by
        :meth:`items` when restoring a dict subclass; the actual
        state is restored by @see me __setstate__.
        """
        pass

    def __len__(self):
        """
        Returns the number of items: the length of *values*, or the
        size of the second dimension of the matrix when one is given.
        """
        return len(self._values) if self._mat is None else self._mat.shape[1]

    def __iter__(self):
        """
        Iterates over the keys.
        """
        for k in self._rev_keys:
            yield k

    def __contains__(self, key):
        return key in self._rev_keys

    def items(self):
        """
        Iterates over pairs *(key, value)* in the order of *rev_keys*.
        """
        if self._mat is None:
            for k, v in self._rev_keys.items():
                yield k, self._values[v]
        else:
            for k, v in self._rev_keys.items():
                yield k, self._mat[self._values, v]

    def keys(self):
        """
        Iterates over the keys.
        """
        for k in self._rev_keys.keys():
            yield k

    def values(self):
        """
        Iterates over the stored values in storage order: the elements
        of *values*, or the elements of ``mat[values]`` (the selected row).
        """
        if self._mat is None:
            for v in self._values:
                yield v
        else:
            for v in self._mat[self._values]:
                yield v

    def asdict(self):
        """
        Returns a regular :class:`dict` copy *{key: value}*.
        """
        return dict(self.items())

    def __str__(self):
        return "ZipMap(%r)" % str(self.asdict())

118 

119 

class ArrayZipMapDictionary(list):
    """
    Mocks an array without changing the data it receives.
    Notebooks :ref:`onnxnodetimerst` illustrates the weaknesses
    and the strengths of this class compare to a list
    of dictionaries.

    Nothing is stored in the underlying :class:`list`; every row is
    exposed as a @see cl ZipMapDictionary viewing the wrapped matrix.

    .. index:: ZipMap
    """

    def __init__(self, rev_keys, mat):
        """
        @param      rev_keys        dictionary *{keys: column index}*
        @param      mat             matrix, two or three dimensions

        @raises     TypeError       if *mat* is not a :epkg:`numpy` array
        @raises     ValueError      if *mat* does not have 2 or 3 dimensions
        """
        if mat is not None:
            if not isinstance(mat, numpy.ndarray):
                raise TypeError(  # pragma: no cover
                    'matrix is expected, got {}.'.format(type(mat)))
            if len(mat.shape) not in (2, 3):
                raise ValueError(  # pragma: no cover
                    "matrix must have two or three dimensions but got {}"
                    ".".format(mat.shape))
        list.__init__(self)
        self._rev_keys = rev_keys
        self._mat = mat

    @property
    def dtype(self):
        """
        Returns the dtype of the wrapped matrix.
        """
        return self._mat.dtype

    def __len__(self):
        """
        Returns the size of the first dimension of the matrix.
        """
        # NOTE(review): for a 3-D matrix, `values` treats shape[1] as the
        # number of rows while this returns shape[0] -- confirm the
        # intended layout before relying on len() for 3-D inputs.
        return self._mat.shape[0]

    def __iter__(self):
        """
        Iterates over the rows, each one wrapped into
        a @see cl ZipMapDictionary.
        """
        # Indices produced by range are always valid, no need for the
        # bounds check done in __getitem__.
        for i in range(len(self)):
            yield ZipMapDictionary(self._rev_keys, i, self._mat)

    def __getitem__(self, i):
        """
        Returns a @see cl ZipMapDictionary viewing row *i*
        without copying the matrix.

        @raises     IndexError      if *i* is an integer outside
                                    ``[-len(self), len(self))``
        """
        # Without this check an out-of-range index returns a proxy which
        # only fails later, when one of its keys is read.
        if isinstance(i, (int, numpy.integer)):
            size = len(self)
            if not -size <= i < size:
                raise IndexError(
                    "Index {} is out of range for {} rows.".format(i, size))
        return ZipMapDictionary(self._rev_keys, i, self._mat)

    def __setitem__(self, pos, value):
        raise RuntimeError(
            "Changing an element is not supported (pos=[{}]).".format(pos))

    @property
    def values(self):
        """
        Equivalent to ``DataFrame(self).values``.
        A 3-D matrix is reshaped into ``(shape[1], -1)``.
        """
        if len(self._mat.shape) == 3:
            return self._mat.reshape((self._mat.shape[1], -1))
        return self._mat

    @property
    def columns(self):
        """
        Equivalent to ``DataFrame(self).columns``.
        Keys are sorted by column index. When no key is known,
        names ``'c0', 'c1', ...`` are generated from the matrix shape.

        @raises     RuntimeError    if the number of columns cannot be guessed
        """
        if len(self._rev_keys) > 0:
            return sorted(self._rev_keys, key=self._rev_keys.__getitem__)
        if len(self._mat.shape) == 2:
            n_cols = self._mat.shape[1]
        elif len(self._mat.shape) == 3:
            # multiclass
            n_cols = self._mat.shape[0] * self._mat.shape[2]
        else:
            raise RuntimeError(  # pragma: no cover
                "Unable to guess the right number of columns for "
                "shapes: {}".format(self._mat.shape))
        return ['c%d' % i for i in range(n_cols)]

    @property
    def is_zip_map(self):
        return True

    def __str__(self):
        return 'ZipMaps[%s]' % ', '.join(map(str, self))

202 

203 

class ZipMap(OpRun):
    """
    Runtime for operator *ZipMap*. It does not produce a dictionary
    as the :epkg:`ONNX` specifications describe; it returns
    a @see cl ArrayZipMapDictionary which wraps the input
    so that the data is never copied.
    """

    atts = {'classlabels_int64s': [], 'classlabels_strings': []}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=ZipMap.atts,
                       **options)
        # Integer labels are looked at first, string labels second;
        # the first non-empty list defines the key -> column mapping.
        for attr_name in ('classlabels_int64s', 'classlabels_strings'):
            labels = getattr(self, attr_name, ())
            if len(labels) > 0:
                self.rev_keys_ = ZipMapDictionary.build_rev_keys(labels)
                break
        else:
            # No label available: keep an empty mapping.
            self.rev_keys_ = {}

    def _run(self, x):  # pylint: disable=W0221
        """
        Wraps *x* into a @see cl ArrayZipMapDictionary without copying it.
        """
        wrapped = ArrayZipMapDictionary(self.rev_keys_, x)
        return (wrapped, )

    def _infer_shapes(self, x):  # pylint: disable=W0221
        """
        Returns a shape keeping the first dimension of *x*
        with dtype ``'map'``.
        """
        first_dim = x[0]
        return (ShapeObject((first_dim, ), dtype='map'), )

    def _infer_types(self, x):  # pylint: disable=W0221
        """
        Returns the output type, always ``'map'``.
        """
        return ('map', )