dot.js 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. import { factory } from '../../utils/factory.js';
  2. import { isMatrix } from '../../utils/is.js';
  const name = 'dot';
  const dependencies = ['typed', 'addScalar', 'multiplyScalar', 'conj', 'size'];

  export const createDot = /* #__PURE__ */ factory(name, dependencies, ({
    typed,
    addScalar,
    multiplyScalar,
    conj,
    size
  }) => {
    /**
     * Calculate the dot product of two vectors. The dot product of
     * `A = [a1, a2, ..., an]` and `B = [b1, b2, ..., bn]` is defined as:
     *
     *    dot(A, B) = conj(a1) * b1 + conj(a2) * b2 + ... + conj(an) * bn
     *
     * The first argument is conjugated for dense and sparse input alike, so
     * complex vectors give the same result whichever storage they use.
     *
     * Syntax:
     *
     *    math.dot(x, y)
     *
     * Examples:
     *
     *    math.dot([2, 4, 1], [2, 2, 3])       // returns number 15
     *    math.multiply([2, 4, 1], [2, 2, 3])  // returns number 15
     *
     * See also:
     *
     *    multiply, cross
     *
     * Each argument may be a 1-dimensional vector of size [n] or a column
     * vector of size [n, 1]. Both arguments must be dense (Array/DenseMatrix),
     * or both must be SparseMatrix.
     *
     * @param  {Array | Matrix} x     First vector
     * @param  {Array | Matrix} y     Second vector
     * @return {*}                    The dot product of `x` and `y`. Its type is
     *                                whatever `addScalar`/`multiplyScalar` produce
     *                                for the element types (e.g. a number for
     *                                numeric input).
     * @throws {RangeError}           When an argument is not a vector, when the
     *                                lengths differ, or when the vectors are empty.
     */
    return typed(name, {
      'Array | DenseMatrix, Array | DenseMatrix': _denseDot,
      'SparseMatrix, SparseMatrix': _sparseDot
    });

    /**
     * Length of a vector, given its size.
     * Accepts a 1-dimensional size `[n]` or a column-vector size `[n, 1]`.
     * @param {number[]} vecSize
     * @return {number}
     * @throws {RangeError} if the size describes anything other than a vector
     */
    function _vectorLength(vecSize) {
      const isVector = vecSize.length === 1 || (vecSize.length === 2 && vecSize[1] === 1);
      if (!isVector) {
        throw new RangeError(`Expected a column vector, instead got a matrix of size (${vecSize.join(', ')})`);
      }
      return vecSize[0];
    }

    /**
     * Validate that two sizes describe non-empty vectors of equal length.
     * @param {number[]} xSize
     * @param {number[]} ySize
     * @return {number} the common vector length
     * @throws {RangeError} on non-vector sizes, unequal lengths, or length 0
     */
    function _validateDim(xSize, ySize) {
      const xLen = _vectorLength(xSize);
      const yLen = _vectorLength(ySize);
      if (xLen !== yLen) {
        throw new RangeError(`Vectors must have equal length (${xLen} != ${yLen})`);
      }
      if (xLen === 0) {
        throw new RangeError('Cannot calculate the dot product of empty vectors');
      }
      return xLen;
    }

    /**
     * Dot product for Array / DenseMatrix operands.
     * Handles every combination of 1-dimensional and column-vector input.
     */
    function _denseDot(a, b) {
      const aSize = _size(a);
      const bSize = _size(b);
      const N = _validateDim(aSize, bSize);

      const adata = isMatrix(a) ? a._data : a;
      const adt = isMatrix(a) ? a._datatype : undefined;
      const bdata = isMatrix(b) ? b._data : b;
      const bdt = isMatrix(b) ? b._datatype : undefined;

      // A 2-dimensional size means an (n x 1) column vector: every entry sits
      // in its own single-element row, so element i is data[i][0].
      const aAt = aSize.length === 2 ? (i) => adata[i][0] : (i) => adata[i];
      const bAt = bSize.length === 2 ? (i) => bdata[i][0] : (i) => bdata[i];

      // When both matrices declare the same datatype, resolve the specialised
      // (dt, dt) signatures once instead of dispatching on every element.
      let add = addScalar;
      let mul = multiplyScalar;
      if (adt && adt === bdt && typeof adt === 'string') {
        add = typed.find(addScalar, [adt, adt]);
        mul = typed.find(multiplyScalar, [adt, adt]);
      }

      // Seed with the first term (N >= 1 is guaranteed by _validateDim), so
      // the result keeps the element type instead of starting from number 0.
      let c = mul(conj(aAt(0)), bAt(0));
      for (let i = 1; i < N; i++) {
        c = add(c, mul(conj(aAt(i)), bAt(i)));
      }
      return c;
    }

    /**
     * Dot product for two SparseMatrix column vectors.
     * Walks both stored index lists in a single merged pass. Only positions
     * stored in both vectors contribute. The walk requires each `_index`
     * array to be sorted ascending.
     */
    function _sparseDot(x, y) {
      _validateDim(_size(x), _size(y));

      const xindex = x._index;
      const xvalues = x._values;
      const yindex = y._index;
      const yvalues = y._values;

      // TODO optimize add & mul using datatype
      let c = 0;
      let i = 0;
      let j = 0;
      while (i < xindex.length && j < yindex.length) {
        const I = xindex[i];
        const J = yindex[j];
        if (I < J) {
          i++;
        } else if (I > J) {
          j++;
        } else {
          // conj() on the first operand keeps this consistent with _denseDot
          // and with the documented definition.
          c = addScalar(c, multiplyScalar(conj(xvalues[i]), yvalues[j]));
          i++;
          j++;
        }
      }
      return c;
    }

    // TODO remove this once #1771 is fixed
    function _size(x) {
      return isMatrix(x) ? x.size() : size(x);
    }
  });