
Update requirements, restructure files and fix formatting for VAE example #3046

Merged Jun 27, 2023 (1 commit)

Conversation

canyon289
Contributor

@canyon289 canyon289 commented Apr 21, 2023

What does this PR do?

Addressing a portion of the issues in #573
Putting this up for review early to ensure I'm trending in the right direction and to learn what else would be needed for a merge

Fixes

  • File structure - copied from WMT Example
  • Google style fixes - 99% done, some small errors may remain
  • Updated requirements

I can fix up Colab, README, CLU, model configs, and add tests in separate PRs if you don't mind

@canyon289 canyon289 changed the title Update requirements, restructure files and fix formatting Update requirements, restructure files and fix formatting for VAE example Apr 21, 2023
@canyon289
Contributor Author

I see the linting issues. Will get them fixed

@8bitmp3
Collaborator

8bitmp3 commented Apr 21, 2023

Thanks @canyon289

@andsteing @cgarciae

Collaborator

@andsteing andsteing left a comment


This looks great, thanks!

Some minor comments inline.

examples/vae/input_pipeline.py (resolved review thread)
# Make sure tf does not allocate gpu memory.
tf.config.experimental.set_visible_devices([], 'GPU')
train.train_and_evaluate(FLAGS)
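For context, the call quoted above empties TensorFlow's GPU device list so that the tf.data input pipeline does not reserve device memory that JAX needs for training. A minimal sketch of the pattern (the `main()` wrapper here is illustrative, not the example's actual entry point):

```python
import tensorflow as tf


def main():
  # Hide all GPUs from TensorFlow: JAX does the actual training, and we
  # don't want TF's host-side input pipeline to grab GPU memory.
  tf.config.experimental.set_visible_devices([], 'GPU')
  # After this call, TF reports no visible GPUs.
  return tf.config.get_visible_devices('GPU')


main()
```

This only restricts TensorFlow's view of the devices; JAX discovers and uses the GPUs independently.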

Collaborator


nit: we usually have two empty lines before this block.

Contributor Author


Fixed

Collaborator


What I meant: we usually have single empty lines within functions (lines 55-56, 59-60 should both be a single line), but we have double empty lines between "top-level codeblocks" (line 65 should be two empty lines)
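The convention being described (single blank lines inside a function, two blank lines between top-level definitions) matches PEP 8. A minimal illustration with toy functions of my own:

```python
def scale(x, factor=2):
  scaled = x * factor  # a single blank line is fine inside a function

  return scaled


def main():  # two blank lines separate top-level definitions
  return scale(3)
```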

Resolved review threads on examples/vae/train.py, examples/vae/input_pipeline.py, and examples/vae/models.py
@andsteing
Collaborator

It's fine to clean up the README in a second step, but could you already update it to run python main.py instead of python train.py?

@andsteing
Collaborator

andsteing commented Apr 24, 2023

Tested that everything installs and runs (with the expected loss) in a fresh Colab CPU runtime:
https://colab.research.google.com/drive/1MDrchufz1eUoB03znEnRV-Z8yQtg0Qau#revisionId=0B6Hoz8j9CmpCNlBmRENvTW5lSnpjZ1RHV1JxZGY2MVQyS1dRPQ

@canyon289
Contributor Author

Thanks @andsteing for all the comments. I'll address them all

@codecov-commenter

Codecov Report

Merging #3046 (38c972b) into main (d8708ed) will not change coverage.
The diff coverage is n/a.

@@           Coverage Diff           @@
##             main    #3046   +/-   ##
=======================================
  Coverage   81.97%   81.97%           
=======================================
  Files          55       55           
  Lines        6031     6031           
=======================================
  Hits         4944     4944           
  Misses       1087     1087           


return VAE(latents=FLAGS.latents)


@jax.jit
Contributor Author


A question for understanding: should the train step be jitted? I looked at other examples and that didn't seem to be the case.

https://github.com/google/flax/blob/main/examples/wmt/train.py#L166

Collaborator


it's jitted here:

p_train_step = jax.pmap(

(the pmap() transform also compiles the code like jit(), but at the same time parallelizes it onto multiple devices)
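As a minimal sketch of that point (the names here are illustrative, not from the PR): pmap compiles its function like jit and additionally maps it over the leading axis of the input, one slice per local device.

```python
import jax
import jax.numpy as jnp

# pmap compiles the function (like jit) and runs one slice of the leading
# axis on each local device. On a single-device host (e.g. CPU), n == 1.
n = jax.local_device_count()
p_square = jax.pmap(lambda x: x * x)

batch = jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)
out = p_square(batch)  # same shape as the input: (n, 3)
```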

Collaborator


should the train step be jitted?

To answer your question: yes, you should always compile the largest possible code block; usually that's the train_step().
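A minimal sketch of a jitted train step along those lines (the toy params, batch, loss, and plain SGD update are mine, not the VAE example's code):

```python
import jax
import jax.numpy as jnp


# Jit the whole step, not just the loss: the compiler then fuses the
# forward pass, gradient computation, and parameter update together.
@jax.jit
def train_step(params, batch):
  def loss_fn(p):
    pred = batch['x'] * p['w']
    return jnp.mean((pred - batch['y']) ** 2)

  loss, grads = jax.value_and_grad(loss_fn)(params)
  # Plain SGD update; a real example would typically use an optimizer.
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
  return params, loss
```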

Contributor Author


Ah, I missed that; I was looking for the decorator. Thank you @andsteing

@canyon289
Contributor Author

Do let me know if there's anything I missed. Happy to keep working on this

@8bitmp3
Collaborator

8bitmp3 commented May 22, 2023

cc @andsteing @chiamp

@canyon289
Contributor Author

I'm still committed to finishing this!

@andsteing
Collaborator

@canyon289 if you would like reviewers to have another look at the PR, you can click on the Request review button in the GitHub UI

(otherwise busy reviewers might not read all the individual updates and miss the PR until that button is pressed)

@andsteing
Collaborator

Tested that everything installs and runs (with the expected loss) in a fresh Colab CPU runtime:
https://colab.research.google.com/drive/1MDrchufz1eUoB03znEnRV-Z8yQtg0Qau#revisionId=0B6Hoz8j9CmpCMlZIdlN0N1MwTitLS21FL05UZUtCNXFIU3dNPQ&scrollTo=1nCF9d0DX0Kk
(runs 30 epochs in 31 minutes; final loss 100.73, 103.4 after 10 epochs)

Collaborator

@andsteing andsteing left a comment


Looks good from my side.
(one open comment about empty lines, otherwise good to submit)

# Make sure tf does not allocate gpu memory.
tf.config.experimental.set_visible_devices([], 'GPU')
train.train_and_evaluate(FLAGS)

Collaborator


What I meant: we usually have single empty lines within functions (lines 55-56, 59-60 should both be a single line), but we have double empty lines between "top-level codeblocks" (line 65 should be two empty lines)

@andsteing
Collaborator

@cgarciae @8bitmp3 should we wait for your review?
(since you're both listed as reviewers)

@canyon289
Contributor Author

What I meant: we usually have single empty lines within functions (lines 55-56, 59-60 should both be a single line), but we have double empty lines between "top-level codeblocks" (line 65 should be two empty lines)

You're right, updated! Thank you

@canyon289
Contributor Author

@canyon289 if you would like reviewers to have another look at the PR, you can click on the Request review button in the GitHub UI

(otherwise busy reviewers might not read all the individual updates and miss the PR until that button is pressed)

I don't have the rights in this repo to re-request review from pending reviewers for some reason. I wouldn't mind another review if someone wants to, but I don't want to obligate anybody. Happy to have this merged and keep moving on from there

@canyon289
Contributor Author

Squashed to single commit and force pushed

@8bitmp3
Collaborator

8bitmp3 commented Jun 26, 2023

@marcvanzee and @levskaya PTAL when you have time please

@andsteing
Collaborator

andsteing commented Jun 27, 2023

@canyon289 if you would like reviewers to have another look at the PR, you can click on the Request review button in the GitHub UI
(otherwise busy reviewers might not read all the individual updates and miss the PR until that button is pressed)

I don't have the rights in this repo to re-request review from pending reviewers for some reason. I wouldn't mind another review if someone wants to, but I don't want to obligate anybody. Happy to have this merged and keep moving on from there

Oh, that's good to know! I must have missed your message previously. Since nobody else commented on the PR, let's move forward.

@copybara-service copybara-service bot merged commit fe93a7e into google:main Jun 27, 2023
19 checks passed
4 participants