Skip to content

Commit

Permalink
fix(property-provider): avoid generating default rejected promise whe…
Browse files Browse the repository at this point in the history
…n chaining (#4843)
  • Loading branch information
trivikr committed Jun 14, 2023
1 parent 8668bab commit ecc9b5f
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 34 deletions.
115 changes: 92 additions & 23 deletions packages/property-provider/src/chain.spec.ts
Original file line number Diff line number Diff line change
@@ -1,48 +1,117 @@
import { chain } from "./chain";
import { fromStatic } from "./fromStatic";
import { ProviderError } from "./ProviderError";

const resolveStatic = (staticValue: unknown) => jest.fn().mockResolvedValue(staticValue);
const rejectWithError = (errorMsg: string) => jest.fn().mockRejectedValue(new Error(errorMsg));
const rejectWithProviderError = (errorMsg: string) => jest.fn().mockRejectedValue(new ProviderError(errorMsg));

describe("chain", () => {
it("should distill many credential providers into one", async () => {
const provider = chain(fromStatic("foo"), fromStatic("bar"));

const provider = chain(resolveStatic("foo"), resolveStatic("bar"));
expect(typeof (await provider())).toBe("string");
});

it("should return the resolved value of the first successful promise", async () => {
const provider = chain(
() => Promise.reject(new ProviderError("Move along")),
() => Promise.reject(new ProviderError("Nothing to see here")),
fromStatic("foo")
);
const expectedOutput = "foo";
const providers = [
rejectWithProviderError("Move along"),
rejectWithProviderError("Nothing to see here"),
resolveStatic(expectedOutput),
];

expect(await provider()).toBe("foo");
try {
const result = await chain(...providers)();
expect(result).toBe(expectedOutput);
} catch (error) {
throw error;
}

expect(providers[0]).toHaveBeenCalledTimes(1);
expect(providers[1]).toHaveBeenCalledTimes(1);
expect(providers[2]).toHaveBeenCalledTimes(1);
});

it("should not invoke subsequent providers once one resolves", async () => {
const expectedOutput = "foo";
const providers = [
jest.fn().mockRejectedValue(new ProviderError("Move along")),
jest.fn().mockResolvedValue("foo"),
jest.fn(() => fail("This provider should not be invoked")),
rejectWithProviderError("Move along"),
resolveStatic(expectedOutput),
rejectWithProviderError("This provider should not be invoked"),
];

expect(await chain(...providers)()).toBe("foo");
expect(providers[0].mock.calls.length).toBe(1);
expect(providers[1].mock.calls.length).toBe(1);
expect(providers[2].mock.calls.length).toBe(0);
try {
const result = await chain(...providers)();
expect(result).toBe(expectedOutput);
} catch (error) {
throw error;
}

expect(providers[0]).toHaveBeenCalledTimes(1);
expect(providers[1]).toHaveBeenCalledTimes(1);
expect(providers[2]).not.toHaveBeenCalled();
});

describe("should throw if no provider resolves", () => {
const expectedErrorMsg = "Last provider failed";

it.each([
[ProviderError, rejectWithProviderError(expectedErrorMsg)],
[Error, rejectWithError(expectedErrorMsg)],
])("case %p", async (errorType, errorProviderMockFn) => {
const firstProviderWhichRejects = rejectWithProviderError("Move along");
try {
await chain(firstProviderWhichRejects, errorProviderMockFn)();
throw new Error("Should not get here");
} catch (error) {
expect(error).toEqual(new errorType(expectedErrorMsg));
}
expect(firstProviderWhichRejects).toHaveBeenCalledTimes(1);
expect(errorProviderMockFn).toHaveBeenCalledTimes(1);
});
});

it("should halt if an unrecognized error is encountered", async () => {
const provider = chain(
() => Promise.reject(new ProviderError("Move along")),
() => Promise.reject(new Error("Unrelated failure")),
fromStatic("foo")
);
const expectedErrorMsg = "Unrelated failure";
const providers = [rejectWithProviderError("Move along"), rejectWithError(expectedErrorMsg), resolveStatic("foo")];

try {
await chain(...providers)();
throw new Error("Should not get here");
} catch (error) {
expect(error).toEqual(new Error(expectedErrorMsg));
}

expect(providers[0]).toHaveBeenCalledTimes(1);
expect(providers[1]).toHaveBeenCalledTimes(1);
expect(providers[2]).not.toHaveBeenCalled();
});

it("should halt if ProviderError explicitly requests it", async () => {
const expectedError = new ProviderError("ProviderError with tryNextLink set to false", false);
const providers = [
rejectWithProviderError("Move along"),
jest.fn().mockRejectedValue(expectedError),
resolveStatic("foo"),
];

try {
await chain(...providers)();
throw new Error("Should not get here");
} catch (error) {
expect(error).toEqual(expectedError);
}

await expect(provider()).rejects.toMatchObject(new Error("Unrelated failure"));
expect(providers[0]).toHaveBeenCalledTimes(1);
expect(providers[1]).toHaveBeenCalledTimes(1);
expect(providers[2]).not.toHaveBeenCalled();
});

it("should reject chains with no links", async () => {
await expect(chain()()).rejects.toMatchObject(new Error("No providers in chain"));
try {
await chain()();
throw new Error("Should not get here");
} catch (error) {
expect(error).toEqual(new Error("No providers in chain"));
}
});
});
28 changes: 17 additions & 11 deletions packages/property-provider/src/chain.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { ProviderError } from "./ProviderError";

/**
* @internal
*
*
* Compose a single credential provider function from multiple credential
* providers. The first provider in the argument list will always be invoked;
* subsequent providers in the list will be invoked in the order in which the
Expand All @@ -13,19 +13,25 @@ import { ProviderError } from "./ProviderError";
* If no providers were received or no provider resolves successfully, the
* returned promise will be rejected.
*/
export function chain<T>(...providers: Array<Provider<T>>): Provider<T> {
return () => {
let promise: Promise<T> = Promise.reject(new ProviderError("No providers in chain"));
export const chain =
<T>(...providers: Array<Provider<T>>): Provider<T> =>
async () => {
if (providers.length === 0) {
throw new ProviderError("No providers in chain");
}

let lastProviderError: Error | undefined;
for (const provider of providers) {
promise = promise.catch((err: any) => {
try {
const credentials = await provider();
return credentials;
} catch (err) {
lastProviderError = err;
if (err?.tryNextLink) {
return provider();
continue;
}

throw err;
});
}
}

return promise;
throw lastProviderError;
};
}

0 comments on commit ecc9b5f

Please sign in to comment.