From a1ede8d1a0dcccdd74f0a7aa3e857fb8e2218ec3 Mon Sep 17 00:00:00 2001 From: falsycat Date: Sun, 30 Mar 2025 08:28:36 +0900 Subject: [PATCH] refactor Digraph --- src/hncore/Digraph.zig | 150 ++++++++++++++++++++++++----------------- 1 file changed, 88 insertions(+), 62 deletions(-) diff --git a/src/hncore/Digraph.zig b/src/hncore/Digraph.zig index 7a2c603..24c5bcb 100644 --- a/src/hncore/Digraph.zig +++ b/src/hncore/Digraph.zig @@ -1,7 +1,7 @@ const std = @import("std"); -/// A data type to store relations of nodes in directional-graph. -pub fn Digraph(comptime T: type) type { +/// A data type to store connections of nodes in directional-graph. +pub fn Digraph(comptime T: type, comptime lessThanFn: LessThanFunc(T)) type { return struct { const Node = T; const Conn = struct { from: T, to: T, }; @@ -14,16 +14,16 @@ pub fn Digraph(comptime T: type) type { map: ConnList, /// - pub fn init(alloc: std.mem.Allocator, map_unsorted: []const Conn) !@This() { - var map_sorted = ConnList.init(alloc); - try map_sorted.ensureTotalCapacity(map_unsorted.len); - for (map_unsorted) |conn| { - try map_sorted.append(conn); + pub fn init(alloc: std.mem.Allocator, mapUnsorted: []const Conn) !@This() { + var mapSorted = ConnList.init(alloc); + try mapSorted.ensureTotalCapacity(mapUnsorted.len); + for (mapUnsorted) |conn| { + try mapSorted.append(conn); } - std.mem.sort(Conn, map_sorted.items, {}, compare_conn); + std.mem.sort(Conn, mapSorted.items, {}, compareConn); return .{ - .map = map_sorted, + .map = mapSorted, }; } /// @@ -32,47 +32,47 @@ pub fn Digraph(comptime T: type) type { } /// - pub fn connect_if(self: *@This(), from: T, to: T) !bool { - const begin, const end = self.find_segment(from); - if (self.find_connection_in_segment(from, to, begin, end)) |_| { + pub fn connectIf(self: *@This(), from: T, to: T) !bool { + const begin, const end = self.findSegment(from); + if (self.findConnectionInSegment(from, to, begin, end)) |_| { return false; } else { try self.map.insert(end, Conn { .from = from, .to = to, }); return true; } } - /// Same to `connect_if`, but returns an error if it's already connected. + /// Same to `connectIf`, but returns an error if it's already connected. pub fn connect(self: *@This(), from: T, to: T) !void { - if (!try self.connect_if(from, to)) { + if (!try self.connectIf(from, to)) { return Error.AlreadyConnected; } } /// - pub fn disconnect_if(self: *@This(), from: T, to: T) bool { - const begin, const end = self.find_segment(from); - if (self.find_connection_in_segment(from, to, begin, end)) |idx| { + pub fn disconnectIf(self: *@This(), from: T, to: T) bool { + const begin, const end = self.findSegment(from); + if (self.findConnectionInSegment(from, to, begin, end)) |idx| { _ = self.map.orderedRemove(idx); return true; } else { return false; } } - /// Same to `disconnect_if`, but returns an error if it's not connected. + /// Same to `disconnectIf`, but returns an error if it's not connected. pub fn disconnect(self: *@This(), from: T, to: T) !void { - if (!self.disconnect_if(from, to)) { + if (!self.disconnectIf(from, to)) { return Error.NotConnected; } } /// - pub fn is_connected(self: *const @This(), from: T, to: T) bool { - const begin, const end = self.find_segment(from); - return self.find_connection_in_segment(from, to, begin, end) != null; + pub fn isConnected(self: *const @This(), from: T, to: T) bool { + const begin, const end = self.findSegment(from); + return self.findConnectionInSegment(from, to, begin, end) != null; } - fn find_connection_in_segment(self: *const @This(), from: T, to: T, begin: usize, end: usize) ?usize { + fn findConnectionInSegment(self: *const @This(), from: T, to: T, begin: usize, end: usize) ?usize { for (self.map.items[begin..end], begin..) |v, idx| { if ((v.from == from) and (v.to == to)) { return idx; @@ -80,36 +80,36 @@ pub fn Digraph(comptime T: type) type { } return null; } - fn find_segment(self: *const @This(), from: T) struct { usize, usize } { + fn findSegment(self: *const @This(), from: T) struct { usize, usize } { const n = self.map.items.len; if (n == 0) { return .{ 0, 0, }; } - const base_idx = self.binsearch(from).?; - const base_from = self.map.items[base_idx].from; + const baseIdx = self.binsearch(from).?; + const baseFrom = self.map.items[baseIdx].from; var begin: usize = undefined; var end : usize = undefined; - if (base_from < from) { - begin = base_idx; + if (baseFrom < from) { + begin = baseIdx; while ((begin < n) and (self.map.items[begin].from != from)) { begin += 1; } end = begin; while ((end < n) and (self.map.items[end].from == from)) { end += 1; } - } else if (base_from > from) { - end = base_idx; + } else if (baseFrom > from) { + end = baseIdx; while ((end > 0) and (self.map.items[end-1].from != from)) { end -= 1; } begin = end; while ((begin > 0) and (self.map.items[begin-1].from == from)) { begin -= 1; } } else { - begin = base_idx; + begin = baseIdx; while ((begin > 0) and (self.map.items[begin-1].from == from)) { begin -= 1; } - end = base_idx; + end = baseIdx; while ((end < n) and (self.map.items[end].from == from)) { end += 1; } } return .{ begin, end, }; @@ -137,14 +137,40 @@ pub fn Digraph(comptime T: type) type { } return idx; } - fn compare_conn(_: void, a: Conn, b: Conn) bool { - return a.from < b.from; + fn compareConn(_: void, a: Conn, b: Conn) bool { + return lessThanFn(a.from, b.from); } }; } +/// A type of comparator function for the type T, which is to be passed as an argument of `Digraph()`. +pub fn LessThanFunc(comptime T: type) type { + return fn (lhs: T, rhs: T) bool; +} + +/// Returns a lessThanFunc for the comparable type T. +pub fn lessThanFuncFor(comptime T: type) LessThanFunc(T) { + return struct { + fn inner(lhs: T, rhs: T) bool { + if (@typeInfo(T) == .pointer) { + return @intFromPtr(lhs) < @intFromPtr(rhs); + } else { + return lhs < rhs; + } + } + }.inner; +} + +test "compile check for various types" { + _ = Digraph(u8, lessThanFuncFor(u8)); + _ = Digraph(u16, lessThanFuncFor(u16)); + _ = Digraph(i8, lessThanFuncFor(i8)); + _ = Digraph(i16, lessThanFuncFor(i16)); + _ = Digraph(*i8, lessThanFuncFor(*i8)); + _ = Digraph(*anyopaque, lessThanFuncFor(*anyopaque)); +} test "check if connected" { - const Sut = Digraph(u8); + const Sut = Digraph(u8, lessThanFuncFor(u8)); const map = [_]Sut.Conn { .{ .from = 3, .to = 0, }, @@ -155,39 +181,39 @@ test "check if connected" { var sut = try Sut.init(std.testing.allocator, map[0..]); defer sut.deinit(); - try std.testing.expect(sut.is_connected(0, 1)); - try std.testing.expect(!sut.is_connected(1, 0)); + try std.testing.expect(sut.isConnected(0, 1)); + try std.testing.expect(!sut.isConnected(1, 0)); - try std.testing.expect(sut.is_connected(1, 3)); - try std.testing.expect(!sut.is_connected(3, 1)); + try std.testing.expect(sut.isConnected(1, 3)); + try std.testing.expect(!sut.isConnected(3, 1)); - try std.testing.expect(sut.is_connected(3, 0)); - try std.testing.expect(!sut.is_connected(0, 3)); + try std.testing.expect(sut.isConnected(3, 0)); + try std.testing.expect(!sut.isConnected(0, 3)); - try std.testing.expect(!sut.is_connected(0, 2)); - try std.testing.expect(!sut.is_connected(2, 0)); + try std.testing.expect(!sut.isConnected(0, 2)); + try std.testing.expect(!sut.isConnected(2, 0)); - try std.testing.expect(!sut.is_connected(1, 2)); - try std.testing.expect(!sut.is_connected(2, 1)); + try std.testing.expect(!sut.isConnected(1, 2)); + try std.testing.expect(!sut.isConnected(2, 1)); } test "make new connection" { - const Sut = Digraph(u8); + const Sut = Digraph(u8, lessThanFuncFor(u8)); var sut = try Sut.init(std.testing.allocator, &.{}); defer sut.deinit(); - try std.testing.expect(try sut.connect_if(2, 1)); + try std.testing.expect(try sut.connectIf(2, 1)); - try std.testing.expect(sut.is_connected(2, 1)); - try std.testing.expect(!sut.is_connected(1, 2)); + try std.testing.expect(sut.isConnected(2, 1)); + try std.testing.expect(!sut.isConnected(1, 2)); try sut.connect(3, 1); - try std.testing.expect(sut.is_connected(3, 1)); - try std.testing.expect(!sut.is_connected(1, 3)); + try std.testing.expect(sut.isConnected(3, 1)); + try std.testing.expect(!sut.isConnected(1, 3)); } test "making an existing connection fails" { - const Sut = Digraph(u8); + const Sut = Digraph(u8, lessThanFuncFor(u8)); const map = [_]Sut.Conn { .{ .from = 0, .to = 1, }, @@ -195,11 +221,11 @@ test "making an existing connection fails" { var sut = try Sut.init(std.testing.allocator, map[0..]); defer sut.deinit(); - try std.testing.expect(!try sut.connect_if(0, 1)); + try std.testing.expect(!try sut.connectIf(0, 1)); try std.testing.expectError(Sut.Error.AlreadyConnected, sut.connect(0, 1)); } test "disconnect an existing connection" { - const Sut = Digraph(u8); + const Sut = Digraph(u8, lessThanFuncFor(u8)); const map = [_]Sut.Conn { .{ .from = 0, .to = 1, }, @@ -208,23 +234,23 @@ test "disconnect an existing connection" { var sut = try Sut.init(std.testing.allocator, map[0..]); defer sut.deinit(); - try std.testing.expect(sut.disconnect_if(0, 1)); - try std.testing.expect(!sut.is_connected(0, 1)); + try std.testing.expect(sut.disconnectIf(0, 1)); + try std.testing.expect(!sut.isConnected(0, 1)); try sut.disconnect(2, 3); - try std.testing.expect(!sut.is_connected(2, 3)); + try std.testing.expect(!sut.isConnected(2, 3)); } test "disconnecting a missing connection fails" { - const Sut = Digraph(u8); + const Sut = Digraph(u8, lessThanFuncFor(u8)); var sut = try Sut.init(std.testing.allocator, &.{}); defer sut.deinit(); - try std.testing.expect(!sut.disconnect_if(0, 1)); + try std.testing.expect(!sut.disconnectIf(0, 1)); try std.testing.expectError(Sut.Error.NotConnected, sut.disconnect(1, 0)); } test "chaotic operation" { - const Sut = Digraph(u16); + const Sut = Digraph(u16, lessThanFuncFor(u16)); var sut = try Sut.init(std.testing.allocator, &.{}); defer sut.deinit(); @@ -240,10 +266,10 @@ test "chaotic operation" { } for (0..N/2) |v| { const x: Sut.Node = @intCast(v); - try std.testing.expect(sut.is_connected(x*%7, x*%13)); + try std.testing.expect(sut.isConnected(x*%7, x*%13)); } for (N/2..N) |v| { const x: Sut.Node = @intCast(v); - try std.testing.expect(!sut.is_connected(x*%7, x*%13)); + try std.testing.expect(!sut.isConnected(x*%7, x*%13)); } }