diff --git a/pymongo/pool_shared.py b/pymongo/pool_shared.py index c97b0eb217..61e96526f2 100644 --- a/pymongo/pool_shared.py +++ b/pymongo/pool_shared.py @@ -189,7 +189,8 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s sock.setblocking(False) await asyncio.get_running_loop().sock_connect(sock, host) return sock - except OSError: + except BaseException: + # Protect against cancellation or interruption where the raw socket would otherwise leak sock.close() raise @@ -238,6 +239,10 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s except OSError as e: sock.close() err = e # type: ignore[assignment] + except BaseException: + # Protect against cancellation or interruption where the raw socket would otherwise leak + sock.close() + raise if err is not None: raise err @@ -289,19 +294,25 @@ async def _async_configured_socket( # mismatch, will be turned into ServerSelectionTimeoutErrors later. details = _get_timeout_details(options) _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details) - if ( - ssl_context.verify_mode - and not ssl_context.check_hostname - and not options.tls_allow_invalid_hostnames - ): - try: + except BaseException: + # Protect against cancellation or interruption where the raw socket would otherwise leak + sock.close() + raise + try: + if ( + ssl_context.verify_mode + and not ssl_context.check_hostname + and not options.tls_allow_invalid_hostnames + ): ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined, unused-ignore] - except _CertificateError: - ssl_sock.close() - raise - ssl_sock.settimeout(options.socket_timeout) - return ssl_sock + ssl_sock.settimeout(options.socket_timeout) + return ssl_sock + except BaseException: + # Protect against cancellation, _CertificateError, or interruption + # where the raw socket would otherwise leak. + ssl_sock.close() + raise async def _configured_protocol_interface( @@ -362,26 +373,27 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # mismatch, will be turned into ServerSelectionTimeoutErrors later. details = _get_timeout_details(options) _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details) - - if ( - ssl_context.verify_mode - and not ssl_context.check_hostname - and not options.tls_allow_invalid_hostnames - ): - try: + try: + if ( + ssl_context.verify_mode + and not ssl_context.check_hostname + and not options.tls_allow_invalid_hostnames + ): ssl.match_hostname(transport.get_extra_info("peercert"), hostname=host) # type:ignore[attr-defined,unused-ignore] - except _CertificateError: - transport.abort() - raise - if ssl_session_cache is not None: - ssl_obj = transport.get_extra_info("ssl_object") - if ssl_obj is not None: - new_session = ssl_obj.session - if new_session is not None: - ssl_session_cache[0] = new_session - - return AsyncNetworkingInterface((transport, protocol)) + if ssl_session_cache is not None: + ssl_obj = transport.get_extra_info("ssl_object") + if ssl_obj is not None: + new_session = ssl_obj.session + if new_session is not None: + ssl_session_cache[0] = new_session + + return AsyncNetworkingInterface((transport, protocol)) + except BaseException: + # Protect against cancellation, _CertificateError, or interruption + # where the transport would otherwise leak. + transport.abort() + raise def _create_connection(address: _Address, options: PoolOptions) -> socket.socket: diff --git a/test/asynchronous/test_async_cancellation.py b/test/asynchronous/test_async_cancellation.py index a96e28d832..a034af0341 100644 --- a/test/asynchronous/test_async_cancellation.py +++ b/test/asynchronous/test_async_cancellation.py @@ -17,13 +17,18 @@ from __future__ import annotations import asyncio +import functools +import socket +import ssl import sys +from unittest.mock import patch from test.asynchronous.utils import async_get_pool from test.utils_shared import delay, one sys.path[0:0] = [""] +from pymongo import pool_shared from test.asynchronous import AsyncIntegrationTest, async_client_context, connected @@ -129,3 +134,100 @@ async def task(): await task self.assertTrue(change_stream._closed) + + async def test_cancellation_closes_socket_during_create_connection(self): + address = (await async_client_context.host, await async_client_context.port) + options = (await async_get_pool(self.client)).opts + + created_sockets: list[socket.socket] = [] + real_socket_cls = socket.socket + target_task = None + + def tracking_socket(*args, **kwargs): + s = real_socket_cls(*args, **kwargs) + if asyncio.current_task() is target_task: + created_sockets.append(s) + return s + + loop = asyncio.get_running_loop() + real_sock_connect = loop.sock_connect + started = asyncio.Event() + block_forever = asyncio.Event() + + async def slow_sock_connect(sock, addr): + if sock in created_sockets: + started.set() + await block_forever.wait() + return None + return await real_sock_connect(sock, addr) + + with ( + patch.object(socket, "socket", tracking_socket), + patch.object(loop, "sock_connect", slow_sock_connect), + ): + task = asyncio.create_task(pool_shared._async_create_connection(address, options)) + target_task = task + await asyncio.wait_for(started.wait(), timeout=5) + task.cancel() + with self.assertRaises(asyncio.CancelledError): + await task + self.assertTrue(created_sockets, "expected at least one socket to be created") + for sock in created_sockets: + self.assertEqual( + sock.fileno(), + -1, + f"socket leaked across cancellation: {sock!r}", + ) + + async def test_cancellation_closes_socket_during_ssl_wrap_socket(self): + address = (await async_client_context.host, await async_client_context.port) + options = (await async_get_pool(self.client)).opts + fake_ssl_context = ssl.create_default_context() + + created_sockets: list[socket.socket] = [] + real_socket_cls = socket.socket + target_task = None + + def tracking_socket(*args, **kwargs): + s = real_socket_cls(*args, **kwargs) + if asyncio.current_task() is target_task: + created_sockets.append(s) + return s + + loop = asyncio.get_running_loop() + real_run_in_executor = loop.run_in_executor + started = asyncio.Event() + + def slow_run_in_executor(executor, func, *args): + # Need to unwrap the SNI branch here if present + inner = func.func if isinstance(func, functools.partial) else func + # Each `ctx.wrap_socket` access returns a fresh bound-method + # object, so we check the bound instance (__self__) instead + if ( + getattr(inner, "__self__", None) is fake_ssl_context + and asyncio.current_task() is target_task + ): + started.set() + # Return a future that never completes for cancellation. + return asyncio.get_running_loop().create_future() + return real_run_in_executor(executor, func, *args) + + with ( + patch.object(socket, "socket", tracking_socket), + patch.object(loop, "run_in_executor", slow_run_in_executor), + patch.object(options, "_PoolOptions__ssl_context", fake_ssl_context), + ): + task = asyncio.create_task(pool_shared._async_configured_socket(address, options)) + target_task = task + await asyncio.wait_for(started.wait(), timeout=5) + task.cancel() + with self.assertRaises(asyncio.CancelledError): + await task + + self.assertTrue(created_sockets, "expected at least one socket to be created") + for sock in created_sockets: + self.assertEqual( + sock.fileno(), + -1, + f"socket leaked across cancellation: {sock!r}", + )