"""Cross platform abstractions for inter-process communication On Unix, this uses AF_UNIX sockets. On Windows, this uses NamedPipes. """ from __future__ import annotations import base64 import os import shutil import sys import tempfile from types import TracebackType from typing import Callable, Final if sys.platform == "win32": # This may be private, but it is needed for IPC on Windows, and is basically stable import ctypes import _winapi _IPCHandle = int kernel32 = ctypes.windll.kernel32 DisconnectNamedPipe: Callable[[_IPCHandle], int] = kernel32.DisconnectNamedPipe FlushFileBuffers: Callable[[_IPCHandle], int] = kernel32.FlushFileBuffers else: import socket _IPCHandle = socket.socket class IPCException(Exception): """Exception for IPC issues.""" class IPCBase: """Base class for communication between the dmypy client and server. This contains logic shared between the client and server, such as reading and writing. """ connection: _IPCHandle def __init__(self, name: str, timeout: float | None) -> None: self.name = name self.timeout = timeout def read(self, size: int = 100000) -> bytes: """Read bytes from an IPC connection until its empty.""" bdata = bytearray() if sys.platform == "win32": while True: ov, err = _winapi.ReadFile(self.connection, size, overlapped=True) try: if err == _winapi.ERROR_IO_PENDING: timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE res = _winapi.WaitForSingleObject(ov.event, timeout) if res != _winapi.WAIT_OBJECT_0: raise IPCException(f"Bad result from I/O wait: {res}") except BaseException: ov.cancel() raise _, err = ov.GetOverlappedResult(True) more = ov.getbuffer() if more: bdata.extend(more) if err == 0: # we are done! break elif err == _winapi.ERROR_MORE_DATA: # read again continue elif err == _winapi.ERROR_OPERATION_ABORTED: raise IPCException("ReadFile operation aborted.") else: while True: more = self.connection.recv(size) if not more: break bdata.extend(more) return bytes(bdata) def write(self, data: bytes) -> None: """Write bytes to an IPC connection.""" if sys.platform == "win32": try: ov, err = _winapi.WriteFile(self.connection, data, overlapped=True) try: if err == _winapi.ERROR_IO_PENDING: timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE res = _winapi.WaitForSingleObject(ov.event, timeout) if res != _winapi.WAIT_OBJECT_0: raise IPCException(f"Bad result from I/O wait: {res}") elif err != 0: raise IPCException(f"Failed writing to pipe with error: {err}") except BaseException: ov.cancel() raise bytes_written, err = ov.GetOverlappedResult(True) assert err == 0, err assert bytes_written == len(data) except OSError as e: raise IPCException(f"Failed to write with error: {e.winerror}") from e else: self.connection.sendall(data) self.connection.shutdown(socket.SHUT_WR) def close(self) -> None: if sys.platform == "win32": if self.connection != _winapi.NULL: _winapi.CloseHandle(self.connection) else: self.connection.close() class IPCClient(IPCBase): """The client side of an IPC connection.""" def __init__(self, name: str, timeout: float | None) -> None: super().__init__(name, timeout) if sys.platform == "win32": timeout = int(self.timeout * 1000) if self.timeout else _winapi.NMPWAIT_WAIT_FOREVER try: _winapi.WaitNamedPipe(self.name, timeout) except FileNotFoundError as e: raise IPCException(f"The NamedPipe at {self.name} was not found.") from e except OSError as e: if e.winerror == _winapi.ERROR_SEM_TIMEOUT: raise IPCException("Timed out waiting for connection.") from e else: raise try: self.connection = _winapi.CreateFile( self.name, _winapi.GENERIC_READ | _winapi.GENERIC_WRITE, 0, _winapi.NULL, _winapi.OPEN_EXISTING, _winapi.FILE_FLAG_OVERLAPPED, _winapi.NULL, ) except OSError as e: if e.winerror == _winapi.ERROR_PIPE_BUSY: raise IPCException("The connection is busy.") from e else: raise _winapi.SetNamedPipeHandleState( self.connection, _winapi.PIPE_READMODE_MESSAGE, None, None ) else: self.connection = socket.socket(socket.AF_UNIX) self.connection.settimeout(timeout) self.connection.connect(name) def __enter__(self) -> IPCClient: return self def __exit__( self, exc_ty: type[BaseException] | None = None, exc_val: BaseException | None = None, exc_tb: TracebackType | None = None, ) -> None: self.close() class IPCServer(IPCBase): BUFFER_SIZE: Final = 2**16 def __init__(self, name: str, timeout: float | None = None) -> None: if sys.platform == "win32": name = r"\\.\pipe\{}-{}.pipe".format( name, base64.urlsafe_b64encode(os.urandom(6)).decode() ) else: name = f"{name}.sock" super().__init__(name, timeout) if sys.platform == "win32": self.connection = _winapi.CreateNamedPipe( self.name, _winapi.PIPE_ACCESS_DUPLEX | _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE | _winapi.FILE_FLAG_OVERLAPPED, _winapi.PIPE_READMODE_MESSAGE | _winapi.PIPE_TYPE_MESSAGE | _winapi.PIPE_WAIT | 0x8, # PIPE_REJECT_REMOTE_CLIENTS 1, # one instance self.BUFFER_SIZE, self.BUFFER_SIZE, _winapi.NMPWAIT_WAIT_FOREVER, 0, # Use default security descriptor ) if self.connection == -1: # INVALID_HANDLE_VALUE err = _winapi.GetLastError() raise IPCException(f"Invalid handle to pipe: {err}") else: self.sock_directory = tempfile.mkdtemp() sockfile = os.path.join(self.sock_directory, self.name) self.sock = socket.socket(socket.AF_UNIX) self.sock.bind(sockfile) self.sock.listen(1) if timeout is not None: self.sock.settimeout(timeout) def __enter__(self) -> IPCServer: if sys.platform == "win32": # NOTE: It is theoretically possible that this will hang forever if the # client never connects, though this can be "solved" by killing the server try: ov = _winapi.ConnectNamedPipe(self.connection, overlapped=True) except OSError as e: # Don't raise if the client already exists, or the client already connected if e.winerror not in (_winapi.ERROR_PIPE_CONNECTED, _winapi.ERROR_NO_DATA): raise else: try: timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE res = _winapi.WaitForSingleObject(ov.event, timeout) assert res == _winapi.WAIT_OBJECT_0 except BaseException: ov.cancel() _winapi.CloseHandle(self.connection) raise _, err = ov.GetOverlappedResult(True) assert err == 0 else: try: self.connection, _ = self.sock.accept() except socket.timeout as e: raise IPCException("The socket timed out") from e return self def __exit__( self, exc_ty: type[BaseException] | None = None, exc_val: BaseException | None = None, exc_tb: TracebackType | None = None, ) -> None: if sys.platform == "win32": try: # Wait for the client to finish reading the last write before disconnecting if not FlushFileBuffers(self.connection): raise IPCException( "Failed to flush NamedPipe buffer, maybe the client hung up?" ) finally: DisconnectNamedPipe(self.connection) else: self.close() def cleanup(self) -> None: if sys.platform == "win32": self.close() else: shutil.rmtree(self.sock_directory) @property def connection_name(self) -> str: if sys.platform == "win32": return self.name else: name = self.sock.getsockname() assert isinstance(name, str) return name