Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 77 additions & 24 deletions rlp/rlpgen/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"sort"

"github.com/ethereum/go-ethereum/rlp/internal/rlpstruct"
"golang.org/x/tools/go/packages"
)

// buildContext keeps the data needed for make*Op.
Expand Down Expand Up @@ -96,14 +97,20 @@ func (bctx *buildContext) typeToStructType(typ types.Type) *rlpstruct.Type {
// file and assigns unique names of temporary variables.
type genContext struct {
inPackage *types.Package
imports map[string]struct{}
imports map[string]genImportPackage
tempCounter int
}

type genImportPackage struct {
alias string
pkg *types.Package
}

func newGenContext(inPackage *types.Package) *genContext {
return &genContext{
inPackage: inPackage,
imports: make(map[string]struct{}),
inPackage: inPackage,
imports: make(map[string]genImportPackage),
tempCounter: 0,
}
}

Expand All @@ -117,32 +124,78 @@ func (ctx *genContext) resetTemp() {
ctx.tempCounter = 0
}

func (ctx *genContext) addImport(path string) {
if path == ctx.inPackage.Path() {
return // avoid importing the package that we're generating in.
func (ctx *genContext) addImportPath(path string) {
pkg, err := ctx.loadPackage(path)
if err != nil {
panic(fmt.Sprintf("can't load package %q: %v", path, err))
}
// TODO: renaming?
ctx.imports[path] = struct{}{}
ctx.addImport(pkg)
}

// importsList returns all packages that need to be imported.
func (ctx *genContext) importsList() []string {
imp := make([]string, 0, len(ctx.imports))
for k := range ctx.imports {
imp = append(imp, k)
func (ctx *genContext) addImport(pkg *types.Package) string {
if pkg.Path() == ctx.inPackage.Path() {
return "" // avoid importing the package that we're generating in
}
sort.Strings(imp)
return imp
if p, exists := ctx.imports[pkg.Path()]; exists {
return p.alias
}
var (
baseName = pkg.Name()
alias = baseName
counter = 1
)
// If the base name conflicts with an existing import, add a numeric suffix.
for ctx.hasAlias(alias) {
alias = fmt.Sprintf("%s%d", baseName, counter)
counter++
}
ctx.imports[pkg.Path()] = genImportPackage{alias, pkg}
return alias
}

// hasAlias checks if an alias is already in use
func (ctx *genContext) hasAlias(alias string) bool {
for _, p := range ctx.imports {
if p.alias == alias {
return true
}
}
return false
}

// qualify is the types.Qualifier used for printing types.
// loadPackage attempts to load package information
func (ctx *genContext) loadPackage(path string) (*types.Package, error) {
cfg := &packages.Config{Mode: packages.NeedName}
pkgs, err := packages.Load(cfg, path)
if err != nil {
return nil, err
}
if len(pkgs) == 0 {
return nil, fmt.Errorf("no package found for path %s", path)
}
return types.NewPackage(path, pkgs[0].Name), nil
}

// qualify is the types.Qualifier used for printing types
func (ctx *genContext) qualify(pkg *types.Package) string {
if pkg.Path() == ctx.inPackage.Path() {
return ""
}
ctx.addImport(pkg.Path())
// TODO: renaming?
return pkg.Name()
return ctx.addImport(pkg)
}

// importsList returns all packages that need to be imported
func (ctx *genContext) importsList() []string {
imp := make([]string, 0, len(ctx.imports))
for path, p := range ctx.imports {
if p.alias == p.pkg.Name() {
imp = append(imp, fmt.Sprintf("%q", path))
} else {
imp = append(imp, fmt.Sprintf("%s %q", p.alias, path))
}
}
sort.Strings(imp)
return imp
}

type op interface {
Expand Down Expand Up @@ -359,7 +412,7 @@ func (op uint256Op) genWrite(ctx *genContext, v string) string {
}

func (op uint256Op) genDecode(ctx *genContext) (string, string) {
ctx.addImport("github.com/holiman/uint256")
ctx.addImportPath("github.com/holiman/uint256")

var b bytes.Buffer
resultV := ctx.temp()
Expand Down Expand Up @@ -732,7 +785,7 @@ func (bctx *buildContext) makeOp(name *types.Named, typ types.Type, tags rlpstru
// generateDecoder generates the DecodeRLP method on 'typ'.
func generateDecoder(ctx *genContext, typ string, op op) []byte {
ctx.resetTemp()
ctx.addImport(pathOfPackageRLP)
ctx.addImportPath(pathOfPackageRLP)

result, code := op.genDecode(ctx)
var b bytes.Buffer
Expand All @@ -747,8 +800,8 @@ func generateDecoder(ctx *genContext, typ string, op op) []byte {
// generateEncoder generates the EncodeRLP method on 'typ'.
func generateEncoder(ctx *genContext, typ string, op op) []byte {
ctx.resetTemp()
ctx.addImport("io")
ctx.addImport(pathOfPackageRLP)
ctx.addImportPath("io")
ctx.addImportPath(pathOfPackageRLP)

var b bytes.Buffer
fmt.Fprintf(&b, "func (obj *%s) EncodeRLP(_w io.Writer) error {\n", typ)
Expand Down Expand Up @@ -783,7 +836,7 @@ func (bctx *buildContext) generate(typ *types.Named, encoder, decoder bool) ([]b
var b bytes.Buffer
fmt.Fprintf(&b, "package %s\n\n", pkg.Name())
for _, imp := range ctx.importsList() {
fmt.Fprintf(&b, "import %q\n", imp)
fmt.Fprintf(&b, "import %s\n", imp)
}
if encoder {
fmt.Fprintln(&b)
Expand Down
2 changes: 1 addition & 1 deletion rlp/rlpgen/gen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func init() {
}
}

var tests = []string{"uints", "nil", "rawvalue", "optional", "bigint", "uint256"}
var tests = []string{"uints", "nil", "rawvalue", "optional", "bigint", "uint256", "pkgclash"}

func TestOutput(t *testing.T) {
for _, test := range tests {
Expand Down
13 changes: 13 additions & 0 deletions rlp/rlpgen/testdata/pkgclash.in.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// -*- mode: go -*-

package test

import (
eth1 "github.com/ethereum/go-ethereum/eth"
eth2 "github.com/ethereum/go-ethereum/eth/protocols/eth"
)

type Test struct {
A eth1.MinerAPI
B eth2.GetReceiptsPacket
}
82 changes: 82 additions & 0 deletions rlp/rlpgen/testdata/pkgclash.out.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package test

import "github.com/ethereum/go-ethereum/common"
import "github.com/ethereum/go-ethereum/eth"
import "github.com/ethereum/go-ethereum/rlp"
import "io"
import eth1 "github.com/ethereum/go-ethereum/eth/protocols/eth"

func (obj *Test) EncodeRLP(_w io.Writer) error {
w := rlp.NewEncoderBuffer(_w)
_tmp0 := w.List()
_tmp1 := w.List()
w.ListEnd(_tmp1)
_tmp2 := w.List()
w.WriteUint64(obj.B.RequestId)
_tmp3 := w.List()
for _, _tmp4 := range obj.B.GetReceiptsRequest {
w.WriteBytes(_tmp4[:])
}
w.ListEnd(_tmp3)
w.ListEnd(_tmp2)
w.ListEnd(_tmp0)
return w.Flush()
}

func (obj *Test) DecodeRLP(dec *rlp.Stream) error {
var _tmp0 Test
{
if _, err := dec.List(); err != nil {
return err
}
// A:
var _tmp1 eth.MinerAPI
{
if _, err := dec.List(); err != nil {
return err
}
if err := dec.ListEnd(); err != nil {
return err
}
}
_tmp0.A = _tmp1
// B:
var _tmp2 eth1.GetReceiptsPacket
{
if _, err := dec.List(); err != nil {
return err
}
// RequestId:
_tmp3, err := dec.Uint64()
if err != nil {
return err
}
_tmp2.RequestId = _tmp3
// GetReceiptsRequest:
var _tmp4 []common.Hash
if _, err := dec.List(); err != nil {
return err
}
for dec.MoreDataInList() {
var _tmp5 common.Hash
if err := dec.ReadBytes(_tmp5[:]); err != nil {
return err
}
_tmp4 = append(_tmp4, _tmp5)
}
if err := dec.ListEnd(); err != nil {
return err
}
_tmp2.GetReceiptsRequest = _tmp4
if err := dec.ListEnd(); err != nil {
return err
}
}
_tmp0.B = _tmp2
if err := dec.ListEnd(); err != nil {
return err
}
}
*obj = _tmp0
return nil
}
Loading