trace.js 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. import { clone } from '../../utils/object.js';
  2. import { format } from '../../utils/string.js';
  3. import { factory } from '../../utils/factory.js';
  4. var name = 'trace';
  5. var dependencies = ['typed', 'matrix', 'add'];
  6. export var createTrace = /* #__PURE__ */factory(name, dependencies, _ref => {
  7. var {
  8. typed,
  9. matrix,
  10. add
  11. } = _ref;
  12. /**
  13. * Calculate the trace of a matrix: the sum of the elements on the main
  14. * diagonal of a square matrix.
  15. *
  16. * Syntax:
  17. *
  18. * math.trace(x)
  19. *
  20. * Examples:
  21. *
  22. * math.trace([[1, 2], [3, 4]]) // returns 5
  23. *
  24. * const A = [
  25. * [1, 2, 3],
  26. * [-1, 2, 3],
  27. * [2, 0, 3]
  28. * ]
  29. * math.trace(A) // returns 6
  30. *
  31. * See also:
  32. *
  33. * diag
  34. *
  35. * @param {Array | Matrix} x A matrix
  36. *
  37. * @return {number} The trace of `x`
  38. */
  39. return typed('trace', {
  40. Array: function _arrayTrace(x) {
  41. // use dense matrix implementation
  42. return _denseTrace(matrix(x));
  43. },
  44. SparseMatrix: _sparseTrace,
  45. DenseMatrix: _denseTrace,
  46. any: clone
  47. });
  48. function _denseTrace(m) {
  49. // matrix size & data
  50. var size = m._size;
  51. var data = m._data;
  52. // process dimensions
  53. switch (size.length) {
  54. case 1:
  55. // vector
  56. if (size[0] === 1) {
  57. // return data[0]
  58. return clone(data[0]);
  59. }
  60. throw new RangeError('Matrix must be square (size: ' + format(size) + ')');
  61. case 2:
  62. {
  63. // two dimensional
  64. var rows = size[0];
  65. var cols = size[1];
  66. if (rows === cols) {
  67. // calulate sum
  68. var sum = 0;
  69. // loop diagonal
  70. for (var i = 0; i < rows; i++) {
  71. sum = add(sum, data[i][i]);
  72. }
  73. // return trace
  74. return sum;
  75. } else {
  76. throw new RangeError('Matrix must be square (size: ' + format(size) + ')');
  77. }
  78. }
  79. default:
  80. // multi dimensional
  81. throw new RangeError('Matrix must be two dimensional (size: ' + format(size) + ')');
  82. }
  83. }
  84. function _sparseTrace(m) {
  85. // matrix arrays
  86. var values = m._values;
  87. var index = m._index;
  88. var ptr = m._ptr;
  89. var size = m._size;
  90. // check dimensions
  91. var rows = size[0];
  92. var columns = size[1];
  93. // matrix must be square
  94. if (rows === columns) {
  95. // calulate sum
  96. var sum = 0;
  97. // check we have data (avoid looping columns)
  98. if (values.length > 0) {
  99. // loop columns
  100. for (var j = 0; j < columns; j++) {
  101. // k0 <= k < k1 where k0 = _ptr[j] && k1 = _ptr[j+1]
  102. var k0 = ptr[j];
  103. var k1 = ptr[j + 1];
  104. // loop k within [k0, k1[
  105. for (var k = k0; k < k1; k++) {
  106. // row index
  107. var i = index[k];
  108. // check row
  109. if (i === j) {
  110. // accumulate value
  111. sum = add(sum, values[k]);
  112. // exit loop
  113. break;
  114. }
  115. if (i > j) {
  116. // exit loop, no value on the diagonal for column j
  117. break;
  118. }
  119. }
  120. }
  121. }
  122. // return trace
  123. return sum;
  124. }
  125. throw new RangeError('Matrix must be square (size: ' + format(size) + ')');
  126. }
  127. });