Skip to content

Commit 8fcc607

Browse files
ltiaofacebook-github-bot
authored andcommitted
Fix batch computation in Pivoted Cholesky
Summary: ## Context TODO: ## Changes Updates a line containing indexing logic that breaks when `len(batch_shape) > 1` Differential Revision: D72906531
1 parent cd36e52 commit 8fcc607

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

botorch/utils/probability/linalg.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def update_(self, eps: float = 1e-10) -> None:
125125
rank1 = L[..., i + 1 :, i : i + 1].clone()
126126
rank1 = (rank1 * rank1.transpose(-1, -2)).tril()
127127
L[..., i + 1 :, i + 1 :] = L[..., i + 1 :, i + 1 :].clone() - rank1
128-
L[Lii <= i * eps, i:, i] = 0 # numerical stability clause
128+
L[..., i:, i][Lii <= i * eps] = 0 # numerical stability clause
129129
self.step += 1
130130

131131
def pivot_(self, pivot: LongTensor) -> None:

0 commit comments

Comments
 (0)