Skip to content

Commit

Permalink
fix(NODE-5993): connection's aborted promise leak
Browse files Browse the repository at this point in the history
  • Loading branch information
nbbeeken committed Mar 6, 2024
1 parent 4ac9675 commit a8b3540
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 55 deletions.
74 changes: 34 additions & 40 deletions src/cmap/connection.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import { type Readable, Transform, type TransformCallback } from 'stream';
import { clearTimeout, setTimeout } from 'timers';
import { promisify } from 'util';

import type { BSONSerializeOptions, Document, ObjectId } from '../bson';
import type { AutoEncrypter } from '../client-side-encryption/auto_encrypter';
Expand Down Expand Up @@ -180,18 +179,18 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {
* Once connection is established, command logging can log events (if enabled)
*/
public established: boolean;
/** Indicates that the connection (including underlying TCP socket) has been closed. */
public closed = false;

private lastUseTime: number;
private clusterTime: Document | null = null;
private error: Error | null = null;
private dataEvents: AsyncGenerator<Buffer, void, void> | null = null;

private readonly socketTimeoutMS: number;
private readonly monitorCommands: boolean;
private readonly socket: Stream;
private readonly controller: AbortController;
private readonly signal: AbortSignal;
private readonly messageStream: Readable;
private readonly socketWrite: (buffer: Uint8Array) => Promise<void>;
private readonly aborted: Promise<never>;

/** @event */
static readonly COMMAND_STARTED = COMMAND_STARTED;
Expand All @@ -211,6 +210,7 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {
constructor(stream: Stream, options: ConnectionOptions) {
super();

this.socket = stream;
this.id = options.id;
this.address = streamIdentifier(stream, options);
this.socketTimeoutMS = options.socketTimeoutMS ?? 0;
Expand All @@ -223,39 +223,12 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {
this.generation = options.generation;
this.lastUseTime = now();

this.socket = stream;

// TODO: Remove signal from connection layer
this.controller = new AbortController();
const { signal } = this.controller;
this.signal = signal;
const { promise: aborted, reject } = promiseWithResolvers<never>();
aborted.then(undefined, () => null); // Prevent unhandled rejection
this.signal.addEventListener(
'abort',
function onAbort() {
reject(signal.reason);
},
{ once: true }
);
this.aborted = aborted;

this.messageStream = this.socket
.on('error', this.onError.bind(this))
.pipe(new SizedMessageTransform({ connection: this }))
.on('error', this.onError.bind(this));
this.socket.on('close', this.onClose.bind(this));
this.socket.on('timeout', this.onTimeout.bind(this));

const socketWrite = promisify(this.socket.write.bind(this.socket));
this.socketWrite = async buffer => {
return Promise.race([socketWrite(buffer), this.aborted]);
};
}

/** Indicates that the connection (including underlying TCP socket) has been closed. */
public get closed(): boolean {
return this.signal.aborted;
}

public get hello() {
Expand Down Expand Up @@ -355,7 +328,11 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {
}

this.socket.destroy();
this.controller.abort(error);
if (error) {
this.error = error;
this.dataEvents?.throw(error).then(undefined, () => null); // squash unhandled rejection
}
this.closed = true;
this.emit(Connection.CLOSE);
}

Expand Down Expand Up @@ -596,7 +573,7 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {
}

private throwIfAborted() {
this.signal.throwIfAborted();
if (this.error) throw this.error;
}

/**
Expand All @@ -619,7 +596,18 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {

const buffer = Buffer.concat(await finalCommand.toBin());

return this.socketWrite(buffer);
if (this.socket.write(buffer)) return;

const { promise: drained, resolve, reject } = promiseWithResolvers<void>();
const onDrain = () => resolve();
const onError = (error: Error) => reject(error);

this.socket.once('drain', onDrain).once('error', onError);
try {
return await drained;
} finally {
this.socket.off('drain', onDrain).off('error', onError);
}
}

/**
Expand All @@ -632,13 +620,19 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {
* Note that `for-await` loops call `return` automatically when the loop is exited.
*/
private async *readMany(): AsyncGenerator<OpMsgResponse | OpQueryResponse> {
for await (const message of onData(this.messageStream, { signal: this.signal })) {
const response = await decompressResponse(message);
yield response;
try {
this.dataEvents = this.dataEvents = onData(this.messageStream);
for await (const message of this.dataEvents) {
const response = await decompressResponse(message);
yield response;

if (!response.moreToCome) {
return;
if (!response.moreToCome) {
return;
}
}
} finally {
this.dataEvents = null;
this.throwIfAborted();
}
}
}
Expand Down
16 changes: 1 addition & 15 deletions src/cmap/wire_protocol/on_data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@ type PendingPromises = Omit<
* Returns an AsyncIterator that iterates each 'data' event emitted from emitter.
* It will reject upon an error event or if the provided signal is aborted.
*/
export function onData(emitter: EventEmitter, options: { signal: AbortSignal }) {
const signal = options.signal;

export function onData(emitter: EventEmitter) {
// Setup pending events and pending promise lists
/**
* When the caller has not yet called .next(), we store the
Expand Down Expand Up @@ -89,19 +87,8 @@ export function onData(emitter: EventEmitter, options: { signal: AbortSignal })
emitter.on('data', eventHandler);
emitter.on('error', errorHandler);

if (signal.aborted) {
// If the signal is aborted, set up the first .next() call to be a rejection
queueMicrotask(abortListener);
} else {
signal.addEventListener('abort', abortListener, { once: true });
}

return iterator;

function abortListener() {
errorHandler(signal.reason);
}

function eventHandler(value: Buffer) {
const promise = unconsumedPromises.shift();
if (promise != null) promise.resolve({ value, done: false });
Expand All @@ -119,7 +106,6 @@ export function onData(emitter: EventEmitter, options: { signal: AbortSignal })
// Adding event handlers
emitter.off('data', eventHandler);
emitter.off('error', errorHandler);
signal.removeEventListener('abort', abortListener);
finished = true;
const doneResult = { value: undefined, done: finished } as const;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import type * as net from 'node:net';

import { expect } from 'chai';
import { type EventEmitter, once } from 'events';
import * as sinon from 'sinon';

import {
Binary,
connect,
Connection,
type ConnectionOptions,
HostAddress,
isHello,
LEGACY_HELLO_COMMAND,
makeClientMetadata,
MongoClient,
Expand All @@ -14,7 +20,9 @@ import {
ServerHeartbeatStartedEvent,
Topology
} from '../../mongodb';
import * as mock from '../../tools/mongodb-mock/index';
import { skipBrokenAuthTestBeforeEachHook } from '../../tools/runner/hooks/configuration';
import { getSymbolFrom, sleep } from '../../tools/utils';
import { assert as test, setupDatabase } from '../shared';

const commonConnectOptions = {
Expand Down Expand Up @@ -197,6 +205,80 @@ describe('Connection', function () {
client.connect();
});

context(
'when a large message is written to the socket',
{ requires: { topology: 'single' } },
() => {
let client, mockServer: import('../../tools/mongodb-mock/src/server').MockServer;

beforeEach(async function () {
mockServer = await mock.createServer();

mockServer
.addMessageHandler('insert', req => {
setTimeout(() => {
req.reply({ ok: 1 });
}, 800);
})
.addMessageHandler('hello', req => {
req.reply(Object.assign({}, mock.HELLO));
})
.addMessageHandler(LEGACY_HELLO_COMMAND, req => {
req.reply(Object.assign({}, mock.HELLO));
});

client = new MongoClient(`mongodb://${mockServer.uri()}`, {
minPoolSize: 1,
maxPoolSize: 1
});
});

afterEach(async function () {
await client.close();
mockServer.destroy();
sinon.restore();
});

it('waits for an async drain event because the write was buffered', async () => {
const connectionReady = once(client, 'connectionReady');
await client.connect();
await connectionReady;

// Get the only connection
const pool = [...client.topology.s.servers.values()][0].pool;
const connection = pool[getSymbolFrom(pool, 'connections')].first();
const socket: EventEmitter = connection.socket;

// Spy on the socket event listeners
const addedListeners: string[] = [];
const removedListeners: string[] = [];
socket
.on('removeListener', name => removedListeners.push(name))
.on('newListener', name => addedListeners.push(name));

// Make server sockets block
for (const s of mockServer.sockets) s.pause();

const insert = client
.db('test')
.collection('test')
// Anything above 16Kb should work I think (10mb to be extra sure)
.insertOne({ a: new Binary(Buffer.alloc(10 * (2 ** 10) ** 2), 127) });

// Sleep a bit and unblock server sockets
await sleep(10);
for (const s of mockServer.sockets) s.resume();

// Let the operation finish
await insert;

// Ensure that we used the drain event for this write
expect(addedListeners).to.deep.equal(['drain', 'error']);
expect(removedListeners).to.deep.equal(['drain', 'error']);
});
}
);

context('when connecting with a username and password', () => {
let utilClient: MongoClient;
let client: MongoClient;
Expand Down
33 changes: 33 additions & 0 deletions test/integration/node-specific/resource_clean_up.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import * as v8 from 'node:v8';

import { expect } from 'chai';

import { sleep } from '../../tools/utils';
import { runScript } from './resource_tracking_script_builder';

/**
Expand Down Expand Up @@ -86,4 +89,34 @@ describe('Driver Resources', () => {
});
});
});

context('when 100s of operations are executed and complete', () => {
beforeEach(function () {
if (this.currentTest && typeof v8.queryObjects !== 'function') {
this.currentTest.skipReason = 'Test requires v8.queryObjects API to count Promises';
this.currentTest?.skip();
}
});

let client;
beforeEach(async function () {
client = this.configuration.newClient();
});

afterEach(async function () {
await client.close();
});

it('does not leave behind additional promises', async () => {
const test = client.db('test').collection('test');
const promiseCountBefore = v8.queryObjects(Promise, { format: 'count' });
for (let i = 0; i < 100; i++) {
await test.findOne();
}
await sleep(10);
const promiseCountAfter = v8.queryObjects(Promise, { format: 'count' });

expect(promiseCountAfter).to.be.within(promiseCountBefore - 5, promiseCountBefore + 5);
});
});
});

0 comments on commit a8b3540

Please sign in to comment.