Skip to content
10 changes: 10 additions & 0 deletions pymongo/network_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,16 @@ def get_conn(self) -> PyMongoProtocol:
def sock(self) -> socket.socket:
return self.conn[0].get_extra_info("socket")

def __del__(self) -> None:
# Synchronously release the raw socket in case the event loop is already closed
# or this connection was orphaned.
# Safe even if asyncio has already closed the socket.
try:
if self.sock is not None:
self.sock.close()
except Exception: # noqa: S110
pass
Comment thread
NoahStapp marked this conversation as resolved.
Outdated


class NetworkingInterface(NetworkingInterfaceBase):
def __init__(self, conn: Union[socket.socket, _sslConn]):
Expand Down
75 changes: 48 additions & 27 deletions pymongo/pool_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s
sock.setblocking(False)
await asyncio.get_running_loop().sock_connect(sock, host)
return sock
except OSError:
except BaseException:
# Protect against cancellation or interruption where the raw socket would otherwise leak
sock.close()
raise
Comment thread
NoahStapp marked this conversation as resolved.

Expand Down Expand Up @@ -231,6 +232,10 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s
except OSError as e:
sock.close()
err = e # type: ignore[assignment]
except BaseException:
# Protect against cancellation or interruption where the raw socket would otherwise leak
sock.close()
raise

if err is not None:
raise err
Expand Down Expand Up @@ -282,19 +287,25 @@ async def _async_configured_socket(
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
except BaseException:
# Protect against cancellation or interruption where the raw socket would otherwise leak
sock.close()
raise
try:
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined, unused-ignore]
except _CertificateError:
ssl_sock.close()
raise

ssl_sock.settimeout(options.socket_timeout)
return ssl_sock
ssl_sock.settimeout(options.socket_timeout)
return ssl_sock
except BaseException:
# Protect against cancellation, _CertificateError, or interruption
# where the raw socket would otherwise leak.
ssl_sock.close()
raise


async def _configured_protocol_interface(
Expand All @@ -311,11 +322,16 @@ async def _configured_protocol_interface(
timeout = options.socket_timeout

if ssl_context is None:
return AsyncNetworkingInterface(
await asyncio.get_running_loop().create_connection(
lambda: PyMongoProtocol(timeout=timeout), sock=sock
try:
return AsyncNetworkingInterface(
await asyncio.get_running_loop().create_connection(
lambda: PyMongoProtocol(timeout=timeout), sock=sock
)
)
)
except BaseException:
# Protect against cancellation or interruption where the raw socket would otherwise leak
sock.close()
raise

host = address[0]
try:
Expand All @@ -337,18 +353,23 @@ async def _configured_protocol_interface(
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
except BaseException:
# Protect against cancellation or interruption where the raw socket would otherwise leak
sock.close()
raise
try:
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
ssl.match_hostname(transport.get_extra_info("peercert"), hostname=host) # type:ignore[attr-defined,unused-ignore]
except _CertificateError:
transport.abort()
raise

return AsyncNetworkingInterface((transport, protocol))
return AsyncNetworkingInterface((transport, protocol))
except BaseException:
# Protect against cancellation, _CertificateError, or interruption
# where the transport would otherwise leak.
transport.abort()
raise


def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
Expand Down
Loading