trace.js 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. "use strict";
  2. Object.defineProperty(exports, "__esModule", {
  3. value: true
  4. });
  5. exports.createTrace = void 0;
  6. var _object = require("../../utils/object.js");
  7. var _string = require("../../utils/string.js");
  8. var _factory = require("../../utils/factory.js");
  var name = 'trace';
  var dependencies = ['typed', 'matrix', 'add'];
  var createTrace = /* #__PURE__ */(0, _factory.factory)(name, dependencies, function (deps) {
    var typed = deps.typed;
    var matrix = deps.matrix;
    var add = deps.add;

    /**
     * Calculate the trace of a matrix: the sum of the elements on the main
     * diagonal of a square matrix.
     *
     * Syntax:
     *
     *    math.trace(x)
     *
     * Examples:
     *
     *    math.trace([[1, 2], [3, 4]]) // returns 5
     *
     *    const A = [
     *      [1, 2, 3],
     *      [-1, 2, 3],
     *      [2, 0, 3]
     *    ]
     *    math.trace(A) // returns 6
     *
     * See also:
     *
     *    diag
     *
     * @param {Array | Matrix} x  A square matrix (or a one-element vector)
     *
     * @return {*} The trace of `x`: the diagonal elements summed with `add`
     *             (a number for plain numeric input). Input that is not an
     *             Array or Matrix is returned as a clone.
     * @throws {RangeError} When `x` is not square or not two dimensional.
     */
    return typed('trace', {
      Array: function _arrayTrace(x) {
        // Plain arrays are wrapped in a matrix and handled by the dense path.
        var wrapped = matrix(x);
        return denseTrace(wrapped);
      },
      SparseMatrix: sparseTrace,
      DenseMatrix: denseTrace,
      any: _object.clone
    });

    // Build the RangeError reported whenever the input is not square.
    function notSquareError(dims) {
      return new RangeError('Matrix must be square (size: ' + (0, _string.format)(dims) + ')');
    }

    // Trace of a DenseMatrix, read from its internal `_size` / `_data` fields.
    function denseTrace(m) {
      var dims = m._size;
      var elements = m._data;
      var rank = dims.length;

      if (rank === 1) {
        // A vector only has a trace when it holds exactly one element.
        if (dims[0] !== 1) {
          throw notSquareError(dims);
        }
        return (0, _object.clone)(elements[0]);
      }

      if (rank !== 2) {
        throw new RangeError('Matrix must be two dimensional (size: ' + (0, _string.format)(dims) + ')');
      }

      var n = dims[0];
      if (n !== dims[1]) {
        throw notSquareError(dims);
      }

      // Accumulate through `add` so non-number element types are supported.
      var total = 0;
      for (var d = 0; d < n; d++) {
        total = add(total, elements[d][d]);
      }
      return total;
    }

    // Find the stored entry of column `col` that lies on row `col`.
    // Returns its position in the value/index arrays, or -1 when the column
    // stores no diagonal entry. The scan stops at the first row index past
    // the diagonal, i.e. it relies on row indices within a column ascending.
    function diagonalPosition(rowIdx, colPtr, col) {
      var last = colPtr[col + 1];
      for (var p = colPtr[col]; p < last; p++) {
        var row = rowIdx[p];
        if (row === col) {
          return p;
        }
        if (row > col) {
          return -1;
        }
      }
      return -1;
    }

    // Trace of a SparseMatrix stored column-wise: the entries of column c are
    // `_values[p]` (row `_index[p]`) for `_ptr[c] <= p < _ptr[c + 1]`.
    function sparseTrace(m) {
      var vals = m._values;
      var rowIdx = m._index;
      var colPtr = m._ptr;
      var dims = m._size;
      var n = dims[0];

      if (n !== dims[1]) {
        throw notSquareError(dims);
      }

      var total = 0;
      // When nothing is stored the sum stays 0, so the column walk is skipped.
      if (vals.length > 0) {
        for (var col = 0; col < n; col++) {
          var pos = diagonalPosition(rowIdx, colPtr, col);
          if (pos !== -1) {
            total = add(total, vals[pos]);
          }
        }
      }
      return total;
    }
  });
  exports.createTrace = createTrace;