Skip to content
This repository was archived by the owner on Dec 18, 2023. It is now read-only.

Commit e2e9ca3

Browse files
ericlippertfacebook-github-bot
authored andcommitted
Add cholesky_ex to list of supported functions in compiler (#1570)
Summary: Pull Request resolved: #1570 Xitong's latest version of the GEP model uses a different version of the Cholesky operator that was not yet recognized by the compiler. The `cholesky_ex` pytorch API returns a tuple consisting of the result and a tensor containing the indices of the inputs which were invalid inputs. We have no easy way to replicate this behavior in BMG, so instead we'll have the compiler emulate the pytorch behavior, but assume that there are no errors. It returns a tuple of the graph node and a zero tensor. Reviewed By: gafter Differential Revision: D38017818 fbshipit-source-id: e7a010d1e9f493ddb61c4a6c162841b81507117e
1 parent 8bdfb73 commit e2e9ca3

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

src/beanmachine/ppl/compiler/special_function_caller.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,7 @@ def __init__(self, bmg: BMGraphBuilder) -> None:
444444
torch.bitwise_right_shift: self._torch_bitwise_right_shift,
445445
torch.Tensor.cholesky: self._torch_cholesky,
446446
torch.linalg.cholesky: self._torch_cholesky,
447+
torch.linalg.cholesky_ex: self._torch_cholesky_ex,
447448
torch.Tensor.div: self._torch_div,
448449
torch.div: self._torch_div,
449450
torch.Tensor.divide: self._torch_div,
@@ -930,6 +931,23 @@ def _torch_cholesky(
930931
# TODO: What to do with upper?
931932
return self._bmg.add_cholesky(input)
932933

934+
def _torch_cholesky_ex(
935+
self,
936+
input: BMGNode,
937+
upper: Optional[BMGNode] = None,
938+
check_errors: Optional[BMGNode] = None,
939+
out: Any = None,
940+
) -> BMGNode:
941+
# TODO: What to do with upper and check_errors?
942+
# cholesky_ex returns a named tuple (L, info) where
943+
# L is the result matrix and info is a tensor containing
944+
# an index saying which input element was not
945+
# positive-definite. We pretend that this operation always
946+
# succeeds and return a graph node and a zero error index.
947+
return torch.return_types.linalg_cholesky_ex( # pyre-ignore
948+
(self._bmg.add_cholesky(input), torch.tensor(0))
949+
)
950+
933951
def _torch_transpose(
934952
self,
935953
input: BMGNode,

src/beanmachine/ppl/compiler/tests/cholesky_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,15 @@ def cholesky4():
5656
return t.cholesky()
5757

5858

59+
@bm.functional
60+
def cholesky5():
61+
n0 = norm(0) * norm(0)
62+
n1 = norm(1) * norm(1)
63+
t = tensor([[n0, 0.0], [0.0, n1]])
64+
L, _ = torch.linalg.cholesky_ex(t)
65+
return L
66+
67+
5968
# TODO: Test with a non-square matrix, should give an error.
6069

6170

@@ -101,6 +110,8 @@ def test_cholesky(self) -> None:
101110
self.assertEqual(expected.strip(), observed.strip())
102111
observed = BMGInference().to_dot([cholesky3()], {})
103112
self.assertEqual(expected.strip(), observed.strip())
113+
observed = BMGInference().to_dot([cholesky5()], {})
114+
self.assertEqual(expected.strip(), observed.strip())
104115

105116
expected = """
106117
digraph "graph" {

0 commit comments

Comments
 (0)