-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ensureShape API #7632
Add ensureShape API #7632
Conversation
tfjs-core/src/ops/ensure_shape.ts
Outdated
@@ -0,0 +1,54 @@ | |||
/** | |||
* @license | |||
* Copyright 2020 Google LLC. All Rights Reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When creating new files, we need to update the year in the licence.
tfjs-core/src/ops/ensure_shape.ts
Outdated
|
||
if (!arraysEqualWithNull($x.shape, shape)) { | ||
throw new Error(`Invalid argument error. Shape of tensor ${ | ||
x} is not compatible with expected shape ${shape}`); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
x here is a tensor, an object. We typically use a small array or primitive values to fill the template. We could use '...tensor ${x.shape}...' here. And we could add a period at the end.
@@ -0,0 +1,32 @@ | |||
/** | |||
* @license | |||
* Copyright 2020 Google LLC. All Rights Reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
update the year
tfjs-core/src/ops/ensure_shape.ts
Outdated
* ``` | ||
* | ||
* @param x The input tensor to be ensured. | ||
* @param shape A TensorShape representing the shape of this tensor, a list, a |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We only support array here, and we could explain using null values
|
||
it('different shape', () => { | ||
const x = tf.ones([2, 3]); | ||
expect(() => ensureShape(x, [5, 3])).toThrowError(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could add a string for toThrowError, might be 'Invalid argument error. Shape of tensor [2, 3] is not compatible with expected shape [5, 3].'
* | ||
* ```js | ||
* const x = tf.tensor1d([1, 2, 3, 4]); | ||
* const y = tf.tensor1d([1, null, 3, 4]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
‘null’ might be in ensureShape's shape argument. The example could be:
const y = tf.tensor2d([1, 2, 3, 4], [2,2]);
tf.ensureShape(y, [null, 2]).print();
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please fix the broken converter tests and make sure CI passes, thanks!
import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; | ||
import {ensureShape} from './ensure_shape'; | ||
|
||
describeWithFlags('ensure_shape', ALL_ENVS, () => { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add tests for:
- shape with nulls
- shape with different lengths
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
and I guess describeWithFlags('ensureShape', ALL_ENVS ...
is more appropriate for the naming pattern of op tests?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sg, added the tests. Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. You can click the resolve button in github to mark/hide the comments you fixed.
Since your PR now has 2 approvals, it's your choice to merge it now or wait for the third one from matt for more comments.
tfjs-core/src/ops/ensure_shape.ts
Outdated
@@ -0,0 +1,59 @@ | |||
/** | |||
* @license | |||
* Copyright 2023 Google LLC. All Rights Reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no "All rights reserved"
You can use https://github.com/chunnienc/tfjs-license-fix with glob to add/fix those headers in a batch.
This PR adds the ops for ensureShape() in Core.
We can ensure the input tensor has the same shape as the given shape.
Fix #7225