"""
An abstraction layer over OS-dependent file-like objects, that provides a
consistent view of a *duplex byte stream*.
"""
import sys
import os
import socket
import errno
from rpyc.lib import safe_import, Timeout, socket_backoff_connect
from rpyc.lib.compat import poll, select_error, BYTES_LITERAL, get_exc_errno, maxint # noqa: F401
from rpyc.core.consts import STREAM_CHUNK
win32file = safe_import("win32file")
win32pipe = safe_import("win32pipe")
win32event = safe_import("win32event")
ssl = safe_import("ssl")
retry_errnos = (errno.EAGAIN, errno.EWOULDBLOCK)
[docs]class Stream(object):
"""Base Stream"""
__slots__ = ()
[docs] def close(self):
"""closes the stream, releasing any system resources associated with it"""
raise NotImplementedError()
@property
def closed(self):
"""tests whether the stream is closed or not"""
raise NotImplementedError()
[docs] def fileno(self):
"""returns the stream's file descriptor"""
raise NotImplementedError()
[docs] def poll(self, timeout):
"""indicates whether the stream has data to read (within *timeout*
seconds)"""
timeout = Timeout(timeout)
try:
p = poll() # from lib.compat, it may be a select object on non-Unix platforms
p.register(self.fileno(), "r")
while True:
try:
rl = p.poll(timeout.timeleft())
except select_error:
ex = sys.exc_info()[1]
if ex.args[0] == errno.EINTR:
continue
else:
raise
else:
break
except ValueError:
# if the underlying call is a select(), then the following errors may happen:
# - "ValueError: filedescriptor cannot be a negative integer (-1)"
# - "ValueError: filedescriptor out of range in select()"
# let's translate them to select.error
ex = sys.exc_info()[1]
raise select_error(str(ex))
return bool(rl)
[docs] def read(self, count):
"""reads **exactly** *count* bytes, or raise EOFError
:param count: the number of bytes to read
:returns: read data
"""
raise NotImplementedError()
[docs] def write(self, data):
"""writes the entire *data*, or raise EOFError
:param data: a string of binary data
"""
raise NotImplementedError()
def __enter__(self):
return self
def __exit__(self, *exc_info):
self.close()
class ClosedFile(object):
"""Represents a closed file object (singleton)"""
__slots__ = ()
def __getattr__(self, name):
if name.startswith("__"): # issue 71
raise AttributeError("stream has been closed")
raise EOFError("stream has been closed")
def close(self):
pass
@property
def closed(self):
return True
def fileno(self):
raise EOFError("stream has been closed")
ClosedFile = ClosedFile()
[docs]class SocketStream(Stream):
"""A stream over a socket"""
__slots__ = ("sock",)
MAX_IO_CHUNK = STREAM_CHUNK
def __init__(self, sock):
self.sock = sock
@classmethod
def _connect(cls, host, port, family=socket.AF_INET, socktype=socket.SOCK_STREAM,
proto=0, timeout=3, nodelay=False, keepalive=False, attempts=6):
family, socktype, proto, _, sockaddr = socket.getaddrinfo(host, port, family,
socktype, proto)[0]
s = socket_backoff_connect(family, socktype, proto, sockaddr, timeout, attempts)
try:
if nodelay:
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
if keepalive:
s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
# Linux specific: after <keepalive> idle seconds, start sending keepalives every <keepalive> seconds.
is_linux_socket = hasattr(socket, "TCP_KEEPIDLE")
is_linux_socket &= hasattr(socket, "TCP_KEEPINTVL")
is_linux_socket &= hasattr(socket, "TCP_KEEPCNT")
if is_linux_socket:
# Drop connection after 5 failed keepalives
# `keepalive` may be a bool or an integer
if keepalive is True:
keepalive = 60
if keepalive < 1:
raise ValueError("Keepalive minimal value is 1 second")
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 5)
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, keepalive)
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, keepalive)
return s
except BaseException:
s.close()
raise
[docs] @classmethod
def connect(cls, host, port, **kwargs):
"""factory method that creates a ``SocketStream`` over a socket connected
to *host* and *port*
:param host: the host name
:param port: the TCP port
:param family: specify a custom socket family
:param socktype: specify a custom socket type
:param proto: specify a custom socket protocol
:param timeout: connection timeout (default is 3 seconds)
:param nodelay: set the TCP_NODELAY socket option
:param keepalive: enable TCP keepalives. The value should be a boolean,
but on Linux, it can also be an integer specifying the
keepalive interval (in seconds)
:param ipv6: if True, creates an IPv6 socket (``AF_INET6``); otherwise
an IPv4 (``AF_INET``) socket is created
:returns: a :class:`SocketStream`
"""
if kwargs.pop("ipv6", False):
kwargs["family"] = socket.AF_INET6
return cls(cls._connect(host, port, **kwargs))
[docs] @classmethod
def unix_connect(cls, path, timeout=3):
"""factory method that creates a ``SocketStream`` over a unix domain socket
located in *path*
:param path: the path to the unix domain socket
:param timeout: socket timeout
"""
s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
s.settimeout(timeout)
s.connect(path)
return cls(s)
except BaseException:
s.close()
raise
[docs] @classmethod
def ssl_connect(cls, host, port, ssl_kwargs, **kwargs):
"""factory method that creates a ``SocketStream`` over an SSL-wrapped
socket, connected to *host* and *port* with the given credentials.
:param host: the host name
:param port: the TCP port
:param ssl_kwargs: a dictionary of keyword arguments for
``ssl.SSLContext`` and ``ssl.SSLContext.wrap_socket``
:param kwargs: additional keyword arguments: ``family``, ``socktype``,
``proto``, ``timeout``, ``nodelay``, passed directly to
the ``socket`` constructor, or ``ipv6``.
:param ipv6: if True, creates an IPv6 socket (``AF_INET6``); otherwise
an IPv4 (``AF_INET``) socket is created
:returns: a :class:`SocketStream`
"""
if kwargs.pop("ipv6", False):
kwargs["family"] = socket.AF_INET6
s = cls._connect(host, port, **kwargs)
try:
if "ssl_version" in ssl_kwargs:
context = ssl.SSLContext(ssl_kwargs.pop("ssl_version"))
else:
context = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH)
certfile = ssl_kwargs.pop("certfile", None)
keyfile = ssl_kwargs.pop("keyfile", None)
if certfile is not None:
context.load_cert_chain(certfile, keyfile=keyfile)
ca_certs = ssl_kwargs.pop("ca_certs", None)
if ca_certs is not None:
context.load_verify_locations(ca_certs)
ciphers = ssl_kwargs.pop("ciphers", None)
if ciphers is not None:
context.set_ciphers(ciphers)
check_hostname = ssl_kwargs.pop("check_hostname", None)
if check_hostname is not None:
context.check_hostname = check_hostname
cert_reqs = ssl_kwargs.pop("cert_reqs", None)
if cert_reqs is not None:
context.verify_mode = cert_reqs
s2 = context.wrap_socket(s, server_hostname=host, **ssl_kwargs)
return cls(s2)
except BaseException:
s.close()
raise
@property
def closed(self):
return self.sock is ClosedFile
[docs] def close(self):
if not self.closed:
try:
self.sock.shutdown(socket.SHUT_RDWR)
except Exception:
pass
self.sock.close()
self.sock = ClosedFile
[docs] def fileno(self):
try:
return self.sock.fileno()
except socket.error:
self.close()
ex = sys.exc_info()[1]
if get_exc_errno(ex) == errno.EBADF:
raise EOFError()
else:
raise
[docs] def read(self, count):
data = []
while count > 0:
try:
buf = self.sock.recv(min(self.MAX_IO_CHUNK, count))
except socket.timeout:
continue
except socket.error:
ex = sys.exc_info()[1]
if get_exc_errno(ex) in retry_errnos:
# windows just has to be a bitch
continue
self.close()
raise EOFError(ex)
if not buf:
self.close()
raise EOFError("connection closed by peer")
data.append(buf)
count -= len(buf)
return BYTES_LITERAL("").join(data)
[docs] def write(self, data):
try:
while data:
count = self.sock.send(data[:self.MAX_IO_CHUNK])
data = data[count:]
except socket.error:
ex = sys.exc_info()[1]
self.close()
raise EOFError(ex)
[docs]class TunneledSocketStream(SocketStream):
"""A socket stream over an SSH tunnel (terminates the tunnel when the connection closes)"""
__slots__ = ("tun",)
def __init__(self, sock):
self.sock = sock
self.tun = None
[docs] def close(self):
SocketStream.close(self)
if self.tun:
self.tun.close()
[docs]class PipeStream(Stream):
"""A stream over two simplex pipes (one used to input, another for output)"""
__slots__ = ("incoming", "outgoing")
MAX_IO_CHUNK = STREAM_CHUNK
def __init__(self, incoming, outgoing):
outgoing.flush()
self.incoming = incoming
self.outgoing = outgoing
[docs] @classmethod
def from_std(cls):
"""factory method that creates a PipeStream over the standard pipes
(``stdin`` and ``stdout``)
:returns: a :class:`PipeStream` instance
"""
return cls(sys.stdin, sys.stdout)
[docs] @classmethod
def create_pair(cls):
"""factory method that creates two pairs of anonymous pipes, and
creates two PipeStreams over them. Useful for ``fork()``.
:returns: a tuple of two :class:`PipeStream` instances
"""
r1, w1 = os.pipe()
r2, w2 = os.pipe()
side1 = cls(os.fdopen(r1, "rb"), os.fdopen(w2, "wb"))
side2 = cls(os.fdopen(r2, "rb"), os.fdopen(w1, "wb"))
return side1, side2
@property
def closed(self):
return self.incoming is ClosedFile
[docs] def close(self):
self.incoming.close()
self.outgoing.close()
self.incoming = ClosedFile
self.outgoing = ClosedFile
[docs] def fileno(self):
return self.incoming.fileno()
[docs] def read(self, count):
data = []
try:
while count > 0:
buf = os.read(self.incoming.fileno(), min(self.MAX_IO_CHUNK, count))
if not buf:
raise EOFError("connection closed by peer")
data.append(buf)
count -= len(buf)
except EOFError:
self.close()
raise
except EnvironmentError:
ex = sys.exc_info()[1]
self.close()
raise EOFError(ex)
return BYTES_LITERAL("").join(data)
[docs] def write(self, data):
try:
while data:
chunk = data[:self.MAX_IO_CHUNK]
written = os.write(self.outgoing.fileno(), chunk)
data = data[written:]
except EnvironmentError:
ex = sys.exc_info()[1]
self.close()
raise EOFError(ex)
[docs]class Win32PipeStream(Stream):
"""A stream over two simplex pipes (one used to input, another for output).
This is an implementation for Windows pipes (which suck)"""
__slots__ = ("incoming", "outgoing", "_fileno", "_keepalive")
PIPE_BUFFER_SIZE = 130000
MAX_IO_CHUNK = STREAM_CHUNK
def __init__(self, incoming, outgoing):
import msvcrt
self._keepalive = (incoming, outgoing)
if hasattr(incoming, "fileno"):
self._fileno = incoming.fileno()
incoming = msvcrt.get_osfhandle(incoming.fileno())
if hasattr(outgoing, "fileno"):
outgoing = msvcrt.get_osfhandle(outgoing.fileno())
self.incoming = incoming
self.outgoing = outgoing
@classmethod
def from_std(cls):
return cls(sys.stdin, sys.stdout)
@classmethod
def create_pair(cls):
r1, w1 = win32pipe.CreatePipe(None, cls.PIPE_BUFFER_SIZE)
r2, w2 = win32pipe.CreatePipe(None, cls.PIPE_BUFFER_SIZE)
return cls(r1, w2), cls(r2, w1)
[docs] def fileno(self):
return self._fileno
@property
def closed(self):
return self.incoming is ClosedFile
[docs] def close(self):
if self.closed:
return
try:
win32file.CloseHandle(self.incoming)
except Exception:
pass
self.incoming = ClosedFile
try:
win32file.CloseHandle(self.outgoing)
except Exception:
pass
self.outgoing = ClosedFile
[docs] def read(self, count):
try:
data = []
while count > 0:
dummy, buf = win32file.ReadFile(self.incoming, int(min(self.MAX_IO_CHUNK, count)))
count -= len(buf)
data.append(buf)
except TypeError:
ex = sys.exc_info()[1]
if not self.closed:
raise
raise EOFError(ex)
except win32file.error:
ex = sys.exc_info()[1]
self.close()
raise EOFError(ex)
return BYTES_LITERAL("").join(data)
[docs] def write(self, data):
try:
while data:
dummy, count = win32file.WriteFile(self.outgoing, data[:self.MAX_IO_CHUNK])
data = data[count:]
except TypeError:
ex = sys.exc_info()[1]
if not self.closed:
raise
raise EOFError(ex)
except win32file.error:
ex = sys.exc_info()[1]
self.close()
raise EOFError(ex)
[docs] def poll(self, timeout, interval=0.001):
"""a Windows version of select()"""
timeout = Timeout(timeout)
try:
while True:
if win32pipe.PeekNamedPipe(self.incoming, 0)[1] != 0:
return True
if timeout.expired():
return False
timeout.sleep(interval)
except TypeError:
ex = sys.exc_info()[1]
if not self.closed:
raise
raise EOFError(ex)
[docs]class NamedPipeStream(Win32PipeStream):
"""A stream over two named pipes (one used to input, another for output).
Windows implementation."""
NAMED_PIPE_PREFIX = r'\\.\pipe\rpyc_'
PIPE_IO_TIMEOUT = 3
CONNECT_TIMEOUT = 3
def __init__(self, handle, is_server_side):
import pywintypes
Win32PipeStream.__init__(self, handle, handle)
self.is_server_side = is_server_side
self.read_overlapped = pywintypes.OVERLAPPED()
self.read_overlapped.hEvent = win32event.CreateEvent(None, 1, 1, None)
self.write_overlapped = pywintypes.OVERLAPPED()
self.write_overlapped.hEvent = win32event.CreateEvent(None, 1, 1, None)
self.poll_buffer = win32file.AllocateReadBuffer(1)
self.poll_read = False
@classmethod
def from_std(cls):
raise NotImplementedError()
@classmethod
def create_pair(cls):
raise NotImplementedError()
[docs] @classmethod
def create_server(cls, pipename, connect=True):
"""factory method that creates a server-side ``NamedPipeStream``, over
a newly-created *named pipe* of the given name.
:param pipename: the name of the pipe. It will be considered absolute if
it starts with ``\\\\.``; otherwise ``\\\\.\\pipe\\rpyc``
will be prepended.
:param connect: whether to connect on creation or not
:returns: a :class:`NamedPipeStream` instance
"""
if not pipename.startswith("\\\\."):
pipename = cls.NAMED_PIPE_PREFIX + pipename
handle = win32pipe.CreateNamedPipe(
pipename,
win32pipe.PIPE_ACCESS_DUPLEX | win32file.FILE_FLAG_OVERLAPPED,
win32pipe.PIPE_TYPE_BYTE | win32pipe.PIPE_READMODE_BYTE,
1,
cls.PIPE_BUFFER_SIZE,
cls.PIPE_BUFFER_SIZE,
cls.PIPE_IO_TIMEOUT * 1000,
None
)
inst = cls(handle, True)
if connect:
inst.connect_server()
return inst
[docs] def connect_server(self):
"""connects the server side of an unconnected named pipe (blocks
until a connection arrives)"""
if not self.is_server_side:
raise ValueError("this must be the server side")
win32pipe.ConnectNamedPipe(self.incoming, self.write_overlapped)
win32event.WaitForSingleObject(self.write_overlapped.hEvent, win32event.INFINITE)
[docs] @classmethod
def create_client(cls, pipename):
"""factory method that creates a client-side ``NamedPipeStream``, over
a newly-created *named pipe* of the given name.
:param pipename: the name of the pipe. It will be considered absolute if
it starts with ``\\\\.``; otherwise ``\\\\.\\pipe\\rpyc``
will be prepended.
:returns: a :class:`NamedPipeStream` instance
"""
if not pipename.startswith("\\\\."):
pipename = cls.NAMED_PIPE_PREFIX + pipename
handle = win32file.CreateFile(
pipename,
win32file.GENERIC_READ | win32file.GENERIC_WRITE,
0,
None,
win32file.OPEN_EXISTING,
win32file.FILE_FLAG_OVERLAPPED,
None
)
return cls(handle, False)
[docs] def close(self):
if self.closed:
return
if self.is_server_side:
win32file.FlushFileBuffers(self.outgoing)
win32pipe.DisconnectNamedPipe(self.outgoing)
win32file.CloseHandle(self.read_overlapped.hEvent)
win32file.CloseHandle(self.write_overlapped.hEvent)
Win32PipeStream.close(self)
[docs] def read(self, count):
try:
if self.poll_read:
win32file.GetOverlappedResult(self.incoming, self.read_overlapped, 1)
data = [self.poll_buffer[:]]
self.poll_read = False
count -= 1
else:
data = []
while count > 0:
hr, buf = win32file.ReadFile(self.incoming,
win32file.AllocateReadBuffer(int(min(self.MAX_IO_CHUNK, count))),
self.read_overlapped)
n = win32file.GetOverlappedResult(self.incoming, self.read_overlapped, 1)
count -= n
data.append(buf[:n])
except TypeError:
ex = sys.exc_info()[1]
if not self.closed:
raise
raise EOFError(ex)
except win32file.error:
ex = sys.exc_info()[1]
self.close()
raise EOFError(ex)
return BYTES_LITERAL("").join(data)
[docs] def write(self, data):
try:
while data:
dummy, count = win32file.WriteFile(self.outgoing, data[:self.MAX_IO_CHUNK], self.write_overlapped)
data = data[count:]
except TypeError:
ex = sys.exc_info()[1]
if not self.closed:
raise
raise EOFError(ex)
except win32file.error:
ex = sys.exc_info()[1]
self.close()
raise EOFError(ex)
[docs] def poll(self, timeout, interval=0.001):
"""Windows version of select()"""
timeout = Timeout(timeout)
try:
if timeout.finite:
wait_time = int(max(1, timeout.timeleft() * 1000))
else:
wait_time = win32event.INFINITE
if not self.poll_read:
hr, self.poll_buffer = win32file.ReadFile(self.incoming,
self.poll_buffer,
self.read_overlapped)
self.poll_read = True
if hr == 0:
return True
res = win32event.WaitForSingleObject(self.read_overlapped.hEvent, wait_time)
return res == win32event.WAIT_OBJECT_0
except TypeError:
ex = sys.exc_info()[1]
if not self.closed:
raise
raise EOFError(ex)
if sys.platform == "win32":
PipeStream = Win32PipeStream # noqa: F811