Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRunUnary, RuntimeTypeError 

9from ..shape_object import ShapeObject 

10from .op_tfidfvectorizer_ import RuntimeTfIdfVectorizer # pylint: disable=E0611,E0401 

11 

12 

class TfIdfVectorizer(OpRunUnary):
    """
    Runtime for the ONNX operator *TfIdfVectorizer*.

    The n-gram computation is delegated to the compiled
    ``RuntimeTfIdfVectorizer`` which is initialized with integer
    tokens only. When the node defines ``pool_strings``, every string
    of the pool is given its position in the pool as integer id, and
    string inputs are translated into these ids before calling the
    compiled runtime (strings missing from the pool become ``-1``).
    """

    atts = {'max_gram_length': 1,
            'max_skip_count': 1,
            'min_gram_length': 1,
            'mode': b'TF',
            'ngram_counts': [],
            'ngram_indexes': [],
            'pool_int64s': [],
            'pool_strings': [],
            'weights': []}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunUnary.__init__(self, onnx_node, desc=desc,
                            expected_attributes=TfIdfVectorizer.atts,
                            **options)
        self.rt_ = RuntimeTfIdfVectorizer()
        if len(self.pool_strings) != 0:
            # The compiled runtime receives int64 tokens only: each pool
            # string is replaced by its position in the pool.
            pool_int64s = list(range(len(self.pool_strings)))
            pool_strings_ = numpy.array(
                [_.decode('utf-8') for _ in self.pool_strings])
            # string -> index in the pool (last occurrence wins on duplicates)
            mapping = {w: i for i, w in enumerate(pool_strings_)}
        else:
            mapping = None
            pool_int64s = self.pool_int64s
            pool_strings_ = None

        # mapping_ is None when inputs are already integer tokens.
        self.mapping_ = mapping
        self.pool_strings_ = pool_strings_
        self.rt_.init(
            self.max_gram_length, self.max_skip_count, self.min_gram_length,
            self.mode, self.ngram_counts, self.ngram_indexes, pool_int64s,
            self.weights)

    def _run(self, x):  # pylint: disable=W0221
        """
        Computes the n-gram representation of *x*.

        :param x: tokens, a 1-D array ``[C]`` or a 2-D array ``[N, C]``;
            strings when the node defines ``pool_strings``, integers
            otherwise
        :return: tuple with one array, ``[N, ?]`` for a 2-D input and
            a flat array for a 1-D input (the ONNX specification expects
            ``[max(ngram_indexes) + 1]`` in that case)
        """
        if self.mapping_ is None:
            xi = x
        else:
            # Translate every string into its pool index, -1 when the
            # string is unknown. Iterating over the flattened array keeps
            # this valid for 1-D inputs as well as 2-D ones.
            lookup = self.mapping_.get
            xi = numpy.fromiter(
                (lookup(v, -1) for v in x.ravel()),
                dtype=numpy.int64, count=x.size).reshape(x.shape)
        res = self.rt_.compute(xi)
        if len(x.shape) > 1:
            return (res.reshape((x.shape[0], -1)), )
        # 1-D input: reshaping on the number of tokens would be wrong,
        # the output is a single n-gram vector.
        return (res.reshape((-1, )), )

    def _infer_shapes(self, x):  # pylint: disable=E0202,W0221
        """
        Infers the output shape from the input shape *x*:
        ``[N, C]`` gives ``[N, ?]`` and ``[C]`` gives ``[?]``, matching
        what ``_run`` returns. The output is float whatever the input
        type is (ONNX specification of *TfIdfVectorizer*).
        """
        if x.shape is None:
            return (x, )
        if len(x) == 1:
            return (ShapeObject((None, ), dtype=numpy.float32,
                                name=self.__class__.__name__), )
        if len(x) == 2:
            return (ShapeObject((x[0], None), dtype=numpy.float32,
                                name=self.__class__.__name__), )
        raise RuntimeTypeError(
            "Only two dimension are allowed, got {}.".format(x))

    def _infer_types(self, x):  # pylint: disable=E0202,W0221,W0613
        """
        Returns the output type: float, whether the input holds
        strings or integers (ONNX specification of *TfIdfVectorizer*).
        """
        return (numpy.float32, )