zig/lib/std/event/channel.zig

Many producer, many consumer, thread-safe, runtime configurable buffer size. When buffer is empty, consumers suspend and are resumed by producers. When buffer is full, producers suspend and are resumed by consumers.

const std = @import("../std.zig");
const builtin = @import("builtin");
const assert = std.debug.assert;
const testing = std.testing;
const Loop = std.event.Loop;

Channel()

Call `deinit` to free resources when done. `buffer` must live until `deinit` is called. For a zero-length buffer, use `[0]T{}`. TODO https://github.com/ziglang/zig/issues/2765


/// Many producer, many consumer, thread-safe, runtime configurable buffer size.
/// When buffer is empty, consumers suspend and are resumed by producers.
/// When buffer is full, producers suspend and are resumed by consumers.
pub fn Channel(comptime T: type) type {
    return struct {
        getters: std.atomic.Queue(GetNode),
        or_null_queue: std.atomic.Queue(*std.atomic.Queue(GetNode).Node),
        putters: std.atomic.Queue(PutNode),
        get_count: usize,
        put_count: usize,
        dispatch_lock: bool,
        need_dispatch: bool,

init()

Must be called when all calls to put and get have suspended and no more calls occur. This can be omitted if caller can guarantee that the suspended putters and getters do not need to be run to completion. Note that this may leave awaiters hanging.


        // simple fixed size ring buffer
        buffer_nodes: []T,
        buffer_index: usize,
        buffer_len: usize,

deinit()

Puts a data item in the channel. The function returns when the value has been added to the buffer, when (in the case of a zero-size buffer) the item has been retrieved by a getter, or when the channel is destroyed.


        // Alias for this channel type; the methods of this struct take `self: *SelfChannel`.
        const SelfChannel = @This();
        const GetNode = struct {
            tick_node: *Loop.NextTickNode,
            data: Data,

put()

Await this function to get an item from the channel. If the buffer is empty, the frame will complete when the next item is put in the channel.


            /// Which kind of receive request a getter node carries:
            /// `Normal` is created by `get`, `OrNull` by `getOrNull`.
            const Data = union(enum) {
                Normal: Normal,
                OrNull: OrNull,
            };

get()

Get an item from the channel. If the buffer is empty and there are no puts waiting, this returns `null`.


            /// Request payload for `get`: dispatch writes the received item through `ptr`.
            const Normal = struct {
                ptr: *T,
            };

getOrNull()


            /// Request payload for `getOrNull`.
            const OrNull = struct {
                /// Dispatch writes the received item here. `getOrNull` initializes the
                /// target to `null`, so it stays `null` when the request is resumed
                /// without an item.
                ptr: *?T,
                /// This request's entry in `or_null_queue`; dispatch removes it when the
                /// request is served from the buffer or directly from a putter.
                or_null: *std.atomic.Queue(*std.atomic.Queue(GetNode).Node).Node,
            };
        };
        /// A pending `put` request, queued in `putters`.
        const PutNode = struct {
            /// The item being sent.
            data: T,
            /// Scheduled on the event loop by dispatch once the item has been handed to a
            /// getter or copied into the ring buffer.
            tick_node: *Loop.NextTickNode,
        };

Test:

std.event.Channel


        // The loop that dispatch schedules resumed frames on. Resolved at compile time:
        // without an event loop instance this type cannot be instantiated at all.
        const global_event_loop = Loop.instance orelse
            @compileError("std.event.Channel currently only works with event-based I/O");

Test:

std.event.Channel wraparound


        /// Initialize the channel in place, using `buffer` as ring-buffer storage.
        /// Call `deinit` to free resources when done.
        /// `buffer` must live until `deinit` is called.
        /// For a zero length buffer, use `[0]T{}`.
        /// TODO https://github.com/ziglang/zig/issues/2765
        pub fn init(self: *SelfChannel, buffer: []T) void {
            // Reads compute `(buffer_index -% buffer_len) % len` with wrapping subtraction,
            // which only lands on the right slot when `len` is a power of two; e.g.
            // (0 -% 1) % 10 == 5 instead of 9. `n & (n -% 1) == 0` holds exactly for 0 and
            // for powers of two.
            const len = buffer.len;
            assert((len & (len -% 1)) == 0);

            self.* = .{
                .getters = std.atomic.Queue(GetNode).init(),
                .or_null_queue = std.atomic.Queue(*std.atomic.Queue(GetNode).Node).init(),
                .putters = std.atomic.Queue(PutNode).init(),
                .get_count = 0,
                .put_count = 0,
                .dispatch_lock = false,
                .need_dispatch = false,
                .buffer_nodes = buffer,
                .buffer_index = 0,
                .buffer_len = 0,
            };
        }

        /// Resume every suspended getter and then every suspended putter, and invalidate
        /// the channel. Must be called when all calls to put and get have suspended and no
        /// more calls occur. This can be omitted if caller can guarantee that the suspended
        /// putters and getters do not need to be run to completion. Note that this may
        /// leave awaiters hanging.
        pub fn deinit(self: *SelfChannel) void {
            // Frames are resumed directly here rather than scheduled on the event loop.
            // A `get` resumed this way returns its result slot without it being written.
            while (self.getters.get()) |waiting| resume waiting.data.tick_node.data;
            while (self.putters.get()) |waiting| resume waiting.data.tick_node.data;
            self.* = undefined;
        }

        /// Send `data` into the channel. The calling frame suspends and is resumed once the
        /// value has been added to the buffer, or, for a zero size buffer, once a getter
        /// has taken it. It is also resumed when the channel is destroyed.
        pub fn put(self: *SelfChannel, data: T) void {
            var tick = Loop.NextTickNode{ .data = @frame() };
            var node: std.atomic.Queue(PutNode).Node = .{
                .data = .{ .data = data, .tick_node = &tick },
            };

            suspend {
                // Enqueue first, then count: dispatch relies on a counted request always
                // having a node in the queue.
                self.putters.put(&node);
                _ = @atomicRmw(usize, &self.put_count, .Add, 1, .SeqCst);
                self.dispatch();
            }
        }

        /// Receive the next item. Await this function; if the buffer is empty, the frame
        /// completes when the next item is put in the channel.
        pub fn get(self: *SelfChannel) callconv(.Async) T {
            // TODO https://github.com/ziglang/zig/issues/2765
            // Dispatch writes the received item into `item` through `GetNode.Normal.ptr`.
            var item: T = undefined;
            var tick = Loop.NextTickNode{ .data = @frame() };
            var node: std.atomic.Queue(GetNode).Node = .{
                .data = .{
                    .tick_node = &tick,
                    .data = .{ .Normal = .{ .ptr = &item } },
                },
            };

            suspend {
                // Enqueue before counting so dispatch never sees a count without a node.
                self.getters.put(&node);
                _ = @atomicRmw(usize, &self.get_count, .Add, 1, .SeqCst);
                self.dispatch();
            }
            return item;
        }

        //pub async fn select(comptime EnumUnion: type, channels: ...) EnumUnion {
        //    assert(@memberCount(EnumUnion) == channels.len); // enum union and channels mismatch
        //    assert(channels.len != 0); // enum unions cannot have 0 fields
        //    if (channels.len == 1) {
        //        const result = await (async channels[0].get() catch unreachable);
        //        return @unionInit(EnumUnion, @memberName(EnumUnion, 0), result);
        //    }
        //}

        /// Receive an item without waiting for a future put: returns `null` when the
        /// buffer is empty and no puts are waiting.
        pub fn getOrNull(self: *SelfChannel) ?T {
            // TODO integrate this function with named return values
            // so we can get rid of this extra result copy
            var item: ?T = null;
            var tick = Loop.NextTickNode{ .data = @frame() };
            var link: std.atomic.Queue(*std.atomic.Queue(GetNode).Node).Node = .{ .data = undefined };
            var node: std.atomic.Queue(GetNode).Node = .{
                .data = .{
                    .tick_node = &tick,
                    .data = .{ .OrNull = .{ .ptr = &item, .or_null = &link } },
                },
            };
            // The or_null entry points back at the getter node so dispatch can pull an
            // unserved request out of `getters` and resume it with `null`.
            link.data = &node;

            suspend {
                self.getters.put(&node);
                _ = @atomicRmw(usize, &self.get_count, .Add, 1, .SeqCst);
                self.or_null_queue.put(&link);
                self.dispatch();
            }
            return item;
        }

        /// Run the channel's matching loop:
        /// - hand buffered items to waiting getters;
        /// - hand items directly from waiting putters to waiting getters;
        /// - move the remaining putters' items into the ring buffer;
        /// - resume every pending `getOrNull` request that could not be served.
        /// Each served request has its frame scheduled on `global_event_loop`.
        ///
        /// At most one caller runs the loop at a time, guarded by `dispatch_lock`. A caller
        /// that finds the lock taken only sets `need_dispatch`; the lock holder keeps
        /// looping until it can release the lock with `need_dispatch` still clear.
        fn dispatch(self: *SelfChannel) void {
            // set the "need dispatch" flag
            @atomicStore(bool, &self.need_dispatch, true, .SeqCst);

            lock: while (true) {
                // set the lock flag; if it was already held, the holder will see
                // `need_dispatch` and perform another pass on our behalf
                if (@atomicRmw(bool, &self.dispatch_lock, .Xchg, true, .SeqCst)) return;

                // clear the need_dispatch flag since we're about to do it
                @atomicStore(bool, &self.need_dispatch, false, .SeqCst);

                while (true) {
                    one_dispatch: {
                        // `@atomicRmw` returns the value from before the subtraction, so each
                        // local count is how many requests may still be claimed, while the
                        // shared counter is left one lower (possibly wrapping below zero).
                        // The extra subtraction is undone after this block. Requesters
                        // enqueue their node before incrementing the counter, which is what
                        // makes the `.?` unwraps on queue `get()` below safe.
                        var get_count = @atomicRmw(usize, &self.get_count, .Sub, 1, .SeqCst);
                        var put_count = @atomicRmw(usize, &self.put_count, .Sub, 1, .SeqCst);

                        // transfer self.buffer to self.getters
                        while (self.buffer_len != 0) {
                            if (get_count == 0) break :one_dispatch;

                            const get_node = &self.getters.get().?.data;
                            switch (get_node.data) {
                                GetNode.Data.Normal => |info| {
                                    info.ptr.* = self.buffer_nodes[(self.buffer_index -% self.buffer_len) % self.buffer_nodes.len];
                                },
                                GetNode.Data.OrNull => |info| {
                                    _ = self.or_null_queue.remove(info.or_null);
                                    info.ptr.* = self.buffer_nodes[(self.buffer_index -% self.buffer_len) % self.buffer_nodes.len];
                                },
                            }
                            global_event_loop.onNextTick(get_node.tick_node);
                            self.buffer_len -= 1;

                            get_count = @atomicRmw(usize, &self.get_count, .Sub, 1, .SeqCst);
                        }

                        // direct transfer self.putters to self.getters
                        while (get_count != 0 and put_count != 0) {
                            const get_node = &self.getters.get().?.data;
                            const put_node = &self.putters.get().?.data;

                            switch (get_node.data) {
                                GetNode.Data.Normal => |info| {
                                    info.ptr.* = put_node.data;
                                },
                                GetNode.Data.OrNull => |info| {
                                    _ = self.or_null_queue.remove(info.or_null);
                                    info.ptr.* = put_node.data;
                                },
                            }
                            global_event_loop.onNextTick(get_node.tick_node);
                            global_event_loop.onNextTick(put_node.tick_node);

                            get_count = @atomicRmw(usize, &self.get_count, .Sub, 1, .SeqCst);
                            put_count = @atomicRmw(usize, &self.put_count, .Sub, 1, .SeqCst);
                        }

                        // transfer self.putters to self.buffer
                        while (self.buffer_len != self.buffer_nodes.len and put_count != 0) {
                            const put_node = &self.putters.get().?.data;

                            self.buffer_nodes[self.buffer_index % self.buffer_nodes.len] = put_node.data;
                            global_event_loop.onNextTick(put_node.tick_node);
                            self.buffer_index +%= 1;
                            self.buffer_len += 1;

                            put_count = @atomicRmw(usize, &self.put_count, .Sub, 1, .SeqCst);
                        }
                    }

                    // undo the extra subtractions
                    _ = @atomicRmw(usize, &self.get_count, .Add, 1, .SeqCst);
                    _ = @atomicRmw(usize, &self.put_count, .Add, 1, .SeqCst);

                    // All the "get or null" functions should resume now.
                    var remove_count: usize = 0;
                    while (self.or_null_queue.get()) |or_null_node| {
                        // `getOrNull` enqueues its getter node before its or_null node, so a
                        // dispatch pass can serve the getter (and schedule its frame) while
                        // the or_null node is not yet in `or_null_queue`. The
                        // `or_null_queue.remove` above then finds nothing, and the stale
                        // or_null node shows up here later. Only schedule frames we actually
                        // removed from `getters`: rescheduling an already-served getter
                        // would resume the same frame twice.
                        if (self.getters.remove(or_null_node.data)) {
                            remove_count += 1;
                            global_event_loop.onNextTick(or_null_node.data.data.tick_node);
                        }
                    }
                    if (remove_count != 0) {
                        _ = @atomicRmw(usize, &self.get_count, .Sub, remove_count, .SeqCst);
                    }

                    // clear need-dispatch flag; if another pass was requested while we were
                    // working, run it before releasing the lock
                    if (@atomicRmw(bool, &self.need_dispatch, .Xchg, false, .SeqCst)) continue;

                    // release the lock, which this caller must be holding
                    assert(@atomicRmw(bool, &self.dispatch_lock, .Xchg, false, .SeqCst));

                    // we have to check again now that we unlocked
                    if (@atomicLoad(bool, &self.need_dispatch, .SeqCst)) continue :lock;

                    return;
                }
            }
        }
    };
}

test "std.event.Channel" {
    if (!std.io.is_async) return error.SkipZigTest;

    // https://github.com/ziglang/zig/issues/1908
    if (builtin.single_threaded) return error.SkipZigTest;

    // https://github.com/ziglang/zig/issues/3251
    if (builtin.os.tag == .freebsd) return error.SkipZigTest;

    var channel: Channel(i32) = undefined;
    channel.init(&[0]i32{});
    defer channel.deinit();

    var handle = async testChannelGetter(&channel);
    var putter = async testChannelPutter(&channel);

    // Await both frames before reporting the getter's outcome so neither frame is
    // abandoned when an expectation fails.
    const getter_result = await handle;
    await putter;
    try getter_result;
}

test "std.event.Channel wraparound" {
    // TODO provide a way to run tests in evented I/O mode
    if (!std.io.is_async) return error.SkipZigTest;

    const channel_size = 2;

    var buf: [channel_size]i32 = undefined;
    var channel: Channel(i32) = undefined;
    channel.init(&buf);
    defer channel.deinit();

    // add items to channel and pull them out until
    // the buffer wraps around, make sure it doesn't crash.
    channel.put(5);
    try testing.expectEqual(@as(i32, 5), channel.get());
    channel.put(6);
    try testing.expectEqual(@as(i32, 6), channel.get());
    channel.put(7);
    try testing.expectEqual(@as(i32, 7), channel.get());
}

/// Receives the two values sent by `testChannelPutter`, then checks `getOrNull` on an
/// empty channel and while a put is in flight. Returns an error if an expectation fails
/// (it must return an error union because it uses `try`).
fn testChannelGetter(channel: *Channel(i32)) callconv(.Async) !void {
    const value1 = channel.get();
    try testing.expect(value1 == 1234);

    const value2 = channel.get();
    try testing.expect(value2 == 4567);

    const value3 = channel.getOrNull();
    try testing.expect(value3 == null);

    var last_put = async testPut(channel, 4444);
    const value4 = channel.getOrNull();
    // Compare the optional directly instead of unwrapping with `.?`, so a `null`
    // result fails the test rather than panicking.
    try testing.expectEqual(@as(?i32, 4444), value4);
    await last_put;
}

/// Sends 1234 and then 4567 into `channel`.
fn testChannelPutter(channel: *Channel(i32)) callconv(.Async) void {
    channel.put(1234);
    channel.put(4567);
}

/// Sends a single `value` into `channel`.
fn testPut(channel: *Channel(i32), value: i32) callconv(.Async) void {
    channel.put(value);
}