Skip to content

proxystore_ex.connectors.dim.zmq

ZeroMQ-based distributed in-memory connector implementation.

ZeroMQConnector

ZeroMQConnector(
    port: int,
    address: str | None = None,
    interface: str | None = None,
    chunk_length: int | None = None,
    timeout: float = 1,
)

ZeroMQ-based distributed in-memory connector.

Note

If no server is already reachable at the configured address and port, the connector spawns a ZeroMQServer process that stores the data; otherwise it connects to the existing server. In either case, this connector only acts as an interface to that server.

Parameters:

  • address (str | None, default: None ) –

    The network IP address to use. Takes precedence over interface if both are provided.

  • interface (str | None, default: None ) –

    The network interface to use. address arg takes precedence if both are provided.

  • port (int) –

    The desired port for the spawned server.

  • chunk_length (int | None, default: None ) –

    Message chunk size in bytes. Defaults to MAX_CHUNK_LENGTH_DEFAULT.

  • timeout (float, default: 1 ) –

    Timeout in seconds to try connecting to an existing local server before spawning one. The same value is used as the time to wait for a newly spawned server to start.

Raises:

  • ServerTimeoutError

    If a local server cannot be connected to within timeout seconds and a newly spawned local server does not respond within timeout seconds after being started.

Source code in proxystore_ex/connectors/dim/zmq.py
def __init__(
    self,
    port: int,
    address: str | None = None,
    interface: str | None = None,
    chunk_length: int | None = None,
    timeout: float = 1,
) -> None:
    """Initialize the connector, connecting to or spawning a local server.

    Args:
        port: Port of the local server to connect to or spawn.
        address: IP address to use. Takes precedence over `interface` if
            both are provided.
        interface: Network interface whose IP address is used when
            `address` is not provided.
        chunk_length: Message chunk size in bytes. Defaults to
            `MAX_CHUNK_LENGTH_DEFAULT`.
        timeout: Seconds to wait for an existing local server to respond
            and, if one has to be spawned, for the new server to start.

    Raises:
        ServerTimeoutError: If no existing server responds and a newly
            spawned server does not respond within `timeout` seconds.
    """
    # ZMQ is not a default dependency so we don't want to raise
    # an error unless the user actually tries to use this code
    if zmq_import_error is not None:  # pragma: no cover
        raise zmq_import_error

    # Keep the user-supplied values so config() can reproduce them exactly.
    self._address = address
    self._interface = interface
    self.port = port
    self.chunk_length = (
        MAX_CHUNK_LENGTH_DEFAULT if chunk_length is None else chunk_length
    )
    self.timeout = timeout

    # Resolve the concrete IP: explicit address > interface > hostname.
    if self._address is not None:
        self.address = self._address
    elif self._interface is not None:  # pragma: darwin no cover
        self.address = get_ip_address(self._interface)
    else:
        host = socket.gethostname()
        self.address = socket.gethostbyname(host)

    self.url = f'tcp://{self.address}:{self.port}'

    # Set only when this instance spawned the server; close() uses it to
    # decide whether there is a process to terminate.
    self.server: multiprocessing.Process | None
    try:
        logger.info(
            f'Connecting to local server (url={self.url})...',
        )
        wait_for_server(self.address, self.port, self.timeout)
        logger.info(
            f'Connected to local server (url={self.url})',
        )
    except ServerTimeoutError:
        logger.info(
            'Failed to connect to local server '
            f'(url={self.url}, timeout={self.timeout})',
        )
        self.server = spawn_server(
            self.address,
            self.port,
            chunk_length=self.chunk_length,
            spawn_timeout=self.timeout,
        )
        logger.info(f'Spawned local server (url={self.url})')
    else:
        self.server = None

    self.context = zmq.Context()
    self.socket = self.context.socket(zmq.REQ)

close

close(kill_server: bool = True) -> None

Close the connector.

Parameters:

  • kill_server (bool, default: True ) –

    Whether to kill the server process. Has no effect on the server if this instance did not spawn the local node's server process; the connector's own socket and context are closed regardless.

Source code in proxystore_ex/connectors/dim/zmq.py
def close(self, kill_server: bool = True) -> None:
    """Shut the connector down and release its ZMQ socket and context.

    Args:
        kill_server: When `True` and this instance spawned the local
            server process, terminate that process and wait for it to
            exit. The server is left untouched otherwise, but the
            connector's socket and context are always closed.
    """
    if kill_server:
        spawned = self.server
        if spawned is not None:
            spawned.terminate()
            spawned.join()
            logger.info(
                'Terminated local server on connector close '
                f'(pid={spawned.pid})',
            )

    # Close the socket before terminating the context that owns it.
    self.socket.close()
    self.context.term()
    logger.info('Closed ZMQ connector')

config

config() -> dict[str, Any]

Get the connector configuration.

The configuration contains all the information needed to reconstruct the connector object.

Source code in proxystore_ex/connectors/dim/zmq.py
def config(self) -> dict[str, Any]:
    """Return the keyword arguments needed to rebuild this connector.

    The address and interface are reported as originally supplied (possibly
    `None`), not as resolved, so passing the result back to the constructor
    repeats the same address resolution.
    """
    settings: dict[str, Any] = dict(
        address=self._address,
        interface=self._interface,
        port=self.port,
        chunk_length=self.chunk_length,
        timeout=self.timeout,
    )
    return settings

from_config classmethod

from_config(config: dict[str, Any]) -> ZeroMQConnector

Create a new connector instance from a configuration.

Parameters:

  • config (dict[str, Any]) –

    Configuration returned by .config().

Source code in proxystore_ex/connectors/dim/zmq.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> ZeroMQConnector:
    """Build a connector from a configuration mapping.

    Args:
        config: Mapping returned by `#!python .config()`; each entry is
            forwarded to the constructor as a keyword argument.
    """
    connector = cls(**config)
    return connector

evict

evict(key: DIMKey) -> None

Evict the object associated with the key.

Parameters:

  • key (DIMKey) –

    Key associated with object to evict.

Source code in proxystore_ex/connectors/dim/zmq.py
def evict(self, key: DIMKey) -> None:
    """Remove the object stored under a key.

    Args:
        key: Key of the object to remove.
    """
    self._send_rpcs([RPC(operation='evict', key=key)])

exists

exists(key: DIMKey) -> bool

Check if an object associated with the key exists.

Parameters:

  • key (DIMKey) –

    Key potentially associated with stored object.

Returns:

  • bool

    If an object associated with the key exists.

Source code in proxystore_ex/connectors/dim/zmq.py
def exists(self, key: DIMKey) -> bool:
    """Report whether an object is stored under a key.

    Args:
        key: Key to look up.

    Returns:
        Whether an object associated with the key exists.
    """
    replies = self._send_rpcs([RPC(operation='exists', key=key)])
    # Exactly one reply is expected for the single request sent.
    (reply,) = replies
    found = reply.exists
    assert found is not None
    return found

get

get(key: DIMKey) -> bytes | None

Get the serialized object associated with the key.

Parameters:

  • key (DIMKey) –

    Key associated with the object to retrieve.

Returns:

  • bytes | None

    Serialized object or None if the object does not exist.

Source code in proxystore_ex/connectors/dim/zmq.py
def get(self, key: DIMKey) -> bytes | None:
    """Fetch the serialized object stored under a key.

    Args:
        key: Key of the object to fetch.

    Returns:
        The serialized object, or `None` if the object does not exist.
    """
    request = RPC(operation='get', key=key)
    # Exactly one reply is expected for the single request sent.
    (reply,) = self._send_rpcs([request])
    payload = reply.data
    return payload

get_batch

get_batch(keys: Sequence[DIMKey]) -> list[bytes | None]

Get a batch of serialized objects associated with the keys.

Parameters:

  • keys (Sequence[DIMKey]) –

    Sequence of keys associated with objects to retrieve.

Returns:

  • list[bytes | None]

    List with same order as keys with the serialized objects or None if the corresponding key does not have an associated object.

Source code in proxystore_ex/connectors/dim/zmq.py
def get_batch(self, keys: Sequence[DIMKey]) -> list[bytes | None]:
    """Fetch several serialized objects in one round of RPCs.

    Args:
        keys: Keys of the objects to fetch.

    Returns:
        List in the same order as `keys` holding each serialized object, \
        or `None` where a key has no associated object.
    """
    requests = [RPC(operation='get', key=k) for k in keys]
    return [reply.data for reply in self._send_rpcs(requests)]

put

put(obj: bytes) -> DIMKey

Put a serialized object in the store.

Parameters:

  • obj (bytes) –

    Serialized object to put in the store.

Returns:

  • DIMKey

    Key which can be used to retrieve the object.

Source code in proxystore_ex/connectors/dim/zmq.py
def put(self, obj: bytes) -> DIMKey:
    """Store a serialized object on the local server.

    Args:
        obj: Serialized object to store.

    Returns:
        Key that retrieves the object; it records this connector's address
        and port as the peer holding the data.
    """
    new_key = DIMKey(
        dim_type='zmq',
        obj_id=str(uuid.uuid4()),
        size=len(obj),
        peer_host=self.address,
        peer_port=self.port,
    )
    self._send_rpcs([RPC(operation='put', key=new_key, data=obj)])
    return new_key

put_batch

put_batch(objs: Sequence[bytes]) -> list[DIMKey]

Put a batch of serialized objects in the store.

Parameters:

  • objs (Sequence[bytes]) –

    Sequence of serialized objects to put in the store.

Returns:

  • list[DIMKey]

    List of keys with the same order as objs which can be used to retrieve the objects.

Source code in proxystore_ex/connectors/dim/zmq.py
def put_batch(self, objs: Sequence[bytes]) -> list[DIMKey]:
    """Store several serialized objects in one round of RPCs.

    Args:
        objs: Serialized objects to store.

    Returns:
        Keys in the same order as `objs` that retrieve the stored \
        objects.
    """
    new_keys = []
    requests = []
    for blob in objs:
        # Each object gets a fresh UUID; the peer fields point back here.
        blob_key = DIMKey(
            dim_type='zmq',
            obj_id=str(uuid.uuid4()),
            size=len(blob),
            peer_host=self.address,
            peer_port=self.port,
        )
        new_keys.append(blob_key)
        requests.append(RPC(operation='put', key=blob_key, data=blob))
    self._send_rpcs(requests)
    return new_keys

ZeroMQServer

ZeroMQServer()

ZeroMQServer implementation.

Source code in proxystore_ex/connectors/dim/zmq.py
def __init__(self) -> None:
    # In-memory store mapping string keys to the serialized bytes held
    # under them; starts empty.
    self.data: dict[str, bytes] = dict()

evict

evict(key: str) -> None

Evict the object associated with the key.

Parameters:

  • key (str) –

    Key associated with object to evict.

Source code in proxystore_ex/connectors/dim/zmq.py
def evict(self, key: str) -> None:
    """Discard the data stored under a key, if there is any.

    Args:
        key: Key whose data should be discarded. Unknown keys are ignored.
    """
    if key in self.data:
        del self.data[key]

exists

exists(key: str) -> bool

Check if an object associated with the key exists.

Parameters:

  • key (str) –

    Key potentially associated with stored object.

Returns:

  • bool

    If an object associated with the key exists.

Source code in proxystore_ex/connectors/dim/zmq.py
def exists(self, key: str) -> bool:
    """Report whether any data is stored under a key.

    Args:
        key: Key to look up.

    Returns:
        Whether the key is present in the store.
    """
    present = key in self.data.keys()
    return present

get

get(key: str) -> bytes | None

Get the serialized object associated with the key.

Parameters:

  • key (str) –

    Key associated with the object to retrieve.

Returns:

  • bytes | None

    Data or None if no data associated with the key exists.

Source code in proxystore_ex/connectors/dim/zmq.py
def get(self, key: str) -> bytes | None:
    """Look up the data stored under a key.

    Args:
        key: Key of the data to look up.

    Returns:
        The stored bytes, or `None` when nothing is stored under `key`.
    """
    try:
        return self.data[key]
    except KeyError:
        return None

put

put(key: str, data: bytes) -> None

Put data in the store.

Parameters:

  • key (str) –

    Key associated with data.

  • data (bytes) –

    Data to put in the store.

Source code in proxystore_ex/connectors/dim/zmq.py
def put(self, key: str, data: bytes) -> None:
    """Store bytes under a key, replacing any existing entry.

    Args:
        key: Key to store the data under.
        data: Bytes to store.
    """
    self.data.update({key: data})

handle_rpc

handle_rpc(rpc: RPC) -> RPCResponse

Process an RPC request.

Parameters:

  • rpc (RPC) –

    Client RPC to process.

Returns:

  • RPCResponse

    Response containing result or an exception if the operation failed.

Source code in proxystore_ex/connectors/dim/zmq.py
def handle_rpc(self, rpc: RPC) -> RPCResponse:
    """Dispatch a client RPC to the matching storage operation.

    Args:
        rpc: Client RPC to process.

    Returns:
        Response carrying the operation's result. If the operation raises,
        the exception is captured in the response instead of propagating.
    """
    op = rpc.operation
    key = rpc.key
    try:
        if op == 'exists':
            return RPCResponse(
                'exists',
                key=key,
                exists=self.exists(key.obj_id),
            )
        if op == 'evict':
            self.evict(key.obj_id)
            return RPCResponse('evict', key=key)
        if op == 'get':
            return RPCResponse('get', key=key, data=self.get(key.obj_id))
        if op == 'put':
            assert rpc.data is not None
            self.put(key.obj_id, rpc.data)
            return RPCResponse('put', key=key)
        # Any other operation is reported back to the client as an error.
        raise AssertionError('Unreachable.')
    except Exception as e:
        return RPCResponse(op, key=key, exception=e)

run_server async

run_server(
    address: str, port: int, chunk_length: int | None = None
) -> None

Listen and reply to RPCs from clients.

Warning

This function does not return until SIGINT or SIGTERM is received.

Parameters:

  • address (str) –

    IP address the server should bind to.

  • port (int) –

    Port the server should listen on.

  • chunk_length (int | None, default: None ) –

    Message chunk size in bytes. Defaults to MAX_CHUNK_LENGTH_DEFAULT.

Source code in proxystore_ex/connectors/dim/zmq.py
async def run_server(
    address: str,
    port: int,
    chunk_length: int | None = None,
) -> None:
    """Listen and reply to RPCs from clients.

    Warning:
        This function does not return until SIGINT or SIGTERM is received.

    Args:
        address: IP address the server should bind to.
        port: Port the server should listen on.
        chunk_length: Message chunk size in bytes. Defaults to
            `MAX_CHUNK_LENGTH_DEFAULT`.
    """
    loop = asyncio.get_running_loop()
    close_future = loop.create_future()

    def _request_close() -> None:
        # Both SIGINT and SIGTERM may arrive (e.g., Ctrl+C reaches the
        # process group and the parent's atexit hook then sends SIGTERM).
        # Calling set_result() twice would raise InvalidStateError inside
        # the signal callback, so only the first signal completes the future.
        if not close_future.done():
            close_future.set_result(None)

    loop.add_signal_handler(signal.SIGINT, _request_close)
    loop.add_signal_handler(signal.SIGTERM, _request_close)

    server = ZeroMQServer()
    chunk_length = (
        MAX_CHUNK_LENGTH_DEFAULT if chunk_length is None else chunk_length
    )

    context = zmq.asyncio.Context()
    socket = context.socket(zmq.REP)
    # Time out receives after 100 ms so the loop regularly re-checks
    # whether a close was requested.
    socket.setsockopt(zmq.RCVTIMEO, 100)

    with socket.bind(f'tcp://{address}:{port}'):
        while not close_future.done():
            try:
                rpc_parts = await socket.recv_multipart()
            except zmq.error.Again:
                continue

            rpc_bytes = b''.join(rpc_parts)

            # Liveness probe used by wait_for_server().
            if rpc_bytes == b'ping':
                await socket.send(b'pong')
                continue

            rpc: RPC = deserialize(rpc_bytes)
            response = server.handle_rpc(rpc)

            message = serialize(response)
            await socket.send_multipart(
                list(chunk_bytes(message, chunk_length)),
            )

    loop.remove_signal_handler(signal.SIGINT)
    loop.remove_signal_handler(signal.SIGTERM)

    socket.close()
    context.term()

start_server

start_server(
    address: str, port: int, chunk_length: int | None = None
) -> None

Run a local server.

Note

This function creates an event loop and executes run_server() within that loop.

Parameters:

  • address (str) –

    IP address the server should bind to.

  • port (int) –

    Port the server should listen on.

  • chunk_length (int | None, default: None ) –

    Message chunk size in bytes. Defaults to MAX_CHUNK_LENGTH_DEFAULT.

Source code in proxystore_ex/connectors/dim/zmq.py
def start_server(
    address: str,
    port: int,
    chunk_length: int | None = None,
) -> None:
    """Run a local server, blocking until it shuts down.

    Note:
        A new event loop is created with `asyncio.run()` and
        [`run_server()`][proxystore_ex.connectors.dim.zmq.run_server] is
        executed within that loop.

    Args:
        address: IP address the server should bind to.
        port: Port the server should listen on.
        chunk_length: Message chunk size in bytes. Defaults to
            `MAX_CHUNK_LENGTH_DEFAULT`.
    """
    server_coro = run_server(address, port, chunk_length)
    asyncio.run(server_coro)

spawn_server

spawn_server(
    address: str,
    port: int,
    *,
    chunk_length: int | None = None,
    spawn_timeout: float = 5.0,
    kill_timeout: float | None = 1.0
) -> Process

Spawn a local server running in a separate process.

Note

An atexit callback is registered which will terminate the spawned server process when the calling process exits.

Parameters:

  • address (str) –

    IP address the server should bind to.

  • port (int) –

    Port the server will listen on.

  • chunk_length (int | None, default: None ) –

    Message chunk size in bytes. Defaults to MAX_CHUNK_LENGTH_DEFAULT.

  • spawn_timeout (float, default: 5.0 ) –

    Max time in seconds to wait for the server to start.

  • kill_timeout (float | None, default: 1.0 ) –

    Max time in seconds to wait for the server to shutdown on exit.

Returns:

  • Process

    The process that the server is running in.

Source code in proxystore_ex/connectors/dim/zmq.py
def spawn_server(
    address: str,
    port: int,
    *,
    chunk_length: int | None = None,
    spawn_timeout: float = 5.0,
    kill_timeout: float | None = 1.0,
) -> multiprocessing.Process:
    """Spawn a local server running in a separate process.

    Note:
        An `atexit` callback is registered which will terminate the spawned
        server process when the calling process exits.

    Args:
        address: IP address the server should bind to.
        port: Port the server will listen on.
        chunk_length: Message chunk size in bytes. Defaults to
            `MAX_CHUNK_LENGTH_DEFAULT`.
        spawn_timeout: Max time in seconds to wait for the server to start.
        kill_timeout: Max time in seconds to wait for the server to shutdown
            on exit.

    Returns:
        The process that the server is running in.

    Raises:
        ServerTimeoutError: If the server does not respond within
            `spawn_timeout` seconds. The spawned process is stopped (and its
            `atexit` callback removed) before the error is re-raised.
    """
    server_process = multiprocessing.Process(
        target=start_server,
        args=(address, port, chunk_length),
    )
    server_process.start()

    def _stop_server() -> None:
        # Ask the server to exit (SIGTERM), then force-kill it if it is
        # still alive after kill_timeout seconds.
        server_process.terminate()
        server_process.join(timeout=kill_timeout)
        if server_process.is_alive():
            server_process.kill()
            server_process.join()

    def _kill_on_exit() -> None:  # pragma: no cover
        _stop_server()
        logger.debug(
            'Server terminated on parent process exit '
            f'(pid={server_process.pid})',
        )

    atexit.register(_kill_on_exit)
    logger.debug('Registered server cleanup atexit callback')

    try:
        wait_for_server(address, port, timeout=spawn_timeout)
    except Exception:
        # On failure the caller never receives the Process handle and so
        # cannot stop it; clean up here instead of leaving the server
        # running until interpreter exit.
        _stop_server()
        atexit.unregister(_kill_on_exit)
        raise

    logger.debug(
        'Server started '
        f'(host={address}, port={port}, pid={server_process.pid})',
    )

    return server_process

wait_for_server

wait_for_server(
    address: str, port: int, timeout: float = 0.1
) -> None

Wait until the server responds.

Parameters:

  • address (str) –

    Host of the server to ping.

  • port (int) –

    Port of the server to ping.

  • timeout (float, default: 0.1 ) –

    Max time in seconds to wait for server response.

Raises:

  • ServerTimeoutError –

    If the server does not respond within the timeout.

Source code in proxystore_ex/connectors/dim/zmq.py
def wait_for_server(address: str, port: int, timeout: float = 0.1) -> None:
    """Wait until the server responds.

    A single `ping` is sent over a short-lived REQ socket, and the function
    waits for the `pong` reply.

    Args:
        address: Host of the server to ping.
        port: Port of the server to ping.
        timeout: Max time in seconds to wait for server response.

    Raises:
        ServerTimeoutError: If the server does not respond within the timeout.
    """
    # A monotonic clock keeps wall-clock adjustments from stretching or
    # cutting short the wait.
    deadline = time.monotonic() + timeout

    # Both the socket and the context are released on every exit path; the
    # previous version never terminated the context it created.
    with zmq.Context() as context, context.socket(zmq.REQ) as sock:
        # Drop the possibly-unanswered ping on close so that terminating the
        # context does not block.
        sock.setsockopt(zmq.LINGER, 0)
        sock.connect(f'tcp://{address}:{port}')
        sock.send(b'ping')

        poller = zmq.Poller()
        poller.register(sock, zmq.POLLIN)

        while True:
            remaining_ms = (deadline - time.monotonic()) * 1000
            if remaining_ms <= 0:
                break
            # Poll in slices of at most 100 ms without overshooting the
            # deadline.
            if poller.poll(min(100, max(1, int(remaining_ms)))):
                response = sock.recv()
                if response != b'pong':
                    # Explicit raise (not `assert`) so the check survives -O.
                    raise AssertionError(
                        f'Unexpected response from server: {response!r}',
                    )
                return

    raise ServerTimeoutError(
        f'Failed to connect to server within timeout ({timeout} seconds).',
    )