Add Op(_scaled_dot_product_flash_attention) | feat(torchlib) #1043

Merged: 8 commits merged into microsoft:main from titaiwang/support_flash on Sep 6, 2023

Conversation

titaiwangms (Contributor) commented Aug 31, 2023

_scaled_dot_product_flash_attention is one of the three ATen implementations of nn.functional.scaled_dot_product_attention documented at https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html.

Which of the three ATen operators represents nn.functional.scaled_dot_product_attention in a model is decided by a backend context manager: https://pytorch.org/docs/stable/backends.html. From the ONNX perspective, they differ only in their function signatures.
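For context, a minimal example (shapes and dtype are illustrative assumptions; a flash-attention-capable CUDA device is assumed) of pinning SDPA to the flash-attention backend with that context manager:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only; flash attention needs fp16/bf16 tensors on CUDA.
query = torch.randn(2, 4, 64, 32, device="cuda", dtype=torch.float16)
key = torch.randn(2, 4, 64, 32, device="cuda", dtype=torch.float16)
value = torch.randn(2, 4, 64, 32, device="cuda", dtype=torch.float16)

# Restrict dispatch to the flash-attention kernel so that a traced/exported
# graph contains aten._scaled_dot_product_flash_attention.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(query, key, value)
```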

Only the first output matters for the model prediction; the shapes and dtypes of the remaining outputs follow the PyTorch meta implementation below:

```python
@register_meta(
    [
        aten._scaled_dot_product_flash_attention,
    ]
)
def meta__scaled_dot_product_flash(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    return_debug_mask: bool = False,
    scale: Optional[float] = None,
):
    batch_size = query.size(0)
    num_heads = query.size(1)
    max_seqlen_batch_q = query.size(2)
    head_dim = query.size(3)

    max_seqlen_batch_k = key.size(2)
    if device_hint(query) == "cpu":
        Nnz_q = batch_size * max_seqlen_batch_q
        query_t = query.transpose(1, 2)
        query_reshaped = query_t.reshape(Nnz_q, num_heads, head_dim)
        attention = torch.empty_like(query_reshaped, device=query.device)
        attention = attention.view(
            batch_size, max_seqlen_batch_q, num_heads, head_dim
        ).transpose(1, 2)
        logsumexp = torch.empty(
            (
                batch_size,
                max_seqlen_batch_q,
                num_heads,
            ),
            dtype=torch.float,
            device=query.device,
        ).transpose(1, 2)
        return (
            attention,
            logsumexp,
            torch.empty((), dtype=torch.int32, device="meta"),
            torch.empty((), dtype=torch.int32, device="meta"),
            0,
            0,
            torch.empty((), dtype=torch.long, device="meta"),
            torch.empty((), dtype=torch.long, device="meta"),
            torch.empty((), dtype=query.dtype, device=query.device),
        )

    # Cuda Path
    query_t = query.transpose(1, 2)
    attention = torch.empty_like(query_t).transpose(1, 2)
    logsumexp = torch.empty(
        (batch_size, num_heads, max_seqlen_batch_q),
        dtype=torch.float,
        device=query.device,
    )
    cumulative_sequence_length_q = torch.empty(
        batch_size + 1, dtype=torch.int32, device="meta"
    )
    cumulative_sequence_length_k = torch.empty(
        batch_size + 1, dtype=torch.int32, device="meta"
    )

    if return_debug_mask:
        blocksize_c = 128 if head_dim > 64 else 256
        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
        if max_seqlen_batch_k <= 128:
            max_seqlen_k = 128
        elif max_seqlen_batch_k <= 256:
            max_seqlen_k = 256
        debug_mask = torch.empty(
            (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
            dtype=query.dtype,
            device=query.device,
        )
    else:
        debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)

    # Note [Seed and Offset]: device for seed and offset below depends on whether we are
    # capturing or not, but at the time of tracing we don't know if we
    # are going to use cudagraphs or not, so we return meta tensors here
    # it's possible we'll need to have some special handling in inductor for sdpa

    return (
        attention,
        logsumexp,
        None,
        None,
        max_seqlen_batch_q,
        max_seqlen_batch_k,
        torch.empty((), dtype=torch.long, device="meta"),
        torch.empty((), dtype=torch.long, device="meta"),
        debug_mask,
    )
```
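As a hedged illustration of the point above (shapes are arbitrary; a flash-attention-capable CUDA build is assumed), calling the ATen op directly shows that only the first element of its 9-tuple is the attention result:

```python
import torch

query = torch.randn(2, 4, 64, 32, device="cuda", dtype=torch.float16)
key = torch.randn(2, 4, 64, 32, device="cuda", dtype=torch.float16)
value = torch.randn(2, 4, 64, 32, device="cuda", dtype=torch.float16)

# The ATen op returns a 9-element tuple; only element 0 feeds the model output.
outputs = torch.ops.aten._scaled_dot_product_flash_attention(query, key, value)
attention = outputs[0]   # shape (batch, num_heads, seq_len_q, head_dim)
logsumexp = outputs[1]   # bookkeeping for the backward pass
# outputs[2:] carry cumulative/max sequence lengths, RNG seed/offset, and the
# (normally empty) debug mask described by the meta function above.
```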

NOTE: The PyTorch converter should account for None appearing among the values passed to `_fill_tensor_shape_type`; otherwise, the exporter crashes.
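A hypothetical sketch of the guard the converter needs; the helper name `fill_shape_type_skipping_none` and its arguments are made up for illustration, with the real `_fill_tensor_shape_type` represented by the `fill_fn` callback:

```python
from typing import Any, Callable, Optional, Sequence

def fill_shape_type_skipping_none(
    outputs: Sequence[Optional[Any]],
    fill_fn: Callable[[Any], None],
) -> None:
    # aten._scaled_dot_product_flash_attention yields None for some outputs on
    # the CUDA path; skip those entries instead of handing None to the
    # shape/type filler (which would trip its beartype checks).
    for out in outputs:
        if out is None:
            continue
        fill_fn(out)
```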

@titaiwangms titaiwangms added the topic: torch_lib Related to the torch/aten function lib in development label Aug 31, 2023
codecov bot commented Aug 31, 2023

Codecov Report

Merging #1043 (376765b) into main (0c25215) will increase coverage by 0.04%.
The diff coverage is 100.00%.

```
@@            Coverage Diff             @@
##             main    #1043      +/-   ##
==========================================
+ Coverage   77.68%   77.73%   +0.04%     
==========================================
  Files         114      114              
  Lines       14445    14473      +28     
  Branches     1545     1546       +1     
==========================================
+ Hits        11222    11250      +28     
  Misses       2857     2857              
  Partials      366      366              
```

| Files Changed | Coverage | Δ |
| --- | --- | --- |
| ...ipt/tests/function_libs/torch_lib/ops_test_data.py | 96.03% <ø> | (ø) |
| onnxscript/function_libs/torch_lib/ops/nn.py | 80.06% <100.00%> | (+0.44%) ⬆️ |
| ...ript/tests/function_libs/torch_lib/extra_opinfo.py | 98.29% <100.00%> | (+0.08%) ⬆️ |

@titaiwangms titaiwangms added the hold on merging Don't merge yet label Aug 31, 2023
@titaiwangms titaiwangms marked this pull request as draft August 31, 2023 20:45
@titaiwangms titaiwangms marked this pull request as ready for review September 5, 2023 17:27
```python
return (
    result,
    logsumexp,
    empty_tensor_int,
```
titaiwangms (Contributor Author):

Should I create TInt for these guys?

Contributor:

INT64?

titaiwangms (Contributor Author):

The one in embedding remains TFloat though, but I can use INT64 in this case. It depends on whether we should follow the native-function signature or what we actually return.

Contributor:

If the return types for the empty float values need to be TFloat, do we need a CastLike on self here? Otherwise it would be FLOAT, because the dtype is set explicitly and does not depend on the input?
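For context on that discussion, a hedged sketch (not necessarily the code merged in this PR) of how an empty INT64 placeholder could be built with ONNX Script ops, so the unused integer outputs carry a concrete dtype instead of following the input's float type:

```python
from onnxscript import opset18 as op
from onnxscript.onnx_types import INT64

def _empty_int64_placeholder():
    # A zero-element tensor cast to INT64, usable as a dummy value for the
    # unused integer outputs of the flash-attention function.
    zero_shape = op.Constant(value_ints=[0])   # shape [0] -> empty tensor
    empty = op.ConstantOfShape(zero_shape)     # float zeros by default
    return op.Cast(empty, to=INT64.dtype)
```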

@justinchuby justinchuby self-requested a review September 5, 2023 17:55
@titaiwangms titaiwangms removed the hold on merging Don't merge yet label Sep 5, 2023
@justinchuby
Contributor

lgtm with the return types fixed

@titaiwangms
Contributor Author

@justinchuby I found that all CI runs fail except torch-nightly. I guess testing this op requires torch-nightly?

@justinchuby
Contributor

> @justinchuby I found that all CI runs fail except torch-nightly. I guess testing this op requires torch-nightly?

Looks like so. We can skip the tests for older torch by using `.skip(enabled_if=version_utils.torch_older_than("2.1"))`
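A hedged sketch of what that skip might look like in ops_test_data.py; the exact TorchLibOpInfo entry and field names are assumptions, and `TorchLibOpInfo`, `nn_ops`, and `version_utils` come from that module's existing imports:

```python
# Sketch only: a registration entry in
# onnxscript/tests/function_libs/torch_lib/ops_test_data.py
TorchLibOpInfo(
    "nn.functional.scaled_dot_product_flash_attention",
    nn_ops.aten__scaled_dot_product_flash_attention,
    trace_only=True,
).skip(
    # Skip on stable torch: the ATen op under test only matches
    # torch >= 2.1 (nightly at the time of this PR).
    enabled_if=version_utils.torch_older_than("2.1"),
    reason="the op is not available before torch 2.1",
)
```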

@titaiwangms titaiwangms merged commit 0e9c495 into microsoft:main Sep 6, 2023
29 of 30 checks passed
@titaiwangms titaiwangms deleted the titaiwang/support_flash branch September 6, 2023 00:01
@justinchuby justinchuby mentioned this pull request Sep 8, 2023
titaiwangms added a commit to pytorch/pytorch that referenced this pull request Sep 12, 2023
titaiwangms added a commit to pytorch/pytorch that referenced this pull request Sep 12, 2023
titaiwangms added a commit to pytorch/pytorch that referenced this pull request Sep 12, 2023
titaiwangms added a commit to pytorch/pytorch that referenced this pull request Sep 12, 2023
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Sep 13, 2023
Prior to this PR, if None is returned from intermediate nodes, it crashes the export because None is not expected to be passed into `_fill_tensor_shape_type`, which raises a beartype error. The function fills in the shape and type of a TorchScriptTensor according to its info from the FX graph.

This was discovered after microsoft/onnxscript#1043 was supported. The op generates None for some of its outputs, but the only output that is consumed is the first one (which is not None).

Reference test from a TorchBench model:
```python

    def test_nanogpt(self):
        import sys

        sys.path.append("/home/titaiwang")

        from nanoGPT.model import GPT, GPTConfig

        # Load the model
        kwargs = {
            "block_size": 256,
            "vocab_size": 8096,  # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
            "n_layer": 2,
            "n_head": 2,
            "n_embd": 128,
            "dropout": 0.0,
            "bias": False,  # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
        }
        config = GPTConfig(**kwargs)
        with torch.backends.cuda.sdp_kernel(
            enable_flash=True, enable_mem_efficient=True
        ):
            model = GPT(config)
        print("Done loading model")
        inputs = torch.arange(128).view(2, 64)
        targets = torch.arange(128).view(2, 64)

        self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
            model,
            (inputs,),
            input_kwargs={
                "targets": targets,
            },
            verbose=True,
        )
```
Pull Request resolved: #108708
Approved by: https://github.com/justinchuby, https://github.com/thiagocrepaldi
titaiwangms added a commit that referenced this pull request Dec 1, 2023
Labels
topic: torch_lib Related to the torch/aten function lib in development