Skip to content

Commit

Permalink
openapi3: refacto ref-resolving end conditions (#874)
Browse files Browse the repository at this point in the history
  • Loading branch information
fenollp committed Nov 26, 2023
1 parent 377bb40 commit 663b0dd
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 19 deletions.
81 changes: 62 additions & 19 deletions openapi3/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -491,9 +491,22 @@ func (loader *Loader) resolveRef(doc *T, ref string, path *url.URL) (*T, string,
return doc, fragment, resolvedPath, nil
}

var (
errMUSTCallback = errors.New("invalid callback: value MUST be an object")
errMUSTExample = errors.New("invalid example: value MUST be an object")
errMUSTHeader = errors.New("invalid header: value MUST be an object")
errMUSTLink = errors.New("invalid link: value MUST be an object")
errMUSTParameter = errors.New("invalid parameter: value MUST be an object")
errMUSTPathItem = errors.New("invalid path item: value MUST be an object")
errMUSTRequestBody = errors.New("invalid requestBody: value MUST be an object")
errMUSTResponse = errors.New("invalid response: value MUST be an object")
errMUSTSchema = errors.New("invalid schema: value MUST be an object")
errMUSTSecurityScheme = errors.New("invalid securityScheme: value MUST be an object")
)

func (loader *Loader) resolveHeaderRef(doc *T, component *HeaderRef, documentPath *url.URL) (err error) {
if component == nil {
return errors.New("invalid header: value MUST be an object")
if component.isEmpty() {
return errMUSTHeader
}

if component.Value != nil {
Expand All @@ -520,6 +533,9 @@ func (loader *Loader) resolveHeaderRef(doc *T, component *HeaderRef, documentPat
return err
}
if err := loader.resolveHeaderRef(doc, &resolved, componentPath); err != nil {
if err == errMUSTHeader {
return nil
}
return err
}
component.Value = resolved.Value
Expand All @@ -539,8 +555,8 @@ func (loader *Loader) resolveHeaderRef(doc *T, component *HeaderRef, documentPat
}

func (loader *Loader) resolveParameterRef(doc *T, component *ParameterRef, documentPath *url.URL) (err error) {
if component == nil {
return errors.New("invalid parameter: value MUST be an object")
if component.isEmpty() {
return errMUSTParameter
}

if component.Value != nil {
Expand All @@ -567,6 +583,9 @@ func (loader *Loader) resolveParameterRef(doc *T, component *ParameterRef, docum
return err
}
if err := loader.resolveParameterRef(doc, &resolved, componentPath); err != nil {
if err == errMUSTParameter {
return nil
}
return err
}
component.Value = resolved.Value
Expand Down Expand Up @@ -596,8 +615,8 @@ func (loader *Loader) resolveParameterRef(doc *T, component *ParameterRef, docum
}

func (loader *Loader) resolveRequestBodyRef(doc *T, component *RequestBodyRef, documentPath *url.URL) (err error) {
if component == nil {
return errors.New("invalid requestBody: value MUST be an object")
if component.isEmpty() {
return errMUSTRequestBody
}

if component.Value != nil {
Expand All @@ -624,6 +643,9 @@ func (loader *Loader) resolveRequestBodyRef(doc *T, component *RequestBodyRef, d
return err
}
if err = loader.resolveRequestBodyRef(doc, &resolved, componentPath); err != nil {
if err == errMUSTRequestBody {
return nil
}
return err
}
component.Value = resolved.Value
Expand Down Expand Up @@ -660,8 +682,8 @@ func (loader *Loader) resolveRequestBodyRef(doc *T, component *RequestBodyRef, d
}

func (loader *Loader) resolveResponseRef(doc *T, component *ResponseRef, documentPath *url.URL) (err error) {
if component == nil {
return errors.New("invalid response: value MUST be an object")
if component.isEmpty() {
return errMUSTResponse
}

if component.Value != nil {
Expand All @@ -688,6 +710,9 @@ func (loader *Loader) resolveResponseRef(doc *T, component *ResponseRef, documen
return err
}
if err := loader.resolveResponseRef(doc, &resolved, componentPath); err != nil {
if err == errMUSTResponse {
return nil
}
return err
}
component.Value = resolved.Value
Expand Down Expand Up @@ -735,8 +760,8 @@ func (loader *Loader) resolveResponseRef(doc *T, component *ResponseRef, documen
}

func (loader *Loader) resolveSchemaRef(doc *T, component *SchemaRef, documentPath *url.URL, visited []string) (err error) {
if component == nil {
return errors.New("invalid schema: value MUST be an object")
if component.isEmpty() {
return errMUSTSchema
}

if component.Value != nil {
Expand Down Expand Up @@ -769,6 +794,9 @@ func (loader *Loader) resolveSchemaRef(doc *T, component *SchemaRef, documentPat
return err
}
if err := loader.resolveSchemaRef(doc, &resolved, componentPath, visited); err != nil {
if err == errMUSTSchema {
return nil
}
return err
}
component.Value = resolved.Value
Expand Down Expand Up @@ -823,8 +851,8 @@ func (loader *Loader) resolveSchemaRef(doc *T, component *SchemaRef, documentPat
}

func (loader *Loader) resolveSecuritySchemeRef(doc *T, component *SecuritySchemeRef, documentPath *url.URL) (err error) {
if component == nil {
return errors.New("invalid securityScheme: value MUST be an object")
if component.isEmpty() {
return errMUSTSecurityScheme
}

if component.Value != nil {
Expand All @@ -851,6 +879,9 @@ func (loader *Loader) resolveSecuritySchemeRef(doc *T, component *SecurityScheme
return err
}
if err := loader.resolveSecuritySchemeRef(doc, &resolved, componentPath); err != nil {
if err == errMUSTSecurityScheme {
return nil
}
return err
}
component.Value = resolved.Value
Expand All @@ -860,8 +891,8 @@ func (loader *Loader) resolveSecuritySchemeRef(doc *T, component *SecurityScheme
}

func (loader *Loader) resolveExampleRef(doc *T, component *ExampleRef, documentPath *url.URL) (err error) {
if component == nil {
return errors.New("invalid example: value MUST be an object")
if component.isEmpty() {
return errMUSTExample
}

if component.Value != nil {
Expand All @@ -888,6 +919,9 @@ func (loader *Loader) resolveExampleRef(doc *T, component *ExampleRef, documentP
return err
}
if err := loader.resolveExampleRef(doc, &resolved, componentPath); err != nil {
if err == errMUSTExample {
return nil
}
return err
}
component.Value = resolved.Value
Expand All @@ -897,8 +931,8 @@ func (loader *Loader) resolveExampleRef(doc *T, component *ExampleRef, documentP
}

func (loader *Loader) resolveCallbackRef(doc *T, component *CallbackRef, documentPath *url.URL) (err error) {
if component == nil {
return errors.New("invalid callback: value MUST be an object")
if component.isEmpty() {
return errMUSTCallback
}

if component.Value != nil {
Expand All @@ -925,6 +959,9 @@ func (loader *Loader) resolveCallbackRef(doc *T, component *CallbackRef, documen
return err
}
if err = loader.resolveCallbackRef(doc, &resolved, componentPath); err != nil {
if err == errMUSTCallback {
return nil
}
return err
}
component.Value = resolved.Value
Expand All @@ -944,8 +981,8 @@ func (loader *Loader) resolveCallbackRef(doc *T, component *CallbackRef, documen
}

func (loader *Loader) resolveLinkRef(doc *T, component *LinkRef, documentPath *url.URL) (err error) {
if component == nil {
return errors.New("invalid link: value MUST be an object")
if component.isEmpty() {
return errMUSTLink
}

if component.Value != nil {
Expand All @@ -972,6 +1009,9 @@ func (loader *Loader) resolveLinkRef(doc *T, component *LinkRef, documentPath *u
return err
}
if err := loader.resolveLinkRef(doc, &resolved, componentPath); err != nil {
if err == errMUSTLink {
return nil
}
return err
}
component.Value = resolved.Value
Expand All @@ -982,7 +1022,7 @@ func (loader *Loader) resolveLinkRef(doc *T, component *LinkRef, documentPath *u

func (loader *Loader) resolvePathItemRef(doc *T, pathItem *PathItem, documentPath *url.URL) (err error) {
if pathItem == nil {
err = errors.New("invalid path item: value MUST be an object")
err = errMUSTPathItem
return
}

Expand All @@ -999,6 +1039,9 @@ func (loader *Loader) resolvePathItemRef(doc *T, pathItem *PathItem, documentPat
} else {
var resolved PathItem
if doc, documentPath, err = loader.resolveComponent(doc, ref, documentPath, &resolved); err != nil {
if err == errMUSTPathItem {
return nil
}
return
}
*pathItem = resolved
Expand Down
18 changes: 18 additions & 0 deletions openapi3/refs.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ type CallbackRef struct {

var _ jsonpointer.JSONPointable = (*CallbackRef)(nil)

func (x *CallbackRef) isEmpty() bool { return x == nil || x.Ref == "" && x.Value == nil }

// MarshalYAML returns the YAML encoding of CallbackRef.
func (x CallbackRef) MarshalYAML() (interface{}, error) {
if ref := x.Ref; ref != "" {
Expand Down Expand Up @@ -96,6 +98,8 @@ type ExampleRef struct {

var _ jsonpointer.JSONPointable = (*ExampleRef)(nil)

func (x *ExampleRef) isEmpty() bool { return x == nil || x.Ref == "" && x.Value == nil }

// MarshalYAML returns the YAML encoding of ExampleRef.
func (x ExampleRef) MarshalYAML() (interface{}, error) {
if ref := x.Ref; ref != "" {
Expand Down Expand Up @@ -172,6 +176,8 @@ type HeaderRef struct {

var _ jsonpointer.JSONPointable = (*HeaderRef)(nil)

func (x *HeaderRef) isEmpty() bool { return x == nil || x.Ref == "" && x.Value == nil }

// MarshalYAML returns the YAML encoding of HeaderRef.
func (x HeaderRef) MarshalYAML() (interface{}, error) {
if ref := x.Ref; ref != "" {
Expand Down Expand Up @@ -248,6 +254,8 @@ type LinkRef struct {

var _ jsonpointer.JSONPointable = (*LinkRef)(nil)

func (x *LinkRef) isEmpty() bool { return x == nil || x.Ref == "" && x.Value == nil }

// MarshalYAML returns the YAML encoding of LinkRef.
func (x LinkRef) MarshalYAML() (interface{}, error) {
if ref := x.Ref; ref != "" {
Expand Down Expand Up @@ -324,6 +332,8 @@ type ParameterRef struct {

var _ jsonpointer.JSONPointable = (*ParameterRef)(nil)

func (x *ParameterRef) isEmpty() bool { return x == nil || x.Ref == "" && x.Value == nil }

// MarshalYAML returns the YAML encoding of ParameterRef.
func (x ParameterRef) MarshalYAML() (interface{}, error) {
if ref := x.Ref; ref != "" {
Expand Down Expand Up @@ -400,6 +410,8 @@ type RequestBodyRef struct {

var _ jsonpointer.JSONPointable = (*RequestBodyRef)(nil)

func (x *RequestBodyRef) isEmpty() bool { return x == nil || x.Ref == "" && x.Value == nil }

// MarshalYAML returns the YAML encoding of RequestBodyRef.
func (x RequestBodyRef) MarshalYAML() (interface{}, error) {
if ref := x.Ref; ref != "" {
Expand Down Expand Up @@ -476,6 +488,8 @@ type ResponseRef struct {

var _ jsonpointer.JSONPointable = (*ResponseRef)(nil)

func (x *ResponseRef) isEmpty() bool { return x == nil || x.Ref == "" && x.Value == nil }

// MarshalYAML returns the YAML encoding of ResponseRef.
func (x ResponseRef) MarshalYAML() (interface{}, error) {
if ref := x.Ref; ref != "" {
Expand Down Expand Up @@ -552,6 +566,8 @@ type SchemaRef struct {

var _ jsonpointer.JSONPointable = (*SchemaRef)(nil)

func (x *SchemaRef) isEmpty() bool { return x == nil || x.Ref == "" && x.Value == nil }

// MarshalYAML returns the YAML encoding of SchemaRef.
func (x SchemaRef) MarshalYAML() (interface{}, error) {
if ref := x.Ref; ref != "" {
Expand Down Expand Up @@ -628,6 +644,8 @@ type SecuritySchemeRef struct {

var _ jsonpointer.JSONPointable = (*SecuritySchemeRef)(nil)

func (x *SecuritySchemeRef) isEmpty() bool { return x == nil || x.Ref == "" && x.Value == nil }

// MarshalYAML returns the YAML encoding of SecuritySchemeRef.
func (x SecuritySchemeRef) MarshalYAML() (interface{}, error) {
if ref := x.Ref; ref != "" {
Expand Down
2 changes: 2 additions & 0 deletions refs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ type ${type}Ref struct {
var _ jsonpointer.JSONPointable = (*${type}Ref)(nil)
func (x *${type}Ref) isEmpty() bool { return x == nil || x.Ref == "" && x.Value == nil }
// MarshalYAML returns the YAML encoding of ${type}Ref.
func (x ${type}Ref) MarshalYAML() (interface{}, error) {
if ref := x.Ref; ref != "" {
Expand Down

0 comments on commit 663b0dd

Please sign in to comment.