Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief numpy redundant functions. 

4""" 

5import numpy 

6from scipy.sparse.coo import coo_matrix 

7 

8 

def numpy_dot_inplace(inplaces, a, b):
    """
    Computes the dot product of *a* and *b*, reusing the memory of one
    operand when the inplace information allows it.
    See :epkg:`numpy:dot`.

    @param      inplaces    mapping queried with ``get``; a true value at
                            key 0 (resp. 1) requests the in-place variant
                            working on *a* (resp. *b*)
    @param      a           left operand
    @param      b           right operand
    @return                 result of the product

    An operand is only handed to an in-place variant if it exposes a
    ``flags`` attribute. The left operand is considered before the
    right one.
    """
    left_requested = inplaces.get(0, False) and hasattr(a, 'flags')
    if left_requested:
        return _numpy_dot_inplace_left(a, b)

    right_requested = inplaces.get(1, False) and hasattr(b, 'flags')
    if right_requested:
        return _numpy_dot_inplace_right(a, b)

    # No usable inplace request: regular dot product.
    return numpy.dot(a, b)

19 

20 

21def _numpy_dot_inplace_left(a, b): 

22 "Subpart of @see fn numpy_dot_inplace." 

23 if a.flags['F_CONTIGUOUS']: 

24 if len(b.shape) == len(a.shape) == 2 and b.shape[1] <= a.shape[1]: 

25 try: 

26 numpy.dot(a, b, out=a[:, :b.shape[1]]) 

27 return a[:, :b.shape[1]] 

28 except ValueError: 

29 return numpy.dot(a, b) 

30 if len(b.shape) == 1: 

31 try: 

32 numpy.dot(a, b.reshape(b.shape[0], 1), out=a[:, :1]) 

33 return a[:, :1].reshape(a.shape[0]) 

34 except ValueError: # pragma no cover 

35 return numpy.dot(a, b) 

36 return numpy.dot(a, b) 

37 

38 

39def _numpy_dot_inplace_right(a, b): 

40 "Subpart of @see fn numpy_dot_inplace." 

41 if b.flags['C_CONTIGUOUS']: 

42 if len(b.shape) == len(a.shape) == 2 and a.shape[0] <= b.shape[0]: 

43 try: 

44 numpy.dot(a, b, out=b[:a.shape[0], :]) 

45 return b[:a.shape[0], :] 

46 except ValueError: # pragma no cover 

47 return numpy.dot(a, b) 

48 if len(a.shape) == 1: 

49 try: 

50 numpy.dot(a, b, out=b[:1, :]) 

51 return b[:1, :] 

52 except ValueError: # pragma no cover 

53 return numpy.dot(a, b) 

54 return numpy.dot(a, b) 

55 

56 

def numpy_matmul_inplace(inplaces, a, b):
    """
    Computes the matrix product of *a* and *b*, taking inplace
    information into account.
    See :epkg:`numpy:matmul`.

    @param      inplaces    inplace information, forwarded to
                            @see fn numpy_dot_inplace
    @param      a           left operand
    @param      b           right operand
    @return                 result of the product

    If one operand is a ``coo_matrix``, the product goes through
    :epkg:`numpy:dot`. If both operands have at most two dimensions,
    @see fn numpy_dot_inplace does the job. Otherwise the function
    calls :epkg:`numpy:matmul`. Any ``ValueError`` is raised again
    with a message showing both shapes.

    Inplace computation does not work well as modifying one of the
    containers modifies the results. This part still needs to be
    improved.
    """
    try:
        if any(isinstance(operand, coo_matrix) for operand in (a, b)):
            return numpy.dot(a, b)
        if len(a.shape) > 2 or len(b.shape) > 2:
            # More than two dimensions: only matmul applies.
            return numpy.matmul(a, b)
        return numpy_dot_inplace(inplaces, a, b)
    except ValueError as exc:
        shapes = (a.shape, b.shape)
        raise ValueError(
            "Unable to multiply shapes %r, %r." % shapes) from exc