Implement pad
#748
Conversation
What about a padding.py file?
Sure, I'll make a new file. It's just not my default. I agree it doesn't belong in basic.
Not quite 1:1 on numpy features, but close. The more exotic padding schemes would need more time for me to understand. Still needs jax/numba overloads, but these should be very trivial.
Codecov Report
Attention: Patch coverage is …

Additional details and impacted files

@@            Coverage Diff             @@
##             main     #748      +/-   ##
==========================================
+ Coverage   81.38%   81.48%   +0.09%
==========================================
  Files         172      174       +2
  Lines       46868    47166     +298
  Branches    11423    11471      +48
==========================================
+ Hits        38145    38434     +289
- Misses       6542     6548       +6
- Partials     2181     2184       +3
Looks great so far, left some small suggestions.
Draft of the JAX overload. Need your input on the … It seems like there might be a difference between how … I also think my loopy pads (symmetric, wrap) need to be redone, because they are failing a new test that arbitrarily pads every dimension of an nd input differently. So all that probably needs a re-design from the ground up.
You may need an operation per dimension.
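A minimal sketch of that idea for the constant-mode case (the helper name and the `pad_width` format are assumptions, not code from this PR): pad one axis at a time by building before/after blocks and concatenating them along that axis.

```python
import pytensor.tensor as pt


def constant_pad(x, pad_width, value=0):
    # pad_width is assumed to be a Python sequence of (before, after) pairs,
    # one pair per dimension of x, like np.pad's explicit form.
    out = x
    for axis, (before, after) in enumerate(pad_width):
        shape = [out.shape[i] for i in range(out.ndim)]
        left = pt.full([*shape[:axis], before, *shape[axis + 1 :]], value, dtype=x.dtype)
        right = pt.full([*shape[:axis], after, *shape[axis + 1 :]], value, dtype=x.dtype)
        # Concatenate the before-block, the running result, and the after-block
        # along the current axis only.
        out = pt.concatenate([left, out, right], axis=axis)
    return out
```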
Regarding JAX, do you need to implement a specific dispatch? For instance, for einsum I don't think we'll need one, because the OFG expression will be as good as what they do internally (since we copied it from them).
No idea on the JAX dispatch. I just assumed I should.
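For reference, a hedged sketch of what a dedicated JAX dispatch could look like if one turns out to be necessary. The `Pad` Op, its module path, and its attributes are assumptions about this draft, not existing code; if pad stays a plain OpFromGraph, the generic OFG dispatch should already cover it, as noted above.

```python
import jax.numpy as jnp

from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.pad import Pad  # hypothetical Op introduced by this PR


@jax_funcify.register(Pad)
def jax_funcify_Pad(op, **kwargs):
    # Assumed: the Op stores its mode and static pad widths as attributes,
    # so jnp.pad can be called with concrete arguments under jit.
    mode = op.pad_mode
    pad_width = op.pad_width

    def pad(x):
        return jnp.pad(x, pad_width, mode=mode)

    return pad
```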
…tch numpy outputs
Fill out `_broadcast_inputs` docstring
….tensor`
Don't understand why the doctest for … is failing.
It says there is an output that was not expected. If you have a print somewhere, you always need to test its output afterwards.
You can also run doctest locally btw |
Something like …
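One assumed way to do that is pytest's `--doctest-modules` flag; the module path below is a placeholder for wherever the pad code ends up.

```python
import pytest

# Collect and run the doctests embedded in the module's docstrings.
pytest.main(["--doctest-modules", "pytensor/tensor/pad.py"])
```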
We should open a follow-up issue for performance. With the reshape and concatenation, we're doing a lot of copies. We should see how much better it would be to have scans with set_subtensor like you tried halfway.
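A rough sketch of that alternative for the constant-mode case (the helper name and `pad_width` format are assumptions): allocate the padded output once and write the original tensor into its interior with set_subtensor, instead of building it from concatenated copies.

```python
import pytensor.tensor as pt


def constant_pad_set_subtensor(x, pad_width, value=0):
    # pad_width is assumed to be a Python sequence of (before, after) pairs.
    new_shape = [
        x.shape[i] + before + after for i, (before, after) in enumerate(pad_width)
    ]
    out = pt.full(new_shape, value, dtype=x.dtype)
    # Slices selecting the interior region where the original tensor goes.
    core = tuple(
        slice(before, before + x.shape[i]) for i, (before, _) in enumerate(pad_width)
    )
    return pt.set_subtensor(out[core], x)
```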
I kind of just want to skip the segfault test and come back to it later. I am trying to debug, but not really sure what's going on. It runs fine when …
This is huge!
Description

Implement pt.pad, following the np.pad API with feature parity.

Very preliminary draft, uploading it in this state so I can ask @ricardoV94 to look at the _linear_ramp_pad function and tell me if I'm missing something obvious related to shapes. It should follow numpy.lib.arraypad._get_linear_ramps. Also the reflection pad uses a scan; curious if we can avoid that somehow, or if we think it will be no big deal (probably the 2nd).

Also I'm not sure where to put this. I put it in tensor/basic, but it might be better in tensor/extra_ops?

Related Issue

pt.pad #743

Checklist

Type of change
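For orientation, a sketch of the intended call, assuming feature parity with np.pad as described above; the exact keyword names accepted by pt.pad in this draft may differ.

```python
import numpy as np
import pytensor.tensor as pt

x = pt.matrix("x")
# Pad one row on each side of axis 0 and two columns on each side of axis 1
# with zeros, mirroring np.pad's per-axis (before, after) convention.
y = pt.pad(x, pad_width=((1, 1), (2, 2)), mode="constant", constant_values=0)

# The equivalent numpy call, for comparison:
a = np.arange(6, dtype="float64").reshape(2, 3)
expected = np.pad(a, pad_width=((1, 1), (2, 2)), mode="constant", constant_values=0)
np.testing.assert_allclose(y.eval({x: a}), expected)
```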