Skip to content

Commit

Permalink
Air: direct representation of ranges in switch cases
Browse files Browse the repository at this point in the history
This commit modifies the representation of the AIR `switch_br`
instruction to represent ranges in cases. Previously, Sema emitted
different AIR in the case of a range, where the `else` branch of the
`switch_br` contained a simple `cond_br` for each such case which did a
simple range check (`x > a and x < b`). Not only does this add
complexity to Sema, which -- as our secondary bottleneck -- we would
like to keep as small as possible, but it also gets in the way of the
implementation of ziglang#8220. This proposal turns certain `switch` statements
into a looping construct, and for optimization purposes, we want to
lower this to AIR fairly directly (i.e. without involving a `loop`
instruction). That means we would ideally like a single instruction to
represent the entire `switch` statement, so that we can dispatch back to
it with a different operand as in ziglang#8220. This is not really possible to
do correctly under the status quo system.

For now, the actual lowering of `switch` is identical for the LLVM and C
backends. This commit contains a TODO which temporarily regresseses all
remaining self-hosted backends in the presence of switch case ranges.
This functionality will be restored for at least the x86_64 backend
before merge of this branch.
  • Loading branch information
mlugg committed May 3, 2024
1 parent d27e73c commit 0d99a13
Show file tree
Hide file tree
Showing 12 changed files with 297 additions and 233 deletions.
8 changes: 5 additions & 3 deletions src/Air.zig
Original file line number Diff line number Diff line change
Expand Up @@ -1132,17 +1132,19 @@ pub const CondBr = struct {
};

/// Trailing:
/// * 0. `Case` for each `cases_len`
/// * 1. the else body, according to `else_body_len`.
/// * 0. case: Case // for each `cases_len`.
/// * 1. else_body_inst: Inst.Index // for each `else_body_len`.
pub const SwitchBr = struct {
cases_len: u32,
else_body_len: u32,

/// Trailing:
/// * item: Inst.Ref // for each `items_len`.
/// * instruction index for each `body_len`.
/// * { range_start: Inst.Ref, range_end: Inst.Ref } // for each `ranges_len`.
/// * body_inst: Inst.Index // for each `body_len`.
pub const Case = struct {
items_len: u32,
ranges_len: u32,
body_len: u32,
};
};
Expand Down
10 changes: 6 additions & 4 deletions src/Liveness.zig
Original file line number Diff line number Diff line change
Expand Up @@ -1681,8 +1681,9 @@ fn analyzeInstSwitchBr(
var air_extra_index: usize = switch_br.end;
for (0..ncases) |_| {
const case = a.air.extraData(Air.SwitchBr.Case, air_extra_index);
const case_body: []const Air.Inst.Index = @ptrCast(a.air.extra[case.end + case.data.items_len ..][0..case.data.body_len]);
air_extra_index = case.end + case.data.items_len + case_body.len;
air_extra_index = case.end + case.data.items_len + 2 * case.data.ranges_len;
const case_body: []const Air.Inst.Index = @ptrCast(a.air.extra[air_extra_index..][0..case.data.body_len]);
air_extra_index += case_body.len;
try analyzeBody(a, pass, data, case_body);
}
{ // else
Expand All @@ -1707,8 +1708,9 @@ fn analyzeInstSwitchBr(
var air_extra_index: usize = switch_br.end;
for (case_live_sets[0..ncases]) |*live_set| {
const case = a.air.extraData(Air.SwitchBr.Case, air_extra_index);
const case_body: []const Air.Inst.Index = @ptrCast(a.air.extra[case.end + case.data.items_len ..][0..case.data.body_len]);
air_extra_index = case.end + case.data.items_len + case_body.len;
air_extra_index = case.end + case.data.items_len + 2 * case.data.ranges_len;
const case_body: []const Air.Inst.Index = @ptrCast(a.air.extra[air_extra_index..][0..case.data.body_len]);
air_extra_index += case_body.len;
try analyzeBody(a, pass, data, case_body);
live_set.* = data.live_set.move();
}
Expand Down
9 changes: 3 additions & 6 deletions src/Liveness/Verify.zig
Original file line number Diff line number Diff line change
Expand Up @@ -526,12 +526,9 @@ fn verifyBody(self: *Verify, body: []const Air.Inst.Index) Error!void {

while (case_i < switch_br.data.cases_len) : (case_i += 1) {
const case = self.air.extraData(Air.SwitchBr.Case, extra_index);
const items = @as(
[]const Air.Inst.Ref,
@ptrCast(self.air.extra[case.end..][0..case.data.items_len]),
);
const case_body: []const Air.Inst.Index = @ptrCast(self.air.extra[case.end + items.len ..][0..case.data.body_len]);
extra_index = case.end + items.len + case_body.len;
extra_index = case.end + case.data.items_len + case.data.ranges_len * 2;
const case_body: []const Air.Inst.Index = @ptrCast(self.air.extra[extra_index..][0..case.data.body_len]);
extra_index += case_body.len;

self.live.deinit(self.gpa);
self.live = try live.clone(self.gpa);
Expand Down
334 changes: 138 additions & 196 deletions src/Sema.zig

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion src/arch/aarch64/CodeGen.zig
Original file line number Diff line number Diff line change
Expand Up @@ -5091,7 +5091,8 @@ fn airSwitch(self: *Self, inst: Air.Inst.Index) !void {
var case_i: u32 = 0;
while (case_i < switch_br.data.cases_len) : (case_i += 1) {
const case = self.air.extraData(Air.SwitchBr.Case, extra_index);
const items = @as([]const Air.Inst.Ref, @ptrCast(self.air.extra[case.end..][0..case.data.items_len]));
if (case.data.ranges_len > 0) return self.fail("TODO: switch with ranges", .{});
const items: []const Air.Inst.Ref = @ptrCast(self.air.extra[case.end..][0..case.data.items_len]);
assert(items.len > 0);
const case_body: []const Air.Inst.Index = @ptrCast(self.air.extra[case.end + items.len ..][0..case.data.body_len]);
extra_index = case.end + items.len + case_body.len;
Expand Down
1 change: 1 addition & 0 deletions src/arch/arm/CodeGen.zig
Original file line number Diff line number Diff line change
Expand Up @@ -5022,6 +5022,7 @@ fn airSwitch(self: *Self, inst: Air.Inst.Index) !void {
var case_i: u32 = 0;
while (case_i < switch_br.data.cases_len) : (case_i += 1) {
const case = self.air.extraData(Air.SwitchBr.Case, extra_index);
if (case.data.ranges_len > 0) return self.fail("TODO: switch with ranges", .{});
const items: []const Air.Inst.Ref = @ptrCast(self.air.extra[case.end..][0..case.data.items_len]);
assert(items.len > 0);
const case_body: []const Air.Inst.Index = @ptrCast(self.air.extra[case.end + items.len ..][0..case.data.body_len]);
Expand Down
1 change: 1 addition & 0 deletions src/arch/wasm/CodeGen.zig
Original file line number Diff line number Diff line change
Expand Up @@ -4015,6 +4015,7 @@ fn airSwitchBr(func: *CodeGen, inst: Air.Inst.Index) InnerError!void {
var highest_maybe: ?i32 = null;
while (case_i < switch_br.data.cases_len) : (case_i += 1) {
const case = func.air.extraData(Air.SwitchBr.Case, extra_index);
if (case.data.ranges_len != 0) return func.fail("TODO: switch with ranges", .{});
const items: []const Air.Inst.Ref = @ptrCast(func.air.extra[case.end..][0..case.data.items_len]);
const case_body: []const Air.Inst.Index = @ptrCast(func.air.extra[case.end + items.len ..][0..case.data.body_len]);
extra_index = case.end + items.len + case_body.len;
Expand Down
1 change: 1 addition & 0 deletions src/arch/x86_64/CodeGen.zig
Original file line number Diff line number Diff line change
Expand Up @@ -13464,6 +13464,7 @@ fn airSwitchBr(self: *Self, inst: Air.Inst.Index) !void {

while (case_i < switch_br.data.cases_len) : (case_i += 1) {
const case = self.air.extraData(Air.SwitchBr.Case, extra_index);
if (case.data.ranges_len > 0) return self.fail("TODO: switch with ranges", .{});
const items: []const Air.Inst.Ref =
@ptrCast(self.air.extra[case.end..][0..case.data.items_len]);
const case_body: []const Air.Inst.Index =
Expand Down
73 changes: 57 additions & 16 deletions src/codegen/c.zig
Original file line number Diff line number Diff line change
Expand Up @@ -5049,14 +5049,16 @@ fn airSwitchBr(f: *Function, inst: Air.Inst.Index) !CValue {
const liveness = try f.liveness.getSwitchBr(gpa, inst, switch_br.data.cases_len + 1);
defer gpa.free(liveness.deaths);

// On the final iteration we do not need to fix any state. This is because, like in the `else`
// branch of a `cond_br`, our parent has to do it for this entire body anyway.
const last_case_i = switch_br.data.cases_len - @intFromBool(switch_br.data.else_body_len == 0);

var any_range_cases = false;
var extra_index: usize = switch_br.end;
for (0..switch_br.data.cases_len) |case_i| {
const case = f.air.extraData(Air.SwitchBr.Case, extra_index);
const items = @as([]const Air.Inst.Ref, @ptrCast(f.air.extra[case.end..][0..case.data.items_len]));
if (case.data.ranges_len != 0) {
any_range_cases = true;
extra_index = case.end + case.data.items_len + case.data.ranges_len * 2 + case.data.body_len;
continue;
}
const items: []const Air.Inst.Ref = @ptrCast(f.air.extra[case.end..][0..case.data.items_len]);
const case_body: []const Air.Inst.Index =
@ptrCast(f.air.extra[case.end + items.len ..][0..case.data.body_len]);
extra_index = case.end + case.data.items_len + case_body.len;
Expand All @@ -5079,30 +5081,69 @@ fn airSwitchBr(f: *Function, inst: Air.Inst.Index) !CValue {
}
try writer.writeByte(' ');

if (case_i != last_case_i) {
try genBodyResolveState(f, inst, liveness.deaths[case_i], case_body, false);
} else {
for (liveness.deaths[case_i]) |death| {
try die(f, inst, death.toRef());
}
try genBody(f, case_body);
}
try genBodyResolveState(f, inst, liveness.deaths[case_i], case_body, false);

// The case body must be noreturn so we don't need to insert a break.
}

const else_body: []const Air.Inst.Index = @ptrCast(f.air.extra[extra_index..][0..switch_br.data.else_body_len]);
try f.object.indent_writer.insertNewline();

try writer.writeAll("default: ");
if (any_range_cases) {
// We will iterate the cases again to handle those with ranges, and generate
// code using conditionals rather than switch cases for such cases.
extra_index = switch_br.end;
for (0..switch_br.data.cases_len) |case_i| {
const case = f.air.extraData(Air.SwitchBr.Case, extra_index);
if (case.data.ranges_len == 0) {
// No ranges, so handled above - skip this case.
extra_index = case.end + case.data.items_len + case.data.body_len;
continue;
}
extra_index = case.end;
const items: []const Air.Inst.Ref = @ptrCast(f.air.extra[extra_index..][0..case.data.items_len]);
extra_index += items.len;
// TODO: this can be written more cleanly once Sema allows @ptrCast on slices where the length changes.
const ranges: []const [2]Air.Inst.Ref = @as([*]const [2]Air.Inst.Ref, @ptrCast(f.air.extra[extra_index..].ptr))[0..case.data.ranges_len];
extra_index += ranges.len * 2;
const case_body: []const Air.Inst.Index = @ptrCast(f.air.extra[extra_index..][0..case.data.body_len]);
extra_index += case_body.len;
try writer.writeAll("if (");
for (items, 0..) |item, item_i| {
if (item_i != 0) try writer.writeAll(" || ");
try f.writeCValue(writer, condition, .Other);
try writer.writeAll(" == ");
try f.object.dg.renderValue(writer, (try f.air.value(item, zcu)).?, .Other);
}
for (ranges, 0..) |range, range_i| {
if (items.len != 0 or range_i != 0) try writer.writeAll(" || ");
// "(x >= lower && x <= upper)"
try writer.writeByte('(');
try f.writeCValue(writer, condition, .Other);
try writer.writeAll(" >= ");
try f.object.dg.renderValue(writer, (try f.air.value(range[0], zcu)).?, .Other);
try writer.writeAll(" && ");
try f.writeCValue(writer, condition, .Other);
try writer.writeAll(" <= ");
try f.object.dg.renderValue(writer, (try f.air.value(range[1], zcu)).?, .Other);
try writer.writeByte(')');
}
try writer.writeAll(") ");
try genBodyResolveState(f, inst, liveness.deaths[case_i], case_body, false);
}
}
if (else_body.len > 0) {
// Note that this must be the last case (i.e. the `last_case_i` case was not hit above)
// Note that this must be the last case, so we do not need to use `caseBodyResolveState` since
// the parent block will do it (because the case body is noreturn).
for (liveness.deaths[liveness.deaths.len - 1]) |death| {
try die(f, inst, death.toRef());
}
try writer.writeAll("default: ");
try genBody(f, else_body);
} else {
try writer.writeAll("default: zig_unreachable();");
try writer.writeAll("zig_unreachable();");
}

try f.object.indent_writer.insertNewline();

f.object.indent_writer.popIndent();
Expand Down
72 changes: 68 additions & 4 deletions src/codegen/llvm.zig
Original file line number Diff line number Diff line change
Expand Up @@ -6096,10 +6096,17 @@ pub const FuncGen = struct {
cond;

var extra_index: usize = switch_br.end;
var case_i: u32 = 0;
var any_range_cases = false;
var llvm_cases_len: u32 = 0;
while (case_i < switch_br.data.cases_len) : (case_i += 1) {
for (0..switch_br.data.cases_len) |_| {
const case = self.air.extraData(Air.SwitchBr.Case, extra_index);
if (case.data.ranges_len != 0) {
// TODO: for ranges, we could still define any scalar cases in the same prong within
// the switch, just directing it to the same bb as the range check.
any_range_cases = true;
extra_index = case.end + case.data.items_len + case.data.ranges_len * 2 + case.data.body_len;
continue;
}
const items: []const Air.Inst.Ref =
@ptrCast(self.air.extra[case.end..][0..case.data.items_len]);
const case_body = self.air.extra[case.end + items.len ..][0..case.data.body_len];
Expand All @@ -6112,9 +6119,12 @@ pub const FuncGen = struct {
defer wip_switch.finish(&self.wip);

extra_index = switch_br.end;
case_i = 0;
while (case_i < switch_br.data.cases_len) : (case_i += 1) {
for (0..switch_br.data.cases_len) |_| {
const case = self.air.extraData(Air.SwitchBr.Case, extra_index);
if (case.data.ranges_len != 0) {
extra_index = case.end + case.data.items_len + case.data.ranges_len * 2 + case.data.body_len;
continue;
}
const items: []const Air.Inst.Ref =
@ptrCast(self.air.extra[case.end..][0..case.data.items_len]);
const case_body: []const Air.Inst.Index = @ptrCast(self.air.extra[case.end + items.len ..][0..case.data.body_len]);
Expand All @@ -6137,6 +6147,60 @@ pub const FuncGen = struct {

self.wip.cursor = .{ .block = else_block };
const else_body: []const Air.Inst.Index = @ptrCast(self.air.extra[extra_index..][0..switch_br.data.else_body_len]);
if (any_range_cases) {
// We will iterate the cases again to handle those with ranges, and generate
// code using conditionals rather than switch cases for such cases.
const cond_ty = self.typeOf(pl_op.operand);
extra_index = switch_br.end;
for (0..switch_br.data.cases_len) |_| {
const case = self.air.extraData(Air.SwitchBr.Case, extra_index);
if (case.data.ranges_len == 0) {
// No ranges, so handled above - skip this case.
extra_index = case.end + case.data.items_len + case.data.body_len;
continue;
}
extra_index = case.end;
const items: []const Air.Inst.Ref = @ptrCast(self.air.extra[extra_index..][0..case.data.items_len]);
extra_index += items.len;
// TODO: this can be written more cleanly once Sema allows @ptrCast on slices where the length changes.
const ranges: []const [2]Air.Inst.Ref = @as([*]const [2]Air.Inst.Ref, @ptrCast(self.air.extra[extra_index..].ptr))[0..case.data.ranges_len];
extra_index += ranges.len * 2;
const case_body: []const Air.Inst.Index = @ptrCast(self.air.extra[extra_index..][0..case.data.body_len]);
extra_index += case_body.len;

var range_cond: ?Builder.Value = null;

for (items) |item| {
const llvm_item = try self.resolveInst(item);
const cond_part = try self.cmp(.normal, .eq, cond_ty, cond, llvm_item);
if (range_cond) |old| {
range_cond = try self.wip.bin(.@"or", old, cond_part, "");
} else range_cond = cond_part;
}
for (ranges) |range| {
const llvm_min = try self.resolveInst(range[0]);
const llvm_max = try self.resolveInst(range[1]);
const cond_part = try self.wip.bin(
.@"and",
try self.cmp(.normal, .gte, cond_ty, cond, llvm_min),
try self.cmp(.normal, .lte, cond_ty, cond, llvm_max),
"",
);
if (range_cond) |old| {
range_cond = try self.wip.bin(.@"or", old, cond_part, "");
} else range_cond = cond_part;
}

const range_case_block = try self.wip.block(1, "RangeCase");
const range_else_block = try self.wip.block(1, "RangeDefault");

_ = try self.wip.brCond(range_cond.?, range_case_block, range_else_block);

self.wip.cursor = .{ .block = range_case_block };
try self.genBodyDebugScope(null, case_body);
self.wip.cursor = .{ .block = range_else_block };
}
}
if (else_body.len != 0) {
try self.genBodyDebugScope(null, else_body);
} else {
Expand Down
1 change: 1 addition & 0 deletions src/codegen/spirv.zig
Original file line number Diff line number Diff line change
Expand Up @@ -5456,6 +5456,7 @@ const DeclGen = struct {
var num_conditions: u32 = 0;
for (0..num_cases) |_| {
const case = self.air.extraData(Air.SwitchBr.Case, extra_index);
if (case.data.ranges_len != 0) return self.fail("TODO: switch with ranges", .{});
const case_body = self.air.extra[case.end + case.data.items_len ..][0..case.data.body_len];
extra_index = case.end + case.data.items_len + case_body.len;
num_conditions += case.data.items_len;
Expand Down
17 changes: 14 additions & 3 deletions src/print_air.zig
Original file line number Diff line number Diff line change
Expand Up @@ -843,15 +843,26 @@ const Writer = struct {

while (case_i < switch_br.data.cases_len) : (case_i += 1) {
const case = w.air.extraData(Air.SwitchBr.Case, extra_index);
const items = @as([]const Air.Inst.Ref, @ptrCast(w.air.extra[case.end..][0..case.data.items_len]));
const case_body: []const Air.Inst.Index = @ptrCast(w.air.extra[case.end + items.len ..][0..case.data.body_len]);
extra_index = case.end + case.data.items_len + case_body.len;
extra_index = case.end;
const items: []const Air.Inst.Ref = @ptrCast(w.air.extra[extra_index..][0..case.data.items_len]);
extra_index += items.len;
// TODO: this can be written more cleanly once Sema allows @ptrCast on slices where the length changes.
const ranges: []const [2]Air.Inst.Ref = @as([*]const [2]Air.Inst.Ref, @ptrCast(w.air.extra[extra_index..].ptr))[0..case.data.ranges_len];
extra_index += case.data.ranges_len * 2;
const case_body: []const Air.Inst.Index = @ptrCast(w.air.extra[extra_index..][0..case.data.body_len]);
extra_index += case_body.len;

try s.writeAll(", [");
for (items, 0..) |item, item_i| {
if (item_i != 0) try s.writeAll(", ");
try w.writeInstRef(s, item, false);
}
for (ranges, 0..) |range, range_i| {
if (items.len != 0 or range_i != 0) try s.writeAll(", ");
try w.writeInstRef(s, range[0], false);
try s.writeAll("..");
try w.writeInstRef(s, range[1], false);
}
try s.writeAll("] => {\n");
w.indent += 2;

Expand Down

0 comments on commit 0d99a13

Please sign in to comment.