diff --git a/src/cmd/compile/internal/gc/export.go b/src/cmd/compile/internal/gc/export.go index e19d52fa95..a11e5fdd30 100644 --- a/src/cmd/compile/internal/gc/export.go +++ b/src/cmd/compile/internal/gc/export.go @@ -94,15 +94,14 @@ func (p *exporter) markObject(n ir.Node) { // markType recursively visits types reachable from t to identify // functions whose inline bodies may be needed. func (p *exporter) markType(t *types.Type) { + if t.IsInstantiatedGeneric() { + // Re-instantiated types don't add anything new, so don't follow them. + return + } if p.marked[t] { return } p.marked[t] = true - if t.HasTParam() { - // Don't deal with any generic types or their methods, since we - // will only be inlining actual instantiations, not generic methods. - return - } // If this is a named type, mark all of its associated // methods. Skip interface types because t.Methods contains @@ -159,5 +158,8 @@ func (p *exporter) markType(t *types.Type) { p.markType(f.Type) } } + + case types.TTYPEPARAM: + // No other type that needs to be followed. } } diff --git a/src/cmd/compile/internal/noder/decl.go b/src/cmd/compile/internal/noder/decl.go index 375eb41898..5c80b20671 100644 --- a/src/cmd/compile/internal/noder/decl.go +++ b/src/cmd/compile/internal/noder/decl.go @@ -109,6 +109,16 @@ func (g *irgen) funcDecl(out *ir.Nodes, decl *syntax.FuncDecl) { } g.funcBody(fn, decl.Recv, decl.Type, decl.Body) + if fn.Type().HasTParam() && fn.Body != nil { + // Set pointers to the dcls/body of a generic function/method in + // the Inl struct, so it is marked for export, is available for + // stenciling, and works with Inline_Flood(). + fn.Inl = &ir.Inline{ + Cost: 1, + Dcl: fn.Dcl, + Body: fn.Body, + } + } out.Append(fn) } diff --git a/src/cmd/compile/internal/noder/stencil.go b/src/cmd/compile/internal/noder/stencil.go index 3ba364f67c..8145f9e8f9 100644 --- a/src/cmd/compile/internal/noder/stencil.go +++ b/src/cmd/compile/internal/noder/stencil.go @@ -17,8 +17,6 @@ import ( "go/constant" ) -// For catching problems as we add more features -// TODO(danscales): remove assertions or replace with base.FatalfAt() func assert(p bool) { if !p { panic("assertion failed") diff --git a/src/cmd/compile/internal/reflectdata/reflect.go b/src/cmd/compile/internal/reflectdata/reflect.go index 604cec6096..0fcb7e3d6d 100644 --- a/src/cmd/compile/internal/reflectdata/reflect.go +++ b/src/cmd/compile/internal/reflectdata/reflect.go @@ -949,7 +949,7 @@ func writeType(t *types.Type) *obj.LSym { // in the local package, even if they may be marked as part of // another package (the package of their base generic type). if tbase.Sym() != nil && tbase.Sym().Pkg != types.LocalPkg && - !tbase.IsInstantiated() { + !tbase.IsFullyInstantiated() { if i := typecheck.BaseTypeIndex(t); i >= 0 { lsym.Pkg = tbase.Sym().Pkg.Prefix lsym.SymIdx = int32(i) @@ -1795,7 +1795,7 @@ func methodWrapper(rcvr *types.Type, method *types.Field, forItab bool) *obj.LSy // instantiated methods. if rcvr.IsPtr() && rcvr.Elem() == method.Type.Recv().Type && rcvr.Elem().Sym() != nil && rcvr.Elem().Sym().Pkg != types.LocalPkg && - !rcvr.Elem().IsInstantiated() { + !rcvr.Elem().IsFullyInstantiated() { return lsym } diff --git a/src/cmd/compile/internal/typecheck/expr.go b/src/cmd/compile/internal/typecheck/expr.go index 24d141e8a2..30d864320f 100644 --- a/src/cmd/compile/internal/typecheck/expr.go +++ b/src/cmd/compile/internal/typecheck/expr.go @@ -311,8 +311,19 @@ func tcCompLit(n *ir.CompLitExpr) (res ir.Node) { f := t.Field(i) s := f.Sym - if s != nil && !types.IsExported(s.Name) && s.Pkg != types.LocalPkg { - base.Errorf("implicit assignment of unexported field '%s' in %v literal", s.Name, t) + + // Do the test for assigning to unexported fields. + // But if this is an instantiated function, then + // the function has already been typechecked. In + // that case, don't do the test, since it can fail + // for the closure structs created in + // walkClosure(), because the instantiated + // function is compiled as if in the source + // package of the generic function. + if !(ir.CurFunc != nil && strings.Index(ir.CurFunc.Nname.Sym().Name, "[") >= 0) { + if s != nil && !types.IsExported(s.Name) && s.Pkg != types.LocalPkg { + base.Errorf("implicit assignment of unexported field '%s' in %v literal", s.Name, t) + } } // No pushtype allowed here. Must name fields for that. n1 = AssignConv(n1, f.Type, "field value") diff --git a/src/cmd/compile/internal/typecheck/func.go b/src/cmd/compile/internal/typecheck/func.go index 760b8868ab..f9ee686f9e 100644 --- a/src/cmd/compile/internal/typecheck/func.go +++ b/src/cmd/compile/internal/typecheck/func.go @@ -74,8 +74,25 @@ func ClosureType(clo *ir.ClosureExpr) *types.Type { // The information appears in the binary in the form of type descriptors; // the struct is unnamed so that closures in multiple packages with the // same struct type can share the descriptor. + + // Make sure the .F field is in the same package as the rest of the + // fields. This deals with closures in instantiated functions, which are + // compiled as if from the source package of the generic function. + var pkg *types.Pkg + if len(clo.Func.ClosureVars) == 0 { + pkg = types.LocalPkg + } else { + for _, v := range clo.Func.ClosureVars { + if pkg == nil { + pkg = v.Sym().Pkg + } else if pkg != v.Sym().Pkg { + base.Fatalf("Closure variables from multiple packages") + } + } + } + fields := []*types.Field{ - types.NewField(base.Pos, Lookup(".F"), types.Types[types.TUINTPTR]), + types.NewField(base.Pos, pkg.Lookup(".F"), types.Types[types.TUINTPTR]), } for _, v := range clo.Func.ClosureVars { typ := v.Type() diff --git a/src/cmd/compile/internal/typecheck/iexport.go b/src/cmd/compile/internal/typecheck/iexport.go index f635b79ada..236f6ed789 100644 --- a/src/cmd/compile/internal/typecheck/iexport.go +++ b/src/cmd/compile/internal/typecheck/iexport.go @@ -1332,24 +1332,9 @@ func (w *exportWriter) funcExt(n *ir.Name) { } } - // Inline body. - if n.Type().HasTParam() { - if n.Func.Inl != nil { - // n.Func.Inl may already be set on a generic function if - // we imported it from another package, but shouldn't be - // set for a generic function in the local package. - if n.Sym().Pkg == types.LocalPkg { - base.FatalfAt(n.Pos(), "generic function is marked inlineable") - } - } else { - // Populate n.Func.Inl, so body of exported generic function will - // be written out. - n.Func.Inl = &ir.Inline{ - Cost: 1, - Dcl: n.Func.Dcl, - Body: n.Func.Body, - } - } + // Write out inline body or body of a generic function/method. + if n.Type().HasTParam() && n.Func.Body != nil && n.Func.Inl == nil { + base.FatalfAt(n.Pos(), "generic function is not marked inlineable") } if n.Func.Inl != nil { w.uint64(1 + uint64(n.Func.Inl.Cost)) diff --git a/src/cmd/compile/internal/types/type.go b/src/cmd/compile/internal/types/type.go index 7a05230a78..a3a6050c52 100644 --- a/src/cmd/compile/internal/types/type.go +++ b/src/cmd/compile/internal/types/type.go @@ -8,6 +8,7 @@ import ( "cmd/compile/internal/base" "cmd/internal/src" "fmt" + "strings" "sync" ) @@ -279,10 +280,23 @@ func (t *Type) SetRParams(rparams []*Type) { } } -// IsInstantiated reports whether t is a fully instantiated generic type; i.e. an +// IsBaseGeneric returns true if t is a generic type (not reinstantiated with +// another type params or fully instantiated. +func (t *Type) IsBaseGeneric() bool { + return len(t.RParams()) > 0 && strings.Index(t.Sym().Name, "[") < 0 +} + +// IsInstantiatedGeneric returns t if t ia generic type that has been +// reinstantiated with new typeparams (i.e. is not fully instantiated). +func (t *Type) IsInstantiatedGeneric() bool { + return len(t.RParams()) > 0 && strings.Index(t.Sym().Name, "[") >= 0 && + t.HasTParam() +} + +// IsFullyInstantiated reports whether t is a fully instantiated generic type; i.e. an // instantiated generic type where all type arguments are non-generic or fully // instantiated generic types. -func (t *Type) IsInstantiated() bool { +func (t *Type) IsFullyInstantiated() bool { return len(t.RParams()) > 0 && !t.HasTParam() } diff --git a/test/typeparam/orderedmapsimp.dir/a.go b/test/typeparam/orderedmapsimp.dir/a.go index 1b5827b4bb..37fc3e79b9 100644 --- a/test/typeparam/orderedmapsimp.dir/a.go +++ b/test/typeparam/orderedmapsimp.dir/a.go @@ -100,25 +100,25 @@ type keyValue[K, V any] struct { } // iterate returns an iterator that traverses the map. -// func (m *Map[K, V]) Iterate() *Iterator[K, V] { -// sender, receiver := Ranger[keyValue[K, V]]() -// var f func(*node[K, V]) bool -// f = func(n *node[K, V]) bool { -// if n == nil { -// return true -// } -// // Stop the traversal if Send fails, which means that -// // nothing is listening to the receiver. -// return f(n.left) && -// sender.Send(context.Background(), keyValue[K, V]{n.key, n.val}) && -// f(n.right) -// } -// go func() { -// f(m.root) -// sender.Close() -// }() -// return &Iterator[K, V]{receiver} -// } +func (m *Map[K, V]) Iterate() *Iterator[K, V] { + sender, receiver := Ranger[keyValue[K, V]]() + var f func(*node[K, V]) bool + f = func(n *node[K, V]) bool { + if n == nil { + return true + } + // Stop the traversal if Send fails, which means that + // nothing is listening to the receiver. + return f(n.left) && + sender.Send(context.Background(), keyValue[K, V]{n.key, n.val}) && + f(n.right) + } + go func() { + f(m.root) + sender.Close() + }() + return &Iterator[K, V]{receiver} +} // Iterator is used to iterate over the map. type Iterator[K, V any] struct { diff --git a/test/typeparam/orderedmapsimp.dir/main.go b/test/typeparam/orderedmapsimp.dir/main.go index 77869ad9fc..ac4cee6a78 100644 --- a/test/typeparam/orderedmapsimp.dir/main.go +++ b/test/typeparam/orderedmapsimp.dir/main.go @@ -41,24 +41,21 @@ func TestMap() { panic(fmt.Sprintf("unexpectedly found %q", []byte("d"))) } - // TODO(danscales): Iterate() has some things to be fixed with inlining in - // stenciled functions and using closures across packages. - - // gather := func(it *a.Iterator[[]byte, int]) []int { - // var r []int - // for { - // _, v, ok := it.Next() - // if !ok { - // return r - // } - // r = append(r, v) - // } - // } - // got := gather(m.Iterate()) - // want := []int{'a', 'b', 'x'} - // if !a.SliceEqual(got, want) { - // panic(fmt.Sprintf("Iterate returned %v, want %v", got, want)) - // } + gather := func(it *a.Iterator[[]byte, int]) []int { + var r []int + for { + _, v, ok := it.Next() + if !ok { + return r + } + r = append(r, v) + } + } + got := gather(m.Iterate()) + want := []int{'a', 'b', 'x'} + if !a.SliceEqual(got, want) { + panic(fmt.Sprintf("Iterate returned %v, want %v", got, want)) + } }