Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(property-provider): avoid generating default rejected promise when chaining #4843

Merged
merged 11 commits into from
Jun 14, 2023
115 changes: 92 additions & 23 deletions packages/property-provider/src/chain.spec.ts
Original file line number Diff line number Diff line change
@@ -1,48 +1,117 @@
import { chain } from "./chain";
import { fromStatic } from "./fromStatic";
import { ProviderError } from "./ProviderError";

const resolveStatic = (staticValue: unknown) => jest.fn().mockResolvedValue(staticValue);
const rejectWithError = (errorMsg: string) => jest.fn().mockRejectedValue(new Error(errorMsg));
const rejectWithProviderError = (errorMsg: string) => jest.fn().mockRejectedValue(new ProviderError(errorMsg));

describe("chain", () => {
it("should distill many credential providers into one", async () => {
const provider = chain(fromStatic("foo"), fromStatic("bar"));

const provider = chain(resolveStatic("foo"), resolveStatic("bar"));
expect(typeof (await provider())).toBe("string");
});

it("should return the resolved value of the first successful promise", async () => {
const provider = chain(
() => Promise.reject(new ProviderError("Move along")),
() => Promise.reject(new ProviderError("Nothing to see here")),
fromStatic("foo")
);
const expectedOutput = "foo";
const providers = [
rejectWithProviderError("Move along"),
rejectWithProviderError("Nothing to see here"),
resolveStatic(expectedOutput),
];

expect(await provider()).toBe("foo");
try {
const result = await chain(...providers)();
expect(result).toBe(expectedOutput);
} catch (error) {
throw error;
}

expect(providers[0]).toHaveBeenCalledTimes(1);
expect(providers[1]).toHaveBeenCalledTimes(1);
expect(providers[2]).toHaveBeenCalledTimes(1);
});

it("should not invoke subsequent providers once one resolves", async () => {
const expectedOutput = "foo";
const providers = [
jest.fn().mockRejectedValue(new ProviderError("Move along")),
jest.fn().mockResolvedValue("foo"),
jest.fn(() => fail("This provider should not be invoked")),
rejectWithProviderError("Move along"),
resolveStatic(expectedOutput),
rejectWithProviderError("This provider should not be invoked"),
];

expect(await chain(...providers)()).toBe("foo");
expect(providers[0].mock.calls.length).toBe(1);
expect(providers[1].mock.calls.length).toBe(1);
expect(providers[2].mock.calls.length).toBe(0);
try {
const result = await chain(...providers)();
expect(result).toBe(expectedOutput);
} catch (error) {
throw error;
}

expect(providers[0]).toHaveBeenCalledTimes(1);
expect(providers[1]).toHaveBeenCalledTimes(1);
expect(providers[2]).not.toHaveBeenCalled();
});

describe("should throw if no provider resolves", () => {
const expectedErrorMsg = "Last provider failed";

it.each([
[ProviderError, rejectWithProviderError(expectedErrorMsg)],
[Error, rejectWithError(expectedErrorMsg)],
])("case %p", async (errorType, errorProviderMockFn) => {
const firstProviderWhichRejects = rejectWithProviderError("Move along");
try {
await chain(firstProviderWhichRejects, errorProviderMockFn)();
throw new Error("Should not get here");
} catch (error) {
expect(error).toEqual(new errorType(expectedErrorMsg));
}
expect(firstProviderWhichRejects).toHaveBeenCalledTimes(1);
expect(errorProviderMockFn).toHaveBeenCalledTimes(1);
});
});

it("should halt if an unrecognized error is encountered", async () => {
const provider = chain(
() => Promise.reject(new ProviderError("Move along")),
() => Promise.reject(new Error("Unrelated failure")),
fromStatic("foo")
);
const expectedErrorMsg = "Unrelated failure";
const providers = [rejectWithProviderError("Move along"), rejectWithError(expectedErrorMsg), resolveStatic("foo")];

try {
await chain(...providers)();
throw new Error("Should not get here");
} catch (error) {
expect(error).toEqual(new Error(expectedErrorMsg));
}

expect(providers[0]).toHaveBeenCalledTimes(1);
expect(providers[1]).toHaveBeenCalledTimes(1);
expect(providers[2]).not.toHaveBeenCalled();
});

it("should halt if ProviderError explicitly requests it", async () => {
const expectedError = new ProviderError("ProviderError with tryNextLink set to false", false);
const providers = [
rejectWithProviderError("Move along"),
jest.fn().mockRejectedValue(expectedError),
resolveStatic("foo"),
];

try {
await chain(...providers)();
throw new Error("Should not get here");
} catch (error) {
expect(error).toEqual(expectedError);
}

await expect(provider()).rejects.toMatchObject(new Error("Unrelated failure"));
expect(providers[0]).toHaveBeenCalledTimes(1);
expect(providers[1]).toHaveBeenCalledTimes(1);
expect(providers[2]).not.toHaveBeenCalled();
});

it("should reject chains with no links", async () => {
await expect(chain()()).rejects.toMatchObject(new Error("No providers in chain"));
try {
await chain()();
throw new Error("Should not get here");
} catch (error) {
expect(error).toEqual(new Error("No providers in chain"));
}
});
});
28 changes: 17 additions & 11 deletions packages/property-provider/src/chain.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { ProviderError } from "./ProviderError";

/**
* @internal
*
*
* Compose a single credential provider function from multiple credential
* providers. The first provider in the argument list will always be invoked;
* subsequent providers in the list will be invoked in the order in which the
Expand All @@ -13,19 +13,25 @@ import { ProviderError } from "./ProviderError";
* If no providers were received or no provider resolves successfully, the
* returned promise will be rejected.
*/
export function chain<T>(...providers: Array<Provider<T>>): Provider<T> {
return () => {
let promise: Promise<T> = Promise.reject(new ProviderError("No providers in chain"));
export const chain =
<T>(...providers: Array<Provider<T>>): Provider<T> =>
async () => {
if (providers.length === 0) {
throw new ProviderError("No providers in chain");
}

let lastProviderError: Error | undefined;
for (const provider of providers) {
promise = promise.catch((err: any) => {
try {
const credentials = await provider();
return credentials;
} catch (err) {
lastProviderError = err;
if (err?.tryNextLink) {
return provider();
continue;
}

throw err;
});
}
}

return promise;
throw lastProviderError;
};
}