Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ensureShape API #7632

Merged
merged 12 commits into from
Apr 28, 2023
Merged

Add ensureShape API #7632

merged 12 commits into from
Apr 28, 2023

Conversation

fengwuyao
Copy link
Collaborator

@fengwuyao fengwuyao commented Apr 26, 2023

This PR adds the ops for ensureShape() in Core.
We can ensure the input tensor has the same shape as the given shape.

// Construct a 1D tensor with [1, 2, 3, 4]
const x = tf.tensor1d([1, 2, 3, 4]);

// Ensure that x has the same shape with shape: [4]
tf.ensureShape(x, [4]);

Fix #7225

@Linchenn Linchenn self-requested a review April 26, 2023 23:42
@@ -0,0 +1,54 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When creating new files, we need to update the year in the licence.


if (!arraysEqualWithNull($x.shape, shape)) {
throw new Error(`Invalid argument error. Shape of tensor ${
x} is not compatible with expected shape ${shape}`);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

x here is a tensor, an object. We typically use a small array or primitive values to fill the template. We could use '...tensor ${x.shape}...' here. And we could add a period at the end.

@@ -0,0 +1,32 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update the year

* ```
*
* @param x The input tensor to be ensured.
* @param shape A TensorShape representing the shape of this tensor, a list, a
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We only support array here, and we could explain using null values


it('different shape', () => {
const x = tf.ones([2, 3]);
expect(() => ensureShape(x, [5, 3])).toThrowError();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could add a string for toThrowError, might be 'Invalid argument error. Shape of tensor [2, 3] is not compatible with expected shape [5, 3].'

@fengwuyao fengwuyao requested a review from Linchenn April 27, 2023 00:48
*
* ```js
* const x = tf.tensor1d([1, 2, 3, 4]);
* const y = tf.tensor1d([1, null, 3, 4]);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

‘null’ might be in ensureShape's shape argument. The example could be:

const y = tf.tensor2d([1, 2, 3, 4], [2,2]);
tf.ensureShape(y, [null, 2]).print();

Copy link
Collaborator

@Linchenn Linchenn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

Copy link
Collaborator

@chunnienc chunnienc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please fix the broken converter tests and make sure CI passes, thanks!

import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
import {ensureShape} from './ensure_shape';

describeWithFlags('ensure_shape', ALL_ENVS, () => {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add tests for:

  1. shape with nulls
  2. shape with different lengths

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and I guess describeWithFlags('ensureShape', ALL_ENVS ... is more appropriate for the naming pattern of op tests?

Copy link
Collaborator Author

@fengwuyao fengwuyao Apr 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sg, added the tests. Thanks!

@fengwuyao fengwuyao requested a review from chunnienc April 28, 2023 17:27
Copy link
Collaborator

@chunnienc chunnienc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. You can click the resolve button in github to mark/hide the comments you fixed.
Since your PR now has 2 approvals, it's your choice to merge it now or wait for the third one from matt for more comments.

@@ -0,0 +1,59 @@
/**
* @license
* Copyright 2023 Google LLC. All Rights Reserved.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no "All rights reserved"
You can use https://github.com/chunnienc/tfjs-license-fix with glob to add/fix those headers in a batch.

@fengwuyao fengwuyao merged commit b667314 into tensorflow:master Apr 28, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Unknown op 'EnsureShape'
3 participants