cmd/compile: handle boolean and pointer relations

The constant lattice for these types is pretty simple.
We no longer need the old-style facts table, as the ordering
table now has all that information.

Change-Id: If0e118c27a4de8e9bfd727b78942185c2eb50c4b
Reviewed-on: https://go-review.googlesource.com/c/go/+/599097
Reviewed-by: David Chase <drchase@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
This commit is contained in:
khr@golang.org 2024-07-07 14:58:47 -07:00 committed by Keith Randall
parent a4a130f6d0
commit 5925cd3d15
3 changed files with 263 additions and 132 deletions

View File

@ -128,9 +128,18 @@ type fact struct {
}
// a limit records known upper and lower bounds for a value.
//
// If we have min>max or umin>umax, then this limit is
// called "unsatisfiable". When we encounter such a limit, we
// know that any code for which that limit applies is unreachable.
// We don't particularly care how unsatisfiable limits propagate,
// including becoming satisfiable, because any optimization
// decisions based on those limits only apply to unreachable code.
type limit struct {
min, max int64 // min <= value <= max, signed
umin, umax uint64 // umin <= value <= umax, unsigned
// For booleans, we use 0==false, 1==true for both ranges
// For pointers, we use 0,0,0,0 for nil and minInt64,maxInt64,1,maxUint64 for nonnil
}
func (l limit) String() string {
@ -359,8 +368,9 @@ type ordering struct {
next *ordering // linked list of all known orderings for v.
// Note: v is implicit here, determined by which linked list it is in.
w *Value
d domain // one of signed or unsigned
d domain
r relation // one of ==,!=,<,<=,>,>=
// if d is boolean or pointer, r can only be ==, !=
}
// factsTable keeps track of relations between pairs of values.
@ -379,9 +389,6 @@ type factsTable struct {
unsat bool // true if facts contains a contradiction
unsatDepth int // number of unsat checkpoints
facts map[pair]relation // current known set of relation
stack []fact // previous sets of relations
// order* is a couple of partial order sets that record information
// about relations between SSA values in the signed and unsigned
// domain.
@ -423,8 +430,6 @@ func newFactsTable(f *Func) *factsTable {
ft.orderS.SetUnsigned(false)
ft.orderU.SetUnsigned(true)
ft.orderings = make(map[ID]*ordering)
ft.facts = make(map[pair]relation)
ft.stack = make([]fact, 4)
ft.limits = f.Cache.allocLimitSlice(f.NumValues())
for _, b := range f.Blocks {
for _, v := range b.Values {
@ -471,6 +476,21 @@ func (ft *factsTable) unsignedMinMax(v *Value, min, max uint64) bool {
return ft.newLimit(v, limit{min: math.MinInt64, max: math.MaxInt64, umin: min, umax: max})
}
func (ft *factsTable) booleanFalse(v *Value) bool {
return ft.newLimit(v, limit{min: 0, max: 0, umin: 0, umax: 0})
}
func (ft *factsTable) booleanTrue(v *Value) bool {
return ft.newLimit(v, limit{min: 1, max: 1, umin: 1, umax: 1})
}
func (ft *factsTable) pointerNil(v *Value) bool {
return ft.newLimit(v, limit{min: 0, max: 0, umin: 0, umax: 0})
}
func (ft *factsTable) pointerNonNil(v *Value) bool {
l := noLimit
l.umin = 1
return ft.newLimit(v, l)
}
// newLimit adds new limiting information for v.
// Returns true if the new limit added any new information.
func (ft *factsTable) newLimit(v *Value, newLim limit) bool {
@ -574,6 +594,38 @@ func (ft *factsTable) newLimit(v *Value, newLim limit) bool {
}
}
}
case boolean:
switch o.r {
case eq:
if lim.min == 0 && lim.max == 0 { // constant false
ft.booleanFalse(o.w)
}
if lim.min == 1 && lim.max == 1 { // constant true
ft.booleanTrue(o.w)
}
case lt | gt:
if lim.min == 0 && lim.max == 0 { // constant false
ft.booleanTrue(o.w)
}
if lim.min == 1 && lim.max == 1 { // constant true
ft.booleanFalse(o.w)
}
}
case pointer:
switch o.r {
case eq:
if lim.umax == 0 { // nil
ft.pointerNil(o.w)
}
if lim.umin > 0 { // non-nil
ft.pointerNonNil(o.w)
}
case lt | gt:
if lim.umax == 0 { // nil
ft.pointerNonNil(o.w)
}
// note: not equal to non-nil doesn't tell us anything.
}
}
}
@ -647,122 +699,163 @@ func (ft *factsTable) update(parent *Block, v, w *Value, d domain, r relation) {
ft.unsat = true
return
}
} else {
if lessByID(w, v) {
v, w = w, v
r = reverseBits[r]
}
p := pair{v, w, d}
oldR, ok := ft.facts[p]
if !ok {
if v == w {
oldR = eq
} else {
oldR = lt | eq | gt
}
if d == boolean || d == pointer {
for o := ft.orderings[v.ID]; o != nil; o = o.next {
if o.d == d && o.w == w {
// We already know a relationship between v and w.
// Either it is a duplicate, or it is a contradiction,
// as we only allow eq and lt|gt for these domains,
if o.r != r {
ft.unsat = true
}
return
}
}
// No changes compared to information already in facts table.
if oldR == r {
return
}
ft.stack = append(ft.stack, fact{p, oldR})
ft.facts[p] = oldR & r
// If this relation is not satisfiable, mark it and exit right away
if oldR&r == 0 {
if parent.Func.pass.debug > 2 {
parent.Func.Warnl(parent.Pos, "unsat %s %s %s", v, w, r)
}
ft.unsat = true
return
}
// TODO: this does not do transitive equality.
// We could use a poset like above, but somewhat degenerate (==,!= only).
ft.addOrdering(v, w, d, r)
ft.addOrdering(w, v, d, r) // note: reverseBits unnecessary for eq and lt|gt.
}
// Extract new constant limits based on the comparison.
if d == signed || d == unsigned {
vLimit := ft.limits[v.ID]
wLimit := ft.limits[w.ID]
// Note: all the +1/-1 below could overflow/underflow. Either will
// still generate correct results, it will just lead to imprecision.
// In fact if there is overflow/underflow, the corresponding
// code is unreachable because the known range is outside the range
// of the value's type.
switch d {
case signed:
switch r {
case eq: // v == w
ft.signedMinMax(v, wLimit.min, wLimit.max)
ft.signedMinMax(w, vLimit.min, vLimit.max)
case lt: // v < w
ft.signedMax(v, wLimit.max-1)
ft.signedMin(w, vLimit.min+1)
case lt | eq: // v <= w
ft.signedMax(v, wLimit.max)
ft.signedMin(w, vLimit.min)
case gt: // v > w
ft.signedMin(v, wLimit.min+1)
ft.signedMax(w, vLimit.max-1)
case gt | eq: // v >= w
ft.signedMin(v, wLimit.min)
ft.signedMax(w, vLimit.max)
case lt | gt: // v != w
if vLimit.min == vLimit.max { // v is a constant
c := vLimit.min
if wLimit.min == c {
ft.signedMin(w, c+1)
}
if wLimit.max == c {
ft.signedMax(w, c-1)
}
vLimit := ft.limits[v.ID]
wLimit := ft.limits[w.ID]
// Note: all the +1/-1 below could overflow/underflow. Either will
// still generate correct results, it will just lead to imprecision.
// In fact if there is overflow/underflow, the corresponding
// code is unreachable because the known range is outside the range
// of the value's type.
switch d {
case signed:
switch r {
case eq: // v == w
ft.signedMinMax(v, wLimit.min, wLimit.max)
ft.signedMinMax(w, vLimit.min, vLimit.max)
case lt: // v < w
ft.signedMax(v, wLimit.max-1)
ft.signedMin(w, vLimit.min+1)
case lt | eq: // v <= w
ft.signedMax(v, wLimit.max)
ft.signedMin(w, vLimit.min)
case gt: // v > w
ft.signedMin(v, wLimit.min+1)
ft.signedMax(w, vLimit.max-1)
case gt | eq: // v >= w
ft.signedMin(v, wLimit.min)
ft.signedMax(w, vLimit.max)
case lt | gt: // v != w
if vLimit.min == vLimit.max { // v is a constant
c := vLimit.min
if wLimit.min == c {
ft.signedMin(w, c+1)
}
if wLimit.min == wLimit.max { // w is a constant
c := wLimit.min
if vLimit.min == c {
ft.signedMin(v, c+1)
}
if vLimit.max == c {
ft.signedMax(v, c-1)
}
if wLimit.max == c {
ft.signedMax(w, c-1)
}
}
case unsigned:
switch r {
case eq: // v == w
ft.unsignedMinMax(v, wLimit.umin, wLimit.umax)
ft.unsignedMinMax(w, vLimit.umin, vLimit.umax)
case lt: // v < w
ft.unsignedMax(v, wLimit.umax-1)
ft.unsignedMin(w, vLimit.umin+1)
case lt | eq: // v <= w
ft.unsignedMax(v, wLimit.umax)
ft.unsignedMin(w, vLimit.umin)
case gt: // v > w
ft.unsignedMin(v, wLimit.umin+1)
ft.unsignedMax(w, vLimit.umax-1)
case gt | eq: // v >= w
ft.unsignedMin(v, wLimit.umin)
ft.unsignedMax(w, vLimit.umax)
case lt | gt: // v != w
if vLimit.umin == vLimit.umax { // v is a constant
c := vLimit.umin
if wLimit.umin == c {
ft.unsignedMin(w, c+1)
}
if wLimit.umax == c {
ft.unsignedMax(w, c-1)
}
if wLimit.min == wLimit.max { // w is a constant
c := wLimit.min
if vLimit.min == c {
ft.signedMin(v, c+1)
}
if wLimit.umin == wLimit.umax { // w is a constant
c := wLimit.umin
if vLimit.umin == c {
ft.unsignedMin(v, c+1)
}
if vLimit.umax == c {
ft.unsignedMax(v, c-1)
}
if vLimit.max == c {
ft.signedMax(v, c-1)
}
}
}
case unsigned:
switch r {
case eq: // v == w
ft.unsignedMinMax(v, wLimit.umin, wLimit.umax)
ft.unsignedMinMax(w, vLimit.umin, vLimit.umax)
case lt: // v < w
ft.unsignedMax(v, wLimit.umax-1)
ft.unsignedMin(w, vLimit.umin+1)
case lt | eq: // v <= w
ft.unsignedMax(v, wLimit.umax)
ft.unsignedMin(w, vLimit.umin)
case gt: // v > w
ft.unsignedMin(v, wLimit.umin+1)
ft.unsignedMax(w, vLimit.umax-1)
case gt | eq: // v >= w
ft.unsignedMin(v, wLimit.umin)
ft.unsignedMax(w, vLimit.umax)
case lt | gt: // v != w
if vLimit.umin == vLimit.umax { // v is a constant
c := vLimit.umin
if wLimit.umin == c {
ft.unsignedMin(w, c+1)
}
if wLimit.umax == c {
ft.unsignedMax(w, c-1)
}
}
if wLimit.umin == wLimit.umax { // w is a constant
c := wLimit.umin
if vLimit.umin == c {
ft.unsignedMin(v, c+1)
}
if vLimit.umax == c {
ft.unsignedMax(v, c-1)
}
}
}
case boolean:
switch r {
case eq: // v == w
if vLimit.min == 1 { // v is true
ft.booleanTrue(w)
}
if vLimit.max == 0 { // v is false
ft.booleanFalse(w)
}
if wLimit.min == 1 { // w is true
ft.booleanTrue(v)
}
if wLimit.max == 0 { // w is false
ft.booleanFalse(v)
}
case lt | gt: // v != w
if vLimit.min == 1 { // v is true
ft.booleanFalse(w)
}
if vLimit.max == 0 { // v is false
ft.booleanTrue(w)
}
if wLimit.min == 1 { // w is true
ft.booleanFalse(v)
}
if wLimit.max == 0 { // w is false
ft.booleanTrue(v)
}
}
case pointer:
switch r {
case eq: // v == w
if vLimit.umax == 0 { // v is nil
ft.pointerNil(w)
}
if vLimit.umin > 0 { // v is non-nil
ft.pointerNonNil(w)
}
if wLimit.umax == 0 { // w is nil
ft.pointerNil(v)
}
if wLimit.umin > 0 { // w is non-nil
ft.pointerNonNil(v)
}
case lt | gt: // v != w
if vLimit.umax == 0 { // v is nil
ft.pointerNonNil(w)
}
if wLimit.umax == 0 { // w is nil
ft.pointerNonNil(v)
}
// Note: the other direction doesn't work.
// Being not equal to a non-nil pointer doesn't
// make you (necessarily) a nil pointer.
}
}
// Derived facts below here are only about numbers.
@ -970,7 +1063,6 @@ func (ft *factsTable) checkpoint() {
if ft.unsat {
ft.unsatDepth++
}
ft.stack = append(ft.stack, checkpointFact)
ft.limitStack = append(ft.limitStack, checkpointBound)
ft.orderS.Checkpoint()
ft.orderU.Checkpoint()
@ -986,18 +1078,6 @@ func (ft *factsTable) restore() {
} else {
ft.unsat = false
}
for {
old := ft.stack[len(ft.stack)-1]
ft.stack = ft.stack[:len(ft.stack)-1]
if old == checkpointFact {
break
}
if old.r == lt|eq|gt {
delete(ft.facts, old.p)
} else {
ft.facts[old.p] = old.r
}
}
for {
old := ft.limitStack[len(ft.limitStack)-1]
ft.limitStack = ft.limitStack[:len(ft.limitStack)-1]
@ -1050,12 +1130,14 @@ var (
OpEq32: {signed | unsigned, eq},
OpEq64: {signed | unsigned, eq},
OpEqPtr: {pointer, eq},
OpEqB: {boolean, eq},
OpNeq8: {signed | unsigned, lt | gt},
OpNeq16: {signed | unsigned, lt | gt},
OpNeq32: {signed | unsigned, lt | gt},
OpNeq64: {signed | unsigned, lt | gt},
OpNeqPtr: {pointer, lt | gt},
OpNeqB: {boolean, lt | gt},
OpLess8: {signed, lt},
OpLess8U: {unsigned, lt},
@ -1407,8 +1489,28 @@ func prove(f *Func) {
// flowLimit, below, which computes additional constraints based on
// ranges of opcode arguments).
func initLimit(v *Value) limit {
if v.Type.IsBoolean() {
switch v.Op {
case OpConstBool:
b := v.AuxInt
return limit{min: b, max: b, umin: uint64(b), umax: uint64(b)}
default:
return limit{min: 0, max: 1, umin: 0, umax: 1}
}
}
if v.Type.IsPtrShaped() { // These are the types that EqPtr/NeqPtr operate on, except uintptr.
switch v.Op {
case OpConstNil:
return limit{min: 0, max: 0, umin: 0, umax: 0}
case OpAddr, OpLocalAddr: // TODO: others?
l := noLimit
l.umin = 1
return l
default:
return noLimit
}
}
if !v.Type.IsInteger() {
// TODO: boolean?
return noLimit
}
@ -1700,9 +1802,9 @@ func addBranchRestrictions(ft *factsTable, b *Block, br branch) {
c := b.Controls[0]
switch {
case br == negative:
addRestrictions(b, ft, boolean, nil, c, eq)
ft.booleanFalse(c)
case br == positive:
addRestrictions(b, ft, boolean, nil, c, lt|gt)
ft.booleanTrue(c)
case br >= jumpTable0:
idx := br - jumpTable0
val := int64(idx)
@ -1769,7 +1871,14 @@ func addBranchRestrictions(ft *factsTable, b *Block, br branch) {
addRestrictions(b, ft, d, c.Args[0], c.Args[1], tr.r)
}
}
}
if c.Op == OpIsNonNil {
switch br {
case positive:
ft.pointerNonNil(c.Args[0])
case negative:
ft.pointerNil(c.Args[0])
}
}
}
@ -1984,7 +2093,7 @@ func simplifyBlock(sdom SparseTree, ft *factsTable, b *Block) {
// Helps in cases where we reuse a value after branching on its equality.
for i, arg := range v.Args {
switch arg.Op {
case OpConst64, OpConst32, OpConst16, OpConst8:
case OpConst64, OpConst32, OpConst16, OpConst8, OpConstBool, OpConstNil:
continue
}
lim := ft.limits[arg.ID]

View File

@ -148,11 +148,11 @@ func fEqInterEqInter(a interface{}, f float64) bool {
}
func fEqInterNeqInter(a interface{}, f float64) bool {
return a == nil && f > Cf2 || a != nil && f < -Cf2
return a == nil && f > Cf2 || a != nil && f < -Cf2 // ERROR "Redirect IsNonNil based on IsNonNil"
}
func fNeqInterEqInter(a interface{}, f float64) bool {
return a != nil && f > Cf2 || a == nil && f < -Cf2
return a != nil && f > Cf2 || a == nil && f < -Cf2 // ERROR "Redirect IsNonNil based on IsNonNil"
}
func fNeqInterNeqInter(a interface{}, f float64) bool {
@ -164,11 +164,11 @@ func fEqSliceEqSlice(a []int, f float64) bool {
}
func fEqSliceNeqSlice(a []int, f float64) bool {
return a == nil && f > Cf2 || a != nil && f < -Cf2
return a == nil && f > Cf2 || a != nil && f < -Cf2 // ERROR "Redirect IsNonNil based on IsNonNil"
}
func fNeqSliceEqSlice(a []int, f float64) bool {
return a != nil && f > Cf2 || a == nil && f < -Cf2
return a != nil && f > Cf2 || a == nil && f < -Cf2 // ERROR "Redirect IsNonNil based on IsNonNil"
}
func fNeqSliceNeqSlice(a []int, f float64) bool {

View File

@ -1159,6 +1159,28 @@ func issue66826b(a [31]byte, i int) {
_ = a[3*i] // ERROR "Proved IsInBounds"
}
func f20(a, b bool) int {
if a == b {
if a {
if b { // ERROR "Proved Arg"
return 1
}
}
}
return 0
}
func f21(a, b *int) int {
if a == b {
if a != nil {
if b != nil { // ERROR "Proved IsNonNil"
return 1
}
}
}
return 0
}
//go:noinline
func useInt(a int) {
}