diff --git a/src/cmap/connection.ts b/src/cmap/connection.ts index b6d92f56e0c..bbe65a25103 100644 --- a/src/cmap/connection.ts +++ b/src/cmap/connection.ts @@ -27,6 +27,7 @@ import { MongoNetworkTimeoutError, MongoOperationTimeoutError, MongoParseError, + MongoRuntimeError, MongoServerError, MongoUnexpectedServerResponseError } from '../error'; @@ -791,22 +792,41 @@ export class SizedMessageTransform extends Transform { } this.bufferPool.append(chunk); - const sizeOfMessage = this.bufferPool.getInt32(); - if (sizeOfMessage == null) { - return callback(); - } + while (this.bufferPool.length) { + // While there are any bytes in the buffer - if (sizeOfMessage < 0) { - return callback(new MongoParseError(`Invalid message size: ${sizeOfMessage}, too small`)); - } + // Try to fetch a size from the top 4 bytes + const sizeOfMessage = this.bufferPool.getInt32(); + + if (sizeOfMessage == null) { + // Not even an int32 worth of data. Stop the loop, we need more chunks. + break; + } + + if (sizeOfMessage < 0) { + // The size in the message has a negative value, this is probably corruption, throw: + return callback(new MongoParseError(`Message size cannot be negative: ${sizeOfMessage}`)); + } - if (sizeOfMessage > this.bufferPool.length) { - return callback(); + if (sizeOfMessage > this.bufferPool.length) { + // We do not have enough bytes to make a sizeOfMessage chunk + break; + } + + // Add a message to the stream + const message = this.bufferPool.read(sizeOfMessage); + + if (!this.push(message)) { + // We only subscribe to data events so we should never get backpressure + // if we do, we do not have the handling for it. + return callback( + new MongoRuntimeError(`SizedMessageTransform does not support backpressure`) + ); + } } - const message = this.bufferPool.read(sizeOfMessage); - return callback(null, message); + callback(); } } diff --git a/test/integration/connection-monitoring-and-pooling/connection.test.ts b/test/integration/connection-monitoring-and-pooling/connection.test.ts index 4307ee32f21..0e5fd45323f 100644 --- a/test/integration/connection-monitoring-and-pooling/connection.test.ts +++ b/test/integration/connection-monitoring-and-pooling/connection.test.ts @@ -22,7 +22,7 @@ import { } from '../../mongodb'; import * as mock from '../../tools/mongodb-mock/index'; import { skipBrokenAuthTestBeforeEachHook } from '../../tools/runner/hooks/configuration'; -import { sleep } from '../../tools/utils'; +import { processTick, sleep } from '../../tools/utils'; import { assert as test, setupDatabase } from '../shared'; const commonConnectOptions = { @@ -249,6 +249,54 @@ describe('Connection', function () { client.connect(); }); + describe( + 'when a monitoring Connection receives many hellos in one chunk', + { requires: { topology: 'replicaset', mongodb: '>=4.4' } }, // need to be on a streaming hello version + function () { + let client: MongoClient; + + beforeEach(async function () { + // set heartbeatFrequencyMS just so we don't have to wait so long for a hello + client = this.configuration.newClient({}, { heartbeatFrequencyMS: 10 }); + }); + + afterEach(async function () { + await client.close(); + }); + + // In the future we may want to skip processing concatenated heartbeats. + // This test exists to prevent regression of processing many messages inside one chunk. + it( + 'processes all of them and emits heartbeats', + { requires: { topology: 'replicaset', mongodb: '>=4.4' } }, + async function () { + let hbSuccess = 0; + client.on('serverHeartbeatSucceeded', () => (hbSuccess += 1)); + expect(hbSuccess).to.equal(0); + + await client.db().command({ ping: 1 }); // start monitoring. + const monitor = [...client.topology.s.servers.values()][0].monitor; + + // @ts-expect-error: accessing private property + const messageStream = monitor.connection.messageStream; + // @ts-expect-error: accessing private property + const socket = monitor.connection.socket; + + const [hello] = (await once(messageStream, 'data')) as [Buffer]; + + const thousandHellos = Array.from({ length: 1000 }, () => [...hello]).flat(1); + + // pretend this came from the server + socket.emit('data', Buffer.from(thousandHellos)); + + // All of the hb will be emitted synchronously in the next tick as the entire chunk is processed. + await processTick(); + expect(hbSuccess).to.be.greaterThan(1000); + } + ); + } + ); + context( 'when a large message is written to the socket', { requires: { topology: 'single', auth: 'disabled' } }, diff --git a/test/unit/cmap/connection.test.ts b/test/unit/cmap/connection.test.ts index aa3e86e2dc6..79fe9ea863d 100644 --- a/test/unit/cmap/connection.test.ts +++ b/test/unit/cmap/connection.test.ts @@ -1,4 +1,5 @@ import { Socket } from 'node:net'; +import { Writable } from 'node:stream'; import { expect } from 'chai'; import * as sinon from 'sinon'; @@ -11,7 +12,9 @@ import { MongoClientAuthProviders, MongoDBCollectionNamespace, MongoNetworkTimeoutError, + MongoRuntimeError, ns, + promiseWithResolvers, SizedMessageTransform } from '../../mongodb'; import * as mock from '../../tools/mongodb-mock/index'; @@ -333,5 +336,76 @@ describe('new Connection()', function () { expect(stream.read(1)).to.deep.equal(Buffer.from([6, 0, 0, 0, 5, 6])); expect(stream.read(1)).to.equal(null); }); + + it('parses many wire messages when a single chunk arrives', function () { + const stream = new SizedMessageTransform({ connection: {} as any }); + + let dataCount = 0; + stream.on('data', chunk => { + expect(chunk).to.have.lengthOf(8); + dataCount += 1; + }); + + // 3 messages of size 8 + stream.write( + Buffer.from([ + ...[8, 0, 0, 0, 0, 0, 0, 0], + ...[8, 0, 0, 0, 0, 0, 0, 0], + ...[8, 0, 0, 0, 0, 0, 0, 0] + ]) + ); + + expect(dataCount).to.equal(3); + }); + + it('parses many wire messages when a single chunk arrives and processes the remaining partial when it is complete', function () { + const stream = new SizedMessageTransform({ connection: {} as any }); + + let dataCount = 0; + stream.on('data', chunk => { + expect(chunk).to.have.lengthOf(8); + dataCount += 1; + }); + + // 3 messages of size 8 + stream.write( + Buffer.from([ + ...[8, 0, 0, 0, 0, 0, 0, 0], + ...[8, 0, 0, 0, 0, 0, 0, 0], + ...[8, 0, 0, 0, 0, 0, 0, 0], + ...[8, 0, 0, 0, 0, 0] // two shy of 8 + ]) + ); + + expect(dataCount).to.equal(3); + + stream.write(Buffer.from([0, 0])); // the rest of the last 8 + + expect(dataCount).to.equal(4); + }); + + it('throws an error when backpressure detected', async function () { + const stream = new SizedMessageTransform({ connection: {} as any }); + const destination = new Writable({ + highWaterMark: 1, + objectMode: true, + write: (chunk, encoding, callback) => { + void stream; + setTimeout(1).then(() => callback()); + } + }); + + // 1000 messages of size 8 + stream.write( + Buffer.from(Array.from({ length: 1000 }, () => [8, 0, 0, 0, 0, 0, 0, 0]).flat(1)) + ); + + const { promise, resolve, reject } = promiseWithResolvers(); + + stream.on('error', reject).pipe(destination).on('error', reject).on('finish', resolve); + + const error = await promise.catch(error => error); + expect(error).to.be.instanceOf(MongoRuntimeError); + }); }); });