Adding dask-ctl support to your project

dask-ctl is an opt-in utility package for managing Dask cluster lifecycles.

In order for a cluster manager to work with dask-ctl it must have the following things:

  • A discovery method registered as an entrypoint

  • A from_name class method which reconstructs the cluster manager


Discovery

In order for a cluster to be visible in dask-ctl, the cluster manager which created it must implement a discovery method and register it as a dask_cluster_discovery entrypoint.

import setuptools
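As a rough sketch, the registration might look like the following. Only the group name dask_cluster_discovery comes from the text above; the entry point name mycluster, the package name, and the target my_package.discovery:discover are hypothetical placeholders. The last few lines parse the declaration with the standard library to show how it maps to a module and a function.

```python
from importlib.metadata import EntryPoint

# Group name from the text above; the entry point name and target are hypothetical.
ENTRY_POINTS = {
    "dask_cluster_discovery": [
        "mycluster = my_package.discovery:discover",
    ],
}


def run_setup() -> None:
    """Register the entry point when building the package.

    In a real setup.py, setuptools.setup(...) is normally called at module
    level; it is wrapped in a function here so the sketch can be run without
    triggering a build.
    """
    import setuptools

    setuptools.setup(name="my_package", entry_points=ENTRY_POINTS)


# Each declaration has the form "name = module:attribute".
name, _, target = ENTRY_POINTS["dask_cluster_discovery"][0].partition("=")
entry_point = EntryPoint(
    name=name.strip(), value=target.strip(), group="dask_cluster_discovery"
)
print(entry_point.name, entry_point.module, entry_point.attr)
```

The entry point name is the identifier the discovery method is known by, so it is worth choosing something short and specific to your cluster manager.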


This method must be an async generator which yields tuples of the cluster name and a class which can be used to reconstruct it.

from typing import Callable, AsyncIterator, Tuple

from my_package.cluster import MyClusterManager  # A cluster manager class which supports the ``from_name`` classmethod

async def discover() -> AsyncIterator[Tuple[str, Callable]]:

    # Discover cluster names in whatever way is appropriate
    cluster_names = [...]

    for cluster_name in cluster_names:
        yield (cluster_name, MyClusterManager)
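To try a discovery method on its own, it can be consumed with async for. The sketch below is self-contained: the MyClusterManager class here is an empty hypothetical stand-in, and the hard-coded cluster names replace whatever lookup your platform needs.

```python
import asyncio
from typing import AsyncIterator, Callable, List, Tuple


class MyClusterManager:
    """Hypothetical stand-in for a real cluster manager class."""


async def discover() -> AsyncIterator[Tuple[str, Callable]]:
    # Hard-coded names stand in for a real lookup (cloud API, local files, ...)
    for cluster_name in ["cluster-a", "cluster-b"]:
        yield (cluster_name, MyClusterManager)


async def collect() -> List[Tuple[str, Callable]]:
    # An async generator is consumed with ``async for``
    return [pair async for pair in discover()]


pairs = asyncio.run(collect())
print(pairs)
```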

From name

When dask-ctl discovers clusters it iterates through all the registered discovery methods and constructs a list of name/cluster manager pairs.

Then when making calls such as get_cluster it will attempt to call the from_name class method on the cluster manager and pass in the name that was provided during discovery.
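The flow described above can be pictured with the following sketch. It is not dask-ctl's actual implementation: DISCOVERY_METHODS, discover_all and get_cluster_sketch are hypothetical names, and the cluster manager is a bare stand-in that only remembers its name.

```python
import asyncio
from typing import AsyncIterator, Callable, Dict, List, Tuple


class MyClusterManager:
    """Hypothetical stand-in for a cluster manager supporting from_name."""

    def __init__(self, name: str):
        self.name = name

    @classmethod
    def from_name(cls, name: str) -> "MyClusterManager":
        return cls(name)


async def discover() -> AsyncIterator[Tuple[str, Callable]]:
    yield ("cluster-a", MyClusterManager)


# Hypothetical registry standing in for the registered entry points
DISCOVERY_METHODS: Dict[str, Callable] = {"mycluster": discover}


async def discover_all() -> List[Tuple[str, Callable]]:
    # Iterate through every discovery method and gather name/class pairs
    pairs = []
    for method in DISCOVERY_METHODS.values():
        async for name, manager_class in method():
            pairs.append((name, manager_class))
    return pairs


async def get_cluster_sketch(name: str) -> MyClusterManager:
    # Find the matching pair, then rebuild the manager from its name
    for discovered_name, manager_class in await discover_all():
        if discovered_name == name:
            return manager_class.from_name(discovered_name)
    raise RuntimeError(f"No cluster named {name!r} was discovered")


cluster = asyncio.run(get_cluster_sketch("cluster-a"))
print(type(cluster).__name__, cluster.name)
```

The point to take away is that the name yielded during discovery is the only information handed to from_name, so it must be enough to locate the running cluster.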

Cluster managers are constructed from name during almost all dask-ctl operations. Even calling dask cluster list on the CLI will create all cluster managers in order to query information about them, such as the number of workers and resources, via the scheduler comm.

The implementation of this method will vary drastically depending on how the cluster manager is implemented, but the interface should take the name argument, construct a cluster manager object and return it.

import asyncio

from distributed.core import rpc
from distributed.deploy.cluster import Cluster

class MyClusterManager(Cluster):

    @classmethod
    def from_name(
        cls, name: str, loop: asyncio.BaseEventLoop = None, asynchronous: bool = False
    ):
        cluster = cls(name=name, asynchronous=asynchronous)

        # Connect to the scheduler comm
        cluster.scheduler_comm = rpc(...)

        # Put the cluster manager into a started and running state
        # (implementation-specific, so elided here)
        ...

        return cluster

Testing integration

A useful test to ensure your cluster manager will be compliant with dask-ctl would be to follow these steps:

  • Create a cluster using your cluster manager class

  • Record the name of that cluster

  • Run dask-ctl discovery and ensure the cluster is listed

  • Ensure the cluster is not shut down when the cluster manager is destroyed

  • Delete the cluster manager object

  • Recreate the cluster manager object from the name

  • Check that the cluster is working as expected

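import pytest

from dask.distributed import Client
from distributed.core import Status
from dask_ctl.discovery import (
    list_discovery_methods,
    discover_cluster_names,
)

from my_package.cluster import MyClusterManager

# Requires the pytest-asyncio plugin
@pytest.mark.asyncio
async def test_from_name():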
    # Create cluster
    cluster = await MyClusterManager(*args, **kwargs)
    await cluster.scale(1)
    name = cluster.name

    # Check cluster listed in discovery
    discovery = "mycluster"
    assert discovery in list_discovery_methods()
    cluster_names = [
        found async for found in discover_cluster_names(discovery=discovery)
    ]
    assert len(cluster_names) == 1
    discovered_name, discovered_class = cluster_names[0]
    assert discovered_name == name
    assert discovered_class == MyClusterManager

    # Delete cluster manager
    cluster.shutdown_on_close = False
    del cluster

    # Recreate cluster manager from name
    cluster = await MyClusterManager.from_name(name, asynchronous=True)
    assert "id" in cluster.scheduler_info
    assert cluster.status == Status.running

    # Ensure work can be run on cluster
    async with Client(cluster, asynchronous=True) as client:
        # Ensure that inter-worker communication works well
        futures = client.map(lambda x: x + 1, range(10))
        total = client.submit(sum, futures)
        assert (await total) == sum(map(lambda x: x + 1, range(10)))
        assert all((await client.has_what()).values())