dot.js 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. "use strict";
  2. Object.defineProperty(exports, "__esModule", {
  3. value: true
  4. });
  5. exports.createDot = void 0;
  6. var _factory = require("../../utils/factory.js");
  7. var _is = require("../../utils/is.js");
  8. var name = 'dot';
  9. var dependencies = ['typed', 'addScalar', 'multiplyScalar', 'conj', 'size'];
  10. var createDot = /* #__PURE__ */(0, _factory.factory)(name, dependencies, function (_ref) {
  11. var typed = _ref.typed,
  12. addScalar = _ref.addScalar,
  13. multiplyScalar = _ref.multiplyScalar,
  14. conj = _ref.conj,
  15. size = _ref.size;
  16. /**
  17. * Calculate the dot product of two vectors. The dot product of
  18. * `A = [a1, a2, ..., an]` and `B = [b1, b2, ..., bn]` is defined as:
  19. *
  20. * dot(A, B) = conj(a1) * b1 + conj(a2) * b2 + ... + conj(an) * bn
  21. *
  22. * Syntax:
  23. *
  24. * math.dot(x, y)
  25. *
  26. * Examples:
  27. *
  28. * math.dot([2, 4, 1], [2, 2, 3]) // returns number 15
  29. * math.multiply([2, 4, 1], [2, 2, 3]) // returns number 15
  30. *
  31. * See also:
  32. *
  33. * multiply, cross
  34. *
  35. * @param {Array | Matrix} x First vector
  36. * @param {Array | Matrix} y Second vector
  37. * @return {number} Returns the dot product of `x` and `y`
  38. */
  39. return typed(name, {
  40. 'Array | DenseMatrix, Array | DenseMatrix': _denseDot,
  41. 'SparseMatrix, SparseMatrix': _sparseDot
  42. });
  /**
   * Check that `x` and `y` are vectors of the same, non-zero length.
   *
   * A vector is anything whose size is `[n]` (1-dimensional) or `[n, 1]`
   * (a column vector).
   *
   * @param {Array | Matrix} x
   * @param {Array | Matrix} y
   * @return {number} The common length n of both vectors
   * @throws {RangeError} On a non-vector operand (x is reported first),
   *                      unequal lengths, or empty vectors
   */
  function _validateDim(x, y) {
    // Length of a vector given its size array; rejects anything else.
    function columnLength(dims) {
      if (dims.length === 1 || (dims.length === 2 && dims[1] === 1)) {
        return dims[0];
      }
      throw new RangeError('Expected a column vector, instead got a matrix of size (' + dims.join(', ') + ')');
    }

    // Both sizes are read up front, before either is validated.
    var dimsX = _size(x);
    var dimsY = _size(y);

    var lenX = columnLength(dimsX);
    var lenY = columnLength(dimsY);

    if (lenX !== lenY) {
      throw new RangeError('Vectors must have equal length (' + lenX + ' != ' + lenY + ')');
    }
    if (lenX === 0) {
      throw new RangeError('Cannot calculate the dot product of empty vectors');
    }
    return lenX;
  }
  /**
   * Dot product of two dense vectors (Arrays or DenseMatrices).
   *
   * Each operand may be a 1-dimensional vector or an n x 1 column vector;
   * _validateDim checks this and returns the shared length N, rejecting
   * N === 0. When both operands are matrices carrying the same string
   * datatype, the scalar add/multiply implementations for that datatype
   * are looked up once with typed.find and reused for every element.
   *
   * @param {Array | DenseMatrix} a  First vector (its entries are conjugated)
   * @param {Array | DenseMatrix} b  Second vector
   * @return {*} conj(a_0) * b_0 + ... + conj(a_(N-1)) * b_(N-1)
   */
  function _denseDot(a, b) {
    // Returns a function i -> i-th entry of `data`. A 2-dimensional
    // (n x 1) column vector keeps entry i at data[i][0]; a 1-dimensional
    // vector keeps it at data[i].
    function _entryAccessor(data, isColumn) {
      if (isColumn) {
        return function (i) {
          return data[i][0];
        };
      }
      return function (i) {
        return data[i];
      };
    }

    var N = _validateDim(a, b);
    var adata = (0, _is.isMatrix)(a) ? a._data : a;
    var adt = (0, _is.isMatrix)(a) ? a._datatype : undefined;
    var bdata = (0, _is.isMatrix)(b) ? b._data : b;
    var bdt = (0, _is.isMatrix)(b) ? b._datatype : undefined;

    // Choosing the accessor once per operand replaces four copy-pasted
    // loops (1-D/column x 1-D/column) with the single loop below.
    var aAt = _entryAccessor(adata, _size(a).length === 2);
    var bAt = _entryAccessor(bdata, _size(b).length === 2);

    var add = addScalar;
    var mul = multiplyScalar;
    if (adt && bdt && adt === bdt && typeof adt === 'string') {
      // find signatures that match (dt, dt)
      add = typed.find(addScalar, [adt, adt]);
      mul = typed.find(multiplyScalar, [adt, adt]);
    }

    // N >= 1 is guaranteed by _validateDim, so the sum is seeded with the
    // first term rather than a numeric 0 that may not match the element type.
    var c = mul(conj(aAt(0)), bAt(0));
    for (var i = 1; i < N; i++) {
      c = add(c, mul(conj(aAt(i)), bAt(i)));
    }
    return c;
  }
  /**
   * Dot product of two SparseMatrix vectors.
   *
   * Walks the stored entries of both operands in one merge pass over their
   * `_index` arrays and accumulates conj(x_k) * y_k for every index k that
   * is stored in both; an index stored in only one operand contributes
   * nothing. `x` is conjugated exactly as in the dense implementation, so
   * complex inputs give the same result whether stored dense or sparse
   * (previously the conjugation was missing here).
   *
   * NOTE(review): the merge assumes each `_index` is sorted ascending (and,
   * for n x 1 matrices, presumably lists row indices) — confirm this is a
   * SparseMatrix invariant.
   *
   * @param {SparseMatrix} x  First vector (its entries are conjugated)
   * @param {SparseMatrix} y  Second vector
   * @return {*} The dot product; the number 0 when no index is stored in both
   */
  function _sparseDot(x, y) {
    _validateDim(x, y);
    var xindex = x._index;
    var xvalues = x._values;
    var yindex = y._index;
    var yvalues = y._values;
    // TODO optimize add & mul using datatype.
    // NOTE(review): the accumulator is seeded with the number 0, so a
    // typed.find(addScalar, [dt, dt]) specialisation for a non-number
    // datatype presumably would not accept it — change the seed as well.
    var c = 0;
    var add = addScalar;
    var mul = multiplyScalar;
    var i = 0;
    var j = 0;
    while (i < xindex.length && j < yindex.length) {
      var I = xindex[i];
      var J = yindex[j];
      if (I < J) {
        i++;
      } else if (I > J) {
        j++;
      } else {
        // Plain `else` (not `I === J`) so that every iteration advances,
        // even for indices that do not compare (e.g. NaN).
        c = add(c, mul(conj(xvalues[i]), yvalues[j]));
        i++;
        j++;
      }
    }
    return c;
  }
  // TODO remove this once #1771 is fixed
  /**
   * Size of a vector/matrix argument.
   *
   * Matrix instances report their own size via `x.size()`; anything else
   * (e.g. a plain Array) is measured with the injected `size` function.
   *
   * @param {Array | Matrix} x
   * @return {*} Whatever `x.size()` or `size(x)` returns
   */
  function _size(x) {
    if (!(0, _is.isMatrix)(x)) {
      return size(x);
    }
    return x.size();
  }
  151. });
  152. exports.createDot = createDot;