diff --git a/.editorconfig b/.editorconfig index db30a6257e8..55bfe517da2 100644 --- a/.editorconfig +++ b/.editorconfig @@ -9,5 +9,11 @@ indent_size = 4 indent_style = space insert_final_newline = true max_line_length = 120 + # Ktlint-specific config -disabled_rules = filename, max-line-length, argument-list-wrapping, parameter-list-wrapping +ktlint_standard = enabled +ktlint_experimental = disabled +ktlint_standard_filename = disabled +ktlint_standard_max-line-length = disabled +ktlint_standard_argument-list-wrapping = disabled +ktlint_standard_parameter-list-wrapping = disabled diff --git a/.github/scripts/docker-image-hash b/.github/scripts/docker-image-hash index 55b532332df..71fae16318d 100755 --- a/.github/scripts/docker-image-hash +++ b/.github/scripts/docker-image-hash @@ -11,4 +11,4 @@ set -eo pipefail cd "$(dirname "$0")" cd "$(git rev-parse --show-toplevel)" -git ls-files -s --full-name "tools" | git hash-object --stdin +git ls-files -s --full-name "tools/ci-build" | git hash-object --stdin diff --git a/.github/scripts/get-or-create-release-branch.sh b/.github/scripts/get-or-create-release-branch.sh index 6cdc8299111..7dcc87676d0 100755 --- a/.github/scripts/get-or-create-release-branch.sh +++ b/.github/scripts/get-or-create-release-branch.sh @@ -8,8 +8,6 @@ set -eux # Compute the name of the release branch starting from the version that needs to be released ($SEMANTIC_VERSION). # If it's the beginning of a new release series, the branch is created and pushed to the remote (chosen according to # the value $DRY_RUN). -# If it isn't the beginning of a new release series, the script makes sure that the commit that will be tagged is at -# the tip of the (pre-existing) release branch. # # The script populates an output file with key-value pairs that are needed in the release CI workflow to carry out # the next steps in the release flow: the name of the release branch and a boolean flag that is set to 'true' if this @@ -57,16 +55,7 @@ if [[ "${DRY_RUN}" == "true" ]]; then git push --force origin "HEAD:refs/heads/${branch_name}" else commit_sha=$(git rev-parse --short HEAD) - if git ls-remote --exit-code --heads origin "${branch_name}"; then - # The release branch already exists, we need to make sure that our commit is its current tip - branch_head_sha=$(git rev-parse --verify --short "refs/heads/${branch_name}") - if [[ "${branch_head_sha}" != "${commit_sha}" ]]; then - echo "The release branch - ${branch_name} - already exists. ${commit_sha}, the commit you chose when " - echo "launching this release, is not its current HEAD (${branch_head_sha}). This is not allowed: you " - echo "MUST release from the HEAD of the release branch if it already exists." - exit 1 - fi - else + if ! git ls-remote --exit-code --heads origin "${branch_name}"; then # The release branch does not exist. # We need to make sure that the commit SHA that we are releasing is on `main`. git fetch origin main @@ -75,7 +64,7 @@ else git checkout -b "${branch_name}" git push origin "${branch_name}" else - echo "You must choose a commit from main to create a new release series!" + echo "You must choose a commit from main to create a new release branch!" exit 1 fi fi diff --git a/.github/workflows/ci-merge-queue.yml b/.github/workflows/ci-merge-queue.yml new file mode 100644 index 00000000000..bedab9beb5a --- /dev/null +++ b/.github/workflows/ci-merge-queue.yml @@ -0,0 +1,93 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +# This workflow runs CI for the GitHub merge queue. + +name: Merge Queue CI +on: + merge_group: + types: [checks_requested] + +# Allow one instance of this workflow per merge +concurrency: + group: ci-merge-queue-yml-${{ github.ref }} + cancel-in-progress: true + +env: + ecr_repository: public.ecr.aws/w0m4q9l7/github-awslabs-smithy-rs-ci + +jobs: + # This job will, if possible, save a docker login password to the job outputs. The token will + # be encrypted with the passphrase stored as a GitHub secret. The login password expires after 12h. + # The login password is encrypted with the repo secret DOCKER_LOGIN_TOKEN_PASSPHRASE + save-docker-login-token: + name: Save a docker login token + outputs: + docker-login-password: ${{ steps.set-token.outputs.docker-login-password }} + permissions: + id-token: write + contents: read + continue-on-error: true + runs-on: ubuntu-latest + steps: + - name: Attempt to load a docker login password + uses: aws-actions/configure-aws-credentials@v1-node16 + with: + role-to-assume: ${{ secrets.SMITHY_RS_PUBLIC_ECR_PUSH_ROLE_ARN }} + role-session-name: GitHubActions + aws-region: us-west-2 + - name: Save the docker login password to the output + id: set-token + run: | + ENCRYPTED_PAYLOAD=$( + gpg --symmetric --batch --passphrase "${{ secrets.DOCKER_LOGIN_TOKEN_PASSPHRASE }}" --output - <(aws ecr-public get-login-password --region us-east-1) | base64 -w0 + ) + echo "docker-login-password=$ENCRYPTED_PAYLOAD" >> $GITHUB_OUTPUT + + # This job detects if the PR made changes to build tools. If it did, then it builds a new + # build Docker image. Otherwise, it downloads a build image from Public ECR. In both cases, + # it uploads the image as a build artifact for other jobs to download and use. + acquire-base-image: + name: Acquire Base Image + needs: save-docker-login-token + runs-on: ubuntu-latest + env: + ENCRYPTED_DOCKER_PASSWORD: ${{ needs.save-docker-login-token.outputs.docker-login-password }} + DOCKER_LOGIN_TOKEN_PASSPHRASE: ${{ secrets.DOCKER_LOGIN_TOKEN_PASSPHRASE }} + permissions: + id-token: write + contents: read + steps: + - uses: actions/checkout@v3 + with: + path: smithy-rs + - name: Acquire base image + id: acquire + env: + DOCKER_BUILDKIT: 1 + run: ./smithy-rs/.github/scripts/acquire-build-image + - name: Acquire credentials + uses: aws-actions/configure-aws-credentials@v1-node16 + with: + role-to-assume: ${{ secrets.SMITHY_RS_PUBLIC_ECR_PUSH_ROLE_ARN }} + role-session-name: GitHubActions + aws-region: us-west-2 + - name: Upload image + run: | + IMAGE_TAG="$(./smithy-rs/.github/scripts/docker-image-hash)" + docker tag "smithy-rs-base-image:${IMAGE_TAG}" "${{ env.ecr_repository }}:${IMAGE_TAG}" + aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws + docker push "${{ env.ecr_repository }}:${IMAGE_TAG}" + + # Run shared CI after the Docker build image has either been rebuilt or found in ECR + ci: + needs: + - save-docker-login-token + - acquire-base-image + if: ${{ github.event.pull_request.head.repo.full_name == 'awslabs/smithy-rs' || toJSON(github.event.merge_group) != '{}' }} + uses: ./.github/workflows/ci.yml + with: + run_sdk_examples: true + secrets: + ENCRYPTED_DOCKER_PASSWORD: ${{ needs.save-docker-login-token.outputs.docker-login-password }} + DOCKER_LOGIN_TOKEN_PASSPHRASE: ${{ secrets.DOCKER_LOGIN_TOKEN_PASSPHRASE }} diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 774c4f3c7a3..52b54efd45e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -29,7 +29,7 @@ on: required: false env: - rust_version: 1.62.1 + rust_version: 1.63.0 rust_toolchain_components: clippy,rustfmt ENCRYPTED_DOCKER_PASSWORD: ${{ secrets.ENCRYPTED_DOCKER_PASSWORD }} DOCKER_LOGIN_TOKEN_PASSPHRASE: ${{ secrets.DOCKER_LOGIN_TOKEN_PASSPHRASE }} diff --git a/.github/workflows/claim-crate-names.yml b/.github/workflows/claim-crate-names.yml index 650b8cac7cd..63b5eac3079 100644 --- a/.github/workflows/claim-crate-names.yml +++ b/.github/workflows/claim-crate-names.yml @@ -10,7 +10,7 @@ concurrency: cancel-in-progress: true env: - rust_version: 1.62.1 + rust_version: 1.63.0 name: Claim unpublished crate names on crates.io run-name: ${{ github.workflow }} diff --git a/.github/workflows/pull-request-bot.yml b/.github/workflows/pull-request-bot.yml index 7bcd2e833f8..131939fcb0d 100644 --- a/.github/workflows/pull-request-bot.yml +++ b/.github/workflows/pull-request-bot.yml @@ -28,7 +28,7 @@ concurrency: env: java_version: 11 - rust_version: 1.62.1 + rust_version: 1.63.0 rust_toolchain_components: clippy,rustfmt apt_dependencies: libssl-dev gnuplot jq diff --git a/.github/workflows/release-scripts/create-release.js b/.github/workflows/release-scripts/create-release.js index ad4f328e8f4..fb10ea1b2c0 100644 --- a/.github/workflows/release-scripts/create-release.js +++ b/.github/workflows/release-scripts/create-release.js @@ -44,10 +44,13 @@ module.exports = async ({ isDryRun, // Release manifest file path releaseManifestPath, + // The commit-like reference that we want to release (e.g. a commit SHA or a branch name) + releaseCommitish, }) => { assert(github !== undefined, "The `github` argument is required"); assert(isDryRun !== undefined, "The `isDryRun` argument is required"); assert(releaseManifestPath !== undefined, "The `releaseManifestPath` argument is required"); + assert(releaseCommitish !== undefined, "The `releaseCommitish` argument is required"); console.info(`Starting GitHub release creation with isDryRun: ${isDryRun}, and releaseManifestPath: '${releaseManifestPath}'`); @@ -74,6 +77,7 @@ module.exports = async ({ name: releaseManifest.name, body: releaseManifest.body, prerelease: releaseManifest.prerelease, + target_commitish: releaseCommitish, }); console.info(`SUCCESS: Created release with ID: ${response.data.id}, URL: ${response.data.html_url} `); } else { diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 580eb26245a..bcc1cf22ab6 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -10,7 +10,7 @@ concurrency: cancel-in-progress: true env: - rust_version: 1.62.1 + rust_version: 1.63.0 name: Release smithy-rs run-name: ${{ github.workflow }} ${{ inputs.semantic_version }} (${{ inputs.commit_sha }}) - ${{ inputs.dry_run && 'Dry run' || 'Production run' }} @@ -18,8 +18,8 @@ on: workflow_dispatch: inputs: commit_sha: - description: | - The SHA of the git commit that you want to release. + description: | + The SHA of the git commit that you want to release. You must use the non-abbreviated SHA (e.g. b2318b0 won't work!). required: true type: string @@ -75,8 +75,8 @@ jobs: # We need `always` here otherwise this job won't run if the previous job has been skipped # See https://samanpavel.medium.com/github-actions-conditional-job-execution-e6aa363d2867 if: | - always() && - needs.acquire-base-image.result == 'success' && + always() && + needs.acquire-base-image.result == 'success' && (needs.release-ci.result == 'success' || needs.release-ci.result == 'skipped') runs-on: ubuntu-latest outputs: @@ -87,6 +87,7 @@ jobs: with: ref: ${{ inputs.commit_sha }} token: ${{ secrets.RELEASE_AUTOMATION_BOT_PAT }} + fetch-depth: 0 - name: Get or create release branch id: branch-push shell: bash @@ -112,11 +113,13 @@ jobs: runs-on: ubuntu-latest outputs: release_branch: ${{ needs.get-or-create-release-branch.outputs.release_branch }} + commit_sha: ${{ steps.gradle-push.outputs.commit_sha }} steps: - uses: actions/checkout@v3 with: - ref: ${{ needs.get-or-create-release-branch.outputs.release_branch }} + ref: ${{ inputs.commit_sha }} path: smithy-rs + fetch-depth: 0 token: ${{ secrets.RELEASE_AUTOMATION_BOT_PAT }} - name: Upgrade gradle.properties uses: ./smithy-rs/.github/actions/docker-build @@ -131,13 +134,30 @@ jobs: shell: bash env: SEMANTIC_VERSION: ${{ inputs.semantic_version }} + RELEASE_COMMIT_SHA: ${{ inputs.commit_sha }} + RELEASE_BRANCH_NAME: ${{ needs.get-or-create-release-branch.outputs.release_branch }} DRY_RUN: ${{ inputs.dry_run }} run: | set -x + # For debugging purposes git status - # The file was actually changed, we need to commit the changes - git diff-index --quiet HEAD || { git -c 'user.name=AWS SDK Rust Bot' -c 'user.email=aws-sdk-rust-primary@amazon.com' commit gradle.properties --message "Upgrade the smithy-rs runtime crates version to ${SEMANTIC_VERSION}" && git push origin; } + + if ! git diff-index --quiet HEAD; then + # gradle.properties was changed, we need to commit and push the diff + git -c 'user.name=AWS SDK Rust Bot' -c 'user.email=aws-sdk-rust-primary@amazon.com' commit gradle.properties --message "Upgrade the smithy-rs runtime crates version to ${SEMANTIC_VERSION}" + + # This will fail if we tried to release from a non-HEAD commit on the release branch. + # The only scenario where we would try to release a non-HEAD commit from the release branch is + # to retry a release action execution that failed due to a transient issue. + # In that case, we expect the commit to be releasable as-is, i.e. the runtime crate version in gradle.properties + # should already be the expected one. + git push origin "HEAD:refs/heads/${RELEASE_BRANCH_NAME}" + + echo "commit_sha=$(git rev-parse HEAD)" > $GITHUB_OUTPUT + else + echo "commit_sha=${RELEASE_COMMIT_SHA}" > $GITHUB_OUTPUT + fi release: name: Release @@ -158,7 +178,7 @@ jobs: - name: Checkout smithy-rs uses: actions/checkout@v3 with: - ref: ${{ needs.upgrade-gradle-properties.outputs.release_branch }} + ref: ${{ needs.upgrade-gradle-properties.outputs.commit_sha }} path: smithy-rs token: ${{ secrets.RELEASE_AUTOMATION_BOT_PAT }} - name: Generate release artifacts @@ -170,9 +190,20 @@ jobs: - name: Push smithy-rs changes shell: bash working-directory: smithy-rs-release/smithy-rs + id: push-changelog + env: + RELEASE_BRANCH_NAME: ${{ needs.upgrade-gradle-properties.outputs.release_branch }} run: | - echo "Pushing release commits..." - git push origin + if ! git diff-index --quiet HEAD; then + echo "Pushing release commits..." + # This will fail if we tried to release from a non-HEAD commit on the release branch. + # The only scenario where we would try to release a non-HEAD commit from the release branch is + # to retry a release action execution that failed due to a transient issue. + # In that case, we expect the commit to be releasable as-is, i.e. the changelog should have already + # been processed. + git push origin "HEAD:refs/heads/${RELEASE_BRANCH_NAME}" + fi + echo "commit_sha=$(git rev-parse HEAD)" > $GITHUB_OUTPUT - name: Tag release uses: actions/github-script@v6 with: @@ -182,7 +213,8 @@ jobs: await createReleaseScript({ github, isDryRun: ${{ inputs.dry_run }}, - releaseManifestPath: "smithy-rs-release/smithy-rs-release-manifest.json" + releaseManifestPath: "smithy-rs-release/smithy-rs-release-manifest.json", + releaseCommitish: "${{ steps.push-changelog.outputs.commit_sha }}" }); - name: Publish to crates.io shell: bash @@ -232,7 +264,7 @@ jobs: shell: bash run: | set -eux - + # This will fail if other commits have been pushed to `main` after `commit_sha` # In particular, this will ALWAYS fail if you are creating a new release series from # a commit that is not the current tip of `main`. diff --git a/.github/workflows/update-sdk-next.yml b/.github/workflows/update-sdk-next.yml index 8fca2bbc62c..fd6082351e3 100644 --- a/.github/workflows/update-sdk-next.yml +++ b/.github/workflows/update-sdk-next.yml @@ -32,7 +32,7 @@ jobs: - name: Set up Rust uses: dtolnay/rust-toolchain@master with: - toolchain: 1.62.1 + toolchain: 1.63.0 - name: Delete old SDK run: | - name: Generate a fresh SDK diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c6da8695fcd..cce6e263271 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,10 +20,10 @@ repos: files: ^.*$ pass_filenames: false - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks - rev: v2.5.0 + rev: v2.6.0 hooks: - id: pretty-format-kotlin - args: [--autofix, --ktlint-version, 0.46.1] + args: [--autofix, --ktlint-version, 0.48.2] - id: pretty-format-yaml args: [--autofix, --indent, '2'] - id: pretty-format-rust diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index 7a5f6d7515a..5317cd5c7b1 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -11,120 +11,231 @@ # meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client | server | all"} # author = "rcoh" +[[aws-sdk-rust]] +message = """Request IDs can now be easily retrieved on successful responses. For example, with S3: +```rust +// Import the trait to get the `request_id` method on outputs +use aws_sdk_s3::types::RequestId; +let output = client.list_buckets().send().await?; +println!("Request ID: {:?}", output.request_id()); +``` +""" +references = ["smithy-rs#76", "smithy-rs#2129"] +meta = { "breaking" = true, "tada" = false, "bug" = false } +author = "jdisanti" + +[[aws-sdk-rust]] +message = """Retrieving a request ID from errors now requires importing the `RequestId` trait. For example, with S3: +```rust +use aws_sdk_s3::types::RequestId; +println!("Request ID: {:?}", error.request_id()); +``` +""" +references = ["smithy-rs#76", "smithy-rs#2129"] +meta = { "breaking" = true, "tada" = false, "bug" = false } +author = "jdisanti" + +[[smithy-rs]] +message = "Generic clients no longer expose a `request_id()` function on errors. To get request ID functionality, use the SDK code generator." +references = ["smithy-rs#76", "smithy-rs#2129"] +meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client"} +author = "jdisanti" + +[[aws-sdk-rust]] +message = "The `message()` and `code()` methods on errors have been moved into `ProvideErrorMetadata` trait. This trait will need to be imported to continue calling these." +references = ["smithy-rs#76", "smithy-rs#2129"] +meta = { "breaking" = true, "tada" = false, "bug" = false } +author = "jdisanti" + +[[smithy-rs]] +message = "The `message()` and `code()` methods on errors have been moved into `ProvideErrorMetadata` trait. This trait will need to be imported to continue calling these." +references = ["smithy-rs#76", "smithy-rs#2129"] +meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client"} +author = "jdisanti" + [[aws-sdk-rust]] message = """ -Provide a way to retrieve fallback credentials if a call to `provide_credentials` is interrupted. An interrupt can occur when a timeout future is raced against a future for `provide_credentials`, and the former wins the race. A new method, `fallback_on_interrupt` on the `ProvideCredentials` trait, can be used in that case. The following code snippet from `LazyCredentialsCache::provide_cached_credentials` has been updated like so: -Before: +The `*Error` and `*ErrorKind` types have been combined to make error matching simpler. +
+Example with S3 +**Before:** ```rust -let timeout_future = self.sleeper.sleep(self.load_timeout); -// --snip-- -let future = Timeout::new(provider.provide_credentials(), timeout_future); -let result = cache - .get_or_load(|| { - async move { - let credentials = future.await.map_err(|_err| { - CredentialsError::provider_timed_out(load_timeout) - })??; - // --snip-- +let result = client + .get_object() + .bucket(BUCKET_NAME) + .key("some-key") + .send() + .await; +match result { + Ok(_output) => { /* Do something with the output */ } + Err(err) => match err.into_service_error() { + GetObjectError { kind, .. } => match kind { + GetObjectErrorKind::InvalidObjectState(value) => println!("invalid object state: {:?}", value), + GetObjectErrorKind::NoSuchKey(_) => println!("object didn't exist"), } - }).await; -// --snip-- + err @ GetObjectError { .. } if err.code() == Some("SomeUnmodeledError") => {} + err @ _ => return Err(err.into()), + }, +} ``` -After: +**After:** ```rust -let timeout_future = self.sleeper.sleep(self.load_timeout); -// --snip-- -let future = Timeout::new(provider.provide_credentials(), timeout_future); -let result = cache - .get_or_load(|| { - async move { - let credentials = match future.await { - Ok(creds) => creds?, - Err(_err) => match provider.fallback_on_interrupt() { // can provide fallback credentials - Some(creds) => creds, - None => return Err(CredentialsError::provider_timed_out(load_timeout)), - } - }; - // --snip-- +// Needed to access the `.code()` function on the error type: +use aws_sdk_s3::types::ProvideErrorMetadata; +let result = client + .get_object() + .bucket(BUCKET_NAME) + .key("some-key") + .send() + .await; +match result { + Ok(_output) => { /* Do something with the output */ } + Err(err) => match err.into_service_error() { + GetObjectError::InvalidObjectState(value) => { + println!("invalid object state: {:?}", value); + } + GetObjectError::NoSuchKey(_) => { + println!("object didn't exist"); } - }).await; -// --snip-- + err if err.code() == Some("SomeUnmodeledError") => {} + err @ _ => return Err(err.into()), + }, +} ``` +
""" -references = ["smithy-rs#2246"] -meta = { "breaking" = false, "tada" = false, "bug" = false } -author = "ysaito1001" +references = ["smithy-rs#76", "smithy-rs#2129", "smithy-rs#2075"] +meta = { "breaking" = true, "tada" = false, "bug" = false } +author = "jdisanti" [[smithy-rs]] -message = "The [`@uniqueItems`](https://smithy.io/2.0/spec/constraint-traits.html#uniqueitems-trait) trait on `list` shapes is now supported in server SDKs." -references = ["smithy-rs#2232", "smithy-rs#1670"] -meta = { "breaking" = false, "tada" = true, "bug" = false, "target" = "server"} -author = "david-perez" - -[[aws-sdk-rust]] message = """ -Add static stability support to IMDS credentials provider. It does not alter common use cases for the provider, but allows the provider to serve expired credentials in case IMDS is unreachable. This allows requests to be dispatched to a target service with expired credentials. This, in turn, allows the target service to make the ultimate decision as to whether requests sent are valid or not. +The `*Error` and `*ErrorKind` types have been combined to make error matching simpler. +
+Example with S3 +**Before:** +```rust +let result = client + .get_object() + .bucket(BUCKET_NAME) + .key("some-key") + .send() + .await; +match result { + Ok(_output) => { /* Do something with the output */ } + Err(err) => match err.into_service_error() { + GetObjectError { kind, .. } => match kind { + GetObjectErrorKind::InvalidObjectState(value) => println!("invalid object state: {:?}", value), + GetObjectErrorKind::NoSuchKey(_) => println!("object didn't exist"), + } + err @ GetObjectError { .. } if err.code() == Some("SomeUnmodeledError") => {} + err @ _ => return Err(err.into()), + }, +} +``` +**After:** +```rust +// Needed to access the `.code()` function on the error type: +use aws_sdk_s3::types::ProvideErrorMetadata; +let result = client + .get_object() + .bucket(BUCKET_NAME) + .key("some-key") + .send() + .await; +match result { + Ok(_output) => { /* Do something with the output */ } + Err(err) => match err.into_service_error() { + GetObjectError::InvalidObjectState(value) => { + println!("invalid object state: {:?}", value); + } + GetObjectError::NoSuchKey(_) => { + println!("object didn't exist"); + } + err if err.code() == Some("SomeUnmodeledError") => {} + err @ _ => return Err(err.into()), + }, +} +``` +
""" -references = ["smithy-rs#2258"] -meta = { "breaking" = false, "tada" = true, "bug" = false } -author = "ysaito1001" +references = ["smithy-rs#76", "smithy-rs#2129", "smithy-rs#2075"] +meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client"} +author = "jdisanti" [[smithy-rs]] -message = "Fix broken doc link for `tokio_stream::Stream` that is a re-export of `futures_core::Stream`." -references = ["smithy-rs#2271"] -meta = { "breaking" = false, "tada" = false, "bug" = true, "target" = "client"} -author = "ysaito1001" +message = "`aws_smithy_types::Error` has been renamed to `aws_smithy_types::error::ErrorMetadata`." +references = ["smithy-rs#76", "smithy-rs#2129"] +meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client"} +author = "jdisanti" [[aws-sdk-rust]] -message = "Fix broken doc link for `tokio_stream::Stream` that is a re-export of `futures_core::Stream`." -references = ["smithy-rs#2271"] -meta = { "breaking" = false, "tada" = false, "bug" = true } -author = "ysaito1001" +message = "`aws_smithy_types::Error` has been renamed to `aws_smithy_types::error::ErrorMetadata`." +references = ["smithy-rs#76", "smithy-rs#2129"] +meta = { "breaking" = true, "tada" = false, "bug" = false } +author = "jdisanti" + +[[aws-sdk-rust]] +message = "Fluent builder methods on the client are now marked as deprecated when the related operation is deprecated." +references = ["aws-sdk-rust#740"] +meta = { "breaking" = false, "tada" = true, "bug" = true } +author = "Velfi" + +[[smithy-rs]] +message = "Fluent builder methods on the client are now marked as deprecated when the related operation is deprecated." +references = ["aws-sdk-rust#740"] +meta = { "breaking" = false, "tada" = true, "bug" = true, "target" = "client"} +author = "Velfi" [[smithy-rs]] message = """ -Fix `name` and `absolute` methods on `OperationExtension`. -The older, [now removed](https://github.com/awslabs/smithy-rs/pull/2161), service builder would insert `OperationExtension` into the `http::Response` containing the [absolute shape ID](https://smithy.io/2.0/spec/model.html#grammar-token-smithy-AbsoluteRootShapeId) with the `#` symbol replaced with a `.`. When [reintroduced](https://github.com/awslabs/smithy-rs/pull/2157) into the new service builder machinery the behavior was changed - we now do _not_ perform the replace. This change fixes the documentation and `name`/`absolute` methods of the `OperationExtension` API to match this new behavior. -In the old service builder, `OperationExtension` was initialized, by the framework, and then used as follows: -```rust -let ext = OperationExtension::new("com.amazonaws.CompleteSnapshot"); -// This is expected -let name = ext.name(); // "CompleteSnapshot" -let namespace = ext.namespace(); // = "com.amazonaws"; +Add support for the `awsQueryCompatible` trait. This allows services to continue supporting a custom error code (via the `awsQueryError` trait) when the services migrate their protocol from `awsQuery` to `awsJson1_0` annotated with `awsQueryCompatible`. +
+Click to expand for more details... + +After the migration, services will include an additional header `x-amzn-query-error` in their responses whose value is in the form of `;`. An example response looks something like ``` -When reintroduced, `OperationExtension` was initialized by the `Plugin` and then used as follows: -```rust -let ext = OperationExtension::new("com.amazonaws#CompleteSnapshot"); -// This is the bug -let name = ext.name(); // "amazonaws#CompleteSnapshot" -let namespace = ext.namespace(); // = "com"; +HTTP/1.1 400 +x-amzn-query-error: AWS.SimpleQueueService.NonExistentQueue;Sender +Date: Wed, 08 Sep 2021 23:46:52 GMT +Content-Type: application/x-amz-json-1.0 +Content-Length: 163 + +{ + "__type": "com.amazonaws.sqs#QueueDoesNotExist", + "message": "some user-visible message" +} ``` -The intended behavior is now restored: +`` is `AWS.SimpleQueueService.NonExistentQueue` and `` is `Sender`. + +If an operation results in an error that causes a service to send back the response above, you can access `` and `` as follows: ```rust -let ext = OperationExtension::new("com.amazonaws#CompleteSnapshot"); -// This is expected -let name = ext.name(); // "CompleteSnapshot" -let namespace = ext.namespace(); // = "com.amazonaws"; +match client.some_operation().send().await { + Ok(_) => { /* success */ } + Err(sdk_err) => { + let err = sdk_err.into_service_error(); + assert_eq!( + error.meta().code(), + Some("AWS.SimpleQueueService.NonExistentQueue"), + ); + assert_eq!(error.meta().extra("type"), Some("Sender")); + } +} +
``` -The rationale behind this change is that the previous design was tailored towards a specific internal use case and shouldn't be enforced on all customers. """ -references = ["smithy-rs#2276"] -meta = { "breaking" = true, "tada" = false, "bug" = true, "target" = "server"} -author = "hlbarber" +references = ["smithy-rs#2398"] +meta = { "breaking" = false, "tada" = true, "bug" = false } +author = "ysaito1001" [[aws-sdk-rust]] -message = """ -Fix request canonicalization for HTTP requests with repeated headers (for example S3's `GetObjectAttributes`). Previously requests with repeated headers would fail with a 403 signature mismatch due to this bug. -""" -references = ["smithy-rs#2261", "aws-sdk-rust#720"] -meta = { "breaking" = false, "tada" = false, "bug" = true } -author = "nipunn1313" +message = "`SdkError` variants can now be constructed for easier unit testing." +references = ["smithy-rs#2428", "smithy-rs#2208"] +meta = { "breaking" = false, "tada" = true, "bug" = false } +author = "jdisanti" [[smithy-rs]] -message = """Add serde crate to `aws-smithy-types`. - -It's behind the feature gate `aws_sdk_unstable` which can only be enabled via a `--cfg` flag. -""" -references = ["smithy-rs#1944"] -meta = { "breaking" = false, "tada" = false, "bug" = false } -author = "thomas-k-cameron" +message = "`SdkError` variants can now be constructed for easier unit testing." +references = ["smithy-rs#2428", "smithy-rs#2208"] +meta = { "breaking" = false, "tada" = true, "bug" = false, "target" = "client" } +author = "jdisanti" diff --git a/aws/SDK_CHANGELOG.next.json b/aws/SDK_CHANGELOG.next.json index 5a4025621f7..5673dbc79f6 100644 --- a/aws/SDK_CHANGELOG.next.json +++ b/aws/SDK_CHANGELOG.next.json @@ -498,4 +498,4 @@ } ], "aws-sdk-model": [] -} \ No newline at end of file +} diff --git a/aws/rust-runtime/aws-config/external-types.toml b/aws/rust-runtime/aws-config/external-types.toml index 9a7dbc128e3..6935fe9bc0e 100644 --- a/aws/rust-runtime/aws-config/external-types.toml +++ b/aws/rust-runtime/aws-config/external-types.toml @@ -29,9 +29,6 @@ allowed_external_types = [ "http::uri::Uri", "tower_service::Service", - # TODO(https://github.com/awslabs/smithy-rs/issues/1193): Decide if `InvalidUri` should be exposed - "http::uri::InvalidUri", - # TODO(https://github.com/awslabs/smithy-rs/issues/1193): Decide if the following should be exposed "hyper::client::connect::Connection", "tokio::io::async_read::AsyncRead", diff --git a/aws/rust-runtime/aws-config/src/imds/client.rs b/aws/rust-runtime/aws-config/src/imds/client.rs index ba5cea046cb..e791ec929a7 100644 --- a/aws/rust-runtime/aws-config/src/imds/client.rs +++ b/aws/rust-runtime/aws-config/src/imds/client.rs @@ -917,7 +917,7 @@ pub(crate) mod test { imds_request("http://169.254.169.254/latest/metadata", TOKEN_A), http::Response::builder() .status(200) - .body(SdkBody::from(vec![0xA0 as u8, 0xA1 as u8])) + .body(SdkBody::from(vec![0xA0, 0xA1])) .unwrap(), ), ]); diff --git a/aws/rust-runtime/aws-config/src/lib.rs b/aws/rust-runtime/aws-config/src/lib.rs index 8ca695a7700..fc82e6d1fe7 100644 --- a/aws/rust-runtime/aws-config/src/lib.rs +++ b/aws/rust-runtime/aws-config/src/lib.rs @@ -3,6 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +#![allow(clippy::derive_partial_eq_without_eq)] #![warn( missing_debug_implementations, missing_docs, diff --git a/aws/rust-runtime/aws-config/src/profile/credentials.rs b/aws/rust-runtime/aws-config/src/profile/credentials.rs index ba7fb9241ed..9ce085503b4 100644 --- a/aws/rust-runtime/aws-config/src/profile/credentials.rs +++ b/aws/rust-runtime/aws-config/src/profile/credentials.rs @@ -61,9 +61,8 @@ impl ProvideCredentials for ProfileFileCredentialsProvider { /// let provider = ProfileFileCredentialsProvider::builder().build(); /// ``` /// -/// _Note: Profile providers to not implement any caching. They will reload and reparse the profile -/// from the file system when called. See [CredentialsCache](aws_credential_types::cache::CredentialsCache) for -/// more information about caching._ +/// _Note: Profile providers, when called, will load and parse the profile from the file system +/// only once. Parsed file contents will be cached indefinitely._ /// /// This provider supports several different credentials formats: /// ### Credentials defined explicitly within the file diff --git a/aws/rust-runtime/aws-config/src/profile/credentials/exec.rs b/aws/rust-runtime/aws-config/src/profile/credentials/exec.rs index 51449f22efa..f930b1a8d8f 100644 --- a/aws/rust-runtime/aws-config/src/profile/credentials/exec.rs +++ b/aws/rust-runtime/aws-config/src/profile/credentials/exec.rs @@ -3,14 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -use std::sync::Arc; - -use aws_sdk_sts::operation::AssumeRole; -use aws_sdk_sts::{Config, Credentials}; -use aws_types::region::Region; - use super::repr::{self, BaseProvider}; - use crate::credential_process::CredentialProcessProvider; use crate::profile::credentials::ProfileFileError; use crate::provider_config::ProviderConfig; @@ -18,10 +11,13 @@ use crate::sso::{SsoConfig, SsoCredentialsProvider}; use crate::sts; use crate::web_identity_token::{StaticConfiguration, WebIdentityTokenCredentialsProvider}; use aws_credential_types::provider::{self, error::CredentialsError, ProvideCredentials}; +use aws_sdk_sts::input::AssumeRoleInput; use aws_sdk_sts::middleware::DefaultMiddleware; +use aws_sdk_sts::{Config, Credentials}; use aws_smithy_client::erase::DynConnector; - +use aws_types::region::Region; use std::fmt::Debug; +use std::sync::Arc; #[derive(Debug)] pub(super) struct AssumeRoleProvider { @@ -51,7 +47,7 @@ impl AssumeRoleProvider { .as_ref() .cloned() .unwrap_or_else(|| sts::util::default_session_name("assume-role-from-profile")); - let operation = AssumeRole::builder() + let operation = AssumeRoleInput::builder() .role_arn(&self.role_arn) .set_external_id(self.external_id.clone()) .role_session_name(session_name) diff --git a/aws/rust-runtime/aws-config/src/sso.rs b/aws/rust-runtime/aws-config/src/sso.rs index 0f58264645b..3881a217ab9 100644 --- a/aws/rust-runtime/aws-config/src/sso.rs +++ b/aws/rust-runtime/aws-config/src/sso.rs @@ -211,7 +211,7 @@ async fn load_sso_credentials( let config = aws_sdk_sso::Config::builder() .region(sso_config.region.clone()) .build(); - let operation = aws_sdk_sso::operation::GetRoleCredentials::builder() + let operation = aws_sdk_sso::input::GetRoleCredentialsInput::builder() .role_name(&sso_config.role_name) .access_token(&*token.access_token) .account_id(&sso_config.account_id) diff --git a/aws/rust-runtime/aws-config/src/sts/assume_role.rs b/aws/rust-runtime/aws-config/src/sts/assume_role.rs index 422b6441515..ad24b11bb56 100644 --- a/aws/rust-runtime/aws-config/src/sts/assume_role.rs +++ b/aws/rust-runtime/aws-config/src/sts/assume_role.rs @@ -5,19 +5,18 @@ //! Assume credentials for a role through the AWS Security Token Service (STS). +use crate::provider_config::ProviderConfig; use aws_credential_types::cache::CredentialsCache; use aws_credential_types::provider::{self, error::CredentialsError, future, ProvideCredentials}; -use aws_sdk_sts::error::AssumeRoleErrorKind; +use aws_sdk_sts::error::AssumeRoleError; +use aws_sdk_sts::input::AssumeRoleInput; use aws_sdk_sts::middleware::DefaultMiddleware; use aws_sdk_sts::model::PolicyDescriptorType; -use aws_sdk_sts::operation::AssumeRole; use aws_smithy_client::erase::DynConnector; use aws_smithy_http::result::SdkError; +use aws_smithy_types::error::display::DisplayErrorContext; use aws_types::region::Region; use std::time::Duration; - -use crate::provider_config::ProviderConfig; -use aws_smithy_types::error::display::DisplayErrorContext; use tracing::Instrument; /// Credentials provider that uses credentials provided by another provider to assume a role @@ -225,7 +224,7 @@ impl AssumeRoleProviderBuilder { .session_name .unwrap_or_else(|| super::util::default_session_name("assume-role-provider")); - let operation = AssumeRole::builder() + let operation = AssumeRoleInput::builder() .set_role_arn(Some(self.role_arn)) .set_external_id(self.external_id) .set_role_session_name(Some(session_name)) @@ -266,9 +265,9 @@ impl Inner { } Err(SdkError::ServiceError(ref context)) if matches!( - context.err().kind, - AssumeRoleErrorKind::RegionDisabledException(_) - | AssumeRoleErrorKind::MalformedPolicyDocumentException(_) + context.err(), + AssumeRoleError::RegionDisabledException(_) + | AssumeRoleError::MalformedPolicyDocumentException(_) ) => { Err(CredentialsError::invalid_configuration( diff --git a/aws/rust-runtime/aws-config/src/web_identity_token.rs b/aws/rust-runtime/aws-config/src/web_identity_token.rs index e3c45c7bb8a..82c184897fb 100644 --- a/aws/rust-runtime/aws-config/src/web_identity_token.rs +++ b/aws/rust-runtime/aws-config/src/web_identity_token.rs @@ -236,7 +236,7 @@ async fn load_credentials( .region(region.clone()) .build(); - let operation = aws_sdk_sts::operation::AssumeRoleWithWebIdentity::builder() + let operation = aws_sdk_sts::input::AssumeRoleWithWebIdentityInput::builder() .role_arn(role_arn) .role_session_name(session_name) .web_identity_token(token) diff --git a/aws/rust-runtime/aws-credential-types/Cargo.toml b/aws/rust-runtime/aws-credential-types/Cargo.toml index bf7e3325681..dd241e764ec 100644 --- a/aws/rust-runtime/aws-credential-types/Cargo.toml +++ b/aws/rust-runtime/aws-credential-types/Cargo.toml @@ -14,6 +14,7 @@ test-util = [] [dependencies] aws-smithy-async = { path = "../../../rust-runtime/aws-smithy-async" } aws-smithy-types = { path = "../../../rust-runtime/aws-smithy-types" } +fastrand = "1.4.0" tokio = { version = "1.8.4", features = ["sync"] } tracing = "0.1" zeroize = "1" diff --git a/aws/rust-runtime/aws-credential-types/src/cache/lazy_caching.rs b/aws/rust-runtime/aws-credential-types/src/cache/lazy_caching.rs index 3a2459c5974..1081b8f3364 100644 --- a/aws/rust-runtime/aws-credential-types/src/cache/lazy_caching.rs +++ b/aws/rust-runtime/aws-credential-types/src/cache/lazy_caching.rs @@ -20,6 +20,7 @@ use crate::time_source::TimeSource; const DEFAULT_LOAD_TIMEOUT: Duration = Duration::from_secs(5); const DEFAULT_CREDENTIAL_EXPIRATION: Duration = Duration::from_secs(15 * 60); const DEFAULT_BUFFER_TIME: Duration = Duration::from_secs(10); +const DEFAULT_BUFFER_TIME_JITTER_FRACTION: fn() -> f64 = fastrand::f64; #[derive(Debug)] pub(crate) struct LazyCredentialsCache { @@ -28,6 +29,8 @@ pub(crate) struct LazyCredentialsCache { cache: ExpiringCache, provider: SharedCredentialsProvider, load_timeout: Duration, + buffer_time: Duration, + buffer_time_jitter_fraction: fn() -> f64, default_credential_expiration: Duration, } @@ -37,8 +40,9 @@ impl LazyCredentialsCache { sleeper: Arc, provider: SharedCredentialsProvider, load_timeout: Duration, - default_credential_expiration: Duration, buffer_time: Duration, + buffer_time_jitter_fraction: fn() -> f64, + default_credential_expiration: Duration, ) -> Self { Self { time, @@ -46,6 +50,8 @@ impl LazyCredentialsCache { cache: ExpiringCache::new(buffer_time), provider, load_timeout, + buffer_time, + buffer_time_jitter_fraction, default_credential_expiration, } } @@ -95,17 +101,28 @@ impl ProvideCachedCredentials for LazyCredentialsCache { let expiry = credentials .expiry() .unwrap_or(now + default_credential_expiration); - Ok((credentials, expiry)) + + let jitter = self + .buffer_time + .mul_f64((self.buffer_time_jitter_fraction)()); + + // Logging for cache miss should be emitted here as opposed to after the call to + // `cache.get_or_load` above. In the case of multiple threads concurrently executing + // `cache.get_or_load`, logging inside `cache.get_or_load` ensures that it is emitted + // only once for the first thread that succeeds in populating a cache value. + info!( + "credentials cache miss occurred; added new AWS credentials (took {:?})", + start_time.elapsed() + ); + + Ok((credentials, expiry + jitter)) } // Only instrument the the actual load future so that no span // is opened if the cache decides not to execute it. .instrument(span) }) .await; - info!( - "credentials cache miss occurred; retrieved new AWS credentials (took {:?})", - start_time.elapsed() - ); + debug!("loaded credentials"); result } }) @@ -125,8 +142,8 @@ mod builder { use super::TimeSource; use super::{ - LazyCredentialsCache, DEFAULT_BUFFER_TIME, DEFAULT_CREDENTIAL_EXPIRATION, - DEFAULT_LOAD_TIMEOUT, + LazyCredentialsCache, DEFAULT_BUFFER_TIME, DEFAULT_BUFFER_TIME_JITTER_FRACTION, + DEFAULT_CREDENTIAL_EXPIRATION, DEFAULT_LOAD_TIMEOUT, }; /// Builder for constructing a `LazyCredentialsCache`. @@ -147,6 +164,7 @@ mod builder { time_source: Option, load_timeout: Option, buffer_time: Option, + buffer_time_jitter_fraction: Option f64>, default_credential_expiration: Option, } @@ -228,6 +246,38 @@ mod builder { self } + /// A random percentage by which buffer time is jittered for randomization. + /// + /// For example, if credentials are expiring in 15 minutes, the buffer time is 10 seconds, + /// and buffer time jitter fraction is 0.2, then buffer time is adjusted to 8 seconds. + /// Therefore, any requests made after 14 minutes and 52 seconds will load new credentials. + /// + /// Defaults to a randomly generated value between 0.0 and 1.0. This setter is for testing only. + #[cfg(feature = "test-util")] + pub fn buffer_time_jitter_fraction( + mut self, + buffer_time_jitter_fraction: fn() -> f64, + ) -> Self { + self.set_buffer_time_jitter_fraction(Some(buffer_time_jitter_fraction)); + self + } + + /// A random percentage by which buffer time is jittered for randomization. + /// + /// For example, if credentials are expiring in 15 minutes, the buffer time is 10 seconds, + /// and buffer time jitter fraction is 0.2, then buffer time is adjusted to 8 seconds. + /// Therefore, any requests made after 14 minutes and 52 seconds will load new credentials. + /// + /// Defaults to a randomly generated value between 0.0 and 1.0. This setter is for testing only. + #[cfg(feature = "test-util")] + pub fn set_buffer_time_jitter_fraction( + &mut self, + buffer_time_jitter_fraction: Option f64>, + ) -> &mut Self { + self.buffer_time_jitter_fraction = buffer_time_jitter_fraction; + self + } + /// Default expiration time to set on credentials if they don't have an expiration time. /// /// This is only used if the given [`ProvideCredentials`](crate::provider::ProvideCredentials) returns @@ -283,8 +333,10 @@ mod builder { }), provider, self.load_timeout.unwrap_or(DEFAULT_LOAD_TIMEOUT), - default_credential_expiration, self.buffer_time.unwrap_or(DEFAULT_BUFFER_TIME), + self.buffer_time_jitter_fraction + .unwrap_or(DEFAULT_BUFFER_TIME_JITTER_FRACTION), + default_credential_expiration, ) } } @@ -310,8 +362,11 @@ mod tests { DEFAULT_LOAD_TIMEOUT, }; + const BUFFER_TIME_NO_JITTER: fn() -> f64 = || 0_f64; + fn test_provider( time: TimeSource, + buffer_time_jitter_fraction: fn() -> f64, load_list: Vec, ) -> LazyCredentialsCache { let load_list = Arc::new(Mutex::new(load_list)); @@ -327,8 +382,9 @@ mod tests { } })), DEFAULT_LOAD_TIMEOUT, - DEFAULT_CREDENTIAL_EXPIRATION, DEFAULT_BUFFER_TIME, + buffer_time_jitter_fraction, + DEFAULT_CREDENTIAL_EXPIRATION, ) } @@ -361,8 +417,9 @@ mod tests { Arc::new(TokioSleep::new()), provider, DEFAULT_LOAD_TIMEOUT, - DEFAULT_CREDENTIAL_EXPIRATION, DEFAULT_BUFFER_TIME, + BUFFER_TIME_NO_JITTER, + DEFAULT_CREDENTIAL_EXPIRATION, ); assert_eq!( epoch_secs(1000), @@ -381,6 +438,7 @@ mod tests { let mut time = TestingTimeSource::new(epoch_secs(100)); let credentials_cache = test_provider( TimeSource::testing(&time), + BUFFER_TIME_NO_JITTER, vec![ Ok(credentials(1000)), Ok(credentials(2000)), @@ -404,6 +462,7 @@ mod tests { let mut time = TestingTimeSource::new(epoch_secs(100)); let credentials_cache = test_provider( TimeSource::testing(&time), + BUFFER_TIME_NO_JITTER, vec![ Ok(credentials(1000)), Err(CredentialsError::not_loaded("failed")), @@ -430,6 +489,7 @@ mod tests { let time = TestingTimeSource::new(epoch_secs(0)); let credentials_cache = Arc::new(test_provider( TimeSource::testing(&time), + BUFFER_TIME_NO_JITTER, vec![ Ok(credentials(500)), Ok(credentials(1500)), @@ -480,8 +540,9 @@ mod tests { Ok(credentials(1000)) })), Duration::from_millis(5), - DEFAULT_CREDENTIAL_EXPIRATION, DEFAULT_BUFFER_TIME, + BUFFER_TIME_NO_JITTER, + DEFAULT_CREDENTIAL_EXPIRATION, ); assert!(matches!( @@ -489,4 +550,30 @@ mod tests { Err(CredentialsError::ProviderTimedOut { .. }) )); } + + #[tokio::test] + async fn buffer_time_jitter() { + let mut time = TestingTimeSource::new(epoch_secs(100)); + let buffer_time_jitter_fraction = || 0.5_f64; + let credentials_cache = test_provider( + TimeSource::testing(&time), + buffer_time_jitter_fraction, + vec![Ok(credentials(1000)), Ok(credentials(2000))], + ); + + expect_creds(1000, &credentials_cache).await; + let buffer_time_with_jitter = + (DEFAULT_BUFFER_TIME.as_secs_f64() * buffer_time_jitter_fraction()) as u64; + assert_eq!(buffer_time_with_jitter, 5); + // Advance time to the point where the first credentials are about to expire (but haven't). + let almost_expired_secs = 1000 - buffer_time_with_jitter - 1; + time.set_time(epoch_secs(almost_expired_secs)); + // We should still use the first credentials. + expect_creds(1000, &credentials_cache).await; + // Now let the first credentials expire. + let expired_secs = almost_expired_secs + 1; + time.set_time(epoch_secs(expired_secs)); + // Now that the first credentials have been expired, the second credentials will be retrieved. + expect_creds(2000, &credentials_cache).await; + } } diff --git a/aws/rust-runtime/aws-credential-types/src/lib.rs b/aws/rust-runtime/aws-credential-types/src/lib.rs index 1f790c3fca5..b2f8330b589 100644 --- a/aws/rust-runtime/aws-credential-types/src/lib.rs +++ b/aws/rust-runtime/aws-credential-types/src/lib.rs @@ -8,6 +8,7 @@ //! * An opaque struct representing credentials //! * Concrete implementations of credentials caching +#![allow(clippy::derive_partial_eq_without_eq)] #![warn( missing_debug_implementations, missing_docs, diff --git a/aws/rust-runtime/aws-endpoint/src/lib.rs b/aws/rust-runtime/aws-endpoint/src/lib.rs index c8951431845..deb4b6b670b 100644 --- a/aws/rust-runtime/aws-endpoint/src/lib.rs +++ b/aws/rust-runtime/aws-endpoint/src/lib.rs @@ -3,6 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ +#![allow(clippy::derive_partial_eq_without_eq)] + use std::collections::HashMap; use std::error::Error; use std::fmt; @@ -270,7 +272,7 @@ mod test { let mut req = operation::Request::new(req); { let mut props = req.properties_mut(); - props.insert(region.clone()); + props.insert(region); props.insert(SigningService::from_static("qldb")); props.insert(endpoint); }; diff --git a/aws/rust-runtime/aws-http/src/auth.rs b/aws/rust-runtime/aws-http/src/auth.rs index c91b4c5bb51..98e0e219bb0 100644 --- a/aws/rust-runtime/aws-http/src/auth.rs +++ b/aws/rust-runtime/aws-http/src/auth.rs @@ -188,10 +188,7 @@ mod tests { .create_cache(SharedCredentialsProvider::new(provide_credentials_fn( || async { Ok(Credentials::for_tests()) }, ))); - set_credentials_cache( - &mut req.properties_mut(), - SharedCredentialsCache::from(credentials_cache), - ); + set_credentials_cache(&mut req.properties_mut(), credentials_cache); let req = CredentialsStage::new() .apply(req) .await diff --git a/aws/rust-runtime/aws-http/src/lib.rs b/aws/rust-runtime/aws-http/src/lib.rs index b000c3d6aea..d5307bcba3d 100644 --- a/aws/rust-runtime/aws-http/src/lib.rs +++ b/aws/rust-runtime/aws-http/src/lib.rs @@ -3,8 +3,9 @@ * SPDX-License-Identifier: Apache-2.0 */ -//! Provides user agent and credentials middleware for the AWS SDK. +//! AWS-specific middleware implementations and HTTP-related features. +#![allow(clippy::derive_partial_eq_without_eq)] #![warn( missing_docs, rustdoc::missing_crate_level_docs, @@ -27,3 +28,6 @@ pub mod user_agent; /// AWS-specific content-encoding tools pub mod content_encoding; + +/// AWS-specific request ID support +pub mod request_id; diff --git a/aws/rust-runtime/aws-http/src/request_id.rs b/aws/rust-runtime/aws-http/src/request_id.rs new file mode 100644 index 00000000000..c3f19272280 --- /dev/null +++ b/aws/rust-runtime/aws-http/src/request_id.rs @@ -0,0 +1,182 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use aws_smithy_http::http::HttpHeaders; +use aws_smithy_http::operation; +use aws_smithy_http::result::SdkError; +use aws_smithy_types::error::metadata::{ + Builder as ErrorMetadataBuilder, ErrorMetadata, ProvideErrorMetadata, +}; +use aws_smithy_types::error::Unhandled; +use http::{HeaderMap, HeaderValue}; + +/// Constant for the [`ErrorMetadata`] extra field that contains the request ID +const AWS_REQUEST_ID: &str = "aws_request_id"; + +/// Implementers add a function to return an AWS request ID +pub trait RequestId { + /// Returns the request ID, or `None` if the service could not be reached. + fn request_id(&self) -> Option<&str>; +} + +impl RequestId for SdkError +where + R: HttpHeaders, +{ + fn request_id(&self) -> Option<&str> { + match self { + Self::ResponseError(err) => extract_request_id(err.raw().http_headers()), + Self::ServiceError(err) => extract_request_id(err.raw().http_headers()), + _ => None, + } + } +} + +impl RequestId for ErrorMetadata { + fn request_id(&self) -> Option<&str> { + self.extra(AWS_REQUEST_ID) + } +} + +impl RequestId for Unhandled { + fn request_id(&self) -> Option<&str> { + self.meta().request_id() + } +} + +impl RequestId for operation::Response { + fn request_id(&self) -> Option<&str> { + extract_request_id(self.http().headers()) + } +} + +impl RequestId for http::Response { + fn request_id(&self) -> Option<&str> { + extract_request_id(self.headers()) + } +} + +impl RequestId for Result +where + O: RequestId, + E: RequestId, +{ + fn request_id(&self) -> Option<&str> { + match self { + Ok(ok) => ok.request_id(), + Err(err) => err.request_id(), + } + } +} + +/// Applies a request ID to a generic error builder +#[doc(hidden)] +pub fn apply_request_id( + builder: ErrorMetadataBuilder, + headers: &HeaderMap, +) -> ErrorMetadataBuilder { + if let Some(request_id) = extract_request_id(headers) { + builder.custom(AWS_REQUEST_ID, request_id) + } else { + builder + } +} + +/// Extracts a request ID from HTTP response headers +fn extract_request_id(headers: &HeaderMap) -> Option<&str> { + headers + .get("x-amzn-requestid") + .or_else(|| headers.get("x-amz-request-id")) + .and_then(|value| value.to_str().ok()) +} + +#[cfg(test)] +mod tests { + use super::*; + use aws_smithy_http::body::SdkBody; + use http::Response; + + #[test] + fn test_request_id_sdk_error() { + let without_request_id = + || operation::Response::new(Response::builder().body(SdkBody::empty()).unwrap()); + let with_request_id = || { + operation::Response::new( + Response::builder() + .header( + "x-amzn-requestid", + HeaderValue::from_static("some-request-id"), + ) + .body(SdkBody::empty()) + .unwrap(), + ) + }; + assert_eq!( + None, + SdkError::<(), _>::response_error("test", without_request_id()).request_id() + ); + assert_eq!( + Some("some-request-id"), + SdkError::<(), _>::response_error("test", with_request_id()).request_id() + ); + assert_eq!( + None, + SdkError::service_error((), without_request_id()).request_id() + ); + assert_eq!( + Some("some-request-id"), + SdkError::service_error((), with_request_id()).request_id() + ); + } + + #[test] + fn test_extract_request_id() { + let mut headers = HeaderMap::new(); + assert_eq!(None, extract_request_id(&headers)); + + headers.append( + "x-amzn-requestid", + HeaderValue::from_static("some-request-id"), + ); + assert_eq!(Some("some-request-id"), extract_request_id(&headers)); + + headers.append( + "x-amz-request-id", + HeaderValue::from_static("other-request-id"), + ); + assert_eq!(Some("some-request-id"), extract_request_id(&headers)); + + headers.remove("x-amzn-requestid"); + assert_eq!(Some("other-request-id"), extract_request_id(&headers)); + } + + #[test] + fn test_apply_request_id() { + let mut headers = HeaderMap::new(); + assert_eq!( + ErrorMetadata::builder().build(), + apply_request_id(ErrorMetadata::builder(), &headers).build(), + ); + + headers.append( + "x-amzn-requestid", + HeaderValue::from_static("some-request-id"), + ); + assert_eq!( + ErrorMetadata::builder() + .custom(AWS_REQUEST_ID, "some-request-id") + .build(), + apply_request_id(ErrorMetadata::builder(), &headers).build(), + ); + } + + #[test] + fn test_error_metadata_request_id_impl() { + let err = ErrorMetadata::builder() + .custom(AWS_REQUEST_ID, "some-request-id") + .build(); + assert_eq!(Some("some-request-id"), err.request_id()); + } +} diff --git a/aws/rust-runtime/aws-inlineable/src/http_body_checksum.rs b/aws/rust-runtime/aws-inlineable/src/http_body_checksum.rs index 59dea2ca41d..d99e1b11231 100644 --- a/aws/rust-runtime/aws-inlineable/src/http_body_checksum.rs +++ b/aws/rust-runtime/aws-inlineable/src/http_body_checksum.rs @@ -269,7 +269,7 @@ mod tests { for i in 0..10000 { let line = format!("This is a large file created for testing purposes {}", i); - file.as_file_mut().write(line.as_bytes()).unwrap(); + file.as_file_mut().write_all(line.as_bytes()).unwrap(); crc32c_checksum.update(line.as_bytes()); } diff --git a/aws/rust-runtime/aws-inlineable/src/lib.rs b/aws/rust-runtime/aws-inlineable/src/lib.rs index b4e00994d41..d5ed3b8be39 100644 --- a/aws/rust-runtime/aws-inlineable/src/lib.rs +++ b/aws/rust-runtime/aws-inlineable/src/lib.rs @@ -10,6 +10,7 @@ //! This is _NOT_ intended to be an actual crate. It is a cargo project to solely to aid //! with local development of the SDK. +#![allow(clippy::derive_partial_eq_without_eq)] #![warn( missing_docs, rustdoc::missing_crate_level_docs, @@ -23,9 +24,11 @@ pub mod no_credentials; /// Support types required for adding presigning to an operation in a generated service. pub mod presigning; +// TODO(CrateReorganization): Delete the `old_presigning` module +pub mod old_presigning; -/// Special logic for handling S3's error responses. -pub mod s3_errors; +/// Special logic for extracting request IDs from S3's responses. +pub mod s3_request_id; /// Glacier-specific checksumming behavior pub mod glacier_checksums; diff --git a/aws/rust-runtime/aws-inlineable/src/old_presigning.rs b/aws/rust-runtime/aws-inlineable/src/old_presigning.rs new file mode 100644 index 00000000000..cf95c3901d1 --- /dev/null +++ b/aws/rust-runtime/aws-inlineable/src/old_presigning.rs @@ -0,0 +1,282 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! Presigned request types and configuration. + +/// Presigning config and builder +pub mod config { + use std::fmt; + use std::time::{Duration, SystemTime}; + + const ONE_WEEK: Duration = Duration::from_secs(604800); + + /// Presigning config values required for creating a presigned request. + #[non_exhaustive] + #[derive(Debug, Clone)] + pub struct PresigningConfig { + start_time: SystemTime, + expires_in: Duration, + } + + impl PresigningConfig { + /// Creates a `PresigningConfig` with the given `expires_in` duration. + /// + /// The `expires_in` duration is the total amount of time the presigned request should + /// be valid for. Other config values are defaulted. + /// + /// Credential expiration time takes priority over the `expires_in` value. + /// If the credentials used to sign the request expire before the presigned request is + /// set to expire, then the presigned request will become invalid. + pub fn expires_in(expires_in: Duration) -> Result { + Self::builder().expires_in(expires_in).build() + } + + /// Creates a new builder for creating a `PresigningConfig`. + pub fn builder() -> Builder { + Builder::default() + } + + /// Returns the amount of time the presigned request should be valid for. + pub fn expires(&self) -> Duration { + self.expires_in + } + + /// Returns the start time. The presigned request will be valid between this and the end + /// time produced by adding the `expires()` value to it. + pub fn start_time(&self) -> SystemTime { + self.start_time + } + } + + #[derive(Debug)] + enum ErrorKind { + /// Presigned requests cannot be valid for longer than one week. + ExpiresInDurationTooLong, + + /// The `PresigningConfig` builder requires a value for `expires_in`. + ExpiresInRequired, + } + + /// `PresigningConfig` build errors. + #[derive(Debug)] + pub struct Error { + kind: ErrorKind, + } + + impl std::error::Error for Error {} + + impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.kind { + ErrorKind::ExpiresInDurationTooLong => { + write!(f, "`expires_in` must be no longer than one week") + } + ErrorKind::ExpiresInRequired => write!(f, "`expires_in` is required"), + } + } + } + + impl From for Error { + fn from(kind: ErrorKind) -> Self { + Self { kind } + } + } + + /// Builder used to create `PresigningConfig`. + #[non_exhaustive] + #[derive(Default, Debug)] + pub struct Builder { + start_time: Option, + expires_in: Option, + } + + impl Builder { + /// Sets the start time for the presigned request. + /// + /// The request will start to be valid at this time, and will cease to be valid after + /// the end time, which can be determined by adding the `expires_in` duration to this + /// start time. If not specified, this will default to the current time. + /// + /// Optional. + pub fn start_time(mut self, start_time: SystemTime) -> Self { + self.set_start_time(Some(start_time)); + self + } + + /// Sets the start time for the presigned request. + /// + /// The request will start to be valid at this time, and will cease to be valid after + /// the end time, which can be determined by adding the `expires_in` duration to this + /// start time. If not specified, this will default to the current time. + /// + /// Optional. + pub fn set_start_time(&mut self, start_time: Option) { + self.start_time = start_time; + } + + /// Sets how long the request should be valid after the `start_time` (which defaults + /// to the current time). + /// + /// Credential expiration time takes priority over the `expires_in` value. + /// If the credentials used to sign the request expire before the presigned request is + /// set to expire, then the presigned request will become invalid. + /// + /// Required. + pub fn expires_in(mut self, expires_in: Duration) -> Self { + self.set_expires_in(Some(expires_in)); + self + } + + /// Sets how long the request should be valid after the `start_time` (which defaults + /// to the current time). + /// + /// Credential expiration time takes priority over the `expires_in` value. + /// If the credentials used to sign the request expire before the presigned request is + /// set to expire, then the presigned request will become invalid. + /// + /// Required. + pub fn set_expires_in(&mut self, expires_in: Option) { + self.expires_in = expires_in; + } + + /// Builds the `PresigningConfig`. This will error if `expires_in` is not + /// given, or if it's longer than one week. + pub fn build(self) -> Result { + let expires_in = self.expires_in.ok_or(ErrorKind::ExpiresInRequired)?; + if expires_in > ONE_WEEK { + return Err(ErrorKind::ExpiresInDurationTooLong.into()); + } + Ok(PresigningConfig { + start_time: self.start_time.unwrap_or_else(SystemTime::now), + expires_in, + }) + } + } +} + +/// Presigned request +pub mod request { + use std::fmt::{Debug, Formatter}; + + /// Represents a presigned request. This only includes the HTTP request method, URI, and headers. + /// + /// **This struct has conversion convenience functions:** + /// + /// - [`PresignedRequest::to_http_request`][Self::to_http_request] returns an [`http::Request`](https://docs.rs/http/0.2.6/http/request/struct.Request.html) + /// - [`PresignedRequest::into`](#impl-From) returns an [`http::request::Builder`](https://docs.rs/http/0.2.6/http/request/struct.Builder.html) + #[non_exhaustive] + pub struct PresignedRequest(http::Request<()>); + + impl PresignedRequest { + pub(crate) fn new(inner: http::Request<()>) -> Self { + Self(inner) + } + + /// Returns the HTTP request method. + pub fn method(&self) -> &http::Method { + self.0.method() + } + + /// Returns the HTTP request URI. + pub fn uri(&self) -> &http::Uri { + self.0.uri() + } + + /// Returns any HTTP headers that need to go along with the request, except for `Host`, + /// which should be sent based on the endpoint in the URI by the HTTP client rather than + /// added directly. + pub fn headers(&self) -> &http::HeaderMap { + self.0.headers() + } + + /// Given a body, convert this `PresignedRequest` into an `http::Request` + pub fn to_http_request(self, body: B) -> Result, http::Error> { + let builder: http::request::Builder = self.into(); + + builder.body(body) + } + } + + impl Debug for PresignedRequest { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PresignedRequest") + .field("method", self.method()) + .field("uri", self.uri()) + .field("headers", self.headers()) + .finish() + } + } + + impl From for http::request::Builder { + fn from(req: PresignedRequest) -> Self { + let mut builder = http::request::Builder::new() + .uri(req.uri()) + .method(req.method()); + + if let Some(headers) = builder.headers_mut() { + *headers = req.headers().clone(); + } + + builder + } + } +} + +/// Tower middleware service for creating presigned requests +#[allow(dead_code)] +pub(crate) mod service { + use super::request::PresignedRequest; + use aws_smithy_http::operation; + use http::header::USER_AGENT; + use std::future::{ready, Ready}; + use std::marker::PhantomData; + use std::task::{Context, Poll}; + + /// Tower [`Service`](tower::Service) for generated a [`PresignedRequest`] from the AWS middleware. + #[derive(Default, Debug)] + #[non_exhaustive] + pub(crate) struct PresignedRequestService { + _phantom: PhantomData, + } + + // Required because of the derive Clone on MapRequestService. + // Manually implemented to avoid requiring errors to implement Clone. + impl Clone for PresignedRequestService { + fn clone(&self) -> Self { + Self { + _phantom: Default::default(), + } + } + } + + impl PresignedRequestService { + /// Creates a new `PresignedRequestService` + pub(crate) fn new() -> Self { + Self { + _phantom: Default::default(), + } + } + } + + impl tower::Service for PresignedRequestService { + type Response = PresignedRequest; + type Error = E; + type Future = Ready>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: operation::Request) -> Self::Future { + let (mut req, _) = req.into_parts(); + + // Remove user agent headers since the request will not be executed by the AWS Rust SDK. + req.headers_mut().remove(USER_AGENT); + req.headers_mut().remove("X-Amz-User-Agent"); + + ready(Ok(PresignedRequest::new(req.map(|_| ())))) + } + } +} diff --git a/aws/rust-runtime/aws-inlineable/src/presigning.rs b/aws/rust-runtime/aws-inlineable/src/presigning.rs index 5a97a19902e..da0997d5918 100644 --- a/aws/rust-runtime/aws-inlineable/src/presigning.rs +++ b/aws/rust-runtime/aws-inlineable/src/presigning.rs @@ -5,229 +5,221 @@ //! Presigned request types and configuration. -/// Presigning config and builder -pub mod config { - use std::fmt; - use std::time::{Duration, SystemTime}; +use std::fmt; +use std::time::{Duration, SystemTime}; - const ONE_WEEK: Duration = Duration::from_secs(604800); +const ONE_WEEK: Duration = Duration::from_secs(604800); - /// Presigning config values required for creating a presigned request. - #[non_exhaustive] - #[derive(Debug, Clone)] - pub struct PresigningConfig { - start_time: SystemTime, - expires_in: Duration, - } +/// Presigning config values required for creating a presigned request. +#[non_exhaustive] +#[derive(Debug, Clone)] +pub struct PresigningConfig { + start_time: SystemTime, + expires_in: Duration, +} - impl PresigningConfig { - /// Creates a `PresigningConfig` with the given `expires_in` duration. - /// - /// The `expires_in` duration is the total amount of time the presigned request should - /// be valid for. Other config values are defaulted. - /// - /// Credential expiration time takes priority over the `expires_in` value. - /// If the credentials used to sign the request expire before the presigned request is - /// set to expire, then the presigned request will become invalid. - pub fn expires_in(expires_in: Duration) -> Result { - Self::builder().expires_in(expires_in).build() - } +impl PresigningConfig { + /// Creates a `PresigningConfig` with the given `expires_in` duration. + /// + /// The `expires_in` duration is the total amount of time the presigned request should + /// be valid for. Other config values are defaulted. + /// + /// Credential expiration time takes priority over the `expires_in` value. + /// If the credentials used to sign the request expire before the presigned request is + /// set to expire, then the presigned request will become invalid. + pub fn expires_in(expires_in: Duration) -> Result { + Self::builder().expires_in(expires_in).build() + } - /// Creates a new builder for creating a `PresigningConfig`. - pub fn builder() -> Builder { - Builder::default() - } + /// Creates a new builder for creating a `PresigningConfig`. + pub fn builder() -> PresigningConfigBuilder { + PresigningConfigBuilder::default() + } - /// Returns the amount of time the presigned request should be valid for. - pub fn expires(&self) -> Duration { - self.expires_in - } + /// Returns the amount of time the presigned request should be valid for. + pub fn expires(&self) -> Duration { + self.expires_in + } - /// Returns the start time. The presigned request will be valid between this and the end - /// time produced by adding the `expires()` value to it. - pub fn start_time(&self) -> SystemTime { - self.start_time - } + /// Returns the start time. The presigned request will be valid between this and the end + /// time produced by adding the `expires()` value to it. + pub fn start_time(&self) -> SystemTime { + self.start_time } +} - #[derive(Debug)] - enum ErrorKind { - /// Presigned requests cannot be valid for longer than one week. - ExpiresInDurationTooLong, +#[derive(Debug)] +enum ErrorKind { + /// Presigned requests cannot be valid for longer than one week. + ExpiresInDurationTooLong, - /// The `PresigningConfig` builder requires a value for `expires_in`. - ExpiresInRequired, - } + /// The `PresigningConfig` builder requires a value for `expires_in`. + ExpiresInRequired, +} - /// `PresigningConfig` build errors. - #[derive(Debug)] - pub struct Error { - kind: ErrorKind, - } +/// `PresigningConfig` build errors. +#[derive(Debug)] +pub struct PresigningConfigError { + kind: ErrorKind, +} - impl std::error::Error for Error {} +impl std::error::Error for PresigningConfigError {} - impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self.kind { - ErrorKind::ExpiresInDurationTooLong => { - write!(f, "`expires_in` must be no longer than one week") - } - ErrorKind::ExpiresInRequired => write!(f, "`expires_in` is required"), +impl fmt::Display for PresigningConfigError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.kind { + ErrorKind::ExpiresInDurationTooLong => { + write!(f, "`expires_in` must be no longer than one week") } + ErrorKind::ExpiresInRequired => write!(f, "`expires_in` is required"), } } +} - impl From for Error { - fn from(kind: ErrorKind) -> Self { - Self { kind } - } - } - - /// Builder used to create `PresigningConfig`. - #[non_exhaustive] - #[derive(Default, Debug)] - pub struct Builder { - start_time: Option, - expires_in: Option, +impl From for PresigningConfigError { + fn from(kind: ErrorKind) -> Self { + Self { kind } } +} - impl Builder { - /// Sets the start time for the presigned request. - /// - /// The request will start to be valid at this time, and will cease to be valid after - /// the end time, which can be determined by adding the `expires_in` duration to this - /// start time. If not specified, this will default to the current time. - /// - /// Optional. - pub fn start_time(mut self, start_time: SystemTime) -> Self { - self.set_start_time(Some(start_time)); - self - } - - /// Sets the start time for the presigned request. - /// - /// The request will start to be valid at this time, and will cease to be valid after - /// the end time, which can be determined by adding the `expires_in` duration to this - /// start time. If not specified, this will default to the current time. - /// - /// Optional. - pub fn set_start_time(&mut self, start_time: Option) { - self.start_time = start_time; - } - - /// Sets how long the request should be valid after the `start_time` (which defaults - /// to the current time). - /// - /// Credential expiration time takes priority over the `expires_in` value. - /// If the credentials used to sign the request expire before the presigned request is - /// set to expire, then the presigned request will become invalid. - /// - /// Required. - pub fn expires_in(mut self, expires_in: Duration) -> Self { - self.set_expires_in(Some(expires_in)); - self - } - - /// Sets how long the request should be valid after the `start_time` (which defaults - /// to the current time). - /// - /// Credential expiration time takes priority over the `expires_in` value. - /// If the credentials used to sign the request expire before the presigned request is - /// set to expire, then the presigned request will become invalid. - /// - /// Required. - pub fn set_expires_in(&mut self, expires_in: Option) { - self.expires_in = expires_in; - } +/// Builder used to create `PresigningConfig`. +#[non_exhaustive] +#[derive(Default, Debug)] +pub struct PresigningConfigBuilder { + start_time: Option, + expires_in: Option, +} - /// Builds the `PresigningConfig`. This will error if `expires_in` is not - /// given, or if it's longer than one week. - pub fn build(self) -> Result { - let expires_in = self.expires_in.ok_or(ErrorKind::ExpiresInRequired)?; - if expires_in > ONE_WEEK { - return Err(ErrorKind::ExpiresInDurationTooLong.into()); - } - Ok(PresigningConfig { - start_time: self.start_time.unwrap_or_else(SystemTime::now), - expires_in, - }) - } +impl PresigningConfigBuilder { + /// Sets the start time for the presigned request. + /// + /// The request will start to be valid at this time, and will cease to be valid after + /// the end time, which can be determined by adding the `expires_in` duration to this + /// start time. If not specified, this will default to the current time. + /// + /// Optional. + pub fn start_time(mut self, start_time: SystemTime) -> Self { + self.set_start_time(Some(start_time)); + self } -} -/// Presigned request -pub mod request { - use std::fmt::{Debug, Formatter}; + /// Sets the start time for the presigned request. + /// + /// The request will start to be valid at this time, and will cease to be valid after + /// the end time, which can be determined by adding the `expires_in` duration to this + /// start time. If not specified, this will default to the current time. + /// + /// Optional. + pub fn set_start_time(&mut self, start_time: Option) { + self.start_time = start_time; + } - /// Represents a presigned request. This only includes the HTTP request method, URI, and headers. + /// Sets how long the request should be valid after the `start_time` (which defaults + /// to the current time). /// - /// **This struct has conversion convenience functions:** + /// Credential expiration time takes priority over the `expires_in` value. + /// If the credentials used to sign the request expire before the presigned request is + /// set to expire, then the presigned request will become invalid. /// - /// - [`PresignedRequest::to_http_request`][Self::to_http_request] returns an [`http::Request`](https://docs.rs/http/0.2.6/http/request/struct.Request.html) - /// - [`PresignedRequest::into`](#impl-From) returns an [`http::request::Builder`](https://docs.rs/http/0.2.6/http/request/struct.Builder.html) - #[non_exhaustive] - pub struct PresignedRequest(http::Request<()>); + /// Required. + pub fn expires_in(mut self, expires_in: Duration) -> Self { + self.set_expires_in(Some(expires_in)); + self + } - impl PresignedRequest { - pub(crate) fn new(inner: http::Request<()>) -> Self { - Self(inner) - } + /// Sets how long the request should be valid after the `start_time` (which defaults + /// to the current time). + /// + /// Credential expiration time takes priority over the `expires_in` value. + /// If the credentials used to sign the request expire before the presigned request is + /// set to expire, then the presigned request will become invalid. + /// + /// Required. + pub fn set_expires_in(&mut self, expires_in: Option) { + self.expires_in = expires_in; + } - /// Returns the HTTP request method. - pub fn method(&self) -> &http::Method { - self.0.method() + /// Builds the `PresigningConfig`. This will error if `expires_in` is not + /// given, or if it's longer than one week. + pub fn build(self) -> Result { + let expires_in = self.expires_in.ok_or(ErrorKind::ExpiresInRequired)?; + if expires_in > ONE_WEEK { + return Err(ErrorKind::ExpiresInDurationTooLong.into()); } + Ok(PresigningConfig { + start_time: self.start_time.unwrap_or_else(SystemTime::now), + expires_in, + }) + } +} - /// Returns the HTTP request URI. - pub fn uri(&self) -> &http::Uri { - self.0.uri() - } +/// Represents a presigned request. This only includes the HTTP request method, URI, and headers. +/// +/// **This struct has conversion convenience functions:** +/// +/// - [`PresignedRequest::to_http_request`][Self::to_http_request] returns an [`http::Request`](https://docs.rs/http/0.2.6/http/request/struct.Request.html) +/// - [`PresignedRequest::into`](#impl-From) returns an [`http::request::Builder`](https://docs.rs/http/0.2.6/http/request/struct.Builder.html) +#[non_exhaustive] +pub struct PresignedRequest(http::Request<()>); + +impl PresignedRequest { + pub(crate) fn new(inner: http::Request<()>) -> Self { + Self(inner) + } - /// Returns any HTTP headers that need to go along with the request, except for `Host`, - /// which should be sent based on the endpoint in the URI by the HTTP client rather than - /// added directly. - pub fn headers(&self) -> &http::HeaderMap { - self.0.headers() - } + /// Returns the HTTP request method. + pub fn method(&self) -> &http::Method { + self.0.method() + } - /// Given a body, convert this `PresignedRequest` into an `http::Request` - pub fn to_http_request(self, body: B) -> Result, http::Error> { - let builder: http::request::Builder = self.into(); + /// Returns the HTTP request URI. + pub fn uri(&self) -> &http::Uri { + self.0.uri() + } - builder.body(body) - } + /// Returns any HTTP headers that need to go along with the request, except for `Host`, + /// which should be sent based on the endpoint in the URI by the HTTP client rather than + /// added directly. + pub fn headers(&self) -> &http::HeaderMap { + self.0.headers() } - impl Debug for PresignedRequest { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("PresignedRequest") - .field("method", self.method()) - .field("uri", self.uri()) - .field("headers", self.headers()) - .finish() - } + /// Given a body, convert this `PresignedRequest` into an `http::Request` + pub fn to_http_request(self, body: B) -> Result, http::Error> { + let builder: http::request::Builder = self.into(); + + builder.body(body) } +} - impl From for http::request::Builder { - fn from(req: PresignedRequest) -> Self { - let mut builder = http::request::Builder::new() - .uri(req.uri()) - .method(req.method()); +impl fmt::Debug for PresignedRequest { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PresignedRequest") + .field("method", self.method()) + .field("uri", self.uri()) + .field("headers", self.headers()) + .finish() + } +} - if let Some(headers) = builder.headers_mut() { - *headers = req.headers().clone(); - } +impl From for http::request::Builder { + fn from(req: PresignedRequest) -> Self { + let mut builder = http::request::Builder::new() + .uri(req.uri()) + .method(req.method()); - builder + if let Some(headers) = builder.headers_mut() { + *headers = req.headers().clone(); } + + builder } } /// Tower middleware service for creating presigned requests #[allow(dead_code)] pub(crate) mod service { - use crate::presigning::request::PresignedRequest; + use super::PresignedRequest; use aws_smithy_http::operation; use http::header::USER_AGENT; use std::future::{ready, Ready}; diff --git a/aws/rust-runtime/aws-inlineable/src/s3_errors.rs b/aws/rust-runtime/aws-inlineable/src/s3_errors.rs deleted file mode 100644 index ca15ddc42bb..00000000000 --- a/aws/rust-runtime/aws-inlineable/src/s3_errors.rs +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -use http::{HeaderMap, HeaderValue}; - -const EXTENDED_REQUEST_ID: &str = "s3_extended_request_id"; - -/// S3-specific service error additions. -pub trait ErrorExt { - /// Returns the S3 Extended Request ID necessary when contacting AWS Support. - /// Read more at . - fn extended_request_id(&self) -> Option<&str>; -} - -impl ErrorExt for aws_smithy_types::Error { - fn extended_request_id(&self) -> Option<&str> { - self.extra(EXTENDED_REQUEST_ID) - } -} - -/// Parses the S3 Extended Request ID out of S3 error response headers. -pub fn parse_extended_error( - error: aws_smithy_types::Error, - headers: &HeaderMap, -) -> aws_smithy_types::Error { - let mut builder = error.into_builder(); - let host_id = headers - .get("x-amz-id-2") - .and_then(|header_value| header_value.to_str().ok()); - if let Some(host_id) = host_id { - builder.custom(EXTENDED_REQUEST_ID, host_id); - } - builder.build() -} - -#[cfg(test)] -mod test { - use crate::s3_errors::{parse_extended_error, ErrorExt}; - - #[test] - fn add_error_fields() { - let resp = http::Response::builder() - .header( - "x-amz-id-2", - "eftixk72aD6Ap51TnqcoF8eFidJG9Z/2mkiDFu8yU9AS1ed4OpIszj7UDNEHGran", - ) - .status(400) - .body("") - .unwrap(); - let error = aws_smithy_types::Error::builder() - .message("123") - .request_id("456") - .build(); - - let error = parse_extended_error(error, resp.headers()); - assert_eq!( - error - .extended_request_id() - .expect("extended request id should be set"), - "eftixk72aD6Ap51TnqcoF8eFidJG9Z/2mkiDFu8yU9AS1ed4OpIszj7UDNEHGran" - ); - } - - #[test] - fn handle_missing_header() { - let resp = http::Response::builder().status(400).body("").unwrap(); - let error = aws_smithy_types::Error::builder() - .message("123") - .request_id("456") - .build(); - - let error = parse_extended_error(error, resp.headers()); - assert_eq!(error.extended_request_id(), None); - } -} diff --git a/aws/rust-runtime/aws-inlineable/src/s3_request_id.rs b/aws/rust-runtime/aws-inlineable/src/s3_request_id.rs new file mode 100644 index 00000000000..909dcbcd7aa --- /dev/null +++ b/aws/rust-runtime/aws-inlineable/src/s3_request_id.rs @@ -0,0 +1,178 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use aws_smithy_client::SdkError; +use aws_smithy_http::http::HttpHeaders; +use aws_smithy_http::operation; +use aws_smithy_types::error::metadata::{ + Builder as ErrorMetadataBuilder, ErrorMetadata, ProvideErrorMetadata, +}; +use aws_smithy_types::error::Unhandled; +use http::{HeaderMap, HeaderValue}; + +const EXTENDED_REQUEST_ID: &str = "s3_extended_request_id"; + +/// Trait to retrieve the S3-specific extended request ID +/// +/// Read more at . +pub trait RequestIdExt { + /// Returns the S3 Extended Request ID necessary when contacting AWS Support. + fn extended_request_id(&self) -> Option<&str>; +} + +impl RequestIdExt for SdkError +where + R: HttpHeaders, +{ + fn extended_request_id(&self) -> Option<&str> { + match self { + Self::ResponseError(err) => extract_extended_request_id(err.raw().http_headers()), + Self::ServiceError(err) => extract_extended_request_id(err.raw().http_headers()), + _ => None, + } + } +} + +impl RequestIdExt for ErrorMetadata { + fn extended_request_id(&self) -> Option<&str> { + self.extra(EXTENDED_REQUEST_ID) + } +} + +impl RequestIdExt for Unhandled { + fn extended_request_id(&self) -> Option<&str> { + self.meta().extended_request_id() + } +} + +impl RequestIdExt for operation::Response { + fn extended_request_id(&self) -> Option<&str> { + extract_extended_request_id(self.http().headers()) + } +} + +impl RequestIdExt for http::Response { + fn extended_request_id(&self) -> Option<&str> { + extract_extended_request_id(self.headers()) + } +} + +impl RequestIdExt for Result +where + O: RequestIdExt, + E: RequestIdExt, +{ + fn extended_request_id(&self) -> Option<&str> { + match self { + Ok(ok) => ok.extended_request_id(), + Err(err) => err.extended_request_id(), + } + } +} + +/// Applies the extended request ID to a generic error builder +#[doc(hidden)] +pub fn apply_extended_request_id( + builder: ErrorMetadataBuilder, + headers: &HeaderMap, +) -> ErrorMetadataBuilder { + if let Some(extended_request_id) = extract_extended_request_id(headers) { + builder.custom(EXTENDED_REQUEST_ID, extended_request_id) + } else { + builder + } +} + +/// Extracts the S3 Extended Request ID from HTTP response headers +fn extract_extended_request_id(headers: &HeaderMap) -> Option<&str> { + headers + .get("x-amz-id-2") + .and_then(|value| value.to_str().ok()) +} + +#[cfg(test)] +mod test { + use super::*; + use aws_smithy_client::SdkError; + use aws_smithy_http::body::SdkBody; + use http::Response; + + #[test] + fn handle_missing_header() { + let resp = http::Response::builder().status(400).body("").unwrap(); + let mut builder = aws_smithy_types::Error::builder().message("123"); + builder = apply_extended_request_id(builder, resp.headers()); + assert_eq!(builder.build().extended_request_id(), None); + } + + #[test] + fn test_extended_request_id_sdk_error() { + let without_extended_request_id = + || operation::Response::new(Response::builder().body(SdkBody::empty()).unwrap()); + let with_extended_request_id = || { + operation::Response::new( + Response::builder() + .header("x-amz-id-2", HeaderValue::from_static("some-request-id")) + .body(SdkBody::empty()) + .unwrap(), + ) + }; + assert_eq!( + None, + SdkError::<(), _>::response_error("test", without_extended_request_id()) + .extended_request_id() + ); + assert_eq!( + Some("some-request-id"), + SdkError::<(), _>::response_error("test", with_extended_request_id()) + .extended_request_id() + ); + assert_eq!( + None, + SdkError::service_error((), without_extended_request_id()).extended_request_id() + ); + assert_eq!( + Some("some-request-id"), + SdkError::service_error((), with_extended_request_id()).extended_request_id() + ); + } + + #[test] + fn test_extract_extended_request_id() { + let mut headers = HeaderMap::new(); + assert_eq!(None, extract_extended_request_id(&headers)); + + headers.append("x-amz-id-2", HeaderValue::from_static("some-request-id")); + assert_eq!( + Some("some-request-id"), + extract_extended_request_id(&headers) + ); + } + + #[test] + fn test_apply_extended_request_id() { + let mut headers = HeaderMap::new(); + assert_eq!( + ErrorMetadata::builder().build(), + apply_extended_request_id(ErrorMetadata::builder(), &headers).build(), + ); + + headers.append("x-amz-id-2", HeaderValue::from_static("some-request-id")); + assert_eq!( + ErrorMetadata::builder() + .custom(EXTENDED_REQUEST_ID, "some-request-id") + .build(), + apply_extended_request_id(ErrorMetadata::builder(), &headers).build(), + ); + } + + #[test] + fn test_error_metadata_extended_request_id_impl() { + let err = ErrorMetadata::builder() + .custom(EXTENDED_REQUEST_ID, "some-request-id") + .build(); + assert_eq!(Some("some-request-id"), err.extended_request_id()); + } +} diff --git a/aws/rust-runtime/aws-sig-auth/src/lib.rs b/aws/rust-runtime/aws-sig-auth/src/lib.rs index 61ae88e81c4..643b3ff216b 100644 --- a/aws/rust-runtime/aws-sig-auth/src/lib.rs +++ b/aws/rust-runtime/aws-sig-auth/src/lib.rs @@ -3,6 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ +#![allow(clippy::derive_partial_eq_without_eq)] + //! AWS Signature Authentication Package //! //! This crate may be used to generate presigned URLs for unmodeled behavior such as `rds-iam-token` diff --git a/aws/rust-runtime/aws-sigv4/Cargo.toml b/aws/rust-runtime/aws-sigv4/Cargo.toml index ad1570bb863..90cf5f42f30 100644 --- a/aws/rust-runtime/aws-sigv4/Cargo.toml +++ b/aws/rust-runtime/aws-sigv4/Cargo.toml @@ -32,7 +32,7 @@ sha2 = "0.10" criterion = "0.4" bytes = "1" httparse = "1.5" -pretty_assertions = "1.0" +pretty_assertions = "1.3" proptest = "1" time = { version = "0.3.4", features = ["parsing"] } diff --git a/aws/rust-runtime/aws-sigv4/src/http_request/canonical_request.rs b/aws/rust-runtime/aws-sigv4/src/http_request/canonical_request.rs index 172f2cbce77..08ac6c8f544 100644 --- a/aws/rust-runtime/aws-sigv4/src/http_request/canonical_request.rs +++ b/aws/rust-runtime/aws-sigv4/src/http_request/canonical_request.rs @@ -5,7 +5,6 @@ use crate::date_time::{format_date, format_date_time}; use crate::http_request::error::CanonicalRequestError; -use crate::http_request::query_writer::QueryWriter; use crate::http_request::settings::UriPathNormalizationMode; use crate::http_request::sign::SignableRequest; use crate::http_request::uri_path_normalization::normalize_uri_path; @@ -13,6 +12,7 @@ use crate::http_request::url_escape::percent_encode_path; use crate::http_request::PercentEncodingMode; use crate::http_request::{PayloadChecksumKind, SignableBody, SignatureLocation, SigningParams}; use crate::sign::sha256_hex_string; +use aws_smithy_http::query_writer::QueryWriter; use http::header::{AsHeaderName, HeaderName, HOST}; use http::{HeaderMap, HeaderValue, Method, Uri}; use std::borrow::Cow; @@ -519,13 +519,13 @@ mod tests { use crate::http_request::canonical_request::{ normalize_header_value, trim_all, CanonicalRequest, SigningScope, StringToSign, }; - use crate::http_request::query_writer::QueryWriter; use crate::http_request::test::{test_canonical_request, test_request, test_sts}; use crate::http_request::{ PayloadChecksumKind, SignableBody, SignableRequest, SigningSettings, }; use crate::http_request::{SignatureLocation, SigningParams}; use crate::sign::sha256_hex_string; + use aws_smithy_http::query_writer::QueryWriter; use http::Uri; use http::{header::HeaderName, HeaderValue}; use pretty_assertions::assert_eq; diff --git a/aws/rust-runtime/aws-sigv4/src/http_request/mod.rs b/aws/rust-runtime/aws-sigv4/src/http_request/mod.rs index c93a052887c..543d58fb2dd 100644 --- a/aws/rust-runtime/aws-sigv4/src/http_request/mod.rs +++ b/aws/rust-runtime/aws-sigv4/src/http_request/mod.rs @@ -43,7 +43,6 @@ mod canonical_request; mod error; -mod query_writer; mod settings; mod sign; mod uri_path_normalization; diff --git a/aws/rust-runtime/aws-sigv4/src/http_request/sign.rs b/aws/rust-runtime/aws-sigv4/src/http_request/sign.rs index 69f5c00819a..dbe76242947 100644 --- a/aws/rust-runtime/aws-sigv4/src/http_request/sign.rs +++ b/aws/rust-runtime/aws-sigv4/src/http_request/sign.rs @@ -8,10 +8,10 @@ use super::{PayloadChecksumKind, SignatureLocation}; use crate::http_request::canonical_request::header; use crate::http_request::canonical_request::param; use crate::http_request::canonical_request::{CanonicalRequest, StringToSign, HMAC_256}; -use crate::http_request::query_writer::QueryWriter; use crate::http_request::SigningParams; use crate::sign::{calculate_signature, generate_signing_key, sha256_hex_string}; use crate::SigningOutput; +use aws_smithy_http::query_writer::QueryWriter; use http::header::HeaderValue; use http::{HeaderMap, Method, Uri}; use std::borrow::Cow; @@ -380,9 +380,11 @@ mod tests { #[test] fn test_sign_vanilla_with_query_params() { - let mut settings = SigningSettings::default(); - settings.signature_location = SignatureLocation::QueryParams; - settings.expires_in = Some(Duration::from_secs(35)); + let settings = SigningSettings { + signature_location: SignatureLocation::QueryParams, + expires_in: Some(Duration::from_secs(35)), + ..Default::default() + }; let params = SigningParams { access_key: "AKIDEXAMPLE", secret_key: "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", diff --git a/aws/rust-runtime/aws-sigv4/src/http_request/url_escape.rs b/aws/rust-runtime/aws-sigv4/src/http_request/url_escape.rs index 651c62c44ac..d7656c355b2 100644 --- a/aws/rust-runtime/aws-sigv4/src/http_request/url_escape.rs +++ b/aws/rust-runtime/aws-sigv4/src/http_request/url_escape.rs @@ -3,11 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -use aws_smithy_http::{label, query}; - -pub(super) fn percent_encode_query(value: &str) -> String { - query::fmt_string(value) -} +use aws_smithy_http::label; pub(super) fn percent_encode_path(value: &str) -> String { label::fmt_string(value, label::EncodingStrategy::Greedy) diff --git a/aws/rust-runtime/aws-sigv4/src/lib.rs b/aws/rust-runtime/aws-sigv4/src/lib.rs index be2552f5781..14d7a7b5fd2 100644 --- a/aws/rust-runtime/aws-sigv4/src/lib.rs +++ b/aws/rust-runtime/aws-sigv4/src/lib.rs @@ -6,6 +6,7 @@ //! Provides functions for calculating Sigv4 signing keys, signatures, and //! optional utilities for signing HTTP requests and Event Stream messages. +#![allow(clippy::derive_partial_eq_without_eq)] #![warn( missing_docs, rustdoc::missing_crate_level_docs, diff --git a/aws/rust-runtime/aws-types/src/lib.rs b/aws/rust-runtime/aws-types/src/lib.rs index 795ff08849b..600adc83cc2 100644 --- a/aws/rust-runtime/aws-types/src/lib.rs +++ b/aws/rust-runtime/aws-types/src/lib.rs @@ -5,6 +5,7 @@ //! Cross-service types for the AWS SDK. +#![allow(clippy::derive_partial_eq_without_eq)] #![warn( missing_docs, rustdoc::missing_crate_level_docs, diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsCodegenDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsCodegenDecorator.kt index 75d6fd7cbdb..3e55fd9ce2e 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsCodegenDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsCodegenDecorator.kt @@ -9,12 +9,15 @@ import software.amazon.smithy.rust.codegen.client.smithy.customizations.DocsRsMe import software.amazon.smithy.rust.codegen.client.smithy.customizations.DocsRsMetadataSettings import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator import software.amazon.smithy.rust.codegen.client.smithy.customize.CombinedClientCodegenDecorator +import software.amazon.smithy.rustsdk.customize.DisabledAuthDecorator import software.amazon.smithy.rustsdk.customize.apigateway.ApiGatewayDecorator -import software.amazon.smithy.rustsdk.customize.auth.DisabledAuthDecorator +import software.amazon.smithy.rustsdk.customize.applyDecorators import software.amazon.smithy.rustsdk.customize.ec2.Ec2Decorator import software.amazon.smithy.rustsdk.customize.glacier.GlacierDecorator +import software.amazon.smithy.rustsdk.customize.onlyApplyTo import software.amazon.smithy.rustsdk.customize.route53.Route53Decorator import software.amazon.smithy.rustsdk.customize.s3.S3Decorator +import software.amazon.smithy.rustsdk.customize.s3.S3ExtendedRequestIdDecorator import software.amazon.smithy.rustsdk.customize.s3control.S3ControlDecorator import software.amazon.smithy.rustsdk.customize.sts.STSDecorator import software.amazon.smithy.rustsdk.endpoints.AwsEndpointDecorator @@ -23,41 +26,49 @@ import software.amazon.smithy.rustsdk.endpoints.OperationInputTestDecorator val DECORATORS: List = listOf( // General AWS Decorators - CredentialsCacheDecorator(), - CredentialsProviderDecorator(), - RegionDecorator(), - AwsEndpointDecorator(), - UserAgentDecorator(), - SigV4SigningDecorator(), - HttpRequestChecksumDecorator(), - HttpResponseChecksumDecorator(), - RetryClassifierDecorator(), - IntegrationTestDecorator(), - AwsFluentClientDecorator(), - CrateLicenseDecorator(), - SdkConfigDecorator(), - ServiceConfigDecorator(), - AwsPresigningDecorator(), - AwsReadmeDecorator(), - HttpConnectorDecorator(), - AwsEndpointsStdLib(), - *PromotedBuiltInsDecorators, - GenericSmithySdkConfigSettings(), - OperationInputTestDecorator(), + listOf( + CredentialsCacheDecorator(), + CredentialsProviderDecorator(), + RegionDecorator(), + AwsEndpointDecorator(), + UserAgentDecorator(), + SigV4SigningDecorator(), + HttpRequestChecksumDecorator(), + HttpResponseChecksumDecorator(), + RetryClassifierDecorator(), + IntegrationTestDecorator(), + AwsFluentClientDecorator(), + CrateLicenseDecorator(), + SdkConfigDecorator(), + ServiceConfigDecorator(), + AwsPresigningDecorator(), + AwsReadmeDecorator(), + HttpConnectorDecorator(), + AwsEndpointsStdLib(), + *PromotedBuiltInsDecorators, + GenericSmithySdkConfigSettings(), + OperationInputTestDecorator(), + AwsRequestIdDecorator(), + DisabledAuthDecorator(), + ), // Service specific decorators - ApiGatewayDecorator(), - DisabledAuthDecorator(), - Ec2Decorator(), - GlacierDecorator(), - Route53Decorator(), - S3Decorator(), - S3ControlDecorator(), - STSDecorator(), + ApiGatewayDecorator().onlyApplyTo("com.amazonaws.apigateway#BackplaneControlService"), + Ec2Decorator().onlyApplyTo("com.amazonaws.ec2#AmazonEC2"), + GlacierDecorator().onlyApplyTo("com.amazonaws.glacier#Glacier"), + Route53Decorator().onlyApplyTo("com.amazonaws.route53#AWSDnsV20130401"), + "com.amazonaws.s3#AmazonS3".applyDecorators( + S3Decorator(), + S3ExtendedRequestIdDecorator(), + ), + S3ControlDecorator().onlyApplyTo("com.amazonaws.s3control#AWSS3ControlServiceV20180820"), + STSDecorator().onlyApplyTo("com.amazonaws.sts#AWSSecurityTokenServiceV20110615"), // Only build docs-rs for linux to reduce load on docs.rs - DocsRsMetadataDecorator(DocsRsMetadataSettings(targets = listOf("x86_64-unknown-linux-gnu"), allFeatures = true)), -) + listOf( + DocsRsMetadataDecorator(DocsRsMetadataSettings(targets = listOf("x86_64-unknown-linux-gnu"), allFeatures = true)), + ), +).flatten() class AwsCodegenDecorator : CombinedClientCodegenDecorator(DECORATORS) { override val name: String = "AwsSdkCodegenDecorator" diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt index c06b58994bd..e240d642700 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt @@ -9,14 +9,14 @@ import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.traits.TitleTrait import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator -import software.amazon.smithy.rust.codegen.client.smithy.generators.client.CustomizableOperationGenerator +import software.amazon.smithy.rust.codegen.client.smithy.featureGatedCustomizeModule import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientCustomization import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientGenerator import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientGenerics import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientSection import software.amazon.smithy.rust.codegen.core.rustlang.Attribute -import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope import software.amazon.smithy.rust.codegen.core.rustlang.Feature import software.amazon.smithy.rust.codegen.core.rustlang.GenericTypeArg import software.amazon.smithy.rust.codegen.core.rustlang.RustGenerics @@ -76,7 +76,7 @@ private class AwsClientGenerics(private val types: Types) : FluentClientGenerics override fun sendBounds( operation: Symbol, operationOutput: Symbol, - operationError: RuntimeType, + operationError: Symbol, retryClassifier: RuntimeType, ): Writable = writable { } @@ -96,17 +96,18 @@ class AwsFluentClientDecorator : ClientCodegenDecorator { val generics = AwsClientGenerics(types) FluentClientGenerator( codegenContext, - generics, + reexportSmithyClientBuilder = false, + generics = generics, customizations = listOf( - AwsPresignedFluentBuilderMethod(runtimeConfig), + AwsPresignedFluentBuilderMethod(codegenContext, runtimeConfig), AwsFluentClientDocs(codegenContext), ), retryClassifier = AwsRuntimeType.awsHttp(runtimeConfig).resolve("retry::AwsResponseRetryClassifier"), ).render(rustCrate) - rustCrate.withModule(CustomizableOperationGenerator.CustomizeModule) { + rustCrate.withModule(codegenContext.featureGatedCustomizeModule()) { renderCustomizableOperationSendMethod(runtimeConfig, generics, this) } - rustCrate.withModule(FluentClientGenerator.clientModule) { + rustCrate.withModule(ClientRustModule.client) { AwsFluentClientExtensions(types).render(this) } val awsSmithyClient = "aws-smithy-client" @@ -228,7 +229,7 @@ private class AwsFluentClientDocs(private val codegenContext: CodegenContext) : private val serviceShape = codegenContext.serviceShape private val crateName = codegenContext.moduleUseName() private val codegenScope = - arrayOf("aws_config" to AwsCargoDependency.awsConfig(codegenContext.runtimeConfig).copy(scope = DependencyScope.Dev).toType()) + arrayOf("aws_config" to AwsCargoDependency.awsConfig(codegenContext.runtimeConfig).toDevDependency().toType()) // If no `aws-config` version is provided, assume that docs referencing `aws-config` cannot be given. // Also, STS and SSO must NOT reference `aws-config` since that would create a circular dependency. diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecorator.kt index c733138076d..1ecca4651e4 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecorator.kt @@ -30,12 +30,10 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.writable -import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection -import software.amazon.smithy.rust.codegen.core.smithy.generators.error.errorSymbol import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.core.util.cloneOperation import software.amazon.smithy.rust.codegen.core.util.expectTrait @@ -129,23 +127,23 @@ class AwsPresigningDecorator internal constructor( } class AwsInputPresignedMethod( - private val codegenContext: CodegenContext, + private val codegenContext: ClientCodegenContext, private val operationShape: OperationShape, ) : OperationCustomization() { private val runtimeConfig = codegenContext.runtimeConfig private val symbolProvider = codegenContext.symbolProvider - private val codegenScope = arrayOf( - "Error" to AwsRuntimeType.Presigning.resolve("config::Error"), - "PresignedRequest" to AwsRuntimeType.Presigning.resolve("request::PresignedRequest"), - "PresignedRequestService" to AwsRuntimeType.Presigning.resolve("service::PresignedRequestService"), - "PresigningConfig" to AwsRuntimeType.Presigning.resolve("config::PresigningConfig"), - "SdkError" to RuntimeType.sdkError(runtimeConfig), - "aws_sigv4" to AwsRuntimeType.awsSigv4(runtimeConfig), - "sig_auth" to AwsRuntimeType.awsSigAuth(runtimeConfig), - "tower" to RuntimeType.Tower, - "Middleware" to runtimeConfig.defaultMiddleware(), - ) + private val codegenScope = ( + presigningTypes(codegenContext) + listOf( + "PresignedRequestService" to AwsRuntimeType.presigning(codegenContext) + .resolve("service::PresignedRequestService"), + "SdkError" to RuntimeType.sdkError(runtimeConfig), + "aws_sigv4" to AwsRuntimeType.awsSigv4(runtimeConfig), + "sig_auth" to AwsRuntimeType.awsSigAuth(runtimeConfig), + "tower" to RuntimeType.Tower, + "Middleware" to runtimeConfig.defaultMiddleware(), + ) + ).toTypedArray() override fun section(section: OperationSection): Writable = writable { @@ -155,7 +153,7 @@ class AwsInputPresignedMethod( } private fun RustWriter.writeInputPresignedMethod(section: OperationSection.InputImpl) { - val operationError = operationShape.errorSymbol(symbolProvider) + val operationError = symbolProvider.symbolForOperationError(operationShape) val presignableOp = PRESIGNABLE_OPERATIONS.getValue(operationShape.id) val makeOperationOp = if (presignableOp.hasModelTransforms()) { @@ -242,14 +240,15 @@ class AwsInputPresignedMethod( } class AwsPresignedFluentBuilderMethod( + codegenContext: ClientCodegenContext, runtimeConfig: RuntimeConfig, ) : FluentClientCustomization() { - private val codegenScope = arrayOf( - "Error" to AwsRuntimeType.Presigning.resolve("config::Error"), - "PresignedRequest" to AwsRuntimeType.Presigning.resolve("request::PresignedRequest"), - "PresigningConfig" to AwsRuntimeType.Presigning.resolve("config::PresigningConfig"), - "SdkError" to RuntimeType.sdkError(runtimeConfig), - ) + private val codegenScope = ( + presigningTypes(codegenContext) + arrayOf( + "Error" to AwsRuntimeType.presigning(codegenContext).resolve("config::Error"), + "SdkError" to RuntimeType.sdkError(runtimeConfig), + ) + ).toTypedArray() override fun section(section: FluentClientSection): Writable = writable { @@ -365,3 +364,15 @@ private fun RustWriter.documentPresignedMethod(hasConfigArg: Boolean) { """, ) } + +private fun presigningTypes(codegenContext: ClientCodegenContext): List> = + when (codegenContext.settings.codegenConfig.enableNewCrateOrganizationScheme) { + true -> listOf( + "PresignedRequest" to AwsRuntimeType.presigning(codegenContext).resolve("PresignedRequest"), + "PresigningConfig" to AwsRuntimeType.presigning(codegenContext).resolve("PresigningConfig"), + ) + else -> listOf( + "PresignedRequest" to AwsRuntimeType.presigning(codegenContext).resolve("request::PresignedRequest"), + "PresigningConfig" to AwsRuntimeType.presigning(codegenContext).resolve("config::PresigningConfig"), + ) + } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsRequestIdDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsRequestIdDecorator.kt new file mode 100644 index 00000000000..0b496ce2c9b --- /dev/null +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsRequestIdDecorator.kt @@ -0,0 +1,29 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rustsdk + +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType + +/** + * Customizes response parsing logic to add AWS request IDs to error metadata and outputs + */ +class AwsRequestIdDecorator : BaseRequestIdDecorator() { + override val name: String = "AwsRequestIdDecorator" + override val order: Byte = 0 + + override val fieldName: String = "request_id" + override val accessorFunctionName: String = "request_id" + + private fun requestIdModule(codegenContext: ClientCodegenContext): RuntimeType = + AwsRuntimeType.awsHttp(codegenContext.runtimeConfig).resolve("request_id") + + override fun accessorTrait(codegenContext: ClientCodegenContext): RuntimeType = + requestIdModule(codegenContext).resolve("RequestId") + + override fun applyToError(codegenContext: ClientCodegenContext): RuntimeType = + requestIdModule(codegenContext).resolve("apply_request_id") +} diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsRuntimeType.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsRuntimeType.kt index f4b05b7bafd..e07607a7224 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsRuntimeType.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsRuntimeType.kt @@ -6,8 +6,8 @@ package software.amazon.smithy.rustsdk import software.amazon.smithy.codegen.core.CodegenException +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency -import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope import software.amazon.smithy.rust.codegen.core.rustlang.Visibility import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RuntimeCrateLocation @@ -42,10 +42,17 @@ fun RuntimeConfig.awsRoot(): RuntimeCrateLocation { } object AwsRuntimeType { - val S3Errors by lazy { RuntimeType.forInlineDependency(InlineAwsDependency.forRustFile("s3_errors")) } - val Presigning by lazy { - RuntimeType.forInlineDependency(InlineAwsDependency.forRustFile("presigning", visibility = Visibility.PUBLIC)) - } + fun presigning(codegenContext: ClientCodegenContext): RuntimeType = + when (codegenContext.settings.codegenConfig.enableNewCrateOrganizationScheme) { + true -> RuntimeType.forInlineDependency(InlineAwsDependency.forRustFile("presigning", visibility = Visibility.PUBLIC)) + else -> RuntimeType.forInlineDependency( + InlineAwsDependency.forRustFileAs( + file = "old_presigning", + moduleName = "presigning", + visibility = Visibility.PUBLIC, + ), + ) + } fun RuntimeConfig.defaultMiddleware() = RuntimeType.forInlineDependency( InlineAwsDependency.forRustFile( @@ -63,7 +70,7 @@ object AwsRuntimeType { fun awsCredentialTypes(runtimeConfig: RuntimeConfig) = AwsCargoDependency.awsCredentialTypes(runtimeConfig).toType() fun awsCredentialTypesTestUtil(runtimeConfig: RuntimeConfig) = - AwsCargoDependency.awsCredentialTypes(runtimeConfig).copy(scope = DependencyScope.Dev).withFeature("test-util").toType() + AwsCargoDependency.awsCredentialTypes(runtimeConfig).toDevDependency().withFeature("test-util").toType() fun awsEndpoint(runtimeConfig: RuntimeConfig) = AwsCargoDependency.awsEndpoint(runtimeConfig).toType() fun awsHttp(runtimeConfig: RuntimeConfig) = AwsCargoDependency.awsHttp(runtimeConfig).toType() diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/BaseRequestIdDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/BaseRequestIdDecorator.kt new file mode 100644 index 00000000000..b70bf419bb0 --- /dev/null +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/BaseRequestIdDecorator.kt @@ -0,0 +1,226 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rustsdk + +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule +import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator +import software.amazon.smithy.rust.codegen.client.smithy.generators.error.ErrorCustomization +import software.amazon.smithy.rust.codegen.client.smithy.generators.error.ErrorSection +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RustCrate +import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization +import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderCustomization +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderSection +import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureCustomization +import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureSection +import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorImplCustomization +import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorImplSection +import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait +import software.amazon.smithy.rust.codegen.core.util.hasTrait + +/** + * Base customization for adding a request ID (or extended request ID) to outputs and errors. + */ +abstract class BaseRequestIdDecorator : ClientCodegenDecorator { + abstract val accessorFunctionName: String + abstract val fieldName: String + abstract fun accessorTrait(codegenContext: ClientCodegenContext): RuntimeType + abstract fun applyToError(codegenContext: ClientCodegenContext): RuntimeType + + override fun operationCustomizations( + codegenContext: ClientCodegenContext, + operation: OperationShape, + baseCustomizations: List, + ): List = baseCustomizations + listOf(RequestIdOperationCustomization(codegenContext)) + + override fun errorCustomizations( + codegenContext: ClientCodegenContext, + baseCustomizations: List, + ): List = + baseCustomizations + listOf(RequestIdErrorCustomization(codegenContext)) + + override fun errorImplCustomizations( + codegenContext: ClientCodegenContext, + baseCustomizations: List, + ): List = baseCustomizations + listOf(RequestIdErrorImplCustomization(codegenContext)) + + override fun structureCustomizations( + codegenContext: ClientCodegenContext, + baseCustomizations: List, + ): List = baseCustomizations + listOf(RequestIdStructureCustomization(codegenContext)) + + override fun builderCustomizations( + codegenContext: ClientCodegenContext, + baseCustomizations: List, + ): List = baseCustomizations + listOf(RequestIdBuilderCustomization()) + + override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { + rustCrate.withModule( + when (codegenContext.settings.codegenConfig.enableNewCrateOrganizationScheme) { + true -> ClientRustModule.Operation + else -> ClientRustModule.types + }, + ) { + // Re-export RequestId in generated crate + rust("pub use #T;", accessorTrait(codegenContext)) + } + } + + private inner class RequestIdOperationCustomization(private val codegenContext: ClientCodegenContext) : + OperationCustomization() { + override fun section(section: OperationSection): Writable = writable { + when (section) { + is OperationSection.PopulateErrorMetadataExtras -> { + rustTemplate( + "${section.builderName} = #{apply_to_error}(${section.builderName}, ${section.responseName}.headers());", + "apply_to_error" to applyToError(codegenContext), + ) + } + is OperationSection.MutateOutput -> { + rust( + "output._set_$fieldName(#T::$accessorFunctionName(response).map(str::to_string));", + accessorTrait(codegenContext), + ) + } + is OperationSection.BeforeParseResponse -> { + rustTemplate( + "#{tracing}::debug!($fieldName = ?#{trait}::$accessorFunctionName(${section.responseName}));", + "tracing" to RuntimeType.Tracing, + "trait" to accessorTrait(codegenContext), + ) + } + else -> {} + } + } + } + + private inner class RequestIdErrorCustomization(private val codegenContext: ClientCodegenContext) : + ErrorCustomization() { + override fun section(section: ErrorSection): Writable = writable { + when (section) { + is ErrorSection.OperationErrorAdditionalTraitImpls -> { + rustTemplate( + """ + impl #{AccessorTrait} for #{error} { + fn $accessorFunctionName(&self) -> Option<&str> { + self.meta().$accessorFunctionName() + } + } + """, + "AccessorTrait" to accessorTrait(codegenContext), + "error" to section.errorSymbol, + ) + } + + is ErrorSection.ServiceErrorAdditionalTraitImpls -> { + rustBlock("impl #T for Error", accessorTrait(codegenContext)) { + rustBlock("fn $accessorFunctionName(&self) -> Option<&str>") { + rustBlock("match self") { + section.allErrors.forEach { error -> + val sym = codegenContext.symbolProvider.toSymbol(error) + rust("Self::${sym.name}(e) => e.$accessorFunctionName(),") + } + rust("Self::Unhandled(e) => e.$accessorFunctionName(),") + } + } + } + } + } + } + } + + private inner class RequestIdErrorImplCustomization(private val codegenContext: ClientCodegenContext) : + ErrorImplCustomization() { + override fun section(section: ErrorImplSection): Writable = writable { + when (section) { + is ErrorImplSection.ErrorAdditionalTraitImpls -> { + rustBlock("impl #1T for #2T", accessorTrait(codegenContext), section.errorType) { + rustBlock("fn $accessorFunctionName(&self) -> Option<&str>") { + rust("use #T;", RuntimeType.provideErrorMetadataTrait(codegenContext.runtimeConfig)) + rust("self.meta().$accessorFunctionName()") + } + } + } + + else -> {} + } + } + } + + private inner class RequestIdStructureCustomization(private val codegenContext: ClientCodegenContext) : + StructureCustomization() { + override fun section(section: StructureSection): Writable = writable { + if (section.shape.hasTrait()) { + when (section) { + is StructureSection.AdditionalFields -> { + rust("_$fieldName: Option,") + } + + is StructureSection.AdditionalTraitImpls -> { + rustTemplate( + """ + impl #{AccessorTrait} for ${section.structName} { + fn $accessorFunctionName(&self) -> Option<&str> { + self._$fieldName.as_deref() + } + } + """, + "AccessorTrait" to accessorTrait(codegenContext), + ) + } + + is StructureSection.AdditionalDebugFields -> { + rust("""${section.formatterName}.field("_$fieldName", &self._$fieldName);""") + } + } + } + } + } + + private inner class RequestIdBuilderCustomization : BuilderCustomization() { + override fun section(section: BuilderSection): Writable = writable { + if (section.shape.hasTrait()) { + when (section) { + is BuilderSection.AdditionalFields -> { + rust("_$fieldName: Option,") + } + + is BuilderSection.AdditionalMethods -> { + rust( + """ + pub(crate) fn _$fieldName(mut self, $fieldName: impl Into) -> Self { + self._$fieldName = Some($fieldName.into()); + self + } + + pub(crate) fn _set_$fieldName(&mut self, $fieldName: Option) -> &mut Self { + self._$fieldName = $fieldName; + self + } + """, + ) + } + + is BuilderSection.AdditionalDebugFields -> { + rust("""${section.formatterName}.field("_$fieldName", &self._$fieldName);""") + } + + is BuilderSection.AdditionalFieldsInBuild -> { + rust("_$fieldName: self._$fieldName,") + } + } + } + } + } +} diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/CredentialProviders.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/CredentialProviders.kt index c9499c39203..30e15a0ca93 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/CredentialProviders.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/CredentialProviders.kt @@ -8,9 +8,9 @@ package software.amazon.smithy.rustsdk import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator import software.amazon.smithy.rust.codegen.client.smithy.customize.TestUtilFeature +import software.amazon.smithy.rust.codegen.client.smithy.featureGatedConfigModule import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig -import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable @@ -19,8 +19,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.customize.AdHocCustomization import software.amazon.smithy.rust.codegen.core.smithy.customize.adhocCustomization -import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization -import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsSection class CredentialsProviderDecorator : ClientCodegenDecorator { override val name: String = "CredentialsProvider" @@ -33,13 +31,6 @@ class CredentialsProviderDecorator : ClientCodegenDecorator { return baseCustomizations + CredentialProviderConfig(codegenContext.runtimeConfig) } - override fun libRsCustomizations( - codegenContext: ClientCodegenContext, - baseCustomizations: List, - ): List { - return baseCustomizations + PubUseCredentials(codegenContext.runtimeConfig) - } - override fun extraSections(codegenContext: ClientCodegenContext): List = listOf( adhocCustomization { section -> @@ -49,6 +40,13 @@ class CredentialsProviderDecorator : ClientCodegenDecorator { override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { rustCrate.mergeFeature(TestUtilFeature.copy(deps = listOf("aws-credential-types/test-util"))) + + rustCrate.withModule(codegenContext.featureGatedConfigModule()) { + rust( + "pub use #T::Credentials;", + AwsRuntimeType.awsCredentialTypes(codegenContext.runtimeConfig), + ) + } } } @@ -95,20 +93,5 @@ class CredentialProviderConfig(runtimeConfig: RuntimeConfig) : ConfigCustomizati } } -class PubUseCredentials(private val runtimeConfig: RuntimeConfig) : LibRsCustomization() { - override fun section(section: LibRsSection): Writable { - return when (section) { - is LibRsSection.Body -> writable { - rust( - "pub use #T::Credentials;", - AwsRuntimeType.awsCredentialTypes(runtimeConfig), - ) - } - - else -> emptySection - } - } -} - fun defaultProvider() = RuntimeType.forInlineDependency(InlineAwsDependency.forRustFile("no_credentials")).resolve("NoCredentials") diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/HttpRequestChecksumDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/HttpRequestChecksumDecorator.kt index 0a5a4cfd027..c1799c8bf54 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/HttpRequestChecksumDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/HttpRequestChecksumDecorator.kt @@ -26,7 +26,7 @@ import software.amazon.smithy.rust.codegen.core.util.orNull fun RuntimeConfig.awsInlineableBodyWithChecksum() = RuntimeType.forInlineDependency( InlineAwsDependency.forRustFile( - "http_body_checksum", visibility = Visibility.PUBLIC, + "http_body_checksum", visibility = Visibility.PUBCRATE, CargoDependency.Http, CargoDependency.HttpBody, CargoDependency.smithyHttp(this), diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/InlineAwsDependency.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/InlineAwsDependency.kt index fa2554655a0..b127795fc33 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/InlineAwsDependency.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/InlineAwsDependency.kt @@ -12,5 +12,8 @@ import software.amazon.smithy.rust.codegen.core.rustlang.Visibility object InlineAwsDependency { fun forRustFile(file: String, visibility: Visibility = Visibility.PRIVATE, vararg additionalDependency: RustDependency): InlineDependency = - InlineDependency.Companion.forRustFile(RustModule.new(file, visibility), "/aws-inlineable/src/$file.rs", *additionalDependency) + forRustFileAs(file, file, visibility, *additionalDependency) + + fun forRustFileAs(file: String, moduleName: String, visibility: Visibility = Visibility.PRIVATE, vararg additionalDependency: RustDependency): InlineDependency = + InlineDependency.Companion.forRustFile(RustModule.new(moduleName, visibility), "/aws-inlineable/src/$file.rs", *additionalDependency) } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt index 0eafec8fcdb..9cbddde2502 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/IntegrationTestDependencies.kt @@ -30,6 +30,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsSection +import software.amazon.smithy.rust.codegen.core.testutil.testDependenciesOnly import java.nio.file.Files import java.nio.file.Paths import kotlin.io.path.absolute @@ -72,7 +73,7 @@ class IntegrationTestDependencies( private val hasBenches: Boolean, ) : LibRsCustomization() { override fun section(section: LibRsSection) = when (section) { - is LibRsSection.Body -> writable { + is LibRsSection.Body -> testDependenciesOnly { if (hasTests) { val smithyClient = CargoDependency.smithyClient(runtimeConfig) .copy(features = setOf("test-util"), scope = DependencyScope.Dev) @@ -81,7 +82,7 @@ class IntegrationTestDependencies( addDependency(SerdeJson) addDependency(Tokio) addDependency(FuturesUtil) - addDependency(Tracing) + addDependency(Tracing.toDevDependency()) addDependency(TracingSubscriber) } if (hasBenches) { @@ -91,6 +92,7 @@ class IntegrationTestDependencies( serviceSpecific.section(section)(this) } } + else -> emptySection } @@ -114,8 +116,8 @@ class S3TestDependencies : LibRsCustomization() { override fun section(section: LibRsSection): Writable = writable { addDependency(AsyncStd) - addDependency(BytesUtils) - addDependency(FastRand) + addDependency(BytesUtils.toDevDependency()) + addDependency(FastRand.toDevDependency()) addDependency(HdrHistogram) addDependency(Smol) addDependency(TempFile) diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RegionDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RegionDecorator.kt index 70775e2cda4..af625e7979d 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RegionDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/RegionDecorator.kt @@ -12,6 +12,7 @@ import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointCustomization +import software.amazon.smithy.rust.codegen.client.smithy.featureGatedConfigModule import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig import software.amazon.smithy.rust.codegen.core.rustlang.Writable @@ -20,12 +21,11 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig +import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.customize.AdHocCustomization import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection import software.amazon.smithy.rust.codegen.core.smithy.customize.adhocCustomization -import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization -import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsSection import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.extendIf import software.amazon.smithy.rust.codegen.core.util.thenSingletonListOf @@ -101,13 +101,6 @@ class RegionDecorator : ClientCodegenDecorator { return baseCustomizations.extendIf(usesRegion(codegenContext)) { RegionConfigPlugin() } } - override fun libRsCustomizations( - codegenContext: ClientCodegenContext, - baseCustomizations: List, - ): List { - return baseCustomizations.extendIf(usesRegion(codegenContext)) { PubUseRegion(codegenContext.runtimeConfig) } - } - override fun extraSections(codegenContext: ClientCodegenContext): List { return usesRegion(codegenContext).thenSingletonListOf { adhocCustomization { section -> @@ -121,6 +114,14 @@ class RegionDecorator : ClientCodegenDecorator { } } + override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { + if (usesRegion(codegenContext)) { + rustCrate.withModule(codegenContext.featureGatedConfigModule()) { + rust("pub use #T::Region;", region(codegenContext.runtimeConfig)) + } + } + } + override fun endpointCustomizations(codegenContext: ClientCodegenContext): List { if (!usesRegion(codegenContext)) { return listOf() @@ -129,7 +130,9 @@ class RegionDecorator : ClientCodegenDecorator { object : EndpointCustomization { override fun loadBuiltInFromServiceConfig(parameter: Parameter, configRef: String): Writable? { return when (parameter.builtIn) { - Builtins.REGION.builtIn -> writable { rust("$configRef.region.as_ref().map(|r|r.as_ref().to_owned())") } + Builtins.REGION.builtIn -> writable { + rust("$configRef.region.as_ref().map(|r|r.as_ref().to_owned())") + } else -> null } } @@ -221,19 +224,4 @@ class RegionConfigPlugin : OperationCustomization() { } } -class PubUseRegion(private val runtimeConfig: RuntimeConfig) : LibRsCustomization() { - override fun section(section: LibRsSection): Writable { - return when (section) { - is LibRsSection.Body -> writable { - rust( - "pub use #T::Region;", - region(runtimeConfig), - ) - } - - else -> emptySection - } - } -} - fun region(runtimeConfig: RuntimeConfig) = AwsRuntimeType.awsTypes(runtimeConfig).resolve("region") diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SdkConfigDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SdkConfigDecorator.kt index f0e2a7e7d64..1d94cd6082b 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SdkConfigDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SdkConfigDecorator.kt @@ -6,10 +6,10 @@ package software.amazon.smithy.rustsdk import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate @@ -105,7 +105,7 @@ class SdkConfigDecorator : ClientCodegenDecorator { val codegenScope = arrayOf( "SdkConfig" to AwsRuntimeType.awsTypes(codegenContext.runtimeConfig).resolve("sdk_config::SdkConfig"), ) - rustCrate.withModule(RustModule.Config) { + rustCrate.withModule(ClientRustModule.Config) { rustTemplate( """ impl From<&#{SdkConfig}> for Builder { diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/UserAgentDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/UserAgentDecorator.kt index 8a2a27f66fb..0fa2c5f66b9 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/UserAgentDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/UserAgentDecorator.kt @@ -9,6 +9,8 @@ import software.amazon.smithy.aws.traits.ServiceTrait import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator +import software.amazon.smithy.rust.codegen.client.smithy.featureGatedConfigModule +import software.amazon.smithy.rust.codegen.client.smithy.featureGatedMetaModule import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig import software.amazon.smithy.rust.codegen.core.rustlang.Writable @@ -16,12 +18,12 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig +import software.amazon.smithy.rust.codegen.core.smithy.RustCrate +import software.amazon.smithy.rust.codegen.core.smithy.customizations.CrateVersionCustomization import software.amazon.smithy.rust.codegen.core.smithy.customize.AdHocCustomization import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection import software.amazon.smithy.rust.codegen.core.smithy.customize.adhocCustomization -import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization -import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsSection import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.expectTrait @@ -39,21 +41,12 @@ class UserAgentDecorator : ClientCodegenDecorator { return baseCustomizations + AppNameCustomization(codegenContext.runtimeConfig) } - override fun libRsCustomizations( - codegenContext: ClientCodegenContext, - baseCustomizations: List, - ): List { - // We are generating an AWS SDK, the service needs to have the AWS service trait - val serviceTrait = codegenContext.serviceShape.expectTrait() - return baseCustomizations + ApiVersionAndPubUse(codegenContext.runtimeConfig, serviceTrait) - } - override fun operationCustomizations( codegenContext: ClientCodegenContext, operation: OperationShape, baseCustomizations: List, ): List { - return baseCustomizations + UserAgentFeature(codegenContext.runtimeConfig) + return baseCustomizations + UserAgentFeature(codegenContext) } override fun extraSections(codegenContext: ClientCodegenContext): List { @@ -65,44 +58,54 @@ class UserAgentDecorator : ClientCodegenDecorator { } /** - * Adds a static `API_METADATA` variable to the crate root containing the serviceId & the version of the crate for this individual service + * Adds a static `API_METADATA` variable to the crate `config` containing the serviceId & the version of the crate for this individual service */ - private class ApiVersionAndPubUse(private val runtimeConfig: RuntimeConfig, serviceTrait: ServiceTrait) : - LibRsCustomization() { - private val serviceId = serviceTrait.sdkId.lowercase().replace(" ", "") - override fun section(section: LibRsSection): Writable = when (section) { - is LibRsSection.Body -> writable { - // PKG_VERSION comes from CrateVersionGenerator - rust( - "static API_METADATA: #1T::ApiMetadata = #1T::ApiMetadata::new(${serviceId.dq()}, PKG_VERSION);", - AwsRuntimeType.awsHttp(runtimeConfig).resolve("user_agent"), - ) + override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { + val runtimeConfig = codegenContext.runtimeConfig - // Re-export the app name so that it can be specified in config programmatically without an explicit dependency - rustTemplate( - "pub use #{AppName};", - "AppName" to AwsRuntimeType.awsTypes(runtimeConfig).resolve("app_name::AppName"), - ) - } + // We are generating an AWS SDK, the service needs to have the AWS service trait + val serviceTrait = codegenContext.serviceShape.expectTrait() + val serviceId = serviceTrait.sdkId.lowercase().replace(" ", "") + + rustCrate.withModule(codegenContext.featureGatedMetaModule()) { + rustTemplate( + """ + pub(crate) static API_METADATA: #{user_agent}::ApiMetadata = + #{user_agent}::ApiMetadata::new(${serviceId.dq()}, #{PKG_VERSION}); + """, + "user_agent" to AwsRuntimeType.awsHttp(runtimeConfig).resolve("user_agent"), + "PKG_VERSION" to CrateVersionCustomization.pkgVersion(codegenContext.featureGatedMetaModule()), + ) + } - else -> emptySection + rustCrate.withModule(codegenContext.featureGatedConfigModule()) { + // Re-export the app name so that it can be specified in config programmatically without an explicit dependency + rustTemplate( + "pub use #{AppName};", + "AppName" to AwsRuntimeType.awsTypes(runtimeConfig).resolve("app_name::AppName"), + ) } } - private class UserAgentFeature(private val runtimeConfig: RuntimeConfig) : OperationCustomization() { + private class UserAgentFeature( + private val codegenContext: ClientCodegenContext, + ) : OperationCustomization() { + private val runtimeConfig = codegenContext.runtimeConfig + override fun section(section: OperationSection): Writable = when (section) { is OperationSection.MutateRequest -> writable { rustTemplate( """ let mut user_agent = #{ua_module}::AwsUserAgent::new_from_environment( #{Env}::real(), - crate::API_METADATA.clone(), + #{meta}::API_METADATA.clone(), ); if let Some(app_name) = _config.app_name() { user_agent = user_agent.with_app_name(app_name.clone()); } ${section.request}.properties_mut().insert(user_agent); """, + "meta" to codegenContext.featureGatedMetaModule(), "ua_module" to AwsRuntimeType.awsHttp(runtimeConfig).resolve("user_agent"), "Env" to AwsRuntimeType.awsTypes(runtimeConfig).resolve("os_shim_internal::Env"), ) diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/auth/DisabledAuthDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/DisabledAuthDecorator.kt similarity index 91% rename from aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/auth/DisabledAuthDecorator.kt rename to aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/DisabledAuthDecorator.kt index 2c65f95bd38..dfbd2ca5979 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/auth/DisabledAuthDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/DisabledAuthDecorator.kt @@ -3,17 +3,15 @@ * SPDX-License-Identifier: Apache-2.0 */ -package software.amazon.smithy.rustsdk.customize.auth +package software.amazon.smithy.rustsdk.customize import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape -import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.traits.AuthTrait import software.amazon.smithy.model.transform.ModelTransformer import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator - -private fun String.shapeId() = ShapeId.from(this) +import software.amazon.smithy.rust.codegen.core.util.shapeId // / STS (and possibly other services) need to have auth manually set to [] class DisabledAuthDecorator : ClientCodegenDecorator { diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/ServiceSpecificDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/ServiceSpecificDecorator.kt new file mode 100644 index 00000000000..8e957b3f599 --- /dev/null +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/ServiceSpecificDecorator.kt @@ -0,0 +1,133 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rustsdk.customize + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.shapes.ToShapeId +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator +import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientProtocolMap +import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointCustomization +import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization +import software.amazon.smithy.rust.codegen.client.smithy.generators.error.ErrorCustomization +import software.amazon.smithy.rust.codegen.core.smithy.RustCrate +import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderCustomization +import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization +import software.amazon.smithy.rust.codegen.core.smithy.generators.ManifestCustomizations +import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureCustomization +import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorImplCustomization + +/** Only apply this decorator to the given service ID */ +fun ClientCodegenDecorator.onlyApplyTo(serviceId: String): List = + listOf(ServiceSpecificDecorator(ShapeId.from(serviceId), this)) + +/** Apply the given decorators only to this service ID */ +fun String.applyDecorators(vararg decorators: ClientCodegenDecorator): List = + decorators.map { it.onlyApplyTo(this) }.flatten() + +/** + * Delegating decorator that only applies to a configured service ID + */ +class ServiceSpecificDecorator( + /** Service ID this decorator is active for */ + private val appliesToServiceId: ShapeId, + /** Decorator to delegate to */ + private val delegateTo: ClientCodegenDecorator, + /** Decorator name */ + override val name: String = "${appliesToServiceId.namespace}.${appliesToServiceId.name}", + /** Decorator order */ + override val order: Byte = 0, +) : ClientCodegenDecorator { + private fun T.maybeApply(serviceId: ToShapeId, delegatedValue: () -> T): T = + if (appliesToServiceId == serviceId.toShapeId()) { + delegatedValue() + } else { + this + } + + // This kind of decorator gets explicitly added to the root sdk-codegen decorator + override fun classpathDiscoverable(): Boolean = false + + override fun builderCustomizations( + codegenContext: ClientCodegenContext, + baseCustomizations: List, + ): List = baseCustomizations.maybeApply(codegenContext.serviceShape) { + delegateTo.builderCustomizations(codegenContext, baseCustomizations) + } + + override fun configCustomizations( + codegenContext: ClientCodegenContext, + baseCustomizations: List, + ): List = baseCustomizations.maybeApply(codegenContext.serviceShape) { + delegateTo.configCustomizations(codegenContext, baseCustomizations) + } + + override fun crateManifestCustomizations(codegenContext: ClientCodegenContext): ManifestCustomizations = + emptyMap().maybeApply(codegenContext.serviceShape) { + delegateTo.crateManifestCustomizations(codegenContext) + } + + override fun endpointCustomizations(codegenContext: ClientCodegenContext): List = + emptyList().maybeApply(codegenContext.serviceShape) { + delegateTo.endpointCustomizations(codegenContext) + } + + override fun errorCustomizations( + codegenContext: ClientCodegenContext, + baseCustomizations: List, + ): List = baseCustomizations.maybeApply(codegenContext.serviceShape) { + delegateTo.errorCustomizations(codegenContext, baseCustomizations) + } + + override fun errorImplCustomizations( + codegenContext: ClientCodegenContext, + baseCustomizations: List, + ): List = baseCustomizations.maybeApply(codegenContext.serviceShape) { + delegateTo.errorImplCustomizations(codegenContext, baseCustomizations) + } + + override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { + maybeApply(codegenContext.serviceShape) { + delegateTo.extras(codegenContext, rustCrate) + } + } + + override fun libRsCustomizations( + codegenContext: ClientCodegenContext, + baseCustomizations: List, + ): List = baseCustomizations.maybeApply(codegenContext.serviceShape) { + delegateTo.libRsCustomizations(codegenContext, baseCustomizations) + } + + override fun operationCustomizations( + codegenContext: ClientCodegenContext, + operation: OperationShape, + baseCustomizations: List, + ): List = baseCustomizations.maybeApply(codegenContext.serviceShape) { + delegateTo.operationCustomizations(codegenContext, operation, baseCustomizations) + } + + override fun protocols(serviceId: ShapeId, currentProtocols: ClientProtocolMap): ClientProtocolMap = + currentProtocols.maybeApply(serviceId) { + delegateTo.protocols(serviceId, currentProtocols) + } + + override fun structureCustomizations( + codegenContext: ClientCodegenContext, + baseCustomizations: List, + ): List = baseCustomizations.maybeApply(codegenContext.serviceShape) { + delegateTo.structureCustomizations(codegenContext, baseCustomizations) + } + + override fun transformModel(service: ServiceShape, model: Model): Model = + model.maybeApply(service) { + delegateTo.transformModel(service, model) + } +} diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/apigateway/ApiGatewayDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/apigateway/ApiGatewayDecorator.kt index 9fc0f3e4c70..5959918ef75 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/apigateway/ApiGatewayDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/apigateway/ApiGatewayDecorator.kt @@ -6,34 +6,24 @@ package software.amazon.smithy.rustsdk.customize.apigateway import software.amazon.smithy.model.shapes.OperationShape -import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.writable -import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection -import software.amazon.smithy.rust.codegen.core.util.letIf class ApiGatewayDecorator : ClientCodegenDecorator { override val name: String = "ApiGateway" override val order: Byte = 0 - private fun applies(codegenContext: CodegenContext) = - codegenContext.serviceShape.id == ShapeId.from("com.amazonaws.apigateway#BackplaneControlService") - override fun operationCustomizations( codegenContext: ClientCodegenContext, operation: OperationShape, baseCustomizations: List, - ): List { - return baseCustomizations.letIf(applies(codegenContext)) { - it + ApiGatewayAddAcceptHeader() - } - } + ): List = baseCustomizations + ApiGatewayAddAcceptHeader() } class ApiGatewayAddAcceptHeader : OperationCustomization() { diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/ec2/Ec2Decorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/ec2/Ec2Decorator.kt index ee005da318a..e788920e1d1 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/ec2/Ec2Decorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/ec2/Ec2Decorator.kt @@ -7,24 +7,14 @@ package software.amazon.smithy.rustsdk.customize.ec2 import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.ServiceShape -import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator -import software.amazon.smithy.rust.codegen.core.util.letIf class Ec2Decorator : ClientCodegenDecorator { override val name: String = "Ec2" override val order: Byte = 0 - private val ec2 = ShapeId.from("com.amazonaws.ec2#AmazonEC2") - private fun applies(serviceShape: ServiceShape) = - serviceShape.id == ec2 - - override fun transformModel(service: ServiceShape, model: Model): Model { - // EC2 incorrectly models primitive shapes as unboxed when they actually - // need to be boxed for the API to work properly - return model.letIf( - applies(service), - EC2MakePrimitivesOptional::processModel, - ) - } + // EC2 incorrectly models primitive shapes as unboxed when they actually + // need to be boxed for the API to work properly + override fun transformModel(service: ServiceShape, model: Model): Model = + EC2MakePrimitivesOptional.processModel(model) } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/glacier/AccountIdAutofill.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/glacier/AccountIdAutofill.kt index 65e4f3a4558..63ef01d1b18 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/glacier/AccountIdAutofill.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/glacier/AccountIdAutofill.kt @@ -37,7 +37,9 @@ class AccountIdAutofill : OperationCustomization() { val input = operation.inputShape(model) return if (input.memberNames.contains("accountId")) { AccountIdAutofill() - } else null + } else { + null + } } } } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/glacier/GlacierDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/glacier/GlacierDecorator.kt index 7bfc3c4e424..5ba71e2f20c 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/glacier/GlacierDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/glacier/GlacierDecorator.kt @@ -6,35 +6,21 @@ package software.amazon.smithy.rustsdk.customize.glacier import software.amazon.smithy.model.shapes.OperationShape -import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator -import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization -val Glacier: ShapeId = ShapeId.from("com.amazonaws.glacier#Glacier") - class GlacierDecorator : ClientCodegenDecorator { override val name: String = "Glacier" override val order: Byte = 0 - private fun applies(codegenContext: CodegenContext) = codegenContext.serviceShape.id == Glacier - override fun operationCustomizations( codegenContext: ClientCodegenContext, operation: OperationShape, baseCustomizations: List, - ): List { - val extras = if (applies(codegenContext)) { - val apiVersion = codegenContext.serviceShape.version - listOfNotNull( - ApiVersionHeader(apiVersion), - TreeHashHeader.forOperation(operation, codegenContext.runtimeConfig), - AccountIdAutofill.forOperation(operation, codegenContext.model), - ) - } else { - emptyList() - } - return baseCustomizations + extras - } + ): List = baseCustomizations + listOfNotNull( + ApiVersionHeader(codegenContext.serviceShape.version), + TreeHashHeader.forOperation(operation, codegenContext.runtimeConfig), + AccountIdAutofill.forOperation(operation, codegenContext.model), + ) } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/glacier/TreeHashHeader.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/glacier/TreeHashHeader.kt index 023a663c429..c97357a2fd9 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/glacier/TreeHashHeader.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/glacier/TreeHashHeader.kt @@ -33,13 +33,16 @@ private val UploadMultipartPart: ShapeId = ShapeId.from("com.amazonaws.glacier#U private val Applies = setOf(UploadArchive, UploadMultipartPart) class TreeHashHeader(private val runtimeConfig: RuntimeConfig) : OperationCustomization() { - private val glacierChecksums = RuntimeType.forInlineDependency(InlineAwsDependency.forRustFile("glacier_checksums")) + private val glacierChecksums = RuntimeType.forInlineDependency( + InlineAwsDependency.forRustFile( + "glacier_checksums", + additionalDependency = TreeHashDependencies.toTypedArray(), + ), + ) + override fun section(section: OperationSection): Writable { return when (section) { is OperationSection.MutateRequest -> writable { - TreeHashDependencies.forEach { dep -> - addDependency(dep) - } rustTemplate( """ #{glacier_checksums}::add_checksum_treehash( @@ -49,6 +52,7 @@ class TreeHashHeader(private val runtimeConfig: RuntimeConfig) : OperationCustom "glacier_checksums" to glacierChecksums, "BuildError" to runtimeConfig.operationBuildError(), ) } + else -> emptySection } } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/route53/Route53Decorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/route53/Route53Decorator.kt index e1adcf46afd..007254c479d 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/route53/Route53Decorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/route53/Route53Decorator.kt @@ -26,26 +26,19 @@ import software.amazon.smithy.rust.codegen.core.util.letIf import software.amazon.smithy.rustsdk.InlineAwsDependency import java.util.logging.Logger -val Route53: ShapeId = ShapeId.from("com.amazonaws.route53#AWSDnsV20130401") - class Route53Decorator : ClientCodegenDecorator { override val name: String = "Route53" override val order: Byte = 0 private val logger: Logger = Logger.getLogger(javaClass.name) private val resourceShapes = setOf(ShapeId.from("com.amazonaws.route53#ResourceId"), ShapeId.from("com.amazonaws.route53#ChangeId")) - private fun applies(service: ServiceShape) = service.id == Route53 - - override fun transformModel(service: ServiceShape, model: Model): Model { - return model.letIf(applies(service)) { - ModelTransformer.create().mapShapes(model) { shape -> - shape.letIf(isResourceId(shape)) { - logger.info("Adding TrimResourceId trait to $shape") - (shape as MemberShape).toBuilder().addTrait(TrimResourceId()).build() - } + override fun transformModel(service: ServiceShape, model: Model): Model = + ModelTransformer.create().mapShapes(model) { shape -> + shape.letIf(isResourceId(shape)) { + logger.info("Adding TrimResourceId trait to $shape") + (shape as MemberShape).toBuilder().addTrait(TrimResourceId()).build() } } - } override fun operationCustomizations( codegenContext: ClientCodegenContext, @@ -56,7 +49,9 @@ class Route53Decorator : ClientCodegenDecorator { operation.inputShape(codegenContext.model).members().find { it.hasTrait() } return if (hostedZoneMember != null) { baseCustomizations + TrimResourceIdCustomization(codegenContext.symbolProvider.toMemberName(hostedZoneMember)) - } else baseCustomizations + } else { + baseCustomizations + } } private fun isResourceId(shape: Shape): Boolean { diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/S3Decorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/S3Decorator.kt index d6c0f8257f4..bd0c5210080 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/S3Decorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/S3Decorator.kt @@ -22,19 +22,15 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.Cli import software.amazon.smithy.rust.codegen.client.smithy.protocols.ClientRestXmlFactory import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.Writable -import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization -import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsSection import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml import software.amazon.smithy.rust.codegen.core.smithy.traits.AllowInvalidXmlRoot import software.amazon.smithy.rust.codegen.core.util.letIf -import software.amazon.smithy.rustsdk.AwsRuntimeType import software.amazon.smithy.rustsdk.endpoints.stripEndpointTrait import software.amazon.smithy.rustsdk.getBuiltIn import software.amazon.smithy.rustsdk.toWritable @@ -52,38 +48,22 @@ class S3Decorator : ClientCodegenDecorator { ShapeId.from("com.amazonaws.s3#GetObjectAttributesOutput"), ) - private fun applies(serviceId: ShapeId) = - serviceId == ShapeId.from("com.amazonaws.s3#AmazonS3") - override fun protocols( serviceId: ShapeId, currentProtocols: ProtocolMap, - ): ProtocolMap = - currentProtocols.letIf(applies(serviceId)) { - it + mapOf( - RestXmlTrait.ID to ClientRestXmlFactory { protocolConfig -> - S3(protocolConfig) - }, - ) - } - - override fun transformModel(service: ServiceShape, model: Model): Model { - return model.letIf(applies(service.id)) { - ModelTransformer.create().mapShapes(model) { shape -> - shape.letIf(isInInvalidXmlRootAllowList(shape)) { - logger.info("Adding AllowInvalidXmlRoot trait to $it") - (it as StructureShape).toBuilder().addTrait(AllowInvalidXmlRoot()).build() - } - }.let(StripBucketFromHttpPath()::transform).let(stripEndpointTrait("RequestRoute")) - } - } + ): ProtocolMap = currentProtocols + mapOf( + RestXmlTrait.ID to ClientRestXmlFactory { protocolConfig -> + S3ProtocolOverride(protocolConfig) + }, + ) - override fun libRsCustomizations( - codegenContext: ClientCodegenContext, - baseCustomizations: List, - ): List = baseCustomizations.letIf(applies(codegenContext.serviceShape.id)) { - it + S3PubUse() - } + override fun transformModel(service: ServiceShape, model: Model): Model = + ModelTransformer.create().mapShapes(model) { shape -> + shape.letIf(isInInvalidXmlRootAllowList(shape)) { + logger.info("Adding AllowInvalidXmlRoot trait to $it") + (it as StructureShape).toBuilder().addTrait(AllowInvalidXmlRoot()).build() + } + }.let(StripBucketFromHttpPath()::transform).let(stripEndpointTrait("RequestRoute")) override fun endpointCustomizations(codegenContext: ClientCodegenContext): List { return listOf(object : EndpointCustomization { @@ -108,35 +88,36 @@ class S3Decorator : ClientCodegenDecorator { } } -class S3(codegenContext: CodegenContext) : RestXml(codegenContext) { +class S3ProtocolOverride(codegenContext: CodegenContext) : RestXml(codegenContext) { private val runtimeConfig = codegenContext.runtimeConfig private val errorScope = arrayOf( "Bytes" to RuntimeType.Bytes, - "Error" to RuntimeType.genericError(runtimeConfig), + "ErrorMetadata" to RuntimeType.errorMetadata(runtimeConfig), + "ErrorBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig), "HeaderMap" to RuntimeType.HttpHeaderMap, "Response" to RuntimeType.HttpResponse, "XmlDecodeError" to RuntimeType.smithyXml(runtimeConfig).resolve("decode::XmlDecodeError"), "base_errors" to restXmlErrors, - "s3_errors" to AwsRuntimeType.S3Errors, ) - override fun parseHttpGenericError(operationShape: OperationShape): RuntimeType { - return RuntimeType.forInlineFun("parse_http_generic_error", RustModule.private("xml_deser")) { + override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType { + return RuntimeType.forInlineFun("parse_http_error_metadata", RustModule.private("xml_deser")) { rustBlockTemplate( - "pub fn parse_http_generic_error(response: &#{Response}<#{Bytes}>) -> Result<#{Error}, #{XmlDecodeError}>", + "pub fn parse_http_error_metadata(response: &#{Response}<#{Bytes}>) -> Result<#{ErrorBuilder}, #{XmlDecodeError}>", *errorScope, ) { rustTemplate( """ + // S3 HEAD responses have no response body to for an error code. Therefore, + // check the HTTP response status and populate an error code for 404s. if response.body().is_empty() { - let mut err = #{Error}::builder(); + let mut builder = #{ErrorMetadata}::builder(); if response.status().as_u16() == 404 { - err.code("NotFound"); + builder = builder.code("NotFound"); } - Ok(err.build()) + Ok(builder) } else { - let base_err = #{base_errors}::parse_generic_error(response.body().as_ref())?; - Ok(#{s3_errors}::parse_extended_error(base_err, response.headers())) + #{base_errors}::parse_error_metadata(response.body().as_ref()) } """, *errorScope, @@ -145,16 +126,3 @@ class S3(codegenContext: CodegenContext) : RestXml(codegenContext) { } } } - -class S3PubUse : LibRsCustomization() { - override fun section(section: LibRsSection): Writable = when (section) { - is LibRsSection.Body -> writable { - rust( - "pub use #T::ErrorExt;", - AwsRuntimeType.S3Errors, - ) - } - - else -> emptySection - } -} diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/S3ExtendedRequestIdDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/S3ExtendedRequestIdDecorator.kt new file mode 100644 index 00000000000..6b117b60da2 --- /dev/null +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/S3ExtendedRequestIdDecorator.kt @@ -0,0 +1,28 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rustsdk.customize.s3 + +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rustsdk.BaseRequestIdDecorator +import software.amazon.smithy.rustsdk.InlineAwsDependency + +class S3ExtendedRequestIdDecorator : BaseRequestIdDecorator() { + override val name: String = "S3ExtendedRequestIdDecorator" + override val order: Byte = 0 + + override val fieldName: String = "extended_request_id" + override val accessorFunctionName: String = "extended_request_id" + + private val requestIdModule: RuntimeType = + RuntimeType.forInlineDependency(InlineAwsDependency.forRustFile("s3_request_id")) + + override fun accessorTrait(codegenContext: ClientCodegenContext): RuntimeType = + requestIdModule.resolve("RequestIdExt") + + override fun applyToError(codegenContext: ClientCodegenContext): RuntimeType = + requestIdModule.resolve("apply_extended_request_id") +} diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3control/S3ControlDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3control/S3ControlDecorator.kt index 39e85d78988..3534258b189 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3control/S3ControlDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3control/S3ControlDecorator.kt @@ -6,21 +6,41 @@ package software.amazon.smithy.rustsdk.customize.s3control import software.amazon.smithy.model.Model +import software.amazon.smithy.model.node.Node import software.amazon.smithy.model.shapes.ServiceShape -import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator +import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointCustomization +import software.amazon.smithy.rust.codegen.client.smithy.endpoint.rustName +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rustsdk.endpoints.stripEndpointTrait +import software.amazon.smithy.rustsdk.getBuiltIn +import software.amazon.smithy.rustsdk.toWritable class S3ControlDecorator : ClientCodegenDecorator { override val name: String = "S3Control" override val order: Byte = 0 - private fun applies(service: ServiceShape) = - service.id == ShapeId.from("com.amazonaws.s3control#AWSS3ControlServiceV20180820") - override fun transformModel(service: ServiceShape, model: Model): Model { - if (!applies(service)) { - return model - } - return stripEndpointTrait("AccountId")(model) + override fun transformModel(service: ServiceShape, model: Model): Model = + stripEndpointTrait("AccountId")(model) + + override fun endpointCustomizations(codegenContext: ClientCodegenContext): List { + return listOf(object : EndpointCustomization { + override fun setBuiltInOnServiceConfig(name: String, value: Node, configBuilderRef: String): Writable? { + if (!name.startsWith("AWS::S3Control")) { + return null + } + val builtIn = codegenContext.getBuiltIn(name) ?: return null + return writable { + rustTemplate( + "let $configBuilderRef = $configBuilderRef.${builtIn.name.rustName()}(#{value});", + "value" to value.toWritable(), + ) + } + } + }, + ) } } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/sts/STSDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/sts/STSDecorator.kt index 75d0555c8f7..a332dd30350 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/sts/STSDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/sts/STSDecorator.kt @@ -7,7 +7,6 @@ package software.amazon.smithy.rustsdk.customize.sts import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.Shape -import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.model.traits.RetryableTrait @@ -22,24 +21,18 @@ class STSDecorator : ClientCodegenDecorator { override val order: Byte = 0 private val logger: Logger = Logger.getLogger(javaClass.name) - private fun applies(serviceId: ShapeId) = - serviceId == ShapeId.from("com.amazonaws.sts#AWSSecurityTokenServiceV20110615") - private fun isIdpCommunicationError(shape: Shape): Boolean = shape is StructureShape && shape.hasTrait() && shape.id.namespace == "com.amazonaws.sts" && shape.id.name == "IDPCommunicationErrorException" - override fun transformModel(service: ServiceShape, model: Model): Model { - return model.letIf(applies(service.id)) { - ModelTransformer.create().mapShapes(model) { shape -> - shape.letIf(isIdpCommunicationError(shape)) { - logger.info("Adding @retryable trait to $shape and setting its error type to 'server'") - (shape as StructureShape).toBuilder() - .removeTrait(ErrorTrait.ID) - .addTrait(ErrorTrait("server")) - .addTrait(RetryableTrait.builder().build()).build() - } + override fun transformModel(service: ServiceShape, model: Model): Model = + ModelTransformer.create().mapShapes(model) { shape -> + shape.letIf(isIdpCommunicationError(shape)) { + logger.info("Adding @retryable trait to $shape and setting its error type to 'server'") + (shape as StructureShape).toBuilder() + .removeTrait(ErrorTrait.ID) + .addTrait(ErrorTrait("server")) + .addTrait(RetryableTrait.builder().build()).build() } } - } } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/endpoints/AwsEndpointDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/endpoints/AwsEndpointDecorator.kt index ab9ce5bf1f7..7885db72d60 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/endpoints/AwsEndpointDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/endpoints/AwsEndpointDecorator.kt @@ -18,6 +18,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointTypesGenerator import software.amazon.smithy.rust.codegen.client.smithy.endpoint.generators.EndpointsModule +import software.amazon.smithy.rust.codegen.client.smithy.featureGatedConfigModule import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig import software.amazon.smithy.rust.codegen.core.rustlang.Attribute @@ -27,12 +28,9 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.customize.AdHocCustomization import software.amazon.smithy.rust.codegen.core.smithy.customize.adhocCustomization -import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization -import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsSection import software.amazon.smithy.rust.codegen.core.util.extendIf import software.amazon.smithy.rust.codegen.core.util.letIf import software.amazon.smithy.rust.codegen.core.util.thenSingletonListOf @@ -91,14 +89,14 @@ class AwsEndpointDecorator : ClientCodegenDecorator { } } - override fun libRsCustomizations( - codegenContext: ClientCodegenContext, - baseCustomizations: List, - ): List { - return baseCustomizations + PubUseEndpoint(codegenContext.runtimeConfig) - } - override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { + rustCrate.withModule(codegenContext.featureGatedConfigModule()) { + rust( + "pub use #T::endpoint::Endpoint;", + CargoDependency.smithyHttp(codegenContext.runtimeConfig).toType(), + ) + } + val epTypes = EndpointTypesGenerator.fromContext(codegenContext) if (epTypes.defaultResolver() == null) { throw CodegenException( @@ -252,23 +250,7 @@ class AwsEndpointDecorator : ClientCodegenDecorator { } ServiceConfig.ConfigStruct -> rust("endpoint_url: Option,") - ServiceConfig.ConfigStructAdditionalDocs -> emptySection - ServiceConfig.Extras -> emptySection - } - } - } - - class PubUseEndpoint(private val runtimeConfig: RuntimeConfig) : LibRsCustomization() { - override fun section(section: LibRsSection): Writable { - return when (section) { - is LibRsSection.Body -> writable { - rust( - "pub use #T::endpoint::Endpoint;", - CargoDependency.smithyHttp(runtimeConfig).toType(), - ) - } - - else -> emptySection + else -> {} } } } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/endpoints/OperationInputTestGenerator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/endpoints/OperationInputTestGenerator.kt index 37ecbad73cc..9d620e11cb6 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/endpoints/OperationInputTestGenerator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/endpoints/OperationInputTestGenerator.kt @@ -17,7 +17,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointTypesG import software.amazon.smithy.rust.codegen.client.smithy.generators.clientInstantiator import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.AttributeKind -import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.escape import software.amazon.smithy.rust.codegen.core.rustlang.join import software.amazon.smithy.rust.codegen.core.rustlang.rust @@ -25,6 +24,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.PublicImportSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName import software.amazon.smithy.rust.codegen.core.testutil.integrationTest @@ -146,8 +146,7 @@ class OperationInputTestGenerator(_ctx: ClientCodegenContext, private val test: let _result = dbg!(#{invoke_operation}); #{assertion} """, - "capture_request" to CargoDependency.smithyClient(runtimeConfig) - .withFeature("test-util").toType().resolve("test_connection::capture_request"), + "capture_request" to RuntimeType.captureRequest(runtimeConfig), "conf" to config(testOperationInput), "invoke_operation" to operationInvocation(testOperationInput), "assertion" to writable { diff --git a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/EndpointsCredentialsTest.kt b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/EndpointsCredentialsTest.kt index 1df1a0c0e04..ea2bfe026fa 100644 --- a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/EndpointsCredentialsTest.kt +++ b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/EndpointsCredentialsTest.kt @@ -6,8 +6,8 @@ package software.amazon.smithy.rustsdk import org.junit.jupiter.api.Test -import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.integrationTest import software.amazon.smithy.rust.codegen.core.testutil.tokioTest @@ -96,8 +96,7 @@ class EndpointsCredentialsTest { let auth_header = req.headers().get("AUTHORIZATION").unwrap().to_str().unwrap(); assert!(auth_header.contains("/us-west-2/foobaz/aws4_request"), "{}", auth_header); """, - "capture_request" to CargoDependency.smithyClient(context.runtimeConfig) - .withFeature("test-util").toType().resolve("test_connection::capture_request"), + "capture_request" to RuntimeType.captureRequest(context.runtimeConfig), "Credentials" to AwsCargoDependency.awsCredentialTypes(context.runtimeConfig) .withFeature("test-util").toType().resolve("Credentials"), "Region" to AwsRuntimeType.awsTypes(context.runtimeConfig).resolve("region::Region"), @@ -120,8 +119,7 @@ class EndpointsCredentialsTest { let auth_header = req.headers().get("AUTHORIZATION").unwrap().to_str().unwrap(); assert!(auth_header.contains("/region-custom-auth/name-custom-auth/aws4_request"), "{}", auth_header); """, - "capture_request" to CargoDependency.smithyClient(context.runtimeConfig) - .withFeature("test-util").toType().resolve("test_connection::capture_request"), + "capture_request" to RuntimeType.captureRequest(context.runtimeConfig), "Credentials" to AwsCargoDependency.awsCredentialTypes(context.runtimeConfig) .withFeature("test-util").toType().resolve("Credentials"), "Region" to AwsRuntimeType.awsTypes(context.runtimeConfig).resolve("region::Region"), diff --git a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/HttpConnectorConfigCustomizationTest.kt b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/HttpConnectorConfigCustomizationTest.kt index 1307a46fbb8..b97952e0021 100644 --- a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/HttpConnectorConfigCustomizationTest.kt +++ b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/HttpConnectorConfigCustomizationTest.kt @@ -7,30 +7,13 @@ package software.amazon.smithy.rustsdk import org.junit.jupiter.api.Test import software.amazon.smithy.rust.codegen.client.testutil.validateConfigCustomizations -import software.amazon.smithy.rust.codegen.core.smithy.CoreRustSettings import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace -import software.amazon.smithy.rust.codegen.core.testutil.rustSettings class HttpConnectorConfigCustomizationTest { @Test fun `generates a valid config`() { val project = TestWorkspace.testProject() - val projectSettings = project.rustSettings() - val codegenContext = awsTestCodegenContext( - coreRustSettings = CoreRustSettings( - service = projectSettings.service, - moduleName = projectSettings.moduleName, - moduleVersion = projectSettings.moduleVersion, - moduleAuthors = projectSettings.moduleAuthors, - moduleDescription = projectSettings.moduleDescription, - moduleRepository = projectSettings.moduleRepository, - runtimeConfig = AwsTestRuntimeConfig, - codegenConfig = projectSettings.codegenConfig, - license = projectSettings.license, - examplesUri = projectSettings.examplesUri, - customizationConfig = projectSettings.customizationConfig, - ), - ) + val codegenContext = awsTestCodegenContext() validateConfigCustomizations(HttpConnectorConfigCustomization(codegenContext), project) } } diff --git a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/RegionProviderConfigTest.kt b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/RegionProviderConfigTest.kt index 8d69fb2c868..9d2e865a646 100644 --- a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/RegionProviderConfigTest.kt +++ b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/RegionProviderConfigTest.kt @@ -6,8 +6,8 @@ package software.amazon.smithy.rustsdk import org.junit.jupiter.api.Test +import software.amazon.smithy.rust.codegen.client.testutil.testClientRustSettings import software.amazon.smithy.rust.codegen.client.testutil.validateConfigCustomizations -import software.amazon.smithy.rust.codegen.core.smithy.CoreRustSettings import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.rustSettings @@ -15,21 +15,12 @@ internal class RegionProviderConfigTest { @Test fun `generates a valid config`() { val project = TestWorkspace.testProject() - val projectSettings = project.rustSettings() - val coreRustSettings = CoreRustSettings( - service = projectSettings.service, - moduleName = projectSettings.moduleName, - moduleVersion = projectSettings.moduleVersion, - moduleAuthors = projectSettings.moduleAuthors, - moduleDescription = projectSettings.moduleDescription, - moduleRepository = projectSettings.moduleRepository, - runtimeConfig = AwsTestRuntimeConfig, - codegenConfig = projectSettings.codegenConfig, - license = projectSettings.license, - examplesUri = projectSettings.examplesUri, - customizationConfig = projectSettings.customizationConfig, + val codegenContext = awsTestCodegenContext( + settings = testClientRustSettings( + moduleName = project.rustSettings().moduleName, + runtimeConfig = AwsTestRuntimeConfig, + ), ) - val codegenContext = awsTestCodegenContext(coreRustSettings = coreRustSettings) validateConfigCustomizations(RegionProviderConfig(codegenContext), project) } } diff --git a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/TestUtil.kt b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/TestUtil.kt index 8db0d6a1dff..d0a619b4679 100644 --- a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/TestUtil.kt +++ b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/TestUtil.kt @@ -8,14 +8,15 @@ package software.amazon.smithy.rustsdk import software.amazon.smithy.model.Model import software.amazon.smithy.model.node.ObjectNode import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.client.smithy.ClientRustSettings import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest -import software.amazon.smithy.rust.codegen.client.testutil.testCodegenContext -import software.amazon.smithy.rust.codegen.core.smithy.CoreRustSettings +import software.amazon.smithy.rust.codegen.client.testutil.testClientCodegenContext +import software.amazon.smithy.rust.codegen.client.testutil.testClientRustSettings import software.amazon.smithy.rust.codegen.core.smithy.RuntimeCrateLocation import software.amazon.smithy.rust.codegen.core.smithy.RustCrate +import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel -import software.amazon.smithy.rust.codegen.core.testutil.testRustSettings import java.io.File // In aws-sdk-codegen, the working dir when gradle runs tests is actually `./aws`. So, to find the smithy runtime, we need @@ -28,10 +29,10 @@ val AwsTestRuntimeConfig = TestRuntimeConfig.copy( }, ) -fun awsTestCodegenContext(model: Model? = null, coreRustSettings: CoreRustSettings?) = - testCodegenContext( +fun awsTestCodegenContext(model: Model? = null, settings: ClientRustSettings? = null) = + testClientCodegenContext( model ?: "namespace test".asSmithyModel(), - settings = coreRustSettings ?: testRustSettings(runtimeConfig = AwsTestRuntimeConfig), + settings = settings ?: testClientRustSettings(runtimeConfig = AwsTestRuntimeConfig), ) fun awsSdkIntegrationTest( @@ -39,9 +40,10 @@ fun awsSdkIntegrationTest( test: (ClientCodegenContext, RustCrate) -> Unit = { _, _ -> }, ) = clientIntegrationTest( - model, runtimeConfig = AwsTestRuntimeConfig, - additionalSettings = ObjectNode.builder() - .withMember( + model, + IntegrationTestParams( + runtimeConfig = AwsTestRuntimeConfig, + additionalSettings = ObjectNode.builder().withMember( "customizationConfig", ObjectNode.builder() .withMember( @@ -51,6 +53,7 @@ fun awsSdkIntegrationTest( .build(), ).build(), ) - .withMember("codegen", ObjectNode.builder().withMember("includeFluentClient", false).build()).build(), + .withMember("codegen", ObjectNode.builder().withMember("includeFluentClient", false).build()).build(), + ), test = test, ) diff --git a/aws/sdk/build.gradle.kts b/aws/sdk/build.gradle.kts index e2b3d0b72bf..31478777048 100644 --- a/aws/sdk/build.gradle.kts +++ b/aws/sdk/build.gradle.kts @@ -52,7 +52,9 @@ dependencies { // Class and functions for service and protocol membership for SDK generation -val awsServices: AwsServices by lazy { discoverServices(properties.get("aws.sdk.models.path"), loadServiceMembership()) } +val awsServices: AwsServices by lazy { + discoverServices(properties.get("aws.sdk.models.path"), loadServiceMembership()) +} val eventStreamAllowList: Set by lazy { eventStreamAllowList() } val crateVersioner by lazy { aws.sdk.CrateVersioner.defaultFor(rootProject, properties) } @@ -97,7 +99,8 @@ fun generateSmithyBuild(services: AwsServices): String { "codegen": { "includeFluentClient": false, "renameErrors": false, - "eventStreamAllowList": [$eventStreamAllowListMembers] + "eventStreamAllowList": [$eventStreamAllowListMembers], + "enableNewCrateOrganizationScheme": false }, "service": "${service.service}", "module": "$moduleName", diff --git a/aws/sdk/integration-tests/Cargo.toml b/aws/sdk/integration-tests/Cargo.toml index 406b718a94b..a36345cda0f 100644 --- a/aws/sdk/integration-tests/Cargo.toml +++ b/aws/sdk/integration-tests/Cargo.toml @@ -15,4 +15,5 @@ members = [ "s3control", "sts", "transcribestreaming", + "using-native-tls-instead-of-rustls", ] diff --git a/aws/sdk/integration-tests/kms/tests/integration.rs b/aws/sdk/integration-tests/kms/tests/integration.rs index baa39ef6a48..8c9bd1e1058 100644 --- a/aws/sdk/integration-tests/kms/tests/integration.rs +++ b/aws/sdk/integration-tests/kms/tests/integration.rs @@ -6,6 +6,7 @@ use aws_http::user_agent::AwsUserAgent; use aws_sdk_kms as kms; use aws_sdk_kms::middleware::DefaultMiddleware; +use aws_sdk_kms::types::RequestId; use aws_smithy_client::test_connection::TestConnection; use aws_smithy_client::{Client as CoreClient, SdkError}; use aws_smithy_http::body::SdkBody; diff --git a/aws/sdk/integration-tests/kms/tests/sensitive-it.rs b/aws/sdk/integration-tests/kms/tests/sensitive-it.rs index 5a97651d83e..00f3c8d95e0 100644 --- a/aws/sdk/integration-tests/kms/tests/sensitive-it.rs +++ b/aws/sdk/integration-tests/kms/tests/sensitive-it.rs @@ -19,12 +19,17 @@ use kms::types::Blob; #[test] fn validate_sensitive_trait() { + let builder = GenerateRandomOutput::builder().plaintext(Blob::new("some output")); + assert_eq!( + format!("{:?}", builder), + "Builder { plaintext: \"*** Sensitive Data Redacted ***\", _request_id: None }" + ); let output = GenerateRandomOutput::builder() .plaintext(Blob::new("some output")) .build(); assert_eq!( format!("{:?}", output), - "GenerateRandomOutput { plaintext: \"*** Sensitive Data Redacted ***\" }" + "GenerateRandomOutput { plaintext: \"*** Sensitive Data Redacted ***\", _request_id: None }" ); } diff --git a/aws/sdk/integration-tests/lambda/tests/request_id.rs b/aws/sdk/integration-tests/lambda/tests/request_id.rs new file mode 100644 index 00000000000..ab3ede5f0ac --- /dev/null +++ b/aws/sdk/integration-tests/lambda/tests/request_id.rs @@ -0,0 +1,39 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use aws_sdk_lambda::error::ListFunctionsError; +use aws_sdk_lambda::operation::ListFunctions; +use aws_sdk_lambda::types::RequestId; +use aws_smithy_http::response::ParseHttpResponse; +use bytes::Bytes; + +#[test] +fn get_request_id_from_unmodeled_error() { + let resp = http::Response::builder() + .header("x-amzn-RequestId", "correct-request-id") + .header("X-Amzn-Errortype", "ListFunctions") + .status(500) + .body("{}") + .unwrap(); + let err = ListFunctions::new() + .parse_loaded(&resp.map(Bytes::from)) + .expect_err("status was 500, this is an error"); + assert!(matches!(err, ListFunctionsError::Unhandled(_))); + assert_eq!(Some("correct-request-id"), err.request_id()); + assert_eq!(Some("correct-request-id"), err.meta().request_id()); +} + +#[test] +fn get_request_id_from_successful_response() { + let resp = http::Response::builder() + .header("x-amzn-RequestId", "correct-request-id") + .status(200) + .body(r#"{"Functions":[],"NextMarker":null}"#) + .unwrap(); + let output = ListFunctions::new() + .parse_loaded(&resp.map(Bytes::from)) + .expect("valid successful response"); + assert_eq!(Some("correct-request-id"), output.request_id()); +} diff --git a/aws/sdk/integration-tests/s3/tests/custom-error-deserializer.rs b/aws/sdk/integration-tests/s3/tests/custom-error-deserializer.rs deleted file mode 100644 index 46b1fc50f72..00000000000 --- a/aws/sdk/integration-tests/s3/tests/custom-error-deserializer.rs +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -use aws_sdk_s3::operation::GetObject; -use aws_sdk_s3::ErrorExt; -use aws_smithy_http::response::ParseHttpResponse; -use bytes::Bytes; - -#[test] -fn deserialize_extended_errors() { - let resp = http::Response::builder() - .header( - "x-amz-id-2", - "gyB+3jRPnrkN98ZajxHXr3u7EFM67bNgSAxexeEHndCX/7GRnfTXxReKUQF28IfP", - ) - .header("x-amz-request-id", "3B3C7C725673C630") - .status(404) - .body( - r#" - - NoSuchKey - The resource you requested does not exist - /mybucket/myfoto.jpg - 4442587FB7D0A2F9 -"#, - ) - .unwrap(); - let err = GetObject::new() - .parse_loaded(&resp.map(Bytes::from)) - .expect_err("status was 404, this is an error"); - assert_eq!( - err.meta().extended_request_id(), - Some("gyB+3jRPnrkN98ZajxHXr3u7EFM67bNgSAxexeEHndCX/7GRnfTXxReKUQF28IfP") - ); - assert_eq!(err.meta().request_id(), Some("4442587FB7D0A2F9")); -} diff --git a/aws/sdk/integration-tests/s3/tests/endpoints.rs b/aws/sdk/integration-tests/s3/tests/endpoints.rs index 93c37f5c872..02b1718569e 100644 --- a/aws/sdk/integration-tests/s3/tests/endpoints.rs +++ b/aws/sdk/integration-tests/s3/tests/endpoints.rs @@ -16,7 +16,7 @@ fn test_client(update_builder: fn(Builder) -> Builder) -> (CaptureRequestReceive let sdk_config = SdkConfig::builder() .credentials_provider(SharedCredentialsProvider::new(Credentials::for_tests())) .region(Region::new("us-west-4")) - .http_connector(conn.clone()) + .http_connector(conn) .build(); let client = Client::from_conf(update_builder(Builder::from(&sdk_config)).build()); (captured_request, client) diff --git a/aws/sdk/integration-tests/s3/tests/presigning.rs b/aws/sdk/integration-tests/s3/tests/presigning.rs index 341c7f3e0ce..c68424a5946 100644 --- a/aws/sdk/integration-tests/s3/tests/presigning.rs +++ b/aws/sdk/integration-tests/s3/tests/presigning.rs @@ -130,3 +130,15 @@ async fn test_presigned_upload_part() -> Result<(), Box> { ); Ok(()) } + +#[tokio::test] +async fn test_presigning_object_lambda() -> Result<(), Box> { + let presigned = presign_input!(s3::input::GetObjectInput::builder() + .bucket("arn:aws:s3-object-lambda:us-west-2:123456789012:accesspoint:my-banner-ap-name") + .key("test2.txt") + .build() + .unwrap()); + // since the URI is `my-banner-api-name...` we know EP2 is working properly for presigning + assert_eq!(presigned.uri().to_string(), "https://my-banner-ap-name-123456789012.s3-object-lambda.us-west-2.amazonaws.com/test2.txt?x-id=GetObject&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=ANOTREAL%2F20090213%2Fus-west-2%2Fs3-object-lambda%2Faws4_request&X-Amz-Date=20090213T233131Z&X-Amz-Expires=30&X-Amz-SignedHeaders=host&X-Amz-Signature=027976453050b6f9cca7af80a59c05ee572b462e0fc1ef564c59412b903fcdf2&X-Amz-Security-Token=notarealsessiontoken"); + Ok(()) +} diff --git a/aws/sdk/integration-tests/s3/tests/query-strings-are-correctly-encoded.rs b/aws/sdk/integration-tests/s3/tests/query-strings-are-correctly-encoded.rs index 858d0138410..7714becfe55 100644 --- a/aws/sdk/integration-tests/s3/tests/query-strings-are-correctly-encoded.rs +++ b/aws/sdk/integration-tests/s3/tests/query-strings-are-correctly-encoded.rs @@ -71,7 +71,7 @@ async fn test_s3_signer_query_string_with_all_valid_chars() { #[tokio::test] #[ignore] async fn test_query_strings_are_correctly_encoded() { - use aws_sdk_s3::error::{ListObjectsV2Error, ListObjectsV2ErrorKind}; + use aws_sdk_s3::error::ListObjectsV2Error; use aws_smithy_http::result::SdkError; tracing_subscriber::fmt::init(); @@ -92,22 +92,19 @@ async fn test_query_strings_are_correctly_encoded() { .send() .await; if let Err(SdkError::ServiceError(context)) = res { - let ListObjectsV2Error { kind, .. } = context.err(); - match kind { - ListObjectsV2ErrorKind::Unhandled(e) + match context.err() { + ListObjectsV2Error::Unhandled(e) if e.to_string().contains("SignatureDoesNotMatch") => { chars_that_break_signing.push(byte); } - ListObjectsV2ErrorKind::Unhandled(e) if e.to_string().contains("InvalidUri") => { + ListObjectsV2Error::Unhandled(e) if e.to_string().contains("InvalidUri") => { chars_that_break_uri_parsing.push(byte); } - ListObjectsV2ErrorKind::Unhandled(e) - if e.to_string().contains("InvalidArgument") => - { + ListObjectsV2Error::Unhandled(e) if e.to_string().contains("InvalidArgument") => { chars_that_are_invalid_arguments.push(byte); } - ListObjectsV2ErrorKind::Unhandled(e) if e.to_string().contains("InvalidToken") => { + ListObjectsV2Error::Unhandled(e) if e.to_string().contains("InvalidToken") => { panic!("refresh your credentials and run this test again"); } e => todo!("unexpected error: {:?}", e), diff --git a/aws/sdk/integration-tests/s3/tests/request_id.rs b/aws/sdk/integration-tests/s3/tests/request_id.rs new file mode 100644 index 00000000000..957dd8cb284 --- /dev/null +++ b/aws/sdk/integration-tests/s3/tests/request_id.rs @@ -0,0 +1,148 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use aws_sdk_s3::error::GetObjectError; +use aws_sdk_s3::operation::{GetObject, ListBuckets}; +use aws_sdk_s3::types::{RequestId, RequestIdExt}; +use aws_smithy_http::body::SdkBody; +use aws_smithy_http::operation; +use aws_smithy_http::response::ParseHttpResponse; +use bytes::Bytes; + +#[test] +fn get_request_id_from_modeled_error() { + let resp = http::Response::builder() + .header("x-amz-request-id", "correct-request-id") + .header("x-amz-id-2", "correct-extended-request-id") + .status(404) + .body( + r#" + + NoSuchKey + The resource you requested does not exist + /mybucket/myfoto.jpg + incorrect-request-id + "#, + ) + .unwrap(); + let err = GetObject::new() + .parse_loaded(&resp.map(Bytes::from)) + .expect_err("status was 404, this is an error"); + assert!(matches!(err, GetObjectError::NoSuchKey(_))); + assert_eq!(Some("correct-request-id"), err.request_id()); + assert_eq!(Some("correct-request-id"), err.meta().request_id()); + assert_eq!( + Some("correct-extended-request-id"), + err.extended_request_id() + ); + assert_eq!( + Some("correct-extended-request-id"), + err.meta().extended_request_id() + ); +} + +#[test] +fn get_request_id_from_unmodeled_error() { + let resp = http::Response::builder() + .header("x-amz-request-id", "correct-request-id") + .header("x-amz-id-2", "correct-extended-request-id") + .status(500) + .body( + r#" + + SomeUnmodeledError + Something bad happened + /mybucket/myfoto.jpg + incorrect-request-id + "#, + ) + .unwrap(); + let err = GetObject::new() + .parse_loaded(&resp.map(Bytes::from)) + .expect_err("status 500"); + assert!(matches!(err, GetObjectError::Unhandled(_))); + assert_eq!(Some("correct-request-id"), err.request_id()); + assert_eq!(Some("correct-request-id"), err.meta().request_id()); + assert_eq!( + Some("correct-extended-request-id"), + err.extended_request_id() + ); + assert_eq!( + Some("correct-extended-request-id"), + err.meta().extended_request_id() + ); +} + +#[test] +fn get_request_id_from_successful_nonstreaming_response() { + let resp = http::Response::builder() + .header("x-amz-request-id", "correct-request-id") + .header("x-amz-id-2", "correct-extended-request-id") + .status(200) + .body( + r#" + + some-idsome-display-name + + "#, + ) + .unwrap(); + let output = ListBuckets::new() + .parse_loaded(&resp.map(Bytes::from)) + .expect("valid successful response"); + assert_eq!(Some("correct-request-id"), output.request_id()); + assert_eq!( + Some("correct-extended-request-id"), + output.extended_request_id() + ); +} + +#[test] +fn get_request_id_from_successful_streaming_response() { + let resp = http::Response::builder() + .header("x-amz-request-id", "correct-request-id") + .header("x-amz-id-2", "correct-extended-request-id") + .status(200) + .body(SdkBody::from("some streaming file data")) + .unwrap(); + let mut resp = operation::Response::new(resp); + let output = GetObject::new() + .parse_unloaded(&mut resp) + .expect("valid successful response"); + assert_eq!(Some("correct-request-id"), output.request_id()); + assert_eq!( + Some("correct-extended-request-id"), + output.extended_request_id() + ); +} + +// Verify that the conversion from operation error to the top-level service error maintains the request ID +#[test] +fn conversion_to_service_error_maintains_request_id() { + let resp = http::Response::builder() + .header("x-amz-request-id", "correct-request-id") + .header("x-amz-id-2", "correct-extended-request-id") + .status(404) + .body( + r#" + + NoSuchKey + The resource you requested does not exist + /mybucket/myfoto.jpg + incorrect-request-id + "#, + ) + .unwrap(); + let err = GetObject::new() + .parse_loaded(&resp.map(Bytes::from)) + .expect_err("status was 404, this is an error"); + + let service_error: aws_sdk_s3::Error = err.into(); + assert_eq!(Some("correct-request-id"), service_error.request_id()); + assert_eq!( + Some("correct-extended-request-id"), + service_error.extended_request_id() + ); +} diff --git a/aws/sdk/integration-tests/sts/tests/retry_idp_comms_err.rs b/aws/sdk/integration-tests/sts/tests/retry_idp_comms_err.rs index 6fe9895cd3b..3b546bbb0b9 100644 --- a/aws/sdk/integration-tests/sts/tests/retry_idp_comms_err.rs +++ b/aws/sdk/integration-tests/sts/tests/retry_idp_comms_err.rs @@ -4,24 +4,21 @@ */ use aws_sdk_sts as sts; -use aws_smithy_types::error::Error as ErrorMeta; +use aws_smithy_types::error::ErrorMetadata; use aws_smithy_types::retry::{ErrorKind, ProvideErrorKind}; -use sts::error::{ - AssumeRoleWithWebIdentityError, AssumeRoleWithWebIdentityErrorKind, - IdpCommunicationErrorException, -}; +use sts::error::{AssumeRoleWithWebIdentityError, IdpCommunicationErrorException}; #[tokio::test] async fn idp_comms_err_retryable() { - let error = AssumeRoleWithWebIdentityError::new( - AssumeRoleWithWebIdentityErrorKind::IdpCommunicationErrorException( - IdpCommunicationErrorException::builder() - .message("test") - .build(), - ), - ErrorMeta::builder() - .code("IDPCommunicationError") + let error = AssumeRoleWithWebIdentityError::IdpCommunicationErrorException( + IdpCommunicationErrorException::builder() .message("test") + .meta( + ErrorMetadata::builder() + .code("IDPCommunicationError") + .message("test") + .build(), + ) .build(), ); assert_eq!( diff --git a/aws/sdk/integration-tests/transcribestreaming/tests/test.rs b/aws/sdk/integration-tests/transcribestreaming/tests/test.rs index f48515038a9..fe6b028820f 100644 --- a/aws/sdk/integration-tests/transcribestreaming/tests/test.rs +++ b/aws/sdk/integration-tests/transcribestreaming/tests/test.rs @@ -4,9 +4,7 @@ */ use async_stream::stream; -use aws_sdk_transcribestreaming::error::{ - AudioStreamError, TranscriptResultStreamError, TranscriptResultStreamErrorKind, -}; +use aws_sdk_transcribestreaming::error::{AudioStreamError, TranscriptResultStreamError}; use aws_sdk_transcribestreaming::model::{ AudioEvent, AudioStream, LanguageCode, MediaEncoding, TranscriptResultStream, }; @@ -76,10 +74,7 @@ async fn test_error() { match output.transcript_result_stream.recv().await { Err(SdkError::ServiceError(context)) => match context.err() { - TranscriptResultStreamError { - kind: TranscriptResultStreamErrorKind::BadRequestException(err), - .. - } => { + TranscriptResultStreamError::BadRequestException(err) => { assert_eq!( Some("A complete signal was sent without the preceding empty frame."), err.message() diff --git a/aws/sdk/integration-tests/using-native-tls-instead-of-rustls/Cargo.toml b/aws/sdk/integration-tests/using-native-tls-instead-of-rustls/Cargo.toml new file mode 100644 index 00000000000..3642d7ba248 --- /dev/null +++ b/aws/sdk/integration-tests/using-native-tls-instead-of-rustls/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "using-native-tls-instead-of-rustls" +version = "0.1.0" +authors = ["AWS Rust SDK Team "] +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dev-dependencies] +# aws-config pulls in rustls and several other things by default. We have to disable defaults in order to use native-tls +# and then manually bring the other defaults back +aws-config = { path = "../../build/aws-sdk/sdk/aws-config", default-features = false, features = [ + "native-tls", + "rt-tokio", +] } +# aws-sdk-s3 brings in rustls by default so we disable that in order to use native-tls only +aws-sdk-s3 = { path = "../../build/aws-sdk/sdk/s3", default-features = false, features = [ + "native-tls", +] } +tokio = { version = "1.20.1", features = ["rt", "macros"] } diff --git a/aws/sdk/integration-tests/using-native-tls-instead-of-rustls/tests/no-rustls-in-dependency.rs b/aws/sdk/integration-tests/using-native-tls-instead-of-rustls/tests/no-rustls-in-dependency.rs new file mode 100644 index 00000000000..dddeebc4795 --- /dev/null +++ b/aws/sdk/integration-tests/using-native-tls-instead-of-rustls/tests/no-rustls-in-dependency.rs @@ -0,0 +1,52 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +/// The SDK defaults to using RusTLS by default but you can also use [`native_tls`](https://github.com/sfackler/rust-native-tls) +/// which will choose a TLS implementation appropriate for your platform. This test looks much like +/// any other. Activating and deactivating `features` in your app's `Cargo.toml` is all that's needed. + +async fn list_buckets() -> Result<(), aws_sdk_s3::Error> { + let sdk_config = aws_config::load_from_env().await; + let client = aws_sdk_s3::Client::new(&sdk_config); + + let _resp = client.list_buckets().send().await?; + + Ok(()) +} + +/// You can run this test to ensure that it is only using `native-tls` and +/// that nothing is pulling in `rustls` as a dependency +#[test] +#[should_panic = "error: package ID specification `rustls` did not match any packages"] +fn test_rustls_is_not_in_dependency_tree() { + let cargo_location = std::env::var("CARGO").unwrap(); + let cargo_command = std::process::Command::new(&cargo_location) + .arg("tree") + .arg("--invert") + .arg("rustls") + .output() + .expect("failed to run 'cargo tree'"); + + let stderr = String::from_utf8_lossy(&cargo_command.stderr); + + // We expect the call to `cargo tree` to error out. If it did, we panic with the resulting + // message here. In the case that no error message is set, that's bad. + if !stderr.is_empty() { + panic!("{}", stderr); + } + + // Uh oh. We expected an error message but got none, likely because `cargo tree` found + // `rustls` in our dependencies. We'll print out the message we got to see what went wrong. + let stdout = String::from_utf8_lossy(&cargo_command.stdout); + + println!("{}", stdout) +} + +// NOTE: not currently run in CI, separate PR will set up a with-creds CI runner +#[tokio::test] +#[ignore] +async fn needs_creds_native_tls_works() { + list_buckets().await.expect("should succeed") +} diff --git a/build.gradle.kts b/build.gradle.kts index 2c06af67257..a2d6385db4e 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -14,9 +14,7 @@ buildscript { } } -plugins { - kotlin("jvm") version "1.3.72" apply false -} +plugins { } allprojects { repositories { @@ -61,7 +59,7 @@ tasks.register("ktlint") { group = "Verification" classpath = configurations.getByName("ktlint") mainClass.set("com.pinterest.ktlint.Main") - args = listOf("--verbose", "--relative", "--") + lintPaths + args = listOf("--log-level=info", "--relative", "--") + lintPaths // https://github.com/pinterest/ktlint/issues/1195#issuecomment-1009027802 jvmArgs("--add-opens", "java.base/java.lang=ALL-UNNAMED") } @@ -71,7 +69,7 @@ tasks.register("ktlintFormat") { group = "formatting" classpath = configurations.getByName("ktlint") mainClass.set("com.pinterest.ktlint.Main") - args = listOf("--verbose", "--relative", "--format", "--") + lintPaths + args = listOf("--log-level=info", "--relative", "--format", "--") + lintPaths // https://github.com/pinterest/ktlint/issues/1195#issuecomment-1009027802 jvmArgs("--add-opens", "java.base/java.lang=ALL-UNNAMED") } diff --git a/buildSrc/src/main/kotlin/CodegenTestCommon.kt b/buildSrc/src/main/kotlin/CodegenTestCommon.kt index d975d1a521e..033d418c843 100644 --- a/buildSrc/src/main/kotlin/CodegenTestCommon.kt +++ b/buildSrc/src/main/kotlin/CodegenTestCommon.kt @@ -62,7 +62,7 @@ enum class Cargo(val toString: String) { CHECK("cargoCheck"), TEST("cargoTest"), DOCS("cargoDoc"), - CLIPPY("cargoClippy"); + CLIPPY("cargoClippy"), } private fun generateCargoWorkspace(pluginName: String, tests: List) = @@ -86,7 +86,9 @@ private fun codegenTests(properties: PropertyRetriever, allTests: List { AllCargoCommands } require(ret.isNotEmpty()) { - "None of the provided cargo commands (`$cargoCommandsOverride`) are valid cargo commands (`${AllCargoCommands.map { it.toString }}`)" + "None of the provided cargo commands (`$cargoCommandsOverride`) are valid cargo commands (`${AllCargoCommands.map { + it.toString + }}`)" } return ret } diff --git a/buildSrc/src/main/kotlin/CrateSet.kt b/buildSrc/src/main/kotlin/CrateSet.kt index c915649f617..90e1df95ab3 100644 --- a/buildSrc/src/main/kotlin/CrateSet.kt +++ b/buildSrc/src/main/kotlin/CrateSet.kt @@ -21,6 +21,7 @@ object CrateSet { "aws-smithy-checksums", "aws-smithy-eventstream", "aws-smithy-http", + "aws-smithy-http-auth", "aws-smithy-http-tower", "aws-smithy-json", "aws-smithy-protocol-test", diff --git a/buildSrc/src/main/kotlin/aws/sdk/DocsLandingPage.kt b/buildSrc/src/main/kotlin/aws/sdk/DocsLandingPage.kt index d522a02f4a3..715917c4db2 100644 --- a/buildSrc/src/main/kotlin/aws/sdk/DocsLandingPage.kt +++ b/buildSrc/src/main/kotlin/aws/sdk/DocsLandingPage.kt @@ -44,7 +44,9 @@ fun Project.docsLandingPage(awsServices: AwsServices, outputPath: File) { /** * Generate a link to the examples for a given service */ -private fun examplesLink(service: AwsService, project: Project) = service.examplesUri(project)?.let { "([examples]($it))" } +private fun examplesLink(service: AwsService, project: Project) = service.examplesUri(project)?.let { + "([examples]($it))" +} /** * Generate a link to the docs diff --git a/buildSrc/src/main/kotlin/aws/sdk/ModelMetadata.kt b/buildSrc/src/main/kotlin/aws/sdk/ModelMetadata.kt index a70aa7a0d68..95d65fd106d 100644 --- a/buildSrc/src/main/kotlin/aws/sdk/ModelMetadata.kt +++ b/buildSrc/src/main/kotlin/aws/sdk/ModelMetadata.kt @@ -11,7 +11,7 @@ import java.io.File enum class ChangeType { UNCHANGED, FEATURE, - DOCUMENTATION + DOCUMENTATION, } /** Model metadata toml file */ diff --git a/buildSrc/src/main/kotlin/aws/sdk/ServiceLoader.kt b/buildSrc/src/main/kotlin/aws/sdk/ServiceLoader.kt index fdc6448b0c5..44510d1b118 100644 --- a/buildSrc/src/main/kotlin/aws/sdk/ServiceLoader.kt +++ b/buildSrc/src/main/kotlin/aws/sdk/ServiceLoader.kt @@ -135,9 +135,9 @@ fun Project.discoverServices(awsModelsPath: String?, serviceMembership: Membersh serviceMembership.exclusions.forEach { disabledService -> check(baseModules.contains(disabledService)) { "Service $disabledService was explicitly disabled but no service was generated with that name. Generated:\n ${ - baseModules.joinToString( - "\n ", - ) + baseModules.joinToString( + "\n ", + ) }" } } @@ -206,7 +206,9 @@ fun parseMembership(rawList: String): Membership { } val conflictingMembers = inclusions.intersect(exclusions) - require(conflictingMembers.isEmpty()) { "$conflictingMembers specified both for inclusion and exclusion in $rawList" } + require(conflictingMembers.isEmpty()) { + "$conflictingMembers specified both for inclusion and exclusion in $rawList" + } return Membership(inclusions, exclusions) } diff --git a/codegen-client-test/build.gradle.kts b/codegen-client-test/build.gradle.kts index 434bcef65f0..f3796a6911f 100644 --- a/codegen-client-test/build.gradle.kts +++ b/codegen-client-test/build.gradle.kts @@ -79,6 +79,11 @@ val allCodegenTests = "../codegen-core/common-test-models".let { commonModels -> """.trimIndent(), imports = listOf("$commonModels/naming-obstacle-course-ops.smithy"), ), + CodegenTest( + "casing#ACRONYMInside_Service", + "naming_test_casing", + imports = listOf("$commonModels/naming-obstacle-course-casing.smithy"), + ), CodegenTest( "naming_obs_structs#NamingObstacleCourseStructs", "naming_test_structs", diff --git a/codegen-client/build.gradle.kts b/codegen-client/build.gradle.kts index ba6ac6ac1b4..62d543beb4d 100644 --- a/codegen-client/build.gradle.kts +++ b/codegen-client/build.gradle.kts @@ -28,6 +28,10 @@ dependencies { implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-waiters:$smithyVersion") implementation("software.amazon.smithy:smithy-rules-engine:$smithyVersion") + + // `smithy.framework#ValidationException` is defined here, which is used in event stream +// marshalling/unmarshalling tests. + testImplementation("software.amazon.smithy:smithy-validation-model:$smithyVersion") } tasks.compileKotlin { diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitor.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitor.kt index 44099cceb1b..881cecdad86 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitor.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitor.kt @@ -6,6 +6,7 @@ package software.amazon.smithy.rust.codegen.client.smithy import software.amazon.smithy.build.PluginContext +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.NullableIndex import software.amazon.smithy.model.shapes.OperationShape @@ -16,38 +17,42 @@ import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.EnumTrait +import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.model.transform.ModelTransformer import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator +import software.amazon.smithy.rust.codegen.client.smithy.generators.ClientEnumGenerator import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceGenerator +import software.amazon.smithy.rust.codegen.client.smithy.generators.error.ErrorGenerator +import software.amazon.smithy.rust.codegen.client.smithy.generators.error.OperationErrorGenerator import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ClientProtocolGenerator import software.amazon.smithy.rust.codegen.client.smithy.protocols.ClientProtocolLoader import software.amazon.smithy.rust.codegen.client.smithy.transformers.AddErrorMessage import software.amazon.smithy.rust.codegen.client.smithy.transformers.RemoveEventStreamOperations +import software.amazon.smithy.rust.codegen.core.rustlang.EscapeFor import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.implBlock import software.amazon.smithy.rust.codegen.core.smithy.DirectedWalker import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider -import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig +import software.amazon.smithy.rust.codegen.core.smithy.contextName import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.error.OperationErrorGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.error.eventStreamErrorSymbol -import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock +import software.amazon.smithy.rust.codegen.core.smithy.module import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolGeneratorFactory -import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.core.smithy.transformers.EventStreamNormalizer import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer -import software.amazon.smithy.rust.codegen.core.smithy.transformers.eventStreamErrors -import software.amazon.smithy.rust.codegen.core.smithy.transformers.operationErrors import software.amazon.smithy.rust.codegen.core.util.CommandFailed import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.isEventStream import software.amazon.smithy.rust.codegen.core.util.letIf import software.amazon.smithy.rust.codegen.core.util.runCommand +import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import java.util.logging.Logger /** @@ -57,7 +62,6 @@ class ClientCodegenVisitor( context: PluginContext, private val codegenDecorator: ClientCodegenDecorator, ) : ShapeVisitor.Default() { - private val logger = Logger.getLogger(javaClass.name) private val settings = ClientRustSettings.from(context.model, context.settings) @@ -70,12 +74,21 @@ class ClientCodegenVisitor( private val protocolGenerator: ClientProtocolGenerator init { - val symbolVisitorConfig = - SymbolVisitorConfig( - runtimeConfig = settings.runtimeConfig, - renameExceptions = settings.codegenConfig.renameExceptions, - nullabilityCheckMode = NullableIndex.CheckMode.CLIENT_ZERO_VALUE_V1, - ) + val rustSymbolProviderConfig = RustSymbolProviderConfig( + runtimeConfig = settings.runtimeConfig, + renameExceptions = settings.codegenConfig.renameExceptions, + nullabilityCheckMode = NullableIndex.CheckMode.CLIENT_ZERO_VALUE_V1, + moduleProvider = when (settings.codegenConfig.enableNewCrateOrganizationScheme) { + true -> ClientModuleProvider + else -> OldModuleSchemeClientModuleProvider + }, + nameBuilderFor = { symbol -> + when (settings.codegenConfig.enableNewCrateOrganizationScheme) { + true -> "${symbol.name}Builder" + else -> "Builder" + } + }, + ) val baseModel = baselineTransform(context.model) val untransformedService = settings.getService(baseModel) val (protocol, generator) = ClientProtocolLoader( @@ -85,7 +98,7 @@ class ClientCodegenVisitor( model = codegenDecorator.transformModel(untransformedService, baseModel) // the model transformer _might_ change the service shape val service = settings.getService(model) - symbolProvider = RustClientCodegenPlugin.baseSymbolProvider(model, service, symbolVisitorConfig) + symbolProvider = RustClientCodegenPlugin.baseSymbolProvider(settings, model, service, rustSymbolProviderConfig) codegenContext = ClientCodegenContext(model, symbolProvider, service, protocol, settings, codegenDecorator) @@ -108,14 +121,14 @@ class ClientCodegenVisitor( // Add errors attached at the service level to the models .let { ModelTransformer.create().copyServiceErrorsToOperations(it, settings.getService(it)) } // Add `Box` to recursive shapes as necessary - .let(RecursiveShapeBoxer::transform) + .let(RecursiveShapeBoxer()::transform) // Normalize the `message` field on errors when enabled in settings (default: true) .letIf(settings.codegenConfig.addMessageToErrors, AddErrorMessage::transform) // NormalizeOperations by ensuring every operation has an input & output shape .let(OperationNormalizer::transform) // Drop unsupported event stream operations from the model .let { RemoveEventStreamOperations.transform(it, settings) } - // - Normalize event stream operations + // Normalize event stream operations .let(EventStreamNormalizer::transform) /** @@ -177,6 +190,29 @@ class ClientCodegenVisitor( override fun getDefault(shape: Shape?) { } + // TODO(CrateReorganization): Remove this function when cleaning up `enableNewCrateOrganizationScheme` + private fun RustCrate.maybeInPrivateModuleWithReexport( + privateModule: RustModule.LeafModule, + symbol: Symbol, + writer: Writable, + ) { + if (codegenContext.settings.codegenConfig.enableNewCrateOrganizationScheme) { + inPrivateModuleWithReexport(privateModule, symbol, writer) + } else { + withModule(symbol.module(), writer) + } + } + + private fun privateModule(shape: Shape): RustModule.LeafModule = + RustModule.private(privateModuleName(shape), parent = symbolProvider.moduleForShape(shape)) + + private fun privateModuleName(shape: Shape): String = + shape.contextName(codegenContext.serviceShape).let(this::privateModuleName) + + private fun privateModuleName(name: String): String = + // Add the underscore to avoid colliding with public module names + "_" + RustReservedWords.escapeIfNeeded(name.toSnakeCase(), EscapeFor.ModuleName) + /** * Structure Shape Visitor * @@ -187,17 +223,50 @@ class ClientCodegenVisitor( * This function _does not_ generate any serializers */ override fun structureShape(shape: StructureShape) { - logger.fine("generating a structure...") - rustCrate.useShapeWriter(shape) { - StructureGenerator(model, symbolProvider, this, shape).render() - if (!shape.hasTrait()) { - val builderGenerator = BuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape) - builderGenerator.render(this) - this.implBlock(shape, symbolProvider) { - builderGenerator.renderConvenienceMethod(this) + val (renderStruct, renderBuilder) = when (val errorTrait = shape.getTrait()) { + null -> { + val struct: Writable = { + StructureGenerator( + model, + symbolProvider, + this, + shape, + codegenDecorator.structureCustomizations(codegenContext, emptyList()), + ).render() + + implBlock(symbolProvider.toSymbol(shape)) { + BuilderGenerator.renderConvenienceMethod(this, symbolProvider, shape) + } } + val builder: Writable = { + BuilderGenerator( + codegenContext.model, + codegenContext.symbolProvider, + shape, + codegenDecorator.builderCustomizations(codegenContext, emptyList()), + ).render(this) + } + struct to builder + } + else -> { + val errorGenerator = ErrorGenerator( + model, + symbolProvider, + shape, + errorTrait, + codegenDecorator.errorImplCustomizations(codegenContext, emptyList()), + ) + errorGenerator::renderStruct to errorGenerator::renderBuilder } } + + val privateModule = privateModule(shape) + rustCrate.maybeInPrivateModuleWithReexport(privateModule, symbolProvider.toSymbol(shape)) { + renderStruct(this) + } + rustCrate.maybeInPrivateModuleWithReexport(privateModule, symbolProvider.symbolForBuilder(shape)) { + renderBuilder(this) + } } /** @@ -206,9 +275,10 @@ class ClientCodegenVisitor( * Although raw strings require no code generation, enums are actually `EnumTrait` applied to string shapes. */ override fun stringShape(shape: StringShape) { - shape.getTrait()?.also { enum -> - rustCrate.useShapeWriter(shape) { - EnumGenerator(model, symbolProvider, this, shape, enum).render() + if (shape.hasTrait()) { + val privateModule = privateModule(shape) + rustCrate.maybeInPrivateModuleWithReexport(privateModule, symbolProvider.toSymbol(shape)) { + ClientEnumGenerator(codegenContext, shape).render(this) } } } @@ -221,17 +291,17 @@ class ClientCodegenVisitor( * Note: this does not generate serializers */ override fun unionShape(shape: UnionShape) { - rustCrate.useShapeWriter(shape) { + rustCrate.maybeInPrivateModuleWithReexport(privateModule(shape), symbolProvider.toSymbol(shape)) { UnionGenerator(model, symbolProvider, this, shape, renderUnknownVariant = true).render() } if (shape.isEventStream()) { - rustCrate.withModule(RustModule.Error) { - val symbol = symbolProvider.toSymbol(shape) - val errors = shape.eventStreamErrors() - .map { model.expectShape(it.asMemberShape().get().target, StructureShape::class.java) } - val errorSymbol = shape.eventStreamErrorSymbol(symbolProvider) - OperationErrorGenerator(model, symbolProvider, symbol, errors) - .renderErrors(this, errorSymbol, symbol) + rustCrate.withModule(ClientRustModule.Error) { + OperationErrorGenerator( + model, + symbolProvider, + shape, + codegenDecorator.errorCustomizations(codegenContext, emptyList()), + ).render(this) } } } @@ -240,13 +310,12 @@ class ClientCodegenVisitor( * Generate errors for operation shapes */ override fun operationShape(shape: OperationShape) { - rustCrate.withModule(RustModule.Error) { - val operationSymbol = symbolProvider.toSymbol(shape) + rustCrate.withModule(symbolProvider.moduleForOperationError(shape)) { OperationErrorGenerator( model, symbolProvider, - operationSymbol, - shape.operationErrors(model).map { it.asStructureShape().get() }, + shape, + codegenDecorator.errorCustomizations(codegenContext, emptyList()), ).render(this) } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientRustModule.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientRustModule.kt new file mode 100644 index 00000000000..895b2c1ff4a --- /dev/null +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientRustModule.kt @@ -0,0 +1,192 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.ErrorTrait +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.smithy.ModuleProvider +import software.amazon.smithy.rust.codegen.core.smithy.ModuleProviderContext +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.contextName +import software.amazon.smithy.rust.codegen.core.smithy.module +import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait +import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait +import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE +import software.amazon.smithy.rust.codegen.core.util.getTrait +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.toSnakeCase + +/** + * Modules for code generated client crates. + */ +object ClientRustModule { + /** crate */ + val root = RustModule.LibRs + + /** crate::client */ + val client = Client.self + object Client { + /** crate::client */ + val self = RustModule.public("client", "Client and fluent builders for calling the service.") + + /** crate::client::customize */ + val customize = RustModule.public("customize", parent = self, documentation = "Operation customization and supporting types") + } + + val Config = RustModule.public("config", documentation = "Configuration for the service.") + val Error = RustModule.public("error", documentation = "All error types that operations can return. Documentation on these types is copied from the model.") + val Operation = RustModule.public("operation", documentation = "All operations that this crate can perform.") + val Meta = RustModule.public("meta", documentation = "Information about this crate.") + val Input = RustModule.public("input", documentation = "Input structures for operations. Documentation on these types is copied from the model.") + val Output = RustModule.public("output", documentation = "Output structures for operations. Documentation on these types is copied from the model.") + val Primitives = RustModule.public("primitives", documentation = "Data primitives referenced by other data types.") + + /** crate::types */ + val types = Types.self + object Types { + /** crate::types */ + val self = RustModule.public("types", documentation = "Data primitives referenced by other data types.") + + /** crate::types::error */ + val Error = RustModule.public("error", parent = self, documentation = "All error types that operations can return. Documentation on these types is copied from the model.") + } + + // TODO(CrateReorganization): Remove this module when cleaning up `enableNewCrateOrganizationScheme` + val Model = RustModule.public("model", documentation = "Data structures used by operation inputs/outputs. Documentation on these types is copied from the model.") +} + +object ClientModuleProvider : ModuleProvider { + override fun moduleForShape(context: ModuleProviderContext, shape: Shape): RustModule.LeafModule = when (shape) { + is OperationShape -> perOperationModule(context, shape) + is StructureShape -> when { + shape.hasTrait() -> ClientRustModule.Types.Error + shape.hasTrait() -> perOperationModule(context, shape) + shape.hasTrait() -> perOperationModule(context, shape) + else -> ClientRustModule.types + } + + else -> ClientRustModule.types + } + + override fun moduleForOperationError( + context: ModuleProviderContext, + operation: OperationShape, + ): RustModule.LeafModule = perOperationModule(context, operation) + + override fun moduleForEventStreamError( + context: ModuleProviderContext, + eventStream: UnionShape, + ): RustModule.LeafModule = ClientRustModule.Error + + override fun moduleForBuilder(context: ModuleProviderContext, shape: Shape, symbol: Symbol): RustModule.LeafModule = + RustModule.public("builders", parent = symbol.module(), documentation = "Builders") + + private fun Shape.findOperation(model: Model): OperationShape { + val inputTrait = getTrait() + val outputTrait = getTrait() + return when { + this is OperationShape -> this + inputTrait != null -> model.expectShape(inputTrait.operation, OperationShape::class.java) + outputTrait != null -> model.expectShape(outputTrait.operation, OperationShape::class.java) + else -> UNREACHABLE("this is only called with compatible shapes") + } + } + + private fun perOperationModule(context: ModuleProviderContext, shape: Shape): RustModule.LeafModule { + val operationShape = shape.findOperation(context.model) + val contextName = operationShape.contextName(context.serviceShape) + val operationModuleName = + RustReservedWords.escapeIfNeeded(contextName.toSnakeCase()) + return RustModule.public( + operationModuleName, + parent = ClientRustModule.Operation, + documentation = "Types for the `$contextName` operation.", + ) + } +} + +// TODO(CrateReorganization): Remove this provider +object OldModuleSchemeClientModuleProvider : ModuleProvider { + override fun moduleForShape(context: ModuleProviderContext, shape: Shape): RustModule.LeafModule = when (shape) { + is OperationShape -> ClientRustModule.Operation + is StructureShape -> when { + shape.hasTrait() -> ClientRustModule.Error + shape.hasTrait() -> ClientRustModule.Input + shape.hasTrait() -> ClientRustModule.Output + else -> ClientRustModule.Model + } + + else -> ClientRustModule.Model + } + + override fun moduleForOperationError( + context: ModuleProviderContext, + operation: OperationShape, + ): RustModule.LeafModule = ClientRustModule.Error + + override fun moduleForEventStreamError( + context: ModuleProviderContext, + eventStream: UnionShape, + ): RustModule.LeafModule = ClientRustModule.Error + + override fun moduleForBuilder(context: ModuleProviderContext, shape: Shape, symbol: Symbol): RustModule.LeafModule { + val builderNamespace = RustReservedWords.escapeIfNeeded(symbol.name.toSnakeCase()) + return RustModule.new( + builderNamespace, + visibility = Visibility.PUBLIC, + parent = symbol.module(), + inline = true, + documentation = "See [`${symbol.name}`](${symbol.module().fullyQualifiedPath()}::${symbol.name}).", + ) + } +} + +// TODO(CrateReorganization): Remove when cleaning up `enableNewCrateOrganizationScheme` +fun ClientCodegenContext.featureGatedConfigModule() = when (settings.codegenConfig.enableNewCrateOrganizationScheme) { + true -> ClientRustModule.Config + else -> ClientRustModule.root +} + +// TODO(CrateReorganization): Remove when cleaning up `enableNewCrateOrganizationScheme` +fun ClientCodegenContext.featureGatedCustomizeModule() = when (settings.codegenConfig.enableNewCrateOrganizationScheme) { + true -> ClientRustModule.Client.customize + else -> RustModule.public( + "customize", + "Operation customization and supporting types", + parent = ClientRustModule.Operation, + ) +} + +// TODO(CrateReorganization): Remove when cleaning up `enableNewCrateOrganizationScheme` +fun ClientCodegenContext.featureGatedMetaModule() = when (settings.codegenConfig.enableNewCrateOrganizationScheme) { + true -> ClientRustModule.Meta + else -> ClientRustModule.root +} + +// TODO(CrateReorganization): Remove when cleaning up `enableNewCrateOrganizationScheme` +fun ClientCodegenContext.featureGatedPaginatorModule(symbolProvider: RustSymbolProvider, operation: OperationShape) = + when (settings.codegenConfig.enableNewCrateOrganizationScheme) { + true -> RustModule.public( + "paginator", + parent = symbolProvider.moduleForShape(operation), + documentation = "Paginator for this operation", + ) + else -> RustModule.public("paginator", "Paginators for the service") + } + +// TODO(CrateReorganization): Remove when cleaning up `enableNewCrateOrganizationScheme` +fun ClientCodegenContext.featureGatedPrimitivesModule() = when (settings.codegenConfig.enableNewCrateOrganizationScheme) { + true -> ClientRustModule.Primitives + else -> ClientRustModule.types +} diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientRustSettings.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientRustSettings.kt index 5fb6eb1d1bf..a33fa4f9e7d 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientRustSettings.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientRustSettings.kt @@ -86,6 +86,8 @@ data class ClientCodegenConfig( val addMessageToErrors: Boolean = defaultAddMessageToErrors, // TODO(EventStream): [CLEANUP] Remove this property when turning on Event Stream for all services val eventStreamAllowList: Set = defaultEventStreamAllowList, + // TODO(CrateReorganization): Remove this once we commit to the breaking change + val enableNewCrateOrganizationScheme: Boolean = defaultEnableNewCrateOrganizationScheme, ) : CoreCodegenConfig( formatTimeoutSeconds, debugMode, ) { @@ -94,6 +96,7 @@ data class ClientCodegenConfig( private const val defaultIncludeFluentClient = true private const val defaultAddMessageToErrors = true private val defaultEventStreamAllowList: Set = emptySet() + private const val defaultEnableNewCrateOrganizationScheme = false fun fromCodegenConfigAndNode(coreCodegenConfig: CoreCodegenConfig, node: Optional) = if (node.isPresent) { @@ -106,12 +109,14 @@ data class ClientCodegenConfig( renameExceptions = node.get().getBooleanMemberOrDefault("renameErrors", defaultRenameExceptions), includeFluentClient = node.get().getBooleanMemberOrDefault("includeFluentClient", defaultIncludeFluentClient), addMessageToErrors = node.get().getBooleanMemberOrDefault("addMessageToErrors", defaultAddMessageToErrors), + enableNewCrateOrganizationScheme = node.get().getBooleanMemberOrDefault("enableNewCrateOrganizationScheme", false), ) } else { ClientCodegenConfig( formatTimeoutSeconds = coreCodegenConfig.formatTimeoutSeconds, debugMode = coreCodegenConfig.debugMode, eventStreamAllowList = defaultEventStreamAllowList, + enableNewCrateOrganizationScheme = defaultEnableNewCrateOrganizationScheme, ) } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt index 0189a4cb65d..38af758ad04 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt @@ -9,6 +9,7 @@ import software.amazon.smithy.build.PluginContext import software.amazon.smithy.codegen.core.ReservedWordSymbolProvider import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.rust.codegen.client.smithy.customizations.ApiKeyAuthDecorator import software.amazon.smithy.rust.codegen.client.smithy.customizations.ClientCustomizations import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator import software.amazon.smithy.rust.codegen.client.smithy.customize.CombinedClientCodegenDecorator @@ -16,16 +17,16 @@ import software.amazon.smithy.rust.codegen.client.smithy.customize.NoOpEventStre import software.amazon.smithy.rust.codegen.client.smithy.customize.RequiredCustomizations import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointsDecorator import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientDecorator -import software.amazon.smithy.rust.codegen.client.testutil.DecoratableBuildPlugin +import software.amazon.smithy.rust.codegen.client.testutil.ClientDecoratableBuildPlugin import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.NonExhaustive import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWordSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.BaseSymbolMetadataProvider import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.EventStreamSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig import software.amazon.smithy.rust.codegen.core.smithy.StreamingShapeMetadataProvider import software.amazon.smithy.rust.codegen.core.smithy.StreamingShapeSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitor -import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig import java.util.logging.Level import java.util.logging.Logger @@ -36,7 +37,7 @@ import java.util.logging.Logger * `resources/META-INF.services/software.amazon.smithy.build.SmithyBuildPlugin` refers to this class by name which * enables the smithy-build plugin to invoke `execute` with all Smithy plugin context + models. */ -class RustClientCodegenPlugin : DecoratableBuildPlugin() { +class RustClientCodegenPlugin : ClientDecoratableBuildPlugin() { override fun getName(): String = "rust-client-codegen" override fun executeWithDecorator( @@ -58,6 +59,7 @@ class RustClientCodegenPlugin : DecoratableBuildPlugin() { FluentClientDecorator(), EndpointsDecorator(), NoOpEventStreamSigningDecorator(), + ApiKeyAuthDecorator(), *decorator, ) @@ -66,24 +68,29 @@ class RustClientCodegenPlugin : DecoratableBuildPlugin() { } companion object { - /** SymbolProvider + /** * When generating code, smithy types need to be converted into Rust types—that is the core role of the symbol provider * - * The Symbol provider is composed of a base `SymbolVisitor` which handles the core functionality, then is layered + * The Symbol provider is composed of a base [SymbolVisitor] which handles the core functionality, then is layered * with other symbol providers, documented inline, to handle the full scope of Smithy types. */ - fun baseSymbolProvider(model: Model, serviceShape: ServiceShape, symbolVisitorConfig: SymbolVisitorConfig) = - SymbolVisitor(model, serviceShape = serviceShape, config = symbolVisitorConfig) + fun baseSymbolProvider( + settings: ClientRustSettings, + model: Model, + serviceShape: ServiceShape, + rustSymbolProviderConfig: RustSymbolProviderConfig, + ) = + SymbolVisitor(settings, model, serviceShape = serviceShape, config = rustSymbolProviderConfig) // Generate different types for EventStream shapes (e.g. transcribe streaming) - .let { EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model, CodegenTarget.CLIENT) } + .let { EventStreamSymbolProvider(rustSymbolProviderConfig.runtimeConfig, it, CodegenTarget.CLIENT) } // Generate `ByteStream` instead of `Blob` for streaming binary shapes (e.g. S3 GetObject) - .let { StreamingShapeSymbolProvider(it, model) } + .let { StreamingShapeSymbolProvider(it) } // Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes - .let { BaseSymbolMetadataProvider(it, model, additionalAttributes = listOf(NonExhaustive)) } + .let { BaseSymbolMetadataProvider(it, additionalAttributes = listOf(NonExhaustive)) } // Streaming shapes need different derives (e.g. they cannot derive `PartialEq`) - .let { StreamingShapeMetadataProvider(it, model) } + .let { StreamingShapeMetadataProvider(it) } // Rename shapes that clash with Rust reserved words & and other SDK specific features e.g. `send()` cannot // be the name of an operation input - .let { RustReservedWordSymbolProvider(it, model) } + .let { RustReservedWordSymbolProvider(it) } } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ApiKeyAuthDecorator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ApiKeyAuthDecorator.kt new file mode 100644 index 00000000000..0d9f5bde460 --- /dev/null +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ApiKeyAuthDecorator.kt @@ -0,0 +1,206 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy.customizations + +import software.amazon.smithy.model.knowledge.ServiceIndex +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.traits.HttpApiKeyAuthTrait +import software.amazon.smithy.model.traits.OptionalAuthTrait +import software.amazon.smithy.model.traits.Trait +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule +import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator +import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization +import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RustCrate +import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization +import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection +import software.amazon.smithy.rust.codegen.core.util.expectTrait +import software.amazon.smithy.rust.codegen.core.util.letIf + +/** + * Inserts a ApiKeyAuth configuration into the operation + */ +class ApiKeyAuthDecorator : ClientCodegenDecorator { + override val name: String = "ApiKeyAuth" + override val order: Byte = 10 + + private fun applies(codegenContext: ClientCodegenContext) = + isSupportedApiKeyAuth(codegenContext) + + override fun configCustomizations( + codegenContext: ClientCodegenContext, + baseCustomizations: List, + ): List { + return baseCustomizations.letIf(applies(codegenContext)) { customizations -> + customizations + ApiKeyConfigCustomization(codegenContext.runtimeConfig) + } + } + + override fun operationCustomizations( + codegenContext: ClientCodegenContext, + operation: OperationShape, + baseCustomizations: List, + ): List { + if (applies(codegenContext) && hasApiKeyAuthScheme(codegenContext, operation)) { + val service = codegenContext.serviceShape + val authDefinition: HttpApiKeyAuthTrait = service.expectTrait(HttpApiKeyAuthTrait::class.java) + return baseCustomizations + ApiKeyOperationCustomization(codegenContext.runtimeConfig, authDefinition) + } + return baseCustomizations + } + + override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { + if (applies(codegenContext)) { + rustCrate.withModule(ClientRustModule.Config) { + rust("pub use #T;", apiKey(codegenContext.runtimeConfig)) + } + } + } +} + +/** + * Returns if the service supports the httpApiKeyAuth trait. + * + * @param codegenContext Codegen context that includes the model and service shape + * @return if the httpApiKeyAuth trait is used by the service + */ +private fun isSupportedApiKeyAuth(codegenContext: ClientCodegenContext): Boolean { + return ServiceIndex.of(codegenContext.model).getAuthSchemes(codegenContext.serviceShape).containsKey(HttpApiKeyAuthTrait.ID) +} + +/** + * Returns if the service and operation have the httpApiKeyAuthTrait. + * + * @param codegenContext codegen context that includes the model and service shape + * @param operation operation shape + * @return if the service and operation have the httpApiKeyAuthTrait + */ +private fun hasApiKeyAuthScheme(codegenContext: ClientCodegenContext, operation: OperationShape): Boolean { + val auth: Map = ServiceIndex.of(codegenContext.model).getEffectiveAuthSchemes(codegenContext.serviceShape.getId(), operation.getId()) + return auth.containsKey(HttpApiKeyAuthTrait.ID) && !operation.hasTrait(OptionalAuthTrait.ID) +} + +private class ApiKeyOperationCustomization(private val runtimeConfig: RuntimeConfig, private val authDefinition: HttpApiKeyAuthTrait) : OperationCustomization() { + override fun section(section: OperationSection): Writable = when (section) { + is OperationSection.MutateRequest -> writable { + rustBlock("if let Some(api_key_config) = ${section.config}.api_key()") { + rust( + """ + ${section.request}.properties_mut().insert(api_key_config.clone()); + let api_key = api_key_config.api_key(); + """, + ) + val definitionName = authDefinition.getName() + if (authDefinition.getIn() == HttpApiKeyAuthTrait.Location.QUERY) { + rustTemplate( + """ + let auth_definition = #{http_auth_definition}::query( + "$definitionName".to_owned(), + ); + let name = auth_definition.name(); + let mut query = #{query_writer}::new(${section.request}.http().uri()); + query.insert(name, api_key); + *${section.request}.http_mut().uri_mut() = query.build_uri(); + """, + "http_auth_definition" to + RuntimeType.smithyHttpAuth(runtimeConfig).resolve("definition::HttpAuthDefinition"), + "query_writer" to RuntimeType.smithyHttp(runtimeConfig).resolve("query_writer::QueryWriter"), + ) + } else { + val definitionScheme: String = authDefinition.getScheme() + .map { scheme -> + "Some(\"" + scheme + "\".to_owned())" + } + .orElse("None") + rustTemplate( + """ + let auth_definition = #{http_auth_definition}::header( + "$definitionName".to_owned(), + $definitionScheme, + ); + let name = auth_definition.name(); + let value = match auth_definition.scheme() { + Some(value) => format!("{value} {api_key}"), + None => api_key.to_owned(), + }; + ${section.request} + .http_mut() + .headers_mut() + .insert( + #{http_header}::HeaderName::from_bytes(name.as_bytes()).expect("valid header name for api key auth"), + #{http_header}::HeaderValue::from_bytes(value.as_bytes()).expect("valid header value for api key auth") + ); + """, + "http_auth_definition" to + RuntimeType.smithyHttpAuth(runtimeConfig).resolve("definition::HttpAuthDefinition"), + "http_header" to RuntimeType.Http.resolve("header"), + ) + } + } + } + else -> emptySection + } +} + +private class ApiKeyConfigCustomization(runtimeConfig: RuntimeConfig) : ConfigCustomization() { + private val codegenScope = arrayOf( + "ApiKey" to apiKey(runtimeConfig), + ) + + override fun section(section: ServiceConfig): Writable = + when (section) { + is ServiceConfig.BuilderStruct -> writable { + rustTemplate("api_key: Option<#{ApiKey}>,", *codegenScope) + } + is ServiceConfig.BuilderImpl -> writable { + rustTemplate( + """ + /// Sets the API key that will be used by the client. + pub fn api_key(mut self, api_key: #{ApiKey}) -> Self { + self.set_api_key(Some(api_key)); + self + } + + /// Sets the API key that will be used by the client. + pub fn set_api_key(&mut self, api_key: Option<#{ApiKey}>) -> &mut Self { + self.api_key = api_key; + self + } + """, + *codegenScope, + ) + } + is ServiceConfig.BuilderBuild -> writable { + rust("api_key: self.api_key,") + } + is ServiceConfig.ConfigStruct -> writable { + rustTemplate("api_key: Option<#{ApiKey}>,", *codegenScope) + } + is ServiceConfig.ConfigImpl -> writable { + rustTemplate( + """ + /// Returns API key used by the client, if it was provided. + pub fn api_key(&self) -> Option<&#{ApiKey}> { + self.api_key.as_ref() + } + """, + *codegenScope, + ) + } + else -> emptySection + } +} + +private fun apiKey(runtimeConfig: RuntimeConfig) = RuntimeType.smithyHttpAuth(runtimeConfig).resolve("api_key::AuthApiKey") diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ClientDocsGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ClientDocsGenerator.kt index 8b8d419ad0b..835f243df37 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ClientDocsGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ClientDocsGenerator.kt @@ -16,7 +16,9 @@ class ClientDocsGenerator : LibRsCustomization() { return when (section) { is LibRsSection.ModuleDocumentation -> if (section.subsection == LibRsSection.CrateOrganization) { crateLayout() - } else emptySection + } else { + emptySection + } else -> emptySection } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ResiliencyConfigCustomization.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ResiliencyConfigCustomization.kt index 84a6d5f38a1..a52c4a81d63 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ResiliencyConfigCustomization.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/ResiliencyConfigCustomization.kt @@ -5,9 +5,9 @@ package software.amazon.smithy.rust.codegen.client.smithy.customizations +import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext @@ -235,7 +235,7 @@ class ResiliencyConfigCustomization(codegenContext: CodegenContext) : ConfigCust class ResiliencyReExportCustomization(private val runtimeConfig: RuntimeConfig) { fun extras(rustCrate: RustCrate) { - rustCrate.withModule(RustModule.Config) { + rustCrate.withModule(ClientRustModule.Config) { rustTemplate( """ pub use #{sleep}::{AsyncSleep, Sleep}; diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ClientCodegenDecorator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ClientCodegenDecorator.kt index f4034c1b9ba..c893ff078f6 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ClientCodegenDecorator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/ClientCodegenDecorator.kt @@ -11,6 +11,7 @@ import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointCustomization import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization +import software.amazon.smithy.rust.codegen.client.smithy.generators.error.ErrorCustomization import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ClientProtocolGenerator import software.amazon.smithy.rust.codegen.core.smithy.customize.CombinedCoreCodegenDecorator import software.amazon.smithy.rust.codegen.core.smithy.customize.CoreCodegenDecorator @@ -40,6 +41,14 @@ interface ClientCodegenDecorator : CoreCodegenDecorator { baseCustomizations: List, ): List = baseCustomizations + /** + * Hook to customize generated errors. + */ + fun errorCustomizations( + codegenContext: ClientCodegenContext, + baseCustomizations: List, + ): List = baseCustomizations + fun protocols(serviceId: ShapeId, currentProtocols: ClientProtocolMap): ClientProtocolMap = currentProtocols fun endpointCustomizations(codegenContext: ClientCodegenContext): List = listOf() @@ -72,6 +81,13 @@ open class CombinedClientCodegenDecorator(decorators: List, + ): List = combineCustomizations(baseCustomizations) { decorator, customizations -> + decorator.errorCustomizations(codegenContext, customizations) + } + override fun protocols(serviceId: ShapeId, currentProtocols: ClientProtocolMap): ClientProtocolMap = combineCustomizations(currentProtocols) { decorator, protocolMap -> decorator.protocols(serviceId, protocolMap) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/RequiredCustomizations.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/RequiredCustomizations.kt index f65e042d445..c88af4c2bdb 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/RequiredCustomizations.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/RequiredCustomizations.kt @@ -7,18 +7,22 @@ package software.amazon.smithy.rust.codegen.client.smithy.customize import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule import software.amazon.smithy.rust.codegen.client.smithy.customizations.EndpointPrefixGenerator import software.amazon.smithy.rust.codegen.client.smithy.customizations.HttpChecksumRequiredGenerator import software.amazon.smithy.rust.codegen.client.smithy.customizations.HttpVersionListCustomization import software.amazon.smithy.rust.codegen.client.smithy.customizations.IdempotencyTokenGenerator import software.amazon.smithy.rust.codegen.client.smithy.customizations.ResiliencyConfigCustomization import software.amazon.smithy.rust.codegen.client.smithy.customizations.ResiliencyReExportCustomization +import software.amazon.smithy.rust.codegen.client.smithy.featureGatedMetaModule +import software.amazon.smithy.rust.codegen.client.smithy.featureGatedPrimitivesModule import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization import software.amazon.smithy.rust.codegen.core.rustlang.Feature import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.customizations.AllowLintsCustomization import software.amazon.smithy.rust.codegen.core.smithy.customizations.CrateVersionCustomization -import software.amazon.smithy.rust.codegen.core.smithy.customizations.pubUseSmithyTypes +import software.amazon.smithy.rust.codegen.core.smithy.customizations.pubUseSmithyErrorTypes +import software.amazon.smithy.rust.codegen.core.smithy.customizations.pubUseSmithyPrimitives import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization @@ -54,7 +58,7 @@ class RequiredCustomizations : ClientCodegenDecorator { codegenContext: ClientCodegenContext, baseCustomizations: List, ): List = - baseCustomizations + CrateVersionCustomization() + AllowLintsCustomization() + baseCustomizations + AllowLintsCustomization() override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { // Add rt-tokio feature for `ByteStream::from_path` @@ -65,6 +69,22 @@ class RequiredCustomizations : ClientCodegenDecorator { // Re-export resiliency types ResiliencyReExportCustomization(codegenContext.runtimeConfig).extras(rustCrate) - pubUseSmithyTypes(codegenContext.runtimeConfig, codegenContext.model, rustCrate) + rustCrate.withModule(codegenContext.featureGatedPrimitivesModule()) { + pubUseSmithyPrimitives(codegenContext, codegenContext.model)(this) + if (!codegenContext.settings.codegenConfig.enableNewCrateOrganizationScheme) { + pubUseSmithyErrorTypes(codegenContext)(this) + } + } + if (codegenContext.settings.codegenConfig.enableNewCrateOrganizationScheme) { + rustCrate.withModule(ClientRustModule.Error) { + pubUseSmithyErrorTypes(codegenContext)(this) + } + } + + codegenContext.featureGatedMetaModule().also { metaModule -> + rustCrate.withModule(metaModule) { + CrateVersionCustomization.extras(rustCrate, metaModule) + } + } } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/ClientContextParamDecorator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/ClientContextConfigCustomization.kt similarity index 100% rename from codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/ClientContextParamDecorator.kt rename to codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/ClientContextConfigCustomization.kt diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointConfigCustomization.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointConfigCustomization.kt index 4d44c5602cd..1b31dec174c 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointConfigCustomization.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/EndpointConfigCustomization.kt @@ -90,7 +90,9 @@ internal class EndpointConfigCustomization( /// let config = $moduleUseName::Config::builder().endpoint_resolver(prefix_resolver); /// ``` """ - } else "" + } else { + "" + } rustTemplate( """ /// Sets the endpoint resolver to use when making requests. diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointParamsGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointParamsGenerator.kt index 1e5059830d1..b2a7bde2921 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointParamsGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointParamsGenerator.kt @@ -13,7 +13,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.endpoint.rustName import software.amazon.smithy.rust.codegen.client.smithy.endpoint.symbol import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.derive -import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter @@ -59,7 +58,7 @@ val EndpointTests = RustModule.new( documentation = "Generated endpoint tests", parent = EndpointsModule, inline = true, -).copy(rustMetadata = RustMetadata.TestModule) +).cfgTest() // stdlib is isolated because it contains code generated names of stdlib functions–we want to ensure we avoid clashing val EndpointsStdLib = RustModule.private("endpoint_lib", "Endpoints standard library functions") diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointTestGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointTestGenerator.kt index 183e25d33e7..c4d6efc3278 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointTestGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/generators/EndpointTestGenerator.kt @@ -14,7 +14,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointCustom import software.amazon.smithy.rust.codegen.client.smithy.endpoint.Types import software.amazon.smithy.rust.codegen.client.smithy.endpoint.rustName import software.amazon.smithy.rust.codegen.client.smithy.generators.clientInstantiator -import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.docs import software.amazon.smithy.rust.codegen.core.rustlang.escape @@ -48,8 +47,7 @@ internal class EndpointTestGenerator( "Error" to types.resolveEndpointError, "Document" to RuntimeType.document(runtimeConfig), "HashMap" to RuntimeType.HashMap, - "capture_request" to CargoDependency.smithyClient(runtimeConfig) - .withFeature("test-util").toType().resolve("test_connection::capture_request"), + "capture_request" to RuntimeType.captureRequest(runtimeConfig), ) private val instantiator = clientInstantiator(codegenContext) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/ExpressionGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/ExpressionGenerator.kt index f0bad1aac2e..b0d45366ed0 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/ExpressionGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/ExpressionGenerator.kt @@ -65,7 +65,14 @@ class ExpressionGenerator( getAttr.path.toList().forEach { part -> when (part) { is GetAttr.Part.Key -> rust(".${part.key().rustName()}()") - is GetAttr.Part.Index -> rust(".get(${part.index()}).cloned()") // we end up with Option<&&T>, we need to get to Option<&T> + is GetAttr.Part.Index -> { + if (part.index() == 0) { + // In this case, `.first()` is more idiomatic and `.get(0)` triggers lint warnings + rust(".first().cloned()") + } else { + rust(".get(${part.index()}).cloned()") // we end up with Option<&&T>, we need to get to Option<&T> + } + } } } if (ownership == Ownership.Owned && getAttr.type() != Type.bool()) { diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/StdLib.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/StdLib.kt index d2d43f28580..2c16a51bb72 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/StdLib.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/endpoint/rulesgen/StdLib.kt @@ -74,10 +74,9 @@ class AwsPartitionResolver(runtimeConfig: RuntimeConfig, private val partitionsD ) override fun structFieldInit() = writable { + val json = Node.printJson(partitionsDotJson).dq() rustTemplate( - """partition_resolver: #{PartitionResolver}::new_from_json(b${ - Node.printJson(partitionsDotJson).dq() - }).expect("valid JSON")""", + """partition_resolver: #{PartitionResolver}::new_from_json(b$json).expect("valid JSON")""", *codegenScope, ) } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGenerator.kt new file mode 100644 index 00000000000..a217b0f5f09 --- /dev/null +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGenerator.kt @@ -0,0 +1,176 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy.generators + +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.client.smithy.featureGatedPrimitivesModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.docs +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGeneratorContext +import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumMemberModel +import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumType +import software.amazon.smithy.rust.codegen.core.util.dq + +/** Infallible enums have an `Unknown` variant and can't fail to parse */ +data class InfallibleEnumType( + val unknownVariantModule: RustModule, +) : EnumType() { + companion object { + /** Name of the generated unknown enum member name for enums with named members. */ + const val UnknownVariant = "Unknown" + + /** Name of the opaque struct that is inner data for the generated [UnknownVariant]. */ + const val UnknownVariantValue = "UnknownVariantValue" + } + + override fun implFromForStr(context: EnumGeneratorContext): Writable = writable { + rustTemplate( + """ + impl #{From}<&str> for ${context.enumName} { + fn from(s: &str) -> Self { + match s { + #{matchArms} + } + } + } + """, + "From" to RuntimeType.From, + "matchArms" to writable { + context.sortedMembers.forEach { member -> + rust("${member.value.dq()} => ${context.enumName}::${member.derivedName()},") + } + rust( + "other => ${context.enumName}::$UnknownVariant(#T(other.to_owned()))", + unknownVariantValue(context), + ) + }, + ) + } + + override fun implFromStr(context: EnumGeneratorContext): Writable = writable { + rust( + """ + impl std::str::FromStr for ${context.enumName} { + type Err = std::convert::Infallible; + + fn from_str(s: &str) -> std::result::Result { + Ok(${context.enumName}::from(s)) + } + } + """, + ) + } + + override fun additionalDocs(context: EnumGeneratorContext): Writable = writable { + renderForwardCompatibilityNote(context.enumName, context.sortedMembers, UnknownVariant, UnknownVariantValue) + } + + override fun additionalEnumMembers(context: EnumGeneratorContext): Writable = writable { + docs("`$UnknownVariant` contains new variants that have been added since this code was generated.") + rust("$UnknownVariant(#T)", unknownVariantValue(context)) + } + + override fun additionalAsStrMatchArms(context: EnumGeneratorContext): Writable = writable { + rust("${context.enumName}::$UnknownVariant(value) => value.as_str()") + } + + private fun unknownVariantValue(context: EnumGeneratorContext): RuntimeType { + return RuntimeType.forInlineFun(UnknownVariantValue, unknownVariantModule) { + docs( + """ + Opaque struct used as inner data for the `Unknown` variant defined in enums in + the crate + + While this is not intended to be used directly, it is marked as `pub` because it is + part of the enums that are public interface. + """.trimIndent(), + ) + context.enumMeta.render(this) + rust("struct $UnknownVariantValue(pub(crate) String);") + rustBlock("impl $UnknownVariantValue") { + // The generated as_str is not pub as we need to prevent users from calling it on this opaque struct. + rustBlock("pub(crate) fn as_str(&self) -> &str") { + rust("&self.0") + } + } + } + } + + /** + * Generate the rustdoc describing how to write a match expression against a generated enum in a + * forward-compatible way. + */ + private fun RustWriter.renderForwardCompatibilityNote( + enumName: String, sortedMembers: List, + unknownVariant: String, unknownVariantValue: String, + ) { + docs( + """ + When writing a match expression against `$enumName`, it is important to ensure + your code is forward-compatible. That is, if a match arm handles a case for a + feature that is supported by the service but has not been represented as an enum + variant in a current version of SDK, your code should continue to work when you + upgrade SDK to a future version in which the enum does include a variant for that + feature. + """.trimIndent(), + ) + docs("") + docs("Here is an example of how you can make a match expression forward-compatible:") + docs("") + docs("```text") + rust("/// ## let ${enumName.lowercase()} = unimplemented!();") + rust("/// match ${enumName.lowercase()} {") + sortedMembers.mapNotNull { it.name() }.forEach { member -> + rust("/// $enumName::${member.name} => { /* ... */ },") + } + rust("""/// other @ _ if other.as_str() == "NewFeature" => { /* handles a case for `NewFeature` */ },""") + rust("/// _ => { /* ... */ },") + rust("/// }") + docs("```") + docs( + """ + The above code demonstrates that when `${enumName.lowercase()}` represents + `NewFeature`, the execution path will lead to the second last match arm, + even though the enum does not contain a variant `$enumName::NewFeature` + in the current version of SDK. The reason is that the variable `other`, + created by the `@` operator, is bound to + `$enumName::$unknownVariant($unknownVariantValue("NewFeature".to_owned()))` + and calling `as_str` on it yields `"NewFeature"`. + This match expression is forward-compatible when executed with a newer + version of SDK where the variant `$enumName::NewFeature` is defined. + Specifically, when `${enumName.lowercase()}` represents `NewFeature`, + the execution path will hit the second last match arm as before by virtue of + calling `as_str` on `$enumName::NewFeature` also yielding `"NewFeature"`. + """.trimIndent(), + ) + docs("") + docs( + """ + Explicitly matching on the `$unknownVariant` variant should + be avoided for two reasons: + - The inner data `$unknownVariantValue` is opaque, and no further information can be extracted. + - It might inadvertently shadow other intended match arms. + """.trimIndent(), + ) + } +} + +class ClientEnumGenerator(codegenContext: ClientCodegenContext, shape: StringShape) : + EnumGenerator( + codegenContext.model, + codegenContext.symbolProvider, + shape, + InfallibleEnumType(codegenContext.featureGatedPrimitivesModule()), + ) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/PaginatorGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/PaginatorGenerator.kt index 54982643e8b..d87c2d1483e 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/PaginatorGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/PaginatorGenerator.kt @@ -8,11 +8,11 @@ package software.amazon.smithy.rust.codegen.client.smithy.generators import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.PaginatedIndex import software.amazon.smithy.model.shapes.OperationShape -import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.traits.IdempotencyTokenTrait import software.amazon.smithy.model.traits.PaginatedTrait +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.client.smithy.featureGatedPaginatorModule import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientGenerics -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.render @@ -20,11 +20,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter import software.amazon.smithy.rust.codegen.core.rustlang.writable -import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider -import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol -import software.amazon.smithy.rust.codegen.core.smithy.generators.error.errorSymbol import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.util.PANIC import software.amazon.smithy.rust.codegen.core.util.findMemberWithTrait @@ -40,25 +36,21 @@ fun OperationShape.isPaginated(model: Model) = .findMemberWithTrait(model) == null class PaginatorGenerator private constructor( - private val model: Model, - private val symbolProvider: RustSymbolProvider, - service: ServiceShape, + codegenContext: ClientCodegenContext, operation: OperationShape, private val generics: FluentClientGenerics, retryClassifier: RuntimeType, ) { companion object { fun paginatorType( - codegenContext: CodegenContext, + codegenContext: ClientCodegenContext, generics: FluentClientGenerics, operationShape: OperationShape, retryClassifier: RuntimeType, ): RuntimeType? { return if (operationShape.isPaginated(codegenContext.model)) { PaginatorGenerator( - codegenContext.model, - codegenContext.symbolProvider, - codegenContext.serviceShape, + codegenContext, operationShape, generics, retryClassifier, @@ -69,17 +61,19 @@ class PaginatorGenerator private constructor( } } + private val model = codegenContext.model + private val symbolProvider = codegenContext.symbolProvider + private val runtimeConfig = codegenContext.runtimeConfig private val paginatorName = "${operation.id.name.toPascalCase()}Paginator" - private val runtimeConfig = symbolProvider.config().runtimeConfig private val idx = PaginatedIndex.of(model) - private val paginationInfo = - idx.getPaginationInfo(service, operation).orNull() ?: PANIC("failed to load pagination info") - private val module = RustModule.public("paginator", "Paginators for the service") + private val paginationInfo = idx.getPaginationInfo(codegenContext.serviceShape, operation).orNull() + ?: PANIC("failed to load pagination info") + private val module = codegenContext.featureGatedPaginatorModule(symbolProvider, operation) private val inputType = symbolProvider.toSymbol(operation.inputShape(model)) private val outputShape = operation.outputShape(model) private val outputType = symbolProvider.toSymbol(outputShape) - private val errorType = operation.errorSymbol(symbolProvider) + private val errorType = symbolProvider.symbolForOperationError(operation) private fun paginatorType(): RuntimeType = RuntimeType.forInlineFun( paginatorName, @@ -103,7 +97,7 @@ class PaginatorGenerator private constructor( "Input" to inputType, "Output" to outputType, "Error" to errorType, - "Builder" to operation.inputShape(model).builderSymbol(symbolProvider), + "Builder" to symbolProvider.symbolForBuilder(operation.inputShape(model)), // SDK Types "SdkError" to RuntimeType.sdkError(runtimeConfig), diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceGenerator.kt index 6710ac9c3d5..c495c623f5d 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ServiceGenerator.kt @@ -7,14 +7,14 @@ package software.amazon.smithy.rust.codegen.client.smithy.generators import software.amazon.smithy.model.knowledge.TopDownIndex import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfigGenerator +import software.amazon.smithy.rust.codegen.client.smithy.generators.error.ServiceErrorGenerator import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ClientProtocolGenerator import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolTestGenerator import software.amazon.smithy.rust.codegen.core.rustlang.Attribute -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.smithy.RustCrate -import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ServiceErrorGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.core.util.inputShape @@ -56,9 +56,13 @@ class ServiceGenerator( } } - ServiceErrorGenerator(clientCodegenContext, operations).render(rustCrate) + ServiceErrorGenerator( + clientCodegenContext, + operations, + decorator.errorCustomizations(clientCodegenContext, emptyList()), + ).render(rustCrate) - rustCrate.withModule(RustModule.Config) { + rustCrate.withModule(ClientRustModule.Config) { ServiceConfigGenerator.withBaseBehavior( clientCodegenContext, extraCustomizations = decorator.configCustomizations(clientCodegenContext, listOf()), diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/CustomizableOperationGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/CustomizableOperationGenerator.kt index 31a86d07ac0..31f2ee8f58b 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/CustomizableOperationGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/CustomizableOperationGenerator.kt @@ -5,14 +5,13 @@ package software.amazon.smithy.rust.codegen.client.smithy.generators.client +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.client.smithy.featureGatedCustomizeModule import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.GenericTypeArg import software.amazon.smithy.rust.codegen.core.rustlang.RustGenerics -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.rustlang.Visibility import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RustCrate @@ -21,20 +20,16 @@ import software.amazon.smithy.rust.codegen.core.smithy.RustCrate * fluent client builders. */ class CustomizableOperationGenerator( - private val runtimeConfig: RuntimeConfig, + private val codegenContext: ClientCodegenContext, private val generics: FluentClientGenerics, - private val includeFluentClient: Boolean, ) { - - companion object { - val CustomizeModule = RustModule.public("customize", "Operation customization and supporting types", parent = RustModule.operation(Visibility.PUBLIC)) - } - + private val includeFluentClient = codegenContext.settings.codegenConfig.includeFluentClient + private val runtimeConfig = codegenContext.runtimeConfig private val smithyHttp = CargoDependency.smithyHttp(runtimeConfig).toType() private val smithyTypes = CargoDependency.smithyTypes(runtimeConfig).toType() fun render(crate: RustCrate) { - crate.withModule(CustomizeModule) { + crate.withModule(codegenContext.featureGatedCustomizeModule()) { rustTemplate( """ pub use #{Operation}; @@ -67,6 +62,7 @@ class CustomizableOperationGenerator( "handle_generics_bounds" to handleGenerics.bounds(), "operation_generics_decl" to operationGenerics.declaration(), "combined_generics_decl" to combinedGenerics.declaration(), + "customize_module" to codegenContext.featureGatedCustomizeModule(), ) writer.rustTemplate( @@ -81,7 +77,7 @@ class CustomizableOperationGenerator( /// A wrapper type for [`Operation`](aws_smithy_http::operation::Operation)s that allows for /// customization of the operation before it is sent. A `CustomizableOperation` may be sent - /// by calling its [`.send()`][crate::operation::customize::CustomizableOperation::send] method. + /// by calling its [`.send()`][#{customize_module}::CustomizableOperation::send] method. ##[derive(Debug)] pub struct CustomizableOperation#{combined_generics_decl:W} { pub(crate) handle: Arc, diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientDecorator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientDecorator.kt index b6fced279fb..475c9c490ae 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientDecorator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientDecorator.kt @@ -5,6 +5,7 @@ package software.amazon.smithy.rust.codegen.client.smithy.generators.client +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext @@ -65,7 +66,7 @@ sealed class FluentClientSection(name: String) : Section(name) { /** Write custom code into an operation fluent builder's impl block */ data class FluentBuilderImpl( val operationShape: OperationShape, - val operationErrorType: RuntimeType, + val operationErrorType: Symbol, ) : FluentClientSection("FluentBuilderImpl") /** Write custom code into the docs */ diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt index d5b8ebd9418..be9c69e0cb8 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt @@ -13,6 +13,8 @@ import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.DocumentationTrait import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule +import software.amazon.smithy.rust.codegen.client.smithy.featureGatedCustomizeModule import software.amazon.smithy.rust.codegen.client.smithy.generators.PaginatorGenerator import software.amazon.smithy.rust.codegen.client.smithy.generators.isPaginated import software.amazon.smithy.rust.codegen.core.rustlang.Attribute @@ -21,12 +23,10 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.rustlang.Visibility import software.amazon.smithy.rust.codegen.core.rustlang.asArgumentType import software.amazon.smithy.rust.codegen.core.rustlang.asOptional import software.amazon.smithy.rust.codegen.core.rustlang.deprecatedShape import software.amazon.smithy.rust.codegen.core.rustlang.docLink -import software.amazon.smithy.rust.codegen.core.rustlang.docs import software.amazon.smithy.rust.codegen.core.rustlang.documentShape import software.amazon.smithy.rust.codegen.core.rustlang.escape import software.amazon.smithy.rust.codegen.core.rustlang.normalizeHtml @@ -43,8 +43,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata -import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol -import software.amazon.smithy.rust.codegen.core.smithy.generators.error.errorSymbol import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.util.inputShape @@ -54,6 +52,7 @@ import software.amazon.smithy.rust.codegen.core.util.toSnakeCase class FluentClientGenerator( private val codegenContext: ClientCodegenContext, + private val reexportSmithyClientBuilder: Boolean = true, private val generics: FluentClientGenerics = FlexibleClientGenerics( connectorDefault = null, middlewareDefault = null, @@ -66,11 +65,6 @@ class FluentClientGenerator( companion object { fun clientOperationFnName(operationShape: OperationShape, symbolProvider: RustSymbolProvider): String = RustReservedWords.escapeIfNeeded(symbolProvider.toSymbol(operationShape).name.toSnakeCase()) - - val clientModule = RustModule.public( - "client", - "Client and fluent builders for calling the service.", - ) } private val serviceShape = codegenContext.serviceShape @@ -82,18 +76,29 @@ class FluentClientGenerator( private val core = FluentClientCore(model) fun render(crate: RustCrate) { - crate.withModule(clientModule) { + crate.withModule(ClientRustModule.client) { renderFluentClient(this) } - CustomizableOperationGenerator( - runtimeConfig, - generics, - codegenContext.settings.codegenConfig.includeFluentClient, - ).render(crate) + operations.forEach { operation -> + crate.withModule(operation.fluentBuilderModule(codegenContext, symbolProvider)) { + renderFluentBuilder(operation) + } + } + + CustomizableOperationGenerator(codegenContext, generics).render(crate) } private fun renderFluentClient(writer: RustWriter) { + if (!codegenContext.settings.codegenConfig.enableNewCrateOrganizationScheme || reexportSmithyClientBuilder) { + writer.rustTemplate( + """ + ##[doc(inline)] + pub use #{client}::Builder; + """, + "client" to RuntimeType.smithyClient(runtimeConfig), + ) + } writer.rustTemplate( """ ##[derive(Debug)] @@ -114,9 +119,6 @@ class FluentClientGenerator( } } - ##[doc(inline)] - pub use #{client}::Builder; - impl${generics.inst} From<#{client}::Client#{smithy_inst:W}> for Client${generics.inst} { fn from(client: #{client}::Client#{smithy_inst:W}) -> Self { Self::with_config(client, crate::Config::builder().build()) @@ -144,15 +146,15 @@ class FluentClientGenerator( "smithy_inst" to generics.smithyInst, "client" to RuntimeType.smithyClient(runtimeConfig), "client_docs" to writable - { - customizations.forEach { - it.section( - FluentClientSection.FluentClientDocs( - serviceShape, - ), - )(this) - } - }, + { + customizations.forEach { + it.section( + FluentClientSection.FluentClientDocs( + serviceShape, + ), + )(this) + } + }, ) writer.rustBlockTemplate( "impl${generics.inst} Client${generics.inst} #{bounds:W}", @@ -161,19 +163,24 @@ class FluentClientGenerator( ) { operations.forEach { operation -> val name = symbolProvider.toSymbol(operation).name - val fullPath = operation.fullyQualifiedFluentBuilder(symbolProvider) + val fullPath = operation.fullyQualifiedFluentBuilder(codegenContext, symbolProvider) val maybePaginated = if (operation.isPaginated(model)) { "\n/// This operation supports pagination; See [`into_paginator()`]($fullPath::into_paginator)." - } else "" + } else { + "" + } val output = operation.outputShape(model) val operationOk = symbolProvider.toSymbol(output) - val operationErr = operation.errorSymbol(symbolProvider).toSymbol() + val operationErr = symbolProvider.symbolForOperationError(operation) - val inputFieldsBody = - generateOperationShapeDocs(writer, symbolProvider, operation, model).joinToString("\n") { - "/// - $it" - } + val inputFieldsBody = generateOperationShapeDocs( + writer, + codegenContext, + symbolProvider, + operation, + model, + ).joinToString("\n") { "/// - $it" } val inputFieldsHead = if (inputFieldsBody.isNotEmpty()) { "The fluent builder is configurable:" @@ -203,155 +210,147 @@ class FluentClientGenerator( """, ) - writer.rust( + // Write a deprecation notice if this operation is deprecated. + writer.deprecatedShape(operation) + + writer.rustTemplate( """ - pub fn ${ - clientOperationFnName( - operation, - symbolProvider, - ) - }(&self) -> fluent_builders::$name${generics.inst} { - fluent_builders::$name::new(self.handle.clone()) + pub fn #{fnName}(&self) -> #{FluentBuilder}${generics.inst} { + #{FluentBuilder}::new(self.handle.clone()) } """, + "fnName" to writable { rust(clientOperationFnName(operation, symbolProvider)) }, + "FluentBuilder" to operation.fluentBuilderType(codegenContext, symbolProvider), ) } } - writer.withInlineModule(RustModule.new("fluent_builders", visibility = Visibility.PUBLIC, inline = true)) { - docs( + } + + private fun RustWriter.renderFluentBuilder(operation: OperationShape) { + val operationSymbol = symbolProvider.toSymbol(operation) + val input = operation.inputShape(model) + val baseDerives = symbolProvider.toSymbol(input).expectRustMetadata().derives + // Filter out any derive that isn't Clone. Then add a Debug derive + val derives = baseDerives.filter { it == RuntimeType.Clone } + RuntimeType.Debug + rust( + """ + /// Fluent builder constructing a request to `${operationSymbol.name}`. + /// + """, + ) + + val builderName = operation.fluentBuilderType(codegenContext, symbolProvider).name + documentShape(operation, model, autoSuppressMissingDocs = false) + deprecatedShape(operation) + Attribute(derive(derives.toSet())).render(this) + rustTemplate( + """ + pub struct $builderName#{generics:W} { + handle: std::sync::Arc, + inner: #{Inner} + } + """, + "Inner" to symbolProvider.symbolForBuilder(input), + "client" to RuntimeType.smithyClient(runtimeConfig), + "generics" to generics.decl, + "operation" to operationSymbol, + ) + + rustBlockTemplate( + "impl${generics.inst} $builderName${generics.inst} #{bounds:W}", + "client" to RuntimeType.smithyClient(runtimeConfig), + "bounds" to generics.bounds, + ) { + val outputType = symbolProvider.toSymbol(operation.outputShape(model)) + val errorType = symbolProvider.symbolForOperationError(operation) + + // Have to use fully-qualified result here or else it could conflict with an op named Result + rustTemplate( """ - Utilities to ergonomically construct a request to the service. + /// Creates a new `${operationSymbol.name}`. + pub(crate) fn new(handle: std::sync::Arc) -> Self { + Self { handle, inner: Default::default() } + } - Fluent builders are created through the [`Client`](crate::client::Client) by calling - one if its operation methods. After parameters are set using the builder methods, - the `send` method can be called to initiate the request. - """.trim(), - newlinePrefix = "//! ", - ) - operations.forEach { operation -> - val operationSymbol = symbolProvider.toSymbol(operation) - val input = operation.inputShape(model) - val baseDerives = symbolProvider.toSymbol(input).expectRustMetadata().derives - // Filter out any derive that isn't Clone. Then add a Debug derive - val derives = baseDerives.filter { it == RuntimeType.Clone } + RuntimeType.Debug - rust( - """ - /// Fluent builder constructing a request to `${operationSymbol.name}`. - /// - """, - ) + /// Consume this builder, creating a customizable operation that can be modified before being + /// sent. The operation's inner [http::Request] can be modified as well. + pub async fn customize(self) -> std::result::Result< + #{CustomizableOperation}#{customizable_op_type_params:W}, + #{SdkError}<#{OperationError}> + > #{send_bounds:W} { + let handle = self.handle.clone(); + let operation = self.inner.build().map_err(#{SdkError}::construction_failure)? + .make_operation(&handle.conf) + .await + .map_err(#{SdkError}::construction_failure)?; + Ok(#{CustomizableOperation} { handle, operation }) + } - documentShape(operation, model, autoSuppressMissingDocs = false) - deprecatedShape(operation) - Attribute(derive(derives.toSet())).render(this) + /// Sends the request and returns the response. + /// + /// If an error occurs, an `SdkError` will be returned with additional details that + /// can be matched against. + /// + /// By default, any retryable failures will be retried twice. Retry behavior + /// is configurable with the [RetryConfig](aws_smithy_types::retry::RetryConfig), which can be + /// set when configuring the client. + pub async fn send(self) -> std::result::Result<#{OperationOutput}, #{SdkError}<#{OperationError}>> + #{send_bounds:W} { + let op = self.inner.build().map_err(#{SdkError}::construction_failure)? + .make_operation(&self.handle.conf) + .await + .map_err(#{SdkError}::construction_failure)?; + self.handle.client.call(op).await + } + """, + "CustomizableOperation" to codegenContext.featureGatedCustomizeModule().toType() + .resolve("CustomizableOperation"), + "ClassifyRetry" to RuntimeType.classifyRetry(runtimeConfig), + "OperationError" to errorType, + "OperationOutput" to outputType, + "SdkError" to RuntimeType.sdkError(runtimeConfig), + "SdkSuccess" to RuntimeType.sdkSuccess(runtimeConfig), + "send_bounds" to generics.sendBounds(operationSymbol, outputType, errorType, retryClassifier), + "customizable_op_type_params" to rustTypeParameters( + symbolProvider.toSymbol(operation), + retryClassifier, + generics.toRustGenerics(), + ), + ) + PaginatorGenerator.paginatorType(codegenContext, generics, operation, retryClassifier)?.also { paginatorType -> rustTemplate( """ - pub struct ${operationSymbol.name}#{generics:W} { - handle: std::sync::Arc, - inner: #{Inner} + /// Create a paginator for this request + /// + /// Paginators are used by calling [`send().await`](#{Paginator}::send) which returns a `Stream`. + pub fn into_paginator(self) -> #{Paginator}${generics.inst} { + #{Paginator}::new(self.handle, self.inner) } """, - "Inner" to input.builderSymbol(symbolProvider), - "client" to RuntimeType.smithyClient(runtimeConfig), - "generics" to generics.decl, - "operation" to operationSymbol, + "Paginator" to paginatorType, ) - - rustBlockTemplate( - "impl${generics.inst} ${operationSymbol.name}${generics.inst} #{bounds:W}", - "client" to RuntimeType.smithyClient(runtimeConfig), - "bounds" to generics.bounds, - ) { - val outputType = symbolProvider.toSymbol(operation.outputShape(model)) - val errorType = operation.errorSymbol(symbolProvider) - - // Have to use fully-qualified result here or else it could conflict with an op named Result - rustTemplate( - """ - /// Creates a new `${operationSymbol.name}`. - pub(crate) fn new(handle: std::sync::Arc) -> Self { - Self { handle, inner: Default::default() } - } - - /// Consume this builder, creating a customizable operation that can be modified before being - /// sent. The operation's inner [http::Request] can be modified as well. - pub async fn customize(self) -> std::result::Result< - crate::operation::customize::CustomizableOperation#{customizable_op_type_params:W}, - #{SdkError}<#{OperationError}> - > #{send_bounds:W} { - let handle = self.handle.clone(); - let operation = self.inner.build().map_err(#{SdkError}::construction_failure)? - .make_operation(&handle.conf) - .await - .map_err(#{SdkError}::construction_failure)?; - Ok(crate::operation::customize::CustomizableOperation { handle, operation }) - } - - /// Sends the request and returns the response. - /// - /// If an error occurs, an `SdkError` will be returned with additional details that - /// can be matched against. - /// - /// By default, any retryable failures will be retried twice. Retry behavior - /// is configurable with the [RetryConfig](aws_smithy_types::retry::RetryConfig), which can be - /// set when configuring the client. - pub async fn send(self) -> std::result::Result<#{OperationOutput}, #{SdkError}<#{OperationError}>> - #{send_bounds:W} { - let op = self.inner.build().map_err(#{SdkError}::construction_failure)? - .make_operation(&self.handle.conf) - .await - .map_err(#{SdkError}::construction_failure)?; - self.handle.client.call(op).await - } - """, - "ClassifyRetry" to RuntimeType.classifyRetry(runtimeConfig), - "OperationError" to errorType, - "OperationOutput" to outputType, - "SdkError" to RuntimeType.sdkError(runtimeConfig), - "SdkSuccess" to RuntimeType.sdkSuccess(runtimeConfig), - "send_bounds" to generics.sendBounds(operationSymbol, outputType, errorType, retryClassifier), - "customizable_op_type_params" to rustTypeParameters( - symbolProvider.toSymbol(operation), - retryClassifier, - generics.toRustGenerics(), - ), - ) - PaginatorGenerator.paginatorType(codegenContext, generics, operation, retryClassifier)?.also { paginatorType -> - rustTemplate( - """ - /// Create a paginator for this request - /// - /// Paginators are used by calling [`send().await`](#{Paginator}::send) which returns a `Stream`. - pub fn into_paginator(self) -> #{Paginator}${generics.inst} { - #{Paginator}::new(self.handle, self.inner) - } - """, - "Paginator" to paginatorType, - ) - } - writeCustomizations( - customizations, - FluentClientSection.FluentBuilderImpl( - operation, - operation.errorSymbol(symbolProvider), - ), - ) - input.members().forEach { member -> - val memberName = symbolProvider.toMemberName(member) - // All fields in the builder are optional - val memberSymbol = symbolProvider.toSymbol(member) - val outerType = memberSymbol.rustType() - when (val coreType = outerType.stripOuter()) { - is RustType.Vec -> with(core) { renderVecHelper(member, memberName, coreType) } - is RustType.HashMap -> with(core) { renderMapHelper(member, memberName, coreType) } - else -> with(core) { renderInputHelper(member, memberName, coreType) } - } - // pure setter - val setterName = member.setterName() - val optionalInputType = outerType.asOptional() - with(core) { renderInputHelper(member, setterName, optionalInputType) } - } + } + writeCustomizations( + customizations, + FluentClientSection.FluentBuilderImpl( + operation, + symbolProvider.symbolForOperationError(operation), + ), + ) + input.members().forEach { member -> + val memberName = symbolProvider.toMemberName(member) + // All fields in the builder are optional + val memberSymbol = symbolProvider.toSymbol(member) + val outerType = memberSymbol.rustType() + when (val coreType = outerType.stripOuter()) { + is RustType.Vec -> with(core) { renderVecHelper(member, memberName, coreType) } + is RustType.HashMap -> with(core) { renderMapHelper(member, memberName, coreType) } + else -> with(core) { renderInputHelper(member, memberName, coreType) } } + // pure setter + val setterName = member.setterName() + val optionalInputType = outerType.asOptional() + with(core) { renderInputHelper(member, setterName, optionalInputType) } } } } @@ -365,12 +364,13 @@ class FluentClientGenerator( */ private fun generateOperationShapeDocs( writer: RustWriter, - symbolProvider: SymbolProvider, + codegenContext: ClientCodegenContext, + symbolProvider: RustSymbolProvider, operation: OperationShape, model: Model, ): List { val input = operation.inputShape(model) - val fluentBuilderFullyQualifiedName = operation.fullyQualifiedFluentBuilder(symbolProvider) + val fluentBuilderFullyQualifiedName = operation.fullyQualifiedFluentBuilder(codegenContext, symbolProvider) return input.members().map { memberShape -> val builderInputDoc = memberShape.asFluentBuilderInputDoc(symbolProvider) val builderInputLink = docLink("$fluentBuilderFullyQualifiedName::${symbolProvider.toMemberName(memberShape)}") @@ -413,17 +413,46 @@ private fun generateShapeMemberDocs( } } +private fun OperationShape.fluentBuilderModule( + codegenContext: ClientCodegenContext, + symbolProvider: RustSymbolProvider, +) = when (codegenContext.settings.codegenConfig.enableNewCrateOrganizationScheme) { + true -> symbolProvider.moduleForBuilder(this) + else -> RustModule.public( + "fluent_builders", + parent = ClientRustModule.client, + documentation = """ + Utilities to ergonomically construct a request to the service. + + Fluent builders are created through the [`Client`](crate::client::Client) by calling + one if its operation methods. After parameters are set using the builder methods, + the `send` method can be called to initiate the request. + """.trimIndent(), + ) +} + +internal fun OperationShape.fluentBuilderType( + codegenContext: ClientCodegenContext, + symbolProvider: RustSymbolProvider, +): RuntimeType = fluentBuilderModule(codegenContext, symbolProvider).toType() + .resolve( + symbolProvider.toSymbol(this).name + + when (codegenContext.settings.codegenConfig.enableNewCrateOrganizationScheme) { + true -> "FluentBuilder" + else -> "" + }, + ) + /** * Generate a valid fully-qualified Type for a fluent builder e.g. - * `OperationShape(AssumeRole)` -> `"crate::client::fluent_builders::AssumeRole"` + * `OperationShape(AssumeRole)` -> `"crate::operations::assume_role::AssumeRoleFluentBuilder"` * * * _NOTE: This function generates the links that appear under **"The fluent builder is configurable:"**_ */ -private fun OperationShape.fullyQualifiedFluentBuilder(symbolProvider: SymbolProvider): String { - val operationName = symbolProvider.toSymbol(this).name - - return "crate::client::fluent_builders::$operationName" -} +private fun OperationShape.fullyQualifiedFluentBuilder( + codegenContext: ClientCodegenContext, + symbolProvider: RustSymbolProvider, +): String = fluentBuilderType(codegenContext, symbolProvider).fullyQualifiedName() /** * Generate a string that looks like a Rust function pointer for documenting a fluent builder method e.g. diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerics.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerics.kt index b3229051e69..399085d5e5b 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerics.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerics.kt @@ -28,7 +28,7 @@ interface FluentClientGenerics { val bounds: Writable /** Bounds for generated `send()` functions */ - fun sendBounds(operation: Symbol, operationOutput: Symbol, operationError: RuntimeType, retryClassifier: RuntimeType): Writable + fun sendBounds(operation: Symbol, operationOutput: Symbol, operationError: Symbol, retryClassifier: RuntimeType): Writable /** Convert this `FluentClientGenerics` into the more general `RustGenerics` */ fun toRustGenerics(): RustGenerics @@ -70,7 +70,7 @@ data class FlexibleClientGenerics( } /** Bounds for generated `send()` functions */ - override fun sendBounds(operation: Symbol, operationOutput: Symbol, operationError: RuntimeType, retryClassifier: RuntimeType): Writable = writable { + override fun sendBounds(operation: Symbol, operationOutput: Symbol, operationError: Symbol, retryClassifier: RuntimeType): Writable = writable { rustTemplate( """ where diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/IdempotencyTokenProviderCustomization.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/IdempotencyTokenProviderCustomization.kt index 7cda821957e..75b075dc4d5 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/IdempotencyTokenProviderCustomization.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/IdempotencyTokenProviderCustomization.kt @@ -63,7 +63,9 @@ class IdempotencyTokenProviderCustomization : NamedCustomization( rust("make_token: self.make_token.unwrap_or_else(#T::default_provider),", RuntimeType.IdempotencyToken) } - is ServiceConfig.DefaultForTests -> writable { rust("""${section.configBuilderRef}.set_make_token(Some("00000000-0000-4000-8000-000000000000".into()));""") } + is ServiceConfig.DefaultForTests -> writable { + rust("""${section.configBuilderRef}.set_make_token(Some("00000000-0000-4000-8000-000000000000".into()));""") + } else -> writable { } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/ErrorCustomization.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/ErrorCustomization.kt new file mode 100644 index 00000000000..d275c9b17d4 --- /dev/null +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/ErrorCustomization.kt @@ -0,0 +1,25 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy.generators.error + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization +import software.amazon.smithy.rust.codegen.core.smithy.customize.Section + +/** Error customization sections */ +sealed class ErrorSection(name: String) : Section(name) { + /** Use this section to add additional trait implementations to the generated operation errors */ + data class OperationErrorAdditionalTraitImpls(val errorSymbol: Symbol, val allErrors: List) : + ErrorSection("OperationErrorAdditionalTraitImpls") + + /** Use this section to add additional trait implementations to the generated service error */ + class ServiceErrorAdditionalTraitImpls(val allErrors: List) : + ErrorSection("ServiceErrorAdditionalTraitImpls") +} + +/** Customizations for generated errors */ +abstract class ErrorCustomization : NamedCustomization() diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/ErrorGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/ErrorGenerator.kt new file mode 100644 index 00000000000..7380d475115 --- /dev/null +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/ErrorGenerator.kt @@ -0,0 +1,126 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy.generators.error + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.traits.ErrorTrait +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.implBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.errorMetadata +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderCustomization +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderSection +import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureCustomization +import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureSection +import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorImplCustomization +import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorImplGenerator + +class ErrorGenerator( + private val model: Model, + private val symbolProvider: RustSymbolProvider, + private val shape: StructureShape, + private val error: ErrorTrait, + private val implCustomizations: List, +) { + private val runtimeConfig = symbolProvider.config.runtimeConfig + private val symbol = symbolProvider.toSymbol(shape) + + fun renderStruct(writer: RustWriter) { + writer.apply { + StructureGenerator( + model, symbolProvider, this, shape, + listOf( + object : StructureCustomization() { + override fun section(section: StructureSection): Writable = writable { + when (section) { + is StructureSection.AdditionalFields -> { + rust("pub(crate) meta: #T,", errorMetadata(runtimeConfig)) + } + + is StructureSection.AdditionalDebugFields -> { + rust("""${section.formatterName}.field("meta", &self.meta);""") + } + + else -> {} + } + } + }, + ), + ).render() + + ErrorImplGenerator( + model, + symbolProvider, + this, + shape, + error, + implCustomizations, + ).render(CodegenTarget.CLIENT) + + rustBlock("impl #T for ${symbol.name}", RuntimeType.provideErrorMetadataTrait(runtimeConfig)) { + rust("fn meta(&self) -> &#T { &self.meta }", errorMetadata(runtimeConfig)) + } + + implBlock(symbol) { + BuilderGenerator.renderConvenienceMethod(this, symbolProvider, shape) + } + } + } + + fun renderBuilder(writer: RustWriter) { + writer.apply { + BuilderGenerator( + model, symbolProvider, shape, + listOf( + object : BuilderCustomization() { + override fun section(section: BuilderSection): Writable = writable { + when (section) { + is BuilderSection.AdditionalFields -> { + rust("meta: Option<#T>,", errorMetadata(runtimeConfig)) + } + + is BuilderSection.AdditionalMethods -> { + rustTemplate( + """ + /// Sets error metadata + pub fn meta(mut self, meta: #{error_metadata}) -> Self { + self.meta = Some(meta); + self + } + + /// Sets error metadata + pub fn set_meta(&mut self, meta: Option<#{error_metadata}>) -> &mut Self { + self.meta = meta; + self + } + """, + "error_metadata" to errorMetadata(runtimeConfig), + ) + } + + is BuilderSection.AdditionalFieldsInBuild -> { + rust("meta: self.meta.unwrap_or_default(),") + } + + else -> {} + } + } + }, + ), + ).render(this) + } + } +} diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/OperationErrorGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/OperationErrorGenerator.kt new file mode 100644 index 00000000000..d9ef055a375 --- /dev/null +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/OperationErrorGenerator.kt @@ -0,0 +1,269 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy.generators.error + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.RetryableTrait +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.deprecatedShape +import software.amazon.smithy.rust.codegen.core.rustlang.docs +import software.amazon.smithy.rust.codegen.core.rustlang.documentShape +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.errorMetadata +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.unhandledError +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.customize.Section +import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations +import software.amazon.smithy.rust.codegen.core.smithy.transformers.eventStreamErrors +import software.amazon.smithy.rust.codegen.core.smithy.transformers.operationErrors +import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE +import software.amazon.smithy.rust.codegen.core.util.dq +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.toSnakeCase + +/** + * Generates a unified error enum for [operation]. [ErrorGenerator] handles generating the individual variants, + * but we must still combine those variants into an enum covering all possible errors for a given operation. + * + * This generator also generates errors for event streams. + */ +class OperationErrorGenerator( + private val model: Model, + private val symbolProvider: RustSymbolProvider, + private val operationOrEventStream: Shape, + private val customizations: List, +) { + private val runtimeConfig = symbolProvider.config.runtimeConfig + private val symbol = symbolProvider.toSymbol(operationOrEventStream) + private val errorMetadata = errorMetadata(symbolProvider.config.runtimeConfig) + private val createUnhandledError = + RuntimeType.smithyHttp(runtimeConfig).resolve("result::CreateUnhandledError") + + private fun operationErrors(): List = + (operationOrEventStream as OperationShape).operationErrors(model).map { it.asStructureShape().get() } + private fun eventStreamErrors(): List = + (operationOrEventStream as UnionShape).eventStreamErrors() + .map { model.expectShape(it.asMemberShape().get().target, StructureShape::class.java) } + + fun render(writer: RustWriter) { + val (errorSymbol, errors) = when (operationOrEventStream) { + is OperationShape -> symbolProvider.symbolForOperationError(operationOrEventStream) to operationErrors() + is UnionShape -> symbolProvider.symbolForEventStreamError(operationOrEventStream) to eventStreamErrors() + else -> UNREACHABLE("OperationErrorGenerator only supports operation or event stream shapes") + } + + val meta = RustMetadata( + derives = setOf(RuntimeType.Debug), + additionalAttributes = listOf(Attribute.NonExhaustive), + visibility = Visibility.PUBLIC, + ) + + // TODO(deprecated): Remove this temporary alias. This was added so that the compiler + // points customers in the right direction when they are upgrading. Unfortunately there's no + // way to provide better backwards compatibility on this change. + val kindDeprecationMessage = "Operation `*Error/*ErrorKind` types were combined into a single `*Error` enum. " + + "The `.kind` field on `*Error` no longer exists and isn't needed anymore (you can just match on the " + + "error directly since it's an enum now)." + writer.rust( + """ + /// Do not use this. + /// + /// $kindDeprecationMessage + ##[deprecated(note = ${kindDeprecationMessage.dq()})] + pub type ${errorSymbol.name}Kind = ${errorSymbol.name}; + """, + ) + + writer.rust("/// Error type for the `${errorSymbol.name}` operation.") + meta.render(writer) + writer.rustBlock("enum ${errorSymbol.name}") { + errors.forEach { errorVariant -> + documentShape(errorVariant, model) + deprecatedShape(errorVariant) + val errorVariantSymbol = symbolProvider.toSymbol(errorVariant) + write("${errorVariantSymbol.name}(#T),", errorVariantSymbol) + } + rust( + """ + /// An unexpected error occurred (e.g., invalid JSON returned by the service or an unknown error code). + Unhandled(#T), + """, + unhandledError(runtimeConfig), + ) + } + writer.rustBlock("impl #T for ${errorSymbol.name}", createUnhandledError) { + rustBlock( + """ + fn create_unhandled_error( + source: Box, + meta: Option<#T> + ) -> Self + """, + errorMetadata, + ) { + rust( + """ + Self::Unhandled({ + let mut builder = #T::builder().source(source); + builder.set_meta(meta); + builder.build() + }) + """, + unhandledError(runtimeConfig), + ) + } + } + writer.rustBlock("impl #T for ${errorSymbol.name}", RuntimeType.Display) { + rustBlock("fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result") { + delegateToVariants(errors) { + writable { rust("_inner.fmt(f)") } + } + } + } + + val errorMetadataTrait = RuntimeType.provideErrorMetadataTrait(runtimeConfig) + writer.rustBlock("impl #T for ${errorSymbol.name}", errorMetadataTrait) { + rustBlock("fn meta(&self) -> &#T", errorMetadata(runtimeConfig)) { + delegateToVariants(errors) { + writable { rust("#T::meta(_inner)", errorMetadataTrait) } + } + } + } + + writer.writeCustomizations(customizations, ErrorSection.OperationErrorAdditionalTraitImpls(errorSymbol, errors)) + + val retryErrorKindT = RuntimeType.retryErrorKind(symbolProvider.config.runtimeConfig) + writer.rustBlock( + "impl #T for ${errorSymbol.name}", + RuntimeType.provideErrorKind(symbolProvider.config.runtimeConfig), + ) { + rustBlock("fn code(&self) -> Option<&str>") { + rust("#T::code(self)", RuntimeType.provideErrorMetadataTrait(runtimeConfig)) + } + + rustBlock("fn retryable_error_kind(&self) -> Option<#T>", retryErrorKindT) { + val retryableVariants = errors.filter { it.hasTrait() } + if (retryableVariants.isEmpty()) { + rust("None") + } else { + rustBlock("match self") { + retryableVariants.forEach { + val errorVariantSymbol = symbolProvider.toSymbol(it) + rust("Self::${errorVariantSymbol.name}(inner) => Some(inner.retryable_error_kind()),") + } + rust("_ => None") + } + } + } + } + + writer.rustBlock("impl ${errorSymbol.name}") { + writer.rustTemplate( + """ + /// Creates the `${errorSymbol.name}::Unhandled` variant from any error type. + pub fn unhandled(err: impl Into>) -> Self { + Self::Unhandled(#{Unhandled}::builder().source(err).build()) + } + + /// Creates the `${errorSymbol.name}::Unhandled` variant from a `#{error_metadata}`. + pub fn generic(err: #{error_metadata}) -> Self { + Self::Unhandled(#{Unhandled}::builder().source(err.clone()).meta(err).build()) + } + """, + "error_metadata" to errorMetadata, + "std_error" to RuntimeType.StdError, + "Unhandled" to unhandledError(runtimeConfig), + ) + writer.docs( + """ + Returns error metadata, which includes the error code, message, + request ID, and potentially additional information. + """, + ) + writer.rustBlock("pub fn meta(&self) -> &#T", errorMetadata) { + rust("use #T;", RuntimeType.provideErrorMetadataTrait(runtimeConfig)) + rustBlock("match self") { + errors.forEach { error -> + val errorVariantSymbol = symbolProvider.toSymbol(error) + rust("Self::${errorVariantSymbol.name}(e) => e.meta(),") + } + rust("Self::Unhandled(e) => e.meta(),") + } + } + errors.forEach { error -> + val errorVariantSymbol = symbolProvider.toSymbol(error) + val fnName = errorVariantSymbol.name.toSnakeCase() + writer.rust("/// Returns `true` if the error kind is `${errorSymbol.name}::${errorVariantSymbol.name}`.") + writer.rustBlock("pub fn is_$fnName(&self) -> bool") { + rust("matches!(self, Self::${errorVariantSymbol.name}(_))") + } + } + } + + writer.rustBlock("impl #T for ${errorSymbol.name}", RuntimeType.StdError) { + rustBlock("fn source(&self) -> Option<&(dyn #T + 'static)>", RuntimeType.StdError) { + delegateToVariants(errors) { + writable { + rust("Some(_inner)") + } + } + } + } + } + + sealed class VariantMatch(name: String) : Section(name) { + object Unhandled : VariantMatch("Unhandled") + data class Modeled(val symbol: Symbol, val shape: Shape) : VariantMatch("Modeled") + } + + /** + * Generates code to delegate behavior to the variants, for example: + * + * ```rust + * match self { + * Self::InvalidGreeting(_inner) => inner.fmt(f), + * Self::ComplexError(_inner) => inner.fmt(f), + * Self::FooError(_inner) => inner.fmt(f), + * Self::Unhandled(_inner) => _inner.fmt(f), + * } + * ``` + * + * [handler] is passed an instance of [VariantMatch]—a [writable] should be returned containing the content to be + * written for this variant. + * + * The field will always be bound as `_inner`. + */ + fun RustWriter.delegateToVariants( + errors: List, + handler: (VariantMatch) -> Writable, + ) { + rustBlock("match self") { + errors.forEach { + val errorSymbol = symbolProvider.toSymbol(it) + rust("""Self::${errorSymbol.name}(_inner) => """) + handler(VariantMatch.Modeled(errorSymbol, it))(this) + write(",") + } + val unhandledHandler = handler(VariantMatch.Unhandled) + rustBlock("Self::Unhandled(_inner) =>") { + unhandledHandler(this) + } + } + } +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ServiceErrorGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/ServiceErrorGenerator.kt similarity index 71% rename from codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ServiceErrorGenerator.kt rename to codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/ServiceErrorGenerator.kt index c03a30b6cca..f461d59e053 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ServiceErrorGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/ServiceErrorGenerator.kt @@ -3,8 +3,9 @@ * SPDX-License-Identifier: Apache-2.0 */ -package software.amazon.smithy.rust.codegen.core.smithy.generators.error +package software.amazon.smithy.rust.codegen.client.smithy.generators.error +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.StructureShape @@ -23,7 +24,9 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.unhandledError import software.amazon.smithy.rust.codegen.core.smithy.RustCrate +import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations import software.amazon.smithy.rust.codegen.core.smithy.transformers.allErrors import software.amazon.smithy.rust.codegen.core.smithy.transformers.eventStreamErrors @@ -42,11 +45,17 @@ import software.amazon.smithy.rust.codegen.core.smithy.transformers.eventStreamE * } * ``` */ -class ServiceErrorGenerator(private val codegenContext: CodegenContext, private val operations: List) { +class ServiceErrorGenerator( + private val codegenContext: CodegenContext, + private val operations: List, + private val customizations: List, +) { private val symbolProvider = codegenContext.symbolProvider private val model = codegenContext.model - private val allErrors = operations.flatMap { it.allErrors(model) }.map { it.id }.distinctBy { it.getName(codegenContext.serviceShape) } + private val allErrors = operations.flatMap { + it.allErrors(model) + }.map { it.id }.distinctBy { it.getName(codegenContext.serviceShape) } .map { codegenContext.model.expectShape(it, StructureShape::class.java) } .sortedBy { it.id.getName(codegenContext.serviceShape) } @@ -59,7 +68,7 @@ class ServiceErrorGenerator(private val codegenContext: CodegenContext, private // Every operation error can be converted into service::Error operations.forEach { operationShape -> // operation errors - renderImplFrom(operationShape.errorSymbol(symbolProvider), operationShape.errors) + renderImplFrom(symbolProvider.symbolForOperationError(operationShape), operationShape.errors) } // event stream errors operations.map { it.eventStreamErrors(codegenContext.model) } @@ -67,11 +76,12 @@ class ServiceErrorGenerator(private val codegenContext: CodegenContext, private .associate { it.key to it.value } .forEach { (unionShape, errors) -> renderImplFrom( - unionShape.eventStreamErrorSymbol(symbolProvider), + symbolProvider.symbolForEventStreamError(unionShape), errors.map { it.id }, ) } rust("impl #T for Error {}", RuntimeType.StdError) + writeCustomizations(customizations, ErrorSection.ServiceErrorAdditionalTraitImpls(allErrors)) } crate.lib { rust("pub use error_meta::Error;") } } @@ -89,7 +99,7 @@ class ServiceErrorGenerator(private val codegenContext: CodegenContext, private } } - private fun RustWriter.renderImplFrom(errorSymbol: RuntimeType, errors: List) { + private fun RustWriter.renderImplFrom(errorSymbol: Symbol, errors: List) { if (errors.isNotEmpty() || CodegenTarget.CLIENT == codegenContext.target) { val operationErrors = errors.map { model.expectShape(it) } rustBlock( @@ -104,25 +114,36 @@ class ServiceErrorGenerator(private val codegenContext: CodegenContext, private ) { rustBlock("match err") { rust("#T::ServiceError(context) => Self::from(context.into_err()),", sdkError) - rust("_ => Error::Unhandled(#T::new(err.into())),", unhandledError()) + rustTemplate( + """ + _ => Error::Unhandled( + #{Unhandled}::builder() + .meta(#{ProvideErrorMetadata}::meta(&err).clone()) + .source(err) + .build() + ), + """, + "Unhandled" to unhandledError(codegenContext.runtimeConfig), + "ProvideErrorMetadata" to RuntimeType.provideErrorMetadataTrait(codegenContext.runtimeConfig), + ) } } } rustBlock("impl From<#T> for Error", errorSymbol) { rustBlock("fn from(err: #T) -> Self", errorSymbol) { - rustBlock("match err.kind") { + rustBlock("match err") { operationErrors.forEach { errorShape -> val errSymbol = symbolProvider.toSymbol(errorShape) rust( - "#TKind::${errSymbol.name}(inner) => Error::${errSymbol.name}(inner),", + "#T::${errSymbol.name}(inner) => Error::${errSymbol.name}(inner),", errorSymbol, ) } rustTemplate( - "#{errorSymbol}Kind::Unhandled(inner) => Error::Unhandled(#{unhandled}::new(inner.into())),", + "#{errorSymbol}::Unhandled(inner) => Error::Unhandled(inner),", "errorSymbol" to errorSymbol, - "unhandled" to unhandledError(), + "unhandled" to unhandledError(codegenContext.runtimeConfig), ) } } @@ -143,8 +164,8 @@ class ServiceErrorGenerator(private val codegenContext: CodegenContext, private val sym = symbolProvider.toSymbol(error) rust("${sym.name}(#T),", sym) } - docs(UNHANDLED_ERROR_DOCS) - rust("Unhandled(#T)", unhandledError()) + docs("An unexpected error occurred (e.g., invalid JSON returned by the service or an unknown error code).") + rust("Unhandled(#T)", unhandledError(codegenContext.runtimeConfig)) } } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/RequestBindingGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/RequestBindingGenerator.kt index 9ec7173d03d..ae725af304a 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/RequestBindingGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/RequestBindingGenerator.kt @@ -5,7 +5,6 @@ package software.amazon.smithy.rust.codegen.client.smithy.generators.http -import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.knowledge.HttpBinding import software.amazon.smithy.model.knowledge.HttpBindingIndex import software.amazon.smithy.model.pattern.SmithyPattern @@ -13,7 +12,6 @@ import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape -import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.model.traits.HttpTrait import software.amazon.smithy.rust.codegen.core.rustlang.Attribute @@ -26,7 +24,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.generators.OperationBuildError -import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.operationBuildError import software.amazon.smithy.rust.codegen.core.smithy.isOptional @@ -68,9 +65,8 @@ class RequestBindingGenerator( private val symbolProvider = codegenContext.symbolProvider private val runtimeConfig = codegenContext.runtimeConfig private val httpTrait = protocol.httpBindingResolver.httpTrait(operationShape) - private fun builderSymbol(shape: StructureShape): Symbol = shape.builderSymbol(symbolProvider) private val httpBindingGenerator = - HttpBindingGenerator(protocol, codegenContext, codegenContext.symbolProvider, operationShape, ::builderSymbol) + HttpBindingGenerator(protocol, codegenContext, codegenContext.symbolProvider, operationShape) private val index = HttpBindingIndex.of(model) private val encoder = RuntimeType.smithyTypes(runtimeConfig).resolve("primitive::Encoder") diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/ResponseBindingGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/ResponseBindingGenerator.kt index 7becd9df016..911028ea5ac 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/ResponseBindingGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/ResponseBindingGenerator.kt @@ -7,24 +7,20 @@ package software.amazon.smithy.rust.codegen.client.smithy.generators.http import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.OperationShape -import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindingGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol class ResponseBindingGenerator( protocol: Protocol, - private val codegenContext: CodegenContext, + codegenContext: CodegenContext, operationShape: OperationShape, ) { - private fun builderSymbol(shape: StructureShape): Symbol = shape.builderSymbol(codegenContext.symbolProvider) - private val httpBindingGenerator = - HttpBindingGenerator(protocol, codegenContext, codegenContext.symbolProvider, operationShape, ::builderSymbol) + HttpBindingGenerator(protocol, codegenContext, codegenContext.symbolProvider, operationShape) fun generateDeserializeHeaderFn(binding: HttpBindingDescriptor): RuntimeType = httpBindingGenerator.generateDeserializeHeaderFn(binding) @@ -34,11 +30,7 @@ class ResponseBindingGenerator( fun generateDeserializePayloadFn( binding: HttpBindingDescriptor, - errorT: RuntimeType, + errorSymbol: Symbol, payloadParser: RustWriter.(String) -> Unit, - ): RuntimeType = httpBindingGenerator.generateDeserializePayloadFn( - binding, - errorT, - payloadParser, - ) + ): RuntimeType = httpBindingGenerator.generateDeserializePayloadFn(binding, errorSymbol, payloadParser) } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolGenerator.kt index b8d6a652e80..31b1fb64ad6 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ClientProtocolGenerator.kt @@ -6,27 +6,29 @@ package software.amazon.smithy.rust.codegen.client.smithy.generators.protocol import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientGenerator +import software.amazon.smithy.rust.codegen.client.smithy.generators.client.fluentBuilderType import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.derive import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.rustlang.docLink +import software.amazon.smithy.rust.codegen.core.rustlang.implBlock import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock -import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTraitImplGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.core.util.inputShape open class ClientProtocolGenerator( - codegenContext: CodegenContext, + private val codegenContext: ClientCodegenContext, private val protocol: Protocol, /** * Operations generate a `make_operation(&config)` method to build a `aws_smithy_http::Operation` that can be dispatched @@ -50,40 +52,77 @@ open class ClientProtocolGenerator( customizations: List, ) { val inputShape = operationShape.inputShape(model) - val builderGenerator = BuilderGenerator(model, symbolProvider, operationShape.inputShape(model)) - builderGenerator.render(inputWriter) // impl OperationInputShape { ... } - val operationName = symbolProvider.toSymbol(operationShape).name - inputWriter.implBlock(inputShape, symbolProvider) { + inputWriter.implBlock(symbolProvider.toSymbol(inputShape)) { writeCustomizations( customizations, OperationSection.InputImpl(customizations, operationShape, inputShape, protocol), ) makeOperationGenerator.generateMakeOperation(this, operationShape, customizations) + } - // pub fn builder() -> ... { } - builderGenerator.renderConvenienceMethod(this) + when (codegenContext.settings.codegenConfig.enableNewCrateOrganizationScheme) { + true -> renderOperationStruct(operationWriter, operationShape, customizations) + else -> oldRenderOperationStruct(operationWriter, operationShape, inputShape, customizations) } + } + + private fun renderOperationStruct( + operationWriter: RustWriter, + operationShape: OperationShape, + customizations: List, + ) { + val operationName = symbolProvider.toSymbol(operationShape).name // pub struct Operation { ... } - val fluentBuilderName = FluentClientGenerator.clientOperationFnName(operationShape, symbolProvider) operationWriter.rust( + """ + /// `ParseStrictResponse` impl for `$operationName`. + """, + ) + Attribute(derive(RuntimeType.Clone, RuntimeType.Default, RuntimeType.Debug)).render(operationWriter) + Attribute.NonExhaustive.render(operationWriter) + Attribute.DocHidden.render(operationWriter) + operationWriter.rust("pub struct $operationName;") + operationWriter.implBlock(symbolProvider.toSymbol(operationShape)) { + rustBlock("pub(crate) fn new() -> Self") { + rust("Self") + } + + writeCustomizations(customizations, OperationSection.OperationImplBlock(customizations)) + } + traitGenerator.generateTraitImpls(operationWriter, operationShape, customizations) + } + + // TODO(CrateReorganization): Remove this function when removing `enableNewCrateOrganizationScheme` + private fun oldRenderOperationStruct( + operationWriter: RustWriter, + operationShape: OperationShape, + inputShape: StructureShape, + customizations: List, + ) { + val operationName = symbolProvider.toSymbol(operationShape).name + + // pub struct Operation { ... } + val fluentBuilderName = FluentClientGenerator.clientOperationFnName(operationShape, symbolProvider) + operationWriter.rustTemplate( """ /// Operation shape for `$operationName`. /// /// This is usually constructed for you using the the fluent builder returned by - /// [`$fluentBuilderName`](${docLink("crate::client::Client::$fluentBuilderName")}). + /// [`$fluentBuilderName`](#{fluentBuilder}). /// - /// See [`crate::client::fluent_builders::$operationName`] for more details about the operation. + /// `ParseStrictResponse` impl for `$operationName`. """, + "fluentBuilder" to operationShape.fluentBuilderType(codegenContext, symbolProvider), ) Attribute(derive(RuntimeType.Clone, RuntimeType.Default, RuntimeType.Debug)).render(operationWriter) operationWriter.rustBlock("pub struct $operationName") { write("_private: ()") } - operationWriter.implBlock(operationShape, symbolProvider) { - builderGenerator.renderConvenienceMethod(this) + operationWriter.implBlock(symbolProvider.toSymbol(operationShape)) { + BuilderGenerator.renderConvenienceMethod(this, symbolProvider, inputShape) rust("/// Creates a new `$operationName` operation.") rustBlock("pub fn new() -> Self") { diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/MakeOperationGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/MakeOperationGenerator.kt index 541af1ceb56..3c37049eb3a 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/MakeOperationGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/MakeOperationGenerator.kt @@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.client.smithy.generators.protocol import software.amazon.smithy.aws.traits.ServiceTrait import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule import software.amazon.smithy.rust.codegen.client.smithy.generators.http.RequestBindingGenerator import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter @@ -54,7 +55,7 @@ open class MakeOperationGenerator( ?: codegenContext.serviceShape.id.getName(codegenContext.serviceShape) private val codegenScope = arrayOf( - "config" to RuntimeType.Config, + "config" to ClientRustModule.Config, "header_util" to RuntimeType.smithyHttp(runtimeConfig).resolve("header"), "http" to RuntimeType.Http, "HttpRequestBuilder" to RuntimeType.HttpRequestBuilder, diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt index 476e67ef5a0..5e9fc23cb0f 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGenerator.kt @@ -19,13 +19,12 @@ import software.amazon.smithy.protocoltests.traits.HttpRequestTestsTrait import software.amazon.smithy.protocoltests.traits.HttpResponseTestCase import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule import software.amazon.smithy.rust.codegen.client.smithy.generators.clientInstantiator import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.allow -import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.rustlang.Visibility import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.escape import software.amazon.smithy.rust.codegen.core.rustlang.rust @@ -34,7 +33,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.generators.error.errorSymbol import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.getTrait @@ -91,14 +89,10 @@ class ProtocolTestGenerator( if (allTests.isNotEmpty()) { val operationName = operationSymbol.name val testModuleName = "${operationName.toSnakeCase()}_request_test" - val moduleMeta = RustMetadata( - visibility = Visibility.PRIVATE, - additionalAttributes = listOf( - Attribute.CfgTest, - Attribute(allow("unreachable_code", "unused_variables")), - ), + val additionalAttributes = listOf( + Attribute(allow("unreachable_code", "unused_variables")), ) - writer.withInlineModule(RustModule.LeafModule(testModuleName, moduleMeta, inline = true)) { + writer.withInlineModule(RustModule.inlineTests(testModuleName, additionalAttributes = additionalAttributes)) { renderAllTestCases(allTests) } } @@ -175,12 +169,12 @@ class ProtocolTestGenerator( } ?: writable { } rustTemplate( """ - let builder = #{Config}::Config::builder().with_test_defaults().endpoint_resolver("https://example.com"); + let builder = #{config}::Config::builder().with_test_defaults().endpoint_resolver("https://example.com"); #{customParams} let config = builder.build(); """, - "Config" to RuntimeType.Config, + "config" to ClientRustModule.Config, "customParams" to customParams, ) writeInline("let input =") @@ -217,9 +211,9 @@ class ProtocolTestGenerator( checkQueryParams(this, httpRequestTestCase.queryParams) checkForbidQueryParams(this, httpRequestTestCase.forbidQueryParams) checkRequiredQueryParams(this, httpRequestTestCase.requireQueryParams) - checkHeaders(this, "&http_request.headers()", httpRequestTestCase.headers) - checkForbidHeaders(this, "&http_request.headers()", httpRequestTestCase.forbidHeaders) - checkRequiredHeaders(this, "&http_request.headers()", httpRequestTestCase.requireHeaders) + checkHeaders(this, "http_request.headers()", httpRequestTestCase.headers) + checkForbidHeaders(this, "http_request.headers()", httpRequestTestCase.forbidHeaders) + checkRequiredHeaders(this, "http_request.headers()", httpRequestTestCase.requireHeaders) if (protocolSupport.requestBodySerialization) { // "If no request body is defined, then no assertions are made about the body of the message." httpRequestTestCase.body.orNull()?.also { body -> @@ -253,10 +247,10 @@ class ProtocolTestGenerator( expectedShape: StructureShape, ) { if (!protocolSupport.responseDeserialization || ( - !protocolSupport.errorDeserialization && expectedShape.hasTrait( + !protocolSupport.errorDeserialization && expectedShape.hasTrait( ErrorTrait::class.java, ) - ) + ) ) { rust("/* test case disabled for this protocol (not yet supported) */") return @@ -296,49 +290,53 @@ class ProtocolTestGenerator( "parse_http_response" to RuntimeType.parseHttpResponse(codegenContext.runtimeConfig), ) if (expectedShape.hasTrait()) { - val errorSymbol = operationShape.errorSymbol(codegenContext.symbolProvider) + val errorSymbol = codegenContext.symbolProvider.symbolForOperationError(operationShape) val errorVariant = codegenContext.symbolProvider.toSymbol(expectedShape).name rust("""let parsed = parsed.expect_err("should be error response");""") - rustBlock("if let #TKind::$errorVariant(actual_error) = parsed.kind", errorSymbol) { - rustTemplate("#{AssertEq}(expected_output, actual_error);", *codegenScope) + rustBlock("if let #T::$errorVariant(parsed) = parsed", errorSymbol) { + compareMembers(expectedShape) } rustBlock("else") { rust("panic!(\"wrong variant: Got: {:?}. Expected: {:?}\", parsed, expected_output);") } } else { rust("let parsed = parsed.unwrap();") - outputShape.members().forEach { member -> - val memberName = codegenContext.symbolProvider.toMemberName(member) - if (member.isStreaming(codegenContext.model)) { - rustTemplate( - """ - #{AssertEq}( - parsed.$memberName.collect().await.unwrap().into_bytes(), - expected_output.$memberName.collect().await.unwrap().into_bytes() - ); - """, - *codegenScope, - ) - } else { - when (codegenContext.model.expectShape(member.target)) { - is DoubleShape, is FloatShape -> { - addUseImports( - RuntimeType.protocolTest(codegenContext.runtimeConfig, "FloatEquals").toSymbol(), - ) - rust( - """ - assert!(parsed.$memberName.float_equals(&expected_output.$memberName), - "Unexpected value for `$memberName` {:?} vs. {:?}", expected_output.$memberName, parsed.$memberName); - """, - ) - } - - else -> - rustTemplate( - """#{AssertEq}(parsed.$memberName, expected_output.$memberName, "Unexpected value for `$memberName`");""", - *codegenScope, - ) + compareMembers(outputShape) + } + } + + private fun RustWriter.compareMembers(shape: StructureShape) { + shape.members().forEach { member -> + val memberName = codegenContext.symbolProvider.toMemberName(member) + if (member.isStreaming(codegenContext.model)) { + rustTemplate( + """ + #{AssertEq}( + parsed.$memberName.collect().await.unwrap().into_bytes(), + expected_output.$memberName.collect().await.unwrap().into_bytes() + ); + """, + *codegenScope, + ) + } else { + when (codegenContext.model.expectShape(member.target)) { + is DoubleShape, is FloatShape -> { + addUseImports( + RuntimeType.protocolTest(codegenContext.runtimeConfig, "FloatEquals").toSymbol(), + ) + rust( + """ + assert!(parsed.$memberName.float_equals(&expected_output.$memberName), + "Unexpected value for `$memberName` {:?} vs. {:?}", expected_output.$memberName, parsed.$memberName); + """, + ) } + + else -> + rustTemplate( + """#{AssertEq}(parsed.$memberName, expected_output.$memberName, "Unexpected value for `$memberName`");""", + *codegenScope, + ) } } } @@ -359,7 +357,7 @@ class ProtocolTestGenerator( assertOk(rustWriter) { rustWriter.write( "#T(&body, ${ - rustWriter.escape(body).dq() + rustWriter.escape(body).dq() }, #T::from(${(mediaType ?: "unknown").dq()}))", RuntimeType.protocolTest(codegenContext.runtimeConfig, "validate_body"), RuntimeType.protocolTest(codegenContext.runtimeConfig, "MediaType"), diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt index ecfea77ff55..a194dc98a2e 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/ClientProtocolLoader.kt @@ -7,16 +7,19 @@ package software.amazon.smithy.rust.codegen.client.smithy.protocols import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait +import software.amazon.smithy.aws.traits.protocols.AwsQueryCompatibleTrait import software.amazon.smithy.aws.traits.protocols.AwsQueryTrait import software.amazon.smithy.aws.traits.protocols.Ec2QueryTrait import software.amazon.smithy.aws.traits.protocols.RestJson1Trait import software.amazon.smithy.aws.traits.protocols.RestXmlTrait +import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ClientProtocolGenerator import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJson import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion +import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsQueryCompatible import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsQueryProtocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.Ec2QueryProtocol import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol @@ -25,6 +28,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolLoader import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml +import software.amazon.smithy.rust.codegen.core.util.hasTrait class ClientProtocolLoader(supportedProtocols: ProtocolMap) : ProtocolLoader(supportedProtocols) { @@ -57,12 +61,20 @@ private val CLIENT_PROTOCOL_SUPPORT = ProtocolSupport( private class ClientAwsJsonFactory(private val version: AwsJsonVersion) : ProtocolGeneratorFactory { - override fun protocol(codegenContext: ClientCodegenContext): Protocol = AwsJson(codegenContext, version) + override fun protocol(codegenContext: ClientCodegenContext): Protocol = + if (compatibleWithAwsQuery(codegenContext.serviceShape, version)) { + AwsQueryCompatible(codegenContext, AwsJson(codegenContext, version)) + } else { + AwsJson(codegenContext, version) + } override fun buildProtocolGenerator(codegenContext: ClientCodegenContext): HttpBoundProtocolGenerator = HttpBoundProtocolGenerator(codegenContext, protocol(codegenContext)) override fun support(): ProtocolSupport = CLIENT_PROTOCOL_SUPPORT + + private fun compatibleWithAwsQuery(serviceShape: ServiceShape, version: AwsJsonVersion) = + serviceShape.hasTrait() && version == AwsJsonVersion.Json10 } private class ClientAwsQueryFactory : ProtocolGeneratorFactory { diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt index 0592d91cd9a..f47a15e559c 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/HttpBoundProtocolGenerator.kt @@ -9,6 +9,7 @@ import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.ErrorTrait +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.generators.http.ResponseBindingGenerator import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ClientProtocolGenerator import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.MakeOperationGenerator @@ -29,8 +30,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustom import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationSection import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol -import software.amazon.smithy.rust.codegen.core.smithy.generators.error.errorSymbol import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTraitImplGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor @@ -49,7 +48,7 @@ import software.amazon.smithy.rust.codegen.core.util.outputShape import software.amazon.smithy.rust.codegen.core.util.toSnakeCase class HttpBoundProtocolGenerator( - codegenContext: CodegenContext, + codegenContext: ClientCodegenContext, protocol: Protocol, ) : ClientProtocolGenerator( codegenContext, @@ -116,6 +115,7 @@ class HttpBoundProtocolTraitImplGenerator( impl #{ParseStrict} for $operationName { type Output = std::result::Result<#{O}, #{E}>; fn parse(&self, response: &#{http}::Response<#{Bytes}>) -> Self::Output { + #{BeforeParseResponse} if !response.status().is_success() && response.status().as_u16() != $successCode { #{parse_error}(response) } else { @@ -125,9 +125,12 @@ class HttpBoundProtocolTraitImplGenerator( }""", *codegenScope, "O" to outputSymbol, - "E" to operationShape.errorSymbol(symbolProvider), - "parse_error" to parseError(operationShape), + "E" to symbolProvider.symbolForOperationError(operationShape), + "parse_error" to parseError(operationShape, customizations), "parse_response" to parseResponse(operationShape, customizations), + "BeforeParseResponse" to writable { + writeCustomizations(customizations, OperationSection.BeforeParseResponse(customizations, "response")) + }, ) } @@ -156,18 +159,18 @@ class HttpBoundProtocolTraitImplGenerator( } """, "O" to outputSymbol, - "E" to operationShape.errorSymbol(symbolProvider), + "E" to symbolProvider.symbolForOperationError(operationShape), "parse_streaming_response" to parseStreamingResponse(operationShape, customizations), - "parse_error" to parseError(operationShape), + "parse_error" to parseError(operationShape, customizations), *codegenScope, ) } - private fun parseError(operationShape: OperationShape): RuntimeType { + private fun parseError(operationShape: OperationShape, customizations: List): RuntimeType { val fnName = "parse_${operationShape.id.name.toSnakeCase()}_error" val outputShape = operationShape.outputShape(model) val outputSymbol = symbolProvider.toSymbol(outputShape) - val errorSymbol = operationShape.errorSymbol(symbolProvider) + val errorSymbol = symbolProvider.symbolForOperationError(operationShape) return RuntimeType.forInlineFun(fnName, operationDeserModule) { Attribute.AllowClippyUnnecessaryWraps.render(this) rustBlockTemplate( @@ -176,11 +179,17 @@ class HttpBoundProtocolTraitImplGenerator( "O" to outputSymbol, "E" to errorSymbol, ) { + Attribute.AllowUnusedMut.render(this) rust( - "let generic = #T(response).map_err(#T::unhandled)?;", - protocol.parseHttpGenericError(operationShape), + "let mut generic_builder = #T(response).map_err(#T::unhandled)?;", + protocol.parseHttpErrorMetadata(operationShape), errorSymbol, ) + writeCustomizations( + customizations, + OperationSection.PopulateErrorMetadataExtras(customizations, "generic_builder", "response"), + ) + rust("let generic = generic_builder.build();") if (operationShape.operationErrors(model).isNotEmpty()) { rustTemplate( """ @@ -200,8 +209,8 @@ class HttpBoundProtocolTraitImplGenerator( val variantName = symbolProvider.toSymbol(model.expectShape(error.id)).name val errorCode = httpBindingResolver.errorCode(errorShape).dq() withBlock( - "$errorCode => #1T { meta: generic, kind: #1TKind::$variantName({", - "})},", + "$errorCode => #1T::$variantName({", + "}),", errorSymbol, ) { Attribute.AllowUnusedMut.render(this) @@ -212,7 +221,14 @@ class HttpBoundProtocolTraitImplGenerator( errorShape, httpBindingResolver.errorResponseBindings(errorShape), errorSymbol, - listOf(), + listOf(object : OperationCustomization() { + override fun section(section: OperationSection): Writable = writable { + if (section is OperationSection.MutateOutput) { + rust("let output = output.meta(generic);") + } + } + }, + ), ) } } @@ -241,7 +257,7 @@ class HttpBoundProtocolTraitImplGenerator( val fnName = "parse_${operationShape.id.name.toSnakeCase()}" val outputShape = operationShape.outputShape(model) val outputSymbol = symbolProvider.toSymbol(outputShape) - val errorSymbol = operationShape.errorSymbol(symbolProvider) + val errorSymbol = symbolProvider.symbolForOperationError(operationShape) return RuntimeType.forInlineFun(fnName, operationDeserModule) { Attribute.AllowClippyUnnecessaryWraps.render(this) rustBlockTemplate( @@ -270,7 +286,7 @@ class HttpBoundProtocolTraitImplGenerator( val fnName = "parse_${operationShape.id.name.toSnakeCase()}_response" val outputShape = operationShape.outputShape(model) val outputSymbol = symbolProvider.toSymbol(outputShape) - val errorSymbol = operationShape.errorSymbol(symbolProvider) + val errorSymbol = symbolProvider.symbolForOperationError(operationShape) return RuntimeType.forInlineFun(fnName, operationDeserModule) { Attribute.AllowClippyUnnecessaryWraps.render(this) rustBlockTemplate( @@ -296,13 +312,13 @@ class HttpBoundProtocolTraitImplGenerator( operationShape: OperationShape, outputShape: StructureShape, bindings: List, - errorSymbol: RuntimeType, + errorSymbol: Symbol, customizations: List, ) { val httpBindingGenerator = ResponseBindingGenerator(protocol, codegenContext, operationShape) val structuredDataParser = protocol.structuredDataParser(operationShape) Attribute.AllowUnusedMut.render(this) - rust("let mut output = #T::default();", outputShape.builderSymbol(symbolProvider)) + rust("let mut output = #T::default();", symbolProvider.symbolForBuilder(outputShape)) // avoid non-usage warnings for response rust("let _ = response;") if (outputShape.id == operationShape.output.get()) { @@ -334,7 +350,9 @@ class HttpBoundProtocolTraitImplGenerator( val err = if (BuilderGenerator.hasFallibleBuilder(outputShape, symbolProvider)) { ".map_err(${format(errorSymbol)}::unhandled)?" - } else "" + } else { + "" + } writeCustomizations(customizations, OperationSection.MutateOutput(customizations, operationShape)) @@ -352,7 +370,7 @@ class HttpBoundProtocolTraitImplGenerator( httpBindingGenerator: ResponseBindingGenerator, structuredDataParser: StructuredDataParserGenerator, ): Writable? { - val errorSymbol = operationShape.errorSymbol(symbolProvider) + val errorSymbol = symbolProvider.symbolForOperationError(operationShape) val member = binding.member return when (binding.location) { HttpLocation.HEADER -> writable { diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/transformers/AddErrorMessage.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/transformers/AddErrorMessage.kt index 7e69ce6995a..02d61d9905e 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/transformers/AddErrorMessage.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/transformers/AddErrorMessage.kt @@ -19,7 +19,7 @@ import java.util.logging.Logger * * Not all errors are modeled with an error message field. However, in many cases, the server can still send an error. * If an error, specifically, a structure shape with the error trait does not have a member `message` or `Message`, - * this transformer will add a `message` member targeting a string. + * this transformer will add a `Message` member targeting a string. * * This ensures that we always generate a modeled error message field enabling end users to easily extract the error * message when present. @@ -37,7 +37,7 @@ object AddErrorMessage { val addMessageField = shape.hasTrait() && shape is StructureShape && shape.errorMessageMember() == null if (addMessageField && shape is StructureShape) { logger.info("Adding message field to ${shape.id}") - shape.toBuilder().addMember("message", ShapeId.from("smithy.api#String")).build() + shape.toBuilder().addMember("Message", ShapeId.from("smithy.api#String")).build() } else { shape } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/testutil/ClientCodegenIntegrationTest.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/testutil/ClientCodegenIntegrationTest.kt new file mode 100644 index 00000000000..138f43cd264 --- /dev/null +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/testutil/ClientCodegenIntegrationTest.kt @@ -0,0 +1,56 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.testutil + +import software.amazon.smithy.build.PluginContext +import software.amazon.smithy.build.SmithyBuildPlugin +import software.amazon.smithy.model.Model +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.client.smithy.RustClientCodegenPlugin +import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator +import software.amazon.smithy.rust.codegen.core.smithy.RustCrate +import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams +import software.amazon.smithy.rust.codegen.core.testutil.codegenIntegrationTest +import java.nio.file.Path + +fun clientIntegrationTest( + model: Model, + params: IntegrationTestParams = IntegrationTestParams(), + additionalDecorators: List = listOf(), + test: (ClientCodegenContext, RustCrate) -> Unit = { _, _ -> }, +): Path { + fun invokeRustCodegenPlugin(ctx: PluginContext) { + val codegenDecorator = object : ClientCodegenDecorator { + override val name: String = "Add tests" + override val order: Byte = 0 + + override fun classpathDiscoverable(): Boolean = false + + override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { + test(codegenContext, rustCrate) + } + } + RustClientCodegenPlugin().executeWithDecorator(ctx, codegenDecorator, *additionalDecorators.toTypedArray()) + } + return codegenIntegrationTest(model, params, invokePlugin = ::invokeRustCodegenPlugin) +} + +/** + * A `SmithyBuildPlugin` that accepts an additional decorator. + * + * This exists to allow tests to easily customize the _real_ build without needing to list out customizations + * or attempt to manually discover them from the path. + */ +abstract class ClientDecoratableBuildPlugin : SmithyBuildPlugin { + abstract fun executeWithDecorator( + context: PluginContext, + vararg decorator: ClientCodegenDecorator, + ) + + override fun execute(context: PluginContext) { + executeWithDecorator(context) + } +} diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/testutil/CodegenIntegrationTest.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/testutil/CodegenIntegrationTest.kt deleted file mode 100644 index c764e290595..00000000000 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/testutil/CodegenIntegrationTest.kt +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.rust.codegen.client.testutil - -import software.amazon.smithy.build.PluginContext -import software.amazon.smithy.build.SmithyBuildPlugin -import software.amazon.smithy.model.Model -import software.amazon.smithy.model.node.ObjectNode -import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext -import software.amazon.smithy.rust.codegen.client.smithy.RustClientCodegenPlugin -import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig -import software.amazon.smithy.rust.codegen.core.smithy.RustCrate -import software.amazon.smithy.rust.codegen.core.testutil.generatePluginContext -import software.amazon.smithy.rust.codegen.core.testutil.printGeneratedFiles -import software.amazon.smithy.rust.codegen.core.util.runCommand -import java.io.File -import java.nio.file.Path - -/** - * Run cargo test on a true, end-to-end, codegen product of a given model. - * - * For test purposes, additional codegen decorators can also be composed. - */ -fun clientIntegrationTest( - model: Model, - additionalDecorators: List = listOf(), - addModuleToEventStreamAllowList: Boolean = false, - service: String? = null, - runtimeConfig: RuntimeConfig? = null, - additionalSettings: ObjectNode = ObjectNode.builder().build(), - command: ((Path) -> Unit)? = null, - test: (ClientCodegenContext, RustCrate) -> Unit = { _, _ -> }, -): Path { - return codegenIntegrationTest( - model, - RustClientCodegenPlugin(), - additionalDecorators, - addModuleToEventStreamAllowList = addModuleToEventStreamAllowList, - service = service, - runtimeConfig = runtimeConfig, - additionalSettings = additionalSettings, - test = test, - command = command, - ) -} - -/** - * A Smithy BuildPlugin that accepts an additional decorator - * - * This exists to allow tests to easily customize the _real_ build without needing to list out customizations - * or attempt to manually discover them from the path - */ -abstract class DecoratableBuildPlugin : SmithyBuildPlugin { - abstract fun executeWithDecorator( - context: PluginContext, - vararg decorator: ClientCodegenDecorator, - ) - - override fun execute(context: PluginContext) { - executeWithDecorator(context) - } -} - -// TODO(https://github.com/awslabs/smithy-rs/issues/1864): move to core once CodegenDecorator is in core -private fun codegenIntegrationTest( - model: Model, - buildPlugin: DecoratableBuildPlugin, - additionalDecorators: List, - additionalSettings: ObjectNode = ObjectNode.builder().build(), - addModuleToEventStreamAllowList: Boolean = false, - service: String? = null, - runtimeConfig: RuntimeConfig? = null, - overrideTestDir: File? = null, test: (ClientCodegenContext, RustCrate) -> Unit, - command: ((Path) -> Unit)? = null, -): Path { - val (ctx, testDir) = generatePluginContext( - model, - additionalSettings, - addModuleToEventStreamAllowList, - service, - runtimeConfig, - overrideTestDir, - ) - - val codegenDecorator = object : ClientCodegenDecorator { - override val name: String = "Add tests" - override val order: Byte = 0 - - override fun classpathDiscoverable(): Boolean = false - - override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { - test(codegenContext, rustCrate) - } - } - buildPlugin.executeWithDecorator(ctx, codegenDecorator, *additionalDecorators.toTypedArray()) - ctx.fileManifest.printGeneratedFiles() - command?.invoke(testDir) ?: "cargo test".runCommand(testDir, environment = mapOf("RUSTFLAGS" to "-D warnings")) - return testDir -} diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/testutil/TestConfigCustomization.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/testutil/TestConfigCustomization.kt index da2362a195a..cb066882254 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/testutil/TestConfigCustomization.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/testutil/TestConfigCustomization.kt @@ -5,10 +5,10 @@ package software.amazon.smithy.rust.codegen.client.testutil +import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfigGenerator -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.writable @@ -74,7 +74,7 @@ fun validateConfigCustomizations( fun stubConfigProject(customization: ConfigCustomization, project: TestWriterDelegator): TestWriterDelegator { val customizations = listOf(stubConfigCustomization("a")) + customization + stubConfigCustomization("b") val generator = ServiceConfigGenerator(customizations = customizations.toList()) - project.withModule(RustModule.Config) { + project.withModule(ClientRustModule.Config) { generator.render(this) unitTest( "config_send_sync", diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/testutil/TestHelpers.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/testutil/TestHelpers.kt index 1b3f8139594..e3fa9faa796 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/testutil/TestHelpers.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/testutil/TestHelpers.kt @@ -6,22 +6,24 @@ package software.amazon.smithy.rust.codegen.client.testutil import software.amazon.smithy.model.Model +import software.amazon.smithy.model.knowledge.NullableIndex import software.amazon.smithy.model.node.ObjectNode import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenConfig +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.ClientRustSettings +import software.amazon.smithy.rust.codegen.client.smithy.OldModuleSchemeClientModuleProvider import software.amazon.smithy.rust.codegen.client.smithy.RustClientCodegenPlugin -import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget -import software.amazon.smithy.rust.codegen.core.smithy.CoreRustSettings +import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator +import software.amazon.smithy.rust.codegen.client.smithy.customize.CombinedClientCodegenDecorator import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig -import software.amazon.smithy.rust.codegen.core.testutil.TestSymbolVisitorConfig -import software.amazon.smithy.rust.codegen.core.testutil.testRustSettings +import software.amazon.smithy.rust.codegen.core.testutil.TestWriterDelegator -fun clientTestRustSettings( +fun testClientRustSettings( service: ShapeId = ShapeId.from("notrelevant#notrelevant"), moduleName: String = "test-module", moduleVersion: String = "0.0.1", @@ -47,25 +49,41 @@ fun clientTestRustSettings( customizationConfig, ) +val TestClientRustSymbolProviderConfig = RustSymbolProviderConfig( + runtimeConfig = TestRuntimeConfig, + renameExceptions = true, + nullabilityCheckMode = NullableIndex.CheckMode.CLIENT_ZERO_VALUE_V1, + moduleProvider = OldModuleSchemeClientModuleProvider, +) + fun testSymbolProvider(model: Model, serviceShape: ServiceShape? = null): RustSymbolProvider = RustClientCodegenPlugin.baseSymbolProvider( + testClientRustSettings(), model, serviceShape ?: ServiceShape.builder().version("test").id("test#Service").build(), - TestSymbolVisitorConfig, + TestClientRustSymbolProviderConfig, ) -fun testCodegenContext( +fun testClientCodegenContext( model: Model, + symbolProvider: RustSymbolProvider? = null, serviceShape: ServiceShape? = null, - settings: CoreRustSettings = testRustSettings(), - codegenTarget: CodegenTarget = CodegenTarget.CLIENT, -): CodegenContext = CodegenContext( + settings: ClientRustSettings = testClientRustSettings(), + rootDecorator: ClientCodegenDecorator? = null, +): ClientCodegenContext = ClientCodegenContext( model, - testSymbolProvider(model), + symbolProvider ?: testSymbolProvider(model), serviceShape ?: model.serviceShapes.firstOrNull() ?: ServiceShape.builder().version("test").id("test#Service").build(), ShapeId.from("test#Protocol"), settings, - codegenTarget, + rootDecorator ?: CombinedClientCodegenDecorator(emptyList()), ) + +fun TestWriterDelegator.clientRustSettings() = + testClientRustSettings( + service = ShapeId.from("fake#Fake"), + moduleName = "test_${baseDir.toFile().nameWithoutExtension}", + codegenConfig = codegenConfig as ClientCodegenConfig, + ) diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/customizations/ApiKeyAuthDecoratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/customizations/ApiKeyAuthDecoratorTest.kt new file mode 100644 index 00000000000..41e3ebae7aa --- /dev/null +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/customizations/ApiKeyAuthDecoratorTest.kt @@ -0,0 +1,175 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.customizations + +import org.junit.jupiter.api.Test +import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.integrationTest +import software.amazon.smithy.rust.codegen.core.testutil.runWithWarnings + +internal class ApiKeyAuthDecoratorTest { + private val modelQuery = """ + namespace test + + use aws.api#service + use aws.protocols#restJson1 + + @service(sdkId: "Test Api Key Auth") + @restJson1 + @httpApiKeyAuth(name: "api_key", in: "query") + @auth([httpApiKeyAuth]) + service TestService { + version: "2023-01-01", + operations: [SomeOperation] + } + + structure SomeOutput { + someAttribute: Long, + someVal: String + } + + @http(uri: "/SomeOperation", method: "GET") + operation SomeOperation { + output: SomeOutput + } + """.asSmithyModel() + + @Test + fun `set an api key in query parameter`() { + val testDir = clientIntegrationTest( + modelQuery, + // just run integration tests + IntegrationTestParams(command = { "cargo test --test *".runWithWarnings(it) }), + ) { clientCodegenContext, rustCrate -> + rustCrate.integrationTest("api_key_present_in_property_bag") { + val moduleName = clientCodegenContext.moduleUseName() + Attribute.TokioTest.render(this) + rust( + """ + async fn api_key_present_in_property_bag() { + use aws_smithy_http_auth::api_key::AuthApiKey; + let api_key_value = "some-api-key"; + let conf = $moduleName::Config::builder() + .api_key(AuthApiKey::new(api_key_value)) + .build(); + let operation = $moduleName::input::SomeOperationInput::builder() + .build() + .expect("input is valid") + .make_operation(&conf) + .await + .expect("valid operation"); + let props = operation.properties(); + let api_key_config = props.get::().expect("api key in the bag"); + assert_eq!( + api_key_config, + &AuthApiKey::new(api_key_value), + ); + } + """, + ) + } + + rustCrate.integrationTest("api_key_auth_is_set_in_query") { + val moduleName = clientCodegenContext.moduleUseName() + Attribute.TokioTest.render(this) + rust( + """ + async fn api_key_auth_is_set_in_query() { + use aws_smithy_http_auth::api_key::AuthApiKey; + let api_key_value = "some-api-key"; + let conf = $moduleName::Config::builder() + .api_key(AuthApiKey::new(api_key_value)) + .build(); + let operation = $moduleName::input::SomeOperationInput::builder() + .build() + .expect("input is valid") + .make_operation(&conf) + .await + .expect("valid operation"); + assert_eq!( + operation.request().uri().query(), + Some("api_key=some-api-key"), + ); + } + """, + ) + } + } + "cargo clippy".runWithWarnings(testDir) + } + + private val modelHeader = """ + namespace test + + use aws.api#service + use aws.protocols#restJson1 + + @service(sdkId: "Test Api Key Auth") + @restJson1 + @httpApiKeyAuth(name: "authorization", in: "header", scheme: "ApiKey") + @auth([httpApiKeyAuth]) + service TestService { + version: "2023-01-01", + operations: [SomeOperation] + } + + structure SomeOutput { + someAttribute: Long, + someVal: String + } + + @http(uri: "/SomeOperation", method: "GET") + operation SomeOperation { + output: SomeOutput + } + """.asSmithyModel() + + @Test + fun `set an api key in http header`() { + val testDir = clientIntegrationTest( + modelHeader, + // just run integration tests + IntegrationTestParams(command = { "cargo test --test *".runWithWarnings(it) }), + ) { clientCodegenContext, rustCrate -> + rustCrate.integrationTest("api_key_auth_is_set_in_http_header") { + val moduleName = clientCodegenContext.moduleUseName() + Attribute.TokioTest.render(this) + rust( + """ + async fn api_key_auth_is_set_in_http_header() { + use aws_smithy_http_auth::api_key::AuthApiKey; + let api_key_value = "some-api-key"; + let conf = $moduleName::Config::builder() + .api_key(AuthApiKey::new(api_key_value)) + .build(); + let operation = $moduleName::input::SomeOperationInput::builder() + .build() + .expect("input is valid") + .make_operation(&conf) + .await + .expect("valid operation"); + let props = operation.properties(); + let api_key_config = props.get::().expect("api key in the bag"); + assert_eq!( + api_key_config, + &AuthApiKey::new(api_key_value), + ); + assert_eq!( + operation.request().headers().contains_key("authorization"), + true, + ); + } + """, + ) + } + } + "cargo clippy".runWithWarnings(testDir) + } +} diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/customizations/HttpVersionListGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/customizations/HttpVersionListGeneratorTest.kt index 0c57ba622a7..eb166ae1c69 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/customizations/HttpVersionListGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/customizations/HttpVersionListGeneratorTest.kt @@ -19,6 +19,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.integrationTest @@ -61,7 +62,7 @@ internal class HttpVersionListGeneratorTest { """ async fn test_http_version_list_defaults() { let conf = $moduleName::Config::builder().build(); - let op = $moduleName::operation::SayHello::builder() + let op = $moduleName::input::SayHelloInput::builder() .greeting("hello") .build().expect("valid operation") .make_operation(&conf).await.expect("hello is a valid prefix"); @@ -112,7 +113,7 @@ internal class HttpVersionListGeneratorTest { """ async fn test_http_version_list_defaults() { let conf = $moduleName::Config::builder().build(); - let op = $moduleName::operation::SayHello::builder() + let op = $moduleName::input::SayHelloInput::builder() .greeting("hello") .build().expect("valid operation") .make_operation(&conf).await.expect("hello is a valid prefix"); @@ -170,8 +171,8 @@ internal class HttpVersionListGeneratorTest { clientIntegrationTest( model, - listOf(FakeSigningDecorator()), - addModuleToEventStreamAllowList = true, + IntegrationTestParams(addModuleToEventStreamAllowList = true), + additionalDecorators = listOf(FakeSigningDecorator()), ) { clientCodegenContext, rustCrate -> val moduleName = clientCodegenContext.moduleUseName() rustCrate.integrationTest("validate_eventstream_http") { @@ -180,7 +181,7 @@ internal class HttpVersionListGeneratorTest { """ async fn test_http_version_list_defaults() { let conf = $moduleName::Config::builder().build(); - let op = $moduleName::operation::SayHello::builder() + let op = $moduleName::input::SayHelloInput::builder() .build().expect("valid operation") .make_operation(&conf).await.unwrap(); let properties = op.properties(); @@ -204,7 +205,9 @@ class FakeSigningDecorator : ClientCodegenDecorator { codegenContext: ClientCodegenContext, baseCustomizations: List, ): List { - return baseCustomizations.filterNot { it is EventStreamSigningConfig } + FakeSigningConfig(codegenContext.runtimeConfig) + return baseCustomizations.filterNot { + it is EventStreamSigningConfig + } + FakeSigningConfig(codegenContext.runtimeConfig) } } diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/customizations/ResiliencyConfigCustomizationTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/customizations/ResiliencyConfigCustomizationTest.kt index 2d2cf76916d..612bf8e3335 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/customizations/ResiliencyConfigCustomizationTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/customizations/ResiliencyConfigCustomizationTest.kt @@ -6,16 +6,17 @@ package software.amazon.smithy.rust.codegen.client.customizations import org.junit.jupiter.api.Test +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenConfig import software.amazon.smithy.rust.codegen.client.smithy.customizations.ResiliencyConfigCustomization import software.amazon.smithy.rust.codegen.client.smithy.customizations.ResiliencyReExportCustomization +import software.amazon.smithy.rust.codegen.client.testutil.clientRustSettings import software.amazon.smithy.rust.codegen.client.testutil.stubConfigProject -import software.amazon.smithy.rust.codegen.client.testutil.testCodegenContext +import software.amazon.smithy.rust.codegen.client.testutil.testClientCodegenContext import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest -import software.amazon.smithy.rust.codegen.core.testutil.rustSettings internal class ResiliencyConfigCustomizationTest { private val baseModel = """ @@ -36,9 +37,9 @@ internal class ResiliencyConfigCustomizationTest { @Test fun `generates a valid config`() { - val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel)) - val project = TestWorkspace.testProject() - val codegenContext = testCodegenContext(model, settings = project.rustSettings()) + val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel)) + val project = TestWorkspace.testProject(model, ClientCodegenConfig()) + val codegenContext = testClientCodegenContext(model, settings = project.clientRustSettings()) stubConfigProject(ResiliencyConfigCustomization(codegenContext), project) ResiliencyReExportCustomization(codegenContext.runtimeConfig).extras(project) diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/endpoint/ClientContextParamsDecoratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/endpoint/ClientContextConfigCustomizationTest.kt similarity index 94% rename from codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/endpoint/ClientContextParamsDecoratorTest.kt rename to codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/endpoint/ClientContextConfigCustomizationTest.kt index cb964902097..8c3a4122e94 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/endpoint/ClientContextParamsDecoratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/endpoint/ClientContextConfigCustomizationTest.kt @@ -7,14 +7,14 @@ package software.amazon.smithy.rust.codegen.client.endpoint import org.junit.jupiter.api.Test import software.amazon.smithy.rust.codegen.client.smithy.endpoint.ClientContextConfigCustomization -import software.amazon.smithy.rust.codegen.client.testutil.testCodegenContext +import software.amazon.smithy.rust.codegen.client.testutil.testClientCodegenContext import software.amazon.smithy.rust.codegen.client.testutil.validateConfigCustomizations import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.unitTest -class ClientContextParamsDecoratorTest { +class ClientContextConfigCustomizationTest { val model = """ namespace test use smithy.rules#clientContextParams @@ -52,6 +52,6 @@ class ClientContextParamsDecoratorTest { """, ) } - validateConfigCustomizations(ClientContextConfigCustomization(testCodegenContext(model)), project) + validateConfigCustomizations(ClientContextConfigCustomization(testClientCodegenContext(model)), project) } } diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/endpoint/EndpointResolverGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/endpoint/EndpointResolverGeneratorTest.kt index ccd028ebcb8..179cbbc944a 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/endpoint/EndpointResolverGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/endpoint/EndpointResolverGeneratorTest.kt @@ -23,7 +23,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.endpoint.generators.End import software.amazon.smithy.rust.codegen.client.smithy.endpoint.generators.EndpointTestGenerator import software.amazon.smithy.rust.codegen.client.smithy.endpoint.rulesgen.SmithyEndpointsStdLib import software.amazon.smithy.rust.codegen.client.smithy.endpoint.rulesgen.awsStandardLib -import software.amazon.smithy.rust.codegen.client.testutil.testCodegenContext +import software.amazon.smithy.rust.codegen.client.testutil.testClientCodegenContext import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace @@ -64,7 +64,7 @@ class EndpointResolverGeneratorTest { paramsType = EndpointParamsGenerator(suite.ruleSet().parameters).paramsStruct(), resolverType = ruleset, suite.ruleSet().parameters, - codegenContext = testCodegenContext(model = Model.builder().build()), + codegenContext = testClientCodegenContext(model = Model.builder().build()), endpointCustomizations = listOf(), ) testGenerator.generate()(this) @@ -90,7 +90,7 @@ class EndpointResolverGeneratorTest { paramsType = EndpointParamsGenerator(suite.ruleSet().parameters).paramsStruct(), resolverType = ruleset, suite.ruleSet().parameters, - codegenContext = testCodegenContext(Model.builder().build()), + codegenContext = testClientCodegenContext(Model.builder().build()), endpointCustomizations = listOf(), ) testGenerator.generate()(this) diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/endpoint/EndpointsDecoratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/endpoint/EndpointsDecoratorTest.kt index d05f3b21de9..2b9e75f4039 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/endpoint/EndpointsDecoratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/endpoint/EndpointsDecoratorTest.kt @@ -11,6 +11,7 @@ import org.junit.jupiter.api.Test import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.integrationTest import software.amazon.smithy.rust.codegen.core.testutil.runWithWarnings @@ -123,8 +124,8 @@ class EndpointsDecoratorTest { fun `set an endpoint in the property bag`() { val testDir = clientIntegrationTest( model, - // just run integration tests - command = { "cargo test --test *".runWithWarnings(it) }, + // Just run integration tests. + IntegrationTestParams(command = { "cargo test --test *".runWithWarnings(it) }), ) { clientCodegenContext, rustCrate -> rustCrate.integrationTest("endpoint_params_test") { val moduleName = clientCodegenContext.moduleUseName() @@ -133,7 +134,7 @@ class EndpointsDecoratorTest { """ async fn endpoint_params_are_set() { let conf = $moduleName::Config::builder().a_string_param("hello").a_bool_param(false).build(); - let operation = $moduleName::operation::TestOperation::builder() + let operation = $moduleName::input::TestOperationInput::builder() .bucket("bucket-name").build().expect("input is valid") .make_operation(&conf).await.expect("valid operation"); use $moduleName::endpoint::{Params}; diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/EventStreamSymbolProviderTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/EventStreamSymbolProviderTest.kt index 5fb38e58e31..42663f27563 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/EventStreamSymbolProviderTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/EventStreamSymbolProviderTest.kt @@ -10,6 +10,8 @@ import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.rust.codegen.client.testutil.TestClientRustSymbolProviderConfig +import software.amazon.smithy.rust.codegen.client.testutil.testClientRustSettings import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.EventStreamSymbolProvider @@ -17,7 +19,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitor import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig -import software.amazon.smithy.rust.codegen.core.testutil.TestSymbolVisitorConfig import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel class EventStreamSymbolProviderTest { @@ -46,7 +47,11 @@ class EventStreamSymbolProviderTest { ) val service = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape - val provider = EventStreamSymbolProvider(TestRuntimeConfig, SymbolVisitor(model, service, TestSymbolVisitorConfig), model, CodegenTarget.CLIENT) + val provider = EventStreamSymbolProvider( + TestRuntimeConfig, + SymbolVisitor(testClientRustSettings(), model, service, TestClientRustSymbolProviderConfig), + CodegenTarget.CLIENT, + ) // Look up the synthetic input/output rather than the original input/output val inputStream = model.expectShape(ShapeId.from("test.synthetic#TestOperationInput\$inputStream")) as MemberShape @@ -82,7 +87,11 @@ class EventStreamSymbolProviderTest { ) val service = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape - val provider = EventStreamSymbolProvider(TestRuntimeConfig, SymbolVisitor(model, service, TestSymbolVisitorConfig), model, CodegenTarget.CLIENT) + val provider = EventStreamSymbolProvider( + TestRuntimeConfig, + SymbolVisitor(testClientRustSettings(), model, service, TestClientRustSymbolProviderConfig), + CodegenTarget.CLIENT, + ) // Look up the synthetic input/output rather than the original input/output val inputStream = model.expectShape(ShapeId.from("test.synthetic#TestOperationInput\$inputStream")) as MemberShape diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/StreamingShapeSymbolProviderTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/StreamingShapeSymbolProviderTest.kt index 6c0c3cdadf1..a2e233c7195 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/StreamingShapeSymbolProviderTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/StreamingShapeSymbolProviderTest.kt @@ -9,8 +9,10 @@ import io.kotest.matchers.shouldBe import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.rust.codegen.client.testutil.testSymbolProvider +import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.smithy.Default import software.amazon.smithy.rust.codegen.core.smithy.defaultValue +import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.util.lookup @@ -38,8 +40,18 @@ internal class StreamingShapeSymbolProviderTest { // "doing the right thing" val modelWithOperationTraits = OperationNormalizer.transform(model) val symbolProvider = testSymbolProvider(modelWithOperationTraits) - symbolProvider.toSymbol(modelWithOperationTraits.lookup("test.synthetic#GenerateSpeechOutput\$data")).name shouldBe ("ByteStream") - symbolProvider.toSymbol(modelWithOperationTraits.lookup("test.synthetic#GenerateSpeechInput\$data")).name shouldBe ("ByteStream") + modelWithOperationTraits.lookup("test.synthetic#GenerateSpeechOutput\$data").also { shape -> + symbolProvider.toSymbol(shape).also { symbol -> + symbol.name shouldBe "data" + symbol.rustType() shouldBe RustType.Opaque("ByteStream", "aws_smithy_http::byte_stream") + } + } + modelWithOperationTraits.lookup("test.synthetic#GenerateSpeechInput\$data").also { shape -> + symbolProvider.toSymbol(shape).also { symbol -> + symbol.name shouldBe "data" + symbol.rustType() shouldBe RustType.Opaque("ByteStream", "aws_smithy_http::byte_stream") + } + } } @Test diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGeneratorTest.kt new file mode 100644 index 00000000000..14ed8ad37f7 --- /dev/null +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGeneratorTest.kt @@ -0,0 +1,160 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy.generators + +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.rust.codegen.client.testutil.testClientCodegenContext +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import software.amazon.smithy.rust.codegen.core.util.lookup + +class ClientEnumGeneratorTest { + @Test + fun `matching on enum should be forward-compatible`() { + fun expectMatchExpressionCompiles(model: Model, shapeId: String, enumToMatchOn: String) { + val shape = model.lookup(shapeId) + val context = testClientCodegenContext(model) + val project = TestWorkspace.testProject(context.symbolProvider) + project.moduleFor(shape) { + ClientEnumGenerator(context, shape).render(this) + unitTest( + "matching_on_enum_should_be_forward_compatible", + """ + match $enumToMatchOn { + SomeEnum::Variant1 => assert!(false, "expected `Variant3` but got `Variant1`"), + SomeEnum::Variant2 => assert!(false, "expected `Variant3` but got `Variant2`"), + other @ _ if other.as_str() == "Variant3" => assert!(true), + _ => assert!(false, "expected `Variant3` but got `_`"), + } + """.trimIndent(), + ) + } + project.compileAndTest() + } + + val modelV1 = """ + namespace test + + @enum([ + { name: "Variant1", value: "Variant1" }, + { name: "Variant2", value: "Variant2" }, + ]) + string SomeEnum + """.asSmithyModel() + val variant3AsUnknown = """SomeEnum::from("Variant3")""" + expectMatchExpressionCompiles(modelV1, "test#SomeEnum", variant3AsUnknown) + + val modelV2 = """ + namespace test + + @enum([ + { name: "Variant1", value: "Variant1" }, + { name: "Variant2", value: "Variant2" }, + { name: "Variant3", value: "Variant3" }, + ]) + string SomeEnum + """.asSmithyModel() + val variant3AsVariant3 = "SomeEnum::Variant3" + expectMatchExpressionCompiles(modelV2, "test#SomeEnum", variant3AsVariant3) + } + + @Test + fun `impl debug for non-sensitive enum should implement the derived debug trait`() { + val model = """ + namespace test + @enum([ + { name: "Foo", value: "Foo" }, + { name: "Bar", value: "Bar" }, + ]) + string SomeEnum + """.asSmithyModel() + + val shape = model.lookup("test#SomeEnum") + val context = testClientCodegenContext(model) + val project = TestWorkspace.testProject(context.symbolProvider) + project.moduleFor(shape) { + ClientEnumGenerator(context, shape).render(this) + unitTest( + "impl_debug_for_non_sensitive_enum_should_implement_the_derived_debug_trait", + """ + assert_eq!(format!("{:?}", SomeEnum::Foo), "Foo"); + assert_eq!(format!("{:?}", SomeEnum::Bar), "Bar"); + assert_eq!( + format!("{:?}", SomeEnum::from("Baz")), + "Unknown(UnknownVariantValue(\"Baz\"))" + ); + """, + ) + } + project.compileAndTest() + } + + @Test + fun `it escapes the Unknown variant if the enum has an unknown value in the model`() { + val model = """ + namespace test + @enum([ + { name: "Known", value: "Known" }, + { name: "Unknown", value: "Unknown" }, + { name: "UnknownValue", value: "UnknownValue" }, + ]) + string SomeEnum + """.asSmithyModel() + + val shape = model.lookup("test#SomeEnum") + val context = testClientCodegenContext(model) + val project = TestWorkspace.testProject(context.symbolProvider) + project.moduleFor(shape) { + ClientEnumGenerator(context, shape).render(this) + unitTest( + "it_escapes_the_unknown_variant_if_the_enum_has_an_unknown_value_in_the_model", + """ + assert_eq!(SomeEnum::from("Unknown"), SomeEnum::UnknownValue); + assert_eq!(SomeEnum::from("UnknownValue"), SomeEnum::UnknownValue_); + assert_eq!(SomeEnum::from("SomethingNew"), SomeEnum::Unknown(crate::types::UnknownVariantValue("SomethingNew".to_owned()))); + """.trimIndent(), + ) + } + project.compileAndTest() + } + + @Test + fun `generated named enums can roundtrip between string and enum value on the unknown variant`() { + val model = """ + namespace test + @enum([ + { value: "t2.nano", name: "T2_NANO" }, + { value: "t2.micro", name: "T2_MICRO" }, + ]) + string InstanceType + """.asSmithyModel() + + val shape = model.lookup("test#InstanceType") + val context = testClientCodegenContext(model) + val project = TestWorkspace.testProject(context.symbolProvider) + project.moduleFor(shape) { + rust("##![allow(deprecated)]") + ClientEnumGenerator(context, shape).render(this) + unitTest( + "generated_named_enums_roundtrip", + """ + let instance = InstanceType::T2Micro; + assert_eq!(instance.as_str(), "t2.micro"); + assert_eq!(InstanceType::from("t2.nano"), InstanceType::T2Nano); + // round trip unknown variants: + assert_eq!(InstanceType::from("other"), InstanceType::Unknown(crate::types::UnknownVariantValue("other".to_owned()))); + assert_eq!(InstanceType::from("other").as_str(), "other"); + """, + ) + } + project.compileAndTest() + } +} diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiatorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiatorTest.kt index f507ba2d4c7..abd497b49b8 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiatorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiatorTest.kt @@ -8,17 +8,14 @@ package software.amazon.smithy.rust.codegen.client.smithy.generators import org.junit.jupiter.api.Test import software.amazon.smithy.model.node.Node import software.amazon.smithy.model.shapes.StringShape -import software.amazon.smithy.rust.codegen.client.testutil.testCodegenContext -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.client.testutil.testClientCodegenContext import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.withBlock -import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.dq -import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.lookup internal class ClientInstantiatorTest { @@ -44,7 +41,7 @@ internal class ClientInstantiatorTest { string NamedEnum """.asSmithyModel() - private val codegenContext = testCodegenContext(model) + private val codegenContext = testClientCodegenContext(model) private val symbolProvider = codegenContext.symbolProvider @Test @@ -53,9 +50,9 @@ internal class ClientInstantiatorTest { val sut = clientInstantiator(codegenContext) val data = Node.parse("t2.nano".dq()) - val project = TestWorkspace.testProject() - project.withModule(RustModule.Model) { - EnumGenerator(model, symbolProvider, this, shape, shape.expectTrait()).render() + val project = TestWorkspace.testProject(symbolProvider) + project.moduleFor(shape) { + ClientEnumGenerator(codegenContext, shape).render(this) unitTest("generate_named_enums") { withBlock("let result = ", ";") { sut.render(this, shape, data) @@ -72,9 +69,9 @@ internal class ClientInstantiatorTest { val sut = clientInstantiator(codegenContext) val data = Node.parse("t2.nano".dq()) - val project = TestWorkspace.testProject() - project.withModule(RustModule.Model) { - EnumGenerator(model, symbolProvider, this, shape, shape.expectTrait()).render() + val project = TestWorkspace.testProject(symbolProvider) + project.moduleFor(shape) { + ClientEnumGenerator(codegenContext, shape).render(this) unitTest("generate_unnamed_enums") { withBlock("let result = ", ";") { sut.render(this, shape, data) diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/EndpointTraitBindingsTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/EndpointTraitBindingsTest.kt index 2a9787f69d4..2e3f6d6a2d7 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/EndpointTraitBindingsTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/EndpointTraitBindingsTest.kt @@ -13,10 +13,10 @@ import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest import software.amazon.smithy.rust.codegen.client.testutil.testSymbolProvider import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.implBlock import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock import software.amazon.smithy.rust.codegen.core.smithy.generators.operationBuildError import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace @@ -50,10 +50,10 @@ internal class EndpointTraitBindingsTest { } """.asSmithyModel() val operationShape: OperationShape = model.lookup("test#GetStatus") - val sym = testSymbolProvider(model) + val symbolProvider = testSymbolProvider(model) val endpointBindingGenerator = EndpointTraitBindings( model, - sym, + symbolProvider, TestRuntimeConfig, operationShape, operationShape.expectTrait(EndpointTrait::class.java), @@ -67,7 +67,7 @@ internal class EndpointTraitBindingsTest { } """, ) - implBlock(model.lookup("test#GetStatusInput"), sym) { + implBlock(symbolProvider.toSymbol(model.lookup("test#GetStatusInput"))) { rustBlock( "fn endpoint_prefix(&self) -> std::result::Result<#T::endpoint::EndpointPrefix, #T>", RuntimeType.smithyHttp(TestRuntimeConfig), @@ -145,10 +145,10 @@ internal class EndpointTraitBindingsTest { """ async fn test_endpoint_prefix() { let conf = $moduleName::Config::builder().build(); - $moduleName::operation::SayHello::builder() + $moduleName::input::SayHelloInput::builder() .greeting("hey there!").build().expect("input is valid") .make_operation(&conf).await.expect_err("no spaces or exclamation points in ep prefixes"); - let op = $moduleName::operation::SayHello::builder() + let op = $moduleName::input::SayHelloInput::builder() .greeting("hello") .build().expect("valid operation") .make_operation(&conf).await.expect("hello is a valid prefix"); diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/ServiceConfigGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/ServiceConfigGeneratorTest.kt index 366ff370dc0..2aaa3e21dd8 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/ServiceConfigGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/ServiceConfigGeneratorTest.kt @@ -8,8 +8,8 @@ package software.amazon.smithy.rust.codegen.client.smithy.generators.config import io.kotest.matchers.shouldBe import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule import software.amazon.smithy.rust.codegen.client.testutil.testSymbolProvider -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.writable @@ -94,7 +94,9 @@ internal class ServiceConfigGeneratorTest { } ServiceConfig.BuilderStruct -> writable { rust("config_field: Option") } ServiceConfig.BuilderImpl -> emptySection - ServiceConfig.BuilderBuild -> writable { rust("config_field: self.config_field.unwrap_or_default(),") } + ServiceConfig.BuilderBuild -> writable { + rust("config_field: self.config_field.unwrap_or_default(),") + } else -> emptySection } } @@ -102,7 +104,7 @@ internal class ServiceConfigGeneratorTest { val sut = ServiceConfigGenerator(listOf(ServiceCustomizer())) val symbolProvider = testSymbolProvider("namespace empty".asSmithyModel()) val project = TestWorkspace.testProject(symbolProvider) - project.withModule(RustModule.Config) { + project.withModule(ClientRustModule.Config) { sut.render(this) unitTest( "set_config_fields", diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/ErrorGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/ErrorGeneratorTest.kt new file mode 100644 index 00000000000..cca1dcb3d5d --- /dev/null +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/ErrorGeneratorTest.kt @@ -0,0 +1,60 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy.generators.error + +import org.junit.jupiter.api.Test +import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel + +class ErrorGeneratorTest { + val model = + """ + namespace com.test + use aws.protocols#awsJson1_1 + + @awsJson1_1 + service TestService { + operations: [TestOp] + } + + operation TestOp { + errors: [MyError] + } + + @error("server") + @retryable + structure MyError { + message: String + } + """.asSmithyModel() + + @Test + fun `generate error structure and builder`() { + clientIntegrationTest(model) { _, rustCrate -> + rustCrate.withFile("src/error.rs") { + rust( + """ + ##[test] + fn test_error_generator() { + use aws_smithy_types::error::metadata::{ErrorMetadata, ProvideErrorMetadata}; + use aws_smithy_types::retry::ErrorKind; + + let err = MyError::builder() + .meta(ErrorMetadata::builder().code("test").message("testmsg").build()) + .message("testmsg") + .build(); + assert_eq!(err.retryable_error_kind(), ErrorKind::ServerError); + assert_eq!("test", err.meta().code().unwrap()); + assert_eq!("testmsg", err.meta().message().unwrap()); + assert_eq!("testmsg", err.message().unwrap()); + } + """, + ) + } + } + } +} diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/OperationErrorGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/OperationErrorGeneratorTest.kt new file mode 100644 index 00000000000..b81de4010c0 --- /dev/null +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/OperationErrorGeneratorTest.kt @@ -0,0 +1,90 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy.generators.error + +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import software.amazon.smithy.rust.codegen.core.util.lookup + +class OperationErrorGeneratorTest { + private val model = """ + namespace error + + @aws.protocols#awsJson1_0 + service TestService { + operations: [Greeting], + } + + operation Greeting { + errors: [InvalidGreeting, ComplexError, FooException, Deprecated] + } + + @error("client") + @retryable + structure InvalidGreeting { + message: String, + } + + @error("server") + structure FooException { } + + @error("server") + structure ComplexError { + abc: String, + other: Integer + } + + @error("server") + @deprecated + structure Deprecated { } + """.asSmithyModel() + + @Test + fun `generates combined error enums`() { + clientIntegrationTest(model) { _, rustCrate -> + rustCrate.moduleFor(model.lookup("error#FooException")) { + unitTest( + name = "generates_combined_error_enums", + test = """ + let error = GreetingError::InvalidGreeting( + InvalidGreeting::builder() + .message("an error") + .meta(aws_smithy_types::Error::builder().code("InvalidGreeting").message("an error").build()) + .build() + ); + assert_eq!(format!("{}", error), "InvalidGreeting: an error"); + assert_eq!(error.meta().message(), Some("an error")); + assert_eq!(error.meta().code(), Some("InvalidGreeting")); + use aws_smithy_types::retry::ProvideErrorKind; + assert_eq!(error.retryable_error_kind(), Some(aws_smithy_types::retry::ErrorKind::ClientError)); + + // Generate is_xyz methods for errors. + assert_eq!(error.is_invalid_greeting(), true); + assert_eq!(error.is_complex_error(), false); + + // Unhandled variants properly delegate message. + let error = GreetingError::generic(aws_smithy_types::Error::builder().message("hello").build()); + assert_eq!(error.meta().message(), Some("hello")); + + let error = GreetingError::unhandled("some other error"); + assert_eq!(error.meta().message(), None); + assert_eq!(error.meta().code(), None); + + // Indicate the original name in the display output. + let error = FooError::builder().build(); + assert_eq!(format!("{}", error), "FooError [FooException]"); + + let error = Deprecated::builder().build(); + assert_eq!(error.to_string(), "Deprecated"); + """, + ) + } + } + } +} diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/ServiceErrorGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/ServiceErrorGeneratorTest.kt new file mode 100644 index 00000000000..1cbb274cfb6 --- /dev/null +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/error/ServiceErrorGeneratorTest.kt @@ -0,0 +1,63 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy.generators.error + +import org.junit.jupiter.api.Test +import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.integrationTest + +internal class ServiceErrorGeneratorTest { + @Test + fun `top level errors are send + sync`() { + val model = """ + namespace com.example + + use aws.protocols#restJson1 + + @restJson1 + service HelloService { + operations: [SayHello], + version: "1" + } + + @http(uri: "/", method: "POST") + operation SayHello { + input: EmptyStruct, + output: EmptyStruct, + errors: [SorryBusy, CanYouRepeatThat, MeDeprecated] + } + + structure EmptyStruct { } + + @error("server") + structure SorryBusy { } + + @error("client") + structure CanYouRepeatThat { } + + @error("client") + @deprecated + structure MeDeprecated { } + """.asSmithyModel() + + clientIntegrationTest(model) { codegenContext, rustCrate -> + rustCrate.integrationTest("validate_errors") { + rust( + """ + fn check_send_sync() {} + + ##[test] + fn service_errors_are_send_sync() { + check_send_sync::<${codegenContext.moduleUseName()}::Error>() + } + """, + ) + } + } + } +} diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/RequestBindingGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/RequestBindingGeneratorTest.kt index 6f214e35402..be42a9647c5 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/RequestBindingGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/RequestBindingGeneratorTest.kt @@ -11,13 +11,13 @@ import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.HttpTrait -import software.amazon.smithy.rust.codegen.client.testutil.testCodegenContext +import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule +import software.amazon.smithy.rust.codegen.client.testutil.testClientCodegenContext import software.amazon.smithy.rust.codegen.client.testutil.testSymbolProvider -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.generators.operationBuildError import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer @@ -127,45 +127,47 @@ class RequestBindingGeneratorTest { private val operationShape = model.expectShape(ShapeId.from("smithy.example#PutObject"), OperationShape::class.java) private val inputShape = model.expectShape(operationShape.input.get(), StructureShape::class.java) - private fun renderOperation(writer: RustWriter) { - inputShape.renderWithModelBuilder(model, symbolProvider, writer) - val codegenContext = testCodegenContext(model) - val bindingGen = RequestBindingGenerator( - codegenContext, - // Any protocol is fine for this test. - RestJson(codegenContext), - operationShape, - ) - writer.rustBlock("impl PutObjectInput") { - // RequestBindingGenerator's functions expect to be rendered inside a function, - // but the unit test needs to call some of these functions individually. This generates - // some wrappers that can be called directly from the tests. The functions will get duplicated, - // but that's not a problem. - - rustBlock( - "pub fn test_uri_query(&self, mut output: &mut String) -> Result<(), #T>", - TestRuntimeConfig.operationBuildError(), - ) { - bindingGen.renderUpdateHttpBuilder(this) - rust("uri_query(self, output)") - } - - rustBlock( - "pub fn test_uri_base(&self, mut output: &mut String) -> Result<(), #T>", - TestRuntimeConfig.operationBuildError(), - ) { - bindingGen.renderUpdateHttpBuilder(this) - rust("uri_base(self, output)") - } - - rustBlock( - "pub fn test_request_builder_base(&self) -> Result<#T, #T>", - RuntimeType.HttpRequestBuilder, - TestRuntimeConfig.operationBuildError(), - ) { - bindingGen.renderUpdateHttpBuilder(this) - rust("let builder = #T::new();", RuntimeType.HttpRequestBuilder) - rust("update_http_builder(self, builder)") + private fun renderOperation(rustCrate: RustCrate) { + inputShape.renderWithModelBuilder(model, symbolProvider, rustCrate) + rustCrate.withModule(ClientRustModule.Input) { + val codegenContext = testClientCodegenContext(model) + val bindingGen = RequestBindingGenerator( + codegenContext, + // Any protocol is fine for this test. + RestJson(codegenContext), + operationShape, + ) + rustBlock("impl PutObjectInput") { + // RequestBindingGenerator's functions expect to be rendered inside a function, + // but the unit test needs to call some of these functions individually. This generates + // some wrappers that can be called directly from the tests. The functions will get duplicated, + // but that's not a problem. + + rustBlock( + "pub fn test_uri_query(&self, mut output: &mut String) -> Result<(), #T>", + TestRuntimeConfig.operationBuildError(), + ) { + bindingGen.renderUpdateHttpBuilder(this) + rust("uri_query(self, output)") + } + + rustBlock( + "pub fn test_uri_base(&self, mut output: &mut String) -> Result<(), #T>", + TestRuntimeConfig.operationBuildError(), + ) { + bindingGen.renderUpdateHttpBuilder(this) + rust("uri_base(self, output)") + } + + rustBlock( + "pub fn test_request_builder_base(&self) -> Result<#T, #T>", + RuntimeType.HttpRequestBuilder, + TestRuntimeConfig.operationBuildError(), + ) { + bindingGen.renderUpdateHttpBuilder(this) + rust("let builder = #T::new();", RuntimeType.HttpRequestBuilder) + rust("update_http_builder(self, builder)") + } } } } @@ -179,9 +181,8 @@ class RequestBindingGeneratorTest { @Test fun `generates valid request bindings`() { val project = TestWorkspace.testProject(symbolProvider) - project.withModule(RustModule.public("input")) { // Currently rendering the operation renders the protocols—I want to separate that at some point. - renderOperation(this) - + renderOperation(project) + project.withModule(ClientRustModule.Input) { // Currently rendering the operation renders the protocols—I want to separate that at some point. unitTest( name = "generate_uris", test = """ diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/ResponseBindingGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/ResponseBindingGeneratorTest.kt index 19ab23102fd..e6c2d1ae2a4 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/ResponseBindingGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/http/ResponseBindingGeneratorTest.kt @@ -8,12 +8,11 @@ package software.amazon.smithy.rust.codegen.client.smithy.generators.http import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ShapeId -import software.amazon.smithy.rust.codegen.client.testutil.testCodegenContext -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule +import software.amazon.smithy.rust.codegen.client.testutil.testClientCodegenContext import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock -import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpTraitHttpBindingResolver import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolContentTypes @@ -67,23 +66,25 @@ class ResponseBindingGeneratorTest { """.asSmithyModel() private val model = OperationNormalizer.transform(baseModel) private val operationShape = model.expectShape(ShapeId.from("smithy.example#PutObject"), OperationShape::class.java) - private val codegenContext: CodegenContext = testCodegenContext(model) + private val codegenContext = testClientCodegenContext(model) private val symbolProvider = codegenContext.symbolProvider - private fun RustWriter.renderOperation() { + private fun RustCrate.renderOperation() { operationShape.outputShape(model).renderWithModelBuilder(model, symbolProvider, this) - rustBlock("impl PutObjectOutput") { - val bindings = HttpTraitHttpBindingResolver(model, ProtocolContentTypes.consistent("dont-care")) - .responseBindings(operationShape) - .filter { it.location == HttpLocation.HEADER } - bindings.forEach { binding -> - val runtimeType = ResponseBindingGenerator( - RestJson(codegenContext), - codegenContext, - operationShape, - ).generateDeserializeHeaderFn(binding) - // little hack to force these functions to be generated - rust("// use #T;", runtimeType) + withModule(ClientRustModule.Output) { + rustBlock("impl PutObjectOutput") { + val bindings = HttpTraitHttpBindingResolver(model, ProtocolContentTypes.consistent("dont-care")) + .responseBindings(operationShape) + .filter { it.location == HttpLocation.HEADER } + bindings.forEach { binding -> + val runtimeType = ResponseBindingGenerator( + RestJson(codegenContext), + codegenContext, + operationShape, + ).generateDeserializeHeaderFn(binding) + // little hack to force these functions to be generated + rust("// use #T;", runtimeType) + } } } } @@ -91,8 +92,8 @@ class ResponseBindingGeneratorTest { @Test fun deserializeHeadersIntoOutputShape() { val testProject = TestWorkspace.testProject(symbolProvider) - testProject.withModule(RustModule.public("output")) { - renderOperation() + testProject.renderOperation() + testProject.withModule(ClientRustModule.Output) { unitTest( "http_header_deser", """ diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGeneratorTest.kt index 1da709422d0..c6f427d318e 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/protocol/ProtocolTestGeneratorTest.kt @@ -21,7 +21,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization -import software.amazon.smithy.rust.codegen.core.smithy.generators.error.errorSymbol import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTraitImplGenerator @@ -62,10 +61,11 @@ private class TestProtocolTraitImplGenerator( fn parse(&self, _response: &#{Response}<#{Bytes}>) -> Self::Output { ${operationWriter.escape(correctResponse)} } - }""", + } + """, "parse_strict" to RuntimeType.parseStrictResponse(codegenContext.runtimeConfig), "Output" to symbolProvider.toSymbol(operationShape.outputShape(codegenContext.model)), - "Error" to operationShape.errorSymbol(symbolProvider), + "Error" to symbolProvider.symbolForOperationError(operationShape), "Response" to RuntimeType.HttpResponse, "Bytes" to RuntimeType.Bytes, ) @@ -92,7 +92,7 @@ private class TestProtocolMakeOperationGenerator( // A stubbed test protocol to do enable testing intentionally broken protocols private class TestProtocolGenerator( - codegenContext: CodegenContext, + codegenContext: ClientCodegenContext, protocol: Protocol, httpRequestBuilder: String, body: String, diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/AwsQueryCompatibleTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/AwsQueryCompatibleTest.kt new file mode 100644 index 00000000000..df33d83151d --- /dev/null +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/AwsQueryCompatibleTest.kt @@ -0,0 +1,152 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy.protocols + +import org.junit.jupiter.api.Test +import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.integrationTest + +class AwsQueryCompatibleTest { + @Test + fun `aws-query-compatible json with aws query error should allow for retrieving error code and type from custom header`() { + val model = """ + namespace test + use aws.protocols#awsJson1_0 + use aws.protocols#awsQueryCompatible + use aws.protocols#awsQueryError + + @awsQueryCompatible + @awsJson1_0 + service TestService { + version: "2023-02-20", + operations: [SomeOperation] + } + + operation SomeOperation { + input: SomeOperationInputOutput, + output: SomeOperationInputOutput, + errors: [InvalidThingException], + } + + structure SomeOperationInputOutput { + a: String, + b: Integer + } + + @awsQueryError( + code: "InvalidThing", + httpResponseCode: 400, + ) + @error("client") + structure InvalidThingException { + message: String + } + """.asSmithyModel() + + clientIntegrationTest(model) { clientCodegenContext, rustCrate -> + val moduleName = clientCodegenContext.moduleUseName() + rustCrate.integrationTest("should_parse_code_and_type_fields") { + rust( + """ + ##[test] + fn should_parse_code_and_type_fields() { + use aws_smithy_http::response::ParseStrictResponse; + + let response = http::Response::builder() + .header( + "x-amzn-query-error", + http::HeaderValue::from_static("AWS.SimpleQueueService.NonExistentQueue;Sender"), + ) + .status(400) + .body( + r##"{ + "__type": "com.amazonaws.sqs##QueueDoesNotExist", + "message": "Some user-visible message" + }"##, + ) + .unwrap(); + let some_operation = $moduleName::operation::SomeOperation::new(); + let error = some_operation + .parse(&response.map(bytes::Bytes::from)) + .err() + .unwrap(); + assert_eq!( + Some("AWS.SimpleQueueService.NonExistentQueue"), + error.meta().code(), + ); + assert_eq!(Some("Sender"), error.meta().extra("type")); + } + """, + ) + } + } + } + + @Test + fun `aws-query-compatible json without aws query error should allow for retrieving error code from payload`() { + val model = """ + namespace test + use aws.protocols#awsJson1_0 + use aws.protocols#awsQueryCompatible + + @awsQueryCompatible + @awsJson1_0 + service TestService { + version: "2023-02-20", + operations: [SomeOperation] + } + + operation SomeOperation { + input: SomeOperationInputOutput, + output: SomeOperationInputOutput, + errors: [InvalidThingException], + } + + structure SomeOperationInputOutput { + a: String, + b: Integer + } + + @error("client") + structure InvalidThingException { + message: String + } + """.asSmithyModel() + + clientIntegrationTest(model) { clientCodegenContext, rustCrate -> + val moduleName = clientCodegenContext.moduleUseName() + rustCrate.integrationTest("should_parse_code_from_payload") { + rust( + """ + ##[test] + fn should_parse_code_from_payload() { + use aws_smithy_http::response::ParseStrictResponse; + + let response = http::Response::builder() + .status(400) + .body( + r##"{ + "__type": "com.amazonaws.sqs##QueueDoesNotExist", + "message": "Some user-visible message" + }"##, + ) + .unwrap(); + let some_operation = $moduleName::operation::SomeOperation::new(); + let error = some_operation + .parse(&response.map(bytes::Bytes::from)) + .err() + .unwrap(); + assert_eq!(Some("QueueDoesNotExist"), error.meta().code()); + assert_eq!(None, error.meta().extra("type")); + } + """, + ) + } + } + } +} diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamBaseRequirements.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamBaseRequirements.kt deleted file mode 100644 index 1717dab2d67..00000000000 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamBaseRequirements.kt +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.rust.codegen.client.smithy.protocols.eventstream - -import org.junit.jupiter.api.extension.ExtensionContext -import org.junit.jupiter.params.provider.Arguments -import org.junit.jupiter.params.provider.ArgumentsProvider -import software.amazon.smithy.codegen.core.Symbol -import software.amazon.smithy.model.Model -import software.amazon.smithy.model.shapes.ServiceShape -import software.amazon.smithy.model.shapes.ShapeId -import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext -import software.amazon.smithy.rust.codegen.client.smithy.customize.CombinedClientCodegenDecorator -import software.amazon.smithy.rust.codegen.client.testutil.clientTestRustSettings -import software.amazon.smithy.rust.codegen.client.testutil.testSymbolProvider -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget -import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider -import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.error.OperationErrorGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock -import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels -import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestRequirements -import java.util.stream.Stream - -class TestCasesProvider : ArgumentsProvider { - override fun provideArguments(context: ExtensionContext?): Stream = - EventStreamTestModels.TEST_CASES.map { Arguments.of(it) }.stream() -} - -abstract class ClientEventStreamBaseRequirements : EventStreamTestRequirements { - override fun createCodegenContext( - model: Model, - serviceShape: ServiceShape, - protocolShapeId: ShapeId, - codegenTarget: CodegenTarget, - ): ClientCodegenContext = ClientCodegenContext( - model, - testSymbolProvider(model), - serviceShape, - protocolShapeId, - clientTestRustSettings(), - CombinedClientCodegenDecorator(emptyList()), - ) - - override fun renderBuilderForShape( - writer: RustWriter, - codegenContext: ClientCodegenContext, - shape: StructureShape, - ) { - BuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape).apply { - render(writer) - writer.implBlock(shape, codegenContext.symbolProvider) { - renderConvenienceMethod(writer) - } - } - } - - override fun renderOperationError( - writer: RustWriter, - model: Model, - symbolProvider: RustSymbolProvider, - operationSymbol: Symbol, - errors: List, - ) { - OperationErrorGenerator(model, symbolProvider, operationSymbol, errors).render(writer) - } -} diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamMarshallerGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamMarshallerGeneratorTest.kt index 936d3b63243..349f6a8cf31 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamMarshallerGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamMarshallerGeneratorTest.kt @@ -5,42 +5,30 @@ package software.amazon.smithy.rust.codegen.client.smithy.protocols.eventstream +import org.junit.jupiter.api.extension.ExtensionContext import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.ArgumentsProvider import org.junit.jupiter.params.provider.ArgumentsSource -import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol -import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.EventStreamMarshallerGenerator +import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest +import software.amazon.smithy.rust.codegen.core.testutil.EventStreamMarshallTestCases.writeMarshallTestCases import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels -import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestTools -import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestVariety -import software.amazon.smithy.rust.codegen.core.testutil.TestEventStreamProject -import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig +import software.amazon.smithy.rust.codegen.core.testutil.testModule +import java.util.stream.Stream class ClientEventStreamMarshallerGeneratorTest { @ParameterizedTest @ArgumentsSource(TestCasesProvider::class) fun test(testCase: EventStreamTestModels.TestCase) { - EventStreamTestTools.runTestCase( - testCase, - object : ClientEventStreamBaseRequirements() { - override fun renderGenerator( - codegenContext: ClientCodegenContext, - project: TestEventStreamProject, - protocol: Protocol, - ): RuntimeType = EventStreamMarshallerGenerator( - project.model, - CodegenTarget.CLIENT, - TestRuntimeConfig, - project.symbolProvider, - project.streamShape, - protocol.structuredDataSerializer(project.operationShape), - testCase.requestContentType, - ).render() - }, - CodegenTarget.CLIENT, - EventStreamTestVariety.Marshall, - ) + clientIntegrationTest(testCase.model) { _, rustCrate -> + rustCrate.testModule { + writeMarshallTestCases(testCase, optionalBuilderInputs = false) + } + } } } + +class TestCasesProvider : ArgumentsProvider { + override fun provideArguments(context: ExtensionContext?): Stream = + EventStreamTestModels.TEST_CASES.map { Arguments.of(it) }.stream() +} diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamUnmarshallerGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamUnmarshallerGeneratorTest.kt index f9be7b3bf4c..db44d62a61c 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamUnmarshallerGeneratorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/eventstream/ClientEventStreamUnmarshallerGeneratorTest.kt @@ -7,43 +7,60 @@ package software.amazon.smithy.rust.codegen.client.smithy.protocols.eventstream import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.ArgumentsSource -import software.amazon.smithy.codegen.core.Symbol -import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol -import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol -import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.EventStreamUnmarshallerGenerator +import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest +import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels -import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestTools -import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestVariety -import software.amazon.smithy.rust.codegen.core.testutil.TestEventStreamProject +import software.amazon.smithy.rust.codegen.core.testutil.EventStreamUnmarshallTestCases.writeUnmarshallTestCases +import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams +import software.amazon.smithy.rust.codegen.core.testutil.testModule +import software.amazon.smithy.rust.codegen.core.testutil.unitTest class ClientEventStreamUnmarshallerGeneratorTest { @ParameterizedTest @ArgumentsSource(TestCasesProvider::class) fun test(testCase: EventStreamTestModels.TestCase) { - EventStreamTestTools.runTestCase( - testCase, - object : ClientEventStreamBaseRequirements() { - override fun renderGenerator( - codegenContext: ClientCodegenContext, - project: TestEventStreamProject, - protocol: Protocol, - ): RuntimeType { - fun builderSymbol(shape: StructureShape): Symbol = shape.builderSymbol(codegenContext.symbolProvider) - return EventStreamUnmarshallerGenerator( - protocol, - codegenContext, - project.operationShape, - project.streamShape, - ::builderSymbol, - ).render() - } - }, - CodegenTarget.CLIENT, - EventStreamTestVariety.Unmarshall, - ) + clientIntegrationTest( + testCase.model, + IntegrationTestParams(service = "test#TestService", addModuleToEventStreamAllowList = true), + ) { _, rustCrate -> + val generator = "crate::event_stream_serde::TestStreamUnmarshaller" + + rustCrate.testModule { + rust("##![allow(unused_imports, dead_code)]") + writeUnmarshallTestCases(testCase, optionalBuilderInputs = false) + + unitTest( + "unknown_message", + """ + let message = msg("event", "NewUnmodeledMessageType", "application/octet-stream", b"hello, world!"); + let result = $generator::new().unmarshall(&message); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + assert!(expect_event(result.unwrap()).is_unknown()); + """, + ) + + unitTest( + "generic_error", + """ + let message = msg( + "exception", + "UnmodeledError", + "${testCase.responseContentType}", + br#"${testCase.validUnmodeledError}"# + ); + let result = $generator::new().unmarshall(&message); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + match expect_error(result.unwrap()) { + TestStreamError::Unhandled(err) => { + let message = format!("{}", aws_smithy_types::error::display::DisplayErrorContext(&err)); + let expected = "message: \"unmodeled error\""; + assert!(message.contains(expected), "Expected '{message}' to contain '{expected}'"); + } + kind => panic!("expected generic error, but got {:?}", kind), + } + """, + ) + } + } } } diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/transformers/RemoveEventStreamOperationsTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/transformers/RemoveEventStreamOperationsTest.kt index 7873a74c2e1..4cc9c594aaa 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/transformers/RemoveEventStreamOperationsTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/transformers/RemoveEventStreamOperationsTest.kt @@ -11,7 +11,7 @@ import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenConfig -import software.amazon.smithy.rust.codegen.client.testutil.clientTestRustSettings +import software.amazon.smithy.rust.codegen.client.testutil.testClientRustSettings import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import java.util.Optional @@ -49,7 +49,7 @@ internal class RemoveEventStreamOperationsTest { fun `remove event stream ops from services that are not in the allow list`() { val transformed = RemoveEventStreamOperations.transform( model, - clientTestRustSettings( + testClientRustSettings( codegenConfig = ClientCodegenConfig(eventStreamAllowList = setOf("not-test-module")), ), ) @@ -61,7 +61,7 @@ internal class RemoveEventStreamOperationsTest { fun `keep event stream ops from services that are in the allow list`() { val transformed = RemoveEventStreamOperations.transform( model, - clientTestRustSettings( + testClientRustSettings( codegenConfig = ClientCodegenConfig(eventStreamAllowList = setOf("test-module")), ), ) diff --git a/codegen-core/common-test-models/naming-obstacle-course-casing.smithy b/codegen-core/common-test-models/naming-obstacle-course-casing.smithy new file mode 100644 index 00000000000..fb80a46d48b --- /dev/null +++ b/codegen-core/common-test-models/naming-obstacle-course-casing.smithy @@ -0,0 +1,63 @@ +$version: "1.0" +namespace casing + +use aws.protocols#awsJson1_1 + +// TODO(https://github.com/awslabs/smithy-rs/issues/2340): The commented part of the model breaks the generator in a +// miriad of ways. Any solution to the linked issue must address this. + +/// Confounds model generation machinery with lots of problematic casing +@awsJson1_1 +service ACRONYMInside_Service { + operations: [ + DoNothing, + // ACRONYMInside_Op + // ACRONYM_InsideOp + ] +} + +operation DoNothing {} + +// operation ACRONYMInside_Op { +// input: Input, +// output: Output, +// errors: [Error], +// } + +// operation ACRONYM_InsideOp { +// input: Input, +// output: Output, +// errors: [Error], +// } + +// structure Input { +// ACRONYMInside_Member: ACRONYMInside_Struct, +// ACRONYM_Inside_Member: ACRONYM_InsideStruct, +// ACRONYM_InsideMember: ACRONYMInsideStruct +// } + +// structure Output { +// ACRONYMInside_Member: ACRONYMInside_Struct, +// ACRONYM_Inside_Member: ACRONYM_InsideStruct, +// ACRONYM_InsideMember: ACRONYMInsideStruct +// } + +// @error("client") +// structure Error { +// ACRONYMInside_Member: ACRONYMInside_Struct, +// ACRONYM_Inside_Member: ACRONYM_InsideStruct, +// ACRONYM_InsideMember: ACRONYMInsideStruct +// } + +// structure ACRONYMInside_Struct { +// ACRONYMInside_Member: ACRONYM_InsideStruct, +// ACRONYM_Inside_Member: Integer, +// } + +// structure ACRONYM_InsideStruct { +// ACRONYMInside_Member: Integer, +// } + +// structure ACRONYMInsideStruct { +// ACRONYMInside_Member: Integer, +// } diff --git a/codegen-core/common-test-models/simple.smithy b/codegen-core/common-test-models/simple.smithy index 43c4bc6aca0..c7e58c8e4a8 100644 --- a/codegen-core/common-test-models/simple.smithy +++ b/codegen-core/common-test-models/simple.smithy @@ -1,136 +1,22 @@ -$version: "1.0" +$version: "2.0" namespace com.amazonaws.simple use aws.protocols#restJson1 -use smithy.test#httpRequestTests -use smithy.test#httpResponseTests -use smithy.framework#ValidationException @restJson1 -@title("SimpleService") -@documentation("A simple service example, with a Service resource that can be registered and a readonly healthcheck") service SimpleService { - version: "2022-01-01", - resources: [ - Service, - ], operations: [ - Healthcheck, - StoreServiceBlob, - ], + Operation + ] } -@documentation("Id of the service that will be registered") -string ServiceId - -@documentation("Name of the service that will be registered") -string ServiceName - -@error("client") -@documentation( - """ - Returned when a new resource cannot be created because one already exists. - """ -) -structure ResourceAlreadyExists { - @required - message: String -} - -@documentation("A resource that can register services") -resource Service { - identifiers: { id: ServiceId }, - put: RegisterService, -} - -@idempotent -@http(method: "PUT", uri: "/service/{id}") -@documentation("Service register operation") -@httpRequestTests([ - { - id: "RegisterServiceRequestTest", - protocol: "aws.protocols#restJson1", - uri: "/service/1", - headers: { - "Content-Type": "application/json", - }, - params: { id: "1", name: "TestService" }, - body: "{\"name\":\"TestService\"}", - method: "PUT", - } -]) -@httpResponseTests([ - { - id: "RegisterServiceResponseTest", - protocol: "aws.protocols#restJson1", - params: { id: "1", name: "TestService" }, - body: "{\"id\":\"1\",\"name\":\"TestService\"}", - code: 200, - headers: { - "Content-Length": "31" - } - } -]) -operation RegisterService { - input: RegisterServiceInputRequest, - output: RegisterServiceOutputResponse, - errors: [ResourceAlreadyExists, ValidationException] -} - -@documentation("Service register input structure") -structure RegisterServiceInputRequest { - @required - @httpLabel - id: ServiceId, - name: ServiceName, -} - -@documentation("Service register output structure") -structure RegisterServiceOutputResponse { - @required - id: ServiceId, - name: ServiceName, -} - -@readonly -@http(uri: "/healthcheck", method: "GET") -@documentation("Read-only healthcheck operation") -operation Healthcheck { - input: HealthcheckInputRequest, - output: HealthcheckOutputResponse +@http(uri: "/operation", method: "POST") +operation Operation { + input: OperationInputOutput + output: OperationInputOutput } -@documentation("Service healthcheck output structure") -structure HealthcheckInputRequest { - -} - -@documentation("Service healthcheck input structure") -structure HealthcheckOutputResponse { - -} - -@readonly -@http(method: "POST", uri: "/service/{id}/blob") -@documentation("Stores a blob for a service id") -operation StoreServiceBlob { - input: StoreServiceBlobInput, - output: StoreServiceBlobOutput, - errors: [ValidationException] -} - -@documentation("Store a blob for a service id input structure") -structure StoreServiceBlobInput { - @required - @httpLabel - id: ServiceId, - @required - @httpPayload - content: Blob, -} - -@documentation("Store a blob for a service id output structure") -structure StoreServiceBlobOutput { - +structure OperationInputOutput { + message: String } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt index 6d355da68a9..167fd2997fa 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/CargoDependency.kt @@ -95,6 +95,13 @@ class InlineDependency( CargoDependency.Http, ) + fun awsQueryCompatibleErrors(runtimeConfig: RuntimeConfig) = + forInlineableRustFile( + "aws_query_compatible_errors", + CargoDependency.smithyJson(runtimeConfig), + CargoDependency.Http, + ) + fun idempotencyToken() = forInlineableRustFile("idempotency_token", CargoDependency.FastRand) @@ -133,6 +140,8 @@ data class CargoDependency( return copy(features = features.toMutableSet().apply { add(feature) }) } + fun toDevDependency() = copy(scope = DependencyScope.Dev) + override fun version(): String = when (location) { is CratesIo -> location.version is Local -> "local" @@ -220,7 +229,12 @@ data class CargoDependency( val Smol: CargoDependency = CargoDependency("smol", CratesIo("1.2.0"), DependencyScope.Dev) val TempFile: CargoDependency = CargoDependency("tempfile", CratesIo("3.2.0"), DependencyScope.Dev) val Tokio: CargoDependency = - CargoDependency("tokio", CratesIo("1.8.4"), DependencyScope.Dev, features = setOf("macros", "test-util", "rt-multi-thread")) + CargoDependency( + "tokio", + CratesIo("1.8.4"), + DependencyScope.Dev, + features = setOf("macros", "test-util", "rt-multi-thread"), + ) val TracingAppender: CargoDependency = CargoDependency( "tracing-appender", CratesIo("0.2.2"), @@ -236,12 +250,17 @@ data class CargoDependency( fun smithyAsync(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-async") fun smithyChecksums(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-checksums") fun smithyClient(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-client") + fun smithyClientTestUtil(runtimeConfig: RuntimeConfig) = + smithyClient(runtimeConfig).toDevDependency().withFeature("test-util") + fun smithyEventStream(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-eventstream") fun smithyHttp(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-http") + fun smithyHttpAuth(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-http-auth") fun smithyHttpTower(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-http-tower") fun smithyJson(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-json") fun smithyProtocolTestHelpers(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-protocol-test", scope = DependencyScope.Dev) + fun smithyQuery(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-query") fun smithyTypes(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-types") fun smithyXml(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-xml") diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustModule.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustModule.kt index 6745e3b2ee7..ef2786a396d 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustModule.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustModule.kt @@ -32,7 +32,10 @@ sealed class RustModule { val documentation: String? = null, val parent: RustModule = LibRs, val inline: Boolean = false, + /* module is a cfg(test) module */ + val tests: Boolean = false, ) : RustModule() { + init { check(!name.contains("::")) { "Module names CANNOT contain `::`—modules must be nested with parent (name was: `$name`)" @@ -45,6 +48,12 @@ sealed class RustModule { "Module `$name` cannot be a module name—it is a reserved word." } } + + /** Convert a module into a module gated with `#[cfg(test)]` */ + fun cfgTest(): LeafModule = this.copy( + rustMetadata = rustMetadata.copy(additionalAttributes = rustMetadata.additionalAttributes + Attribute.CfgTest), + tests = true, + ) } companion object { @@ -78,13 +87,17 @@ sealed class RustModule { fun pubCrate(name: String, documentation: String? = null, parent: RustModule): LeafModule = new(name, visibility = Visibility.PUBCRATE, documentation = documentation, inline = false, parent = parent) - /* Common modules used across client, server and tests */ - val Config = public("config", documentation = "Configuration for the service.") - val Error = public("error", documentation = "All error types that operations can return. Documentation on these types is copied from the model.") - val Model = public("model", documentation = "Data structures used by operation inputs/outputs. Documentation on these types is copied from the model.") - val Input = public("input", documentation = "Input structures for operations. Documentation on these types is copied from the model.") - val Output = public("output", documentation = "Output structures for operations. Documentation on these types is copied from the model.") - val Types = public("types", documentation = "Data primitives referenced by other data types.") + fun inlineTests( + name: String = "test", + parent: RustModule = LibRs, + additionalAttributes: List = listOf(), + ) = new( + name, + Visibility.PRIVATE, + inline = true, + additionalAttributes = additionalAttributes, + parent = parent, + ).cfgTest() /** * Helper method to generate the `operation` Rust module. diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWords.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWords.kt index efe9ae7cc88..3879167383f 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWords.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWords.kt @@ -8,30 +8,32 @@ package software.amazon.smithy.rust.codegen.core.rustlang import software.amazon.smithy.codegen.core.ReservedWordSymbolProvider import software.amazon.smithy.codegen.core.ReservedWords import software.amazon.smithy.codegen.core.Symbol -import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.EnumShape import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape -import software.amazon.smithy.model.traits.EnumDefinition -import software.amazon.smithy.rust.codegen.core.smithy.MaybeRenamed +import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.renamedFrom +import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.letIf -import software.amazon.smithy.rust.codegen.core.util.orNull -import software.amazon.smithy.rust.codegen.core.util.toPascalCase -class RustReservedWordSymbolProvider(private val base: RustSymbolProvider, private val model: Model) : - WrappingSymbolProvider(base) { +class RustReservedWordSymbolProvider(private val base: RustSymbolProvider) : WrappingSymbolProvider(base) { private val internal = - ReservedWordSymbolProvider.builder().symbolProvider(base).memberReservedWords(RustReservedWords).build() + ReservedWordSymbolProvider.builder().symbolProvider(base) + .nameReservedWords(RustReservedWords) + .memberReservedWords(RustReservedWords) + .build() override fun toMemberName(shape: MemberShape): String { - val baseName = internal.toMemberName(shape) - return when (val container = model.expectShape(shape.container)) { - is StructureShape -> when (baseName) { + val baseName = super.toMemberName(shape) + val reservedWordReplacedName = internal.toMemberName(shape) + val container = model.expectShape(shape.container) + return when { + container is StructureShape -> when (baseName) { "build" -> "build_value" "builder" -> "builder_value" "default" -> "default_value" @@ -40,20 +42,25 @@ class RustReservedWordSymbolProvider(private val base: RustSymbolProvider, priva "make_operation" -> "make_operation_value" "presigned" -> "presigned_value" "customize" -> "customize_value" - else -> baseName + // To avoid conflicts with the error metadata `meta` field + "meta" -> "meta_value" + else -> reservedWordReplacedName } - is UnionShape -> when (baseName) { + container is UnionShape -> when (baseName) { // Unions contain an `Unknown` variant. This exists to support parsing data returned from the server // that represent union variants that have been added since this SDK was generated. UnionGenerator.UnknownVariantName -> "${UnionGenerator.UnknownVariantName}Value" "${UnionGenerator.UnknownVariantName}Value" -> "${UnionGenerator.UnknownVariantName}Value_" - // Self cannot be used as a raw identifier, so we can't use the normal escaping strategy - // https://internals.rust-lang.org/t/raw-identifiers-dont-work-for-all-identifiers/9094/4 - "Self" -> "SelfValue" + else -> reservedWordReplacedName + } + + container is EnumShape || container.hasTrait() -> when (baseName) { + // Unknown is used as the name of the variant containing unexpected values + "Unknown" -> "UnknownValue" // Real models won't end in `_` so it's safe to stop here - "SelfValue" -> "SelfValue_" - else -> baseName + "UnknownValue" -> "UnknownValue_" + else -> reservedWordReplacedName } else -> error("unexpected container: $container") @@ -67,46 +74,36 @@ class RustReservedWordSymbolProvider(private val base: RustSymbolProvider, priva * code generators to generate special docs. */ override fun toSymbol(shape: Shape): Symbol { + // Sanity check that the symbol provider stack is set up correctly + check(super.toSymbol(shape).renamedFrom() == null) { + "RustReservedWordSymbolProvider should only run once" + } + + var renamedSymbol = internal.toSymbol(shape) return when (shape) { is MemberShape -> { val container = model.expectShape(shape.container) - if (!(container is StructureShape || container is UnionShape)) { + val containerIsEnum = container is EnumShape || container.hasTrait() + if (container !is StructureShape && container !is UnionShape && !containerIsEnum) { return base.toSymbol(shape) } val previousName = base.toMemberName(shape) val escapedName = this.toMemberName(shape) - val baseSymbol = base.toSymbol(shape) // if the names don't match and it isn't a simple escaping with `r#`, record a rename - baseSymbol.letIf(escapedName != previousName && !escapedName.contains("r#")) { - it.toBuilder().renamedFrom(previousName).build() - } + renamedSymbol.toBuilder().name(escapedName) + .letIf(escapedName != previousName && !escapedName.contains("r#")) { + it.renamedFrom(previousName) + }.build() } - else -> base.toSymbol(shape) + else -> renamedSymbol } } +} - override fun toEnumVariantName(definition: EnumDefinition): MaybeRenamed? { - val baseName = base.toEnumVariantName(definition) ?: return null - check(definition.name.orNull()?.toPascalCase() == baseName.name) { - "Enum variants must already be in pascal case ${baseName.name} differed from ${baseName.name.toPascalCase()}. Definition: ${definition.name}" - } - check(baseName.renamedFrom == null) { - "definitions should only pass through the renamer once" - } - return when (baseName.name) { - // Self cannot be used as a raw identifier, so we can't use the normal escaping strategy - // https://internals.rust-lang.org/t/raw-identifiers-dont-work-for-all-identifiers/9094/4 - "Self" -> MaybeRenamed("SelfValue", "Self") - // Real models won't end in `_` so it's safe to stop here - "SelfValue" -> MaybeRenamed("SelfValue_", "SelfValue") - // Unknown is used as the name of the variant containing unexpected values - "Unknown" -> MaybeRenamed("UnknownValue", "Unknown") - // Real models won't end in `_` so it's safe to stop here - "UnknownValue" -> MaybeRenamed("UnknownValue_", "UnknownValue") - else -> baseName - } - } +enum class EscapeFor { + TypeName, + ModuleName, } object RustReservedWords : ReservedWords { @@ -166,17 +163,33 @@ object RustReservedWords : ReservedWords { "try", ) - private val cantBeRaw = setOf("self", "crate", "super") + // Some things can't be used as a raw identifier, so we can't use the normal escaping strategy + // https://internals.rust-lang.org/t/raw-identifiers-dont-work-for-all-identifiers/9094/4 + private val keywordEscapingMap = mapOf( + "crate" to "crate_", + "super" to "super_", + "self" to "self_", + "Self" to "SelfValue", + // Real models won't end in `_` so it's safe to stop here + "SelfValue" to "SelfValue_", + ) - override fun escape(word: String): String = when { - cantBeRaw.contains(word) -> "${word}_" - else -> "r##$word" - } + override fun escape(word: String): String = doEscape(word, EscapeFor.TypeName) - fun escapeIfNeeded(word: String): String = when (isReserved(word)) { - true -> escape(word) - else -> word - } + private fun doEscape(word: String, escapeFor: EscapeFor = EscapeFor.TypeName): String = + when (val mapped = keywordEscapingMap[word]) { + null -> when (escapeFor) { + EscapeFor.TypeName -> "r##$word" + EscapeFor.ModuleName -> "${word}_" + } + else -> mapped + } + + fun escapeIfNeeded(word: String, escapeFor: EscapeFor = EscapeFor.TypeName): String = + when (isReserved(word)) { + true -> doEscape(word, escapeFor) + else -> word + } - override fun isReserved(word: String): Boolean = RustKeywords.contains(word) + override fun isReserved(word: String): Boolean = RustKeywords.contains(word) || keywordEscapingMap.contains(word) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt index f4fbfd5b70a..95036079802 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustType.kt @@ -222,7 +222,9 @@ fun RustType.asArgument(name: String) = Argument( fun RustType.render(fullyQualified: Boolean = true): String { val namespace = if (fullyQualified) { this.namespace?.let { "$it::" } ?: "" - } else "" + } else { + "" + } val base = when (this) { is RustType.Unit -> this.name is RustType.Bool -> this.name @@ -325,6 +327,16 @@ fun RustType.isCopy(): Boolean = when (this) { else -> false } +/** Returns true if the type implements Eq */ +fun RustType.isEq(): Boolean = when (this) { + is RustType.Integer -> true + is RustType.Bool -> true + is RustType.String -> true + is RustType.Unit -> true + is RustType.Container -> this.member.isEq() + else -> false +} + enum class Visibility { PRIVATE, PUBCRATE, PUBLIC; @@ -416,7 +428,7 @@ enum class AttributeKind { /** * Outer attributes, written without the bang after the hash, apply to the thing that follows the attribute. */ - Outer + Outer, } /** @@ -458,9 +470,11 @@ class Attribute(val inner: Writable) { val AllowDeprecated = Attribute(allow("deprecated")) val AllowIrrefutableLetPatterns = Attribute(allow("irrefutable_let_patterns")) val AllowUnreachableCode = Attribute(allow("unreachable_code")) + val AllowUnreachablePatterns = Attribute(allow("unreachable_patterns")) val AllowUnusedImports = Attribute(allow("unused_imports")) val AllowUnusedMut = Attribute(allow("unused_mut")) val AllowUnusedVariables = Attribute(allow("unused_variables")) + val AllowMissingDocs = Attribute(allow("missing_docs")) val CfgTest = Attribute(cfg("test")) val DenyMissingDocs = Attribute(deny("missing_docs")) val DocHidden = Attribute(doc("hidden")) @@ -546,3 +560,10 @@ class Attribute(val inner: Writable) { } } } + +/** Render all attributes in this list, one after another */ +fun Collection.render(writer: RustWriter) { + for (attr in this) { + attr.render(writer) + } +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriter.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriter.kt index 2b4a17c193b..1628fe9cca3 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriter.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriter.kt @@ -10,6 +10,7 @@ import org.jsoup.Jsoup import org.jsoup.nodes.Element import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.codegen.core.SymbolDependencyContainer import software.amazon.smithy.codegen.core.SymbolWriter import software.amazon.smithy.codegen.core.SymbolWriter.Factory import software.amazon.smithy.model.Model @@ -115,7 +116,7 @@ private fun , U> T.withTemplate( * This enables conditionally wrapping a block in a prefix/suffix, e.g. * * ``` - * writer.withBlock("Some(", ")", conditional = symbol.isOptional()) { + * writer.conditionalBlock("Some(", ")", conditional = symbol.isOptional()) { * write("symbolValue") * } * ``` @@ -166,9 +167,10 @@ private fun transformTemplate(template: String, scope: Array> T.docsOrFallback( note: String? = null, ): T { val htmlDocs: (T.() -> Unit)? = when (docString?.isNotBlank()) { - true -> { { docs(normalizeHtml(escape(docString))) } } + true -> { + { docs(normalizeHtml(escape(docString))) } + } + else -> null } return docsOrFallback(htmlDocs, autoSuppressMissingDocs, note) @@ -376,6 +381,13 @@ private fun Element.changeInto(tagName: String) { replaceWith(Element(tagName).also { elem -> elem.appendChildren(childNodesCopy()) }) } +/** Write an `impl` block for the given symbol */ +fun RustWriter.implBlock(symbol: Symbol, block: Writable) { + rustBlock("impl ${symbol.name}") { + block() + } +} + /** * Write _exactly_ the text as written into the code writer without newlines or formatting */ @@ -394,6 +406,8 @@ class RustWriter private constructor( private val printWarning: Boolean = true, /** Insert comments indicating where code was generated */ private val debugMode: Boolean = false, + /** When true, automatically change all dependencies to be in the test scope */ + val devDependenciesOnly: Boolean = false, ) : SymbolWriter(UseDeclarations(namespace)) { companion object { @@ -407,8 +421,16 @@ class RustWriter private constructor( fun factory(debugMode: Boolean): Factory = Factory { fileName: String, namespace: String -> when { fileName.endsWith(".toml") -> RustWriter(fileName, namespace, "#", debugMode = debugMode) + fileName.endsWith(".py") -> RustWriter(fileName, namespace, "#", debugMode = debugMode) fileName.endsWith(".md") -> rawWriter(fileName, debugMode = debugMode) fileName == "LICENSE" -> rawWriter(fileName, debugMode = debugMode) + fileName.startsWith("tests/") -> RustWriter( + fileName, + namespace, + debugMode = debugMode, + devDependenciesOnly = true, + ) + else -> RustWriter(fileName, namespace, debugMode = debugMode) } } @@ -454,7 +476,9 @@ class RustWriter private constructor( init { expressionStart = '#' if (filename.endsWith(".rs")) { - require(namespace.startsWith("crate") || filename.startsWith("tests/")) { "We can only write into files in the crate (got $namespace)" } + require(namespace.startsWith("crate") || filename.startsWith("tests/")) { + "We can only write into files in the crate (got $namespace)" + } } putFormatter('T', formatter) putFormatter('D', RustDocLinker()) @@ -463,7 +487,9 @@ class RustWriter private constructor( fun module(): String? = if (filename.startsWith("src") && filename.endsWith(".rs")) { filename.removeSuffix(".rs").substringAfterLast(File.separatorChar) - } else null + } else { + null + } fun safeName(prefix: String = "var"): String { n += 1 @@ -474,6 +500,22 @@ class RustWriter private constructor( preamble.add(preWriter) } + private fun addDependencyTestAware(dependencyContainer: SymbolDependencyContainer): RustWriter { + if (!devDependenciesOnly) { + super.addDependency(dependencyContainer) + } else { + dependencyContainer.dependencies.forEach { dependency -> + super.addDependency( + when (val dep = RustDependency.fromSymbolDependency(dependency)) { + is CargoDependency -> dep.toDevDependency() + else -> dependencyContainer + }, + ) + } + } + return this + } + /** * Create an inline module. Instead of being in a new file, inline modules are written as a `mod { ... }` block * directly into the parent. @@ -481,7 +523,7 @@ class RustWriter private constructor( * Callers must take care to use [this] when writing to ensure code is written to the right place: * ```kotlin * val writer = RustWriter.forModule("model") - * writer.withModule(RustModule.public("nested")) { + * writer.withInlineModule(RustModule.public("nested")) { * Generator(...).render(this) // GOOD * Generator(...).render(writer) // WRONG! * } @@ -499,14 +541,19 @@ class RustWriter private constructor( // In Rust, modules must specify their own imports—they don't have access to the parent scope. // To easily handle this, create a new inner writer to collect imports, then dump it // into an inline module. - val innerWriter = RustWriter(this.filename, "${this.namespace}::${module.name}", printWarning = false) + val innerWriter = RustWriter( + this.filename, + "${this.namespace}::${module.name}", + printWarning = false, + devDependenciesOnly = devDependenciesOnly || module.tests, + ) moduleWriter(innerWriter) module.documentation?.let { docs -> docs(docs) } module.rustMetadata.render(this) rustBlock("mod ${module.name}") { writeWithNoFormatting(innerWriter.toString()) } - innerWriter.dependencies.forEach { addDependency(it) } + innerWriter.dependencies.forEach { addDependencyTestAware(it) } return this } @@ -605,15 +652,19 @@ class RustWriter private constructor( override fun toString(): String { val contents = super.toString() val preheader = if (preamble.isNotEmpty()) { - val prewriter = RustWriter(filename, namespace, printWarning = false) + val prewriter = RustWriter(filename, namespace, printWarning = false, devDependenciesOnly = devDependenciesOnly) preamble.forEach { it(prewriter) } prewriter.toString() - } else null + } else { + null + } // Hack to support TOML: the [commentCharacter] is overridden to support writing TOML. val header = if (printWarning) { "$commentCharacter Code generated by software.amazon.smithy.rust.codegen.smithy-rs. DO NOT EDIT." - } else null + } else { + null + } val useDecls = importContainer.toString().ifEmpty { null } @@ -623,7 +674,7 @@ class RustWriter private constructor( fun format(r: Any) = formatter.apply(r, "") fun addDepsRecursively(symbol: Symbol) { - addDependency(symbol) + addDependencyTestAware(symbol) symbol.references.forEach { addDepsRecursively(it.symbol) } } @@ -647,9 +698,9 @@ class RustWriter private constructor( @Suppress("UNCHECKED_CAST") val func = t as? Writable ?: throw CodegenException("RustWriteableInjector.apply choked on non-function t ($t)") - val innerWriter = RustWriter(filename, namespace, printWarning = false) + val innerWriter = RustWriter(filename, namespace, printWarning = false, devDependenciesOnly = devDependenciesOnly) func(innerWriter) - innerWriter.dependencies.forEach { addDependency(it) } + innerWriter.dependencies.forEach { addDependencyTestAware(it) } return innerWriter.toString().trimEnd() } } @@ -658,11 +709,15 @@ class RustWriter private constructor( override fun apply(t: Any, u: String): String { return when (t) { is RuntimeType -> { - t.dependency?.also { addDependency(it) } + t.dependency?.also { addDependencyTestAware(it) } // for now, use the fully qualified type name t.fullyQualifiedName() } + is RustModule -> { + t.fullyQualifiedPath() + } + is Symbol -> { addDepsRecursively(t) t.rustType().render(fullyQualified = true) @@ -676,9 +731,9 @@ class RustWriter private constructor( @Suppress("UNCHECKED_CAST") val func = t as? Writable ?: throw CodegenException("Invalid function type (expected writable) ($t)") - val innerWriter = RustWriter(filename, namespace, printWarning = false) + val innerWriter = RustWriter(filename, namespace, printWarning = false, devDependenciesOnly = devDependenciesOnly) func(innerWriter) - innerWriter.dependencies.forEach { addDependency(it) } + innerWriter.dependencies.forEach { addDependencyTestAware(it) } return innerWriter.toString().trimEnd() } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenDelegator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenDelegator.kt index 768a24073b8..cbfb2cc39a6 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenDelegator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenDelegator.kt @@ -6,17 +6,20 @@ package software.amazon.smithy.rust.codegen.core.smithy import software.amazon.smithy.build.FileManifest +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.codegen.core.WriterDelegator import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope import software.amazon.smithy.rust.codegen.core.rustlang.Feature import software.amazon.smithy.rust.codegen.core.rustlang.InlineDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.smithy.generators.CargoTomlGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsGenerator @@ -146,21 +149,34 @@ open class RustCrate( is RustModule.LibRs -> lib { moduleWriter(this) } is RustModule.LeafModule -> { checkDups(module) - // Create a dependency which adds the mod statement for this module. This will be added to the writer - // so that _usage_ of this module will generate _exactly one_ `mod ` with the correct modifiers. - val modStatement = RuntimeType.forInlineFun("mod_" + module.fullyQualifiedPath(), module.parent) { - module.renderModStatement(this) - } - val path = module.fullyQualifiedPath().split("::").drop(1).joinToString("/") - inner.useFileWriter("src/$path.rs", module.fullyQualifiedPath()) { writer -> - moduleWriter(writer) - writer.addDependency(modStatement.dependency) + + if (module.isInline()) { + withModule(module.parent) { + withInlineModule(module, moduleWriter) + } + } else { + // Create a dependency which adds the mod statement for this module. This will be added to the writer + // so that _usage_ of this module will generate _exactly one_ `mod ` with the correct modifiers. + val modStatement = RuntimeType.forInlineFun("mod_" + module.fullyQualifiedPath(), module.parent) { + module.renderModStatement(this) + } + val path = module.fullyQualifiedPath().split("::").drop(1).joinToString("/") + inner.useFileWriter("src/$path.rs", module.fullyQualifiedPath()) { writer -> + moduleWriter(writer) + writer.addDependency(modStatement.dependency) + } } } } return this } + /** + * Returns the module for a given Shape. + */ + fun moduleFor(shape: Shape, moduleWriter: Writable): RustCrate = + withModule((symbolProvider as RustSymbolProvider).moduleForShape(shape), moduleWriter) + /** * Create a new file directly */ @@ -169,14 +185,25 @@ open class RustCrate( fileWriter(it) } } -} -val ErrorsModule = RustModule.public("error", documentation = "All error types that operations can return. Documentation on these types is copied from the model.") -val OperationsModule = RustModule.public("operation", documentation = "All operations that this crate can perform.") -val ModelsModule = RustModule.public("model", documentation = "Data structures used by operation inputs/outputs. Documentation on these types is copied from the model.") -val InputsModule = RustModule.public("input", documentation = "Input structures for operations. Documentation on these types is copied from the model.") -val OutputsModule = RustModule.public("output", documentation = "Output structures for operations. Documentation on these types is copied from the model.") + /** + * Render something in a private module and re-export it into the given symbol. + * + * @param privateModule: Private module to render into + * @param symbol: The symbol of the thing being rendered, which will be re-exported. This symbol + * should be the public-facing symbol rather than the private symbol. + */ + fun inPrivateModuleWithReexport(privateModule: RustModule.LeafModule, symbol: Symbol, writer: Writable) { + withModule(privateModule, writer) + privateModule.toType().resolve(symbol.name).toSymbol().also { privateSymbol -> + withModule(symbol.module()) { + rust("pub use #T;", privateSymbol) + } + } + } +} +// TODO(https://github.com/awslabs/smithy-rs/issues/2341): Remove unconstrained/constrained from codegen-core val UnconstrainedModule = RustModule.private("unconstrained", "Unconstrained types for constrained shapes.") val ConstrainedModule = @@ -198,10 +225,12 @@ fun WriterDelegator.finalize( this.useFileWriter("src/lib.rs", "crate::lib") { LibRsGenerator(settings, model, libRsCustomizations, requireDocs).render(it) } - val cargoDependencies = mergeDependencyFeatures( + val cargoDependencies = + this.dependencies.map { RustDependency.fromSymbolDependency(it) } - .filterIsInstance().distinct(), - ) + .filterIsInstance().distinct() + .mergeDependencyFeatures() + .mergeIdenticalTestDependencies() this.useFileWriter("Cargo.toml") { val cargoToml = CargoTomlGenerator( settings, @@ -223,9 +252,21 @@ private fun CargoDependency.mergeWith(other: CargoDependency): CargoDependency { ) } -fun mergeDependencyFeatures(cargoDependencies: List): List = - cargoDependencies.groupBy { it.key } +internal fun List.mergeDependencyFeatures(): List = + this.groupBy { it.key } .mapValues { group -> group.value.reduce { acc, next -> acc.mergeWith(next) } } .values .toList() .sortedBy { it.name } + +/** + * If the same dependency exists both in prod and test scope, remove it from the test scope. + */ +internal fun List.mergeIdenticalTestDependencies(): List { + val compileDeps = + this.filter { it.scope == DependencyScope.Compile }.toSet() + + return this.filterNot { + it.scope == DependencyScope.Dev && compileDeps.contains(it.copy(scope = DependencyScope.Compile)) + } +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/DirectedWalker.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/DirectedWalker.kt index 51c0b4ecf80..f48b996045e 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/DirectedWalker.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/DirectedWalker.kt @@ -19,11 +19,8 @@ import java.util.function.Predicate class DirectedWalker(model: Model) { private val inner = Walker(model) - fun walkShapes(shape: Shape): Set { - return walkShapes(shape) { _ -> true } - } + fun walkShapes(shape: Shape): Set = walkShapes(shape) { true } - fun walkShapes(shape: Shape, predicate: Predicate): Set { - return inner.walkShapes(shape) { rel -> predicate.test(rel) && rel.direction == RelationshipDirection.DIRECTED } - } + fun walkShapes(shape: Shape, predicate: Predicate): Set = + inner.walkShapes(shape) { rel -> predicate.test(rel) && rel.direction == RelationshipDirection.DIRECTED } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/EventStreamSymbolProvider.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/EventStreamSymbolProvider.kt index 1aff86f7d49..6eeab26d65b 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/EventStreamSymbolProvider.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/EventStreamSymbolProvider.kt @@ -6,7 +6,6 @@ package software.amazon.smithy.rust.codegen.core.smithy import software.amazon.smithy.codegen.core.Symbol -import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape @@ -14,7 +13,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.render import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter -import software.amazon.smithy.rust.codegen.core.smithy.generators.error.eventStreamErrorSymbol import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait import software.amazon.smithy.rust.codegen.core.smithy.transformers.eventStreamErrors @@ -29,7 +27,6 @@ import software.amazon.smithy.rust.codegen.core.util.isOutputEventStream class EventStreamSymbolProvider( private val runtimeConfig: RuntimeConfig, base: RustSymbolProvider, - private val model: Model, private val target: CodegenTarget, ) : WrappingSymbolProvider(base) { override fun toSymbol(shape: Shape): Symbol { @@ -49,7 +46,7 @@ class EventStreamSymbolProvider( val error = if (target == CodegenTarget.SERVER && unionShape.eventStreamErrors().isEmpty()) { RuntimeType.smithyHttp(runtimeConfig).resolve("event_stream::MessageStreamError").toSymbol() } else { - unionShape.eventStreamErrorSymbol(this).toSymbol() + symbolForEventStreamError(unionShape) } val errorFmt = error.rustType().render(fullyQualified = true) val innerFmt = initial.rustType().stripOuter().render(fullyQualified = true) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt index 1c489f188f7..219cb33c038 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt @@ -80,10 +80,11 @@ data class RuntimeConfig( */ fun fromNode(maybeNode: Optional): RuntimeConfig { val node = maybeNode.orElse(Node.objectNode()) - val crateVersionMap = node.getObjectMember("versions").orElse(Node.objectNode()).members.entries.let { members -> - val map = members.associate { it.key.toString() to it.value.expectStringNode().value } - CrateVersionMap(map) - } + val crateVersionMap = + node.getObjectMember("versions").orElse(Node.objectNode()).members.entries.let { members -> + val map = members.associate { it.key.toString() to it.value.expectStringNode().value } + CrateVersionMap(map) + } val path = node.getStringMember("relativePath").orNull()?.value val runtimeCrateLocation = RuntimeCrateLocation(path = path, versions = crateVersionMap) return RuntimeConfig( @@ -95,7 +96,11 @@ data class RuntimeConfig( val crateSrcPrefix: String = cratePrefix.replace("-", "_") - fun smithyRuntimeCrate(runtimeCrateName: String, optional: Boolean = false, scope: DependencyScope = DependencyScope.Compile): CargoDependency { + fun smithyRuntimeCrate( + runtimeCrateName: String, + optional: Boolean = false, + scope: DependencyScope = DependencyScope.Compile, + ): CargoDependency { val crateName = "$cratePrefix-$runtimeCrateName" return CargoDependency( crateName, @@ -236,7 +241,6 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) val Tracing = CargoDependency.Tracing.toType() // codegen types - val Config = RuntimeType("crate::config") val ConstrainedTrait = RuntimeType("crate::constrained::Constrained", InlineDependency.constrained()) val MaybeConstrained = RuntimeType("crate::constrained::MaybeConstrained", InlineDependency.constrained()) @@ -246,35 +250,56 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) fun smithyClient(runtimeConfig: RuntimeConfig) = CargoDependency.smithyClient(runtimeConfig).toType() fun smithyEventStream(runtimeConfig: RuntimeConfig) = CargoDependency.smithyEventStream(runtimeConfig).toType() fun smithyHttp(runtimeConfig: RuntimeConfig) = CargoDependency.smithyHttp(runtimeConfig).toType() + fun smithyHttpAuth(runtimeConfig: RuntimeConfig) = CargoDependency.smithyHttpAuth(runtimeConfig).toType() + fun smithyHttpTower(runtimeConfig: RuntimeConfig) = CargoDependency.smithyHttpTower(runtimeConfig).toType() fun smithyJson(runtimeConfig: RuntimeConfig) = CargoDependency.smithyJson(runtimeConfig).toType() fun smithyQuery(runtimeConfig: RuntimeConfig) = CargoDependency.smithyQuery(runtimeConfig).toType() fun smithyTypes(runtimeConfig: RuntimeConfig) = CargoDependency.smithyTypes(runtimeConfig).toType() fun smithyXml(runtimeConfig: RuntimeConfig) = CargoDependency.smithyXml(runtimeConfig).toType() - private fun smithyProtocolTest(runtimeConfig: RuntimeConfig) = CargoDependency.smithyProtocolTestHelpers(runtimeConfig).toType() + private fun smithyProtocolTest(runtimeConfig: RuntimeConfig) = + CargoDependency.smithyProtocolTestHelpers(runtimeConfig).toType() // smithy runtime type members - fun base64Decode(runtimeConfig: RuntimeConfig): RuntimeType = smithyTypes(runtimeConfig).resolve("base64::decode") - fun base64Encode(runtimeConfig: RuntimeConfig): RuntimeType = smithyTypes(runtimeConfig).resolve("base64::encode") + fun base64Decode(runtimeConfig: RuntimeConfig): RuntimeType = + smithyTypes(runtimeConfig).resolve("base64::decode") + + fun base64Encode(runtimeConfig: RuntimeConfig): RuntimeType = + smithyTypes(runtimeConfig).resolve("base64::encode") + fun blob(runtimeConfig: RuntimeConfig) = smithyTypes(runtimeConfig).resolve("Blob") fun byteStream(runtimeConfig: RuntimeConfig) = smithyHttp(runtimeConfig).resolve("byte_stream::ByteStream") fun classifyRetry(runtimeConfig: RuntimeConfig) = smithyHttp(runtimeConfig).resolve("retry::ClassifyRetry") fun dateTime(runtimeConfig: RuntimeConfig) = smithyTypes(runtimeConfig).resolve("DateTime") fun document(runtimeConfig: RuntimeConfig): RuntimeType = smithyTypes(runtimeConfig).resolve("Document") - fun errorKind(runtimeConfig: RuntimeConfig) = smithyTypes(runtimeConfig).resolve("retry::ErrorKind") + fun retryErrorKind(runtimeConfig: RuntimeConfig) = smithyTypes(runtimeConfig).resolve("retry::ErrorKind") fun eventStreamReceiver(runtimeConfig: RuntimeConfig): RuntimeType = smithyHttp(runtimeConfig).resolve("event_stream::Receiver") - fun genericError(runtimeConfig: RuntimeConfig) = smithyTypes(runtimeConfig).resolve("Error") + fun errorMetadata(runtimeConfig: RuntimeConfig) = smithyTypes(runtimeConfig).resolve("error::ErrorMetadata") + fun errorMetadataBuilder(runtimeConfig: RuntimeConfig) = smithyTypes(runtimeConfig).resolve("error::metadata::Builder") + fun provideErrorMetadataTrait(runtimeConfig: RuntimeConfig) = smithyTypes(runtimeConfig).resolve("error::metadata::ProvideErrorMetadata") + fun unhandledError(runtimeConfig: RuntimeConfig) = smithyTypes(runtimeConfig).resolve("error::Unhandled") fun jsonErrors(runtimeConfig: RuntimeConfig) = forInlineDependency(InlineDependency.jsonErrors(runtimeConfig)) + fun awsQueryCompatibleErrors(runtimeConfig: RuntimeConfig) = forInlineDependency(InlineDependency.awsQueryCompatibleErrors(runtimeConfig)) fun labelFormat(runtimeConfig: RuntimeConfig, func: String) = smithyHttp(runtimeConfig).resolve("label::$func") fun operation(runtimeConfig: RuntimeConfig) = smithyHttp(runtimeConfig).resolve("operation::Operation") fun operationModule(runtimeConfig: RuntimeConfig) = smithyHttp(runtimeConfig).resolve("operation") - fun parseHttpResponse(runtimeConfig: RuntimeConfig) = smithyHttp(runtimeConfig).resolve("response::ParseHttpResponse") - fun parseStrictResponse(runtimeConfig: RuntimeConfig) = smithyHttp(runtimeConfig).resolve("response::ParseStrictResponse") - fun protocolTest(runtimeConfig: RuntimeConfig, func: String): RuntimeType = smithyProtocolTest(runtimeConfig).resolve(func) - fun provideErrorKind(runtimeConfig: RuntimeConfig) = smithyTypes(runtimeConfig).resolve("retry::ProvideErrorKind") + fun parseHttpResponse(runtimeConfig: RuntimeConfig) = + smithyHttp(runtimeConfig).resolve("response::ParseHttpResponse") + + fun parseStrictResponse(runtimeConfig: RuntimeConfig) = + smithyHttp(runtimeConfig).resolve("response::ParseStrictResponse") + + fun protocolTest(runtimeConfig: RuntimeConfig, func: String): RuntimeType = + smithyProtocolTest(runtimeConfig).resolve(func) + + fun provideErrorKind(runtimeConfig: RuntimeConfig) = + smithyTypes(runtimeConfig).resolve("retry::ProvideErrorKind") + fun queryFormat(runtimeConfig: RuntimeConfig, func: String) = smithyHttp(runtimeConfig).resolve("query::$func") fun sdkBody(runtimeConfig: RuntimeConfig): RuntimeType = smithyHttp(runtimeConfig).resolve("body::SdkBody") fun sdkError(runtimeConfig: RuntimeConfig): RuntimeType = smithyHttp(runtimeConfig).resolve("result::SdkError") - fun sdkSuccess(runtimeConfig: RuntimeConfig): RuntimeType = smithyHttp(runtimeConfig).resolve("result::SdkSuccess") + fun sdkSuccess(runtimeConfig: RuntimeConfig): RuntimeType = + smithyHttp(runtimeConfig).resolve("result::SdkSuccess") + fun timestampFormat(runtimeConfig: RuntimeConfig, format: TimestampFormatTrait.Format): RuntimeType { val timestampFormat = when (format) { TimestampFormatTrait.Format.EPOCH_SECONDS -> "EpochSeconds" @@ -286,7 +311,11 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) return smithyTypes(runtimeConfig).resolve("date_time::Format::$timestampFormat") } - fun forInlineDependency(inlineDependency: InlineDependency) = RuntimeType("crate::${inlineDependency.name}", inlineDependency) + fun captureRequest(runtimeConfig: RuntimeConfig) = + CargoDependency.smithyClientTestUtil(runtimeConfig).toType().resolve("test_connection::capture_request") + + fun forInlineDependency(inlineDependency: InlineDependency) = + RuntimeType("crate::${inlineDependency.name}", inlineDependency) fun forInlineFun(name: String, module: RustModule, func: Writable) = RuntimeType( "${module.fullyQualifiedPath()}::$name", @@ -296,10 +325,13 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) // inlinable types fun ec2QueryErrors(runtimeConfig: RuntimeConfig) = forInlineDependency(InlineDependency.ec2QueryErrors(runtimeConfig)) + fun wrappedXmlErrors(runtimeConfig: RuntimeConfig) = forInlineDependency(InlineDependency.wrappedXmlErrors(runtimeConfig)) + fun unwrappedXmlErrors(runtimeConfig: RuntimeConfig) = forInlineDependency(InlineDependency.unwrappedXmlErrors(runtimeConfig)) + val IdempotencyToken by lazy { forInlineDependency(InlineDependency.idempotencyToken()) } } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RustSymbolProvider.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RustSymbolProvider.kt new file mode 100644 index 00000000000..2314007aa93 --- /dev/null +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RustSymbolProvider.kt @@ -0,0 +1,101 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.core.smithy + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.codegen.core.SymbolProvider +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.knowledge.NullableIndex +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule + +/** + * SymbolProvider interface that carries additional configuration and module/symbol resolution. + */ +interface RustSymbolProvider : SymbolProvider { + val model: Model + val moduleProviderContext: ModuleProviderContext + val config: RustSymbolProviderConfig + + fun moduleForShape(shape: Shape): RustModule.LeafModule = + config.moduleProvider.moduleForShape(moduleProviderContext, shape) + fun moduleForOperationError(operation: OperationShape): RustModule.LeafModule = + config.moduleProvider.moduleForOperationError(moduleProviderContext, operation) + fun moduleForEventStreamError(eventStream: UnionShape): RustModule.LeafModule = + config.moduleProvider.moduleForEventStreamError(moduleProviderContext, eventStream) + fun moduleForBuilder(shape: Shape): RustModule.LeafModule = + config.moduleProvider.moduleForBuilder(moduleProviderContext, shape, toSymbol(shape)) + + /** Returns the symbol for an operation error */ + fun symbolForOperationError(operation: OperationShape): Symbol + + /** Returns the symbol for an event stream error */ + fun symbolForEventStreamError(eventStream: UnionShape): Symbol + + /** Returns the symbol for a builder */ + fun symbolForBuilder(shape: Shape): Symbol +} + +/** + * Module providers can't use the full CodegenContext since they're invoked from + * inside the SymbolVisitor, which is created before CodegenContext is created. + */ +data class ModuleProviderContext( + val settings: CoreRustSettings, + val model: Model, + val serviceShape: ServiceShape?, +) + +fun CodegenContext.toModuleProviderContext(): ModuleProviderContext = + ModuleProviderContext(settings, model, serviceShape) + +/** + * Provider for RustModules so that the symbol provider knows where to organize things. + */ +interface ModuleProvider { + /** Returns the module for a shape */ + fun moduleForShape(context: ModuleProviderContext, shape: Shape): RustModule.LeafModule + + /** Returns the module for an operation error */ + fun moduleForOperationError(context: ModuleProviderContext, operation: OperationShape): RustModule.LeafModule + + /** Returns the module for an event stream error */ + fun moduleForEventStreamError(context: ModuleProviderContext, eventStream: UnionShape): RustModule.LeafModule + + /** Returns the module for a builder */ + fun moduleForBuilder(context: ModuleProviderContext, shape: Shape, symbol: Symbol): RustModule.LeafModule +} + +/** + * Configuration for symbol providers. + */ +data class RustSymbolProviderConfig( + val runtimeConfig: RuntimeConfig, + val renameExceptions: Boolean, + val nullabilityCheckMode: NullableIndex.CheckMode, + val moduleProvider: ModuleProvider, + val nameBuilderFor: (Symbol) -> String = { _ -> "Builder" }, +) + +/** + * Default delegator to enable easily decorating another symbol provider. + */ +open class WrappingSymbolProvider(private val base: RustSymbolProvider) : RustSymbolProvider { + override val model: Model get() = base.model + override val moduleProviderContext: ModuleProviderContext get() = base.moduleProviderContext + override val config: RustSymbolProviderConfig get() = base.config + + override fun toSymbol(shape: Shape): Symbol = base.toSymbol(shape) + override fun toMemberName(shape: MemberShape): String = base.toMemberName(shape) + override fun symbolForOperationError(operation: OperationShape): Symbol = base.symbolForOperationError(operation) + override fun symbolForEventStreamError(eventStream: UnionShape): Symbol = + base.symbolForEventStreamError(eventStream) + override fun symbolForBuilder(shape: Shape): Symbol = base.symbolForBuilder(shape) +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/StreamingTraitSymbolProvider.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/StreamingTraitSymbolProvider.kt index 3e1d082627d..051f3c3d1ee 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/StreamingTraitSymbolProvider.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/StreamingTraitSymbolProvider.kt @@ -6,7 +6,6 @@ package software.amazon.smithy.rust.codegen.core.smithy import software.amazon.smithy.codegen.core.Symbol -import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.ListShape import software.amazon.smithy.model.shapes.MapShape @@ -26,8 +25,7 @@ import software.amazon.smithy.rust.codegen.core.util.isStreaming /** * Wrapping symbol provider to change `Blob` to `ByteStream` when it targets a streaming member */ -class StreamingShapeSymbolProvider(private val base: RustSymbolProvider, private val model: Model) : - WrappingSymbolProvider(base) { +class StreamingShapeSymbolProvider(private val base: RustSymbolProvider) : WrappingSymbolProvider(base) { override fun toSymbol(shape: Shape): Symbol { val initial = base.toSymbol(shape) // We are only targeting member shapes @@ -44,7 +42,7 @@ class StreamingShapeSymbolProvider(private val base: RustSymbolProvider, private // We are only targeting streaming blobs return if (target is BlobShape && shape.isStreaming(model)) { - RuntimeType.byteStream(config().runtimeConfig).toSymbol().toBuilder().setDefault(Default.RustDefault).build() + RuntimeType.byteStream(config.runtimeConfig).toSymbol().toBuilder().setDefault(Default.RustDefault).build() } else { base.toSymbol(shape) } @@ -59,22 +57,23 @@ class StreamingShapeSymbolProvider(private val base: RustSymbolProvider, private * * Note that since streaming members can only be used on the root shape, this can only impact input and output shapes. */ -class StreamingShapeMetadataProvider( - private val base: RustSymbolProvider, - private val model: Model, -) : SymbolMetadataProvider(base) { +class StreamingShapeMetadataProvider(private val base: RustSymbolProvider) : SymbolMetadataProvider(base) { override fun structureMeta(structureShape: StructureShape): RustMetadata { val baseMetadata = base.toSymbol(structureShape).expectRustMetadata() return if (structureShape.hasStreamingMember(model)) { baseMetadata.withoutDerives(RuntimeType.Clone, RuntimeType.PartialEq) - } else baseMetadata + } else { + baseMetadata + } } override fun unionMeta(unionShape: UnionShape): RustMetadata { val baseMetadata = base.toSymbol(unionShape).expectRustMetadata() return if (unionShape.hasStreamingMember(model)) { baseMetadata.withoutDerives(RuntimeType.Clone, RuntimeType.PartialEq) - } else baseMetadata + } else { + baseMetadata + } } override fun memberMeta(memberShape: MemberShape) = base.toSymbol(memberShape).expectRustMetadata() diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolExt.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolExt.kt new file mode 100644 index 00000000000..3b9307ab7ee --- /dev/null +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolExt.kt @@ -0,0 +1,138 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.core.smithy + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustType +import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter +import software.amazon.smithy.rust.codegen.core.util.orNull + +/** Set the symbolLocation for this symbol builder */ +fun Symbol.Builder.locatedIn(rustModule: RustModule.LeafModule): Symbol.Builder { + val currentRustType = this.build().rustType() + check(currentRustType is RustType.Opaque) { + "Only `RustType.Opaque` can have its namespace updated. Received $currentRustType." + } + val newRustType = currentRustType.copy(namespace = rustModule.fullyQualifiedPath()) + return this.definitionFile(rustModule.definitionFile()) + .namespace(rustModule.fullyQualifiedPath(), "::") + .rustType(newRustType) + .module(rustModule) +} + +/** + * Make the Rust type of a symbol optional (hold `Option`) + * + * This is idempotent and will have no change if the type is already optional. + */ +fun Symbol.makeOptional(): Symbol = + if (isOptional()) { + this + } else { + val rustType = RustType.Option(this.rustType()) + Symbol.builder() + .rustType(rustType) + .addReference(this) + .name(rustType.name) + .build() + } + +/** + * Make the Rust type of a symbol boxed (hold `Box`). + * + * This is idempotent and will have no change if the type is already boxed. + */ +fun Symbol.makeRustBoxed(): Symbol = + if (isRustBoxed()) { + this + } else { + val rustType = RustType.Box(this.rustType()) + Symbol.builder() + .rustType(rustType) + .addReference(this) + .name(rustType.name) + .build() + } + +/** + * Make the Rust type of a symbol wrapped in `MaybeConstrained`. (hold `MaybeConstrained`). + * + * This is idempotent and will have no change if the type is already `MaybeConstrained`. + */ +fun Symbol.makeMaybeConstrained(): Symbol = + if (this.rustType() is RustType.MaybeConstrained) { + this + } else { + val rustType = RustType.MaybeConstrained(this.rustType()) + Symbol.builder() + .rustType(rustType) + .addReference(this) + .name(rustType.name) + .build() + } + +/** + * Map the [RustType] of a symbol with [f]. + * + * WARNING: This function does not update any symbol references (e.g., `symbol.addReference()`) on the + * returned symbol. You will have to add those yourself if your logic relies on them. + **/ +fun Symbol.mapRustType(f: (RustType) -> RustType): Symbol { + val newType = f(this.rustType()) + return Symbol.builder().rustType(newType) + .name(newType.name) + .build() +} + +/** + * Type representing the default value for a given type (e.g. for Strings, this is `""`). + */ +sealed class Default { + /** + * This symbol has no default value. If the symbol is not optional, this will error during builder construction + */ + object NoDefault : Default() + + /** + * This symbol should use the Rust `std::default::Default` when unset + */ + object RustDefault : Default() +} + +/** + * Returns true when it's valid to use the default/0 value for [this] symbol during construction. + */ +fun Symbol.canUseDefault(): Boolean = this.defaultValue() != Default.NoDefault + +/** + * True when [this] is will be represented by Option in Rust + */ +fun Symbol.isOptional(): Boolean = when (this.rustType()) { + is RustType.Option -> true + else -> false +} + +fun Symbol.isRustBoxed(): Boolean = rustType().stripOuter() is RustType.Box + +private const val RUST_TYPE_KEY = "rusttype" +private const val SHAPE_KEY = "shape" +private const val RUST_MODULE_KEY = "rustmodule" +private const val RENAMED_FROM_KEY = "renamedfrom" +private const val SYMBOL_DEFAULT = "symboldefault" + +// Symbols should _always_ be created with a Rust type & shape attached +fun Symbol.rustType(): RustType = this.expectProperty(RUST_TYPE_KEY, RustType::class.java) +fun Symbol.Builder.rustType(rustType: RustType): Symbol.Builder = this.putProperty(RUST_TYPE_KEY, rustType) +fun Symbol.shape(): Shape = this.expectProperty(SHAPE_KEY, Shape::class.java) +fun Symbol.Builder.shape(shape: Shape?): Symbol.Builder = this.putProperty(SHAPE_KEY, shape) +fun Symbol.module(): RustModule.LeafModule = this.expectProperty(RUST_MODULE_KEY, RustModule.LeafModule::class.java) +fun Symbol.Builder.module(module: RustModule.LeafModule): Symbol.Builder = this.putProperty(RUST_MODULE_KEY, module) +fun Symbol.renamedFrom(): String? = this.getProperty(RENAMED_FROM_KEY, String::class.java).orNull() +fun Symbol.Builder.renamedFrom(name: String): Symbol.Builder = this.putProperty(RENAMED_FROM_KEY, name) +fun Symbol.defaultValue(): Default = this.getProperty(SYMBOL_DEFAULT, Default::class.java).orElse(Default.NoDefault) +fun Symbol.Builder.setDefault(default: Default): Symbol.Builder = this.putProperty(SYMBOL_DEFAULT, default) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolMetadataProvider.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolMetadataProvider.kt index a7017b504ef..880ac0510ad 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolMetadataProvider.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolMetadataProvider.kt @@ -18,7 +18,6 @@ import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape -import software.amazon.smithy.model.traits.EnumDefinition import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.model.traits.SensitiveTrait import software.amazon.smithy.model.traits.StreamingTrait @@ -27,27 +26,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata import software.amazon.smithy.rust.codegen.core.rustlang.Visibility import software.amazon.smithy.rust.codegen.core.util.hasTrait -/** - * Default delegator to enable easily decorating another symbol provider. - */ -open class WrappingSymbolProvider(private val base: RustSymbolProvider) : RustSymbolProvider { - override fun config(): SymbolVisitorConfig { - return base.config() - } - - override fun toEnumVariantName(definition: EnumDefinition): MaybeRenamed? { - return base.toEnumVariantName(definition) - } - - override fun toSymbol(shape: Shape): Symbol { - return base.toSymbol(shape) - } - - override fun toMemberName(shape: MemberShape): String { - return base.toMemberName(shape) - } -} - /** * Attach `meta` to symbols. `meta` is used by the generators (e.g. StructureGenerator) to configure the generated models. * @@ -92,7 +70,7 @@ fun containerDefaultMetadata( model: Model, additionalAttributes: List = emptyList(), ): RustMetadata { - val defaultDerives = setOf(RuntimeType.Debug, RuntimeType.PartialEq, RuntimeType.Clone) + val derives = mutableSetOf(RuntimeType.Debug, RuntimeType.PartialEq, RuntimeType.Clone) val isSensitive = shape.hasTrait() || // Checking the shape's direct members for the sensitive trait should suffice. @@ -101,27 +79,21 @@ fun containerDefaultMetadata( // shape; any sensitive descendant should still be printed as redacted. shape.members().any { it.getMemberTrait(model, SensitiveTrait::class.java).isPresent } - val setOfDerives = if (isSensitive) { - defaultDerives - RuntimeType.Debug - } else { - defaultDerives + if (isSensitive) { + derives.remove(RuntimeType.Debug) } - return RustMetadata( - setOfDerives, - additionalAttributes, - Visibility.PUBLIC, - ) + + return RustMetadata(derives, additionalAttributes, Visibility.PUBLIC) } /** * The base metadata supports a set of attributes that are used by generators to decorate code. * - * By default we apply `#[non_exhaustive]` in [additionalAttributes] only to client structures since breaking model + * By default, we apply `#[non_exhaustive]` in [additionalAttributes] only to client structures since breaking model * changes are fine when generating server code. */ class BaseSymbolMetadataProvider( base: RustSymbolProvider, - private val model: Model, private val additionalAttributes: List, ) : SymbolMetadataProvider(base) { @@ -142,6 +114,11 @@ class BaseSymbolMetadataProvider( } is UnionShape, is CollectionShape, is MapShape -> RustMetadata(visibility = Visibility.PUBLIC) + + // This covers strings with the enum trait for now, and can be removed once we're fully on EnumShape + // TODO(https://github.com/awslabs/smithy-rs/issues/1700): Remove this `is StringShape` match arm + is StringShape -> RustMetadata(visibility = Visibility.PUBLIC) + else -> TODO("Unrecognized container type: $container") } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitor.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitor.kt index c6a1b1dac6b..fbab905dad1 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitor.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitor.kt @@ -17,6 +17,7 @@ import software.amazon.smithy.model.shapes.BooleanShape import software.amazon.smithy.model.shapes.ByteShape import software.amazon.smithy.model.shapes.DocumentShape import software.amazon.smithy.model.shapes.DoubleShape +import software.amazon.smithy.model.shapes.EnumShape import software.amazon.smithy.model.shapes.FloatShape import software.amazon.smithy.model.shapes.IntegerShape import software.amazon.smithy.model.shapes.ListShape @@ -35,7 +36,6 @@ import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.TimestampShape import software.amazon.smithy.model.shapes.UnionShape -import software.amazon.smithy.model.traits.EnumDefinition import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.rust.codegen.core.rustlang.Attribute @@ -43,14 +43,10 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.Visibility -import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter import software.amazon.smithy.rust.codegen.core.smithy.traits.RustBoxTrait -import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait -import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait import software.amazon.smithy.rust.codegen.core.util.PANIC import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.letIf -import software.amazon.smithy.rust.codegen.core.util.orNull import software.amazon.smithy.rust.codegen.core.util.toPascalCase import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import kotlin.reflect.KClass @@ -67,89 +63,6 @@ val SimpleShapes: Map, RustType> = mapOf( StringShape::class to RustType.String, ) -data class SymbolVisitorConfig( - val runtimeConfig: RuntimeConfig, - val renameExceptions: Boolean, - val nullabilityCheckMode: CheckMode, -) - -/** - * Make the Rust type of a symbol optional (hold `Option`) - * - * This is idempotent and will have no change if the type is already optional. - */ -fun Symbol.makeOptional(): Symbol = - if (isOptional()) { - this - } else { - val rustType = RustType.Option(this.rustType()) - Symbol.builder() - .rustType(rustType) - .addReference(this) - .name(rustType.name) - .build() - } - -/** - * Make the Rust type of a symbol boxed (hold `Box`). - * - * This is idempotent and will have no change if the type is already boxed. - */ -fun Symbol.makeRustBoxed(): Symbol = - if (isRustBoxed()) { - this - } else { - val rustType = RustType.Box(this.rustType()) - Symbol.builder() - .rustType(rustType) - .addReference(this) - .name(rustType.name) - .build() - } - -/** - * Make the Rust type of a symbol wrapped in `MaybeConstrained`. (hold `MaybeConstrained`). - * - * This is idempotent and will have no change if the type is already `MaybeConstrained`. - */ -fun Symbol.makeMaybeConstrained(): Symbol = - if (this.rustType() is RustType.MaybeConstrained) { - this - } else { - val rustType = RustType.MaybeConstrained(this.rustType()) - Symbol.builder() - .rustType(rustType) - .addReference(this) - .name(rustType.name) - .build() - } - -/** - * Map the [RustType] of a symbol with [f]. - * - * WARNING: This function does not set any `SymbolReference`s on the returned symbol. You will have to add those - * yourself if your logic relies on them. - **/ -fun Symbol.mapRustType(f: (RustType) -> RustType): Symbol { - val newType = f(this.rustType()) - return Symbol.builder().rustType(newType) - .name(newType.name) - .build() -} - -/** Set the symbolLocation for this symbol builder */ -fun Symbol.Builder.locatedIn(rustModule: RustModule.LeafModule): Symbol.Builder { - val currentRustType = this.build().rustType() - check(currentRustType is RustType.Opaque) { - "Only `Opaque` can have their namespace updated" - } - val newRustType = currentRustType.copy(namespace = rustModule.fullyQualifiedPath()) - return this.definitionFile(rustModule.definitionFile()) - .namespace(rustModule.fullyQualifiedPath(), "::") - .rustType(newRustType) - .module(rustModule) -} - /** * Track both the past and current name of a symbol * @@ -161,23 +74,19 @@ fun Symbol.Builder.locatedIn(rustModule: RustModule.LeafModule): Symbol.Builder */ data class MaybeRenamed(val name: String, val renamedFrom: String?) -/** - * SymbolProvider interface that carries both the inner configuration and a function to produce an enum variant name. - */ -interface RustSymbolProvider : SymbolProvider { - fun config(): SymbolVisitorConfig - fun toEnumVariantName(definition: EnumDefinition): MaybeRenamed? -} - /** * Make the return [value] optional if the [member] symbol is as well optional. */ -fun SymbolProvider.wrapOptional(member: MemberShape, value: String): String = value.letIf(toSymbol(member).isOptional()) { "Some($value)" } +fun SymbolProvider.wrapOptional(member: MemberShape, value: String): String = value.letIf(toSymbol(member).isOptional()) { + "Some($value)" +} /** * Make the return [value] optional if the [member] symbol is not optional. */ -fun SymbolProvider.toOptional(member: MemberShape, value: String): String = value.letIf(!toSymbol(member).isOptional()) { "Some($value)" } +fun SymbolProvider.toOptional(member: MemberShape, value: String): String = value.letIf(!toSymbol(member).isOptional()) { + "Some($value)" +} /** * Services can rename their contained shapes. See https://awslabs.github.io/smithy/1.0/spec/core/model.html#service @@ -199,33 +108,42 @@ fun Shape.contextName(serviceShape: ServiceShape?): String { * derives for a given shape. */ open class SymbolVisitor( - private val model: Model, + settings: CoreRustSettings, + override val model: Model, private val serviceShape: ServiceShape?, - private val config: SymbolVisitorConfig, -) : RustSymbolProvider, - ShapeVisitor { + override val config: RustSymbolProviderConfig, +) : RustSymbolProvider, ShapeVisitor { + override val moduleProviderContext = ModuleProviderContext(settings, model, serviceShape) private val nullableIndex = NullableIndex.of(model) - override fun config(): SymbolVisitorConfig = config override fun toSymbol(shape: Shape): Symbol { return shape.accept(this) } - /** - * Return the name of a given `enum` variant. Note that this refers to `enum` in the Smithy context - * where enum is a trait that can be applied to [StringShape] and not in the Rust context of an algebraic data type. - * - * Because enum variants are not member shape, a separate handler is required. - */ - override fun toEnumVariantName(definition: EnumDefinition): MaybeRenamed? { - val baseName = definition.name.orNull()?.toPascalCase() ?: return null - return MaybeRenamed(baseName, null) + override fun symbolForOperationError(operation: OperationShape): Symbol = + toSymbol(operation).let { symbol -> + val module = moduleForOperationError(operation) + module.toType().resolve("${symbol.name}Error").toSymbol().toBuilder().locatedIn(module).build() + } + + override fun symbolForEventStreamError(eventStream: UnionShape): Symbol = + toSymbol(eventStream).let { symbol -> + val module = moduleForEventStreamError(eventStream) + module.toType().resolve("${symbol.name}Error").toSymbol().toBuilder().locatedIn(module).build() + } + + override fun symbolForBuilder(shape: Shape): Symbol = toSymbol(shape).let { symbol -> + val module = moduleForBuilder(shape) + module.toType().resolve(config.nameBuilderFor(symbol)).toSymbol().toBuilder().locatedIn(module).build() } - override fun toMemberName(shape: MemberShape): String = when (val container = model.expectShape(shape.container)) { - is StructureShape -> shape.memberName.toSnakeCase() - is UnionShape -> shape.memberName.toPascalCase() - else -> error("unexpected container shape: $container") + override fun toMemberName(shape: MemberShape): String { + val container = model.expectShape(shape.container) + return when { + container is StructureShape -> shape.memberName.toSnakeCase() + container is UnionShape || container is EnumShape || container.hasTrait() -> shape.memberName.toPascalCase() + else -> error("unexpected container shape: $container") + } } override fun blobShape(shape: BlobShape?): Symbol { @@ -244,7 +162,9 @@ open class SymbolVisitor( name(rustType.name) build() } - } else symbol + } else { + symbol + } } private fun simpleShape(shape: SimpleShape): Symbol { @@ -261,7 +181,7 @@ open class SymbolVisitor( override fun stringShape(shape: StringShape): Symbol { return if (shape.hasTrait()) { val rustType = RustType.Opaque(shape.contextName(serviceShape).toPascalCase()) - symbolBuilder(shape, rustType).locatedIn(ModelsModule).build() + symbolBuilder(shape, rustType).locatedIn(moduleForShape(shape)).build() } else { simpleShape(shape) } @@ -312,7 +232,7 @@ open class SymbolVisitor( .replaceFirstChar { it.uppercase() }, ), ) - .locatedIn(OperationsModule) + .locatedIn(moduleForShape(shape)) .build() } @@ -326,25 +246,15 @@ open class SymbolVisitor( override fun structureShape(shape: StructureShape): Symbol { val isError = shape.hasTrait() - val isInput = shape.hasTrait() - val isOutput = shape.hasTrait() val name = shape.contextName(serviceShape).toPascalCase().letIf(isError && config.renameExceptions) { it.replace("Exception", "Error") } - val builder = symbolBuilder(shape, RustType.Opaque(name)) - return when { - isError -> builder.locatedIn(ErrorsModule) - isInput -> builder.locatedIn(InputsModule) - isOutput -> builder.locatedIn(OutputsModule) - else -> builder.locatedIn(ModelsModule) - }.build() + return symbolBuilder(shape, RustType.Opaque(name)).locatedIn(moduleForShape(shape)).build() } override fun unionShape(shape: UnionShape): Symbol { val name = shape.contextName(serviceShape).toPascalCase() - val builder = symbolBuilder(shape, RustType.Opaque(name)).locatedIn(ModelsModule) - - return builder.build() + return symbolBuilder(shape, RustType.Opaque(name)).locatedIn(moduleForShape(shape)).build() } override fun memberShape(shape: MemberShape): Symbol { @@ -372,30 +282,20 @@ open class SymbolVisitor( fun handleRustBoxing(symbol: Symbol, shape: MemberShape): Symbol = if (shape.hasTrait()) { symbol.makeRustBoxed() - } else symbol + } else { + symbol + } -fun symbolBuilder(shape: Shape?, rustType: RustType): Symbol.Builder { - val builder = Symbol.builder().putProperty(SHAPE_KEY, shape) - return builder.rustType(rustType) +fun symbolBuilder(shape: Shape?, rustType: RustType): Symbol.Builder = + Symbol.builder().shape(shape).rustType(rustType) .name(rustType.name) // Every symbol that actually gets defined somewhere should set a definition file // If we ever generate a `thisisabug.rs`, there is a bug in our symbol generation .definitionFile("thisisabug.rs") -} fun handleOptionality(symbol: Symbol, member: MemberShape, nullableIndex: NullableIndex, nullabilityCheckMode: CheckMode): Symbol = symbol.letIf(nullableIndex.isMemberNullable(member, nullabilityCheckMode)) { symbol.makeOptional() } -private const val RUST_TYPE_KEY = "rusttype" -private const val RUST_MODULE_KEY = "rustmodule" -private const val SHAPE_KEY = "shape" -private const val SYMBOL_DEFAULT = "symboldefault" -private const val RENAMED_FROM_KEY = "renamedfrom" - -fun Symbol.Builder.rustType(rustType: RustType): Symbol.Builder = this.putProperty(RUST_TYPE_KEY, rustType) -fun Symbol.Builder.module(module: RustModule.LeafModule): Symbol.Builder = this.putProperty(RUST_MODULE_KEY, module) -fun Symbol.module(): RustModule.LeafModule = this.expectProperty(RUST_MODULE_KEY, RustModule.LeafModule::class.java) - /** * Creates a test module for this symbol. * For example if the symbol represents the name for the struct `struct MyStruct { ... }`, @@ -418,49 +318,6 @@ fun SymbolProvider.testModuleForShape(shape: Shape): RustModule.LeafModule { ) } -fun Symbol.Builder.renamedFrom(name: String): Symbol.Builder { - return this.putProperty(RENAMED_FROM_KEY, name) -} - -fun Symbol.renamedFrom(): String? = this.getProperty(RENAMED_FROM_KEY, String::class.java).orNull() - -fun Symbol.defaultValue(): Default = this.getProperty(SYMBOL_DEFAULT, Default::class.java).orElse(Default.NoDefault) -fun Symbol.Builder.setDefault(default: Default): Symbol.Builder = this.putProperty(SYMBOL_DEFAULT, default) - -/** - * Type representing the default value for a given type. (eg. for Strings, this is `""`) - */ -sealed class Default { - /** - * This symbol has no default value. If the symbol is not optional, this will be an error during builder construction - */ - object NoDefault : Default() - - /** - * This symbol should use the Rust `std::default::Default` when unset - */ - object RustDefault : Default() -} - -/** - * True when it is valid to use the default/0 value for [this] symbol during construction. - */ -fun Symbol.canUseDefault(): Boolean = this.defaultValue() != Default.NoDefault - -/** - * True when [this] is will be represented by Option in Rust - */ -fun Symbol.isOptional(): Boolean = when (this.rustType()) { - is RustType.Option -> true - else -> false -} - -fun Symbol.isRustBoxed(): Boolean = rustType().stripOuter() is RustType.Box - -// Symbols should _always_ be created with a Rust type & shape attached -fun Symbol.rustType(): RustType = this.expectProperty(RUST_TYPE_KEY, RustType::class.java) -fun Symbol.shape(): Shape = this.expectProperty(SHAPE_KEY, Shape::class.java) - /** * You should rarely need this function, rust names in general should be symbol-aware, * this is "automatic" if you use things like [software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate]. diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/AllowLintsCustomization.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/AllowLintsCustomization.kt index a06ff50a281..27e4598246b 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/AllowLintsCustomization.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/AllowLintsCustomization.kt @@ -34,7 +34,7 @@ private val allowedClippyLints = listOf( "should_implement_trait", // Protocol tests use silly names like `baz`, don't flag that. - // TODO(msrv_upgrade): switch + // TODO(msrv_upgrade): switch upon MSRV upgrade to Rust 1.65 "blacklisted_name", // "disallowed_names", @@ -48,11 +48,10 @@ private val allowedClippyLints = listOf( "needless_return", // For backwards compatibility, we often don't derive Eq - // TODO(msrv_upgrade): enable - // "derive_partial_eq_without_eq", + "derive_partial_eq_without_eq", // Keeping errors small in a backwards compatible way is challenging - // TODO(msrv_upgrade): enable + // TODO(msrv_upgrade): uncomment upon MSRV upgrade to Rust 1.65 // "result_large_err", ) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/CrateVersionCustomization.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/CrateVersionCustomization.kt index eca55030506..93db223c20e 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/CrateVersionCustomization.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/CrateVersionCustomization.kt @@ -5,24 +5,24 @@ package software.amazon.smithy.rust.codegen.core.smithy.customizations +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.rust -import software.amazon.smithy.rust.codegen.core.rustlang.writable -import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization -import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsSection +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RustCrate /** * Add `PGK_VERSION` const in lib.rs to enable knowing the version of the current module */ -class CrateVersionCustomization : LibRsCustomization() { - override fun section(section: LibRsSection) = - writable { - if (section is LibRsSection.Body) { - rust( - """ - /// Crate version number. - pub static PKG_VERSION: &str = env!("CARGO_PKG_VERSION"); - """, - ) - } +object CrateVersionCustomization { + fun pkgVersion(module: RustModule): RuntimeType = RuntimeType(module.fullyQualifiedPath() + "::PKG_VERSION") + + fun extras(rustCrate: RustCrate, module: RustModule) = + rustCrate.withModule(module) { + rust( + """ + /// Crate version number. + pub static PKG_VERSION: &str = env!("CARGO_PKG_VERSION"); + """, + ) } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtra.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtra.kt index 02a8843c19c..48b68bfa6d3 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtra.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtra.kt @@ -8,13 +8,15 @@ package software.amazon.smithy.rust.codegen.core.smithy.customizations import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.util.hasEventStreamMember import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember +import software.amazon.smithy.rust.codegen.core.util.letIf private data class PubUseType( val type: RuntimeType, @@ -46,16 +48,14 @@ private fun hasBlobs(model: Model): Boolean = structUnionMembersMatchPredicate(m private fun hasDateTimes(model: Model): Boolean = structUnionMembersMatchPredicate(model, Shape::isTimestampShape) /** Returns a list of types that should be re-exported for the given model */ -internal fun pubUseTypes(runtimeConfig: RuntimeConfig, model: Model): List { +internal fun pubUseTypes(codegenContext: CodegenContext, model: Model): List { + val runtimeConfig = codegenContext.runtimeConfig return ( listOf( PubUseType(RuntimeType.blob(runtimeConfig), ::hasBlobs), PubUseType(RuntimeType.dateTime(runtimeConfig), ::hasDateTimes), - ) + RuntimeType.smithyTypes(runtimeConfig).let { types -> - listOf(PubUseType(types.resolve("error::display::DisplayErrorContext")) { true }) - } + RuntimeType.smithyHttp(runtimeConfig).let { http -> + ) + RuntimeType.smithyHttp(runtimeConfig).let { http -> listOf( - PubUseType(http.resolve("result::SdkError")) { true }, PubUseType(http.resolve("byte_stream::ByteStream"), ::hasStreamingOperations), PubUseType(http.resolve("byte_stream::AggregatedBytes"), ::hasStreamingOperations), ) @@ -63,12 +63,32 @@ internal fun pubUseTypes(runtimeConfig: RuntimeConfig, model: Model): List pubUseType.shouldExport(model) }.map { it.type } } -/** Adds re-export statements in a separate file for the types module */ -fun pubUseSmithyTypes(runtimeConfig: RuntimeConfig, model: Model, rustCrate: RustCrate) { - rustCrate.withModule(RustModule.Types) { - val types = pubUseTypes(runtimeConfig, model) - if (types.isNotEmpty()) { - types.forEach { type -> rust("pub use #T;", type) } - } +/** Adds re-export statements for Smithy primitives */ +fun pubUseSmithyPrimitives(codegenContext: CodegenContext, model: Model): Writable = writable { + val types = pubUseTypes(codegenContext, model) + if (types.isNotEmpty()) { + types.forEach { type -> rust("pub use #T;", type) } + } +} + +/** Adds re-export statements for error types */ +fun pubUseSmithyErrorTypes(codegenContext: CodegenContext): Writable = writable { + val runtimeConfig = codegenContext.runtimeConfig + val reexports = listOf( + listOf( + RuntimeType.smithyHttp(runtimeConfig).let { http -> + PubUseType(http.resolve("result::SdkError")) { true } + }, + ), + RuntimeType.smithyTypes(runtimeConfig).let { types -> + listOf(PubUseType(types.resolve("error::display::DisplayErrorContext")) { true }) + // Only re-export `ProvideErrorMetadata` for clients + .letIf(codegenContext.target == CodegenTarget.CLIENT) { list -> + list + listOf(PubUseType(types.resolve("error::metadata::ProvideErrorMetadata")) { true }) + } + }, + ).flatten() + reexports.forEach { reexport -> + rust("pub use #T;", reexport.type) } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customize/CoreCodegenDecorator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customize/CoreCodegenDecorator.kt index c886c4ed97c..756f0d132c3 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customize/CoreCodegenDecorator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customize/CoreCodegenDecorator.kt @@ -9,8 +9,11 @@ import software.amazon.smithy.build.PluginContext import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.rust.codegen.core.smithy.RustCrate +import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderCustomization import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization import software.amazon.smithy.rust.codegen.core.smithy.generators.ManifestCustomizations +import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureCustomization +import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorImplCustomization import software.amazon.smithy.rust.codegen.core.util.deepMergeWith import java.util.ServiceLoader import java.util.logging.Logger @@ -62,6 +65,31 @@ interface CoreCodegenDecorator { baseCustomizations: List, ): List = baseCustomizations + /** + * Hook to customize structures generated by `StructureGenerator`. + */ + fun structureCustomizations( + codegenContext: CodegenContext, + baseCustomizations: List, + ): List = baseCustomizations + + // TODO(https://github.com/awslabs/smithy-rs/issues/1401): Move builder customizations into `ClientCodegenDecorator` + /** + * Hook to customize generated builders. + */ + fun builderCustomizations( + codegenContext: CodegenContext, + baseCustomizations: List, + ): List = baseCustomizations + + /** + * Hook to customize error struct `impl` blocks. + */ + fun errorImplCustomizations( + codegenContext: CodegenContext, + baseCustomizations: List, + ): List = baseCustomizations + /** * Extra sections allow one decorator to influence another. This is intended to be used by querying the `rootDecorator` */ @@ -76,14 +104,6 @@ abstract class CombinedCoreCodegenDecorator { private val orderedDecorators = decorators.sortedBy { it.order } - final override fun libRsCustomizations( - codegenContext: CodegenContext, - baseCustomizations: List, - ): List = - combineCustomizations(baseCustomizations) { decorator, customizations -> - decorator.libRsCustomizations(codegenContext, customizations) - } - final override fun crateManifestCustomizations(codegenContext: CodegenContext): ManifestCustomizations = combineCustomizations(emptyMap()) { decorator, customizations -> customizations.deepMergeWith(decorator.crateManifestCustomizations(codegenContext)) @@ -98,6 +118,35 @@ abstract class CombinedCoreCodegenDecorator, + ): List = + combineCustomizations(baseCustomizations) { decorator, customizations -> + decorator.libRsCustomizations(codegenContext, customizations) + } + + override fun structureCustomizations( + codegenContext: CodegenContext, + baseCustomizations: List, + ): List = combineCustomizations(baseCustomizations) { decorator, customizations -> + decorator.structureCustomizations(codegenContext, customizations) + } + + override fun builderCustomizations( + codegenContext: CodegenContext, + baseCustomizations: List, + ): List = combineCustomizations(baseCustomizations) { decorator, customizations -> + decorator.builderCustomizations(codegenContext, customizations) + } + + override fun errorImplCustomizations( + codegenContext: CodegenContext, + baseCustomizations: List, + ): List = combineCustomizations(baseCustomizations) { decorator, customizations -> + decorator.errorImplCustomizations(codegenContext, customizations) + } + final override fun extraSections(codegenContext: CodegenContext): List = addCustomizations { decorator -> decorator.extraSections(codegenContext) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customize/OperationCustomization.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customize/OperationCustomization.kt index d11f98a5c7a..7ab4586c311 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customize/OperationCustomization.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customize/OperationCustomization.kt @@ -53,6 +53,26 @@ sealed class OperationSection(name: String) : Section(name) { override val customizations: List, val operationShape: OperationShape, ) : OperationSection("MutateOutput") + + /** + * Allows for adding additional properties to the `extras` field on the + * `aws_smithy_types::error::ErrorMetadata`. + */ + data class PopulateErrorMetadataExtras( + override val customizations: List, + /** Name of the generic error builder (for referring to it in Rust code) */ + val builderName: String, + /** Name of the response (for referring to it in Rust code) */ + val responseName: String, + ) : OperationSection("PopulateErrorMetadataExtras") + + /** + * Hook to add custom code right before the response is parsed. + */ + data class BeforeParseResponse( + override val customizations: List, + val responseName: String, + ) : OperationSection("BeforeParseResponse") } abstract class OperationCustomization : NamedCustomization() { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt index 69aca456308..16cb28b8032 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGenerator.kt @@ -12,11 +12,8 @@ import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.derive -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule -import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.rustlang.Visibility import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.asArgument import software.amazon.smithy.rust.codegen.core.rustlang.asOptional @@ -36,12 +33,13 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.canUseDefault +import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization +import software.amazon.smithy.rust.codegen.core.smithy.customize.Section +import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations import software.amazon.smithy.rust.codegen.core.smithy.defaultValue import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.core.smithy.isOptional -import software.amazon.smithy.rust.codegen.core.smithy.locatedIn import software.amazon.smithy.rust.codegen.core.smithy.makeOptional -import software.amazon.smithy.rust.codegen.core.smithy.module import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.core.util.dq @@ -52,22 +50,27 @@ import software.amazon.smithy.rust.codegen.core.util.toSnakeCase // TODO(https://github.com/awslabs/smithy-rs/issues/1401) This builder generator is only used by the client. // Move this entire file, and its tests, to `codegen-client`. -fun builderSymbolFn(symbolProvider: RustSymbolProvider): (StructureShape) -> Symbol = { structureShape -> - structureShape.builderSymbol(symbolProvider) -} +/** BuilderGenerator customization sections */ +sealed class BuilderSection(name: String) : Section(name) { + abstract val shape: StructureShape + + /** Hook to add additional fields to the builder */ + data class AdditionalFields(override val shape: StructureShape) : BuilderSection("AdditionalFields") + + /** Hook to add additional methods to the builder */ + data class AdditionalMethods(override val shape: StructureShape) : BuilderSection("AdditionalMethods") + + /** Hook to add additional fields to the `build()` method */ + data class AdditionalFieldsInBuild(override val shape: StructureShape) : BuilderSection("AdditionalFieldsInBuild") -fun StructureShape.builderSymbol(symbolProvider: RustSymbolProvider): Symbol { - val structureSymbol = symbolProvider.toSymbol(this) - val builderNamespace = RustReservedWords.escapeIfNeeded(structureSymbol.name.toSnakeCase()) - val module = RustModule.new(builderNamespace, visibility = Visibility.PUBLIC, parent = structureSymbol.module(), inline = true) - val rustType = RustType.Opaque("Builder", module.fullyQualifiedPath()) - return Symbol.builder() - .rustType(rustType) - .name(rustType.name) - .locatedIn(module) - .build() + /** Hook to add additional fields to the `Debug` impl */ + data class AdditionalDebugFields(override val shape: StructureShape, val formatterName: String) : + BuilderSection("AdditionalDebugFields") } +/** Customizations for BuilderGenerator */ +abstract class BuilderCustomization : NamedCustomization() + fun RuntimeConfig.operationBuildError() = RuntimeType.operationModule(this).resolve("error::BuildError") fun RuntimeConfig.serializationError() = RuntimeType.operationModule(this).resolve("error::SerializationError") @@ -92,6 +95,7 @@ class BuilderGenerator( private val model: Model, private val symbolProvider: RustSymbolProvider, private val shape: StructureShape, + private val customizations: List, ) { companion object { /** @@ -108,27 +112,34 @@ class BuilderGenerator( // generate a fallible builder. !it.isOptional() && !it.canUseDefault() } + + fun renderConvenienceMethod(implBlock: RustWriter, symbolProvider: RustSymbolProvider, shape: StructureShape) { + implBlock.docs("Creates a new builder-style object to manufacture #D.", symbolProvider.toSymbol(shape)) + symbolProvider.symbolForBuilder(shape).also { builderSymbol -> + implBlock.rustBlock("pub fn builder() -> #T", builderSymbol) { + write("#T::default()", builderSymbol) + } + } + } } - private val runtimeConfig = symbolProvider.config().runtimeConfig + private val runtimeConfig = symbolProvider.config.runtimeConfig private val members: List = shape.allMembers.values.toList() private val structureSymbol = symbolProvider.toSymbol(shape) - private val builderSymbol = shape.builderSymbol(symbolProvider) - private val baseDerives = structureSymbol.expectRustMetadata().derives + private val builderSymbol = symbolProvider.symbolForBuilder(shape) + private val metadata = structureSymbol.expectRustMetadata() // Filter out any derive that isn't Debug, PartialEq, or Clone. Then add a Default derive - private val builderDerives = baseDerives.filter { it == RuntimeType.Debug || it == RuntimeType.PartialEq || it == RuntimeType.Clone } + RuntimeType.Default - private val builderName = "Builder" + private val builderDerives = metadata.derives.filter { + it == RuntimeType.Debug || it == RuntimeType.PartialEq || it == RuntimeType.Clone + } + RuntimeType.Default + private val builderName = symbolProvider.symbolForBuilder(shape).name fun render(writer: RustWriter) { - val symbol = symbolProvider.toSymbol(shape) - writer.docs("See #D.", symbol) - writer.withInlineModule(shape.builderSymbol(symbolProvider).module()) { - // Matching derives to the main structure + `Default` since we are a builder and everything is optional. - renderBuilder(this) - if (!structureSymbol.expectRustMetadata().hasDebugDerive()) { - renderDebugImpl(this) - } + // Matching derives to the main structure + `Default` since we are a builder and everything is optional. + renderBuilder(writer) + if (!structureSymbol.expectRustMetadata().hasDebugDerive()) { + renderDebugImpl(writer) } } @@ -153,13 +164,6 @@ class BuilderGenerator( OperationBuildError(runtimeConfig).missingField(field, detailedMessage)(this) } - fun renderConvenienceMethod(implBlock: RustWriter) { - implBlock.docs("Creates a new builder-style object to manufacture #D.", structureSymbol) - implBlock.rustBlock("pub fn builder() -> #T", builderSymbol) { - write("#T::default()", builderSymbol) - } - } - // TODO(EventStream): [DX] Consider updating builders to take EventInputStream as Into private fun renderBuilderMember(writer: RustWriter, memberName: String, memberSymbol: Symbol) { // Builder members are crate-public to enable using them directly in serializers/deserializers. @@ -207,6 +211,7 @@ class BuilderGenerator( private fun renderBuilder(writer: RustWriter) { writer.docs("A builder for #D.", structureSymbol) + metadata.additionalAttributes.render(writer) Attribute(derive(builderDerives)).render(writer) writer.rustBlock("pub struct $builderName") { for (member in members) { @@ -215,6 +220,7 @@ class BuilderGenerator( val memberSymbol = symbolProvider.toSymbol(member).makeOptional() renderBuilderMember(this, memberName, memberSymbol) } + writeCustomizations(customizations, BuilderSection.AdditionalFields(shape)) } writer.rustBlock("impl $builderName") { @@ -234,6 +240,7 @@ class BuilderGenerator( renderBuilderMemberSetterFn(this, outerType, member, memberName) } + writeCustomizations(customizations, BuilderSection.AdditionalMethods(shape)) renderBuildFn(this) } } @@ -250,6 +257,7 @@ class BuilderGenerator( "formatter.field(${memberName.dq()}, &$fieldValue);", ) } + writeCustomizations(customizations, BuilderSection.AdditionalDebugFields(shape, "formatter")) rust("formatter.finish()") } } @@ -328,6 +336,7 @@ class BuilderGenerator( } } } + writeCustomizations(customizations, BuilderSection.AdditionalFieldsInBuild(shape)) } } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt index ec8f3505e97..379a6982da5 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt @@ -7,13 +7,16 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.traits.DocumentationTrait import software.amazon.smithy.model.traits.EnumDefinition import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.rust.codegen.core.rustlang.Attribute -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.deprecatedShape import software.amazon.smithy.rust.codegen.core.rustlang.docs import software.amazon.smithy.rust.codegen.core.rustlang.documentShape @@ -21,28 +24,92 @@ import software.amazon.smithy.rust.codegen.core.rustlang.escape import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate -import software.amazon.smithy.rust.codegen.core.rustlang.withBlock -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.MaybeRenamed import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata +import software.amazon.smithy.rust.codegen.core.smithy.renamedFrom import software.amazon.smithy.rust.codegen.core.util.REDACTION -import software.amazon.smithy.rust.codegen.core.util.doubleQuote import software.amazon.smithy.rust.codegen.core.util.dq +import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.orNull import software.amazon.smithy.rust.codegen.core.util.shouldRedact +import software.amazon.smithy.rust.codegen.core.util.toPascalCase + +data class EnumGeneratorContext( + val enumName: String, + val enumMeta: RustMetadata, + val enumTrait: EnumTrait, + val sortedMembers: List, +) + +/** + * Type of enum to generate + * + * In codegen-core, there are only `Infallible` enums. Server adds additional enum types, which + * is why this class is abstract rather than sealed. + */ +abstract class EnumType { + /** Returns a writable that implements `From<&str>` and/or `TryFrom<&str>` for the enum */ + abstract fun implFromForStr(context: EnumGeneratorContext): Writable + + /** Returns a writable that implements `FromStr` for the enum */ + abstract fun implFromStr(context: EnumGeneratorContext): Writable + + /** Optionally adds additional documentation to the `enum` docs */ + open fun additionalDocs(context: EnumGeneratorContext): Writable = writable {} + + /** Optionally adds additional enum members */ + open fun additionalEnumMembers(context: EnumGeneratorContext): Writable = writable {} + + /** Optionally adds match arms to the `as_str` match implementation for named enums */ + open fun additionalAsStrMatchArms(context: EnumGeneratorContext): Writable = writable {} + + /** Optionally add more attributes to the enum */ + open fun additionalEnumAttributes(context: EnumGeneratorContext): List = emptyList() + + /** Optionally add more impls to the enum */ + open fun additionalEnumImpls(context: EnumGeneratorContext): Writable = writable {} +} /** Model that wraps [EnumDefinition] to calculate and cache values required to generate the Rust enum source. */ -class EnumMemberModel(private val definition: EnumDefinition, private val symbolProvider: RustSymbolProvider) { +class EnumMemberModel( + private val parentShape: Shape, + private val definition: EnumDefinition, + private val symbolProvider: RustSymbolProvider, +) { + companion object { + /** + * Return the name of a given `enum` variant. Note that this refers to `enum` in the Smithy context + * where enum is a trait that can be applied to [StringShape] and not in the Rust context of an algebraic data type. + * + * Ordinarily, the symbol provider would determine this name, but the enum trait doesn't allow for this. + * + * TODO(https://github.com/awslabs/smithy-rs/issues/1700): Remove this function when refactoring to EnumShape. + */ + @Deprecated("This function will go away when we handle EnumShape instead of EnumTrait") + fun toEnumVariantName( + symbolProvider: RustSymbolProvider, + parentShape: Shape, + definition: EnumDefinition, + ): MaybeRenamed? { + val name = definition.name.orNull()?.toPascalCase() ?: return null + // Create a fake member shape for symbol look up until we refactor to use EnumShape + val fakeMemberShape = + MemberShape.builder().id(parentShape.id.withMember(name)).target("smithy.api#String").build() + val symbol = symbolProvider.toSymbol(fakeMemberShape) + return MaybeRenamed(symbol.name, symbol.renamedFrom()) + } + } // Because enum variants always start with an upper case letter, they will never // conflict with reserved words (which are always lower case), therefore, we never need // to fall back to raw identifiers val value: String get() = definition.value - fun name(): MaybeRenamed? = symbolProvider.toEnumVariantName(definition) + fun name(): MaybeRenamed? = toEnumVariantName(symbolProvider, parentShape, definition) private fun renderDocumentation(writer: RustWriter) { val name = @@ -61,7 +128,7 @@ class EnumMemberModel(private val definition: EnumDefinition, private val symbol } } - fun derivedName() = checkNotNull(symbolProvider.toEnumVariantName(definition)).name + fun derivedName() = checkNotNull(toEnumVariantName(symbolProvider, parentShape, definition)).name fun render(writer: RustWriter) { renderDocumentation(writer) @@ -88,152 +155,134 @@ private fun RustWriter.docWithNote(doc: String?, note: String?) { open class EnumGenerator( private val model: Model, private val symbolProvider: RustSymbolProvider, - private val writer: RustWriter, - protected val shape: StringShape, - protected val enumTrait: EnumTrait, + private val shape: StringShape, + private val enumType: EnumType, ) { - protected val symbol: Symbol = symbolProvider.toSymbol(shape) - protected val enumName: String = symbol.name - protected val meta = symbol.expectRustMetadata() - protected val sortedMembers: List = - enumTrait.values.sortedBy { it.value }.map { EnumMemberModel(it, symbolProvider) } - protected open var target: CodegenTarget = CodegenTarget.CLIENT - companion object { - /** Name of the generated unknown enum member name for enums with named members. */ - const val UnknownVariant = "Unknown" - - /** Name of the opaque struct that is inner data for the generated [UnknownVariant]. */ - const val UnknownVariantValue = "UnknownVariantValue" - /** Name of the function on the enum impl to get a vec of value names */ const val Values = "values" } - open fun render() { + private val enumTrait: EnumTrait = shape.expectTrait() + private val symbol: Symbol = symbolProvider.toSymbol(shape) + private val context = EnumGeneratorContext( + enumName = symbol.name, + enumMeta = symbol.expectRustMetadata(), + enumTrait = enumTrait, + sortedMembers = enumTrait.values.sortedBy { it.value }.map { EnumMemberModel(shape, it, symbolProvider) }, + ) + + fun render(writer: RustWriter) { + enumType.additionalEnumAttributes(context).forEach { attribute -> + attribute.render(writer) + } if (enumTrait.hasNames()) { - // pub enum Blah { V1, V2, .. } - renderEnum() - writer.insertTrailingNewline() - // impl From for Blah { ... } - renderFromForStr() - // impl FromStr for Blah { ... } - renderFromStr() - writer.insertTrailingNewline() - // impl Blah { pub fn as_str(&self) -> &str - implBlock() - writer.rustBlock("impl AsRef for $enumName") { - rustBlock("fn as_ref(&self) -> &str") { - rust("self.as_str()") - } - } + writer.renderNamedEnum() } else { - renderUnnamedEnum() + writer.renderUnnamedEnum() } + enumType.additionalEnumImpls(context)(writer) if (shape.shouldRedact(model)) { - renderDebugImplForSensitiveEnum() + writer.renderDebugImplForSensitiveEnum() } } - private fun renderUnnamedEnum() { - writer.documentShape(shape, model) - writer.deprecatedShape(shape) - meta.render(writer) - writer.write("struct $enumName(String);") - writer.rustBlock("impl $enumName") { - docs("Returns the `&str` value of the enum member.") - rustBlock("pub fn as_str(&self) -> &str") { - rust("&self.0") - } - - docs("Returns all the `&str` representations of the enum members.") - rustBlock("pub const fn $Values() -> &'static [&'static str]") { - withBlock("&[", "]") { - val memberList = sortedMembers.joinToString(", ") { it.value.dq() } - rust(memberList) + private fun RustWriter.renderNamedEnum() { + // pub enum Blah { V1, V2, .. } + renderEnum() + insertTrailingNewline() + // impl From for Blah { ... } + enumType.implFromForStr(context)(this) + // impl FromStr for Blah { ... } + enumType.implFromStr(context)(this) + insertTrailingNewline() + // impl Blah { pub fn as_str(&self) -> &str + implBlock( + asStrImpl = writable { + rustBlock("match self") { + context.sortedMembers.forEach { member -> + rust("""${context.enumName}::${member.derivedName()} => ${member.value.dq()},""") + } + enumType.additionalAsStrMatchArms(context)(this) + } + }, + ) + rust( + """ + impl AsRef for ${context.enumName} { + fn as_ref(&self) -> &str { + self.as_str() } } - } + """, + ) + } - writer.rustBlock("impl #T for $enumName where T: #T", RuntimeType.From, RuntimeType.AsRef) { - rustBlock("fn from(s: T) -> Self") { - rust("$enumName(s.as_ref().to_owned())") + private fun RustWriter.renderUnnamedEnum() { + documentShape(shape, model) + deprecatedShape(shape) + context.enumMeta.render(this) + rust("struct ${context.enumName}(String);") + implBlock( + asStrImpl = writable { + rust("&self.0") + }, + ) + + rustTemplate( + """ + impl #{From} for ${context.enumName} where T: #{AsRef} { + fn from(s: T) -> Self { + ${context.enumName}(s.as_ref().to_owned()) + } } - } + """, + "From" to RuntimeType.From, + "AsRef" to RuntimeType.AsRef, + ) } - private fun renderEnum() { - target.ifClient { - writer.renderForwardCompatibilityNote(enumName, sortedMembers, UnknownVariant, UnknownVariantValue) - } + private fun RustWriter.renderEnum() { + enumType.additionalDocs(context)(this) val renamedWarning = - sortedMembers.mapNotNull { it.name() }.filter { it.renamedFrom != null }.joinToString("\n") { + context.sortedMembers.mapNotNull { it.name() }.filter { it.renamedFrom != null }.joinToString("\n") { val previousName = it.renamedFrom!! - "`$enumName::$previousName` has been renamed to `::${it.name}`." + "`${context.enumName}::$previousName` has been renamed to `::${it.name}`." } - writer.docWithNote( + docWithNote( shape.getTrait()?.value, renamedWarning.ifBlank { null }, ) - writer.deprecatedShape(shape) + deprecatedShape(shape) - meta.render(writer) - writer.rustBlock("enum $enumName") { - sortedMembers.forEach { member -> member.render(writer) } - target.ifClient { - docs("`$UnknownVariant` contains new variants that have been added since this code was generated.") - rust("$UnknownVariant(#T)", unknownVariantValue()) - } + context.enumMeta.render(this) + rustBlock("enum ${context.enumName}") { + context.sortedMembers.forEach { member -> member.render(this) } + enumType.additionalEnumMembers(context)(this) } } - private fun implBlock() { - writer.rustBlock("impl $enumName") { - rust("/// Returns the `&str` value of the enum member.") - rustBlock("pub fn as_str(&self) -> &str") { - rustBlock("match self") { - sortedMembers.forEach { member -> - rust("""$enumName::${member.derivedName()} => ${member.value.dq()},""") - } - - target.ifClient { - rust("$enumName::$UnknownVariant(value) => value.as_str()") - } - } - } - - rust("/// Returns all the `&str` values of the enum members.") - rustBlock("pub const fn $Values() -> &'static [&'static str]") { - withBlock("&[", "]") { - val memberList = sortedMembers.joinToString(", ") { it.value.doubleQuote() } - write(memberList) + private fun RustWriter.implBlock(asStrImpl: Writable) { + rustTemplate( + """ + impl ${context.enumName} { + /// Returns the `&str` value of the enum member. + pub fn as_str(&self) -> &str { + #{asStrImpl:W} } - } - } - } - - private fun unknownVariantValue(): RuntimeType { - return RuntimeType.forInlineFun(UnknownVariantValue, RustModule.Types) { - docs( - """ - Opaque struct used as inner data for the `Unknown` variant defined in enums in - the crate - - While this is not intended to be used directly, it is marked as `pub` because it is - part of the enums that are public interface. - """.trimIndent(), - ) - meta.render(this) - rust("struct $UnknownVariantValue(pub(crate) String);") - rustBlock("impl $UnknownVariantValue") { - // The generated as_str is not pub as we need to prevent users from calling it on this opaque struct. - rustBlock("pub(crate) fn as_str(&self) -> &str") { - rust("&self.0") + /// Returns all the `&str` representations of the enum members. + pub const fn $Values() -> &'static [&'static str] { + &[#{Values:W}] } } - } + """, + "asStrImpl" to asStrImpl, + "Values" to writable { + rust(context.sortedMembers.joinToString(", ") { it.value.dq() }) + }, + ) } /** @@ -241,10 +290,10 @@ open class EnumGenerator( * * It prints the redacted text regardless of the variant it is asked to print. */ - private fun renderDebugImplForSensitiveEnum() { - writer.rustTemplate( + private fun RustWriter.renderDebugImplForSensitiveEnum() { + rustTemplate( """ - impl #{Debug} for $enumName { + impl #{Debug} for ${context.enumName} { fn fmt(&self, f: &mut #{StdFmt}::Formatter<'_>) -> #{StdFmt}::Result { write!(f, $REDACTION) } @@ -254,89 +303,4 @@ open class EnumGenerator( "StdFmt" to RuntimeType.stdFmt, ) } - - protected open fun renderFromForStr() { - writer.rustBlock("impl #T<&str> for $enumName", RuntimeType.From) { - rustBlock("fn from(s: &str) -> Self") { - rustBlock("match s") { - sortedMembers.forEach { member -> - rust("""${member.value.dq()} => $enumName::${member.derivedName()},""") - } - rust("other => $enumName::$UnknownVariant(#T(other.to_owned()))", unknownVariantValue()) - } - } - } - } - - open fun renderFromStr() { - writer.rust( - """ - impl std::str::FromStr for $enumName { - type Err = std::convert::Infallible; - - fn from_str(s: &str) -> std::result::Result { - Ok($enumName::from(s)) - } - } - """, - ) - } -} - -/** - * Generate the rustdoc describing how to write a match expression against a generated enum in a - * forward-compatible way. - */ -private fun RustWriter.renderForwardCompatibilityNote( - enumName: String, sortedMembers: List, - unknownVariant: String, unknownVariantValue: String, -) { - docs( - """ - When writing a match expression against `$enumName`, it is important to ensure - your code is forward-compatible. That is, if a match arm handles a case for a - feature that is supported by the service but has not been represented as an enum - variant in a current version of SDK, your code should continue to work when you - upgrade SDK to a future version in which the enum does include a variant for that - feature. - """.trimIndent(), - ) - docs("") - docs("Here is an example of how you can make a match expression forward-compatible:") - docs("") - docs("```text") - rust("/// ## let ${enumName.lowercase()} = unimplemented!();") - rust("/// match ${enumName.lowercase()} {") - sortedMembers.mapNotNull { it.name() }.forEach { member -> - rust("/// $enumName::${member.name} => { /* ... */ },") - } - rust("""/// other @ _ if other.as_str() == "NewFeature" => { /* handles a case for `NewFeature` */ },""") - rust("/// _ => { /* ... */ },") - rust("/// }") - docs("```") - docs( - """ - The above code demonstrates that when `${enumName.lowercase()}` represents - `NewFeature`, the execution path will lead to the second last match arm, - even though the enum does not contain a variant `$enumName::NewFeature` - in the current version of SDK. The reason is that the variable `other`, - created by the `@` operator, is bound to - `$enumName::$unknownVariant($unknownVariantValue("NewFeature".to_owned()))` - and calling `as_str` on it yields `"NewFeature"`. - This match expression is forward-compatible when executed with a newer - version of SDK where the variant `$enumName::NewFeature` is defined. - Specifically, when `${enumName.lowercase()}` represents `NewFeature`, - the execution path will hit the second last match arm as before by virtue of - calling `as_str` on `$enumName::NewFeature` also yielding `"NewFeature"`. - """.trimIndent(), - ) - docs("") - docs( - """ - Explicitly matching on the `$unknownVariant` variant should - be avoided for two reasons: - - The inner data `$unknownVariantValue` is opaque, and no further information can be extracted. - - It might inadvertently shadow other intended match arms. - """.trimIndent(), - ) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/LibRsGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/LibRsGenerator.kt index 359386e34da..c50ee557077 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/LibRsGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/LibRsGenerator.kt @@ -44,7 +44,9 @@ class LibRsGenerator( val libraryDocs = settings.getService(model).getTrait()?.value ?: settings.moduleName containerDocs(escape(libraryDocs)) - val crateLayout = customizations.map { it.section(LibRsSection.ModuleDocumentation(LibRsSection.CrateOrganization)) }.filter { !it.isEmpty() } + val crateLayout = customizations.map { + it.section(LibRsSection.ModuleDocumentation(LibRsSection.CrateOrganization)) + }.filter { !it.isEmpty() } if (crateLayout.isNotEmpty()) { containerDocs("\n## Crate Organization") crateLayout.forEach { it(this) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGenerator.kt index 7e88d3d1fad..e91bf3b8718 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGenerator.kt @@ -6,16 +6,13 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators import software.amazon.smithy.codegen.core.Symbol -import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.MemberShape -import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.model.traits.SensitiveTrait import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.asDeref import software.amazon.smithy.rust.codegen.core.rustlang.asRef import software.amazon.smithy.rust.codegen.core.rustlang.deprecatedShape @@ -25,43 +22,55 @@ import software.amazon.smithy.rust.codegen.core.rustlang.isDeref import software.amazon.smithy.rust.codegen.core.rustlang.render import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization +import software.amazon.smithy.rust.codegen.core.smithy.customize.Section +import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata -import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorGenerator import software.amazon.smithy.rust.codegen.core.smithy.renamedFrom import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.redactIfNecessary -fun RustWriter.implBlock(structureShape: Shape, symbolProvider: SymbolProvider, block: Writable) { - rustBlock("impl ${symbolProvider.toSymbol(structureShape).name}") { - block() - } +/** StructureGenerator customization sections */ +sealed class StructureSection(name: String) : Section(name) { + abstract val shape: StructureShape + + /** Hook to add additional fields to the structure */ + data class AdditionalFields(override val shape: StructureShape) : StructureSection("AdditionalFields") + + /** Hook to add additional fields to the `Debug` impl */ + data class AdditionalDebugFields(override val shape: StructureShape, val formatterName: String) : + StructureSection("AdditionalDebugFields") + + /** Hook to add additional trait impls to the structure */ + data class AdditionalTraitImpls(override val shape: StructureShape, val structName: String) : + StructureSection("AdditionalTraitImpls") } +/** Customizations for StructureGenerator */ +abstract class StructureCustomization : NamedCustomization() + open class StructureGenerator( val model: Model, private val symbolProvider: RustSymbolProvider, private val writer: RustWriter, private val shape: StructureShape, + private val customizations: List, ) { private val errorTrait = shape.getTrait() protected val members: List = shape.allMembers.values.toList() - protected val accessorMembers: List = when (errorTrait) { + private val accessorMembers: List = when (errorTrait) { null -> members // Let the ErrorGenerator render the error message accessor if this is an error struct else -> members.filter { "message" != symbolProvider.toMemberName(it) } } - protected val name = symbolProvider.toSymbol(shape).name + protected val name: String = symbolProvider.toSymbol(shape).name - fun render(forWhom: CodegenTarget = CodegenTarget.CLIENT) { + fun render() { renderStructure() - errorTrait?.also { errorTrait -> - ErrorGenerator(model, symbolProvider, writer, shape, errorTrait).render(forWhom) - } } /** @@ -79,7 +88,9 @@ open class StructureGenerator( }.toSet().sorted() return if (lifetimes.isNotEmpty()) { "<${lifetimes.joinToString { "'$it" }}>" - } else "" + } else { + "" + } } /** @@ -98,6 +109,7 @@ open class StructureGenerator( "formatter.field(${memberName.dq()}, &$fieldValue);", ) } + writeCustomizations(customizations, StructureSection.AdditionalDebugFields(shape, "formatter")) rust("formatter.finish()") } } @@ -150,12 +162,15 @@ open class StructureGenerator( writer.forEachMember(members) { member, memberName, memberSymbol -> renderStructureMember(writer, member, memberName, memberSymbol) } + writeCustomizations(customizations, StructureSection.AdditionalFields(shape)) } renderStructureImpl() if (!containerMeta.hasDebugDerive()) { renderDebugImpl() } + + writer.writeCustomizations(customizations, StructureSection.AdditionalTraitImpls(shape, name)) } protected fun RustWriter.forEachMember( diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/UnionGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/UnionGenerator.kt index 69cd25f2ea6..dbec900302e 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/UnionGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/UnionGenerator.kt @@ -16,6 +16,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.deprecatedShape import software.amazon.smithy.rust.codegen.core.rustlang.docs import software.amazon.smithy.rust.codegen.core.rustlang.documentShape +import software.amazon.smithy.rust.codegen.core.rustlang.render import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate @@ -23,6 +24,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.core.smithy.renamedFrom +import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.util.REDACTION import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.hasTrait @@ -112,7 +114,7 @@ class UnionGenerator( if (sortedMembers.size == 1) { Attribute.AllowIrrefutableLetPatterns.render(this) } - writer.renderAsVariant(member, variantName, funcNamePart, unionSymbol, memberSymbol) + writer.renderAsVariant(model, symbolProvider, member, variantName, funcNamePart, unionSymbol) rust("/// Returns true if this is a [`$variantName`](#T::$variantName).", unionSymbol) rustBlock("pub fn is_$funcNamePart(&self) -> bool") { rust("self.as_$funcNamePart().is_ok()") @@ -183,11 +185,12 @@ private fun RustWriter.renderVariant(symbolProvider: SymbolProvider, member: Mem } private fun RustWriter.renderAsVariant( + model: Model, + symbolProvider: SymbolProvider, member: MemberShape, variantName: String, funcNamePart: String, unionSymbol: Symbol, - memberSymbol: Symbol, ) { if (member.isTargetUnit()) { rust( @@ -198,13 +201,15 @@ private fun RustWriter.renderAsVariant( rust("if let ${unionSymbol.name}::$variantName = &self { Ok(()) } else { Err(self) }") } } else { + val memberSymbol = symbolProvider.toSymbol(member) + val targetSymbol = symbolProvider.toSymbol(model.expectShape(member.target)) rust( "/// Tries to convert the enum instance into [`$variantName`](#T::$variantName), extracting the inner #D.", unionSymbol, - memberSymbol, + targetSymbol, ) rust("/// Returns `Err(&Self)` if it can't be converted.") - rustBlock("pub fn as_$funcNamePart(&self) -> std::result::Result<&#T, &Self>", memberSymbol) { + rustBlock("pub fn as_$funcNamePart(&self) -> std::result::Result<&${memberSymbol.rustType().render()}, &Self>") { rust("if let ${unionSymbol.name}::$variantName(val) = &self { Ok(val) } else { Err(self) }") } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ErrorGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ErrorImplGenerator.kt similarity index 83% rename from codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ErrorGenerator.kt rename to codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ErrorImplGenerator.kt index 071a5bd89a3..692bf32aab3 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ErrorGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ErrorImplGenerator.kt @@ -5,6 +5,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators.error +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.ErrorTrait @@ -23,6 +24,9 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.StdError import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization +import software.amazon.smithy.rust.codegen.core.smithy.customize.Section +import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.mapRustType import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.ValueExpression @@ -34,22 +38,31 @@ import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.letIf import software.amazon.smithy.rust.codegen.core.util.shouldRedact +/** Error customization sections */ +sealed class ErrorImplSection(name: String) : Section(name) { + /** Use this section to add additional trait implementations to the generated error structures */ + class ErrorAdditionalTraitImpls(val errorType: Symbol) : ErrorImplSection("ErrorAdditionalTraitImpls") +} + +/** Customizations for generated errors */ +abstract class ErrorImplCustomization : NamedCustomization() + sealed class ErrorKind { abstract fun writable(runtimeConfig: RuntimeConfig): Writable object Throttling : ErrorKind() { override fun writable(runtimeConfig: RuntimeConfig) = - writable { rust("#T::ThrottlingError", RuntimeType.errorKind(runtimeConfig)) } + writable { rust("#T::ThrottlingError", RuntimeType.retryErrorKind(runtimeConfig)) } } object Client : ErrorKind() { override fun writable(runtimeConfig: RuntimeConfig) = - writable { rust("#T::ClientError", RuntimeType.errorKind(runtimeConfig)) } + writable { rust("#T::ClientError", RuntimeType.retryErrorKind(runtimeConfig)) } } object Server : ErrorKind() { override fun writable(runtimeConfig: RuntimeConfig) = - writable { rust("#T::ServerError", RuntimeType.errorKind(runtimeConfig)) } + writable { rust("#T::ServerError", RuntimeType.retryErrorKind(runtimeConfig)) } } } @@ -69,19 +82,22 @@ fun StructureShape.modeledRetryKind(errorTrait: ErrorTrait): ErrorKind? { } } -class ErrorGenerator( +class ErrorImplGenerator( private val model: Model, private val symbolProvider: RustSymbolProvider, private val writer: RustWriter, private val shape: StructureShape, private val error: ErrorTrait, + private val customizations: List, ) { + private val runtimeConfig = symbolProvider.config.runtimeConfig + fun render(forWhom: CodegenTarget = CodegenTarget.CLIENT) { val symbol = symbolProvider.toSymbol(shape) val messageShape = shape.errorMessageMember() - val errorKindT = RuntimeType.errorKind(symbolProvider.config().runtimeConfig) + val errorKindT = RuntimeType.retryErrorKind(runtimeConfig) writer.rustBlock("impl ${symbol.name}") { - val retryKindWriteable = shape.modeledRetryKind(error)?.writable(symbolProvider.config().runtimeConfig) + val retryKindWriteable = shape.modeledRetryKind(error)?.writable(runtimeConfig) if (retryKindWriteable != null) { rust("/// Returns `Some(${errorKindT.name})` if the error is retryable. Otherwise, returns `None`.") rustBlock("pub fn retryable_error_kind(&self) -> #T", errorKindT) { @@ -153,6 +169,9 @@ class ErrorGenerator( write("Ok(())") } } + writer.write("impl #T for ${symbol.name} {}", StdError) + + writer.writeCustomizations(customizations, ErrorImplSection.ErrorAdditionalTraitImpls(symbol)) } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/OperationErrorGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/OperationErrorGenerator.kt deleted file mode 100644 index 69f1ad8639a..00000000000 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/OperationErrorGenerator.kt +++ /dev/null @@ -1,267 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.rust.codegen.core.smithy.generators.error - -import software.amazon.smithy.codegen.core.Symbol -import software.amazon.smithy.model.Model -import software.amazon.smithy.model.shapes.OperationShape -import software.amazon.smithy.model.shapes.Shape -import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.model.shapes.UnionShape -import software.amazon.smithy.model.traits.RetryableTrait -import software.amazon.smithy.rust.codegen.core.rustlang.Attribute -import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.rustlang.Visibility -import software.amazon.smithy.rust.codegen.core.rustlang.Writable -import software.amazon.smithy.rust.codegen.core.rustlang.deprecatedShape -import software.amazon.smithy.rust.codegen.core.rustlang.docs -import software.amazon.smithy.rust.codegen.core.rustlang.documentShape -import software.amazon.smithy.rust.codegen.core.rustlang.rust -import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock -import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate -import software.amazon.smithy.rust.codegen.core.rustlang.writable -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider -import software.amazon.smithy.rust.codegen.core.smithy.customize.Section -import software.amazon.smithy.rust.codegen.core.util.hasTrait -import software.amazon.smithy.rust.codegen.core.util.toSnakeCase - -/** - * For a given Operation ([this]), return the symbol referring to the operation error. This can be used - * if you, e.g. want to return an operation error from a function: - * - * ```kotlin - * rustWriter.rustBlock("fn get_error() -> #T", operation.errorSymbol(symbolProvider)) { - * write("todo!() // function body") - * } - * ``` - */ -fun OperationShape.errorSymbol(symbolProvider: RustSymbolProvider): RuntimeType { - val operationSymbol = symbolProvider.toSymbol(this) - return RustModule.Error.toType().resolve("${operationSymbol.name}Error") -} - -fun UnionShape.eventStreamErrorSymbol(symbolProvider: RustSymbolProvider): RuntimeType { - val unionSymbol = symbolProvider.toSymbol(this) - return RustModule.Error.toType().resolve("${unionSymbol.name}Error") -} - -/** - * Generates a unified error enum for [operation]. [ErrorGenerator] handles generating the individual variants, - * but we must still combine those variants into an enum covering all possible errors for a given operation. - */ -class OperationErrorGenerator( - private val model: Model, - private val symbolProvider: RustSymbolProvider, - private val operationSymbol: Symbol, - private val errors: List, -) { - private val runtimeConfig = symbolProvider.config().runtimeConfig - private val genericError = RuntimeType.genericError(symbolProvider.config().runtimeConfig) - private val createUnhandledError = - RuntimeType.smithyHttp(runtimeConfig).resolve("result::CreateUnhandledError") - - fun render(writer: RustWriter) { - val errorSymbol = RuntimeType("crate::error::${operationSymbol.name}Error") - renderErrors(writer, errorSymbol, operationSymbol) - } - - fun renderErrors( - writer: RustWriter, - errorSymbol: RuntimeType, - operationSymbol: Symbol, - ) { - val meta = RustMetadata( - derives = setOf(RuntimeType.Debug), - additionalAttributes = listOf(Attribute.NonExhaustive), - visibility = Visibility.PUBLIC, - ) - - writer.rust("/// Error type for the `${operationSymbol.name}` operation.") - meta.render(writer) - writer.rustBlock("struct ${errorSymbol.name}") { - rust( - """ - /// Kind of error that occurred. - pub kind: ${errorSymbol.name}Kind, - /// Additional metadata about the error, including error code, message, and request ID. - pub (crate) meta: #T - """, - RuntimeType.genericError(runtimeConfig), - ) - } - writer.rustBlock("impl #T for ${errorSymbol.name}", createUnhandledError) { - rustBlock("fn create_unhandled_error(source: Box) -> Self") { - rustBlock("Self") { - rust("kind: ${errorSymbol.name}Kind::Unhandled(#T::new(source)),", unhandledError()) - rust("meta: Default::default()") - } - } - } - - writer.rust("/// Types of errors that can occur for the `${operationSymbol.name}` operation.") - meta.render(writer) - writer.rustBlock("enum ${errorSymbol.name}Kind") { - errors.forEach { errorVariant -> - documentShape(errorVariant, model) - deprecatedShape(errorVariant) - val errorVariantSymbol = symbolProvider.toSymbol(errorVariant) - write("${errorVariantSymbol.name}(#T),", errorVariantSymbol) - } - docs(UNHANDLED_ERROR_DOCS) - rust( - """ - Unhandled(#T), - """, - unhandledError(), - ) - } - writer.rustBlock("impl #T for ${errorSymbol.name}", RuntimeType.Display) { - rustBlock("fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result") { - delegateToVariants(errors, errorSymbol) { - writable { rust("_inner.fmt(f)") } - } - } - } - - val errorKindT = RuntimeType.errorKind(symbolProvider.config().runtimeConfig) - writer.rustBlock( - "impl #T for ${errorSymbol.name}", - RuntimeType.provideErrorKind(symbolProvider.config().runtimeConfig), - ) { - rustBlock("fn code(&self) -> Option<&str>") { - rust("${errorSymbol.name}::code(self)") - } - - rustBlock("fn retryable_error_kind(&self) -> Option<#T>", errorKindT) { - val retryableVariants = errors.filter { it.hasTrait() } - if (retryableVariants.isEmpty()) { - rust("None") - } else { - rustBlock("match &self.kind") { - retryableVariants.forEach { - val errorVariantSymbol = symbolProvider.toSymbol(it) - rust("${errorSymbol.name}Kind::${errorVariantSymbol.name}(inner) => Some(inner.retryable_error_kind()),") - } - rust("_ => None") - } - } - } - } - - writer.rustBlock("impl ${errorSymbol.name}") { - writer.rustTemplate( - """ - /// Creates a new `${errorSymbol.name}`. - pub fn new(kind: ${errorSymbol.name}Kind, meta: #{generic_error}) -> Self { - Self { kind, meta } - } - - /// Creates the `${errorSymbol.name}::Unhandled` variant from any error type. - pub fn unhandled(err: impl Into>) -> Self { - Self { - kind: ${errorSymbol.name}Kind::Unhandled(#{Unhandled}::new(err.into())), - meta: Default::default() - } - } - - /// Creates the `${errorSymbol.name}::Unhandled` variant from a `#{generic_error}`. - pub fn generic(err: #{generic_error}) -> Self { - Self { - meta: err.clone(), - kind: ${errorSymbol.name}Kind::Unhandled(#{Unhandled}::new(err.into())), - } - } - - /// Returns the error message if one is available. - pub fn message(&self) -> Option<&str> { - self.meta.message() - } - - /// Returns error metadata, which includes the error code, message, - /// request ID, and potentially additional information. - pub fn meta(&self) -> &#{generic_error} { - &self.meta - } - - /// Returns the request ID if it's available. - pub fn request_id(&self) -> Option<&str> { - self.meta.request_id() - } - - /// Returns the error code if it's available. - pub fn code(&self) -> Option<&str> { - self.meta.code() - } - """, - "generic_error" to genericError, - "std_error" to RuntimeType.StdError, - "Unhandled" to unhandledError(), - ) - errors.forEach { error -> - val errorVariantSymbol = symbolProvider.toSymbol(error) - val fnName = errorVariantSymbol.name.toSnakeCase() - writer.rust("/// Returns `true` if the error kind is `${errorSymbol.name}Kind::${errorVariantSymbol.name}`.") - writer.rustBlock("pub fn is_$fnName(&self) -> bool") { - rust("matches!(&self.kind, ${errorSymbol.name}Kind::${errorVariantSymbol.name}(_))") - } - } - } - - writer.rustBlock("impl #T for ${errorSymbol.name}", RuntimeType.StdError) { - rustBlock("fn source(&self) -> Option<&(dyn #T + 'static)>", RuntimeType.StdError) { - delegateToVariants(errors, errorSymbol) { - writable { - rust("Some(_inner)") - } - } - } - } - } - - sealed class VariantMatch(name: String) : Section(name) { - object Unhandled : VariantMatch("Unhandled") - data class Modeled(val symbol: Symbol, val shape: Shape) : VariantMatch("Modeled") - } - - /** - * Generates code to delegate behavior to the variants, for example: - * - * ```rust - * match &self.kind { - * GreetingWithErrorsError::InvalidGreeting(_inner) => inner.fmt(f), - * GreetingWithErrorsError::ComplexError(_inner) => inner.fmt(f), - * GreetingWithErrorsError::FooError(_inner) => inner.fmt(f), - * GreetingWithErrorsError::Unhandled(_inner) => _inner.fmt(f), - * } - * ``` - * - * [handler] is passed an instance of [VariantMatch]—a [writable] should be returned containing the content to be - * written for this variant. - * - * The field will always be bound as `_inner`. - */ - private fun RustWriter.delegateToVariants( - errors: List, - symbol: RuntimeType, - handler: (VariantMatch) -> Writable, - ) { - rustBlock("match &self.kind") { - errors.forEach { - val errorSymbol = symbolProvider.toSymbol(it) - rust("""${symbol.name}Kind::${errorSymbol.name}(_inner) => """) - handler(VariantMatch.Modeled(errorSymbol, it))(this) - write(",") - } - val unhandledHandler = handler(VariantMatch.Unhandled) - rustBlock("${symbol.name}Kind::Unhandled(_inner) =>") { - unhandledHandler(this) - } - } - } -} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/UnhandledErrorGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/UnhandledErrorGenerator.kt deleted file mode 100644 index bd28c1a8524..00000000000 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/UnhandledErrorGenerator.kt +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.rust.codegen.core.smithy.generators.error - -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule -import software.amazon.smithy.rust.codegen.core.rustlang.docs -import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType - -internal const val UNHANDLED_ERROR_DOCS = - """ - An unexpected error occurred (e.g., invalid JSON returned by the service or an unknown error code). - - When logging an error from the SDK, it is recommended that you either wrap the error in - [`DisplayErrorContext`](crate::types::DisplayErrorContext), use another - error reporter library that visits the error's cause/source chain, or call - [`Error::source`](std::error::Error::source) for more details about the underlying cause. - """ - -internal fun unhandledError(): RuntimeType = RuntimeType.forInlineFun("Unhandled", RustModule.Error) { - docs(UNHANDLED_ERROR_DOCS) - rustTemplate( - """ - ##[derive(Debug)] - pub struct Unhandled { - source: Box, - } - impl Unhandled { - ##[allow(unused)] - pub(crate) fn new(source: Box) -> Self { - Self { source } - } - } - impl std::fmt::Display for Unhandled { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { - write!(f, "unhandled error") - } - } - impl #{StdError} for Unhandled { - fn source(&self) -> Option<&(dyn #{StdError} + 'static)> { - Some(self.source.as_ref() as _) - } - } - """, - "StdError" to RuntimeType.StdError, - ) -} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt index b2ca5e066fe..c784e8d1e66 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/http/HttpBindingGenerator.kt @@ -115,8 +115,6 @@ class HttpBindingGenerator( private val codegenContext: CodegenContext, private val symbolProvider: SymbolProvider, private val operationShape: OperationShape, - /** Function that maps a StructureShape into its builder symbol */ - private val builderSymbol: (StructureShape) -> Symbol, private val customizations: List = listOf(), ) { private val runtimeConfig = codegenContext.runtimeConfig @@ -210,7 +208,7 @@ class HttpBindingGenerator( */ fun generateDeserializePayloadFn( binding: HttpBindingDescriptor, - errorT: RuntimeType, + errorSymbol: Symbol, // Deserialize a single structure, union or document member marked as a payload payloadParser: RustWriter.(String) -> Unit, httpMessageType: HttpMessageType = HttpMessageType.RESPONSE, @@ -224,7 +222,7 @@ class HttpBindingGenerator( "pub fn $fnName(body: &mut #T) -> std::result::Result<#T, #T>", RuntimeType.sdkBody(runtimeConfig), outputT, - errorT, + errorSymbol, ) { // Streaming unions are Event Streams and should be handled separately val target = model.expectShape(binding.member.target) @@ -238,10 +236,10 @@ class HttpBindingGenerator( // The output needs to be Optional when deserializing the payload body or the caller signature // will not match. val outputT = symbolProvider.toSymbol(binding.member).makeOptional() - rustBlock("pub fn $fnName(body: &[u8]) -> std::result::Result<#T, #T>", outputT, errorT) { + rustBlock("pub fn $fnName(body: &[u8]) -> std::result::Result<#T, #T>", outputT, errorSymbol) { deserializePayloadBody( binding, - errorT, + errorSymbol, structuredHandler = payloadParser, httpMessageType, ) @@ -256,7 +254,6 @@ class HttpBindingGenerator( codegenContext, operationShape, targetShape, - builderSymbol, ).render() rustTemplate( """ @@ -286,7 +283,7 @@ class HttpBindingGenerator( private fun RustWriter.deserializePayloadBody( binding: HttpBindingDescriptor, - errorSymbol: RuntimeType, + errorSymbol: Symbol, structuredHandler: RustWriter.(String) -> Unit, httpMessageType: HttpMessageType = HttpMessageType.RESPONSE, ) { @@ -460,7 +457,7 @@ class HttpBindingGenerator( // Rename here technically not required, operations and members cannot be renamed. private fun fnName(operationShape: OperationShape, binding: HttpBindingDescriptor) = "${ - operationShape.id.getName(service).toSnakeCase() + operationShape.id.getName(service).toSnakeCase() }_${binding.member.container.name.toSnakeCase()}_${binding.memberName.toSnakeCase()}" /** @@ -650,12 +647,13 @@ class HttpBindingGenerator( let $safeName = $formatted; if !$safeName.is_empty() { let header_value = $safeName; - let header_value = http::header::HeaderValue::try_from(&*header_value).map_err(|err| { + let header_value: #{HeaderValue} = header_value.parse().map_err(|err| { #{invalid_field_error:W} })?; builder = builder.header("$headerName", header_value); } """, + "HeaderValue" to RuntimeType.Http.resolve("HeaderValue"), "invalid_field_error" to renderErrorMessage("header_value"), ) } @@ -690,21 +688,22 @@ class HttpBindingGenerator( #{invalid_header_name:W} })?; let header_value = ${ - headerFmtFun( - this, - valueTargetShape, - timestampFormat, - "v", - isMultiValuedHeader = false, - ) + headerFmtFun( + this, + valueTargetShape, + timestampFormat, + "v", + isMultiValuedHeader = false, + ) }; - let header_value = http::header::HeaderValue::try_from(&*header_value).map_err(|err| { + let header_value: #{HeaderValue} = header_value.parse().map_err(|err| { #{invalid_header_value:W} })?; builder = builder.header(header_name, header_value); } """, + "HeaderValue" to RuntimeType.Http.resolve("HeaderValue"), "invalid_header_name" to OperationBuildError(runtimeConfig).invalidField(memberName) { rust("""format!("`{k}` cannot be used as a header name: {err}")""") }, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt index ff60da3ebe6..9bc32bfa5cc 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt @@ -19,7 +19,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbolFn import software.amazon.smithy.rust.codegen.core.smithy.generators.serializationError import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator @@ -128,7 +127,7 @@ open class AwsJson( private val runtimeConfig = codegenContext.runtimeConfig private val errorScope = arrayOf( "Bytes" to RuntimeType.Bytes, - "Error" to RuntimeType.genericError(runtimeConfig), + "ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig), "HeaderMap" to RuntimeType.Http.resolve("HeaderMap"), "JsonError" to CargoDependency.smithyJson(runtimeConfig).toType() .resolve("deserialize::error::DeserializeError"), @@ -152,32 +151,31 @@ open class AwsJson( codegenContext, httpBindingResolver, ::awsJsonFieldName, - builderSymbolFn(codegenContext.symbolProvider), ) } override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = AwsJsonSerializerGenerator(codegenContext, httpBindingResolver) - override fun parseHttpGenericError(operationShape: OperationShape): RuntimeType = - RuntimeType.forInlineFun("parse_http_generic_error", jsonDeserModule) { + override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType = + RuntimeType.forInlineFun("parse_http_error_metadata", jsonDeserModule) { rustTemplate( """ - pub fn parse_http_generic_error(response: &#{Response}<#{Bytes}>) -> Result<#{Error}, #{JsonError}> { - #{json_errors}::parse_generic_error(response.body(), response.headers()) + pub fn parse_http_error_metadata(response: &#{Response}<#{Bytes}>) -> Result<#{ErrorMetadataBuilder}, #{JsonError}> { + #{json_errors}::parse_error_metadata(response.body(), response.headers()) } """, *errorScope, ) } - override fun parseEventStreamGenericError(operationShape: OperationShape): RuntimeType = - RuntimeType.forInlineFun("parse_event_stream_generic_error", jsonDeserModule) { + override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType = + RuntimeType.forInlineFun("parse_event_stream_error_metadata", jsonDeserModule) { rustTemplate( """ - pub fn parse_event_stream_generic_error(payload: &#{Bytes}) -> Result<#{Error}, #{JsonError}> { + pub fn parse_event_stream_error_metadata(payload: &#{Bytes}) -> Result<#{ErrorMetadataBuilder}, #{JsonError}> { // Note: HeaderMap::new() doesn't allocate - #{json_errors}::parse_generic_error(payload, &#{HeaderMap}::new()) + #{json_errors}::parse_error_metadata(payload, &#{HeaderMap}::new()) } """, *errorScope, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQuery.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQuery.kt index 6bb7a62a58e..1914076175a 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQuery.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQuery.kt @@ -6,11 +6,9 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols import software.amazon.smithy.aws.traits.protocols.AwsQueryErrorTrait -import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model import software.amazon.smithy.model.pattern.UriPattern import software.amazon.smithy.model.shapes.OperationShape -import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.ToShapeId import software.amazon.smithy.model.traits.HttpTrait import software.amazon.smithy.model.traits.TimestampFormatTrait @@ -19,7 +17,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.AwsQueryParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.AwsQuerySerializerGenerator @@ -45,7 +42,7 @@ class AwsQueryProtocol(private val codegenContext: CodegenContext) : Protocol { private val awsQueryErrors: RuntimeType = RuntimeType.wrappedXmlErrors(runtimeConfig) private val errorScope = arrayOf( "Bytes" to RuntimeType.Bytes, - "Error" to RuntimeType.genericError(runtimeConfig), + "ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig), "HeaderMap" to RuntimeType.HttpHeaderMap, "Response" to RuntimeType.HttpResponse, "XmlDecodeError" to RuntimeType.smithyXml(runtimeConfig).resolve("decode::XmlDecodeError"), @@ -56,32 +53,29 @@ class AwsQueryProtocol(private val codegenContext: CodegenContext) : Protocol { override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.DATE_TIME - override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator { - fun builderSymbol(shape: StructureShape): Symbol = - shape.builderSymbol(codegenContext.symbolProvider) - return AwsQueryParserGenerator(codegenContext, awsQueryErrors, ::builderSymbol) - } + override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator = + AwsQueryParserGenerator(codegenContext, awsQueryErrors) override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = AwsQuerySerializerGenerator(codegenContext) - override fun parseHttpGenericError(operationShape: OperationShape): RuntimeType = - RuntimeType.forInlineFun("parse_http_generic_error", xmlDeserModule) { + override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType = + RuntimeType.forInlineFun("parse_http_error_metadata", xmlDeserModule) { rustBlockTemplate( - "pub fn parse_http_generic_error(response: &#{Response}<#{Bytes}>) -> Result<#{Error}, #{XmlDecodeError}>", + "pub fn parse_http_error_metadata(response: &#{Response}<#{Bytes}>) -> Result<#{ErrorMetadataBuilder}, #{XmlDecodeError}>", *errorScope, ) { - rust("#T::parse_generic_error(response.body().as_ref())", awsQueryErrors) + rust("#T::parse_error_metadata(response.body().as_ref())", awsQueryErrors) } } - override fun parseEventStreamGenericError(operationShape: OperationShape): RuntimeType = - RuntimeType.forInlineFun("parse_event_stream_generic_error", xmlDeserModule) { + override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType = + RuntimeType.forInlineFun("parse_event_stream_error_metadata", xmlDeserModule) { rustBlockTemplate( - "pub fn parse_event_stream_generic_error(payload: &#{Bytes}) -> Result<#{Error}, #{XmlDecodeError}>", + "pub fn parse_event_stream_error_metadata(payload: &#{Bytes}) -> Result<#{ErrorMetadataBuilder}, #{XmlDecodeError}>", *errorScope, ) { - rust("#T::parse_generic_error(payload.as_ref())", awsQueryErrors) + rust("#T::parse_error_metadata(payload.as_ref())", awsQueryErrors) } } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQueryCompatible.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQueryCompatible.kt new file mode 100644 index 00000000000..32f9fdbb600 --- /dev/null +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQueryCompatible.kt @@ -0,0 +1,97 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.core.smithy.protocols + +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.ToShapeId +import software.amazon.smithy.model.traits.HttpTrait +import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator +import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator + +class AwsQueryCompatibleHttpBindingResolver( + private val awsQueryBindingResolver: AwsQueryBindingResolver, + private val awsJsonHttpBindingResolver: AwsJsonHttpBindingResolver, +) : HttpBindingResolver { + override fun httpTrait(operationShape: OperationShape): HttpTrait = + awsJsonHttpBindingResolver.httpTrait(operationShape) + + override fun requestBindings(operationShape: OperationShape): List = + awsJsonHttpBindingResolver.requestBindings(operationShape) + + override fun responseBindings(operationShape: OperationShape): List = + awsJsonHttpBindingResolver.responseBindings(operationShape) + + override fun errorResponseBindings(errorShape: ToShapeId): List = + awsJsonHttpBindingResolver.errorResponseBindings(errorShape) + + override fun errorCode(errorShape: ToShapeId): String = + awsQueryBindingResolver.errorCode(errorShape) + + override fun requestContentType(operationShape: OperationShape): String = + awsJsonHttpBindingResolver.requestContentType(operationShape) + + override fun responseContentType(operationShape: OperationShape): String = + awsJsonHttpBindingResolver.requestContentType(operationShape) +} + +class AwsQueryCompatible( + val codegenContext: CodegenContext, + private val awsJson: AwsJson, +) : Protocol { + private val runtimeConfig = codegenContext.runtimeConfig + private val errorScope = arrayOf( + "Bytes" to RuntimeType.Bytes, + "ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig), + "JsonError" to CargoDependency.smithyJson(runtimeConfig).toType() + .resolve("deserialize::error::DeserializeError"), + "Response" to RuntimeType.Http.resolve("Response"), + "json_errors" to RuntimeType.jsonErrors(runtimeConfig), + "aws_query_compatible_errors" to RuntimeType.awsQueryCompatibleErrors(runtimeConfig), + ) + private val jsonDeserModule = RustModule.private("json_deser") + + override val httpBindingResolver: HttpBindingResolver = + AwsQueryCompatibleHttpBindingResolver( + AwsQueryBindingResolver(codegenContext.model), + AwsJsonHttpBindingResolver(codegenContext.model, awsJson.version), + ) + + override val defaultTimestampFormat = awsJson.defaultTimestampFormat + + override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator = + awsJson.structuredDataParser(operationShape) + + override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = + awsJson.structuredDataSerializer(operationShape) + + override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType = + RuntimeType.forInlineFun("parse_http_error_metadata", jsonDeserModule) { + rustTemplate( + """ + pub fn parse_http_error_metadata(response: &#{Response}<#{Bytes}>) -> Result<#{ErrorMetadataBuilder}, #{JsonError}> { + let mut builder = + #{json_errors}::parse_error_metadata(response.body(), response.headers())?; + if let Some((error_code, error_type)) = + #{aws_query_compatible_errors}::parse_aws_query_compatible_error(response.headers()) + { + builder = builder.code(error_code); + builder = builder.custom("type", error_type); + } + Ok(builder) + } + """, + *errorScope, + ) + } + + override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType = + awsJson.parseEventStreamErrorMetadata(operationShape) +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Ec2Query.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Ec2Query.kt index c388b8e85b4..8e6c6c063c2 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Ec2Query.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Ec2Query.kt @@ -5,10 +5,8 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols -import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.pattern.UriPattern import software.amazon.smithy.model.shapes.OperationShape -import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.HttpTrait import software.amazon.smithy.model.traits.TimestampFormatTrait import software.amazon.smithy.rust.codegen.core.rustlang.RustModule @@ -16,7 +14,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.Ec2QueryParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.Ec2QuerySerializerGenerator @@ -27,7 +24,7 @@ class Ec2QueryProtocol(private val codegenContext: CodegenContext) : Protocol { private val ec2QueryErrors: RuntimeType = RuntimeType.ec2QueryErrors(runtimeConfig) private val errorScope = arrayOf( "Bytes" to RuntimeType.Bytes, - "Error" to RuntimeType.genericError(runtimeConfig), + "ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig), "HeaderMap" to RuntimeType.HttpHeaderMap, "Response" to RuntimeType.HttpResponse, "XmlDecodeError" to RuntimeType.smithyXml(runtimeConfig).resolve("decode::XmlDecodeError"), @@ -48,31 +45,29 @@ class Ec2QueryProtocol(private val codegenContext: CodegenContext) : Protocol { override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.DATE_TIME override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator { - fun builderSymbol(shape: StructureShape): Symbol = - shape.builderSymbol(codegenContext.symbolProvider) - return Ec2QueryParserGenerator(codegenContext, ec2QueryErrors, ::builderSymbol) + return Ec2QueryParserGenerator(codegenContext, ec2QueryErrors) } override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = Ec2QuerySerializerGenerator(codegenContext) - override fun parseHttpGenericError(operationShape: OperationShape): RuntimeType = - RuntimeType.forInlineFun("parse_http_generic_error", xmlDeserModule) { + override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType = + RuntimeType.forInlineFun("parse_http_error_metadata", xmlDeserModule) { rustBlockTemplate( - "pub fn parse_http_generic_error(response: &#{Response}<#{Bytes}>) -> Result<#{Error}, #{XmlDecodeError}>", + "pub fn parse_http_error_metadata(response: &#{Response}<#{Bytes}>) -> Result<#{ErrorMetadataBuilder}, #{XmlDecodeError}>", *errorScope, ) { - rust("#T::parse_generic_error(response.body().as_ref())", ec2QueryErrors) + rust("#T::parse_error_metadata(response.body().as_ref())", ec2QueryErrors) } } - override fun parseEventStreamGenericError(operationShape: OperationShape): RuntimeType = - RuntimeType.forInlineFun("parse_event_stream_generic_error", xmlDeserModule) { + override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType = + RuntimeType.forInlineFun("parse_event_stream_error_metadata", xmlDeserModule) { rustBlockTemplate( - "pub fn parse_event_stream_generic_error(payload: &#{Bytes}) -> Result<#{Error}, #{XmlDecodeError}>", + "pub fn parse_event_stream_error_metadata(payload: &#{Bytes}) -> Result<#{ErrorMetadataBuilder}, #{XmlDecodeError}>", *errorScope, ) { - rust("#T::parse_generic_error(payload.as_ref())", ec2QueryErrors) + rust("#T::parse_error_metadata(payload.as_ref())", ec2QueryErrors) } } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt index c5d93ae3b00..1b3e4b4a9ae 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/Protocol.kt @@ -46,21 +46,21 @@ interface Protocol { /** * Generates a function signature like the following: * ```rust - * fn parse_http_generic_error(response: &Response) -> aws_smithy_types::error::Error + * fn parse_http_error_metadata(response: &Response) -> aws_smithy_types::error::Builder * ``` */ - fun parseHttpGenericError(operationShape: OperationShape): RuntimeType + fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType /** * Generates a function signature like the following: * ```rust - * fn parse_event_stream_generic_error(payload: &Bytes) -> aws_smithy_types::error::Error + * fn parse_event_stream_error_metadata(payload: &Bytes) -> aws_smithy_types::error::Error * ``` * * Event Stream generic errors are almost identical to HTTP generic errors, except that * there are no response headers or statuses available to further inform the error parsing. */ - fun parseEventStreamGenericError(operationShape: OperationShape): RuntimeType + fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType } typealias ProtocolMap = Map> diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt index cbcd2d511b0..0bbb1c5f5c0 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt @@ -5,7 +5,6 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols -import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape @@ -20,7 +19,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerGenerator @@ -66,7 +64,7 @@ open class RestJson(val codegenContext: CodegenContext) : Protocol { private val runtimeConfig = codegenContext.runtimeConfig private val errorScope = arrayOf( "Bytes" to RuntimeType.Bytes, - "Error" to RuntimeType.genericError(runtimeConfig), + "ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig), "HeaderMap" to RuntimeType.Http.resolve("HeaderMap"), "JsonError" to CargoDependency.smithyJson(runtimeConfig).toType() .resolve("deserialize::error::DeserializeError"), @@ -94,33 +92,31 @@ open class RestJson(val codegenContext: CodegenContext) : Protocol { listOf("x-amzn-errortype" to errorShape.id.toString()) override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator { - fun builderSymbol(shape: StructureShape): Symbol = - shape.builderSymbol(codegenContext.symbolProvider) - return JsonParserGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName, ::builderSymbol) + return JsonParserGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName) } override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = JsonSerializerGenerator(codegenContext, httpBindingResolver, ::restJsonFieldName) - override fun parseHttpGenericError(operationShape: OperationShape): RuntimeType = - RuntimeType.forInlineFun("parse_http_generic_error", jsonDeserModule) { + override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType = + RuntimeType.forInlineFun("parse_http_error_metadata", jsonDeserModule) { rustTemplate( """ - pub fn parse_http_generic_error(response: &#{Response}<#{Bytes}>) -> Result<#{Error}, #{JsonError}> { - #{json_errors}::parse_generic_error(response.body(), response.headers()) + pub fn parse_http_error_metadata(response: &#{Response}<#{Bytes}>) -> Result<#{ErrorMetadataBuilder}, #{JsonError}> { + #{json_errors}::parse_error_metadata(response.body(), response.headers()) } """, *errorScope, ) } - override fun parseEventStreamGenericError(operationShape: OperationShape): RuntimeType = - RuntimeType.forInlineFun("parse_event_stream_generic_error", jsonDeserModule) { + override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType = + RuntimeType.forInlineFun("parse_event_stream_error_metadata", jsonDeserModule) { rustTemplate( """ - pub fn parse_event_stream_generic_error(payload: &#{Bytes}) -> Result<#{Error}, #{JsonError}> { + pub fn parse_event_stream_error_metadata(payload: &#{Bytes}) -> Result<#{ErrorMetadataBuilder}, #{JsonError}> { // Note: HeaderMap::new() doesn't allocate - #{json_errors}::parse_generic_error(payload, &#{HeaderMap}::new()) + #{json_errors}::parse_error_metadata(payload, &#{HeaderMap}::new()) } """, *errorScope, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestXml.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestXml.kt index 44a9631ef74..eba23589995 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestXml.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestXml.kt @@ -6,16 +6,13 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols import software.amazon.smithy.aws.traits.protocols.RestXmlTrait -import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.OperationShape -import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.TimestampFormatTrait import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.RestXmlParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator @@ -27,7 +24,7 @@ open class RestXml(val codegenContext: CodegenContext) : Protocol { private val runtimeConfig = codegenContext.runtimeConfig private val errorScope = arrayOf( "Bytes" to RuntimeType.Bytes, - "Error" to RuntimeType.genericError(runtimeConfig), + "ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig), "HeaderMap" to RuntimeType.HttpHeaderMap, "Response" to RuntimeType.HttpResponse, "XmlDecodeError" to RuntimeType.smithyXml(runtimeConfig).resolve("decode::XmlDecodeError"), @@ -46,32 +43,30 @@ open class RestXml(val codegenContext: CodegenContext) : Protocol { TimestampFormatTrait.Format.DATE_TIME override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator { - fun builderSymbol(shape: StructureShape): Symbol = - shape.builderSymbol(codegenContext.symbolProvider) - return RestXmlParserGenerator(codegenContext, restXmlErrors, ::builderSymbol) + return RestXmlParserGenerator(codegenContext, restXmlErrors) } override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator { return XmlBindingTraitSerializerGenerator(codegenContext, httpBindingResolver) } - override fun parseHttpGenericError(operationShape: OperationShape): RuntimeType = - RuntimeType.forInlineFun("parse_http_generic_error", xmlDeserModule) { + override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType = + RuntimeType.forInlineFun("parse_http_error_metadata", xmlDeserModule) { rustBlockTemplate( - "pub fn parse_http_generic_error(response: &#{Response}<#{Bytes}>) -> Result<#{Error}, #{XmlDecodeError}>", + "pub fn parse_http_error_metadata(response: &#{Response}<#{Bytes}>) -> Result<#{ErrorMetadataBuilder}, #{XmlDecodeError}>", *errorScope, ) { - rust("#T::parse_generic_error(response.body().as_ref())", restXmlErrors) + rust("#T::parse_error_metadata(response.body().as_ref())", restXmlErrors) } } - override fun parseEventStreamGenericError(operationShape: OperationShape): RuntimeType = - RuntimeType.forInlineFun("parse_event_stream_generic_error", xmlDeserModule) { + override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType = + RuntimeType.forInlineFun("parse_event_stream_error_metadata", xmlDeserModule) { rustBlockTemplate( - "pub fn parse_event_stream_generic_error(payload: &#{Bytes}) -> Result<#{Error}, #{XmlDecodeError}>", + "pub fn parse_event_stream_error_metadata(payload: &#{Bytes}) -> Result<#{ErrorMetadataBuilder}, #{XmlDecodeError}>", *errorScope, ) { - rust("#T::parse_generic_error(payload.as_ref())", restXmlErrors) + rust("#T::parse_error_metadata(payload.as_ref())", restXmlErrors) } } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGenerator.kt index 2dbe6d72fbe..2fc1f5c6497 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGenerator.kt @@ -5,8 +5,6 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse -import software.amazon.smithy.codegen.core.Symbol -import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType @@ -29,12 +27,10 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType class AwsQueryParserGenerator( codegenContext: CodegenContext, xmlErrors: RuntimeType, - builderSymbol: (shape: StructureShape) -> Symbol, private val xmlBindingTraitParserGenerator: XmlBindingTraitParserGenerator = XmlBindingTraitParserGenerator( codegenContext, xmlErrors, - builderSymbol, ) { context, inner -> val operationName = codegenContext.symbolProvider.toSymbol(context.shape).name val responseWrapperName = operationName + "Response" diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGenerator.kt index e7be46f3bdc..00cdc784d17 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGenerator.kt @@ -5,8 +5,6 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse -import software.amazon.smithy.codegen.core.Symbol -import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType @@ -27,12 +25,10 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType class Ec2QueryParserGenerator( codegenContext: CodegenContext, xmlErrors: RuntimeType, - builderSymbol: (shape: StructureShape) -> Symbol, private val xmlBindingTraitParserGenerator: XmlBindingTraitParserGenerator = XmlBindingTraitParserGenerator( codegenContext, xmlErrors, - builderSymbol, ) { context, inner -> val operationName = codegenContext.symbolProvider.toSymbol(context.shape).name val responseWrapperName = operationName + "Response" diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt index e6bf4812f34..4fa98fa6ee0 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt @@ -33,7 +33,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.error.eventStreamErrorSymbol import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol @@ -44,13 +43,14 @@ import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.toPascalCase +fun RustModule.Companion.eventStreamSerdeModule(): RustModule.LeafModule = + private("event_stream_serde") + class EventStreamUnmarshallerGenerator( private val protocol: Protocol, codegenContext: CodegenContext, private val operationShape: OperationShape, private val unionShape: UnionShape, - /** Function that maps a StructureShape into its builder symbol */ - private val builderSymbol: (StructureShape) -> Symbol, ) { private val model = codegenContext.model private val symbolProvider = codegenContext.symbolProvider @@ -60,10 +60,10 @@ class EventStreamUnmarshallerGenerator( private val errorSymbol = if (codegenTarget == CodegenTarget.SERVER && unionShape.eventStreamErrors().isEmpty()) { RuntimeType.smithyHttp(runtimeConfig).resolve("event_stream::MessageStreamError").toSymbol() } else { - unionShape.eventStreamErrorSymbol(symbolProvider).toSymbol() + symbolProvider.symbolForEventStreamError(unionShape) } private val smithyEventStream = RuntimeType.smithyEventStream(runtimeConfig) - private val eventStreamSerdeModule = RustModule.private("event_stream_serde") + private val eventStreamSerdeModule = RustModule.eventStreamSerdeModule() private val codegenScope = arrayOf( "Blob" to RuntimeType.blob(runtimeConfig), "expect_fns" to smithyEventStream.resolve("smithy"), @@ -87,15 +87,16 @@ class EventStreamUnmarshallerGenerator( } private fun RustWriter.renderUnmarshaller(unmarshallerType: RuntimeType, unionSymbol: Symbol) { + val unmarshallerTypeName = unmarshallerType.name rust( """ ##[non_exhaustive] ##[derive(Debug)] - pub struct ${unmarshallerType.name}; + pub struct $unmarshallerTypeName; - impl ${unmarshallerType.name} { + impl $unmarshallerTypeName { pub fn new() -> Self { - ${unmarshallerType.name} + $unmarshallerTypeName } } """, @@ -157,6 +158,7 @@ class EventStreamUnmarshallerGenerator( "Output" to unionSymbol, *codegenScope, ) + false -> rustTemplate( "return Err(#{Error}::unmarshalling(format!(\"unrecognized :event-type: {}\", _unknown_variant)));", *codegenScope, @@ -182,6 +184,7 @@ class EventStreamUnmarshallerGenerator( *codegenScope, ) } + payloadOnly -> { withBlock("let parsed = ", ";") { renderParseProtocolPayload(unionMember) @@ -192,8 +195,9 @@ class EventStreamUnmarshallerGenerator( *codegenScope, ) } + else -> { - rust("let mut builder = #T::default();", builderSymbol(unionStruct)) + rust("let mut builder = #T::default();", symbolProvider.symbolForBuilder(unionStruct)) val payloadMember = unionStruct.members().firstOrNull { it.hasTrait() } if (payloadMember != null) { renderUnmarshallEventPayload(payloadMember) @@ -268,6 +272,7 @@ class EventStreamUnmarshallerGenerator( is BlobShape -> { rustTemplate("#{Blob}::new(message.payload().as_ref())", *codegenScope) } + is StringShape -> { rustTemplate( """ @@ -278,6 +283,7 @@ class EventStreamUnmarshallerGenerator( *codegenScope, ) } + is UnionShape, is StructureShape -> { renderParseProtocolPayload(member) } @@ -306,15 +312,16 @@ class EventStreamUnmarshallerGenerator( CodegenTarget.CLIENT -> { rustTemplate( """ - let generic = match #{parse_generic_error}(message.payload()) { - Ok(generic) => generic, + let generic = match #{parse_error_metadata}(message.payload()) { + Ok(builder) => builder.build(), Err(err) => return Ok(#{UnmarshalledMessage}::Error(#{OpError}::unhandled(err))), }; """, - "parse_generic_error" to protocol.parseEventStreamGenericError(operationShape), + "parse_error_metadata" to protocol.parseEventStreamErrorMetadata(operationShape), *codegenScope, ) } + CodegenTarget.SERVER -> {} } @@ -336,18 +343,16 @@ class EventStreamUnmarshallerGenerator( val target = model.expectShape(member.target, StructureShape::class.java) val parser = protocol.structuredDataParser(operationShape).errorParser(target) if (parser != null) { - rust("let mut builder = #T::default();", builderSymbol(target)) + rust("let mut builder = #T::default();", symbolProvider.symbolForBuilder(target)) rustTemplate( """ builder = #{parser}(&message.payload()[..], builder) .map_err(|err| { #{Error}::unmarshalling(format!("failed to unmarshall ${member.memberName}: {}", err)) })?; + builder.set_meta(Some(generic)); return Ok(#{UnmarshalledMessage}::Error( - #{OpError}::new( - #{OpError}Kind::${member.target.name}(builder.build()), - generic, - ) + #{OpError}::${member.target.name}(builder.build()) )) """, "parser" to parser, @@ -355,11 +360,12 @@ class EventStreamUnmarshallerGenerator( ) } } + CodegenTarget.SERVER -> { val target = model.expectShape(member.target, StructureShape::class.java) val parser = protocol.structuredDataParser(operationShape).errorParser(target) val mut = if (parser != null) { " mut" } else { "" } - rust("let$mut builder = #T::default();", builderSymbol(target)) + rust("let$mut builder = #T::default();", symbolProvider.symbolForBuilder(target)) if (parser != null) { rustTemplate( """ @@ -396,6 +402,7 @@ class EventStreamUnmarshallerGenerator( CodegenTarget.CLIENT -> { rustTemplate("Ok(#{UnmarshalledMessage}::Error(#{OpError}::generic(generic)))", *codegenScope) } + CodegenTarget.SERVER -> { rustTemplate( """ diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt index 1e8f8151f7a..f7a38f0fb28 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGenerator.kt @@ -28,7 +28,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.escape -import software.amazon.smithy.rust.codegen.core.rustlang.render import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate @@ -77,8 +76,6 @@ class JsonParserGenerator( private val httpBindingResolver: HttpBindingResolver, /** Function that maps a MemberShape into a JSON field name */ private val jsonName: (MemberShape) -> String, - /** Function that maps a StructureShape into its builder symbol */ - private val builderSymbol: (StructureShape) -> Symbol, /** * Whether we should parse a value for a shape into its associated unconstrained type. For example, when the shape * is a `StructureShape`, we should construct and return a builder instead of building into the final `struct` the @@ -153,7 +150,9 @@ class JsonParserGenerator( override fun payloadParser(member: MemberShape): RuntimeType { val shape = model.expectShape(member.target) - check(shape is UnionShape || shape is StructureShape || shape is DocumentShape) { "payload parser should only be used on structures & unions" } + check(shape is UnionShape || shape is StructureShape || shape is DocumentShape) { + "payload parser should only be used on structures & unions" + } val fnName = symbolProvider.deserializeFunctionName(shape) + "_payload" return RuntimeType.forInlineFun(fnName, jsonDeserModule) { rustBlockTemplate( @@ -191,7 +190,7 @@ class JsonParserGenerator( } val outputShape = operationShape.outputShape(model) val fnName = symbolProvider.deserializeFunctionName(operationShape) - return structureParser(fnName, builderSymbol(outputShape), httpDocumentMembers) + return structureParser(fnName, symbolProvider.symbolForBuilder(outputShape), httpDocumentMembers) } override fun errorParser(errorShape: StructureShape): RuntimeType? { @@ -199,7 +198,7 @@ class JsonParserGenerator( return null } val fnName = symbolProvider.deserializeFunctionName(errorShape) + "_json_err" - return structureParser(fnName, builderSymbol(errorShape), errorShape.members().toList()) + return structureParser(fnName, symbolProvider.symbolForBuilder(errorShape), errorShape.members().toList()) } private fun orEmptyJson(): RuntimeType = RuntimeType.forInlineFun("or_empty_doc", jsonDeserModule) { @@ -223,7 +222,7 @@ class JsonParserGenerator( } val inputShape = operationShape.inputShape(model) val fnName = symbolProvider.deserializeFunctionName(operationShape) - return structureParser(fnName, builderSymbol(inputShape), includedMembers) + return structureParser(fnName, symbolProvider.symbolForBuilder(inputShape), includedMembers) } private fun RustWriter.expectEndOfTokenStream() { @@ -493,7 +492,11 @@ class JsonParserGenerator( ) { startObjectOrNull { Attribute.AllowUnusedMut.render(this) - rustTemplate("let mut builder = #{Builder}::default();", *codegenScope, "Builder" to builderSymbol(shape)) + rustTemplate( + "let mut builder = #{Builder}::default();", + *codegenScope, + "Builder" to symbolProvider.symbolForBuilder(shape), + ) deserializeStructInner(shape.members()) // Only call `build()` if the builder is not fallible. Otherwise, return the builder. if (returnSymbolToParse.isUnconstrained) { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/RestXmlParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/RestXmlParserGenerator.kt index f09598e6f17..d37413f29f2 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/RestXmlParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/RestXmlParserGenerator.kt @@ -5,8 +5,6 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse -import software.amazon.smithy.codegen.core.Symbol -import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType @@ -19,12 +17,10 @@ import software.amazon.smithy.rust.codegen.core.util.orNull class RestXmlParserGenerator( codegenContext: CodegenContext, xmlErrors: RuntimeType, - builderSymbol: (shape: StructureShape) -> Symbol, private val xmlBindingTraitParserGenerator: XmlBindingTraitParserGenerator = XmlBindingTraitParserGenerator( codegenContext, xmlErrors, - builderSymbol, ) { context, inner -> val shapeName = context.outputShapeName // Get the non-synthetic version of the outputShape and check to see if it has the `AllowInvalidXmlRoot` trait diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt index 74b71bb7a90..1a05bfbf33c 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt @@ -7,7 +7,6 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse import software.amazon.smithy.aws.traits.customizations.S3UnwrappedXmlOutputTrait import software.amazon.smithy.codegen.core.CodegenException -import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.HttpBinding import software.amazon.smithy.model.knowledge.HttpBindingIndex @@ -71,7 +70,6 @@ data class OperationWrapperContext( class XmlBindingTraitParserGenerator( codegenContext: CodegenContext, private val xmlErrors: RuntimeType, - private val builderSymbol: (shape: StructureShape) -> Symbol, private val writeOperationWrapper: RustWriter.(OperationWrapperContext, OperationInnerWriteable) -> Unit, ) : StructuredDataParserGenerator { @@ -131,7 +129,9 @@ class XmlBindingTraitParserGenerator( */ override fun payloadParser(member: MemberShape): RuntimeType { val shape = model.expectShape(member.target) - check(shape is UnionShape || shape is StructureShape) { "payload parser should only be used on structures & unions" } + check(shape is UnionShape || shape is StructureShape) { + "payload parser should only be used on structures & unions" + } val fnName = symbolProvider.deserializeFunctionName(member) return RuntimeType.forInlineFun(fnName, xmlDeserModule) { rustBlock( @@ -188,7 +188,7 @@ class XmlBindingTraitParserGenerator( Attribute.AllowUnusedMut.render(this) rustBlock( "pub fn $fnName(inp: &[u8], mut builder: #1T) -> Result<#1T, #2T>", - builderSymbol(outputShape), + symbolProvider.symbolForBuilder(outputShape), xmlDecodeError, ) { rustTemplate( @@ -221,7 +221,7 @@ class XmlBindingTraitParserGenerator( Attribute.AllowUnusedMut.render(this) rustBlock( "pub fn $fnName(inp: &[u8], mut builder: #1T) -> Result<#1T, #2T>", - builderSymbol(errorShape), + symbolProvider.symbolForBuilder(errorShape), xmlDecodeError, ) { val members = errorShape.errorXmlMembers() @@ -255,7 +255,7 @@ class XmlBindingTraitParserGenerator( Attribute.AllowUnusedMut.render(this) rustBlock( "pub fn $fnName(inp: &[u8], mut builder: #1T) -> Result<#1T, #2T>", - builderSymbol(inputShape), + symbolProvider.symbolForBuilder(inputShape), xmlDecodeError, ) { rustTemplate( @@ -456,7 +456,7 @@ class XmlBindingTraitParserGenerator( private fun RustWriter.case(member: MemberShape, inner: Writable) { rustBlock( "s if ${ - member.xmlName().matchExpression("s") + member.xmlName().matchExpression("s") } /* ${member.memberName} ${escape(member.id.toString())} */ => ", ) { inner() diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamErrorMarshallerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamErrorMarshallerGenerator.kt index e4ec2cec3ac..3a0c5c1b30b 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamErrorMarshallerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamErrorMarshallerGenerator.kt @@ -23,9 +23,9 @@ import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider -import software.amazon.smithy.rust.codegen.core.smithy.generators.error.eventStreamErrorSymbol import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant import software.amazon.smithy.rust.codegen.core.smithy.generators.unknownVariantError +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.eventStreamSerdeModule import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticEventStreamUnionTrait import software.amazon.smithy.rust.codegen.core.smithy.transformers.eventStreamErrors @@ -48,9 +48,9 @@ class EventStreamErrorMarshallerGenerator( private val operationErrorSymbol = if (target == CodegenTarget.SERVER && unionShape.eventStreamErrors().isEmpty()) { RuntimeType.smithyHttp(runtimeConfig).resolve("event_stream::MessageStreamError").toSymbol() } else { - unionShape.eventStreamErrorSymbol(symbolProvider).toSymbol() + symbolProvider.symbolForEventStreamError(unionShape) } - private val eventStreamSerdeModule = RustModule.private("event_stream_serde") + private val eventStreamSerdeModule = RustModule.eventStreamSerdeModule() private val errorsShape = unionShape.expectTrait() private val codegenScope = arrayOf( "MarshallMessage" to smithyEventStream.resolve("frame::MarshallMessage"), @@ -96,25 +96,15 @@ class EventStreamErrorMarshallerGenerator( ) { rust("let mut headers = Vec::new();") addStringHeader(":message-type", """"exception".into()""") - val kind = when (target) { - CodegenTarget.CLIENT -> ".kind" - CodegenTarget.SERVER -> "" - } if (errorsShape.errorMembers.isEmpty()) { rust("let payload = Vec::new();") } else { - rustBlock("let payload = match _input$kind") { - val symbol = operationErrorSymbol - val errorName = when (target) { - CodegenTarget.CLIENT -> "${symbol}Kind" - CodegenTarget.SERVER -> "$symbol" - } - + rustBlock("let payload = match _input") { errorsShape.errorMembers.forEach { error -> - val errorSymbol = symbolProvider.toSymbol(error) val errorString = error.memberName val target = model.expectShape(error.target, StructureShape::class.java) - rustBlock("$errorName::${errorSymbol.name}(inner) => ") { + val targetSymbol = symbolProvider.toSymbol(target) + rustBlock("#T::${targetSymbol.name}(inner) => ", operationErrorSymbol) { addStringHeader(":exception-type", "${errorString.dq()}.into()") renderMarshallEvent(error, target) } @@ -122,11 +112,12 @@ class EventStreamErrorMarshallerGenerator( if (target.renderUnknownVariant()) { rustTemplate( """ - $errorName::Unhandled(_inner) => return Err( + #{OperationError}::Unhandled(_inner) => return Err( #{Error}::marshalling(${unknownVariantError(unionSymbol.rustType().name).dq()}.to_owned()) ), """, *codegenScope, + "OperationError" to operationErrorSymbol, ) } } @@ -136,7 +127,7 @@ class EventStreamErrorMarshallerGenerator( } } - fun RustWriter.renderMarshallEvent(unionMember: MemberShape, eventStruct: StructureShape) { + private fun RustWriter.renderMarshallEvent(unionMember: MemberShape, eventStruct: StructureShape) { val headerMembers = eventStruct.members().filter { it.hasTrait() } val payloadMember = eventStruct.members().firstOrNull { it.hasTrait() } for (member in headerMembers) { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt index cb6833aaf75..201cd82ed5d 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt @@ -38,6 +38,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant import software.amazon.smithy.rust.codegen.core.smithy.generators.unknownVariantError import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.eventStreamSerdeModule import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.hasTrait @@ -53,7 +54,7 @@ open class EventStreamMarshallerGenerator( private val payloadContentType: String, ) { private val smithyEventStream = RuntimeType.smithyEventStream(runtimeConfig) - private val eventStreamSerdeModule = RustModule.private("event_stream_serde") + private val eventStreamSerdeModule = RustModule.eventStreamSerdeModule() private val codegenScope = arrayOf( "MarshallMessage" to smithyEventStream.resolve("frame::MarshallMessage"), "Message" to smithyEventStream.resolve("frame::Message"), diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGenerator.kt index 3d69218e253..81430c51fb2 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGenerator.kt @@ -157,7 +157,7 @@ class XmlBindingTraitSerializerGenerator( let mut writer = #{XmlWriter}::new(&mut out); ##[allow(unused_mut)] let mut root = writer.start_el(${xmlIndex.payloadShapeName(member).dq()})${ - target.xmlNamespace(root = true).apply() + target.xmlNamespace(root = true).apply() }; """, *codegenScope, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/RustBoxTrait.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/RustBoxTrait.kt index 2d5a3134012..41144d5945c 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/RustBoxTrait.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/traits/RustBoxTrait.kt @@ -12,7 +12,8 @@ import software.amazon.smithy.model.traits.Trait /** * Trait indicating that this shape should be represented with `Box` when converted into Rust * - * This is used to handle recursive shapes. See RecursiveShapeBoxer. + * This is used to handle recursive shapes. + * See [software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer]. * * This trait is synthetic, applied during code generation, and never used in actual models. */ diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizer.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizer.kt index 1a00d431aea..78a6852fd6b 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizer.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/OperationNormalizer.kt @@ -59,7 +59,9 @@ object OperationNormalizer { val shapeConflict = newShapes.firstOrNull { shape -> model.getShape(shape.id).isPresent } check( shapeConflict == null, - ) { "shape $shapeConflict conflicted with an existing shape in the model (${model.getShape(shapeConflict!!.id)}. This is a bug." } + ) { + "shape $shapeConflict conflicted with an existing shape in the model (${model.getShape(shapeConflict!!.id)}. This is a bug." + } val modelWithOperationInputs = model.toBuilder().addShapes(newShapes).build() return transformer.mapShapes(modelWithOperationInputs) { // Update all operations to point to their new input/output shapes diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapeBoxer.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapeBoxer.kt index 4b47801a8d5..d53751829fb 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapeBoxer.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapeBoxer.kt @@ -7,25 +7,50 @@ package software.amazon.smithy.rust.codegen.core.smithy.transformers import software.amazon.smithy.codegen.core.TopologicalIndex import software.amazon.smithy.model.Model -import software.amazon.smithy.model.shapes.ListShape +import software.amazon.smithy.model.shapes.CollectionShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape -import software.amazon.smithy.model.shapes.SetShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.transform.ModelTransformer import software.amazon.smithy.rust.codegen.core.smithy.traits.RustBoxTrait import software.amazon.smithy.rust.codegen.core.util.hasTrait -object RecursiveShapeBoxer { +class RecursiveShapeBoxer( /** - * Transform a model which may contain recursive shapes into a model annotated with [RustBoxTrait] + * A predicate that determines when a cycle in the shape graph contains "indirection". If a cycle contains + * indirection, no shape needs to be tagged. What constitutes indirection is up to the caller to decide. + */ + private val containsIndirectionPredicate: (Collection) -> Boolean = ::containsIndirection, + /** + * A closure that gets called on one member shape of a cycle that does not contain indirection for "fixing". For + * example, the [RustBoxTrait] trait can be used to tag the member shape. + */ + private val boxShapeFn: (MemberShape) -> MemberShape = ::addRustBoxTrait, +) { + /** + * Transform a model which may contain recursive shapes. * - * When recursive shapes do NOT go through a List, Map, or Set, they must be boxed in Rust. This function will - * iteratively find loops & add the `RustBox` trait in a deterministic way until it reaches a fixed point. + * For example, when recursive shapes do NOT go through a `CollectionShape` or a `MapShape` shape, they must be + * boxed in Rust. This function will iteratively find cycles and call [boxShapeFn] on a member shape in the + * cycle to act on it. This is done in a deterministic way until it reaches a fixed point. * - * This function MUST be deterministic (always choose the same shapes to `Box`). If it is not, that is a bug. Even so + * This function MUST be deterministic (always choose the same shapes to fix). If it is not, that is a bug. Even so * this function may cause backward compatibility issues in certain pathological cases where a changes to recursive * structures cause different members to be boxed. We may need to address these via customizations. + * + * For example, given the following model, + * + * ```smithy + * namespace com.example + * + * structure Recursive { + * recursiveStruct: Recursive + * anotherField: Boolean + * } + * ``` + * + * The `com.example#Recursive$recursiveStruct` member shape is part of a cycle, but the + * `com.example#Recursive$anotherField` member shape is not. */ fun transform(model: Model): Model { val next = transformInner(model) @@ -37,16 +62,17 @@ object RecursiveShapeBoxer { } /** - * If [model] contains a recursive loop that must be boxed, apply one instance of [RustBoxTrait] return the new model. - * If [model] contains no loops, return null. + * If [model] contains a recursive loop that must be boxed, return the transformed model resulting form a call to + * [boxShapeFn]. + * If [model] contains no loops, return `null`. */ private fun transformInner(model: Model): Model? { - // Execute 1-step of the boxing algorithm in the path to reaching a fixed point - // 1. Find all the shapes that are part of a cycle - // 2. Find all the loops that those shapes are part of - // 3. Filter out the loops that go through a layer of indirection - // 3. Pick _just one_ of the remaining loops to fix - // 4. Select the member shape in that loop with the earliest shape id + // Execute 1 step of the boxing algorithm in the path to reaching a fixed point: + // 1. Find all the shapes that are part of a cycle. + // 2. Find all the loops that those shapes are part of. + // 3. Filter out the loops that go through a layer of indirection. + // 3. Pick _just one_ of the remaining loops to fix. + // 4. Select the member shape in that loop with the earliest shape id. // 5. Box it. // (External to this function) Go back to 1. val index = TopologicalIndex.of(model) @@ -58,34 +84,38 @@ object RecursiveShapeBoxer { // Flatten the connections into shapes. loops.map { it.shapes } } - val loopToFix = loops.firstOrNull { !containsIndirection(it) } + val loopToFix = loops.firstOrNull { !containsIndirectionPredicate(it) } return loopToFix?.let { loop: List -> check(loop.isNotEmpty()) - // pick the shape to box in a deterministic way + // Pick the shape to box in a deterministic way. val shapeToBox = loop.filterIsInstance().minByOrNull { it.id }!! ModelTransformer.create().mapShapes(model) { shape -> if (shape == shapeToBox) { - shape.asMemberShape().get().toBuilder().addTrait(RustBoxTrait()).build() + boxShapeFn(shape.asMemberShape().get()) } else { shape } } } } +} - /** - * Check if a List contains a shape which will use a pointer when represented in Rust, avoiding the - * need to add more Boxes - */ - private fun containsIndirection(loop: List): Boolean { - return loop.find { - when (it) { - is ListShape, - is MapShape, - is SetShape, -> true - else -> it.hasTrait() - } - } != null +/** + * Check if a `List` contains a shape which will use a pointer when represented in Rust, avoiding the + * need to add more `Box`es. + * + * Why `CollectionShape`s and `MapShape`s? Note that `CollectionShape`s get rendered in Rust as `Vec`, and + * `MapShape`s as `HashMap`; they're the only Smithy shapes that "organically" introduce indirection + * (via a pointer to the heap) in the recursive path. For other recursive paths, we thus have to introduce the + * indirection artificially ourselves using `Box`. + * + */ +private fun containsIndirection(loop: Collection): Boolean = loop.find { + when (it) { + is CollectionShape, is MapShape -> true + else -> it.hasTrait() } -} +} != null + +private fun addRustBoxTrait(shape: MemberShape): MemberShape = shape.toBuilder().addTrait(RustBoxTrait()).build() diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/CodegenIntegrationTest.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/CodegenIntegrationTest.kt new file mode 100644 index 00000000000..d1f208c3931 --- /dev/null +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/CodegenIntegrationTest.kt @@ -0,0 +1,45 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.core.testutil + +import software.amazon.smithy.build.PluginContext +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.node.ObjectNode +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig +import software.amazon.smithy.rust.codegen.core.util.runCommand +import java.io.File +import java.nio.file.Path + +/** + * A helper class holding common data with defaults that is threaded through several functions, to make their + * signatures shorter. + */ +data class IntegrationTestParams( + val addModuleToEventStreamAllowList: Boolean = false, + val service: String? = null, + val runtimeConfig: RuntimeConfig? = null, + val additionalSettings: ObjectNode = ObjectNode.builder().build(), + val overrideTestDir: File? = null, + val command: ((Path) -> Unit)? = null, +) + +/** + * Run cargo test on a true, end-to-end, codegen product of a given model. + */ +fun codegenIntegrationTest(model: Model, params: IntegrationTestParams, invokePlugin: (PluginContext) -> Unit): Path { + val (ctx, testDir) = generatePluginContext( + model, + params.additionalSettings, + params.addModuleToEventStreamAllowList, + params.service, + params.runtimeConfig, + params.overrideTestDir, + ) + invokePlugin(ctx) + ctx.fileManifest.printGeneratedFiles() + params.command?.invoke(testDir) ?: "cargo test".runCommand(testDir, environment = mapOf("RUSTFLAGS" to "-D warnings")) + return testDir +} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamMarshallTestCases.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamMarshallTestCases.kt index 6e82fc1b2ca..95ea5677a4c 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamMarshallTestCases.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamMarshallTestCases.kt @@ -5,26 +5,35 @@ package software.amazon.smithy.rust.codegen.core.testutil +import org.intellij.lang.annotations.Language import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.util.dq -internal object EventStreamMarshallTestCases { - internal fun RustWriter.writeMarshallTestCases( +object EventStreamMarshallTestCases { + fun RustWriter.writeMarshallTestCases( testCase: EventStreamTestModels.TestCase, - generator: RuntimeType, + optionalBuilderInputs: Boolean, ) { + val generator = "crate::event_stream_serde::TestStreamMarshaller" + val protocolTestHelpers = CargoDependency.smithyProtocolTestHelpers(TestRuntimeConfig) .copy(scope = DependencyScope.Compile) + + fun builderInput( + @Language("Rust", prefix = "macro_rules! foo { () => {{\n", suffix = "\n}}}") + input: String, + vararg ctx: Pair, + ): Writable = conditionalBuilderInput(input, conditional = optionalBuilderInputs, ctx = ctx) + rustTemplate( """ use aws_smithy_eventstream::frame::{Message, Header, HeaderValue, MarshallMessage}; use std::collections::HashMap; use aws_smithy_types::{Blob, DateTime}; - use crate::error::*; use crate::model::*; use #{validate_body}; @@ -46,163 +55,192 @@ internal object EventStreamMarshallTestCases { "MediaType" to protocolTestHelpers.toType().resolve("MediaType"), ) - unitTest( - "message_with_blob", - """ - let event = TestStream::MessageWithBlob( - MessageWithBlob::builder().data(Blob::new(&b"hello, world!"[..])).build() - ); - let result = ${format(generator)}().marshall(event); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - let message = result.unwrap(); - let headers = headers_to_map(message.headers()); - assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); - assert_eq!(&str_header("MessageWithBlob"), *headers.get(":event-type").unwrap()); - assert_eq!(&str_header("application/octet-stream"), *headers.get(":content-type").unwrap()); - assert_eq!(&b"hello, world!"[..], message.payload()); - """, - ) - - unitTest( - "message_with_string", - """ - let event = TestStream::MessageWithString( - MessageWithString::builder().data("hello, world!").build() - ); - let result = ${format(generator)}().marshall(event); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - let message = result.unwrap(); - let headers = headers_to_map(message.headers()); - assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); - assert_eq!(&str_header("MessageWithString"), *headers.get(":event-type").unwrap()); - assert_eq!(&str_header("text/plain"), *headers.get(":content-type").unwrap()); - assert_eq!(&b"hello, world!"[..], message.payload()); - """, - ) - - unitTest( - "message_with_struct", - """ - let event = TestStream::MessageWithStruct( - MessageWithStruct::builder().some_struct( - TestStruct::builder() - .some_string("hello") - .some_int(5) - .build() - ).build() - ); - let result = ${format(generator)}().marshall(event); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - let message = result.unwrap(); - let headers = headers_to_map(message.headers()); - assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); - assert_eq!(&str_header("MessageWithStruct"), *headers.get(":event-type").unwrap()); - assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap()); - - validate_body( - message.payload(), - ${testCase.validTestStruct.dq()}, - MediaType::from(${testCase.requestContentType.dq()}) - ).unwrap(); - """, - ) - - unitTest( - "message_with_union", - """ - let event = TestStream::MessageWithUnion(MessageWithUnion::builder().some_union( - TestUnion::Foo("hello".into()) - ).build()); - let result = ${format(generator)}().marshall(event); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - let message = result.unwrap(); - let headers = headers_to_map(message.headers()); - assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); - assert_eq!(&str_header("MessageWithUnion"), *headers.get(":event-type").unwrap()); - assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap()); - - validate_body( - message.payload(), - ${testCase.validTestUnion.dq()}, - MediaType::from(${testCase.requestContentType.dq()}) - ).unwrap(); - """, - ) - - unitTest( - "message_with_headers", - """ - let event = TestStream::MessageWithHeaders(MessageWithHeaders::builder() - .blob(Blob::new(&b"test"[..])) - .boolean(true) - .byte(55i8) - .int(100_000i32) - .long(9_000_000_000i64) - .short(16_000i16) - .string("test") - .timestamp(DateTime::from_secs(5)) - .build() - ); - let result = ${format(generator)}().marshall(event); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - let actual_message = result.unwrap(); - let expected_message = Message::new(&b""[..]) - .add_header(Header::new(":message-type", HeaderValue::String("event".into()))) - .add_header(Header::new(":event-type", HeaderValue::String("MessageWithHeaders".into()))) - .add_header(Header::new("blob", HeaderValue::ByteArray((&b"test"[..]).into()))) - .add_header(Header::new("boolean", HeaderValue::Bool(true))) - .add_header(Header::new("byte", HeaderValue::Byte(55i8))) - .add_header(Header::new("int", HeaderValue::Int32(100_000i32))) - .add_header(Header::new("long", HeaderValue::Int64(9_000_000_000i64))) - .add_header(Header::new("short", HeaderValue::Int16(16_000i16))) - .add_header(Header::new("string", HeaderValue::String("test".into()))) - .add_header(Header::new("timestamp", HeaderValue::Timestamp(DateTime::from_secs(5)))); - assert_eq!(expected_message, actual_message); - """, - ) - - unitTest( - "message_with_header_and_payload", - """ - let event = TestStream::MessageWithHeaderAndPayload(MessageWithHeaderAndPayload::builder() - .header("header") - .payload(Blob::new(&b"payload"[..])) - .build() - ); - let result = ${format(generator)}().marshall(event); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - let actual_message = result.unwrap(); - let expected_message = Message::new(&b"payload"[..]) - .add_header(Header::new(":message-type", HeaderValue::String("event".into()))) - .add_header(Header::new(":event-type", HeaderValue::String("MessageWithHeaderAndPayload".into()))) - .add_header(Header::new("header", HeaderValue::String("header".into()))) - .add_header(Header::new(":content-type", HeaderValue::String("application/octet-stream".into()))); - assert_eq!(expected_message, actual_message); - """, - ) - - unitTest( - "message_with_no_header_payload_traits", - """ - let event = TestStream::MessageWithNoHeaderPayloadTraits(MessageWithNoHeaderPayloadTraits::builder() - .some_int(5) - .some_string("hello") - .build() - ); - let result = ${format(generator)}().marshall(event); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - let message = result.unwrap(); - let headers = headers_to_map(message.headers()); - assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); - assert_eq!(&str_header("MessageWithNoHeaderPayloadTraits"), *headers.get(":event-type").unwrap()); - assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap()); - - validate_body( - message.payload(), - ${testCase.validMessageWithNoHeaderPayloadTraits.dq()}, - MediaType::from(${testCase.requestContentType.dq()}) - ).unwrap(); - """, - ) + unitTest("message_with_blob") { + rustTemplate( + """ + let event = TestStream::MessageWithBlob( + MessageWithBlob::builder().data(#{BlobInput:W}).build() + ); + let result = $generator::new().marshall(event); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + let message = result.unwrap(); + let headers = headers_to_map(message.headers()); + assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); + assert_eq!(&str_header("MessageWithBlob"), *headers.get(":event-type").unwrap()); + assert_eq!(&str_header("application/octet-stream"), *headers.get(":content-type").unwrap()); + assert_eq!(&b"hello, world!"[..], message.payload()); + """, + "BlobInput" to builderInput("Blob::new(&b\"hello, world!\"[..])"), + ) + } + + unitTest("message_with_string") { + rustTemplate( + """ + let event = TestStream::MessageWithString( + MessageWithString::builder().data(#{StringInput}).build() + ); + let result = $generator::new().marshall(event); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + let message = result.unwrap(); + let headers = headers_to_map(message.headers()); + assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); + assert_eq!(&str_header("MessageWithString"), *headers.get(":event-type").unwrap()); + assert_eq!(&str_header("text/plain"), *headers.get(":content-type").unwrap()); + assert_eq!(&b"hello, world!"[..], message.payload()); + """, + "StringInput" to builderInput("\"hello, world!\""), + ) + } + + unitTest("message_with_struct") { + rustTemplate( + """ + let event = TestStream::MessageWithStruct( + MessageWithStruct::builder().some_struct(#{StructInput}).build() + ); + let result = $generator::new().marshall(event); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + let message = result.unwrap(); + let headers = headers_to_map(message.headers()); + assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); + assert_eq!(&str_header("MessageWithStruct"), *headers.get(":event-type").unwrap()); + assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap()); + + validate_body( + message.payload(), + ${testCase.validTestStruct.dq()}, + MediaType::from(${testCase.mediaType.dq()}) + ).unwrap(); + """, + "StructInput" to + builderInput( + """ + TestStruct::builder() + .some_string(#{StringInput}) + .some_int(#{IntInput}) + .build() + """, + "IntInput" to builderInput("5"), + "StringInput" to builderInput("\"hello\""), + ), + ) + } + + unitTest("message_with_union") { + rustTemplate( + """ + let event = TestStream::MessageWithUnion(MessageWithUnion::builder() + .some_union(#{UnionInput}) + .build() + ); + let result = $generator::new().marshall(event); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + let message = result.unwrap(); + let headers = headers_to_map(message.headers()); + assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); + assert_eq!(&str_header("MessageWithUnion"), *headers.get(":event-type").unwrap()); + assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap()); + + validate_body( + message.payload(), + ${testCase.validTestUnion.dq()}, + MediaType::from(${testCase.mediaType.dq()}) + ).unwrap(); + """, + "UnionInput" to builderInput("TestUnion::Foo(\"hello\".into())"), + ) + } + + unitTest("message_with_headers") { + rustTemplate( + """ + let event = TestStream::MessageWithHeaders(MessageWithHeaders::builder() + .blob(#{BlobInput}) + .boolean(#{BooleanInput}) + .byte(#{ByteInput}) + .int(#{IntInput}) + .long(#{LongInput}) + .short(#{ShortInput}) + .string(#{StringInput}) + .timestamp(#{TimestampInput}) + .build() + ); + let result = $generator::new().marshall(event); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + let actual_message = result.unwrap(); + let expected_message = Message::new(&b""[..]) + .add_header(Header::new(":message-type", HeaderValue::String("event".into()))) + .add_header(Header::new(":event-type", HeaderValue::String("MessageWithHeaders".into()))) + .add_header(Header::new("blob", HeaderValue::ByteArray((&b"test"[..]).into()))) + .add_header(Header::new("boolean", HeaderValue::Bool(true))) + .add_header(Header::new("byte", HeaderValue::Byte(55i8))) + .add_header(Header::new("int", HeaderValue::Int32(100_000i32))) + .add_header(Header::new("long", HeaderValue::Int64(9_000_000_000i64))) + .add_header(Header::new("short", HeaderValue::Int16(16_000i16))) + .add_header(Header::new("string", HeaderValue::String("test".into()))) + .add_header(Header::new("timestamp", HeaderValue::Timestamp(DateTime::from_secs(5)))); + assert_eq!(expected_message, actual_message); + """, + "BlobInput" to builderInput("Blob::new(&b\"test\"[..])"), + "BooleanInput" to builderInput("true"), + "ByteInput" to builderInput("55i8"), + "IntInput" to builderInput("100_000i32"), + "LongInput" to builderInput("9_000_000_000i64"), + "ShortInput" to builderInput("16_000i16"), + "StringInput" to builderInput("\"test\""), + "TimestampInput" to builderInput("DateTime::from_secs(5)"), + ) + } + + unitTest("message_with_header_and_payload") { + rustTemplate( + """ + let event = TestStream::MessageWithHeaderAndPayload(MessageWithHeaderAndPayload::builder() + .header(#{HeaderInput}) + .payload(#{PayloadInput}) + .build() + ); + let result = $generator::new().marshall(event); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + let actual_message = result.unwrap(); + let expected_message = Message::new(&b"payload"[..]) + .add_header(Header::new(":message-type", HeaderValue::String("event".into()))) + .add_header(Header::new(":event-type", HeaderValue::String("MessageWithHeaderAndPayload".into()))) + .add_header(Header::new("header", HeaderValue::String("header".into()))) + .add_header(Header::new(":content-type", HeaderValue::String("application/octet-stream".into()))); + assert_eq!(expected_message, actual_message); + """, + "HeaderInput" to builderInput("\"header\""), + "PayloadInput" to builderInput("Blob::new(&b\"payload\"[..])"), + ) + } + + unitTest("message_with_no_header_payload_traits") { + rustTemplate( + """ + let event = TestStream::MessageWithNoHeaderPayloadTraits(MessageWithNoHeaderPayloadTraits::builder() + .some_int(#{IntInput}) + .some_string(#{StringInput}) + .build() + ); + let result = $generator::new().marshall(event); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + let message = result.unwrap(); + let headers = headers_to_map(message.headers()); + assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); + assert_eq!(&str_header("MessageWithNoHeaderPayloadTraits"), *headers.get(":event-type").unwrap()); + assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap()); + + validate_body( + message.payload(), + ${testCase.validMessageWithNoHeaderPayloadTraits.dq()}, + MediaType::from(${testCase.mediaType.dq()}) + ).unwrap(); + """, + "IntInput" to builderInput("5"), + "StringInput" to builderInput("\"hello\""), + ) + } } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt index 58ab85eec60..e944a552a08 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt @@ -19,6 +19,7 @@ private fun fillInBaseModel( ): String = """ namespace test + use smithy.framework#ValidationException use aws.protocols#$protocolName union TestUnion { @@ -69,12 +70,20 @@ private fun fillInBaseModel( MessageWithNoHeaderPayloadTraits: MessageWithNoHeaderPayloadTraits, SomeError: SomeError, } - structure TestStreamInputOutput { @httpPayload @required value: TestStream } + + structure TestStreamInputOutput { + @required + @httpPayload + value: TestStream + } + + @http(method: "POST", uri: "/test") operation TestStreamOp { input: TestStreamInputOutput, output: TestStreamInputOutput, - errors: [SomeError], + errors: [SomeError, ValidationException], } + $extraServiceAnnotations @$protocolName service TestService { version: "123", operations: [TestStreamOp] } @@ -92,6 +101,7 @@ object EventStreamTestModels { data class TestCase( val protocolShapeId: String, val model: Model, + val mediaType: String, val requestContentType: String, val responseContentType: String, val validTestStruct: String, @@ -111,7 +121,8 @@ object EventStreamTestModels { TestCase( protocolShapeId = "aws.protocols#restJson1", model = restJson1(), - requestContentType = "application/json", + mediaType = "application/json", + requestContentType = "application/vnd.amazon.eventstream", responseContentType = "application/json", validTestStruct = """{"someString":"hello","someInt":5}""", validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""", @@ -126,6 +137,7 @@ object EventStreamTestModels { TestCase( protocolShapeId = "aws.protocols#awsJson1_1", model = awsJson11(), + mediaType = "application/x-amz-json-1.1", requestContentType = "application/x-amz-json-1.1", responseContentType = "application/x-amz-json-1.1", validTestStruct = """{"someString":"hello","someInt":5}""", @@ -141,7 +153,8 @@ object EventStreamTestModels { TestCase( protocolShapeId = "aws.protocols#restXml", model = restXml(), - requestContentType = "application/xml", + mediaType = "application/xml", + requestContentType = "application/vnd.amazon.eventstream", responseContentType = "application/xml", validTestStruct = """ diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestTools.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestTools.kt deleted file mode 100644 index 05151b7ce7c..00000000000 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestTools.kt +++ /dev/null @@ -1,176 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.rust.codegen.core.testutil - -import software.amazon.smithy.codegen.core.Symbol -import software.amazon.smithy.model.Model -import software.amazon.smithy.model.shapes.OperationShape -import software.amazon.smithy.model.shapes.ServiceShape -import software.amazon.smithy.model.shapes.Shape -import software.amazon.smithy.model.shapes.ShapeId -import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.model.shapes.UnionShape -import software.amazon.smithy.model.traits.ErrorTrait -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget -import software.amazon.smithy.rust.codegen.core.smithy.ErrorsModule -import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider -import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant -import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol -import software.amazon.smithy.rust.codegen.core.smithy.transformers.EventStreamNormalizer -import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer -import software.amazon.smithy.rust.codegen.core.testutil.EventStreamMarshallTestCases.writeMarshallTestCases -import software.amazon.smithy.rust.codegen.core.testutil.EventStreamUnmarshallTestCases.writeUnmarshallTestCases -import software.amazon.smithy.rust.codegen.core.util.hasTrait -import software.amazon.smithy.rust.codegen.core.util.lookup -import software.amazon.smithy.rust.codegen.core.util.outputShape - -data class TestEventStreamProject( - val model: Model, - val serviceShape: ServiceShape, - val operationShape: OperationShape, - val streamShape: UnionShape, - val symbolProvider: RustSymbolProvider, - val project: TestWriterDelegator, -) - -enum class EventStreamTestVariety { - Marshall, - Unmarshall -} - -interface EventStreamTestRequirements { - /** Create a codegen context for the tests */ - fun createCodegenContext( - model: Model, - serviceShape: ServiceShape, - protocolShapeId: ShapeId, - codegenTarget: CodegenTarget, - ): C - - /** Render the event stream marshall/unmarshall code generator */ - fun renderGenerator( - codegenContext: C, - project: TestEventStreamProject, - protocol: Protocol, - ): RuntimeType - - /** Render a builder for the given shape */ - fun renderBuilderForShape( - writer: RustWriter, - codegenContext: C, - shape: StructureShape, - ) - - /** Render an operation error for the given operation and error shapes */ - fun renderOperationError( - writer: RustWriter, - model: Model, - symbolProvider: RustSymbolProvider, - operationSymbol: Symbol, - errors: List, - ) -} - -object EventStreamTestTools { - fun runTestCase( - testCase: EventStreamTestModels.TestCase, - requirements: EventStreamTestRequirements, - codegenTarget: CodegenTarget, - variety: EventStreamTestVariety, - ) { - val model = EventStreamNormalizer.transform(OperationNormalizer.transform(testCase.model)) - val serviceShape = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape - val codegenContext = requirements.createCodegenContext( - model, - serviceShape, - ShapeId.from(testCase.protocolShapeId), - codegenTarget, - ) - val test = generateTestProject(requirements, codegenContext, codegenTarget) - val protocol = testCase.protocolBuilder(codegenContext) - val generator = requirements.renderGenerator(codegenContext, test, protocol) - - test.project.lib { - when (variety) { - EventStreamTestVariety.Marshall -> writeMarshallTestCases(testCase, generator) - EventStreamTestVariety.Unmarshall -> writeUnmarshallTestCases(testCase, codegenTarget, generator) - } - } - test.project.compileAndTest() - } - - private fun generateTestProject( - requirements: EventStreamTestRequirements, - codegenContext: C, - codegenTarget: CodegenTarget, - ): TestEventStreamProject { - val model = codegenContext.model - val symbolProvider = codegenContext.symbolProvider - val operationShape = model.expectShape(ShapeId.from("test#TestStreamOp")) as OperationShape - val unionShape = model.expectShape(ShapeId.from("test#TestStream")) as UnionShape - - val project = TestWorkspace.testProject(symbolProvider) - val operationSymbol = symbolProvider.toSymbol(operationShape) - project.withModule(ErrorsModule) { - val errors = model.structureShapes.filter { shape -> shape.hasTrait() } - requirements.renderOperationError(this, model, symbolProvider, operationSymbol, errors) - requirements.renderOperationError(this, model, symbolProvider, symbolProvider.toSymbol(unionShape), errors) - for (shape in errors) { - StructureGenerator(model, symbolProvider, this, shape).render(codegenTarget) - requirements.renderBuilderForShape(this, codegenContext, shape) - } - } - project.withModule(ModelsModule) { - val inputOutput = model.lookup("test#TestStreamInputOutput") - recursivelyGenerateModels(model, symbolProvider, inputOutput, this, codegenTarget) - } - project.withModule(RustModule.Output) { - operationShape.outputShape(model).renderWithModelBuilder(model, symbolProvider, this) - } - return TestEventStreamProject( - model, - codegenContext.serviceShape, - operationShape, - unionShape, - symbolProvider, - project, - ) - } - - private fun recursivelyGenerateModels( - model: Model, - symbolProvider: RustSymbolProvider, - shape: Shape, - writer: RustWriter, - mode: CodegenTarget, - ) { - for (member in shape.members()) { - if (member.target.namespace == "smithy.api") { - continue - } - val target = model.expectShape(member.target) - when (target) { - is StructureShape -> target.renderWithModelBuilder(model, symbolProvider, writer) - is UnionShape -> UnionGenerator( - model, - symbolProvider, - writer, - target, - renderUnknownVariant = mode.renderUnknownVariant(), - ).render() - else -> TODO("EventStreamTestTools doesn't support rendering $target") - } - recursivelyGenerateModels(model, symbolProvider, target, writer, mode) - } - } -} diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamUnmarshallTestCases.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamUnmarshallTestCases.kt index bb27c724e00..daff01af3f0 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamUnmarshallTestCases.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamUnmarshallTestCases.kt @@ -5,22 +5,27 @@ package software.amazon.smithy.rust.codegen.core.testutil +import org.intellij.lang.annotations.Language +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock import software.amazon.smithy.rust.codegen.core.rustlang.rust -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable -internal object EventStreamUnmarshallTestCases { - internal fun RustWriter.writeUnmarshallTestCases( +object EventStreamUnmarshallTestCases { + fun RustWriter.writeUnmarshallTestCases( testCase: EventStreamTestModels.TestCase, - codegenTarget: CodegenTarget, - generator: RuntimeType, + optionalBuilderInputs: Boolean = false, ) { + val generator = "crate::event_stream_serde::TestStreamUnmarshaller" + rust( """ use aws_smithy_eventstream::frame::{Header, HeaderValue, Message, UnmarshallMessage, UnmarshalledMessage}; use aws_smithy_types::{Blob, DateTime}; - use crate::error::*; + use crate::error::TestStreamError; use crate::model::*; fn msg( @@ -53,206 +58,199 @@ internal object EventStreamUnmarshallTestCases { """, ) - unitTest( - name = "message_with_blob", - test = """ + unitTest("message_with_blob") { + rustTemplate( + """ let message = msg("event", "MessageWithBlob", "application/octet-stream", b"hello, world!"); - let result = ${format(generator)}().unmarshall(&message); + let result = $generator::new().unmarshall(&message); assert!(result.is_ok(), "expected ok, got: {:?}", result); assert_eq!( TestStream::MessageWithBlob( - MessageWithBlob::builder().data(Blob::new(&b"hello, world!"[..])).build() + MessageWithBlob::builder().data(#{DataInput:W}).build() ), expect_event(result.unwrap()) ); - """, - ) + """, + "DataInput" to conditionalBuilderInput( + """ + Blob::new(&b"hello, world!"[..]) + """, + conditional = optionalBuilderInputs, + ), - if (codegenTarget == CodegenTarget.CLIENT) { - unitTest( - "unknown_message", + ) + } + + unitTest("message_with_string") { + rustTemplate( """ - let message = msg("event", "NewUnmodeledMessageType", "application/octet-stream", b"hello, world!"); - let result = ${format(generator)}().unmarshall(&message); + let message = msg("event", "MessageWithString", "text/plain", b"hello, world!"); + let result = $generator::new().unmarshall(&message); assert!(result.is_ok(), "expected ok, got: {:?}", result); assert_eq!( - TestStream::Unknown, + TestStream::MessageWithString(MessageWithString::builder().data(#{DataInput}).build()), expect_event(result.unwrap()) ); """, + "DataInput" to conditionalBuilderInput("\"hello, world!\"", conditional = optionalBuilderInputs), ) } - unitTest( - "message_with_string", - """ - let message = msg("event", "MessageWithString", "text/plain", b"hello, world!"); - let result = ${format(generator)}().unmarshall(&message); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - assert_eq!( - TestStream::MessageWithString(MessageWithString::builder().data("hello, world!").build()), - expect_event(result.unwrap()) - ); - """, - ) - - unitTest( - "message_with_struct", - """ - let message = msg( - "event", - "MessageWithStruct", - "${testCase.responseContentType}", - br#"${testCase.validTestStruct}"# - ); - let result = ${format(generator)}().unmarshall(&message); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - assert_eq!( - TestStream::MessageWithStruct(MessageWithStruct::builder().some_struct( + unitTest("message_with_struct") { + rustTemplate( + """ + let message = msg( + "event", + "MessageWithStruct", + "${testCase.responseContentType}", + br##"${testCase.validTestStruct}"## + ); + let result = $generator::new().unmarshall(&message); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + assert_eq!( + TestStream::MessageWithStruct(MessageWithStruct::builder().some_struct(#{StructInput}).build()), + expect_event(result.unwrap()) + ); + """, + "StructInput" to conditionalBuilderInput( + """ TestStruct::builder() - .some_string("hello") - .some_int(5) + .some_string(#{StringInput}) + .some_int(#{IntInput}) .build() - ).build()), - expect_event(result.unwrap()) - ); - """, - ) + """, + conditional = optionalBuilderInputs, + "StringInput" to conditionalBuilderInput("\"hello\"", conditional = optionalBuilderInputs), + "IntInput" to conditionalBuilderInput("5", conditional = optionalBuilderInputs), + ), - unitTest( - "message_with_union", - """ - let message = msg( - "event", - "MessageWithUnion", - "${testCase.responseContentType}", - br#"${testCase.validTestUnion}"# - ); - let result = ${format(generator)}().unmarshall(&message); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - assert_eq!( - TestStream::MessageWithUnion(MessageWithUnion::builder().some_union( - TestUnion::Foo("hello".into()) - ).build()), - expect_event(result.unwrap()) - ); - """, - ) + ) + } - unitTest( - "message_with_headers", - """ - let message = msg("event", "MessageWithHeaders", "application/octet-stream", b"") - .add_header(Header::new("blob", HeaderValue::ByteArray((&b"test"[..]).into()))) - .add_header(Header::new("boolean", HeaderValue::Bool(true))) - .add_header(Header::new("byte", HeaderValue::Byte(55i8))) - .add_header(Header::new("int", HeaderValue::Int32(100_000i32))) - .add_header(Header::new("long", HeaderValue::Int64(9_000_000_000i64))) - .add_header(Header::new("short", HeaderValue::Int16(16_000i16))) - .add_header(Header::new("string", HeaderValue::String("test".into()))) - .add_header(Header::new("timestamp", HeaderValue::Timestamp(DateTime::from_secs(5)))); - let result = ${format(generator)}().unmarshall(&message); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - assert_eq!( - TestStream::MessageWithHeaders(MessageWithHeaders::builder() - .blob(Blob::new(&b"test"[..])) - .boolean(true) - .byte(55i8) - .int(100_000i32) - .long(9_000_000_000i64) - .short(16_000i16) - .string("test") - .timestamp(DateTime::from_secs(5)) - .build() - ), - expect_event(result.unwrap()) - ); - """, - ) + unitTest("message_with_union") { + rustTemplate( + """ + let message = msg( + "event", + "MessageWithUnion", + "${testCase.responseContentType}", + br##"${testCase.validTestUnion}"## + ); + let result = $generator::new().unmarshall(&message); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + assert_eq!( + TestStream::MessageWithUnion(MessageWithUnion::builder().some_union(#{UnionInput}).build()), + expect_event(result.unwrap()) + ); + """, + "UnionInput" to conditionalBuilderInput("TestUnion::Foo(\"hello\".into())", conditional = optionalBuilderInputs), + ) + } - unitTest( - "message_with_header_and_payload", - """ - let message = msg("event", "MessageWithHeaderAndPayload", "application/octet-stream", b"payload") - .add_header(Header::new("header", HeaderValue::String("header".into()))); - let result = ${format(generator)}().unmarshall(&message); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - assert_eq!( - TestStream::MessageWithHeaderAndPayload(MessageWithHeaderAndPayload::builder() - .header("header") - .payload(Blob::new(&b"payload"[..])) - .build() - ), - expect_event(result.unwrap()) - ); - """, - ) + unitTest("message_with_headers") { + rustTemplate( + """ + let message = msg("event", "MessageWithHeaders", "application/octet-stream", b"") + .add_header(Header::new("blob", HeaderValue::ByteArray((&b"test"[..]).into()))) + .add_header(Header::new("boolean", HeaderValue::Bool(true))) + .add_header(Header::new("byte", HeaderValue::Byte(55i8))) + .add_header(Header::new("int", HeaderValue::Int32(100_000i32))) + .add_header(Header::new("long", HeaderValue::Int64(9_000_000_000i64))) + .add_header(Header::new("short", HeaderValue::Int16(16_000i16))) + .add_header(Header::new("string", HeaderValue::String("test".into()))) + .add_header(Header::new("timestamp", HeaderValue::Timestamp(DateTime::from_secs(5)))); + let result = $generator::new().unmarshall(&message); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + assert_eq!( + TestStream::MessageWithHeaders(MessageWithHeaders::builder() + .blob(#{BlobInput}) + .boolean(#{BoolInput}) + .byte(#{ByteInput}) + .int(#{IntInput}) + .long(#{LongInput}) + .short(#{ShortInput}) + .string(#{StringInput}) + .timestamp(#{TimestampInput}) + .build() + ), + expect_event(result.unwrap()) + ); + """, + "BlobInput" to conditionalBuilderInput("Blob::new(&b\"test\"[..])", conditional = optionalBuilderInputs), + "BoolInput" to conditionalBuilderInput("true", conditional = optionalBuilderInputs), + "ByteInput" to conditionalBuilderInput("55i8", conditional = optionalBuilderInputs), + "IntInput" to conditionalBuilderInput("100_000i32", conditional = optionalBuilderInputs), + "LongInput" to conditionalBuilderInput("9_000_000_000i64", conditional = optionalBuilderInputs), + "ShortInput" to conditionalBuilderInput("16_000i16", conditional = optionalBuilderInputs), + "StringInput" to conditionalBuilderInput("\"test\"", conditional = optionalBuilderInputs), + "TimestampInput" to conditionalBuilderInput("DateTime::from_secs(5)", conditional = optionalBuilderInputs), + ) + } - unitTest( - "message_with_no_header_payload_traits", - """ - let message = msg( - "event", - "MessageWithNoHeaderPayloadTraits", - "${testCase.responseContentType}", - br#"${testCase.validMessageWithNoHeaderPayloadTraits}"# - ); - let result = ${format(generator)}().unmarshall(&message); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - assert_eq!( - TestStream::MessageWithNoHeaderPayloadTraits(MessageWithNoHeaderPayloadTraits::builder() - .some_int(5) - .some_string("hello") - .build() - ), - expect_event(result.unwrap()) - ); - """, - ) + unitTest("message_with_header_and_payload") { + rustTemplate( + """ + let message = msg("event", "MessageWithHeaderAndPayload", "application/octet-stream", b"payload") + .add_header(Header::new("header", HeaderValue::String("header".into()))); + let result = $generator::new().unmarshall(&message); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + assert_eq!( + TestStream::MessageWithHeaderAndPayload(MessageWithHeaderAndPayload::builder() + .header(#{HeaderInput}) + .payload(#{PayloadInput}) + .build() + ), + expect_event(result.unwrap()) + ); + """, + "HeaderInput" to conditionalBuilderInput("\"header\"", conditional = optionalBuilderInputs), + "PayloadInput" to conditionalBuilderInput("Blob::new(&b\"payload\"[..])", conditional = optionalBuilderInputs), + ) + } - val (someError, kindSuffix) = when (codegenTarget) { - CodegenTarget.CLIENT -> "TestStreamErrorKind::SomeError" to ".kind" - CodegenTarget.SERVER -> "TestStreamError::SomeError" to "" + unitTest("message_with_no_header_payload_traits") { + rustTemplate( + """ + let message = msg( + "event", + "MessageWithNoHeaderPayloadTraits", + "${testCase.responseContentType}", + br##"${testCase.validMessageWithNoHeaderPayloadTraits}"## + ); + let result = $generator::new().unmarshall(&message); + assert!(result.is_ok(), "expected ok, got: {:?}", result); + assert_eq!( + TestStream::MessageWithNoHeaderPayloadTraits(MessageWithNoHeaderPayloadTraits::builder() + .some_int(#{IntInput}) + .some_string(#{StringInput}) + .build() + ), + expect_event(result.unwrap()) + ); + """, + "IntInput" to conditionalBuilderInput("5", conditional = optionalBuilderInputs), + "StringInput" to conditionalBuilderInput("\"hello\"", conditional = optionalBuilderInputs), + ) } - unitTest( - "some_error", - """ - let message = msg( - "exception", - "SomeError", - "${testCase.responseContentType}", - br#"${testCase.validSomeError}"# - ); - let result = ${format(generator)}().unmarshall(&message); - assert!(result.is_ok(), "expected ok, got: {:?}", result); - match expect_error(result.unwrap())$kindSuffix { - $someError(err) => assert_eq!(Some("some error"), err.message()), - kind => panic!("expected SomeError, but got {:?}", kind), - } - """, - ) - if (codegenTarget == CodegenTarget.CLIENT) { - unitTest( - "generic_error", + unitTest("some_error") { + rustTemplate( """ let message = msg( "exception", - "UnmodeledError", + "SomeError", "${testCase.responseContentType}", - br#"${testCase.validUnmodeledError}"# + br##"${testCase.validSomeError}"## ); - let result = ${format(generator)}().unmarshall(&message); + let result = $generator::new().unmarshall(&message); assert!(result.is_ok(), "expected ok, got: {:?}", result); - match expect_error(result.unwrap())$kindSuffix { - TestStreamErrorKind::Unhandled(err) => { - let message = format!("{}", aws_smithy_types::error::display::DisplayErrorContext(&err)); - let expected = "message: \"unmodeled error\""; - assert!(message.contains(expected), "Expected '{message}' to contain '{expected}'"); - } - kind => panic!("expected generic error, but got {:?}", kind), + match expect_error(result.unwrap()) { + TestStreamError::SomeError(err) => assert_eq!(Some("some error"), err.message()), + #{AllowUnreachablePatterns:W} + kind => panic!("expected SomeError, but got {:?}", kind), } """, + "AllowUnreachablePatterns" to writable { Attribute.AllowUnreachablePatterns.render(this) }, ) } @@ -265,10 +263,21 @@ internal object EventStreamUnmarshallTestCases { "wrong-content-type", br#"${testCase.validTestStruct}"# ); - let result = ${format(generator)}().unmarshall(&message); + let result = $generator::new().unmarshall(&message); assert!(result.is_err(), "expected error, got: {:?}", result); assert!(format!("{}", result.err().unwrap()).contains("expected :content-type to be")); """, ) } } + +internal fun conditionalBuilderInput( + @Language("Rust", prefix = "macro_rules! foo { () => {{\n", suffix = "\n}}}") contents: String, + conditional: Boolean, + vararg ctx: Pair, +): Writable = + writable { + conditionalBlock("Some(", ".into())", conditional = conditional) { + rustTemplate(contents, *ctx) + } + } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/Rust.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/Rust.kt index a8567c4db47..613a12252d1 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/Rust.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/Rust.kt @@ -9,31 +9,30 @@ import com.moandjiezana.toml.TomlWriter import org.intellij.lang.annotations.Language import software.amazon.smithy.build.FileManifest import software.amazon.smithy.build.PluginContext -import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.model.Model +import software.amazon.smithy.model.loader.ModelAssembler import software.amazon.smithy.model.node.Node import software.amazon.smithy.model.node.ObjectNode -import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.ShapeId -import software.amazon.smithy.model.traits.EnumDefinition import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope import software.amazon.smithy.rust.codegen.core.rustlang.RustDependency +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.raw import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.smithy.CoreCodegenConfig -import software.amazon.smithy.rust.codegen.core.smithy.MaybeRenamed import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider -import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig import software.amazon.smithy.rust.codegen.core.util.CommandFailed -import software.amazon.smithy.rust.codegen.core.util.PANIC import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.letIf +import software.amazon.smithy.rust.codegen.core.util.orNullIfEmpty import software.amazon.smithy.rust.codegen.core.util.runCommand import java.io.File import java.nio.file.Files.createTempDirectory @@ -101,7 +100,7 @@ object TestWorkspace { // help rust select the right version when we run cargo test // TODO(https://github.com/awslabs/smithy-rs/issues/2048): load this from the msrv property using a // method as we do for runtime crate versions - "[toolchain]\nchannel = \"1.62.1\"\n", + "[toolchain]\nchannel = \"1.63.0\"\n", ) // ensure there at least an empty lib.rs file to avoid broken crates newProject.resolve("src").mkdirs() @@ -112,26 +111,20 @@ object TestWorkspace { } } - @Suppress("NAME_SHADOWING") - fun testProject(symbolProvider: RustSymbolProvider? = null, debugMode: Boolean = false): TestWriterDelegator { - val subprojectDir = subproject() - val symbolProvider = symbolProvider ?: object : RustSymbolProvider { - override fun config(): SymbolVisitorConfig { - PANIC("") - } - - override fun toEnumVariantName(definition: EnumDefinition): MaybeRenamed? { - PANIC("") - } + fun testProject( + model: Model = ModelAssembler().assemble().unwrap(), + codegenConfig: CoreCodegenConfig = CoreCodegenConfig(), + ): TestWriterDelegator = testProject(testSymbolProvider(model), codegenConfig) - override fun toSymbol(shape: Shape?): Symbol { - PANIC("") - } - } + fun testProject( + symbolProvider: RustSymbolProvider, + codegenConfig: CoreCodegenConfig = CoreCodegenConfig(), + ): TestWriterDelegator { + val subprojectDir = subproject() return TestWriterDelegator( FileManifest.create(subprojectDir.toPath()), symbolProvider, - CoreCodegenConfig(debugMode = debugMode), + codegenConfig, ).apply { lib { // If the test fails before the crate is finalized, we'll end up with a broken crate. @@ -190,8 +183,7 @@ fun generatePluginContext( ) } - val settings = settingsBuilder.merge(additionalSettings) - .build() + val settings = settingsBuilder.merge(additionalSettings).build() val pluginContext = PluginContext.builder().model(model).fileManifest(manifest).settings(settings).build() return pluginContext to testPath } @@ -221,7 +213,47 @@ fun RustWriter.unitTest( if (async) { rust("async") } - return rustBlock("fn $name()", *args, block = block) + return testDependenciesOnly { rustBlock("fn $name()", *args, block = block) } +} + +fun RustWriter.cargoDependencies() = dependencies.map { RustDependency.fromSymbolDependency(it) } + .filterIsInstance().distinct() + +fun RustWriter.assertNoNewDependencies(block: Writable, dependencyFilter: (CargoDependency) -> String?): RustWriter { + val startingDependencies = cargoDependencies().toSet() + block(this) + val endingDependencies = cargoDependencies().toSet() + val newDeps = (endingDependencies - startingDependencies) + val invalidDeps = + newDeps.mapNotNull { dep -> dependencyFilter(dep)?.let { message -> message to dep } }.orNullIfEmpty() + if (invalidDeps != null) { + val badDeps = invalidDeps.map { it.second.rustName } + val writtenOut = this.toString() + val badLines = writtenOut.lines().filter { line -> badDeps.any { line.contains(it) } } + throw CodegenException( + "found invalid dependencies. ${invalidDeps.map { + it.first + }}\nHint: the following lines may be the problem.\n${ + badLines.joinToString( + separator = "\n", + prefix = " ", + ) + }", + ) + } + return this +} + +fun RustWriter.testDependenciesOnly(block: Writable) = assertNoNewDependencies(block) { dep -> + if (dep.scope != DependencyScope.Dev) { + "Cannot add $dep — this writer should only add test dependencies." + } else { + null + } +} + +fun testDependenciesOnly(block: Writable): Writable = { + testDependenciesOnly(block) } fun RustWriter.tokioTest(name: String, vararg args: Any, block: Writable) { @@ -252,6 +284,13 @@ class TestWriterDelegator( fun generatedFiles() = fileManifest.files.map { baseDir.relativize(it) } } +/** + * Generate a newtest module + * + * This should only be used in test code—the generated module name will be something like `tests_123` + */ +fun RustCrate.testModule(block: Writable) = lib { withInlineModule(RustModule.inlineTests(safeName("tests")), block) } + fun FileManifest.printGeneratedFiles() { this.files.forEach { path -> println("file:///$path") @@ -263,7 +302,10 @@ fun FileManifest.printGeneratedFiles() { * should generally be set to `false` to avoid invalidating the Cargo cache between * every unit test run. */ -fun TestWriterDelegator.compileAndTest(runClippy: Boolean = false) { +fun TestWriterDelegator.compileAndTest( + runClippy: Boolean = false, + expectFailure: Boolean = false, +): String { val stubModel = """ namespace fake service Fake { @@ -284,10 +326,11 @@ fun TestWriterDelegator.compileAndTest(runClippy: Boolean = false) { // cargo fmt errors are useless, ignore } val env = mapOf("RUSTFLAGS" to "-A dead_code") - "cargo test".runCommand(baseDir, env) + val testOutput = "cargo test".runCommand(baseDir, env) if (runClippy) { "cargo clippy".runCommand(baseDir, env) } + return testOutput } fun TestWriterDelegator.rustSettings() = @@ -424,8 +467,11 @@ fun RustCrate.integrationTest(name: String, writable: Writable) = this.withFile( fun TestWriterDelegator.unitTest(test: Writable): TestWriterDelegator { lib { - unitTest(safeName("test")) { - test(this) + val name = safeName("test") + withInlineModule(RustModule.inlineTests(name)) { + unitTest(name) { + test(this) + } } } return this diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt index 6c9310a309d..574e9daa044 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/TestHelpers.kt @@ -5,37 +5,105 @@ package software.amazon.smithy.rust.codegen.core.testutil +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.NullableIndex +import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.rust.codegen.core.rustlang.Attribute +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWordSymbolProvider -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.rustlang.implBlock import software.amazon.smithy.rust.codegen.core.smithy.BaseSymbolMetadataProvider import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.CoreCodegenConfig import software.amazon.smithy.rust.codegen.core.smithy.CoreRustSettings +import software.amazon.smithy.rust.codegen.core.smithy.ModuleProvider +import software.amazon.smithy.rust.codegen.core.smithy.ModuleProviderContext import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RuntimeCrateLocation +import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitor -import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock +import software.amazon.smithy.rust.codegen.core.smithy.module +import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait +import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait import software.amazon.smithy.rust.codegen.core.util.dq +import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.letIf +import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import java.io.File val TestRuntimeConfig = RuntimeConfig(runtimeCrateLocation = RuntimeCrateLocation.Path(File("../rust-runtime/").absolutePath)) -val TestSymbolVisitorConfig = SymbolVisitorConfig( + +/** + * IMPORTANT: You shouldn't need to refer to these directly in code or tests. They are private for a reason. + * + * In general, the RustSymbolProvider's `config()` has a `moduleFor` function that should be used + * to find the destination module for a given shape. + */ +private object CodegenCoreTestModules { + // Use module paths that don't align with either server or client to make sure + // the codegen is resilient to differences in module path. + val ModelsTestModule = RustModule.public("test_model", documentation = "Test models module") + val ErrorsTestModule = RustModule.public("test_error", documentation = "Test error module") + val InputsTestModule = RustModule.public("test_input", documentation = "Test input module") + val OutputsTestModule = RustModule.public("test_output", documentation = "Test output module") + val OperationsTestModule = RustModule.public("test_operation", documentation = "Test operation module") + + object TestModuleProvider : ModuleProvider { + override fun moduleForShape(context: ModuleProviderContext, shape: Shape): RustModule.LeafModule = + when (shape) { + is OperationShape -> OperationsTestModule + is StructureShape -> when { + shape.hasTrait() -> ErrorsTestModule + shape.hasTrait() -> InputsTestModule + shape.hasTrait() -> OutputsTestModule + else -> ModelsTestModule + } + + else -> ModelsTestModule + } + + override fun moduleForOperationError( + context: ModuleProviderContext, + operation: OperationShape, + ): RustModule.LeafModule = ErrorsTestModule + + override fun moduleForEventStreamError( + context: ModuleProviderContext, + eventStream: UnionShape, + ): RustModule.LeafModule = ErrorsTestModule + + override fun moduleForBuilder(context: ModuleProviderContext, shape: Shape, symbol: Symbol): RustModule.LeafModule { + val builderNamespace = RustReservedWords.escapeIfNeeded("test_" + symbol.name.toSnakeCase()) + return RustModule.new( + builderNamespace, + visibility = Visibility.PUBLIC, + parent = symbol.module(), + inline = true, + ) + } + } +} + +val TestRustSymbolProviderConfig = RustSymbolProviderConfig( runtimeConfig = TestRuntimeConfig, renameExceptions = true, nullabilityCheckMode = NullableIndex.CheckMode.CLIENT_ZERO_VALUE_V1, + moduleProvider = CodegenCoreTestModules.TestModuleProvider, ) fun testRustSettings( @@ -71,11 +139,12 @@ fun String.asSmithyModel(sourceLocation: String? = null, smithyVersion: String = // Intentionally only visible to codegen-core since the other modules have their own symbol providers internal fun testSymbolProvider(model: Model): RustSymbolProvider = SymbolVisitor( + testRustSettings(), model, ServiceShape.builder().version("test").id("test#Service").build(), - TestSymbolVisitorConfig, -).let { BaseSymbolMetadataProvider(it, model, additionalAttributes = listOf(Attribute.NonExhaustive)) } - .let { RustReservedWordSymbolProvider(it, model) } + TestRustSymbolProviderConfig, +).let { BaseSymbolMetadataProvider(it, additionalAttributes = listOf(Attribute.NonExhaustive)) } + .let { RustReservedWordSymbolProvider(it) } // Intentionally only visible to codegen-core since the other modules have their own contexts internal fun testCodegenContext( @@ -100,13 +169,16 @@ internal fun testCodegenContext( fun StructureShape.renderWithModelBuilder( model: Model, symbolProvider: RustSymbolProvider, - writer: RustWriter, - forWhom: CodegenTarget = CodegenTarget.CLIENT, + rustCrate: RustCrate, ) { - StructureGenerator(model, symbolProvider, writer, this).render(forWhom) - val modelBuilder = BuilderGenerator(model, symbolProvider, this) - modelBuilder.render(writer) - writer.implBlock(this, symbolProvider) { - modelBuilder.renderConvenienceMethod(this) + val struct = this + rustCrate.withModule(symbolProvider.moduleForShape(struct)) { + StructureGenerator(model, symbolProvider, this, struct, emptyList()).render() + implBlock(symbolProvider.toSymbol(struct)) { + BuilderGenerator.renderConvenienceMethod(this, symbolProvider, struct) + } + } + rustCrate.withModule(symbolProvider.moduleForBuilder(struct)) { + BuilderGenerator(model, symbolProvider, struct, emptyList()).render(this) } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/LetIf.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/LetIf.kt index 2ac4da683d7..89868f7a005 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/LetIf.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/LetIf.kt @@ -10,7 +10,9 @@ package software.amazon.smithy.rust.codegen.core.util fun T.letIf(cond: Boolean, f: (T) -> T): T { return if (cond) { f(this) - } else this + } else { + this + } } fun List.extendIf(condition: Boolean, f: () -> T) = if (condition) { diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Panic.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Panic.kt index b7dc3e681c1..7d24a179883 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Panic.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Panic.kt @@ -6,7 +6,7 @@ package software.amazon.smithy.rust.codegen.core.util /** Something has gone horribly wrong due to a coding error */ -fun PANIC(reason: String): Nothing = throw RuntimeException(reason) +fun PANIC(reason: String = ""): Nothing = throw RuntimeException(reason) /** This code should never be executed (but Kotlin cannot prove that) */ fun UNREACHABLE(reason: String): Nothing = throw IllegalStateException("This should be unreachable: $reason") diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Smithy.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Smithy.kt index 267327c35ec..2e5e7fd5b62 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Smithy.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/util/Smithy.kt @@ -42,7 +42,9 @@ fun StructureShape.expectMember(member: String): MemberShape = fun UnionShape.expectMember(member: String): MemberShape = this.getMember(member).orElseThrow { CodegenException("$member did not exist on $this") } -fun StructureShape.errorMessageMember(): MemberShape? = this.getMember("message").or { this.getMember("Message") }.orNull() +fun StructureShape.errorMessageMember(): MemberShape? = this.getMember("message").or { + this.getMember("Message") +}.orNull() fun StructureShape.hasStreamingMember(model: Model) = this.findStreamingMember(model) != null fun UnionShape.hasStreamingMember(model: Model) = this.findMemberWithTrait(model) != null @@ -137,3 +139,6 @@ fun Shape.isPrimitive(): Boolean { else -> false } } + +/** Convert a string to a ShapeId */ +fun String.shapeId() = ShapeId.from(this) diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/InlineDependencyTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/InlineDependencyTest.kt index 2f1d170e5e8..9664ce1dac7 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/InlineDependencyTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/InlineDependencyTest.kt @@ -34,12 +34,15 @@ internal class InlineDependencyTest { fun `locate dependencies from the inlineable module`() { val dep = InlineDependency.idempotencyToken() val testProject = TestWorkspace.testProject() - testProject.unitTest { + testProject.lib { rustTemplate( """ - use #{idempotency}::uuid_v4; - let res = uuid_v4(0); - assert_eq!(res, "00000000-0000-4000-8000-000000000000"); + ##[test] + fn idempotency_works() { + use #{idempotency}::uuid_v4; + let res = uuid_v4(0); + assert_eq!(res, "00000000-0000-4000-8000-000000000000"); + } """, "idempotency" to dep.toType(), diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWordsTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWordsTest.kt index 5f72e337653..0cbd3fcbb12 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWordsTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustReservedWordsTest.kt @@ -7,31 +7,31 @@ package software.amazon.smithy.rust.codegen.core.rustlang import io.kotest.matchers.shouldBe import org.junit.jupiter.api.Test -import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.MemberShape -import software.amazon.smithy.model.shapes.Shape -import software.amazon.smithy.model.traits.EnumDefinition +import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.rust.codegen.core.smithy.MaybeRenamed -import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider -import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig +import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitor +import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.renamedFrom +import software.amazon.smithy.rust.codegen.core.testutil.TestRustSymbolProviderConfig import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel -import software.amazon.smithy.rust.codegen.core.util.PANIC -import software.amazon.smithy.rust.codegen.core.util.orNull -import software.amazon.smithy.rust.codegen.core.util.toPascalCase +import software.amazon.smithy.rust.codegen.core.testutil.testRustSettings +import software.amazon.smithy.rust.codegen.core.util.lookup internal class RustReservedWordSymbolProviderTest { - class Stub : RustSymbolProvider { - override fun config(): SymbolVisitorConfig { - PANIC("") - } - - override fun toEnumVariantName(definition: EnumDefinition): MaybeRenamed? { - return definition.name.orNull()?.let { MaybeRenamed(it.toPascalCase(), null) } - } + private class TestSymbolProvider(model: Model) : + WrappingSymbolProvider(SymbolVisitor(testRustSettings(), model, null, TestRustSymbolProviderConfig)) - override fun toSymbol(shape: Shape): Symbol { - return Symbol.builder().name(shape.id.name).build() - } + @Test + fun `structs are escaped`() { + val model = """ + namespace test + structure Self {} + """.asSmithyModel() + val provider = RustReservedWordSymbolProvider(TestSymbolProvider(model)) + val symbol = provider.toSymbol(model.lookup("test#Self")) + symbol.name shouldBe "SelfValue" } @Test @@ -41,8 +41,8 @@ internal class RustReservedWordSymbolProviderTest { structure container { async: String } - """.trimMargin().asSmithyModel() - val provider = RustReservedWordSymbolProvider(Stub(), model) + """.asSmithyModel() + val provider = RustReservedWordSymbolProvider(TestSymbolProvider(model)) provider.toMemberName( MemberShape.builder().id("namespace#container\$async").target("namespace#Integer").build(), ) shouldBe "r##async" @@ -54,6 +54,23 @@ internal class RustReservedWordSymbolProviderTest { @Test fun `enum variant names are updated to avoid conflicts`() { + val model = """ + namespace foo + @enum([{ name: "dontcare", value: "dontcare" }]) string Container + """.asSmithyModel() + val provider = RustReservedWordSymbolProvider(TestSymbolProvider(model)) + + fun expectEnumRename(original: String, expected: MaybeRenamed) { + val symbol = provider.toSymbol( + MemberShape.builder() + .id(ShapeId.fromParts("foo", "Container").withMember(original)) + .target("smithy.api#String") + .build(), + ) + symbol.name shouldBe expected.name + symbol.renamedFrom() shouldBe expected.renamedFrom + } + expectEnumRename("Unknown", MaybeRenamed("UnknownValue", "Unknown")) expectEnumRename("UnknownValue", MaybeRenamed("UnknownValue_", "UnknownValue")) expectEnumRename("UnknownOther", MaybeRenamed("UnknownOther", null)) @@ -63,10 +80,4 @@ internal class RustReservedWordSymbolProviderTest { expectEnumRename("SelfOther", MaybeRenamed("SelfOther", null)) expectEnumRename("SELF", MaybeRenamed("SelfValue", "Self")) } - - private fun expectEnumRename(original: String, expected: MaybeRenamed) { - val model = "namespace foo".asSmithyModel() - val provider = RustReservedWordSymbolProvider(Stub(), model) - provider.toEnumVariantName(EnumDefinition.builder().name(original).value("foo").build()) shouldBe expected - } } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustTypeTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustTypeTest.kt index 41e0c100666..7fa364dfaf3 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustTypeTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustTypeTest.kt @@ -163,7 +163,12 @@ internal class RustTypesTest { ), ), ) - forInputExpectOutput(writable { attributeMacro.render(this) }, "#[cfg(all(feature = \"unstable\", any(feature = \"serialize\", feature = \"deserialize\")))]\n") + forInputExpectOutput( + writable { + attributeMacro.render(this) + }, + "#[cfg(all(feature = \"unstable\", any(feature = \"serialize\", feature = \"deserialize\")))]\n", + ) } @Test @@ -178,7 +183,12 @@ internal class RustTypesTest { ), ), ) - forInputExpectOutput(writable { attributeMacro.render(this) }, "#[cfg(all(feature = \"unstable\", feature = \"serialize\", feature = \"deserialize\"))]\n") + forInputExpectOutput( + writable { + attributeMacro.render(this) + }, + "#[cfg(all(feature = \"unstable\", feature = \"serialize\", feature = \"deserialize\"))]\n", + ) } @Test @@ -197,7 +207,12 @@ internal class RustTypesTest { RuntimeType.StdError, ), ) - forInputExpectOutput(writable { attributeMacro.render(this) }, "#[derive(std::clone::Clone, std::error::Error, std::fmt::Debug)]\n") + forInputExpectOutput( + writable { + attributeMacro.render(this) + }, + "#[derive(std::clone::Clone, std::error::Error, std::fmt::Debug)]\n", + ) } @Test diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriterTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriterTest.kt index 86d48a2602d..0d8f93617c9 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriterTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustWriterTest.kt @@ -10,6 +10,8 @@ import io.kotest.matchers.shouldBe import io.kotest.matchers.string.shouldContain import io.kotest.matchers.string.shouldContainOnlyOnce import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.SetShape import software.amazon.smithy.model.shapes.StringShape @@ -96,7 +98,7 @@ class RustWriterTest { val symbol = testSymbolProvider(model).toSymbol(shape) val sut = RustWriter.forModule("lib") sut.docs("A link! #D", symbol) - sut.toString() shouldContain "/// A link! [`Foo`](crate::model::Foo)" + sut.toString() shouldContain "/// A link! [`Foo`](crate::test_model::Foo)" } @Test @@ -161,6 +163,23 @@ class RustWriterTest { sut.toString().shouldContain("inner: hello, regular: http::foo") } + @Test + fun `missing template parameters are enclosed in backticks in the exception message`() { + val sut = RustWriter.forModule("lib") + val exception = assertThrows { + sut.rustTemplate( + "#{Foo} #{Bar}", + "Foo Bar" to CargoDependency.Http.toType().resolve("foo"), + "Baz" to CargoDependency.Http.toType().resolve("foo"), + ) + } + exception.message shouldBe + """ + Rust block template expected `Foo` but was not present in template. + Hint: Template contains: [`Foo Bar`, `Baz`] + """.trimIndent() + } + @Test fun `can handle file paths properly when determining module`() { val sut = RustWriter.forModule("src/module_name") diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenDelegatorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenDelegatorTest.kt index 5be3d0a898b..c9407372d76 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenDelegatorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenDelegatorTest.kt @@ -10,11 +10,12 @@ import org.junit.jupiter.api.Test import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.CratesIo import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope.Compile +import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope.Dev class CodegenDelegatorTest { @Test fun testMergeDependencyFeatures() { - val merged = mergeDependencyFeatures( + val merged = listOf( CargoDependency("A", CratesIo("1"), Compile, optional = false, features = setOf()), CargoDependency("A", CratesIo("1"), Compile, optional = false, features = setOf("f1")), @@ -26,8 +27,7 @@ class CodegenDelegatorTest { CargoDependency("C", CratesIo("3"), Compile, optional = true, features = setOf()), CargoDependency("C", CratesIo("3"), Compile, optional = true, features = setOf()), - ).shuffled(), - ) + ).shuffled().mergeDependencyFeatures() merged shouldBe setOf( CargoDependency("A", CratesIo("1"), Compile, optional = false, features = setOf("f1", "f2")), @@ -35,4 +35,19 @@ class CodegenDelegatorTest { CargoDependency("C", CratesIo("3"), Compile, optional = true, features = setOf()), ) } + + @Test + fun testMergeIdenticalFeatures() { + val merged = listOf( + CargoDependency("A", CratesIo("1"), Compile), + CargoDependency("A", CratesIo("1"), Dev), + CargoDependency("B", CratesIo("1"), Compile), + CargoDependency("B", CratesIo("1"), Dev, features = setOf("a", "b")), + ).mergeIdenticalTestDependencies() + merged shouldBe setOf( + CargoDependency("A", CratesIo("1"), Compile), + CargoDependency("B", CratesIo("1"), Compile), + CargoDependency("B", CratesIo("1"), Dev, features = setOf("a", "b")), + ) + } } diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/SymbolVisitorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitorTest.kt similarity index 91% rename from codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/SymbolVisitorTest.kt rename to codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitorTest.kt index bc26847bef5..3c47ca4f672 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/SymbolVisitorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolVisitorTest.kt @@ -3,12 +3,11 @@ * SPDX-License-Identifier: Apache-2.0 */ -package software.amazon.smithy.rust.codegen.client.smithy +package software.amazon.smithy.rust.codegen.core.smithy import io.kotest.matchers.collections.shouldContain import io.kotest.matchers.collections.shouldNotBeEmpty import io.kotest.matchers.shouldBe -import io.kotest.matchers.string.shouldContain import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.DisplayName import org.junit.jupiter.api.Test @@ -26,15 +25,10 @@ import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.model.traits.SparseTrait -import software.amazon.smithy.rust.codegen.client.testutil.testSymbolProvider import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.render -import software.amazon.smithy.rust.codegen.core.smithy.ErrorsModule -import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule -import software.amazon.smithy.rust.codegen.core.smithy.OperationsModule -import software.amazon.smithy.rust.codegen.core.smithy.isOptional -import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.testSymbolProvider class SymbolVisitorTest { private fun Symbol.referenceClosure(): List { @@ -57,8 +51,8 @@ class SymbolVisitorTest { val provider: SymbolProvider = testSymbolProvider(model) val sym = provider.toSymbol(struct) sym.rustType().render(false) shouldBe "MyStruct" - sym.definitionFile shouldContain ModelsModule.definitionFile() - sym.namespace shouldBe "crate::model" + sym.definitionFile shouldBe "src/test_model.rs" + sym.namespace shouldBe "crate::test_model" } @Test @@ -77,7 +71,7 @@ class SymbolVisitorTest { val provider: SymbolProvider = testSymbolProvider(model) val sym = provider.toSymbol(struct) sym.rustType().render(false) shouldBe "TerribleError" - sym.definitionFile shouldContain ErrorsModule.definitionFile() + sym.definitionFile shouldBe "src/test_error.rs" } @Test @@ -101,8 +95,8 @@ class SymbolVisitorTest { val provider: SymbolProvider = testSymbolProvider(model) val sym = provider.toSymbol(shape) sym.rustType().render(false) shouldBe "StandardUnit" - sym.definitionFile shouldContain ModelsModule.definitionFile() - sym.namespace shouldBe "crate::model" + sym.definitionFile shouldBe "src/test_model.rs" + sym.namespace shouldBe "crate::test_model" } @DisplayName("Creates primitives") @@ -260,7 +254,7 @@ class SymbolVisitorTest { } """.asSmithyModel() val symbol = testSymbolProvider(model).toSymbol(model.expectShape(ShapeId.from("smithy.example#PutObject"))) - symbol.definitionFile shouldBe(OperationsModule.definitionFile()) + symbol.definitionFile shouldBe "src/test_operation.rs" symbol.name shouldBe "PutObject" } } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtraTest.kt similarity index 92% rename from codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseGeneratorTest.kt rename to codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtraTest.kt index c147567d712..120fc5cb208 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/customizations/SmithyTypesPubUseExtraTest.kt @@ -8,10 +8,11 @@ package software.amazon.smithy.rust.codegen.core.smithy.customizations import org.junit.jupiter.api.Test import software.amazon.smithy.model.Model import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig +import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGeneratorTest.Companion.model import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.testCodegenContext -class SmithyTypesPubUseGeneratorTest { +class SmithyTypesPubUseExtraTest { private fun modelWithMember( inputMember: String = "", outputMember: String = "", @@ -48,7 +49,7 @@ class SmithyTypesPubUseGeneratorTest { outputMember: String = "", unionMember: String = "", additionalShape: String = "", - ) = pubUseTypes(TestRuntimeConfig, modelWithMember(inputMember, outputMember, unionMember, additionalShape)) + ) = pubUseTypes(testCodegenContext(model), modelWithMember(inputMember, outputMember, unionMember, additionalShape)) private fun assertDoesntHaveTypes(types: List, expectedTypes: List) = expectedTypes.forEach { assertDoesntHaveType(types, it) } @@ -71,11 +72,6 @@ class SmithyTypesPubUseGeneratorTest { } } - @Test - fun `it always re-exports SdkError`() { - assertHasType(typesWithEmptyModel(), "aws_smithy_http::result::SdkError") - } - @Test fun `it re-exports Blob when a model uses blobs`() { assertDoesntHaveType(typesWithEmptyModel(), "aws_smithy_types::Blob") diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGeneratorTest.kt index 3d5c75384c6..ec107f3fcb8 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/BuilderGeneratorTest.kt @@ -7,18 +7,17 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators import org.junit.jupiter.api.Test import software.amazon.smithy.codegen.core.Symbol -import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.Shape -import software.amazon.smithy.model.traits.EnumDefinition -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.AllowDeprecated +import software.amazon.smithy.rust.codegen.core.rustlang.implBlock import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.smithy.Default -import software.amazon.smithy.rust.codegen.core.smithy.MaybeRenamed -import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider -import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig +import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.setDefault +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest import software.amazon.smithy.rust.codegen.core.testutil.testSymbolProvider +import software.amazon.smithy.rust.codegen.core.testutil.unitTest internal class BuilderGeneratorTest { private val model = StructureGeneratorTest.model @@ -30,114 +29,112 @@ internal class BuilderGeneratorTest { @Test fun `generate builders`() { val provider = testSymbolProvider(model) - val writer = RustWriter.forModule("model") - writer.rust("##![allow(deprecated)]") - val innerGenerator = StructureGenerator(model, provider, writer, inner) - val generator = StructureGenerator(model, provider, writer, struct) - val builderGenerator = BuilderGenerator(model, provider, struct) - generator.render() - innerGenerator.render() - builderGenerator.render(writer) - writer.implBlock(struct, provider) { - builderGenerator.renderConvenienceMethod(this) + val project = TestWorkspace.testProject(provider) + project.moduleFor(inner) { + rust("##![allow(deprecated)]") + StructureGenerator(model, provider, this, inner, emptyList()).render() + StructureGenerator(model, provider, this, struct, emptyList()).render() + implBlock(provider.toSymbol(struct)) { + BuilderGenerator.renderConvenienceMethod(this, provider, struct) + } + unitTest("generate_builders") { + rust( + """ + let my_struct = MyStruct::builder().byte_value(4).foo("hello!").build(); + assert_eq!(my_struct.foo.unwrap(), "hello!"); + assert_eq!(my_struct.bar, 0); + """, + ) + } } - writer.compileAndTest( - """ - let my_struct = MyStruct::builder().byte_value(4).foo("hello!").build(); - assert_eq!(my_struct.foo.unwrap(), "hello!"); - assert_eq!(my_struct.bar, 0); - """, - ) + project.withModule(provider.moduleForBuilder(struct)) { + BuilderGenerator(model, provider, struct, emptyList()).render(this) + } + project.compileAndTest() } @Test fun `generate fallible builders`() { - val baseProvider: RustSymbolProvider = testSymbolProvider(StructureGeneratorTest.model) - val provider = - object : RustSymbolProvider { - override fun config(): SymbolVisitorConfig { - return baseProvider.config() - } - - override fun toEnumVariantName(definition: EnumDefinition): MaybeRenamed? { - return baseProvider.toEnumVariantName(definition) - } - - override fun toSymbol(shape: Shape?): Symbol { - return baseProvider.toSymbol(shape).toBuilder().setDefault(Default.NoDefault).build() - } + val baseProvider = testSymbolProvider(StructureGeneratorTest.model) + val provider = object : WrappingSymbolProvider(baseProvider) { + override fun toSymbol(shape: Shape): Symbol { + return baseProvider.toSymbol(shape).toBuilder().setDefault(Default.NoDefault).build() + } + } + val project = TestWorkspace.testProject(provider) - override fun toMemberName(shape: MemberShape?): String { - return baseProvider.toMemberName(shape) - } + project.moduleFor(StructureGeneratorTest.struct) { + AllowDeprecated.render(this) + StructureGenerator(model, provider, this, inner, emptyList()).render() + StructureGenerator(model, provider, this, struct, emptyList()).render() + implBlock(provider.toSymbol(struct)) { + BuilderGenerator.renderConvenienceMethod(this, provider, struct) + } + unitTest("generate_fallible_builders") { + rust( + """ + let my_struct = MyStruct::builder().byte_value(4).foo("hello!").bar(0).build().expect("required field was not provided"); + assert_eq!(my_struct.foo.unwrap(), "hello!"); + assert_eq!(my_struct.bar, 0); + """, + ) } - val writer = RustWriter.forModule("model") - writer.rust("##![allow(deprecated)]") - val innerGenerator = StructureGenerator( - StructureGeneratorTest.model, provider, writer, - StructureGeneratorTest.inner, - ) - val generator = StructureGenerator( - StructureGeneratorTest.model, provider, writer, - StructureGeneratorTest.struct, - ) - generator.render() - innerGenerator.render() - val builderGenerator = BuilderGenerator(model, provider, struct) - builderGenerator.render(writer) - writer.implBlock(struct, provider) { - builderGenerator.renderConvenienceMethod(this) } - writer.compileAndTest( - """ - let my_struct = MyStruct::builder().byte_value(4).foo("hello!").bar(0).build().expect("required field was not provided"); - assert_eq!(my_struct.foo.unwrap(), "hello!"); - assert_eq!(my_struct.bar, 0); - """, - ) + project.withModule(provider.moduleForBuilder(struct)) { + BuilderGenerator(model, provider, struct, emptyList()).render(this) + } + project.compileAndTest() } @Test fun `builder for a struct with sensitive fields should implement the debug trait as such`() { val provider = testSymbolProvider(model) - val writer = RustWriter.forModule("model") - val credsGenerator = StructureGenerator(model, provider, writer, credentials) - val builderGenerator = BuilderGenerator(model, provider, credentials) - credsGenerator.render() - builderGenerator.render(writer) - writer.implBlock(credentials, provider) { - builderGenerator.renderConvenienceMethod(this) + val project = TestWorkspace.testProject(provider) + project.moduleFor(credentials) { + StructureGenerator(model, provider, this, credentials, emptyList()).render() + implBlock(provider.toSymbol(credentials)) { + BuilderGenerator.renderConvenienceMethod(this, provider, credentials) + } + unitTest("sensitive_fields") { + rust( + """ + let builder = Credentials::builder() + .username("admin") + .password("pswd") + .secret_key("12345"); + assert_eq!(format!("{:?}", builder), "Builder { username: Some(\"admin\"), password: \"*** Sensitive Data Redacted ***\", secret_key: \"*** Sensitive Data Redacted ***\" }"); + """, + ) + } + } + project.withModule(provider.moduleForBuilder(credentials)) { + BuilderGenerator(model, provider, credentials, emptyList()).render(this) } - writer.compileAndTest( - """ - use super::*; - let builder = Credentials::builder() - .username("admin") - .password("pswd") - .secret_key("12345"); - assert_eq!(format!("{:?}", builder), "Builder { username: Some(\"admin\"), password: \"*** Sensitive Data Redacted ***\", secret_key: \"*** Sensitive Data Redacted ***\" }"); - """, - ) + project.compileAndTest() } @Test fun `builder for a sensitive struct should implement the debug trait as such`() { val provider = testSymbolProvider(model) - val writer = RustWriter.forModule("model") - val structGenerator = StructureGenerator(model, provider, writer, secretStructure) - val builderGenerator = BuilderGenerator(model, provider, secretStructure) - structGenerator.render() - builderGenerator.render(writer) - writer.implBlock(secretStructure, provider) { - builderGenerator.renderConvenienceMethod(this) + val project = TestWorkspace.testProject(provider) + project.moduleFor(secretStructure) { + StructureGenerator(model, provider, this, secretStructure, emptyList()).render() + implBlock(provider.toSymbol(secretStructure)) { + BuilderGenerator.renderConvenienceMethod(this, provider, secretStructure) + } + unitTest("sensitive_struct") { + rust( + """ + let builder = SecretStructure::builder() + .secret_field("secret"); + assert_eq!(format!("{:?}", builder), "Builder { secret_field: \"*** Sensitive Data Redacted ***\" }"); + """, + ) + } + } + project.withModule(provider.moduleForBuilder(secretStructure)) { + BuilderGenerator(model, provider, secretStructure, emptyList()).render(this) } - writer.compileAndTest( - """ - use super::*; - let builder = SecretStructure::builder() - .secret_field("secret"); - assert_eq!(format!("{:?}", builder), "Builder { secret_field: \"*** Sensitive Data Redacted ***\" }"); - """, - ) + project.compileAndTest() } } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt index a36fbf0b08b..2da87e1d453 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt @@ -7,14 +7,18 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators import io.kotest.matchers.shouldBe import io.kotest.matchers.string.shouldContain +import io.kotest.matchers.string.shouldNotContain import org.junit.jupiter.api.Nested import org.junit.jupiter.api.Test import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.traits.EnumTrait -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.AllowDeprecated import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest @@ -47,8 +51,11 @@ class EnumGeneratorTest { private val enumTrait = testModel.lookup("test#EnumWithUnknown").expectTrait() - private fun model(name: String): EnumMemberModel = - EnumMemberModel(enumTrait.values.first { it.name.orNull() == name }, symbolProvider) + private fun model(name: String): EnumMemberModel = EnumMemberModel( + testModel.lookup("test#EnumWithUnknown"), + enumTrait.values.first { it.name.orNull() == name }, + symbolProvider, + ) @Test fun `it converts enum names to PascalCase and renames any named Unknown to UnknownValue`() { @@ -89,6 +96,15 @@ class EnumGeneratorTest { @Nested inner class EnumGeneratorTests { + fun RustWriter.renderEnum( + model: Model, + provider: RustSymbolProvider, + shape: StringShape, + enumType: EnumType = TestEnumType, + ) { + EnumGenerator(model, provider, shape, enumType).render(this) + } + @Test fun `it generates named enums`() { val model = """ @@ -113,30 +129,26 @@ class EnumGeneratorTest { """.asSmithyModel() val shape = model.lookup("test#InstanceType") - val trait = shape.expectTrait() val provider = testSymbolProvider(model) val project = TestWorkspace.testProject(provider) - project.withModule(RustModule.Model) { + project.moduleFor(shape) { rust("##![allow(deprecated)]") - val generator = EnumGenerator(model, provider, this, shape, trait) - generator.render() + renderEnum(model, provider, shape) unitTest( "it_generates_named_enums", """ let instance = InstanceType::T2Micro; assert_eq!(instance.as_str(), "t2.micro"); assert_eq!(InstanceType::from("t2.nano"), InstanceType::T2Nano); - assert_eq!(InstanceType::from("other"), InstanceType::Unknown(crate::types::UnknownVariantValue("other".to_owned()))); - // round trip unknown variants: - assert_eq!(InstanceType::from("other").as_str(), "other"); """, ) - val output = toString() - output.shouldContain("#[non_exhaustive]") - // on enum variant `T2Micro` - output.shouldContain("#[deprecated]") - // on enum itself - output.shouldContain("#[deprecated(since = \"1.2.3\")]") + toString().also { output -> + output.shouldContain("#[non_exhaustive]") + // on enum variant `T2Micro` + output.shouldContain("#[deprecated]") + // on enum itself + output.shouldContain("#[deprecated(since = \"1.2.3\")]") + } } project.compileAndTest() } @@ -158,12 +170,10 @@ class EnumGeneratorTest { """.asSmithyModel() val shape = model.lookup("test#FooEnum") - val trait = shape.expectTrait() val provider = testSymbolProvider(model) val project = TestWorkspace.testProject(provider) - project.withModule(RustModule.Model) { - val generator = EnumGenerator(model, provider, this, shape, trait) - generator.render() + project.moduleFor(shape) { + renderEnum(model, provider, shape) unitTest( "named_enums_implement_eq_and_hash", """ @@ -193,13 +203,11 @@ class EnumGeneratorTest { """.asSmithyModel() val shape = model.lookup("test#FooEnum") - val trait = shape.expectTrait() val provider = testSymbolProvider(model) val project = TestWorkspace.testProject(provider) - project.withModule(RustModule.Model) { - rust("##![allow(deprecated)]") - val generator = EnumGenerator(model, provider, this, shape, trait) - generator.render() + project.moduleFor(shape) { + AllowDeprecated.render(this) + renderEnum(model, provider, shape) unitTest( "unnamed_enums_implement_eq_and_hash", """ @@ -238,13 +246,11 @@ class EnumGeneratorTest { """.asSmithyModel() val shape = model.lookup("test#FooEnum") - val trait = shape.expectTrait() val provider = testSymbolProvider(model) val project = TestWorkspace.testProject(provider) - project.withModule(RustModule.Model) { - rust("##![allow(deprecated)]") - val generator = EnumGenerator(model, provider, this, shape, trait) - generator.render() + project.moduleFor(shape) { + AllowDeprecated.render(this) + renderEnum(model, provider, shape) unitTest( "it_generates_unnamed_enums", """ @@ -257,304 +263,256 @@ class EnumGeneratorTest { } @Test - fun `it escapes the Unknown variant if the enum has an unknown value in the model`() { + fun `it should generate documentation for enums`() { val model = """ namespace test + + /// Some top-level documentation. @enum([ { name: "Known", value: "Known" }, { name: "Unknown", value: "Unknown" }, - { name: "UnknownValue", value: "UnknownValue" }, ]) string SomeEnum """.asSmithyModel() val shape = model.lookup("test#SomeEnum") - val trait = shape.expectTrait() val provider = testSymbolProvider(model) val project = TestWorkspace.testProject(provider) - project.withModule(RustModule.Model) { - val generator = EnumGenerator(model, provider, this, shape, trait) - generator.render() - unitTest( - "it_escapes_the_unknown_variant_if_the_enum_has_an_unknown_value_in_the_model", + project.moduleFor(shape) { + renderEnum(model, provider, shape) + val rendered = toString() + rendered shouldContain """ - assert_eq!(SomeEnum::from("Unknown"), SomeEnum::UnknownValue); - assert_eq!(SomeEnum::from("UnknownValue"), SomeEnum::UnknownValue_); - assert_eq!(SomeEnum::from("SomethingNew"), SomeEnum::Unknown(crate::types::UnknownVariantValue("SomethingNew".to_owned()))); - """.trimIndent(), - ) + /// Some top-level documentation. + /// + /// _Note: `SomeEnum::Unknown` has been renamed to `::UnknownValue`._ + """.trimIndent() } project.compileAndTest() } @Test - fun `it should generate documentation for enums`() { + fun `it should generate documentation for unnamed enums`() { val model = """ namespace test /// Some top-level documentation. @enum([ - { name: "Known", value: "Known" }, - { name: "Unknown", value: "Unknown" }, + { value: "One" }, + { value: "Two" }, ]) string SomeEnum """.asSmithyModel() val shape = model.lookup("test#SomeEnum") - val trait = shape.expectTrait() val provider = testSymbolProvider(model) val project = TestWorkspace.testProject(provider) - project.withModule(RustModule.Model) { - val generator = EnumGenerator(model, provider, this, shape, trait) - generator.render() + project.moduleFor(shape) { + renderEnum(model, provider, shape) val rendered = toString() rendered shouldContain """ /// Some top-level documentation. - /// - /// _Note: `SomeEnum::Unknown` has been renamed to `::UnknownValue`._ """.trimIndent() } project.compileAndTest() } @Test - fun `it should generate documentation for unnamed enums`() { + fun `it handles variants that clash with Rust reserved words`() { val model = """ namespace test + @enum([ + { name: "Known", value: "Known" }, + { name: "Self", value: "other" }, + ]) + string SomeEnum + """.asSmithyModel() - /// Some top-level documentation. + val shape = model.lookup("test#SomeEnum") + val provider = testSymbolProvider(model) + val project = TestWorkspace.testProject(provider) + project.moduleFor(shape) { + renderEnum(model, provider, shape) + unitTest( + "it_handles_variants_that_clash_with_rust_reserved_words", + """assert_eq!(SomeEnum::from("other"), SomeEnum::SelfValue);""", + ) + } + project.compileAndTest() + } + + @Test + fun `impl debug for non-sensitive enum should implement the derived debug trait`() { + val model = """ + namespace test @enum([ - { value: "One" }, - { value: "Two" }, + { name: "Foo", value: "Foo" }, + { name: "Bar", value: "Bar" }, ]) string SomeEnum """.asSmithyModel() val shape = model.lookup("test#SomeEnum") - val trait = shape.expectTrait() val provider = testSymbolProvider(model) val project = TestWorkspace.testProject(provider) - project.withModule(RustModule.Model) { - val generator = EnumGenerator(model, provider, this, shape, trait) - generator.render() - val rendered = toString() - rendered shouldContain + project.moduleFor(shape) { + renderEnum(model, provider, shape) + unitTest( + "impl_debug_for_non_sensitive_enum_should_implement_the_derived_debug_trait", """ - /// Some top-level documentation. - """.trimIndent() + assert_eq!(format!("{:?}", SomeEnum::Foo), "Foo"); + assert_eq!(format!("{:?}", SomeEnum::Bar), "Bar"); + """, + ) } project.compileAndTest() } - } - @Test - fun `it handles variants that clash with Rust reserved words`() { - val model = """ - namespace test - @enum([ - { name: "Known", value: "Known" }, - { name: "Self", value: "other" }, - ]) - string SomeEnum - """.asSmithyModel() + @Test + fun `impl debug for sensitive enum should redact text`() { + val model = """ + namespace test + @sensitive + @enum([ + { name: "Foo", value: "Foo" }, + { name: "Bar", value: "Bar" }, + ]) + string SomeEnum + """.asSmithyModel() - val shape = model.lookup("test#SomeEnum") - val trait = shape.expectTrait() - val provider = testSymbolProvider(model) - val project = TestWorkspace.testProject(provider) - project.withModule(RustModule.Model) { - val generator = EnumGenerator(model, provider, this, shape, trait) - generator.render() - unitTest( - "it_handles_variants_that_clash_with_rust_reserved_words", - """ - assert_eq!(SomeEnum::from("other"), SomeEnum::SelfValue); - assert_eq!(SomeEnum::from("SomethingNew"), SomeEnum::Unknown(crate::types::UnknownVariantValue("SomethingNew".to_owned()))); - """.trimIndent(), - ) + val shape = model.lookup("test#SomeEnum") + val provider = testSymbolProvider(model) + val project = TestWorkspace.testProject(provider) + project.moduleFor(shape) { + renderEnum(model, provider, shape) + unitTest( + "impl_debug_for_sensitive_enum_should_redact_text", + """ + assert_eq!(format!("{:?}", SomeEnum::Foo), $REDACTION); + assert_eq!(format!("{:?}", SomeEnum::Bar), $REDACTION); + """, + ) + } + project.compileAndTest() } - project.compileAndTest() - } - @Test - fun `matching on enum should be forward-compatible`() { - fun expectMatchExpressionCompiles(model: Model, shapeId: String, enumToMatchOn: String) { - val shape = model.lookup(shapeId) - val trait = shape.expectTrait() + @Test + fun `impl debug for non-sensitive unnamed enum should implement the derived debug trait`() { + val model = """ + namespace test + @enum([ + { value: "Foo" }, + { value: "Bar" }, + ]) + string SomeEnum + """.asSmithyModel() + + val shape = model.lookup("test#SomeEnum") val provider = testSymbolProvider(model) val project = TestWorkspace.testProject(provider) - project.withModule(RustModule.Model) { - val generator = EnumGenerator(model, provider, this, shape, trait) - generator.render() + project.moduleFor(shape) { + renderEnum(model, provider, shape) unitTest( - "matching_on_enum_should_be_forward_compatible", + "impl_debug_for_non_sensitive_unnamed_enum_should_implement_the_derived_debug_trait", """ - match $enumToMatchOn { - SomeEnum::Variant1 => assert!(false, "expected `Variant3` but got `Variant1`"), - SomeEnum::Variant2 => assert!(false, "expected `Variant3` but got `Variant2`"), - other @ _ if other.as_str() == "Variant3" => assert!(true), - _ => assert!(false, "expected `Variant3` but got `_`"), + for variant in SomeEnum::values() { + assert_eq!( + format!("{:?}", SomeEnum(variant.to_string())), + format!("SomeEnum(\"{}\")", variant.to_owned()) + ); } - """.trimIndent(), + """, ) } project.compileAndTest() } - val modelV1 = """ - namespace test - - @enum([ - { name: "Variant1", value: "Variant1" }, - { name: "Variant2", value: "Variant2" }, - ]) - string SomeEnum - """.asSmithyModel() - val variant3AsUnknown = """SomeEnum::from("Variant3")""" - expectMatchExpressionCompiles(modelV1, "test#SomeEnum", variant3AsUnknown) - - val modelV2 = """ - namespace test - - @enum([ - { name: "Variant1", value: "Variant1" }, - { name: "Variant2", value: "Variant2" }, - { name: "Variant3", value: "Variant3" }, - ]) - string SomeEnum - """.asSmithyModel() - val variant3AsVariant3 = "SomeEnum::Variant3" - expectMatchExpressionCompiles(modelV2, "test#SomeEnum", variant3AsVariant3) - } - - @Test - fun `impl debug for non-sensitive enum should implement the derived debug trait`() { - val model = """ - namespace test - @enum([ - { name: "Foo", value: "Foo" }, - { name: "Bar", value: "Bar" }, - ]) - string SomeEnum - """.asSmithyModel() + @Test + fun `impl debug for sensitive unnamed enum should redact text`() { + val model = """ + namespace test + @sensitive + @enum([ + { value: "Foo" }, + { value: "Bar" }, + ]) + string SomeEnum + """.asSmithyModel() - val shape = model.lookup("test#SomeEnum") - val trait = shape.expectTrait() - val provider = testSymbolProvider(model) - val project = TestWorkspace.testProject(provider) - project.withModule(RustModule.Model) { - val generator = EnumGenerator(model, provider, this, shape, trait) - generator.render() - unitTest( - "impl_debug_for_non_sensitive_enum_should_implement_the_derived_debug_trait", - """ - assert_eq!(format!("{:?}", SomeEnum::Foo), "Foo"); - assert_eq!(format!("{:?}", SomeEnum::Bar), "Bar"); - assert_eq!( - format!("{:?}", SomeEnum::from("Baz")), - "Unknown(UnknownVariantValue(\"Baz\"))" - ); - """, - ) + val shape = model.lookup("test#SomeEnum") + val provider = testSymbolProvider(model) + val project = TestWorkspace.testProject(provider) + project.moduleFor(shape) { + renderEnum(model, provider, shape) + unitTest( + "impl_debug_for_sensitive_unnamed_enum_should_redact_text", + """ + for variant in SomeEnum::values() { + assert_eq!( + format!("{:?}", SomeEnum(variant.to_string())), + $REDACTION + ); + } + """, + ) + } + project.compileAndTest() } - project.compileAndTest() - } - @Test - fun `impl debug for sensitive enum should redact text`() { - val model = """ - namespace test - @sensitive - @enum([ - { name: "Foo", value: "Foo" }, - { name: "Bar", value: "Bar" }, - ]) - string SomeEnum - """.asSmithyModel() - - val shape = model.lookup("test#SomeEnum") - val trait = shape.expectTrait() - val provider = testSymbolProvider(model) - val project = TestWorkspace.testProject(provider) - project.withModule(RustModule.Model) { - val generator = EnumGenerator(model, provider, this, shape, trait) - generator.render() - unitTest( - "impl_debug_for_sensitive_enum_should_redact_text", - """ - assert_eq!(format!("{:?}", SomeEnum::Foo), $REDACTION); - assert_eq!(format!("{:?}", SomeEnum::Bar), $REDACTION); - """, - ) - } - project.compileAndTest() - } + @Test + fun `it supports other enum types`() { + class CustomizingEnumType : EnumType() { + override fun implFromForStr(context: EnumGeneratorContext): Writable = writable { + // intentional no-op + } - @Test - fun `impl debug for non-sensitive unnamed enum should implement the derived debug trait`() { - val model = """ - namespace test - @enum([ - { value: "Foo" }, - { value: "Bar" }, - ]) - string SomeEnum - """.asSmithyModel() + override fun implFromStr(context: EnumGeneratorContext): Writable = writable { + // intentional no-op + } - val shape = model.lookup("test#SomeEnum") - val trait = shape.expectTrait() - val provider = testSymbolProvider(model) - val project = TestWorkspace.testProject(provider) - project.withModule(RustModule.Model) { - val generator = EnumGenerator(model, provider, this, shape, trait) - generator.render() - unitTest( - "impl_debug_for_non_sensitive_unnamed_enum_should_implement_the_derived_debug_trait", - """ - for variant in SomeEnum::values() { - assert_eq!( - format!("{:?}", SomeEnum(variant.to_string())), - format!("SomeEnum(\"{}\")", variant.to_owned()) - ); + override fun additionalEnumMembers(context: EnumGeneratorContext): Writable = writable { + rust("// additional enum members") } - """, - ) - } - project.compileAndTest() - } - @Test - fun `impl debug for sensitive unnamed enum should redact text`() { - val model = """ - namespace test - @sensitive - @enum([ - { value: "Foo" }, - { value: "Bar" }, - ]) - string SomeEnum - """.asSmithyModel() + override fun additionalAsStrMatchArms(context: EnumGeneratorContext): Writable = writable { + rust("// additional as_str match arm") + } - val shape = model.lookup("test#SomeEnum") - val trait = shape.expectTrait() - val provider = testSymbolProvider(model) - val project = TestWorkspace.testProject(provider) - project.withModule(RustModule.Model) { - val generator = EnumGenerator(model, provider, this, shape, trait) - generator.render() - unitTest( - "impl_debug_for_sensitive_unnamed_enum_should_redact_text", - """ - for variant in SomeEnum::values() { - assert_eq!( - format!("{:?}", SomeEnum(variant.to_string())), - $REDACTION - ); + override fun additionalDocs(context: EnumGeneratorContext): Writable = writable { + rust("// additional docs") } - """, - ) + } + + val model = """ + namespace test + @enum([ + { name: "Known", value: "Known" }, + { name: "Self", value: "other" }, + ]) + string SomeEnum + """.asSmithyModel() + val shape = model.lookup("test#SomeEnum") + + val provider = testSymbolProvider(model) + val output = RustWriter.root().apply { + renderEnum(model, provider, shape, CustomizingEnumType()) + }.toString() + + // Since we didn't use the Infallible EnumType, there should be no Unknown variant + output shouldNotContain "Unknown" + output shouldNotContain "unknown" + output shouldNotContain "impl From" + output shouldNotContain "impl FromStr" + output shouldContain "// additional enum members" + output shouldContain "// additional as_str match arm" + output shouldContain "// additional docs" + + val project = TestWorkspace.testProject(provider) + project.moduleFor(shape) { + renderEnum(model, provider, shape, CustomizingEnumType()) + } + project.compileAndTest() } - project.compileAndTest() } } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/InstantiatorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/InstantiatorTest.kt index a0f8200c50f..c5d877394f6 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/InstantiatorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/InstantiatorTest.kt @@ -14,7 +14,6 @@ import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.writable @@ -25,6 +24,7 @@ import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder import software.amazon.smithy.rust.codegen.core.testutil.testCodegenContext +import software.amazon.smithy.rust.codegen.core.testutil.testModule import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.lookup @@ -82,7 +82,7 @@ class InstantiatorTest { @required num: Integer } - """.asSmithyModel().let { RecursiveShapeBoxer.transform(it) } + """.asSmithyModel().let { RecursiveShapeBoxer().transform(it) } private val codegenContext = testCodegenContext(model) private val symbolProvider = codegenContext.symbolProvider @@ -108,8 +108,8 @@ class InstantiatorTest { Instantiator(symbolProvider, model, runtimeConfig, BuilderKindBehavior(codegenContext), ::enumFromStringFn) val data = Node.parse("""{ "stringVariant": "ok!" }""") - val project = TestWorkspace.testProject() - project.withModule(RustModule.Model) { + val project = TestWorkspace.testProject(model) + project.moduleFor(union) { UnionGenerator(model, symbolProvider, this, union).render() unitTest("generate_unions") { withBlock("let result = ", ";") { @@ -128,9 +128,9 @@ class InstantiatorTest { Instantiator(symbolProvider, model, runtimeConfig, BuilderKindBehavior(codegenContext), ::enumFromStringFn) val data = Node.parse("""{ "bar": 10, "foo": "hello" }""") - val project = TestWorkspace.testProject() - project.withModule(RustModule.Model) { - structure.renderWithModelBuilder(model, symbolProvider, this) + val project = TestWorkspace.testProject(model) + structure.renderWithModelBuilder(model, symbolProvider, project) + project.moduleFor(structure) { unitTest("generate_struct_builders") { withBlock("let result = ", ";") { sut.render(this, structure, data) @@ -162,9 +162,9 @@ class InstantiatorTest { """, ) - val project = TestWorkspace.testProject() - project.withModule(RustModule.Model) { - structure.renderWithModelBuilder(model, symbolProvider, this) + val project = TestWorkspace.testProject(model) + structure.renderWithModelBuilder(model, symbolProvider, project) + project.moduleFor(structure) { unitTest("generate_builders_for_boxed_structs") { withBlock("let result = ", ";") { sut.render(this, structure, data) @@ -192,7 +192,7 @@ class InstantiatorTest { Instantiator(symbolProvider, model, runtimeConfig, BuilderKindBehavior(codegenContext), ::enumFromStringFn) val project = TestWorkspace.testProject() - project.withModule(RustModule.Model) { + project.lib { unitTest("generate_lists") { withBlock("let result = ", ";") { sut.render(this, model.lookup("com.test#MyList"), data) @@ -213,8 +213,8 @@ class InstantiatorTest { ::enumFromStringFn, ) - val project = TestWorkspace.testProject() - project.withModule(RustModule.Model) { + val project = TestWorkspace.testProject(model) + project.lib { unitTest("generate_sparse_lists") { withBlock("let result = ", ";") { sut.render(this, model.lookup("com.test#MySparseList"), data) @@ -245,9 +245,9 @@ class InstantiatorTest { ) val inner = model.lookup("com.test#Inner") - val project = TestWorkspace.testProject() - project.withModule(RustModule.Model) { - inner.renderWithModelBuilder(model, symbolProvider, this) + val project = TestWorkspace.testProject(model) + inner.renderWithModelBuilder(model, symbolProvider, project) + project.moduleFor(inner) { unitTest("generate_maps_of_maps") { withBlock("let result = ", ";") { sut.render(this, model.lookup("com.test#NestedMap"), data) @@ -277,8 +277,8 @@ class InstantiatorTest { ::enumFromStringFn, ) - val project = TestWorkspace.testProject() - project.withModule(RustModule.Model) { + val project = TestWorkspace.testProject(model) + project.testModule { unitTest("blob_inputs_are_binary_data") { withBlock("let blob = ", ";") { sut.render( diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt index 37d73291ef8..e6be7eee97f 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/StructureGeneratorTest.kt @@ -15,7 +15,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock -import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel @@ -93,8 +92,8 @@ class StructureGeneratorTest { val provider = testSymbolProvider(model) val project = TestWorkspace.testProject(provider) project.useShapeWriter(inner) { - StructureGenerator(model, provider, this, inner).render() - StructureGenerator(model, provider, this, struct).render() + StructureGenerator(model, provider, this, inner, emptyList()).render() + StructureGenerator(model, provider, this, struct, emptyList()).render() unitTest( "struct_fields_optional", """ @@ -111,16 +110,16 @@ class StructureGeneratorTest { @Test fun `generate structures with public fields`() { - val project = TestWorkspace.testProject() val provider = testSymbolProvider(model) + val project = TestWorkspace.testProject(provider) project.lib { Attribute.AllowDeprecated.render(this) } - project.withModule(ModelsModule) { - val innerGenerator = StructureGenerator(model, provider, this, inner) + project.moduleFor(inner) { + val innerGenerator = StructureGenerator(model, provider, this, inner, emptyList()) innerGenerator.render() } project.withModule(RustModule.public("structs")) { - val generator = StructureGenerator(model, provider, this, struct) + val generator = StructureGenerator(model, provider, this, struct, emptyList()) generator.render() } // By putting the test in another module, it can't access the struct @@ -139,25 +138,11 @@ class StructureGeneratorTest { project.compileAndTest() } - @Test - fun `generate error structures`() { - val provider = testSymbolProvider(model) - val writer = RustWriter.forModule("error") - val generator = StructureGenerator(model, provider, writer, error) - generator.render() - writer.compileAndTest( - """ - let err = MyError { message: None }; - assert_eq!(err.retryable_error_kind(), aws_smithy_types::retry::ErrorKind::ServerError); - """, - ) - } - @Test fun `generate a custom debug implementation when the sensitive trait is applied to some members`() { val provider = testSymbolProvider(model) val writer = RustWriter.forModule("lib") - val generator = StructureGenerator(model, provider, writer, credentials) + val generator = StructureGenerator(model, provider, writer, credentials, emptyList()) generator.render() writer.unitTest( "sensitive_fields_redacted", @@ -177,7 +162,7 @@ class StructureGeneratorTest { fun `generate a custom debug implementation when the sensitive trait is applied to the struct`() { val provider = testSymbolProvider(model) val writer = RustWriter.forModule("lib") - val generator = StructureGenerator(model, provider, writer, secretStructure) + val generator = StructureGenerator(model, provider, writer, secretStructure, emptyList()) generator.render() writer.unitTest( "sensitive_structure_redacted", @@ -196,8 +181,8 @@ class StructureGeneratorTest { val provider = testSymbolProvider(model) val project = TestWorkspace.testProject(provider) project.useShapeWriter(inner) { - val secretGenerator = StructureGenerator(model, provider, this, secretStructure) - val generator = StructureGenerator(model, provider, this, structWithInnerSecretStructure) + val secretGenerator = StructureGenerator(model, provider, this, secretStructure, emptyList()) + val generator = StructureGenerator(model, provider, this, structWithInnerSecretStructure, emptyList()) secretGenerator.render() generator.render() unitTest( @@ -239,9 +224,9 @@ class StructureGeneratorTest { project.lib { Attribute.DenyMissingDocs.render(this) } - project.withModule(ModelsModule) { - StructureGenerator(model, provider, this, model.lookup("com.test#Inner")).render() - StructureGenerator(model, provider, this, model.lookup("com.test#MyStruct")).render() + project.moduleFor(model.lookup("com.test#Inner")) { + StructureGenerator(model, provider, this, model.lookup("com.test#Inner"), emptyList()).render() + StructureGenerator(model, provider, this, model.lookup("com.test#MyStruct"), emptyList()).render() } project.compileAndTest() @@ -251,7 +236,7 @@ class StructureGeneratorTest { fun `documents are optional in structs`() { val provider = testSymbolProvider(model) val writer = RustWriter.forModule("lib") - StructureGenerator(model, provider, writer, structWithDoc).render() + StructureGenerator(model, provider, writer, structWithDoc, emptyList()).render() writer.compileAndTest( """ @@ -283,11 +268,11 @@ class StructureGeneratorTest { val provider = testSymbolProvider(model) val project = TestWorkspace.testProject(provider) project.lib { rust("##![allow(deprecated)]") } - project.withModule(ModelsModule) { - StructureGenerator(model, provider, this, model.lookup("test#Foo")).render() - StructureGenerator(model, provider, this, model.lookup("test#Bar")).render() - StructureGenerator(model, provider, this, model.lookup("test#Baz")).render() - StructureGenerator(model, provider, this, model.lookup("test#Qux")).render() + project.moduleFor(model.lookup("test#Foo")) { + StructureGenerator(model, provider, this, model.lookup("test#Foo"), emptyList()).render() + StructureGenerator(model, provider, this, model.lookup("test#Bar"), emptyList()).render() + StructureGenerator(model, provider, this, model.lookup("test#Baz"), emptyList()).render() + StructureGenerator(model, provider, this, model.lookup("test#Qux"), emptyList()).render() } // turn on clippy to check the semver-compliant version of `since`. @@ -316,10 +301,10 @@ class StructureGeneratorTest { val provider = testSymbolProvider(model) val project = TestWorkspace.testProject(provider) project.lib { rust("##![allow(deprecated)]") } - project.withModule(ModelsModule) { - StructureGenerator(model, provider, this, model.lookup("test#Nested")).render() - StructureGenerator(model, provider, this, model.lookup("test#Foo")).render() - StructureGenerator(model, provider, this, model.lookup("test#Bar")).render() + project.moduleFor(model.lookup("test#Nested")) { + StructureGenerator(model, provider, this, model.lookup("test#Nested"), emptyList()).render() + StructureGenerator(model, provider, this, model.lookup("test#Foo"), emptyList()).render() + StructureGenerator(model, provider, this, model.lookup("test#Bar"), emptyList()).render() } project.compileAndTest() @@ -328,7 +313,7 @@ class StructureGeneratorTest { @Test fun `it generates accessor methods`() { val testModel = - RecursiveShapeBoxer.transform( + RecursiveShapeBoxer().transform( """ namespace test @@ -366,10 +351,10 @@ class StructureGeneratorTest { val project = TestWorkspace.testProject(provider) project.useShapeWriter(inner) { - StructureGenerator(testModel, provider, this, testModel.lookup("test#One")).render() - StructureGenerator(testModel, provider, this, testModel.lookup("test#Two")).render() + StructureGenerator(testModel, provider, this, testModel.lookup("test#One"), emptyList()).render() + StructureGenerator(testModel, provider, this, testModel.lookup("test#Two"), emptyList()).render() - rustBlock("fn compile_test_one(one: &crate::model::One)") { + rustBlock("fn compile_test_one(one: &crate::test_model::One)") { rust( """ let _: Option<&str> = one.field_string(); @@ -390,17 +375,17 @@ class StructureGeneratorTest { let _: f32 = one.field_primitive_float(); let _: Option = one.field_double(); let _: f64 = one.field_primitive_double(); - let _: Option<&crate::model::Two> = one.two(); + let _: Option<&crate::test_model::Two> = one.two(); let _: Option = one.build_value(); let _: Option = one.builder_value(); let _: Option = one.default_value(); """, ) } - rustBlock("fn compile_test_two(two: &crate::model::Two)") { + rustBlock("fn compile_test_two(two: &crate::test_model::Two)") { rust( """ - let _: Option<&crate::model::One> = two.one(); + let _: Option<&crate::test_model::One> = two.one(); """, ) } @@ -424,7 +409,7 @@ class StructureGeneratorTest { val provider = testSymbolProvider(model) RustWriter.forModule("test").let { - StructureGenerator(model, provider, it, struct).render() + StructureGenerator(model, provider, it, struct, emptyList()).render() assertEquals(6, it.toString().split("#[doc(hidden)]").size, "there should be 5 doc-hiddens") } } @@ -440,7 +425,7 @@ class StructureGeneratorTest { val provider = testSymbolProvider(model) RustWriter.forModule("test").let { writer -> - StructureGenerator(model, provider, writer, struct).render() + StructureGenerator(model, provider, writer, struct, emptyList()).render() writer.toString().shouldNotContain("#[doc(hidden)]") } } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/TestEnumType.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/TestEnumType.kt new file mode 100644 index 00000000000..e8ea12c0dbc --- /dev/null +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/TestEnumType.kt @@ -0,0 +1,49 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.core.smithy.generators + +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.util.dq + +object TestEnumType : EnumType() { + override fun implFromForStr(context: EnumGeneratorContext): Writable = writable { + rustTemplate( + """ + impl #{From}<&str> for ${context.enumName} { + fn from(s: &str) -> Self { + match s { + #{matchArms} + } + } + } + """, + "From" to RuntimeType.From, + "matchArms" to writable { + context.sortedMembers.forEach { member -> + rust("${member.value.dq()} => ${context.enumName}::${member.derivedName()},") + } + rust("_ => panic!()") + }, + ) + } + + override fun implFromStr(context: EnumGeneratorContext): Writable = writable { + rust( + """ + impl std::str::FromStr for ${context.enumName} { + type Err = std::convert::Infallible; + fn from_str(s: &str) -> std::result::Result { + Ok(${context.enumName}::from(s)) + } + } + """, + ) + } +} diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/UnionGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/UnionGeneratorTest.kt index 8a66ad112e9..8b6778890f0 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/UnionGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/UnionGeneratorTest.kt @@ -10,7 +10,6 @@ import org.junit.jupiter.api.Test import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.rust -import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest @@ -117,7 +116,7 @@ class UnionGeneratorTest { val provider = testSymbolProvider(model) val project = TestWorkspace.testProject(provider) project.lib { rust("##![allow(deprecated)]") } - project.withModule(ModelsModule) { + project.moduleFor(model.lookup("test#Nested")) { UnionGenerator(model, provider, this, model.lookup("test#Nested")).render() UnionGenerator(model, provider, this, model.lookup("test#Foo")).render() UnionGenerator(model, provider, this, model.lookup("test#Bar")).render() diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ErrorImplGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ErrorImplGeneratorTest.kt new file mode 100644 index 00000000000..678cc808b5b --- /dev/null +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ErrorImplGeneratorTest.kt @@ -0,0 +1,49 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.core.smithy.generators.error + +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.traits.ErrorTrait +import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder +import software.amazon.smithy.rust.codegen.core.testutil.testSymbolProvider +import software.amazon.smithy.rust.codegen.core.util.getTrait + +class ErrorImplGeneratorTest { + val model = + """ + namespace com.test + + @error("server") + @retryable + structure MyError { + message: String + } + """.asSmithyModel() + + @Test + fun `generate error structures`() { + val provider = testSymbolProvider(model) + val project = TestWorkspace.testProject(provider) + val errorShape = model.expectShape(ShapeId.from("com.test#MyError")) as StructureShape + errorShape.renderWithModelBuilder(model, provider, project) + project.moduleFor(errorShape) { + val errorTrait = errorShape.getTrait()!! + ErrorImplGenerator(model, provider, this, errorShape, errorTrait, emptyList()).render(CodegenTarget.CLIENT) + compileAndTest( + """ + let err = MyError::builder().build(); + assert_eq!(err.retryable_error_kind(), aws_smithy_types::retry::ErrorKind::ServerError); + """, + ) + } + } +} diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/OperationErrorGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/OperationErrorGeneratorTest.kt deleted file mode 100644 index 38b27ecf4f7..00000000000 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/OperationErrorGeneratorTest.kt +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.rust.codegen.core.smithy.generators.error - -import org.junit.jupiter.api.Test -import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.core.smithy.ErrorsModule -import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer -import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace -import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel -import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest -import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder -import software.amazon.smithy.rust.codegen.core.testutil.testSymbolProvider -import software.amazon.smithy.rust.codegen.core.testutil.unitTest -import software.amazon.smithy.rust.codegen.core.util.lookup - -class OperationErrorGeneratorTest { - private val baseModel = """ - namespace error - - operation Greeting { - errors: [InvalidGreeting, ComplexError, FooException, Deprecated] - } - - @error("client") - @retryable - structure InvalidGreeting { - message: String, - } - - @error("server") - structure FooException { } - - @error("server") - structure ComplexError { - abc: String, - other: Integer - } - - @error("server") - @deprecated - structure Deprecated { } - """.asSmithyModel() - private val model = OperationNormalizer.transform(baseModel) - private val symbolProvider = testSymbolProvider(model) - - @Test - fun `generates combined error enums`() { - val project = TestWorkspace.testProject(symbolProvider) - project.withModule(ErrorsModule) { - listOf("FooException", "ComplexError", "InvalidGreeting", "Deprecated").forEach { - model.lookup("error#$it").renderWithModelBuilder(model, symbolProvider, this) - } - val errors = listOf("FooException", "ComplexError", "InvalidGreeting").map { model.lookup("error#$it") } - val generator = OperationErrorGenerator(model, symbolProvider, symbolProvider.toSymbol(model.lookup("error#Greeting")), errors) - generator.render(this) - - unitTest( - name = "generates_combined_error_enums", - test = """ - let kind = GreetingErrorKind::InvalidGreeting(InvalidGreeting::builder().message("an error").build()); - let error = GreetingError::new(kind, aws_smithy_types::Error::builder().code("InvalidGreeting").message("an error").build()); - assert_eq!(format!("{}", error), "InvalidGreeting: an error"); - assert_eq!(error.message(), Some("an error")); - assert_eq!(error.code(), Some("InvalidGreeting")); - use aws_smithy_types::retry::ProvideErrorKind; - assert_eq!(error.retryable_error_kind(), Some(aws_smithy_types::retry::ErrorKind::ClientError)); - - // Generate is_xyz methods for errors. - assert_eq!(error.is_invalid_greeting(), true); - assert_eq!(error.is_complex_error(), false); - - // Unhandled variants properly delegate message. - let error = GreetingError::generic(aws_smithy_types::Error::builder().message("hello").build()); - assert_eq!(error.message(), Some("hello")); - - let error = GreetingError::unhandled("some other error"); - assert_eq!(error.message(), None); - assert_eq!(error.code(), None); - - // Indicate the original name in the display output. - let error = FooError::builder().build(); - assert_eq!(format!("{}", error), "FooError [FooException]"); - - let error = Deprecated::builder().build(); - assert_eq!(error.to_string(), "Deprecated"); - """, - ) - - project.compileAndTest() - } - } -} diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ServiceErrorGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ServiceErrorGeneratorTest.kt deleted file mode 100644 index 746927d134b..00000000000 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/error/ServiceErrorGeneratorTest.kt +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.rust.codegen.core.smithy.generators.error - -import org.junit.jupiter.api.Test -import software.amazon.smithy.model.shapes.ServiceShape -import software.amazon.smithy.model.shapes.ShapeId -import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.core.rustlang.Attribute -import software.amazon.smithy.rust.codegen.core.rustlang.AttributeKind -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule -import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget -import software.amazon.smithy.rust.codegen.core.smithy.CoreRustSettings -import software.amazon.smithy.rust.codegen.core.smithy.RustCrate -import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator -import software.amazon.smithy.rust.codegen.core.smithy.transformers.operationErrors -import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel -import software.amazon.smithy.rust.codegen.core.testutil.generatePluginContext -import software.amazon.smithy.rust.codegen.core.testutil.testSymbolProvider -import software.amazon.smithy.rust.codegen.core.util.runCommand -import kotlin.io.path.ExperimentalPathApi -import kotlin.io.path.createDirectory -import kotlin.io.path.writeText - -internal class ServiceErrorGeneratorTest { - @ExperimentalPathApi - @Test - fun `top level errors are send + sync`() { - val model = """ - namespace com.example - - use aws.protocols#restJson1 - - @restJson1 - service HelloService { - operations: [SayHello], - version: "1" - } - - @http(uri: "/", method: "POST") - operation SayHello { - input: EmptyStruct, - output: EmptyStruct, - errors: [SorryBusy, CanYouRepeatThat, MeDeprecated] - } - - structure EmptyStruct { } - - @error("server") - structure SorryBusy { } - - @error("client") - structure CanYouRepeatThat { } - - @error("client") - @deprecated - structure MeDeprecated { } - """.asSmithyModel() - - val (pluginContext, testDir) = generatePluginContext(model) - val moduleName = pluginContext.settings.expectStringMember("module").value.replace('-', '_') - val symbolProvider = testSymbolProvider(model) - val settings = CoreRustSettings.from(model, pluginContext.settings) - val codegenContext = CodegenContext( - model, - symbolProvider, - model.expectShape(ShapeId.from("com.example#HelloService")) as ServiceShape, - ShapeId.from("aws.protocols#restJson1"), - settings, - CodegenTarget.CLIENT, - ) - - val rustCrate = RustCrate( - pluginContext.fileManifest, - symbolProvider, - codegenContext.settings.codegenConfig, - ) - - rustCrate.lib { - Attribute.AllowDeprecated.render(this, AttributeKind.Inner) - } - rustCrate.withModule(RustModule.Error) { - for (operation in model.operationShapes) { - if (operation.id.namespace == "com.example") { - OperationErrorGenerator( - model, - symbolProvider, - symbolProvider.toSymbol(operation), - operation.operationErrors(model).map { it as StructureShape }, - ).render(this) - } - } - for (shape in model.structureShapes) { - if (shape.id.namespace == "com.example") { - StructureGenerator(model, symbolProvider, this, shape).render(CodegenTarget.CLIENT) - } - } - } - ServiceErrorGenerator(codegenContext, model.operationShapes.toList()).render(rustCrate) - - testDir.resolve("tests").createDirectory() - testDir.resolve("tests/validate_errors.rs").writeText( - """ - fn check_send_sync() {} - #[test] - fn tl_errors_are_send_sync() { - check_send_sync::<$moduleName::Error>() - } - """, - ) - rustCrate.finalize(settings, model, emptyMap(), emptyList(), false) - - "cargo test".runCommand(testDir) - } -} diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/InlineFunctionNamerTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/InlineFunctionNamerTest.kt index a988872b8bb..1685f66a834 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/InlineFunctionNamerTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/InlineFunctionNamerTest.kt @@ -102,12 +102,12 @@ class InlineFunctionNamerTest { symbolProvider.deserializeFunctionName(testModel.lookup(shapeId)) shouldBe "deser_$suffix" } - test("test#Op1", "operation_crate_operation_op1") + test("test#Op1", "operation_crate_test_operation_op1") test("test#SomeList1", "list_test_some_list1") test("test#SomeMap1", "map_test_some_map1") test("test#SomeSet1", "set_test_some_set1") - test("test#SomeStruct1", "structure_crate_model_some_struct1") - test("test#SomeUnion1", "union_crate_model_some_union1") + test("test#SomeStruct1", "structure_crate_test_model_some_struct1") + test("test#SomeUnion1", "union_crate_test_model_some_union1") test("test#SomeStruct1\$some_string", "member_test_some_struct1_some_string") } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGeneratorTest.kt index b90543bea3e..38beea1e1be 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/AwsQueryParserGeneratorTest.kt @@ -8,9 +8,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbolFn import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig @@ -43,13 +41,12 @@ class AwsQueryParserGeneratorTest { @Test fun `it modifies operation parsing to include Response and Result tags`() { - val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel)) + val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel)) val codegenContext = testCodegenContext(model) val symbolProvider = codegenContext.symbolProvider val parserGenerator = AwsQueryParserGenerator( codegenContext, RuntimeType.wrappedXmlErrors(TestRuntimeConfig), - builderSymbolFn(symbolProvider), ) val operationParser = parserGenerator.operationParser(model.lookup("test#SomeOperation"))!! val project = TestWorkspace.testProject(testSymbolProvider(model)) @@ -65,20 +62,17 @@ class AwsQueryParserGeneratorTest { "#; - let output = ${format(operationParser)}(xml, output::some_operation_output::Builder::default()).unwrap().build(); + let output = ${format(operationParser)}(xml, test_output::SomeOperationOutput::builder()).unwrap().build(); assert_eq!(output.some_attribute, Some(5)); assert_eq!(output.some_val, Some("Some value".to_string())); """, ) } - - project.withModule(RustModule.public("model")) { - model.lookup("test#SomeOutput").renderWithModelBuilder(model, symbolProvider, this) + model.lookup("test#SomeOutput").also { struct -> + struct.renderWithModelBuilder(model, symbolProvider, project) } - - project.withModule(RustModule.public("output")) { - model.lookup("test#SomeOperation").outputShape(model) - .renderWithModelBuilder(model, symbolProvider, this) + model.lookup("test#SomeOperation").outputShape(model).also { output -> + output.renderWithModelBuilder(model, symbolProvider, project) } project.compileAndTest() } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGeneratorTest.kt index 7b835d82234..9a51b072538 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/Ec2QueryParserGeneratorTest.kt @@ -8,9 +8,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbolFn import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig @@ -43,13 +41,12 @@ class Ec2QueryParserGeneratorTest { @Test fun `it modifies operation parsing to include Response and Result tags`() { - val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel)) + val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel)) val codegenContext = testCodegenContext(model) val symbolProvider = codegenContext.symbolProvider val parserGenerator = Ec2QueryParserGenerator( codegenContext, RuntimeType.wrappedXmlErrors(TestRuntimeConfig), - builderSymbolFn(symbolProvider), ) val operationParser = parserGenerator.operationParser(model.lookup("test#SomeOperation"))!! val project = TestWorkspace.testProject(testSymbolProvider(model)) @@ -63,20 +60,19 @@ class Ec2QueryParserGeneratorTest { Some value "#; - let output = ${format(operationParser)}(xml, output::some_operation_output::Builder::default()).unwrap().build(); + let output = ${format(operationParser)}(xml, test_output::SomeOperationOutput::builder()).unwrap().build(); assert_eq!(output.some_attribute, Some(5)); assert_eq!(output.some_val, Some("Some value".to_string())); """, ) } - project.withModule(RustModule.public("model")) { - model.lookup("test#SomeOutput").renderWithModelBuilder(model, symbolProvider, this) + model.lookup("test#SomeOutput").also { struct -> + struct.renderWithModelBuilder(model, symbolProvider, project) } - project.withModule(RustModule.public("output")) { - model.lookup("test#SomeOperation").outputShape(model) - .renderWithModelBuilder(model, symbolProvider, this) + model.lookup("test#SomeOperation").outputShape(model).also { output -> + output.renderWithModelBuilder(model, symbolProvider, project) } project.compileAndTest() } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt index 9bd71a7a4ef..79435b5e9b8 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt @@ -6,14 +6,12 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols.parse import org.junit.jupiter.api.Test -import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.TestEnumType import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpTraitHttpBindingResolver import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolContentTypes import software.amazon.smithy.rust.codegen.core.smithy.protocols.restJsonFieldName @@ -26,7 +24,6 @@ import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder import software.amazon.smithy.rust.codegen.core.testutil.testCodegenContext import software.amazon.smithy.rust.codegen.core.testutil.testSymbolProvider import software.amazon.smithy.rust.codegen.core.testutil.unitTest -import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.lookup import software.amazon.smithy.rust.codegen.core.util.outputShape @@ -115,17 +112,14 @@ class JsonParserGeneratorTest { @Test fun `generates valid deserializers`() { - val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel)) + val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel)) val codegenContext = testCodegenContext(model) val symbolProvider = codegenContext.symbolProvider - fun builderSymbol(shape: StructureShape): Symbol = - shape.builderSymbol(symbolProvider) val parserGenerator = JsonParserGenerator( codegenContext, HttpTraitHttpBindingResolver(model, ProtocolContentTypes.consistent("application/json")), ::restJsonFieldName, - ::builderSymbol, ) val operationGenerator = parserGenerator.operationParser(model.lookup("test#Op")) val payloadGenerator = parserGenerator.payloadParser(model.lookup("test#OpOutput\$top")) @@ -136,7 +130,7 @@ class JsonParserGeneratorTest { unitTest( "json_parser", """ - use model::Choice; + use test_model::Choice; // Generate the document serializer even though it's not tested directly // ${format(payloadGenerator)} @@ -151,7 +145,7 @@ class JsonParserGeneratorTest { } "#; - let output = ${format(operationGenerator!!)}(json, output::op_output::Builder::default()).unwrap().build(); + let output = ${format(operationGenerator!!)}(json, test_output::OpOutput::builder()).unwrap().build(); let top = output.top.expect("top"); assert_eq!(Some(45), top.extra); assert_eq!(Some("something".to_string()), top.field); @@ -162,7 +156,7 @@ class JsonParserGeneratorTest { "empty_body", """ // empty body - let output = ${format(operationGenerator)}(b"", output::op_output::Builder::default()).unwrap().build(); + let output = ${format(operationGenerator)}(b"", test_output::OpOutput::builder()).unwrap().build(); assert_eq!(output.top, None); """, ) @@ -171,7 +165,7 @@ class JsonParserGeneratorTest { """ // unknown variant let input = br#"{ "top": { "choice": { "somenewvariant": "data" } } }"#; - let output = ${format(operationGenerator)}(input, output::op_output::Builder::default()).unwrap().build(); + let output = ${format(operationGenerator)}(input, test_output::OpOutput::builder()).unwrap().build(); assert!(output.top.unwrap().choice.unwrap().is_unknown()); """, ) @@ -180,7 +174,7 @@ class JsonParserGeneratorTest { "empty_error", """ // empty error - let error_output = ${format(errorParser!!)}(b"", error::error::Builder::default()).unwrap().build(); + let error_output = ${format(errorParser!!)}(b"", test_error::Error::builder()).unwrap().build(); assert_eq!(error_output.message, None); """, ) @@ -189,24 +183,25 @@ class JsonParserGeneratorTest { "error_with_message", """ // error with message - let error_output = ${format(errorParser)}(br#"{"message": "hello"}"#, error::error::Builder::default()).unwrap().build(); + let error_output = ${format(errorParser)}(br#"{"message": "hello"}"#, test_error::Error::builder()).unwrap().build(); assert_eq!(error_output.message.expect("message should be set"), "hello"); """, ) } - project.withModule(RustModule.public("model")) { - model.lookup("test#Top").renderWithModelBuilder(model, symbolProvider, this) - model.lookup("test#EmptyStruct").renderWithModelBuilder(model, symbolProvider, this) - UnionGenerator(model, symbolProvider, this, model.lookup("test#Choice")).render() - val enum = model.lookup("test#FooEnum") - EnumGenerator(model, symbolProvider, this, enum, enum.expectTrait()).render() + model.lookup("test#Top").also { top -> + top.renderWithModelBuilder(model, symbolProvider, project) + model.lookup("test#EmptyStruct").renderWithModelBuilder(model, symbolProvider, project) + project.moduleFor(top) { + UnionGenerator(model, symbolProvider, this, model.lookup("test#Choice")).render() + val enum = model.lookup("test#FooEnum") + EnumGenerator(model, symbolProvider, enum, TestEnumType).render(this) + } } - - project.withModule(RustModule.public("output")) { - model.lookup("test#Op").outputShape(model).renderWithModelBuilder(model, symbolProvider, this) + model.lookup("test#Op").outputShape(model).also { output -> + output.renderWithModelBuilder(model, symbolProvider, project) } - project.withModule(RustModule.public("error")) { - model.lookup("test#Error").renderWithModelBuilder(model, symbolProvider, this) + model.lookup("test#Error").also { error -> + error.renderWithModelBuilder(model, symbolProvider, project) } project.compileAndTest() } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGeneratorTest.kt index 50fb343d6d8..47f310e83df 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGeneratorTest.kt @@ -9,11 +9,12 @@ import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.TestEnumType import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbolFn import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig @@ -24,7 +25,6 @@ import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder import software.amazon.smithy.rust.codegen.core.testutil.testCodegenContext import software.amazon.smithy.rust.codegen.core.testutil.testSymbolProvider import software.amazon.smithy.rust.codegen.core.testutil.unitTest -import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.lookup import software.amazon.smithy.rust.codegen.core.util.outputShape @@ -93,21 +93,22 @@ internal class XmlBindingTraitParserGeneratorTest { @Test fun `generates valid parsers`() { - val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel)) + val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel)) val codegenContext = testCodegenContext(model) val symbolProvider = codegenContext.symbolProvider val parserGenerator = XmlBindingTraitParserGenerator( codegenContext, RuntimeType.wrappedXmlErrors(TestRuntimeConfig), - builderSymbolFn(symbolProvider), ) { _, inner -> inner("decoder") } val operationParser = parserGenerator.operationParser(model.lookup("test#Op"))!! + + val choiceShape = model.lookup("test#Choice") val project = TestWorkspace.testProject(testSymbolProvider(model)) project.lib { - unitTest( - name = "valid_input", - test = """ - let xml = br#" + unitTest(name = "valid_input") { + rustTemplate( + """ + let xml = br##" some key @@ -118,19 +119,21 @@ internal class XmlBindingTraitParserGeneratorTest { hey - "#; - let output = ${format(operationParser)}(xml, output::op_output::Builder::default()).unwrap().build(); + "##; + let output = ${format(operationParser)}(xml, test_output::OpOutput::builder()).unwrap().build(); let mut map = std::collections::HashMap::new(); - map.insert("some key".to_string(), model::Choice::S("hello".to_string())); - assert_eq!(output.choice, Some(model::Choice::FlatMap(map))); + map.insert("some key".to_string(), #{Choice}::S("hello".to_string())); + assert_eq!(output.choice, Some(#{Choice}::FlatMap(map))); assert_eq!(output.renamed_with_prefix.as_deref(), Some("hey")); - """, - ) - - unitTest( - name = "ignore_extras", - test = """ - let xml = br#" + """, + "Choice" to symbolProvider.toSymbol(choiceShape), + ) + } + + unitTest(name = "ignore_extras") { + rustTemplate( + """ + let xml = br##" @@ -146,13 +149,15 @@ internal class XmlBindingTraitParserGeneratorTest { - "#; - let output = ${format(operationParser)}(xml, output::op_output::Builder::default()).unwrap().build(); + "##; + let output = ${format(operationParser)}(xml, test_output::OpOutput::builder()).unwrap().build(); let mut map = std::collections::HashMap::new(); - map.insert("some key".to_string(), model::Choice::S("hello".to_string())); - assert_eq!(output.choice, Some(model::Choice::FlatMap(map))); - """, - ) + map.insert("some key".to_string(), #{Choice}::S("hello".to_string())); + assert_eq!(output.choice, Some(#{Choice}::FlatMap(map))); + """, + "Choice" to symbolProvider.toSymbol(choiceShape), + ) + } unitTest( name = "nopanics_on_invalid", @@ -174,7 +179,7 @@ internal class XmlBindingTraitParserGeneratorTest { "#; - ${format(operationParser)}(xml, output::op_output::Builder::default()).expect("unknown union variant does not cause failure"); + ${format(operationParser)}(xml, test_output::OpOutput::builder()).expect("unknown union variant does not cause failure"); """, ) unitTest( @@ -191,20 +196,23 @@ internal class XmlBindingTraitParserGeneratorTest { "#; - let output = ${format(operationParser)}(xml, output::op_output::Builder::default()).unwrap().build(); + let output = ${format(operationParser)}(xml, test_output::OpOutput::builder()).unwrap().build(); assert!(output.choice.unwrap().is_unknown()); """, ) } - project.withModule(RustModule.public("model")) { - model.lookup("test#Top").renderWithModelBuilder(model, symbolProvider, this) - UnionGenerator(model, symbolProvider, this, model.lookup("test#Choice")).render() - val enum = model.lookup("test#FooEnum") - EnumGenerator(model, symbolProvider, this, enum, enum.expectTrait()).render() + model.lookup("test#Top").also { top -> + top.renderWithModelBuilder(model, symbolProvider, project) + project.moduleFor(top) { + UnionGenerator(model, symbolProvider, this, choiceShape).render() + model.lookup("test#FooEnum").also { enum -> + EnumGenerator(model, symbolProvider, enum, TestEnumType).render(this) + } + } } - project.withModule(RustModule.public("output")) { - model.lookup("test#Op").outputShape(model).renderWithModelBuilder(model, symbolProvider, this) + model.lookup("test#Op").outputShape(model).also { out -> + out.renderWithModelBuilder(model, symbolProvider, project) } project.compileAndTest() } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/AwsQuerySerializerGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/AwsQuerySerializerGeneratorTest.kt index f8ef938bea7..ad44554e36d 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/AwsQuerySerializerGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/AwsQuerySerializerGeneratorTest.kt @@ -10,9 +10,9 @@ import org.junit.jupiter.params.provider.CsvSource import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.TestEnumType import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer @@ -22,7 +22,6 @@ import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder import software.amazon.smithy.rust.codegen.core.testutil.testCodegenContext import software.amazon.smithy.rust.codegen.core.testutil.unitTest -import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.lookup @@ -93,7 +92,7 @@ class AwsQuerySerializerGeneratorTest { true -> CodegenTarget.CLIENT false -> CodegenTarget.SERVER } - val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel)) + val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel)) val codegenContext = testCodegenContext(model, codegenTarget = codegenTarget) val symbolProvider = codegenContext.symbolProvider val parserGenerator = AwsQuerySerializerGenerator(testCodegenContext(model, codegenTarget = codegenTarget)) @@ -104,9 +103,9 @@ class AwsQuerySerializerGeneratorTest { unitTest( "query_serializer", """ - use model::Top; + use test_model::Top; - let input = crate::input::OpInput::builder() + let input = crate::test_input::OpInput::builder() .top( Top::builder() .field("hello!") @@ -133,15 +132,23 @@ class AwsQuerySerializerGeneratorTest { """, ) } - project.withModule(RustModule.public("model")) { - model.lookup("test#Top").renderWithModelBuilder(model, symbolProvider, this) - UnionGenerator(model, symbolProvider, this, model.lookup("test#Choice"), renderUnknownVariant = generateUnknownVariant).render() - val enum = model.lookup("test#FooEnum") - EnumGenerator(model, symbolProvider, this, enum, enum.expectTrait()).render() + model.lookup("test#Top").also { top -> + top.renderWithModelBuilder(model, symbolProvider, project) + project.moduleFor(top) { + UnionGenerator( + model, + symbolProvider, + this, + model.lookup("test#Choice"), + renderUnknownVariant = generateUnknownVariant, + ).render() + val enum = model.lookup("test#FooEnum") + EnumGenerator(model, symbolProvider, enum, TestEnumType).render(this) + } } - project.withModule(RustModule.public("input")) { - model.lookup("test#Op").inputShape(model).renderWithModelBuilder(model, symbolProvider, this) + model.lookup("test#Op").inputShape(model).also { input -> + input.renderWithModelBuilder(model, symbolProvider, project) } project.compileAndTest() } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/Ec2QuerySerializerGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/Ec2QuerySerializerGeneratorTest.kt index b3a21898ee0..2436aff7061 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/Ec2QuerySerializerGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/Ec2QuerySerializerGeneratorTest.kt @@ -9,8 +9,8 @@ import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.TestEnumType import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer @@ -21,7 +21,6 @@ import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder import software.amazon.smithy.rust.codegen.core.testutil.testCodegenContext import software.amazon.smithy.rust.codegen.core.testutil.testSymbolProvider import software.amazon.smithy.rust.codegen.core.testutil.unitTest -import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.lookup @@ -86,7 +85,7 @@ class Ec2QuerySerializerGeneratorTest { @Test fun `generates valid serializers`() { - val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel)) + val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel)) val codegenContext = testCodegenContext(model) val symbolProvider = codegenContext.symbolProvider val parserGenerator = Ec2QuerySerializerGenerator(codegenContext) @@ -97,9 +96,9 @@ class Ec2QuerySerializerGeneratorTest { unitTest( "ec2query_serializer", """ - use model::Top; + use test_model::Top; - let input = crate::input::OpInput::builder() + let input = crate::test_input::OpInput::builder() .top( Top::builder() .field("hello!") @@ -126,15 +125,17 @@ class Ec2QuerySerializerGeneratorTest { """, ) } - project.withModule(RustModule.public("model")) { - model.lookup("test#Top").renderWithModelBuilder(model, symbolProvider, this) - UnionGenerator(model, symbolProvider, this, model.lookup("test#Choice")).render() - val enum = model.lookup("test#FooEnum") - EnumGenerator(model, symbolProvider, this, enum, enum.expectTrait()).render() + model.lookup("test#Top").also { top -> + top.renderWithModelBuilder(model, symbolProvider, project) + project.moduleFor(top) { + UnionGenerator(model, symbolProvider, this, model.lookup("test#Choice")).render() + val enum = model.lookup("test#FooEnum") + EnumGenerator(model, symbolProvider, enum, TestEnumType).render(this) + } } - project.withModule(RustModule.public("input")) { - model.lookup("test#Op").inputShape(model).renderWithModelBuilder(model, symbolProvider, this) + model.lookup("test#Op").inputShape(model).also { input -> + input.renderWithModelBuilder(model, symbolProvider, project) } project.compileAndTest() } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGeneratorTest.kt index bf3fb604da4..23c27f331b1 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGeneratorTest.kt @@ -9,8 +9,8 @@ import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.TestEnumType import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpTraitHttpBindingResolver import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolContentTypes @@ -24,7 +24,6 @@ import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder import software.amazon.smithy.rust.codegen.core.testutil.testCodegenContext import software.amazon.smithy.rust.codegen.core.testutil.testSymbolProvider import software.amazon.smithy.rust.codegen.core.testutil.unitTest -import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.lookup @@ -101,7 +100,7 @@ class JsonSerializerGeneratorTest { @Test fun `generates valid serializers`() { - val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel)) + val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel)) val codegenContext = testCodegenContext(model) val symbolProvider = codegenContext.symbolProvider val parserSerializer = JsonSerializerGenerator( @@ -117,12 +116,12 @@ class JsonSerializerGeneratorTest { unitTest( "json_serializers", """ - use model::{Top, Choice}; + use test_model::{Top, Choice}; // Generate the document serializer even though it's not tested directly // ${format(documentGenerator)} - let input = crate::input::OpInput::builder().top( + let input = crate::test_input::OpInput::builder().top( Top::builder() .field("hello!") .extra(45) @@ -133,7 +132,7 @@ class JsonSerializerGeneratorTest { let output = std::str::from_utf8(serialized.bytes().unwrap()).unwrap(); assert_eq!(output, r#"{"top":{"field":"hello!","extra":45,"rec":[{"extra":55}]}}"#); - let input = crate::input::OpInput::builder().top( + let input = crate::test_input::OpInput::builder().top( Top::builder() .choice(Choice::Unknown) .build() @@ -142,15 +141,17 @@ class JsonSerializerGeneratorTest { """, ) } - project.withModule(RustModule.public("model")) { - model.lookup("test#Top").renderWithModelBuilder(model, symbolProvider, this) - UnionGenerator(model, symbolProvider, this, model.lookup("test#Choice")).render() - val enum = model.lookup("test#FooEnum") - EnumGenerator(model, symbolProvider, this, enum, enum.expectTrait()).render() + model.lookup("test#Top").also { top -> + top.renderWithModelBuilder(model, symbolProvider, project) + project.moduleFor(top) { + UnionGenerator(model, symbolProvider, this, model.lookup("test#Choice")).render() + val enum = model.lookup("test#FooEnum") + EnumGenerator(model, symbolProvider, enum, TestEnumType).render(this) + } } - project.withModule(RustModule.public("input")) { - model.lookup("test#Op").inputShape(model).renderWithModelBuilder(model, symbolProvider, this) + model.lookup("test#Op").inputShape(model).also { input -> + input.renderWithModelBuilder(model, symbolProvider, project) } project.compileAndTest() } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGeneratorTest.kt index f8f9aafa7bc..a695d2a4018 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGeneratorTest.kt @@ -9,8 +9,8 @@ import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.TestEnumType import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpTraitHttpBindingResolver import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolContentTypes @@ -23,7 +23,6 @@ import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder import software.amazon.smithy.rust.codegen.core.testutil.testCodegenContext import software.amazon.smithy.rust.codegen.core.testutil.testSymbolProvider import software.amazon.smithy.rust.codegen.core.testutil.unitTest -import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.lookup @@ -106,7 +105,7 @@ internal class XmlBindingTraitSerializerGeneratorTest { @Test fun `generates valid serializers`() { - val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel)) + val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel)) val codegenContext = testCodegenContext(model) val symbolProvider = codegenContext.symbolProvider val parserGenerator = XmlBindingTraitSerializerGenerator( @@ -120,8 +119,8 @@ internal class XmlBindingTraitSerializerGeneratorTest { unitTest( "serialize_xml", """ - use model::Top; - let inp = crate::input::OpInput::builder().payload( + use test_model::Top; + let inp = crate::test_input::OpInput::builder().payload( Top::builder() .field("hello!") .extra(45) @@ -136,8 +135,8 @@ internal class XmlBindingTraitSerializerGeneratorTest { unitTest( "unknown_variants", """ - use model::{Top, Choice}; - let input = crate::input::OpInput::builder().payload( + use test_model::{Top, Choice}; + let input = crate::test_input::OpInput::builder().payload( Top::builder() .choice(Choice::Unknown) .build() @@ -146,15 +145,16 @@ internal class XmlBindingTraitSerializerGeneratorTest { """, ) } - project.withModule(RustModule.public("model")) { - model.lookup("test#Top").renderWithModelBuilder(model, symbolProvider, this) - UnionGenerator(model, symbolProvider, this, model.lookup("test#Choice")).render() - val enum = model.lookup("test#FooEnum") - EnumGenerator(model, symbolProvider, this, enum, enum.expectTrait()).render() + model.lookup("test#Top").also { top -> + top.renderWithModelBuilder(model, symbolProvider, project) + project.moduleFor(top) { + UnionGenerator(model, symbolProvider, this, model.lookup("test#Choice")).render() + val enum = model.lookup("test#FooEnum") + EnumGenerator(model, symbolProvider, enum, TestEnumType).render(this) + } } - - project.withModule(RustModule.public("input")) { - model.lookup("test#Op").inputShape(model).renderWithModelBuilder(model, symbolProvider, this) + model.lookup("test#Op").inputShape(model).also { input -> + input.renderWithModelBuilder(model, symbolProvider, project) } project.compileAndTest() } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapeBoxerTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapeBoxerTest.kt index 061814a73a6..293e2217131 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapeBoxerTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapeBoxerTest.kt @@ -31,7 +31,7 @@ internal class RecursiveShapeBoxerTest { hello: Hello } """.asSmithyModel() - RecursiveShapeBoxer.transform(model) shouldBe model + RecursiveShapeBoxer().transform(model) shouldBe model } @Test @@ -43,7 +43,7 @@ internal class RecursiveShapeBoxerTest { anotherField: Boolean } """.asSmithyModel() - val transformed = RecursiveShapeBoxer.transform(model) + val transformed = RecursiveShapeBoxer().transform(model) val member: MemberShape = transformed.lookup("com.example#Recursive\$RecursiveStruct") member.expectTrait() } @@ -70,7 +70,7 @@ internal class RecursiveShapeBoxerTest { third: SecondTree } """.asSmithyModel() - val transformed = RecursiveShapeBoxer.transform(model) + val transformed = RecursiveShapeBoxer().transform(model) val boxed = transformed.shapes().filter { it.hasTrait() }.toList() boxed.map { it.id.toString().removePrefix("com.example#") }.toSet() shouldBe setOf( "Atom\$add", diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapesIntegrationTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapesIntegrationTest.kt index bea21fd0833..8b33c576945 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapesIntegrationTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/transformers/RecursiveShapesIntegrationTest.kt @@ -10,9 +10,10 @@ import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest import software.amazon.smithy.rust.codegen.core.testutil.testSymbolProvider @@ -42,23 +43,30 @@ class RecursiveShapesIntegrationTest { third: SecondTree } """.asSmithyModel() + val check = { input: Model -> + val symbolProvider = testSymbolProvider(model) + val project = TestWorkspace.testProject(symbolProvider) val structures = listOf("Expr", "SecondTree").map { input.lookup("com.example#$it") } - val writer = RustWriter.forModule("model") - val symbolProvider = testSymbolProvider(input) - structures.forEach { - StructureGenerator(input, symbolProvider, writer, it).render() + structures.forEach { struct -> + project.moduleFor(struct) { + StructureGenerator(input, symbolProvider, this, struct, emptyList()).render() + } + } + input.lookup("com.example#Atom").also { atom -> + project.moduleFor(atom) { + UnionGenerator(input, symbolProvider, this, atom).render() + } } - UnionGenerator(input, symbolProvider, writer, input.lookup("com.example#Atom")).render() - writer + project } - val unmodifiedWriter = check(model) + val unmodifiedProject = check(model) val output = assertThrows { - unmodifiedWriter.compileAndTest(expectFailure = true) + unmodifiedProject.compileAndTest(expectFailure = true) } output.message shouldContain "has infinite size" - val fixedWriter = check(RecursiveShapeBoxer.transform(model)) - fixedWriter.compileAndTest() + val fixedProject = check(RecursiveShapeBoxer().transform(model)) + fixedProject.compileAndTest() } } diff --git a/codegen-server-test/build.gradle.kts b/codegen-server-test/build.gradle.kts index b2841452c14..d8ec164a406 100644 --- a/codegen-server-test/build.gradle.kts +++ b/codegen-server-test/build.gradle.kts @@ -38,6 +38,7 @@ dependencies { val allCodegenTests = "../codegen-core/common-test-models".let { commonModels -> listOf( CodegenTest("crate#Config", "naming_test_ops", imports = listOf("$commonModels/naming-obstacle-course-ops.smithy")), + CodegenTest("casing#ACRONYMInside_Service", "naming_test_casing", imports = listOf("$commonModels/naming-obstacle-course-casing.smithy")), CodegenTest( "naming_obs_structs#NamingObstacleCourseStructs", "naming_test_structs", diff --git a/codegen-server/build.gradle.kts b/codegen-server/build.gradle.kts index 8dd6dc5d186..f5b2f5bc29a 100644 --- a/codegen-server/build.gradle.kts +++ b/codegen-server/build.gradle.kts @@ -26,6 +26,10 @@ dependencies { implementation(project(":codegen-core")) implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") + + // `smithy.framework#ValidationException` is defined here, which is used in `constraints.smithy`, which is used + // in `CustomValidationExceptionWithReasonDecoratorTest`. + testImplementation("software.amazon.smithy:smithy-validation-model:$smithyVersion") } tasks.compileKotlin { kotlinOptions.jvmTarget = "1.8" } diff --git a/codegen-server/python/build.gradle.kts b/codegen-server/python/build.gradle.kts index 5a23bd5d7dc..bbc30ac6e72 100644 --- a/codegen-server/python/build.gradle.kts +++ b/codegen-server/python/build.gradle.kts @@ -24,7 +24,6 @@ val smithyVersion: String by project dependencies { implementation(project(":codegen-core")) - implementation(project(":codegen-client")) implementation(project(":codegen-server")) implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion") implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion") diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt index 8fd2a07f583..51fdb95c8c5 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt @@ -16,17 +16,21 @@ import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.EnumTrait +import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.rust.codegen.core.rustlang.RustModule -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RustCrate -import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig +import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorImplGenerator +import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerEnumGenerator import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerOperationHandlerGenerator import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerServiceGenerator import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerStructureGenerator import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenVisitor +import software.amazon.smithy.rust.codegen.server.smithy.ServerModuleProvider +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustSettings import software.amazon.smithy.rust.codegen.server.smithy.ServerSymbolProviders import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol @@ -41,15 +45,16 @@ import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerProtoco */ class PythonServerCodegenVisitor( context: PluginContext, - codegenDecorator: ServerCodegenDecorator, + private val codegenDecorator: ServerCodegenDecorator, ) : ServerCodegenVisitor(context, codegenDecorator) { init { - val symbolVisitorConfig = - SymbolVisitorConfig( + val rustSymbolProviderConfig = + RustSymbolProviderConfig( runtimeConfig = settings.runtimeConfig, renameExceptions = false, nullabilityCheckMode = NullableIndex.CheckMode.SERVER, + moduleProvider = ServerModuleProvider, ) val baseModel = baselineTransform(context.model) val service = settings.getService(baseModel) @@ -70,16 +75,19 @@ class PythonServerCodegenVisitor( settings = settings.copy(codegenConfig = settings.codegenConfig.copy(publicConstrainedTypes = false)) fun baseSymbolProviderFactory( + settings: ServerRustSettings, model: Model, serviceShape: ServiceShape, - symbolVisitorConfig: SymbolVisitorConfig, + rustSymbolProviderConfig: RustSymbolProviderConfig, publicConstrainedTypes: Boolean, - ) = PythonCodegenServerPlugin.baseSymbolProvider(model, serviceShape, symbolVisitorConfig, publicConstrainedTypes) + includeConstraintShapeProvider: Boolean, + ) = RustServerCodegenPythonPlugin.baseSymbolProvider(settings, model, serviceShape, rustSymbolProviderConfig, publicConstrainedTypes) val serverSymbolProviders = ServerSymbolProviders.from( + settings, model, service, - symbolVisitorConfig, + rustSymbolProviderConfig, settings.codegenConfig.publicConstrainedTypes, ::baseSymbolProviderFactory, ) @@ -119,7 +127,18 @@ class PythonServerCodegenVisitor( rustCrate.useShapeWriter(shape) { // Use Python specific structure generator that adds the #[pyclass] attribute // and #[pymethods] implementation. - PythonServerStructureGenerator(model, codegenContext.symbolProvider, this, shape).render(CodegenTarget.SERVER) + PythonServerStructureGenerator(model, codegenContext.symbolProvider, this, shape).render() + + shape.getTrait()?.also { errorTrait -> + ErrorImplGenerator( + model, + codegenContext.symbolProvider, + this, + shape, + errorTrait, + codegenDecorator.errorImplCustomizations(codegenContext, emptyList()), + ).render(CodegenTarget.SERVER) + } renderStructureShapeBuilder(shape, this) } @@ -131,8 +150,8 @@ class PythonServerCodegenVisitor( * Although raw strings require no code generation, enums are actually [EnumTrait] applied to string shapes. */ override fun stringShape(shape: StringShape) { - fun pythonServerEnumGeneratorFactory(codegenContext: ServerCodegenContext, writer: RustWriter, shape: StringShape) = - PythonServerEnumGenerator(codegenContext, writer, shape) + fun pythonServerEnumGeneratorFactory(codegenContext: ServerCodegenContext, shape: StringShape) = + PythonServerEnumGenerator(codegenContext, shape, validationExceptionConversionGenerator) stringShape(shape, ::pythonServerEnumGeneratorFactory) } diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerSymbolProvider.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerSymbolProvider.kt index ca750bbb5ca..51f1f3caac5 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerSymbolProvider.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerSymbolProvider.kt @@ -22,15 +22,16 @@ import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig import software.amazon.smithy.rust.codegen.core.smithy.SymbolMetadataProvider import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitor -import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.isStreaming +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustSettings /** * Symbol visitor allowing that recursively replace symbols in nested shapes. @@ -44,11 +45,12 @@ import software.amazon.smithy.rust.codegen.core.util.isStreaming * `aws_smithy_http_server_python::types`. */ class PythonServerSymbolVisitor( - private val model: Model, + settings: ServerRustSettings, + model: Model, serviceShape: ServiceShape?, - config: SymbolVisitorConfig, -) : SymbolVisitor(model, serviceShape, config) { - private val runtimeConfig = config().runtimeConfig + config: RustSymbolProviderConfig, +) : SymbolVisitor(settings, model, serviceShape, config) { + private val runtimeConfig = config.runtimeConfig override fun toSymbol(shape: Shape): Symbol { val initial = shape.accept(this) @@ -68,7 +70,7 @@ class PythonServerSymbolVisitor( // For example a TimestampShape doesn't become a different symbol when streaming is involved, but BlobShape // become a ByteStream. return if (target is BlobShape && shape.isStreaming(model)) { - PythonServerRuntimeType.byteStream(config().runtimeConfig).toSymbol() + PythonServerRuntimeType.byteStream(config.runtimeConfig).toSymbol() } else { initial } @@ -95,19 +97,23 @@ class PythonServerSymbolVisitor( * * Note that since streaming members can only be used on the root shape, this can only impact input and output shapes. */ -class PythonStreamingShapeMetadataProvider(private val base: RustSymbolProvider, private val model: Model) : SymbolMetadataProvider(base) { +class PythonStreamingShapeMetadataProvider(private val base: RustSymbolProvider) : SymbolMetadataProvider(base) { override fun structureMeta(structureShape: StructureShape): RustMetadata { val baseMetadata = base.toSymbol(structureShape).expectRustMetadata() return if (structureShape.hasStreamingMember(model)) { baseMetadata.withoutDerives(RuntimeType.PartialEq) - } else baseMetadata + } else { + baseMetadata + } } override fun unionMeta(unionShape: UnionShape): RustMetadata { val baseMetadata = base.toSymbol(unionShape).expectRustMetadata() return if (unionShape.hasStreamingMember(model)) { baseMetadata.withoutDerives(RuntimeType.PartialEq) - } else baseMetadata + } else { + baseMetadata + } } override fun memberMeta(memberShape: MemberShape) = base.toSymbol(memberShape).expectRustMetadata() diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonType.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonType.kt new file mode 100644 index 00000000000..c5606b5d38a --- /dev/null +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonType.kt @@ -0,0 +1,176 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.python.smithy + +import software.amazon.smithy.rust.codegen.core.rustlang.RustType + +/** + * A hierarchy of Python types handled by Smithy codegen. + * + * Mostly copied from [RustType] and modified for Python accordingly. + */ +sealed class PythonType { + /** + * A Python type that contains [member], another [PythonType]. + * Used to generically operate over shapes that contain other shape. + */ + sealed interface Container { + val member: PythonType + val namespace: String? + val name: String + } + + /** + * Name refers to the top-level type for import purposes. + */ + abstract val name: String + + open val namespace: String? = null + + object None : PythonType() { + override val name: String = "None" + } + + object Bool : PythonType() { + override val name: String = "bool" + } + + object Int : PythonType() { + override val name: String = "int" + } + + object Float : PythonType() { + override val name: String = "float" + } + + object Str : PythonType() { + override val name: String = "str" + } + + object Any : PythonType() { + override val name: String = "Any" + override val namespace: String = "typing" + } + + data class List(override val member: PythonType) : PythonType(), Container { + override val name: String = "List" + override val namespace: String = "typing" + } + + data class Dict(val key: PythonType, override val member: PythonType) : PythonType(), Container { + override val name: String = "Dict" + override val namespace: String = "typing" + } + + data class Set(override val member: PythonType) : PythonType(), Container { + override val name: String = "Set" + override val namespace: String = "typing" + } + + data class Optional(override val member: PythonType) : PythonType(), Container { + override val name: String = "Optional" + override val namespace: String = "typing" + } + + data class Awaitable(override val member: PythonType) : PythonType(), Container { + override val name: String = "Awaitable" + override val namespace: String = "typing" + } + + data class Callable(val args: kotlin.collections.List, val rtype: PythonType) : PythonType() { + override val name: String = "Callable" + override val namespace: String = "typing" + } + + data class Union(val args: kotlin.collections.List) : PythonType() { + override val name: String = "Union" + override val namespace: String = "typing" + } + + data class Opaque(override val name: String, val rustNamespace: String? = null) : PythonType() { + // Since Python doesn't have a something like Rust's `crate::` we are using a custom placeholder here + // and in our stub generation script we will replace placeholder with the real root module name. + private val pythonRootModulePlaceholder = "__root_module_name__" + + override val namespace: String? = rustNamespace?.split("::")?.joinToString(".") { + when (it) { + "crate" -> pythonRootModulePlaceholder + // In Python, we expose submodules from `aws_smithy_http_server_python` + // like `types`, `middleware`, `tls` etc. from `__root_module__name` + "aws_smithy_http_server_python" -> pythonRootModulePlaceholder + else -> it + } + } + } +} + +/** + * Return corresponding [PythonType] for a [RustType]. + */ +fun RustType.pythonType(): PythonType = + when (this) { + is RustType.Unit -> PythonType.None + is RustType.Bool -> PythonType.Bool + is RustType.Float -> PythonType.Float + is RustType.Integer -> PythonType.Int + is RustType.String -> PythonType.Str + is RustType.Vec -> PythonType.List(this.member.pythonType()) + is RustType.Slice -> PythonType.List(this.member.pythonType()) + is RustType.HashMap -> PythonType.Dict(this.key.pythonType(), this.member.pythonType()) + is RustType.HashSet -> PythonType.Set(this.member.pythonType()) + is RustType.Reference -> this.member.pythonType() + is RustType.Option -> PythonType.Optional(this.member.pythonType()) + is RustType.Box -> this.member.pythonType() + is RustType.Dyn -> this.member.pythonType() + is RustType.Opaque -> PythonType.Opaque(this.name, this.namespace) + // TODO(Constraints): How to handle this? + // Revisit as part of https://github.com/awslabs/smithy-rs/issues/2114 + is RustType.MaybeConstrained -> this.member.pythonType() + } + +/** + * Render this type, including references and generic parameters. + * It generates something like `typing.Dict[String, String]`. + */ +fun PythonType.render(fullyQualified: Boolean = true): String { + val namespace = if (fullyQualified) { + this.namespace?.let { "$it." } ?: "" + } else { + "" + } + val base = when (this) { + is PythonType.None -> this.name + is PythonType.Bool -> this.name + is PythonType.Float -> this.name + is PythonType.Int -> this.name + is PythonType.Str -> this.name + is PythonType.Any -> this.name + is PythonType.Opaque -> this.name + is PythonType.List -> "${this.name}[${this.member.render(fullyQualified)}]" + is PythonType.Dict -> "${this.name}[${this.key.render(fullyQualified)}, ${this.member.render(fullyQualified)}]" + is PythonType.Set -> "${this.name}[${this.member.render(fullyQualified)}]" + is PythonType.Awaitable -> "${this.name}[${this.member.render(fullyQualified)}]" + is PythonType.Optional -> "${this.name}[${this.member.render(fullyQualified)}]" + is PythonType.Callable -> { + val args = this.args.joinToString(", ") { it.render(fullyQualified) } + val rtype = this.rtype.render(fullyQualified) + "${this.name}[[$args], $rtype]" + } + is PythonType.Union -> { + val args = this.args.joinToString(", ") { it.render(fullyQualified) } + "${this.name}[$args]" + } + } + return "$namespace$base" +} + +/** + * Renders [PythonType] with proper escaping for Docstrings. + */ +fun PythonType.renderAsDocstring(): String = + this.render() + .replace("[", "\\[") + .replace("]", "\\]") diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonCodegenServerPlugin.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/RustServerCodegenPythonPlugin.kt similarity index 71% rename from codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonCodegenServerPlugin.kt rename to codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/RustServerCodegenPythonPlugin.kt index 7e277185f7b..b61bb725ae1 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonCodegenServerPlugin.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/RustServerCodegenPythonPlugin.kt @@ -14,28 +14,34 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWordSymbolP import software.amazon.smithy.rust.codegen.core.smithy.BaseSymbolMetadataProvider import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.EventStreamSymbolProvider -import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitor -import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig import software.amazon.smithy.rust.codegen.server.python.smithy.customizations.DECORATORS import software.amazon.smithy.rust.codegen.server.smithy.ConstrainedShapeSymbolMetadataProvider import software.amazon.smithy.rust.codegen.server.smithy.ConstrainedShapeSymbolProvider import software.amazon.smithy.rust.codegen.server.smithy.DeriveEqAndHashSymbolMetadataProvider +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustSettings +import software.amazon.smithy.rust.codegen.server.smithy.customizations.CustomValidationExceptionWithReasonDecorator import software.amazon.smithy.rust.codegen.server.smithy.customizations.ServerRequiredCustomizations +import software.amazon.smithy.rust.codegen.server.smithy.customizations.SmithyValidationExceptionDecorator import software.amazon.smithy.rust.codegen.server.smithy.customize.CombinedServerCodegenDecorator import java.util.logging.Level import java.util.logging.Logger /** - * Rust with Python bindings Codegen Plugin. + * Rust Server with Python bindings Codegen Plugin. + * * This is the entrypoint for code generation, triggered by the smithy-build plugin. * `resources/META-INF.services/software.amazon.smithy.build.SmithyBuildPlugin` refers to this class by name which * enables the smithy-build plugin to invoke `execute` with all of the Smithy plugin context + models. */ -class PythonCodegenServerPlugin : SmithyBuildPlugin { +class RustServerCodegenPythonPlugin : SmithyBuildPlugin { private val logger = Logger.getLogger(javaClass.name) override fun getName(): String = "rust-server-codegen-python" + /** + * See [software.amazon.smithy.rust.codegen.client.smithy.RustClientCodegenPlugin]. + */ override fun execute(context: PluginContext) { // Suppress extremely noisy logs about reserved words Logger.getLogger(ReservedWordSymbolProvider::class.java.name).level = Level.OFF @@ -47,7 +53,10 @@ class PythonCodegenServerPlugin : SmithyBuildPlugin { val codegenDecorator: CombinedServerCodegenDecorator = CombinedServerCodegenDecorator.fromClasspath( context, - CombinedServerCodegenDecorator(DECORATORS + ServerRequiredCustomizations()), + ServerRequiredCustomizations(), + SmithyValidationExceptionDecorator(), + CustomValidationExceptionWithReasonDecorator(), + *DECORATORS, ) // PythonServerCodegenVisitor is the main driver of code generation that traverses the model and generates code @@ -57,36 +66,36 @@ class PythonCodegenServerPlugin : SmithyBuildPlugin { companion object { /** - * When generating code, smithy types need to be converted into Rust types—that is the core role of the symbol provider - * - * The Symbol provider is composed of a base [SymbolVisitor] which handles the core functionality, then is layered - * with other symbol providers, documented inline, to handle the full scope of Smithy types. + * See [software.amazon.smithy.rust.codegen.client.smithy.RustClientCodegenPlugin]. */ fun baseSymbolProvider( + settings: ServerRustSettings, model: Model, serviceShape: ServiceShape, - symbolVisitorConfig: SymbolVisitorConfig, + rustSymbolProviderConfig: RustSymbolProviderConfig, constrainedTypes: Boolean = true, ) = // Rename a set of symbols that do not implement `PyClass` and have been wrapped in // `aws_smithy_http_server_python::types`. - PythonServerSymbolVisitor(model, serviceShape = serviceShape, config = symbolVisitorConfig) + PythonServerSymbolVisitor(settings, model, serviceShape = serviceShape, config = rustSymbolProviderConfig) // Generate public constrained types for directly constrained shapes. // In the Python server project, this is only done to generate constrained types for simple shapes (e.g. // a `string` shape with the `length` trait), but these always remain `pub(crate)`. - .let { if (constrainedTypes) ConstrainedShapeSymbolProvider(it, model, serviceShape) else it } + .let { + if (constrainedTypes) ConstrainedShapeSymbolProvider(it, serviceShape, constrainedTypes) else it + } // Generate different types for EventStream shapes (e.g. transcribe streaming) - .let { EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model, CodegenTarget.SERVER) } + .let { EventStreamSymbolProvider(rustSymbolProviderConfig.runtimeConfig, it, CodegenTarget.SERVER) } // Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes - .let { BaseSymbolMetadataProvider(it, model, additionalAttributes = listOf()) } + .let { BaseSymbolMetadataProvider(it, additionalAttributes = listOf()) } // Constrained shapes generate newtypes that need the same derives we place on types generated from aggregate shapes. - .let { ConstrainedShapeSymbolMetadataProvider(it, model, constrainedTypes) } + .let { ConstrainedShapeSymbolMetadataProvider(it, constrainedTypes) } // Streaming shapes need different derives (e.g. they cannot derive Eq) - .let { PythonStreamingShapeMetadataProvider(it, model) } + .let { PythonStreamingShapeMetadataProvider(it) } // Derive `Eq` and `Hash` if possible. - .let { DeriveEqAndHashSymbolMetadataProvider(it, model) } + .let { DeriveEqAndHashSymbolMetadataProvider(it) } // Rename shapes that clash with Rust reserved words & and other SDK specific features e.g. `send()` cannot // be the name of an operation input - .let { RustReservedWordSymbolProvider(it, model) } + .let { RustReservedWordSymbolProvider(it) } } } diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/customizations/PythonServerCodegenDecorator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/customizations/PythonServerCodegenDecorator.kt index ff07ba1207e..9e353db2cd0 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/customizations/PythonServerCodegenDecorator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/customizations/PythonServerCodegenDecorator.kt @@ -98,6 +98,7 @@ class PubUsePythonTypesDecorator : ServerCodegenDecorator { /** * Generates `pyproject.toml` for the crate. * - Configures Maturin as the build system + * - Configures Python source directory */ class PyProjectTomlDecorator : ServerCodegenDecorator { override val name: String = "PyProjectTomlDecorator" @@ -110,6 +111,11 @@ class PyProjectTomlDecorator : ServerCodegenDecorator { "requires" to listOfNotNull("maturin>=0.14,<0.15"), "build-backend" to "maturin", ).toMap(), + "tool" to listOfNotNull( + "maturin" to listOfNotNull( + "python-source" to "python", + ).toMap(), + ).toMap(), ) writeWithNoFormatting(TomlWriter().write(config)) } @@ -134,7 +140,61 @@ class PyO3ExtensionModuleDecorator : ServerCodegenDecorator { } } -val DECORATORS = listOf( +/** + * Generates `__init__.py` for the Python source. + * + * This file allows Python module to be imported like: + * ``` + * import pokemon_service_server_sdk + * pokemon_service_server_sdk.App() + * ``` + * instead of: + * ``` + * from pokemon_service_server_sdk import pokemon_service_server_sdk + * ``` + */ +class InitPyDecorator : ServerCodegenDecorator { + override val name: String = "InitPyDecorator" + override val order: Byte = 0 + + override fun extras(codegenContext: ServerCodegenContext, rustCrate: RustCrate) { + val libName = codegenContext.settings.moduleName.toSnakeCase() + + rustCrate.withFile("python/$libName/__init__.py") { + writeWithNoFormatting( + """ + from .$libName import * + + __doc__ = $libName.__doc__ + if hasattr($libName, "__all__"): + __all__ = $libName.__all__ + """.trimIndent(), + ) + } + } +} + +/** + * Generates `py.typed` for the Python source. + * + * This marker file is required to be PEP 561 compliant stub package. + * Type definitions will be ignored by `mypy` if the package is not PEP 561 compliant: + * https://mypy.readthedocs.io/en/stable/running_mypy.html#missing-library-stubs-or-py-typed-marker + */ +class PyTypedMarkerDecorator : ServerCodegenDecorator { + override val name: String = "PyTypedMarkerDecorator" + override val order: Byte = 0 + + override fun extras(codegenContext: ServerCodegenContext, rustCrate: RustCrate) { + val libName = codegenContext.settings.moduleName.toSnakeCase() + + rustCrate.withFile("python/$libName/py.typed") { + writeWithNoFormatting("") + } + } +} + +val DECORATORS = arrayOf( /** * Add the [InternalServerError] error to all operations. * This is done because the Python interpreter can raise exceptions during execution. @@ -150,4 +210,8 @@ val DECORATORS = listOf( PyProjectTomlDecorator(), // Add PyO3 extension module feature. PyO3ExtensionModuleDecorator(), + // Generate `__init__.py` for the Python source. + InitPyDecorator(), + // Generate `py.typed` for the Python source. + PyTypedMarkerDecorator(), ) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt index 793118b4410..8db6dc84396 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt @@ -12,9 +12,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext -import software.amazon.smithy.rust.codegen.core.smithy.ErrorsModule -import software.amazon.smithy.rust.codegen.core.smithy.InputsModule -import software.amazon.smithy.rust.codegen.core.smithy.OutputsModule import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.inputShape @@ -22,8 +19,13 @@ import software.amazon.smithy.rust.codegen.core.util.outputShape import software.amazon.smithy.rust.codegen.core.util.toPascalCase import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerCargoDependency +import software.amazon.smithy.rust.codegen.server.python.smithy.PythonType +import software.amazon.smithy.rust.codegen.server.python.smithy.renderAsDocstring import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule.Error as ErrorModule +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule.Input as InputModule +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule.Output as OutputModule /** * Generates a Python compatible application and server that can be configured from Python. @@ -103,6 +105,9 @@ class PythonApplicationGenerator( """ ##[#{pyo3}::pyclass] ##[derive(Debug)] + /// :generic Ctx: + /// :extends typing.Generic\[Ctx\]: + /// :rtype None: pub struct App { handlers: #{HashMap}, middlewares: Vec<#{SmithyPython}::PyMiddlewareHandler>, @@ -239,6 +244,12 @@ class PythonApplicationGenerator( """, *codegenScope, ) { + val middlewareRequest = PythonType.Opaque("Request", "crate::middleware") + val middlewareResponse = PythonType.Opaque("Response", "crate::middleware") + val middlewareNext = PythonType.Callable(listOf(middlewareRequest), PythonType.Awaitable(middlewareResponse)) + val middlewareFunc = PythonType.Callable(listOf(middlewareRequest, middlewareNext), PythonType.Awaitable(middlewareResponse)) + val tlsConfig = PythonType.Opaque("TlsConfig", "crate::tls") + rustTemplate( """ /// Create a new [App]. @@ -246,12 +257,20 @@ class PythonApplicationGenerator( pub fn new() -> Self { Self::default() } + /// Register a context object that will be shared between handlers. + /// + /// :param context Ctx: + /// :rtype ${PythonType.None.renderAsDocstring()}: ##[pyo3(text_signature = "(${'$'}self, context)")] pub fn context(&mut self, context: #{pyo3}::PyObject) { self.context = Some(context); } + /// Register a Python function to be executed inside a Tower middleware layer. + /// + /// :param func ${middlewareFunc.renderAsDocstring()}: + /// :rtype ${PythonType.None.renderAsDocstring()}: ##[pyo3(text_signature = "(${'$'}self, func)")] pub fn middleware(&mut self, py: #{pyo3}::Python, func: #{pyo3}::PyObject) -> #{pyo3}::PyResult<()> { let handler = #{SmithyPython}::PyMiddlewareHandler::new(py, func)?; @@ -263,8 +282,16 @@ class PythonApplicationGenerator( self.middlewares.push(handler); Ok(()) } + /// Main entrypoint: start the server on multiple workers. - ##[pyo3(text_signature = "(${'$'}self, address, port, backlog, workers, tls)")] + /// + /// :param address ${PythonType.Optional(PythonType.Str).renderAsDocstring()}: + /// :param port ${PythonType.Optional(PythonType.Int).renderAsDocstring()}: + /// :param backlog ${PythonType.Optional(PythonType.Int).renderAsDocstring()}: + /// :param workers ${PythonType.Optional(PythonType.Int).renderAsDocstring()}: + /// :param tls ${PythonType.Optional(tlsConfig).renderAsDocstring()}: + /// :rtype ${PythonType.None.renderAsDocstring()}: + ##[pyo3(text_signature = "(${'$'}self, address=None, port=None, backlog=None, workers=None, tls=None)")] pub fn run( &mut self, py: #{pyo3}::Python, @@ -277,7 +304,10 @@ class PythonApplicationGenerator( use #{SmithyPython}::PyApp; self.run_server(py, address, port, backlog, workers, tls) } + /// Lambda entrypoint: start the server on Lambda. + /// + /// :rtype ${PythonType.None.renderAsDocstring()}: ##[pyo3(text_signature = "(${'$'}self)")] pub fn run_lambda( &mut self, @@ -286,8 +316,9 @@ class PythonApplicationGenerator( use #{SmithyPython}::PyApp; self.run_lambda_handler(py) } + /// Build the service and start a single worker. - ##[pyo3(text_signature = "(${'$'}self, socket, worker_number, tls)")] + ##[pyo3(text_signature = "(${'$'}self, socket, worker_number, tls=None)")] pub fn start_worker( &mut self, py: pyo3::Python, @@ -306,10 +337,31 @@ class PythonApplicationGenerator( operations.map { operation -> val operationName = symbolProvider.toSymbol(operation).name val name = operationName.toSnakeCase() + + val input = PythonType.Opaque("${operationName}Input", "crate::input") + val output = PythonType.Opaque("${operationName}Output", "crate::output") + val context = PythonType.Opaque("Ctx") + val returnType = PythonType.Union(listOf(output, PythonType.Awaitable(output))) + val handler = PythonType.Union( + listOf( + PythonType.Callable( + listOf(input, context), + returnType, + ), + PythonType.Callable( + listOf(input), + returnType, + ), + ), + ) + rustTemplate( """ /// Method to register `$name` Python implementation inside the handlers map. /// It can be used as a function decorator in Python. + /// + /// :param func ${handler.renderAsDocstring()}: + /// :rtype ${PythonType.None.renderAsDocstring()}: ##[pyo3(text_signature = "(${'$'}self, func)")] pub fn $name(&mut self, py: #{pyo3}::Python, func: #{pyo3}::PyObject) -> #{pyo3}::PyResult<()> { use #{SmithyPython}::PyApp; @@ -338,12 +390,12 @@ class PythonApplicationGenerator( ) writer.rust( """ - /// from $libName import ${InputsModule.name} - /// from $libName import ${OutputsModule.name} + /// from $libName import ${InputModule.name} + /// from $libName import ${OutputModule.name} """.trimIndent(), ) if (operations.any { it.errors.isNotEmpty() }) { - writer.rust("""/// from $libName import ${ErrorsModule.name}""".trimIndent()) + writer.rust("""/// from $libName import ${ErrorModule.name}""".trimIndent()) } writer.rust( """ @@ -382,7 +434,9 @@ class PythonApplicationGenerator( val operationDocumentation = it.getTrait()?.value val ret = if (!operationDocumentation.isNullOrBlank()) { operationDocumentation.replace("#", "##").prependIndent("/// ## ") + "\n" - } else "" + } else { + "" + } ret + """ /// ${it.signature()}: @@ -397,8 +451,8 @@ class PythonApplicationGenerator( private fun OperationShape.signature(): String { val inputSymbol = symbolProvider.toSymbol(inputShape(model)) val outputSymbol = symbolProvider.toSymbol(outputShape(model)) - val inputT = "${InputsModule.name}::${inputSymbol.name}" - val outputT = "${OutputsModule.name}::${outputSymbol.name}" + val inputT = "${InputModule.name}::${inputSymbol.name}" + val outputT = "${OutputModule.name}::${outputSymbol.name}" val operationName = symbolProvider.toSymbol(this).name.toSnakeCase() return "@app.$operationName\n/// def $operationName(input: $inputT, ctx: Context) -> $outputT" } diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEnumGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEnumGenerator.kt index 90bdcc6e444..ac12cc0df37 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEnumGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEnumGenerator.kt @@ -7,45 +7,39 @@ package software.amazon.smithy.rust.codegen.server.python.smithy.generators import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.rust.codegen.core.rustlang.Attribute -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGeneratorContext import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext -import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerEnumGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.ConstrainedEnum +import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator /** * To share enums defined in Rust with Python, `pyo3` provides the `PyClass` trait. * This class generates enums definitions, implements the `PyClass` trait and adds * some utility functions like `__str__()` and `__repr__()`. */ -class PythonServerEnumGenerator( +class PythonConstrainedEnum( codegenContext: ServerCodegenContext, - private val writer: RustWriter, shape: StringShape, -) : ServerEnumGenerator(codegenContext, writer, shape) { - + validationExceptionConversionGenerator: ValidationExceptionConversionGenerator, +) : ConstrainedEnum(codegenContext, shape, validationExceptionConversionGenerator) { private val pyO3 = PythonServerCargoDependency.PyO3.toType() - override fun render() { - renderPyClass() - super.render() - renderPyO3Methods() - } - - private fun renderPyClass() { - Attribute(pyO3.resolve("pyclass")).render(writer) - } + override fun additionalEnumAttributes(context: EnumGeneratorContext): List = + listOf(Attribute(pyO3.resolve("pyclass"))) - private fun renderPyO3Methods() { - Attribute(pyO3.resolve("pymethods")).render(writer) - writer.rustTemplate( + override fun additionalEnumImpls(context: EnumGeneratorContext): Writable = writable { + Attribute(pyO3.resolve("pymethods")).render(this) + rustTemplate( """ - impl $enumName { + impl ${context.enumName} { #{name_method:W} ##[getter] pub fn value(&self) -> &str { @@ -59,11 +53,11 @@ class PythonServerEnumGenerator( } } """, - "name_method" to renderPyEnumName(), + "name_method" to pyEnumName(context), ) } - private fun renderPyEnumName(): Writable = + private fun pyEnumName(context: EnumGeneratorContext): Writable = writable { rustBlock( """ @@ -72,11 +66,22 @@ class PythonServerEnumGenerator( """, ) { rustBlock("match self") { - sortedMembers.forEach { member -> + context.sortedMembers.forEach { member -> val memberName = member.name()?.name - rust("""$enumName::$memberName => ${memberName?.dq()},""") + rust("""${context.enumName}::$memberName => ${memberName?.dq()},""") } } } } } + +class PythonServerEnumGenerator( + codegenContext: ServerCodegenContext, + shape: StringShape, + validationExceptionConversionGenerator: ValidationExceptionConversionGenerator, +) : EnumGenerator( + codegenContext.model, + codegenContext.symbolProvider, + shape, + PythonConstrainedEnum(codegenContext, shape, validationExceptionConversionGenerator), +) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt index 9d9d4d26179..df7e246880e 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt @@ -99,6 +99,7 @@ class PythonServerModuleGenerator( let types = #{pyo3}::types::PyModule::new(py, "types")?; types.add_class::<#{SmithyPython}::types::Blob>()?; types.add_class::<#{SmithyPython}::types::DateTime>()?; + types.add_class::<#{SmithyPython}::types::Format>()?; types.add_class::<#{SmithyPython}::types::ByteStream>()?; #{pyo3}::py_run!( py, @@ -185,6 +186,10 @@ class PythonServerModuleGenerator( """ let aws_lambda = #{pyo3}::types::PyModule::new(py, "aws_lambda")?; aws_lambda.add_class::<#{SmithyPython}::lambda::PyLambdaContext>()?; + aws_lambda.add_class::<#{SmithyPython}::lambda::PyClientApplication>()?; + aws_lambda.add_class::<#{SmithyPython}::lambda::PyClientContext>()?; + aws_lambda.add_class::<#{SmithyPython}::lambda::PyCognitoIdentity>()?; + aws_lambda.add_class::<#{SmithyPython}::lambda::PyConfig>()?; pyo3::py_run!( py, aws_lambda, diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationErrorGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationErrorGenerator.kt index 65e53adb6c3..e2f66a83681 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationErrorGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationErrorGenerator.kt @@ -15,7 +15,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider -import software.amazon.smithy.rust.codegen.core.smithy.generators.error.errorSymbol import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerOperationErrorGenerator @@ -27,13 +26,11 @@ class PythonServerOperationErrorGenerator( private val model: Model, private val symbolProvider: RustSymbolProvider, private val operation: OperationShape, -) : ServerOperationErrorGenerator(model, symbolProvider, symbolProvider.toSymbol(operation), listOf()) { - +) { private val operationIndex = OperationIndex.of(model) private val errors = operationIndex.getErrors(operation) - override fun render(writer: RustWriter) { - super.render(writer) + fun render(writer: RustWriter) { renderFromPyErr(writer) } @@ -52,7 +49,7 @@ class PythonServerOperationErrorGenerator( """, "pyo3" to PythonServerCargoDependency.PyO3.toType(), - "Error" to operation.errorSymbol(symbolProvider), + "Error" to symbolProvider.symbolForOperationError(operation), "From" to RuntimeType.From, "CastPyErrToRustError" to castPyErrToRustError(), ) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerStructureGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerStructureGenerator.kt index 436d956fa29..496660e28c5 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerStructureGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerStructureGenerator.kt @@ -23,6 +23,9 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGener import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerCargoDependency +import software.amazon.smithy.rust.codegen.server.python.smithy.PythonType +import software.amazon.smithy.rust.codegen.server.python.smithy.pythonType +import software.amazon.smithy.rust.codegen.server.python.smithy.renderAsDocstring /** * To share structures defined in Rust with Python, `pyo3` provides the `PyClass` trait. @@ -34,7 +37,7 @@ class PythonServerStructureGenerator( private val symbolProvider: RustSymbolProvider, private val writer: RustWriter, private val shape: StructureShape, -) : StructureGenerator(model, symbolProvider, writer, shape) { +) : StructureGenerator(model, symbolProvider, writer, shape, emptyList()) { private val pyO3 = PythonServerCargoDependency.PyO3.toType() @@ -52,6 +55,7 @@ class PythonServerStructureGenerator( } else { Attribute(pyO3.resolve("pyclass")).render(writer) } + writer.rustTemplate("#{ConstructorSignature:W}", "ConstructorSignature" to renderConstructorSignature()) super.renderStructure() renderPyO3Methods() } @@ -65,6 +69,7 @@ class PythonServerStructureGenerator( writer.addDependency(PythonServerCargoDependency.PyO3) // Above, we manually add dependency since we can't use a `RuntimeType` below Attribute("pyo3(get, set)").render(writer) + writer.rustTemplate("#{Signature:W}", "Signature" to renderSymbolSignature(memberSymbol)) super.renderStructureMember(writer, member, memberName, memberSymbol) } @@ -107,4 +112,20 @@ class PythonServerStructureGenerator( rust("$memberName,") } } + + private fun renderConstructorSignature(): Writable = + writable { + forEachMember(members) { _, memberName, memberSymbol -> + val memberType = memberSymbol.rustType().pythonType() + rust("/// :param $memberName ${memberType.renderAsDocstring()}:") + } + + rust("/// :rtype ${PythonType.None.renderAsDocstring()}:") + } + + private fun renderSymbolSignature(symbol: Symbol): Writable = + writable { + val pythonType = symbol.rustType().pythonType() + rust("/// :type ${pythonType.renderAsDocstring()}:") + } } diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/testutil/PythonServerTestHelpers.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/testutil/PythonServerTestHelpers.kt index e38e1ae67c3..669279c8d53 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/testutil/PythonServerTestHelpers.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/testutil/PythonServerTestHelpers.kt @@ -13,7 +13,9 @@ import software.amazon.smithy.rust.codegen.core.testutil.generatePluginContext import software.amazon.smithy.rust.codegen.core.util.runCommand import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerCodegenVisitor import software.amazon.smithy.rust.codegen.server.python.smithy.customizations.DECORATORS +import software.amazon.smithy.rust.codegen.server.smithy.customizations.CustomValidationExceptionWithReasonDecorator import software.amazon.smithy.rust.codegen.server.smithy.customizations.ServerRequiredCustomizations +import software.amazon.smithy.rust.codegen.server.smithy.customizations.SmithyValidationExceptionDecorator import software.amazon.smithy.rust.codegen.server.smithy.customize.CombinedServerCodegenDecorator import java.io.File import java.nio.file.Path @@ -25,10 +27,13 @@ fun generatePythonServerPluginContext(model: Model) = generatePluginContext(model, runtimeConfig = TestRuntimeConfig) fun executePythonServerCodegenVisitor(pluginCtx: PluginContext) { - val codegenDecorator: CombinedServerCodegenDecorator = + val codegenDecorator = CombinedServerCodegenDecorator.fromClasspath( pluginCtx, - CombinedServerCodegenDecorator(DECORATORS + ServerRequiredCustomizations()), + *DECORATORS, + ServerRequiredCustomizations(), + SmithyValidationExceptionDecorator(), + CustomValidationExceptionWithReasonDecorator(), ) PythonServerCodegenVisitor(pluginCtx, codegenDecorator).execute() } diff --git a/codegen-server/python/src/main/resources/META-INF/services/software.amazon.smithy.build.SmithyBuildPlugin b/codegen-server/python/src/main/resources/META-INF/services/software.amazon.smithy.build.SmithyBuildPlugin index 6dc5b76c780..000cc4b7aec 100644 --- a/codegen-server/python/src/main/resources/META-INF/services/software.amazon.smithy.build.SmithyBuildPlugin +++ b/codegen-server/python/src/main/resources/META-INF/services/software.amazon.smithy.build.SmithyBuildPlugin @@ -2,4 +2,4 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 # -software.amazon.smithy.rust.codegen.server.python.smithy.PythonCodegenServerPlugin +software.amazon.smithy.rust.codegen.server.python.smithy.RustServerCodegenPythonPlugin diff --git a/codegen-server/python/src/test/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerSymbolProviderTest.kt b/codegen-server/python/src/test/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerSymbolProviderTest.kt index 96439ac7870..c7467b58eda 100644 --- a/codegen-server/python/src/test/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerSymbolProviderTest.kt +++ b/codegen-server/python/src/test/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerSymbolProviderTest.kt @@ -12,7 +12,8 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerSymbolVisitor -import software.amazon.smithy.rust.codegen.server.smithy.testutil.ServerTestSymbolVisitorConfig +import software.amazon.smithy.rust.codegen.server.smithy.testutil.ServerTestRustSymbolProviderConfig +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestRustSettings internal class PythonServerSymbolProviderTest { private val pythonBlobType = RustType.Opaque("Blob", "aws_smithy_http_server_python::types") @@ -45,7 +46,8 @@ internal class PythonServerSymbolProviderTest { value: Timestamp } """.asSmithyModel() - val provider = PythonServerSymbolVisitor(model, null, ServerTestSymbolVisitorConfig) + val provider = + PythonServerSymbolVisitor(serverTestRustSettings(), model, null, ServerTestRustSymbolProviderConfig) // Struct test val timestamp = provider.toSymbol(model.expectShape(ShapeId.from("test#TimestampStruct\$inner"))).rustType() @@ -95,7 +97,8 @@ internal class PythonServerSymbolProviderTest { value: Blob } """.asSmithyModel() - val provider = PythonServerSymbolVisitor(model, null, ServerTestSymbolVisitorConfig) + val provider = + PythonServerSymbolVisitor(serverTestRustSettings(), model, null, ServerTestRustSymbolProviderConfig) // Struct test val blob = provider.toSymbol(model.expectShape(ShapeId.from("test#BlobStruct\$inner"))).rustType() diff --git a/codegen-server/python/src/test/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerTypesTest.kt b/codegen-server/python/src/test/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerTypesTest.kt index c15b399744b..be510db87bc 100644 --- a/codegen-server/python/src/test/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerTypesTest.kt +++ b/codegen-server/python/src/test/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerTypesTest.kt @@ -21,7 +21,7 @@ internal class PythonServerTypesTest { fun `document type`() { val model = """ namespace test - + use aws.protocols#restJson1 @restJson1 @@ -30,7 +30,7 @@ internal class PythonServerTypesTest { Echo, ], } - + @http(method: "POST", uri: "/echo") operation Echo { input: EchoInput, @@ -40,7 +40,7 @@ internal class PythonServerTypesTest { structure EchoInput { value: Document, } - + structure EchoOutput { value: Document, } @@ -53,38 +53,38 @@ internal class PythonServerTypesTest { Pair( """ { "value": 42 } """, """ - assert input.value == 42 - output = EchoOutput(value=input.value) + assert input.value == 42 + output = EchoOutput(value=input.value) """, ), Pair( """ { "value": "foobar" } """, """ - assert input.value == "foobar" - output = EchoOutput(value=input.value) + assert input.value == "foobar" + output = EchoOutput(value=input.value) """, ), Pair( """ - { - "value": [ - true, - false, - 42, - 42.0, - -42, - { - "nested": "value" - }, - { - "nested": [1, 2, 3] - } - ] - } + { + "value": [ + true, + false, + 42, + 42.0, + -42, + { + "nested": "value" + }, + { + "nested": [1, 2, 3] + } + ] + } """, """ - assert input.value == [True, False, 42, 42.0, -42, {"nested": "value"}, {"nested": [1, 2, 3]}] - output = EchoOutput(value=input.value) + assert input.value == [True, False, 42, 42.0, -42, {"nested": "value"}, {"nested": [1, 2, 3]}] + output = EchoOutput(value=input.value) """, ), ) @@ -97,7 +97,7 @@ internal class PythonServerTypesTest { use pyo3::{types::IntoPyDict, IntoPy, Python}; use hyper::{Body, Request, body}; use crate::{input, output}; - + pyo3::prepare_freethreaded_python(); """.trimIndent(), ) @@ -112,9 +112,9 @@ internal class PythonServerTypesTest { Ok(Python::with_gil(|py| { let globals = [("EchoOutput", py.get_type::())].into_py_dict(py); let locals = [("input", input.into_py(py))].into_py_dict(py); - + py.run(${pythonHandler.dq()}, Some(globals), Some(locals)).unwrap(); - + locals .get_item("output") .unwrap() @@ -124,13 +124,13 @@ internal class PythonServerTypesTest { }) .build() .unwrap(); - + let req = Request::builder() .method("POST") .uri("/echo") .body(Body::from(${payload.dq()})) .unwrap(); - + let res = service.call(req).await.unwrap(); assert!(res.status().is_success()); let body = body::to_bytes(res.into_body()).await.unwrap(); diff --git a/codegen-server/python/src/test/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonTypeInformationGenerationTest.kt b/codegen-server/python/src/test/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonTypeInformationGenerationTest.kt new file mode 100644 index 00000000000..1473edcffe8 --- /dev/null +++ b/codegen-server/python/src/test/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonTypeInformationGenerationTest.kt @@ -0,0 +1,45 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.python.smithy.generators + +import io.kotest.matchers.string.shouldContain +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext + +internal class PythonTypeInformationGenerationTest { + @Test + fun `generates python type information`() { + val model = """ + namespace test + + structure Foo { + @required + bar: String, + baz: Integer + } + """.asSmithyModel() + val foo = model.lookup("test#Foo") + + val codegenContext = serverTestCodegenContext(model) + val symbolProvider = codegenContext.symbolProvider + val writer = RustWriter.forModule("model") + PythonServerStructureGenerator(model, symbolProvider, writer, foo).render() + + val result = writer.toString() + + // Constructor signature + result.shouldContain("/// :param bar str:") + result.shouldContain("/// :param baz typing.Optional\\[int\\]:") + + // Field types + result.shouldContain("/// :type str:") + result.shouldContain("/// :type typing.Optional\\[int\\]:") + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolMetadataProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolMetadataProvider.kt index 01e8255ccc2..0c83795f2fd 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolMetadataProvider.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolMetadataProvider.kt @@ -5,7 +5,6 @@ package software.amazon.smithy.rust.codegen.server.smithy -import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.ListShape import software.amazon.smithy.model.shapes.MapShape @@ -29,7 +28,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata */ class ConstrainedShapeSymbolMetadataProvider( private val base: RustSymbolProvider, - private val model: Model, private val constrainedTypes: Boolean, ) : SymbolMetadataProvider(base) { diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProvider.kt index 5d37c465a09..25b10009773 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProvider.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProvider.kt @@ -6,7 +6,6 @@ package software.amazon.smithy.rust.codegen.server.smithy import software.amazon.smithy.codegen.core.Symbol -import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.NullableIndex import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.ByteShape @@ -19,9 +18,12 @@ import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.ShortShape import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.LengthTrait +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords import software.amazon.smithy.rust.codegen.core.rustlang.RustType -import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.contextName @@ -30,9 +32,13 @@ import software.amazon.smithy.rust.codegen.core.smithy.handleRustBoxing import software.amazon.smithy.rust.codegen.core.smithy.locatedIn import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.smithy.symbolBuilder +import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.orNull import software.amazon.smithy.rust.codegen.core.util.toPascalCase +import software.amazon.smithy.rust.codegen.core.util.toSnakeCase +import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderModule +import software.amazon.smithy.rust.codegen.server.smithy.traits.SyntheticStructureFromConstrainedMemberTrait /** * The [ConstrainedShapeSymbolProvider] returns, for a given _directly_ @@ -55,16 +61,17 @@ import software.amazon.smithy.rust.codegen.core.util.toPascalCase */ class ConstrainedShapeSymbolProvider( private val base: RustSymbolProvider, - private val model: Model, private val serviceShape: ServiceShape, + private val publicConstrainedTypes: Boolean, ) : WrappingSymbolProvider(base) { private val nullableIndex = NullableIndex.of(model) private fun publicConstrainedSymbolForMapOrCollectionShape(shape: Shape): Symbol { check(shape is MapShape || shape is CollectionShape) - val rustType = RustType.Opaque(shape.contextName(serviceShape).toPascalCase()) - return symbolBuilder(shape, rustType).locatedIn(ModelsModule).build() + val (name, module) = getMemberNameAndModule(shape, serviceShape, ServerRustModule.Model, !publicConstrainedTypes) + val rustType = RustType.Opaque(name) + return symbolBuilder(shape, rustType).locatedIn(module).build() } override fun toSymbol(shape: Shape): Symbol { @@ -75,8 +82,14 @@ class ConstrainedShapeSymbolProvider( val target = model.expectShape(shape.target) val targetSymbol = this.toSymbol(target) // Handle boxing first, so we end up with `Option>`, not `Box>`. - handleOptionality(handleRustBoxing(targetSymbol, shape), shape, nullableIndex, base.config().nullabilityCheckMode) + handleOptionality( + handleRustBoxing(targetSymbol, shape), + shape, + nullableIndex, + base.config.nullabilityCheckMode, + ) } + is MapShape -> { if (shape.isDirectlyConstrained(base)) { check(shape.hasTrait()) { @@ -92,6 +105,7 @@ class ConstrainedShapeSymbolProvider( .build() } } + is CollectionShape -> { if (shape.isDirectlyConstrained(base)) { check(constrainedCollectionCheck(shape)) { @@ -106,8 +120,11 @@ class ConstrainedShapeSymbolProvider( is StringShape, is IntegerShape, is ShortShape, is LongShape, is ByteShape, is BlobShape -> { if (shape.isDirectlyConstrained(base)) { - val rustType = RustType.Opaque(shape.contextName(serviceShape).toPascalCase()) - symbolBuilder(shape, rustType).locatedIn(ModelsModule).build() + // A standalone constrained shape goes into `ModelsModule`, but one + // arising from a constrained member shape goes into a module for the container. + val (name, module) = getMemberNameAndModule(shape, serviceShape, ServerRustModule.Model, !publicConstrainedTypes) + val rustType = RustType.Opaque(name) + symbolBuilder(shape, rustType).locatedIn(module).build() } else { base.toSymbol(shape) } @@ -123,9 +140,51 @@ class ConstrainedShapeSymbolProvider( * - That it has no unsupported constraints applied. */ private fun constrainedCollectionCheck(shape: CollectionShape): Boolean { - val supportedConstraintTraits = supportedCollectionConstraintTraits.mapNotNull { shape.getTrait(it).orNull() }.toSet() + val supportedConstraintTraits = + supportedCollectionConstraintTraits.mapNotNull { shape.getTrait(it).orNull() }.toSet() val allConstraintTraits = allConstraintTraits.mapNotNull { shape.getTrait(it).orNull() }.toSet() - return supportedConstraintTraits.isNotEmpty() && allConstraintTraits.subtract(supportedConstraintTraits).isEmpty() + return supportedConstraintTraits.isNotEmpty() && allConstraintTraits.subtract(supportedConstraintTraits) + .isEmpty() + } + + /** + * Returns the pair (Rust Symbol Name, Inline Module) for the shape. At the time of model transformation all + * constrained member shapes are extracted and are given a model-wide unique name. However, the generated code + * for the new shapes is in a module that is named after the containing shape (structure, list, map or union). + * The new shape's Rust Symbol is renamed from `{structureName}{memberName}` to `{structure_name}::{member_name}` + */ + private fun getMemberNameAndModule( + shape: Shape, + serviceShape: ServiceShape, + defaultModule: RustModule.LeafModule, + pubCrateServerBuilder: Boolean, + ): Pair { + val syntheticMemberTrait = shape.getTrait() + ?: return Pair(shape.contextName(serviceShape), defaultModule) + + return if (syntheticMemberTrait.container is StructureShape) { + val builderModule = syntheticMemberTrait.container.serverBuilderModule(base, pubCrateServerBuilder) + val renameTo = syntheticMemberTrait.member.memberName ?: syntheticMemberTrait.member.id.name + Pair(renameTo.toPascalCase(), builderModule) + } else { + // For non-structure shapes, the new shape defined for a constrained member shape + // needs to be placed in an inline module named `pub {container_name_in_snake_case}`. + val moduleName = RustReservedWords.escapeIfNeeded(syntheticMemberTrait.container.id.name.toSnakeCase()) + val innerModuleName = moduleName + if (pubCrateServerBuilder) { + "_internal" + } else { + "" + } + + val innerModule = RustModule.new( + innerModuleName, + visibility = Visibility.publicIf(!pubCrateServerBuilder, Visibility.PUBCRATE), + parent = defaultModule, + inline = true, + ) + val renameTo = syntheticMemberTrait.member.memberName ?: syntheticMemberTrait.member.id.name + Pair(renameTo.toPascalCase(), innerModule) + } } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintViolationSymbolProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintViolationSymbolProvider.kt index 845150c5cc1..969b95a6f88 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintViolationSymbolProvider.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintViolationSymbolProvider.kt @@ -6,7 +6,6 @@ package software.amazon.smithy.rust.codegen.server.smithy import software.amazon.smithy.codegen.core.Symbol -import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.ByteShape import software.amazon.smithy.model.shapes.CollectionShape @@ -23,15 +22,16 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.Visibility -import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.contextName import software.amazon.smithy.rust.codegen.core.smithy.locatedIn import software.amazon.smithy.rust.codegen.core.smithy.module import software.amazon.smithy.rust.codegen.core.smithy.rustType +import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderSymbol +import software.amazon.smithy.rust.codegen.server.smithy.traits.SyntheticStructureFromConstrainedMemberTrait /** * The [ConstraintViolationSymbolProvider] returns, for a given constrained @@ -68,7 +68,6 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilde */ class ConstraintViolationSymbolProvider( private val base: RustSymbolProvider, - private val model: Model, private val publicConstrainedTypes: Boolean, private val serviceShape: ServiceShape, ) : WrappingSymbolProvider(base) { @@ -80,15 +79,29 @@ class ConstraintViolationSymbolProvider( private fun Shape.shapeModule(): RustModule.LeafModule { val documentation = if (publicConstrainedTypes && this.isDirectlyConstrained(base)) { - "See [`${this.contextName(serviceShape)}`]." + val symbol = base.toSymbol(this) + "See [`${this.contextName(serviceShape)}`]($symbol)." } else { null } - return RustModule.new( + + val syntheticTrait = getTrait() + + val (module, name) = if (syntheticTrait != null) { + // For constrained member shapes, the ConstraintViolation code needs to go in an inline rust module + // that is a descendant of the module that contains the extracted shape itself. + val overriddenMemberModule = this.getParentAndInlineModuleForConstrainedMember(base, publicConstrainedTypes)!! + val name = syntheticTrait.member.memberName + Pair(overriddenMemberModule.second, RustReservedWords.escapeIfNeeded(name).toSnakeCase()) + } else { // Need to use the context name so we get the correct name for maps. - name = RustReservedWords.escapeIfNeeded(this.contextName(serviceShape)).toSnakeCase(), + Pair(ServerRustModule.Model, RustReservedWords.escapeIfNeeded(this.contextName(serviceShape)).toSnakeCase()) + } + + return RustModule.new( + name = name, visibility = visibility, - parent = ModelsModule, + parent = module, inline = true, documentation = documentation, ) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/Constraints.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/Constraints.kt index f6cc943a9cb..4bac39c2df8 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/Constraints.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/Constraints.kt @@ -26,11 +26,19 @@ import software.amazon.smithy.model.traits.PatternTrait import software.amazon.smithy.model.traits.RangeTrait import software.amazon.smithy.model.traits.RequiredTrait import software.amazon.smithy.model.traits.UniqueItemsTrait +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.DirectedWalker import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.core.smithy.module import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE +import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.toSnakeCase +import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderModule +import software.amazon.smithy.rust.codegen.server.smithy.traits.SyntheticStructureFromConstrainedMemberTrait /** * This file contains utilities to work with constrained shapes. @@ -83,7 +91,7 @@ fun Shape.isDirectlyConstrained(symbolProvider: SymbolProvider): Boolean = when // The only reason why the functions in this file have // to take in a `SymbolProvider` is because non-`required` blob streaming members are interpreted as // `required`, so we can't use `member.isOptional` here. - this.members().map { symbolProvider.toSymbol(it) }.any { !it.isOptional() } + this.members().any { !symbolProvider.toSymbol(it).isOptional() && !it.hasNonNullDefault() } } is MapShape -> this.hasTrait() @@ -160,3 +168,47 @@ fun Shape.typeNameContainsNonPublicType( is StructureShape, is UnionShape -> false else -> UNREACHABLE("the above arms should be exhaustive, but we received shape: $this") } + +/** + * For synthetic shapes that are added to the model because of member constrained shapes, it returns + * the "container" and "the member shape" that originally had the constraint trait. For all other + * shapes, it returns null. + */ +fun Shape.overriddenConstrainedMemberInfo(): Pair? { + val trait = getTrait() ?: return null + return Pair(trait.container, trait.member) +} + +/** + * Returns the parent and the inline module that this particular shape should go in. + */ +fun Shape.getParentAndInlineModuleForConstrainedMember(symbolProvider: SymbolProvider, publicConstrainedTypes: Boolean): Pair? { + val overriddenTrait = getTrait() ?: return null + return if (overriddenTrait.container is StructureShape) { + val structureModule = symbolProvider.toSymbol(overriddenTrait.container).module() + val builderModule = overriddenTrait.container.serverBuilderModule(symbolProvider, !publicConstrainedTypes) + Pair(structureModule, builderModule) + } else { + // For constrained member shapes, the ConstraintViolation code needs to go in an inline rust module + // that is a descendant of the module that contains the extracted shape itself. + return if (publicConstrainedTypes) { + // Non-structured shape types need to go into their own module. + val shapeSymbol = symbolProvider.toSymbol(this) + val shapeModule = shapeSymbol.module() + check(!shapeModule.parent.isInline()) { + "Parent module of $id should not be an inline module." + } + Pair(shapeModule.parent as RustModule.LeafModule, shapeModule) + } else { + val name = RustReservedWords.escapeIfNeeded(overriddenTrait.container.id.name).toSnakeCase() + "_internal" + val innerModule = RustModule.new( + name = name, + visibility = Visibility.PUBCRATE, + parent = ServerRustModule.Model, + inline = true, + ) + + Pair(ServerRustModule.Model, innerModule) + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/DeriveEqAndHashSymbolMetadataProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/DeriveEqAndHashSymbolMetadataProvider.kt index 5438447ed5b..d3f70592714 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/DeriveEqAndHashSymbolMetadataProvider.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/DeriveEqAndHashSymbolMetadataProvider.kt @@ -5,7 +5,6 @@ package software.amazon.smithy.rust.codegen.server.smithy -import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.DocumentShape import software.amazon.smithy.model.shapes.DoubleShape @@ -49,7 +48,6 @@ import software.amazon.smithy.rust.codegen.core.util.hasTrait */ class DeriveEqAndHashSymbolMetadataProvider( private val base: RustSymbolProvider, - val model: Model, ) : SymbolMetadataProvider(base) { private val walker = DirectedWalker(model) @@ -58,7 +56,7 @@ class DeriveEqAndHashSymbolMetadataProvider( val baseMetadata = base.toSymbol(shape).expectRustMetadata() // See class-level documentation for why we filter these out. return if (walker.walkShapes(shape) - .any { it is FloatShape || it is DoubleShape || it is DocumentShape || it.hasTrait() } + .any { it is FloatShape || it is DoubleShape || it is DocumentShape || it.hasTrait() } ) { baseMetadata } else { diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/LengthTraitValidationErrorMessage.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/LengthTraitValidationErrorMessage.kt index b15b2dc8f00..cb08de3c734 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/LengthTraitValidationErrorMessage.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/LengthTraitValidationErrorMessage.kt @@ -11,9 +11,11 @@ fun LengthTrait.validationErrorMessage(): String { val beginning = "Value with length {} at '{}' failed to satisfy constraint: Member must have length " val ending = if (this.min.isPresent && this.max.isPresent) { "between ${this.min.get()} and ${this.max.get()}, inclusive" - } else if (this.min.isPresent) ( - "greater than or equal to ${this.min.get()}" - ) else { + } else if (this.min.isPresent) { + ( + "greater than or equal to ${this.min.get()}" + ) + } else { check(this.max.isPresent) "less than or equal to ${this.max.get()}" } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PatternTraitValidationErrorMessage.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PatternTraitValidationErrorMessage.kt deleted file mode 100644 index 8bb3cb648e1..00000000000 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PatternTraitValidationErrorMessage.kt +++ /dev/null @@ -1,12 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.rust.codegen.server.smithy - -import software.amazon.smithy.model.traits.PatternTrait - -@Suppress("UnusedReceiverParameter") -fun PatternTrait.validationErrorMessage(): String = - "Value {} at '{}' failed to satisfy constraint: Member must satisfy regular expression pattern: {}" diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProvider.kt index 800dc6c7302..c64182f152d 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProvider.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProvider.kt @@ -6,7 +6,6 @@ package software.amazon.smithy.rust.codegen.server.smithy import software.amazon.smithy.codegen.core.Symbol -import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.NullableIndex import software.amazon.smithy.model.shapes.CollectionShape import software.amazon.smithy.model.shapes.MapShape @@ -20,7 +19,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.Visibility -import software.amazon.smithy.rust.codegen.core.smithy.ConstrainedModule import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.handleOptionality @@ -62,7 +60,6 @@ import software.amazon.smithy.rust.codegen.core.util.toSnakeCase */ class PubCrateConstrainedShapeSymbolProvider( private val base: RustSymbolProvider, - private val model: Model, private val serviceShape: ServiceShape, ) : WrappingSymbolProvider(base) { private val nullableIndex = NullableIndex.of(model) @@ -74,7 +71,7 @@ class PubCrateConstrainedShapeSymbolProvider( val module = RustModule.new( RustReservedWords.escapeIfNeeded(name.toSnakeCase()), visibility = Visibility.PUBCRATE, - parent = ConstrainedModule, + parent = ServerRustModule.ConstrainedModule, inline = true, ) val rustType = RustType.Opaque(name, module.fullyQualifiedPath()) @@ -110,7 +107,7 @@ class PubCrateConstrainedShapeSymbolProvider( handleRustBoxing(targetSymbol, shape), shape, nullableIndex, - base.config().nullabilityCheckMode, + base.config.nullabilityCheckMode, ) } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RangeTraitValidationErrorMessage.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RangeTraitValidationErrorMessage.kt index 5512da84704..20881503945 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RangeTraitValidationErrorMessage.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RangeTraitValidationErrorMessage.kt @@ -11,9 +11,11 @@ fun RangeTrait.validationErrorMessage(): String { val beginning = "Value {} at '{}' failed to satisfy constraint: Member must be " val ending = if (this.min.isPresent && this.max.isPresent) { "between ${this.min.get()} and ${this.max.get()}, inclusive" - } else if (this.min.isPresent) ( - "greater than or equal to ${this.min.get()}" - ) else { + } else if (this.min.isPresent) { + ( + "greater than or equal to ${this.min.get()}" + ) + } else { check(this.max.isPresent) "less than or equal to ${this.max.get()}" } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCrateInlineModuleComposingWriter.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCrateInlineModuleComposingWriter.kt new file mode 100644 index 00000000000..f4f352260eb --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCrateInlineModuleComposingWriter.kt @@ -0,0 +1,349 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.smithy.RustCrate +import software.amazon.smithy.rust.codegen.core.smithy.module +import java.util.concurrent.ConcurrentHashMap + +typealias DocWriter = () -> Any +typealias InlineModuleCreator = (Symbol, Writable) -> Unit + +/** + * Initializes RustCrate -> InnerModule data structure. + */ +fun RustCrate.initializeInlineModuleWriter(debugMode: Boolean): InnerModule = + crateToInlineModule + .getOrPut(this) { InnerModule(debugMode) } + +/** + * Returns the InnerModule for the given RustCrate + */ +fun RustCrate.getInlineModuleWriter(): InnerModule { + return crateToInlineModule.getOrPut(this) { InnerModule(false) } +} + +/** + * Returns a function that can be used to create an inline module writer. + */ +fun RustCrate.createInlineModuleCreator(): InlineModuleCreator { + return { symbol: Symbol, writable: Writable -> + this.getInlineModuleWriter().withInlineModuleHierarchyUsingCrate(this, symbol.module()) { + writable() + } + } +} + +/** + * If the passed in `shape` is a synthetic extracted shape resulting from a constrained struct member, + * the `Writable` is called using the structure's builder module. Otherwise, the `Writable` is called + * using the given `module`. + */ +fun RustCrate.withModuleOrWithStructureBuilderModule( + module: RustModule, + shape: Shape, + codegenContext: ServerCodegenContext, + codeWritable: Writable, +) { + // All structure constrained-member-shapes code is generated inside the structure builder's module. + val parentAndInlineModuleInfo = + shape.getParentAndInlineModuleForConstrainedMember(codegenContext.symbolProvider, codegenContext.settings.codegenConfig.publicConstrainedTypes) + if (parentAndInlineModuleInfo == null) { + this.withModule(module, codeWritable) + } else { + val (parent, inline) = parentAndInlineModuleInfo + val inlineWriter = this.getInlineModuleWriter() + + inlineWriter.withInlineModuleHierarchyUsingCrate(this, parent) { + inlineWriter.withInlineModuleHierarchy(this, inline) { + codeWritable(this) + } + } + } +} + +/** + * If the passed in `shape` is a synthetic extracted shape resulting from a constrained struct member, + * the `Writable` is called using the structure's builder module. Otherwise, the `Writable` is called + * using shape's `module`. + */ +fun RustCrate.useShapeWriterOrUseWithStructureBuilder( + shape: Shape, + codegenContext: ServerCodegenContext, + docWriter: DocWriter? = null, + writable: Writable, +) { + // All structure constrained-member-shapes code is generated inside the structure builder's module. + val parentAndInlineModuleInfo = + shape.getParentAndInlineModuleForConstrainedMember(codegenContext.symbolProvider, codegenContext.settings.codegenConfig.publicConstrainedTypes) + if (parentAndInlineModuleInfo == null) { + docWriter?.invoke() + this.useShapeWriter(shape, writable) + } else { + val (parent, inline) = parentAndInlineModuleInfo + val inlineWriter = this.getInlineModuleWriter() + + inlineWriter.withInlineModuleHierarchyUsingCrate(this, parent) { + inlineWriter.withInlineModuleHierarchy(this, inline) { + writable(this) + } + } + } +} + +fun RustCrate.renderInlineMemoryModules() { + val inlineModule = crateToInlineModule[this] + check(inlineModule != null) { + "InlineModule writer has not been registered for this crate" + } + inlineModule.render() +} + +/** + * Given a `RustWriter` calls the `Writable` using a `RustWriter` for the `inlineModule` + */ +fun RustCrate.withInMemoryInlineModule( + outerWriter: RustWriter, + inlineModule: RustModule.LeafModule, + docWriter: DocWriter?, + codeWritable: Writable, +) { + check(inlineModule.isInline()) { + "Module has to be an inline module for it to be used with the InlineModuleWriter" + } + this.getInlineModuleWriter().withInlineModuleHierarchy(outerWriter, inlineModule, docWriter) { + codeWritable(this) + } +} + +fun RustWriter.createTestInlineModuleCreator(): InlineModuleCreator { + return { symbol: Symbol, writable: Writable -> + this.withInlineModule(symbol.module()) { + writable() + } + } +} + +/** + * Maintains the `RustWriter` that has been created for a `RustModule.LeafModule`. + */ +private data class InlineModuleWithWriter(val inlineModule: RustModule.LeafModule, val writer: RustWriter) + +/** + * For each RustCrate a separate mapping of inline-module to `RustWriter` is maintained. + */ +private val crateToInlineModule: ConcurrentHashMap = + ConcurrentHashMap() + +class InnerModule(debugMode: Boolean) { + // Holds the root modules to start rendering the descendents from. + private val topLevelModuleWriters: ConcurrentHashMap = ConcurrentHashMap() + private val inlineModuleWriters: ConcurrentHashMap> = ConcurrentHashMap() + private val docWriters: ConcurrentHashMap> = ConcurrentHashMap() + private val writerCreator = RustWriter.factory(debugMode) + + // By default, when a RustWriter is rendered, it prints a comment on top + // indicating that it contains generated code and should not be manually edited. This comment + // appears on each descendent inline module. To remove those comments, each time an inline + // module is rendered, first `emptyLineCount` characters are removed from it. + private val emptyLineCount: Int = writerCreator + .apply("lines-it-always-writes.rs", "crate") + .toString() + .split("\n")[0] + .length + + fun withInlineModule(outerWriter: RustWriter, innerModule: RustModule.LeafModule, docWriter: DocWriter? = null, writable: Writable) { + if (docWriter != null) { + val moduleDocWriterList = docWriters.getOrPut(innerModule) { mutableListOf() } + moduleDocWriterList.add(docWriter) + } + writable(getWriter(outerWriter, innerModule)) + } + + /** + * Given a `RustCrate` and a `RustModule.LeafModule()`, it creates a writer to that module and calls the writable. + */ + fun withInlineModuleHierarchyUsingCrate(rustCrate: RustCrate, inlineModule: RustModule.LeafModule, docWriter: DocWriter? = null, writable: Writable) { + val hierarchy = getHierarchy(inlineModule).toMutableList() + check(!hierarchy.first().isInline()) { + "When adding a `RustModule.LeafModule` to the crate, the topmost module in the hierarchy cannot be an inline module." + } + // The last in the hierarchy is the one we will return the writer for. + val bottomMost = hierarchy.removeLast() + + // In case it is a top level module that has been passed (e.g. ModelsModule, OutputsModule) then + // register it with the topLevel writers and call the writable on it. Otherwise, go over the + // complete hierarchy, registering each of the inner modules and then call the `Writable` + // with the bottom most inline module that has been passed. + if (hierarchy.isNotEmpty()) { + val topMost = hierarchy.removeFirst() + + // Create an intermediate writer for all inner modules in the hierarchy. + rustCrate.withModule(topMost) { + var writer = this + hierarchy.forEach { + writer = getWriter(writer, it) + } + + withInlineModule(writer, bottomMost, docWriter, writable) + } + } else { + check(!bottomMost.isInline()) { + "There is only one module in the hierarchy, so it has to be non-inlined." + } + rustCrate.withModule(bottomMost) { + registerTopMostWriter(this) + writable(this) + } + } + } + + /** + * Given a `Writer` to a module and an inline `RustModule.LeafModule()`, it creates a writer to that module and calls the writable. + * It registers the complete hierarchy including the `outerWriter` if that is not already registrered. + */ + fun withInlineModuleHierarchy(outerWriter: RustWriter, inlineModule: RustModule.LeafModule, docWriter: DocWriter? = null, writable: Writable) { + val hierarchy = getHierarchy(inlineModule).toMutableList() + if (!hierarchy.first().isInline()) { + hierarchy.removeFirst() + } + check(hierarchy.isNotEmpty()) { + "An inline module should always have one parent besides itself." + } + + // The last in the hierarchy is the module under which the new inline module resides. + val bottomMost = hierarchy.removeLast() + + // Create an entry in the HashMap for all the descendent modules in the hierarchy. + var writer = outerWriter + hierarchy.forEach { + writer = getWriter(writer, it) + } + + withInlineModule(writer, bottomMost, docWriter, writable) + } + + /** + * Creates an in memory writer and registers it with a map of RustWriter -> listOf(Inline descendent modules) + */ + private fun createNewInlineModule(): RustWriter { + val writer = writerCreator.apply("unknown-module-would-never-be-written.rs", "crate") + // Register the new RustWriter in the map to allow further descendent inline modules to be created inside it. + inlineModuleWriters[writer] = mutableListOf() + return writer + } + + /** + * Returns the complete hierarchy of a `RustModule.LeafModule` from top to bottom + */ + private fun getHierarchy(module: RustModule.LeafModule): List { + var current: RustModule = module + var hierarchy = listOf() + + while (current is RustModule.LeafModule) { + hierarchy = listOf(current) + hierarchy + current = current.parent + } + + return hierarchy + } + + /** + * Writes out each inline module's code (`toString`) to the respective top level `RustWriter`. + */ + fun render() { + var writerToAddDependencies: RustWriter? = null + + fun writeInlineCode(rustWriter: RustWriter, code: String) { + val inlineCode = code.drop(emptyLineCount) + rustWriter.writeWithNoFormatting(inlineCode) + } + + fun renderDescendents(topLevelWriter: RustWriter, inMemoryWriter: RustWriter) { + // Traverse all descendent inline modules and render them. + inlineModuleWriters[inMemoryWriter]!!.forEach { + writeDocs(it.inlineModule) + + topLevelWriter.withInlineModule(it.inlineModule) { + writeInlineCode(this, it.writer.toString()) + renderDescendents(this, it.writer) + } + + // Add dependencies introduced by the inline module to the top most RustWriter. + it.writer.dependencies.forEach { dep -> writerToAddDependencies!!.addDependency(dep) } + } + } + + // Go over all the top level modules, create an `inlineModule` on the `RustWriter` + // and call the descendent hierarchy renderer using the `inlineModule::RustWriter`. + topLevelModuleWriters.keys.forEach { + writerToAddDependencies = it + + check(inlineModuleWriters[it] != null) { + "There must be a registered RustWriter for this module." + } + + renderDescendents(it, it) + } + } + + /** + * Given the inline-module returns an existing `RustWriter`, or if that inline module + * has never been registered before then a new `RustWriter` is created and returned. + */ + private fun getWriter(outerWriter: RustWriter, inlineModule: RustModule.LeafModule): RustWriter { + val nestedModuleWriter = inlineModuleWriters[outerWriter] + if (nestedModuleWriter != null) { + return findOrAddToList(nestedModuleWriter, inlineModule) + } + + val inlineWriters = registerTopMostWriter(outerWriter) + return findOrAddToList(inlineWriters, inlineModule) + } + + /** + * Records the root of a dependency graph of inline modules. + */ + private fun registerTopMostWriter(outerWriter: RustWriter): MutableList { + topLevelModuleWriters[outerWriter] = Unit + return inlineModuleWriters.getOrPut(outerWriter) { mutableListOf() } + } + + /** + * Either gets a new `RustWriter` for the inline module or creates a new one and adds it to + * the list of inline modules. + */ + private fun findOrAddToList( + inlineModuleList: MutableList, + lookForModule: RustModule.LeafModule, + ): RustWriter { + val inlineModuleAndWriter = inlineModuleList.firstOrNull() { + it.inlineModule.name == lookForModule.name + } + return if (inlineModuleAndWriter == null) { + val inlineWriter = createNewInlineModule() + inlineModuleList.add(InlineModuleWithWriter(lookForModule, inlineWriter)) + inlineWriter + } else { + check(inlineModuleAndWriter.inlineModule == lookForModule) { + "The two inline modules have the same name but different attributes on them." + } + + inlineModuleAndWriter.writer + } + } + + private fun writeDocs(innerModule: RustModule.LeafModule) { + docWriters[innerModule]?.forEach { + it() + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCodegenServerPlugin.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustServerCodegenPlugin.kt similarity index 58% rename from codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCodegenServerPlugin.kt rename to codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustServerCodegenPlugin.kt index 8a1dc17e546..396ee17841c 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCodegenServerPlugin.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustServerCodegenPlugin.kt @@ -6,7 +6,6 @@ package software.amazon.smithy.rust.codegen.server.smithy import software.amazon.smithy.build.PluginContext -import software.amazon.smithy.build.SmithyBuildPlugin import software.amazon.smithy.codegen.core.ReservedWordSymbolProvider import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.ServiceShape @@ -14,12 +13,16 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWordSymbolP import software.amazon.smithy.rust.codegen.core.smithy.BaseSymbolMetadataProvider import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.EventStreamSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig import software.amazon.smithy.rust.codegen.core.smithy.StreamingShapeMetadataProvider import software.amazon.smithy.rust.codegen.core.smithy.StreamingShapeSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitor -import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig +import software.amazon.smithy.rust.codegen.server.smithy.customizations.CustomValidationExceptionWithReasonDecorator import software.amazon.smithy.rust.codegen.server.smithy.customizations.ServerRequiredCustomizations +import software.amazon.smithy.rust.codegen.server.smithy.customizations.SmithyValidationExceptionDecorator import software.amazon.smithy.rust.codegen.server.smithy.customize.CombinedServerCodegenDecorator +import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator +import software.amazon.smithy.rust.codegen.server.smithy.testutil.ServerDecoratableBuildPlugin import java.util.logging.Level import java.util.logging.Logger @@ -30,57 +33,62 @@ import java.util.logging.Logger * `resources/META-INF.services/software.amazon.smithy.build.SmithyBuildPlugin` refers to this class by name which * enables the smithy-build plugin to invoke `execute` with all Smithy plugin context + models. */ -class RustCodegenServerPlugin : SmithyBuildPlugin { +class RustServerCodegenPlugin : ServerDecoratableBuildPlugin() { private val logger = Logger.getLogger(javaClass.name) override fun getName(): String = "rust-server-codegen" - override fun execute(context: PluginContext) { - // Suppress extremely noisy logs about reserved words + /** + * See [software.amazon.smithy.rust.codegen.client.smithy.RustClientCodegenPlugin]. + */ + override fun executeWithDecorator( + context: PluginContext, + vararg decorator: ServerCodegenDecorator, + ) { Logger.getLogger(ReservedWordSymbolProvider::class.java.name).level = Level.OFF - // Discover [RustCodegenDecorators] on the classpath. [RustCodegenDecorator] returns different types of - // customizations. A customization is a function of: - // - location (e.g. the mutate section of an operation) - // - context (e.g. the of the operation) - // - writer: The active RustWriter at the given location - val codegenDecorator: CombinedServerCodegenDecorator = - CombinedServerCodegenDecorator.fromClasspath(context, ServerRequiredCustomizations()) - - // ServerCodegenVisitor is the main driver of code generation that traverses the model and generates code + val codegenDecorator = + CombinedServerCodegenDecorator.fromClasspath( + context, + ServerRequiredCustomizations(), + SmithyValidationExceptionDecorator(), + CustomValidationExceptionWithReasonDecorator(), + *decorator, + ) logger.info("Loaded plugin to generate pure Rust bindings for the server SDK") ServerCodegenVisitor(context, codegenDecorator).execute() } companion object { /** - * When generating code, smithy types need to be converted into Rust types—that is the core role of the symbol provider. - * - * The Symbol provider is composed of a base [SymbolVisitor] which handles the core functionality, then is layered - * with other symbol providers, documented inline, to handle the full scope of Smithy types. + * See [software.amazon.smithy.rust.codegen.client.smithy.RustClientCodegenPlugin]. */ fun baseSymbolProvider( + settings: ServerRustSettings, model: Model, serviceShape: ServiceShape, - symbolVisitorConfig: SymbolVisitorConfig, + rustSymbolProviderConfig: RustSymbolProviderConfig, constrainedTypes: Boolean = true, + includeConstrainedShapeProvider: Boolean = true, ) = - SymbolVisitor(model, serviceShape = serviceShape, config = symbolVisitorConfig) + SymbolVisitor(settings, model, serviceShape = serviceShape, config = rustSymbolProviderConfig) // Generate public constrained types for directly constrained shapes. - .let { if (constrainedTypes) ConstrainedShapeSymbolProvider(it, model, serviceShape) else it } + .let { + if (includeConstrainedShapeProvider) ConstrainedShapeSymbolProvider(it, serviceShape, constrainedTypes) else it + } // Generate different types for EventStream shapes (e.g. transcribe streaming) - .let { EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model, CodegenTarget.SERVER) } + .let { EventStreamSymbolProvider(rustSymbolProviderConfig.runtimeConfig, it, CodegenTarget.SERVER) } // Generate [ByteStream] instead of `Blob` for streaming binary shapes (e.g. S3 GetObject) - .let { StreamingShapeSymbolProvider(it, model) } + .let { StreamingShapeSymbolProvider(it) } // Add Rust attributes (like `#[derive(PartialEq)]`) to generated shapes - .let { BaseSymbolMetadataProvider(it, model, additionalAttributes = listOf()) } + .let { BaseSymbolMetadataProvider(it, additionalAttributes = listOf()) } // Constrained shapes generate newtypes that need the same derives we place on types generated from aggregate shapes. - .let { ConstrainedShapeSymbolMetadataProvider(it, model, constrainedTypes) } + .let { ConstrainedShapeSymbolMetadataProvider(it, constrainedTypes) } // Streaming shapes need different derives (e.g. they cannot derive `PartialEq`) - .let { StreamingShapeMetadataProvider(it, model) } + .let { StreamingShapeMetadataProvider(it) } // Derive `Eq` and `Hash` if possible. - .let { DeriveEqAndHashSymbolMetadataProvider(it, model) } + .let { DeriveEqAndHashSymbolMetadataProvider(it) } // Rename shapes that clash with Rust reserved words & and other SDK specific features e.g. `send()` cannot // be the name of an operation input - .let { RustReservedWordSymbolProvider(it, model) } + .let { RustReservedWordSymbolProvider(it) } } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenContext.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenContext.kt index a0ad38f04f6..e71ae7ded1c 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenContext.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenContext.kt @@ -13,7 +13,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider /** - * [ServerCodegenContext] contains code-generation context that is _specific_ to the [RustCodegenServerPlugin] plugin + * [ServerCodegenContext] contains code-generation context that is _specific_ to the [RustServerCodegenPlugin] plugin * from the `rust-codegen-server` subproject. * * It inherits from [CodegenContext], which contains code-generation context that is common to _all_ smithy-rs plugins. diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt index 607aad2158e..7c04c4904f2 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt @@ -16,6 +16,7 @@ import software.amazon.smithy.model.shapes.IntegerShape import software.amazon.smithy.model.shapes.ListShape import software.amazon.smithy.model.shapes.LongShape import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.NumberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.SetShape @@ -26,29 +27,26 @@ import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.EnumTrait +import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.model.traits.LengthTrait import software.amazon.smithy.model.transform.ModelTransformer -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.implBlock import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget -import software.amazon.smithy.rust.codegen.core.smithy.ConstrainedModule import software.amazon.smithy.rust.codegen.core.smithy.CoreRustSettings import software.amazon.smithy.rust.codegen.core.smithy.DirectedWalker -import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule import software.amazon.smithy.rust.codegen.core.smithy.RustCrate -import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig -import software.amazon.smithy.rust.codegen.core.smithy.UnconstrainedModule +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig +import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.error.eventStreamErrorSymbol -import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock +import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorImplGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolGeneratorFactory import software.amazon.smithy.rust.codegen.core.smithy.transformers.EventStreamNormalizer import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer -import software.amazon.smithy.rust.codegen.core.smithy.transformers.eventStreamErrors -import software.amazon.smithy.rust.codegen.core.smithy.transformers.operationErrors import software.amazon.smithy.rust.codegen.core.util.CommandFailed +import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.hasEventStreamMember import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.isEventStream @@ -74,11 +72,14 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerStruct import software.amazon.smithy.rust.codegen.server.smithy.generators.UnconstrainedCollectionGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.UnconstrainedMapGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.UnconstrainedUnionGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerProtocolLoader import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput import software.amazon.smithy.rust.codegen.server.smithy.transformers.AttachValidationExceptionToConstrainedOperationInputsInAllowList +import software.amazon.smithy.rust.codegen.server.smithy.transformers.ConstrainedMemberTransform +import software.amazon.smithy.rust.codegen.server.smithy.transformers.RecursiveConstraintViolationBoxer import software.amazon.smithy.rust.codegen.server.smithy.transformers.RemoveEbsModelValidationException import software.amazon.smithy.rust.codegen.server.smithy.transformers.ShapesReachableFromOperationInputTagger import java.util.logging.Logger @@ -101,14 +102,15 @@ open class ServerCodegenVisitor( protected var codegenContext: ServerCodegenContext protected var protocolGeneratorFactory: ProtocolGeneratorFactory protected var protocolGenerator: ServerProtocolGenerator + protected var validationExceptionConversionGenerator: ValidationExceptionConversionGenerator init { - val symbolVisitorConfig = - SymbolVisitorConfig( - runtimeConfig = settings.runtimeConfig, - renameExceptions = false, - nullabilityCheckMode = NullableIndex.CheckMode.SERVER, - ) + val rustSymbolProviderConfig = RustSymbolProviderConfig( + runtimeConfig = settings.runtimeConfig, + renameExceptions = false, + nullabilityCheckMode = NullableIndex.CheckMode.SERVER, + moduleProvider = ServerModuleProvider, + ) val baseModel = baselineTransform(context.model) val service = settings.getService(baseModel) @@ -125,11 +127,12 @@ open class ServerCodegenVisitor( model = codegenDecorator.transformModel(service, baseModel) val serverSymbolProviders = ServerSymbolProviders.from( + settings, model, service, - symbolVisitorConfig, + rustSymbolProviderConfig, settings.codegenConfig.publicConstrainedTypes, - RustCodegenServerPlugin::baseSymbolProvider, + RustServerCodegenPlugin::baseSymbolProvider, ) codegenContext = ServerCodegenContext( @@ -144,6 +147,9 @@ open class ServerCodegenVisitor( serverSymbolProviders.pubCrateConstrainedShapeSymbolProvider, ) + // We can use a not-null assertion because [CombinedServerCodegenDecorator] returns a not null value. + validationExceptionConversionGenerator = codegenDecorator.validationExceptionConversion(codegenContext)!! + rustCrate = RustCrate(context.fileManifest, codegenContext.symbolProvider, settings.codegenConfig) protocolGenerator = protocolGeneratorFactory.buildProtocolGenerator(codegenContext) } @@ -159,9 +165,14 @@ open class ServerCodegenVisitor( // Add errors attached at the service level to the models .let { ModelTransformer.create().copyServiceErrorsToOperations(it, settings.getService(it)) } // Add `Box` to recursive shapes as necessary - .let(RecursiveShapeBoxer::transform) + .let(RecursiveShapeBoxer()::transform) + // Add `Box` to recursive constraint violations as necessary + .let(RecursiveConstraintViolationBoxer::transform) // Normalize operations by adding synthetic input and output shapes to every operation .let(OperationNormalizer::transform) + // Transforms constrained member shapes into non-constrained member shapes targeting a new shape that + // has the member's constraints. + .let(ConstrainedMemberTransform::transform) // Remove the EBS model's own `ValidationException`, which collides with `smithy.framework#ValidationException` .let(RemoveEbsModelValidationException::transform) // Attach the `smithy.framework#ValidationException` error to operations whose inputs are constrained, @@ -195,10 +206,14 @@ open class ServerCodegenVisitor( "[rust-server-codegen] Generating Rust server for service $service, protocol ${codegenContext.protocol}", ) + val validationExceptionShapeId = validationExceptionConversionGenerator.shapeId for (validationResult in listOf( - validateOperationsWithConstrainedInputHaveValidationExceptionAttached( - model, - service, + codegenDecorator.postprocessValidationExceptionNotAttachedErrorMessage( + validateOperationsWithConstrainedInputHaveValidationExceptionAttached( + model, + service, + validationExceptionShapeId, + ), ), validateUnsupportedConstraints(model, service, codegenContext.settings.codegenConfig), )) { @@ -207,13 +222,18 @@ open class ServerCodegenVisitor( logger.log(logMessage.level, logMessage.message) } if (validationResult.shouldAbort) { - throw CodegenException("Unsupported constraints feature used; see error messages above for resolution") + throw CodegenException("Unsupported constraints feature used; see error messages above for resolution", validationResult) } } + rustCrate.initializeInlineModuleWriter(codegenContext.settings.codegenConfig.debugMode) + val serviceShapes = DirectedWalker(model).walkShapes(service) serviceShapes.forEach { it.accept(this) } codegenDecorator.extras(codegenContext, rustCrate) + + rustCrate.getInlineModuleWriter().render() + rustCrate.finalize( settings, model, @@ -249,7 +269,24 @@ open class ServerCodegenVisitor( override fun structureShape(shape: StructureShape) { logger.info("[rust-server-codegen] Generating a structure $shape") rustCrate.useShapeWriter(shape) { - StructureGenerator(model, codegenContext.symbolProvider, this, shape).render(CodegenTarget.SERVER) + StructureGenerator( + model, + codegenContext.symbolProvider, + this, + shape, + codegenDecorator.structureCustomizations(codegenContext, emptyList()), + ).render() + + shape.getTrait()?.also { errorTrait -> + ErrorImplGenerator( + model, + codegenContext.symbolProvider, + this, + shape, + errorTrait, + codegenDecorator.errorImplCustomizations(codegenContext, emptyList()), + ).render(CodegenTarget.SERVER) + } renderStructureShapeBuilder(shape, this) } @@ -260,11 +297,11 @@ open class ServerCodegenVisitor( writer: RustWriter, ) { if (codegenContext.settings.codegenConfig.publicConstrainedTypes || shape.isReachableFromOperationInput()) { - val serverBuilderGenerator = ServerBuilderGenerator(codegenContext, shape) - serverBuilderGenerator.render(writer) + val serverBuilderGenerator = ServerBuilderGenerator(codegenContext, shape, validationExceptionConversionGenerator) + serverBuilderGenerator.render(rustCrate, writer) if (codegenContext.settings.codegenConfig.publicConstrainedTypes) { - writer.implBlock(shape, codegenContext.symbolProvider) { + writer.implBlock(codegenContext.symbolProvider.toSymbol(shape)) { serverBuilderGenerator.renderConvenienceMethod(this) } } @@ -281,10 +318,10 @@ open class ServerCodegenVisitor( if (!codegenContext.settings.codegenConfig.publicConstrainedTypes) { val serverBuilderGeneratorWithoutPublicConstrainedTypes = - ServerBuilderGeneratorWithoutPublicConstrainedTypes(codegenContext, shape) - serverBuilderGeneratorWithoutPublicConstrainedTypes.render(writer) + ServerBuilderGeneratorWithoutPublicConstrainedTypes(codegenContext, shape, validationExceptionConversionGenerator) + serverBuilderGeneratorWithoutPublicConstrainedTypes.render(rustCrate, writer) - writer.implBlock(shape, codegenContext.symbolProvider) { + writer.implBlock(codegenContext.symbolProvider.toSymbol(shape)) { serverBuilderGeneratorWithoutPublicConstrainedTypes.renderConvenienceMethod(this) } } @@ -303,25 +340,29 @@ open class ServerCodegenVisitor( if (renderUnconstrainedList) { logger.info("[rust-server-codegen] Generating an unconstrained type for collection shape $shape") - rustCrate.withModule(UnconstrainedModule) { + rustCrate.withModuleOrWithStructureBuilderModule(ServerRustModule.UnconstrainedModule, shape, codegenContext) { UnconstrainedCollectionGenerator( codegenContext, - this, + rustCrate.createInlineModuleCreator(), shape, ).render() } if (!isDirectlyConstrained) { logger.info("[rust-server-codegen] Generating a constrained type for collection shape $shape") - rustCrate.withModule(ConstrainedModule) { - PubCrateConstrainedCollectionGenerator(codegenContext, this, shape).render() + rustCrate.withModuleOrWithStructureBuilderModule(ServerRustModule.ConstrainedModule, shape, codegenContext) { + PubCrateConstrainedCollectionGenerator( + codegenContext, + rustCrate.createInlineModuleCreator(), + shape, + ).render() } } } val constraintsInfo = CollectionTraitInfo.fromShape(shape, codegenContext.constrainedShapeSymbolProvider) if (isDirectlyConstrained) { - rustCrate.withModule(ModelsModule) { + rustCrate.withModuleOrWithStructureBuilderModule(ServerRustModule.Model, shape, codegenContext) { ConstrainedCollectionGenerator( codegenContext, this, @@ -333,8 +374,13 @@ open class ServerCodegenVisitor( } if (isDirectlyConstrained || renderUnconstrainedList) { - rustCrate.withModule(ModelsModule) { - CollectionConstraintViolationGenerator(codegenContext, this, shape, constraintsInfo).render() + rustCrate.withModuleOrWithStructureBuilderModule(ServerRustModule.Model, shape, codegenContext) { + CollectionConstraintViolationGenerator( + codegenContext, + rustCrate.createInlineModuleCreator(), + shape, constraintsInfo, + validationExceptionConversionGenerator, + ).render() } } } @@ -349,20 +395,28 @@ open class ServerCodegenVisitor( if (renderUnconstrainedMap) { logger.info("[rust-server-codegen] Generating an unconstrained type for map $shape") - rustCrate.withModule(UnconstrainedModule) { - UnconstrainedMapGenerator(codegenContext, this, shape).render() + rustCrate.withModuleOrWithStructureBuilderModule(ServerRustModule.UnconstrainedModule, shape, codegenContext) { + UnconstrainedMapGenerator( + codegenContext, + rustCrate.createInlineModuleCreator(), + shape, + ).render() } if (!isDirectlyConstrained) { logger.info("[rust-server-codegen] Generating a constrained type for map $shape") - rustCrate.withModule(ConstrainedModule) { - PubCrateConstrainedMapGenerator(codegenContext, this, shape).render() + rustCrate.withModuleOrWithStructureBuilderModule(ServerRustModule.ConstrainedModule, shape, codegenContext) { + PubCrateConstrainedMapGenerator( + codegenContext, + rustCrate.createInlineModuleCreator(), + shape, + ).render() } } } if (isDirectlyConstrained) { - rustCrate.withModule(ModelsModule) { + rustCrate.withModuleOrWithStructureBuilderModule(ServerRustModule.Model, shape, codegenContext) { ConstrainedMapGenerator( codegenContext, this, @@ -373,8 +427,13 @@ open class ServerCodegenVisitor( } if (isDirectlyConstrained || renderUnconstrainedMap) { - rustCrate.withModule(ModelsModule) { - MapConstraintViolationGenerator(codegenContext, this, shape).render() + rustCrate.withModuleOrWithStructureBuilderModule(ServerRustModule.Model, shape, codegenContext) { + MapConstraintViolationGenerator( + codegenContext, + rustCrate.createInlineModuleCreator(), + shape, + validationExceptionConversionGenerator, + ).render() } } } @@ -385,55 +444,37 @@ open class ServerCodegenVisitor( * Although raw strings require no code generation, enums are actually [EnumTrait] applied to string shapes. */ override fun stringShape(shape: StringShape) { - fun serverEnumGeneratorFactory(codegenContext: ServerCodegenContext, writer: RustWriter, shape: StringShape) = - ServerEnumGenerator(codegenContext, writer, shape) + fun serverEnumGeneratorFactory(codegenContext: ServerCodegenContext, shape: StringShape) = + ServerEnumGenerator(codegenContext, shape, validationExceptionConversionGenerator) stringShape(shape, ::serverEnumGeneratorFactory) } - override fun integerShape(shape: IntegerShape) { - if (shape.isDirectlyConstrained(codegenContext.symbolProvider)) { - logger.info("[rust-server-codegen] Generating a constrained integer $shape") - rustCrate.withModule(ModelsModule) { - ConstrainedNumberGenerator(codegenContext, this, shape).render() - } - } - } - - override fun shortShape(shape: ShortShape) { + override fun integerShape(shape: IntegerShape) = integralShape(shape) + override fun shortShape(shape: ShortShape) = integralShape(shape) + override fun longShape(shape: LongShape) = integralShape(shape) + override fun byteShape(shape: ByteShape) = integralShape(shape) + private fun integralShape(shape: NumberShape) { if (shape.isDirectlyConstrained(codegenContext.symbolProvider)) { - logger.info("[rust-server-codegen] Generating a constrained short $shape") - rustCrate.withModule(ModelsModule) { - ConstrainedNumberGenerator(codegenContext, this, shape).render() - } - } - } - - override fun longShape(shape: LongShape) { - if (shape.isDirectlyConstrained(codegenContext.symbolProvider)) { - logger.info("[rust-server-codegen] Generating a constrained long $shape") - rustCrate.withModule(ModelsModule) { - ConstrainedNumberGenerator(codegenContext, this, shape).render() - } - } - } - - override fun byteShape(shape: ByteShape) { - if (shape.isDirectlyConstrained(codegenContext.symbolProvider)) { - logger.info("[rust-server-codegen] Generating a constrained byte $shape") - rustCrate.withModule(ModelsModule) { - ConstrainedNumberGenerator(codegenContext, this, shape).render() + logger.info("[rust-server-codegen] Generating a constrained integral $shape") + rustCrate.withModuleOrWithStructureBuilderModule(ServerRustModule.Model, shape, codegenContext) { + ConstrainedNumberGenerator( + codegenContext, rustCrate.createInlineModuleCreator(), + this, + shape, + validationExceptionConversionGenerator, + ).render() } } } protected fun stringShape( shape: StringShape, - enumShapeGeneratorFactory: (codegenContext: ServerCodegenContext, writer: RustWriter, shape: StringShape) -> ServerEnumGenerator, + enumShapeGeneratorFactory: (codegenContext: ServerCodegenContext, shape: StringShape) -> EnumGenerator, ) { if (shape.hasTrait()) { logger.info("[rust-server-codegen] Generating an enum $shape") - rustCrate.useShapeWriter(shape) { - enumShapeGeneratorFactory(codegenContext, this, shape).render() + rustCrate.useShapeWriterOrUseWithStructureBuilder(shape, codegenContext) { + enumShapeGeneratorFactory(codegenContext, shape).render(this) ConstrainedTraitForEnumGenerator(model, codegenContext.symbolProvider, this, shape).render() } } @@ -449,8 +490,14 @@ open class ServerCodegenVisitor( ) } else if (!shape.hasTrait() && shape.isDirectlyConstrained(codegenContext.symbolProvider)) { logger.info("[rust-server-codegen] Generating a constrained string $shape") - rustCrate.withModule(ModelsModule) { - ConstrainedStringGenerator(codegenContext, this, shape).render() + rustCrate.withModuleOrWithStructureBuilderModule(ServerRustModule.Model, shape, codegenContext) { + ConstrainedStringGenerator( + codegenContext, + rustCrate.createInlineModuleCreator(), + this, + shape, + validationExceptionConversionGenerator, + ).render() } } } @@ -474,28 +521,19 @@ open class ServerCodegenVisitor( ) ) { logger.info("[rust-server-codegen] Generating an unconstrained type for union shape $shape") - rustCrate.withModule(UnconstrainedModule) unconstrainedModuleWriter@{ - rustCrate.withModule(ModelsModule) modelsModuleWriter@{ - UnconstrainedUnionGenerator( - codegenContext, - this@unconstrainedModuleWriter, - this@modelsModuleWriter, - shape, - ).render() - } + rustCrate.withModule(ServerRustModule.UnconstrainedModule) modelsModuleWriter@{ + UnconstrainedUnionGenerator( + codegenContext, + rustCrate.createInlineModuleCreator(), + this@modelsModuleWriter, + shape, + ).render() } } if (shape.isEventStream()) { - val errors = shape.eventStreamErrors() - .map { model.expectShape(it.asMemberShape().get().target, StructureShape::class.java) } - if (errors.isNotEmpty()) { - rustCrate.withModule(RustModule.Error) { - val symbol = codegenContext.symbolProvider.toSymbol(shape) - val errorSymbol = shape.eventStreamErrorSymbol(codegenContext.symbolProvider) - ServerOperationErrorGenerator(model, codegenContext.symbolProvider, symbol, errors) - .renderErrors(this, errorSymbol, symbol) - } + rustCrate.withModule(ServerRustModule.Error) { + ServerOperationErrorGenerator(model, codegenContext.symbolProvider, shape).render(this) } } } @@ -525,14 +563,8 @@ open class ServerCodegenVisitor( * Generate errors for operation shapes */ override fun operationShape(shape: OperationShape) { - rustCrate.withModule(RustModule.Error) { - val symbol = codegenContext.symbolProvider.toSymbol(shape) - ServerOperationErrorGenerator( - model, - codegenContext.symbolProvider, - symbol, - shape.operationErrors(model).map { it.asStructureShape().get() }, - ).render(this) + rustCrate.withModule(ServerRustModule.Error) { + ServerOperationErrorGenerator(model, codegenContext.symbolProvider, shape).render(this) } } @@ -543,8 +575,14 @@ open class ServerCodegenVisitor( } if (shape.isDirectlyConstrained(codegenContext.symbolProvider)) { - rustCrate.withModule(ModelsModule) { - ConstrainedBlobGenerator(codegenContext, this, shape).render() + rustCrate.withModuleOrWithStructureBuilderModule(ServerRustModule.Model, shape, codegenContext) { + ConstrainedBlobGenerator( + codegenContext, + rustCrate.createInlineModuleCreator(), + this, + shape, + validationExceptionConversionGenerator, + ).render() } } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRustModule.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRustModule.kt new file mode 100644 index 00000000000..bd784a6ece0 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRustModule.kt @@ -0,0 +1,77 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.ErrorTrait +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.smithy.ModuleProvider +import software.amazon.smithy.rust.codegen.core.smithy.ModuleProviderContext +import software.amazon.smithy.rust.codegen.core.smithy.module +import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait +import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.toSnakeCase + +object ServerRustModule { + val root = RustModule.LibRs + + val Error = RustModule.public("error", documentation = "All error types that operations can return. Documentation on these types is copied from the model.") + val Operation = RustModule.public("operation", documentation = "All operations that this crate can perform.") + val Model = RustModule.public("model", documentation = "Data structures used by operation inputs/outputs. Documentation on these types is copied from the model.") + val Input = RustModule.public("input", documentation = "Input structures for operations. Documentation on these types is copied from the model.") + val Output = RustModule.public("output", documentation = "Output structures for operations. Documentation on these types is copied from the model.") + val Types = RustModule.public("types", documentation = "Data primitives referenced by other data types.") + + val UnconstrainedModule = + software.amazon.smithy.rust.codegen.core.smithy.UnconstrainedModule + val ConstrainedModule = + software.amazon.smithy.rust.codegen.core.smithy.ConstrainedModule +} + +object ServerModuleProvider : ModuleProvider { + override fun moduleForShape(context: ModuleProviderContext, shape: Shape): RustModule.LeafModule = when (shape) { + is OperationShape -> ServerRustModule.Operation + is StructureShape -> when { + shape.hasTrait() -> ServerRustModule.Error + shape.hasTrait() -> ServerRustModule.Input + shape.hasTrait() -> ServerRustModule.Output + else -> ServerRustModule.Model + } + else -> ServerRustModule.Model + } + + override fun moduleForOperationError( + context: ModuleProviderContext, + operation: OperationShape, + ): RustModule.LeafModule = ServerRustModule.Error + + override fun moduleForEventStreamError( + context: ModuleProviderContext, + eventStream: UnionShape, + ): RustModule.LeafModule = ServerRustModule.Error + + override fun moduleForBuilder(context: ModuleProviderContext, shape: Shape, symbol: Symbol): RustModule.LeafModule { + val pubCrate = !(context.settings as ServerRustSettings).codegenConfig.publicConstrainedTypes + val builderNamespace = RustReservedWords.escapeIfNeeded(symbol.name.toSnakeCase()) + + if (pubCrate) { + "_internal" + } else { + "" + } + val visibility = when (pubCrate) { + true -> Visibility.PUBCRATE + false -> Visibility.PUBLIC + } + return RustModule.new(builderNamespace, visibility, parent = symbol.module(), inline = true) + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRustSettings.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRustSettings.kt index dbfc8356a29..67fbea91cd6 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRustSettings.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRustSettings.kt @@ -25,7 +25,7 @@ import java.util.Optional */ /** - * Settings used by [RustCodegenServerPlugin]. + * Settings used by [RustServerCodegenPlugin]. */ data class ServerRustSettings( override val service: ShapeId, @@ -83,12 +83,20 @@ data class ServerCodegenConfig( override val debugMode: Boolean = defaultDebugMode, val publicConstrainedTypes: Boolean = defaultPublicConstrainedTypes, val ignoreUnsupportedConstraints: Boolean = defaultIgnoreUnsupportedConstraints, + /** + * A flag to enable _experimental_ support for custom validation exceptions via the + * [CustomValidationExceptionWithReasonDecorator] decorator. + * TODO(https://github.com/awslabs/smithy-rs/pull/2053): this will go away once we implement the RFC, when users will be + * able to define the converters in their Rust application code. + */ + val experimentalCustomValidationExceptionWithReasonPleaseDoNotUse: String? = defaultExperimentalCustomValidationExceptionWithReasonPleaseDoNotUse, ) : CoreCodegenConfig( formatTimeoutSeconds, debugMode, ) { companion object { private const val defaultPublicConstrainedTypes = true private const val defaultIgnoreUnsupportedConstraints = false + private val defaultExperimentalCustomValidationExceptionWithReasonPleaseDoNotUse = null fun fromCodegenConfigAndNode(coreCodegenConfig: CoreCodegenConfig, node: Optional) = if (node.isPresent) { @@ -97,6 +105,7 @@ data class ServerCodegenConfig( debugMode = coreCodegenConfig.debugMode, publicConstrainedTypes = node.get().getBooleanMemberOrDefault("publicConstrainedTypes", defaultPublicConstrainedTypes), ignoreUnsupportedConstraints = node.get().getBooleanMemberOrDefault("ignoreUnsupportedConstraints", defaultIgnoreUnsupportedConstraints), + experimentalCustomValidationExceptionWithReasonPleaseDoNotUse = node.get().getStringMemberOrDefault("experimentalCustomValidationExceptionWithReasonPleaseDoNotUse", defaultExperimentalCustomValidationExceptionWithReasonPleaseDoNotUse), ) } else { ServerCodegenConfig( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerSymbolProviders.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerSymbolProviders.kt index 0e368d85175..675b72b0b38 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerSymbolProviders.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerSymbolProviders.kt @@ -8,7 +8,7 @@ package software.amazon.smithy.rust.codegen.server.smithy import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider -import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig /** * Just a handy class to centralize initialization all the symbol providers required by the server code generators, to @@ -24,38 +24,41 @@ class ServerSymbolProviders private constructor( ) { companion object { fun from( + settings: ServerRustSettings, model: Model, service: ServiceShape, - symbolVisitorConfig: SymbolVisitorConfig, + rustSymbolProviderConfig: RustSymbolProviderConfig, publicConstrainedTypes: Boolean, - baseSymbolProviderFactory: (model: Model, service: ServiceShape, symbolVisitorConfig: SymbolVisitorConfig, publicConstrainedTypes: Boolean) -> RustSymbolProvider, + baseSymbolProviderFactory: (settings: ServerRustSettings, model: Model, service: ServiceShape, rustSymbolProviderConfig: RustSymbolProviderConfig, publicConstrainedTypes: Boolean, includeConstraintShapeProvider: Boolean) -> RustSymbolProvider, ): ServerSymbolProviders { - val baseSymbolProvider = baseSymbolProviderFactory(model, service, symbolVisitorConfig, publicConstrainedTypes) + val baseSymbolProvider = baseSymbolProviderFactory(settings, model, service, rustSymbolProviderConfig, publicConstrainedTypes, publicConstrainedTypes) return ServerSymbolProviders( symbolProvider = baseSymbolProvider, constrainedShapeSymbolProvider = baseSymbolProviderFactory( + settings, model, service, - symbolVisitorConfig, + rustSymbolProviderConfig, + publicConstrainedTypes, true, ), unconstrainedShapeSymbolProvider = UnconstrainedShapeSymbolProvider( baseSymbolProviderFactory( + settings, model, service, - symbolVisitorConfig, + rustSymbolProviderConfig, + false, false, ), - model, publicConstrainedTypes, service, + publicConstrainedTypes, service, ), pubCrateConstrainedShapeSymbolProvider = PubCrateConstrainedShapeSymbolProvider( baseSymbolProvider, - model, service, ), constraintViolationSymbolProvider = ConstraintViolationSymbolProvider( baseSymbolProvider, - model, publicConstrainedTypes, service, ), diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/UnconstrainedShapeSymbolProvider.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/UnconstrainedShapeSymbolProvider.kt index 3da31293879..711f35e462f 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/UnconstrainedShapeSymbolProvider.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/UnconstrainedShapeSymbolProvider.kt @@ -6,7 +6,6 @@ package software.amazon.smithy.rust.codegen.server.smithy import software.amazon.smithy.codegen.core.Symbol -import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.NullableIndex import software.amazon.smithy.model.shapes.CollectionShape import software.amazon.smithy.model.shapes.MapShape @@ -22,7 +21,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.Visibility import software.amazon.smithy.rust.codegen.core.smithy.Default import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider -import software.amazon.smithy.rust.codegen.core.smithy.UnconstrainedModule import software.amazon.smithy.rust.codegen.core.smithy.WrappingSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.contextName import software.amazon.smithy.rust.codegen.core.smithy.handleOptionality @@ -78,7 +76,6 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilde */ class UnconstrainedShapeSymbolProvider( private val base: RustSymbolProvider, - private val model: Model, private val publicConstrainedTypes: Boolean, private val serviceShape: ServiceShape, ) : WrappingSymbolProvider(base) { @@ -101,10 +98,12 @@ class UnconstrainedShapeSymbolProvider( check(shape is CollectionShape || shape is MapShape || shape is UnionShape) val name = unconstrainedTypeNameForCollectionOrMapOrUnionShape(shape) + val parent = shape.getParentAndInlineModuleForConstrainedMember(this, publicConstrainedTypes)?.second ?: ServerRustModule.UnconstrainedModule + val module = RustModule.new( RustReservedWords.escapeIfNeeded(name.toSnakeCase()), visibility = Visibility.PUBCRATE, - parent = UnconstrainedModule, + parent = parent, inline = true, ) val rustType = RustType.Opaque(name, module.fullyQualifiedPath()) @@ -167,7 +166,7 @@ class UnconstrainedShapeSymbolProvider( handleRustBoxing(targetSymbol, shape), shape, nullableIndex, - base.config().nullabilityCheckMode, + base.config.nullabilityCheckMode, ) } else { base.toSymbol(shape) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraints.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraints.kt index 8bf6f928d9c..3a6cdabcd08 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraints.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraints.kt @@ -10,7 +10,9 @@ import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.ByteShape import software.amazon.smithy.model.shapes.EnumShape import software.amazon.smithy.model.shapes.IntegerShape +import software.amazon.smithy.model.shapes.ListShape import software.amazon.smithy.model.shapes.LongShape +import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape @@ -36,13 +38,17 @@ private sealed class UnsupportedConstraintMessageKind { private val constraintTraitsUberIssue = "https://github.com/awslabs/smithy-rs/issues/1401" fun intoLogMessage(ignoreUnsupportedConstraints: Boolean): LogMessage { - fun buildMessage(intro: String, willSupport: Boolean, trackingIssue: String, canBeIgnored: Boolean = true): String { + fun buildMessage(intro: String, willSupport: Boolean, trackingIssue: String? = null, canBeIgnored: Boolean = true): String { var msg = """ $intro This is not supported in the smithy-rs server SDK.""" if (willSupport) { msg += """ - It will be supported in the future. See the tracking issue ($trackingIssue).""" + It will be supported in the future.""" + } + if (trackingIssue != null) { + msg += """ + For more information, and to report if you're affected by this, please use the tracking issue: $trackingIssue.""" } if (canBeIgnored) { msg += """ @@ -106,6 +112,19 @@ private sealed class UnsupportedConstraintMessageKind { level, buildMessageShapeHasUnsupportedConstraintTrait(shape, uniqueItemsTrait, constraintTraitsUberIssue), ) + + is UnsupportedMapShapeReachableFromUniqueItemsList -> LogMessage( + Level.SEVERE, + buildMessage( + """ + The map shape `${mapShape.id}` is reachable from the list shape `${listShape.id}`, which has the + `@uniqueItems` trait attached. + """.trimIndent().replace("\n", " "), + willSupport = false, + trackingIssue = "https://github.com/awslabs/smithy/issues/1567", + canBeIgnored = false, + ), + ) } } } @@ -129,14 +148,25 @@ private data class UnsupportedRangeTraitOnShape(val shape: Shape, val rangeTrait private data class UnsupportedUniqueItemsTraitOnShape(val shape: Shape, val uniqueItemsTrait: UniqueItemsTrait) : UnsupportedConstraintMessageKind() +private data class UnsupportedMapShapeReachableFromUniqueItemsList( + val listShape: ListShape, + val uniqueItemsTrait: UniqueItemsTrait, + val mapShape: MapShape, +) : UnsupportedConstraintMessageKind() + data class LogMessage(val level: Level, val message: String) -data class ValidationResult(val shouldAbort: Boolean, val messages: List) +data class ValidationResult(val shouldAbort: Boolean, val messages: List) : + Throwable(message = messages.joinToString("\n") { it.message }) private val unsupportedConstraintsOnMemberShapes = allConstraintTraits - RequiredTrait::class.java +/** + * Validate that all constrained operations have the shape [validationExceptionShapeId] shape attached to their errors. + */ fun validateOperationsWithConstrainedInputHaveValidationExceptionAttached( model: Model, service: ServiceShape, + validationExceptionShapeId: ShapeId, ): ValidationResult { // Traverse the model and error out if an operation uses constrained input, but it does not have // `ValidationException` attached in `errors`. https://github.com/awslabs/smithy-rs/pull/1199#discussion_r809424783 @@ -151,7 +181,7 @@ fun validateOperationsWithConstrainedInputHaveValidationExceptionAttached( walker.walkShapes(operationShape.inputShape(model)) .any { it is SetShape || it is EnumShape || it.hasConstraintTrait() } } - .filter { !it.errors.contains(ShapeId.from("smithy.framework#ValidationException")) } + .filter { !it.errors.contains(validationExceptionShapeId) } .map { OperationWithConstrainedInputWithoutValidationException(it) } .toSet() @@ -167,11 +197,11 @@ fun validateOperationsWithConstrainedInputHaveValidationExceptionAttached( """ ```smithy - use smithy.framework#ValidationException + use $validationExceptionShapeId operation ${it.shape.id.name} { ... - errors: [..., ValidationException] // <-- Add this. + errors: [..., ${validationExceptionShapeId.name}] // <-- Add this. } ``` """.trimIndent(), @@ -189,18 +219,7 @@ fun validateUnsupportedConstraints( // Traverse the model and error out if: val walker = DirectedWalker(model) - // 1. Constraint traits on member shapes are used. [Constraint trait precedence] has not been implemented yet. - // TODO(https://github.com/awslabs/smithy-rs/issues/1401) - // [Constraint trait precedence]: https://awslabs.github.io/smithy/2.0/spec/model.html#applying-traits - val unsupportedConstraintOnMemberShapeSet = walker - .walkShapes(service) - .asSequence() - .filterIsInstance() - .filterMapShapesToTraits(unsupportedConstraintsOnMemberShapes) - .map { (shape, trait) -> UnsupportedConstraintOnMemberShape(shape as MemberShape, trait) } - .toSet() - - // 2. Constraint traits on streaming blob shapes are used. Their semantics are unclear. + // 1. Constraint traits on streaming blob shapes are used. Their semantics are unclear. // TODO(https://github.com/awslabs/smithy/issues/1389) val unsupportedLengthTraitOnStreamingBlobShapeSet = walker .walkShapes(service) @@ -210,7 +229,7 @@ fun validateUnsupportedConstraints( .map { UnsupportedLengthTraitOnStreamingBlobShape(it, it.expectTrait(), it.expectTrait()) } .toSet() - // 3. Constraint traits in event streams are used. Their semantics are unclear. + // 2. Constraint traits in event streams are used. Their semantics are unclear. // TODO(https://github.com/awslabs/smithy/issues/1388) val eventStreamShapes = walker .walkShapes(service) @@ -221,7 +240,9 @@ fun validateUnsupportedConstraints( .filterMapShapesToTraits(allConstraintTraits) .map { (shape, trait) -> UnsupportedConstraintOnShapeReachableViaAnEventStream(shape, trait) } .toSet() - val eventStreamErrors = eventStreamShapes.map { it.expectTrait() }.map { it.errorMembers } + val eventStreamErrors = eventStreamShapes.map { + it.expectTrait() + }.map { it.errorMembers } val unsupportedConstraintErrorShapeReachableViaAnEventStreamSet = eventStreamErrors .flatMap { it } .flatMap { walker.walkShapes(it) } @@ -231,7 +252,7 @@ fun validateUnsupportedConstraints( val unsupportedConstraintShapeReachableViaAnEventStreamSet = unsupportedConstraintOnNonErrorShapeReachableViaAnEventStreamSet + unsupportedConstraintErrorShapeReachableViaAnEventStreamSet - // 4. Range trait used on unsupported shapes. + // 3. Range trait used on unsupported shapes. // TODO(https://github.com/awslabs/smithy-rs/issues/1401) val unsupportedRangeTraitOnShapeSet = walker .walkShapes(service) @@ -241,11 +262,34 @@ fun validateUnsupportedConstraints( .map { (shape, rangeTrait) -> UnsupportedRangeTraitOnShape(shape, rangeTrait as RangeTrait) } .toSet() + // 5. `@uniqueItems` cannot reach a map shape. + // See https://github.com/awslabs/smithy/issues/1567. + val mapShapeReachableFromUniqueItemsListShapeSet = walker + .walkShapes(service) + .asSequence() + .filterMapShapesToTraits(setOf(UniqueItemsTrait::class.java)) + .flatMap { (listShape, uniqueItemsTrait) -> + walker.walkShapes(listShape).filterIsInstance().map { mapShape -> + UnsupportedMapShapeReachableFromUniqueItemsList( + listShape as ListShape, + uniqueItemsTrait as UniqueItemsTrait, + mapShape, + ) + } + } + .toSet() + val messages = - unsupportedConstraintOnMemberShapeSet.map { it.intoLogMessage(codegenConfig.ignoreUnsupportedConstraints) } + - unsupportedLengthTraitOnStreamingBlobShapeSet.map { it.intoLogMessage(codegenConfig.ignoreUnsupportedConstraints) } + - unsupportedConstraintShapeReachableViaAnEventStreamSet.map { it.intoLogMessage(codegenConfig.ignoreUnsupportedConstraints) } + - unsupportedRangeTraitOnShapeSet.map { it.intoLogMessage(codegenConfig.ignoreUnsupportedConstraints) } + unsupportedLengthTraitOnStreamingBlobShapeSet.map { + it.intoLogMessage(codegenConfig.ignoreUnsupportedConstraints) + } + + unsupportedConstraintShapeReachableViaAnEventStreamSet.map { + it.intoLogMessage(codegenConfig.ignoreUnsupportedConstraints) + } + + unsupportedRangeTraitOnShapeSet.map { it.intoLogMessage(codegenConfig.ignoreUnsupportedConstraints) } + + mapShapeReachableFromUniqueItemsListShapeSet.map { + it.intoLogMessage(codegenConfig.ignoreUnsupportedConstraints) + } return ValidationResult(shouldAbort = messages.any { it.level == Level.SEVERE }, messages) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/CustomValidationExceptionWithReasonDecorator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/CustomValidationExceptionWithReasonDecorator.kt new file mode 100644 index 00000000000..746ea98ca94 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/CustomValidationExceptionWithReasonDecorator.kt @@ -0,0 +1,314 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.customizations + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.traits.EnumTrait +import software.amazon.smithy.model.traits.LengthTrait +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.join +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.util.getTrait +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType +import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator +import software.amazon.smithy.rust.codegen.server.smithy.generators.BlobLength +import software.amazon.smithy.rust.codegen.server.smithy.generators.CollectionTraitInfo +import software.amazon.smithy.rust.codegen.server.smithy.generators.ConstraintViolation +import software.amazon.smithy.rust.codegen.server.smithy.generators.Length +import software.amazon.smithy.rust.codegen.server.smithy.generators.Pattern +import software.amazon.smithy.rust.codegen.server.smithy.generators.Range +import software.amazon.smithy.rust.codegen.server.smithy.generators.StringTraitInfo +import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.isKeyConstrained +import software.amazon.smithy.rust.codegen.server.smithy.generators.isValueConstrained +import software.amazon.smithy.rust.codegen.server.smithy.validationErrorMessage + +/** + * A decorator that adds code to convert from constraint violations to a custom `ValidationException` shape that is very + * similar to `smithy.framework#ValidationException`, with an additional `reason` field. + * + * The shape definition is in [CustomValidationExceptionWithReasonDecoratorTest]. + * + * This is just an example to showcase experimental support for custom validation exceptions. + * TODO(https://github.com/awslabs/smithy-rs/pull/2053): this will go away once we implement the RFC, when users will be + * able to define the converters in their Rust application code. + */ +class CustomValidationExceptionWithReasonDecorator : ServerCodegenDecorator { + override val name: String + get() = "CustomValidationExceptionWithReasonDecorator" + override val order: Byte + get() = -69 + + override fun validationExceptionConversion(codegenContext: ServerCodegenContext): + ValidationExceptionConversionGenerator? = + if (codegenContext.settings.codegenConfig.experimentalCustomValidationExceptionWithReasonPleaseDoNotUse != null) { + ValidationExceptionWithReasonConversionGenerator(codegenContext) + } else { + null + } +} + +class ValidationExceptionWithReasonConversionGenerator(private val codegenContext: ServerCodegenContext) : + ValidationExceptionConversionGenerator { + override val shapeId: ShapeId = + ShapeId.from(codegenContext.settings.codegenConfig.experimentalCustomValidationExceptionWithReasonPleaseDoNotUse) + + override fun renderImplFromConstraintViolationForRequestRejection(): Writable = writable { + val codegenScope = arrayOf( + "RequestRejection" to ServerRuntimeType.requestRejection(codegenContext.runtimeConfig), + "From" to RuntimeType.From, + ) + rustTemplate( + """ + impl #{From} for #{RequestRejection} { + fn from(constraint_violation: ConstraintViolation) -> Self { + let first_validation_exception_field = constraint_violation.as_validation_exception_field("".to_owned()); + let validation_exception = crate::error::ValidationException { + message: format!("1 validation error detected. {}", &first_validation_exception_field.message), + reason: crate::model::ValidationExceptionReason::FieldValidationFailed, + fields: Some(vec![first_validation_exception_field]), + }; + Self::ConstraintViolation( + crate::operation_ser::serialize_structure_crate_error_validation_exception(&validation_exception) + .expect("validation exceptions should never fail to serialize; please file a bug report under https://github.com/awslabs/smithy-rs/issues") + ) + } + } + """, + *codegenScope, + ) + } + + override fun stringShapeConstraintViolationImplBlock(stringConstraintsInfo: Collection): Writable = writable { + val validationExceptionFields = + stringConstraintsInfo.map { + writable { + when (it) { + is Pattern -> { + rustTemplate( + """ + Self::Pattern(string) => crate::model::ValidationExceptionField { + message: #{MessageWritable:W}, + name: path, + reason: crate::model::ValidationExceptionFieldReason::PatternNotValid, + }, + """, + "MessageWritable" to it.errorMessage(), + ) + } + is Length -> { + rust( + """ + Self::Length(length) => crate::model::ValidationExceptionField { + message: format!("${it.lengthTrait.validationErrorMessage()}", length, &path), + name: path, + reason: crate::model::ValidationExceptionFieldReason::LengthNotValid, + }, + """, + ) + } + } + } + }.join("\n") + + rustTemplate( + """ + pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { + match self { + #{ValidationExceptionFields:W} + } + } + """, + "String" to RuntimeType.String, + "ValidationExceptionFields" to validationExceptionFields, + ) + } + + override fun enumShapeConstraintViolationImplBlock(enumTrait: EnumTrait) = writable { + val enumValueSet = enumTrait.enumDefinitionValues.joinToString(", ") + val message = "Value {} at '{}' failed to satisfy constraint: Member must satisfy enum value set: [$enumValueSet]" + rustTemplate( + """ + pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { + crate::model::ValidationExceptionField { + message: format!(r##"$message"##, &self.0, &path), + name: path, + reason: crate::model::ValidationExceptionFieldReason::ValueNotValid, + } + } + """, + "String" to RuntimeType.String, + ) + } + + override fun numberShapeConstraintViolationImplBlock(rangeInfo: Range) = writable { + rustTemplate( + """ + pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { + match self { + Self::Range(value) => crate::model::ValidationExceptionField { + message: format!("${rangeInfo.rangeTrait.validationErrorMessage()}", value, &path), + name: path, + reason: crate::model::ValidationExceptionFieldReason::ValueNotValid, + } + } + } + """, + "String" to RuntimeType.String, + ) + } + + override fun blobShapeConstraintViolationImplBlock(blobConstraintsInfo: Collection) = writable { + val validationExceptionFields = + blobConstraintsInfo.map { + writable { + rust( + """ + Self::Length(length) => crate::model::ValidationExceptionField { + message: format!("${it.lengthTrait.validationErrorMessage()}", length, &path), + name: path, + reason: crate::model::ValidationExceptionFieldReason::LengthNotValid, + }, + """, + ) + } + }.join("\n") + + rustTemplate( + """ + pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { + match self { + #{ValidationExceptionFields:W} + } + } + """, + "String" to RuntimeType.String, + "ValidationExceptionFields" to validationExceptionFields, + ) + } + + override fun mapShapeConstraintViolationImplBlock( + shape: MapShape, + keyShape: StringShape, + valueShape: Shape, + symbolProvider: RustSymbolProvider, + model: Model, + ) = writable { + rustBlockTemplate( + "pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField", + "String" to RuntimeType.String, + ) { + rustBlock("match self") { + shape.getTrait()?.also { + rust( + """ + Self::Length(length) => crate::model::ValidationExceptionField { + message: format!("${it.validationErrorMessage()}", length, &path), + name: path, + reason: crate::model::ValidationExceptionFieldReason::LengthNotValid, + }, + """, + ) + } + if (isKeyConstrained(keyShape, symbolProvider)) { + rust("""Self::Key(key_constraint_violation) => key_constraint_violation.as_validation_exception_field(path),""") + } + if (isValueConstrained(valueShape, model, symbolProvider)) { + rust("""Self::Value(key, value_constraint_violation) => value_constraint_violation.as_validation_exception_field(path + "/" + key.as_str()),""") + } + } + } + } + + override fun builderConstraintViolationImplBlock(constraintViolations: Collection) = writable { + rustBlock("match self") { + constraintViolations.forEach { + if (it.hasInner()) { + rust("""ConstraintViolation::${it.name()}(inner) => inner.as_validation_exception_field(path + "/${it.forMember.memberName}"),""") + } else { + rust( + """ + ConstraintViolation::${it.name()} => crate::model::ValidationExceptionField { + message: format!("Value null at '{}/${it.forMember.memberName}' failed to satisfy constraint: Member must not be null", path), + name: path + "/${it.forMember.memberName}", + reason: crate::model::ValidationExceptionFieldReason::Other, + }, + """, + ) + } + } + } + } + + override fun collectionShapeConstraintViolationImplBlock( + collectionConstraintsInfo: + Collection, + isMemberConstrained: Boolean, + ) = writable { + val validationExceptionFields = collectionConstraintsInfo.map { + writable { + when (it) { + is CollectionTraitInfo.Length -> { + rust( + """ + Self::Length(length) => crate::model::ValidationExceptionField { + message: format!("${it.lengthTrait.validationErrorMessage()}", length, &path), + name: path, + reason: crate::model::ValidationExceptionFieldReason::LengthNotValid, + }, + """, + ) + } + is CollectionTraitInfo.UniqueItems -> { + rust( + """ + Self::UniqueItems { duplicate_indices, .. } => + crate::model::ValidationExceptionField { + message: format!("${it.uniqueItemsTrait.validationErrorMessage()}", &duplicate_indices, &path), + name: path, + reason: crate::model::ValidationExceptionFieldReason::ValueNotValid, + }, + """, + ) + } + } + } + }.toMutableList() + + if (isMemberConstrained) { + validationExceptionFields += { + rust( + """ + Self::Member(index, member_constraint_violation) => + member_constraint_violation.as_validation_exception_field(path + "/" + &index.to_string()) + """, + ) + } + } + rustTemplate( + """ + pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { + match self { + #{AsValidationExceptionFields:W} + } + } + """, + "String" to RuntimeType.String, + "AsValidationExceptionFields" to validationExceptionFields.join("\n"), + ) + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/ServerRequiredCustomizations.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/ServerRequiredCustomizations.kt index 90b3550b981..ceba38691bf 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/ServerRequiredCustomizations.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/ServerRequiredCustomizations.kt @@ -9,9 +9,11 @@ import software.amazon.smithy.rust.codegen.core.rustlang.Feature import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.customizations.AllowLintsCustomization import software.amazon.smithy.rust.codegen.core.smithy.customizations.CrateVersionCustomization -import software.amazon.smithy.rust.codegen.core.smithy.customizations.pubUseSmithyTypes +import software.amazon.smithy.rust.codegen.core.smithy.customizations.pubUseSmithyErrorTypes +import software.amazon.smithy.rust.codegen.core.smithy.customizations.pubUseSmithyPrimitives import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator /** @@ -29,12 +31,19 @@ class ServerRequiredCustomizations : ServerCodegenDecorator { codegenContext: ServerCodegenContext, baseCustomizations: List, ): List = - baseCustomizations + CrateVersionCustomization() + AllowLintsCustomization() + baseCustomizations + AllowLintsCustomization() override fun extras(codegenContext: ServerCodegenContext, rustCrate: RustCrate) { // Add rt-tokio feature for `ByteStream::from_path` rustCrate.mergeFeature(Feature("rt-tokio", true, listOf("aws-smithy-http/rt-tokio"))) - pubUseSmithyTypes(codegenContext.runtimeConfig, codegenContext.model, rustCrate) + rustCrate.withModule(ServerRustModule.Types) { + pubUseSmithyPrimitives(codegenContext, codegenContext.model)(this) + pubUseSmithyErrorTypes(codegenContext)(this) + } + + rustCrate.withModule(ServerRustModule.root) { + CrateVersionCustomization.extras(rustCrate, ServerRustModule.root) + } } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/SmithyValidationExceptionDecorator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/SmithyValidationExceptionDecorator.kt new file mode 100644 index 00000000000..42e2eac264e --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/SmithyValidationExceptionDecorator.kt @@ -0,0 +1,239 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.customizations + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.traits.EnumTrait +import software.amazon.smithy.model.traits.LengthTrait +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.join +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.util.getTrait +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType +import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator +import software.amazon.smithy.rust.codegen.server.smithy.generators.BlobLength +import software.amazon.smithy.rust.codegen.server.smithy.generators.CollectionTraitInfo +import software.amazon.smithy.rust.codegen.server.smithy.generators.ConstraintViolation +import software.amazon.smithy.rust.codegen.server.smithy.generators.Range +import software.amazon.smithy.rust.codegen.server.smithy.generators.StringTraitInfo +import software.amazon.smithy.rust.codegen.server.smithy.generators.TraitInfo +import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.isKeyConstrained +import software.amazon.smithy.rust.codegen.server.smithy.generators.isValueConstrained +import software.amazon.smithy.rust.codegen.server.smithy.validationErrorMessage + +/** + * A decorator that adds code to convert from constraint violations to Smithy's `smithy.framework#ValidationException`, + * defined in [0]. This is Smithy's recommended shape to return when validation fails. + * + * This decorator is always enabled when using the `rust-server-codegen` plugin. + * + * [0]: https://github.com/awslabs/smithy/tree/main/smithy-validation-model + * + * TODO(https://github.com/awslabs/smithy-rs/pull/2053): once the RFC is implemented, consider moving this back into the + * generators. + */ +class SmithyValidationExceptionDecorator : ServerCodegenDecorator { + override val name: String + get() = "SmithyValidationExceptionDecorator" + override val order: Byte + get() = 69 + + override fun validationExceptionConversion(codegenContext: ServerCodegenContext): ValidationExceptionConversionGenerator = + SmithyValidationExceptionConversionGenerator(codegenContext) +} + +class SmithyValidationExceptionConversionGenerator(private val codegenContext: ServerCodegenContext) : + ValidationExceptionConversionGenerator { + + // Define a companion object so that we can refer to this shape id globally. + companion object { + val SHAPE_ID: ShapeId = ShapeId.from("smithy.framework#ValidationException") + } + override val shapeId: ShapeId = SHAPE_ID + + override fun renderImplFromConstraintViolationForRequestRejection(): Writable = writable { + val codegenScope = arrayOf( + "RequestRejection" to ServerRuntimeType.requestRejection(codegenContext.runtimeConfig), + "From" to RuntimeType.From, + ) + rustTemplate( + """ + impl #{From} for #{RequestRejection} { + fn from(constraint_violation: ConstraintViolation) -> Self { + let first_validation_exception_field = constraint_violation.as_validation_exception_field("".to_owned()); + let validation_exception = crate::error::ValidationException { + message: format!("1 validation error detected. {}", &first_validation_exception_field.message), + field_list: Some(vec![first_validation_exception_field]), + }; + Self::ConstraintViolation( + crate::operation_ser::serialize_structure_crate_error_validation_exception(&validation_exception) + .expect("validation exceptions should never fail to serialize; please file a bug report under https://github.com/awslabs/smithy-rs/issues") + ) + } + } + """, + *codegenScope, + ) + } + + override fun stringShapeConstraintViolationImplBlock(stringConstraintsInfo: Collection): Writable = writable { + val constraintsInfo: List = stringConstraintsInfo.map(StringTraitInfo::toTraitInfo) + + rustTemplate( + """ + pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { + match self { + #{ValidationExceptionFields:W} + } + } + """, + "String" to RuntimeType.String, + "ValidationExceptionFields" to constraintsInfo.map { it.asValidationExceptionField }.join("\n"), + ) + } + + override fun blobShapeConstraintViolationImplBlock(blobConstraintsInfo: Collection): Writable = writable { + val constraintsInfo: List = blobConstraintsInfo.map(BlobLength::toTraitInfo) + + rustTemplate( + """ + pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { + match self { + #{ValidationExceptionFields:W} + } + } + """, + "String" to RuntimeType.String, + "ValidationExceptionFields" to constraintsInfo.map { it.asValidationExceptionField }.join("\n"), + ) + } + + override fun mapShapeConstraintViolationImplBlock( + shape: MapShape, + keyShape: StringShape, + valueShape: Shape, + symbolProvider: RustSymbolProvider, + model: Model, + ) = writable { + rustBlockTemplate( + "pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField", + "String" to RuntimeType.String, + ) { + rustBlock("match self") { + shape.getTrait()?.also { + rust( + """ + Self::Length(length) => crate::model::ValidationExceptionField { + message: format!("${it.validationErrorMessage()}", length, &path), + path, + },""", + ) + } + if (isKeyConstrained(keyShape, symbolProvider)) { + // Note how we _do not_ append the key's member name to the path. This is intentional, as + // per the `RestJsonMalformedLengthMapKey` test. Note keys are always strings. + // https://github.com/awslabs/smithy/blob/ee0b4ff90daaaa5101f32da936c25af8c91cc6e9/smithy-aws-protocol-tests/model/restJson1/validation/malformed-length.smithy#L296-L295 + rust("""Self::Key(key_constraint_violation) => key_constraint_violation.as_validation_exception_field(path),""") + } + if (isValueConstrained(valueShape, model, symbolProvider)) { + // `as_str()` works with regular `String`s and constrained string shapes. + rust("""Self::Value(key, value_constraint_violation) => value_constraint_violation.as_validation_exception_field(path + "/" + key.as_str()),""") + } + } + } + } + + override fun enumShapeConstraintViolationImplBlock(enumTrait: EnumTrait) = writable { + val enumValueSet = enumTrait.enumDefinitionValues.joinToString(", ") + val message = "Value {} at '{}' failed to satisfy constraint: Member must satisfy enum value set: [$enumValueSet]" + rustTemplate( + """ + pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { + crate::model::ValidationExceptionField { + message: format!(r##"$message"##, &self.0, &path), + path, + } + } + """, + "String" to RuntimeType.String, + ) + } + + override fun numberShapeConstraintViolationImplBlock(rangeInfo: Range) = writable { + rustTemplate( + """ + pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { + match self { + #{ValidationExceptionFields:W} + } + } + """, + "String" to RuntimeType.String, + "ValidationExceptionFields" to rangeInfo.toTraitInfo().asValidationExceptionField, + ) + } + + override fun builderConstraintViolationImplBlock(constraintViolations: Collection) = writable { + rustBlock("match self") { + constraintViolations.forEach { + if (it.hasInner()) { + rust("""ConstraintViolation::${it.name()}(inner) => inner.as_validation_exception_field(path + "/${it.forMember.memberName}"),""") + } else { + rust( + """ + ConstraintViolation::${it.name()} => crate::model::ValidationExceptionField { + message: format!("Value null at '{}/${it.forMember.memberName}' failed to satisfy constraint: Member must not be null", path), + path: path + "/${it.forMember.memberName}", + }, + """, + ) + } + } + } + } + + override fun collectionShapeConstraintViolationImplBlock( + collectionConstraintsInfo: + Collection, + isMemberConstrained: Boolean, + ) = writable { + val validationExceptionFields = collectionConstraintsInfo.map { + it.toTraitInfo().asValidationExceptionField + }.toMutableList() + if (isMemberConstrained) { + validationExceptionFields += { + rust( + """Self::Member(index, member_constraint_violation) => + member_constraint_violation.as_validation_exception_field(path + "/" + &index.to_string()) + """, + ) + } + } + rustTemplate( + """ + pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { + match self { + #{AsValidationExceptionFields:W} + } + } + """, + "String" to RuntimeType.String, + "AsValidationExceptionFields" to validationExceptionFields.join(""), + ) + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customize/ServerCodegenDecorator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customize/ServerCodegenDecorator.kt index b5b0f192955..8e771cb1222 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customize/ServerCodegenDecorator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customize/ServerCodegenDecorator.kt @@ -11,6 +11,8 @@ import software.amazon.smithy.rust.codegen.core.smithy.customize.CombinedCoreCod import software.amazon.smithy.rust.codegen.core.smithy.customize.CoreCodegenDecorator import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.ValidationResult +import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator import java.util.logging.Logger @@ -21,6 +23,13 @@ typealias ServerProtocolMap = ProtocolMap { fun protocols(serviceId: ShapeId, currentProtocols: ServerProtocolMap): ServerProtocolMap = currentProtocols + fun validationExceptionConversion(codegenContext: ServerCodegenContext): ValidationExceptionConversionGenerator? = null + + /** + * Injection point to allow a decorator to postprocess the error message that arises when an operation is + * constrained but the `ValidationException` shape is not attached to the operation's errors. + */ + fun postprocessValidationExceptionNotAttachedErrorMessage(validationResult: ValidationResult) = validationResult } /** @@ -28,9 +37,12 @@ interface ServerCodegenDecorator : CoreCodegenDecorator { * * This makes the actual concrete codegen simpler by not needing to deal with multiple separate decorators. */ -class CombinedServerCodegenDecorator(decorators: List) : +class CombinedServerCodegenDecorator(private val decorators: List) : CombinedCoreCodegenDecorator(decorators), ServerCodegenDecorator { + + private val orderedDecorators = decorators.sortedBy { it.order } + override val name: String get() = "CombinedServerCodegenDecorator" override val order: Byte @@ -41,6 +53,16 @@ class CombinedServerCodegenDecorator(decorators: List) : decorator.protocols(serviceId, protocolMap) } + override fun validationExceptionConversion(codegenContext: ServerCodegenContext): ValidationExceptionConversionGenerator = + // We use `firstNotNullOf` instead of `firstNotNullOfOrNull` because the [SmithyValidationExceptionDecorator] + // is registered. + orderedDecorators.firstNotNullOf { it.validationExceptionConversion(codegenContext) } + + override fun postprocessValidationExceptionNotAttachedErrorMessage(validationResult: ValidationResult): ValidationResult = + orderedDecorators.foldRight(validationResult) { decorator, accumulated -> + decorator.postprocessValidationExceptionNotAttachedErrorMessage(accumulated) + } + companion object { fun fromClasspath( context: PluginContext, diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/CollectionConstraintViolationGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/CollectionConstraintViolationGenerator.kt index 7867b045c42..e2a177f5360 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/CollectionConstraintViolationGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/CollectionConstraintViolationGenerator.kt @@ -6,23 +6,25 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.model.shapes.CollectionShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Visibility import software.amazon.smithy.rust.codegen.core.rustlang.join -import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.module +import software.amazon.smithy.rust.codegen.core.smithy.makeRustBoxed +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.letIf +import software.amazon.smithy.rust.codegen.server.smithy.InlineModuleCreator import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.traits.ConstraintViolationRustBoxTrait import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput class CollectionConstraintViolationGenerator( codegenContext: ServerCodegenContext, - private val modelsModuleWriter: RustWriter, + private val inlineModuleCreator: InlineModuleCreator, private val shape: CollectionShape, - private val constraintsInfo: List, + private val collectionConstraintsInfo: List, + private val validationExceptionConversionGenerator: ValidationExceptionConversionGenerator, ) { private val model = codegenContext.model private val symbolProvider = codegenContext.symbolProvider @@ -35,18 +37,25 @@ class CollectionConstraintViolationGenerator( PubCrateConstraintViolationSymbolProvider(this) } } + private val constraintsInfo: List = collectionConstraintsInfo.map { it.toTraitInfo() } fun render() { - val memberShape = model.expectShape(shape.member.target) + val targetShape = model.expectShape(shape.member.target) val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(shape) val constraintViolationName = constraintViolationSymbol.name - val isMemberConstrained = memberShape.canReachConstrainedShape(model, symbolProvider) + val isMemberConstrained = targetShape.canReachConstrainedShape(model, symbolProvider) val constraintViolationVisibility = Visibility.publicIf(publicConstrainedTypes, Visibility.PUBCRATE) - modelsModuleWriter.withInlineModule(constraintViolationSymbol.module()) { + inlineModuleCreator(constraintViolationSymbol) { val constraintViolationVariants = constraintsInfo.map { it.constraintViolationVariant }.toMutableList() if (isMemberConstrained) { constraintViolationVariants += { + val memberConstraintViolationSymbol = + constraintViolationSymbolProvider.toSymbol(targetShape).letIf( + shape.member.hasTrait(), + ) { + it.makeRustBoxed() + } rustTemplate( """ /// Constraint violation error when an element doesn't satisfy its own constraints. @@ -55,7 +64,7 @@ class CollectionConstraintViolationGenerator( ##[doc(hidden)] Member(usize, #{MemberConstraintViolationSymbol}) """, - "MemberConstraintViolationSymbol" to constraintViolationSymbolProvider.toSymbol(memberShape), + "MemberConstraintViolationSymbol" to memberConstraintViolationSymbol, ) } } @@ -66,6 +75,7 @@ class CollectionConstraintViolationGenerator( // and is for use by the framework. rustTemplate( """ + ##[allow(clippy::enum_variant_names)] ##[derive(Debug, PartialEq)] ${constraintViolationVisibility.toRustQualifier()} enum $constraintViolationName { #{ConstraintViolationVariants:W} @@ -75,30 +85,13 @@ class CollectionConstraintViolationGenerator( ) if (shape.isReachableFromOperationInput()) { - val validationExceptionFields = constraintsInfo.map { it.asValidationExceptionField }.toMutableList() - if (isMemberConstrained) { - validationExceptionFields += { - rust( - """ - Self::Member(index, member_constraint_violation) => - member_constraint_violation.as_validation_exception_field(path + "/" + &index.to_string()) - """, - ) - } - } - rustTemplate( """ impl $constraintViolationName { - pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { - match self { - #{AsValidationExceptionFields:W} - } - } + #{CollectionShapeConstraintViolationImplBlock} } """, - "String" to RuntimeType.String, - "AsValidationExceptionFields" to validationExceptionFields.join("\n"), + "CollectionShapeConstraintViolationImplBlock" to validationExceptionConversionGenerator.collectionShapeConstraintViolationImplBlock(collectionConstraintsInfo, isMemberConstrained), ) } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedBlobGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedBlobGenerator.kt index 41fec1cec75..5a1c3bc4e18 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedBlobGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedBlobGenerator.kt @@ -21,9 +21,9 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained -import software.amazon.smithy.rust.codegen.core.smithy.module import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.util.orNull +import software.amazon.smithy.rust.codegen.server.smithy.InlineModuleCreator import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput @@ -31,8 +31,10 @@ import software.amazon.smithy.rust.codegen.server.smithy.validationErrorMessage class ConstrainedBlobGenerator( val codegenContext: ServerCodegenContext, + private val inlineModuleCreator: InlineModuleCreator, val writer: RustWriter, val shape: BlobShape, + private val validationExceptionConversionGenerator: ValidationExceptionConversionGenerator, ) { val model = codegenContext.model val constrainedShapeSymbolProvider = codegenContext.constrainedShapeSymbolProvider @@ -45,9 +47,10 @@ class ConstrainedBlobGenerator( PubCrateConstraintViolationSymbolProvider(this) } } - private val constraintsInfo: List = listOf(LengthTrait::class.java) + private val blobConstraintsInfo: List = listOf(LengthTrait::class.java) .mapNotNull { shape.getTrait(it).orNull() } - .map { BlobLength(it).toTraitInfo() } + .map { BlobLength(it) } + private val constraintsInfo: List = blobConstraintsInfo.map { it.toTraitInfo() } fun render() { val symbol = constrainedShapeSymbolProvider.toSymbol(shape) @@ -108,7 +111,7 @@ class ConstrainedBlobGenerator( "From" to RuntimeType.From, ) - writer.withInlineModule(constraintViolation.module()) { + inlineModuleCreator(constraintViolation) { renderConstraintViolationEnum(this, shape, constraintViolation) } } @@ -128,21 +131,16 @@ class ConstrainedBlobGenerator( writer.rustTemplate( """ impl ${constraintViolation.name} { - pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { - match self { - #{ValidationExceptionFields:W} - } - } + #{BlobShapeConstraintViolationImplBlock} } """, - "String" to RuntimeType.String, - "ValidationExceptionFields" to constraintsInfo.map { it.asValidationExceptionField }.join("\n"), + "BlobShapeConstraintViolationImplBlock" to validationExceptionConversionGenerator.blobShapeConstraintViolationImplBlock(blobConstraintsInfo), ) } } } -private data class BlobLength(val lengthTrait: LengthTrait) { +data class BlobLength(val lengthTrait: LengthTrait) { fun toTraitInfo(): TraitInfo = TraitInfo( { rust("Self::check_length(&value)?;") }, { @@ -155,8 +153,7 @@ private data class BlobLength(val lengthTrait: LengthTrait) { Self::Length(length) => crate::model::ValidationExceptionField { message: format!("${lengthTrait.validationErrorMessage()}", length, &path), path, - }, - """, + },""", ) }, this::renderValidationFunction, diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedCollectionGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedCollectionGenerator.kt index 4463220a964..9b5775478e9 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedCollectionGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedCollectionGenerator.kt @@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.EnumShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.LengthTrait @@ -48,7 +49,7 @@ class ConstrainedCollectionGenerator( val codegenContext: ServerCodegenContext, val writer: RustWriter, val shape: CollectionShape, - private val constraintsInfo: List, + collectionConstraintsInfo: List, private val unconstrainedSymbol: Symbol? = null, ) { private val model = codegenContext.model @@ -63,6 +64,7 @@ class ConstrainedCollectionGenerator( } } private val symbolProvider = codegenContext.symbolProvider + private val constraintsInfo = collectionConstraintsInfo.map { it.toTraitInfo() } fun render() { check(constraintsInfo.isNotEmpty()) { @@ -114,7 +116,9 @@ class ConstrainedCollectionGenerator( #{ValidationFunctions:W} """, *codegenScope, - "ValidationFunctions" to constraintsInfo.map { it.validationFunctionDefinition(constraintViolation, inner) }.join("\n"), + "ValidationFunctions" to constraintsInfo.map { + it.validationFunctionDefinition(constraintViolation, inner) + }.join("\n"), ) } @@ -145,7 +149,8 @@ class ConstrainedCollectionGenerator( if (!publicConstrainedTypes && innerShape.canReachConstrainedShape(model, symbolProvider) && innerShape !is StructureShape && - innerShape !is UnionShape + innerShape !is UnionShape && + innerShape !is EnumShape ) { writer.rustTemplate( """ @@ -178,7 +183,7 @@ class ConstrainedCollectionGenerator( } } -internal sealed class CollectionTraitInfo { +sealed class CollectionTraitInfo { data class UniqueItems(val uniqueItemsTrait: UniqueItemsTrait, val memberSymbol: Symbol) : CollectionTraitInfo() { override fun toTraitInfo(): TraitInfo = TraitInfo( @@ -245,7 +250,7 @@ internal sealed class CollectionTraitInfo { // [1]: https://github.com/awslabs/smithy-typescript/blob/517c85f8baccf0e5334b4e66d8786bdb5791c595/smithy-typescript-ssdk-libs/server-common/src/validation/index.ts#L106-L111 rust( """ - Self::UniqueItems { duplicate_indices, .. } => + Self::UniqueItems { duplicate_indices, .. } => crate::model::ValidationExceptionField { message: format!("${uniqueItemsTrait.validationErrorMessage()}", &duplicate_indices, &path), path, @@ -365,11 +370,10 @@ internal sealed class CollectionTraitInfo { } } - fun fromShape(shape: CollectionShape, symbolProvider: SymbolProvider): List = + fun fromShape(shape: CollectionShape, symbolProvider: SymbolProvider): List = supportedCollectionConstraintTraits .mapNotNull { shape.getTrait(it).orNull() } .map { trait -> fromTrait(trait, shape, symbolProvider) } - .map(CollectionTraitInfo::toTraitInfo) } abstract fun toTraitInfo(): TraitInfo diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGenerator.kt index e5721e77414..28b0d9f8d75 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGenerator.kt @@ -7,6 +7,7 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.LengthTrait @@ -21,6 +22,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.typeNameContainsNonPublicType /** * [ConstrainedMapGenerator] generates a wrapper tuple newtype holding a constrained `std::collections::HashMap`. @@ -130,6 +132,14 @@ class ConstrainedMapGenerator( valueShape !is StructureShape && valueShape !is UnionShape ) { + val keyShape = model.expectShape(shape.key.target, StringShape::class.java) + val keyNeedsConversion = keyShape.typeNameContainsNonPublicType(model, symbolProvider, publicConstrainedTypes) + val key = if (keyNeedsConversion) { + "k.into()" + } else { + "k" + } + writer.rustTemplate( """ impl #{From}<$name> for #{FullyUnconstrainedSymbol} { @@ -137,7 +147,7 @@ class ConstrainedMapGenerator( value .into_inner() .into_iter() - .map(|(k, v)| (k, v.into())) + .map(|(k, v)| ($key, v.into())) .collect() } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedNumberGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedNumberGenerator.kt index 281f0005c11..45ef7dd0a14 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedNumberGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedNumberGenerator.kt @@ -21,16 +21,14 @@ import software.amazon.smithy.rust.codegen.core.rustlang.docs import software.amazon.smithy.rust.codegen.core.rustlang.documentShape import software.amazon.smithy.rust.codegen.core.rustlang.render import software.amazon.smithy.rust.codegen.core.rustlang.rust -import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock -import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained -import software.amazon.smithy.rust.codegen.core.smithy.module import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.redactIfNecessary +import software.amazon.smithy.rust.codegen.server.smithy.InlineModuleCreator import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput @@ -43,8 +41,10 @@ import software.amazon.smithy.rust.codegen.server.smithy.validationErrorMessage */ class ConstrainedNumberGenerator( val codegenContext: ServerCodegenContext, - val writer: RustWriter, + private val inlineModuleCreator: InlineModuleCreator, + private val writer: RustWriter, val shape: NumberShape, + private val validationExceptionConversionGenerator: ValidationExceptionConversionGenerator, ) { val model = codegenContext.model val constrainedShapeSymbolProvider = codegenContext.constrainedShapeSymbolProvider @@ -74,7 +74,8 @@ class ConstrainedNumberGenerator( val name = symbol.name val unconstrainedTypeName = unconstrainedType.render() val constraintViolation = constraintViolationSymbolProvider.toSymbol(shape) - val constraintsInfo = listOf(Range(rangeTrait).toTraitInfo(unconstrainedTypeName)) + val rangeInfo = Range(rangeTrait) + val constraintsInfo = listOf(rangeInfo.toTraitInfo()) writer.documentShape(shape, model) writer.docs(rustDocsConstrainedTypeEpilogue(name)) @@ -132,7 +133,7 @@ class ConstrainedNumberGenerator( writer.renderTryFrom(unconstrainedTypeName, name, constraintViolation, constraintsInfo) - writer.withInlineModule(constraintViolation.module()) { + inlineModuleCreator(constraintViolation) { rust( """ ##[derive(Debug, PartialEq)] @@ -143,35 +144,23 @@ class ConstrainedNumberGenerator( ) if (shape.isReachableFromOperationInput()) { - rustBlock("impl ${constraintViolation.name}") { - rustBlockTemplate( - "pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField", - "String" to RuntimeType.String, - ) { - rustBlock("match self") { - rust( - """ - Self::Range(value) => crate::model::ValidationExceptionField { - message: format!("${rangeTrait.validationErrorMessage()}", value, &path), - path, - }, - """, - ) - } + rustTemplate( + """ + impl ${constraintViolation.name} { + #{NumberShapeConstraintViolationImplBlock} } - } + """, + "NumberShapeConstraintViolationImplBlock" to validationExceptionConversionGenerator.numberShapeConstraintViolationImplBlock(rangeInfo), + ) } } } } -private data class Range(val rangeTrait: RangeTrait) { - fun toTraitInfo(unconstrainedTypeName: String): TraitInfo = TraitInfo( +data class Range(val rangeTrait: RangeTrait) { + fun toTraitInfo(): TraitInfo = TraitInfo( { rust("Self::check_range(value)?;") }, - { - docs("Error when a number doesn't satisfy its `@range` requirements.") - rust("Range($unconstrainedTypeName)") - }, + { docs("Error when a number doesn't satisfy its `@range` requirements.") }, { rust( """ diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGenerator.kt index 20a6746aa81..7d3fe751101 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGenerator.kt @@ -10,6 +10,7 @@ import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.traits.LengthTrait import software.amazon.smithy.model.traits.PatternTrait +import software.amazon.smithy.model.traits.SensitiveTrait import software.amazon.smithy.model.traits.Trait import software.amazon.smithy.rust.codegen.core.rustlang.Attribute import software.amazon.smithy.rust.codegen.core.rustlang.RustType @@ -22,15 +23,17 @@ import software.amazon.smithy.rust.codegen.core.rustlang.join import software.amazon.smithy.rust.codegen.core.rustlang.render import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained -import software.amazon.smithy.rust.codegen.core.smithy.module import software.amazon.smithy.rust.codegen.core.smithy.testModuleForShape import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.PANIC +import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.orNull import software.amazon.smithy.rust.codegen.core.util.redactIfNecessary +import software.amazon.smithy.rust.codegen.server.smithy.InlineModuleCreator import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext @@ -45,8 +48,10 @@ import software.amazon.smithy.rust.codegen.server.smithy.validationErrorMessage */ class ConstrainedStringGenerator( val codegenContext: ServerCodegenContext, - val writer: RustWriter, + private val inlineModuleCreator: InlineModuleCreator, + private val writer: RustWriter, val shape: StringShape, + private val validationExceptionConversionGenerator: ValidationExceptionConversionGenerator, ) { val model = codegenContext.model val constrainedShapeSymbolProvider = codegenContext.constrainedShapeSymbolProvider @@ -60,10 +65,12 @@ class ConstrainedStringGenerator( } } private val symbol = constrainedShapeSymbolProvider.toSymbol(shape) - private val constraintsInfo: List = + private val stringConstraintsInfo: List = supportedStringConstraintTraits .mapNotNull { shape.getTrait(it).orNull() } - .map { StringTraitInfo.fromTrait(symbol, it) } + .map { StringTraitInfo.fromTrait(symbol, it, isSensitive = shape.hasTrait()) } + private val constraintsInfo: List = + stringConstraintsInfo .map(StringTraitInfo::toTraitInfo) fun render() { @@ -133,7 +140,7 @@ class ConstrainedStringGenerator( "From" to RuntimeType.From, ) - writer.withInlineModule(constraintViolation.module()) { + inlineModuleCreator(constraintViolation) { renderConstraintViolationEnum(this, shape, constraintViolation) } @@ -155,15 +162,10 @@ class ConstrainedStringGenerator( writer.rustTemplate( """ impl ${constraintViolation.name} { - pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { - match self { - #{ValidationExceptionFields:W} - } - } + #{StringShapeConstraintViolationImplBlock:W} } """, - "String" to RuntimeType.String, - "ValidationExceptionFields" to constraintsInfo.map { it.asValidationExceptionField }.join("\n"), + "StringShapeConstraintViolationImplBlock" to validationExceptionConversionGenerator.stringShapeConstraintViolationImplBlock(stringConstraintsInfo), ) } } @@ -184,7 +186,7 @@ class ConstrainedStringGenerator( } } } -private data class Length(val lengthTrait: LengthTrait) : StringTraitInfo() { +data class Length(val lengthTrait: LengthTrait) : StringTraitInfo() { override fun toTraitInfo(): TraitInfo = TraitInfo( tryFromCheck = { rust("Self::check_length(&value)?;") }, constraintViolationVariant = { @@ -229,10 +231,8 @@ private data class Length(val lengthTrait: LengthTrait) : StringTraitInfo() { } } -private data class Pattern(val symbol: Symbol, val patternTrait: PatternTrait) : StringTraitInfo() { +data class Pattern(val symbol: Symbol, val patternTrait: PatternTrait, val isSensitive: Boolean) : StringTraitInfo() { override fun toTraitInfo(): TraitInfo { - val pattern = patternTrait.pattern - return TraitInfo( tryFromCheck = { rust("let value = Self::check_pattern(value)?;") }, constraintViolationVariant = { @@ -241,13 +241,15 @@ private data class Pattern(val symbol: Symbol, val patternTrait: PatternTrait) : rust("Pattern(String)") }, asValidationExceptionField = { - rust( + Attribute.AllowUnusedVariables.render(this) + rustTemplate( """ Self::Pattern(string) => crate::model::ValidationExceptionField { - message: format!("${patternTrait.validationErrorMessage()}", &string, &path, r##"$pattern"##), + message: #{ErrorMessage:W}, path }, """, + "ErrorMessage" to errorMessage(), ) }, this::renderValidationFunction, @@ -264,6 +266,28 @@ private data class Pattern(val symbol: Symbol, val patternTrait: PatternTrait) : ) } + fun errorMessage(): Writable { + val pattern = patternTrait.pattern + + return if (isSensitive) { + writable { + rust( + """ + format!("Value at '{}' failed to satisfy constraint: Member must satisfy regular expression pattern: {}", &path, r##"$pattern"##) + """, + ) + } + } else { + writable { + rust( + """ + format!("Value {} at '{}' failed to satisfy constraint: Member must satisfy regular expression pattern: {}", &string, &path, r##"$pattern"##) + """, + ) + } + } + } + /** * Renders a `check_pattern` function to validate the string matches the * supplied regex in the `@pattern` trait. @@ -301,16 +325,18 @@ private data class Pattern(val symbol: Symbol, val patternTrait: PatternTrait) : } } -private sealed class StringTraitInfo { +sealed class StringTraitInfo { companion object { - fun fromTrait(symbol: Symbol, trait: Trait) = + fun fromTrait(symbol: Symbol, trait: Trait, isSensitive: Boolean) = when (trait) { is PatternTrait -> { - Pattern(symbol, trait) + Pattern(symbol, trait, isSensitive) } + is LengthTrait -> { Length(trait) } + else -> PANIC("StringTraitInfo.fromTrait called with unsupported trait $trait") } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/DocHandlerGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/DocHandlerGenerator.kt index 759f8870883..a0dcf07ba9e 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/DocHandlerGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/DocHandlerGenerator.kt @@ -12,12 +12,11 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext -import software.amazon.smithy.rust.codegen.core.smithy.ErrorsModule -import software.amazon.smithy.rust.codegen.core.smithy.InputsModule -import software.amazon.smithy.rust.codegen.core.smithy.OutputsModule -import software.amazon.smithy.rust.codegen.core.smithy.generators.error.errorSymbol import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.outputShape +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule.Error as ErrorModule +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule.Input as InputModule +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule.Output as OutputModule /** * Generates a handler implementation stub for use within documentation. @@ -33,22 +32,22 @@ class DocHandlerGenerator( private val inputSymbol = symbolProvider.toSymbol(operation.inputShape(model)) private val outputSymbol = symbolProvider.toSymbol(operation.outputShape(model)) - private val errorSymbol = operation.errorSymbol(symbolProvider) + private val errorSymbol = symbolProvider.symbolForOperationError(operation) /** * Returns the function signature for an operation handler implementation. Used in the documentation. */ fun docSignature(): Writable { val outputT = if (operation.errors.isEmpty()) { - "${OutputsModule.name}::${outputSymbol.name}" + "${OutputModule.name}::${outputSymbol.name}" } else { - "Result<${OutputsModule.name}::${outputSymbol.name}, ${ErrorsModule.name}::${errorSymbol.name}>" + "Result<${OutputModule.name}::${outputSymbol.name}, ${ErrorModule.name}::${errorSymbol.name}>" } return writable { rust( """ - $commentToken async fn $handlerName(input: ${InputsModule.name}::${inputSymbol.name}) -> $outputT { + $commentToken async fn $handlerName(input: ${InputModule.name}::${inputSymbol.name}) -> $outputT { $commentToken todo!() $commentToken } """.trimIndent(), diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/MapConstraintViolationGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/MapConstraintViolationGenerator.kt index cfcf5a53e1e..065a4067c77 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/MapConstraintViolationGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/MapConstraintViolationGenerator.kt @@ -8,25 +8,22 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.traits.LengthTrait -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Visibility -import software.amazon.smithy.rust.codegen.core.rustlang.rust -import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock -import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.module -import software.amazon.smithy.rust.codegen.core.util.getTrait +import software.amazon.smithy.rust.codegen.core.smithy.makeRustBoxed import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.letIf +import software.amazon.smithy.rust.codegen.server.smithy.InlineModuleCreator import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.traits.ConstraintViolationRustBoxTrait import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput -import software.amazon.smithy.rust.codegen.server.smithy.validationErrorMessage class MapConstraintViolationGenerator( codegenContext: ServerCodegenContext, - private val modelsModuleWriter: RustWriter, + private val inlineModuleCreator: InlineModuleCreator, val shape: MapShape, + private val validationExceptionConversionGenerator: ValidationExceptionConversionGenerator, ) { private val model = codegenContext.model private val constrainedShapeSymbolProvider = codegenContext.constrainedShapeSymbolProvider @@ -52,7 +49,14 @@ class MapConstraintViolationGenerator( constraintViolationCodegenScopeMutableList.add("KeyConstraintViolationSymbol" to constraintViolationSymbolProvider.toSymbol(keyShape)) } if (isValueConstrained(valueShape, model, symbolProvider)) { - constraintViolationCodegenScopeMutableList.add("ValueConstraintViolationSymbol" to constraintViolationSymbolProvider.toSymbol(valueShape)) + constraintViolationCodegenScopeMutableList.add( + "ValueConstraintViolationSymbol" to + constraintViolationSymbolProvider.toSymbol(valueShape).letIf( + shape.value.hasTrait(), + ) { + it.makeRustBoxed() + }, + ) constraintViolationCodegenScopeMutableList.add("KeySymbol" to constrainedShapeSymbolProvider.toSymbol(keyShape)) } val constraintViolationCodegenScope = constraintViolationCodegenScopeMutableList.toTypedArray() @@ -62,13 +66,15 @@ class MapConstraintViolationGenerator( } else { Visibility.PUBCRATE } - modelsModuleWriter.withInlineModule(constraintViolationSymbol.module()) { + + inlineModuleCreator(constraintViolationSymbol) { // TODO(https://github.com/awslabs/smithy-rs/issues/1401) We should really have two `ConstraintViolation` // types here. One will just have variants for each constraint trait on the map shape, for use by the user. // The other one will have variants if the shape's key or value is directly or transitively constrained, // and is for use by the framework. rustTemplate( """ + ##[allow(clippy::enum_variant_names)] ##[derive(Debug, PartialEq)] pub${ if (constraintViolationVisibility == Visibility.PUBCRATE) " (crate) " else "" } enum $constraintViolationName { ${if (shape.hasTrait()) "Length(usize)," else ""} @@ -80,35 +86,20 @@ class MapConstraintViolationGenerator( ) if (shape.isReachableFromOperationInput()) { - rustBlock("impl $constraintViolationName") { - rustBlockTemplate( - "pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField", - "String" to RuntimeType.String, - ) { - rustBlock("match self") { - shape.getTrait()?.also { - rust( - """ - Self::Length(length) => crate::model::ValidationExceptionField { - message: format!("${it.validationErrorMessage()}", length, &path), - path, - }, - """, - ) - } - if (isKeyConstrained(keyShape, symbolProvider)) { - // Note how we _do not_ append the key's member name to the path. This is intentional, as - // per the `RestJsonMalformedLengthMapKey` test. Note keys are always strings. - // https://github.com/awslabs/smithy/blob/ee0b4ff90daaaa5101f32da936c25af8c91cc6e9/smithy-aws-protocol-tests/model/restJson1/validation/malformed-length.smithy#L296-L295 - rust("""Self::Key(key_constraint_violation) => key_constraint_violation.as_validation_exception_field(path),""") - } - if (isValueConstrained(valueShape, model, symbolProvider)) { - // `as_str()` works with regular `String`s and constrained string shapes. - rust("""Self::Value(key, value_constraint_violation) => value_constraint_violation.as_validation_exception_field(path + "/" + key.as_str()),""") - } - } + rustTemplate( + """ + impl $constraintViolationName { + #{MapShapeConstraintViolationImplBlock} } - } + """, + "MapShapeConstraintViolationImplBlock" to validationExceptionConversionGenerator.mapShapeConstraintViolationImplBlock( + shape, + keyShape, + valueShape, + symbolProvider, + model, + ), + ) } } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedCollectionGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedCollectionGenerator.kt index 09f9352cdec..1a563c25b98 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedCollectionGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedCollectionGenerator.kt @@ -7,7 +7,6 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.model.shapes.CollectionShape import software.amazon.smithy.model.shapes.MapShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock @@ -16,7 +15,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.isOptional -import software.amazon.smithy.rust.codegen.core.smithy.module +import software.amazon.smithy.rust.codegen.server.smithy.InlineModuleCreator import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape import software.amazon.smithy.rust.codegen.server.smithy.isDirectlyConstrained @@ -41,7 +40,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.typeNameContainsNonPubl */ class PubCrateConstrainedCollectionGenerator( val codegenContext: ServerCodegenContext, - val writer: RustWriter, + private val inlineModuleCreator: InlineModuleCreator, val shape: CollectionShape, ) { private val model = codegenContext.model @@ -74,7 +73,7 @@ class PubCrateConstrainedCollectionGenerator( "From" to RuntimeType.From, ) - writer.withInlineModule(constrainedSymbol.module()) { + inlineModuleCreator(constrainedSymbol) { rustTemplate( """ ##[derive(Debug, Clone)] @@ -109,11 +108,11 @@ class PubCrateConstrainedCollectionGenerator( impl #{From}<#{Symbol}> for $name { fn from(v: #{Symbol}) -> Self { ${ - if (innerNeedsConstraining) { - "Self(v.into_iter().map(|item| item.into()).collect())" - } else { - "Self(v)" - } + if (innerNeedsConstraining) { + "Self(v.into_iter().map(|item| item.into()).collect())" + } else { + "Self(v)" + } } } } @@ -121,11 +120,11 @@ class PubCrateConstrainedCollectionGenerator( impl #{From}<$name> for #{Symbol} { fn from(v: $name) -> Self { ${ - if (innerNeedsConstraining) { - "v.0.into_iter().map(|item| item.into()).collect()" - } else { - "v.0" - } + if (innerNeedsConstraining) { + "v.0.into_iter().map(|item| item.into()).collect()" + } else { + "v.0" + } } } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedMapGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedMapGenerator.kt index 9d5ad811253..838d4da0856 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedMapGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/PubCrateConstrainedMapGenerator.kt @@ -8,7 +8,6 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.model.shapes.CollectionShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.StringShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock @@ -17,7 +16,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.isOptional -import software.amazon.smithy.rust.codegen.core.smithy.module +import software.amazon.smithy.rust.codegen.server.smithy.InlineModuleCreator import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape import software.amazon.smithy.rust.codegen.server.smithy.isDirectlyConstrained @@ -40,7 +39,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.typeNameContainsNonPubl */ class PubCrateConstrainedMapGenerator( val codegenContext: ServerCodegenContext, - val writer: RustWriter, + private val inlineModuleCreator: InlineModuleCreator, val shape: MapShape, ) { private val model = codegenContext.model @@ -75,7 +74,7 @@ class PubCrateConstrainedMapGenerator( "From" to RuntimeType.From, ) - writer.withInlineModule(constrainedSymbol.module()) { + inlineModuleCreator(constrainedSymbol) { rustTemplate( """ ##[derive(Debug, Clone)] diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolations.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolations.kt index 55eae76c510..1ecb71d8785 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolations.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolations.kt @@ -17,17 +17,16 @@ import software.amazon.smithy.rust.codegen.core.rustlang.docs import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate -import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.makeRustBoxed -import software.amazon.smithy.rust.codegen.core.smithy.traits.RustBoxTrait import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.letIf import software.amazon.smithy.rust.codegen.core.util.toPascalCase import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.traits.ConstraintViolationRustBoxTrait /** * Renders constraint violation types that arise when building a structure shape builder. @@ -38,6 +37,7 @@ class ServerBuilderConstraintViolations( codegenContext: ServerCodegenContext, private val shape: StructureShape, private val builderTakesInUnconstrainedTypes: Boolean, + private val validationExceptionConversionGenerator: ValidationExceptionConversionGenerator, ) { private val model = codegenContext.model private val symbolProvider = codegenContext.symbolProvider @@ -71,7 +71,11 @@ class ServerBuilderConstraintViolations( writer.docs("Holds one variant for each of the ways the builder can fail.") if (nonExhaustive) Attribute.NonExhaustive.render(writer) val constraintViolationSymbolName = constraintViolationSymbolProvider.toSymbol(shape).name - writer.rustBlock("pub${if (visibility == Visibility.PUBCRATE) " (crate) " else ""} enum $constraintViolationSymbolName") { + writer.rustBlock( + """ + ##[allow(clippy::enum_variant_names)] + pub${if (visibility == Visibility.PUBCRATE) " (crate) " else ""} enum $constraintViolationSymbolName""", + ) { renderConstraintViolations(writer) } @@ -129,7 +133,11 @@ class ServerBuilderConstraintViolations( for (constraintViolation in all) { when (constraintViolation.kind) { ConstraintViolationKind.MISSING_MEMBER -> { - writer.docs("${constraintViolation.message(symbolProvider, model).replaceFirstChar { it.uppercaseChar() }}.") + writer.docs( + "${constraintViolation.message(symbolProvider, model).replaceFirstChar { + it.uppercaseChar() + }}.", + ) writer.rust("${constraintViolation.name()},") } @@ -138,14 +146,18 @@ class ServerBuilderConstraintViolations( val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(targetShape) - // If the corresponding structure's member is boxed, box this constraint violation symbol too. - .letIf(constraintViolation.forMember.hasTrait()) { + // Box this constraint violation symbol if necessary. + .letIf(constraintViolation.forMember.hasTrait()) { it.makeRustBoxed() } // Note we cannot express the inner constraint violation as `>::Error`, because `T` might // be `pub(crate)` and that would leak `T` in a public interface. - writer.docs("${constraintViolation.message(symbolProvider, model)}.".replaceFirstChar { it.uppercaseChar() }) + writer.docs( + "${constraintViolation.message(symbolProvider, model)}.".replaceFirstChar { + it.uppercaseChar() + }, + ) Attribute.DocHidden.render(writer) writer.rust("${constraintViolation.name()}(#T),", constraintViolationSymbol) } @@ -154,25 +166,6 @@ class ServerBuilderConstraintViolations( } private fun renderAsValidationExceptionFieldList(writer: RustWriter) { - val validationExceptionFieldWritable = writable { - rustBlock("match self") { - all.forEach { - if (it.hasInner()) { - rust("""ConstraintViolation::${it.name()}(inner) => inner.as_validation_exception_field(path + "/${it.forMember.memberName}"),""") - } else { - rust( - """ - ConstraintViolation::${it.name()} => crate::model::ValidationExceptionField { - message: format!("Value null at '{}/${it.forMember.memberName}' failed to satisfy constraint: Member must not be null", path), - path: path + "/${it.forMember.memberName}", - }, - """, - ) - } - } - } - } - writer.rustTemplate( """ impl ConstraintViolation { @@ -181,7 +174,7 @@ class ServerBuilderConstraintViolations( } } """, - "ValidationExceptionFieldWritable" to validationExceptionFieldWritable, + "ValidationExceptionFieldWritable" to validationExceptionConversionGenerator.builderConstraintViolationImplBlock((all)), "String" to RuntimeType.String, ) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGenerator.kt index 27bfa69d32c..c544c5b4c35 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGenerator.kt @@ -29,6 +29,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.isRustBoxed @@ -49,7 +50,9 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape import software.amazon.smithy.rust.codegen.server.smithy.hasConstraintTraitOrTargetHasConstraintTrait import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.traits.ConstraintViolationRustBoxTrait import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput +import software.amazon.smithy.rust.codegen.server.smithy.withInMemoryInlineModule import software.amazon.smithy.rust.codegen.server.smithy.wouldHaveConstrainedWrapperTupleTypeWerePublicConstrainedTypesEnabled /** @@ -86,8 +89,9 @@ import software.amazon.smithy.rust.codegen.server.smithy.wouldHaveConstrainedWra * [derive_builder]: https://docs.rs/derive_builder/latest/derive_builder/index.html */ class ServerBuilderGenerator( - codegenContext: ServerCodegenContext, + val codegenContext: ServerCodegenContext, private val shape: StructureShape, + private val customValidationExceptionWithReasonConversionGenerator: ValidationExceptionConversionGenerator, ) { companion object { /** @@ -141,7 +145,7 @@ class ServerBuilderGenerator( private val builderSymbol = shape.serverBuilderSymbol(codegenContext) private val isBuilderFallible = hasFallibleBuilder(shape, model, symbolProvider, takeInUnconstrainedTypes) private val serverBuilderConstraintViolations = - ServerBuilderConstraintViolations(codegenContext, shape, takeInUnconstrainedTypes) + ServerBuilderConstraintViolations(codegenContext, shape, takeInUnconstrainedTypes, customValidationExceptionWithReasonConversionGenerator) private val codegenScope = arrayOf( "RequestRejection" to ServerRuntimeType.requestRejection(runtimeConfig), @@ -151,9 +155,9 @@ class ServerBuilderGenerator( "MaybeConstrained" to RuntimeType.MaybeConstrained, ) - fun render(writer: RustWriter) { - writer.docs("See #D.", structureSymbol) - writer.withInlineModule(builderSymbol.module()) { + fun render(rustCrate: RustCrate, writer: RustWriter) { + val docWriter: () -> Unit = { writer.docs("See #D.", structureSymbol) } + rustCrate.withInMemoryInlineModule(writer, builderSymbol.module(), docWriter) { renderBuilder(this) } } @@ -187,7 +191,9 @@ class ServerBuilderGenerator( // since we are a builder and everything is optional. val baseDerives = structureSymbol.expectRustMetadata().derives // Filter out any derive that isn't Debug or Clone. Then add a Default derive - val builderDerives = baseDerives.filter { it == RuntimeType.Debug || it == RuntimeType.Clone } + RuntimeType.Default + val builderDerives = baseDerives.filter { + it == RuntimeType.Debug || it == RuntimeType.Clone + } + RuntimeType.Default Attribute(derive(builderDerives)).render(writer) writer.rustBlock("${visibility.toRustQualifier()} struct Builder") { members.forEach { renderBuilderMember(this, it) } @@ -214,21 +220,9 @@ class ServerBuilderGenerator( private fun renderImplFromConstraintViolationForRequestRejection(writer: RustWriter) { writer.rustTemplate( """ - impl #{From} for #{RequestRejection} { - fn from(constraint_violation: ConstraintViolation) -> Self { - let first_validation_exception_field = constraint_violation.as_validation_exception_field("".to_owned()); - let validation_exception = crate::error::ValidationException { - message: format!("1 validation error detected. {}", &first_validation_exception_field.message), - field_list: Some(vec![first_validation_exception_field]), - }; - Self::ConstraintViolation( - crate::operation_ser::serialize_structure_crate_error_validation_exception(&validation_exception) - .expect("impossible") - ) - } - } + #{Converter:W} """, - *codegenScope, + "Converter" to customValidationExceptionWithReasonConversionGenerator.renderImplFromConstraintViolationForRequestRejection(), ) } @@ -401,12 +395,12 @@ class ServerBuilderGenerator( rust( """ self.$memberName = ${ - // TODO(https://github.com/awslabs/smithy-rs/issues/1302, https://github.com/awslabs/smithy/issues/1179): See above. - if (symbolProvider.toSymbol(member).isOptional()) { - "input.map(|v| v.into())" - } else { - "Some(input.into())" - } + // TODO(https://github.com/awslabs/smithy-rs/issues/1302, https://github.com/awslabs/smithy/issues/1179): See above. + if (symbolProvider.toSymbol(member).isOptional()) { + "input.map(|v| v.into())" + } else { + "Some(input.into())" + } }; self """, @@ -552,6 +546,8 @@ class ServerBuilderGenerator( val hasBox = builderMemberSymbol(member) .mapRustType { it.stripOuter() } .isRustBoxed() + val errHasBox = member.hasTrait() + if (hasBox) { writer.rustTemplate( """ @@ -559,11 +555,6 @@ class ServerBuilderGenerator( #{MaybeConstrained}::Constrained(x) => Ok(Box::new(x)), #{MaybeConstrained}::Unconstrained(x) => Ok(Box::new(x.try_into()?)), }) - .map(|res| - res${ if (constrainedTypeHoldsFinalType(member)) "" else ".map(|v| v.into())" } - .map_err(|err| ConstraintViolation::${constraintViolation.name()}(Box::new(err))) - ) - .transpose()? """, *codegenScope, ) @@ -574,16 +565,22 @@ class ServerBuilderGenerator( #{MaybeConstrained}::Constrained(x) => Ok(x), #{MaybeConstrained}::Unconstrained(x) => x.try_into(), }) - .map(|res| - res${if (constrainedTypeHoldsFinalType(member)) "" else ".map(|v| v.into())"} - .map_err(ConstraintViolation::${constraintViolation.name()}) - ) - .transpose()? """, *codegenScope, ) } + writer.rustTemplate( + """ + .map(|res| + res${if (constrainedTypeHoldsFinalType(member)) "" else ".map(|v| v.into())"} ${if (errHasBox) ".map_err(Box::new)" else "" } + .map_err(ConstraintViolation::${constraintViolation.name()}) + ) + .transpose()? + """, + *codegenScope, + ) + // Constrained types are not public and this is a member shape that would have generated a // public constrained type, were the setting to be enabled. // We've just checked the constraints hold by going through the non-public diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorCommon.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorCommon.kt index f16e2640b2e..389f0dc173e 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorCommon.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorCommon.kt @@ -39,6 +39,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumMemberModel import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.expectTrait @@ -143,7 +144,9 @@ fun defaultValue( .entries .filter { entry -> entry.value == value } .map { entry -> - symbolProvider.toEnumVariantName( + EnumMemberModel.toEnumVariantName( + symbolProvider, + target, EnumDefinition.builder().name(entry.key).value(entry.value.toString()).build(), )!! } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorWithoutPublicConstrainedTypes.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorWithoutPublicConstrainedTypes.kt index b30252a8a74..2ae17d1a80a 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorWithoutPublicConstrainedTypes.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorWithoutPublicConstrainedTypes.kt @@ -23,12 +23,14 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.makeOptional import software.amazon.smithy.rust.codegen.core.smithy.module import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType +import software.amazon.smithy.rust.codegen.server.smithy.withInMemoryInlineModule /** * Generates a builder for the Rust type associated with the [StructureShape]. @@ -46,6 +48,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType class ServerBuilderGeneratorWithoutPublicConstrainedTypes( private val codegenContext: ServerCodegenContext, shape: StructureShape, + validationExceptionConversionGenerator: ValidationExceptionConversionGenerator, ) { companion object { /** @@ -79,7 +82,7 @@ class ServerBuilderGeneratorWithoutPublicConstrainedTypes( private val builderSymbol = shape.serverBuilderSymbol(symbolProvider, false) private val isBuilderFallible = hasFallibleBuilder(shape, symbolProvider) private val serverBuilderConstraintViolations = - ServerBuilderConstraintViolations(codegenContext, shape, builderTakesInUnconstrainedTypes = false) + ServerBuilderConstraintViolations(codegenContext, shape, builderTakesInUnconstrainedTypes = false, validationExceptionConversionGenerator) private val codegenScope = arrayOf( "RequestRejection" to ServerRuntimeType.requestRejection(codegenContext.runtimeConfig), @@ -89,12 +92,12 @@ class ServerBuilderGeneratorWithoutPublicConstrainedTypes( "MaybeConstrained" to RuntimeType.MaybeConstrained, ) - fun render(writer: RustWriter) { + fun render(rustCrate: RustCrate, writer: RustWriter) { check(!codegenContext.settings.codegenConfig.publicConstrainedTypes) { "ServerBuilderGeneratorWithoutPublicConstrainedTypes should only be used when `publicConstrainedTypes` is false" } - writer.docs("See #D.", structureSymbol) - writer.withInlineModule(builderSymbol.module()) { + val docWriter = { writer.docs("See #D.", structureSymbol) } + rustCrate.withInMemoryInlineModule(writer, builderSymbol.module(), docWriter) { renderBuilder(this) } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderSymbol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderSymbol.kt index 9720717383c..948aea0d897 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderSymbol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderSymbol.kt @@ -2,7 +2,6 @@ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ - package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.codegen.core.Symbol @@ -17,13 +16,15 @@ import software.amazon.smithy.rust.codegen.core.smithy.rustType import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +// TODO(https://github.com/awslabs/smithy-rs/issues/2396): Replace this with `RustSymbolProvider.symbolForBuilder` fun StructureShape.serverBuilderSymbol(codegenContext: ServerCodegenContext): Symbol = this.serverBuilderSymbol( codegenContext.symbolProvider, !codegenContext.settings.codegenConfig.publicConstrainedTypes, ) -fun StructureShape.serverBuilderSymbol(symbolProvider: SymbolProvider, pubCrate: Boolean): Symbol { +// TODO(https://github.com/awslabs/smithy-rs/issues/2396): Replace this with `RustSymbolProvider.moduleForBuilder` +fun StructureShape.serverBuilderModule(symbolProvider: SymbolProvider, pubCrate: Boolean): RustModule.LeafModule { val structureSymbol = symbolProvider.toSymbol(this) val builderNamespace = RustReservedWords.escapeIfNeeded(structureSymbol.name.toSnakeCase()) + if (pubCrate) { @@ -35,7 +36,12 @@ fun StructureShape.serverBuilderSymbol(symbolProvider: SymbolProvider, pubCrate: true -> Visibility.PUBCRATE false -> Visibility.PUBLIC } - val builderModule = RustModule.new(builderNamespace, visibility, parent = structureSymbol.module(), inline = true) + return RustModule.new(builderNamespace, visibility, parent = structureSymbol.module(), inline = true) +} + +// TODO(https://github.com/awslabs/smithy-rs/issues/2396): Replace this with `RustSymbolProvider.symbolForBuilder` +fun StructureShape.serverBuilderSymbol(symbolProvider: SymbolProvider, pubCrate: Boolean): Symbol { + val builderModule = serverBuilderModule(symbolProvider, pubCrate) val rustType = RustType.Opaque("Builder", builderModule.fullyQualifiedPath()) return Symbol.builder() .rustType(rustType) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt index 1a55cb88796..88ca5e4fef7 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt @@ -5,27 +5,26 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.model.shapes.StringShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGeneratorContext +import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumType import software.amazon.smithy.rust.codegen.core.smithy.module import software.amazon.smithy.rust.codegen.core.util.dq -import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput -open class ServerEnumGenerator( - val codegenContext: ServerCodegenContext, - private val writer: RustWriter, - shape: StringShape, -) : EnumGenerator(codegenContext.model, codegenContext.symbolProvider, writer, shape, shape.expectTrait()) { - override var target: CodegenTarget = CodegenTarget.SERVER - +open class ConstrainedEnum( + codegenContext: ServerCodegenContext, + private val shape: StringShape, + private val validationExceptionConversionGenerator: ValidationExceptionConversionGenerator, +) : EnumType() { private val publicConstrainedTypes = codegenContext.settings.codegenConfig.publicConstrainedTypes private val constraintViolationSymbolProvider = with(codegenContext.constraintViolationSymbolProvider) { @@ -41,8 +40,8 @@ open class ServerEnumGenerator( "String" to RuntimeType.String, ) - override fun renderFromForStr() { - writer.withInlineModule(constraintViolationSymbol.module()) { + override fun implFromForStr(context: EnumGeneratorContext): Writable = writable { + withInlineModule(constraintViolationSymbol.module()) { rustTemplate( """ ##[derive(Debug, PartialEq)] @@ -52,39 +51,33 @@ open class ServerEnumGenerator( ) if (shape.isReachableFromOperationInput()) { - val enumValueSet = enumTrait.enumDefinitionValues.joinToString(", ") - val message = "Value {} at '{}' failed to satisfy constraint: Member must satisfy enum value set: [$enumValueSet]" - rustTemplate( """ impl $constraintViolationName { - pub(crate) fn as_validation_exception_field(self, path: #{String}) -> crate::model::ValidationExceptionField { - crate::model::ValidationExceptionField { - message: format!(r##"$message"##, &self.0, &path), - path, - } - } + #{EnumShapeConstraintViolationImplBlock:W} } """, - *codegenScope, + "EnumShapeConstraintViolationImplBlock" to validationExceptionConversionGenerator.enumShapeConstraintViolationImplBlock( + context.enumTrait, + ), ) } } - writer.rustBlock("impl #T<&str> for $enumName", RuntimeType.TryFrom) { + rustBlock("impl #T<&str> for ${context.enumName}", RuntimeType.TryFrom) { rust("type Error = #T;", constraintViolationSymbol) rustBlock("fn try_from(s: &str) -> Result>::Error>", RuntimeType.TryFrom) { rustBlock("match s") { - sortedMembers.forEach { member -> - rust("${member.value.dq()} => Ok($enumName::${member.derivedName()}),") + context.sortedMembers.forEach { member -> + rust("${member.value.dq()} => Ok(${context.enumName}::${member.derivedName()}),") } rust("_ => Err(#T(s.to_owned()))", constraintViolationSymbol) } } } - writer.rustTemplate( + rustTemplate( """ - impl #{TryFrom}<#{String}> for $enumName { - type Error = #{UnknownVariantSymbol}; + impl #{TryFrom}<#{String}> for ${context.enumName} { + type Error = #{ConstraintViolation}; fn try_from(s: #{String}) -> std::result::Result>::Error> { s.as_str().try_into() } @@ -92,21 +85,32 @@ open class ServerEnumGenerator( """, "String" to RuntimeType.String, "TryFrom" to RuntimeType.TryFrom, - "UnknownVariantSymbol" to constraintViolationSymbol, + "ConstraintViolation" to constraintViolationSymbol, ) } - override fun renderFromStr() { - writer.rustTemplate( + override fun implFromStr(context: EnumGeneratorContext): Writable = writable { + rustTemplate( """ - impl std::str::FromStr for $enumName { - type Err = #{UnknownVariantSymbol}; + impl std::str::FromStr for ${context.enumName} { + type Err = #{ConstraintViolation}; fn from_str(s: &str) -> std::result::Result::Err> { Self::try_from(s) } } """, - "UnknownVariantSymbol" to constraintViolationSymbol, + "ConstraintViolation" to constraintViolationSymbol, ) } } + +class ServerEnumGenerator( + codegenContext: ServerCodegenContext, + shape: StringShape, + validationExceptionConversionGenerator: ValidationExceptionConversionGenerator, +) : EnumGenerator( + codegenContext.model, + codegenContext.symbolProvider, + shape, + enumType = ConstrainedEnum(codegenContext, shape, validationExceptionConversionGenerator), +) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerHttpSensitivityGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerHttpSensitivityGenerator.kt index 47b42825fa0..0163f094f5f 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerHttpSensitivityGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerHttpSensitivityGenerator.kt @@ -85,10 +85,14 @@ class LabelSensitivity(internal val labelIndexes: List, internal val greedy private fun hasRedactions(): Boolean = labelIndexes.isNotEmpty() || greedyLabel != null /** Returns the type of the `MakeFmt`. */ - fun type(): Writable = if (hasRedactions()) writable { - rustTemplate("#{SmithyHttpServer}::instrumentation::sensitivity::uri::MakeLabel bool>", *codegenScope) - } else writable { - rustTemplate("#{SmithyHttpServer}::instrumentation::MakeIdentity", *codegenScope) + fun type(): Writable = if (hasRedactions()) { + writable { + rustTemplate("#{SmithyHttpServer}::instrumentation::sensitivity::uri::MakeLabel bool>", *codegenScope) + } + } else { + writable { + rustTemplate("#{SmithyHttpServer}::instrumentation::MakeIdentity", *codegenScope) + } } /** Returns the value of the `GreedyLabel`. */ @@ -105,9 +109,13 @@ class LabelSensitivity(internal val labelIndexes: List, internal val greedy } /** Returns the setter enclosing the closure or suffix position. */ - fun setter(): Writable = if (hasRedactions()) writable { - rustTemplate(".label(#{Closure:W}, #{GreedyLabel:W})", "Closure" to closure(), "GreedyLabel" to greedyLabelStruct()) - } else writable { } + fun setter(): Writable = if (hasRedactions()) { + writable { + rustTemplate(".label(#{Closure:W}, #{GreedyLabel:W})", "Closure" to closure(), "GreedyLabel" to greedyLabelStruct()) + } + } else { + writable { } + } } /** Models the ways headers can be bound and sensitive */ @@ -156,11 +164,15 @@ sealed class HeaderSensitivity( /** Returns the closure used during construction. */ internal fun closure(): Writable { - val nameMatch = if (headerKeys.isEmpty()) writable { - rust("false") - } else writable { - val matches = headerKeys.joinToString("|") { it.dq() } - rust("matches!(name.as_str(), $matches)") + val nameMatch = if (headerKeys.isEmpty()) { + writable { + rust("false") + } + } else { + writable { + val matches = headerKeys.joinToString("|") { it.dq() } + rust("matches!(name.as_str(), $matches)") + } } val suffixAndValue = when (this) { @@ -252,11 +264,15 @@ sealed class QuerySensitivity( is SensitiveMapValue -> writable { rust("true") } - is NotSensitiveMapValue -> if (queryKeys.isEmpty()) writable { - rust("false;") - } else writable { - val matches = queryKeys.joinToString("|") { it.dq() } - rust("matches!(name, $matches);") + is NotSensitiveMapValue -> if (queryKeys.isEmpty()) { + writable { + rust("false;") + } + } else { + writable { + val matches = queryKeys.joinToString("|") { it.dq() } + rust("matches!(name, $matches);") + } } } @@ -498,7 +514,9 @@ class ServerHttpSensitivityGenerator( ) } - val value = writable { rustTemplate("#{SmithyHttpServer}::instrumentation::sensitivity::RequestFmt::new()", *codegenScope) } + headerSensitivity.setter() + labelSensitivity.setter() + querySensitivity.setters() + val value = writable { + rustTemplate("#{SmithyHttpServer}::instrumentation::sensitivity::RequestFmt::new()", *codegenScope) + } + headerSensitivity.setter() + labelSensitivity.setter() + querySensitivity.setters() return MakeFmt(type, value) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt index 3179b5370e8..3d3105f9c70 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiator.kt @@ -16,6 +16,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.Instantiator import software.amazon.smithy.rust.codegen.core.smithy.generators.InstantiatorCustomization import software.amazon.smithy.rust.codegen.core.smithy.generators.InstantiatorSection import software.amazon.smithy.rust.codegen.core.smithy.isOptional +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.isDirectlyConstrained import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput @@ -47,12 +48,26 @@ class ServerBuilderKindBehavior(val codegenContext: CodegenContext) : Instantiat override fun hasFallibleBuilder(shape: StructureShape): Boolean { // Only operation input builders take in unconstrained types. val takesInUnconstrainedTypes = shape.isReachableFromOperationInput() - return ServerBuilderGenerator.hasFallibleBuilder( - shape, - codegenContext.model, - codegenContext.symbolProvider, - takesInUnconstrainedTypes, - ) + + val publicConstrainedTypes = if (codegenContext is ServerCodegenContext) { + codegenContext.settings.codegenConfig.publicConstrainedTypes + } else { + true + } + + return if (publicConstrainedTypes) { + ServerBuilderGenerator.hasFallibleBuilder( + shape, + codegenContext.model, + codegenContext.symbolProvider, + takesInUnconstrainedTypes, + ) + } else { + ServerBuilderGeneratorWithoutPublicConstrainedTypes.hasFallibleBuilder( + shape, + codegenContext.symbolProvider, + ) + } } override fun setterName(memberShape: MemberShape): String = codegenContext.symbolProvider.toMemberName(memberShape) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationErrorGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationErrorGenerator.kt index 44048276d8e..68e80f7acb9 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationErrorGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationErrorGenerator.kt @@ -7,7 +7,10 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Visibility @@ -18,6 +21,9 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.transformers.eventStreamErrors +import software.amazon.smithy.rust.codegen.core.smithy.transformers.operationErrors +import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE import software.amazon.smithy.rust.codegen.core.util.toSnakeCase /** @@ -27,28 +33,32 @@ import software.amazon.smithy.rust.codegen.core.util.toSnakeCase open class ServerOperationErrorGenerator( private val model: Model, private val symbolProvider: RustSymbolProvider, - private val operationSymbol: Symbol, - private val errors: List, + private val operationOrEventStream: Shape, ) { - open fun render(writer: RustWriter) { - val symbol = RuntimeType("crate::error::${operationSymbol.name}Error") - if (errors.isNotEmpty()) { - renderErrors(writer, symbol, operationSymbol) - } - } + private val symbol = symbolProvider.toSymbol(operationOrEventStream) - fun renderErrors( - writer: RustWriter, - errorSymbol: RuntimeType, - operationSymbol: Symbol, - ) { + private fun operationErrors(): List = + (operationOrEventStream as OperationShape).operationErrors(model).map { it.asStructureShape().get() } + private fun eventStreamErrors(): List = + (operationOrEventStream as UnionShape).eventStreamErrors() + .map { model.expectShape(it.asMemberShape().get().target, StructureShape::class.java) } + + fun render(writer: RustWriter) { + val (errorSymbol, errors) = when (operationOrEventStream) { + is OperationShape -> symbolProvider.symbolForOperationError(operationOrEventStream) to operationErrors() + is UnionShape -> symbolProvider.symbolForEventStreamError(operationOrEventStream) to eventStreamErrors() + else -> UNREACHABLE("OperationErrorGenerator only supports operation or event stream shapes") + } + if (errors.isEmpty()) { + return + } val meta = RustMetadata( derives = setOf(RuntimeType.Debug), visibility = Visibility.PUBLIC, ) - writer.rust("/// Error type for the `${operationSymbol.name}` operation.") - writer.rust("/// Each variant represents an error that can occur for the `${operationSymbol.name}` operation.") + writer.rust("/// Error type for the `${symbol.name}` operation.") + writer.rust("/// Each variant represents an error that can occur for the `${symbol.name}` operation.") meta.render(writer) writer.rustBlock("enum ${errorSymbol.name}") { errors.forEach { errorVariant -> @@ -120,7 +130,7 @@ open class ServerOperationErrorGenerator( */ private fun RustWriter.delegateToVariants( errors: List, - symbol: RuntimeType, + symbol: Symbol, writable: Writable, ) { rustBlock("match &self") { diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt index cf405dc5e99..4d2a275fce6 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt @@ -15,17 +15,18 @@ import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.join import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate -import software.amazon.smithy.rust.codegen.core.smithy.ErrorsModule -import software.amazon.smithy.rust.codegen.core.smithy.InputsModule -import software.amazon.smithy.rust.codegen.core.smithy.OutputsModule import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport +import software.amazon.smithy.rust.codegen.core.util.toPascalCase import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolTestGenerator +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule.Error as ErrorModule +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule.Input as InputModule +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule.Output as OutputModule /** * ServerServiceGenerator @@ -42,12 +43,14 @@ open class ServerServiceGenerator( ) { private val index = TopDownIndex.of(codegenContext.model) protected val operations = index.getContainedOperations(codegenContext.serviceShape).sortedBy { it.id } - private val serviceName = codegenContext.serviceShape.id.name.toString() + private val serviceName = codegenContext.serviceShape.id.name.toPascalCase() fun documentation(writer: RustWriter) { val operations = index.getContainedOperations(codegenContext.serviceShape).toSortedSet(compareBy { it.id }) val builderFieldNames = - operations.associateWith { RustReservedWords.escapeIfNeeded(codegenContext.symbolProvider.toSymbol(it).name.toSnakeCase()) } + operations.associateWith { + RustReservedWords.escapeIfNeeded(codegenContext.symbolProvider.toSymbol(it).name.toSnakeCase()) + } .toSortedMap( compareBy { it.id }, ) @@ -68,7 +71,7 @@ open class ServerServiceGenerator( //! //! The primary entrypoint is [`$serviceName`]: it satisfies the [`Service`](#{Tower}::Service) //! trait and therefore can be handed to a [`hyper` server](https://github.com/hyperium/hyper) via [`$serviceName::into_make_service`] or used in Lambda via [`LambdaHandler`](#{SmithyHttpServer}::routing::LambdaHandler). - //! The [`crate::${InputsModule.name}`], ${if (!hasErrors) "and " else ""}[`crate::${OutputsModule.name}`], ${if (hasErrors) "and [`crate::${ErrorsModule.name}`]" else "" } + //! The [`crate::${InputModule.name}`], ${if (!hasErrors) "and " else ""}[`crate::${OutputModule.name}`], ${if (hasErrors) "and [`crate::${ErrorModule.name}`]" else "" } //! modules provide the types used in each operation. //! //! ###### Running on Hyper @@ -248,7 +251,7 @@ open class ServerServiceGenerator( for (operation in operations) { if (operation.errors.isNotEmpty()) { - rustCrate.withModule(RustModule.Error) { + rustCrate.withModule(ErrorModule) { renderCombinedErrors(this, operation) } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGeneratorV2.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGeneratorV2.kt index e8483f4a206..3b7d09c3cea 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGeneratorV2.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGeneratorV2.kt @@ -18,9 +18,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.join import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable -import software.amazon.smithy.rust.codegen.core.smithy.ErrorsModule -import software.amazon.smithy.rust.codegen.core.smithy.InputsModule -import software.amazon.smithy.rust.codegen.core.smithy.OutputsModule import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.letIf @@ -29,6 +26,9 @@ import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule.Error as ErrorModule +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule.Input as InputModule +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule.Output as OutputModule class ServerServiceGeneratorV2( private val codegenContext: ServerCodegenContext, @@ -540,8 +540,8 @@ class ServerServiceGeneratorV2( */ fun handlerImports(crateName: String, operations: Collection, commentToken: String = "///") = writable { val hasErrors = operations.any { it.errors.isNotEmpty() } - val errorImport = if (hasErrors) ", ${ErrorsModule.name}" else "" + val errorImport = if (hasErrors) ", ${ErrorModule.name}" else "" if (operations.isNotEmpty()) { - rust("$commentToken use $crateName::{${InputsModule.name}, ${OutputsModule.name}$errorImport};") + rust("$commentToken use $crateName::{${InputModule.name}, ${OutputModule.name}$errorImport};") } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGenerator.kt index b5a9d45895b..6916617b6ff 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGenerator.kt @@ -18,12 +18,14 @@ import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained -import software.amazon.smithy.rust.codegen.core.smithy.module +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.server.smithy.InlineModuleCreator import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.UnconstrainedShapeSymbolProvider import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape import software.amazon.smithy.rust.codegen.server.smithy.isDirectlyConstrained +import software.amazon.smithy.rust.codegen.server.smithy.traits.ConstraintViolationRustBoxTrait /** * Generates a Rust type for a constrained collection shape that is able to hold values for the corresponding @@ -38,7 +40,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.isDirectlyConstrained */ class UnconstrainedCollectionGenerator( val codegenContext: ServerCodegenContext, - private val unconstrainedModuleWriter: RustWriter, + private val inlineModuleCreator: InlineModuleCreator, val shape: CollectionShape, ) { private val model = codegenContext.model @@ -70,7 +72,7 @@ class UnconstrainedCollectionGenerator( val innerMemberSymbol = unconstrainedShapeSymbolProvider.toSymbol(shape.member) - unconstrainedModuleWriter.withInlineModule(symbol.module()) { + inlineModuleCreator(symbol) { rustTemplate( """ ##[derive(Debug, Clone)] @@ -107,7 +109,11 @@ class UnconstrainedCollectionGenerator( constrainedShapeSymbolProvider.toSymbol(shape.member) } val innerConstraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(innerShape) - + val boxErr = if (shape.member.hasTrait()) { + ".map_err(|(idx, inner_violation)| (idx, Box::new(inner_violation)))" + } else { + "" + } val constrainValueWritable = writable { conditionalBlock("inner.map(|inner| ", ").transpose()", constrainedMemberSymbol.isOptional()) { rust("inner.try_into().map_err(|inner_violation| (idx, inner_violation))") @@ -124,7 +130,9 @@ class UnconstrainedCollectionGenerator( #{ConstrainValueWritable:W} }) .collect(); - let inner = res.map_err(|(idx, inner_violation)| Self::Error::Member(idx, inner_violation))?; + let inner = res + $boxErr + .map_err(|(idx, inner_violation)| Self::Error::Member(idx, inner_violation))?; """, "Vec" to RuntimeType.Vec, "ConstrainedMemberSymbol" to constrainedMemberSymbol, diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGenerator.kt index e18d372c75e..0862a26987c 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGenerator.kt @@ -19,11 +19,13 @@ import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained -import software.amazon.smithy.rust.codegen.core.smithy.module +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.server.smithy.InlineModuleCreator import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape import software.amazon.smithy.rust.codegen.server.smithy.isDirectlyConstrained +import software.amazon.smithy.rust.codegen.server.smithy.traits.ConstraintViolationRustBoxTrait /** * Generates a Rust type for a constrained map shape that is able to hold values for the corresponding @@ -38,7 +40,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.isDirectlyConstrained */ class UnconstrainedMapGenerator( val codegenContext: ServerCodegenContext, - private val unconstrainedModuleWriter: RustWriter, + private val inlineModuleCreator: InlineModuleCreator, val shape: MapShape, ) { private val model = codegenContext.model @@ -72,7 +74,7 @@ class UnconstrainedMapGenerator( val keySymbol = unconstrainedShapeSymbolProvider.toSymbol(keyShape) val valueMemberSymbol = unconstrainedShapeSymbolProvider.toSymbol(shape.value) - unconstrainedModuleWriter.withInlineModule(symbol.module()) { + inlineModuleCreator(symbol) { rustTemplate( """ ##[derive(Debug, Clone)] @@ -125,6 +127,11 @@ class UnconstrainedMapGenerator( ) } val constrainValueWritable = writable { + val boxErr = if (shape.value.hasTrait()) { + ".map_err(Box::new)" + } else { + "" + } if (constrainedMemberValueSymbol.isOptional()) { // The map is `@sparse`. rustBlock("match v") { @@ -133,7 +140,7 @@ class UnconstrainedMapGenerator( // DRYing this up with the else branch below would make this less understandable. rustTemplate( """ - match #{ConstrainedValueSymbol}::try_from(v) { + match #{ConstrainedValueSymbol}::try_from(v)$boxErr { Ok(v) => Ok((k, Some(v))), Err(inner_constraint_violation) => Err(Self::Error::Value(k, inner_constraint_violation)), } @@ -145,7 +152,7 @@ class UnconstrainedMapGenerator( } else { rustTemplate( """ - match #{ConstrainedValueSymbol}::try_from(v) { + match #{ConstrainedValueSymbol}::try_from(v)$boxErr { Ok(v) => #{Epilogue:W}, Err(inner_constraint_violation) => Err(Self::Error::Value(k, inner_constraint_violation)), } @@ -214,9 +221,10 @@ class UnconstrainedMapGenerator( // ``` rustTemplate( """ - let hm: std::collections::HashMap<#{KeySymbol}, #{ValueSymbol}> = + let hm: #{HashMap}<#{KeySymbol}, #{ValueSymbol}> = hm.into_iter().map(|(k, v)| (k, v.into())).collect(); """, + "HashMap" to RuntimeType.HashMap, "KeySymbol" to symbolProvider.toSymbol(keyShape), "ValueSymbol" to symbolProvider.toSymbol(valueShape), ) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGenerator.kt index 72655675a03..f3ec8322c7e 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGenerator.kt @@ -23,16 +23,17 @@ import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained import software.amazon.smithy.rust.codegen.core.smithy.makeRustBoxed -import software.amazon.smithy.rust.codegen.core.smithy.module import software.amazon.smithy.rust.codegen.core.smithy.traits.RustBoxTrait import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.letIf import software.amazon.smithy.rust.codegen.core.util.toPascalCase +import software.amazon.smithy.rust.codegen.server.smithy.InlineModuleCreator import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape import software.amazon.smithy.rust.codegen.server.smithy.isDirectlyConstrained import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape +import software.amazon.smithy.rust.codegen.server.smithy.traits.ConstraintViolationRustBoxTrait import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput /** @@ -48,7 +49,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromO */ class UnconstrainedUnionGenerator( val codegenContext: ServerCodegenContext, - private val unconstrainedModuleWriter: RustWriter, + private val inlineModuleCreator: InlineModuleCreator, private val modelsModuleWriter: RustWriter, val shape: UnionShape, ) { @@ -76,7 +77,7 @@ class UnconstrainedUnionGenerator( val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(shape) val constraintViolationName = constraintViolationSymbol.name - unconstrainedModuleWriter.withInlineModule(symbol.module()) { + inlineModuleCreator(symbol) { rustBlock( """ ##[allow(clippy::enum_variant_names)] @@ -132,11 +133,16 @@ class UnconstrainedUnionGenerator( } else { Visibility.PUBCRATE } - modelsModuleWriter.withInlineModule( - constraintViolationSymbol.module(), + + inlineModuleCreator( + constraintViolationSymbol, ) { Attribute(derive(RuntimeType.Debug, RuntimeType.PartialEq)).render(this) - rustBlock("pub${if (constraintViolationVisibility == Visibility.PUBCRATE) " (crate)" else ""} enum $constraintViolationName") { + rustBlock( + """ + ##[allow(clippy::enum_variant_names)] + pub${if (constraintViolationVisibility == Visibility.PUBCRATE) " (crate)" else ""} enum $constraintViolationName""", + ) { constraintViolations().forEach { renderConstraintViolation(this, it) } } @@ -171,8 +177,8 @@ class UnconstrainedUnionGenerator( val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(targetShape) - // If the corresponding union's member is boxed, box this constraint violation symbol too. - .letIf(constraintViolation.forMember.hasTrait()) { + // Box this constraint violation symbol if necessary. + .letIf(constraintViolation.forMember.hasTrait()) { it.makeRustBoxed() } @@ -201,10 +207,15 @@ class UnconstrainedUnionGenerator( (!publicConstrainedTypes || !targetShape.isDirectlyConstrained(symbolProvider)) val (unconstrainedVar, boxIt) = if (member.hasTrait()) { - "(*unconstrained)" to ".map(Box::new).map_err(Box::new)" + "(*unconstrained)" to ".map(Box::new)" } else { "unconstrained" to "" } + val boxErr = if (member.hasTrait()) { + ".map_err(Box::new)" + } else { + "" + } if (resolveToNonPublicConstrainedType) { val constrainedSymbol = @@ -217,8 +228,7 @@ class UnconstrainedUnionGenerator( """ { let constrained: #{ConstrainedSymbol} = $unconstrainedVar - .try_into() - $boxIt + .try_into() $boxIt $boxErr .map_err(Self::Error::${ConstraintViolation(member).name()})?; constrained.into() } @@ -231,6 +241,7 @@ class UnconstrainedUnionGenerator( $unconstrainedVar .try_into() $boxIt + $boxErr .map_err(Self::Error::${ConstraintViolation(member).name()})? """, ) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ValidationExceptionConversionGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ValidationExceptionConversionGenerator.kt new file mode 100644 index 00000000000..7db44265863 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ValidationExceptionConversionGenerator.kt @@ -0,0 +1,50 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.traits.EnumTrait +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider + +/** + * Collection of methods that will be invoked by the respective generators to generate code to convert constraint + * violations to validation exceptions. + * This is only rendered for shapes that lie in a constrained operation's closure. + */ +interface ValidationExceptionConversionGenerator { + val shapeId: ShapeId + + /** + * Convert from a top-level operation input's constraint violation into + * `aws_smithy_http_server::rejection::RequestRejection`. + */ + fun renderImplFromConstraintViolationForRequestRejection(): Writable + + // Simple shapes. + fun stringShapeConstraintViolationImplBlock(stringConstraintsInfo: Collection): Writable + fun enumShapeConstraintViolationImplBlock(enumTrait: EnumTrait): Writable + fun numberShapeConstraintViolationImplBlock(rangeInfo: Range): Writable + fun blobShapeConstraintViolationImplBlock(blobConstraintsInfo: Collection): Writable + + // Aggregate shapes. + fun mapShapeConstraintViolationImplBlock( + shape: MapShape, + keyShape: StringShape, + valueShape: Shape, + symbolProvider: RustSymbolProvider, + model: Model, + ): Writable + fun builderConstraintViolationImplBlock(constraintViolations: Collection): Writable + fun collectionShapeConstraintViolationImplBlock( + collectionConstraintsInfo: Collection, + isMemberConstrained: Boolean, + ): Writable +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt index b01c2f633e7..41201b8695f 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerRequestBindingGenerator.kt @@ -7,7 +7,6 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators.http import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.OperationShape -import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.RustType import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable @@ -23,25 +22,19 @@ import software.amazon.smithy.rust.codegen.core.smithy.mapRustType import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext -import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderSymbol import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape class ServerRequestBindingGenerator( protocol: Protocol, - private val codegenContext: ServerCodegenContext, + codegenContext: ServerCodegenContext, operationShape: OperationShape, ) { - private fun serverBuilderSymbol(shape: StructureShape): Symbol = shape.serverBuilderSymbol( - codegenContext.symbolProvider, - !codegenContext.settings.codegenConfig.publicConstrainedTypes, - ) private val httpBindingGenerator = HttpBindingGenerator( protocol, codegenContext, codegenContext.unconstrainedShapeSymbolProvider, operationShape, - ::serverBuilderSymbol, listOf( ServerRequestAfterDeserializingIntoAHashMapOfHttpPrefixHeadersWrapInUnconstrainedMapHttpBindingCustomization( codegenContext, @@ -54,11 +47,11 @@ class ServerRequestBindingGenerator( fun generateDeserializePayloadFn( binding: HttpBindingDescriptor, - errorT: RuntimeType, + errorSymbol: Symbol, structuredHandler: RustWriter.(String) -> Unit, ): RuntimeType = httpBindingGenerator.generateDeserializePayloadFn( binding, - errorT, + errorSymbol, structuredHandler, HttpMessageType.REQUEST, ) @@ -82,7 +75,9 @@ class ServerRequestAfterDeserializingIntoAHashMapOfHttpPrefixHeadersWrapInUncons if (section.memberShape.targetCanReachConstrainedShape(codegenContext.model, codegenContext.unconstrainedShapeSymbolProvider)) { rust( "let out = out.map(#T);", - codegenContext.unconstrainedShapeSymbolProvider.toSymbol(section.memberShape).mapRustType { it.stripOuter() }, + codegenContext.unconstrainedShapeSymbolProvider.toSymbol(section.memberShape).mapRustType { + it.stripOuter() + }, ) } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt index e30d9cc6336..33f062b6b98 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/http/ServerResponseBindingGenerator.kt @@ -41,7 +41,6 @@ class ServerResponseBindingGenerator( codegenContext, codegenContext.symbolProvider, operationShape, - ::builderSymbol, listOf( ServerResponseBeforeIteratingOverMapBoundWithHttpPrefixHeadersUnwrapConstrainedMapHttpBindingCustomization( codegenContext, diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt index 09eb45a5e64..7df2a98c7fe 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt @@ -87,8 +87,6 @@ class ServerAwsJsonProtocol( private val runtimeConfig = codegenContext.runtimeConfig override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator { - fun builderSymbol(shape: StructureShape): Symbol = - shape.serverBuilderSymbol(serverCodegenContext) fun returnSymbolToParse(shape: Shape): ReturnSymbolToParse = if (shape.canReachConstrainedShape(codegenContext.model, serverCodegenContext.symbolProvider)) { ReturnSymbolToParse(serverCodegenContext.unconstrainedShapeSymbolProvider.toSymbol(shape), true) @@ -99,7 +97,6 @@ class ServerAwsJsonProtocol( codegenContext, httpBindingResolver, ::awsJsonFieldName, - ::builderSymbol, ::returnSymbolToParse, listOf( ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonParserCustomization(serverCodegenContext), @@ -168,7 +165,6 @@ class ServerRestJsonProtocol( codegenContext, httpBindingResolver, ::restJsonFieldName, - ::builderSymbol, ::returnSymbolToParse, listOf( ServerRequestBeforeBoxingDeserializedMemberConvertToMaybeConstrainedJsonParserCustomization( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 9727ca580bd..d1ff2b37d4c 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -344,8 +344,8 @@ class ServerProtocolTestGenerator( val operationErrorName = "crate::error::${operationSymbol.name}Error" if (!protocolSupport.responseSerialization || ( - !protocolSupport.errorSerialization && shape.hasTrait() - ) + !protocolSupport.errorSerialization && shape.hasTrait() + ) ) { rust("/* test case disabled for this protocol (not yet supported) */") return @@ -415,22 +415,22 @@ class ServerProtocolTestGenerator( rustTemplate( """ .body(${ - if (body != null) { - // The `replace` is necessary to fix the malformed request test `RestJsonInvalidJsonBody`. - // https://github.com/awslabs/smithy/blob/887ae4f6d118e55937105583a07deb90d8fabe1c/smithy-aws-protocol-tests/model/restJson1/malformedRequests/malformed-request-body.smithy#L47 - // - // Smithy is written in Java, which parses `\u000c` within a `String` as a single char given by the - // corresponding Unicode code point. That is the "form feed" 0x0c character. When printing it, - // it gets written as "\f", which is an invalid Rust escape sequence: https://static.rust-lang.org/doc/master/reference.html#literals - // So we need to write the corresponding Rust Unicode escape sequence to make the program compile. - // - // We also escape to avoid interactions with templating in the case where the body contains `#`. - val sanitizedBody = escape(body.replace("\u000c", "\\u{000c}")).dq() - - "#{SmithyHttpServer}::body::Body::from(#{Bytes}::from_static($sanitizedBody.as_bytes()))" - } else { - "#{SmithyHttpServer}::body::Body::empty()" - } + if (body != null) { + // The `replace` is necessary to fix the malformed request test `RestJsonInvalidJsonBody`. + // https://github.com/awslabs/smithy/blob/887ae4f6d118e55937105583a07deb90d8fabe1c/smithy-aws-protocol-tests/model/restJson1/malformedRequests/malformed-request-body.smithy#L47 + // + // Smithy is written in Java, which parses `\u000c` within a `String` as a single char given by the + // corresponding Unicode code point. That is the "form feed" 0x0c character. When printing it, + // it gets written as "\f", which is an invalid Rust escape sequence: https://static.rust-lang.org/doc/master/reference.html#literals + // So we need to write the corresponding Rust Unicode escape sequence to make the program compile. + // + // We also escape to avoid interactions with templating in the case where the body contains `#`. + val sanitizedBody = escape(body.replace("\u000c", "\\u{000c}")).dq() + + "#{SmithyHttpServer}::body::Body::from(#{Bytes}::from_static($sanitizedBody.as_bytes()))" + } else { + "#{SmithyHttpServer}::body::Body::empty()" + } }).unwrap(); """, *codegenScope, @@ -643,7 +643,7 @@ class ServerProtocolTestGenerator( assertOk(rustWriter) { rustWriter.rust( "#T(&body, ${ - rustWriter.escape(body).dq() + rustWriter.escape(body).dq() }, #T::from(${(mediaType ?: "unknown").dq()}))", RuntimeType.protocolTest(codegenContext.runtimeConfig, "validate_body"), RuntimeType.protocolTest(codegenContext.runtimeConfig, "MediaType"), @@ -772,90 +772,18 @@ class ServerProtocolTestGenerator( FailingTest(RestJson, "RestJsonEndpointTrait", TestType.Request), FailingTest(RestJson, "RestJsonEndpointTraitWithHostLabel", TestType.Request), - FailingTest(RestJson, "RestJsonWithBodyExpectsApplicationJsonContentType", TestType.MalformedRequest), - - // Tests involving constraint traits, which are not yet fully implemented. - // See https://github.com/awslabs/smithy-rs/issues/1401. + // Tests involving `@range` on floats. + // Pending resolution from the Smithy team, see https://github.com/awslabs/smithy-rs/issues/2007. FailingTest(RestJsonValidation, "RestJsonMalformedRangeFloat_case0", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedRangeFloat_case1", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedRangeMaxFloat", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedRangeMinFloat", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedPatternSensitiveString", TestType.MalformedRequest), - - // See https://github.com/awslabs/smithy-rs/issues/1969 - FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeShortOverride_case0", TestType.MalformedRequest), - FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeShortOverride_case1", TestType.MalformedRequest), - FailingTest( - MalformedRangeValidation, - "RestJsonMalformedRangeIntegerOverride_case0", - TestType.MalformedRequest, - ), - FailingTest( - MalformedRangeValidation, - "RestJsonMalformedRangeIntegerOverride_case1", - TestType.MalformedRequest, - ), - FailingTest( - MalformedRangeValidation, - "RestJsonMalformedRangeLongOverride_case0", - TestType.MalformedRequest, - ), - FailingTest( - MalformedRangeValidation, - "RestJsonMalformedRangeLongOverride_case1", - TestType.MalformedRequest, - ), - FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeMaxShortOverride", TestType.MalformedRequest), - FailingTest( - MalformedRangeValidation, - "RestJsonMalformedRangeMaxIntegerOverride", - TestType.MalformedRequest, - ), - FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeMaxLongOverride", TestType.MalformedRequest), - FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeMinShortOverride", TestType.MalformedRequest), - FailingTest( - MalformedRangeValidation, - "RestJsonMalformedRangeMinIntegerOverride", - TestType.MalformedRequest, - ), - FailingTest(MalformedRangeValidation, "RestJsonMalformedRangeMinLongOverride", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedRangeByteOverride_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedRangeByteOverride_case1", TestType.MalformedRequest), + + // Tests involving floating point shapes and the `@range` trait; see https://github.com/awslabs/smithy-rs/issues/2007 FailingTest(RestJsonValidation, "RestJsonMalformedRangeFloatOverride_case0", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedRangeFloatOverride_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthMaxStringOverride", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthMinStringOverride", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedRangeMaxByteOverride", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedRangeMaxFloatOverride", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedRangeMinByteOverride", TestType.MalformedRequest), FailingTest(RestJsonValidation, "RestJsonMalformedRangeMinFloatOverride", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedPatternListOverride_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedPatternListOverride_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedPatternMapKeyOverride_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedPatternMapKeyOverride_case1", TestType.MalformedRequest), - FailingTest( - RestJsonValidation, - "RestJsonMalformedPatternMapValueOverride_case0", - TestType.MalformedRequest, - ), - FailingTest( - RestJsonValidation, - "RestJsonMalformedPatternMapValueOverride_case1", - TestType.MalformedRequest, - ), - FailingTest(RestJsonValidation, "RestJsonMalformedPatternStringOverride_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedPatternStringOverride_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedPatternUnionOverride_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedPatternUnionOverride_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthBlobOverride_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthBlobOverride_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthListOverride_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthListOverride_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthMapOverride_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthMapOverride_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthStringOverride_case0", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthStringOverride_case1", TestType.MalformedRequest), - FailingTest(RestJsonValidation, "RestJsonMalformedLengthStringOverride_case2", TestType.MalformedRequest), // Some tests for the S3 service (restXml). FailingTest("com.amazonaws.s3#AmazonS3", "GetBucketLocationUnwrappedOutput", TestType.Response), @@ -872,16 +800,11 @@ class ServerProtocolTestGenerator( FailingTest("aws.protocoltests.json10#JsonRpc10", "AwsJson10EndpointTrait", TestType.Request), // AwsJson1.1 failing tests. - FailingTest("aws.protocoltests.json#JsonProtocol", "AwsJson11EndpointTraitWithHostLabel", TestType.Request), - FailingTest("aws.protocoltests.json#JsonProtocol", "AwsJson11EndpointTrait", TestType.Request), - FailingTest("aws.protocoltests.json#JsonProtocol", "parses_httpdate_timestamps", TestType.Response), - FailingTest("aws.protocoltests.json#JsonProtocol", "parses_iso8601_timestamps", TestType.Response), - FailingTest( - "aws.protocoltests.json#JsonProtocol", - "parses_the_request_id_from_the_response", - TestType.Response, - ), - + FailingTest(AwsJson11, "AwsJson11EndpointTraitWithHostLabel", TestType.Request), + FailingTest(AwsJson11, "AwsJson11EndpointTrait", TestType.Request), + FailingTest(AwsJson11, "parses_httpdate_timestamps", TestType.Response), + FailingTest(AwsJson11, "parses_iso8601_timestamps", TestType.Response), + FailingTest(AwsJson11, "parses_the_request_id_from_the_response", TestType.Response), ) private val RunOnly: Set? = null diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index ff7496b973d..adf7cf0a933 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -44,7 +44,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization import software.amazon.smithy.rust.codegen.core.smithy.generators.TypeConversionGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.error.errorSymbol import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTraitImplGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName @@ -53,6 +52,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.mapRustType import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBoundProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation +import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.core.smithy.transformers.operationErrors @@ -253,7 +253,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( ) // Implement `into_response` for output types. - val errorSymbol = operationShape.errorSymbol(symbolProvider) + val errorSymbol = symbolProvider.symbolForOperationError(operationShape) rustTemplate( """ @@ -365,7 +365,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( private fun serverSerializeError(operationShape: OperationShape): RuntimeType { val fnName = "serialize_${operationShape.id.name.toSnakeCase()}_error" - val errorSymbol = operationShape.errorSymbol(symbolProvider) + val errorSymbol = symbolProvider.symbolForOperationError(operationShape) return RuntimeType.forInlineFun(fnName, operationSerModule) { Attribute.AllowClippyUnnecessaryWraps.render(this) rustBlockTemplate( @@ -385,7 +385,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( private fun RustWriter.serverRenderErrorShapeResponseSerializer( operationShape: OperationShape, - errorSymbol: RuntimeType, + errorSymbol: Symbol, ) { val operationName = symbolProvider.toSymbol(operationShape).name val structuredDataSerializer = protocol.structuredDataSerializer(operationShape) @@ -628,7 +628,20 @@ private class ServerHttpBoundProtocolTraitImplGenerator( """ let bytes = #{Hyper}::body::to_bytes(body).await?; if !bytes.is_empty() { - input = #{parser}(bytes.as_ref(), input)?; + """, + *codegenScope, + ) + if (protocol is RestJson) { + rustTemplate( + """ + #{SmithyHttpServer}::protocols::content_type_header_classifier(&parts.headers, Some("application/json"))?; + """, + *codegenScope, + ) + } + rustTemplate( + """ + input = #{parser}(bytes.as_ref(), input)?; } """, *codegenScope, @@ -645,11 +658,11 @@ private class ServerHttpBoundProtocolTraitImplGenerator( """ { input = input.${member.setterName()}(${ - if (symbolProvider.toSymbol(binding.member).isOptional()) { - "Some(value)" - } else { - "value" - } + if (symbolProvider.toSymbol(binding.member).isOptional()) { + "Some(value)" + } else { + "value" + } }); } """, @@ -677,7 +690,9 @@ private class ServerHttpBoundProtocolTraitImplGenerator( ) ) { "?" - } else "" + } else { + "" + } rustTemplate("input.build()$err", *codegenScope) } @@ -827,7 +842,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( // * a map of list of string; or // * a map of set of string. enum class QueryParamsTargetMapValueType { - STRING, LIST, SET; + STRING, LIST, SET } private fun queryParamsTargetMapValueType(targetMapValue: Shape): QueryParamsTargetMapValueType = @@ -1010,7 +1025,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( rustBlock("if !$memberName.is_empty()") { withBlock( "input = input.${ - binding.member.setterName() + binding.member.setterName() }(", ");", ) { @@ -1019,7 +1034,9 @@ private class ServerHttpBoundProtocolTraitImplGenerator( "#T(", ")", conditional = hasConstrainedTarget, - unconstrainedShapeSymbolProvider.toSymbol(binding.member).mapRustType { it.stripOuter() }, + unconstrainedShapeSymbolProvider.toSymbol(binding.member).mapRustType { + it.stripOuter() + }, ) { write(memberName) } @@ -1140,18 +1157,18 @@ private class ServerHttpBoundProtocolTraitImplGenerator( * Returns the error type of the function that deserializes a non-streaming HTTP payload (a byte slab) into the * shape targeted by the `httpPayload` trait. */ - private fun getDeserializePayloadErrorSymbol(binding: HttpBindingDescriptor): RuntimeType { + private fun getDeserializePayloadErrorSymbol(binding: HttpBindingDescriptor): Symbol { check(binding.location == HttpLocation.PAYLOAD) if (model.expectShape(binding.member.target) is StringShape) { - return ServerRuntimeType.requestRejection(runtimeConfig) + return ServerRuntimeType.requestRejection(runtimeConfig).toSymbol() } return when (codegenContext.protocol) { RestJson1Trait.ID, AwsJson1_0Trait.ID, AwsJson1_1Trait.ID -> { - RuntimeType.smithyJson(runtimeConfig).resolve("deserialize::error::DeserializeError") + RuntimeType.smithyJson(runtimeConfig).resolve("deserialize::error::DeserializeError").toSymbol() } RestXmlTrait.ID -> { - RuntimeType.smithyXml(runtimeConfig).resolve("decode::XmlDecodeError") + RuntimeType.smithyXml(runtimeConfig).resolve("decode::XmlDecodeError").toSymbol() } else -> { TODO("Protocol ${codegenContext.protocol} not supported yet") diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerCodegenIntegrationTest.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerCodegenIntegrationTest.kt new file mode 100644 index 00000000000..fc83f1392b0 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerCodegenIntegrationTest.kt @@ -0,0 +1,54 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.testutil + +import software.amazon.smithy.build.PluginContext +import software.amazon.smithy.build.SmithyBuildPlugin +import software.amazon.smithy.model.Model +import software.amazon.smithy.rust.codegen.core.smithy.RustCrate +import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams +import software.amazon.smithy.rust.codegen.core.testutil.codegenIntegrationTest +import software.amazon.smithy.rust.codegen.server.smithy.RustServerCodegenPlugin +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator +import java.nio.file.Path + +/** + * This file is entirely analogous to [software.amazon.smithy.rust.codegen.client.testutil.ClientCodegenIntegrationTest.kt]. + */ + +fun serverIntegrationTest( + model: Model, + params: IntegrationTestParams = IntegrationTestParams(), + additionalDecorators: List = listOf(), + test: (ServerCodegenContext, RustCrate) -> Unit = { _, _ -> }, +): Path { + fun invokeRustCodegenPlugin(ctx: PluginContext) { + val codegenDecorator = object : ServerCodegenDecorator { + override val name: String = "Add tests" + override val order: Byte = 0 + + override fun classpathDiscoverable(): Boolean = false + + override fun extras(codegenContext: ServerCodegenContext, rustCrate: RustCrate) { + test(codegenContext, rustCrate) + } + } + RustServerCodegenPlugin().executeWithDecorator(ctx, codegenDecorator, *additionalDecorators.toTypedArray()) + } + return codegenIntegrationTest(model, params, invokePlugin = ::invokeRustCodegenPlugin) +} + +abstract class ServerDecoratableBuildPlugin : SmithyBuildPlugin { + abstract fun executeWithDecorator( + context: PluginContext, + vararg decorator: ServerCodegenDecorator, + ) + + override fun execute(context: PluginContext) { + executeWithDecorator(context) + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerTestHelpers.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerTestHelpers.kt index 9d49dfb9fea..d035c9c7fca 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerTestHelpers.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/testutil/ServerTestHelpers.kt @@ -12,25 +12,28 @@ import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.rustlang.implBlock import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig +import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider -import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProviderConfig import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig -import software.amazon.smithy.rust.codegen.server.smithy.RustCodegenServerPlugin +import software.amazon.smithy.rust.codegen.server.smithy.RustServerCodegenPlugin import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenConfig import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.ServerModuleProvider import software.amazon.smithy.rust.codegen.server.smithy.ServerRustSettings import software.amazon.smithy.rust.codegen.server.smithy.ServerSymbolProviders +import software.amazon.smithy.rust.codegen.server.smithy.customizations.SmithyValidationExceptionConversionGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderGenerator // These are the settings we default to if the user does not override them in their `smithy-build.json`. -val ServerTestSymbolVisitorConfig = SymbolVisitorConfig( +val ServerTestRustSymbolProviderConfig = RustSymbolProviderConfig( runtimeConfig = TestRuntimeConfig, renameExceptions = false, nullabilityCheckMode = NullableIndex.CheckMode.SERVER, + moduleProvider = ServerModuleProvider, ) private fun testServiceShapeFor(model: Model) = @@ -45,15 +48,16 @@ fun serverTestSymbolProviders( settings: ServerRustSettings? = null, ) = ServerSymbolProviders.from( + serverTestRustSettings(), model, serviceShape ?: testServiceShapeFor(model), - ServerTestSymbolVisitorConfig, + ServerTestRustSymbolProviderConfig, ( settings ?: serverTestRustSettings( (serviceShape ?: testServiceShapeFor(model)).id, ) ).codegenConfig.publicConstrainedTypes, - RustCodegenServerPlugin::baseSymbolProvider, + RustServerCodegenPlugin::baseSymbolProvider, ) fun serverTestRustSettings( @@ -94,11 +98,12 @@ fun serverTestCodegenContext( ?: ServiceShape.builder().version("test").id("test#Service").build() val protocol = protocolShapeId ?: ShapeId.from("test#Protocol") val serverSymbolProviders = ServerSymbolProviders.from( + settings, model, service, - ServerTestSymbolVisitorConfig, + ServerTestRustSymbolProviderConfig, settings.codegenConfig.publicConstrainedTypes, - RustCodegenServerPlugin::baseSymbolProvider, + RustServerCodegenPlugin::baseSymbolProvider, ) return ServerCodegenContext( @@ -117,14 +122,14 @@ fun serverTestCodegenContext( /** * In tests, we frequently need to generate a struct, a builder, and an impl block to access said builder. */ -fun StructureShape.serverRenderWithModelBuilder(model: Model, symbolProvider: RustSymbolProvider, writer: RustWriter) { - StructureGenerator(model, symbolProvider, writer, this).render(CodegenTarget.SERVER) +fun StructureShape.serverRenderWithModelBuilder(rustCrate: RustCrate, model: Model, symbolProvider: RustSymbolProvider, writer: RustWriter) { + StructureGenerator(model, symbolProvider, writer, this, emptyList()).render() val serverCodegenContext = serverTestCodegenContext(model) // Note that this always uses `ServerBuilderGenerator` and _not_ `ServerBuilderGeneratorWithoutPublicConstrainedTypes`, // regardless of the `publicConstrainedTypes` setting. - val modelBuilder = ServerBuilderGenerator(serverCodegenContext, this) - modelBuilder.render(writer) - writer.implBlock(this, symbolProvider) { + val modelBuilder = ServerBuilderGenerator(serverCodegenContext, this, SmithyValidationExceptionConversionGenerator(serverCodegenContext)) + modelBuilder.render(rustCrate, writer) + writer.implBlock(symbolProvider.toSymbol(this)) { modelBuilder.renderConvenienceMethod(this) } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/traits/ConstraintViolationRustBoxTrait.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/traits/ConstraintViolationRustBoxTrait.kt new file mode 100644 index 00000000000..9aee2b884ef --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/traits/ConstraintViolationRustBoxTrait.kt @@ -0,0 +1,25 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.traits + +import software.amazon.smithy.model.node.Node +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.traits.Trait + +/** + * This shape is analogous to [software.amazon.smithy.rust.codegen.core.smithy.traits.RustBoxTrait], but for the + * constraint violation graph. The sets of shapes we tag are different, and they are interpreted by the code generator + * differently, so we need a separate tag. + * + * This is used to handle recursive constraint violations. + * See [software.amazon.smithy.rust.codegen.server.smithy.transformers.RecursiveConstraintViolationBoxer]. + */ +class ConstraintViolationRustBoxTrait : Trait { + val ID = ShapeId.from("software.amazon.smithy.rust.codegen.smithy.rust.synthetic#constraintViolationBox") + override fun toNode(): Node = Node.objectNode() + + override fun toShapeId(): ShapeId = ID +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/traits/SyntheticStructureFromConstrainedMemberTrait.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/traits/SyntheticStructureFromConstrainedMemberTrait.kt new file mode 100644 index 00000000000..de5890c87ad --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/traits/SyntheticStructureFromConstrainedMemberTrait.kt @@ -0,0 +1,21 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.traits + +import software.amazon.smithy.model.node.Node +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.traits.AnnotationTrait + +/** + * Trait applied to an overridden shape indicating the member of this new shape type + */ +class SyntheticStructureFromConstrainedMemberTrait(val container: Shape, val member: MemberShape) : AnnotationTrait(SyntheticStructureFromConstrainedMemberTrait.ID, Node.objectNode()) { + companion object { + val ID: ShapeId = ShapeId.from("smithy.api.internal#overriddenMember") + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/AttachValidationExceptionToConstrainedOperationInputsInAllowList.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/AttachValidationExceptionToConstrainedOperationInputsInAllowList.kt index 02e1d4be643..68840bde201 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/AttachValidationExceptionToConstrainedOperationInputsInAllowList.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/AttachValidationExceptionToConstrainedOperationInputsInAllowList.kt @@ -13,6 +13,7 @@ import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.transform.ModelTransformer import software.amazon.smithy.rust.codegen.core.smithy.DirectedWalker import software.amazon.smithy.rust.codegen.core.util.inputShape +import software.amazon.smithy.rust.codegen.server.smithy.customizations.SmithyValidationExceptionConversionGenerator import software.amazon.smithy.rust.codegen.server.smithy.hasConstraintTrait /** @@ -60,11 +61,11 @@ object AttachValidationExceptionToConstrainedOperationInputsInAllowList { walker.walkShapes(operationShape.inputShape(model)) .any { it is SetShape || it is EnumShape || it.hasConstraintTrait() } } - .filter { !it.errors.contains(ShapeId.from("smithy.framework#ValidationException")) } + .filter { !it.errors.contains(SmithyValidationExceptionConversionGenerator.SHAPE_ID) } return ModelTransformer.create().mapShapes(model) { shape -> if (shape is OperationShape && operationsWithConstrainedInputWithoutValidationException.contains(shape)) { - shape.toBuilder().addError("smithy.framework#ValidationException").build() + shape.toBuilder().addError(SmithyValidationExceptionConversionGenerator.SHAPE_ID).build() } else { shape } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ConstrainedMemberTransform.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ConstrainedMemberTransform.kt new file mode 100644 index 00000000000..3ab99ce64db --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/ConstrainedMemberTransform.kt @@ -0,0 +1,226 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.transformers + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.AbstractShapeBuilder +import software.amazon.smithy.model.shapes.ListShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.RequiredTrait +import software.amazon.smithy.model.traits.Trait +import software.amazon.smithy.model.transform.ModelTransformer +import software.amazon.smithy.rust.codegen.core.smithy.DirectedWalker +import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait +import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait +import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE +import software.amazon.smithy.rust.codegen.core.util.orNull +import software.amazon.smithy.rust.codegen.server.smithy.allConstraintTraits +import software.amazon.smithy.rust.codegen.server.smithy.traits.SyntheticStructureFromConstrainedMemberTrait +import software.amazon.smithy.utils.ToSmithyBuilder +import java.lang.IllegalStateException +import java.util.* + +/** + * Transforms all member shapes that have constraints on them into equivalent non-constrained + * member shapes targeting synthetic constrained structure shapes with the member's constraints. + * + * E.g.: + * ``` + * structure A { + * @length(min: 1, max: 69) + * string: ConstrainedString + * } + * + * @length(min: 2, max: 10) + * @pattern("^[A-Za-z]+$") + * string ConstrainedString + * ``` + * + * to + * + * ``` + * structure A { + * string: OverriddenConstrainedString + * } + * + * @length(min: 1, max: 69) + * @pattern("^[A-Za-z]+$") + * OverriddenConstrainedString + * + * @length(min: 2, max: 10) + * @pattern("^[A-Za-z]+$") + * string ConstrainedString + * ``` + */ +object ConstrainedMemberTransform { + private data class MemberShapeTransformation( + val newShape: Shape, + val memberToChange: MemberShape, + val traitsToKeep: List, + ) + + private val memberConstraintTraitsToOverride = allConstraintTraits - RequiredTrait::class.java + + private fun Shape.hasMemberConstraintTrait() = + memberConstraintTraitsToOverride.any(this::hasTrait) + + fun transform(model: Model): Model { + val additionalNames = HashSet() + val walker = DirectedWalker(model) + + // Find all synthetic input / output structures that have been added by + // the OperationNormalizer, get constrained members out of those structures, + // convert them into non-constrained members and then pass them to the transformer. + // The transformer will add new shapes, and will replace existing member shapes' target + // with the newly added shapes. + val transformations = model.operationShapes + .flatMap { listOfNotNull(it.input.orNull(), it.output.orNull()) + it.errors } + .mapNotNull { model.expectShape(it).asStructureShape().orElse(null) } + .filter { it.hasTrait(SyntheticInputTrait.ID) || it.hasTrait(SyntheticOutputTrait.ID) } + .flatMap { walker.walkShapes(it) } + .filter { it is StructureShape || it is ListShape || it is UnionShape || it is MapShape } + .flatMap { it.constrainedMembers() } + .mapNotNull { + val transformation = it.makeNonConstrained(model, additionalNames) + // Keep record of new names that have been generated to ensure none of them regenerated. + additionalNames.add(transformation.newShape.id) + + transformation + } + + return applyTransformations(model, transformations) + } + + /*** + * Returns a Model that has all the transformations applied on the original model. + */ + private fun applyTransformations( + model: Model, + transformations: List, + ): Model { + val modelBuilder = model.toBuilder() + + val memberShapesToReplace = transformations.map { + // Add the new shape to the model. + modelBuilder.addShape(it.newShape) + + it.memberToChange.toBuilder() + .target(it.newShape.id) + .traits(it.traitsToKeep) + .build() + } + + // Change all original constrained member shapes with the new standalone shapes. + return ModelTransformer.create() + .replaceShapes(modelBuilder.build(), memberShapesToReplace) + } + + /** + * Returns a list of members that have constraint traits applied to them + */ + private fun Shape.constrainedMembers(): List = + this.allMembers.values.filter { + it.hasMemberConstraintTrait() + } + + /** + * Returns the unique (within the model) shape ID of the new shape + */ + private fun overriddenShapeId( + model: Model, + additionalNames: Set, + memberShape: ShapeId, + ): ShapeId { + val structName = memberShape.name + val memberName = memberShape.member.orElse(null) + .replaceFirstChar { if (it.isLowerCase()) it.titlecase(Locale.getDefault()) else it.toString() } + + fun makeStructName(suffix: String = "") = + ShapeId.from("${memberShape.namespace}#${structName}${memberName}$suffix") + + fun structNameIsUnique(newName: ShapeId) = + model.getShape(newName).isEmpty && !additionalNames.contains(newName) + + fun generateUniqueName(): ShapeId { + // Ensure the name does not already exist in the model, else make it unique + // by appending a new number as the suffix. + (0..100).forEach { + val extractedStructName = if (it == 0) makeStructName("") else makeStructName("$it") + if (structNameIsUnique(extractedStructName)) { + return extractedStructName + } + } + + throw IllegalStateException("A unique name for the overridden structure type could not be generated") + } + + return generateUniqueName() + } + + /** + * Returns the transformation that would be required to turn the given member shape + * into a non-constrained member shape. + */ + private fun MemberShape.makeNonConstrained( + model: Model, + additionalNames: MutableSet, + ): MemberShapeTransformation { + val (memberConstraintTraits, otherTraits) = this.allTraits.values + .partition { + memberConstraintTraitsToOverride.contains(it.javaClass) + } + + check(memberConstraintTraits.isNotEmpty()) { + "There must at least be one member constraint on the shape" + } + + // Build a new shape similar to the target of the constrained member shape. It should + // have all of the original constraints that have not been overridden, and the ones + // that this member shape overrides. + val targetShape = model.expectShape(this.target) + if (targetShape !is ToSmithyBuilder<*>) { + UNREACHABLE("Member target shapes will always be buildable") + } + + return when (val builder = targetShape.toBuilder()) { + is AbstractShapeBuilder<*, *> -> { + // Use the target builder to create a new standalone shape that would + // be added to the model later on. Keep all existing traits on the target + // but replace the ones that are overridden on the member shape. + val nonOverriddenConstraintTraits = + builder.allTraits.values.filter { existingTrait -> + memberConstraintTraits.none { it.toShapeId() == existingTrait.toShapeId() } + } + + // Add a synthetic constraint on all new shapes being defined, that would link + // the new shape to the root structure from which it is reachable. + val syntheticTrait = + SyntheticStructureFromConstrainedMemberTrait(model.expectShape(this.container), this) + + // Combine target traits, overridden traits and the synthetic trait + val newTraits = + nonOverriddenConstraintTraits + memberConstraintTraits + syntheticTrait + + // Create a new unique standalone shape that will be added to the model later on + val shapeId = overriddenShapeId(model, additionalNames, this.id) + val standaloneShape = builder.id(shapeId) + .traits(newTraits) + .build() + + // Since the new shape has not been added to the model as yet, the current + // memberShape's target cannot be changed to the new shape. + MemberShapeTransformation(standaloneShape, this, otherTraits) + } + + else -> UNREACHABLE("Constraint traits cannot to applied on ${this.id}") + } + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/RecursiveConstraintViolationBoxer.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/RecursiveConstraintViolationBoxer.kt new file mode 100644 index 00000000000..d2e41ead36a --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/RecursiveConstraintViolationBoxer.kt @@ -0,0 +1,78 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.transformers + +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.server.smithy.traits.ConstraintViolationRustBoxTrait + +object RecursiveConstraintViolationBoxer { + /** + * Transform a model which may contain recursive shapes into a model annotated with [ConstraintViolationRustBoxTrait]. + * + * See [RecursiveShapeBoxer] for how the tagging algorithm works. + * + * The constraint violation graph needs to box types in recursive paths more often. Since we don't collect + * constraint violations (yet, see [0]), the constraint violation graph never holds `Vec`s or `HashMap`s, + * only simple types. Indeed, the following simple recursive model: + * + * ```smithy + * union Recursive { + * list: List + * } + * + * @length(min: 69) + * list List { + * member: Recursive + * } + * ``` + * + * has a cycle that goes through a list shape, so no shapes in it need boxing in the regular shape graph. However, + * the constraint violation graph is infinitely recursive if we don't introduce boxing somewhere: + * + * ```rust + * pub mod model { + * pub mod list { + * pub enum ConstraintViolation { + * Length(usize), + * Member( + * usize, + * crate::model::recursive::ConstraintViolation, + * ), + * } + * } + * + * pub mod recursive { + * pub enum ConstraintViolation { + * List(crate::model::list::ConstraintViolation), + * } + * } + * } + * ``` + * + * So what we do to fix this is to configure the `RecursiveShapeBoxer` model transform so that the "cycles through + * lists and maps introduce indirection" assumption can be lifted. This allows this model transform to tag member + * shapes along recursive paths with a new trait, `ConstraintViolationRustBoxTrait`, that the constraint violation + * type generation then utilizes to ensure that no infinitely recursive constraint violation types get generated. + * Places where constraint violations are handled (like where unconstrained types are converted to constrained + * types) must account for the scenario where they now are or need to be boxed. + * + * [0] https://github.com/awslabs/smithy-rs/pull/2040 + */ + fun transform(model: Model): Model = RecursiveShapeBoxer( + containsIndirectionPredicate = ::constraintViolationLoopContainsIndirection, + boxShapeFn = ::addConstraintViolationRustBoxTrait, + ).transform(model) + + private fun constraintViolationLoopContainsIndirection(loop: Collection): Boolean = + loop.find { it.hasTrait() } != null + + private fun addConstraintViolationRustBoxTrait(memberShape: MemberShape): MemberShape = + memberShape.toBuilder().addTrait(ConstraintViolationRustBoxTrait()).build() +} diff --git a/codegen-server/src/main/resources/META-INF/services/software.amazon.smithy.build.SmithyBuildPlugin b/codegen-server/src/main/resources/META-INF/services/software.amazon.smithy.build.SmithyBuildPlugin index 00f891cb228..392b1d43920 100644 --- a/codegen-server/src/main/resources/META-INF/services/software.amazon.smithy.build.SmithyBuildPlugin +++ b/codegen-server/src/main/resources/META-INF/services/software.amazon.smithy.build.SmithyBuildPlugin @@ -2,4 +2,4 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 # -software.amazon.smithy.rust.codegen.server.smithy.RustCodegenServerPlugin +software.amazon.smithy.rust.codegen.server.smithy.RustServerCodegenPlugin diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProviderTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProviderTest.kt index d96c761fa8c..21d7e5c48f6 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProviderTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstrainedShapeSymbolProviderTest.kt @@ -90,7 +90,7 @@ class ConstrainedShapeSymbolProviderTest { private val model = baseModelString.asSmithyModel() private val serviceShape = model.lookup("test#TestService") private val symbolProvider = serverTestSymbolProvider(model, serviceShape) - private val constrainedShapeSymbolProvider = ConstrainedShapeSymbolProvider(symbolProvider, model, serviceShape) + private val constrainedShapeSymbolProvider = ConstrainedShapeSymbolProvider(symbolProvider, serviceShape, true) companion object { @JvmStatic diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsMemberShapeTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsMemberShapeTest.kt new file mode 100644 index 00000000000..952140d7262 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsMemberShapeTest.kt @@ -0,0 +1,499 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import io.kotest.matchers.shouldBe +import io.kotest.matchers.shouldNotBe +import org.junit.jupiter.api.Test +import software.amazon.smithy.aws.traits.DataTrait +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.SourceLocation +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.traits.RequiredTrait +import software.amazon.smithy.model.traits.Trait +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeCrateLocation +import software.amazon.smithy.rust.codegen.core.smithy.RustCrate +import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.generatePluginContext +import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import software.amazon.smithy.rust.codegen.core.util.runCommand +import software.amazon.smithy.rust.codegen.core.util.toPascalCase +import software.amazon.smithy.rust.codegen.core.util.toSnakeCase +import software.amazon.smithy.rust.codegen.server.smithy.customizations.CustomValidationExceptionWithReasonDecorator +import software.amazon.smithy.rust.codegen.server.smithy.customizations.ServerRequiredCustomizations +import software.amazon.smithy.rust.codegen.server.smithy.customizations.SmithyValidationExceptionDecorator +import software.amazon.smithy.rust.codegen.server.smithy.customize.CombinedServerCodegenDecorator +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.transformers.ConstrainedMemberTransform +import java.io.File +import java.nio.file.Path + +class ConstraintsMemberShapeTest { + private val outputModelOnly = """ + namespace constrainedMemberShape + + use aws.protocols#restJson1 + use aws.api#data + + @restJson1 + service ConstrainedService { + operations: [OperationUsingGet] + } + + @http(uri: "/anOperation", method: "GET") + operation OperationUsingGet { + output : OperationUsingGetOutput + } + structure OperationUsingGetOutput { + plainLong : Long + plainInteger : Integer + plainShort : Short + plainByte : Byte + plainFloat: Float + plainString: String + + @range(min: 1, max:100) + constrainedLong : Long + @range(min: 2, max:100) + constrainedInteger : Integer + @range(min: 3, max:100) + constrainedShort : Short + @range(min: 4, max:100) + constrainedByte : Byte + @length(max: 100) + constrainedString: String + + @required + @range(min: 5, max:100) + requiredConstrainedLong : Long + @required + @range(min: 6, max:100) + requiredConstrainedInteger : Integer + @required + @range(min: 7, max:100) + requiredConstrainedShort : Short + @required + @range(min: 8, max:100) + requiredConstrainedByte : Byte + @required + @length(max: 101) + requiredConstrainedString: String + + patternString : PatternString + + @data("content") + @pattern("^[g-m]+${'$'}") + constrainedPatternString : PatternString + + plainStringList : PlainStringList + patternStringList : PatternStringList + patternStringListOverride : PatternStringListOverride + + plainStructField : PlainStructWithInteger + structWithConstrainedMember : StructWithConstrainedMember + structWithConstrainedMemberOverride : StructWithConstrainedMemberOverride + + patternUnion: PatternUnion + patternUnionOverride: PatternUnionOverride + patternMap : PatternMap + patternMapOverride: PatternMapOverride + } + list ListWithIntegerMemberStruct { + member: PlainStructWithInteger + } + structure PlainStructWithInteger { + lat : Integer + long : Integer + } + structure StructWithConstrainedMember { + @range(min: 100) + lat : Integer + long : Integer + } + structure StructWithConstrainedMemberOverride { + @range(min: 10) + lat : RangedInteger + @range(min: 10, max:100) + long : RangedInteger + } + list PlainStringList { + member: String + } + list PatternStringList { + member: PatternString + } + list PatternStringListOverride { + @pattern("^[g-m]+${'$'}") + member: PatternString + } + map PatternMap { + key: PatternString, + value: PatternString + } + map PatternMapOverride { + @pattern("^[g-m]+${'$'}") + key: PatternString, + @pattern("^[g-m]+${'$'}") + value: PatternString + } + union PatternUnion { + first: PatternString, + second: PatternString + } + union PatternUnionOverride { + @pattern("^[g-m]+${'$'}") + first: PatternString, + @pattern("^[g-m]+${'$'}") + second: PatternString + } + @pattern("^[a-m]+${'$'}") + string PatternString + @range(min: 0, max:1000) + integer RangedInteger + """.asSmithyModel() + + private fun loadModel(model: Model): Model = + ConstrainedMemberTransform.transform(OperationNormalizer.transform(model)) + + @Test + fun `non constrained fields should not be changed`() { + val transformedModel = loadModel(outputModelOnly) + + fun checkFieldTargetRemainsSame(fieldName: String) { + checkMemberShapeIsSame( + transformedModel, + outputModelOnly, + "constrainedMemberShape.synthetic#OperationUsingGetOutput\$$fieldName", + "constrainedMemberShape#OperationUsingGetOutput\$$fieldName", + ) { + "OperationUsingGetOutput$fieldName has changed whereas it is not constrained and should have remained same" + } + } + + setOf( + "plainInteger", + "plainLong", + "plainByte", + "plainShort", + "plainFloat", + "patternString", + "plainStringList", + "patternStringList", + "patternStringListOverride", + "plainStructField", + "structWithConstrainedMember", + "structWithConstrainedMemberOverride", + "patternUnion", + "patternUnionOverride", + "patternMap", + "patternMapOverride", + ).forEach(::checkFieldTargetRemainsSame) + + checkMemberShapeIsSame( + transformedModel, + outputModelOnly, + "constrainedMemberShape#StructWithConstrainedMember\$long", + "constrainedMemberShape#StructWithConstrainedMember\$long", + ) + } + + @Test + fun `constrained members should have a different target now`() { + val transformedModel = loadModel(outputModelOnly) + checkMemberShapeChanged( + transformedModel, + outputModelOnly, + "constrainedMemberShape#PatternStringListOverride\$member", + "constrainedMemberShape#PatternStringListOverride\$member", + ) + + fun checkSyntheticFieldTargetChanged(fieldName: String) { + checkMemberShapeChanged( + transformedModel, + outputModelOnly, + "constrainedMemberShape.synthetic#OperationUsingGetOutput\$$fieldName", + "constrainedMemberShape#OperationUsingGetOutput\$$fieldName", + ) { + "constrained member $fieldName should have been changed into a new type." + } + } + + fun checkFieldTargetChanged(memberNameWithContainer: String) { + checkMemberShapeChanged( + transformedModel, + outputModelOnly, + "constrainedMemberShape#$memberNameWithContainer", + "constrainedMemberShape#$memberNameWithContainer", + ) { + "constrained member $memberNameWithContainer should have been changed into a new type." + } + } + + setOf( + "constrainedLong", + "constrainedByte", + "constrainedShort", + "constrainedInteger", + "constrainedString", + "requiredConstrainedString", + "requiredConstrainedLong", + "requiredConstrainedByte", + "requiredConstrainedInteger", + "requiredConstrainedShort", + "constrainedPatternString", + ).forEach(::checkSyntheticFieldTargetChanged) + + setOf( + "StructWithConstrainedMember\$lat", + "PatternMapOverride\$key", + "PatternMapOverride\$value", + "PatternStringListOverride\$member", + ).forEach(::checkFieldTargetChanged) + } + + @Test + fun `extra trait on a constrained member should remain on it`() { + val transformedModel = loadModel(outputModelOnly) + checkShapeHasTrait( + transformedModel, + outputModelOnly, + "constrainedMemberShape.synthetic#OperationUsingGetOutput\$constrainedPatternString", + "constrainedMemberShape#OperationUsingGetOutput\$constrainedPatternString", + DataTrait("content", SourceLocation.NONE), + ) + } + + @Test + fun `required remains on constrained member shape`() { + val transformedModel = loadModel(outputModelOnly) + checkShapeHasTrait( + transformedModel, + outputModelOnly, + "constrainedMemberShape.synthetic#OperationUsingGetOutput\$requiredConstrainedString", + "constrainedMemberShape#OperationUsingGetOutput\$requiredConstrainedString", + RequiredTrait(), + ) + } + + private fun runServerCodeGen(model: Model, dirToUse: File? = null, writable: Writable): Path { + val runtimeConfig = + RuntimeConfig(runtimeCrateLocation = RuntimeCrateLocation.Path(File("../rust-runtime").absolutePath)) + + val (context, dir) = generatePluginContext( + model, + runtimeConfig = runtimeConfig, + overrideTestDir = dirToUse, + ) + val codegenDecorator = + CombinedServerCodegenDecorator.fromClasspath( + context, + ServerRequiredCustomizations(), + SmithyValidationExceptionDecorator(), + CustomValidationExceptionWithReasonDecorator(), + ) + + ServerCodegenVisitor(context, codegenDecorator) + .execute() + + val codegenContext = serverTestCodegenContext(model) + val settings = ServerRustSettings.from(context.model, context.settings) + val rustCrate = RustCrate( + context.fileManifest, + codegenContext.symbolProvider, + settings.codegenConfig, + ) + + // We cannot write to the lib anymore as the RustWriter overwrites it, so writing code directly to check.rs + // and then adding a `mod check;` to the lib.rs + rustCrate.withModule(RustModule.public("check")) { + writable(this) + File("$dir/src/check.rs").writeText(toString()) + } + + val lib = File("$dir/src/lib.rs") + val libContents = lib.readText() + "\nmod check;" + lib.writeText(libContents) + + return dir + } + + @Test + fun `generate code and check member constrained shapes are in the right modules`() { + val dir = runServerCodeGen(outputModelOnly) { + fun RustWriter.testTypeExistsInBuilderModule(typeName: String) { + unitTest( + "builder_module_has_${typeName.toSnakeCase()}", + """ + #[allow(unused_imports)] use crate::output::operation_using_get_output::$typeName; + """, + ) + } + + // All directly constrained members of the output structure should be in the builder module + setOf( + "ConstrainedLong", + "ConstrainedByte", + "ConstrainedShort", + "ConstrainedInteger", + "ConstrainedString", + "RequiredConstrainedString", + "RequiredConstrainedLong", + "RequiredConstrainedByte", + "RequiredConstrainedInteger", + "RequiredConstrainedShort", + "ConstrainedPatternString", + ).forEach(::testTypeExistsInBuilderModule) + + fun Set.generateUseStatements(prefix: String) = + this.joinToString(separator = "\n") { + "#[allow(unused_imports)] use $prefix::$it;" + } + + unitTest( + "map_overridden_enum", + setOf( + "Value", + "value::ConstraintViolation as ValueCV", + "Key", + "key::ConstraintViolation as KeyCV", + ).generateUseStatements("crate::model::pattern_map_override"), + ) + + unitTest( + "union_overridden_enum", + setOf( + "First", + "first::ConstraintViolation as FirstCV", + "Second", + "second::ConstraintViolation as SecondCV", + ).generateUseStatements("crate::model::pattern_union_override"), + ) + + unitTest( + "list_overridden_enum", + setOf( + "Member", + "member::ConstraintViolation as MemberCV", + ).generateUseStatements("crate::model::pattern_string_list_override"), + ) + } + + val env = mapOf("RUSTFLAGS" to "-A dead_code") + "cargo test".runCommand(dir, env) + } + + /** + * Checks that the given member shape: + * 1. Has been changed to a new shape + * 2. New shape has the same type as the original shape's target e.g. float Centigrade, + * float newType + */ + private fun checkMemberShapeChanged( + model: Model, + baseModel: Model, + member: String, + orgModelMember: String, + lazyMessage: () -> Any = ::defaultError, + ) { + val memberId = ShapeId.from(member) + assert(model.getShape(memberId).isPresent) { lazyMessage } + + val memberShape = model.expectShape(memberId).asMemberShape().get() + val memberTargetShape = model.expectShape(memberShape.target) + val orgMemberId = ShapeId.from(orgModelMember) + assert(baseModel.getShape(orgMemberId).isPresent) { lazyMessage } + + val beforeTransformMemberShape = baseModel.expectShape(orgMemberId).asMemberShape().get() + val originalTargetShape = model.expectShape(beforeTransformMemberShape.target) + + val extractableConstraintTraits = allConstraintTraits - RequiredTrait::class.java + + // New member shape should not have the overridden constraints on it + assert(!extractableConstraintTraits.any(memberShape::hasTrait)) { lazyMessage } + + // Target shape has to be changed to a new shape + memberTargetShape.id.name shouldNotBe beforeTransformMemberShape.target.name + + // Target shape's name should match the expected name + val expectedName = memberShape.container.name.substringAfter('#') + + memberShape.memberName.substringBefore('#').toPascalCase() + + memberTargetShape.id.name shouldBe expectedName + + // New shape should have all the constraint traits that were on the member shape, + // and it should also have the traits that the target type contains. + val beforeTransformConstraintTraits = + beforeTransformMemberShape.allTraits.values.filter { allConstraintTraits.contains(it.javaClass) }.toSet() + val newShapeConstrainedTraits = + memberTargetShape.allTraits.values.filter { allConstraintTraits.contains(it.javaClass) }.toSet() + + val leftOutConstraintTrait = beforeTransformConstraintTraits - newShapeConstrainedTraits + assert( + leftOutConstraintTrait.isEmpty() || leftOutConstraintTrait.all { + it.toShapeId() == RequiredTrait.ID + }, + ) { lazyMessage } + + // In case the target shape has some more constraints, which the member shape did not override, + // then those still need to apply on the new standalone shape that has been defined. + val leftOverTraits = originalTargetShape.allTraits.values + .filter { beforeOverridingTrait -> + beforeTransformConstraintTraits.none { + beforeOverridingTrait.toShapeId() == it.toShapeId() + } + } + val allNewShapeTraits = memberTargetShape.allTraits.values.toList() + assert((leftOverTraits + newShapeConstrainedTraits).all { it in allNewShapeTraits }) { lazyMessage } + } + + private fun defaultError() = "test failed" + + /** + * Checks that the given shape has not changed in the transformed model and is exactly + * the same as the original model + */ + private fun checkMemberShapeIsSame( + model: Model, + baseModel: Model, + member: String, + orgModelMember: String, + lazyMessage: () -> Any = ::defaultError, + ) { + val memberId = ShapeId.from(member) + assert(model.getShape(memberId).isPresent) { lazyMessage } + + val memberShape = model.expectShape(memberId).asMemberShape().get() + val memberTargetShape = model.expectShape(memberShape.target) + val originalShape = baseModel.expectShape(ShapeId.from(orgModelMember)).asMemberShape().get() + + // Member shape should not have any constraints on it + assert(!memberShape.hasConstraintTrait()) { lazyMessage } + // Target shape has to be same as the original shape + memberTargetShape.id shouldBe originalShape.target + } + + private fun checkShapeHasTrait( + model: Model, + orgModel: Model, + member: String, + orgModelMember: String, + trait: Trait, + ) { + val memberId = ShapeId.from(member) + val memberShape = model.expectShape(memberId).asMemberShape().get() + val orgMemberShape = orgModel.expectShape(ShapeId.from(orgModelMember)).asMemberShape().get() + + val newMemberTrait = memberShape.expectTrait(trait::class.java) + val oldMemberTrait = orgMemberShape.expectTrait(trait::class.java) + + newMemberTrait shouldBe oldMemberTrait + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt index 946027ce02b..30e5e648130 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt @@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.server.smithy import io.kotest.inspectors.forAll import io.kotest.matchers.shouldBe import org.junit.jupiter.api.Test +import software.amazon.smithy.model.shapes.BooleanShape import software.amazon.smithy.model.shapes.ListShape import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.MemberShape @@ -81,7 +82,12 @@ class ConstraintsTest { @length(min: 1, max: 5) mapAPrecedence: MapA } - """.asSmithyModel() + + structure StructWithInnerDefault { + @default(false) + inner: PrimitiveBoolean + } + """.asSmithyModel(smithyVersion = "2") private val symbolProvider = serverTestSymbolProvider(model) private val testInputOutput = model.lookup("test#TestInputOutput") @@ -93,6 +99,8 @@ class ConstraintsTest { private val structA = model.lookup("test#StructureA") private val structAInt = model.lookup("test#StructureA\$int") private val structAString = model.lookup("test#StructureA\$string") + private val structWithInnerDefault = model.lookup("test#StructWithInnerDefault") + private val primitiveBoolean = model.lookup("smithy.api#PrimitiveBoolean") @Test fun `it should detect supported constrained traits as constrained`() { @@ -119,4 +127,10 @@ class ConstraintsTest { mapB.canReachConstrainedShape(model, symbolProvider) shouldBe true recursiveShape.canReachConstrainedShape(model, symbolProvider) shouldBe true } + + @Test + fun `it should not consider shapes with the default trait as constrained`() { + structWithInnerDefault.canReachConstrainedShape(model, symbolProvider) shouldBe false + primitiveBoolean.isDirectlyConstrained(symbolProvider) shouldBe false + } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/DeriveEqAndHashSymbolMetadataProviderTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/DeriveEqAndHashSymbolMetadataProviderTest.kt index 55c0c6e680b..5f2dea66e2d 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/DeriveEqAndHashSymbolMetadataProviderTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/DeriveEqAndHashSymbolMetadataProviderTest.kt @@ -171,8 +171,8 @@ internal class DeriveEqAndHashSymbolMetadataProviderTest { """.asSmithyModel(smithyVersion = "2.0") private val serviceShape = model.lookup("test#TestService") private val deriveEqAndHashSymbolMetadataProvider = serverTestSymbolProvider(model, serviceShape) - .let { BaseSymbolMetadataProvider(it, model, additionalAttributes = listOf()) } - .let { DeriveEqAndHashSymbolMetadataProvider(it, model) } + .let { BaseSymbolMetadataProvider(it, additionalAttributes = listOf()) } + .let { DeriveEqAndHashSymbolMetadataProvider(it) } companion object { @JvmStatic diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProviderTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProviderTest.kt index f0b339a4852..8d07a5959af 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProviderTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/PubCrateConstrainedShapeSymbolProviderTest.kt @@ -55,13 +55,17 @@ class PubCrateConstrainedShapeSymbolProviderTest { @Test fun `it should crash when provided with a shape that is directly constrained`() { val constrainedStringShape = model.lookup("test#ConstrainedString") - shouldThrow { pubCrateConstrainedShapeSymbolProvider.toSymbol(constrainedStringShape) } + shouldThrow { + pubCrateConstrainedShapeSymbolProvider.toSymbol(constrainedStringShape) + } } @Test fun `it should crash when provided with a shape that is unconstrained`() { val unconstrainedStringShape = model.lookup("test#UnconstrainedString") - shouldThrow { pubCrateConstrainedShapeSymbolProvider.toSymbol(unconstrainedStringShape) } + shouldThrow { + pubCrateConstrainedShapeSymbolProvider.toSymbol(unconstrainedStringShape) + } } @Test diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RecursiveConstraintViolationsTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RecursiveConstraintViolationsTest.kt new file mode 100644 index 00000000000..7467d0d76ff --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RecursiveConstraintViolationsTest.kt @@ -0,0 +1,185 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import org.junit.jupiter.api.extension.ExtensionContext +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.ArgumentsProvider +import org.junit.jupiter.params.provider.ArgumentsSource +import software.amazon.smithy.model.Model +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest +import java.util.stream.Stream + +internal class RecursiveConstraintViolationsTest { + + data class TestCase( + /** The test name is only used in the generated report, to easily identify a failing test. **/ + val testName: String, + /** The model to generate **/ + val model: Model, + /** The shape ID of the member shape that should have the marker trait attached. **/ + val shapeIdWithConstraintViolationRustBoxTrait: String, + ) + + class RecursiveConstraintViolationsTestProvider : ArgumentsProvider { + private val baseModel = + """ + namespace com.amazonaws.recursiveconstraintviolations + + use aws.protocols#restJson1 + use smithy.framework#ValidationException + + @restJson1 + service RecursiveConstraintViolations { + operations: [ + Operation + ] + } + + @http(uri: "/operation", method: "POST") + operation Operation { + input: Recursive + output: Recursive + errors: [ValidationException] + } + """ + + private fun recursiveListModel(sparse: Boolean, listPrefix: String = ""): Pair = + """ + $baseModel + + structure Recursive { + list: ${listPrefix}List + } + + ${ if (sparse) { "@sparse" } else { "" } } + @length(min: 69) + list ${listPrefix}List { + member: Recursive + } + """.asSmithyModel() to if ("${listPrefix}List" < "Recursive") { + "com.amazonaws.recursiveconstraintviolations#${listPrefix}List\$member" + } else { + "com.amazonaws.recursiveconstraintviolations#Recursive\$list" + } + + private fun recursiveMapModel(sparse: Boolean, mapPrefix: String = ""): Pair = + """ + $baseModel + + structure Recursive { + map: ${mapPrefix}Map + } + + ${ if (sparse) { "@sparse" } else { "" } } + @length(min: 69) + map ${mapPrefix}Map { + key: String, + value: Recursive + } + """.asSmithyModel() to if ("${mapPrefix}Map" < "Recursive") { + "com.amazonaws.recursiveconstraintviolations#${mapPrefix}Map\$value" + } else { + "com.amazonaws.recursiveconstraintviolations#Recursive\$map" + } + + private fun recursiveUnionModel(unionPrefix: String = ""): Pair = + """ + $baseModel + + structure Recursive { + attributeValue: ${unionPrefix}AttributeValue + } + + // Named `${unionPrefix}AttributeValue` in honor of DynamoDB's famous `AttributeValue`. + // https://docs.rs/aws-sdk-dynamodb/latest/aws_sdk_dynamodb/model/enum.AttributeValue.html + union ${unionPrefix}AttributeValue { + set: SetAttribute + } + + @uniqueItems + list SetAttribute { + member: ${unionPrefix}AttributeValue + } + """.asSmithyModel() to + // The first loop the algorithm picks out to fix turns out to be the `list <-> union` loop: + // + // ``` + // [ + // ${unionPrefix}AttributeValue, + // ${unionPrefix}AttributeValue$set, + // SetAttribute, + // SetAttribute$member + // ] + // ``` + // + // Upon which, after fixing it, the other loop (`structure <-> list <-> union`) already contains + // indirection, so we disregard it. + // + // This is hence a good test in that it tests that `RecursiveConstraintViolationBoxer` does not + // superfluously add more indirection than strictly necessary. + // However, it is a bad test in that if the Smithy library ever returns the recursive paths in a + // different order, the (`structure <-> list <-> union`) loop might be fixed first, and this test might + // start to fail! So watch out for that. Nonetheless, `RecursiveShapeBoxer` calls out: + // + // This function MUST be deterministic (always choose the same shapes to `Box`). If it is not, that is a bug. + // + // So I think it's fair to write this test under the above assumption. + if ("${unionPrefix}AttributeValue" < "SetAttribute") { + "com.amazonaws.recursiveconstraintviolations#${unionPrefix}AttributeValue\$set" + } else { + "com.amazonaws.recursiveconstraintviolations#SetAttribute\$member" + } + + override fun provideArguments(context: ExtensionContext?): Stream { + val listModels = listOf(false, true).flatMap { isSparse -> + listOf("", "ZZZ").map { listPrefix -> + val (model, shapeIdWithConstraintViolationRustBoxTrait) = recursiveListModel(isSparse, listPrefix) + var testName = "${ if (isSparse) "sparse" else "non-sparse" } recursive list" + if (listPrefix.isNotEmpty()) { + testName += " with shape name prefix $listPrefix" + } + TestCase(testName, model, shapeIdWithConstraintViolationRustBoxTrait) + } + } + val mapModels = listOf(false, true).flatMap { isSparse -> + listOf("", "ZZZ").map { mapPrefix -> + val (model, shapeIdWithConstraintViolationRustBoxTrait) = recursiveMapModel(isSparse, mapPrefix) + var testName = "${ if (isSparse) "sparse" else "non-sparse" } recursive map" + if (mapPrefix.isNotEmpty()) { + testName += " with shape name prefix $mapPrefix" + } + TestCase(testName, model, shapeIdWithConstraintViolationRustBoxTrait) + } + } + val unionModels = listOf("", "ZZZ").map { unionPrefix -> + val (model, shapeIdWithConstraintViolationRustBoxTrait) = recursiveUnionModel(unionPrefix) + var testName = "recursive union" + if (unionPrefix.isNotEmpty()) { + testName += " with shape name prefix $unionPrefix" + } + TestCase(testName, model, shapeIdWithConstraintViolationRustBoxTrait) + } + return listOf(listModels, mapModels, unionModels) + .flatten() + .map { Arguments.of(it) }.stream() + } + } + + /** + * Ensures the models generate code that compiles. + * + * Make sure the tests in [software.amazon.smithy.rust.codegen.server.smithy.transformers.RecursiveConstraintViolationBoxerTest] + * are all passing before debugging any of these tests, since the former tests test preconditions for these. + */ + @ParameterizedTest + @ArgumentsSource(RecursiveConstraintViolationsTestProvider::class) + fun `recursive constraint violation code generation test`(testCase: TestCase) { + serverIntegrationTest(testCase.model) + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCrateInlineModuleComposingWriterTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCrateInlineModuleComposingWriterTest.kt new file mode 100644 index 00000000000..3637548d9a4 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/RustCrateInlineModuleComposingWriterTest.kt @@ -0,0 +1,271 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy + +import io.kotest.matchers.collections.shouldContain +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.Model +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Visibility +import software.amazon.smithy.rust.codegen.core.rustlang.comment +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeCrateLocation +import software.amazon.smithy.rust.codegen.core.smithy.RustCrate +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.core.testutil.generatePluginContext +import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider +import java.io.File + +class RustCrateInlineModuleComposingWriterTest { + private val rustCrate: RustCrate + private val codegenContext: ServerCodegenContext + private val model: Model = """ + ${'$'}version: "2.0" + namespace test + + use aws.api#data + use aws.protocols#restJson1 + + @title("Weather Service") + @restJson1 + service WeatherService { + operations: [MalformedPatternOverride] + } + + @suppress(["UnstableTrait"]) + @http(uri: "/MalformedPatternOverride", method: "GET") + operation MalformedPatternOverride { + output: MalformedPatternOverrideInput, + errors: [] + } + + structure MalformedPatternOverrideInput { + @pattern("^[g-m]+${'$'}") + string: PatternString, + } + + @pattern("^[a-m]+${'$'}") + string PatternString + """.trimIndent().asSmithyModel() + + init { + codegenContext = serverTestCodegenContext(model) + val runtimeConfig = + RuntimeConfig(runtimeCrateLocation = RuntimeCrateLocation.Path(File("../rust-runtime").absolutePath)) + + val (context, _) = generatePluginContext( + model, + runtimeConfig = runtimeConfig, + ) + val settings = ServerRustSettings.from(context.model, context.settings) + rustCrate = RustCrate(context.fileManifest, codegenContext.symbolProvider, settings.codegenConfig) + } + + private fun createTestInlineModule(parentModule: RustModule, moduleName: String, documentation: String? = null): RustModule.LeafModule = + RustModule.new( + moduleName, + visibility = Visibility.PUBLIC, + documentation = documentation ?: moduleName, + parent = parentModule, + inline = true, + ) + + private fun createTestOrphanInlineModule(moduleName: String): RustModule.LeafModule = + RustModule.new( + moduleName, + visibility = Visibility.PUBLIC, + documentation = moduleName, + parent = RustModule.LibRs, + inline = true, + ) + + private fun helloWorld(writer: RustWriter, moduleName: String) { + writer.rustBlock("pub fn hello_world()") { + writer.comment("Module $moduleName") + } + } + + private fun byeWorld(writer: RustWriter, moduleName: String) { + writer.rustBlock("pub fn bye_world()") { + writer.comment("Module $moduleName") + writer.rust("""println!("from inside $moduleName");""") + } + } + + @Test + fun `calling withModule multiple times returns same object on rustModule`() { + val testProject = TestWorkspace.testProject(serverTestSymbolProvider(model)) + val writers: MutableSet = mutableSetOf() + testProject.withModule(ServerRustModule.Model) { + writers.add(this) + } + testProject.withModule(ServerRustModule.Model) { + writers shouldContain this + } + } + + @Test + fun `simple inline module works`() { + val testProject = TestWorkspace.testProject(serverTestSymbolProvider(model)) + val moduleA = createTestInlineModule(ServerRustModule.Model, "a") + testProject.withModule(ServerRustModule.Model) { + testProject.getInlineModuleWriter().withInlineModule(this, moduleA) { + helloWorld(this, "a") + } + } + + testProject.getInlineModuleWriter().render() + testProject.withModule(ServerRustModule.Model) { + this.unitTest("test_a") { + rust("crate::model::a::hello_world();") + } + } + testProject.compileAndTest() + } + + @Test + fun `creating nested modules works from different rustWriters`() { + // Define the following functions in different inline modules. + // crate::model::a::hello_world(); + // crate::model::a::bye_world(); + // crate::model::b::hello_world(); + // crate::model::b::bye_world(); + // crate::model::b::c::hello_world(); + // crate::model::b::c::bye_world(); + // crate::input::e::hello_world(); + // crate::output::f::hello_world(); + // crate::output::f::g::hello_world(); + // crate::output::h::hello_world(); + // crate::output::h::i::hello_world(); + + val testProject = TestWorkspace.testProject(serverTestSymbolProvider(model)) + val modules = hashMapOf( + "a" to createTestOrphanInlineModule("a"), + "d" to createTestOrphanInlineModule("d"), + "e" to createTestOrphanInlineModule("e"), + "i" to createTestOrphanInlineModule("i"), + ) + + modules["b"] = createTestInlineModule(ServerRustModule.Model, "b") + modules["c"] = createTestInlineModule(modules["b"]!!, "c") + modules["f"] = createTestInlineModule(ServerRustModule.Output, "f") + modules["g"] = createTestInlineModule(modules["f"]!!, "g") + modules["h"] = createTestInlineModule(ServerRustModule.Output, "h") + // A different kotlin object but would still go in the right place + + testProject.withModule(ServerRustModule.Model) { + testProject.getInlineModuleWriter().withInlineModule(this, modules["a"]!!) { + helloWorld(this, "a") + } + testProject.getInlineModuleWriter().withInlineModule(this, modules["b"]!!) { + helloWorld(this, "b") + testProject.getInlineModuleWriter().withInlineModule(this, modules["c"]!!) { + byeWorld(this, "b::c") + } + } + // Writing to the same module crate::model::a second time should work. + testProject.getInlineModuleWriter().withInlineModule(this, modules["a"]!!) { + byeWorld(this, "a") + } + // Writing to model::b, when model::b and model::b::c have already been written to + // should work. + testProject.getInlineModuleWriter().withInlineModule(this, modules["b"]!!) { + byeWorld(this, "b") + } + } + + // Write directly to an inline module without specifying the immediate parent. crate::model::b::c + // should have a `hello_world` fn in it now. + testProject.withModule(ServerRustModule.Model) { + testProject.getInlineModuleWriter().withInlineModuleHierarchy(this, modules["c"]!!) { + helloWorld(this, "c") + } + } + // Write to a different top level module to confirm that works. + testProject.withModule(ServerRustModule.Input) { + testProject.getInlineModuleWriter().withInlineModuleHierarchy(this, modules["e"]!!) { + helloWorld(this, "e") + } + } + + // Create a descendent inner module crate::output::f::g and then try writing to the intermediate inner module + // that did not exist before the descendent was dded. + testProject.getInlineModuleWriter().withInlineModuleHierarchyUsingCrate(testProject, modules["f"]!!) { + testProject.getInlineModuleWriter().withInlineModuleHierarchyUsingCrate(testProject, modules["g"]!!) { + helloWorld(this, "g") + } + } + + testProject.getInlineModuleWriter().withInlineModuleHierarchyUsingCrate(testProject, modules["f"]!!) { + helloWorld(this, "f") + } + + // It should work even if the inner descendent module was added using `withInlineModuleHierarchy` and then + // code is added to the intermediate module using `withInlineModuleHierarchyUsingCrate` + testProject.withModule(ServerRustModule.Output) { + testProject.getInlineModuleWriter().withInlineModuleHierarchy(this, modules["h"]!!) { + testProject.getInlineModuleWriter().withInlineModuleHierarchy(this, modules["i"]!!) { + helloWorld(this, "i") + } + testProject.withModule(ServerRustModule.Model) { + // While writing to output::h::i, it should be able to a completely different module + testProject.getInlineModuleWriter().withInlineModuleHierarchy(this, modules["b"]!!) { + rustBlock("pub fn some_other_writer_wrote_this()") { + rust("""println!("from inside crate::model::b::some_other_writer_wrote_this");""") + } + } + } + } + } + testProject.getInlineModuleWriter().withInlineModuleHierarchyUsingCrate(testProject, modules["h"]!!) { + helloWorld(this, "h") + } + + // Render all of the code. + testProject.getInlineModuleWriter().render() + + testProject.withModule(ServerRustModule.Model) { + this.unitTest("test_a") { + rust("crate::model::a::hello_world();") + rust("crate::model::a::bye_world();") + } + this.unitTest("test_b") { + rust("crate::model::b::hello_world();") + rust("crate::model::b::bye_world();") + } + this.unitTest("test_someother_writer_wrote") { + rust("crate::model::b::some_other_writer_wrote_this();") + } + this.unitTest("test_b_c") { + rust("crate::model::b::c::hello_world();") + rust("crate::model::b::c::bye_world();") + } + this.unitTest("test_e") { + rust("crate::input::e::hello_world();") + } + this.unitTest("test_f") { + rust("crate::output::f::hello_world();") + } + this.unitTest("test_g") { + rust("crate::output::f::g::hello_world();") + } + this.unitTest("test_h") { + rust("crate::output::h::hello_world();") + } + this.unitTest("test_h_i") { + rust("crate::output::h::i::hello_world();") + } + } + testProject.compileAndTest() + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitorTest.kt index 2f23c143f43..942e5f76170 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitorTest.kt @@ -11,6 +11,7 @@ import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.generatePluginContext import software.amazon.smithy.rust.codegen.server.smithy.customizations.ServerRequiredCustomizations +import software.amazon.smithy.rust.codegen.server.smithy.customizations.SmithyValidationExceptionDecorator import software.amazon.smithy.rust.codegen.server.smithy.customize.CombinedServerCodegenDecorator import kotlin.io.path.writeText @@ -45,8 +46,12 @@ class ServerCodegenVisitorTest { """.asSmithyModel(smithyVersion = "2.0") val (ctx, testDir) = generatePluginContext(model) testDir.resolve("src/main.rs").writeText("fn main() {}") - val codegenDecorator: CombinedServerCodegenDecorator = - CombinedServerCodegenDecorator.fromClasspath(ctx, ServerRequiredCustomizations()) + val codegenDecorator = + CombinedServerCodegenDecorator.fromClasspath( + ctx, + ServerRequiredCustomizations(), + SmithyValidationExceptionDecorator(), + ) val visitor = ServerCodegenVisitor(ctx, codegenDecorator) val baselineModel = visitor.baselineTransformInternalTest(model) baselineModel.getShapesWithTrait(ShapeId.from("smithy.api#mixin")).isEmpty() shouldBe true diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraintsAreNotUsedTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraintsAreNotUsedTest.kt index 83ea701a975..68b88a7978d 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraintsAreNotUsedTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ValidateUnsupportedConstraintsAreNotUsedTest.kt @@ -18,6 +18,7 @@ import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.rust.codegen.core.smithy.transformers.EventStreamNormalizer import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.customizations.SmithyValidationExceptionConversionGenerator import java.util.logging.Level internal class ValidateUnsupportedConstraintsAreNotUsedTest { @@ -26,7 +27,6 @@ internal class ValidateUnsupportedConstraintsAreNotUsedTest { namespace test service TestService { - version: "123", operations: [TestOperation] } @@ -53,7 +53,11 @@ internal class ValidateUnsupportedConstraintsAreNotUsedTest { } """.asSmithyModel() val service = model.lookup("test#TestService") - val validationResult = validateOperationsWithConstrainedInputHaveValidationExceptionAttached(model, service) + val validationResult = validateOperationsWithConstrainedInputHaveValidationExceptionAttached( + model, + service, + SmithyValidationExceptionConversionGenerator.SHAPE_ID, + ) validationResult.messages shouldHaveSize 1 @@ -71,39 +75,6 @@ internal class ValidateUnsupportedConstraintsAreNotUsedTest { """.trimIndent() } - @Test - fun `it should detect when unsupported constraint traits on member shapes are used`() { - val model = - """ - $baseModel - - structure TestInputOutput { - @length(min: 1, max: 69) - lengthString: String - } - """.asSmithyModel() - val validationResult = validateModel(model) - - validationResult.messages shouldHaveSize 1 - validationResult.messages[0].message shouldContain "The member shape `test#TestInputOutput\$lengthString` has the constraint trait `smithy.api#length` attached" - } - - @Test - fun `it should not detect when the required trait on a member shape is used`() { - val model = - """ - $baseModel - - structure TestInputOutput { - @required - string: String - } - """.asSmithyModel() - val validationResult = validateModel(model) - - validationResult.messages shouldHaveSize 0 - } - private val constraintTraitOnStreamingBlobShapeModel = """ $baseModel @@ -181,6 +152,49 @@ internal class ValidateUnsupportedConstraintsAreNotUsedTest { } } + private val mapShapeReachableFromUniqueItemsListShapeModel = + """ + $baseModel + + structure TestInputOutput { + uniqueItemsList: UniqueItemsList + } + + @uniqueItems + list UniqueItemsList { + member: Map + } + + map Map { + key: String + value: String + } + """.asSmithyModel() + + @Test + fun `it should detect when a map shape is reachable from a uniqueItems list shape`() { + val validationResult = validateModel(mapShapeReachableFromUniqueItemsListShapeModel) + + validationResult.messages shouldHaveSize 1 + validationResult.shouldAbort shouldBe true + validationResult.messages[0].message shouldContain( + """ + The map shape `test#Map` is reachable from the list shape `test#UniqueItemsList`, which has the + `@uniqueItems` trait attached. + """.trimIndent().replace("\n", " ") + ) + } + + @Test + fun `it should abort when a map shape is reachable from a uniqueItems list shape, despite opting into ignoreUnsupportedConstraintTraits`() { + val validationResult = validateModel( + mapShapeReachableFromUniqueItemsListShapeModel, + ServerCodegenConfig().copy(ignoreUnsupportedConstraints = true), + ) + + validationResult.shouldAbort shouldBe true + } + @Test fun `it should abort when constraint traits in event streams are used, despite opting into ignoreUnsupportedConstraintTraits`() { val validationResult = validateModel( diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/CustomValidationExceptionWithReasonDecoratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/CustomValidationExceptionWithReasonDecoratorTest.kt new file mode 100644 index 00000000000..31b46023819 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/CustomValidationExceptionWithReasonDecoratorTest.kt @@ -0,0 +1,109 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.customizations + +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.node.Node +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.transform.ModelTransformer +import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest +import java.io.File +import kotlin.streams.toList + +fun swapOutSmithyValidationExceptionForCustomOne(model: Model): Model { + val customValidationExceptionModel = + """ + namespace com.amazonaws.constraints + + enum ValidationExceptionFieldReason { + LENGTH_NOT_VALID = "LengthNotValid" + PATTERN_NOT_VALID = "PatternNotValid" + SYNTAX_NOT_VALID = "SyntaxNotValid" + VALUE_NOT_VALID = "ValueNotValid" + OTHER = "Other" + } + + /// Stores information about a field passed inside a request that resulted in an exception. + structure ValidationExceptionField { + /// The field name. + @required + Name: String + + @required + Reason: ValidationExceptionFieldReason + + /// Message describing why the field failed validation. + @required + Message: String + } + + /// A list of fields. + list ValidationExceptionFieldList { + member: ValidationExceptionField + } + + enum ValidationExceptionReason { + FIELD_VALIDATION_FAILED = "FieldValidationFailed" + UNKNOWN_OPERATION = "UnknownOperation" + CANNOT_PARSE = "CannotParse" + OTHER = "Other" + } + + /// The input fails to satisfy the constraints specified by an AWS service. + @error("client") + @httpError(400) + structure ValidationException { + /// Description of the error. + @required + Message: String + + /// Reason the request failed validation. + @required + Reason: ValidationExceptionReason + + /// The field that caused the error, if applicable. If more than one field + /// caused the error, pick one and elaborate in the message. + Fields: ValidationExceptionFieldList + } + """.asSmithyModel(smithyVersion = "2.0") + + // Remove Smithy's `ValidationException`. + var model = ModelTransformer.create().removeShapes( + model, + listOf(model.expectShape(SmithyValidationExceptionConversionGenerator.SHAPE_ID)), + ) + // Add our custom one. + model = ModelTransformer.create().replaceShapes(model, customValidationExceptionModel.shapes().toList()) + // Make all operations use our custom one. + val newOperationShapes = model.operationShapes.map { operationShape -> + operationShape.toBuilder().addError(ShapeId.from("com.amazonaws.constraints#ValidationException")).build() + } + return ModelTransformer.create().replaceShapes(model, newOperationShapes) +} + +internal class CustomValidationExceptionWithReasonDecoratorTest { + @Test + fun `constraints model with the CustomValidationExceptionWithReasonDecorator applied compiles`() { + var model = File("../codegen-core/common-test-models/constraints.smithy").readText().asSmithyModel() + model = swapOutSmithyValidationExceptionForCustomOne(model) + + serverIntegrationTest( + model, + IntegrationTestParams( + additionalSettings = Node.objectNodeBuilder().withMember( + "codegen", + Node.objectNodeBuilder() + .withMember("experimentalCustomValidationExceptionWithReasonPleaseDoNotUse", "com.amazonaws.constraints#ValidationException") + .build(), + ).build(), + ), + ) + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/PostprocessValidationExceptionNotAttachedErrorMessageDecoratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/PostprocessValidationExceptionNotAttachedErrorMessageDecoratorTest.kt new file mode 100644 index 00000000000..9130d30b330 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/PostprocessValidationExceptionNotAttachedErrorMessageDecoratorTest.kt @@ -0,0 +1,73 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.customizations + +import io.kotest.matchers.shouldBe +import io.kotest.matchers.string.shouldContain +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import software.amazon.smithy.codegen.core.CodegenException +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.server.smithy.LogMessage +import software.amazon.smithy.rust.codegen.server.smithy.ValidationResult +import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest + +internal class PostprocessValidationExceptionNotAttachedErrorMessageDecoratorTest { + @Test + fun `validation exception not attached error message is postprocessed if decorator is registered`() { + val model = + """ + namespace test + use aws.protocols#restJson1 + + @restJson1 + service TestService { + operations: ["ConstrainedOperation"], + } + + operation ConstrainedOperation { + input: ConstrainedOperationInput + } + + structure ConstrainedOperationInput { + @required + requiredString: String + } + """.asSmithyModel() + + val validationExceptionNotAttachedErrorMessageDummyPostprocessorDecorator = object : ServerCodegenDecorator { + override val name: String + get() = "ValidationExceptionNotAttachedErrorMessageDummyPostprocessorDecorator" + override val order: Byte + get() = 69 + + override fun postprocessValidationExceptionNotAttachedErrorMessage(validationResult: ValidationResult): ValidationResult { + check(validationResult.messages.size == 1) + + val level = validationResult.messages.first().level + val message = + """ + ${validationResult.messages.first().message} + + There are three things all wise men fear: the sea in storm, a night with no moon, and the anger of a gentle man. + """ + + return validationResult.copy(messages = listOf(LogMessage(level, message))) + } + } + + val exception = assertThrows { + serverIntegrationTest( + model, + additionalDecorators = listOf(validationExceptionNotAttachedErrorMessageDummyPostprocessorDecorator), + ) + } + val exceptionCause = (exception.cause!! as ValidationResult) + exceptionCause.messages.size shouldBe 1 + exceptionCause.messages.first().message shouldContain "There are three things all wise men fear: the sea in storm, a night with no moon, and the anger of a gentle man." + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedBlobGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedBlobGeneratorTest.kt index cfb8bbd38bb..060e0166a4d 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedBlobGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedBlobGeneratorTest.kt @@ -15,7 +15,6 @@ import org.junit.jupiter.params.provider.ArgumentsSource import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel @@ -23,6 +22,9 @@ import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule +import software.amazon.smithy.rust.codegen.server.smithy.createTestInlineModuleCreator +import software.amazon.smithy.rust.codegen.server.smithy.customizations.SmithyValidationExceptionConversionGenerator import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext import java.util.stream.Stream @@ -66,9 +68,16 @@ class ConstrainedBlobGeneratorTest { val project = TestWorkspace.testProject(symbolProvider) - project.withModule(ModelsModule) { + project.withModule(ServerRustModule.Model) { addDependency(RuntimeType.blob(codegenContext.runtimeConfig).toSymbol()) - ConstrainedBlobGenerator(codegenContext, this, constrainedBlobShape).render() + + ConstrainedBlobGenerator( + codegenContext, + this.createTestInlineModuleCreator(), + this, + constrainedBlobShape, + SmithyValidationExceptionConversionGenerator(codegenContext), + ).render() unitTest( name = "try_from_success", @@ -119,9 +128,15 @@ class ConstrainedBlobGeneratorTest { val codegenContext = serverTestCodegenContext(model) - val writer = RustWriter.forModule(ModelsModule.name) + val writer = RustWriter.forModule(ServerRustModule.Model.name) - ConstrainedBlobGenerator(codegenContext, writer, constrainedBlobShape).render() + ConstrainedBlobGenerator( + codegenContext, + writer.createTestInlineModuleCreator(), + writer, + constrainedBlobShape, + SmithyValidationExceptionConversionGenerator(codegenContext), + ).render() // Check that the wrapped type is `pub(crate)`. writer.toString() shouldContain "pub struct ConstrainedBlob(pub(crate) aws_smithy_types::Blob);" diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedCollectionGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedCollectionGeneratorTest.kt index 8e043ec47c0..cce7da4aa63 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedCollectionGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedCollectionGeneratorTest.kt @@ -25,7 +25,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.writable -import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest @@ -33,6 +32,9 @@ import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE import software.amazon.smithy.rust.codegen.core.util.lookup import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule +import software.amazon.smithy.rust.codegen.server.smithy.createTestInlineModuleCreator +import software.amazon.smithy.rust.codegen.server.smithy.customizations.SmithyValidationExceptionConversionGenerator import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.transformers.ShapesReachableFromOperationInputTagger import java.util.stream.Stream @@ -174,7 +176,7 @@ class ConstrainedCollectionGeneratorTest { else -> UNREACHABLE("Shape is either list or set.") } - project.withModule(ModelsModule) { + project.withModule(ServerRustModule.Model) { render(codegenContext, this, shape) val instantiator = serverInstantiator(codegenContext) @@ -268,7 +270,7 @@ class ConstrainedCollectionGeneratorTest { """.asSmithyModel().let(ShapesReachableFromOperationInputTagger::transform) val constrainedCollectionShape = model.lookup("test#ConstrainedList") - val writer = RustWriter.forModule(ModelsModule.name) + val writer = RustWriter.forModule(ServerRustModule.Model.name) val codegenContext = serverTestCodegenContext(model) render(codegenContext, writer, constrainedCollectionShape) @@ -284,6 +286,12 @@ class ConstrainedCollectionGeneratorTest { ) { val constraintsInfo = CollectionTraitInfo.fromShape(constrainedCollectionShape, codegenContext.symbolProvider) ConstrainedCollectionGenerator(codegenContext, writer, constrainedCollectionShape, constraintsInfo).render() - CollectionConstraintViolationGenerator(codegenContext, writer, constrainedCollectionShape, constraintsInfo).render() + CollectionConstraintViolationGenerator( + codegenContext, + writer.createTestInlineModuleCreator(), + constrainedCollectionShape, + constraintsInfo, + SmithyValidationExceptionConversionGenerator(codegenContext), + ).render() } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGeneratorTest.kt index 2da058cde55..0eebb7e36bb 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedMapGeneratorTest.kt @@ -17,13 +17,15 @@ import software.amazon.smithy.model.node.ObjectNode import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock -import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.lookup import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule +import software.amazon.smithy.rust.codegen.server.smithy.createTestInlineModuleCreator +import software.amazon.smithy.rust.codegen.server.smithy.customizations.SmithyValidationExceptionConversionGenerator import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.transformers.ShapesReachableFromOperationInputTagger import java.util.stream.Stream @@ -76,7 +78,7 @@ class ConstrainedMapGeneratorTest { val project = TestWorkspace.testProject(symbolProvider) - project.withModule(ModelsModule) { + project.withModule(ServerRustModule.Model) { render(codegenContext, this, constrainedMapShape) val instantiator = serverInstantiator(codegenContext) @@ -138,7 +140,7 @@ class ConstrainedMapGeneratorTest { """.asSmithyModel().let(ShapesReachableFromOperationInputTagger::transform) val constrainedMapShape = model.lookup("test#ConstrainedMap") - val writer = RustWriter.forModule(ModelsModule.name) + val writer = RustWriter.forModule(ServerRustModule.Model.name) val codegenContext = serverTestCodegenContext(model) render(codegenContext, writer, constrainedMapShape) @@ -153,6 +155,11 @@ class ConstrainedMapGeneratorTest { constrainedMapShape: MapShape, ) { ConstrainedMapGenerator(codegenContext, writer, constrainedMapShape).render() - MapConstraintViolationGenerator(codegenContext, writer, constrainedMapShape).render() + MapConstraintViolationGenerator( + codegenContext, + writer.createTestInlineModuleCreator(), + constrainedMapShape, + SmithyValidationExceptionConversionGenerator(codegenContext), + ).render() } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedNumberGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedNumberGeneratorTest.kt index 681d0fffe35..5a78574c93e 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedNumberGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedNumberGeneratorTest.kt @@ -14,12 +14,14 @@ import org.junit.jupiter.params.provider.ArgumentsSource import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.NumberShape import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule +import software.amazon.smithy.rust.codegen.server.smithy.createTestInlineModuleCreator +import software.amazon.smithy.rust.codegen.server.smithy.customizations.SmithyValidationExceptionConversionGenerator import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext import java.util.stream.Stream @@ -70,8 +72,14 @@ class ConstrainedNumberGeneratorTest { val project = TestWorkspace.testProject(symbolProvider) - project.withModule(ModelsModule) { - ConstrainedNumberGenerator(codegenContext, this, shape).render() + project.withModule(ServerRustModule.Model) { + ConstrainedNumberGenerator( + codegenContext, + this.createTestInlineModuleCreator(), + this, + shape, + SmithyValidationExceptionConversionGenerator(codegenContext), + ).render() unitTest( name = "try_from_success", @@ -131,8 +139,14 @@ class ConstrainedNumberGeneratorTest { val codegenContext = serverTestCodegenContext(model) - val writer = RustWriter.forModule(ModelsModule.name) - ConstrainedNumberGenerator(codegenContext, writer, constrainedShape).render() + val writer = RustWriter.forModule(ServerRustModule.Model.name) + ConstrainedNumberGenerator( + codegenContext, + writer.createTestInlineModuleCreator(), + writer, + constrainedShape, + SmithyValidationExceptionConversionGenerator(codegenContext), + ).render() // Check that the wrapped type is `pub(crate)`. writer.toString() shouldContain "pub struct $shapeName(pub(crate) $rustType);" diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGeneratorTest.kt index a8fa6e441d3..62a7d061c30 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ConstrainedStringGeneratorTest.kt @@ -16,13 +16,15 @@ import org.junit.jupiter.params.provider.ArgumentsSource import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.CommandFailed import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule +import software.amazon.smithy.rust.codegen.server.smithy.createTestInlineModuleCreator +import software.amazon.smithy.rust.codegen.server.smithy.customizations.SmithyValidationExceptionConversionGenerator import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext import java.util.stream.Stream @@ -81,8 +83,14 @@ class ConstrainedStringGeneratorTest { val project = TestWorkspace.testProject(symbolProvider) - project.withModule(ModelsModule) { - ConstrainedStringGenerator(codegenContext, this, constrainedStringShape).render() + project.withModule(ServerRustModule.Model) { + ConstrainedStringGenerator( + codegenContext, + this.createTestInlineModuleCreator(), + this, + constrainedStringShape, + SmithyValidationExceptionConversionGenerator(codegenContext), + ).render() unitTest( name = "try_from_success", @@ -134,9 +142,15 @@ class ConstrainedStringGeneratorTest { val codegenContext = serverTestCodegenContext(model) - val writer = RustWriter.forModule(ModelsModule.name) + val writer = RustWriter.forModule(ServerRustModule.Model.name) - ConstrainedStringGenerator(codegenContext, writer, constrainedStringShape).render() + ConstrainedStringGenerator( + codegenContext, + writer.createTestInlineModuleCreator(), + writer, + constrainedStringShape, + SmithyValidationExceptionConversionGenerator(codegenContext), + ).render() // Check that the wrapped type is `pub(crate)`. writer.toString() shouldContain "pub struct ConstrainedString(pub(crate) std::string::String);" @@ -161,9 +175,22 @@ class ConstrainedStringGeneratorTest { val project = TestWorkspace.testProject(codegenContext.symbolProvider) - project.withModule(ModelsModule) { - ConstrainedStringGenerator(codegenContext, this, constrainedStringShape).render() - ConstrainedStringGenerator(codegenContext, this, sensitiveConstrainedStringShape).render() + project.withModule(ServerRustModule.Model) { + val validationExceptionConversionGenerator = SmithyValidationExceptionConversionGenerator(codegenContext) + ConstrainedStringGenerator( + codegenContext, + this.createTestInlineModuleCreator(), + this, + constrainedStringShape, + validationExceptionConversionGenerator, + ).render() + ConstrainedStringGenerator( + codegenContext, + this.createTestInlineModuleCreator(), + this, + sensitiveConstrainedStringShape, + validationExceptionConversionGenerator, + ).render() unitTest( name = "non_sensitive_string_display_implementation", @@ -200,8 +227,14 @@ class ConstrainedStringGeneratorTest { val codegenContext = serverTestCodegenContext(model) val project = TestWorkspace.testProject(codegenContext.symbolProvider) - project.withModule(ModelsModule) { - ConstrainedStringGenerator(codegenContext, this, constrainedStringShape).render() + project.withModule(ServerRustModule.Model) { + ConstrainedStringGenerator( + codegenContext, + this.createTestInlineModuleCreator(), + this, + constrainedStringShape, + SmithyValidationExceptionConversionGenerator(codegenContext), + ).render() } assertThrows { diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolationsTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolationsTest.kt new file mode 100644 index 00000000000..1f685909f38 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolationsTest.kt @@ -0,0 +1,49 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import org.junit.jupiter.api.Test +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest + +class ServerBuilderConstraintViolationsTest { + + // This test exists not to regress on [this](https://github.com/awslabs/smithy-rs/issues/2343) issue. + // We generated constraint violation variants, pointing to a structure (StructWithInnerDefault below), + // but the structure was not constrained, because the structure's member have a default value + // and default values are validated at generation time from the model. + @Test + fun `it should not generate constraint violations for members with a default value`() { + val model = """ + namespace test + + use aws.protocols#restJson1 + use smithy.framework#ValidationException + + @restJson1 + service SimpleService { + operations: [Operation] + } + + @http(uri: "/operation", method: "POST") + operation Operation { + input: OperationInput + errors: [ValidationException] + } + + structure OperationInput { + @required + requiredStructureWithInnerDefault: StructWithInnerDefault + } + + structure StructWithInnerDefault { + @default(false) + inner: PrimitiveBoolean + } + """.asSmithyModel(smithyVersion = "2") + serverIntegrationTest(model) + } +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderDefaultValuesTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderDefaultValuesTest.kt index 2219c2e653b..580ebff60cb 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderDefaultValuesTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderDefaultValuesTest.kt @@ -15,13 +15,14 @@ import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock +import software.amazon.smithy.rust.codegen.core.rustlang.implBlock import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.withBlock import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest @@ -31,6 +32,8 @@ import software.amazon.smithy.rust.codegen.core.util.lookup import software.amazon.smithy.rust.codegen.core.util.toPascalCase import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenConfig +import software.amazon.smithy.rust.codegen.server.smithy.customizations.SmithyValidationExceptionConversionGenerator +import software.amazon.smithy.rust.codegen.server.smithy.renderInlineMemoryModules import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestRustSettings import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider @@ -104,10 +107,10 @@ class ServerBuilderDefaultValuesTest { project.withModule(RustModule.public("model")) { when (builderGeneratorKind) { BuilderGeneratorKind.SERVER_BUILDER_GENERATOR -> { - writeServerBuilderGenerator(this, model, symbolProvider) + writeServerBuilderGenerator(project, this, model, symbolProvider) } BuilderGeneratorKind.SERVER_BUILDER_GENERATOR_WITHOUT_PUBLIC_CONSTRAINED_TYPES -> { - writeServerBuilderGeneratorWithoutPublicConstrainedTypes(this, model, symbolProvider) + writeServerBuilderGeneratorWithoutPublicConstrainedTypes(project, this, model, symbolProvider) } } @@ -143,6 +146,7 @@ class ServerBuilderDefaultValuesTest { ) } + project.renderInlineMemoryModules() // Run clippy because the builder's code for handling `@default` is prone to upset it. project.compileAndTest(runClippy = true) } @@ -167,7 +171,7 @@ class ServerBuilderDefaultValuesTest { .map { it.key to "${it.value}.into()" } } - private fun writeServerBuilderGeneratorWithoutPublicConstrainedTypes(writer: RustWriter, model: Model, symbolProvider: RustSymbolProvider) { + private fun writeServerBuilderGeneratorWithoutPublicConstrainedTypes(rustCrate: RustCrate, writer: RustWriter, model: Model, symbolProvider: RustSymbolProvider) { val struct = model.lookup("com.test#MyStruct") val codegenContext = serverTestCodegenContext( model, @@ -175,29 +179,37 @@ class ServerBuilderDefaultValuesTest { codegenConfig = ServerCodegenConfig(publicConstrainedTypes = false), ), ) - val builderGenerator = ServerBuilderGeneratorWithoutPublicConstrainedTypes(codegenContext, struct) + val builderGenerator = ServerBuilderGeneratorWithoutPublicConstrainedTypes(codegenContext, struct, SmithyValidationExceptionConversionGenerator(codegenContext)) - writer.implBlock(struct, symbolProvider) { + writer.implBlock(symbolProvider.toSymbol(struct)) { builderGenerator.renderConvenienceMethod(writer) } - builderGenerator.render(writer) - - ServerEnumGenerator(codegenContext, writer, model.lookup("com.test#Language")).render() - StructureGenerator(model, symbolProvider, writer, struct).render() + builderGenerator.render(rustCrate, writer) + + ServerEnumGenerator( + codegenContext, + model.lookup("com.test#Language"), + SmithyValidationExceptionConversionGenerator(codegenContext), + ).render(writer) + StructureGenerator(model, symbolProvider, writer, struct, emptyList()).render() } - private fun writeServerBuilderGenerator(writer: RustWriter, model: Model, symbolProvider: RustSymbolProvider) { + private fun writeServerBuilderGenerator(rustCrate: RustCrate, writer: RustWriter, model: Model, symbolProvider: RustSymbolProvider) { val struct = model.lookup("com.test#MyStruct") val codegenContext = serverTestCodegenContext(model) - val builderGenerator = ServerBuilderGenerator(codegenContext, struct) + val builderGenerator = ServerBuilderGenerator(codegenContext, struct, SmithyValidationExceptionConversionGenerator(codegenContext)) - writer.implBlock(struct, symbolProvider) { + writer.implBlock(symbolProvider.toSymbol(struct)) { builderGenerator.renderConvenienceMethod(writer) } - builderGenerator.render(writer) - - ServerEnumGenerator(codegenContext, writer, model.lookup("com.test#Language")).render() - StructureGenerator(model, symbolProvider, writer, struct).render() + builderGenerator.render(rustCrate, writer) + + ServerEnumGenerator( + codegenContext, + model.lookup("com.test#Language"), + SmithyValidationExceptionConversionGenerator(codegenContext), + ).render(writer) + StructureGenerator(model, symbolProvider, writer, struct, emptyList()).render() } private fun structSetters(values: Map, optional: Boolean) = writable { diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorTest.kt index 24688f36546..d26e6b35db6 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorTest.kt @@ -7,13 +7,17 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.rustlang.implBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule +import software.amazon.smithy.rust.codegen.server.smithy.customizations.SmithyValidationExceptionConversionGenerator +import software.amazon.smithy.rust.codegen.server.smithy.renderInlineMemoryModules import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext class ServerBuilderGeneratorTest { @@ -35,23 +39,40 @@ class ServerBuilderGeneratorTest { """.asSmithyModel() val codegenContext = serverTestCodegenContext(model) - val writer = RustWriter.forModule("model") - val shape = model.lookup("test#Credentials") - StructureGenerator(model, codegenContext.symbolProvider, writer, shape).render(CodegenTarget.SERVER) - val builderGenerator = ServerBuilderGenerator(codegenContext, shape) - builderGenerator.render(writer) - writer.implBlock(shape, codegenContext.symbolProvider) { - builderGenerator.renderConvenienceMethod(this) + val project = TestWorkspace.testProject() + project.withModule(ServerRustModule.Model) { + val writer = this + val shape = model.lookup("test#Credentials") + + StructureGenerator(model, codegenContext.symbolProvider, writer, shape, emptyList()).render() + val builderGenerator = ServerBuilderGenerator( + codegenContext, + shape, + SmithyValidationExceptionConversionGenerator(codegenContext), + ) + + builderGenerator.render(project, writer) + + writer.implBlock(codegenContext.symbolProvider.toSymbol(shape)) { + builderGenerator.renderConvenienceMethod(this) + } + + project.renderInlineMemoryModules() + } + + project.unitTest { + rust( + """ + use super::*; + use crate::model::*; + let builder = Credentials::builder() + .username(Some("admin".to_owned())) + .password(Some("pswd".to_owned())) + .secret_key(Some("12345".to_owned())); + assert_eq!(format!("{:?}", builder), "Builder { username: Some(\"admin\"), password: \"*** Sensitive Data Redacted ***\", secret_key: \"*** Sensitive Data Redacted ***\" }"); + """, + ) } - writer.compileAndTest( - """ - use super::*; - let builder = Credentials::builder() - .username(Some("admin".to_owned())) - .password(Some("pswd".to_owned())) - .secret_key(Some("12345".to_owned())); - assert_eq!(format!("{:?}", builder), "Builder { username: Some(\"admin\"), password: \"*** Sensitive Data Redacted ***\", secret_key: \"*** Sensitive Data Redacted ***\" }"); - """, - ) + project.compileAndTest() } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGeneratorTest.kt index 0e813cebbd4..bffb82bb408 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGeneratorTest.kt @@ -12,6 +12,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.customizations.SmithyValidationExceptionConversionGenerator import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext class ServerEnumGeneratorTest { @@ -40,7 +41,11 @@ class ServerEnumGeneratorTest { @Test fun `it generates TryFrom, FromStr and errors for enums`() { - ServerEnumGenerator(codegenContext, writer, shape).render() + ServerEnumGenerator( + codegenContext, + shape, + SmithyValidationExceptionConversionGenerator(codegenContext), + ).render(writer) writer.compileAndTest( """ use std::str::FromStr; @@ -53,10 +58,14 @@ class ServerEnumGeneratorTest { @Test fun `it generates enums without the unknown variant`() { - ServerEnumGenerator(codegenContext, writer, shape).render() + ServerEnumGenerator( + codegenContext, + shape, + SmithyValidationExceptionConversionGenerator(codegenContext), + ).render(writer) writer.compileAndTest( """ - // check no unknown + // Check no `Unknown` variant. let instance = InstanceType::T2Micro; match instance { InstanceType::T2Micro => (), @@ -68,7 +77,11 @@ class ServerEnumGeneratorTest { @Test fun `it generates enums without non_exhaustive`() { - ServerEnumGenerator(codegenContext, writer, shape).render() + ServerEnumGenerator( + codegenContext, + shape, + SmithyValidationExceptionConversionGenerator(codegenContext), + ).render(writer) writer.toString() shouldNotContain "#[non_exhaustive]" } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerHttpSensitivityGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerHttpSensitivityGeneratorTest.kt index 9a8ae01cabf..a90acec6e48 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerHttpSensitivityGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerHttpSensitivityGeneratorTest.kt @@ -14,6 +14,7 @@ import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.core.testutil.testModule import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.inputShape @@ -57,7 +58,7 @@ class ServerHttpSensitivityGeneratorTest { assertEquals(listOf("query_b"), (querySensitivity as QuerySensitivity.NotSensitiveMapValue).queryKeys) val testProject = TestWorkspace.testProject(serverTestSymbolProvider(model)) - testProject.lib { + testProject.testModule { unitTest("query_closure") { rustTemplate( """ @@ -70,6 +71,7 @@ class ServerHttpSensitivityGeneratorTest { ) } } + testProject.compileAndTest() } @@ -104,7 +106,7 @@ class ServerHttpSensitivityGeneratorTest { querySensitivity as QuerySensitivity.SensitiveMapValue val testProject = TestWorkspace.testProject(serverTestSymbolProvider(model)) - testProject.lib { + testProject.testModule { unitTest("query_params_closure") { rustTemplate( """ @@ -152,7 +154,7 @@ class ServerHttpSensitivityGeneratorTest { assert((querySensitivity as QuerySensitivity.NotSensitiveMapValue).queryKeys.isEmpty()) val testProject = TestWorkspace.testProject(serverTestSymbolProvider(model)) - testProject.lib { + testProject.testModule { unitTest("query_params_special_closure") { rustTemplate( """ @@ -200,7 +202,7 @@ class ServerHttpSensitivityGeneratorTest { querySensitivity as QuerySensitivity.SensitiveMapValue val testProject = TestWorkspace.testProject(serverTestSymbolProvider(model)) - testProject.lib { + testProject.testModule { unitTest("query_params_special_closure") { rustTemplate( """ @@ -277,7 +279,7 @@ class ServerHttpSensitivityGeneratorTest { assertEquals(null, (headerData as HeaderSensitivity.NotSensitiveMapValue).prefixHeader) val testProject = TestWorkspace.testProject(serverTestSymbolProvider(model)) - testProject.lib { + testProject.testModule { unitTest("header_closure") { rustTemplate( """ @@ -325,7 +327,7 @@ class ServerHttpSensitivityGeneratorTest { assertEquals("prefix-", (headerData as HeaderSensitivity.SensitiveMapValue).prefixHeader) val testProject = TestWorkspace.testProject(serverTestSymbolProvider(model)) - testProject.lib { + testProject.testModule { unitTest("prefix_headers_closure") { rustTemplate( """ @@ -408,7 +410,7 @@ class ServerHttpSensitivityGeneratorTest { assertEquals("prefix-", asMapValue.prefixHeader) val testProject = TestWorkspace.testProject(serverTestSymbolProvider(model)) - testProject.lib { + testProject.testModule { unitTest("prefix_headers_special_closure") { rustTemplate( """ @@ -462,7 +464,7 @@ class ServerHttpSensitivityGeneratorTest { assert(!asSensitiveMapValue.keySensitive) val testProject = TestWorkspace.testProject(serverTestSymbolProvider(model)) - testProject.lib { + testProject.testModule { unitTest("prefix_headers_special_closure") { rustTemplate( """ @@ -514,7 +516,7 @@ class ServerHttpSensitivityGeneratorTest { assertEquals(listOf(1, 2), labelData.labelIndexes) val testProject = TestWorkspace.testProject(serverTestSymbolProvider(model)) - testProject.lib { + testProject.testModule { unitTest("uri_closure") { rustTemplate( """ diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt index 1bfbd26e2f5..eef80be1c4b 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt @@ -13,7 +13,6 @@ import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.withBlock -import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace @@ -21,8 +20,10 @@ import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.dq -import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule +import software.amazon.smithy.rust.codegen.server.smithy.customizations.SmithyValidationExceptionConversionGenerator +import software.amazon.smithy.rust.codegen.server.smithy.renderInlineMemoryModules import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverRenderWithModelBuilder import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext @@ -123,7 +124,7 @@ class ServerInstantiatorTest { }, ]) string NamedEnum - """.asSmithyModel().let { RecursiveShapeBoxer.transform(it) } + """.asSmithyModel().let { RecursiveShapeBoxer().transform(it) } private val codegenContext = serverTestCodegenContext(model) private val symbolProvider = codegenContext.symbolProvider @@ -138,45 +139,50 @@ class ServerInstantiatorTest { val data = Node.parse("{}") val project = TestWorkspace.testProject() - project.withModule(RustModule.Model) { - structure.serverRenderWithModelBuilder(model, symbolProvider, this) - inner.serverRenderWithModelBuilder(model, symbolProvider, this) - nestedStruct.serverRenderWithModelBuilder(model, symbolProvider, this) + + project.withModule(ServerRustModule.Model) { + structure.serverRenderWithModelBuilder(project, model, symbolProvider, this) + inner.serverRenderWithModelBuilder(project, model, symbolProvider, this) + nestedStruct.serverRenderWithModelBuilder(project, model, symbolProvider, this) UnionGenerator(model, symbolProvider, this, union).render() - unitTest("server_instantiator_test") { - withBlock("let result = ", ";") { - sut.render(this, structure, data) - } + withInlineModule(RustModule.inlineTests()) { + unitTest("server_instantiator_test") { + withBlock("let result = ", ";") { + sut.render(this, structure, data) + } - rust( - """ - use std::collections::HashMap; - use aws_smithy_types::{DateTime, Document}; - - let expected = MyStructRequired { - str: "".to_owned(), - primitive_int: 0, - int: 0, - ts: DateTime::from_secs(0), - byte: 0, - union: NestedUnion::Struct(NestedStruct { - str: "".to_owned(), - num: 0, - }), - structure: NestedStruct { + rust( + """ + use std::collections::HashMap; + use aws_smithy_types::{DateTime, Document}; + use super::*; + + let expected = MyStructRequired { str: "".to_owned(), - num: 0, - }, - list: Vec::new(), - map: HashMap::new(), - doc: Document::Object(HashMap::new()), - }; - assert_eq!(result, expected); - """, - ) + primitive_int: 0, + int: 0, + ts: DateTime::from_secs(0), + byte: 0, + union: NestedUnion::Struct(NestedStruct { + str: "".to_owned(), + num: 0, + }), + structure: NestedStruct { + str: "".to_owned(), + num: 0, + }, + list: Vec::new(), + map: HashMap::new(), + doc: Document::Object(HashMap::new()), + }; + assert_eq!(result, expected); + """, + ) + } } } + project.renderInlineMemoryModules() project.compileAndTest() } @@ -187,8 +193,12 @@ class ServerInstantiatorTest { val data = Node.parse("t2.nano".dq()) val project = TestWorkspace.testProject() - project.withModule(RustModule.Model) { - EnumGenerator(model, symbolProvider, this, shape, shape.expectTrait()).render() + project.withModule(ServerRustModule.Model) { + ServerEnumGenerator( + codegenContext, + shape, + SmithyValidationExceptionConversionGenerator(codegenContext), + ).render(this) unitTest("generate_named_enums") { withBlock("let result = ", ";") { sut.render(this, shape, data) @@ -206,8 +216,12 @@ class ServerInstantiatorTest { val data = Node.parse("t2.nano".dq()) val project = TestWorkspace.testProject() - project.withModule(RustModule.Model) { - EnumGenerator(model, symbolProvider, this, shape, shape.expectTrait()).render() + project.withModule(ServerRustModule.Model) { + ServerEnumGenerator( + codegenContext, + shape, + SmithyValidationExceptionConversionGenerator(codegenContext), + ).render(this) unitTest("generate_unnamed_enums") { withBlock("let result = ", ";") { sut.render(this, shape, data) diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationErrorGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationErrorGeneratorTest.kt index 32f565feba1..5851344d964 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationErrorGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationErrorGeneratorTest.kt @@ -7,13 +7,14 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.core.smithy.ErrorsModule import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule +import software.amazon.smithy.rust.codegen.server.smithy.renderInlineMemoryModules import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverRenderWithModelBuilder import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider @@ -51,16 +52,14 @@ class ServerOperationErrorGeneratorTest { @Test fun `generates combined error enums`() { val project = TestWorkspace.testProject(symbolProvider) - project.withModule(ErrorsModule) { + project.withModule(ServerRustModule.Error) { listOf("FooException", "ComplexError", "InvalidGreeting", "Deprecated").forEach { - model.lookup("error#$it").serverRenderWithModelBuilder(model, symbolProvider, this) + model.lookup("error#$it").serverRenderWithModelBuilder(project, model, symbolProvider, this) } - val errors = listOf("FooException", "ComplexError", "InvalidGreeting").map { model.lookup("error#$it") } ServerOperationErrorGenerator( model, symbolProvider, - symbolProvider.toSymbol(model.lookup("error#Greeting")), - errors, + model.lookup("error#Greeting"), ).render(this) unitTest( @@ -96,7 +95,7 @@ class ServerOperationErrorGeneratorTest { let error: GreetingError = variant.into(); """, ) - + project.renderInlineMemoryModules() project.compileAndTest() } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGeneratorTest.kt index 6a02765c9a1..28a8c1ef4c6 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedCollectionGeneratorTest.kt @@ -8,14 +8,15 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.ListShape import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.core.smithy.ConstrainedModule -import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule -import software.amazon.smithy.rust.codegen.core.smithy.UnconstrainedModule import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule +import software.amazon.smithy.rust.codegen.server.smithy.createTestInlineModuleCreator +import software.amazon.smithy.rust.codegen.server.smithy.customizations.SmithyValidationExceptionConversionGenerator +import software.amazon.smithy.rust.codegen.server.smithy.renderInlineMemoryModules import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverRenderWithModelBuilder import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext @@ -50,25 +51,35 @@ class UnconstrainedCollectionGeneratorTest { val project = TestWorkspace.testProject(symbolProvider) - project.withModule(ModelsModule) { - model.lookup("test#StructureC").serverRenderWithModelBuilder(model, symbolProvider, this) + project.withModule(ServerRustModule.Model) { + model.lookup("test#StructureC").serverRenderWithModelBuilder(project, model, symbolProvider, this) } - project.withModule(ConstrainedModule) { + project.withModule(ServerRustModule.ConstrainedModule) { listOf(listA, listB).forEach { - PubCrateConstrainedCollectionGenerator(codegenContext, this, it).render() + PubCrateConstrainedCollectionGenerator( + codegenContext, + this.createTestInlineModuleCreator(), + it, + ).render() } } - project.withModule(UnconstrainedModule) unconstrainedModuleWriter@{ - project.withModule(ModelsModule) modelsModuleWriter@{ + project.withModule(ServerRustModule.UnconstrainedModule) unconstrainedModuleWriter@{ + project.withModule(ServerRustModule.Model) modelsModuleWriter@{ listOf(listA, listB).forEach { UnconstrainedCollectionGenerator( codegenContext, - this@unconstrainedModuleWriter, + this@unconstrainedModuleWriter.createTestInlineModuleCreator(), it, ).render() - CollectionConstraintViolationGenerator(codegenContext, this@modelsModuleWriter, it, listOf()).render() + CollectionConstraintViolationGenerator( + codegenContext, + this@modelsModuleWriter.createTestInlineModuleCreator(), + it, + CollectionTraitInfo.fromShape(it, codegenContext.constrainedShapeSymbolProvider), + SmithyValidationExceptionConversionGenerator(codegenContext), + ).render() } this@unconstrainedModuleWriter.unitTest( @@ -121,6 +132,7 @@ class UnconstrainedCollectionGeneratorTest { ) } } + project.renderInlineMemoryModules() project.compileAndTest() } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGeneratorTest.kt index 75e176be39d..dbb52b85a77 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedMapGeneratorTest.kt @@ -8,14 +8,17 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.core.smithy.ConstrainedModule -import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule -import software.amazon.smithy.rust.codegen.core.smithy.UnconstrainedModule +import software.amazon.smithy.rust.codegen.core.smithy.CoreCodegenConfig import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule.Model +import software.amazon.smithy.rust.codegen.server.smithy.createTestInlineModuleCreator +import software.amazon.smithy.rust.codegen.server.smithy.customizations.SmithyValidationExceptionConversionGenerator +import software.amazon.smithy.rust.codegen.server.smithy.renderInlineMemoryModules import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverRenderWithModelBuilder import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext @@ -50,23 +53,35 @@ class UnconstrainedMapGeneratorTest { val mapA = model.lookup("test#MapA") val mapB = model.lookup("test#MapB") - val project = TestWorkspace.testProject(symbolProvider, debugMode = true) + val project = TestWorkspace.testProject(symbolProvider, CoreCodegenConfig(debugMode = true)) - project.withModule(ModelsModule) { - model.lookup("test#StructureC").serverRenderWithModelBuilder(model, symbolProvider, this) + project.withModule(Model) { + model.lookup("test#StructureC").serverRenderWithModelBuilder(project, model, symbolProvider, this) } - project.withModule(ConstrainedModule) { + project.withModule(ServerRustModule.ConstrainedModule) { listOf(mapA, mapB).forEach { - PubCrateConstrainedMapGenerator(codegenContext, this, it).render() + PubCrateConstrainedMapGenerator( + codegenContext, + this.createTestInlineModuleCreator(), + it, + ).render() } } - project.withModule(UnconstrainedModule) unconstrainedModuleWriter@{ - project.withModule(ModelsModule) modelsModuleWriter@{ + project.withModule(ServerRustModule.UnconstrainedModule) unconstrainedModuleWriter@{ + project.withModule(Model) modelsModuleWriter@{ listOf(mapA, mapB).forEach { - UnconstrainedMapGenerator(codegenContext, this@unconstrainedModuleWriter, it).render() - - MapConstraintViolationGenerator(codegenContext, this@modelsModuleWriter, it).render() + UnconstrainedMapGenerator( + codegenContext, + this@unconstrainedModuleWriter.createTestInlineModuleCreator(), it, + ).render() + + MapConstraintViolationGenerator( + codegenContext, + this@modelsModuleWriter.createTestInlineModuleCreator(), + it, + SmithyValidationExceptionConversionGenerator(codegenContext), + ).render() } this@unconstrainedModuleWriter.unitTest( @@ -159,7 +174,7 @@ class UnconstrainedMapGeneratorTest { ) } } - + project.renderInlineMemoryModules() project.compileAndTest() } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGeneratorTest.kt index 4b5eca2d1e4..bf1904757e6 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/UnconstrainedUnionGeneratorTest.kt @@ -8,14 +8,15 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import org.junit.jupiter.api.Test import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape -import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule -import software.amazon.smithy.rust.codegen.core.smithy.UnconstrainedModule import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule +import software.amazon.smithy.rust.codegen.server.smithy.createInlineModuleCreator +import software.amazon.smithy.rust.codegen.server.smithy.renderInlineMemoryModules import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverRenderWithModelBuilder import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext @@ -42,16 +43,17 @@ class UnconstrainedUnionGeneratorTest { val project = TestWorkspace.testProject(symbolProvider) - project.withModule(ModelsModule) { - model.lookup("test#Structure").serverRenderWithModelBuilder(model, symbolProvider, this) + project.withModule(ServerRustModule.Model) { + model.lookup("test#Structure").serverRenderWithModelBuilder(project, model, symbolProvider, this) } - project.withModule(ModelsModule) { + project.withModule(ServerRustModule.Model) { UnionGenerator(model, symbolProvider, this, unionShape, renderUnknownVariant = false).render() } - project.withModule(UnconstrainedModule) unconstrainedModuleWriter@{ - project.withModule(ModelsModule) modelsModuleWriter@{ - UnconstrainedUnionGenerator(codegenContext, this@unconstrainedModuleWriter, this@modelsModuleWriter, unionShape).render() + + project.withModule(ServerRustModule.UnconstrainedModule) unconstrainedModuleWriter@{ + project.withModule(ServerRustModule.Model) modelsModuleWriter@{ + UnconstrainedUnionGenerator(codegenContext, project.createInlineModuleCreator(), this@modelsModuleWriter, unionShape).render() this@unconstrainedModuleWriter.unitTest( name = "unconstrained_union_fail_to_constrain", @@ -97,6 +99,7 @@ class UnconstrainedUnionGeneratorTest { ) } } + project.renderInlineMemoryModules() project.compileAndTest() } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamBaseRequirements.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamBaseRequirements.kt deleted file mode 100644 index e3ff924db85..00000000000 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamBaseRequirements.kt +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.rust.codegen.server.smithy.protocols.eventstream - -import org.junit.jupiter.api.extension.ExtensionContext -import org.junit.jupiter.params.provider.Arguments -import org.junit.jupiter.params.provider.ArgumentsProvider -import software.amazon.smithy.codegen.core.Symbol -import software.amazon.smithy.model.Model -import software.amazon.smithy.model.shapes.ServiceShape -import software.amazon.smithy.model.shapes.ShapeId -import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget -import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider -import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock -import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels -import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestRequirements -import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenConfig -import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext -import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderGenerator -import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderGeneratorWithoutPublicConstrainedTypes -import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerOperationErrorGenerator -import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext -import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestRustSettings -import java.util.stream.Stream - -data class TestCase( - val eventStreamTestCase: EventStreamTestModels.TestCase, - val publicConstrainedTypes: Boolean, -) { - override fun toString(): String = "$eventStreamTestCase, publicConstrainedTypes = $publicConstrainedTypes" -} - -class TestCasesProvider : ArgumentsProvider { - override fun provideArguments(context: ExtensionContext?): Stream = - EventStreamTestModels.TEST_CASES - .flatMap { testCase -> - listOf( - TestCase(testCase, publicConstrainedTypes = false), - TestCase(testCase, publicConstrainedTypes = true), - ) - }.map { Arguments.of(it) }.stream() -} - -abstract class ServerEventStreamBaseRequirements : EventStreamTestRequirements { - abstract val publicConstrainedTypes: Boolean - - override fun createCodegenContext( - model: Model, - serviceShape: ServiceShape, - protocolShapeId: ShapeId, - codegenTarget: CodegenTarget, - ): ServerCodegenContext = serverTestCodegenContext( - model, serviceShape, - serverTestRustSettings( - codegenConfig = ServerCodegenConfig(publicConstrainedTypes = publicConstrainedTypes), - ), - protocolShapeId, - ) - - override fun renderBuilderForShape( - writer: RustWriter, - codegenContext: ServerCodegenContext, - shape: StructureShape, - ) { - if (codegenContext.settings.codegenConfig.publicConstrainedTypes) { - ServerBuilderGenerator(codegenContext, shape).apply { - render(writer) - writer.implBlock(shape, codegenContext.symbolProvider) { - renderConvenienceMethod(writer) - } - } - } else { - ServerBuilderGeneratorWithoutPublicConstrainedTypes(codegenContext, shape).apply { - render(writer) - writer.implBlock(shape, codegenContext.symbolProvider) { - renderConvenienceMethod(writer) - } - } - } - } - - override fun renderOperationError( - writer: RustWriter, - model: Model, - symbolProvider: RustSymbolProvider, - operationSymbol: Symbol, - errors: List, - ) { - ServerOperationErrorGenerator(model, symbolProvider, operationSymbol, errors).render(writer) - } -} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamMarshallerGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamMarshallerGeneratorTest.kt index 860b451ee85..c4c15742e3a 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamMarshallerGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamMarshallerGeneratorTest.kt @@ -5,45 +5,43 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols.eventstream +import org.junit.jupiter.api.extension.ExtensionContext import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.ArgumentsProvider import org.junit.jupiter.params.provider.ArgumentsSource -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol -import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.EventStreamMarshallerGenerator -import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestTools -import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestVariety -import software.amazon.smithy.rust.codegen.core.testutil.TestEventStreamProject -import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig -import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.core.testutil.EventStreamMarshallTestCases.writeMarshallTestCases +import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestModels +import software.amazon.smithy.rust.codegen.core.testutil.testModule +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest +import java.util.stream.Stream class ServerEventStreamMarshallerGeneratorTest { @ParameterizedTest @ArgumentsSource(TestCasesProvider::class) fun test(testCase: TestCase) { - EventStreamTestTools.runTestCase( - testCase.eventStreamTestCase, - object : ServerEventStreamBaseRequirements() { - override val publicConstrainedTypes: Boolean get() = testCase.publicConstrainedTypes - - override fun renderGenerator( - codegenContext: ServerCodegenContext, - project: TestEventStreamProject, - protocol: Protocol, - ): RuntimeType { - return EventStreamMarshallerGenerator( - project.model, - CodegenTarget.SERVER, - TestRuntimeConfig, - project.symbolProvider, - project.streamShape, - protocol.structuredDataSerializer(project.operationShape), - testCase.eventStreamTestCase.requestContentType, - ).render() - } - }, - CodegenTarget.SERVER, - EventStreamTestVariety.Marshall, - ) + serverIntegrationTest(testCase.eventStreamTestCase.model) { _, rustCrate -> + rustCrate.testModule { + writeMarshallTestCases(testCase.eventStreamTestCase, optionalBuilderInputs = true) + } + } } } + +data class TestCase( + val eventStreamTestCase: EventStreamTestModels.TestCase, + val publicConstrainedTypes: Boolean, +) { + override fun toString(): String = "$eventStreamTestCase, publicConstrainedTypes = $publicConstrainedTypes" +} + +class TestCasesProvider : ArgumentsProvider { + override fun provideArguments(context: ExtensionContext?): Stream = + EventStreamTestModels.TEST_CASES + .flatMap { testCase -> + listOf( + TestCase(testCase, publicConstrainedTypes = false), + TestCase(testCase, publicConstrainedTypes = true), + ) + }.map { Arguments.of(it) }.stream() +} diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamUnmarshallerGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamUnmarshallerGeneratorTest.kt index 08a5ef5f588..7d88d00ee60 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamUnmarshallerGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/eventstream/ServerEventStreamUnmarshallerGeneratorTest.kt @@ -7,20 +7,10 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols.eventstream import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.ArgumentsSource -import software.amazon.smithy.codegen.core.Symbol -import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget -import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock -import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol -import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.EventStreamUnmarshallerGenerator -import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestTools -import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestVariety -import software.amazon.smithy.rust.codegen.core.testutil.TestEventStreamProject -import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext -import software.amazon.smithy.rust.codegen.server.smithy.generators.serverBuilderSymbol +import software.amazon.smithy.rust.codegen.core.testutil.EventStreamUnmarshallTestCases.writeUnmarshallTestCases +import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams +import software.amazon.smithy.rust.codegen.core.testutil.testModule +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest class ServerEventStreamUnmarshallerGeneratorTest { @ParameterizedTest @@ -32,42 +22,16 @@ class ServerEventStreamUnmarshallerGeneratorTest { return } - EventStreamTestTools.runTestCase( - testCase.eventStreamTestCase, - object : ServerEventStreamBaseRequirements() { - override val publicConstrainedTypes: Boolean get() = testCase.publicConstrainedTypes - - override fun renderGenerator( - codegenContext: ServerCodegenContext, - project: TestEventStreamProject, - protocol: Protocol, - ): RuntimeType { - fun builderSymbol(shape: StructureShape): Symbol = shape.serverBuilderSymbol(codegenContext) - return EventStreamUnmarshallerGenerator( - protocol, - codegenContext, - project.operationShape, - project.streamShape, - ::builderSymbol, - ).render() - } - - // TODO(https://github.com/awslabs/smithy-rs/issues/1442): Delete this function override to use the correct builder from the parent class - override fun renderBuilderForShape( - writer: RustWriter, - codegenContext: ServerCodegenContext, - shape: StructureShape, - ) { - BuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape).apply { - render(writer) - writer.implBlock(shape, codegenContext.symbolProvider) { - renderConvenienceMethod(writer) - } - } - } - }, - CodegenTarget.SERVER, - EventStreamTestVariety.Unmarshall, - ) + serverIntegrationTest( + testCase.eventStreamTestCase.model, + IntegrationTestParams(service = "test#TestService", addModuleToEventStreamAllowList = true), + ) { _, rustCrate -> + rustCrate.testModule { + writeUnmarshallTestCases( + testCase.eventStreamTestCase, + optionalBuilderInputs = true, + ) + } + } } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/RecursiveConstraintViolationBoxerTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/RecursiveConstraintViolationBoxerTest.kt new file mode 100644 index 00000000000..622d3806039 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/transformers/RecursiveConstraintViolationBoxerTest.kt @@ -0,0 +1,31 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.transformers + +import io.kotest.matchers.shouldBe +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ArgumentsSource +import software.amazon.smithy.rust.codegen.core.util.hasTrait +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.RecursiveConstraintViolationsTest +import software.amazon.smithy.rust.codegen.server.smithy.traits.ConstraintViolationRustBoxTrait +import kotlin.streams.toList + +internal class RecursiveConstraintViolationBoxerTest { + @ParameterizedTest + @ArgumentsSource(RecursiveConstraintViolationsTest.RecursiveConstraintViolationsTestProvider::class) + fun `recursive constraint violation boxer test`(testCase: RecursiveConstraintViolationsTest.TestCase) { + val transformed = RecursiveConstraintViolationBoxer.transform(testCase.model) + + val shapesWithConstraintViolationRustBoxTrait = transformed.shapes().filter { + it.hasTrait() + }.toList() + + // Only the provided member shape should have the trait attached. + shapesWithConstraintViolationRustBoxTrait shouldBe + listOf(transformed.lookup(testCase.shapeIdWithConstraintViolationRustBoxTrait)) + } +} diff --git a/design/src/SUMMARY.md b/design/src/SUMMARY.md index 381b27987a2..ae53de0a473 100644 --- a/design/src/SUMMARY.md +++ b/design/src/SUMMARY.md @@ -54,6 +54,8 @@ - [RFC-0029: Finding New Home for Credential Types](./rfcs/rfc0029_new_home_for_cred_types.md) - [RFC-0030: Serialization And Deserialization](./rfcs/rfc0030_serialization_and_deserialization.md) - [RFC-0031: Providing Fallback Credentials on Timeout](./rfcs/rfc0031_providing_fallback_credentials_on_timeout.md) + - [RFC-0032: Better Constraint Violations](./rfcs/rfc0032_better_constraint_violations.md) + - [RFC-0033: Improving access to request IDs in SDK clients](./rfcs/rfc0033_improve_sdk_request_id_access.md) - [Contributing](./contributing/overview.md) - [Writing and debugging a low-level feature that relies on HTTP](./contributing/writing_and_debugging_a_low-level_feature_that_relies_on_HTTP.md) diff --git a/design/src/rfcs/overview.md b/design/src/rfcs/overview.md index 4109cd56fbd..e388ec49299 100644 --- a/design/src/rfcs/overview.md +++ b/design/src/rfcs/overview.md @@ -41,3 +41,5 @@ - [RFC-0029: Finding New Home for Credential Types](./rfc0029_new_home_for_cred_types.md) - [RFC-0030: Serialization And Deserialization](./rfc0030_serialization_and_deserialization.md) - [RFC-0031: Providing Fallback Credentials on Timeout](./rfc0031_providing_fallback_credentials_on_timeout.md) +- [RFC-0032: Better Constraint Violations](./rfc0032_better_constraint_violations.md) +- [RFC-0033: Improving access to request IDs in SDK clients](./rfc0033_improve_sdk_request_id_access.md) diff --git a/design/src/rfcs/rfc0020_service_builder.md b/design/src/rfcs/rfc0020_service_builder.md index a0fad97e305..35bf97f1f9f 100644 --- a/design/src/rfcs/rfc0020_service_builder.md +++ b/design/src/rfcs/rfc0020_service_builder.md @@ -862,4 +862,6 @@ A toy implementation of the combined proposal is presented in [this PR](https:// - - [x] Add code generation which outputs new service builder. - -- [ ] Deprecate `OperationRegistryBuilder`, `OperationRegistry` and `Router`. +- [x] Deprecate `OperationRegistryBuilder`, `OperationRegistry` and `Router`. + - + - diff --git a/design/src/rfcs/rfc0023_refine_builder.md b/design/src/rfcs/rfc0023_refine_builder.md index fb519ca4c7e..d93cd6bf77c 100644 --- a/design/src/rfcs/rfc0023_refine_builder.md +++ b/design/src/rfcs/rfc0023_refine_builder.md @@ -781,9 +781,12 @@ The API proposed in this RFC has been manually implemented for the Pokemon servi ## Changes checklist -- [ ] Update `codegen-server` to generate the proposed service builder API -- [ ] Implement `Pluggable` for `PluginStack` -- [ ] Evaluate the introduction of a `PluginBuilder` as the primary API to compose multiple plugins (instead of `PluginStack::new(IdentityPlugin, IdentityPlugin).apply(...)`) +- [x] Update `codegen-server` to generate the proposed service builder API + - +- [x] Implement `Pluggable` for `PluginStack` + - +- [x] Evaluate the introduction of a `PluginBuilder` as the primary API to compose multiple plugins (instead of `PluginStack::new(IdentityPlugin, IdentityPlugin).apply(...)`) + - [RFC 20]: rfc0020_service_builder.md [Pokemon service]: https://github.com/awslabs/smithy-rs/blob/c7ddb164b28b920313432789cfe05d8112a035cc/codegen-core/common-test-models/pokemon.smithy diff --git a/design/src/rfcs/rfc0032_better_constraint_violations.md b/design/src/rfcs/rfc0032_better_constraint_violations.md new file mode 100644 index 00000000000..d8648df3166 --- /dev/null +++ b/design/src/rfcs/rfc0032_better_constraint_violations.md @@ -0,0 +1,825 @@ +RFC: Better Constraint Violations +================================= + +> Status: Accepted +> +> Applies to: server + +During and after [the design][constraint-traits-rfc] and [the core +implementation][builders-of-builders-pr] of [constraint traits] in the server +SDK, some problems relating to constraint violations were identified. This RFC +sets out to explain and address three of them: [impossible constraint +violations](#impossible-constraint-violations), [collecting constraint +violations](#collecting-constraint-violations), and ["tightness" of constraint +violations](#tightness-of-constraint-violations). The RFC explains each of them +in turn, solving them in an iterative and pedagogical manner, i.e. the solution +of a problem depends on the previous ones having been solved with their +proposed solutions. The three problems are meant to be addressed atomically in +one changeset (see the [Checklist](#checklist)) section. + +Note: code snippets from generated SDKs in this document are abridged so as to +be didactic and relevant to the point being made. They are accurate with +regards to commit [`2226fe`]. + +[constraint-traits-rfc]: https://github.com/awslabs/smithy-rs/pull/1199 +[builders-of-builders-pr]: https://github.com/awslabs/smithy-rs/pull/1342 +[`2226fe`]: https://github.com/awslabs/smithy-rs/tree/2226feff8f7fa884204f81a50d7e016386912acc +[constraint traits]: https://awslabs.github.io/smithy/2.0/spec/constraint-traits.html + +Terminology +----------- + +[The design][constraint-traits-rfc] and the description of [the +PR][builders-of-builders-pr] where the core implementation of constraint traits +was made are recommended prior reading to understand this RFC. + +- **Shape closure**: the set of shapes a shape can "reach", including itself. +- **Transitively constrained shape**: a shape whose closure includes: + 1. a shape with a [constraint trait][constraint traits] attached, + 2. a (member) shape with a [`required` trait] attached, + 3. an [`enum` shape]; or + 4. an [`intEnum` shape]. +- A **directly constrained shape** is any of these: + 1. a shape with a [constraint trait][constraint traits] attached, + 2. a (member) shape with a [`required` trait] attached, + 3. an [`enum` shape], + 4. an [`intEnum` shape]; or + 5. a [`structure` shape] with at least one `required` member shape. +- **Constrained type**: the Rust type a constrained shape gets rendered as. For + shapes that are not `structure`, `union`, `enum` or `intEnum` shapes, these + are wrapper [newtype]s. + +[`required` trait]: https://smithy.io/2.0/spec/type-refinement-traits.html#required-trait +[`enum` shape]: https://smithy.io/2.0/spec/simple-types.html#enum +[`intEnum` shape]: https://smithy.io/2.0/spec/simple-types.html#intenum +[`structure` shape]: https://smithy.io/2.0/spec/aggregate-types.html#structure +[newtype]: https://doc.rust-lang.org/rust-by-example/generics/new_types.html + +In the absence of a qualifier, "constrained shape" should be interpreted as +"transitively constrained shape". + +Impossible constraint violations +-------------------------------- + +### Background + +A constrained type has a fallible constructor by virtue of it implementing the +[`TryFrom`] trait. The error type this constructor may yield is known as a +**constraint violation**: + +```rust +impl TryFrom for ConstrainedType { + type Error = ConstraintViolation; + + fn try_from(value: UnconstrainedType) -> Result { + ... + } +} +``` + +The `ConstraintViolation` type is a Rust `enum` with one variant per way +"constraining" the input value may fail. So, for example, the following Smithy +model: + +```smithy +structure A { + @required + member: String, +} +``` + +Yields: + +```rust +/// See [`A`](crate::model::A). +pub mod a { + #[derive(std::cmp::PartialEq, std::fmt::Debug)] + /// Holds one variant for each of the ways the builder can fail. + pub enum ConstraintViolation { + /// `member` was not provided but it is required when building `A`. + MissingMember, + } +} +``` + +Constraint violations are always Rust `enum`s, even if they only have one +variant. + +Constraint violations can occur in application code: + +```rust +use my_server_sdk::model + +let res = model::a::Builder::default().build(); // We forgot to set `member`. + +match res { + Ok(a) => { ... }, + Err(e) => { + assert_eq!(model::a::ConstraintViolation::MissingMember, e); + } +} +``` + +[`TryFrom`]: https://doc.rust-lang.org/std/convert/trait.TryFrom.html + +### Problem + +Currently, the constraint violation types we generate are used by _both_: + +1. the server framework upon request deserialization; and +2. by users in application code. + +However, the kinds of constraint violations that can occur in application code +can sometimes be a _strict subset_ of those that can occur during request +deserialization. + +Consider the following model: + +```smithy +@length(min: 1, max: 69) +map LengthMap { + key: String, + value: LengthString +} + +@length(min: 2, max: 69) +string LengthString +``` + +This produces: + +```rust +pub struct LengthMap( + pub(crate) std::collections::HashMap, +); + +impl + std::convert::TryFrom< + std::collections::HashMap, + > for LengthMap +{ + type Error = crate::model::length_map::ConstraintViolation; + + /// Constructs a `LengthMap` from an + /// [`std::collections::HashMap`], failing when the provided value does not + /// satisfy the modeled constraints. + fn try_from( + value: std::collections::HashMap, + ) -> Result { + let length = value.len(); + if (1..=69).contains(&length) { + Ok(Self(value)) + } else { + Err(crate::model::length_map::ConstraintViolation::Length(length)) + } + } +} + +pub mod length_map { + pub enum ConstraintViolation { + Length(usize), + Value( + std::string::String, + crate::model::length_string::ConstraintViolation, + ), + } + ... +} +``` + +Observe how the `ConstraintViolation::Value` variant is never constructed. +Indeed, this variant is impossible to be constructed _in application code_: a +user has to provide a map whose values are already constrained `LengthString`s +to the `try_from` constructor, which only enforces the map's `@length` trait. + +The reason why these seemingly "impossible violations" are being generated is +because they can arise during request deserialization. Indeed, the server +framework deserializes requests into **fully unconstrained types**. These are +types holding unconstrained types all the way through their closures. For +instance, in the case of structure shapes, builder types (the unconstrained +type corresponding to the structure shape) [hold +builders][builders-of-builders-pr] all the way down. + +In the case of the above model, below is the alternate `pub(crate)` constructor +the server framework uses upon deserialization. Observe how +`LengthMapOfLengthStringsUnconstrained` is _fully unconstrained_ and how the +`try_from` constructor can yield `ConstraintViolation::Value`. + +```rust +pub(crate) mod length_map_of_length_strings_unconstrained { + #[derive(Debug, Clone)] + pub(crate) struct LengthMapOfLengthStringsUnconstrained( + pub(crate) std::collections::HashMap, + ); + + impl std::convert::TryFrom + for crate::model::LengthMapOfLengthStrings + { + type Error = crate::model::length_map_of_length_strings::ConstraintViolation; + fn try_from(value: LengthMapOfLengthStringsUnconstrained) -> Result { + let res: Result< + std::collections::HashMap, + Self::Error, + > = value + .0 + .into_iter() + .map(|(k, v)| { + let v: crate::model::LengthString = k.try_into().map_err(Self::Error::Key)?; + + Ok((k, v)) + }) + .collect(); + let hm = res?; + Self::try_from(hm) + } + } +} +``` + +In conclusion, the user is currently exposed to an internal detail of how the +framework operates that has no bearing on their application code. They +shouldn't be exposed to impossible constraint violation variants in their Rust +docs, nor have to `match` on these variants when handling errors. + +Note: [this comment] alludes to the problem described above. + +[this comment]: https://github.com/awslabs/smithy-rs/blob/27020be3421fb93e35692803f9a795f92feb1d19/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/MapConstraintViolationGenerator.kt#L66-L69 + +### Solution proposal + +The problem can be mitigated by adding `#[doc(hidden)]` to the internal +variants and `#[non_exhaustive]` to the enum. We're already doing this in some +constraint violation types. + +However, a "less leaky" solution is achieved by _splitting_ the constraint +violation type into two types, which this RFC proposes: + +1. one for use by the framework, with `pub(crate)` visibility, named + `ConstraintViolationException`; and +2. one for use by user application code, with `pub` visibility, named + `ConstraintViolation`. + +```rust +pub mod length_map { + pub enum ConstraintViolation { + Length(usize), + } + pub (crate) enum ConstraintViolationException { + Length(usize), + Value( + std::string::String, + crate::model::length_string::ConstraintViolation, + ), + } +} +``` + +Note that, to some extent, the spirit of this approach is [already currently +present](https://github.com/awslabs/smithy-rs/blob/9a4c1f304f6f5237d480cfb56dad2951d927d424/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGenerator.kt#L78-L81) +in the case of builder types when `publicConstrainedTypes` is set to `false`: + +1. [`ServerBuilderGenerator.kt`] renders the usual builder type that enforces + constraint traits, setting its visibility to `pub (crate)`, for exclusive + use by the framework. +2. [`ServerBuilderGeneratorWithoutPublicConstrainedTypes.kt`] renders the + builder type the user is exposed to: this builder does not take in + constrained types and does not enforce all modeled constraints. + +[`ServerBuilderGenerator.kt`]: https://github.com/awslabs/smithy-rs/blob/2226feff8f7fa884204f81a50d7e016386912acc/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGenerator.kt +[`ServerBuilderGeneratorWithoutPublicConstrainedTypes.kt`]: https://github.com/awslabs/smithy-rs/blob/2226feff8f7fa884204f81a50d7e016386912acc/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorWithoutPublicConstrainedTypes.kt + +Collecting constraint violations +-------------------------------- + +### Background + +Constrained operations are currently required to have +`smithy.framework#ValidationException` as a member in their [`errors` +property](https://smithy.io/2.0/spec/service-types.html#operation). This is the +shape that is rendered in responses when a request contains data that violates +the modeled constraints. + +The shape is defined in the +[`smithy-validation-model`](https://search.maven.org/artifact/software.amazon.smithy/smithy-validation-model) +Maven package, [as +follows](https://github.com/awslabs/smithy/blob/main/smithy-validation-model/model/smithy.framework.validation.smithy): + +```smithy +$version: "2.0" + +namespace smithy.framework + +/// A standard error for input validation failures. +/// This should be thrown by services when a member of the input structure +/// falls outside of the modeled or documented constraints. +@error("client") +structure ValidationException { + + /// A summary of the validation failure. + @required + message: String, + + /// A list of specific failures encountered while validating the input. + /// A member can appear in this list more than once if it failed to satisfy multiple constraints. + fieldList: ValidationExceptionFieldList +} + +/// Describes one specific validation failure for an input member. +structure ValidationExceptionField { + /// A JSONPointer expression to the structure member whose value failed to satisfy the modeled constraints. + @required + path: String, + + /// A detailed description of the validation failure. + @required + message: String +} + +list ValidationExceptionFieldList { + member: ValidationExceptionField +} +``` + +It was mentioned in the [constraint traits +RFC](https://github.com/awslabs/smithy-rs/pull/1199#discussion_r809300673), and +implicit in the definition of Smithy's +[`smithy.framework.ValidationException`](https://github.com/awslabs/smithy/blob/main/smithy-validation-model/model/smithy.framework.validation.smithy) +shape, that server frameworks should respond with a _complete_ collection of +errors encountered during constraint trait enforcement to the client. + +### Problem + +As of writing, the `TryFrom` constructor of constrained types whose shapes have +more than one constraint trait attached can only yield a single error. For +example, the following shape: + +```smithy +@pattern("[a-f0-5]*") +@length(min: 5, max: 10) +string LengthPatternString +``` + +Yields: + +```rust +pub struct LengthPatternString(pub(crate) std::string::String); + +impl LengthPatternString { + fn check_length( + string: &str, + ) -> Result<(), crate::model::length_pattern_string::ConstraintViolation> { + let length = string.chars().count(); + + if (5..=10).contains(&length) { + Ok(()) + } else { + Err(crate::model::length_pattern_string::ConstraintViolation::Length(length)) + } + } + + fn check_pattern( + string: String, + ) -> Result { + let regex = Self::compile_regex(); + + if regex.is_match(&string) { + Ok(string) + } else { + Err(crate::model::length_pattern_string::ConstraintViolation::Pattern(string)) + } + } + + pub fn compile_regex() -> &'static regex::Regex { + static REGEX: once_cell::sync::Lazy = once_cell::sync::Lazy::new(|| { + regex::Regex::new(r#"[a-f0-5]*"#).expect(r#"The regular expression [a-f0-5]* is not supported by the `regex` crate; feel free to file an issue under https://github.com/awslabs/smithy-rs/issues for support"#) + }); + + ®EX + } +} + +impl std::convert::TryFrom for LengthPatternString { + type Error = crate::model::length_pattern_string::ConstraintViolation; + + /// Constructs a `LengthPatternString` from an [`std::string::String`], + /// failing when the provided value does not satisfy the modeled constraints. + fn try_from(value: std::string::String) -> Result { + Self::check_length(&value)?; + + let value = Self::check_pattern(value)?; + + Ok(Self(value)) + } +} +``` + +Observe how a failure to adhere to the `@length` trait will short-circuit the +evaluation of the constructor, when the value could technically also not adhere +with the `@pattern` trait. + +Similarly, constrained structures fail upon encountering the first member that +violates a constraint. + +Additionally, _in framework request deserialization code_: + +- collections whose members are constrained fail upon encountering the first + member that violates the constraint, +- maps whose keys and/or values are constrained fail upon encountering the + first violation; and +- structures whose members are constrained fail upon encountering the first + member that violates the constraint, + +In summary, any shape that is transitively constrained yields types whose +constructors (both the internal one and the user-facing one) currently +short-circuit upon encountering the first violation. + +### Solution proposal + +The deserializing architecture lends itself to be easily refactored so that we +can collect constraint violations before returning them. Indeed, note that +deserializers enforce constraint traits in a two-step phase: first, the +_entirety_ of the unconstrained value is deserialized, _then_ constraint traits +are enforced by feeding the entire value to the `TryFrom` constructor. + +Let's consider a `ConstraintViolations` type (note the plural) that represents +a collection of constraint violations that can occur _within user application +code_. Roughly: + +```rust +pub ConstraintViolations(pub(crate) Vec); + +impl IntoIterator for ConstraintViolations { ... } + +impl std::convert::TryFrom for LengthPatternString { + type Error = ConstraintViolations; + + fn try_from(value: std::string::String) -> Result { + // Check constraints and collect violations. + ... + } +} +``` + +- The main reason for wrapping a vector in `ConstraintViolations` as opposed to + directly returning the vector is forwards-compatibility: we may want to + expand `ConstraintViolations` with conveniences. +- If the constrained type can only ever yield a single violation, we will + dispense with `ConstraintViolations` and keep directly returning the + `crate::model::shape_name::ConstraintViolation` type. + +We will analogously introduce a `ConstraintViolationExceptions` type that +represents a collection of constraint violations that can occur _within the +framework's request deserialization code_. This type will be `pub(crate)` and +will be the one the framework will map to Smithy's `ValidationException` that +eventually gets serialized into the response. + +#### Collecting constraint violations may constitute a DOS attack vector + +This is a problem that _already_ exists as of writing, but that collecting +constraint violations highlights, so it is a good opportunity, from a +pedagogical perspective, to explain it here. Consider the following model: + +```smithy +@length(max: 3) +list ListOfPatternStrings { + member: PatternString +} + +@pattern("expensive regex to evaluate") +string PatternString +``` + +Our implementation currently enforces constraints _from the leaf to the root_: +when enforcing the `@length` constraint, the `TryFrom` constructor the server +framework uses gets a `Vec` and _first_ checks the members adhere to +the `@pattern` trait, and only _after_ is the `@length` trait checked. This +means that if a client sends a request with `n >>> 3` list members, the +expensive check runs `n` times, when a constant-time check inspecting the +length of the input vector would have sufficed to reject the request. +Additionally, we may want to avoid serializing `n` `ValidationExceptionField`s +due to performance concerns. + +1. A possibility to circumvent this is making the `@length` validator special, + having it bound the other validators via effectively permuting the order of + the checks and thus short-circuiting. + * In general, it's unclear what constraint traits should cause + short-circuiting. A probably reasonable rule of thumb is to include + traits that can be attached directly to aggregate shapes: as of writing, + that would be `@uniqueItems` on list shapes and `@length` on list shapes. +1. Another possiblity is to _do nothing_ and value _complete_ validation + exception response messages over trying to mitigate this with special + handling. One could argue that these kind of DOS attack vectors should be + taken care of with a separate solution e.g. a layer that bounds a request + body's size to a reasonable default (see [how Axum added + this](https://github.com/tokio-rs/axum/pull/1420)). We will provide a similar + request body limiting mechanism regardless. + +This RFC advocates for implementing the first option, arguing that [it's fair +to say that the framework should return an error that is as informative as +possible, but it doesn't necessarily have to be +complete](https://github.com/awslabs/smithy-rs/pull/2040#discussion_r1036226762). +However, we will also write a layer, applied by default to all server SDKs, +that bounds a request body's size to a reasonable (yet high) default. Relying +on users to manually apply the layer is dangerous, since such a configuration +is [trivially +exploitable]. +Users can always manually apply the layer again to their resulting service if +they want to further restrict a request's body size. + +[trivially exploitable]: https://jfrog.com/blog/watch-out-for-dos-when-using-rusts-popular-hyper-package/ + +"Tightness" of constraint violations +------------------------------------ + +### Problem + +`ConstraintViolationExceptions` [is not +"tight"](https://www.ecorax.net/tightness/) in that there's nothing in the type +system that indicates to the user, when writing the custom validation error +mapping function, that the iterator will not return a sequence of +`ConstraintViolationException`s that is actually impossible to occur in +practice. + +Recall that `ConstraintViolationException`s are `enum`s that model both direct +constraint violations as well as transitive ones. For example, given the model: + +```smithy +@length(min: 1, max: 69) +map LengthMap { + key: String, + value: LengthString +} + +@length(min: 2, max: 69) +string LengthString +``` + +The corresponding `ConstraintViolationException` Rust type for the `LengthMap` +shape is: + +```rust +pub mod length_map { + pub enum ConstraintViolation { + Length(usize), + } + pub (crate) enum ConstraintViolationException { + Length(usize), + Value( + std::string::String, + crate::model::length_string::ConstraintViolationException, + ), + } +} +``` + +`ConstraintViolationExceptions` is just a container over this type: + +```rust +pub ConstraintViolationExceptions(pub(crate) Vec); + +impl IntoIterator for ConstraintViolationExceptions { ... } +``` + +There might be multiple map values that fail to adhere to the constraints in +`LengthString`, which would make the iterator yield multiple +`length_map::ConstraintViolationException::Value`s; however, at most one +`length_map::ConstraintViolationException::Length` can be yielded _in +practice_. This might be obvious to the service owner when inspecting the model +and the Rust docs, but it's not expressed in the type system. + +The above tightness problem has been formulated in terms of +`ConstraintViolationExceptions`, because the fact that +`ConstraintViolationExceptions` contain transitive constraint violations +highlights the tightness problem. Note, however, that **the tightness problem +also afflicts `ConstraintViolations`**. + +Indeed, consider the following model: + +```smithy +@pattern("[a-f0-5]*") +@length(min: 5, max: 10) +string LengthPatternString +``` + +This would yield: + +```rust +pub ConstraintViolations(pub(crate) Vec); + +impl IntoIterator for ConstraintViolations { ... } + +pub mod length_pattern_string { + pub enum ConstraintViolation { + Length(usize), + Pattern(String) + } +} + +impl std::convert::TryFrom for LengthPatternString { + type Error = ConstraintViolations; + + fn try_from(value: std::string::String) -> Result { + // Check constraints and collect violations. + ... + } +} +``` + +Observe how the iterator of an instance of +`ConstraintViolations`, +may, a priori, yield e.g. the +`length_pattern_string::ConstraintViolation::Length` variant twice, when it's +clear that the iterator should contain _at most one_ of each of +`length_pattern_string::ConstraintViolation`'s variants. + +### Final solution proposal + +We propose a tighter API design. + +1. We substitute `enum`s for `struct`s whose members are all `Option`al, + representing all the constraint violations that can occur. +1. For list shapes and map shapes: + 1. we implement `IntoIterator` on an additional `struct` `Members` + representing only the violations that can occur on the collection's + members. + 2. we add a _non_ `Option`-al field to the `struct` representing the + constraint violations of type `Members`. + +Let's walk through an example. Take the last model: + +```smithy +@pattern("[a-f0-5]*") +@length(min: 5, max: 10) +string LengthPatternString +``` + +This would yield, as per the first substitution: + +```rust +pub mod length_pattern_string { + pub struct ConstraintViolations { + pub length: Option, + pub pattern: Option, + } + + pub mod constraint_violation { + pub struct Length(usize); + pub struct Pattern(String); + } +} + +impl std::convert::TryFrom for LengthPatternString { + type Error = length_pattern_string::ConstraintViolations; + + // The error type returned by this constructor, `ConstraintViolations`, + // will always have _at least_ one member set. + fn try_from(value: std::string::String) -> Result { + // Check constraints and collect violations. + ... + } +} +``` + +We now expand the model to highlight the second step of the algorithm: + +```smithy +@length(min: 1, max: 69) +map LengthMap { + key: String, + value: LengthString +} +``` + +This gives us: + +```rust +pub mod length_map { + pub struct ConstraintViolations { + pub length: Option, + + // Would be `Option` in the case of an aggregate shape that is _not_ a + // list shape or a map shape. + pub member_violations: constraint_violation::Members, + } + + pub mod constraint_violation { + // Note that this could now live outside the `length_map` module and be + // reused across all `@length`-constrained shapes, if we expanded it with + // another `usize` indicating the _modeled_ value in the `@length` trait; by + // keeping it inside `length_map` we can hardcode that value in the + // implementation of e.g. error messages. + pub struct Length(usize); + + pub struct Members { + pub(crate) Vec + } + + pub struct Member { + // If the map's key shape were constrained, we'd have a `key` + // field here too. + + value: Option + } + + pub struct Value( + std::string::String, + crate::model::length_string::ConstraintViolation, + ); + + impl IntoIterator for Members { ... } + } +} +``` + +--- + +The above examples have featured the tight API design with +`ConstraintViolation`s. Of course, we will apply the same design in the case of +`ConstraintViolationException`s. For the sake of completeness, let's expand our +model yet again with a structure shape: + +```smithy +structure A { + @required + member: String, + + @required + length_map: LengthMap, +} +``` + +And this time let's feature _both_ the resulting +`ConstraintViolationExceptions` and `ConstraintViolations` types: + +```rust +pub mod a { + pub struct ConstraintViolationExceptions { + // All fields must be `Option`, despite the members being `@required`, + // since no violations for their values might have occurred. + + pub missing_member_exception: Option, + pub missing_length_map_exception: Option, + pub length_map_exceptions: Option, + } + + pub mod constraint_violation_exception { + pub struct MissingMember; + pub struct MissingLengthMap; + } + + pub struct ConstraintViolations { + pub missing_member: Option, + pub missing_length_map: Option, + } + + pub mod constraint_violation { + pub struct MissingMember; + pub struct MissingLengthMap; + } +} +``` + +As can be intuited, the only differences are that: + +* `ConstraintViolationExceptions` hold transitive violations while + `ConstraintViolations` only need to expose direct violations (as explained in + the [Impossible constraint violations](#impossible-constraint-violations) + section), +* `ConstraintViolationExceptions` have members suffixed with `_exception`, as + is the module name. + +Note that while the constraint violation (exception) type names are plural, the +module names are always singular. + +We also make a conscious decision of, in this case of structure shapes, making +the types of all members `Option`s, for simplicity. Another choice would have +been to make `length_map_exceptions` not `Option`-al, and, in the case where no +violations in `LengthMap` values occurred, set +`length_map::ConstraintViolations::length` to `None` and +`length_map::ConstraintViolations::member_violations` eventually reach an empty +iterator. However, it's best that we use the expressiveness of `Option`s at the +earliest ("highest" in the shape hierarchy) opportunity: if a member is `Some`, +it means it (eventually) reaches data. + +Checklist +--------- + +Unfortunately, while this RFC _could_ be implemented iteratively (i.e. solve +each of the problems in turn), it would introduce too much churn and throwaway +work: solving the tightness problem requires a more or less complete overhaul +of the constraint violations code generator. It's best that all three problems +be solved in the same changeset. + +- [ ] Generate `ConstraintViolations` and `ConstraintViolationExceptions` types + so as to not reify [impossible constraint + violations](#impossible-constraint-violations), add the ability to [collect + constraint + violations](#collecting-constraint-violations), and solve the ["tightness" problem of constraint violations](#tightness-of-constraint-violations). +- [ ] Special-case generated request deserialization code for operations + using `@length` and `@uniqueItems` constrained shapes whose closures reach + other constrained shapes so that the validators for these two traits + short-circuit upon encountering a number of inner constraint violations + above a certain threshold. +- [ ] Write and expose a layer, applied by default to all generated server SDKs, + that bounds a request body's size to a reasonable (yet high) default, to prevent [trivial DoS attacks][trivially exploitable]. diff --git a/design/src/rfcs/rfc0033_improve_sdk_request_id_access.md b/design/src/rfcs/rfc0033_improve_sdk_request_id_access.md new file mode 100644 index 00000000000..3e78f4070c7 --- /dev/null +++ b/design/src/rfcs/rfc0033_improve_sdk_request_id_access.md @@ -0,0 +1,245 @@ +RFC: Improving access to request IDs in SDK clients +=================================================== + +> Status: Implemented in [#2129](https://github.com/awslabs/smithy-rs/pull/2129) +> +> Applies to: AWS SDK clients + +At time of writing, customers can retrieve a request ID in one of four ways in the Rust SDK: + +1. For error cases where the response parsed successfully, the request ID can be retrieved + via accessor method on operation error. This also works for unmodeled errors so long as + the response parsing succeeds. +2. For error cases where a response was received but parsing fails, the response headers + can be retrieved from the raw response on the error, but customers have to manually extract + the request ID from those headers (there's no convenient accessor method). +3. For all error cases where the request ID header was sent in the response, customers can + call `SdkError::into_service_error` to transform the `SdkError` into an operation error, + which has a `request_id` accessor on it. +4. For success cases, the customer can't retrieve the request ID at all if they use the fluent + client. Instead, they must manually make the operation and call the underlying Smithy client + so that they have access to `SdkSuccess`, which provides the raw response where the request ID + can be manually extracted from headers. + +Only one of these mechanisms is convenient and ergonomic. The rest need considerable improvements. +Additionally, the request ID should be attached to tracing events where possible so that enabling +debug logging reveals the request IDs without any code changes being necessary. + +This RFC proposes changes to make the request ID easier to access. + +Terminology +----------- + +- **Request ID:** A unique identifier assigned to and associated with a request to AWS that is + sent back in the response headers. This identifier is useful to customers when requesting support. +- **Operation Error:** Operation errors are code generated for each operation in a Smithy model. + They are an enum of every possible modeled error that that operation can respond with, as well + as an `Unhandled` variant for any unmodeled or unrecognized errors. +- **Modeled Errors:** Any error that is represented in a Smithy model with the `@error` trait. +- **Unmodeled Errors:** Errors that a service responds with that do not appear in the Smithy model. +- **SDK Clients:** Clients generated for the AWS SDK, including "adhoc" or "one-off" clients. +- **Smithy Clients:** Any clients not generated for the AWS SDK, excluding "adhoc" or "one-off" clients. + +SDK/Smithy Purity +----------------- + +Before proposing any changes, the topic of purity needs to be covered. Request IDs are not +currently a Smithy concept. However, at time of writing, the request ID concept is leaked into +the non-SDK rust runtime crates and generated code via the [generic error] struct and the +`request_id` functions on generated operation errors (e.g., [`GetObjectError` example in S3]). + +This RFC attempts to remove these leaks from Smithy clients. + +Proposed Changes +---------------- + +First, we'll explore making it easier to retrieve a request ID from errors, +and then look at making it possible to retrieve them from successful responses. +To see the customer experience of these changes, see the **Example Interactions** +section below. + +### Make request ID retrieval on errors consistent + +One could argue that customers being able to convert a `SdkError` into an operation error +that has a request ID on it is sufficient. However, there's no way to write a function +that takes an error from any operation and logs a request ID, so it's still not ideal. + +The `aws-http` crate needs to have a `RequestId` trait on it to facilitate generic +request ID retrieval: + +```rust +pub trait RequestId { + /// Returns the request ID if it's available. + fn request_id(&self) -> Option<&str>; +} +``` + +This trait will be implemented for `SdkError` in `aws-http` where it is declared, +complete with logic to pull the request ID header out of the raw HTTP responses +(it will always return `None` for event stream `Message` responses; an additional +trait may need to be added to `aws-smithy-http` to facilitate access to the headers). +This logic will try different request ID header names in order of probability +since AWS services have a couple of header name variations. `x-amzn-requestid` is +the most common, with `x-amzn-request-id` being the second most common. + +`aws-http` will also implement `RequestId` for `aws_smithy_types::error::Error`, +and the `request_id` method will be removed from `aws_smithy_types::error::Error`. +Places that construct `Error` will place the request ID into its `extras` field, +where the `RequestId` trait implementation can retrieve it. + +A codegen decorator will be added to `sdk-codegen` to implement `RequestId` for +operation errors, and the existing `request_id` accessors will be removed from +`CombinedErrorGenerator` in `codegen-core`. + +With these changes, customers can directly access request IDs from `SdkError` and +operations errors by importing the `RequestId` trait. Additionally, the Smithy/SDK +purity is improved since both places where request IDs are leaked to Smithy clients +will be resolved. + +### Implement `RequestId` for outputs + +To make it possible to retrieve request IDs when using the fluent client, the new +`RequestId` trait can be implemented for outputs. + +Some services (e.g., Transcribe Streaming) model the request ID header in their +outputs, while other services (e.g., Directory Service) model a request ID +field on errors. In some cases, services take `RequestId` as a modeled input +(e.g., IoT Event Data). It follows that it is possible, but unlikely, that +a service could have a field named `RequestId` that is not the same concept +in the future. + +Thus, name collisions are going to be a concern for putting a request ID accessor +on output. However, if it is implemented as a trait, then this concern is partially +resolved. In the vast majority of cases, importing `RequestId` will provide the +accessor without any confusion. In cases where it is already modeled and is the +same concept, customers will likely just use it and not even realize they didn't +import the trait. The only concern is future cases where it is modeled as a +separate concept, and as long as customers don't import `RequestId` for something +else in the same file, that confusion can be avoided. + +In order to implement `RequestId` for outputs, either the original response needs +to be stored on the output, or the request ID needs to be extracted earlier and +stored on the output. The latter will lead to a small amount of header lookup +code duplication. + +In either case, the `StructureGenerator` needs to be customized in `sdk-codegen` +(Appendix B outlines an alternative approach to this and why it was dismissed). +This will be done by adding customization hooks to `StructureGenerator` similar +to the ones for `ServiceConfigGenerator` so that a `sdk-codegen` decorator can +conditionally add fields and functions to any generated structs. A hook will +also be needed to additional trait impl blocks. + +Once the hooks are in place, a decorator will be added to store either the original +response or the request ID on outputs, and then the `RequestId` trait will be +implemented for them. The `ParseResponse` trait implementation will be customized +to populate this new field. + +Note: To avoid name collisions of the request ID or response on the output struct, +these fields can be prefixed with an underscore. It shouldn't be possible for SDK +fields to code generate with this prefix given the model validation rules in place. + +### Implement `RequestId` for `Operation` and `operation::Response` + +In the case that a customer wants to ditch the fluent client, it should still +be easy to retrieve a request ID. To do this, `aws-http` will provide `RequestId` +implementations for `Operation` and `operation::Response`. These implementations +will likely make the other `RequestId` implementations easier to implement as well. + +### Implement `RequestId` for `Result` + +The `Result` returned by the SDK should directly implement `RequestId` when both +its `Ok` and `Err` variants implement `RequestId`. This will make it possible +for a customer to feed the return value from `send()` directly to a request ID logger. + +Example Interactions +-------------------- + +### Generic Handling Case + +```rust +// A re-export of the RequestId trait +use aws_sdk_service::primitives::RequestId; + +fn my_request_id_logging_fn(request_id: &dyn RequestId) { + println!("request ID: {:?}", request_id.request_id()); +} + +let result = client.some_operation().send().await?; +my_request_id_logging_fn(&result); +``` + +### Success Case + +```rust +use aws_sdk_service::primitives::RequestId; + +let output = client.some_operation().send().await?; +println!("request ID: {:?}", output.request_id()); +``` + +### Error Case with `SdkError` + +```rust +use aws_sdk_service::primitives::RequestId; + +match client.some_operation().send().await { + Ok(_) => { /* handle OK */ } + Err(err) => { + println!("request ID: {:?}", output.request_id()); + } +} +``` + +### Error Case with operation error + +```rust +use aws_sdk_service::primitives::RequestId; + +match client.some_operation().send().await { + Ok(_) => { /* handle OK */ } + Err(err) => match err.into_service_err() { + err @ SomeOperationError::SomeError(_) => { println!("request ID: {:?}", err.request_id()); } + _ => { /* don't care */ } + } +} +``` + +Changes Checklist +----------------- + +- [x] Create the `RequestId` trait in `aws-http` +- [x] Implement for errors + - [x] Implement `RequestId` for `SdkError` in `aws-http` + - [x] Remove `request_id` from `aws_smithy_types::error::Error`, and store request IDs in its `extras` instead + - [x] Implement `RequestId` for `aws_smithy_types::error::Error` in `aws-http` + - [x] Remove generation of `request_id` accessors from `CombinedErrorGenerator` in `codegen-core` +- [x] Implement for outputs + - [x] Add customization hooks to `StructureGenerator` + - [x] Add customization hook to `ParseResponse` + - [x] Add customization hook to `HttpBoundProtocolGenerator` + - [x] Customize output structure code gen in `sdk-codegen` to add either a request ID or a response field + - [x] Customize `ParseResponse` in `sdk-codegen` to populate the outputs +- [x] Implement `RequestId` for `Operation` and `operation::Response` +- [x] Implement `RequestId` for `Result` where `O` and `E` both implement `RequestId` +- [x] Re-export `RequestId` in generated crates +- [x] Add integration tests for each request ID access point + +Appendix A: Alternate solution for access on successful responses +----------------------------------------------------------------- + +Alternatively, for successful responses, a second `send` method (that is difficult to name)w +be added to the fluent client that has a return value that includes both the output and +the request ID (or entire response). + +This solution was dismissed due to difficulty naming, and the risk of name collision. + +Appendix B: Adding `RequestId` as a string to outputs via model transform +------------------------------------------------------------------------- + +The request ID could be stored on outputs by doing a model transform in `sdk-codegen` to add a +`RequestId` member field. However, this causes problems when an output already has a `RequestId` field, +and requires the addition of a synthetic trait to skip binding the field in the generated +serializers/deserializers. + +[generic error]: https://docs.rs/aws-smithy-types/0.51.0/aws_smithy_types/error/struct.Error.html +[`GetObjectError` example in S3]: https://docs.rs/aws-sdk-s3/0.21.0/aws_sdk_s3/error/struct.GetObjectError.html#method.request_id diff --git a/design/src/server/code_generation.md b/design/src/server/code_generation.md index 20da6c46c61..b04008cb46f 100644 --- a/design/src/server/code_generation.md +++ b/design/src/server/code_generation.md @@ -7,13 +7,15 @@ This document introduces the project and how code is being generated. It is writ The project is divided in: -- `/codegen`: it contains shared code for both client and server, but only generates a client -- `/codegen-server`: server only. This project started with `codegen` to generate a client, but client and server share common code; that code lives in `codegen`, which `codegen-server` depends on +- `/codegen-core`: contains common code to be used for both client and server code generation +- `/codegen-client`: client code generation. Depends on `codegen-core` +- `/codegen-server`: server code generation. Depends on `codegen-core` - `/aws`: the AWS Rust SDK, it deals with AWS services specifically. The folder structure reflects the project's, with the `rust-runtime` and the `codegen` - `/rust-runtime`: the generated client and server crates may depend on crates in this folder. Crates here are not code generated. The only crate that is not published is `inlineable`, which contains common functions used by other crates, [copied into][2] the source crate -`/rust-runtime` crates ("runtime crates") are added to a crate's dependency only when used. If a model uses event streams, it will depend on [`aws-smithy-eventstream`][3]. +Crates in `/rust-runtime` (informally referred to as "runtime crates") are added to a crate's dependency only when used. +For example, if a model uses event streams, the generated crates will depend on [`aws-smithy-eventstream`][3]. ## Generating code diff --git a/gradle.properties b/gradle.properties index 9857dcc8951..a353e421aff 100644 --- a/gradle.properties +++ b/gradle.properties @@ -4,8 +4,11 @@ # # Rust MSRV (entered into the generated README) -rust.msrv=1.62.1 +rust.msrv=1.63.0 +# To enable debug, swap out the two lines below. +# When changing this value, be sure to run `./gradlew --stop` to kill the Gradle daemon. +# org.gradle.jvmargs=-Xmx1024M -agentlib:jdwp=transport=dt_socket,server=y,suspend=y,address=localhost:5006 org.gradle.jvmargs=-Xmx1024M # Version number to use for the generated runtime crates @@ -18,10 +21,10 @@ smithyGradlePluginVersion=0.6.0 smithyVersion=1.26.2 # kotlin -kotlinVersion=1.6.21 +kotlinVersion=1.7.21 # testing/utility -ktlintVersion=0.46.1 +ktlintVersion=0.48.2 kotestVersion=5.2.3 # Avoid registering dependencies/plugins/tasks that are only used for testing purposes isTestingEnabled=true diff --git a/rust-runtime/Cargo.toml b/rust-runtime/Cargo.toml index 185d7082a0a..6a53b080e99 100644 --- a/rust-runtime/Cargo.toml +++ b/rust-runtime/Cargo.toml @@ -7,6 +7,7 @@ members = [ "aws-smithy-checksums", "aws-smithy-eventstream", "aws-smithy-http", + "aws-smithy-http-auth", "aws-smithy-http-tower", "aws-smithy-json", "aws-smithy-protocol-test", diff --git a/rust-runtime/aws-smithy-async/src/lib.rs b/rust-runtime/aws-smithy-async/src/lib.rs index 6cd95109e27..b6b4951afc2 100644 --- a/rust-runtime/aws-smithy-async/src/lib.rs +++ b/rust-runtime/aws-smithy-async/src/lib.rs @@ -3,6 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +#![allow(clippy::derive_partial_eq_without_eq)] #![warn( missing_debug_implementations, missing_docs, diff --git a/rust-runtime/aws-smithy-checksums/Cargo.toml b/rust-runtime/aws-smithy-checksums/Cargo.toml index b210360d556..ed770fe7997 100644 --- a/rust-runtime/aws-smithy-checksums/Cargo.toml +++ b/rust-runtime/aws-smithy-checksums/Cargo.toml @@ -29,7 +29,7 @@ tracing = "0.1" [dev-dependencies] bytes-utils = "0.1.2" -pretty_assertions = "1.2" +pretty_assertions = "1.3" tokio = { version = "1.8.4", features = ["macros", "rt"] } tracing-test = "0.2.1" diff --git a/rust-runtime/aws-smithy-checksums/src/lib.rs b/rust-runtime/aws-smithy-checksums/src/lib.rs index ce422fe5ee9..95789ad7c1e 100644 --- a/rust-runtime/aws-smithy-checksums/src/lib.rs +++ b/rust-runtime/aws-smithy-checksums/src/lib.rs @@ -3,6 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ +#![allow(clippy::derive_partial_eq_without_eq)] + //! Checksum calculation and verification callbacks. use crate::error::UnknownChecksumAlgorithmError; @@ -379,8 +381,7 @@ mod tests { fn test_checksum_algorithm_returns_error_for_unknown() { let error = "some invalid checksum algorithm" .parse::() - .err() - .expect("it should error"); + .expect_err("it should error"); assert_eq!( "some invalid checksum algorithm", error.checksum_algorithm() diff --git a/rust-runtime/aws-smithy-client/src/bounds.rs b/rust-runtime/aws-smithy-client/src/bounds.rs index 0cddf86096e..e0abf9e4dfc 100644 --- a/rust-runtime/aws-smithy-client/src/bounds.rs +++ b/rust-runtime/aws-smithy-client/src/bounds.rs @@ -7,7 +7,7 @@ //! required for `call` and friends. //! //! The short-hands will one day be true [trait aliases], but for now they are traits with blanket -//! implementations. Also, due to [compiler limitations], the bounds repeat a nubmer of associated +//! implementations. Also, due to [compiler limitations], the bounds repeat a number of associated //! types with bounds so that those bounds [do not need to be repeated] at the call site. It's a //! bit of a mess to define, but _should_ be invisible to callers. //! @@ -17,8 +17,12 @@ use crate::erase::DynConnector; use crate::http_connector::HttpConnector; -use crate::*; -use aws_smithy_http::result::ConnectorError; +use aws_smithy_http::body::SdkBody; +use aws_smithy_http::operation::{self, Operation}; +use aws_smithy_http::response::ParseHttpResponse; +use aws_smithy_http::result::{ConnectorError, SdkError, SdkSuccess}; +use aws_smithy_http::retry::ClassifyRetry; +use tower::{Layer, Service}; /// A service that has parsed a raw Smithy response. pub type Parsed = @@ -75,14 +79,14 @@ where } } -/// A Smithy middleware service that adjusts [`aws_smithy_http::operation::Request`]s. +/// A Smithy middleware service that adjusts [`aws_smithy_http::operation::Request`](operation::Request)s. /// /// This trait has a blanket implementation for all compatible types, and should never be /// implemented. pub trait SmithyMiddlewareService: Service< - aws_smithy_http::operation::Request, - Response = aws_smithy_http::operation::Response, + operation::Request, + Response = operation::Response, Error = aws_smithy_http_tower::SendOperationError, Future = ::Future, > @@ -96,8 +100,8 @@ pub trait SmithyMiddlewareService: impl SmithyMiddlewareService for T where T: Service< - aws_smithy_http::operation::Request, - Response = aws_smithy_http::operation::Response, + operation::Request, + Response = operation::Response, Error = aws_smithy_http_tower::SendOperationError, >, T::Future: Send + 'static, @@ -143,7 +147,7 @@ pub trait SmithyRetryPolicy: /// Forwarding type to `E` for bound inference. /// /// See module-level docs for details. - type E: Error; + type E: std::error::Error; /// Forwarding type to `Retry` for bound inference. /// @@ -155,7 +159,7 @@ impl SmithyRetryPolicy for R where R: tower::retry::Policy, SdkSuccess, SdkError> + Clone, O: ParseHttpResponse> + Send + Sync + Clone + 'static, - E: Error, + E: std::error::Error, Retry: ClassifyRetry, SdkError>, { type O = O; diff --git a/rust-runtime/aws-smithy-client/src/conns.rs b/rust-runtime/aws-smithy-client/src/conns.rs new file mode 100644 index 00000000000..1aac78cf467 --- /dev/null +++ b/rust-runtime/aws-smithy-client/src/conns.rs @@ -0,0 +1,129 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! Type aliases for standard connection types. + +#[cfg(feature = "rustls")] +/// A `hyper` connector that uses the `rustls` crate for TLS. To use this in a smithy client, +/// wrap it in a [hyper_ext::Adapter](crate::hyper_ext::Adapter). +pub type Https = hyper_rustls::HttpsConnector; + +#[cfg(feature = "native-tls")] +/// A `hyper` connector that uses the `native-tls` crate for TLS. To use this in a smithy client, +/// wrap it in a [hyper_ext::Adapter](crate::hyper_ext::Adapter). +pub type NativeTls = hyper_tls::HttpsConnector; + +#[cfg(feature = "rustls")] +/// A smithy connector that uses the `rustls` crate for TLS. +pub type Rustls = crate::hyper_ext::Adapter; + +// Creating a `with_native_roots` HTTP client takes 300ms on OS X. Cache this so that we +// don't need to repeatedly incur that cost. +#[cfg(feature = "rustls")] +lazy_static::lazy_static! { + static ref HTTPS_NATIVE_ROOTS: Https = { + hyper_rustls::HttpsConnectorBuilder::new() + .with_native_roots() + .https_or_http() + .enable_http1() + .enable_http2() + .build() + }; +} + +#[cfg(feature = "rustls")] +/// Return a default HTTPS connector backed by the `rustls` crate. +/// +/// It requires a minimum TLS version of 1.2. +/// It allows you to connect to both `http` and `https` URLs. +pub fn https() -> Https { + HTTPS_NATIVE_ROOTS.clone() +} + +#[cfg(feature = "native-tls")] +/// Return a default HTTPS connector backed by the `hyper_tls` crate. +/// +/// It requires a minimum TLS version of 1.2. +/// It allows you to connect to both `http` and `https` URLs. +pub fn native_tls() -> NativeTls { + // `TlsConnector` actually comes for here: https://docs.rs/native-tls/latest/native_tls/ + // hyper_tls just re-exports the crate for convenience. + let mut tls = hyper_tls::native_tls::TlsConnector::builder(); + let tls = tls + .min_protocol_version(Some(hyper_tls::native_tls::Protocol::Tlsv12)) + .build() + .unwrap_or_else(|e| panic!("Error while creating TLS connector: {}", e)); + let mut http = hyper::client::HttpConnector::new(); + http.enforce_http(false); + hyper_tls::HttpsConnector::from((http, tls.into())) +} + +#[cfg(all(test, any(feature = "native-tls", feature = "rustls")))] +mod tests { + use crate::erase::DynConnector; + use crate::hyper_ext::Adapter; + use aws_smithy_http::body::SdkBody; + use http::{Method, Request, Uri}; + use tower::{Service, ServiceBuilder}; + + async fn send_request_and_assert_success(conn: DynConnector, uri: &Uri) { + let mut svc = ServiceBuilder::new().service(conn); + let req = Request::builder() + .uri(uri) + .method(Method::GET) + .body(SdkBody::empty()) + .unwrap(); + let res = svc.call(req).await.unwrap(); + assert!(res.status().is_success()); + } + + #[cfg(feature = "native-tls")] + mod native_tls_tests { + use super::super::native_tls; + use super::*; + + #[tokio::test] + async fn test_native_tls_connector_can_make_http_requests() { + let conn = Adapter::builder().build(native_tls()); + let conn = DynConnector::new(conn); + let http_uri: Uri = "http://example.com/".parse().unwrap(); + + send_request_and_assert_success(conn, &http_uri).await; + } + + #[tokio::test] + async fn test_native_tls_connector_can_make_https_requests() { + let conn = Adapter::builder().build(native_tls()); + let conn = DynConnector::new(conn); + let https_uri: Uri = "https://example.com/".parse().unwrap(); + + send_request_and_assert_success(conn, &https_uri).await; + } + } + + #[cfg(feature = "rustls")] + mod rustls_tests { + use super::super::https; + use super::*; + + #[tokio::test] + async fn test_rustls_connector_can_make_http_requests() { + let conn = Adapter::builder().build(https()); + let conn = DynConnector::new(conn); + let http_uri: Uri = "http://example.com/".parse().unwrap(); + + send_request_and_assert_success(conn, &http_uri).await; + } + + #[tokio::test] + async fn test_rustls_connector_can_make_https_requests() { + let conn = Adapter::builder().build(https()); + let conn = DynConnector::new(conn); + let https_uri: Uri = "https://example.com/".parse().unwrap(); + + send_request_and_assert_success(conn, &https_uri).await; + } + } +} diff --git a/rust-runtime/aws-smithy-client/src/lib.rs b/rust-runtime/aws-smithy-client/src/lib.rs index b1e2b1128d6..6e5e5ba9ee7 100644 --- a/rust-runtime/aws-smithy-client/src/lib.rs +++ b/rust-runtime/aws-smithy-client/src/lib.rs @@ -14,6 +14,7 @@ //! | `rustls` | Use `rustls` as the HTTP client's TLS implementation | //! | `client-hyper` | Use `hyper` to handle HTTP requests | +#![allow(clippy::derive_partial_eq_without_eq)] #![warn( missing_debug_implementations, missing_docs, @@ -23,7 +24,10 @@ pub mod bounds; pub mod erase; +pub mod http_connector; +pub mod never; pub mod retry; +pub mod timeout; // https://github.com/rust-lang/rust/issues/72081 #[allow(rustdoc::private_doc_tests)] @@ -35,8 +39,8 @@ pub mod dvr; #[cfg(feature = "test-util")] pub mod test_connection; -pub mod http_connector; - +#[cfg(feature = "client-hyper")] +pub mod conns; #[cfg(feature = "client-hyper")] pub mod hyper_ext; @@ -46,64 +50,19 @@ pub mod hyper_ext; #[doc(hidden)] pub mod static_tests; -pub mod never; -pub mod timeout; -pub use timeout::TimeoutLayer; - -/// Type aliases for standard connection types. -#[cfg(feature = "client-hyper")] -#[allow(missing_docs)] -pub mod conns { - #[cfg(feature = "rustls")] - pub type Https = hyper_rustls::HttpsConnector; - - // Creating a `with_native_roots` HTTP client takes 300ms on OS X. Cache this so that we - // don't need to repeatedly incur that cost. - #[cfg(feature = "rustls")] - lazy_static::lazy_static! { - static ref HTTPS_NATIVE_ROOTS: Https = { - hyper_rustls::HttpsConnectorBuilder::new() - .with_native_roots() - .https_or_http() - .enable_http1() - .enable_http2() - .build() - }; - } - - #[cfg(feature = "rustls")] - pub fn https() -> Https { - HTTPS_NATIVE_ROOTS.clone() - } - - #[cfg(feature = "native-tls")] - pub fn native_tls() -> NativeTls { - hyper_tls::HttpsConnector::new() - } - - #[cfg(feature = "native-tls")] - pub type NativeTls = hyper_tls::HttpsConnector; - - #[cfg(feature = "rustls")] - pub type Rustls = - crate::hyper_ext::Adapter>; -} - use aws_smithy_async::rt::sleep::AsyncSleep; -use aws_smithy_http::body::SdkBody; use aws_smithy_http::operation::Operation; use aws_smithy_http::response::ParseHttpResponse; pub use aws_smithy_http::result::{SdkError, SdkSuccess}; -use aws_smithy_http::retry::ClassifyRetry; use aws_smithy_http_tower::dispatch::DispatchLayer; use aws_smithy_http_tower::parse_response::ParseResponseLayer; use aws_smithy_types::error::display::DisplayErrorContext; use aws_smithy_types::retry::ProvideErrorKind; use aws_smithy_types::timeout::OperationTimeoutConfig; -use std::error::Error; use std::sync::Arc; use timeout::ClientTimeoutParams; -use tower::{Layer, Service, ServiceBuilder, ServiceExt}; +pub use timeout::TimeoutLayer; +use tower::{Service, ServiceBuilder, ServiceExt}; use tracing::{debug_span, field, field::display, Instrument}; /// Smithy service client. @@ -116,7 +75,7 @@ use tracing::{debug_span, field, field::display, Instrument}; /// such as those used for routing (like the URL), authentication, and authorization. /// /// The middleware takes the form of a [`tower::Layer`] that wraps the actual connection for each -/// request. The [`tower::Service`] that the middleware produces must accept requests of the type +/// request. The [`tower::Service`](Service) that the middleware produces must accept requests of the type /// [`aws_smithy_http::operation::Request`] and return responses of the type /// [`http::Response`], most likely by modifying the provided request in place, passing it /// to the inner service, and then ultimately returning the inner service's response. @@ -150,9 +109,9 @@ impl Client where M: Default, { - /// Create a Smithy client from the given `connector`, a middleware default, the [standard - /// retry policy](crate::retry::Standard), and the [`default_async_sleep`](aws_smithy_async::rt::sleep::default_async_sleep) - /// sleep implementation. + /// Create a Smithy client from the given `connector`, a middleware default, the + /// [standard retry policy](retry::Standard), and the + /// [`default_async_sleep`](aws_smithy_async::rt::sleep::default_async_sleep) sleep implementation. pub fn new(connector: C) -> Self { Builder::new() .connector(connector) diff --git a/rust-runtime/aws-smithy-client/src/static_tests.rs b/rust-runtime/aws-smithy-client/src/static_tests.rs index 4c9ef91bad8..a8cd503022c 100644 --- a/rust-runtime/aws-smithy-client/src/static_tests.rs +++ b/rust-runtime/aws-smithy-client/src/static_tests.rs @@ -5,7 +5,7 @@ //! This module provides types useful for static tests. #![allow(missing_docs, missing_debug_implementations)] -use crate::{Builder, Error, Operation, ParseHttpResponse, ProvideErrorKind}; +use crate::{Builder, Operation, ParseHttpResponse, ProvideErrorKind}; use aws_smithy_http::operation; use aws_smithy_http::retry::DefaultResponseRetryClassifier; @@ -17,7 +17,7 @@ impl std::fmt::Display for TestOperationError { unreachable!("only used for static tests") } } -impl Error for TestOperationError {} +impl std::error::Error for TestOperationError {} impl ProvideErrorKind for TestOperationError { fn retryable_error_kind(&self) -> Option { unreachable!("only used for static tests") diff --git a/rust-runtime/aws-smithy-eventstream/src/lib.rs b/rust-runtime/aws-smithy-eventstream/src/lib.rs index 07cb5388577..594be5d3c9b 100644 --- a/rust-runtime/aws-smithy-eventstream/src/lib.rs +++ b/rust-runtime/aws-smithy-eventstream/src/lib.rs @@ -3,6 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +#![allow(clippy::derive_partial_eq_without_eq)] #![warn( missing_debug_implementations, /*missing_docs, diff --git a/rust-runtime/aws-smithy-http-auth/Cargo.toml b/rust-runtime/aws-smithy-http-auth/Cargo.toml new file mode 100644 index 00000000000..0d70b25eec7 --- /dev/null +++ b/rust-runtime/aws-smithy-http-auth/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "aws-smithy-http-auth" +version = "0.0.0-smithy-rs-head" +authors = [ + "AWS Rust SDK Team ", + "Eduardo Rodrigues <16357187+eduardomourar@users.noreply.github.com>", +] +description = "Smithy HTTP logic for smithy-rs." +edition = "2021" +license = "Apache-2.0" +repository = "https://github.com/awslabs/smithy-rs" + +[dependencies] +zeroize = "1" + +[package.metadata.docs.rs] +all-features = true +targets = ["x86_64-unknown-linux-gnu"] +rustdoc-args = ["--cfg", "docsrs"] +# End of docs.rs metadata diff --git a/rust-runtime/aws-smithy-http-auth/LICENSE b/rust-runtime/aws-smithy-http-auth/LICENSE new file mode 100644 index 00000000000..67db8588217 --- /dev/null +++ b/rust-runtime/aws-smithy-http-auth/LICENSE @@ -0,0 +1,175 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. diff --git a/rust-runtime/aws-smithy-http-auth/README.md b/rust-runtime/aws-smithy-http-auth/README.md new file mode 100644 index 00000000000..1d963cafce0 --- /dev/null +++ b/rust-runtime/aws-smithy-http-auth/README.md @@ -0,0 +1,7 @@ +# aws-smithy-http-auth + +HTTP Auth implementation for service clients generated by [smithy-rs](https://github.com/awslabs/smithy-rs). + + +This crate is part of the [AWS SDK for Rust](https://awslabs.github.io/aws-sdk-rust/) and the [smithy-rs](https://github.com/awslabs/smithy-rs) code generator. In most cases, it should not be used directly. + diff --git a/rust-runtime/aws-smithy-http-auth/external-types.toml b/rust-runtime/aws-smithy-http-auth/external-types.toml new file mode 100644 index 00000000000..ff30ccf5ad0 --- /dev/null +++ b/rust-runtime/aws-smithy-http-auth/external-types.toml @@ -0,0 +1,2 @@ +allowed_external_types = [ +] diff --git a/rust-runtime/aws-smithy-http-auth/src/api_key.rs b/rust-runtime/aws-smithy-http-auth/src/api_key.rs new file mode 100644 index 00000000000..bb2ab65b3c5 --- /dev/null +++ b/rust-runtime/aws-smithy-http-auth/src/api_key.rs @@ -0,0 +1,74 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! HTTP Auth API Key + +use std::cmp::PartialEq; +use std::fmt::Debug; +use std::sync::Arc; +use zeroize::Zeroizing; + +/// Authentication configuration to connect to a Smithy Service +#[derive(Clone, Eq, PartialEq)] +pub struct AuthApiKey(Arc); + +#[derive(Clone, Eq, PartialEq)] +struct Inner { + api_key: Zeroizing, +} + +impl Debug for AuthApiKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut auth_api_key = f.debug_struct("AuthApiKey"); + auth_api_key.field("api_key", &"** redacted **").finish() + } +} + +impl AuthApiKey { + /// Constructs a new API key. + pub fn new(api_key: impl Into) -> Self { + Self(Arc::new(Inner { + api_key: Zeroizing::new(api_key.into()), + })) + } + + /// Returns the underlying api key. + pub fn api_key(&self) -> &str { + &self.0.api_key + } +} + +impl From<&str> for AuthApiKey { + fn from(api_key: &str) -> Self { + Self::from(api_key.to_owned()) + } +} + +impl From for AuthApiKey { + fn from(api_key: String) -> Self { + Self(Arc::new(Inner { + api_key: Zeroizing::new(api_key), + })) + } +} + +#[cfg(test)] +mod tests { + use super::AuthApiKey; + + #[test] + fn api_key_is_equal() { + let api_key_a: AuthApiKey = "some-api-key".into(); + let api_key_b = AuthApiKey::new("some-api-key"); + assert_eq!(api_key_a, api_key_b); + } + + #[test] + fn api_key_is_different() { + let api_key_a = AuthApiKey::new("some-api-key"); + let api_key_b: AuthApiKey = String::from("another-api-key").into(); + assert_ne!(api_key_a, api_key_b); + } +} diff --git a/rust-runtime/aws-smithy-http-auth/src/definition.rs b/rust-runtime/aws-smithy-http-auth/src/definition.rs new file mode 100644 index 00000000000..918f6aae8f3 --- /dev/null +++ b/rust-runtime/aws-smithy-http-auth/src/definition.rs @@ -0,0 +1,251 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! HTTP Auth Definition + +use crate::location::HttpAuthLocation; +use std::cmp::PartialEq; +use std::fmt::Debug; + +/// A HTTP-specific authentication scheme that sends an arbitrary +/// auth value in a header or query string parameter. +// As described in the Smithy documentation: +// https://github.com/awslabs/smithy/blob/main/smithy-model/src/main/resources/software/amazon/smithy/model/loader/prelude.smithy +#[derive(Clone, Debug, Default, PartialEq)] +pub struct HttpAuthDefinition { + /// Defines the location of where the Auth is serialized. + location: HttpAuthLocation, + + /// Defines the name of the HTTP header or query string parameter + /// that contains the Auth. + name: String, + + /// Defines the security scheme to use on the `Authorization` header value. + /// This can only be set if the "location" property is set to [`HttpAuthLocation::Header`]. + scheme: Option, +} + +impl HttpAuthDefinition { + /// Returns a builder for `HttpAuthDefinition`. + pub fn builder() -> http_auth_definition::Builder { + http_auth_definition::Builder::default() + } + + /// Constructs a new HTTP auth definition in header. + pub fn header(header_name: N, scheme: S) -> Self + where + N: Into, + S: Into>, + { + let mut builder = Self::builder() + .location(HttpAuthLocation::Header) + .name(header_name); + let scheme: Option = scheme.into(); + if scheme.is_some() { + builder.set_scheme(scheme); + } + builder.build() + } + + /// Constructs a new HTTP auth definition following the RFC 2617 for Basic Auth. + pub fn basic_auth() -> Self { + Self::builder() + .location(HttpAuthLocation::Header) + .name("Authorization".to_owned()) + .scheme("Basic".to_owned()) + .build() + } + + /// Constructs a new HTTP auth definition following the RFC 2617 for Digest Auth. + pub fn digest_auth() -> Self { + Self::builder() + .location(HttpAuthLocation::Header) + .name("Authorization".to_owned()) + .scheme("Digest".to_owned()) + .build() + } + + /// Constructs a new HTTP auth definition following the RFC 6750 for Bearer Auth. + pub fn bearer_auth() -> Self { + Self::builder() + .location(HttpAuthLocation::Header) + .name("Authorization".to_owned()) + .scheme("Bearer".to_owned()) + .build() + } + + /// Constructs a new HTTP auth definition in query string. + pub fn query(name: impl Into) -> Self { + Self::builder() + .location(HttpAuthLocation::Query) + .name(name.into()) + .build() + } + + /// Returns the HTTP auth location. + pub fn location(&self) -> HttpAuthLocation { + self.location + } + + /// Returns the HTTP auth name. + pub fn name(&self) -> &str { + &self.name + } + + /// Returns the HTTP auth scheme. + pub fn scheme(&self) -> Option<&str> { + self.scheme.as_deref() + } +} + +/// Types associated with [`HttpAuthDefinition`]. +pub mod http_auth_definition { + use super::HttpAuthDefinition; + use crate::{ + definition::HttpAuthLocation, + error::{AuthError, AuthErrorKind}, + }; + + /// A builder for [`HttpAuthDefinition`]. + #[derive(Debug, Default)] + pub struct Builder { + location: Option, + name: Option, + scheme: Option, + } + + impl Builder { + /// Sets the HTTP auth location. + pub fn location(mut self, location: HttpAuthLocation) -> Self { + self.location = Some(location); + self + } + + /// Sets the HTTP auth location. + pub fn set_location(&mut self, location: Option) -> &mut Self { + self.location = location; + self + } + + /// Sets the the HTTP auth name. + pub fn name(mut self, name: impl Into) -> Self { + self.name = Some(name.into()); + self + } + + /// Sets the the HTTP auth name. + pub fn set_name(&mut self, name: Option) -> &mut Self { + self.name = name; + self + } + + /// Sets the HTTP auth scheme. + pub fn scheme(mut self, scheme: impl Into) -> Self { + self.scheme = Some(scheme.into()); + self + } + + /// Sets the HTTP auth scheme. + pub fn set_scheme(&mut self, scheme: Option) -> &mut Self { + self.scheme = scheme; + self + } + + /// Constructs a [`HttpAuthDefinition`] from the builder. + pub fn build(self) -> HttpAuthDefinition { + if self.scheme.is_some() + && self + .name + .as_deref() + .map_or("".to_string(), |s| s.to_ascii_lowercase()) + != "authorization" + { + // Stop execution because the Smithy model should not contain such combination. + // Otherwise, this would cause unexpected behavior in the SDK. + panic!("{}", AuthError::from(AuthErrorKind::SchemeNotAllowed)); + } + HttpAuthDefinition { + location: self.location.unwrap_or_else(|| { + panic!( + "{}", + AuthError::from(AuthErrorKind::MissingRequiredField("location")) + ) + }), + name: self.name.unwrap_or_else(|| { + panic!( + "{}", + AuthError::from(AuthErrorKind::MissingRequiredField("name")) + ) + }), + scheme: self.scheme, + } + } + } +} + +#[cfg(test)] +mod tests { + use super::HttpAuthDefinition; + use crate::location::HttpAuthLocation; + + #[test] + fn definition_for_header_without_scheme() { + let definition = HttpAuthDefinition::header("Header", None); + assert_eq!(definition.location, HttpAuthLocation::Header); + assert_eq!(definition.name, "Header"); + assert_eq!(definition.scheme, None); + } + + #[test] + fn definition_for_authorization_header_with_scheme() { + let definition = HttpAuthDefinition::header("authorization", "Scheme".to_owned()); + assert_eq!(definition.location(), HttpAuthLocation::Header); + assert_eq!(definition.name(), "authorization"); + assert_eq!(definition.scheme(), Some("Scheme")); + } + + #[test] + #[should_panic] + fn definition_fails_with_scheme_not_allowed() { + let _ = HttpAuthDefinition::header("Invalid".to_owned(), "Scheme".to_owned()); + } + + #[test] + fn definition_for_basic() { + let definition = HttpAuthDefinition::basic_auth(); + assert_eq!( + definition, + HttpAuthDefinition { + location: HttpAuthLocation::Header, + name: "Authorization".to_owned(), + scheme: Some("Basic".to_owned()), + } + ); + } + + #[test] + fn definition_for_digest() { + let definition = HttpAuthDefinition::digest_auth(); + assert_eq!(definition.location(), HttpAuthLocation::Header); + assert_eq!(definition.name(), "Authorization"); + assert_eq!(definition.scheme(), Some("Digest")); + } + + #[test] + fn definition_for_bearer_token() { + let definition = HttpAuthDefinition::bearer_auth(); + assert_eq!(definition.location(), HttpAuthLocation::Header); + assert_eq!(definition.name(), "Authorization"); + assert_eq!(definition.scheme(), Some("Bearer")); + } + + #[test] + fn definition_for_query() { + let definition = HttpAuthDefinition::query("query_key"); + assert_eq!(definition.location(), HttpAuthLocation::Query); + assert_eq!(definition.name(), "query_key"); + assert_eq!(definition.scheme(), None); + } +} diff --git a/rust-runtime/aws-smithy-http-auth/src/error.rs b/rust-runtime/aws-smithy-http-auth/src/error.rs new file mode 100644 index 00000000000..227dbe1cf29 --- /dev/null +++ b/rust-runtime/aws-smithy-http-auth/src/error.rs @@ -0,0 +1,42 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! HTTP Auth Error + +use std::cmp::PartialEq; +use std::fmt::Debug; + +#[derive(Debug, Eq, PartialEq)] +pub(crate) enum AuthErrorKind { + InvalidLocation, + MissingRequiredField(&'static str), + SchemeNotAllowed, +} + +/// Error for Smithy authentication +#[derive(Debug, Eq, PartialEq)] +pub struct AuthError { + kind: AuthErrorKind, +} + +impl std::fmt::Display for AuthError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use AuthErrorKind::*; + match &self.kind { + InvalidLocation => write!(f, "invalid location: expected `header` or `query`"), + MissingRequiredField(field) => write!(f, "missing required field: {}", field), + SchemeNotAllowed => write!( + f, + "scheme only allowed when it is set into the `Authorization` header" + ), + } + } +} + +impl From for AuthError { + fn from(kind: AuthErrorKind) -> Self { + Self { kind } + } +} diff --git a/rust-runtime/aws-smithy-http-auth/src/lib.rs b/rust-runtime/aws-smithy-http-auth/src/lib.rs new file mode 100644 index 00000000000..ada202143ec --- /dev/null +++ b/rust-runtime/aws-smithy-http-auth/src/lib.rs @@ -0,0 +1,14 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#![allow(clippy::derive_partial_eq_without_eq)] +#![warn(missing_debug_implementations, missing_docs, rustdoc::all)] + +//! Smithy HTTP Auth Types + +pub mod api_key; +pub mod definition; +pub mod error; +pub mod location; diff --git a/rust-runtime/aws-smithy-http-auth/src/location.rs b/rust-runtime/aws-smithy-http-auth/src/location.rs new file mode 100644 index 00000000000..9fc0ced0e8a --- /dev/null +++ b/rust-runtime/aws-smithy-http-auth/src/location.rs @@ -0,0 +1,73 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! HTTP Auth Location + +use std::cmp::PartialEq; +use std::fmt::Debug; + +use crate::error::{AuthError, AuthErrorKind}; + +/// Enum for describing where the HTTP Auth can be placed. +#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)] +pub enum HttpAuthLocation { + /// In the HTTP header. + #[default] + Header, + /// In the query string of the URL. + Query, +} + +impl HttpAuthLocation { + fn as_str(&self) -> &'static str { + match self { + Self::Header => "header", + Self::Query => "query", + } + } +} + +impl TryFrom<&str> for HttpAuthLocation { + type Error = AuthError; + fn try_from(value: &str) -> Result { + match value { + "header" => Ok(Self::Header), + "query" => Ok(Self::Query), + _ => Err(AuthError::from(AuthErrorKind::InvalidLocation)), + } + } +} + +impl TryFrom for HttpAuthLocation { + type Error = AuthError; + fn try_from(value: String) -> Result { + Self::try_from(value.as_str()) + } +} + +impl AsRef for HttpAuthLocation { + fn as_ref(&self) -> &str { + self.as_str() + } +} + +impl std::fmt::Display for HttpAuthLocation { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + std::fmt::Display::fmt(&self.as_str(), f) + } +} + +#[cfg(test)] +mod tests { + use super::HttpAuthLocation; + use crate::error::{AuthError, AuthErrorKind}; + + #[test] + fn fails_if_location_is_invalid() { + let actual = HttpAuthLocation::try_from("invalid").unwrap_err(); + let expected = AuthError::from(AuthErrorKind::InvalidLocation); + assert_eq!(actual, expected); + } +} diff --git a/rust-runtime/aws-smithy-http-server-python/examples/Makefile b/rust-runtime/aws-smithy-http-server-python/examples/Makefile index 22be94a19e2..3b73f08ca6c 100644 --- a/rust-runtime/aws-smithy-http-server-python/examples/Makefile +++ b/rust-runtime/aws-smithy-http-server-python/examples/Makefile @@ -25,29 +25,35 @@ endif # Note on `--compatibility linux`: Maturin by default uses `manylinux_x_y` but it is not supported # by our current CI version (3.7.10), we can drop `--compatibility linux` when we switch to higher Python version. # For more detail: https://github.com/pypa/manylinux -build-wheel: ensure-maturin codegen - maturin build --manifest-path $(SERVER_SDK_DST)/Cargo.toml --out $(WHEELS) --compatibility linux +build-wheel: ensure-maturin + cd $(SERVER_SDK_DST) && maturin build --out $(WHEELS) --compatibility linux -build-wheel-release: ensure-maturin codegen - maturin build --manifest-path $(SERVER_SDK_DST)/Cargo.toml --out $(WHEELS) --compatibility linux --release +build-wheel-release: ensure-maturin + cd $(SERVER_SDK_DST) && maturin build --out $(WHEELS) --compatibility linux --release install-wheel: find $(WHEELS) -type f -name '*.whl' | xargs python3 -m pip install --user --force-reinstall -build: build-wheel install-wheel +generate-stubs: + python3 $(CUR_DIR)/stubgen.py pokemon_service_server_sdk $(SERVER_SDK_DST)/python/pokemon_service_server_sdk -release: build-wheel-release install-wheel +build: codegen + $(MAKE) build-wheel + $(MAKE) install-wheel + $(MAKE) generate-stubs + $(MAKE) build-wheel-release + $(MAKE) install-wheel run: build python3 $(CUR_DIR)/pokemon_service.py -run-release: release - python3 $(CUR_DIR)/pokemon_service.py - py-check: build - mypy pokemon_service.py + python3 -m mypy pokemon_service.py + +py-test: + python3 stubgen_test.py -test: build +test: build py-check py-test cargo test clippy: codegen @@ -60,6 +66,6 @@ clean: cargo clean || echo "Unable to run cargo clean" distclean: clean - rm -rf $(SERVER_SDK_DST) $(CLIENT_SDK_DST) $(WHEELS) $(CUR_DIR)/Cargo.lock + rm -rf $(SERVER_SDK_DST) $(SERVER_SDK_SRC) $(CLIENT_SDK_DST) $(CLIENT_SDK_SRC) $(WHEELS) $(CUR_DIR)/Cargo.lock .PHONY: all diff --git a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py index 1f45c5fe13b..a3c7bf1c933 100644 --- a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py +++ b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py @@ -8,34 +8,34 @@ import random from threading import Lock from dataclasses import dataclass -from typing import List, Optional, Callable, Awaitable +from typing import Dict, Any, List, Optional, Callable, Awaitable from pokemon_service_server_sdk import App -from pokemon_service_server_sdk.tls import TlsConfig # type: ignore -from pokemon_service_server_sdk.aws_lambda import LambdaContext # type: ignore -from pokemon_service_server_sdk.error import ResourceNotFoundException # type: ignore -from pokemon_service_server_sdk.input import ( # type: ignore +from pokemon_service_server_sdk.tls import TlsConfig +from pokemon_service_server_sdk.aws_lambda import LambdaContext +from pokemon_service_server_sdk.error import ResourceNotFoundException +from pokemon_service_server_sdk.input import ( DoNothingInput, GetPokemonSpeciesInput, GetServerStatisticsInput, CheckHealthInput, StreamPokemonRadioInput, ) -from pokemon_service_server_sdk.logging import TracingHandler # type: ignore -from pokemon_service_server_sdk.middleware import ( # type: ignore +from pokemon_service_server_sdk.logging import TracingHandler +from pokemon_service_server_sdk.middleware import ( MiddlewareException, Response, Request, ) -from pokemon_service_server_sdk.model import FlavorText, Language # type: ignore -from pokemon_service_server_sdk.output import ( # type: ignore +from pokemon_service_server_sdk.model import FlavorText, Language +from pokemon_service_server_sdk.output import ( DoNothingOutput, GetPokemonSpeciesOutput, GetServerStatisticsOutput, CheckHealthOutput, StreamPokemonRadioOutput, ) -from pokemon_service_server_sdk.types import ByteStream # type: ignore +from pokemon_service_server_sdk.types import ByteStream # Logging can bee setup using standard Python tooling. We provide # fast logging handler, Tracingandler based on Rust tracing crate. @@ -131,7 +131,7 @@ def get_random_radio_stream(self) -> str: # Entrypoint ########################################################### # Get an App instance. -app = App() +app: "App[Context]" = App() # Register the context. app.context(Context()) @@ -249,7 +249,7 @@ def check_health(_: CheckHealthInput) -> CheckHealthOutput: async def stream_pokemon_radio( _: StreamPokemonRadioInput, context: Context ) -> StreamPokemonRadioOutput: - import aiohttp + import aiohttp # type: ignore radio_url = context.get_random_radio_stream() logging.info("Random radio URL for this stream is %s", radio_url) @@ -270,7 +270,7 @@ def main() -> None: parser.add_argument("--tls-cert-path") args = parser.parse_args() - config = dict(workers=1) + config: Dict[str, Any] = dict(workers=1) if args.enable_tls: config["tls"] = TlsConfig( key_path=args.tls_key_path, diff --git a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service_server_sdk.pyi b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service_server_sdk.pyi deleted file mode 100644 index ccbd5edb2e7..00000000000 --- a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service_server_sdk.pyi +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 - -# NOTE: This is manually created to surpass some mypy errors and it is incomplete, -# in future we will autogenerate correct stubs. - -from typing import Any, TypeVar, Callable - -F = TypeVar("F", bound=Callable[..., Any]) - -class App: - context: Any - run: Any - - def middleware(self, func: F) -> F: ... - def do_nothing(self, func: F) -> F: ... - def get_pokemon_species(self, func: F) -> F: ... - def get_server_statistics(self, func: F) -> F: ... - def check_health(self, func: F) -> F: ... - def stream_pokemon_radio(self, func: F) -> F: ... diff --git a/rust-runtime/aws-smithy-http-server-python/examples/stubgen.py b/rust-runtime/aws-smithy-http-server-python/examples/stubgen.py new file mode 100644 index 00000000000..30348838d2d --- /dev/null +++ b/rust-runtime/aws-smithy-http-server-python/examples/stubgen.py @@ -0,0 +1,422 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations +import re +import inspect +import textwrap +from pathlib import Path +from typing import Any, Set, Dict, List, Tuple, Optional + +ROOT_MODULE_NAME_PLACEHOLDER = "__root_module_name__" + + +class Writer: + """ + Writer provides utilities for writing Python stubs. + """ + + root_module_name: str + path: Path + subwriters: List[Writer] + imports: Set[str] + defs: List[str] + generics: Set[str] + + def __init__(self, path: Path, root_module_name: str) -> None: + self.path = path + self.root_module_name = root_module_name + self.subwriters = [] + self.imports = set([]) + self.defs = [] + self.generics = set([]) + + def fix_path(self, path: str) -> str: + """ + Returns fixed version of given type path. + It unescapes `\\[` and `\\]` and also populates placeholder for root module name. + """ + return ( + path.replace(ROOT_MODULE_NAME_PLACEHOLDER, self.root_module_name) + .replace("\\[", "[") + .replace("\\]", "]") + ) + + def submodule(self, path: Path) -> Writer: + w = Writer(path, self.root_module_name) + self.subwriters.append(w) + return w + + def include(self, path: str) -> str: + # `path` might be nested like: typing.Optional[typing.List[pokemon_service_server_sdk.model.GetPokemonSpecies]] + # we need to process every subpath in a nested path + paths = filter(lambda p: p, re.split("\\[|\\]|,| ", path)) + for subpath in paths: + parts = subpath.rsplit(".", maxsplit=1) + # add `typing` to imports for a path like `typing.List` + # but skip if the path doesn't have any namespace like `str` or `bool` + if len(parts) == 2: + self.imports.add(parts[0]) + + return path + + def fix_and_include(self, path: str) -> str: + return self.include(self.fix_path(path)) + + def define(self, code: str) -> None: + self.defs.append(code) + + def generic(self, name: str) -> None: + self.generics.add(name) + + def dump(self) -> None: + for w in self.subwriters: + w.dump() + + generics = "" + for g in sorted(self.generics): + generics += f"{g} = {self.include('typing.TypeVar')}('{g}')\n" + + self.path.parent.mkdir(parents=True, exist_ok=True) + contents = join([f"import {p}" for p in sorted(self.imports)]) + contents += "\n\n" + if generics: + contents += generics + "\n" + contents += join(self.defs) + self.path.write_text(contents) + + +class DocstringParserResult: + def __init__(self) -> None: + self.types: List[str] = [] + self.params: List[Tuple[str, str]] = [] + self.rtypes: List[str] = [] + self.generics: List[str] = [] + self.extends: List[str] = [] + + +def parse_type_directive(line: str, res: DocstringParserResult): + parts = line.split(" ", maxsplit=1) + if len(parts) != 2: + raise ValueError( + f"Invalid `:type` directive: `{line}` must be in `:type T:` format" + ) + res.types.append(parts[1].rstrip(":")) + + +def parse_rtype_directive(line: str, res: DocstringParserResult): + parts = line.split(" ", maxsplit=1) + if len(parts) != 2: + raise ValueError( + f"Invalid `:rtype` directive: `{line}` must be in `:rtype T:` format" + ) + res.rtypes.append(parts[1].rstrip(":")) + + +def parse_param_directive(line: str, res: DocstringParserResult): + parts = line.split(" ", maxsplit=2) + if len(parts) != 3: + raise ValueError( + f"Invalid `:param` directive: `{line}` must be in `:param name T:` format" + ) + name = parts[1] + ty = parts[2].rstrip(":") + res.params.append((name, ty)) + + +def parse_generic_directive(line: str, res: DocstringParserResult): + parts = line.split(" ", maxsplit=1) + if len(parts) != 2: + raise ValueError( + f"Invalid `:generic` directive: `{line}` must be in `:generic T:` format" + ) + res.generics.append(parts[1].rstrip(":")) + + +def parse_extends_directive(line: str, res: DocstringParserResult): + parts = line.split(" ", maxsplit=1) + if len(parts) != 2: + raise ValueError( + f"Invalid `:extends` directive: `{line}` must be in `:extends Base[...]:` format" + ) + res.extends.append(parts[1].rstrip(":")) + + +DocstringParserDirectives = { + "type": parse_type_directive, + "param": parse_param_directive, + "rtype": parse_rtype_directive, + "generic": parse_generic_directive, + "extends": parse_extends_directive, +} + + +class DocstringParser: + """ + DocstringParser provides utilities for parsing type information from docstring. + """ + + @staticmethod + def parse(obj: Any) -> Optional[DocstringParserResult]: + doc = inspect.getdoc(obj) + if not doc: + return None + + res = DocstringParserResult() + for line in doc.splitlines(): + line = line.strip() + for d, p in DocstringParserDirectives.items(): + if line.startswith(f":{d} ") and line.endswith(":"): + p(line, res) + return res + + @staticmethod + def parse_type(obj: Any) -> str: + result = DocstringParser.parse(obj) + if not result or len(result.types) == 0: + return "typing.Any" + return result.types[0] + + @staticmethod + def parse_function(obj: Any) -> Optional[Tuple[List[Tuple[str, str]], str]]: + result = DocstringParser.parse(obj) + if not result: + return None + + return ( + result.params, + "None" if len(result.rtypes) == 0 else result.rtypes[0], + ) + + @staticmethod + def parse_class(obj: Any) -> Tuple[List[str], List[str]]: + result = DocstringParser.parse(obj) + if not result: + return ([], []) + return (result.generics, result.extends) + + @staticmethod + def clean_doc(obj: Any) -> str: + doc = inspect.getdoc(obj) + if not doc: + return "" + + def predicate(l: str) -> bool: + for k in DocstringParserDirectives.keys(): + if l.startswith(f":{k} ") and l.endswith(":"): + return False + return True + + return "\n".join([l for l in doc.splitlines() if predicate(l)]).strip() + + +def indent(code: str, level: int = 4) -> str: + return textwrap.indent(code, level * " ") + + +def is_fn_like(obj: Any) -> bool: + return ( + inspect.isbuiltin(obj) + or inspect.ismethod(obj) + or inspect.isfunction(obj) + or inspect.ismethoddescriptor(obj) + or inspect.iscoroutine(obj) + or inspect.iscoroutinefunction(obj) + ) + + +def join(args: List[str], delim: str = "\n") -> str: + return delim.join(filter(lambda x: x, args)) + + +def make_doc(obj: Any) -> str: + doc = DocstringParser.clean_doc(obj) + doc = textwrap.dedent(doc) + if not doc: + return "" + + return join(['"""', doc, '"""']) + + +def make_field(writer: Writer, name: str, field: Any) -> str: + return f"{name}: {writer.fix_and_include(DocstringParser.parse_type(field))}" + + +def make_function( + writer: Writer, + name: str, + obj: Any, + include_docs: bool = True, + parent: Optional[Any] = None, +) -> str: + is_static_method = False + if parent and isinstance(obj, staticmethod): + # Get real method instance from `parent` if `obj` is a `staticmethod` + is_static_method = True + obj = getattr(parent, name) + + res = DocstringParser.parse_function(obj) + if not res: + # Make it `Any` if we can't parse the docstring + return f"{name}: {writer.include('typing.Any')}" + + params, rtype = res + # We're using signature for getting default values only, currently type hints are not supported + # in signatures. We can leverage signatures more if it supports type hints in future. + sig: Optional[inspect.Signature] = None + try: + sig = inspect.signature(obj) + except: + pass + + def has_default(param: str, ty: str) -> bool: + # PyO3 allows omitting `Option` params while calling a Rust function from Python, + # we should always mark `typing.Optional[T]` values as they have default values to allow same + # flexibiliy as runtime dynamics in type-stubs. + if ty.startswith("typing.Optional["): + return True + + if sig is None: + return False + + sig_param = sig.parameters.get(param) + return sig_param is not None and sig_param.default is not sig_param.empty + + receivers: List[str] = [] + attrs: List[str] = [] + if parent: + if is_static_method: + attrs.append("@staticmethod") + else: + receivers.append("self") + + def make_param(name: str, ty: str) -> str: + fixed_ty = writer.fix_and_include(ty) + param = f"{name}: {fixed_ty}" + if has_default(name, fixed_ty): + param += " = ..." + return param + + params = join(receivers + [make_param(n, t) for n, t in params], delim=", ") + attrs_str = join(attrs) + rtype = writer.fix_and_include(rtype) + body = "..." + if include_docs: + body = join([make_doc(obj), body]) + + return f""" +{attrs_str} +def {name}({params}) -> {rtype}: +{indent(body)} +""".lstrip() + + +def make_class(writer: Writer, name: str, klass: Any) -> str: + bases = list( + filter(lambda n: n != "object", map(lambda b: b.__name__, klass.__bases__)) + ) + class_sig = DocstringParser.parse_class(klass) + if class_sig: + (generics, extends) = class_sig + bases.extend(map(writer.fix_and_include, extends)) + for g in generics: + writer.generic(g) + + members: List[str] = [] + + class_vars: Dict[str, Any] = vars(klass) + for member_name, member in sorted(class_vars.items(), key=lambda k: k[0]): + if member_name.startswith("__"): + continue + + if inspect.isdatadescriptor(member): + members.append( + join( + [ + make_field(writer, member_name, member), + make_doc(member), + ] + ) + ) + elif is_fn_like(member): + members.append( + make_function(writer, member_name, member, parent=klass), + ) + elif isinstance(member, klass): + # Enum variant + members.append( + join( + [ + f"{member_name}: {name}", + make_doc(member), + ] + ) + ) + else: + print(f"Unknown member type: {member}") + + if inspect.getdoc(klass) is not None: + constructor_sig = DocstringParser.parse(klass) + if constructor_sig is not None and ( + # Make sure to only generate `__init__` if the class has a constructor defined + len(constructor_sig.rtypes) > 0 + or len(constructor_sig.params) > 0 + ): + members.append( + make_function( + writer, + "__init__", + klass, + include_docs=False, + parent=klass, + ) + ) + + bases_str = "" if len(bases) == 0 else f"({join(bases, delim=', ')})" + doc = make_doc(klass) + if doc: + doc += "\n" + body = join([doc, join(members, delim="\n\n") or "..."]) + return f"""\ +class {name}{bases_str}: +{indent(body)} +""" + + +def walk_module(writer: Writer, mod: Any): + exported = mod.__all__ + + for (name, member) in inspect.getmembers(mod): + if name not in exported: + continue + + if inspect.ismodule(member): + subpath = writer.path.parent / name / "__init__.pyi" + walk_module(writer.submodule(subpath), member) + elif inspect.isclass(member): + writer.define(make_class(writer, name, member)) + elif is_fn_like(member): + writer.define(make_function(writer, name, member)) + else: + print(f"Unknown type: {member}") + + +if __name__ == "__main__": + import argparse + import importlib + + parser = argparse.ArgumentParser() + parser.add_argument("module") + parser.add_argument("outdir") + args = parser.parse_args() + + path = Path(args.outdir) / f"{args.module}.pyi" + writer = Writer( + path, + args.module, + ) + walk_module( + writer, + importlib.import_module(args.module), + ) + writer.dump() diff --git a/rust-runtime/aws-smithy-http-server-python/examples/stubgen_test.py b/rust-runtime/aws-smithy-http-server-python/examples/stubgen_test.py new file mode 100644 index 00000000000..f1d2aa840a0 --- /dev/null +++ b/rust-runtime/aws-smithy-http-server-python/examples/stubgen_test.py @@ -0,0 +1,407 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import sys +import unittest +from types import ModuleType +from textwrap import dedent +from pathlib import Path +from tempfile import TemporaryDirectory + +from stubgen import Writer, walk_module + + +def create_module(name: str, code: str) -> ModuleType: + mod = ModuleType(name) + exec(dedent(code), mod.__dict__) + if not hasattr(mod, "__all__"): + # Manually populate `__all__` with all the members that doesn't start with `__` + mod.__all__ = [k for k in mod.__dict__.keys() if not k.startswith("__")] # type: ignore + sys.modules[name] = mod + return mod + + +class TestStubgen(unittest.TestCase): + def test_function_without_docstring(self): + self.single_mod( + """ + def foo(): + pass + """, + """ + import typing + + foo: typing.Any + """, + ) + + def test_regular_function(self): + self.single_mod( + """ + def foo(bar): + ''' + :param bar str: + :rtype bool: + ''' + pass + """, + """ + def foo(bar: str) -> bool: + ... + """, + ) + + def test_function_with_default_value(self): + self.single_mod( + """ + def foo(bar, qux=None): + ''' + :param bar int: + :param qux typing.Optional[str]: + :rtype None: + ''' + pass + """, + """ + import typing + + def foo(bar: int, qux: typing.Optional[str] = ...) -> None: + ... + """, + ) + + def test_empty_class(self): + self.single_mod( + """ + class Foo: + pass + """, + """ + class Foo: + ... + """, + ) + + def test_class(self): + self.single_mod( + """ + class Foo: + @property + def bar(self): + ''' + :type typing.List[bool]: + ''' + pass + + def qux(self, a, b, c): + ''' + :param a typing.Dict[typing.List[int]]: + :param b str: + :param c float: + :rtype typing.Union[int, str, bool]: + ''' + pass + """, + """ + import typing + + class Foo: + bar: typing.List[bool] + + def qux(self, a: typing.Dict[typing.List[int]], b: str, c: float) -> typing.Union[int, str, bool]: + ... + """, + ) + + def test_class_with_constructor_signature(self): + self.single_mod( + """ + class Foo: + ''' + :param bar str: + :rtype None: + ''' + """, + """ + class Foo: + def __init__(self, bar: str) -> None: + ... + """, + ) + + def test_class_with_static_method(self): + self.single_mod( + """ + class Foo: + @staticmethod + def bar(name): + ''' + :param name str: + :rtype typing.List[bool]: + ''' + pass + """, + """ + import typing + + class Foo: + @staticmethod + def bar(name: str) -> typing.List[bool]: + ... + """, + ) + + def test_class_with_an_undocumented_descriptor(self): + self.single_mod( + """ + class Foo: + @property + def bar(self): + pass + """, + """ + import typing + + class Foo: + bar: typing.Any + """, + ) + + def test_enum(self): + self.single_mod( + """ + class Foo: + def __init__(self, name): + pass + + Foo.Bar = Foo("Bar") + Foo.Baz = Foo("Baz") + Foo.Qux = Foo("Qux") + """, + """ + class Foo: + Bar: Foo + + Baz: Foo + + Qux: Foo + """, + ) + + def test_generic(self): + self.single_mod( + """ + class Foo: + ''' + :generic T: + :generic U: + :extends typing.Generic[T]: + :extends typing.Generic[U]: + ''' + + @property + def bar(self): + ''' + :type typing.Tuple[T, U]: + ''' + pass + + def baz(self, a): + ''' + :param a U: + :rtype T: + ''' + pass + """, + """ + import typing + + T = typing.TypeVar('T') + U = typing.TypeVar('U') + + class Foo(typing.Generic[T], typing.Generic[U]): + bar: typing.Tuple[T, U] + + def baz(self, a: U) -> T: + ... + """, + ) + + def test_items_with_docstrings(self): + self.single_mod( + """ + class Foo: + ''' + This is the docstring of Foo. + + And it has multiple lines. + + :generic T: + :extends typing.Generic[T]: + :param member T: + ''' + + @property + def bar(self): + ''' + This is the docstring of property `bar`. + + :type typing.Optional[T]: + ''' + pass + + def baz(self, t): + ''' + This is the docstring of method `baz`. + :param t T: + :rtype T: + ''' + pass + """, + ''' + import typing + + T = typing.TypeVar('T') + + class Foo(typing.Generic[T]): + """ + This is the docstring of Foo. + + And it has multiple lines. + """ + + bar: typing.Optional[T] + """ + This is the docstring of property `bar`. + """ + + def baz(self, t: T) -> T: + """ + This is the docstring of method `baz`. + """ + ... + + + def __init__(self, member: T) -> None: + ... + ''', + ) + + def test_adds_default_to_optional_types(self): + # Since PyO3 provides `impl FromPyObject for Option` and maps Python `None` to Rust `None`, + # you don't have to pass `None` explicitly. Type-stubs also shoudln't require `None`s + # to be passed explicitly (meaning they should have a default value). + + self.single_mod( + """ + def foo(bar, qux): + ''' + :param bar typing.Optional[int]: + :param qux typing.List[typing.Optional[int]]: + :rtype int: + ''' + pass + """, + """ + import typing + + def foo(bar: typing.Optional[int] = ..., qux: typing.List[typing.Optional[int]]) -> int: + ... + """, + ) + + def test_multiple_mods(self): + create_module( + "foo.bar", + """ + class Bar: + ''' + :param qux str: + :rtype None: + ''' + pass + """, + ) + + foo = create_module( + "foo", + """ + import sys + + bar = sys.modules["foo.bar"] + + class Foo: + ''' + :param a __root_module_name__.bar.Bar: + :param b typing.Optional[__root_module_name__.bar.Bar]: + :rtype None: + ''' + + @property + def a(self): + ''' + :type __root_module_name__.bar.Bar: + ''' + pass + + @property + def b(self): + ''' + :type typing.Optional[__root_module_name__.bar.Bar]: + ''' + pass + + __all__ = ["bar", "Foo"] + """, + ) + + with TemporaryDirectory() as temp_dir: + foo_path = Path(temp_dir) / "foo.pyi" + bar_path = Path(temp_dir) / "bar" / "__init__.pyi" + + writer = Writer(foo_path, "foo") + walk_module(writer, foo) + writer.dump() + + self.assert_stub( + foo_path, + """ + import foo.bar + import typing + + class Foo: + a: foo.bar.Bar + + b: typing.Optional[foo.bar.Bar] + + def __init__(self, a: foo.bar.Bar, b: typing.Optional[foo.bar.Bar] = ...) -> None: + ... + """, + ) + + self.assert_stub( + bar_path, + """ + class Bar: + def __init__(self, qux: str) -> None: + ... + """, + ) + + def single_mod(self, mod_code: str, expected_stub: str) -> None: + with TemporaryDirectory() as temp_dir: + mod = create_module("test", mod_code) + path = Path(temp_dir) / "test.pyi" + + writer = Writer(path, "test") + walk_module(writer, mod) + writer.dump() + + self.assert_stub(path, expected_stub) + + def assert_stub(self, path: Path, expected: str) -> None: + self.assertEqual(path.read_text().strip(), dedent(expected).strip()) + + +if __name__ == "__main__": + unittest.main() diff --git a/rust-runtime/aws-smithy-http-server-python/src/context.rs b/rust-runtime/aws-smithy-http-server-python/src/context.rs index 04f370a8761..dbbd49f7335 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/context.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/context.rs @@ -16,7 +16,6 @@ mod testing; /// PyContext is a wrapper for context object provided by the user. /// It injects some values (currently only [super::lambda::PyLambdaContext]) that is type-hinted by the user. /// -/// /// PyContext is initialised during the startup, it inspects the provided context object for fields /// that are type-hinted to inject some values provided by the framework (see [PyContext::new()]). /// diff --git a/rust-runtime/aws-smithy-http-server-python/src/error.rs b/rust-runtime/aws-smithy-http-server-python/src/error.rs index 06e20e9b520..01596be5071 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/error.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/error.rs @@ -39,12 +39,19 @@ impl From for PyErr { /// /// It allows to specify a message and HTTP status code and implementing protocol specific capabilities /// to build a [aws_smithy_http_server::response::Response] from it. +/// +/// :param message str: +/// :param status_code typing.Optional\[int\]: +/// :rtype None: #[pyclass(name = "MiddlewareException", extends = BasePyException)] -#[pyo3(text_signature = "(message, status_code)")] +#[pyo3(text_signature = "($self, message, status_code=None)")] #[derive(Debug, Clone)] pub struct PyMiddlewareException { + /// :type str: #[pyo3(get, set)] message: String, + + /// :type int: #[pyo3(get, set)] status_code: u16, } diff --git a/rust-runtime/aws-smithy-http-server-python/src/lambda.rs b/rust-runtime/aws-smithy-http-server-python/src/lambda.rs index 15bb9bf1b09..3823e8d93f6 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/lambda.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/lambda.rs @@ -11,70 +11,111 @@ use lambda_http::Context; use pyo3::pyclass; /// AWS Mobile SDK client fields. -#[pyclass] +#[pyclass(name = "ClientApplication")] #[derive(Clone)] -struct PyClientApplication { +pub struct PyClientApplication { /// The mobile app installation id + /// + /// :type str: #[pyo3(get)] installation_id: String, + /// The app title for the mobile app as registered with AWS' mobile services. + /// + /// :type str: #[pyo3(get)] app_title: String, + /// The version name of the application as registered with AWS' mobile services. + /// + /// :type str: #[pyo3(get)] app_version_name: String, + /// The app version code. + /// + /// :type str: #[pyo3(get)] app_version_code: String, + /// The package name for the mobile application invoking the function + /// + /// :type str: #[pyo3(get)] app_package_name: String, } /// Client context sent by the AWS Mobile SDK. -#[pyclass] +#[pyclass(name = "ClientContext")] #[derive(Clone)] -struct PyClientContext { +pub struct PyClientContext { /// Information about the mobile application invoking the function. + /// + /// :type ClientApplication: #[pyo3(get)] client: PyClientApplication, + /// Custom properties attached to the mobile event context. + /// + /// :type typing.Dict[str, str]: #[pyo3(get)] custom: HashMap, + /// Environment settings from the mobile client. + /// + /// :type typing.Dict[str, str]: #[pyo3(get)] environment: HashMap, } /// Cognito identity information sent with the event -#[pyclass] +#[pyclass(name = "CognitoIdentity")] #[derive(Clone)] -struct PyCognitoIdentity { +pub struct PyCognitoIdentity { /// The unique identity id for the Cognito credentials invoking the function. + /// + /// :type str: #[pyo3(get)] identity_id: String, + /// The identity pool id the caller is "registered" with. + /// + /// :type str: #[pyo3(get)] identity_pool_id: String, } /// Configuration derived from environment variables. -#[pyclass] +#[pyclass(name = "Config")] #[derive(Clone)] -struct PyConfig { +pub struct PyConfig { /// The name of the function. + /// + /// :type str: #[pyo3(get)] function_name: String, + /// The amount of memory available to the function in MB. + /// + /// :type int: #[pyo3(get)] memory: i32, + /// The version of the function being executed. + /// + /// :type str: #[pyo3(get)] version: String, + /// The name of the Amazon CloudWatch Logs stream for the function. + /// + /// :type str: #[pyo3(get)] log_stream: String, + /// The name of the Amazon CloudWatch Logs group for the function. + /// + /// :type str: #[pyo3(get)] log_group: String, } @@ -86,29 +127,49 @@ struct PyConfig { #[pyclass(name = "LambdaContext")] pub struct PyLambdaContext { /// The AWS request ID generated by the Lambda service. + /// + /// :type str: #[pyo3(get)] request_id: String, + /// The execution deadline for the current invocation in milliseconds. + /// + /// :type int: #[pyo3(get)] deadline: u64, + /// The ARN of the Lambda function being invoked. + /// + /// :type str: #[pyo3(get)] invoked_function_arn: String, + /// The X-Ray trace ID for the current invocation. + /// + /// :type typing.Optional\[str\]: #[pyo3(get)] xray_trace_id: Option, + /// The client context object sent by the AWS mobile SDK. This field is /// empty unless the function is invoked using an AWS mobile SDK. + /// + /// :type typing.Optional\[ClientContext\]: #[pyo3(get)] client_context: Option, + /// The Cognito identity that invoked the function. This field is empty /// unless the invocation request to the Lambda APIs was made using AWS /// credentials issues by Amazon Cognito Identity Pools. + /// + /// :type typing.Optional\[CognitoIdentity\]: #[pyo3(get)] identity: Option, + /// Lambda function configuration from the local environment variables. /// Includes information such as the function name, memory allocation, /// version, and log streams. + /// + /// :type Config: #[pyo3(get)] env_config: PyConfig, } diff --git a/rust-runtime/aws-smithy-http-server-python/src/lib.rs b/rust-runtime/aws-smithy-http-server-python/src/lib.rs index b8efefe03f0..17b7a96a035 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/lib.rs @@ -3,6 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +#![allow(clippy::derive_partial_eq_without_eq)] #![cfg_attr(docsrs, feature(doc_cfg))] //! Rust/Python bindings, runtime and utilities. diff --git a/rust-runtime/aws-smithy-http-server-python/src/logging.rs b/rust-runtime/aws-smithy-http-server-python/src/logging.rs index 2492096269e..80d4966ac0b 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/logging.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/logging.rs @@ -86,7 +86,12 @@ fn setup_tracing_subscriber( /// - A new builtin function `logging.py_tracing_event` transcodes `logging.LogRecord`s to `tracing::Event`s. This function /// is not exported in `logging.__all__`, as it is not intended to be called directly. /// - A new class `logging.TracingHandler` provides a `logging.Handler` that delivers all records to `python_tracing`. +/// +/// :param level typing.Optional\[int\]: +/// :param logfile typing.Optional\[pathlib.Path\]: +/// :rtype None: #[pyclass(name = "TracingHandler")] +#[pyo3(text_signature = "($self, level=None, logfile=None)")] #[derive(Debug)] pub struct PyTracingHandler { _guard: Option, @@ -104,6 +109,7 @@ impl PyTracingHandler { Ok(Self { _guard }) } + /// :rtype typing.Any: fn handler(&self, py: Python) -> PyResult> { let logging = py.import("logging")?; logging.setattr( diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/layer.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/layer.rs index 118b4e7cf8b..506e5a6c208 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/layer.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/pytests/layer.rs @@ -253,6 +253,7 @@ fn simple_request(body: &'static str) -> Request { .expect("could not create request") } +#[allow(clippy::type_complexity)] fn spawn_service( layer: L, ) -> ( @@ -306,7 +307,7 @@ fn py_handler(code: &str) -> PyMiddlewareHandler { .get_item("middleware") .expect("your handler must be named `middleware`") .into(); - Ok::<_, PyErr>(PyMiddlewareHandler::new(py, handler)?) + PyMiddlewareHandler::new(py, handler) }) .unwrap() } diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs index 16fc9d6d7fb..d1eff86e002 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs @@ -17,7 +17,6 @@ use super::{PyHeaderMap, PyMiddlewareError}; /// Python-compatible [Request] object. #[pyclass(name = "Request")] -#[pyo3(text_signature = "(request)")] #[derive(Debug)] pub struct PyRequest { parts: Option, @@ -56,6 +55,8 @@ impl PyRequest { #[pymethods] impl PyRequest { /// Return the HTTP method of this request. + /// + /// :type str: #[getter] fn method(&self) -> PyResult { self.parts @@ -65,6 +66,8 @@ impl PyRequest { } /// Return the URI of this request. + /// + /// :type str: #[getter] fn uri(&self) -> PyResult { self.parts @@ -74,6 +77,8 @@ impl PyRequest { } /// Return the HTTP version of this request. + /// + /// :type str: #[getter] fn version(&self) -> PyResult { self.parts @@ -83,6 +88,8 @@ impl PyRequest { } /// Return the HTTP headers of this request. + /// + /// :type typing.MutableMapping[str, str]: #[getter] fn headers(&self) -> PyHeaderMap { self.headers.clone() @@ -90,6 +97,8 @@ impl PyRequest { /// Return the HTTP body of this request. /// Note that this is a costly operation because the whole request body is cloned. + /// + /// :type typing.Awaitable[bytes]: #[getter] fn body<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { let body = self.body.clone(); diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs index 5e3619cb11d..2e97af5e49b 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs @@ -17,8 +17,13 @@ use tokio::sync::Mutex; use super::{PyHeaderMap, PyMiddlewareError}; /// Python-compatible [Response] object. +/// +/// :param status int: +/// :param headers typing.Optional[typing.Dict[str, str]]: +/// :param body typing.Optional[bytes]: +/// :rtype None: #[pyclass(name = "Response")] -#[pyo3(text_signature = "(status, headers, body)")] +#[pyo3(text_signature = "($self, status, headers=None, body=None)")] pub struct PyResponse { parts: Option, headers: PyHeaderMap, @@ -78,6 +83,8 @@ impl PyResponse { } /// Return the HTTP status of this response. + /// + /// :type int: #[getter] fn status(&self) -> PyResult { self.parts @@ -87,6 +94,8 @@ impl PyResponse { } /// Return the HTTP version of this response. + /// + /// :type str: #[getter] fn version(&self) -> PyResult { self.parts @@ -96,6 +105,8 @@ impl PyResponse { } /// Return the HTTP headers of this response. + /// + /// :type typing.MutableMapping[str, str]: #[getter] fn headers(&self) -> PyHeaderMap { self.headers.clone() @@ -103,6 +114,8 @@ impl PyResponse { /// Return the HTTP body of this response. /// Note that this is a costly operation because the whole response body is cloned. + /// + /// :type typing.Awaitable[bytes]: #[getter] fn body<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { let body = self.body.clone(); diff --git a/rust-runtime/aws-smithy-http-server-python/src/socket.rs b/rust-runtime/aws-smithy-http-server-python/src/socket.rs index 8243aa28c22..13900ff8c8f 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/socket.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/socket.rs @@ -20,7 +20,12 @@ use std::net::SocketAddr; /// computing capacity of the host. /// /// [GIL]: https://wiki.python.org/moin/GlobalInterpreterLock -#[pyclass] +/// +/// :param address str: +/// :param port int: +/// :param backlog typing.Optional\[int\]: +/// :rtype None: +#[pyclass(text_signature = "($self, address, port, backlog=None)")] #[derive(Debug)] pub struct PySocket { pub(crate) inner: Socket, @@ -49,7 +54,8 @@ impl PySocket { /// Clone the inner socket allowing it to be shared between multiple /// Python processes. - #[pyo3(text_signature = "($self, socket, worker_number)")] + /// + /// :rtype PySocket: pub fn try_clone(&self) -> PyResult { let copied = self.inner.try_clone()?; Ok(PySocket { inner: copied }) diff --git a/rust-runtime/aws-smithy-http-server-python/src/tls.rs b/rust-runtime/aws-smithy-http-server-python/src/tls.rs index 0c09a224e8f..538508fcec0 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/tls.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/tls.rs @@ -20,19 +20,30 @@ use tokio_rustls::rustls::{Certificate, Error as RustTlsError, PrivateKey, Serve pub mod listener; /// PyTlsConfig represents TLS configuration created from Python. +/// +/// :param key_path pathlib.Path: +/// :param cert_path pathlib.Path: +/// :param reload_secs int: +/// :rtype None: #[pyclass( name = "TlsConfig", - text_signature = "(*, key_path, cert_path, reload)" + text_signature = "($self, *, key_path, cert_path, reload_secs=86400)" )] #[derive(Clone)] pub struct PyTlsConfig { /// Absolute path of the RSA or PKCS private key. + /// + /// :type pathlib.Path: key_path: PathBuf, /// Absolute path of the x509 certificate. + /// + /// :type pathlib.Path: cert_path: PathBuf, /// Duration to reloading certificates. + /// + /// :type int: reload_secs: u64, } diff --git a/rust-runtime/aws-smithy-http-server-python/src/tls/listener.rs b/rust-runtime/aws-smithy-http-server-python/src/tls/listener.rs index 345e0e24ab9..fb4f759f898 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/tls/listener.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/tls/listener.rs @@ -214,7 +214,7 @@ mod tests { fn client_config_with_cert(cert: &rcgen::Certificate) -> ClientConfig { let mut roots = RootCertStore::empty(); - roots.add_parsable_certificates(&vec![cert.serialize_der().unwrap()]); + roots.add_parsable_certificates(&[cert.serialize_der().unwrap()]); ClientConfig::builder() .with_safe_defaults() .with_root_certificates(roots) @@ -223,7 +223,7 @@ mod tests { fn cert_with_invalid_date() -> rcgen::Certificate { let mut params = rcgen::CertificateParams::new(vec!["localhost".to_string()]); - params.not_after = rcgen::date_time_ymd(1970, 01, 01); + params.not_after = rcgen::date_time_ymd(1970, 1, 1); rcgen::Certificate::from_params(params).unwrap() } diff --git a/rust-runtime/aws-smithy-http-server-python/src/types.rs b/rust-runtime/aws-smithy-http-server-python/src/types.rs index b42ced51a5a..1ae44a0432f 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/types.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/types.rs @@ -27,6 +27,9 @@ use tokio_stream::StreamExt; use crate::PyError; /// Python Wrapper for [aws_smithy_types::Blob]. +/// +/// :param input bytes: +/// :rtype None: #[pyclass] #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Blob(aws_smithy_types::Blob); @@ -58,6 +61,8 @@ impl Blob { } /// Python getter for the `Blob` byte array. + /// + /// :type bytes: #[getter(data)] pub fn get_data(&self) -> &[u8] { self.as_ref() @@ -134,6 +139,9 @@ impl DateTime { #[pymethods] impl DateTime { /// Creates a `DateTime` from a number of seconds since the Unix epoch. + /// + /// :param epoch_seconds int: + /// :rtype DateTime: #[staticmethod] pub fn from_secs(epoch_seconds: i64) -> Self { Self(aws_smithy_types::date_time::DateTime::from_secs( @@ -142,6 +150,9 @@ impl DateTime { } /// Creates a `DateTime` from a number of milliseconds since the Unix epoch. + /// + /// :param epoch_millis int: + /// :rtype DateTime: #[staticmethod] pub fn from_millis(epoch_millis: i64) -> Self { Self(aws_smithy_types::date_time::DateTime::from_secs( @@ -150,6 +161,9 @@ impl DateTime { } /// Creates a `DateTime` from a number of nanoseconds since the Unix epoch. + /// + /// :param epoch_nanos int: + /// :rtype DateTime: #[staticmethod] pub fn from_nanos(epoch_nanos: i128) -> PyResult { Ok(Self( @@ -159,6 +173,12 @@ impl DateTime { } /// Read 1 date of `format` from `s`, expecting either `delim` or EOF. + /// + /// TODO(PythonTyping): How do we represent `char` in Python? + /// + /// :param format Format: + /// :param delim str: + /// :rtype typing.Tuple[DateTime, str]: #[staticmethod] pub fn read(s: &str, format: Format, delim: char) -> PyResult<(Self, &str)> { let (self_, next) = aws_smithy_types::date_time::DateTime::read(s, format.into(), delim) @@ -167,6 +187,10 @@ impl DateTime { } /// Creates a `DateTime` from a number of seconds and a fractional second since the Unix epoch. + /// + /// :param epoch_seconds int: + /// :param fraction float: + /// :rtype DateTime: #[staticmethod] pub fn from_fractional_secs(epoch_seconds: i64, fraction: f64) -> Self { Self(aws_smithy_types::date_time::DateTime::from_fractional_secs( @@ -176,6 +200,10 @@ impl DateTime { } /// Creates a `DateTime` from a number of seconds and sub-second nanos since the Unix epoch. + /// + /// :param seconds int: + /// :param subsecond_nanos int: + /// :rtype DateTime: #[staticmethod] pub fn from_secs_and_nanos(seconds: i64, subsecond_nanos: u32) -> Self { Self(aws_smithy_types::date_time::DateTime::from_secs_and_nanos( @@ -185,6 +213,9 @@ impl DateTime { } /// Creates a `DateTime` from an `f64` representing the number of seconds since the Unix epoch. + /// + /// :param epoch_seconds float: + /// :rtype DateTime: #[staticmethod] pub fn from_secs_f64(epoch_seconds: f64) -> Self { Self(aws_smithy_types::date_time::DateTime::from_secs_f64( @@ -193,6 +224,10 @@ impl DateTime { } /// Parses a `DateTime` from a string using the given `format`. + /// + /// :param s str: + /// :param format Format: + /// :rtype DateTime: #[staticmethod] pub fn from_str(s: &str, format: Format) -> PyResult { Ok(Self( @@ -202,31 +237,43 @@ impl DateTime { } /// Returns the number of nanoseconds since the Unix epoch that this `DateTime` represents. + /// + /// :rtype int: pub fn as_nanos(&self) -> i128 { self.0.as_nanos() } /// Returns the `DateTime` value as an `f64` representing the seconds since the Unix epoch. + /// + /// :rtype float: pub fn as_secs_f64(&self) -> f64 { self.0.as_secs_f64() } /// Returns true if sub-second nanos is greater than zero. + /// + /// :rtype bool: pub fn has_subsec_nanos(&self) -> bool { self.0.has_subsec_nanos() } /// Returns the epoch seconds component of the `DateTime`. + /// + /// :rtype int: pub fn secs(&self) -> i64 { self.0.secs() } /// Returns the sub-second nanos component of the `DateTime`. + /// + /// :rtype int: pub fn subsec_nanos(&self) -> u32 { self.0.subsec_nanos() } /// Converts the `DateTime` to the number of milliseconds since the Unix epoch. + /// + /// :rtype int: pub fn to_millis(&self) -> PyResult { Ok(self.0.to_millis().map_err(PyError::DateTimeConversion)?) } @@ -283,6 +330,9 @@ impl<'date> From<&'date DateTime> for &'date aws_smithy_types::DateTime { /// /// The original Rust [ByteStream](aws_smithy_http::byte_stream::ByteStream) is wrapped inside a `Arc` to allow the type to be /// [Clone] (required by PyO3) and to allow internal mutability, required to fetch the next chunk of data. +/// +/// :param input bytes: +/// :rtype None: #[pyclass] #[derive(Debug, Clone)] pub struct ByteStream(Arc>); @@ -347,6 +397,9 @@ impl ByteStream { /// requiring Python to await this method. /// /// **NOTE:** This method will block the Rust event loop when it is running. + /// + /// :param path str: + /// :rtype ByteStream: #[staticmethod] pub fn from_path_blocking(py: Python, path: String) -> PyResult> { let byte_stream = futures::executor::block_on(async { @@ -360,6 +413,9 @@ impl ByteStream { /// Create a new [ByteStream](aws_smithy_http::byte_stream::ByteStream) from a path, forcing /// Python to await this coroutine. + /// + /// :param path str: + /// :rtype typing.Awaitable[ByteStream]: #[staticmethod] pub fn from_path(py: Python, path: String) -> PyResult<&PyAny> { pyo3_asyncio::tokio::future_into_py(py, async move { @@ -654,10 +710,10 @@ mod tests { .into(), ), "{ - 't': True, - 'foo': 'foo', - 'f42': 42.0, - 'i42': 42, + 't': True, + 'foo': 'foo', + 'f42': 42.0, + 'i42': 42, 'f': False, 'vec': [ 'inner', diff --git a/rust-runtime/aws-smithy-http-server-python/src/util.rs b/rust-runtime/aws-smithy-http-server-python/src/util.rs index b420c26fd45..4cac7f06d27 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/util.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/util.rs @@ -43,7 +43,6 @@ fn is_coroutine(py: Python, func: &PyObject) -> PyResult { } // Checks whether given Python type is `Optional[T]`. -#[allow(unused)] pub fn is_optional_of(py: Python, ty: &PyAny) -> PyResult { // for reference: https://stackoverflow.com/a/56833826 @@ -131,6 +130,7 @@ async def async_func(): }) } + #[allow(clippy::bool_assert_comparison)] #[test] fn check_if_is_optional_of() -> PyResult<()> { pyo3::prepare_freethreaded_python(); diff --git a/rust-runtime/aws-smithy-http-server/examples/pokemon-service/Cargo.toml b/rust-runtime/aws-smithy-http-server/examples/pokemon-service/Cargo.toml index da4cb284c08..b548c221f60 100644 --- a/rust-runtime/aws-smithy-http-server/examples/pokemon-service/Cargo.toml +++ b/rust-runtime/aws-smithy-http-server/examples/pokemon-service/Cargo.toml @@ -51,7 +51,6 @@ pokemon-service-server-sdk = { path = "../pokemon-service-server-sdk/" } assert_cmd = "2.0" home = "0.5" serial_test = "0.7.0" -wrk-api-bench = "0.0.8" # This dependency is only required for testing the `pokemon-service-tls` program. hyper-rustls = { version = "0.23.0", features = ["http2"] } diff --git a/rust-runtime/aws-smithy-http-server/examples/pokemon-service/src/lib.rs b/rust-runtime/aws-smithy-http-server/examples/pokemon-service/src/lib.rs index 0beb61b1103..7a4d868da0a 100644 --- a/rust-runtime/aws-smithy-http-server/examples/pokemon-service/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server/examples/pokemon-service/src/lib.rs @@ -226,8 +226,8 @@ pub async fn capture_pokemon( Some(event) => { let capturing_event = event.as_event(); if let Ok(attempt) = capturing_event { - let payload = attempt.payload.clone().unwrap_or(CapturingPayload::builder().build()); - let pokeball = payload.pokeball.as_ref().map(|ball| ball.as_str()).unwrap_or(""); + let payload = attempt.payload.clone().unwrap_or_else(|| CapturingPayload::builder().build()); + let pokeball = payload.pokeball().unwrap_or(""); if ! matches!(pokeball, "Master Ball" | "Great Ball" | "Fast Ball") { yield Err( crate::error::CapturePokemonEventsError::InvalidPokeballError( @@ -249,9 +249,7 @@ pub async fn capture_pokemon( if captured { let shiny = rand::thread_rng().gen_range(0..4096) == 0; let pokemon = payload - .name - .as_ref() - .map(|name| name.as_str()) + .name() .unwrap_or("") .to_string(); let pokedex: Vec = (0..255).collect(); diff --git a/rust-runtime/aws-smithy-http-server/examples/pokemon-service/tests/benchmark.rs b/rust-runtime/aws-smithy-http-server/examples/pokemon-service/tests/benchmark.rs deleted file mode 100644 index 4499854d795..00000000000 --- a/rust-runtime/aws-smithy-http-server/examples/pokemon-service/tests/benchmark.rs +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -use std::{env, fs::OpenOptions, io::Write, path::Path, time::Duration}; - -use tokio::time; -use wrk_api_bench::{BenchmarkBuilder, HistoryPeriod, WrkBuilder}; - -use crate::helpers::PokemonService; - -mod helpers; - -#[tokio::test] -async fn benchmark() -> Result<(), Box> { - // Benchmarks are expensive, so they run only if the environment - // variable `RUN_BENCHMARKS` is present. - if env::var_os("RUN_BENCHMARKS").is_some() { - let _program = PokemonService::run(); - // Give PokémonService some time to start up. - time::sleep(Duration::from_millis(50)).await; - - // The history directory is cached inside GitHub actions under - // the running use home directory to allow us to recover historical - // data between runs. - let history_dir = if env::var_os("GITHUB_ACTIONS").is_some() { - home::home_dir().unwrap().join(".wrk-api-bench") - } else { - Path::new(".").join(".wrk-api-bench") - }; - - let mut wrk = WrkBuilder::default() - .url(String::from("http://localhost:13734/empty-operation")) - .history_dir(history_dir) - .build()?; - - // Run a single benchmark with 8 threads and 64 connections for 60 seconds. - let benches = vec![BenchmarkBuilder::default() - .duration(Duration::from_secs(90)) - .threads(2) - .connections(32) - .build()?]; - wrk.bench(&benches)?; - - // Calculate deviation from last run and write it to disk. - if let Ok(deviation) = wrk.deviation(HistoryPeriod::Last) { - let mut deviation_file = OpenOptions::new() - .create(true) - .write(true) - .truncate(true) - .open("/tmp/smithy_rs_benchmark_deviation.txt") - .unwrap(); - deviation_file.write_all(deviation.to_github_markdown().as_bytes())?; - } - } - Ok(()) -} diff --git a/rust-runtime/aws-smithy-http-server/examples/pokemon-service/tests/simple_integration_test.rs b/rust-runtime/aws-smithy-http-server/examples/pokemon-service/tests/simple_integration_test.rs index 5e6efd29dac..7b0a477a34a 100644 --- a/rust-runtime/aws-smithy-http-server/examples/pokemon-service/tests/simple_integration_test.rs +++ b/rust-runtime/aws-smithy-http-server/examples/pokemon-service/tests/simple_integration_test.rs @@ -12,10 +12,7 @@ use crate::helpers::{client, client_http2_only, PokemonService}; use async_stream::stream; use aws_smithy_types::error::display::DisplayErrorContext; use pokemon_service_client::{ - error::{ - AttemptCapturingPokemonEventError, AttemptCapturingPokemonEventErrorKind, GetStorageError, GetStorageErrorKind, - MasterBallUnsuccessful, StorageAccessNotAuthorized, - }, + error::{AttemptCapturingPokemonEventError, GetStorageError, MasterBallUnsuccessful, StorageAccessNotAuthorized}, model::{AttemptCapturingPokemonEvent, CapturingEvent, CapturingPayload}, types::SdkError, }; @@ -77,10 +74,7 @@ async fn simple_integration_test() { let has_not_authorized_error = if let Err(SdkError::ServiceError(context)) = storage_err { matches!( context.err(), - GetStorageError { - kind: GetStorageErrorKind::StorageAccessNotAuthorized(StorageAccessNotAuthorized { .. }), - .. - } + GetStorageError::StorageAccessNotAuthorized(StorageAccessNotAuthorized { .. }), ) } else { false @@ -137,10 +131,7 @@ async fn event_stream_test() { .build()) .build() )); - yield Err(AttemptCapturingPokemonEventError::new( - AttemptCapturingPokemonEventErrorKind::MasterBallUnsuccessful(MasterBallUnsuccessful::builder().build()), - Default::default() - )); + yield Err(AttemptCapturingPokemonEventError::MasterBallUnsuccessful(MasterBallUnsuccessful::builder().build())); // The next event should not happen yield Ok(AttemptCapturingPokemonEvent::Event( CapturingEvent::builder() diff --git a/rust-runtime/aws-smithy-http-server/src/lib.rs b/rust-runtime/aws-smithy-http-server/src/lib.rs index 804b00d8a30..6031c53e2a2 100644 --- a/rust-runtime/aws-smithy-http-server/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server/src/lib.rs @@ -3,6 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +#![allow(clippy::derive_partial_eq_without_eq)] #![cfg_attr(docsrs, feature(doc_cfg))] //! HTTP server runtime and utilities, loosely based on [axum]. diff --git a/rust-runtime/aws-smithy-http-tower/src/lib.rs b/rust-runtime/aws-smithy-http-tower/src/lib.rs index d2a1c546033..92b60c27029 100644 --- a/rust-runtime/aws-smithy-http-tower/src/lib.rs +++ b/rust-runtime/aws-smithy-http-tower/src/lib.rs @@ -3,6 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ +#![allow(clippy::derive_partial_eq_without_eq)] + pub mod dispatch; pub mod map_request; pub mod parse_response; diff --git a/rust-runtime/aws-smithy-http/Cargo.toml b/rust-runtime/aws-smithy-http/Cargo.toml index 23bb28437d4..05e39ee18bf 100644 --- a/rust-runtime/aws-smithy-http/Cargo.toml +++ b/rust-runtime/aws-smithy-http/Cargo.toml @@ -39,7 +39,7 @@ tokio-util = { version = "0.7", optional = true } async-stream = "0.3" futures-util = "0.3" hyper = { version = "0.14.12", features = ["stream"] } -pretty_assertions = "1.2" +pretty_assertions = "1.3" proptest = "1" tokio = { version = "1.8.4", features = [ "macros", diff --git a/rust-runtime/aws-smithy-http/external-types.toml b/rust-runtime/aws-smithy-http/external-types.toml index a3dbbce3c57..b06231e92f5 100644 --- a/rust-runtime/aws-smithy-http/external-types.toml +++ b/rust-runtime/aws-smithy-http/external-types.toml @@ -26,9 +26,6 @@ allowed_external_types = [ # TODO(https://github.com/awslabs/smithy-rs/issues/1193): Feature gate references to Tokio `File` "tokio::fs::file::File", - # TODO(https://github.com/awslabs/smithy-rs/issues/1193): Decide if `InvalidUri` should be exposed - "http::uri::InvalidUri", - # TODO(https://github.com/awslabs/smithy-rs/issues/1193): Don't expose `once_cell` in public API "once_cell::sync::Lazy", diff --git a/rust-runtime/aws-smithy-http/src/body.rs b/rust-runtime/aws-smithy-http/src/body.rs index 5e9b4fbf2d7..8fea913b2ba 100644 --- a/rust-runtime/aws-smithy-http/src/body.rs +++ b/rust-runtime/aws-smithy-http/src/body.rs @@ -277,6 +277,7 @@ mod test { assert_eq!(SdkBody::from("").size_hint().exact(), Some(0)); } + #[allow(clippy::bool_assert_comparison)] #[test] fn valid_eos() { assert_eq!(SdkBody::from("hello").is_end_stream(), false); diff --git a/rust-runtime/aws-smithy-http/src/event_stream/sender.rs b/rust-runtime/aws-smithy-http/src/event_stream/sender.rs index 8ecc1b7fd4f..fb7af53714b 100644 --- a/rust-runtime/aws-smithy-http/src/event_stream/sender.rs +++ b/rust-runtime/aws-smithy-http/src/event_stream/sender.rs @@ -228,9 +228,7 @@ mod tests { type Input = TestServiceError; fn marshall(&self, _input: Self::Input) -> Result { - Err(Message::read_from(&b""[..]) - .err() - .expect("this should always fail")) + Err(Message::read_from(&b""[..]).expect_err("this should always fail")) } } diff --git a/rust-runtime/aws-smithy-http/src/http.rs b/rust-runtime/aws-smithy-http/src/http.rs new file mode 100644 index 00000000000..4c7bcbe93b9 --- /dev/null +++ b/rust-runtime/aws-smithy-http/src/http.rs @@ -0,0 +1,37 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use http::{HeaderMap, HeaderValue}; + +/// Trait for accessing HTTP headers. +/// +/// Useful for generic impls so that they can access headers via trait bounds. +pub trait HttpHeaders { + /// Returns a reference to the associated header map. + fn http_headers(&self) -> &HeaderMap; + + /// Returns a mutable reference to the associated header map. + fn http_headers_mut(&mut self) -> &mut HeaderMap; +} + +impl HttpHeaders for http::Response { + fn http_headers(&self) -> &HeaderMap { + self.headers() + } + + fn http_headers_mut(&mut self) -> &mut HeaderMap { + self.headers_mut() + } +} + +impl HttpHeaders for crate::operation::Response { + fn http_headers(&self) -> &HeaderMap { + self.http().http_headers() + } + + fn http_headers_mut(&mut self) -> &mut HeaderMap { + self.http_mut().http_headers_mut() + } +} diff --git a/rust-runtime/aws-smithy-http/src/lib.rs b/rust-runtime/aws-smithy-http/src/lib.rs index 54449bf38ef..f777e15c82a 100644 --- a/rust-runtime/aws-smithy-http/src/lib.rs +++ b/rust-runtime/aws-smithy-http/src/lib.rs @@ -15,17 +15,21 @@ //! | `rt-tokio` | Provides features that are dependent on `tokio` including the `ByteStream::from_path` util | //! | `event-stream` | Provides Sender/Receiver implementations for Event Stream codegen. | +#![allow(clippy::derive_partial_eq_without_eq)] #![cfg_attr(docsrs, feature(doc_cfg))] pub mod body; pub mod endpoint; pub mod header; +pub mod http; pub mod http_versions; pub mod label; pub mod middleware; pub mod operation; pub mod property_bag; pub mod query; +#[doc(hidden)] +pub mod query_writer; pub mod response; pub mod result; pub mod retry; diff --git a/aws/rust-runtime/aws-sigv4/src/http_request/query_writer.rs b/rust-runtime/aws-smithy-http/src/query_writer.rs similarity index 92% rename from aws/rust-runtime/aws-sigv4/src/http_request/query_writer.rs rename to rust-runtime/aws-smithy-http/src/query_writer.rs index 40a98d9aba3..a7f9d7cfe13 100644 --- a/aws/rust-runtime/aws-sigv4/src/http_request/query_writer.rs +++ b/rust-runtime/aws-smithy-http/src/query_writer.rs @@ -3,11 +3,11 @@ * SPDX-License-Identifier: Apache-2.0 */ -use crate::http_request::url_escape::percent_encode_query; +use crate::query::fmt_string as percent_encode_query; use http::Uri; /// Utility for updating the query string in a [`Uri`]. -pub(super) struct QueryWriter { +pub struct QueryWriter { base_uri: Uri, new_path_and_query: String, prefix: Option, @@ -15,7 +15,7 @@ pub(super) struct QueryWriter { impl QueryWriter { /// Creates a new `QueryWriter` based off the given `uri`. - pub(super) fn new(uri: &Uri) -> Self { + pub fn new(uri: &Uri) -> Self { let new_path_and_query = uri .path_and_query() .map(|pq| pq.to_string()) @@ -35,7 +35,7 @@ impl QueryWriter { } /// Clears all query parameters. - pub(super) fn clear_params(&mut self) { + pub fn clear_params(&mut self) { if let Some(index) = self.new_path_and_query.find('?') { self.new_path_and_query.truncate(index); self.prefix = Some('?'); @@ -44,7 +44,7 @@ impl QueryWriter { /// Inserts a new query parameter. The key and value are percent encoded /// by `QueryWriter`. Passing in percent encoded values will result in double encoding. - pub(super) fn insert(&mut self, k: &str, v: &str) { + pub fn insert(&mut self, k: &str, v: &str) { if let Some(prefix) = self.prefix { self.new_path_and_query.push(prefix); } @@ -56,12 +56,12 @@ impl QueryWriter { } /// Returns just the built query string. - pub(super) fn build_query(self) -> String { + pub fn build_query(self) -> String { self.build_uri().query().unwrap_or_default().to_string() } /// Returns a full [`Uri`] with the query string updated. - pub(super) fn build_uri(self) -> Uri { + pub fn build_uri(self) -> Uri { let mut parts = self.base_uri.into_parts(); parts.path_and_query = Some( self.new_path_and_query @@ -142,7 +142,7 @@ mod test { let mut query_writer = QueryWriter::new(&uri); query_writer.insert("key", value); - if let Err(_) = std::panic::catch_unwind(|| query_writer.build_uri()) { + if std::panic::catch_unwind(|| query_writer.build_uri()).is_err() { problematic_chars.push(char::from(byte)); }; } diff --git a/rust-runtime/aws-smithy-http/src/result.rs b/rust-runtime/aws-smithy-http/src/result.rs index 6cf2fc5c784..f11dcc2d36e 100644 --- a/rust-runtime/aws-smithy-http/src/result.rs +++ b/rust-runtime/aws-smithy-http/src/result.rs @@ -13,6 +13,8 @@ //! `Result` wrapper types for [success](SdkSuccess) and [failure](SdkError) responses. use crate::operation; +use aws_smithy_types::error::metadata::{ProvideErrorMetadata, EMPTY_ERROR_METADATA}; +use aws_smithy_types::error::ErrorMetadata; use aws_smithy_types::retry::ErrorKind; use std::error::Error; use std::fmt; @@ -30,18 +32,183 @@ pub struct SdkSuccess { pub parsed: O, } +/// Builders for `SdkError` variant context. +pub mod builders { + use super::*; + + macro_rules! source_only_error_builder { + ($errorName:ident, $builderName:ident, $sourceType:ident) => { + #[doc = concat!("Builder for [`", stringify!($errorName), "`](super::", stringify!($errorName), ").")] + #[derive(Debug, Default)] + pub struct $builderName { + source: Option<$sourceType>, + } + + impl $builderName { + #[doc = "Creates a new builder."] + pub fn new() -> Self { Default::default() } + + #[doc = "Sets the error source."] + pub fn source(mut self, source: impl Into<$sourceType>) -> Self { + self.source = Some(source.into()); + self + } + + #[doc = "Sets the error source."] + pub fn set_source(&mut self, source: Option<$sourceType>) -> &mut Self { + self.source = source; + self + } + + #[doc = "Builds the error context."] + pub fn build(self) -> $errorName { + $errorName { source: self.source.expect("source is required") } + } + } + }; + } + + source_only_error_builder!(ConstructionFailure, ConstructionFailureBuilder, BoxError); + source_only_error_builder!(TimeoutError, TimeoutErrorBuilder, BoxError); + source_only_error_builder!(DispatchFailure, DispatchFailureBuilder, ConnectorError); + + /// Builder for [`ResponseError`](super::ResponseError). + #[derive(Debug)] + pub struct ResponseErrorBuilder { + source: Option, + raw: Option, + } + + impl Default for ResponseErrorBuilder { + fn default() -> Self { + Self { + source: None, + raw: None, + } + } + } + + impl ResponseErrorBuilder { + /// Creates a new builder. + pub fn new() -> Self { + Default::default() + } + + /// Sets the error source. + pub fn source(mut self, source: impl Into) -> Self { + self.source = Some(source.into()); + self + } + + /// Sets the error source. + pub fn set_source(&mut self, source: Option) -> &mut Self { + self.source = source; + self + } + + /// Sets the raw response. + pub fn raw(mut self, raw: R) -> Self { + self.raw = Some(raw); + self + } + + /// Sets the raw response. + pub fn set_raw(&mut self, raw: Option) -> &mut Self { + self.raw = raw; + self + } + + /// Builds the error context. + pub fn build(self) -> ResponseError { + ResponseError { + source: self.source.expect("source is required"), + raw: self.raw.expect("a raw response is required"), + } + } + } + + /// Builder for [`ServiceError`](super::ServiceError). + #[derive(Debug)] + pub struct ServiceErrorBuilder { + source: Option, + raw: Option, + } + + impl Default for ServiceErrorBuilder { + fn default() -> Self { + Self { + source: None, + raw: None, + } + } + } + + impl ServiceErrorBuilder { + /// Creates a new builder. + pub fn new() -> Self { + Default::default() + } + + /// Sets the error source. + pub fn source(mut self, source: impl Into) -> Self { + self.source = Some(source.into()); + self + } + + /// Sets the error source. + pub fn set_source(&mut self, source: Option) -> &mut Self { + self.source = source; + self + } + + /// Sets the raw response. + pub fn raw(mut self, raw: R) -> Self { + self.raw = Some(raw); + self + } + + /// Sets the raw response. + pub fn set_raw(&mut self, raw: Option) -> &mut Self { + self.raw = raw; + self + } + + /// Builds the error context. + pub fn build(self) -> ServiceError { + ServiceError { + source: self.source.expect("source is required"), + raw: self.raw.expect("a raw response is required"), + } + } + } +} + /// Error context for [`SdkError::ConstructionFailure`] #[derive(Debug)] pub struct ConstructionFailure { source: BoxError, } +impl ConstructionFailure { + /// Creates a builder for this error context type. + pub fn builder() -> builders::ConstructionFailureBuilder { + builders::ConstructionFailureBuilder::new() + } +} + /// Error context for [`SdkError::TimeoutError`] #[derive(Debug)] pub struct TimeoutError { source: BoxError, } +impl TimeoutError { + /// Creates a builder for this error context type. + pub fn builder() -> builders::TimeoutErrorBuilder { + builders::TimeoutErrorBuilder::new() + } +} + /// Error context for [`SdkError::DispatchFailure`] #[derive(Debug)] pub struct DispatchFailure { @@ -49,6 +216,11 @@ pub struct DispatchFailure { } impl DispatchFailure { + /// Creates a builder for this error context type. + pub fn builder() -> builders::DispatchFailureBuilder { + builders::DispatchFailureBuilder::new() + } + /// Returns true if the error is an IO error pub fn is_io(&self) -> bool { self.source.is_io() @@ -80,6 +252,11 @@ pub struct ResponseError { } impl ResponseError { + /// Creates a builder for this error context type. + pub fn builder() -> builders::ResponseErrorBuilder { + builders::ResponseErrorBuilder::new() + } + /// Returns a reference to the raw response pub fn raw(&self) -> &R { &self.raw @@ -101,6 +278,11 @@ pub struct ServiceError { } impl ServiceError { + /// Creates a builder for this error context type. + pub fn builder() -> builders::ServiceErrorBuilder { + builders::ServiceErrorBuilder::new() + } + /// Returns the underlying error of type `E` pub fn err(&self) -> &E { &self.source @@ -126,8 +308,11 @@ impl ServiceError { /// /// This trait exists so that [`SdkError::into_service_error`] can be infallible. pub trait CreateUnhandledError { - /// Creates an unhandled error variant with the given `source`. - fn create_unhandled_error(source: Box) -> Self; + /// Creates an unhandled error variant with the given `source` and error metadata. + fn create_unhandled_error( + source: Box, + meta: Option, + ) -> Self; } /// Failed SDK Result @@ -200,19 +385,21 @@ impl SdkError { /// /// ```no_run /// # use aws_smithy_http::result::{SdkError, CreateUnhandledError}; - /// # #[derive(Debug)] enum GetObjectErrorKind { NoSuchKey(()), Other(()) } - /// # #[derive(Debug)] struct GetObjectError { kind: GetObjectErrorKind } + /// # #[derive(Debug)] enum GetObjectError { NoSuchKey(()), Other(()) } /// # impl std::fmt::Display for GetObjectError { /// # fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { unimplemented!() } /// # } /// # impl std::error::Error for GetObjectError {} /// # impl CreateUnhandledError for GetObjectError { - /// # fn create_unhandled_error(_: Box) -> Self { unimplemented!() } + /// # fn create_unhandled_error( + /// # _: Box, + /// # _: Option, + /// # ) -> Self { unimplemented!() } /// # } /// # fn example() -> Result<(), GetObjectError> { - /// # let sdk_err = SdkError::service_error(GetObjectError { kind: GetObjectErrorKind::NoSuchKey(()) }, ()); + /// # let sdk_err = SdkError::service_error(GetObjectError::NoSuchKey(()), ()); /// match sdk_err.into_service_error() { - /// GetObjectError { kind: GetObjectErrorKind::NoSuchKey(_) } => { + /// GetObjectError::NoSuchKey(_) => { /// // handle NoSuchKey /// } /// err @ _ => return Err(err), @@ -227,7 +414,7 @@ impl SdkError { { match self { Self::ServiceError(context) => context.source, - _ => E::create_unhandled_error(self.into()), + _ => E::create_unhandled_error(self.into(), None), } } @@ -278,6 +465,21 @@ where } } +impl ProvideErrorMetadata for SdkError +where + E: ProvideErrorMetadata, +{ + fn meta(&self) -> &aws_smithy_types::Error { + match self { + Self::ConstructionFailure(_) => &EMPTY_ERROR_METADATA, + Self::TimeoutError(_) => &EMPTY_ERROR_METADATA, + Self::DispatchFailure(_) => &EMPTY_ERROR_METADATA, + Self::ResponseError(_) => &EMPTY_ERROR_METADATA, + Self::ServiceError(err) => err.source.meta(), + } + } +} + #[derive(Debug)] enum ConnectorErrorKind { /// A timeout occurred while processing the request diff --git a/rust-runtime/aws-smithy-json/src/lib.rs b/rust-runtime/aws-smithy-json/src/lib.rs index 58da71b621b..a4e6904924f 100644 --- a/rust-runtime/aws-smithy-json/src/lib.rs +++ b/rust-runtime/aws-smithy-json/src/lib.rs @@ -3,6 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ +#![allow(clippy::derive_partial_eq_without_eq)] + //! JSON Abstractions for Smithy pub mod deserialize; diff --git a/rust-runtime/aws-smithy-protocol-test/Cargo.toml b/rust-runtime/aws-smithy-protocol-test/Cargo.toml index b4d5bb516b8..e18aea12b40 100644 --- a/rust-runtime/aws-smithy-protocol-test/Cargo.toml +++ b/rust-runtime/aws-smithy-protocol-test/Cargo.toml @@ -15,7 +15,7 @@ regex = "1.5" # Not perfect for our needs, but good for now assert-json-diff = "1.1" -pretty_assertions = "1.0" +pretty_assertions = "1.3" roxmltree = "0.14.1" diff --git a/rust-runtime/aws-smithy-query/src/lib.rs b/rust-runtime/aws-smithy-query/src/lib.rs index 3b9d57a9d02..26e5ab50bfe 100644 --- a/rust-runtime/aws-smithy-query/src/lib.rs +++ b/rust-runtime/aws-smithy-query/src/lib.rs @@ -3,6 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ +#![allow(clippy::derive_partial_eq_without_eq)] + //! Abstractions for the Smithy AWS Query protocol use aws_smithy_types::date_time::{DateTimeFormatError, Format}; diff --git a/rust-runtime/aws-smithy-types-convert/src/lib.rs b/rust-runtime/aws-smithy-types-convert/src/lib.rs index 4487db52b66..c18ebec8e31 100644 --- a/rust-runtime/aws-smithy-types-convert/src/lib.rs +++ b/rust-runtime/aws-smithy-types-convert/src/lib.rs @@ -5,6 +5,7 @@ //! Conversions between `aws-smithy-types` and the types of frequently used Rust libraries. +#![allow(clippy::derive_partial_eq_without_eq)] #![warn( missing_docs, rustdoc::missing_crate_level_docs, diff --git a/rust-runtime/aws-smithy-types/Cargo.toml b/rust-runtime/aws-smithy-types/Cargo.toml index 079664ada13..30063308d38 100644 --- a/rust-runtime/aws-smithy-types/Cargo.toml +++ b/rust-runtime/aws-smithy-types/Cargo.toml @@ -12,7 +12,7 @@ itoa = "1.0.0" num-integer = "0.1.44" ryu = "1.0.5" time = { version = "0.3.4", features = ["parsing"] } -base64-simd = "0.7" +base64-simd = "0.8" [target.'cfg(aws_sdk_unstable)'.dependencies.serde] version = "1" diff --git a/rust-runtime/aws-smithy-types/benches/base64.rs b/rust-runtime/aws-smithy-types/benches/base64.rs index d29654b0591..0539c190ff2 100644 --- a/rust-runtime/aws-smithy-types/benches/base64.rs +++ b/rust-runtime/aws-smithy-types/benches/base64.rs @@ -4,7 +4,6 @@ */ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; -use rand; use rand::distributions::{Alphanumeric, DistString}; /// Generates a random string of a given length @@ -155,6 +154,7 @@ mod handrolled_base64 { } /// Failure to decode a base64 value. + #[allow(clippy::enum_variant_names)] #[derive(Debug, Clone, Eq, PartialEq)] #[non_exhaustive] pub enum DecodeError { diff --git a/rust-runtime/aws-smithy-types/src/base64.rs b/rust-runtime/aws-smithy-types/src/base64.rs index 460a07e33d4..76a7943ccd4 100644 --- a/rust-runtime/aws-smithy-types/src/base64.rs +++ b/rust-runtime/aws-smithy-types/src/base64.rs @@ -3,9 +3,9 @@ * SPDX-License-Identifier: Apache-2.0 */ -//! A thin wrapper over `base64-simd` +//! A thin wrapper over [`base64-simd`](https://docs.rs/base64-simd/) -use base64_simd::Base64; +use base64_simd::STANDARD; use std::error::Error; /// Failure to decode a base64 value. @@ -28,20 +28,15 @@ impl std::fmt::Display for DecodeError { /// /// If input is not a valid base64 encoded string, this function will return `DecodeError`. pub fn decode(input: impl AsRef) -> Result, DecodeError> { - Base64::STANDARD - .decode_to_boxed_bytes(input.as_ref().as_bytes()) - .map(|bytes| bytes.into_vec()) - .map_err(DecodeError) + STANDARD.decode_to_vec(input.as_ref()).map_err(DecodeError) } /// Encode `input` into base64 using the standard base64 alphabet pub fn encode(input: impl AsRef<[u8]>) -> String { - Base64::STANDARD - .encode_to_boxed_str(input.as_ref()) - .into_string() + STANDARD.encode_to_string(input.as_ref()) } /// Returns the base64 representation's length for the given `length` of data pub fn encoded_length(length: usize) -> usize { - Base64::STANDARD.encoded_length(length) + STANDARD.encoded_length(length) } diff --git a/rust-runtime/aws-smithy-types/src/error.rs b/rust-runtime/aws-smithy-types/src/error.rs index dc41a67d833..b83d6ffe077 100644 --- a/rust-runtime/aws-smithy-types/src/error.rs +++ b/rust-runtime/aws-smithy-types/src/error.rs @@ -3,147 +3,16 @@ * SPDX-License-Identifier: Apache-2.0 */ -//! Generic errors for Smithy codegen +//! Errors for Smithy codegen -use crate::retry::{ErrorKind, ProvideErrorKind}; -use std::collections::HashMap; use std::fmt; pub mod display; +pub mod metadata; +mod unhandled; -/// Generic Error type -/// -/// For many services, Errors are modeled. However, many services only partially model errors or don't -/// model errors at all. In these cases, the SDK will return this generic error type to expose the -/// `code`, `message` and `request_id`. -#[derive(Debug, Eq, PartialEq, Default, Clone)] -pub struct Error { - code: Option, - message: Option, - request_id: Option, - extras: HashMap<&'static str, String>, -} - -/// Builder for [`Error`]. -#[derive(Debug, Default)] -pub struct Builder { - inner: Error, -} - -impl Builder { - /// Sets the error message. - pub fn message(&mut self, message: impl Into) -> &mut Self { - self.inner.message = Some(message.into()); - self - } - - /// Sets the error code. - pub fn code(&mut self, code: impl Into) -> &mut Self { - self.inner.code = Some(code.into()); - self - } - - /// Sets the request ID the error happened for. - pub fn request_id(&mut self, request_id: impl Into) -> &mut Self { - self.inner.request_id = Some(request_id.into()); - self - } - - /// Set a custom field on the error metadata - /// - /// Typically, these will be accessed with an extension trait: - /// ```rust - /// use aws_smithy_types::Error; - /// const HOST_ID: &str = "host_id"; - /// trait S3ErrorExt { - /// fn extended_request_id(&self) -> Option<&str>; - /// } - /// - /// impl S3ErrorExt for Error { - /// fn extended_request_id(&self) -> Option<&str> { - /// self.extra(HOST_ID) - /// } - /// } - /// - /// fn main() { - /// // Extension trait must be brought into scope - /// use S3ErrorExt; - /// let sdk_response: Result<(), Error> = Err(Error::builder().custom(HOST_ID, "x-1234").build()); - /// if let Err(err) = sdk_response { - /// println!("request id: {:?}, extended request id: {:?}", err.request_id(), err.extended_request_id()); - /// } - /// } - /// ``` - pub fn custom(&mut self, key: &'static str, value: impl Into) -> &mut Self { - self.inner.extras.insert(key, value.into()); - self - } - - /// Creates the error. - pub fn build(&mut self) -> Error { - std::mem::take(&mut self.inner) - } -} - -impl Error { - /// Returns the error code. - pub fn code(&self) -> Option<&str> { - self.code.as_deref() - } - /// Returns the error message. - pub fn message(&self) -> Option<&str> { - self.message.as_deref() - } - /// Returns the request ID the error occurred for, if it's available. - pub fn request_id(&self) -> Option<&str> { - self.request_id.as_deref() - } - /// Returns additional information about the error if it's present. - pub fn extra(&self, key: &'static str) -> Option<&str> { - self.extras.get(key).map(|k| k.as_str()) - } - - /// Creates an `Error` builder. - pub fn builder() -> Builder { - Builder::default() - } - - /// Converts an `Error` into a builder. - pub fn into_builder(self) -> Builder { - Builder { inner: self } - } -} - -impl ProvideErrorKind for Error { - fn retryable_error_kind(&self) -> Option { - None - } - - fn code(&self) -> Option<&str> { - Error::code(self) - } -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut fmt = f.debug_struct("Error"); - if let Some(code) = &self.code { - fmt.field("code", code); - } - if let Some(message) = &self.message { - fmt.field("message", message); - } - if let Some(req_id) = &self.request_id { - fmt.field("request_id", req_id); - } - for (k, v) in &self.extras { - fmt.field(k, &v); - } - fmt.finish() - } -} - -impl std::error::Error for Error {} +pub use metadata::ErrorMetadata; +pub use unhandled::Unhandled; #[derive(Debug)] pub(super) enum TryFromNumberErrorKind { diff --git a/rust-runtime/aws-smithy-types/src/error/metadata.rs b/rust-runtime/aws-smithy-types/src/error/metadata.rs new file mode 100644 index 00000000000..06925e13f9c --- /dev/null +++ b/rust-runtime/aws-smithy-types/src/error/metadata.rs @@ -0,0 +1,166 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! Error metadata + +use crate::retry::{ErrorKind, ProvideErrorKind}; +use std::collections::HashMap; +use std::fmt; + +/// Trait to retrieve error metadata from a result +pub trait ProvideErrorMetadata { + /// Returns error metadata, which includes the error code, message, + /// request ID, and potentially additional information. + fn meta(&self) -> &ErrorMetadata; + + /// Returns the error code if it's available. + fn code(&self) -> Option<&str> { + self.meta().code() + } + + /// Returns the error message, if there is one. + fn message(&self) -> Option<&str> { + self.meta().message() + } +} + +/// Empty error metadata +#[doc(hidden)] +pub const EMPTY_ERROR_METADATA: ErrorMetadata = ErrorMetadata { + code: None, + message: None, + extras: None, +}; + +/// Generic Error type +/// +/// For many services, Errors are modeled. However, many services only partially model errors or don't +/// model errors at all. In these cases, the SDK will return this generic error type to expose the +/// `code`, `message` and `request_id`. +#[derive(Debug, Eq, PartialEq, Default, Clone)] +pub struct ErrorMetadata { + code: Option, + message: Option, + extras: Option>, +} + +/// Builder for [`ErrorMetadata`]. +#[derive(Debug, Default)] +pub struct Builder { + inner: ErrorMetadata, +} + +impl Builder { + /// Sets the error message. + pub fn message(mut self, message: impl Into) -> Self { + self.inner.message = Some(message.into()); + self + } + + /// Sets the error code. + pub fn code(mut self, code: impl Into) -> Self { + self.inner.code = Some(code.into()); + self + } + + /// Set a custom field on the error metadata + /// + /// Typically, these will be accessed with an extension trait: + /// ```rust + /// use aws_smithy_types::Error; + /// const HOST_ID: &str = "host_id"; + /// trait S3ErrorExt { + /// fn extended_request_id(&self) -> Option<&str>; + /// } + /// + /// impl S3ErrorExt for Error { + /// fn extended_request_id(&self) -> Option<&str> { + /// self.extra(HOST_ID) + /// } + /// } + /// + /// fn main() { + /// // Extension trait must be brought into scope + /// use S3ErrorExt; + /// let sdk_response: Result<(), Error> = Err(Error::builder().custom(HOST_ID, "x-1234").build()); + /// if let Err(err) = sdk_response { + /// println!("extended request id: {:?}", err.extended_request_id()); + /// } + /// } + /// ``` + pub fn custom(mut self, key: &'static str, value: impl Into) -> Self { + if self.inner.extras.is_none() { + self.inner.extras = Some(HashMap::new()); + } + self.inner + .extras + .as_mut() + .unwrap() + .insert(key, value.into()); + self + } + + /// Creates the error. + pub fn build(self) -> ErrorMetadata { + self.inner + } +} + +impl ErrorMetadata { + /// Returns the error code. + pub fn code(&self) -> Option<&str> { + self.code.as_deref() + } + /// Returns the error message. + pub fn message(&self) -> Option<&str> { + self.message.as_deref() + } + /// Returns additional information about the error if it's present. + pub fn extra(&self, key: &'static str) -> Option<&str> { + self.extras + .as_ref() + .and_then(|extras| extras.get(key).map(|k| k.as_str())) + } + + /// Creates an `Error` builder. + pub fn builder() -> Builder { + Builder::default() + } + + /// Converts an `Error` into a builder. + pub fn into_builder(self) -> Builder { + Builder { inner: self } + } +} + +impl ProvideErrorKind for ErrorMetadata { + fn retryable_error_kind(&self) -> Option { + None + } + + fn code(&self) -> Option<&str> { + ErrorMetadata::code(self) + } +} + +impl fmt::Display for ErrorMetadata { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut fmt = f.debug_struct("Error"); + if let Some(code) = &self.code { + fmt.field("code", code); + } + if let Some(message) = &self.message { + fmt.field("message", message); + } + if let Some(extras) = &self.extras { + for (k, v) in extras { + fmt.field(k, &v); + } + } + fmt.finish() + } +} + +impl std::error::Error for ErrorMetadata {} diff --git a/rust-runtime/aws-smithy-types/src/error/unhandled.rs b/rust-runtime/aws-smithy-types/src/error/unhandled.rs new file mode 100644 index 00000000000..2397d700ffc --- /dev/null +++ b/rust-runtime/aws-smithy-types/src/error/unhandled.rs @@ -0,0 +1,90 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! Unhandled error type. + +use crate::error::{metadata::ProvideErrorMetadata, ErrorMetadata}; +use std::error::Error as StdError; + +/// Builder for [`Unhandled`] +#[derive(Default, Debug)] +pub struct Builder { + source: Option>, + meta: Option, +} + +impl Builder { + /// Sets the error source + pub fn source(mut self, source: impl Into>) -> Self { + self.source = Some(source.into()); + self + } + + /// Sets the error source + pub fn set_source( + &mut self, + source: Option>, + ) -> &mut Self { + self.source = source; + self + } + + /// Sets the error metadata + pub fn meta(mut self, meta: ErrorMetadata) -> Self { + self.meta = Some(meta); + self + } + + /// Sets the error metadata + pub fn set_meta(&mut self, meta: Option) -> &mut Self { + self.meta = meta; + self + } + + /// Builds the unhandled error + pub fn build(self) -> Unhandled { + Unhandled { + source: self.source.expect("unhandled errors must have a source"), + meta: self.meta.unwrap_or_default(), + } + } +} + +/// An unexpected error occurred (e.g., invalid JSON returned by the service or an unknown error code). +/// +/// When logging an error from the SDK, it is recommended that you either wrap the error in +/// [`DisplayErrorContext`](crate::error::display::DisplayErrorContext), use another +/// error reporter library that visits the error's cause/source chain, or call +/// [`Error::source`](std::error::Error::source) for more details about the underlying cause. +#[derive(Debug)] +pub struct Unhandled { + source: Box, + meta: ErrorMetadata, +} + +impl Unhandled { + /// Returns a builder to construct an unhandled error. + pub fn builder() -> Builder { + Default::default() + } +} + +impl std::fmt::Display for Unhandled { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + write!(f, "unhandled error") + } +} + +impl StdError for Unhandled { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + Some(self.source.as_ref() as _) + } +} + +impl ProvideErrorMetadata for Unhandled { + fn meta(&self) -> &ErrorMetadata { + &self.meta + } +} diff --git a/rust-runtime/aws-smithy-types/src/lib.rs b/rust-runtime/aws-smithy-types/src/lib.rs index f6542733c6c..b5cdb4974f9 100644 --- a/rust-runtime/aws-smithy-types/src/lib.rs +++ b/rust-runtime/aws-smithy-types/src/lib.rs @@ -5,6 +5,7 @@ //! Protocol-agnostic types for smithy-rs. +#![allow(clippy::derive_partial_eq_without_eq)] #![warn( missing_docs, rustdoc::missing_crate_level_docs, @@ -24,8 +25,543 @@ pub mod primitive; pub mod retry; pub mod timeout; -pub use blob::Blob; -pub use date_time::DateTime; -pub use document::Document; -pub use error::Error; -pub use number::Number; +pub use crate::date_time::DateTime; + +// TODO(deprecated): Remove deprecated re-export +/// Use [error::ErrorMetadata] instead. +#[deprecated( + note = "`aws_smithy_types::Error` has been renamed to `aws_smithy_types::error::ErrorMetadata`" +)] +pub use error::ErrorMetadata as Error; + +/// Binary Blob Type +/// +/// Blobs represent protocol-agnostic binary content. +#[derive(Debug, PartialEq, Eq, Hash, Clone)] +pub struct Blob { + inner: Vec, +} + +impl Blob { + /// Creates a new blob from the given `input`. + pub fn new>>(input: T) -> Self { + Blob { + inner: input.into(), + } + } + + /// Consumes the `Blob` and returns a `Vec` with its contents. + pub fn into_inner(self) -> Vec { + self.inner + } +} + +impl AsRef<[u8]> for Blob { + fn as_ref(&self) -> &[u8] { + &self.inner + } +} + +/* ANCHOR: document */ + +/// Document Type +/// +/// Document types represents protocol-agnostic open content that is accessed like JSON data. +/// Open content is useful for modeling unstructured data that has no schema, data that can't be +/// modeled using rigid types, or data that has a schema that evolves outside of the purview of a model. +/// The serialization format of a document is an implementation detail of a protocol. +#[derive(Debug, Clone, PartialEq)] +pub enum Document { + /// JSON object + Object(HashMap), + /// JSON array + Array(Vec), + /// JSON number + Number(Number), + /// JSON string + String(String), + /// JSON boolean + Bool(bool), + /// JSON null + Null, +} + +impl From for Document { + fn from(value: bool) -> Self { + Document::Bool(value) + } +} + +impl From for Document { + fn from(value: String) -> Self { + Document::String(value) + } +} + +impl From> for Document { + fn from(values: Vec) -> Self { + Document::Array(values) + } +} + +impl From> for Document { + fn from(values: HashMap) -> Self { + Document::Object(values) + } +} + +impl From for Document { + fn from(value: u64) -> Self { + Document::Number(Number::PosInt(value)) + } +} + +impl From for Document { + fn from(value: i64) -> Self { + Document::Number(Number::NegInt(value)) + } +} + +impl From for Document { + fn from(value: i32) -> Self { + Document::Number(Number::NegInt(value as i64)) + } +} + +/// A number type that implements Javascript / JSON semantics, modeled on serde_json: +/// +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum Number { + /// Unsigned 64-bit integer value. + PosInt(u64), + /// Signed 64-bit integer value. The wrapped value is _always_ negative. + NegInt(i64), + /// 64-bit floating-point value. + Float(f64), +} + +/* ANCHOR_END: document */ + +impl Number { + /// Converts to an `f64` lossily. + /// Use `Number::try_from` to make the conversion only if it is not lossy. + pub fn to_f64_lossy(self) -> f64 { + match self { + Number::PosInt(v) => v as f64, + Number::NegInt(v) => v as f64, + Number::Float(v) => v as f64, + } + } + + /// Converts to an `f32` lossily. + /// Use `Number::try_from` to make the conversion only if it is not lossy. + pub fn to_f32_lossy(self) -> f32 { + match self { + Number::PosInt(v) => v as f32, + Number::NegInt(v) => v as f32, + Number::Float(v) => v as f32, + } + } +} + +macro_rules! to_unsigned_integer_converter { + ($typ:ident, $styp:expr) => { + #[doc = "Converts to a `"] + #[doc = $styp] + #[doc = "`. This conversion fails if it is lossy."] + impl TryFrom for $typ { + type Error = TryFromNumberError; + + fn try_from(value: Number) -> Result { + match value { + Number::PosInt(v) => Ok(Self::try_from(v)?), + Number::NegInt(v) => { + Err(TryFromNumberErrorKind::NegativeToUnsignedLossyConversion(v).into()) + } + Number::Float(v) => { + Err(TryFromNumberErrorKind::FloatToIntegerLossyConversion(v).into()) + } + } + } + } + }; + + ($typ:ident) => { + to_unsigned_integer_converter!($typ, stringify!($typ)); + }; +} + +macro_rules! to_signed_integer_converter { + ($typ:ident, $styp:expr) => { + #[doc = "Converts to a `"] + #[doc = $styp] + #[doc = "`. This conversion fails if it is lossy."] + impl TryFrom for $typ { + type Error = TryFromNumberError; + + fn try_from(value: Number) -> Result { + match value { + Number::PosInt(v) => Ok(Self::try_from(v)?), + Number::NegInt(v) => Ok(Self::try_from(v)?), + Number::Float(v) => { + Err(TryFromNumberErrorKind::FloatToIntegerLossyConversion(v).into()) + } + } + } + } + }; + + ($typ:ident) => { + to_signed_integer_converter!($typ, stringify!($typ)); + }; +} + +/// Converts to a `u64`. The conversion fails if it is lossy. +impl TryFrom for u64 { + type Error = TryFromNumberError; + + fn try_from(value: Number) -> Result { + match value { + Number::PosInt(v) => Ok(v), + Number::NegInt(v) => { + Err(TryFromNumberErrorKind::NegativeToUnsignedLossyConversion(v).into()) + } + Number::Float(v) => { + Err(TryFromNumberErrorKind::FloatToIntegerLossyConversion(v).into()) + } + } + } +} +to_unsigned_integer_converter!(u32); +to_unsigned_integer_converter!(u16); +to_unsigned_integer_converter!(u8); + +impl TryFrom for i64 { + type Error = TryFromNumberError; + + fn try_from(value: Number) -> Result { + match value { + Number::PosInt(v) => Ok(Self::try_from(v)?), + Number::NegInt(v) => Ok(v), + Number::Float(v) => { + Err(TryFromNumberErrorKind::FloatToIntegerLossyConversion(v).into()) + } + } + } +} +to_signed_integer_converter!(i32); +to_signed_integer_converter!(i16); +to_signed_integer_converter!(i8); + +/// Converts to an `f64`. The conversion fails if it is lossy. +impl TryFrom for f64 { + type Error = TryFromNumberError; + + fn try_from(value: Number) -> Result { + match value { + // Integers can only be represented with full precision in a float if they fit in the + // significand, which is 24 bits in `f32` and 53 bits in `f64`. + // https://github.com/rust-lang/rust/blob/58f11791af4f97572e7afd83f11cffe04bbbd12f/library/core/src/convert/num.rs#L151-L153 + Number::PosInt(v) => { + if v <= (1 << 53) { + Ok(v as Self) + } else { + Err(TryFromNumberErrorKind::U64ToFloatLossyConversion(v).into()) + } + } + Number::NegInt(v) => { + if (-(1 << 53)..=(1 << 53)).contains(&v) { + Ok(v as Self) + } else { + Err(TryFromNumberErrorKind::I64ToFloatLossyConversion(v).into()) + } + } + Number::Float(v) => Ok(v), + } + } +} + +/// Converts to an `f64`. The conversion fails if it is lossy. +impl TryFrom for f32 { + type Error = TryFromNumberError; + + fn try_from(value: Number) -> Result { + match value { + Number::PosInt(v) => { + if v <= (1 << 24) { + Ok(v as Self) + } else { + Err(TryFromNumberErrorKind::U64ToFloatLossyConversion(v).into()) + } + } + Number::NegInt(v) => { + if (-(1 << 24)..=(1 << 24)).contains(&v) { + Ok(v as Self) + } else { + Err(TryFromNumberErrorKind::I64ToFloatLossyConversion(v).into()) + } + } + Number::Float(v) => Err(TryFromNumberErrorKind::F64ToF32LossyConversion(v).into()), + } + } +} + +#[cfg(test)] +mod number { + use super::*; + use crate::error::{TryFromNumberError, TryFromNumberErrorKind}; + + macro_rules! to_unsigned_converter_tests { + ($typ:ident) => { + assert_eq!($typ::try_from(Number::PosInt(69u64)).unwrap(), 69); + + assert!(matches!( + $typ::try_from(Number::PosInt(($typ::MAX as u64) + 1u64)).unwrap_err(), + TryFromNumberError { + kind: TryFromNumberErrorKind::OutsideIntegerRange(..) + } + )); + + assert!(matches!( + $typ::try_from(Number::NegInt(-1i64)).unwrap_err(), + TryFromNumberError { + kind: TryFromNumberErrorKind::NegativeToUnsignedLossyConversion(..) + } + )); + + for val in [69.69f64, f64::NAN, f64::INFINITY, f64::NEG_INFINITY] { + assert!(matches!( + $typ::try_from(Number::Float(val)).unwrap_err(), + TryFromNumberError { + kind: TryFromNumberErrorKind::FloatToIntegerLossyConversion(..) + } + )); + } + }; + } + + #[test] + fn to_u64() { + assert_eq!(u64::try_from(Number::PosInt(69u64)).unwrap(), 69u64); + + assert!(matches!( + u64::try_from(Number::NegInt(-1i64)).unwrap_err(), + TryFromNumberError { + kind: TryFromNumberErrorKind::NegativeToUnsignedLossyConversion(..) + } + )); + + for val in [69.69f64, f64::NAN, f64::INFINITY, f64::NEG_INFINITY] { + assert!(matches!( + u64::try_from(Number::Float(val)).unwrap_err(), + TryFromNumberError { + kind: TryFromNumberErrorKind::FloatToIntegerLossyConversion(..) + } + )); + } + } + + #[test] + fn to_u32() { + to_unsigned_converter_tests!(u32); + } + + #[test] + fn to_u16() { + to_unsigned_converter_tests!(u16); + } + + #[test] + fn to_u8() { + to_unsigned_converter_tests!(u8); + } + + macro_rules! to_signed_converter_tests { + ($typ:ident) => { + assert_eq!($typ::try_from(Number::PosInt(69u64)).unwrap(), 69); + assert_eq!($typ::try_from(Number::NegInt(-69i64)).unwrap(), -69); + + assert!(matches!( + $typ::try_from(Number::PosInt(($typ::MAX as u64) + 1u64)).unwrap_err(), + TryFromNumberError { + kind: TryFromNumberErrorKind::OutsideIntegerRange(..) + } + )); + + assert!(matches!( + $typ::try_from(Number::NegInt(($typ::MIN as i64) - 1i64)).unwrap_err(), + TryFromNumberError { + kind: TryFromNumberErrorKind::OutsideIntegerRange(..) + } + )); + + for val in [69.69f64, f64::NAN, f64::INFINITY, f64::NEG_INFINITY] { + assert!(matches!( + u64::try_from(Number::Float(val)).unwrap_err(), + TryFromNumberError { + kind: TryFromNumberErrorKind::FloatToIntegerLossyConversion(..) + } + )); + } + }; + } + + #[test] + fn to_i64() { + assert_eq!(i64::try_from(Number::PosInt(69u64)).unwrap(), 69); + assert_eq!(i64::try_from(Number::NegInt(-69i64)).unwrap(), -69); + + for val in [69.69f64, f64::NAN, f64::INFINITY, f64::NEG_INFINITY] { + assert!(matches!( + u64::try_from(Number::Float(val)).unwrap_err(), + TryFromNumberError { + kind: TryFromNumberErrorKind::FloatToIntegerLossyConversion(..) + } + )); + } + } + + #[test] + fn to_i32() { + to_signed_converter_tests!(i32); + } + + #[test] + fn to_i16() { + to_signed_converter_tests!(i16); + } + + #[test] + fn to_i8() { + to_signed_converter_tests!(i8); + } + + #[test] + fn to_f64() { + assert_eq!(f64::try_from(Number::PosInt(69u64)).unwrap(), 69f64); + assert_eq!(f64::try_from(Number::NegInt(-69i64)).unwrap(), -69f64); + assert_eq!(f64::try_from(Number::Float(-69f64)).unwrap(), -69f64); + assert!(f64::try_from(Number::Float(f64::NAN)).unwrap().is_nan()); + assert_eq!( + f64::try_from(Number::Float(f64::INFINITY)).unwrap(), + f64::INFINITY + ); + assert_eq!( + f64::try_from(Number::Float(f64::NEG_INFINITY)).unwrap(), + f64::NEG_INFINITY + ); + + let significand_max_u64: u64 = 1 << 53; + let significand_max_i64: i64 = 1 << 53; + + assert_eq!( + f64::try_from(Number::PosInt(significand_max_u64)).unwrap(), + 9007199254740992f64 + ); + + assert_eq!( + f64::try_from(Number::NegInt(significand_max_i64)).unwrap(), + 9007199254740992f64 + ); + assert_eq!( + f64::try_from(Number::NegInt(-significand_max_i64)).unwrap(), + -9007199254740992f64 + ); + + assert!(matches!( + f64::try_from(Number::PosInt(significand_max_u64 + 1)).unwrap_err(), + TryFromNumberError { + kind: TryFromNumberErrorKind::U64ToFloatLossyConversion(..) + } + )); + + assert!(matches!( + f64::try_from(Number::NegInt(significand_max_i64 + 1)).unwrap_err(), + TryFromNumberError { + kind: TryFromNumberErrorKind::I64ToFloatLossyConversion(..) + } + )); + assert!(matches!( + f64::try_from(Number::NegInt(-significand_max_i64 - 1)).unwrap_err(), + TryFromNumberError { + kind: TryFromNumberErrorKind::I64ToFloatLossyConversion(..) + } + )); + } + + #[test] + fn to_f32() { + assert_eq!(f32::try_from(Number::PosInt(69u64)).unwrap(), 69f32); + assert_eq!(f32::try_from(Number::NegInt(-69i64)).unwrap(), -69f32); + + let significand_max_u64: u64 = 1 << 24; + let significand_max_i64: i64 = 1 << 24; + + assert_eq!( + f32::try_from(Number::PosInt(significand_max_u64)).unwrap(), + 16777216f32 + ); + + assert_eq!( + f32::try_from(Number::NegInt(significand_max_i64)).unwrap(), + 16777216f32 + ); + assert_eq!( + f32::try_from(Number::NegInt(-significand_max_i64)).unwrap(), + -16777216f32 + ); + + assert!(matches!( + f32::try_from(Number::PosInt(significand_max_u64 + 1)).unwrap_err(), + TryFromNumberError { + kind: TryFromNumberErrorKind::U64ToFloatLossyConversion(..) + } + )); + + assert!(matches!( + f32::try_from(Number::NegInt(significand_max_i64 + 1)).unwrap_err(), + TryFromNumberError { + kind: TryFromNumberErrorKind::I64ToFloatLossyConversion(..) + } + )); + assert!(matches!( + f32::try_from(Number::NegInt(-significand_max_i64 - 1)).unwrap_err(), + TryFromNumberError { + kind: TryFromNumberErrorKind::I64ToFloatLossyConversion(..) + } + )); + + for val in [69f64, f64::NAN, f64::INFINITY, f64::NEG_INFINITY] { + assert!(matches!( + f32::try_from(Number::Float(val)).unwrap_err(), + TryFromNumberError { + kind: TryFromNumberErrorKind::F64ToF32LossyConversion(..) + } + )); + } + } + + #[test] + fn to_f64_lossy() { + assert_eq!(Number::PosInt(69u64).to_f64_lossy(), 69f64); + assert_eq!( + Number::PosInt((1 << 53) + 1).to_f64_lossy(), + 9007199254740992f64 + ); + assert_eq!( + Number::NegInt(-(1 << 53) - 1).to_f64_lossy(), + -9007199254740992f64 + ); + } + + #[test] + fn to_f32_lossy() { + assert_eq!(Number::PosInt(69u64).to_f32_lossy(), 69f32); + assert_eq!(Number::PosInt((1 << 24) + 1).to_f32_lossy(), 16777216f32); + assert_eq!(Number::NegInt(-(1 << 24) - 1).to_f32_lossy(), -16777216f32); + assert_eq!( + Number::Float(1452089033.7674935).to_f32_lossy(), + 1452089100f32 + ); + } +} diff --git a/rust-runtime/aws-smithy-xml/src/lib.rs b/rust-runtime/aws-smithy-xml/src/lib.rs index f9d4e990ced..fcd41447008 100644 --- a/rust-runtime/aws-smithy-xml/src/lib.rs +++ b/rust-runtime/aws-smithy-xml/src/lib.rs @@ -3,6 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ +#![allow(clippy::derive_partial_eq_without_eq)] + //! Abstractions for Smithy //! [XML Binding Traits](https://awslabs.github.io/smithy/1.0/spec/core/xml-traits.html) pub mod decode; diff --git a/rust-runtime/inlineable/src/aws_query_compatible_errors.rs b/rust-runtime/inlineable/src/aws_query_compatible_errors.rs new file mode 100644 index 00000000000..7a94064d717 --- /dev/null +++ b/rust-runtime/inlineable/src/aws_query_compatible_errors.rs @@ -0,0 +1,103 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use http::header::ToStrError; +use http::{HeaderMap, HeaderValue}; + +const X_AMZN_QUERY_ERROR: &str = "x-amzn-query-error"; +const QUERY_COMPATIBLE_ERRORCODE_DELIMITER: char = ';'; + +fn aws_query_compatible_error_from_header( + headers: &HeaderMap, +) -> Result, ToStrError> { + headers + .get(X_AMZN_QUERY_ERROR) + .map(|v| v.to_str()) + .transpose() +} + +/// Obtains custom error code and error type from the given `headers`. +/// +/// Looks up a value for the `X_AMZN_QUERY_ERROR` header and if found, the value should be in the +/// form of `;`. The function then splits it into two parts and returns +/// a (error code, error type) as a tuple. +/// +/// Any execution path besides the above happy path will yield a `None`. +pub fn parse_aws_query_compatible_error(headers: &HeaderMap) -> Option<(&str, &str)> { + let header_value = match aws_query_compatible_error_from_header(headers) { + Ok(error) => error?, + _ => return None, + }; + + header_value + .find(QUERY_COMPATIBLE_ERRORCODE_DELIMITER) + .map(|idx| (&header_value[..idx], &header_value[idx + 1..])) +} + +#[cfg(test)] +mod test { + use crate::aws_query_compatible_errors::{ + aws_query_compatible_error_from_header, parse_aws_query_compatible_error, + X_AMZN_QUERY_ERROR, + }; + + #[test] + fn aws_query_compatible_error_from_header_should_provide_value_for_custom_header() { + let mut response: http::Response<()> = http::Response::default(); + response.headers_mut().insert( + X_AMZN_QUERY_ERROR, + http::HeaderValue::from_static("AWS.SimpleQueueService.NonExistentQueue;Sender"), + ); + + let actual = aws_query_compatible_error_from_header(response.headers()).unwrap(); + + assert_eq!( + Some("AWS.SimpleQueueService.NonExistentQueue;Sender"), + actual, + ); + } + + #[test] + fn parse_aws_query_compatible_error_should_parse_code_and_type_fields() { + let mut response: http::Response<()> = http::Response::default(); + response.headers_mut().insert( + X_AMZN_QUERY_ERROR, + http::HeaderValue::from_static("AWS.SimpleQueueService.NonExistentQueue;Sender"), + ); + + let actual = parse_aws_query_compatible_error(response.headers()); + + assert_eq!( + Some(("AWS.SimpleQueueService.NonExistentQueue", "Sender")), + actual, + ); + } + + #[test] + fn parse_aws_query_compatible_error_should_return_none_when_header_value_has_no_delimiter() { + let mut response: http::Response<()> = http::Response::default(); + response.headers_mut().insert( + X_AMZN_QUERY_ERROR, + http::HeaderValue::from_static("AWS.SimpleQueueService.NonExistentQueue"), + ); + + let actual = parse_aws_query_compatible_error(response.headers()); + + assert_eq!(None, actual); + } + + #[test] + fn parse_aws_query_compatible_error_should_return_none_when_there_is_no_target_header() { + let mut response: http::Response<()> = http::Response::default(); + response.headers_mut().insert( + "x-amzn-requestid", + http::HeaderValue::from_static("a918fbf2-457a-4fe1-99ba-5685ce220fc1"), + ); + + let actual = parse_aws_query_compatible_error(response.headers()); + + assert_eq!(None, actual); + } +} diff --git a/rust-runtime/inlineable/src/ec2_query_errors.rs b/rust-runtime/inlineable/src/ec2_query_errors.rs index a7fc1b11631..3355dbe0042 100644 --- a/rust-runtime/inlineable/src/ec2_query_errors.rs +++ b/rust-runtime/inlineable/src/ec2_query_errors.rs @@ -3,6 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +use aws_smithy_types::error::metadata::{Builder as ErrorMetadataBuilder, ErrorMetadata}; use aws_smithy_xml::decode::{try_data, Document, ScopedDecoder, XmlDecodeError}; use std::convert::TryFrom; @@ -13,36 +14,30 @@ pub fn body_is_error(body: &[u8]) -> Result { Ok(scoped.start_el().matches("Response")) } -pub fn parse_generic_error(body: &[u8]) -> Result { +pub fn parse_error_metadata(body: &[u8]) -> Result { let mut doc = Document::try_from(body)?; let mut root = doc.root_element()?; - let mut err_builder = aws_smithy_types::Error::builder(); + let mut err_builder = ErrorMetadata::builder(); while let Some(mut tag) = root.next_tag() { - match tag.start_el().local() { - "Errors" => { - while let Some(mut error_tag) = tag.next_tag() { - if let "Error" = error_tag.start_el().local() { - while let Some(mut error_field) = error_tag.next_tag() { - match error_field.start_el().local() { - "Code" => { - err_builder.code(try_data(&mut error_field)?); - } - "Message" => { - err_builder.message(try_data(&mut error_field)?); - } - _ => {} + if tag.start_el().local() == "Errors" { + while let Some(mut error_tag) = tag.next_tag() { + if let "Error" = error_tag.start_el().local() { + while let Some(mut error_field) = error_tag.next_tag() { + match error_field.start_el().local() { + "Code" => { + err_builder = err_builder.code(try_data(&mut error_field)?); } + "Message" => { + err_builder = err_builder.message(try_data(&mut error_field)?); + } + _ => {} } } } } - "RequestId" => { - err_builder.request_id(try_data(&mut tag)?); - } - _ => {} } } - Ok(err_builder.build()) + Ok(err_builder) } #[allow(unused)] @@ -71,7 +66,7 @@ pub fn error_scope<'a, 'b>( #[cfg(test)] mod test { - use super::{body_is_error, parse_generic_error}; + use super::{body_is_error, parse_error_metadata}; use crate::ec2_query_errors::error_scope; use aws_smithy_xml::decode::Document; use std::convert::TryFrom; @@ -92,8 +87,7 @@ mod test { "#; assert!(body_is_error(xml).unwrap()); - let parsed = parse_generic_error(xml).expect("valid xml"); - assert_eq!(parsed.request_id(), Some("foo-id")); + let parsed = parse_error_metadata(xml).expect("valid xml").build(); assert_eq!(parsed.message(), Some("Hi")); assert_eq!(parsed.code(), Some("InvalidGreeting")); } diff --git a/rust-runtime/inlineable/src/endpoint_lib/host.rs b/rust-runtime/inlineable/src/endpoint_lib/host.rs index 41dc029122b..4c4168437b4 100644 --- a/rust-runtime/inlineable/src/endpoint_lib/host.rs +++ b/rust-runtime/inlineable/src/endpoint_lib/host.rs @@ -40,6 +40,7 @@ mod test { super::is_valid_host_label(label, allow_dots, &mut DiagnosticCollector::new()) } + #[allow(clippy::bool_assert_comparison)] #[test] fn basic_cases() { assert_eq!(is_valid_host_label("", false), false); @@ -57,6 +58,7 @@ mod test { ); } + #[allow(clippy::bool_assert_comparison)] #[test] fn start_bounds() { assert_eq!(is_valid_host_label("-foo", false), false); diff --git a/rust-runtime/inlineable/src/endpoint_lib/parse_url.rs b/rust-runtime/inlineable/src/endpoint_lib/parse_url.rs index bd7862aa943..5e437c9bfe3 100644 --- a/rust-runtime/inlineable/src/endpoint_lib/parse_url.rs +++ b/rust-runtime/inlineable/src/endpoint_lib/parse_url.rs @@ -71,6 +71,7 @@ mod test { use super::*; use crate::endpoint_lib::diagnostic::DiagnosticCollector; + #[allow(clippy::bool_assert_comparison)] #[test] fn parse_simple_url() { let url = "https://control.vpce-1a2b3c4d-5e6f.s3.us-west-2.vpce.amazonaws.com"; @@ -92,6 +93,7 @@ mod test { assert_eq!(url.scheme(), "https"); } + #[allow(clippy::bool_assert_comparison)] #[test] fn parse_url_with_port() { let url = "http://localhost:8000/path"; diff --git a/rust-runtime/inlineable/src/endpoint_lib/partition.rs b/rust-runtime/inlineable/src/endpoint_lib/partition.rs index 25bec52d952..55b97a21d26 100644 --- a/rust-runtime/inlineable/src/endpoint_lib/partition.rs +++ b/rust-runtime/inlineable/src/endpoint_lib/partition.rs @@ -574,8 +574,10 @@ mod test { #[test] fn resolve_partitions() { let mut resolver = PartitionResolver::empty(); - let mut new_suffix = PartitionOutputOverride::default(); - new_suffix.dns_suffix = Some("mars.aws".into()); + let new_suffix = PartitionOutputOverride { + dns_suffix: Some("mars.aws".into()), + ..Default::default() + }; resolver.add_partition(PartitionMetadata { id: "aws".into(), region_regex: Regex::new("^(us|eu|ap|sa|ca|me|af)-\\w+-\\d+$").unwrap(), diff --git a/rust-runtime/inlineable/src/json_errors.rs b/rust-runtime/inlineable/src/json_errors.rs index ea13da3ba82..1973f59d795 100644 --- a/rust-runtime/inlineable/src/json_errors.rs +++ b/rust-runtime/inlineable/src/json_errors.rs @@ -5,7 +5,7 @@ use aws_smithy_json::deserialize::token::skip_value; use aws_smithy_json::deserialize::{error::DeserializeError, json_token_iter, Token}; -use aws_smithy_types::Error as SmithyError; +use aws_smithy_types::error::metadata::{Builder as ErrorMetadataBuilder, ErrorMetadata}; use bytes::Bytes; use http::header::ToStrError; use http::{HeaderMap, HeaderValue}; @@ -82,56 +82,47 @@ fn error_type_from_header(headers: &HeaderMap) -> Result) -> Option<&str> { - headers - .get("X-Amzn-Requestid") - .and_then(|v| v.to_str().ok()) -} - -pub fn parse_generic_error( +pub fn parse_error_metadata( payload: &Bytes, headers: &HeaderMap, -) -> Result { +) -> Result { let ErrorBody { code, message } = parse_error_body(payload.as_ref())?; - let mut err_builder = SmithyError::builder(); + let mut err_builder = ErrorMetadata::builder(); if let Some(code) = error_type_from_header(headers) .map_err(|_| DeserializeError::custom("X-Amzn-Errortype header was not valid UTF-8"))? .or(code.as_deref()) .map(sanitize_error_code) { - err_builder.code(code); + err_builder = err_builder.code(code); } if let Some(message) = message { - err_builder.message(message); + err_builder = err_builder.message(message); } - if let Some(request_id) = request_id(headers) { - err_builder.request_id(request_id); - } - Ok(err_builder.build()) + Ok(err_builder) } #[cfg(test)] mod test { - use crate::json_errors::{parse_error_body, parse_generic_error, sanitize_error_code}; + use crate::json_errors::{parse_error_body, parse_error_metadata, sanitize_error_code}; use aws_smithy_types::Error; use bytes::Bytes; use std::borrow::Cow; #[test] - fn generic_error() { + fn error_metadata() { let response = http::Response::builder() - .header("X-Amzn-Requestid", "1234") .body(Bytes::from_static( br#"{ "__type": "FooError", "message": "Go to foo" }"#, )) .unwrap(); assert_eq!( - parse_generic_error(response.body(), response.headers()).unwrap(), + parse_error_metadata(response.body(), response.headers()) + .unwrap() + .build(), Error::builder() .code("FooError") .message("Go to foo") - .request_id("1234") .build() ) } @@ -209,7 +200,9 @@ mod test { )) .unwrap(); assert_eq!( - parse_generic_error(response.body(), response.headers()).unwrap(), + parse_error_metadata(response.body(), response.headers()) + .unwrap() + .build(), Error::builder() .code("ResourceNotFoundException") .message("Functions from 'us-west-2' are not reachable from us-east-1") diff --git a/rust-runtime/inlineable/src/lib.rs b/rust-runtime/inlineable/src/lib.rs index 41af3589197..e53b81db7db 100644 --- a/rust-runtime/inlineable/src/lib.rs +++ b/rust-runtime/inlineable/src/lib.rs @@ -3,6 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ +#[allow(dead_code)] +mod aws_query_compatible_errors; #[allow(unused)] mod constrained; #[allow(dead_code)] diff --git a/rust-runtime/inlineable/src/rest_xml_unwrapped_errors.rs b/rust-runtime/inlineable/src/rest_xml_unwrapped_errors.rs index df0f22ef4e7..def901cf7f6 100644 --- a/rust-runtime/inlineable/src/rest_xml_unwrapped_errors.rs +++ b/rust-runtime/inlineable/src/rest_xml_unwrapped_errors.rs @@ -6,6 +6,7 @@ //! Error abstractions for `noErrorWrapping`. Code generators should either inline this file //! or its companion `rest_xml_wrapped_errors.rs` for code generation +use aws_smithy_types::error::metadata::{Builder as ErrorMetadataBuilder, ErrorMetadata}; use aws_smithy_xml::decode::{try_data, Document, ScopedDecoder, XmlDecodeError}; use std::convert::TryFrom; @@ -26,30 +27,27 @@ pub fn error_scope<'a, 'b>( Ok(scoped) } -pub fn parse_generic_error(body: &[u8]) -> Result { +pub fn parse_error_metadata(body: &[u8]) -> Result { let mut doc = Document::try_from(body)?; let mut root = doc.root_element()?; - let mut err = aws_smithy_types::Error::builder(); + let mut builder = ErrorMetadata::builder(); while let Some(mut tag) = root.next_tag() { match tag.start_el().local() { "Code" => { - err.code(try_data(&mut tag)?); + builder = builder.code(try_data(&mut tag)?); } "Message" => { - err.message(try_data(&mut tag)?); - } - "RequestId" => { - err.request_id(try_data(&mut tag)?); + builder = builder.message(try_data(&mut tag)?); } _ => {} } } - Ok(err.build()) + Ok(builder) } #[cfg(test)] mod test { - use super::{body_is_error, parse_generic_error}; + use super::{body_is_error, parse_error_metadata}; #[test] fn parse_unwrapped_error() { @@ -61,8 +59,7 @@ mod test { foo-id "#; assert!(body_is_error(xml).unwrap()); - let parsed = parse_generic_error(xml).expect("valid xml"); - assert_eq!(parsed.request_id(), Some("foo-id")); + let parsed = parse_error_metadata(xml).expect("valid xml").build(); assert_eq!(parsed.message(), Some("Hi")); assert_eq!(parsed.code(), Some("InvalidGreeting")); } diff --git a/rust-runtime/inlineable/src/rest_xml_wrapped_errors.rs b/rust-runtime/inlineable/src/rest_xml_wrapped_errors.rs index c90301bf396..b735b77249c 100644 --- a/rust-runtime/inlineable/src/rest_xml_wrapped_errors.rs +++ b/rust-runtime/inlineable/src/rest_xml_wrapped_errors.rs @@ -3,6 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +use aws_smithy_types::error::metadata::{Builder as ErrorMetadataBuilder, ErrorMetadata}; use aws_smithy_xml::decode::{try_data, Document, ScopedDecoder, XmlDecodeError}; use std::convert::TryFrom; @@ -13,32 +14,27 @@ pub fn body_is_error(body: &[u8]) -> Result { Ok(scoped.start_el().matches("ErrorResponse")) } -pub fn parse_generic_error(body: &[u8]) -> Result { +#[allow(dead_code)] +pub fn parse_error_metadata(body: &[u8]) -> Result { let mut doc = Document::try_from(body)?; let mut root = doc.root_element()?; - let mut err_builder = aws_smithy_types::Error::builder(); + let mut err_builder = ErrorMetadata::builder(); while let Some(mut tag) = root.next_tag() { - match tag.start_el().local() { - "Error" => { - while let Some(mut error_field) = tag.next_tag() { - match error_field.start_el().local() { - "Code" => { - err_builder.code(try_data(&mut error_field)?); - } - "Message" => { - err_builder.message(try_data(&mut error_field)?); - } - _ => {} + if tag.start_el().local() == "Error" { + while let Some(mut error_field) = tag.next_tag() { + match error_field.start_el().local() { + "Code" => { + err_builder = err_builder.code(try_data(&mut error_field)?); } + "Message" => { + err_builder = err_builder.message(try_data(&mut error_field)?); + } + _ => {} } } - "RequestId" => { - err_builder.request_id(try_data(&mut tag)?); - } - _ => {} } } - Ok(err_builder.build()) + Ok(err_builder) } #[allow(unused)] @@ -65,7 +61,7 @@ pub fn error_scope<'a, 'b>( #[cfg(test)] mod test { - use super::{body_is_error, parse_generic_error}; + use super::{body_is_error, parse_error_metadata}; use crate::rest_xml_wrapped_errors::error_scope; use aws_smithy_xml::decode::Document; use std::convert::TryFrom; @@ -83,8 +79,7 @@ mod test { foo-id "#; assert!(body_is_error(xml).unwrap()); - let parsed = parse_generic_error(xml).expect("valid xml"); - assert_eq!(parsed.request_id(), Some("foo-id")); + let parsed = parse_error_metadata(xml).expect("valid xml").build(); assert_eq!(parsed.message(), Some("Hi")); assert_eq!(parsed.code(), Some("InvalidGreeting")); } diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 77c704ff6bf..6a423bf4348 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,2 +1,2 @@ [toolchain] -channel = "1.62.1" +channel = "1.63.0" diff --git a/tools/ci-build/Dockerfile b/tools/ci-build/Dockerfile index 1b5074ffea6..6a34cda1c70 100644 --- a/tools/ci-build/Dockerfile +++ b/tools/ci-build/Dockerfile @@ -6,12 +6,19 @@ # This is the base Docker build image used by CI ARG base_image=public.ecr.aws/amazonlinux/amazonlinux:2 -ARG rust_stable_version=1.62.1 +ARG rust_stable_version=1.63.0 ARG rust_nightly_version=nightly-2022-11-16 FROM ${base_image} AS bare_base_image RUN yum -y updateinfo +FROM bare_base_image as musl_toolchain +RUN yum -y install tar gzip gcc make +RUN curl https://musl.libc.org/releases/musl-1.2.3.tar.gz -o musl-1.2.3.tar.gz \ + && ls \ + && tar xvzf musl-1.2.3.tar.gz \ + && (cd musl-1.2.3 && ./configure && make install) + # # Rust & Tools Installation Stage # @@ -51,6 +58,7 @@ RUN set -eux; \ rustup component add rustfmt; \ rustup component add clippy; \ rustup toolchain install ${rust_nightly_version} --component clippy; \ + rustup target add x86_64-unknown-linux-musl; \ cargo --version; \ cargo +${rust_nightly_version} --version; @@ -106,6 +114,8 @@ ARG maturin_version=0.14.1 ARG rust_nightly_version RUN cargo +${rust_nightly_version} -Z sparse-registry install maturin --locked --version ${maturin_version} + + # # Final image # @@ -138,6 +148,8 @@ COPY --chown=build:build --from=cargo_minimal_versions /opt/cargo/bin/cargo-mini COPY --chown=build:build --from=cargo_check_external_types /opt/cargo/bin/cargo-check-external-types /opt/cargo/bin/cargo-check-external-types COPY --chown=build:build --from=maturin /opt/cargo/bin/maturin /opt/cargo/bin/maturin COPY --chown=build:build --from=install_rust /opt/rustup /opt/rustup +COPY --chown=build:build --from=musl_toolchain /usr/local/musl/ /usr/local/musl/ +ENV PATH=$PATH:/usr/local/musl/bin/ ENV PATH=/opt/cargo/bin:$PATH \ CARGO_HOME=/opt/cargo \ RUSTUP_HOME=/opt/rustup \ @@ -154,6 +166,7 @@ ENV PATH=/opt/cargo/bin:$PATH \ # This is used primarily by the `build.gradle.kts` files in choosing how to execute build tools. If inside the image, # they will assume the tools are on the PATH, but if outside of the image, they will `cargo run` the tools. ENV SMITHY_RS_DOCKER_BUILD_IMAGE=1 +RUN pip3 install --no-cache-dir mypy==0.991 WORKDIR /home/build COPY sanity-test /home/build/sanity-test RUN /home/build/sanity-test diff --git a/tools/ci-build/changelogger/Cargo.toml b/tools/ci-build/changelogger/Cargo.toml index fd4d4e5c8c6..8dd09363652 100644 --- a/tools/ci-build/changelogger/Cargo.toml +++ b/tools/ci-build/changelogger/Cargo.toml @@ -26,5 +26,5 @@ time = { version = "0.3.9", features = ["local-offset"]} toml = "0.5.8" [dev-dependencies] -pretty_assertions = "1.2.1" +pretty_assertions = "1.3" tempfile = "3.3.0" diff --git a/tools/ci-build/changelogger/src/main.rs b/tools/ci-build/changelogger/src/main.rs index db36023ff35..5e29945e349 100644 --- a/tools/ci-build/changelogger/src/main.rs +++ b/tools/ci-build/changelogger/src/main.rs @@ -65,6 +65,7 @@ mod tests { source_to_truncate: PathBuf::from("fromplace"), changelog_output: PathBuf::from("some-changelog"), release_manifest_output: Some(PathBuf::from("some-manifest")), + current_release_versions_manifest: None, previous_release_versions_manifest: None, date_override: None, smithy_rs_location: None, @@ -97,6 +98,7 @@ mod tests { source_to_truncate: PathBuf::from("fromplace"), changelog_output: PathBuf::from("some-changelog"), release_manifest_output: None, + current_release_versions_manifest: None, previous_release_versions_manifest: None, date_override: None, smithy_rs_location: None, @@ -127,6 +129,7 @@ mod tests { source_to_truncate: PathBuf::from("fromplace"), changelog_output: PathBuf::from("some-changelog"), release_manifest_output: None, + current_release_versions_manifest: None, previous_release_versions_manifest: Some(PathBuf::from("path/to/versions.toml")), date_override: None, smithy_rs_location: None, @@ -148,5 +151,42 @@ mod tests { ]) .unwrap() ); + + assert_eq!( + Args::Render(RenderArgs { + change_set: ChangeSet::AwsSdk, + independent_versioning: true, + source: vec![PathBuf::from("fromplace")], + source_to_truncate: PathBuf::from("fromplace"), + changelog_output: PathBuf::from("some-changelog"), + release_manifest_output: None, + current_release_versions_manifest: Some(PathBuf::from( + "path/to/current/versions.toml" + )), + previous_release_versions_manifest: Some(PathBuf::from( + "path/to/previous/versions.toml" + )), + date_override: None, + smithy_rs_location: None, + }), + Args::try_parse_from([ + "./changelogger", + "render", + "--change-set", + "aws-sdk", + "--independent-versioning", + "--source", + "fromplace", + "--source-to-truncate", + "fromplace", + "--changelog-output", + "some-changelog", + "--current-release-versions-manifest", + "path/to/current/versions.toml", + "--previous-release-versions-manifest", + "path/to/previous/versions.toml" + ]) + .unwrap() + ); } } diff --git a/tools/ci-build/changelogger/src/render.rs b/tools/ci-build/changelogger/src/render.rs index 90c26b63d5c..4dd8e22c1e4 100644 --- a/tools/ci-build/changelogger/src/render.rs +++ b/tools/ci-build/changelogger/src/render.rs @@ -13,9 +13,10 @@ use smithy_rs_tool_common::changelog::{ Changelog, HandAuthoredEntry, Reference, SdkModelChangeKind, SdkModelEntry, }; use smithy_rs_tool_common::git::{find_git_repository_root, Git, GitCLI}; +use smithy_rs_tool_common::versions_manifest::{CrateVersionMetadataMap, VersionsManifest}; use std::env; use std::fmt::Write; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use time::OffsetDateTime; pub const EXAMPLE_ENTRY: &str = r#" @@ -67,6 +68,10 @@ pub struct RenderArgs { /// Optional path to output a release manifest file to #[clap(long, action)] pub release_manifest_output: Option, + /// Optional path to the SDK's versions.toml file for the current release. + /// This is used to generate a markdown table showing crate versions. + #[clap(long, action)] + pub current_release_versions_manifest: Option, /// Optional path to the SDK's versions.toml file for the previous release. /// This is used to filter out changelog entries that have `since_commit` information. #[clap(long, action)] @@ -217,6 +222,16 @@ fn indented_message(message: &str) -> String { out } +fn render_table_row(columns: [&str; 2], out: &mut String) { + let mut row = "|".to_owned(); + for column in columns { + row.push_str(column); + row.push('|'); + } + write!(out, "{row}").unwrap(); + out.push('\n'); +} + fn load_changelogs(args: &RenderArgs) -> Result { let mut combined = Changelog::new(); for source in &args.source { @@ -233,6 +248,19 @@ fn load_changelogs(args: &RenderArgs) -> Result { Ok(combined) } +fn load_current_crate_version_metadata_map( + current_release_versions_manifest: Option<&Path>, +) -> CrateVersionMetadataMap { + current_release_versions_manifest + .and_then( + |manifest_path| match VersionsManifest::from_file(manifest_path) { + Ok(manifest) => Some(manifest.crates), + Err(_) => None, + }, + ) + .unwrap_or_default() +} + fn update_changelogs( args: &RenderArgs, smithy_rs: &dyn Git, @@ -250,7 +278,13 @@ fn update_changelogs( args.change_set, args.previous_release_versions_manifest.as_deref(), )?; - let (release_header, release_notes) = render(&entries, &release_metadata.title); + let current_crate_version_metadata_map = + load_current_crate_version_metadata_map(args.current_release_versions_manifest.as_deref()); + let (release_header, release_notes) = render( + &entries, + current_crate_version_metadata_map, + &release_metadata.title, + ); if let Some(output_path) = &args.release_manifest_output { let release_manifest = ReleaseManifest { tag_name: release_metadata.tag.clone(), @@ -329,9 +363,94 @@ fn render_sdk_model_entries<'a>( } } -/// Convert a list of changelog entries into markdown. +fn render_external_contributors(entries: &[ChangelogEntry], out: &mut String) { + let mut external_contribs = entries + .iter() + .filter_map(|entry| entry.hand_authored().map(|e| &e.author)) + .filter(|author| !is_maintainer(author)) + .collect::>(); + if external_contribs.is_empty() { + return; + } + external_contribs.sort(); + external_contribs.dedup(); + out.push_str("**Contributors**\nThank you for your contributions! ❤\n"); + for contributor_handle in external_contribs { + // retrieve all contributions this author made + let mut contribution_references = entries + .iter() + .filter(|entry| { + entry + .hand_authored() + .map(|e| e.author.eq_ignore_ascii_case(contributor_handle.as_str())) + .unwrap_or(false) + }) + .flat_map(|entry| { + entry + .hand_authored() + .unwrap() + .references + .iter() + .map(to_md_link) + }) + .collect::>(); + contribution_references.sort(); + contribution_references.dedup(); + let contribution_references = contribution_references.as_slice().join(", "); + out.push_str("- @"); + out.push_str(contributor_handle); + if !contribution_references.is_empty() { + write!(out, " ({})", contribution_references) + // The `Write` implementation for `String` is infallible, + // see https://doc.rust-lang.org/src/alloc/string.rs.html#2815 + .unwrap() + } + out.push('\n'); + } + out.push('\n'); +} + +fn render_details(summary: &str, body: &str, out: &mut String) { + out.push_str("
"); + out.push('\n'); + write!(out, "{}", summary).unwrap(); + out.push('\n'); + // A blank line is required for the body to be rendered properly + out.push('\n'); + out.push_str(body); + out.push_str("
"); + out.push('\n'); +} + +fn render_crate_versions(crate_version_metadata_map: CrateVersionMetadataMap, out: &mut String) { + if crate_version_metadata_map.is_empty() { + // If the map is empty, we choose to not render anything, as opposed to + // rendering the
element with empty contents and a user toggling + // it only to find out there is nothing in it. + return; + } + + out.push_str("**Crate Versions**"); + out.push('\n'); + + let mut table = String::new(); + render_table_row(["Crate", "Version"], &mut table); + render_table_row(["-", "-"], &mut table); + for (crate_name, version_metadata) in &crate_version_metadata_map { + render_table_row([crate_name, &version_metadata.version], &mut table); + } + + render_details("Click to expand to view crate versions...", &table, out); + out.push('\n'); +} + +/// Convert a list of changelog entries and crate versions into markdown. /// Returns (header, body) -fn render(entries: &[ChangelogEntry], release_header: &str) -> (String, String) { +fn render( + entries: &[ChangelogEntry], + crate_version_metadata_map: CrateVersionMetadataMap, + release_header: &str, +) -> (String, String) { let mut header = String::new(); header.push_str(release_header); header.push('\n'); @@ -349,49 +468,8 @@ fn render(entries: &[ChangelogEntry], release_header: &str) -> (String, String) entries.iter().filter_map(ChangelogEntry::aws_sdk_model), &mut out, ); - - let mut external_contribs = entries - .iter() - .filter_map(|entry| entry.hand_authored().map(|e| &e.author)) - .filter(|author| !is_maintainer(author)) - .collect::>(); - external_contribs.sort(); - external_contribs.dedup(); - if !external_contribs.is_empty() { - out.push_str("**Contributors**\nThank you for your contributions! ❤\n"); - for contributor_handle in external_contribs { - // retrieve all contributions this author made - let mut contribution_references = entries - .iter() - .filter(|entry| { - entry - .hand_authored() - .map(|e| e.author.eq_ignore_ascii_case(contributor_handle.as_str())) - .unwrap_or(false) - }) - .flat_map(|entry| { - entry - .hand_authored() - .unwrap() - .references - .iter() - .map(to_md_link) - }) - .collect::>(); - contribution_references.sort(); - contribution_references.dedup(); - let contribution_references = contribution_references.as_slice().join(", "); - out.push_str("- @"); - out.push_str(contributor_handle); - if !contribution_references.is_empty() { - write!(&mut out, " ({})", contribution_references) - // The `Write` implementation for `String` is infallible, - // see https://doc.rust-lang.org/src/alloc/string.rs.html#2815 - .unwrap() - } - out.push('\n'); - } - } + render_external_contributors(entries, &mut out); + render_crate_versions(crate_version_metadata_map, &mut out); (header, out) } @@ -399,11 +477,15 @@ fn render(entries: &[ChangelogEntry], release_header: &str) -> (String, String) #[cfg(test)] mod test { use super::{date_based_release_metadata, render, Changelog, ChangelogEntries, ChangelogEntry}; - use smithy_rs_tool_common::changelog::SdkAffected; + use smithy_rs_tool_common::{ + changelog::SdkAffected, + package::PackageCategory, + versions_manifest::{CrateVersion, CrateVersionMetadataMap}, + }; use time::OffsetDateTime; fn render_full(entries: &[ChangelogEntry], release_header: &str) -> String { - let (header, body) = render(entries, release_header); + let (header, body) = render(entries, CrateVersionMetadataMap::new(), release_header); format!("{header}{body}") } @@ -494,6 +576,7 @@ v0.3.0 (January 4th, 2022) Thank you for your contributions! ❤ - @another-contrib ([smithy-rs#200](https://github.com/awslabs/smithy-rs/issues/200)) - @external-contrib ([smithy-rs#446](https://github.com/awslabs/smithy-rs/issues/446)) + "# .trim_start(); pretty_assertions::assert_str_eq!(smithy_rs_expected, smithy_rs_rendered); @@ -518,6 +601,7 @@ v0.1.0 (January 4th, 2022) **Contributors** Thank you for your contributions! ❤ - @external-contrib ([smithy-rs#446](https://github.com/awslabs/smithy-rs/issues/446)) + "# .trim_start(); pretty_assertions::assert_str_eq!(aws_sdk_expected, aws_sdk_rust_rendered); @@ -592,9 +676,69 @@ author = "rcoh" #[test] fn test_empty_render() { let smithy_rs = Vec::::new(); - let (release_title, release_notes) = render(&smithy_rs, "some header"); + let (release_title, release_notes) = + render(&smithy_rs, CrateVersionMetadataMap::new(), "some header"); assert_eq!(release_title, "some header\n===========\n"); assert_eq!(release_notes, ""); } + + #[test] + fn test_crate_versions() { + let mut crate_version_metadata_map = CrateVersionMetadataMap::new(); + crate_version_metadata_map.insert( + "aws-config".to_owned(), + CrateVersion { + category: PackageCategory::AwsRuntime, + version: "0.54.1".to_owned(), + source_hash: "e93380cfbd05e68d39801cbf0113737ede552a5eceb28f4c34b090048d539df9" + .to_owned(), + model_hash: None, + }, + ); + crate_version_metadata_map.insert( + "aws-sdk-accessanalyzer".to_owned(), + CrateVersion { + category: PackageCategory::AwsSdk, + version: "0.24.0".to_owned(), + source_hash: "a7728756b41b33d02f68a5865d3456802b7bc3949ec089790bc4e726c0de8539" + .to_owned(), + model_hash: Some( + "71f1f130504ebd55396c3166d9441513f97e49b281a5dd420fd7e2429860b41b".to_owned(), + ), + }, + ); + crate_version_metadata_map.insert( + "aws-smithy-async".to_owned(), + CrateVersion { + category: PackageCategory::SmithyRuntime, + version: "0.54.1".to_owned(), + source_hash: "8ced52afc783cbb0df47ee8b55260b98e9febdc95edd796ed14c43db5199b0a9" + .to_owned(), + model_hash: None, + }, + ); + let (release_title, release_notes) = render( + &Vec::::new(), + crate_version_metadata_map, + "some header", + ); + + assert_eq!(release_title, "some header\n===========\n"); + let expected_body = r#" +**Crate Versions** +
+Click to expand to view crate versions... + +|Crate|Version| +|-|-| +|aws-config|0.54.1| +|aws-sdk-accessanalyzer|0.24.0| +|aws-smithy-async|0.54.1| +
+ +"# + .trim_start(); + pretty_assertions::assert_str_eq!(release_notes, expected_body); + } } diff --git a/tools/ci-build/changelogger/tests/e2e_test.rs b/tools/ci-build/changelogger/tests/e2e_test.rs index 07dc64a8391..2431bad8a4b 100644 --- a/tools/ci-build/changelogger/tests/e2e_test.rs +++ b/tools/ci-build/changelogger/tests/e2e_test.rs @@ -45,6 +45,37 @@ const SDK_MODEL_SOURCE_TOML: &str = r#" message = "Some API change" "#; +const VERSIONS_TOML: &str = r#" + smithy_rs_revision = '41ca31b85b4ba8c0ad680fe62a230266cc52cc44' + aws_doc_sdk_examples_revision = '97a177aab8c3d2fef97416cb66e4b4d0da840138' + + [manual_interventions] + crates_to_remove = [] + [crates.aws-config] + category = 'AwsRuntime' + version = '0.54.1' + source_hash = 'e93380cfbd05e68d39801cbf0113737ede552a5eceb28f4c34b090048d539df9' + + [crates.aws-sdk-accessanalyzer] + category = 'AwsSdk' + version = '0.24.0' + source_hash = 'a7728756b41b33d02f68a5865d3456802b7bc3949ec089790bc4e726c0de8539' + model_hash = '71f1f130504ebd55396c3166d9441513f97e49b281a5dd420fd7e2429860b41b' + + [crates.aws-smithy-async] + category = 'SmithyRuntime' + version = '0.54.1' + source_hash = '8ced52afc783cbb0df47ee8b55260b98e9febdc95edd796ed14c43db5199b0a9' + + [release] + tag = 'release-2023-01-26' + + [release.crates] + aws-config = "0.54.1" + aws-sdk-accessanalyzer = '0.24.0' + aws-smithy-async = '0.54.1' +"#; + fn create_fake_repo_root( path: &Path, smithy_rs_version: &str, @@ -98,7 +129,7 @@ fn create_fake_repo_root( } #[test] -fn split_aws_sdk_test() { +fn split_aws_sdk() { let tmp_dir = TempDir::new().unwrap(); let source_path = tmp_dir.path().join("source.toml"); let dest_path = tmp_dir.path().join("dest.toml"); @@ -226,7 +257,7 @@ fn split_aws_sdk_test() { } #[test] -fn render_smithy_rs_test() { +fn render_smithy_rs() { let tmp_dir = TempDir::new().unwrap(); let source_path = tmp_dir.path().join("source.toml"); let dest_path = tmp_dir.path().join("dest.md"); @@ -253,6 +284,7 @@ fn render_smithy_rs_test() { changelog_output: dest_path.clone(), release_manifest_output: Some(tmp_dir.path().into()), date_override: Some(OffsetDateTime::UNIX_EPOCH), + current_release_versions_manifest: None, previous_release_versions_manifest: None, smithy_rs_location: Some(tmp_dir.path().into()), }) @@ -274,6 +306,7 @@ January 1st, 1970 Thank you for your contributions! ❤ - @another-dev ([smithy-rs#1234](https://github.com/awslabs/smithy-rs/issues/1234)) + v0.41.0 (Some date in the past) ========= @@ -285,7 +318,7 @@ Old entry contents r#"{ "tagName": "release-1970-01-01", "name": "January 1st, 1970", - "body": "**New this release:**\n- (all, [smithy-rs#1234](https://github.com/awslabs/smithy-rs/issues/1234), @another-dev) Another change\n\n**Contributors**\nThank you for your contributions! ❤\n- @another-dev ([smithy-rs#1234](https://github.com/awslabs/smithy-rs/issues/1234))\n", + "body": "**New this release:**\n- (all, [smithy-rs#1234](https://github.com/awslabs/smithy-rs/issues/1234), @another-dev) Another change\n\n**Contributors**\nThank you for your contributions! ❤\n- @another-dev ([smithy-rs#1234](https://github.com/awslabs/smithy-rs/issues/1234))\n\n", "prerelease": true }"#, release_manifest @@ -293,13 +326,13 @@ Old entry contents } #[test] -fn render_aws_sdk_test() { +fn render_aws_sdk() { let tmp_dir = TempDir::new().unwrap(); let source1_path = tmp_dir.path().join("source1.toml"); let source2_path = tmp_dir.path().join("source2.toml"); let dest_path = tmp_dir.path().join("dest.md"); let release_manifest_path = tmp_dir.path().join("aws-sdk-rust-release-manifest.json"); - let versions_manifest_path = tmp_dir.path().join("versions.toml"); + let previous_versions_manifest_path = tmp_dir.path().join("versions.toml"); let (release_1_commit, release_2_commit) = create_fake_repo_root(tmp_dir.path(), "0.42.0", "0.12.0"); @@ -322,7 +355,7 @@ fn render_aws_sdk_test() { .unwrap(); fs::write(&release_manifest_path, "overwrite-me").unwrap(); fs::write( - &versions_manifest_path, + &previous_versions_manifest_path, format!( "smithy_rs_revision = '{release_1_commit}' aws_doc_sdk_examples_revision = 'not-relevant' @@ -339,7 +372,8 @@ fn render_aws_sdk_test() { changelog_output: dest_path.clone(), release_manifest_output: Some(tmp_dir.path().into()), date_override: Some(OffsetDateTime::UNIX_EPOCH), - previous_release_versions_manifest: Some(versions_manifest_path), + current_release_versions_manifest: None, + previous_release_versions_manifest: Some(previous_versions_manifest_path), smithy_rs_location: Some(tmp_dir.path().into()), }) .unwrap(); @@ -368,6 +402,7 @@ January 1st, 1970 Thank you for your contributions! ❤ - @test-dev ([aws-sdk-rust#234](https://github.com/awslabs/aws-sdk-rust/issues/234), [smithy-rs#567](https://github.com/awslabs/smithy-rs/issues/567)) + v0.41.0 (Some date in the past) ========= @@ -379,7 +414,7 @@ Old entry contents r#"{ "tagName": "release-1970-01-01", "name": "January 1st, 1970", - "body": "**New this release:**\n- 🐛 ([aws-sdk-rust#234](https://github.com/awslabs/aws-sdk-rust/issues/234), [smithy-rs#567](https://github.com/awslabs/smithy-rs/issues/567), @test-dev) Some other change\n\n**Service Features:**\n- `aws-sdk-ec2` (0.12.0): Some API change\n\n**Contributors**\nThank you for your contributions! ❤\n- @test-dev ([aws-sdk-rust#234](https://github.com/awslabs/aws-sdk-rust/issues/234), [smithy-rs#567](https://github.com/awslabs/smithy-rs/issues/567))\n", + "body": "**New this release:**\n- 🐛 ([aws-sdk-rust#234](https://github.com/awslabs/aws-sdk-rust/issues/234), [smithy-rs#567](https://github.com/awslabs/smithy-rs/issues/567), @test-dev) Some other change\n\n**Service Features:**\n- `aws-sdk-ec2` (0.12.0): Some API change\n\n**Contributors**\nThank you for your contributions! ❤\n- @test-dev ([aws-sdk-rust#234](https://github.com/awslabs/aws-sdk-rust/issues/234), [smithy-rs#567](https://github.com/awslabs/smithy-rs/issues/567))\n\n", "prerelease": true }"#, release_manifest @@ -460,6 +495,7 @@ author = "LukeMathWalker" changelog_output: dest_path.clone(), release_manifest_output: Some(tmp_dir.path().into()), date_override: Some(OffsetDateTime::UNIX_EPOCH), + current_release_versions_manifest: None, previous_release_versions_manifest: None, smithy_rs_location: Some(tmp_dir.path().into()), }) @@ -487,6 +523,7 @@ Thank you for your contributions! ❤ - @another-dev ([smithy-rs#2](https://github.com/awslabs/smithy-rs/issues/2)) - @server-dev ([smithy-rs#1](https://github.com/awslabs/smithy-rs/issues/1)) + v0.41.0 (Some date in the past) ========= @@ -569,6 +606,7 @@ author = "rcoh" changelog_output: dest_path, release_manifest_output: Some(tmp_dir.path().into()), date_override: Some(OffsetDateTime::UNIX_EPOCH), + current_release_versions_manifest: None, previous_release_versions_manifest: None, smithy_rs_location: Some(tmp_dir.path().into()), }); @@ -582,3 +620,86 @@ author = "rcoh" panic!("This should have been error that aws-sdk-rust has a target entry"); } } + +#[test] +fn render_crate_versions() { + let tmp_dir = TempDir::new().unwrap(); + let source_path = tmp_dir.path().join("source.toml"); + let dest_path = tmp_dir.path().join("dest.md"); + let release_manifest_path = tmp_dir.path().join("smithy-rs-release-manifest.json"); + let current_versions_manifest_path = tmp_dir.path().join("versions.toml"); + + create_fake_repo_root(tmp_dir.path(), "0.54.1", "0.24.0"); + + fs::write(&source_path, SOURCE_TOML).unwrap(); + fs::write( + &dest_path, + format!( + "{}\nv0.54.0 (Some date in the past)\n=========\n\nOld entry contents\n", + USE_UPDATE_CHANGELOGS + ), + ) + .unwrap(); + fs::write(&release_manifest_path, "overwrite-me").unwrap(); + fs::write(¤t_versions_manifest_path, VERSIONS_TOML).unwrap(); + + subcommand_render(&RenderArgs { + change_set: ChangeSet::SmithyRs, + independent_versioning: true, + source: vec![source_path.clone()], + source_to_truncate: source_path.clone(), + changelog_output: dest_path.clone(), + release_manifest_output: Some(tmp_dir.path().into()), + date_override: Some(OffsetDateTime::UNIX_EPOCH), + current_release_versions_manifest: Some(current_versions_manifest_path), + previous_release_versions_manifest: None, + smithy_rs_location: Some(tmp_dir.path().into()), + }) + .unwrap(); + + let source = fs::read_to_string(&source_path).unwrap(); + let dest = fs::read_to_string(&dest_path).unwrap(); + let release_manifest = fs::read_to_string(&release_manifest_path).unwrap(); + + // source file should be empty + pretty_assertions::assert_str_eq!(EXAMPLE_ENTRY.trim(), source); + pretty_assertions::assert_str_eq!( + r#" +January 1st, 1970 +================= +**New this release:** +- (all, [smithy-rs#1234](https://github.com/awslabs/smithy-rs/issues/1234), @another-dev) Another change + +**Contributors** +Thank you for your contributions! ❤ +- @another-dev ([smithy-rs#1234](https://github.com/awslabs/smithy-rs/issues/1234)) + +**Crate Versions** +
+Click to expand to view crate versions... + +|Crate|Version| +|-|-| +|aws-config|0.54.1| +|aws-sdk-accessanalyzer|0.24.0| +|aws-smithy-async|0.54.1| +
+ + +v0.54.0 (Some date in the past) +========= + +Old entry contents +"#, + dest + ); + pretty_assertions::assert_str_eq!( + r#"{ + "tagName": "release-1970-01-01", + "name": "January 1st, 1970", + "body": "**New this release:**\n- (all, [smithy-rs#1234](https://github.com/awslabs/smithy-rs/issues/1234), @another-dev) Another change\n\n**Contributors**\nThank you for your contributions! ❤\n- @another-dev ([smithy-rs#1234](https://github.com/awslabs/smithy-rs/issues/1234))\n\n**Crate Versions**\n
\nClick to expand to view crate versions...\n\n|Crate|Version|\n|-|-|\n|aws-config|0.54.1|\n|aws-sdk-accessanalyzer|0.24.0|\n|aws-smithy-async|0.54.1|\n
\n\n", + "prerelease": true +}"#, + release_manifest + ); +} diff --git a/tools/ci-build/crate-hasher/Cargo.lock b/tools/ci-build/crate-hasher/Cargo.lock index df4ff9f4ffe..c6f1912dd28 100644 --- a/tools/ci-build/crate-hasher/Cargo.lock +++ b/tools/ci-build/crate-hasher/Cargo.lock @@ -504,8 +504,8 @@ version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e334db67871c14c18fc066ad14af13f9fdf5f9a91c61af432d1e3a39c8c6a141" dependencies = [ - "hex", - "sha2", + "hex", + "sha2", ] [[package]] diff --git a/tools/ci-build/crate-hasher/Cargo.toml b/tools/ci-build/crate-hasher/Cargo.toml index 3456cdff456..3f9376443a2 100644 --- a/tools/ci-build/crate-hasher/Cargo.toml +++ b/tools/ci-build/crate-hasher/Cargo.toml @@ -22,6 +22,6 @@ sha256 = "1.1" [dev-dependencies] flate2 = "1.0" -pretty_assertions = "1.2" +pretty_assertions = "1.3" tar = "0.4" tempdir = "0.3" diff --git a/tools/ci-build/publisher/Cargo.lock b/tools/ci-build/publisher/Cargo.lock index 98b686679d6..450e0b81b54 100644 --- a/tools/ci-build/publisher/Cargo.lock +++ b/tools/ci-build/publisher/Cargo.lock @@ -4,27 +4,18 @@ version = 3 [[package]] name = "aho-corasick" -version = "0.7.19" +version = "0.7.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4f55bd91a0978cbfd91c457a164bab8b4001c833b7f323132c0a4e1922dd44e" +checksum = "cc936419f96fa211c1b9166887b38e5e40b19958e5b895be7c1f93adec7071ac" dependencies = [ "memchr", ] -[[package]] -name = "ansi_term" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2" -dependencies = [ - "winapi", -] - [[package]] name = "anyhow" -version = "1.0.65" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "98161a4e3e2184da77bb14f02184cdd111e83bbbcc9979dfee3c44b9a85f5602" +checksum = "224afbd727c3d6e4b90103ece64b8d1b67fbb1973b1046c2281eed3f3803f800" [[package]] name = "async-recursion" @@ -39,9 +30,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.57" +version = "0.1.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76464446b8bc32758d7e88ee1a804d9914cd9b1cb264c029899680b0be29826f" +checksum = "1cd7fce9ba8c3c042128ce72d8b2ddbf3a05747efb67ea0313c635e10bda47a2" dependencies = [ "proc-macro2", "quote", @@ -54,7 +45,7 @@ version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" dependencies = [ - "hermit-abi", + "hermit-abi 0.1.19", "libc", "winapi", ] @@ -67,9 +58,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "base64" -version = "0.13.0" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd" +checksum = "a4a4ddaa51a5bc52a6948f74c06d20aaaddb71924eab79b8c97a8c556e942d6a" [[package]] name = "bitflags" @@ -97,15 +88,15 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.11.0" +version = "3.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1ad822118d20d2c234f427000d5acc36eabe1e29a348c89b63dd60b13f28e5d" +checksum = "0d261e256854913907f67ed06efbc3338dfe6179796deefc1ff763fc1aee5535" [[package]] name = "bytes" -version = "1.2.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec8a7b6a70fde80372154c65702f00a0f56f3e1c36abbc6c440484be248856db" +checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" [[package]] name = "cargo_toml" @@ -120,9 +111,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.73" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fff2a6927b3bb87f9595d67196a70493f627687a71d87a0d692242c33f58c11" +checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" [[package]] name = "cfg-if" @@ -132,9 +123,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.22" +version = "0.4.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfd4d1b31faaa3a89d7934dbded3111da0d2ef28e3ebccdb4f0179f5929d1ef1" +checksum = "16b0a3d9ed01224b22057780a37bb8c5dbfe1be8ba48678e7bf57ec4b385411f" dependencies = [ "num-integer", "num-traits", @@ -249,9 +240,9 @@ dependencies = [ [[package]] name = "ctor" -version = "0.1.23" +version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cdffe87e1d521a10f9696f833fe502293ea446d7f256c06128293a4119bdf4cb" +checksum = "6d2301688392eb071b0bf1a37be05c469d3cc4dbbd95df672fe28ab021e6a096" dependencies = [ "quote", "syn", @@ -286,9 +277,9 @@ dependencies = [ [[package]] name = "digest" -version = "0.10.5" +version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adfbc57365a37acbd2ebf2b64d7e69bb766e2fea813521ed536f5d0520dcf86c" +checksum = "8168378f4e5023e7218c89c891c0fd8ecdb5e5e4f18cb78f38cf245dd021e76f" dependencies = [ "block-buffer 0.10.3", "crypto-common", @@ -302,9 +293,9 @@ checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" [[package]] name = "encoding_rs" -version = "0.8.31" +version = "0.8.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9852635589dc9f9ea1b6fe9f05b50ef208c85c834a562f0c6abb1c475736ec2b" +checksum = "071a31f4ee85403370b58aca746f01041ede6f0da2730960ad001edc2b71b394" dependencies = [ "cfg-if", ] @@ -356,9 +347,9 @@ checksum = "0845fa252299212f0389d64ba26f34fa32cfe41588355f21ed507c59a0f64541" [[package]] name = "futures" -version = "0.3.24" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f21eda599937fba36daeb58a22e8f5cee2d14c4a17b5b7739c7c8e5e3b8230c" +checksum = "13e2792b0ff0340399d58445b88fd9770e3489eff258a4cbc1523418f12abf84" dependencies = [ "futures-channel", "futures-core", @@ -371,9 +362,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.24" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30bdd20c28fadd505d0fd6712cdfcb0d4b5648baf45faef7f852afb2399bb050" +checksum = "2e5317663a9089767a1ec00a487df42e0ca174b61b4483213ac24448e4664df5" dependencies = [ "futures-core", "futures-sink", @@ -381,15 +372,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.24" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e5aa3de05362c3fb88de6531e6296e85cde7739cccad4b9dfeeb7f6ebce56bf" +checksum = "ec90ff4d0fe1f57d600049061dc6bb68ed03c7d2fbd697274c41805dcb3f8608" [[package]] name = "futures-executor" -version = "0.3.24" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ff63c23854bee61b6e9cd331d523909f238fc7636290b96826e9cfa5faa00ab" +checksum = "e8de0a35a6ab97ec8869e32a2473f4b1324459e14c29275d14b10cb1fd19b50e" dependencies = [ "futures-core", "futures-task", @@ -398,15 +389,15 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.24" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbf4d2a7a308fd4578637c0b17c7e1c7ba127b8f6ba00b29f717e9655d85eb68" +checksum = "bfb8371b6fb2aeb2d280374607aeabfc99d95c72edfe51692e42d3d7f0d08531" [[package]] name = "futures-macro" -version = "0.3.24" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42cd15d1c7456c04dbdf7e88bcd69760d74f3a798d6444e16974b505b0e62f17" +checksum = "95a73af87da33b5acf53acfebdc339fe592ecf5357ac7c0a7734ab9d8c876a70" dependencies = [ "proc-macro2", "quote", @@ -415,21 +406,21 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.24" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b20ba5a92e727ba30e72834706623d94ac93a725410b6a6b6fbc1b07f7ba56" +checksum = "f310820bb3e8cfd46c80db4d7fb8353e15dfff853a127158425f31e0be6c8364" [[package]] name = "futures-task" -version = "0.3.24" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6508c467c73851293f390476d4491cf4d227dbabcd4170f3bb6044959b294f1" +checksum = "dcf79a1bf610b10f42aea489289c5a2c478a786509693b80cd39c44ccd936366" [[package]] name = "futures-util" -version = "0.3.24" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44fb6cb1be61cc1d2e43b262516aafcf63b241cffdb1d3fa115f91d9c7b09c90" +checksum = "9c1d6de3acfef38d2be4b1f543f553131788603495be83da675e180c8d6b7bd1" dependencies = [ "futures-channel", "futures-core", @@ -455,9 +446,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.3.14" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ca32592cf21ac7ccab1825cd87f6c9b3d9022c44d086172ed0966bec8af30be" +checksum = "5f9f29bc9dda355256b2916cf526ab02ce0aeaaaf2bad60d65ef3f12f11dd0f4" dependencies = [ "bytes", "fnv", @@ -474,9 +465,9 @@ dependencies = [ [[package]] name = "handlebars" -version = "4.3.4" +version = "4.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56b224eaa4987c03c30b251de7ef0c15a6a59f34222905850dbc3026dfb24d5f" +checksum = "035ef95d03713f2c347a72547b7cd38cbc9af7cd51e6099fb62d586d4a6dee3a" dependencies = [ "log", "pest", @@ -494,9 +485,9 @@ checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" [[package]] name = "heck" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2540771e65fc8cb83cd6e8a237f70c319bd5c29f78ed1084ba5d50eeac86f7f9" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" [[package]] name = "hermit-abi" @@ -507,6 +498,15 @@ dependencies = [ "libc", ] +[[package]] +name = "hermit-abi" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee512640fe35acbfb4bb779db6f0d80704c2cacfa2e39b601ef3e3f47d1ae4c7" +dependencies = [ + "libc", +] + [[package]] name = "hex" version = "0.4.3" @@ -549,9 +549,9 @@ checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421" [[package]] name = "hyper" -version = "0.14.20" +version = "0.14.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02c929dc5c39e335a03c405292728118860721b10190d98c2a0f0efd5baafbac" +checksum = "5e011372fa0b68db8350aa7a248930ecc7839bf46d8485577d69f117a75f164c" dependencies = [ "bytes", "futures-channel", @@ -596,9 +596,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "1.9.1" +version = "1.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10a35a97730320ffe8e2d410b5d3b69279b98d2c14bdb8b70ea89ecf7888d41e" +checksum = "1885e79c1fc4b10f0e172c475f458b7f7b93061064d98c3293e98c5ba0c8b399" dependencies = [ "autocfg", "hashbrown", @@ -615,21 +615,21 @@ dependencies = [ [[package]] name = "ipnet" -version = "2.5.0" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "879d54834c8c76457ef4293a689b2a8c59b076067ad77b15efafbb05f92a592b" +checksum = "30e22bd8629359895450b59ea7a776c850561b96a3b1d31321c1949d9e6c9146" [[package]] name = "itoa" -version = "1.0.3" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c8af84674fe1f223a982c933a0ee1086ac4d4052aa0fb8060c12c6ad838e754" +checksum = "fad582f4b9e86b6caa621cabeb0963332d92eea04729ab12892c2533951e6440" [[package]] name = "js-sys" -version = "0.3.60" +version = "0.3.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49409df3e3bf0856b916e2ceaca09ee28e6871cf7d9ce97a692cacfdb2a25a47" +checksum = "445dde2150c55e483f3d8416706b97ec8e8237c307e5b7b4b8dd15e6af2a0730" dependencies = [ "wasm-bindgen", ] @@ -642,9 +642,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.134" +version = "0.2.139" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "329c933548736bc49fd575ee68c89e8be4d260064184389a5b77517cddd99ffb" +checksum = "201de327520df007757c1f0adce6e827fe8562fbc28bfd9c15571c66ca1f5f79" [[package]] name = "lock_api" @@ -688,21 +688,21 @@ checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d" [[package]] name = "mio" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57ee1c23c7c63b0c9250c339ffdc69255f110b298b901b9f6c82547b7b87caaf" +checksum = "e5d732bc30207a6423068df043e3d02e0735b155ad7ce1a6f76fe2baa5b158de" dependencies = [ "libc", "log", "wasi", - "windows-sys", + "windows-sys 0.42.0", ] [[package]] name = "native-tls" -version = "0.2.10" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd7e2f3618557f980e0b17e8856252eee3c97fa12c54dff0ca290fb6266ca4a9" +checksum = "07226173c32f2926027b63cce4bcd8076c3552846cbe7925f3aaffeac0a3b92e" dependencies = [ "lazy_static", "libc", @@ -716,6 +716,16 @@ dependencies = [ "tempfile", ] +[[package]] +name = "nu-ansi-term" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + "overload", + "winapi", +] + [[package]] name = "num-integer" version = "0.1.45" @@ -737,19 +747,19 @@ dependencies = [ [[package]] name = "num_cpus" -version = "1.13.1" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19e64526ebdee182341572e50e9ad03965aa510cd94427a4549448f285e957a1" +checksum = "0fac9e2da13b5eb447a6ce3d392f23a29d8694bff781bf03a16cd9ac8697593b" dependencies = [ - "hermit-abi", + "hermit-abi 0.2.6", "libc", ] [[package]] name = "once_cell" -version = "1.15.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e82dad04139b71a90c080c8463fe0dc7902db5192d939bd0950f074d014339e1" +checksum = "6f61fba1741ea2b3d6a1e3178721804bb716a68a6aeba1149b5d52e3d464ea66" [[package]] name = "opaque-debug" @@ -759,9 +769,9 @@ checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" [[package]] name = "openssl" -version = "0.10.42" +version = "0.10.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12fc0523e3bd51a692c8850d075d74dc062ccf251c0110668cbd921917118a13" +checksum = "b102428fd03bc5edf97f62620f7298614c45cedf287c271e7ed450bbaf83f2e1" dependencies = [ "bitflags", "cfg-if", @@ -791,9 +801,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.76" +version = "0.9.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5230151e44c0f05157effb743e8d517472843121cf9243e8b81393edb5acd9ce" +checksum = "23bbbf7854cd45b83958ebe919f0e8e516793727652e27fda10a8384cfc790b7" dependencies = [ "autocfg", "cc", @@ -804,9 +814,9 @@ dependencies = [ [[package]] name = "os_str_bytes" -version = "6.3.0" +version = "6.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ff7415e9ae3fff1225851df9e0d9e4e5479f947619774677a63572e55e80eff" +checksum = "9b7820b9daea5457c9f21c69448905d723fbd21136ccf521748f23fd49e723ee" [[package]] name = "output_vt100" @@ -817,6 +827,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + [[package]] name = "parking_lot" version = "0.12.1" @@ -829,15 +845,15 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.3" +version = "0.9.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09a279cbf25cb0757810394fbc1e359949b59e348145c643a939a525692e6929" +checksum = "9069cbb9f99e3a5083476ccb29ceb1de18b9118cafa53e90c9551235de2b9521" dependencies = [ "cfg-if", "libc", "redox_syscall", "smallvec", - "windows-sys", + "windows-sys 0.45.0", ] [[package]] @@ -848,9 +864,9 @@ checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" [[package]] name = "pest" -version = "2.4.0" +version = "2.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbc7bc69c062e492337d74d59b120c274fd3d261b6bf6d3207d499b4b379c41a" +checksum = "4ab62d2fa33726dbe6321cc97ef96d8cde531e3eeaf858a058de53a8a6d40d8f" dependencies = [ "thiserror", "ucd-trie", @@ -858,9 +874,9 @@ dependencies = [ [[package]] name = "pest_derive" -version = "2.4.0" +version = "2.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60b75706b9642ebcb34dab3bc7750f811609a0eb1dd8b88c2d15bf628c1c65b2" +checksum = "8bf026e2d0581559db66d837fe5242320f525d85c76283c61f4d51a1238d65ea" dependencies = [ "pest", "pest_generator", @@ -868,9 +884,9 @@ dependencies = [ [[package]] name = "pest_generator" -version = "2.4.0" +version = "2.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4f9272122f5979a6511a749af9db9bfc810393f63119970d7085fed1c4ea0db" +checksum = "2b27bd18aa01d91c8ed2b61ea23406a676b42d82609c6e2581fba42f0c15f17f" dependencies = [ "pest", "pest_meta", @@ -881,13 +897,13 @@ dependencies = [ [[package]] name = "pest_meta" -version = "2.4.0" +version = "2.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c8717927f9b79515e565a64fe46c38b8cd0427e64c40680b14a7365ab09ac8d" +checksum = "9f02b677c1859756359fc9983c2e56a0237f18624a3789528804406b7e915e5d" dependencies = [ "once_cell", "pest", - "sha1", + "sha2 0.10.6", ] [[package]] @@ -904,9 +920,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkg-config" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1df8c4ec4b0627e53bdf214615ad287367e482558cf84b109250b37464dc03ae" +checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160" [[package]] name = "pretty_assertions" @@ -946,9 +962,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.46" +version = "1.0.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94e2ef8dbfc347b10c094890f778ee2e36ca9bb4262e86dc99cd217e35f3470b" +checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6" dependencies = [ "unicode-ident", ] @@ -984,9 +1000,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.21" +version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbe448f377a7d6961e30f5955f9b8d106c3f5e449d493ee1b125c1d43c2b5179" +checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b" dependencies = [ "proc-macro2", ] @@ -1002,9 +1018,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.6.0" +version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c4eb3267174b8c6c2f654116623910a0fef09c4753f8dd83db29c48a0df988b" +checksum = "48aaa5748ba571fb95cd2c85c09f629215d3a6ece942baa100950af03a34f733" dependencies = [ "aho-corasick", "memchr", @@ -1022,9 +1038,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.6.27" +version = "0.6.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3f87b73ce11b1619a3c6332f45341e0047173771e8b8b73f87bfeefb7b56244" +checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848" [[package]] name = "remove_dir_all" @@ -1037,9 +1053,9 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.11.12" +version = "0.11.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "431949c384f4e2ae07605ccaa56d1d9d2ecdb5cadd4f9577ccfab29f2e5149fc" +checksum = "21eed90ec8570952d53b772ecf8f206aa1ec9a3d76b2521c56c42973f2d91ee9" dependencies = [ "base64", "bytes", @@ -1074,18 +1090,17 @@ dependencies = [ [[package]] name = "ryu" -version = "1.0.11" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4501abdff3ae82a1c1b477a17252eb69cee9e66eb915c1abaa4f44d873df9f09" +checksum = "7b4b9743ed687d4b4bcedf9ff5eaa7398495ae14e61cba0a295704edbc7decde" [[package]] name = "schannel" -version = "0.1.20" +version = "0.1.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88d6731146462ea25d9244b2ed5fd1d716d25c52e4d54aa4fb0f3c4e9854dbe2" +checksum = "713cfb06c7059f3588fb8044c0fad1d09e3c01d225e25b9220dbfdcf16dbb1b3" dependencies = [ - "lazy_static", - "windows-sys", + "windows-sys 0.42.0", ] [[package]] @@ -1096,9 +1111,9 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "security-framework" -version = "2.7.0" +version = "2.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2bc1bb97804af6631813c55739f771071e0f2ed33ee20b68c86ec505d906356c" +checksum = "a332be01508d814fed64bf28f798a146d73792121129962fdf335bb3c49a4254" dependencies = [ "bitflags", "core-foundation", @@ -1109,9 +1124,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.6.1" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0160a13a177a45bfb43ce71c01580998474f556ad854dcbca936dd2841a5c556" +checksum = "31c9bb296072e961fcbd8853511dd39c2d8be2deb1e17c6860b1d30732b323b4" dependencies = [ "core-foundation-sys", "libc", @@ -1119,24 +1134,24 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.14" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e25dfac463d778e353db5be2449d1cce89bd6fd23c9f1ea21310ce6e5a1b29c4" +checksum = "58bc9567378fc7690d6b2addae4e60ac2eeea07becb2c64b9f218b53865cba2a" [[package]] name = "serde" -version = "1.0.145" +version = "1.0.152" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "728eb6351430bccb993660dfffc5a72f91ccc1295abaa8ce19b27ebe4f75568b" +checksum = "bb7d1f0d3021d347a83e556fc4683dea2ea09d87bccdf88ff5c12545d89d5efb" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.145" +version = "1.0.152" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81fa1584d3d1bcacd84c277a0dfe21f5b0f6accf4a23d04d4c6d61f1af522b4c" +checksum = "af487d118eecd09402d70a5d72551860e788df87b464af30e5ea6a38c75c541e" dependencies = [ "proc-macro2", "quote", @@ -1145,9 +1160,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.85" +version = "1.0.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e55a28e3aaef9d5ce0506d0a14dbba8054ddc7e499ef522dd8b26859ec9d4a44" +checksum = "7434af0dc1cbd59268aa98b4c22c131c0584d2232f6fb166efb993e2832e896a" dependencies = [ "itoa", "ryu", @@ -1167,37 +1182,37 @@ dependencies = [ ] [[package]] -name = "sha1" -version = "0.10.5" +name = "sha2" +version = "0.9.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3" +checksum = "4d58a1e1bf39749807d89cf2d98ac2dfa0ff1cb3faa38fbb64dd88ac8013d800" dependencies = [ + "block-buffer 0.9.0", "cfg-if", "cpufeatures", - "digest 0.10.5", + "digest 0.9.0", + "opaque-debug", ] [[package]] name = "sha2" -version = "0.9.9" +version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d58a1e1bf39749807d89cf2d98ac2dfa0ff1cb3faa38fbb64dd88ac8013d800" +checksum = "82e6b795fe2e3b1e845bafcb27aa35405c4d47cdfc92af5fc8d3002f76cebdc0" dependencies = [ - "block-buffer 0.9.0", "cfg-if", "cpufeatures", - "digest 0.9.0", - "opaque-debug", + "digest 0.10.6", ] [[package]] name = "sha256" -version = "1.0.3" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e84a7f596c081d359de5e06a83877138bc3c4483591e1af1916e1472e6e146e" +checksum = "e334db67871c14c18fc066ad14af13f9fdf5f9a91c61af432d1e3a39c8c6a141" dependencies = [ "hex", - "sha2", + "sha2 0.9.9", ] [[package]] @@ -1268,9 +1283,9 @@ checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" [[package]] name = "syn" -version = "1.0.101" +version = "1.0.107" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e90cde112c4b9690b8cbe810cba9ddd8bc1d7472e2cae317b69e9438c1cba7d2" +checksum = "1f4064b5b16e03ae50984a5a8ed5d4f8803e6bc1fd170a3cda91a1be4b18e3f5" dependencies = [ "proc-macro2", "quote", @@ -1293,9 +1308,9 @@ dependencies = [ [[package]] name = "termcolor" -version = "1.1.3" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bab24d30b911b2376f3a13cc2cd443142f0c81dda04c118693e35b3835757755" +checksum = "be55cf8942feac5c765c2c993422806843c9a9a45d4d5c407ad6dd2ea95eb9b6" dependencies = [ "winapi-util", ] @@ -1312,24 +1327,24 @@ dependencies = [ [[package]] name = "textwrap" -version = "0.15.1" +version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "949517c0cf1bf4ee812e2e07e08ab448e3ae0d23472aee8a06c985f0c8815b16" +checksum = "b7b3e525a49ec206798b40326a44121291b530c963cfb01018f63e135bac543d" [[package]] name = "thiserror" -version = "1.0.37" +version = "1.0.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10deb33631e3c9018b9baf9dcbbc4f737320d2b576bac10f6aefa048fa407e3e" +checksum = "6a9cd18aa97d5c45c6603caea1da6628790b37f7a34b6ca89522331c5180fed0" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.37" +version = "1.0.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "982d17546b47146b28f7c22e3d08465f6b8903d0ea13c1660d9d84a6e7adcdbb" +checksum = "1fb327af4685e4d03fa8cbcf1716380da910eeb2bb8be417e7f9fd3fb164f36f" dependencies = [ "proc-macro2", "quote", @@ -1356,15 +1371,15 @@ dependencies = [ [[package]] name = "tinyvec_macros" -version = "0.1.0" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.21.2" +version = "1.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9e03c497dc955702ba729190dc4aac6f2a0ce97f913e5b1b5912fc5039d9099" +checksum = "c8e00990ebabbe4c14c08aca901caed183ecd5c09562a12c824bb53d3c3fd3af" dependencies = [ "autocfg", "bytes", @@ -1377,14 +1392,14 @@ dependencies = [ "signal-hook-registry", "socket2", "tokio-macros", - "winapi", + "windows-sys 0.42.0", ] [[package]] name = "tokio-macros" -version = "1.8.0" +version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9724f9a975fb987ef7a3cd9be0350edcbe130698af5b8f7a631e23d42d052484" +checksum = "d266c00fde287f55d3f1c3e96c500c362a2b8c695076ec180f27918820bc6df8" dependencies = [ "proc-macro2", "quote", @@ -1417,9 +1432,9 @@ dependencies = [ [[package]] name = "toml" -version = "0.5.9" +version = "0.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d82e1a7758622a465f8cee077614c73484dac5b836c02ff6a40d5d1010324d7" +checksum = "f4f7f0dd8d50a853a531c426359045b1998f04219d88799810762cd4ad314234" dependencies = [ "indexmap", "serde", @@ -1433,9 +1448,9 @@ checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" [[package]] name = "tracing" -version = "0.1.36" +version = "0.1.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fce9567bd60a67d08a16488756721ba392f24f29006402881e43b19aac64307" +checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" dependencies = [ "cfg-if", "pin-project-lite", @@ -1445,9 +1460,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.22" +version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11c75893af559bc8e10716548bdef5cb2b983f8e637db9d0e15126b61b484ee2" +checksum = "4017f8f45139870ca7e672686113917c71c7a6e02d4924eda67186083c03081a" dependencies = [ "proc-macro2", "quote", @@ -1456,9 +1471,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.29" +version = "0.1.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aeea4303076558a00714b823f9ad67d58a3bbda1df83d8827d21193156e22f7" +checksum = "24eb03ba0eab1fd845050058ce5e616558e8f8d8fca633e6b163fe25c797213a" dependencies = [ "once_cell", "valuable", @@ -1477,12 +1492,12 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.15" +version = "0.3.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60db860322da191b40952ad9affe65ea23e7dd6a5c442c2c42865810c6ab8e6b" +checksum = "a6176eae26dd70d0c919749377897b54a9276bd7061339665dd68777926b5a70" dependencies = [ - "ansi_term", "matchers", + "nu-ansi-term", "once_cell", "regex", "sharded-slab", @@ -1495,15 +1510,15 @@ dependencies = [ [[package]] name = "try-lock" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642" +checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" [[package]] name = "typenum" -version = "1.15.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987" +checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" [[package]] name = "ucd-trie" @@ -1513,15 +1528,15 @@ checksum = "9e79c4d996edb816c91e4308506774452e55e95c3c9de07b6729e17e15a5ef81" [[package]] name = "unicode-bidi" -version = "0.3.8" +version = "0.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "099b7128301d285f79ddd55b9a83d5e6b9e97c92e0ea0daebee7263e932de992" +checksum = "d54675592c1dbefd78cbd98db9bacd89886e1ca50692a0692baefffdeb92dd58" [[package]] name = "unicode-ident" -version = "1.0.4" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcc811dc4066ac62f84f11307873c4850cb653bfa9b1719cee2bd2204a4bc5dd" +checksum = "84a22b9f218b40614adcb3f4ff08b703773ad44fa9423e4e0d346d5db86e4ebc" [[package]] name = "unicode-normalization" @@ -1585,9 +1600,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.83" +version = "0.2.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eaf9f5aceeec8be17c128b2e93e031fb8a4d469bb9c4ae2d7dc1888b26887268" +checksum = "31f8dcbc21f30d9b8f2ea926ecb58f6b91192c17e9d33594b3df58b2007ca53b" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -1595,9 +1610,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.83" +version = "0.2.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c8ffb332579b0557b52d268b91feab8df3615f265d5270fec2a8c95b17c1142" +checksum = "95ce90fd5bcc06af55a641a86428ee4229e44e07033963a2290a8e241607ccb9" dependencies = [ "bumpalo", "log", @@ -1610,9 +1625,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.33" +version = "0.4.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23639446165ca5a5de86ae1d8896b737ae80319560fbaa4c2887b7da6e7ebd7d" +checksum = "f219e0d211ba40266969f6dbdd90636da12f75bee4fc9d6c23d1260dadb51454" dependencies = [ "cfg-if", "js-sys", @@ -1622,9 +1637,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.83" +version = "0.2.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "052be0f94026e6cbc75cdefc9bae13fd6052cdcaf532fa6c45e7ae33a1e6c810" +checksum = "4c21f77c0bedc37fd5dc21f897894a5ca01e7bb159884559461862ae90c0b4c5" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -1632,9 +1647,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.83" +version = "0.2.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07bc0c051dc5f23e307b13285f9d75df86bfdf816c5721e573dec1f9b8aa193c" +checksum = "2aff81306fcac3c7515ad4e177f521b5c9a15f2b08f4e32d823066102f35a5f6" dependencies = [ "proc-macro2", "quote", @@ -1645,15 +1660,15 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.83" +version = "0.2.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c38c045535d93ec4f0b4defec448e4291638ee608530863b1e2ba115d4fff7f" +checksum = "0046fef7e28c3804e5e38bfa31ea2a0f73905319b677e57ebe37e49358989b5d" [[package]] name = "web-sys" -version = "0.3.60" +version = "0.3.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcda906d8be16e728fd5adc5b729afad4e444e106ab28cd1c7256e54fa61510f" +checksum = "e33b99f4b23ba3eec1a53ac264e35a755f00e966e0065077d6027c0f575b0b97" dependencies = [ "js-sys", "wasm-bindgen", @@ -1692,46 +1707,84 @@ checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] name = "windows-sys" -version = "0.36.1" +version = "0.42.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea04155a16a59f9eab786fe12a4a450e75cdb175f9e0d80da1e17db09f55b8d2" +checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" dependencies = [ + "windows_aarch64_gnullvm", "windows_aarch64_msvc", "windows_i686_gnu", "windows_i686_msvc", "windows_x86_64_gnu", + "windows_x86_64_gnullvm", "windows_x86_64_msvc", ] +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.42.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e2522491fbfcd58cc84d47aeb2958948c4b8982e9a2d8a2a35bbaed431390e7" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c9864e83243fdec7fc9c5444389dcbbfd258f745e7853198f365e3c4968a608" + [[package]] name = "windows_aarch64_msvc" -version = "0.36.1" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9bb8c3fd39ade2d67e9874ac4f3db21f0d710bee00fe7cab16949ec184eeaa47" +checksum = "4c8b1b673ffc16c47a9ff48570a9d85e25d265735c503681332589af6253c6c7" [[package]] name = "windows_i686_gnu" -version = "0.36.1" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "180e6ccf01daf4c426b846dfc66db1fc518f074baa793aa7d9b9aaeffad6a3b6" +checksum = "de3887528ad530ba7bdbb1faa8275ec7a1155a45ffa57c37993960277145d640" [[package]] name = "windows_i686_msvc" -version = "0.36.1" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2e7917148b2812d1eeafaeb22a97e4813dfa60a3f8f78ebe204bcc88f12f024" +checksum = "bf4d1122317eddd6ff351aa852118a2418ad4214e6613a50e0191f7004372605" [[package]] name = "windows_x86_64_gnu" -version = "0.36.1" +version = "0.42.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1040f221285e17ebccbc2591ffdc2d44ee1f9186324dd3e84e99ac68d699c45" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dcd171b8776c41b97521e5da127a2d86ad280114807d0b2ab1e462bc764d9e1" +checksum = "628bfdf232daa22b0d64fdb62b09fcc36bb01f05a3939e20ab73aaf9470d0463" [[package]] name = "windows_x86_64_msvc" -version = "0.36.1" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c811ca4a8c853ef420abd8592ba53ddbbac90410fab6903b3e79972a631f7680" +checksum = "447660ad36a13288b1db4d4248e857b510e8c3a225c822ba4fb748c0aafecffd" [[package]] name = "winreg" diff --git a/tools/ci-build/publisher/Cargo.toml b/tools/ci-build/publisher/Cargo.toml index fb543a9daa6..f076cedacf6 100644 --- a/tools/ci-build/publisher/Cargo.toml +++ b/tools/ci-build/publisher/Cargo.toml @@ -38,4 +38,4 @@ tracing-subscriber = { version = "0.3.15", features = ["env-filter"] } tempfile = "3.3.0" [dev-dependencies] -pretty_assertions = "1.2.1" +pretty_assertions = "1.3" diff --git a/tools/ci-build/publisher/src/retry.rs b/tools/ci-build/publisher/src/retry.rs index 12db87a9200..1608da8448f 100644 --- a/tools/ci-build/publisher/src/retry.rs +++ b/tools/ci-build/publisher/src/retry.rs @@ -51,7 +51,7 @@ where } ErrorClass::Retry => { info!( - "{} failed on attempt {} with retryable error: {}. Will retry after {:?}", + "{} failed on attempt {} with retryable error: {:?}. Will retry after {:?}", what, attempt, err.into(), backoff ); } diff --git a/tools/ci-build/publisher/src/subcommand/generate_version_manifest.rs b/tools/ci-build/publisher/src/subcommand/generate_version_manifest.rs index 719cb167f72..20751b856e9 100644 --- a/tools/ci-build/publisher/src/subcommand/generate_version_manifest.rs +++ b/tools/ci-build/publisher/src/subcommand/generate_version_manifest.rs @@ -197,7 +197,7 @@ fn hash_models(projection: &SmithyBuildProjection) -> Result { // Must match `hashModels` in `CrateVersioner.kt` let mut hashes = String::new(); for import in &projection.imports { - hashes.push_str(&sha256::digest_file(import).context("hash model")?); + hashes.push_str(&sha256::try_digest(import.as_path())?); hashes.push('\n'); } Ok(sha256::digest(hashes)) @@ -217,7 +217,7 @@ impl SmithyBuildRoot { #[derive(Debug, Deserialize)] struct SmithyBuildProjection { - imports: Vec, + imports: Vec, } #[cfg(test)] @@ -348,10 +348,7 @@ mod tests { fs::write(&model1b, "bar").unwrap(); let hash = hash_models(&SmithyBuildProjection { - imports: vec![ - model1a.to_str().unwrap().to_string(), - model1b.to_str().unwrap().to_string(), - ], + imports: vec![model1a, model1b], }) .unwrap(); diff --git a/tools/ci-build/publisher/src/subcommand/publish.rs b/tools/ci-build/publisher/src/subcommand/publish.rs index 5d76e5ff017..dfa2c7af543 100644 --- a/tools/ci-build/publisher/src/subcommand/publish.rs +++ b/tools/ci-build/publisher/src/subcommand/publish.rs @@ -52,7 +52,7 @@ pub async fn subcommand_publish( // Don't proceed unless the user confirms the plan confirm_plan(&batches, stats, *skip_confirmation)?; - for batch in batches { + for batch in &batches { let mut any_published = false; for package in batch { // Only publish if it hasn't been published yet. @@ -65,13 +65,12 @@ pub async fn subcommand_publish( // Sometimes it takes a little bit of time for the new package version // to become available after publish. If we proceed too quickly, then // the next package publish can fail if it depends on this package. - wait_for_eventual_consistency(&package).await?; - info!("Successfully published `{}`", package.handle); + wait_for_eventual_consistency(package).await?; + info!("Successfully published `{}`", &package.handle); any_published = true; } else { - info!("`{}` was already published", package.handle); + info!("`{}` was already published", &package.handle); } - correct_owner(&package.handle, &package.category).await?; } if any_published { info!("Sleeping 30 seconds after completion of the batch"); @@ -81,6 +80,12 @@ pub async fn subcommand_publish( } } + for batch in &batches { + for package in batch { + correct_owner(&package.handle, &package.category).await?; + } + } + Ok(()) } @@ -142,6 +147,11 @@ async fn wait_for_eventual_consistency(package: &Package) -> Result<()> { /// Corrects the crate ownership. pub async fn correct_owner(handle: &PackageHandle, category: &PackageCategory) -> Result<()> { + // https://github.com/orgs/awslabs/teams/smithy-rs-server + const SMITHY_RS_SERVER_OWNER: &str = "github:awslabs:smithy-rs-server"; + // https://github.com/orgs/awslabs/teams/rust-sdk-owners + const RUST_SDK_OWNER: &str = "github:awslabs:rust-sdk-owners"; + run_with_retry( &format!("Correcting ownership of `{}`", handle.name), 3, @@ -151,7 +161,7 @@ pub async fn correct_owner(handle: &PackageHandle, category: &PackageCategory) - let expected_owners = expected_package_owners(category, &handle.name); let owners_to_be_added = expected_owners.difference(&actual_owners); - let incorrect_owners = actual_owners.difference(&expected_owners); + let owners_to_be_removed = actual_owners.difference(&expected_owners); let mut added_individual = false; for crate_owner in owners_to_be_added { @@ -162,21 +172,26 @@ pub async fn correct_owner(handle: &PackageHandle, category: &PackageCategory) - // Teams in crates.io start with `github:` while individuals are just the GitHub user name added_individual |= !crate_owner.starts_with("github:"); } - for incorrect_owner in incorrect_owners { + for crate_owner in owners_to_be_removed { + // Trying to remove them will result in an error due to a bug in crates.io + // Upstream tracking issue: https://github.com/rust-lang/crates.io/issues/2736 + if crate_owner == SMITHY_RS_SERVER_OWNER || crate_owner == RUST_SDK_OWNER { + continue; + } // Adding an individual owner requires accepting an invite, so don't attempt to remove // anyone if an owner was added, as removing the last individual owner may break. // The next publish run will remove the incorrect owner. if !added_individual { - cargo::RemoveOwner::new(&handle.name, incorrect_owner) + cargo::RemoveOwner::new(&handle.name, crate_owner) .spawn() .await - .context(format!("remove incorrect owner `{}` from crate `{}`", incorrect_owner, handle))?; + .with_context(|| format!("remove incorrect owner `{}` from crate `{}`", crate_owner, handle))?; info!( "Removed incorrect owner `{}` from crate `{}`", - incorrect_owner, handle + crate_owner, handle ); } else { - info!("Skipping removal of incorrect owner `{}` from crate `{}` due to new owners", incorrect_owner, handle); + info!("Skipping removal of incorrect owner `{}` from crate `{}` due to new owners", crate_owner, handle); } } Result::<_, BoxError>::Ok(()) diff --git a/tools/ci-build/publisher/src/subcommand/upgrade_runtime_crates_version.rs b/tools/ci-build/publisher/src/subcommand/upgrade_runtime_crates_version.rs index 0bf17e3c014..8132e1fa66e 100644 --- a/tools/ci-build/publisher/src/subcommand/upgrade_runtime_crates_version.rs +++ b/tools/ci-build/publisher/src/subcommand/upgrade_runtime_crates_version.rs @@ -7,6 +7,7 @@ use crate::fs::Fs; use anyhow::{anyhow, bail, Context}; use clap::Parser; use regex::Regex; +use std::borrow::Cow; use std::path::{Path, PathBuf}; #[derive(Parser, Debug)] @@ -27,35 +28,42 @@ pub async fn subcommand_upgrade_runtime_crates_version( .with_context(|| format!("{} is not a valid semver version", &args.version))?; let fs = Fs::Real; let gradle_properties = read_gradle_properties(fs, &args.gradle_properties_path).await?; + let updated_gradle_properties = update_gradle_properties(&gradle_properties, &upgraded_version) + .with_context(|| { + format!( + "Failed to extract the expected runtime crates version from `{:?}`", + &args.gradle_properties_path + ) + })?; + update_gradle_properties_file( + fs, + &args.gradle_properties_path, + updated_gradle_properties.as_ref(), + ) + .await?; + Ok(()) +} + +fn update_gradle_properties<'a>( + gradle_properties: &'a str, + upgraded_version: &'a semver::Version, +) -> Result, anyhow::Error> { let version_regex = - Regex::new(r"(?Psmithy\.rs\.runtime\.crate\.version=)(?P\d+\.\d+\.\d+-.*)") + Regex::new(r"(?Psmithy\.rs\.runtime\.crate\.version=)(?P\d+\.\d+\.\d+.*)") .unwrap(); - let current_version = version_regex.captures(&gradle_properties).ok_or_else(|| { - anyhow!( - "Failed to extract the expected runtime crates version from `{:?}`", - &args.gradle_properties_path - ) - })?; + let current_version = version_regex + .captures(gradle_properties) + .ok_or_else(|| anyhow!("Failed to extract the expected runtime crates version"))?; let current_version = current_version.name("version").unwrap(); let current_version = semver::Version::parse(current_version.as_str()) .with_context(|| format!("{} is not a valid semver version", current_version.as_str()))?; - if current_version > upgraded_version + if ¤t_version > upgraded_version // Special version tag used on the `main` branch && current_version != semver::Version::parse("0.0.0-smithy-rs-head").unwrap() { bail!("Moving from {current_version} to {upgraded_version} would be a *downgrade*. This command doesn't allow it!"); } - let updated_gradle_properties = version_regex.replace( - &gradle_properties, - format!("${{field}}{}", upgraded_version), - ); - update_gradle_properties( - fs, - &args.gradle_properties_path, - updated_gradle_properties.as_ref(), - ) - .await?; - Ok(()) + Ok(version_regex.replace(gradle_properties, format!("${{field}}{}", upgraded_version))) } async fn read_gradle_properties(fs: Fs, path: &Path) -> Result { @@ -65,7 +73,7 @@ async fn read_gradle_properties(fs: Fs, path: &Path) -> Result; + /// Root struct representing a `versions.toml` manifest #[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)] pub struct VersionsManifest { @@ -31,7 +33,7 @@ pub struct VersionsManifest { pub manual_interventions: ManualInterventions, /// All SDK crate version metadata - pub crates: BTreeMap, + pub crates: CrateVersionMetadataMap, /// Crate versions that were a part of this SDK release. /// Releases may not release every single crate, which can happen if a crate has no changes. diff --git a/tools/ci-cdk/canary-lambda/src/s3_canary.rs b/tools/ci-cdk/canary-lambda/src/s3_canary.rs index cb56797f016..70e3d18c55b 100644 --- a/tools/ci-cdk/canary-lambda/src/s3_canary.rs +++ b/tools/ci-cdk/canary-lambda/src/s3_canary.rs @@ -8,8 +8,9 @@ use crate::{mk_canary, CanaryEnv}; use anyhow::Context; use aws_config::SdkConfig; use aws_sdk_s3 as s3; -use s3::error::{GetObjectError, GetObjectErrorKind}; +use aws_sdk_s3::presigning::config::PresigningConfig; use s3::types::ByteStream; +use std::time::Duration; use uuid::Uuid; const METADATA_TEST_VALUE: &str = "some value"; @@ -35,15 +36,13 @@ pub async fn s3_canary(client: s3::Client, s3_bucket_name: String) -> anyhow::Re CanaryError(format!("Expected object {} to not exist in S3", test_key)).into(), ); } - Err(err) => match err.into_service_error() { - GetObjectError { - kind: GetObjectErrorKind::NoSuchKey(..), - .. - } => { - // good + Err(err) => { + let err = err.into_service_error(); + // If we get anything other than "No such key", we have a problem + if !err.is_no_such_key() { + return Err(err).context("unexpected s3::GetObject failure"); } - err => Err(err).context("unexpected s3::GetObject failure")?, - }, + } } // Put the test object @@ -66,6 +65,23 @@ pub async fn s3_canary(client: s3::Client, s3_bucket_name: String) -> anyhow::Re .await .context("s3::GetObject[2]")?; + // repeat the test with a presigned url + let uri = client + .get_object() + .bucket(&s3_bucket_name) + .key(&test_key) + .presigned(PresigningConfig::expires_in(Duration::from_secs(120)).unwrap()) + .await + .unwrap(); + let response = reqwest::get(uri.uri().to_string()) + .await + .context("s3::presigned")? + .text() + .await?; + if response != "test" { + return Err(CanaryError(format!("presigned URL returned bad data: {:?}", response)).into()); + } + let mut result = Ok(()); match output.metadata() { Some(map) => { diff --git a/tools/ci-cdk/canary-runner/Cargo.lock b/tools/ci-cdk/canary-runner/Cargo.lock index 3bdfd38331a..adccd1cddc0 100644 --- a/tools/ci-cdk/canary-runner/Cargo.lock +++ b/tools/ci-cdk/canary-runner/Cargo.lock @@ -10,9 +10,9 @@ checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] name = "aho-corasick" -version = "0.7.19" +version = "0.7.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4f55bd91a0978cbfd91c457a164bab8b4001c833b7f323132c0a4e1922dd44e" +checksum = "cc936419f96fa211c1b9166887b38e5e40b19958e5b895be7c1f93adec7071ac" dependencies = [ "memchr", ] @@ -26,20 +26,11 @@ dependencies = [ "libc", ] -[[package]] -name = "ansi_term" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2" -dependencies = [ - "winapi", -] - [[package]] name = "anyhow" -version = "1.0.65" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "98161a4e3e2184da77bb14f02184cdd111e83bbbcc9979dfee3c44b9a85f5602" +checksum = "224afbd727c3d6e4b90103ece64b8d1b67fbb1973b1046c2281eed3f3803f800" [[package]] name = "async-recursion" @@ -54,9 +45,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.57" +version = "0.1.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76464446b8bc32758d7e88ee1a804d9914cd9b1cb264c029899680b0be29826f" +checksum = "1cd7fce9ba8c3c042128ce72d8b2ddbf3a05747efb67ea0313c635e10bda47a2" dependencies = [ "proc-macro2", "quote", @@ -69,7 +60,7 @@ version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" dependencies = [ - "hermit-abi", + "hermit-abi 0.1.19", "libc", "winapi", ] @@ -101,7 +92,7 @@ dependencies = [ "http", "hyper", "ring", - "time 0.3.15", + "time 0.3.19", "tokio", "tower", "tracing", @@ -286,7 +277,7 @@ dependencies = [ "percent-encoding", "regex", "ring", - "time 0.3.15", + "time 0.3.19", "tracing", ] @@ -422,7 +413,7 @@ dependencies = [ "itoa", "num-integer", "ryu", - "time 0.3.15", + "time 0.3.19", ] [[package]] @@ -458,9 +449,15 @@ checksum = "3441f0f7b02788e948e47f457ca01f1d7e6d92c693bc132c22b087d3141c03ff" [[package]] name = "base64" -version = "0.13.0" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" + +[[package]] +name = "base64" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd" +checksum = "a4a4ddaa51a5bc52a6948f74c06d20aaaddb71924eab79b8c97a8c556e942d6a" [[package]] name = "bitflags" @@ -479,9 +476,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.11.0" +version = "3.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1ad822118d20d2c234f427000d5acc36eabe1e29a348c89b63dd60b13f28e5d" +checksum = "0d261e256854913907f67ed06efbc3338dfe6179796deefc1ff763fc1aee5535" [[package]] name = "byteorder" @@ -515,7 +512,7 @@ dependencies = [ "aws-sdk-cloudwatch", "aws-sdk-lambda", "aws-sdk-s3", - "base64 0.13.0", + "base64 0.13.1", "clap", "hex", "lazy_static", @@ -535,9 +532,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.73" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fff2a6927b3bb87f9595d67196a70493f627687a71d87a0d692242c33f58c11" +checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" [[package]] name = "cfg-if" @@ -547,25 +544,25 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.22" +version = "0.4.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfd4d1b31faaa3a89d7934dbded3111da0d2ef28e3ebccdb4f0179f5929d1ef1" +checksum = "16b0a3d9ed01224b22057780a37bb8c5dbfe1be8ba48678e7bf57ec4b385411f" dependencies = [ "iana-time-zone", "js-sys", "num-integer", "num-traits", "serde", - "time 0.1.44", + "time 0.1.45", "wasm-bindgen", "winapi", ] [[package]] name = "clap" -version = "3.2.22" +version = "3.2.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86447ad904c7fb335a790c9d7fe3d0d971dc523b8ccd1561a520de9a85302750" +checksum = "71655c45cb9845d3270c9d6df84ebe72b4dad3c2ba3f7023ad47c144e4e473a5" dependencies = [ "atty", "bitflags", @@ -600,6 +597,16 @@ dependencies = [ "os_str_bytes", ] +[[package]] +name = "codespan-reporting" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3538270d33cc669650c4b093848450d380def10c331d38c768e34cac80576e6e" +dependencies = [ + "termcolor", + "unicode-width", +] + [[package]] name = "core-foundation" version = "0.9.3" @@ -655,9 +662,9 @@ dependencies = [ [[package]] name = "crossbeam-utils" -version = "0.8.12" +version = "0.8.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edbafec5fa1f196ca66527c1b12c2ec4745ca14b50f1ad8f9f6f720b55d11fac" +checksum = "4fb766fa798726286dbbb842f174001dab8abc7b627a1dd86e0b7222a95d929f" dependencies = [ "cfg-if", ] @@ -683,10 +690,54 @@ dependencies = [ [[package]] name = "ctor" -version = "0.1.23" +version = "0.1.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d2301688392eb071b0bf1a37be05c469d3cc4dbbd95df672fe28ab021e6a096" +dependencies = [ + "quote", + "syn", +] + +[[package]] +name = "cxx" +version = "1.0.91" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86d3488e7665a7a483b57e25bdd90d0aeb2bc7608c8d0346acf2ad3f1caf1d62" +dependencies = [ + "cc", + "cxxbridge-flags", + "cxxbridge-macro", + "link-cplusplus", +] + +[[package]] +name = "cxx-build" +version = "1.0.91" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48fcaf066a053a41a81dfb14d57d99738b767febb8b735c3016e469fac5da690" +dependencies = [ + "cc", + "codespan-reporting", + "once_cell", + "proc-macro2", + "quote", + "scratch", + "syn", +] + +[[package]] +name = "cxxbridge-flags" +version = "1.0.91" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2ef98b8b717a829ca5603af80e1f9e2e48013ab227b68ef37872ef84ee479bf" + +[[package]] +name = "cxxbridge-macro" +version = "1.0.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cdffe87e1d521a10f9696f833fe502293ea446d7f256c06128293a4119bdf4cb" +checksum = "086c685979a698443656e5cf7856c95c642295a38599f12fb1ff76fb28d19892" dependencies = [ + "proc-macro2", "quote", "syn", ] @@ -699,9 +750,9 @@ checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" [[package]] name = "digest" -version = "0.10.5" +version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adfbc57365a37acbd2ebf2b64d7e69bb766e2fea813521ed536f5d0520dcf86c" +checksum = "8168378f4e5023e7218c89c891c0fd8ecdb5e5e4f18cb78f38cf245dd021e76f" dependencies = [ "block-buffer", "crypto-common", @@ -709,39 +760,39 @@ dependencies = [ [[package]] name = "dyn-clone" -version = "1.0.9" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f94fa09c2aeea5b8839e414b7b841bf429fd25b9c522116ac97ee87856d88b2" +checksum = "c9b0705efd4599c15a38151f4721f7bc388306f61084d3bfd50bd07fbca5cb60" [[package]] name = "either" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90e5c1c8368803113bf0c9584fc495a58b86dc8a29edbf8fe877d21d9507e797" +checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" [[package]] name = "encoding_rs" -version = "0.8.31" +version = "0.8.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9852635589dc9f9ea1b6fe9f05b50ef208c85c834a562f0c6abb1c475736ec2b" +checksum = "071a31f4ee85403370b58aca746f01041ede6f0da2730960ad001edc2b71b394" dependencies = [ "cfg-if", ] [[package]] name = "fastrand" -version = "1.8.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a407cfaa3385c4ae6b23e84623d48c2798d06e3e6a1878f7f59f17b3f86499" +checksum = "e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be" dependencies = [ "instant", ] [[package]] name = "flate2" -version = "1.0.24" +version = "1.0.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f82b0f4c27ad9f8bfd1f3208d882da2b09c301bc1c828fd3a00d0216d2fbbff6" +checksum = "a8a2db397cb1c8772f31494cb8917e48cd1e64f0fa7efac59fbd741a0a8ce841" dependencies = [ "crc32fast", "miniz_oxide", @@ -780,9 +831,9 @@ dependencies = [ [[package]] name = "futures" -version = "0.3.24" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f21eda599937fba36daeb58a22e8f5cee2d14c4a17b5b7739c7c8e5e3b8230c" +checksum = "13e2792b0ff0340399d58445b88fd9770e3489eff258a4cbc1523418f12abf84" dependencies = [ "futures-channel", "futures-core", @@ -795,9 +846,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.24" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30bdd20c28fadd505d0fd6712cdfcb0d4b5648baf45faef7f852afb2399bb050" +checksum = "2e5317663a9089767a1ec00a487df42e0ca174b61b4483213ac24448e4664df5" dependencies = [ "futures-core", "futures-sink", @@ -805,15 +856,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.24" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e5aa3de05362c3fb88de6531e6296e85cde7739cccad4b9dfeeb7f6ebce56bf" +checksum = "ec90ff4d0fe1f57d600049061dc6bb68ed03c7d2fbd697274c41805dcb3f8608" [[package]] name = "futures-executor" -version = "0.3.24" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ff63c23854bee61b6e9cd331d523909f238fc7636290b96826e9cfa5faa00ab" +checksum = "e8de0a35a6ab97ec8869e32a2473f4b1324459e14c29275d14b10cb1fd19b50e" dependencies = [ "futures-core", "futures-task", @@ -822,15 +873,15 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.24" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbf4d2a7a308fd4578637c0b17c7e1c7ba127b8f6ba00b29f717e9655d85eb68" +checksum = "bfb8371b6fb2aeb2d280374607aeabfc99d95c72edfe51692e42d3d7f0d08531" [[package]] name = "futures-macro" -version = "0.3.24" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42cd15d1c7456c04dbdf7e88bcd69760d74f3a798d6444e16974b505b0e62f17" +checksum = "95a73af87da33b5acf53acfebdc339fe592ecf5357ac7c0a7734ab9d8c876a70" dependencies = [ "proc-macro2", "quote", @@ -839,21 +890,21 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.24" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b20ba5a92e727ba30e72834706623d94ac93a725410b6a6b6fbc1b07f7ba56" +checksum = "f310820bb3e8cfd46c80db4d7fb8353e15dfff853a127158425f31e0be6c8364" [[package]] name = "futures-task" -version = "0.3.24" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6508c467c73851293f390476d4491cf4d227dbabcd4170f3bb6044959b294f1" +checksum = "dcf79a1bf610b10f42aea489289c5a2c478a786509693b80cd39c44ccd936366" [[package]] name = "futures-util" -version = "0.3.24" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44fb6cb1be61cc1d2e43b262516aafcf63b241cffdb1d3fa115f91d9c7b09c90" +checksum = "9c1d6de3acfef38d2be4b1f543f553131788603495be83da675e180c8d6b7bd1" dependencies = [ "futures-channel", "futures-core", @@ -879,9 +930,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.7" +version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4eb1a864a501629691edf6c15a593b7a51eebaa1e8468e9ddc623de7c9b58ec6" +checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31" dependencies = [ "cfg-if", "libc", @@ -890,9 +941,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.3.14" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ca32592cf21ac7ccab1825cd87f6c9b3d9022c44d086172ed0966bec8af30be" +checksum = "5f9f29bc9dda355256b2916cf526ab02ce0aeaaaf2bad60d65ef3f12f11dd0f4" dependencies = [ "bytes", "fnv", @@ -915,9 +966,9 @@ checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" [[package]] name = "heck" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2540771e65fc8cb83cd6e8a237f70c319bd5c29f78ed1084ba5d50eeac86f7f9" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" [[package]] name = "hermit-abi" @@ -928,6 +979,15 @@ dependencies = [ "libc", ] +[[package]] +name = "hermit-abi" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee512640fe35acbfb4bb779db6f0d80704c2cacfa2e39b601ef3e3f47d1ae4c7" +dependencies = [ + "libc", +] + [[package]] name = "hex" version = "0.4.3" @@ -936,9 +996,9 @@ checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" [[package]] name = "http" -version = "0.2.8" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75f43d41e26995c17e71ee126451dd3941010b0514a81a9d11f3b341debc2399" +checksum = "bd6effc99afb63425aff9b05836f029929e345a6148a14b7ecd5ab67af944482" dependencies = [ "bytes", "fnv", @@ -970,9 +1030,9 @@ checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421" [[package]] name = "hyper" -version = "0.14.20" +version = "0.14.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02c929dc5c39e335a03c405292728118860721b10190d98c2a0f0efd5baafbac" +checksum = "5e011372fa0b68db8350aa7a248930ecc7839bf46d8485577d69f117a75f164c" dependencies = [ "bytes", "futures-channel", @@ -1011,13 +1071,13 @@ dependencies = [ [[package]] name = "hyper-rustls" -version = "0.23.0" +version = "0.23.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d87c48c02e0dc5e3b849a2041db3029fd066650f8f717c07bf8ed78ccb895cac" +checksum = "1788965e61b367cd03a62950836d5cd41560c3577d90e40e0819373194d1661c" dependencies = [ "http", "hyper", - "rustls 0.20.6", + "rustls 0.20.8", "tokio", "tokio-rustls 0.23.4", ] @@ -1041,7 +1101,7 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5617e92fc2f2501c3e2bc6ce547cad841adba2bae5b921c7e52510beca6d084c" dependencies = [ - "base64 0.13.0", + "base64 0.13.1", "bytes", "http", "httpdate", @@ -1053,17 +1113,28 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.50" +version = "0.1.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd911b35d940d2bd0bea0f9100068e5b97b51a1cbe13d13382f132e0365257a0" +checksum = "64c122667b287044802d6ce17ee2ddf13207ed924c712de9a66a5814d5b64765" dependencies = [ "android_system_properties", "core-foundation-sys", + "iana-time-zone-haiku", "js-sys", "wasm-bindgen", "winapi", ] +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0703ae284fc167426161c2e3f1da3ea71d94b21bedbcc9494e92b28e334e3dca" +dependencies = [ + "cxx", + "cxx-build", +] + [[package]] name = "idna" version = "0.2.3" @@ -1077,9 +1148,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "1.9.1" +version = "1.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10a35a97730320ffe8e2d410b5d3b69279b98d2c14bdb8b70ea89ecf7888d41e" +checksum = "1885e79c1fc4b10f0e172c475f458b7f7b93061064d98c3293e98c5ba0c8b399" dependencies = [ "autocfg", "hashbrown", @@ -1096,21 +1167,21 @@ dependencies = [ [[package]] name = "ipnet" -version = "2.5.0" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "879d54834c8c76457ef4293a689b2a8c59b076067ad77b15efafbb05f92a592b" +checksum = "30e22bd8629359895450b59ea7a776c850561b96a3b1d31321c1949d9e6c9146" [[package]] name = "itoa" -version = "1.0.3" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c8af84674fe1f223a982c933a0ee1086ac4d4052aa0fb8060c12c6ad838e754" +checksum = "fad582f4b9e86b6caa621cabeb0963332d92eea04729ab12892c2533951e6440" [[package]] name = "js-sys" -version = "0.3.60" +version = "0.3.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49409df3e3bf0856b916e2ceaca09ee28e6871cf7d9ce97a692cacfdb2a25a47" +checksum = "445dde2150c55e483f3d8416706b97ec8e8237c307e5b7b4b8dd15e6af2a0730" dependencies = [ "wasm-bindgen", ] @@ -1143,9 +1214,18 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.134" +version = "0.2.139" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "329c933548736bc49fd575ee68c89e8be4d260064184389a5b77517cddd99ffb" +checksum = "201de327520df007757c1f0adce6e827fe8562fbc28bfd9c15571c66ca1f5f79" + +[[package]] +name = "link-cplusplus" +version = "1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecd207c9c713c34f95a097a5b029ac2ce6010530c7b49d7fea24d977dede04f5" +dependencies = [ + "cc", +] [[package]] name = "lock_api" @@ -1178,9 +1258,9 @@ dependencies = [ [[package]] name = "matches" -version = "0.1.9" +version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3e378b66a060d48947b590737b30a1be76706c8dd7b8ba0f2fe3989c68a853f" +checksum = "2532096657941c2fea9c289d370a250971c689d4f143798ff67113ec042024a5" [[package]] name = "md-5" @@ -1215,30 +1295,30 @@ dependencies = [ [[package]] name = "miniz_oxide" -version = "0.5.4" +version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96590ba8f175222643a85693f33d26e9c8a015f599c216509b1a6894af675d34" +checksum = "b275950c28b37e794e8c55d88aeb5e139d0ce23fdbbeda68f8d7174abdf9e8fa" dependencies = [ "adler", ] [[package]] name = "mio" -version = "0.8.4" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57ee1c23c7c63b0c9250c339ffdc69255f110b298b901b9f6c82547b7b87caaf" +checksum = "5b9d9a46eff5b4ff64b45a9e316a6d1e0bc719ef429cbec4dc630684212bfdf9" dependencies = [ "libc", "log", "wasi 0.11.0+wasi-snapshot-preview1", - "windows-sys", + "windows-sys 0.45.0", ] [[package]] name = "native-tls" -version = "0.2.10" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd7e2f3618557f980e0b17e8856252eee3c97fa12c54dff0ca290fb6266ca4a9" +checksum = "07226173c32f2926027b63cce4bcd8076c3552846cbe7925f3aaffeac0a3b92e" dependencies = [ "lazy_static", "libc", @@ -1252,6 +1332,16 @@ dependencies = [ "tempfile", ] +[[package]] +name = "nu-ansi-term" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + "overload", + "winapi", +] + [[package]] name = "num-bigint" version = "0.2.6" @@ -1284,20 +1374,11 @@ dependencies = [ [[package]] name = "num_cpus" -version = "1.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19e64526ebdee182341572e50e9ad03965aa510cd94427a4549448f285e957a1" -dependencies = [ - "hermit-abi", - "libc", -] - -[[package]] -name = "num_threads" -version = "0.1.6" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2819ce041d2ee131036f4fc9d6ae7ae125a3a40e97ba64d04fe799ad9dabbb44" +checksum = "0fac9e2da13b5eb447a6ce3d392f23a29d8694bff781bf03a16cd9ac8697593b" dependencies = [ + "hermit-abi 0.2.6", "libc", ] @@ -1333,15 +1414,15 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.15.0" +version = "1.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e82dad04139b71a90c080c8463fe0dc7902db5192d939bd0950f074d014339e1" +checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" [[package]] name = "openssl" -version = "0.10.42" +version = "0.10.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12fc0523e3bd51a692c8850d075d74dc062ccf251c0110668cbd921917118a13" +checksum = "b102428fd03bc5edf97f62620f7298614c45cedf287c271e7ed450bbaf83f2e1" dependencies = [ "bitflags", "cfg-if", @@ -1371,9 +1452,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.76" +version = "0.9.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5230151e44c0f05157effb743e8d517472843121cf9243e8b81393edb5acd9ce" +checksum = "23bbbf7854cd45b83958ebe919f0e8e516793727652e27fda10a8384cfc790b7" dependencies = [ "autocfg", "cc", @@ -1403,9 +1484,9 @@ dependencies = [ [[package]] name = "os_str_bytes" -version = "6.3.0" +version = "6.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ff7415e9ae3fff1225851df9e0d9e4e5479f947619774677a63572e55e80eff" +checksum = "9b7820b9daea5457c9f21c69448905d723fbd21136ccf521748f23fd49e723ee" [[package]] name = "output_vt100" @@ -1416,6 +1497,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + [[package]] name = "parking_lot" version = "0.12.1" @@ -1428,15 +1515,15 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.3" +version = "0.9.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09a279cbf25cb0757810394fbc1e359949b59e348145c643a939a525692e6929" +checksum = "9069cbb9f99e3a5083476ccb29ceb1de18b9118cafa53e90c9551235de2b9521" dependencies = [ "cfg-if", "libc", "redox_syscall", "smallvec", - "windows-sys", + "windows-sys 0.45.0", ] [[package]] @@ -1445,7 +1532,7 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fd56cbd21fea48d0c440b41cd69c589faacade08c992d9a54e471b79d0fd13eb" dependencies = [ - "base64 0.13.0", + "base64 0.13.1", "once_cell", "regex", ] @@ -1490,15 +1577,15 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkg-config" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1df8c4ec4b0627e53bdf214615ad287367e482558cf84b109250b37464dc03ae" +checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160" [[package]] name = "ppv-lite86" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "pretty_assertions" @@ -1538,18 +1625,18 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.46" +version = "1.0.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94e2ef8dbfc347b10c094890f778ee2e36ca9bb4262e86dc99cd217e35f3470b" +checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.21" +version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbe448f377a7d6961e30f5955f9b8d106c3f5e449d493ee1b125c1d43c2b5179" +checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b" dependencies = [ "proc-macro2", ] @@ -1595,9 +1682,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.6.0" +version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c4eb3267174b8c6c2f654116623910a0fef09c4753f8dd83db29c48a0df988b" +checksum = "48aaa5748ba571fb95cd2c85c09f629215d3a6ece942baa100950af03a34f733" dependencies = [ "aho-corasick", "memchr", @@ -1615,9 +1702,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.6.27" +version = "0.6.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3f87b73ce11b1619a3c6332f45341e0047173771e8b8b73f87bfeefb7b56244" +checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848" [[package]] name = "remove_dir_all" @@ -1630,11 +1717,11 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.11.12" +version = "0.11.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "431949c384f4e2ae07605ccaa56d1d9d2ecdb5cadd4f9577ccfab29f2e5149fc" +checksum = "21eed90ec8570952d53b772ecf8f206aa1ec9a3d76b2521c56c42973f2d91ee9" dependencies = [ - "base64 0.13.0", + "base64 0.21.0", "bytes", "encoding_rs", "futures-core", @@ -1643,7 +1730,7 @@ dependencies = [ "http", "http-body", "hyper", - "hyper-rustls 0.23.0", + "hyper-rustls 0.23.2", "hyper-tls", "ipnet", "js-sys", @@ -1654,7 +1741,7 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", - "rustls 0.20.6", + "rustls 0.20.8", "rustls-pemfile", "serde", "serde_json", @@ -1737,9 +1824,9 @@ dependencies = [ [[package]] name = "retry-policies" -version = "0.1.1" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47f9e19b18c6cdd796cc70aea8a9ea5ee7b813be611c6589e3624fcdbfd05f9d" +checksum = "e09bbcb5003282bcb688f0bae741b278e9c7e8f378f561522c9806c58e075d9b" dependencies = [ "anyhow", "chrono", @@ -1776,7 +1863,7 @@ version = "0.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "35edb675feee39aec9c99fa5ff985081995a06d594114ae14cbe797ad7b7a6d7" dependencies = [ - "base64 0.13.0", + "base64 0.13.1", "log", "ring", "sct 0.6.1", @@ -1785,9 +1872,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.20.6" +version = "0.20.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aab8ee6c7097ed6057f43c187a62418d0c05a4bd5f18b3571db50ee0f9ce033" +checksum = "fff78fc74d175294f4e83b28343315ffcfb114b156f0185e9741cb5570f50e2f" dependencies = [ "log", "ring", @@ -1809,27 +1896,26 @@ dependencies = [ [[package]] name = "rustls-pemfile" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0864aeff53f8c05aa08d86e5ef839d3dfcf07aeba2db32f12db0ef716e87bd55" +checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" dependencies = [ - "base64 0.13.0", + "base64 0.21.0", ] [[package]] name = "ryu" -version = "1.0.11" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4501abdff3ae82a1c1b477a17252eb69cee9e66eb915c1abaa4f44d873df9f09" +checksum = "7b4b9743ed687d4b4bcedf9ff5eaa7398495ae14e61cba0a295704edbc7decde" [[package]] name = "schannel" -version = "0.1.20" +version = "0.1.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88d6731146462ea25d9244b2ed5fd1d716d25c52e4d54aa4fb0f3c4e9854dbe2" +checksum = "713cfb06c7059f3588fb8044c0fad1d09e3c01d225e25b9220dbfdcf16dbb1b3" dependencies = [ - "lazy_static", - "windows-sys", + "windows-sys 0.42.0", ] [[package]] @@ -1866,6 +1952,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +[[package]] +name = "scratch" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddccb15bcce173023b3fedd9436f882a0739b8dfb45e4f6b6002bee5929f61b2" + [[package]] name = "sct" version = "0.6.1" @@ -1888,9 +1980,9 @@ dependencies = [ [[package]] name = "security-framework" -version = "2.7.0" +version = "2.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2bc1bb97804af6631813c55739f771071e0f2ed33ee20b68c86ec505d906356c" +checksum = "a332be01508d814fed64bf28f798a146d73792121129962fdf335bb3c49a4254" dependencies = [ "bitflags", "core-foundation", @@ -1901,9 +1993,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.6.1" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0160a13a177a45bfb43ce71c01580998474f556ad854dcbca936dd2841a5c556" +checksum = "31c9bb296072e961fcbd8853511dd39c2d8be2deb1e17c6860b1d30732b323b4" dependencies = [ "core-foundation-sys", "libc", @@ -1911,24 +2003,24 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.14" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e25dfac463d778e353db5be2449d1cce89bd6fd23c9f1ea21310ce6e5a1b29c4" +checksum = "58bc9567378fc7690d6b2addae4e60ac2eeea07becb2c64b9f218b53865cba2a" [[package]] name = "serde" -version = "1.0.145" +version = "1.0.152" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "728eb6351430bccb993660dfffc5a72f91ccc1295abaa8ce19b27ebe4f75568b" +checksum = "bb7d1f0d3021d347a83e556fc4683dea2ea09d87bccdf88ff5c12545d89d5efb" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.145" +version = "1.0.152" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81fa1584d3d1bcacd84c277a0dfe21f5b0f6accf4a23d04d4c6d61f1af522b4c" +checksum = "af487d118eecd09402d70a5d72551860e788df87b464af30e5ea6a38c75c541e" dependencies = [ "proc-macro2", "quote", @@ -1948,9 +2040,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.85" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e55a28e3aaef9d5ce0506d0a14dbba8054ddc7e499ef522dd8b26859ec9d4a44" +checksum = "cad406b69c91885b5107daf2c29572f6c8cdb3c66826821e286c533490c0bc76" dependencies = [ "itoa", "ryu", @@ -2002,9 +2094,9 @@ dependencies = [ [[package]] name = "signal-hook-registry" -version = "1.4.0" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0" +checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" dependencies = [ "libc", ] @@ -2022,9 +2114,9 @@ dependencies = [ [[package]] name = "slab" -version = "0.4.7" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4614a76b2a8be0058caa9dbbaf66d988527d86d003c11a94fbd335d7661edcef" +checksum = "6528351c9bc8ab22353f9d776db39a20288e8d6c37ef8cfe3317cf875eecfc2d" dependencies = [ "autocfg", ] @@ -2076,9 +2168,9 @@ checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" [[package]] name = "syn" -version = "1.0.101" +version = "1.0.107" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e90cde112c4b9690b8cbe810cba9ddd8bc1d7472e2cae317b69e9438c1cba7d2" +checksum = "1f4064b5b16e03ae50984a5a8ed5d4f8803e6bc1fd170a3cda91a1be4b18e3f5" dependencies = [ "proc-macro2", "quote", @@ -2110,33 +2202,33 @@ dependencies = [ [[package]] name = "termcolor" -version = "1.1.3" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bab24d30b911b2376f3a13cc2cd443142f0c81dda04c118693e35b3835757755" +checksum = "be55cf8942feac5c765c2c993422806843c9a9a45d4d5c407ad6dd2ea95eb9b6" dependencies = [ "winapi-util", ] [[package]] name = "textwrap" -version = "0.15.1" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "949517c0cf1bf4ee812e2e07e08ab448e3ae0d23472aee8a06c985f0c8815b16" +checksum = "222a222a5bfe1bba4a77b45ec488a741b3cb8872e5e499451fd7d0129c9c7c3d" [[package]] name = "thiserror" -version = "1.0.37" +version = "1.0.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10deb33631e3c9018b9baf9dcbbc4f737320d2b576bac10f6aefa048fa407e3e" +checksum = "6a9cd18aa97d5c45c6603caea1da6628790b37f7a34b6ca89522331c5180fed0" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.37" +version = "1.0.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "982d17546b47146b28f7c22e3d08465f6b8903d0ea13c1660d9d84a6e7adcdbb" +checksum = "1fb327af4685e4d03fa8cbcf1716380da910eeb2bb8be417e7f9fd3fb164f36f" dependencies = [ "proc-macro2", "quote", @@ -2145,18 +2237,19 @@ dependencies = [ [[package]] name = "thread_local" -version = "1.1.4" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5516c27b78311c50bf42c071425c560ac799b11c30b31f87e3081965fe5e0180" +checksum = "3fdd6f064ccff2d6567adcb3873ca630700f00b5ad3f060c25b5dcfd9a4ce152" dependencies = [ + "cfg-if", "once_cell", ] [[package]] name = "time" -version = "0.1.44" +version = "0.1.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6db9e6914ab8b1ae1c260a4ae7a49b6c5611b40328a735b21862567685e73255" +checksum = "1b797afad3f312d1c66a56d11d0316f916356d11bd158fbc6ca6389ff6bf805a" dependencies = [ "libc", "wasi 0.10.0+wasi-snapshot-preview1", @@ -2165,12 +2258,28 @@ dependencies = [ [[package]] name = "time" -version = "0.3.15" +version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d634a985c4d4238ec39cacaed2e7ae552fbd3c476b552c1deac3021b7d7eaf0c" +checksum = "53250a3b3fed8ff8fd988587d8925d26a83ac3845d9e03b220b37f34c2b8d6c2" dependencies = [ - "libc", - "num_threads", + "serde", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e153e1f1acaef8acc537e68b44906d2db6436e2b35ac2c6b42640fff91f00fd" + +[[package]] +name = "time-macros" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a460aeb8de6dcb0f381e1ee05f1cd56fcf5a5f6eb8187ff3d8f0b11078d38b7c" +dependencies = [ + "time-core", ] [[package]] @@ -2184,15 +2293,15 @@ dependencies = [ [[package]] name = "tinyvec_macros" -version = "0.1.0" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.21.2" +version = "1.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9e03c497dc955702ba729190dc4aac6f2a0ce97f913e5b1b5912fc5039d9099" +checksum = "c8e00990ebabbe4c14c08aca901caed183ecd5c09562a12c824bb53d3c3fd3af" dependencies = [ "autocfg", "bytes", @@ -2205,14 +2314,14 @@ dependencies = [ "signal-hook-registry", "socket2", "tokio-macros", - "winapi", + "windows-sys 0.42.0", ] [[package]] name = "tokio-macros" -version = "1.8.0" +version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9724f9a975fb987ef7a3cd9be0350edcbe130698af5b8f7a631e23d42d052484" +checksum = "d266c00fde287f55d3f1c3e96c500c362a2b8c695076ec180f27918820bc6df8" dependencies = [ "proc-macro2", "quote", @@ -2221,9 +2330,9 @@ dependencies = [ [[package]] name = "tokio-native-tls" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7d995660bd2b7f8c1568414c1126076c13fbb725c40112dc0120b78eb9b717b" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" dependencies = [ "native-tls", "tokio", @@ -2246,16 +2355,16 @@ version = "0.23.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c43ee83903113e03984cb9e5cebe6c04a5116269e900e3ddba8f068a62adda59" dependencies = [ - "rustls 0.20.6", + "rustls 0.20.8", "tokio", "webpki 0.22.0", ] [[package]] name = "tokio-stream" -version = "0.1.10" +version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6edf2d6bc038a43d31353570e27270603f4648d18f5ed10c0e179abe43255af" +checksum = "8fb52b74f05dbf495a8fba459fdc331812b96aa086d9eb78101fa0d4569c3313" dependencies = [ "futures-core", "pin-project-lite", @@ -2264,9 +2373,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.4" +version = "0.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bb2e075f03b3d66d8d8785356224ba688d2906a371015e225beeb65ca92c740" +checksum = "5427d89453009325de0d8f342c9490009f76e999cb7672d77e46267448f7e6b2" dependencies = [ "bytes", "futures-core", @@ -2278,9 +2387,9 @@ dependencies = [ [[package]] name = "toml" -version = "0.5.9" +version = "0.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d82e1a7758622a465f8cee077614c73484dac5b836c02ff6a40d5d1010324d7" +checksum = "f4f7f0dd8d50a853a531c426359045b1998f04219d88799810762cd4ad314234" dependencies = [ "indexmap", "serde", @@ -2304,9 +2413,9 @@ dependencies = [ [[package]] name = "tower-layer" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "343bc9466d3fe6b0f960ef45960509f84480bf4fd96f92901afe7ff3df9d3a62" +checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" [[package]] name = "tower-service" @@ -2316,9 +2425,9 @@ checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" [[package]] name = "tracing" -version = "0.1.36" +version = "0.1.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fce9567bd60a67d08a16488756721ba392f24f29006402881e43b19aac64307" +checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" dependencies = [ "cfg-if", "log", @@ -2329,9 +2438,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.22" +version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11c75893af559bc8e10716548bdef5cb2b983f8e637db9d0e15126b61b484ee2" +checksum = "4017f8f45139870ca7e672686113917c71c7a6e02d4924eda67186083c03081a" dependencies = [ "proc-macro2", "quote", @@ -2340,9 +2449,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.29" +version = "0.1.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aeea4303076558a00714b823f9ad67d58a3bbda1df83d8827d21193156e22f7" +checksum = "24eb03ba0eab1fd845050058ce5e616558e8f8d8fca633e6b163fe25c797213a" dependencies = [ "once_cell", "valuable", @@ -2375,12 +2484,12 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.15" +version = "0.3.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60db860322da191b40952ad9affe65ea23e7dd6a5c442c2c42865810c6ab8e6b" +checksum = "a6176eae26dd70d0c919749377897b54a9276bd7061339665dd68777926b5a70" dependencies = [ - "ansi_term", "matchers", + "nu-ansi-term", "once_cell", "regex", "sharded-slab", @@ -2393,15 +2502,15 @@ dependencies = [ [[package]] name = "try-lock" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642" +checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" [[package]] name = "typenum" -version = "1.15.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987" +checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" [[package]] name = "unicase" @@ -2414,15 +2523,15 @@ dependencies = [ [[package]] name = "unicode-bidi" -version = "0.3.8" +version = "0.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "099b7128301d285f79ddd55b9a83d5e6b9e97c92e0ea0daebee7263e932de992" +checksum = "d54675592c1dbefd78cbd98db9bacd89886e1ca50692a0692baefffdeb92dd58" [[package]] name = "unicode-ident" -version = "1.0.4" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcc811dc4066ac62f84f11307873c4850cb653bfa9b1719cee2bd2204a4bc5dd" +checksum = "84a22b9f218b40614adcb3f4ff08b703773ad44fa9423e4e0d346d5db86e4ebc" [[package]] name = "unicode-normalization" @@ -2433,6 +2542,12 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-width" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b" + [[package]] name = "untrusted" version = "0.7.1" @@ -2505,9 +2620,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.83" +version = "0.2.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eaf9f5aceeec8be17c128b2e93e031fb8a4d469bb9c4ae2d7dc1888b26887268" +checksum = "31f8dcbc21f30d9b8f2ea926ecb58f6b91192c17e9d33594b3df58b2007ca53b" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -2515,9 +2630,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.83" +version = "0.2.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c8ffb332579b0557b52d268b91feab8df3615f265d5270fec2a8c95b17c1142" +checksum = "95ce90fd5bcc06af55a641a86428ee4229e44e07033963a2290a8e241607ccb9" dependencies = [ "bumpalo", "log", @@ -2530,9 +2645,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.33" +version = "0.4.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23639446165ca5a5de86ae1d8896b737ae80319560fbaa4c2887b7da6e7ebd7d" +checksum = "f219e0d211ba40266969f6dbdd90636da12f75bee4fc9d6c23d1260dadb51454" dependencies = [ "cfg-if", "js-sys", @@ -2542,9 +2657,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.83" +version = "0.2.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "052be0f94026e6cbc75cdefc9bae13fd6052cdcaf532fa6c45e7ae33a1e6c810" +checksum = "4c21f77c0bedc37fd5dc21f897894a5ca01e7bb159884559461862ae90c0b4c5" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2552,9 +2667,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.83" +version = "0.2.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07bc0c051dc5f23e307b13285f9d75df86bfdf816c5721e573dec1f9b8aa193c" +checksum = "2aff81306fcac3c7515ad4e177f521b5c9a15f2b08f4e32d823066102f35a5f6" dependencies = [ "proc-macro2", "quote", @@ -2565,15 +2680,15 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.83" +version = "0.2.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c38c045535d93ec4f0b4defec448e4291638ee608530863b1e2ba115d4fff7f" +checksum = "0046fef7e28c3804e5e38bfa31ea2a0f73905319b677e57ebe37e49358989b5d" [[package]] name = "web-sys" -version = "0.3.60" +version = "0.3.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcda906d8be16e728fd5adc5b729afad4e444e106ab28cd1c7256e54fa61510f" +checksum = "e33b99f4b23ba3eec1a53ac264e35a755f00e966e0065077d6027c0f575b0b97" dependencies = [ "js-sys", "wasm-bindgen", @@ -2601,9 +2716,9 @@ dependencies = [ [[package]] name = "webpki-roots" -version = "0.22.5" +version = "0.22.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "368bfe657969fb01238bb756d351dcade285e0f6fcbd36dcb23359a5169975be" +checksum = "b6c71e40d7d2c34a5106301fb632274ca37242cd0c9d3e64dbece371a40a2d87" dependencies = [ "webpki 0.22.0", ] @@ -2641,46 +2756,84 @@ checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] name = "windows-sys" -version = "0.36.1" +version = "0.42.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea04155a16a59f9eab786fe12a4a450e75cdb175f9e0d80da1e17db09f55b8d2" +checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" dependencies = [ + "windows_aarch64_gnullvm", "windows_aarch64_msvc", "windows_i686_gnu", "windows_i686_msvc", "windows_x86_64_gnu", + "windows_x86_64_gnullvm", "windows_x86_64_msvc", ] +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.42.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e2522491fbfcd58cc84d47aeb2958948c4b8982e9a2d8a2a35bbaed431390e7" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c9864e83243fdec7fc9c5444389dcbbfd258f745e7853198f365e3c4968a608" + [[package]] name = "windows_aarch64_msvc" -version = "0.36.1" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9bb8c3fd39ade2d67e9874ac4f3db21f0d710bee00fe7cab16949ec184eeaa47" +checksum = "4c8b1b673ffc16c47a9ff48570a9d85e25d265735c503681332589af6253c6c7" [[package]] name = "windows_i686_gnu" -version = "0.36.1" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "180e6ccf01daf4c426b846dfc66db1fc518f074baa793aa7d9b9aaeffad6a3b6" +checksum = "de3887528ad530ba7bdbb1faa8275ec7a1155a45ffa57c37993960277145d640" [[package]] name = "windows_i686_msvc" -version = "0.36.1" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2e7917148b2812d1eeafaeb22a97e4813dfa60a3f8f78ebe204bcc88f12f024" +checksum = "bf4d1122317eddd6ff351aa852118a2418ad4214e6613a50e0191f7004372605" [[package]] name = "windows_x86_64_gnu" -version = "0.36.1" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dcd171b8776c41b97521e5da127a2d86ad280114807d0b2ab1e462bc764d9e1" +checksum = "c1040f221285e17ebccbc2591ffdc2d44ee1f9186324dd3e84e99ac68d699c45" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "628bfdf232daa22b0d64fdb62b09fcc36bb01f05a3939e20ab73aaf9470d0463" [[package]] name = "windows_x86_64_msvc" -version = "0.36.1" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c811ca4a8c853ef420abd8592ba53ddbbac90410fab6903b3e79972a631f7680" +checksum = "447660ad36a13288b1db4d4248e857b510e8c3a225c822ba4fb748c0aafecffd" [[package]] name = "winreg" @@ -2711,9 +2864,9 @@ checksum = "c394b5bd0c6f669e7275d9c20aa90ae064cb22e75a1cad54e1b34088034b149f" [[package]] name = "zip" -version = "0.6.2" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf225bcf73bb52cbb496e70475c7bd7a3f769df699c0020f6c7bd9a96dcf0b8d" +checksum = "0445d0fbc924bb93539b4316c11afb121ea39296f99a3c4c9edad09e3658cdef" dependencies = [ "byteorder", "crc32fast", diff --git a/tools/ci-cdk/canary-runner/Cargo.toml b/tools/ci-cdk/canary-runner/Cargo.toml index aaa0872a0c2..d2325d59f9f 100644 --- a/tools/ci-cdk/canary-runner/Cargo.toml +++ b/tools/ci-cdk/canary-runner/Cargo.toml @@ -33,4 +33,4 @@ tracing-subscriber = { version = "0.3.15", features = ["env-filter", "fmt"] } zip = { version = "0.6.2", default-features = false, features = ["deflate"] } [dev-dependencies] -pretty_assertions = "1.1" +pretty_assertions = "1.3" diff --git a/tools/ci-cdk/canary-runner/additional-ci b/tools/ci-cdk/canary-runner/additional-ci new file mode 100755 index 00000000000..14e8002b88c --- /dev/null +++ b/tools/ci-cdk/canary-runner/additional-ci @@ -0,0 +1,5 @@ +# run build-bundle on musl to verify that everything works +cargo run -- build-bundle \ + --sdk-release-tag release-2022-12-14 \ + --canary-path ../canary-lambda \ + --rust-version stable --musl diff --git a/tools/ci-cdk/canary-runner/src/build_bundle.rs b/tools/ci-cdk/canary-runner/src/build_bundle.rs index aa3857f3309..1f0c4a3f982 100644 --- a/tools/ci-cdk/canary-runner/src/build_bundle.rs +++ b/tools/ci-cdk/canary-runner/src/build_bundle.rs @@ -53,6 +53,7 @@ tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] } uuid = { version = "0.8", features = ["v4"] } tokio-stream = "0" tracing-texray = "0.1.1" +reqwest = { version = "0.11.14", features = ["rustls-tls"], default-features = false } "#; const REQUIRED_SDK_CRATES: &[&str] = &[ @@ -428,6 +429,7 @@ tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] } uuid = { version = "0.8", features = ["v4"] } tokio-stream = "0" tracing-texray = "0.1.1" +reqwest = { version = "0.11.14", features = ["rustls-tls"], default-features = false } aws-config = { path = "some/sdk/path/aws-config" } aws-sdk-s3 = { path = "some/sdk/path/s3" } aws-sdk-ec2 = { path = "some/sdk/path/ec2" } @@ -490,6 +492,7 @@ tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] } uuid = { version = "0.8", features = ["v4"] } tokio-stream = "0" tracing-texray = "0.1.1" +reqwest = { version = "0.11.14", features = ["rustls-tls"], default-features = false } aws-config = "0.46.0" aws-sdk-s3 = "0.20.0" aws-sdk-ec2 = "0.19.0" diff --git a/tools/ci-scripts/codegen-diff-revisions.py b/tools/ci-scripts/codegen-diff-revisions.py index ae2bc420114..0892719bcca 100755 --- a/tools/ci-scripts/codegen-diff-revisions.py +++ b/tools/ci-scripts/codegen-diff-revisions.py @@ -14,22 +14,21 @@ # # ``` # $ cd test/smithy-rs -# $ ../../smithy-rs/tools/codegen-diff-revisions.py . +# $ ../../smithy-rs/tools/ci-scripts/codegen-diff-revisions.py . # ``` # # It will diff the generated code from HEAD against any commit hash you feed it. If you want to test # a specific range, change the HEAD of the test repository. # -# This script requires `diff2html-cli` to be installed from NPM: +# This script requires `difftags` to be installed from `tools/ci-build/difftags`: # ``` -# $ npm install -g diff2html-cli@5.1.11 +# $ cargo install --path tools/ci-build/difftags # ``` # Make sure the local version matches the version referenced from the GitHub Actions workflow. import os import sys import subprocess -import tempfile import shlex @@ -89,29 +88,35 @@ def main(): def generate_and_commit_generated_code(revision_sha): # Clean the build artifacts before continuing run("rm -rf aws/sdk/build") + run("cd rust-runtime/aws-smithy-http-server-python/examples && make distclean", shell=True) run("./gradlew codegen-core:clean codegen-client:clean codegen-server:clean aws:sdk-codegen:clean") # Generate code - run("./gradlew --rerun-tasks :aws:sdk:assemble") - run("./gradlew --rerun-tasks :codegen-server-test:assemble") - run("./gradlew --rerun-tasks :codegen-server-test:python:assemble") + run("./gradlew --rerun-tasks aws:sdk:assemble codegen-client-test:assemble codegen-server-test:assemble") + run("cd rust-runtime/aws-smithy-http-server-python/examples && make build", shell=True, check=False) # Move generated code into codegen-diff/ directory run(f"rm -rf {OUTPUT_PATH}") run(f"mkdir {OUTPUT_PATH}") run(f"mv aws/sdk/build/aws-sdk {OUTPUT_PATH}/") + run(f"mv codegen-client-test/build/smithyprojections/codegen-client-test {OUTPUT_PATH}/") run(f"mv codegen-server-test/build/smithyprojections/codegen-server-test {OUTPUT_PATH}/") - run(f"mv codegen-server-test/python/build/smithyprojections/codegen-server-test-python {OUTPUT_PATH}/") + run(f"mv rust-runtime/aws-smithy-http-server-python/examples/pokemon-service-server-sdk/ {OUTPUT_PATH}/codegen-server-test-python/", check=False) + + # Clean up the SDK directory + run(f"rm -f {OUTPUT_PATH}/aws-sdk/versions.toml") + + # Clean up the client-test folder + run(f"rm -rf {OUTPUT_PATH}/codegen-client-test/source") + run(f"find {OUTPUT_PATH}/codegen-client-test | " + f"grep -E 'smithy-build-info.json|sources/manifest|model.json' | " + f"xargs rm -f", shell=True) # Clean up the server-test folder run(f"rm -rf {OUTPUT_PATH}/codegen-server-test/source") - run(f"rm -rf {OUTPUT_PATH}/codegen-server-test-python/source") run(f"find {OUTPUT_PATH}/codegen-server-test | " f"grep -E 'smithy-build-info.json|sources/manifest|model.json' | " f"xargs rm -f", shell=True) - run(f"find {OUTPUT_PATH}/codegen-server-test-python | " - f"grep -E 'smithy-build-info.json|sources/manifest|model.json' | " - f"xargs rm -f", shell=True) run(f"git add -f {OUTPUT_PATH}") run(f"git -c 'user.name=GitHub Action (generated code preview)' " @@ -155,6 +160,10 @@ def make_diffs(base_commit_sha, head_commit_sha): head_commit_sha, "aws-sdk", whitespace=True) sdk_nows = make_diff("AWS SDK", f"{OUTPUT_PATH}/aws-sdk", base_commit_sha, head_commit_sha, "aws-sdk-ignore-whitespace", whitespace=False) + client_ws = make_diff("Client Test", f"{OUTPUT_PATH}/codegen-client-test", base_commit_sha, + head_commit_sha, "client-test", whitespace=True) + client_nows = make_diff("Client Test", f"{OUTPUT_PATH}/codegen-client-test", base_commit_sha, + head_commit_sha, "client-test-ignore-whitespace", whitespace=False) server_ws = make_diff("Server Test", f"{OUTPUT_PATH}/codegen-server-test", base_commit_sha, head_commit_sha, "server-test", whitespace=True) server_nows = make_diff("Server Test", f"{OUTPUT_PATH}/codegen-server-test", base_commit_sha, @@ -166,6 +175,8 @@ def make_diffs(base_commit_sha, head_commit_sha): sdk_links = diff_link('AWS SDK', 'No codegen difference in the AWS SDK', sdk_ws, 'ignoring whitespace', sdk_nows) + client_links = diff_link('Client Test', 'No codegen difference in the Client Test', + client_ws, 'ignoring whitespace', client_nows) server_links = diff_link('Server Test', 'No codegen difference in the Server Test', server_ws, 'ignoring whitespace', server_nows) server_links_python = diff_link('Server Test Python', 'No codegen difference in the Server Test Python', @@ -173,6 +184,7 @@ def make_diffs(base_commit_sha, head_commit_sha): # Save escaped newlines so that the GitHub Action script gets the whole message return "A new generated diff is ready to view.\\n"\ f"- {sdk_links}\\n"\ + f"- {client_links}\\n"\ f"- {server_links}\\n"\ f"- {server_links_python}\\n" @@ -188,10 +200,10 @@ def eprint(*args, **kwargs): # Runs a shell command -def run(command, shell=False): +def run(command, shell=False, check=True): if not shell: command = shlex.split(command) - subprocess.run(command, stdout=sys.stderr, stderr=sys.stderr, shell=shell, check=True) + subprocess.run(command, stdout=sys.stderr, stderr=sys.stderr, shell=shell, check=check) # Returns the output from a shell command. Bails if the command failed