kldivergence.js 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. "use strict";
  2. Object.defineProperty(exports, "__esModule", {
  3. value: true
  4. });
  5. exports.createKldivergence = void 0;
  6. var _factory = require("../../utils/factory.js");
  // Name under which the factory below registers the function.
  var name = 'kldivergence';

  // Names of the functions the factory callback unpacks from its argument.
  var dependencies = [
    'typed',
    'matrix',
    'divide',
    'sum',
    'multiply',
    'map',
    'dotDivide',
    'log',
    'isNumeric'
  ];
  9. var createKldivergence = /* #__PURE__ */(0, _factory.factory)(name, dependencies, function (_ref) {
  10. var typed = _ref.typed,
  11. matrix = _ref.matrix,
  12. divide = _ref.divide,
  13. sum = _ref.sum,
  14. multiply = _ref.multiply,
  15. map = _ref.map,
  16. dotDivide = _ref.dotDivide,
  17. log = _ref.log,
  18. isNumeric = _ref.isNumeric;
  19. /**
  20. * Calculate the Kullback-Leibler (KL) divergence between two distributions
  21. *
  22. * Syntax:
  23. *
  24. * math.kldivergence(x, y)
  25. *
  26. * Examples:
  27. *
  28. * math.kldivergence([0.7,0.5,0.4], [0.2,0.9,0.5]) //returns 0.24376698773121153
  29. *
  30. *
  31. * @param {Array | Matrix} q First vector
  32. * @param {Array | Matrix} p Second vector
  33. * @return {number} Returns distance between q and p
  34. */
  35. return typed(name, {
  36. 'Array, Array': function ArrayArray(q, p) {
  37. return _kldiv(matrix(q), matrix(p));
  38. },
  39. 'Matrix, Array': function MatrixArray(q, p) {
  40. return _kldiv(q, matrix(p));
  41. },
  42. 'Array, Matrix': function ArrayMatrix(q, p) {
  43. return _kldiv(matrix(q), p);
  44. },
  45. 'Matrix, Matrix': function MatrixMatrix(q, p) {
  46. return _kldiv(q, p);
  47. }
  48. });
  49. function _kldiv(q, p) {
  50. var plength = p.size().length;
  51. var qlength = q.size().length;
  52. if (plength > 1) {
  53. throw new Error('first object must be one dimensional');
  54. }
  55. if (qlength > 1) {
  56. throw new Error('second object must be one dimensional');
  57. }
  58. if (plength !== qlength) {
  59. throw new Error('Length of two vectors must be equal');
  60. }
  61. // Before calculation, apply normalization
  62. var sumq = sum(q);
  63. if (sumq === 0) {
  64. throw new Error('Sum of elements in first object must be non zero');
  65. }
  66. var sump = sum(p);
  67. if (sump === 0) {
  68. throw new Error('Sum of elements in second object must be non zero');
  69. }
  70. var qnorm = divide(q, sum(q));
  71. var pnorm = divide(p, sum(p));
  72. var result = sum(multiply(qnorm, map(dotDivide(qnorm, pnorm), function (x) {
  73. return log(x);
  74. })));
  75. if (isNumeric(result)) {
  76. return result;
  77. } else {
  78. return Number.NaN;
  79. }
  80. }
  81. });
  82. exports.createKldivergence = createKldivergence;