Added rewrite for matrix inv(inv(x)) -> x #893

Merged: 18 commits merged on Jul 19, 2024

Conversation

tanish1729
Contributor

@tanish1729 tanish1729 commented Jul 7, 2024

Description

Adds rewrite for inv(inv(x)) -> x
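
For readers who want to check the identity behind this rewrite, here is a minimal sketch in plain Python. It is not PyTensor code: `inv2` is a hypothetical helper that inverts a 2x2 matrix exactly with `fractions.Fraction`, which is enough to show that applying the inverse twice returns the original matrix.

```python
from fractions import Fraction


def inv2(m):
    """Exact inverse of a 2x2 matrix given as nested lists (hypothetical helper)."""
    (a, b), (c, d) = m
    det = a * d - b * c
    if det == 0:
        raise ValueError("matrix is singular")
    return [[d / det, -b / det], [-c / det, a / det]]


x = [[Fraction(2), Fraction(1)], [Fraction(5), Fraction(3)]]

# Inverting twice gives back the input, so a graph containing inv(inv(x))
# can be replaced by x without computing either inverse.
assert inv2(inv2(x)) == x
```

Using exact fractions avoids floating-point round-off, so the equality check is exact rather than approximate.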

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@jessegrabowski jessegrabowski requested review from jessegrabowski and ricardoV94 and removed request for jessegrabowski July 7, 2024 02:22
Member

@jessegrabowski jessegrabowski left a comment

Great start! Biggest missing piece is the solve-based inverses. Let's use this opportunity to write a helper that detects those, because all inverse rewrites will need it.
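
A rough sketch of what such a helper might look like, using a toy graph rather than PyTensor's real classes. Everything here (`Node`, `is_inverse_like`, the op names `"inv"`, `"pinv"`, `"solve"`, `"eye"`) is hypothetical and only illustrates the idea that `solve(A, eye)` computes the same thing as `inv(A)` and should be detected as an inverse.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Toy graph node: an op name plus its inputs (hypothetical, not PyTensor)."""
    op: str
    inputs: list = field(default_factory=list)


def is_inverse_like(node):
    """Return True for inv(A), pinv(A), or the solve-based form solve(A, eye)."""
    if not isinstance(node, Node):
        return False
    if node.op in ("inv", "pinv"):
        return True
    if node.op == "solve" and len(node.inputs) == 2:
        b = node.inputs[1]
        return isinstance(b, Node) and b.op == "eye"
    return False


def rewrite_inv_inv(node):
    """Return the innermost operand if node is inverse-of-inverse, else None."""
    if not is_inverse_like(node):
        return None
    inner = node.inputs[0]
    if not is_inverse_like(inner):
        return None
    return inner.inputs[0]


# inv(solve(x, eye)) is an inverse of an inverse, so it collapses to x.
nested = Node("inv", [Node("solve", ["x", Node("eye")])])
assert rewrite_inv_inv(nested) == "x"
```

Keeping the detection in one helper means other inverse-related rewrites could reuse it instead of repeating the same checks.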

@ricardoV94 ricardoV94 added the enhancement, graph rewriting, and linalg labels Jul 7, 2024
@ricardoV94 ricardoV94 changed the title from "Added rewrite for inv(inv(x)) -> x" to "Added rewrite for matrix inv(inv(x)) -> x" Jul 7, 2024

codecov bot commented Jul 11, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 81.50%. Comparing base (ad27dc7) to head (4bf3c2d).
Report is 91 commits behind head on main.

Additional details and impacted files

@@           Coverage Diff           @@
##             main     #893   +/-   ##
=======================================
  Coverage   81.49%   81.50%           
=======================================
  Files         176      176           
  Lines       46925    46938   +13     
  Branches    11428    11435    +7     
=======================================
+ Hits        38242    38255   +13     
- Misses       6498     6500    +2     
+ Partials     2185     2183    -2     
Files with missing lines Coverage Δ
pytensor/tensor/rewriting/linalg.py 91.28% <100.00%> (+0.49%) ⬆️

... and 2 files with indirect coverage changes

Member

@jessegrabowski jessegrabowski left a comment

lgtm

@tanish1729
Contributor Author

@ricardoV94 lmk if there's anything else to be done here. otherwise, it should be good to merge

@tanish1729 tanish1729 requested a review from ricardoV94 July 19, 2024 09:23
Comment on lines +576 to +600
@node_rewriter([Blockwise])
def rewrite_inv_inv(fgraph, node):
    """
    This rewrite takes advantage of the fact that two consecutive inverse operations (inv(inv(input))) return the original input, so neither inverse has to be computed.
    Here, we check for direct inverse operations (inv/pinv) and allow any combination of these "inverse" nodes to be rewritten.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    valid_inverses = (MatrixInverse, MatrixPinv)
    # Check if it's a valid inverse operation (either inv/pinv)
    # In case the outer operation is an inverse, it directly goes to the next step of finding inner operation
    # If the outer operation is not a valid inverse, we do not apply this rewrite
    if not isinstance(node.op.core_op, valid_inverses):
        return None
Member

If we predefine the two possible pinv Ops as Blockwise instances (as we do for matrix_inverse), the rewrite can track them specifically and avoid being called for every Blockwise it sees:

Suggested change
@node_rewriter([matrix_inverse, matrix_pinv_hermitian, matrix_pinv_non_hermitian])
def rewrite_inv_inv(fgraph, node):
    """
    This rewrite takes advantage of the fact that two consecutive inverse operations (inv(inv(input))) return the original input, so neither inverse has to be computed.
    Here, we check for direct inverse operations (inv/pinv) and allow any combination of these "inverse" nodes to be rewritten.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """

Need to predefine those matrix_pinv* Ops. The helper "pinv" should return the predefined Ops instead of creating new ones, to avoid Op duplication.

Contributor Author

@tanish1729 tanish1729 Jul 19, 2024

I don't quite understand the last line you wrote:

Need to predefine those matrix_pinv*.. The helper "pinv" should return the predefined Ops instead of creating new ones to avoid Op duplication

also, as a general rule for rewrites, is it better that they should be tracking more specific Ops instead of just Blockwise

Member

also, as a general rule for rewrites, is it better that they should be tracking more specific Ops instead of just Blockwise

Is that a question? The answer is yes. It avoids useless calls to the rewrite function when the Op is not in the graph

Contributor Author

But to check the inner op, I will still have to use the method I am already using.

Member

@ricardoV94 ricardoV94 Jul 19, 2024

Regarding pinv, the helper function is here:

def pinv(x, hermitian=False):

The idea is instead:

pinv_hermitian = Blockwise(Pinv(hermitian=True))
pinv_non_hermitian = Blockwise(Pinv(hermitian=False))

def pinv(x, hermitian=False):
  ...
  return pinv_hermitian if hermitian else pinv_non_hermitian
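
To make the "predefined Ops" idea concrete, here is a toy sketch that does not use PyTensor. `ToyPinv`, `PINV_HERMITIAN` and `PINV_NON_HERMITIAN` are hypothetical stand-ins, and the helper is assumed to apply the chosen Op to `x` rather than return the Op itself. The point is only that every call reuses one of two module-level instances instead of constructing a new Op each time.

```python
class ToyPinv:
    """Stand-in for a pseudo-inverse Op (hypothetical, not PyTensor's class)."""

    def __init__(self, hermitian):
        self.hermitian = hermitian

    def __call__(self, x):
        # Return a toy "node" that records which Op instance produced it.
        return (self, x)


# Created once at import time; every call to pinv() reuses these instances.
PINV_HERMITIAN = ToyPinv(hermitian=True)
PINV_NON_HERMITIAN = ToyPinv(hermitian=False)


def pinv(x, hermitian=False):
    op = PINV_HERMITIAN if hermitian else PINV_NON_HERMITIAN
    return op(x)


# Two separate calls share the same Op object, so a rewrite could track
# exactly these two instances.
assert pinv("a")[0] is pinv("b")[0]
assert pinv("a", hermitian=True)[0] is PINV_HERMITIAN
```

The trade-off is that the instances are built at import time whether or not they are ever used.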

Member

But actually let's not do that now, I have to think about it. This makes initialization a bit slower because we have to create more instances...

Contributor Author

Oh, alright. But this is essentially the same as first checking that the op is a Blockwise and then checking that its core op is one of the valid inverses. What's the difference between the two approaches?

Member

The difference is that, when tracking the specific Ops, the rewrite is not even considered if the node is not a Blockwise(Pinv). With the current approach, the rewrite is attempted on every single Blockwise in the graph.

It's just an optimization, not a question of correctness
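
The cost difference can be illustrated with a toy dispatcher. This is a simplified model and not PyTensor's rewrite machinery: the `Blockwise` class, the predefined instances and `count_calls` are all hypothetical. It counts how often a rewrite function would be invoked when it tracks the whole `Blockwise` class versus two specific instances.

```python
class Blockwise:
    """Toy wrapper Op holding a core op name (hypothetical, not PyTensor's class)."""

    def __init__(self, core_op):
        self.core_op = core_op


# Predefined instances, so a rewrite can track them by identity.
MATRIX_INVERSE = Blockwise("inv")
MATRIX_PINV = Blockwise("pinv")
ADD = Blockwise("add")
DOT = Blockwise("dot")


def count_calls(tracks, graph_ops):
    """Count how many ops in the graph would trigger a call to the rewrite."""
    calls = 0
    for op in graph_ops:
        tracked = any(
            isinstance(op, t) if isinstance(t, type) else op is t
            for t in tracks
        )
        if tracked:
            calls += 1
    return calls


graph_ops = [ADD, MATRIX_INVERSE, DOT, ADD, MATRIX_PINV]

# Tracking the class: the rewrite is called for every Blockwise node.
generic_calls = count_calls([Blockwise], graph_ops)
# Tracking specific instances: it is called only for the inverse nodes.
specific_calls = count_calls([MATRIX_INVERSE, MATRIX_PINV], graph_ops)

assert generic_calls == 5
assert specific_calls == 2
```

In both cases the rewrite produces the same result; only the number of wasted calls changes.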

@ricardoV94 ricardoV94 merged commit f489cf4 into pymc-devs:main Jul 19, 2024
59 checks passed
Ch0ronomato pushed a commit to Ch0ronomato/pytensor that referenced this pull request Aug 15, 2024

Successfully merging this pull request may close these issues.

Rewrite for consecutive matrix inverses
3 participants