From d5127c0728578085f1d085608698daa2aeb74909 Mon Sep 17 00:00:00 2001
From: Oliver Schacht
Date: Fri, 7 Mar 2025 17:23:38 +0100
Subject: [PATCH] fix: add explicit normalization

---
 causallearn/utils/FastKCI/FastKCI.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/causallearn/utils/FastKCI/FastKCI.py b/causallearn/utils/FastKCI/FastKCI.py
index f5fcc50..5ab54d7 100644
--- a/causallearn/utils/FastKCI/FastKCI.py
+++ b/causallearn/utils/FastKCI/FastKCI.py
@@ -98,7 +98,11 @@ def partition_data(self):
         ll = np.tile(np.log(pi_j), (self.n, 1))
         for k in range(self.K):
             ll[:, k] += stats.multivariate_normal.logpdf(self.data_z, mu_k[k, :], cov=sigma_k, allow_singular=True)
-        Z = np.array([np.random.multinomial(1, np.exp(ll[n, :]-logsumexp(ll[n, :]))).argmax() for n in range(self.n)])
+
+        ll = np.exp(ll - logsumexp(ll, axis=1, keepdims=True))
+        ll = ll / ll.sum(axis=1, keepdims=True)
+
+        Z = np.array([np.random.multinomial(1, ll[n, :]).argmax() for n in range(self.n)])
         le = LabelEncoder()
         Z = le.fit_transform(Z)
         return Z
@@ -414,7 +418,11 @@ def partition_data(self):
         ll = np.tile(np.log(pi_j), (self.n, 1))
         for k in range(self.K):
             ll[:, k] += stats.multivariate_normal.logpdf(self.data_y, mu_k[k, :], cov=sigma_k, allow_singular=True)
-        Z = np.array([np.random.multinomial(1, np.exp(ll[n, :]-logsumexp(ll[n, :]))).argmax() for n in range(self.n)])
+
+        ll = np.exp(ll - logsumexp(ll, axis=1, keepdims=True))
+        ll = ll / ll.sum(axis=1, keepdims=True)
+
+        Z = np.array([np.random.multinomial(1, ll[n, :]).argmax() for n in range(self.n)])
         prop_Y = np.take_along_axis(ll, Z[:, None], axis=1).sum()
         le = LabelEncoder()
         Z = le.fit_transform(Z)