zig/lib/std/event/wait_group.zig

A WaitGroup keeps track of, and waits for, a group of async tasks to finish. Call `begin` when creating new tasks, and have tasks call `finish` when done. You can provide a count for both operations to perform them in bulk. Call `wait` to suspend until all tasks are completed. Multiple waiters are supported. WaitGroup is an instance of WaitGroupGeneric, which takes the bit size of the internal counter; WaitGroup uses a counter as wide as `usize`. It's also possible to set a maximum value for the counter (the `max_counter` field) so that `begin` returns error.Overflow when the limit is reached, even if the integer type itself has not overflowed. By default `max_counter` is set to std.math.maxInt(CounterType).

const std = @import("../std.zig");
const builtin = @import("builtin");
const Loop = std.event.Loop;

WaitGroup


/// A WaitGroup keeps track of, and waits for, a group of async tasks to finish.
/// Call `begin` when creating new tasks, and have tasks call `finish` when done.
/// You can provide a count for both operations to perform them in bulk.
/// Call `wait` to suspend until all tasks are completed.
/// Multiple waiters are supported.
///
/// WaitGroup is an instance of WaitGroupGeneric, which takes the bit size
/// of the internal counter. WaitGroup uses a counter as wide as `usize`.
/// It's also possible to set a maximum value for the counter (the
/// `max_counter` field) so that `begin` returns error.Overflow when the
/// limit is reached, even if the integer type itself has not overflowed.
/// By default `max_counter` is set to std.math.maxInt(CounterType).
pub const WaitGroup = WaitGroupGeneric(@bitSizeOf(usize));

WaitGroupGeneric()


/// Returns a WaitGroup type whose internal counter is an unsigned integer
/// of `counter_size` bits.
///
/// The returned type requires event-based I/O: instantiating it when
/// `Loop.instance` is null is a compile error.
pub fn WaitGroupGeneric(comptime counter_size: u16) type {
    const CounterType = std.meta.Int(.unsigned, counter_size);

    const global_event_loop = Loop.instance orelse
        @compileError("std.event.WaitGroup currently only works with event-based I/O");

    return struct {
        /// Number of outstanding tasks; `wait` returns once this is zero.
        counter: CounterType = 0,
        /// Upper bound for `counter`; `begin` fails with error.Overflow past it.
        max_counter: CounterType = std.math.maxInt(CounterType),
        /// Guards `counter` and `waiters`.
        mutex: std.Thread.Mutex = .{},
        /// Singly linked list of suspended callers of `wait`. The head's
        /// `tail` points at the last element so appends need no traversal.
        waiters: ?*Waiter = null,

        /// One suspended caller of `wait`. It is a local variable of that
        /// caller's frame, so it must not be touched after the frame resumes.
        const Waiter = struct {
            next: ?*Waiter,
            /// Only read on the list head.
            tail: *Waiter,
            node: Loop.NextTickNode,
        };

        const Self = @This();

        /// Registers `count` new tasks.
        /// Returns error.Overflow, leaving the counter unchanged, if the new
        /// value would overflow `CounterType` or exceed `max_counter`.
        pub fn begin(self: *Self, count: CounterType) error{Overflow}!void {
            self.mutex.lock();
            defer self.mutex.unlock();

            const new_counter = try std.math.add(CounterType, self.counter, count);
            if (new_counter > self.max_counter) return error.Overflow;
            self.counter = new_counter;
        }

        /// Marks `count` tasks as done. When the counter reaches zero, every
        /// pending waiter is rescheduled on the global event loop.
        /// `count` must not exceed the current counter (checked with
        /// `catch unreachable`).
        pub fn finish(self: *Self, count: CounterType) void {
            var waiters: ?*Waiter = blk: {
                self.mutex.lock();
                defer self.mutex.unlock();
                self.counter = std.math.sub(CounterType, self.counter, count) catch unreachable;
                if (self.counter != 0) break :blk null;
                // Detach the whole list so it can be walked without the lock.
                const detached = self.waiters;
                self.waiters = null;
                break :blk detached;
            };

            // We don't need to hold the lock to reschedule any potential waiter.
            while (waiters) |w| {
                // Read `next` before scheduling: once `w` is rescheduled its
                // frame may resume, ending the lifetime of the Waiter.
                waiters = w.next;
                global_event_loop.onNextTick(&w.node);
            }
        }

        /// Suspends the calling async frame until the counter reaches zero.
        /// Returns immediately if the counter is already zero.
        pub fn wait(self: *Self) void {
            self.mutex.lock();

            if (self.counter == 0) {
                self.mutex.unlock();
                return;
            }

            var self_waiter: Waiter = undefined;
            self_waiter.node.data = @frame();
            // The new waiter always becomes the last list element, so it must
            // terminate the chain; otherwise `finish` would read an undefined
            // `next` pointer whenever more than one task is waiting.
            self_waiter.next = null;
            if (self.waiters) |head| {
                head.tail.next = &self_waiter;
                head.tail = &self_waiter;
            } else {
                self_waiter.tail = &self_waiter;
                self.waiters = &self_waiter;
            }
            // Unlock inside the suspend block so the lock is held until this
            // frame is suspended; `finish` therefore cannot reschedule a frame
            // that has not suspended yet.
            suspend {
                self.mutex.unlock();
            }
        }
    };
}

test "basic WaitGroup usage" {
    // WaitGroup needs the global event loop, which only exists with async I/O.
    if (!std.io.is_async) return error.SkipZigTest;

    // TODO https://github.com/ziglang/zig/issues/1908 (single-threaded)
    // TODO https://github.com/ziglang/zig/issues/3251 (freebsd)
    if (builtin.single_threaded or builtin.os.tag == .freebsd) return error.SkipZigTest;

    // `start_group` releases the task; `done_group` reports its completion.
    var start_group = WaitGroup{};
    var done_group = WaitGroup{};
    try start_group.begin(1);
    try done_group.begin(1);

    var worker_frame = async task(&start_group, &done_group);
    start_group.finish(1);
    done_group.wait();
    await worker_frame;
}

/// Test helper: suspends until `wg_i` drains to zero, then finishes one
/// unit of `wg_f`.
fn task(wg_i: *WaitGroup, wg_f: *WaitGroup) void {
    // Runs on the way out, i.e. after the wait below has returned.
    defer wg_f.finish(1);
    wg_i.wait();
}