[python] asyncio.gather vs asyncio.wait

In addition to the previous answers, I would like to point out how gather() and wait() behave differently when they are cancelled.

Gather cancellation

If gather() is cancelled, all submitted awaitables (that have not completed yet) are also cancelled.

Wait cancellation

If the task running wait() is cancelled, wait() simply raises CancelledError, and the tasks it was waiting on keep running.

Simple example:

import asyncio


async def task(arg):
    """Simulate 5 seconds of work, then hand back *arg* unchanged."""
    # asyncio.sleep() returns its second argument once the delay elapses.
    return await asyncio.sleep(5, arg)


async def cancel_waiting_task(work_task, waiting_task, *, delay: float = 2):
    """Cancel *waiting_task* after *delay* seconds and report what happened.

    Prints whether the waiter finished or was cancelled, then whether
    *work_task* still produced its result or was cancelled along with it.

    Args:
        work_task: the task doing the actual work.
        waiting_task: the task/future that is waiting on *work_task*
            (e.g. a task wrapping ``asyncio.wait`` or a ``gather`` future).
        delay: seconds to let everything run before cancelling the waiter;
            defaults to 2.
    """
    await asyncio.sleep(delay)
    waiting_task.cancel()
    try:
        await waiting_task
        print("Waiting done")
    except asyncio.CancelledError:
        print("Waiting task cancelled")

    # The interesting part: did cancelling the waiter also cancel the work?
    try:
        res = await work_task
        print(f"Work result: {res}")
    except asyncio.CancelledError:
        print("Work task cancelled")


async def main():
    """Show what happens to a running task when its awaiter is cancelled.

    The task is first awaited through asyncio.wait() and then (a fresh one)
    through asyncio.gather(). Each section prints a header so the output
    matches the listing shown after this example.
    """
    print("asyncio.wait()")
    work_task = asyncio.create_task(task("done"))
    # asyncio.wait() is a coroutine; wrap it in a task so it can be
    # cancelled independently of work_task.
    waiting = asyncio.create_task(asyncio.wait({work_task}))
    await cancel_waiting_task(work_task, waiting)

    print("----------------")
    print("asyncio.gather()")
    work_task = asyncio.create_task(task("done"))
    # gather() already returns a future, which can be cancelled directly.
    waiting = asyncio.gather(work_task)
    await cancel_waiting_task(work_task, waiting)


if __name__ == "__main__":
    # Run the demo only when executed as a script, not on import.
    asyncio.run(main())

Output:

asyncio.wait()
Waiting task cancelled
Work result: done
----------------
asyncio.gather()
Waiting task cancelled
Work task cancelled

Sometimes it is necessary to combine the functionality of wait() and gather(). For example, we may want to wait until at least one task completes and then cancel the remaining pending tasks. If the waiting itself is cancelled, all pending tasks should be cancelled as well.

Here are two real-world examples. First, suppose we have a disconnect event and a work task: we want to wait for the result of the work task, but cancel it if the connection is lost. Second, suppose we send several requests in parallel: as soon as one of them responds, we cancel all the others.

It could be done this way:

import asyncio
from typing import Optional, Tuple, Set


async def wait_any(
        tasks: Set[asyncio.Future], *, timeout: Optional[float] = None,
) -> Tuple[Set[asyncio.Future], Set[asyncio.Future]]:
    """Wait until at least one of *tasks* completes, then cancel the rest.

    Combines ``asyncio.wait`` (return on first completion) with the
    cancellation semantics of ``asyncio.gather``. Tasks that have not
    finished are cancelled either way. This happens when the wait returns,
    including on *timeout*, and when the task running ``wait_any`` is itself
    cancelled.

    Args:
        tasks: the tasks/futures to wait on.
        timeout: maximum number of seconds to wait (int or float). If it
            expires before anything completes, ``done`` is empty and every
            task is cancelled.

    Returns:
        ``(done, pending)`` as returned by ``asyncio.wait``. Cancellation has
        been *requested* on every task in ``pending``, but those tasks may
        not have finished unwinding yet. Await them if that matters.

    Raises:
        asyncio.CancelledError: if the waiting itself is cancelled. All of
            *tasks* are cancelled before the error propagates.
    """
    try:
        done, pending = await asyncio.wait(
            tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
        )
    except asyncio.CancelledError:
        # Behave like gather(): cancelling the waiter cancels everything it
        # waits on. cancel() is a no-op on tasks that already finished.
        for fut in tasks:
            fut.cancel()
        raise
    for fut in pending:
        fut.cancel()
    return done, pending


async def task():
    """Simulate 5 seconds of work that produces no result."""
    # asyncio.sleep() yields None when no result argument is given.
    return await asyncio.sleep(5)


async def cancel_waiting_task(work_task, waiting_task, *, delay: float = 2):
    """Cancel *waiting_task* after *delay* seconds and report what happened.

    Prints whether the waiter finished or was cancelled, then whether
    *work_task* still produced its result or was cancelled along with it.

    NOTE: none of the scenarios in this example call this helper; it mirrors
    the one from the first example.

    Args:
        work_task: the task doing the actual work.
        waiting_task: the task/future that is waiting on *work_task*.
        delay: seconds to let everything run before cancelling the waiter;
            defaults to 2.
    """
    await asyncio.sleep(delay)
    waiting_task.cancel()
    try:
        await waiting_task
        print("Waiting done")
    except asyncio.CancelledError:
        print("Waiting task cancelled")

    # Did cancelling the waiter also cancel the work?
    try:
        res = await work_task
        print(f"Work result: {res}")
    except asyncio.CancelledError:
        print("Work task cancelled")


async def check_tasks(waiting_task, working_task, waiting_conn_lost_task):
    """Await each task in turn and print whether it finished or was cancelled.

    The report order is fixed: the wait_any() waiter first, then the
    connection-lost watcher, then the work task.
    """
    report_plan = (
        (waiting_task, "waiting is done", "waiting is cancelled"),
        (waiting_conn_lost_task, "connection is lost",
         "waiting connection lost is cancelled"),
        (working_task, "work is done", "work is cancelled"),
    )
    for fut, finished_msg, cancelled_msg in report_plan:
        try:
            await fut
        except asyncio.CancelledError:
            print(cancelled_msg)
        else:
            print(finished_msg)


def _start_scenario():
    """Start the tasks shared by every scenario below.

    Creates the work task, a task waiting on a fresh connection-lost event,
    and a task running wait_any() over both. Must be called while an event
    loop is running, because it uses asyncio.create_task().

    Returns:
        ``(working_task, connection_lost_event, waiting_conn_lost_task,
        waiting_task)``.
    """
    working_task = asyncio.create_task(task())
    connection_lost_event = asyncio.Event()
    waiting_conn_lost_task = asyncio.create_task(connection_lost_event.wait())
    waiting_task = asyncio.create_task(
        wait_any({working_task, waiting_conn_lost_task})
    )
    return working_task, connection_lost_event, waiting_conn_lost_task, waiting_task


async def work_done_case():
    """Scenario: the connection-lost event is never set, so the work finishes."""
    working_task, _, waiting_conn_lost_task, waiting_task = _start_scenario()
    await check_tasks(waiting_task, working_task, waiting_conn_lost_task)


async def conn_lost_case():
    """Scenario: the connection-lost event is set after 2 seconds."""
    working_task, connection_lost_event, waiting_conn_lost_task, waiting_task = (
        _start_scenario()
    )
    await asyncio.sleep(2)
    connection_lost_event.set()  # <---
    await check_tasks(waiting_task, working_task, waiting_conn_lost_task)


async def cancel_waiting_case():
    """Scenario: the wait_any() waiter itself is cancelled after 2 seconds."""
    working_task, _, waiting_conn_lost_task, waiting_task = _start_scenario()
    await asyncio.sleep(2)
    waiting_task.cancel()  # <---
    await check_tasks(waiting_task, working_task, waiting_conn_lost_task)


async def main():
    """Run the three wait_any() scenarios in sequence, each under a heading."""
    print("Work done")
    print("-------------------")
    await work_done_case()
    print("\nConnection lost")
    print("-------------------")
    await conn_lost_case()
    print("\nCancel waiting")
    print("-------------------")
    await cancel_waiting_case()


if __name__ == "__main__":
    # Run the demo only when executed as a script, not on import.
    asyncio.run(main())

Output:

Work done
-------------------
waiting is done
waiting connection lost is cancelled
work is done

Connection lost
-------------------
waiting is done
connection is lost
work is cancelled

Cancel waiting
-------------------
waiting is cancelled
waiting connection lost is cancelled
work is cancelled