Skip to content

Commit

Permalink
Merge pull request from GHSA-5mqj-xc49-246p
Browse files Browse the repository at this point in the history
  • Loading branch information
crewjam committed Mar 22, 2023
1 parent 2aeb2ef commit 8e92368
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 3 deletions.
31 changes: 31 additions & 0 deletions flate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package saml

import (
"compress/flate"
"fmt"
"io"
)

const flateUncompressLimit = 10 * 1024 * 1024 // 10MB

func newSaferFlateReader(r io.Reader) io.ReadCloser {
return &saferFlateReader{r: flate.NewReader(r)}
}

type saferFlateReader struct {
r io.ReadCloser
count int
}

func (r *saferFlateReader) Read(p []byte) (n int, err error) {
if r.count+len(p) > flateUncompressLimit {
return 0, fmt.Errorf("flate: uncompress limit exceeded (%d bytes)", flateUncompressLimit)
}
n, err = r.r.Read(p)
r.count += n
return n, err
}

func (r *saferFlateReader) Close() error {
return r.r.Close()
}
3 changes: 1 addition & 2 deletions identity_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package saml

import (
"bytes"
"compress/flate"
"crypto"
"crypto/tls"
"crypto/x509"
Expand Down Expand Up @@ -363,7 +362,7 @@ func NewIdpAuthnRequest(idp *IdentityProvider, r *http.Request) (*IdpAuthnReques
if err != nil {
return nil, fmt.Errorf("cannot decode request: %s", err)
}
req.RequestBuffer, err = ioutil.ReadAll(flate.NewReader(bytes.NewReader(compressedRequest)))
req.RequestBuffer, err = ioutil.ReadAll(newSaferFlateReader(bytes.NewReader(compressedRequest)))
if err != nil {
return nil, fmt.Errorf("cannot decompress request: %s", err)
}
Expand Down
28 changes: 28 additions & 0 deletions identity_provider_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package saml

import (
"bytes"
"compress/flate"
"crypto"
"crypto/rsa"
"crypto/x509"
Expand Down Expand Up @@ -1013,3 +1015,29 @@ func TestIDPNoDestination(t *testing.T) {
err = req.MakeResponse()
assert.Check(t, err)
}

func TestIDPRejectDecompressionBomb(t *testing.T) {
test := NewIdentifyProviderTest(t)
test.IDP.SessionProvider = &mockSessionProvider{
GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session {
fmt.Fprintf(w, "RelayState: %s\nSAMLRequest: %s",
req.RelayState, req.RequestBuffer)
return nil
},
}

//w := httptest.NewRecorder()

data := bytes.Repeat([]byte("a"), 768*1024*1024)
var compressed bytes.Buffer
w, _ := flate.NewWriter(&compressed, flate.BestCompression)
w.Write(data)
w.Close()
encoded := base64.StdEncoding.EncodeToString(compressed.Bytes())

r, _ := http.NewRequest("GET", "/dontcare?"+url.Values{
"SAMLRequest": {encoded},
}.Encode(), nil)
_, err := NewIdpAuthnRequest(&test.IDP, r)
assert.Error(t, err, "cannot decompress request: flate: uncompress limit exceeded (10485760 bytes)")
}
2 changes: 1 addition & 1 deletion service_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -1524,7 +1524,7 @@ func (sp *ServiceProvider) ValidateLogoutResponseRedirect(queryParameterData str
}
retErr.Response = string(rawResponseBuf)

gr, err := ioutil.ReadAll(flate.NewReader(bytes.NewBuffer(rawResponseBuf)))
gr, err := ioutil.ReadAll(newSaferFlateReader(bytes.NewBuffer(rawResponseBuf)))
if err != nil {
retErr.PrivateErr = err
return retErr
Expand Down

0 comments on commit 8e92368

Please sign in to comment.