diff --git a/cheroot/makefile.py b/cheroot/makefile.py index f5780a1ede..59c00e9769 100644 --- a/cheroot/makefile.py +++ b/cheroot/makefile.py @@ -2,12 +2,16 @@ # prefer slower Python-based io module import _pyio as io +import select import socket # Write only 16K at a time to sockets SOCK_WRITE_BLOCKSIZE = 16384 +# Seconds to wait for a blocked socket to become writable +SOCK_WRITE_TIMEOUT = 10 + class BufferedWriter(io.BufferedWriter): """Faux file object attached to a socket object.""" @@ -26,12 +30,28 @@ def write(self, b): def _flush_unlocked(self): self._checkClosed('flush of closed file') while self._write_buf: + n = None try: # ssl sockets only except 'bytes', not bytearrays # so perhaps we should conditionally wrap this for perf? - n = self.raw.write(bytes(self._write_buf)) + n = self.raw.write( + bytes(self._write_buf[:SOCK_WRITE_BLOCKSIZE]), + ) except io.BlockingIOError as e: n = e.characters_written + if n is None: + _, writable, _ = select.select( + [], + [self.raw], + [], + SOCK_WRITE_TIMEOUT, + ) + if not writable: + raise io.BlockingIOError( + 0, + 'raw stream blocked; no bytes written', + ) + continue del self._write_buf[:n] diff --git a/cheroot/makefile.pyi b/cheroot/makefile.pyi index 3f5ea2756b..8e8d57d73f 100644 --- a/cheroot/makefile.pyi +++ b/cheroot/makefile.pyi @@ -1,6 +1,7 @@ import io SOCK_WRITE_BLOCKSIZE: int +SOCK_WRITE_TIMEOUT: int class BufferedWriter(io.BufferedWriter): def write(self, b): ... diff --git a/cheroot/test/test_makefile.py b/cheroot/test/test_makefile.py index d65d4ea268..f0582dba15 100644 --- a/cheroot/test/test_makefile.py +++ b/cheroot/test/test_makefile.py @@ -51,3 +51,100 @@ def test_bytes_written(): wfile = makefile.MakeFile(sock, 'w') wfile.write(b'bar') assert wfile.bytes_written == 3 + + +class _RawWriteBlockOnce: + """Mock raw.write() returning None once, then writing normally.""" + + def __init__(self): + """Initialize _RawWriteBlockOnce.""" + self.call_count = 0 + self.written = bytearray() + + def __call__(self, chunk): + """Return None on first call to simulate a blocked write.""" + self.call_count += 1 + if self.call_count == 1: + return None + self.written.extend(chunk) + return len(chunk) + + def fileno(self): + """Return a fake fd for select().""" + return -1 + + +class _RawWriteBlockAlways: + """Mock raw.write() that always returns None.""" + + def __init__(self): + """Initialize _RawWriteBlockAlways.""" + self.call_count = 0 + + def __call__(self, chunk): + """Return None to simulate a permanently blocked socket.""" + self.call_count += 1 + + def fileno(self): + """Return a fake fd for select().""" + return -1 + + +def test_flush_recovers_from_temporary_block(monkeypatch): + """_flush_unlocked() retries after select when raw.write() returns None. + + A temporarily blocked socket should recover once select() reports + the socket is writable again, delivering all buffered data. + """ + data = b'x' * (makefile.SOCK_WRITE_BLOCKSIZE * 2) + + sock = MockSocket() + wfile = makefile.MakeFile(sock, 'w') + wfile._write_buf.extend(data) + + mock = _RawWriteBlockOnce() + wfile.raw.write = mock + + # select() reports writable immediately + monkeypatch.setattr( + 'cheroot.makefile.select.select', + lambda _rlist, wlist, _xlist, _timeout: ([], wlist, []), + ) + wfile._flush_unlocked() + + assert bytes(mock.written) == data, ( + 'all buffered data should be written after select retry' + ) + + +def test_flush_raises_on_sustained_block(monkeypatch): + """_flush_unlocked() raises BlockingIOError after select timeout. + + If the socket stays blocked past SOCK_WRITE_TIMEOUT, the write + buffer must be preserved and BlockingIOError raised. + """ + import io + + import pytest + + data = b'x' * makefile.SOCK_WRITE_BLOCKSIZE + + sock = MockSocket() + wfile = makefile.MakeFile(sock, 'w') + wfile._write_buf.extend(data) + + mock = _RawWriteBlockAlways() + wfile.raw.write = mock + + # select() reports not writable (timeout) + monkeypatch.setattr( + 'cheroot.makefile.select.select', + lambda _rlist, _wlist, _xlist, _timeout: ([], [], []), + ) + + with pytest.raises(io.BlockingIOError): + wfile._flush_unlocked() + + assert len(wfile._write_buf) == len(data), ( + 'write buffer must be preserved when socket stays blocked' + ) diff --git a/docs/changelog-fragments.d/822.bugfix.rst b/docs/changelog-fragments.d/822.bugfix.rst new file mode 100644 index 0000000000..97493eeb59 --- /dev/null +++ b/docs/changelog-fragments.d/822.bugfix.rst @@ -0,0 +1,3 @@ +Fixed a bug that could cause premature clearing of the write buffer when a socket write is blocked. + +-- by :user:`cbbm142` diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 06b204518e..981a6b3f97 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -5,6 +5,7 @@ backports bugfixes builtin b'xb +buf compat config conftest @@ -22,6 +23,7 @@ hardcoded hostname inclusivity intersphinx +io iterable linter linters @@ -48,6 +50,7 @@ preconfigure py pytest pythonic +RawIOBase readonly rebase Refactor