Skip to content
58 changes: 35 additions & 23 deletions pymongo/pool_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s
sock.setblocking(False)
await asyncio.get_running_loop().sock_connect(sock, host)
return sock
except OSError:
except BaseException:
# Protect against cancellation or interruption where the raw socket would otherwise leak
sock.close()
raise
Comment thread
NoahStapp marked this conversation as resolved.

Expand Down Expand Up @@ -231,6 +232,10 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s
except OSError as e:
sock.close()
err = e # type: ignore[assignment]
except BaseException:
# Protect against cancellation or interruption where the raw socket would otherwise leak
sock.close()
raise

if err is not None:
raise err
Expand Down Expand Up @@ -282,19 +287,25 @@ async def _async_configured_socket(
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
except BaseException:
# Protect against cancellation or interruption where the raw socket would otherwise leak
sock.close()
raise
try:
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined, unused-ignore]
except _CertificateError:
ssl_sock.close()
raise

ssl_sock.settimeout(options.socket_timeout)
return ssl_sock
ssl_sock.settimeout(options.socket_timeout)
return ssl_sock
except BaseException:
# Protect against cancellation, _CertificateError, or interruption
# where the raw socket would otherwise leak.
ssl_sock.close()
raise


async def _configured_protocol_interface(
Expand Down Expand Up @@ -337,18 +348,19 @@ async def _configured_protocol_interface(
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
try:
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
ssl.match_hostname(transport.get_extra_info("peercert"), hostname=host) # type:ignore[attr-defined,unused-ignore]
except _CertificateError:
transport.abort()
raise

return AsyncNetworkingInterface((transport, protocol))
return AsyncNetworkingInterface((transport, protocol))
except BaseException:
# Protect against cancellation, _CertificateError, or interruption
# where the transport would otherwise leak.
transport.abort()
raise


def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
Expand Down
42 changes: 42 additions & 0 deletions test/asynchronous/test_async_cancellation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,18 @@
from __future__ import annotations

import asyncio
import socket as _socket
import sys
from test.asynchronous.utils import async_get_pool
from test.utils_shared import delay, one
from unittest.mock import patch

sys.path[0:0] = [""]

from test.asynchronous import AsyncIntegrationTest, async_client_context, connected

from pymongo import pool_shared


class TestAsyncCancellation(AsyncIntegrationTest):
async def test_async_cancellation_closes_connection(self):
Expand Down Expand Up @@ -127,3 +131,41 @@ async def task():
await task

self.assertTrue(change_stream._closed)

async def test_cancellation_closes_socket_during_create_connection(self):
address = (await async_client_context.host, await async_client_context.port)
options = (await async_get_pool(self.client)).opts

created_sockets: list[_socket.socket] = []
real_socket_cls = _socket.socket

def tracking_socket(*args, **kwargs):
s = real_socket_cls(*args, **kwargs)
created_sockets.append(s)
return s
Comment thread
NoahStapp marked this conversation as resolved.

loop = asyncio.get_running_loop()
started = asyncio.Event()
block_forever = asyncio.Event()

async def slow_sock_connect(sock, addr):
started.set()
await block_forever.wait()

with (
patch.object(_socket, "socket", tracking_socket),
patch.object(loop, "sock_connect", slow_sock_connect),
):
task = asyncio.create_task(pool_shared._async_create_connection(address, options))
await asyncio.wait_for(started.wait(), timeout=5)
task.cancel()
with self.assertRaises(asyncio.CancelledError):
await task

Comment thread
NoahStapp marked this conversation as resolved.
Outdated
self.assertTrue(created_sockets, "expected at least one socket to be created")
for sock in created_sockets:
self.assertEqual(
sock.fileno(),
-1,
f"socket leaked across cancellation: {sock!r}",
)
Loading