123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384 |
- "use strict";
- Object.defineProperty(exports, "__esModule", {
- value: true
- });
- exports.createKldivergence = void 0;
- var _factory = require("../../utils/factory.js");
- var name = 'kldivergence';
- var dependencies = ['typed', 'matrix', 'divide', 'sum', 'multiply', 'map', 'dotDivide', 'log', 'isNumeric'];
- var createKldivergence = /* #__PURE__ */(0, _factory.factory)(name, dependencies, function (_ref) {
- var typed = _ref.typed,
- matrix = _ref.matrix,
- divide = _ref.divide,
- sum = _ref.sum,
- multiply = _ref.multiply,
- map = _ref.map,
- dotDivide = _ref.dotDivide,
- log = _ref.log,
- isNumeric = _ref.isNumeric;
- /**
- * Calculate the Kullback-Leibler (KL) divergence between two distributions
- *
- * Syntax:
- *
- * math.kldivergence(x, y)
- *
- * Examples:
- *
- * math.kldivergence([0.7,0.5,0.4], [0.2,0.9,0.5]) //returns 0.24376698773121153
- *
- *
- * @param {Array | Matrix} q First vector
- * @param {Array | Matrix} p Second vector
- * @return {number} Returns distance between q and p
- */
- return typed(name, {
- 'Array, Array': function ArrayArray(q, p) {
- return _kldiv(matrix(q), matrix(p));
- },
- 'Matrix, Array': function MatrixArray(q, p) {
- return _kldiv(q, matrix(p));
- },
- 'Array, Matrix': function ArrayMatrix(q, p) {
- return _kldiv(matrix(q), p);
- },
- 'Matrix, Matrix': function MatrixMatrix(q, p) {
- return _kldiv(q, p);
- }
- });
- function _kldiv(q, p) {
- var plength = p.size().length;
- var qlength = q.size().length;
- if (plength > 1) {
- throw new Error('first object must be one dimensional');
- }
- if (qlength > 1) {
- throw new Error('second object must be one dimensional');
- }
- if (plength !== qlength) {
- throw new Error('Length of two vectors must be equal');
- }
- // Before calculation, apply normalization
- var sumq = sum(q);
- if (sumq === 0) {
- throw new Error('Sum of elements in first object must be non zero');
- }
- var sump = sum(p);
- if (sump === 0) {
- throw new Error('Sum of elements in second object must be non zero');
- }
- var qnorm = divide(q, sum(q));
- var pnorm = divide(p, sum(p));
- var result = sum(multiply(qnorm, map(dotDivide(qnorm, pnorm), function (x) {
- return log(x);
- })));
- if (isNumeric(result)) {
- return result;
- } else {
- return Number.NaN;
- }
- }
- });
- exports.createKldivergence = createKldivergence;
|