Added rewrite for matrix inv(inv(x)) -> x #893
Great start! Biggest missing piece is the solve-based inverses. Let's use this opportunity to write a helper that detects those, because all inverse rewrites will need it.
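As a rough illustration of what such a helper might look like, here is a minimal sketch on a toy expression graph. It assumes "solve-based inverse" means a graph of the form solve(A, I), which is mathematically the same as inv(A). The classes `Var`, `Eye`, and `Apply` and the function `inverse_operand` are hypothetical stand-ins invented for this example, not PyTensor API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    """Toy leaf variable (hypothetical, not a PyTensor class)."""
    name: str


@dataclass(frozen=True)
class Eye:
    """Toy marker for an identity matrix."""


@dataclass(frozen=True)
class Apply:
    """Toy node: an op name applied to a tuple of inputs."""
    op: str
    args: tuple


def inverse_operand(node):
    """Return A if `node` computes an inverse of A, otherwise None.

    Recognizes inv(A), pinv(A), and the solve-based form solve(A, I).
    """
    if not isinstance(node, Apply):
        return None
    if node.op in ("inv", "pinv") and len(node.args) == 1:
        return node.args[0]
    if node.op == "solve" and len(node.args) == 2 and isinstance(node.args[1], Eye):
        return node.args[0]
    return None


a = Var("a")
b = Var("b")
found_direct = inverse_operand(Apply("inv", (a,)))
found_solve = inverse_operand(Apply("solve", (a, Eye())))
not_inverse = inverse_operand(Apply("solve", (a, b)))
```

With a single helper like this, every inverse-related rewrite could ask "is this node an inverse, and of what?" in one place instead of repeating the checks.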
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files

@@            Coverage Diff            @@
##             main     #893    +/-    ##
=========================================
  Coverage   81.49%   81.50%
=========================================
  Files         176      176
  Lines       46925    46938    +13
  Branches    11428    11435     +7
=========================================
+ Hits        38242    38255    +13
- Misses       6498     6500     +2
+ Partials     2185     2183     -2
lgtm
@ricardoV94 let me know if there's anything else to be done here. Otherwise, it should be good to merge.
7baab4a to 1d62dd3
@node_rewriter([Blockwise])
def rewrite_inv_inv(fgraph, node):
    """
    This rewrite takes advantage of the fact that two consecutive inverse operations (inv(inv(input))) give back the original input, so no inverse has to be computed at all.
    Here, we check for direct inverse operations (inv/pinv) and allow any combination of these "inverse" nodes to be rewritten.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    valid_inverses = (MatrixInverse, MatrixPinv)
    # Check if it's a valid inverse operation (either inv/pinv)
    # If the outer operation is an inverse, we go on to the next step of finding the inner operation
    # If the outer operation is not a valid inverse, we do not apply this rewrite
    if not isinstance(node.op.core_op, valid_inverses):
        return None
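The hunk above only shows the check on the outer operation. To make the full idea concrete, here is a minimal sketch of the two-step check (outer inverse, then inner inverse) on a toy expression graph. `Var`, `Apply`, and `VALID_INVERSES` are hypothetical names made up for this example; the real rewrite works on PyTensor nodes and is not reproduced here.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    """Toy leaf variable (hypothetical, not a PyTensor class)."""
    name: str


@dataclass(frozen=True)
class Apply:
    """Toy node: an op name applied to a tuple of inputs."""
    op: str
    args: tuple


# Op names treated as "inverse-like" in this toy model
VALID_INVERSES = ("inv", "pinv")


def rewrite_inv_inv(node):
    """Return [x] if `node` is inverse(inverse(x)), otherwise None."""
    # Step 1: the outer operation must be an inverse
    if not (isinstance(node, Apply) and node.op in VALID_INVERSES):
        return None
    # Step 2: its input must itself be produced by an inverse
    inner = node.args[0]
    if not (isinstance(inner, Apply) and inner.op in VALID_INVERSES):
        return None
    # Both inverses cancel, so the innermost input replaces the whole expression
    return [inner.args[0]]


x = Var("x")
double_inverse = rewrite_inv_inv(Apply("inv", (Apply("pinv", (x,)),)))
single_inverse = rewrite_inv_inv(Apply("inv", (x,)))
```

Returning a list on success and None otherwise mirrors the return convention described in the docstring above.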
If we predefine the two possible pinv Ops as Blockwise (like we do for matrix_inverse), the rewrite can track them more specifically and avoid being called for every Blockwise it sees:
@node_rewriter([matrix_inverse, matrix_pinv_hermitian, matrix_pinv_non_hermitian])
def rewrite_inv_inv(fgraph, node):
    """
    This rewrite takes advantage of the fact that two consecutive inverse operations (inv(inv(input))) give back the original input, so no inverse has to be computed at all.
    Here, we check for direct inverse operations (inv/pinv) and allow any combination of these "inverse" nodes to be rewritten.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
We need to predefine those matrix_pinv* Ops. The helper "pinv" should return the predefined Ops instead of creating new ones, to avoid Op duplication.
I don't quite understand the last line you wrote:
> We need to predefine those matrix_pinv* Ops. The helper "pinv" should return the predefined Ops instead of creating new ones, to avoid Op duplication.
Also, as a general rule for rewrites, is it better for them to track more specific Ops instead of just Blockwise?
> Also, as a general rule for rewrites, is it better for them to track more specific Ops instead of just Blockwise?
Is that a question? The answer is yes. It avoids useless calls to the rewrite function when the Op is not in the graph.
But to check the inner Ops, I will still have to use the approach I'm already using.
Regarding pinv, the helper function is here:
pytensor/pytensor/tensor/nlinalg.py
Line 68 in ad27dc7
def pinv(x, hermitian=False):
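The idea is instead:
# Create the two Blockwise Ops once, at import time
pinv_hermitian = Blockwise(MatrixPinv(hermitian=True))
pinv_non_hermitian = Blockwise(MatrixPinv(hermitian=False))
def pinv(x, hermitian=False):
    ...
    # Reuse a predefined Op instead of building a new one on every call
    return pinv_hermitian(x) if hermitian else pinv_non_hermitian(x)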
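To illustrate the Op-duplication point, here is a small library-free sketch contrasting a helper that builds a fresh Op object on every call with one that hands out instances created once at import time. `ToyPinv` and both helper functions are hypothetical names for this example only; how PyTensor itself compares or tracks Ops is not shown here.

```python
class ToyPinv:
    """Hypothetical stand-in for an Op parametrized by `hermitian`."""

    def __init__(self, hermitian):
        self.hermitian = hermitian


def pinv_op_fresh(hermitian=False):
    """Option A: build a new Op instance on every call."""
    return ToyPinv(hermitian)


# Option B: create both variants once, at import time, and reuse them
PINV_HERMITIAN = ToyPinv(hermitian=True)
PINV_NON_HERMITIAN = ToyPinv(hermitian=False)


def pinv_op_shared(hermitian=False):
    """Return one of the two predefined instances."""
    return PINV_HERMITIAN if hermitian else PINV_NON_HERMITIAN


# Fresh instances are distinct objects; shared ones are the same object
fresh_are_distinct = pinv_op_fresh() is not pinv_op_fresh()
shared_are_same = pinv_op_shared() is pinv_op_shared()
```

The trade-off is that the shared instances must be constructed up front, whether or not they are ever used.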
But actually, let's not do that now; I have to think about it. It makes initialization a bit slower because we have to create more instances...
Oh, alright. But this is essentially the same as first checking that the Op is a Blockwise and then that the core Op is one of the valid inverses. What's the difference between the two ideas?
The difference is that the rewrite is not even considered if the node is not a Blockwise(Pinv). The way it is now, the rewrite is attempted on every single Blockwise in the graph.
It's just an optimization, not a question of correctness
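The cost difference can be sketched with a toy model that counts how many nodes a rewriter would be called on, depending on what it tracks. Nodes are modeled as (wrapper, core op) tuples, and `count_calls` is a hypothetical helper for this example, not how PyTensor dispatches rewrites internally.

```python
def count_calls(tracked, graph):
    """Count the nodes on which a rewriter tracking `tracked` would be invoked.

    A node matches if the exact (wrapper, core_op) pair is tracked,
    or if its wrapper name alone is tracked.
    """
    return sum(1 for node in graph if node in tracked or node[0] in tracked)


# Toy graph: each node is a (wrapper, core_op) pair
graph = [
    ("Blockwise", "MatrixInverse"),
    ("Blockwise", "Cholesky"),
    ("Blockwise", "Solve"),
    ("Elemwise", "Add"),
]

# Tracking the generic wrapper: called on every Blockwise node
broad_calls = count_calls({"Blockwise"}, graph)

# Tracking the specific Op: called only where the rewrite can actually apply
narrow_calls = count_calls({("Blockwise", "MatrixInverse")}, graph)
```

In both cases the rewrite produces the same result; the narrow version just skips the calls that would immediately return None.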
Description
Adds rewrite for inv(inv(x)) -> x