diff --git a/src/cmd/compile/internal/ssa/_gen/generic.rules b/src/cmd/compile/internal/ssa/_gen/generic.rules index d72824c4bf..6ede0fb200 100644 --- a/src/cmd/compile/internal/ssa/_gen/generic.rules +++ b/src/cmd/compile/internal/ssa/_gen/generic.rules @@ -518,8 +518,17 @@ (Leq32 (Const32 [0]) (Rsh32Ux64 _ (Const64 [c]))) && c > 0 => (ConstBool [true]) (Leq64 (Const64 [0]) (Rsh64Ux64 _ (Const64 [c]))) && c > 0 => (ConstBool [true]) +// prefer equalities with zero (Less(64|32|16|8) (Const(64|32|16|8) [0]) x) && isNonNegative(x) => (Neq(64|32|16|8) (Const(64|32|16|8) [0]) x) (Less(64|32|16|8) x (Const(64|32|16|8) [1])) && isNonNegative(x) => (Eq(64|32|16|8) (Const(64|32|16|8) [0]) x) +(Less(64|32|16|8)U x (Const(64|32|16|8) [1])) => (Eq(64|32|16|8) (Const(64|32|16|8) [0]) x) +(Leq(64|32|16|8)U (Const(64|32|16|8) [1]) x) => (Neq(64|32|16|8) (Const(64|32|16|8) [0]) x) + +// prefer comparisons with zero +(Less(64|32|16|8) x (Const(64|32|16|8) [1])) => (Leq(64|32|16|8) x (Const(64|32|16|8) [0])) +(Leq(64|32|16|8) x (Const(64|32|16|8) [-1])) => (Less(64|32|16|8) x (Const(64|32|16|8) [0])) +(Leq(64|32|16|8) (Const(64|32|16|8) [1]) x) => (Less(64|32|16|8) (Const(64|32|16|8) [0]) x) +(Less(64|32|16|8) (Const(64|32|16|8) [-1]) x) => (Leq(64|32|16|8) (Const(64|32|16|8) [0]) x) // constant floating point comparisons (Eq32F (Const32F [c]) (Const32F [d])) => (ConstBool [c == d]) diff --git a/src/cmd/compile/internal/ssa/rewritegeneric.go b/src/cmd/compile/internal/ssa/rewritegeneric.go index 49a721b5f2..ceb52fa5fd 100644 --- a/src/cmd/compile/internal/ssa/rewritegeneric.go +++ b/src/cmd/compile/internal/ssa/rewritegeneric.go @@ -11137,6 +11137,7 @@ func rewriteValuegeneric_OpIsSliceInBounds(v *Value) bool { func rewriteValuegeneric_OpLeq16(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] + b := v.Block // match: (Leq16 (Const16 [c]) (Const16 [d])) // result: (ConstBool [c <= d]) for { @@ -11196,11 +11197,46 @@ func rewriteValuegeneric_OpLeq16(v *Value) bool { v.AuxInt = boolToAuxInt(true) return true } + // match: (Leq16 x (Const16 [-1])) + // result: (Less16 x (Const16 [0])) + for { + x := v_0 + if v_1.Op != OpConst16 { + break + } + t := v_1.Type + if auxIntToInt16(v_1.AuxInt) != -1 { + break + } + v.reset(OpLess16) + v0 := b.NewValue0(v.Pos, OpConst16, t) + v0.AuxInt = int16ToAuxInt(0) + v.AddArg2(x, v0) + return true + } + // match: (Leq16 (Const16 [1]) x) + // result: (Less16 (Const16 [0]) x) + for { + if v_0.Op != OpConst16 { + break + } + t := v_0.Type + if auxIntToInt16(v_0.AuxInt) != 1 { + break + } + x := v_1 + v.reset(OpLess16) + v0 := b.NewValue0(v.Pos, OpConst16, t) + v0.AuxInt = int16ToAuxInt(0) + v.AddArg2(v0, x) + return true + } return false } func rewriteValuegeneric_OpLeq16U(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] + b := v.Block // match: (Leq16U (Const16 [c]) (Const16 [d])) // result: (ConstBool [uint16(c) <= uint16(d)]) for { @@ -11216,6 +11252,23 @@ func rewriteValuegeneric_OpLeq16U(v *Value) bool { v.AuxInt = boolToAuxInt(uint16(c) <= uint16(d)) return true } + // match: (Leq16U (Const16 [1]) x) + // result: (Neq16 (Const16 [0]) x) + for { + if v_0.Op != OpConst16 { + break + } + t := v_0.Type + if auxIntToInt16(v_0.AuxInt) != 1 { + break + } + x := v_1 + v.reset(OpNeq16) + v0 := b.NewValue0(v.Pos, OpConst16, t) + v0.AuxInt = int16ToAuxInt(0) + v.AddArg2(v0, x) + return true + } // match: (Leq16U (Const16 [0]) _) // result: (ConstBool [true]) for { @@ -11231,6 +11284,7 @@ func rewriteValuegeneric_OpLeq16U(v *Value) bool { func rewriteValuegeneric_OpLeq32(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] + b := v.Block // match: (Leq32 (Const32 [c]) (Const32 [d])) // result: (ConstBool [c <= d]) for { @@ -11290,6 +11344,40 @@ func rewriteValuegeneric_OpLeq32(v *Value) bool { v.AuxInt = boolToAuxInt(true) return true } + // match: (Leq32 x (Const32 [-1])) + // result: (Less32 x (Const32 [0])) + for { + x := v_0 + if v_1.Op != OpConst32 { + break + } + t := v_1.Type + if auxIntToInt32(v_1.AuxInt) != -1 { + break + } + v.reset(OpLess32) + v0 := b.NewValue0(v.Pos, OpConst32, t) + v0.AuxInt = int32ToAuxInt(0) + v.AddArg2(x, v0) + return true + } + // match: (Leq32 (Const32 [1]) x) + // result: (Less32 (Const32 [0]) x) + for { + if v_0.Op != OpConst32 { + break + } + t := v_0.Type + if auxIntToInt32(v_0.AuxInt) != 1 { + break + } + x := v_1 + v.reset(OpLess32) + v0 := b.NewValue0(v.Pos, OpConst32, t) + v0.AuxInt = int32ToAuxInt(0) + v.AddArg2(v0, x) + return true + } return false } func rewriteValuegeneric_OpLeq32F(v *Value) bool { @@ -11315,6 +11403,7 @@ func rewriteValuegeneric_OpLeq32F(v *Value) bool { func rewriteValuegeneric_OpLeq32U(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] + b := v.Block // match: (Leq32U (Const32 [c]) (Const32 [d])) // result: (ConstBool [uint32(c) <= uint32(d)]) for { @@ -11330,6 +11419,23 @@ func rewriteValuegeneric_OpLeq32U(v *Value) bool { v.AuxInt = boolToAuxInt(uint32(c) <= uint32(d)) return true } + // match: (Leq32U (Const32 [1]) x) + // result: (Neq32 (Const32 [0]) x) + for { + if v_0.Op != OpConst32 { + break + } + t := v_0.Type + if auxIntToInt32(v_0.AuxInt) != 1 { + break + } + x := v_1 + v.reset(OpNeq32) + v0 := b.NewValue0(v.Pos, OpConst32, t) + v0.AuxInt = int32ToAuxInt(0) + v.AddArg2(v0, x) + return true + } // match: (Leq32U (Const32 [0]) _) // result: (ConstBool [true]) for { @@ -11345,6 +11451,7 @@ func rewriteValuegeneric_OpLeq32U(v *Value) bool { func rewriteValuegeneric_OpLeq64(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] + b := v.Block // match: (Leq64 (Const64 [c]) (Const64 [d])) // result: (ConstBool [c <= d]) for { @@ -11404,6 +11511,40 @@ func rewriteValuegeneric_OpLeq64(v *Value) bool { v.AuxInt = boolToAuxInt(true) return true } + // match: (Leq64 x (Const64 [-1])) + // result: (Less64 x (Const64 [0])) + for { + x := v_0 + if v_1.Op != OpConst64 { + break + } + t := v_1.Type + if auxIntToInt64(v_1.AuxInt) != -1 { + break + } + v.reset(OpLess64) + v0 := b.NewValue0(v.Pos, OpConst64, t) + v0.AuxInt = int64ToAuxInt(0) + v.AddArg2(x, v0) + return true + } + // match: (Leq64 (Const64 [1]) x) + // result: (Less64 (Const64 [0]) x) + for { + if v_0.Op != OpConst64 { + break + } + t := v_0.Type + if auxIntToInt64(v_0.AuxInt) != 1 { + break + } + x := v_1 + v.reset(OpLess64) + v0 := b.NewValue0(v.Pos, OpConst64, t) + v0.AuxInt = int64ToAuxInt(0) + v.AddArg2(v0, x) + return true + } return false } func rewriteValuegeneric_OpLeq64F(v *Value) bool { @@ -11429,6 +11570,7 @@ func rewriteValuegeneric_OpLeq64F(v *Value) bool { func rewriteValuegeneric_OpLeq64U(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] + b := v.Block // match: (Leq64U (Const64 [c]) (Const64 [d])) // result: (ConstBool [uint64(c) <= uint64(d)]) for { @@ -11444,6 +11586,23 @@ func rewriteValuegeneric_OpLeq64U(v *Value) bool { v.AuxInt = boolToAuxInt(uint64(c) <= uint64(d)) return true } + // match: (Leq64U (Const64 [1]) x) + // result: (Neq64 (Const64 [0]) x) + for { + if v_0.Op != OpConst64 { + break + } + t := v_0.Type + if auxIntToInt64(v_0.AuxInt) != 1 { + break + } + x := v_1 + v.reset(OpNeq64) + v0 := b.NewValue0(v.Pos, OpConst64, t) + v0.AuxInt = int64ToAuxInt(0) + v.AddArg2(v0, x) + return true + } // match: (Leq64U (Const64 [0]) _) // result: (ConstBool [true]) for { @@ -11459,6 +11618,7 @@ func rewriteValuegeneric_OpLeq64U(v *Value) bool { func rewriteValuegeneric_OpLeq8(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] + b := v.Block // match: (Leq8 (Const8 [c]) (Const8 [d])) // result: (ConstBool [c <= d]) for { @@ -11518,11 +11678,46 @@ func rewriteValuegeneric_OpLeq8(v *Value) bool { v.AuxInt = boolToAuxInt(true) return true } + // match: (Leq8 x (Const8 [-1])) + // result: (Less8 x (Const8 [0])) + for { + x := v_0 + if v_1.Op != OpConst8 { + break + } + t := v_1.Type + if auxIntToInt8(v_1.AuxInt) != -1 { + break + } + v.reset(OpLess8) + v0 := b.NewValue0(v.Pos, OpConst8, t) + v0.AuxInt = int8ToAuxInt(0) + v.AddArg2(x, v0) + return true + } + // match: (Leq8 (Const8 [1]) x) + // result: (Less8 (Const8 [0]) x) + for { + if v_0.Op != OpConst8 { + break + } + t := v_0.Type + if auxIntToInt8(v_0.AuxInt) != 1 { + break + } + x := v_1 + v.reset(OpLess8) + v0 := b.NewValue0(v.Pos, OpConst8, t) + v0.AuxInt = int8ToAuxInt(0) + v.AddArg2(v0, x) + return true + } return false } func rewriteValuegeneric_OpLeq8U(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] + b := v.Block // match: (Leq8U (Const8 [c]) (Const8 [d])) // result: (ConstBool [ uint8(c) <= uint8(d)]) for { @@ -11538,6 +11733,23 @@ func rewriteValuegeneric_OpLeq8U(v *Value) bool { v.AuxInt = boolToAuxInt(uint8(c) <= uint8(d)) return true } + // match: (Leq8U (Const8 [1]) x) + // result: (Neq8 (Const8 [0]) x) + for { + if v_0.Op != OpConst8 { + break + } + t := v_0.Type + if auxIntToInt8(v_0.AuxInt) != 1 { + break + } + x := v_1 + v.reset(OpNeq8) + v0 := b.NewValue0(v.Pos, OpConst8, t) + v0.AuxInt = int8ToAuxInt(0) + v.AddArg2(v0, x) + return true + } // match: (Leq8U (Const8 [0]) _) // result: (ConstBool [true]) for { @@ -11608,11 +11820,46 @@ func rewriteValuegeneric_OpLess16(v *Value) bool { v.AddArg2(v0, x) return true } + // match: (Less16 x (Const16 [1])) + // result: (Leq16 x (Const16 [0])) + for { + x := v_0 + if v_1.Op != OpConst16 { + break + } + t := v_1.Type + if auxIntToInt16(v_1.AuxInt) != 1 { + break + } + v.reset(OpLeq16) + v0 := b.NewValue0(v.Pos, OpConst16, t) + v0.AuxInt = int16ToAuxInt(0) + v.AddArg2(x, v0) + return true + } + // match: (Less16 (Const16 [-1]) x) + // result: (Leq16 (Const16 [0]) x) + for { + if v_0.Op != OpConst16 { + break + } + t := v_0.Type + if auxIntToInt16(v_0.AuxInt) != -1 { + break + } + x := v_1 + v.reset(OpLeq16) + v0 := b.NewValue0(v.Pos, OpConst16, t) + v0.AuxInt = int16ToAuxInt(0) + v.AddArg2(v0, x) + return true + } return false } func rewriteValuegeneric_OpLess16U(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] + b := v.Block // match: (Less16U (Const16 [c]) (Const16 [d])) // result: (ConstBool [uint16(c) < uint16(d)]) for { @@ -11628,6 +11875,23 @@ func rewriteValuegeneric_OpLess16U(v *Value) bool { v.AuxInt = boolToAuxInt(uint16(c) < uint16(d)) return true } + // match: (Less16U x (Const16 [1])) + // result: (Eq16 (Const16 [0]) x) + for { + x := v_0 + if v_1.Op != OpConst16 { + break + } + t := v_1.Type + if auxIntToInt16(v_1.AuxInt) != 1 { + break + } + v.reset(OpEq16) + v0 := b.NewValue0(v.Pos, OpConst16, t) + v0.AuxInt = int16ToAuxInt(0) + v.AddArg2(v0, x) + return true + } // match: (Less16U _ (Const16 [0])) // result: (ConstBool [false]) for { @@ -11698,6 +11962,40 @@ func rewriteValuegeneric_OpLess32(v *Value) bool { v.AddArg2(v0, x) return true } + // match: (Less32 x (Const32 [1])) + // result: (Leq32 x (Const32 [0])) + for { + x := v_0 + if v_1.Op != OpConst32 { + break + } + t := v_1.Type + if auxIntToInt32(v_1.AuxInt) != 1 { + break + } + v.reset(OpLeq32) + v0 := b.NewValue0(v.Pos, OpConst32, t) + v0.AuxInt = int32ToAuxInt(0) + v.AddArg2(x, v0) + return true + } + // match: (Less32 (Const32 [-1]) x) + // result: (Leq32 (Const32 [0]) x) + for { + if v_0.Op != OpConst32 { + break + } + t := v_0.Type + if auxIntToInt32(v_0.AuxInt) != -1 { + break + } + x := v_1 + v.reset(OpLeq32) + v0 := b.NewValue0(v.Pos, OpConst32, t) + v0.AuxInt = int32ToAuxInt(0) + v.AddArg2(v0, x) + return true + } return false } func rewriteValuegeneric_OpLess32F(v *Value) bool { @@ -11723,6 +12021,7 @@ func rewriteValuegeneric_OpLess32F(v *Value) bool { func rewriteValuegeneric_OpLess32U(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] + b := v.Block // match: (Less32U (Const32 [c]) (Const32 [d])) // result: (ConstBool [uint32(c) < uint32(d)]) for { @@ -11738,6 +12037,23 @@ func rewriteValuegeneric_OpLess32U(v *Value) bool { v.AuxInt = boolToAuxInt(uint32(c) < uint32(d)) return true } + // match: (Less32U x (Const32 [1])) + // result: (Eq32 (Const32 [0]) x) + for { + x := v_0 + if v_1.Op != OpConst32 { + break + } + t := v_1.Type + if auxIntToInt32(v_1.AuxInt) != 1 { + break + } + v.reset(OpEq32) + v0 := b.NewValue0(v.Pos, OpConst32, t) + v0.AuxInt = int32ToAuxInt(0) + v.AddArg2(v0, x) + return true + } // match: (Less32U _ (Const32 [0])) // result: (ConstBool [false]) for { @@ -11808,6 +12124,40 @@ func rewriteValuegeneric_OpLess64(v *Value) bool { v.AddArg2(v0, x) return true } + // match: (Less64 x (Const64 [1])) + // result: (Leq64 x (Const64 [0])) + for { + x := v_0 + if v_1.Op != OpConst64 { + break + } + t := v_1.Type + if auxIntToInt64(v_1.AuxInt) != 1 { + break + } + v.reset(OpLeq64) + v0 := b.NewValue0(v.Pos, OpConst64, t) + v0.AuxInt = int64ToAuxInt(0) + v.AddArg2(x, v0) + return true + } + // match: (Less64 (Const64 [-1]) x) + // result: (Leq64 (Const64 [0]) x) + for { + if v_0.Op != OpConst64 { + break + } + t := v_0.Type + if auxIntToInt64(v_0.AuxInt) != -1 { + break + } + x := v_1 + v.reset(OpLeq64) + v0 := b.NewValue0(v.Pos, OpConst64, t) + v0.AuxInt = int64ToAuxInt(0) + v.AddArg2(v0, x) + return true + } return false } func rewriteValuegeneric_OpLess64F(v *Value) bool { @@ -11833,6 +12183,7 @@ func rewriteValuegeneric_OpLess64F(v *Value) bool { func rewriteValuegeneric_OpLess64U(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] + b := v.Block // match: (Less64U (Const64 [c]) (Const64 [d])) // result: (ConstBool [uint64(c) < uint64(d)]) for { @@ -11848,6 +12199,23 @@ func rewriteValuegeneric_OpLess64U(v *Value) bool { v.AuxInt = boolToAuxInt(uint64(c) < uint64(d)) return true } + // match: (Less64U x (Const64 [1])) + // result: (Eq64 (Const64 [0]) x) + for { + x := v_0 + if v_1.Op != OpConst64 { + break + } + t := v_1.Type + if auxIntToInt64(v_1.AuxInt) != 1 { + break + } + v.reset(OpEq64) + v0 := b.NewValue0(v.Pos, OpConst64, t) + v0.AuxInt = int64ToAuxInt(0) + v.AddArg2(v0, x) + return true + } // match: (Less64U _ (Const64 [0])) // result: (ConstBool [false]) for { @@ -11918,11 +12286,46 @@ func rewriteValuegeneric_OpLess8(v *Value) bool { v.AddArg2(v0, x) return true } + // match: (Less8 x (Const8 [1])) + // result: (Leq8 x (Const8 [0])) + for { + x := v_0 + if v_1.Op != OpConst8 { + break + } + t := v_1.Type + if auxIntToInt8(v_1.AuxInt) != 1 { + break + } + v.reset(OpLeq8) + v0 := b.NewValue0(v.Pos, OpConst8, t) + v0.AuxInt = int8ToAuxInt(0) + v.AddArg2(x, v0) + return true + } + // match: (Less8 (Const8 [-1]) x) + // result: (Leq8 (Const8 [0]) x) + for { + if v_0.Op != OpConst8 { + break + } + t := v_0.Type + if auxIntToInt8(v_0.AuxInt) != -1 { + break + } + x := v_1 + v.reset(OpLeq8) + v0 := b.NewValue0(v.Pos, OpConst8, t) + v0.AuxInt = int8ToAuxInt(0) + v.AddArg2(v0, x) + return true + } return false } func rewriteValuegeneric_OpLess8U(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] + b := v.Block // match: (Less8U (Const8 [c]) (Const8 [d])) // result: (ConstBool [ uint8(c) < uint8(d)]) for { @@ -11938,6 +12341,23 @@ func rewriteValuegeneric_OpLess8U(v *Value) bool { v.AuxInt = boolToAuxInt(uint8(c) < uint8(d)) return true } + // match: (Less8U x (Const8 [1])) + // result: (Eq8 (Const8 [0]) x) + for { + x := v_0 + if v_1.Op != OpConst8 { + break + } + t := v_1.Type + if auxIntToInt8(v_1.AuxInt) != 1 { + break + } + v.reset(OpEq8) + v0 := b.NewValue0(v.Pos, OpConst8, t) + v0.AuxInt = int8ToAuxInt(0) + v.AddArg2(v0, x) + return true + } // match: (Less8U _ (Const8 [0])) // result: (ConstBool [false]) for { diff --git a/test/codegen/compare_and_branch.go b/test/codegen/compare_and_branch.go index f7515064b0..b3feef0eb7 100644 --- a/test/codegen/compare_and_branch.go +++ b/test/codegen/compare_and_branch.go @@ -204,3 +204,39 @@ func ui32xu8(x chan uint32) { dummy() } } + +// Signed 64-bit comparison with 1/-1 to comparison with 0. +func si64x0(x chan int64) { + // riscv64:"BGTZ" + for <-x >= 1 { + dummy() + } + + // riscv64:"BLEZ" + for <-x < 1 { + dummy() + } + + // riscv64:"BLTZ" + for <-x <= -1 { + dummy() + } + + // riscv64:"BGEZ" + for <-x > -1 { + dummy() + } +} + +// Unsigned 64-bit comparison with 1 to comparison with 0. +func ui64x0(x chan uint64) { + // riscv64:"BNEZ" + for <-x >= 1 { + dummy() + } + + // riscv64:"BEQZ" + for <-x < 1 { + dummy() + } +} diff --git a/test/prove.go b/test/prove.go index cf225ff38e..00bc0a315f 100644 --- a/test/prove.go +++ b/test/prove.go @@ -820,7 +820,7 @@ func unrollDownExcl0(a []int) int { // Induction variable in unrolled loop. func unrollDownExcl1(a []int) int { var i, x int - for i = len(a) - 1; i >= 1; i -= 2 { // ERROR "Induction variable: limits \[1,\?\], increment 2$" + for i = len(a) - 1; i >= 1; i -= 2 { // ERROR "Induction variable: limits \(0,\?\], increment 2$" x += a[i] // ERROR "Proved IsInBounds$" x += a[i-1] // ERROR "Proved IsInBounds$" }