kldivergence.js 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import { factory } from '../../utils/factory.js';
  // Function name handed to factory() and typed(), plus the list of
  // dependencies that factory() injects into the creator callback below.
  const name = 'kldivergence';
  const dependencies = [
    'typed',
    'matrix',
    'divide',
    'sum',
    'multiply',
    'map',
    'dotDivide',
    'log',
    'isNumeric'
  ];
  export var createKldivergence = /* #__PURE__ */factory(name, dependencies, _ref => {
    const {
      typed,
      matrix,
      divide,
      sum,
      multiply,
      map,
      dotDivide,
      log,
      isNumeric
    } = _ref;

    /**
     * Calculate the Kullback-Leibler (KL) divergence of distribution q from
     * distribution p:  D_KL(q || p) = sum_i q_i * log(q_i / p_i).
     *
     * Both inputs are normalized (divided by the sum of their elements) before
     * the divergence is computed, so unnormalized weights may be passed.
     * KL divergence is not symmetric: kldivergence(q, p) and kldivergence(p, q)
     * generally differ, so it is not a distance in the metric sense.
     *
     * Syntax:
     *
     *     math.kldivergence(q, p)
     *
     * Examples:
     *
     *     math.kldivergence([0.7,0.5,0.4], [0.2,0.9,0.5])   // returns ~0.308614578106259
     *     math.kldivergence([0.2,0.9,0.5], [0.7,0.5,0.4])   // returns ~0.24376698773121153
     *
     * @param {Array | Matrix} q First vector (the distribution being compared)
     * @param {Array | Matrix} p Second vector (the reference distribution)
     * @return {number} The divergence D_KL(q || p), or NaN when the computed
     *                  result is not numeric
     * @throws {Error} When either input is not one dimensional, when the two
     *                 vectors differ in length, or when either input sums to zero
     */
    return typed(name, {
      'Array, Array': function ArrayArray(q, p) {
        return _kldiv(matrix(q), matrix(p));
      },
      'Matrix, Array': function MatrixArray(q, p) {
        return _kldiv(q, matrix(p));
      },
      'Array, Matrix': function ArrayMatrix(q, p) {
        return _kldiv(matrix(q), p);
      },
      'Matrix, Matrix': function MatrixMatrix(q, p) {
        return _kldiv(q, p);
      }
    });

    /**
     * Shared implementation behind every typed signature.
     *
     * @param {Matrix} q First distribution, already converted to a Matrix
     * @param {Matrix} p Second distribution, already converted to a Matrix
     * @return {number} D_KL(q || p), or NaN when the result is not numeric
     */
    function _kldiv(q, p) {
      const qsize = q.size();
      const psize = p.size();

      // Validate q before p so each message names the argument that failed.
      if (qsize.length > 1) {
        throw new Error('first object must be one dimensional');
      }
      if (psize.length > 1) {
        throw new Error('second object must be one dimensional');
      }
      // Compare element counts. Comparing size().length would only compare the
      // number of dimensions, which is already known to be 1 for both here.
      if (qsize[0] !== psize[0]) {
        throw new Error('Length of two vectors must be equal');
      }

      // Normalize both inputs into probability distributions.
      const sumq = sum(q);
      if (sumq === 0) {
        throw new Error('Sum of elements in first object must be non zero');
      }
      const sump = sum(p);
      if (sump === 0) {
        throw new Error('Sum of elements in second object must be non zero');
      }
      const qnorm = divide(q, sumq);
      const pnorm = divide(p, sump);

      // In mathjs, multiply() of two vectors is their dot product, giving
      // sum_i qnorm_i * log(qnorm_i / pnorm_i); sum() passes that scalar through.
      const result = sum(multiply(qnorm, map(dotDivide(qnorm, pnorm), x => log(x))));

      // A non-numeric result (e.g. a complex value from log) is reported as NaN.
      return isNumeric(result) ? result : Number.NaN;
    }
  });