Black Lives Matter. Support the Equal Justice Initiative.

Source file src/cmd/fix/typecheck.go

Documentation: cmd/fix

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"fmt"
     9  	"go/ast"
    10  	"go/parser"
    11  	"go/token"
    12  	exec "internal/execabs"
    13  	"os"
    14  	"path/filepath"
    15  	"reflect"
    16  	"runtime"
    17  	"strings"
    18  )
    19  
    20  // Partial type checker.
    21  //
    22  // The fact that it is partial is very important: the input is
    23  // an AST and a description of some type information to
    24  // assume about one or more packages, but not all the
    25  // packages that the program imports. The checker is
    26  // expected to do as much as it can with what it has been
    27  // given. There is not enough information supplied to do
    28  // a full type check, but the type checker is expected to
    29  // apply information that can be derived from variable
    30  // declarations, function and method returns, and type switches
    31  // as far as it can, so that the caller can still tell the types
    32  // of expressions relevant to a particular fix.
    33  //
    34  // TODO(rsc,gri): Replace with go/typechecker.
    35  // Doing that could be an interesting test case for go/typechecker:
    36  // the constraints about working with partial information will
    37  // likely exercise it in interesting ways. The ideal interface would
    38  // be to pass typecheck a map from importpath to package API text
    39  // (Go source code), but for now we use data structures (TypeConfig, Type).
    40  //
    41  // The strings mostly use gofmt form.
    42  //
    43  // A Field or FieldList has as its type a comma-separated list
    44  // of the types of the fields. For example, the field list
    45  //	x, y, z int
    46  // has type "int, int, int".
    47  
    48  // The prefix "type " is the type of a type.
    49  // For example, given
    50  //	var x int
    51  //	type T int
    52  // x's type is "int" but T's type is "type int".
    53  // mkType inserts the "type " prefix.
    54  // getType removes it.
    55  // isType tests for it.
    56  
    57  func mkType(t string) string {
    58  	return "type " + t
    59  }
    60  
    61  func getType(t string) string {
    62  	if !isType(t) {
    63  		return ""
    64  	}
    65  	return t[len("type "):]
    66  }
    67  
    68  func isType(t string) bool {
    69  	return strings.HasPrefix(t, "type ")
    70  }
    71  
    72  // TypeConfig describes the universe of relevant types.
    73  // For ease of creation, the types are all referred to by string
    74  // name (e.g., "reflect.Value").  TypeByName is the only place
    75  // where the strings are resolved.
    76  
    77  type TypeConfig struct {
    78  	Type map[string]*Type
    79  	Var  map[string]string
    80  	Func map[string]string
    81  
    82  	// External maps from a name to its type.
    83  	// It provides additional typings not present in the Go source itself.
    84  	// For now, the only additional typings are those generated by cgo.
    85  	External map[string]string
    86  }
    87  
    88  // typeof returns the type of the given name, which may be of
    89  // the form "x" or "p.X".
    90  func (cfg *TypeConfig) typeof(name string) string {
    91  	if cfg.Var != nil {
    92  		if t := cfg.Var[name]; t != "" {
    93  			return t
    94  		}
    95  	}
    96  	if cfg.Func != nil {
    97  		if t := cfg.Func[name]; t != "" {
    98  			return "func()" + t
    99  		}
   100  	}
   101  	return ""
   102  }
   103  
   104  // Type describes the Fields and Methods of a type.
   105  // If the field or method cannot be found there, it is next
   106  // looked for in the Embed list.
   107  type Type struct {
   108  	Field  map[string]string // map field name to type
   109  	Method map[string]string // map method name to comma-separated return types (should start with "func ")
   110  	Embed  []string          // list of types this type embeds (for extra methods)
   111  	Def    string            // definition of named type
   112  }
   113  
   114  // dot returns the type of "typ.name", making its decision
   115  // using the type information in cfg.
   116  func (typ *Type) dot(cfg *TypeConfig, name string) string {
   117  	if typ.Field != nil {
   118  		if t := typ.Field[name]; t != "" {
   119  			return t
   120  		}
   121  	}
   122  	if typ.Method != nil {
   123  		if t := typ.Method[name]; t != "" {
   124  			return t
   125  		}
   126  	}
   127  
   128  	for _, e := range typ.Embed {
   129  		etyp := cfg.Type[e]
   130  		if etyp != nil {
   131  			if t := etyp.dot(cfg, name); t != "" {
   132  				return t
   133  			}
   134  		}
   135  	}
   136  
   137  	return ""
   138  }
   139  
   140  // typecheck type checks the AST f assuming the information in cfg.
   141  // It returns two maps with type information:
   142  // typeof maps AST nodes to type information in gofmt string form.
   143  // assign maps type strings to lists of expressions that were assigned
   144  // to values of another type that were assigned to that type.
   145  func typecheck(cfg *TypeConfig, f *ast.File) (typeof map[interface{}]string, assign map[string][]interface{}) {
   146  	typeof = make(map[interface{}]string)
   147  	assign = make(map[string][]interface{})
   148  	cfg1 := &TypeConfig{}
   149  	*cfg1 = *cfg // make copy so we can add locally
   150  	copied := false
   151  
   152  	// If we import "C", add types of cgo objects.
   153  	cfg.External = map[string]string{}
   154  	cfg1.External = cfg.External
   155  	if imports(f, "C") {
   156  		// Run cgo on gofmtFile(f)
   157  		// Parse, extract decls from _cgo_gotypes.go
   158  		// Map _Ctype_* types to C.* types.
   159  		err := func() error {
   160  			txt, err := gofmtFile(f)
   161  			if err != nil {
   162  				return err
   163  			}
   164  			dir, err := os.MkdirTemp(os.TempDir(), "fix_cgo_typecheck")
   165  			if err != nil {
   166  				return err
   167  			}
   168  			defer os.RemoveAll(dir)
   169  			err = os.WriteFile(filepath.Join(dir, "in.go"), txt, 0600)
   170  			if err != nil {
   171  				return err
   172  			}
   173  			cmd := exec.Command(filepath.Join(runtime.GOROOT(), "bin", "go"), "tool", "cgo", "-objdir", dir, "-srcdir", dir, "in.go")
   174  			err = cmd.Run()
   175  			if err != nil {
   176  				return err
   177  			}
   178  			out, err := os.ReadFile(filepath.Join(dir, "_cgo_gotypes.go"))
   179  			if err != nil {
   180  				return err
   181  			}
   182  			cgo, err := parser.ParseFile(token.NewFileSet(), "cgo.go", out, 0)
   183  			if err != nil {
   184  				return err
   185  			}
   186  			for _, decl := range cgo.Decls {
   187  				fn, ok := decl.(*ast.FuncDecl)
   188  				if !ok {
   189  					continue
   190  				}
   191  				if strings.HasPrefix(fn.Name.Name, "_Cfunc_") {
   192  					var params, results []string
   193  					for _, p := range fn.Type.Params.List {
   194  						t := gofmt(p.Type)
   195  						t = strings.ReplaceAll(t, "_Ctype_", "C.")
   196  						params = append(params, t)
   197  					}
   198  					for _, r := range fn.Type.Results.List {
   199  						t := gofmt(r.Type)
   200  						t = strings.ReplaceAll(t, "_Ctype_", "C.")
   201  						results = append(results, t)
   202  					}
   203  					cfg.External["C."+fn.Name.Name[7:]] = joinFunc(params, results)
   204  				}
   205  			}
   206  			return nil
   207  		}()
   208  		if err != nil {
   209  			fmt.Fprintf(os.Stderr, "go fix: warning: no cgo types: %s\n", err)
   210  		}
   211  	}
   212  
   213  	// gather function declarations
   214  	for _, decl := range f.Decls {
   215  		fn, ok := decl.(*ast.FuncDecl)
   216  		if !ok {
   217  			continue
   218  		}
   219  		typecheck1(cfg, fn.Type, typeof, assign)
   220  		t := typeof[fn.Type]
   221  		if fn.Recv != nil {
   222  			// The receiver must be a type.
   223  			rcvr := typeof[fn.Recv]
   224  			if !isType(rcvr) {
   225  				if len(fn.Recv.List) != 1 {
   226  					continue
   227  				}
   228  				rcvr = mkType(gofmt(fn.Recv.List[0].Type))
   229  				typeof[fn.Recv.List[0].Type] = rcvr
   230  			}
   231  			rcvr = getType(rcvr)
   232  			if rcvr != "" && rcvr[0] == '*' {
   233  				rcvr = rcvr[1:]
   234  			}
   235  			typeof[rcvr+"."+fn.Name.Name] = t
   236  		} else {
   237  			if isType(t) {
   238  				t = getType(t)
   239  			} else {
   240  				t = gofmt(fn.Type)
   241  			}
   242  			typeof[fn.Name] = t
   243  
   244  			// Record typeof[fn.Name.Obj] for future references to fn.Name.
   245  			typeof[fn.Name.Obj] = t
   246  		}
   247  	}
   248  
   249  	// gather struct declarations
   250  	for _, decl := range f.Decls {
   251  		d, ok := decl.(*ast.GenDecl)
   252  		if ok {
   253  			for _, s := range d.Specs {
   254  				switch s := s.(type) {
   255  				case *ast.TypeSpec:
   256  					if cfg1.Type[s.Name.Name] != nil {
   257  						break
   258  					}
   259  					if !copied {
   260  						copied = true
   261  						// Copy map lazily: it's time.
   262  						cfg1.Type = make(map[string]*Type)
   263  						for k, v := range cfg.Type {
   264  							cfg1.Type[k] = v
   265  						}
   266  					}
   267  					t := &Type{Field: map[string]string{}}
   268  					cfg1.Type[s.Name.Name] = t
   269  					switch st := s.Type.(type) {
   270  					case *ast.StructType:
   271  						for _, f := range st.Fields.List {
   272  							for _, n := range f.Names {
   273  								t.Field[n.Name] = gofmt(f.Type)
   274  							}
   275  						}
   276  					case *ast.ArrayType, *ast.StarExpr, *ast.MapType:
   277  						t.Def = gofmt(st)
   278  					}
   279  				}
   280  			}
   281  		}
   282  	}
   283  
   284  	typecheck1(cfg1, f, typeof, assign)
   285  	return typeof, assign
   286  }
   287  
   288  func makeExprList(a []*ast.Ident) []ast.Expr {
   289  	var b []ast.Expr
   290  	for _, x := range a {
   291  		b = append(b, x)
   292  	}
   293  	return b
   294  }
   295  
   296  // Typecheck1 is the recursive form of typecheck.
   297  // It is like typecheck but adds to the information in typeof
   298  // instead of allocating a new map.
   299  func typecheck1(cfg *TypeConfig, f interface{}, typeof map[interface{}]string, assign map[string][]interface{}) {
   300  	// set sets the type of n to typ.
   301  	// If isDecl is true, n is being declared.
   302  	set := func(n ast.Expr, typ string, isDecl bool) {
   303  		if typeof[n] != "" || typ == "" {
   304  			if typeof[n] != typ {
   305  				assign[typ] = append(assign[typ], n)
   306  			}
   307  			return
   308  		}
   309  		typeof[n] = typ
   310  
   311  		// If we obtained typ from the declaration of x
   312  		// propagate the type to all the uses.
   313  		// The !isDecl case is a cheat here, but it makes
   314  		// up in some cases for not paying attention to
   315  		// struct fields. The real type checker will be
   316  		// more accurate so we won't need the cheat.
   317  		if id, ok := n.(*ast.Ident); ok && id.Obj != nil && (isDecl || typeof[id.Obj] == "") {
   318  			typeof[id.Obj] = typ
   319  		}
   320  	}
   321  
   322  	// Type-check an assignment lhs = rhs.
   323  	// If isDecl is true, this is := so we can update
   324  	// the types of the objects that lhs refers to.
   325  	typecheckAssign := func(lhs, rhs []ast.Expr, isDecl bool) {
   326  		if len(lhs) > 1 && len(rhs) == 1 {
   327  			if _, ok := rhs[0].(*ast.CallExpr); ok {
   328  				t := split(typeof[rhs[0]])
   329  				// Lists should have same length but may not; pair what can be paired.
   330  				for i := 0; i < len(lhs) && i < len(t); i++ {
   331  					set(lhs[i], t[i], isDecl)
   332  				}
   333  				return
   334  			}
   335  		}
   336  		if len(lhs) == 1 && len(rhs) == 2 {
   337  			// x = y, ok
   338  			rhs = rhs[:1]
   339  		} else if len(lhs) == 2 && len(rhs) == 1 {
   340  			// x, ok = y
   341  			lhs = lhs[:1]
   342  		}
   343  
   344  		// Match as much as we can.
   345  		for i := 0; i < len(lhs) && i < len(rhs); i++ {
   346  			x, y := lhs[i], rhs[i]
   347  			if typeof[y] != "" {
   348  				set(x, typeof[y], isDecl)
   349  			} else {
   350  				set(y, typeof[x], false)
   351  			}
   352  		}
   353  	}
   354  
   355  	expand := func(s string) string {
   356  		typ := cfg.Type[s]
   357  		if typ != nil && typ.Def != "" {
   358  			return typ.Def
   359  		}
   360  		return s
   361  	}
   362  
   363  	// The main type check is a recursive algorithm implemented
   364  	// by walkBeforeAfter(n, before, after).
   365  	// Most of it is bottom-up, but in a few places we need
   366  	// to know the type of the function we are checking.
   367  	// The before function records that information on
   368  	// the curfn stack.
   369  	var curfn []*ast.FuncType
   370  
   371  	before := func(n interface{}) {
   372  		// push function type on stack
   373  		switch n := n.(type) {
   374  		case *ast.FuncDecl:
   375  			curfn = append(curfn, n.Type)
   376  		case *ast.FuncLit:
   377  			curfn = append(curfn, n.Type)
   378  		}
   379  	}
   380  
   381  	// After is the real type checker.
   382  	after := func(n interface{}) {
   383  		if n == nil {
   384  			return
   385  		}
   386  		if false && reflect.TypeOf(n).Kind() == reflect.Ptr { // debugging trace
   387  			defer func() {
   388  				if t := typeof[n]; t != "" {
   389  					pos := fset.Position(n.(ast.Node).Pos())
   390  					fmt.Fprintf(os.Stderr, "%s: typeof[%s] = %s\n", pos, gofmt(n), t)
   391  				}
   392  			}()
   393  		}
   394  
   395  		switch n := n.(type) {
   396  		case *ast.FuncDecl, *ast.FuncLit:
   397  			// pop function type off stack
   398  			curfn = curfn[:len(curfn)-1]
   399  
   400  		case *ast.FuncType:
   401  			typeof[n] = mkType(joinFunc(split(typeof[n.Params]), split(typeof[n.Results])))
   402  
   403  		case *ast.FieldList:
   404  			// Field list is concatenation of sub-lists.
   405  			t := ""
   406  			for _, field := range n.List {
   407  				if t != "" {
   408  					t += ", "
   409  				}
   410  				t += typeof[field]
   411  			}
   412  			typeof[n] = t
   413  
   414  		case *ast.Field:
   415  			// Field is one instance of the type per name.
   416  			all := ""
   417  			t := typeof[n.Type]
   418  			if !isType(t) {
   419  				// Create a type, because it is typically *T or *p.T
   420  				// and we might care about that type.
   421  				t = mkType(gofmt(n.Type))
   422  				typeof[n.Type] = t
   423  			}
   424  			t = getType(t)
   425  			if len(n.Names) == 0 {
   426  				all = t
   427  			} else {
   428  				for _, id := range n.Names {
   429  					if all != "" {
   430  						all += ", "
   431  					}
   432  					all += t
   433  					typeof[id.Obj] = t
   434  					typeof[id] = t
   435  				}
   436  			}
   437  			typeof[n] = all
   438  
   439  		case *ast.ValueSpec:
   440  			// var declaration. Use type if present.
   441  			if n.Type != nil {
   442  				t := typeof[n.Type]
   443  				if !isType(t) {
   444  					t = mkType(gofmt(n.Type))
   445  					typeof[n.Type] = t
   446  				}
   447  				t = getType(t)
   448  				for _, id := range n.Names {
   449  					set(id, t, true)
   450  				}
   451  			}
   452  			// Now treat same as assignment.
   453  			typecheckAssign(makeExprList(n.Names), n.Values, true)
   454  
   455  		case *ast.AssignStmt:
   456  			typecheckAssign(n.Lhs, n.Rhs, n.Tok == token.DEFINE)
   457  
   458  		case *ast.Ident:
   459  			// Identifier can take its type from underlying object.
   460  			if t := typeof[n.Obj]; t != "" {
   461  				typeof[n] = t
   462  			}
   463  
   464  		case *ast.SelectorExpr:
   465  			// Field or method.
   466  			name := n.Sel.Name
   467  			if t := typeof[n.X]; t != "" {
   468  				t = strings.TrimPrefix(t, "*") // implicit *
   469  				if typ := cfg.Type[t]; typ != nil {
   470  					if t := typ.dot(cfg, name); t != "" {
   471  						typeof[n] = t
   472  						return
   473  					}
   474  				}
   475  				tt := typeof[t+"."+name]
   476  				if isType(tt) {
   477  					typeof[n] = getType(tt)
   478  					return
   479  				}
   480  			}
   481  			// Package selector.
   482  			if x, ok := n.X.(*ast.Ident); ok && x.Obj == nil {
   483  				str := x.Name + "." + name
   484  				if cfg.Type[str] != nil {
   485  					typeof[n] = mkType(str)
   486  					return
   487  				}
   488  				if t := cfg.typeof(x.Name + "." + name); t != "" {
   489  					typeof[n] = t
   490  					return
   491  				}
   492  			}
   493  
   494  		case *ast.CallExpr:
   495  			// make(T) has type T.
   496  			if isTopName(n.Fun, "make") && len(n.Args) >= 1 {
   497  				typeof[n] = gofmt(n.Args[0])
   498  				return
   499  			}
   500  			// new(T) has type *T
   501  			if isTopName(n.Fun, "new") && len(n.Args) == 1 {
   502  				typeof[n] = "*" + gofmt(n.Args[0])
   503  				return
   504  			}
   505  			// Otherwise, use type of function to determine arguments.
   506  			t := typeof[n.Fun]
   507  			if t == "" {
   508  				t = cfg.External[gofmt(n.Fun)]
   509  			}
   510  			in, out := splitFunc(t)
   511  			if in == nil && out == nil {
   512  				return
   513  			}
   514  			typeof[n] = join(out)
   515  			for i, arg := range n.Args {
   516  				if i >= len(in) {
   517  					break
   518  				}
   519  				if typeof[arg] == "" {
   520  					typeof[arg] = in[i]
   521  				}
   522  			}
   523  
   524  		case *ast.TypeAssertExpr:
   525  			// x.(type) has type of x.
   526  			if n.Type == nil {
   527  				typeof[n] = typeof[n.X]
   528  				return
   529  			}
   530  			// x.(T) has type T.
   531  			if t := typeof[n.Type]; isType(t) {
   532  				typeof[n] = getType(t)
   533  			} else {
   534  				typeof[n] = gofmt(n.Type)
   535  			}
   536  
   537  		case *ast.SliceExpr:
   538  			// x[i:j] has type of x.
   539  			typeof[n] = typeof[n.X]
   540  
   541  		case *ast.IndexExpr:
   542  			// x[i] has key type of x's type.
   543  			t := expand(typeof[n.X])
   544  			if strings.HasPrefix(t, "[") || strings.HasPrefix(t, "map[") {
   545  				// Lazy: assume there are no nested [] in the array
   546  				// length or map key type.
   547  				if i := strings.Index(t, "]"); i >= 0 {
   548  					typeof[n] = t[i+1:]
   549  				}
   550  			}
   551  
   552  		case *ast.StarExpr:
   553  			// *x for x of type *T has type T when x is an expr.
   554  			// We don't use the result when *x is a type, but
   555  			// compute it anyway.
   556  			t := expand(typeof[n.X])
   557  			if isType(t) {
   558  				typeof[n] = "type *" + getType(t)
   559  			} else if strings.HasPrefix(t, "*") {
   560  				typeof[n] = t[len("*"):]
   561  			}
   562  
   563  		case *ast.UnaryExpr:
   564  			// &x for x of type T has type *T.
   565  			t := typeof[n.X]
   566  			if t != "" && n.Op == token.AND {
   567  				typeof[n] = "*" + t
   568  			}
   569  
   570  		case *ast.CompositeLit:
   571  			// T{...} has type T.
   572  			typeof[n] = gofmt(n.Type)
   573  
   574  			// Propagate types down to values used in the composite literal.
   575  			t := expand(typeof[n])
   576  			if strings.HasPrefix(t, "[") { // array or slice
   577  				// Lazy: assume there are no nested [] in the array length.
   578  				if i := strings.Index(t, "]"); i >= 0 {
   579  					et := t[i+1:]
   580  					for _, e := range n.Elts {
   581  						if kv, ok := e.(*ast.KeyValueExpr); ok {
   582  							e = kv.Value
   583  						}
   584  						if typeof[e] == "" {
   585  							typeof[e] = et
   586  						}
   587  					}
   588  				}
   589  			}
   590  			if strings.HasPrefix(t, "map[") { // map
   591  				// Lazy: assume there are no nested [] in the map key type.
   592  				if i := strings.Index(t, "]"); i >= 0 {
   593  					kt, vt := t[4:i], t[i+1:]
   594  					for _, e := range n.Elts {
   595  						if kv, ok := e.(*ast.KeyValueExpr); ok {
   596  							if typeof[kv.Key] == "" {
   597  								typeof[kv.Key] = kt
   598  							}
   599  							if typeof[kv.Value] == "" {
   600  								typeof[kv.Value] = vt
   601  							}
   602  						}
   603  					}
   604  				}
   605  			}
   606  			if typ := cfg.Type[t]; typ != nil && len(typ.Field) > 0 { // struct
   607  				for _, e := range n.Elts {
   608  					if kv, ok := e.(*ast.KeyValueExpr); ok {
   609  						if ft := typ.Field[fmt.Sprintf("%s", kv.Key)]; ft != "" {
   610  							if typeof[kv.Value] == "" {
   611  								typeof[kv.Value] = ft
   612  							}
   613  						}
   614  					}
   615  				}
   616  			}
   617  
   618  		case *ast.ParenExpr:
   619  			// (x) has type of x.
   620  			typeof[n] = typeof[n.X]
   621  
   622  		case *ast.RangeStmt:
   623  			t := expand(typeof[n.X])
   624  			if t == "" {
   625  				return
   626  			}
   627  			var key, value string
   628  			if t == "string" {
   629  				key, value = "int", "rune"
   630  			} else if strings.HasPrefix(t, "[") {
   631  				key = "int"
   632  				if i := strings.Index(t, "]"); i >= 0 {
   633  					value = t[i+1:]
   634  				}
   635  			} else if strings.HasPrefix(t, "map[") {
   636  				if i := strings.Index(t, "]"); i >= 0 {
   637  					key, value = t[4:i], t[i+1:]
   638  				}
   639  			}
   640  			changed := false
   641  			if n.Key != nil && key != "" {
   642  				changed = true
   643  				set(n.Key, key, n.Tok == token.DEFINE)
   644  			}
   645  			if n.Value != nil && value != "" {
   646  				changed = true
   647  				set(n.Value, value, n.Tok == token.DEFINE)
   648  			}
   649  			// Ugly failure of vision: already type-checked body.
   650  			// Do it again now that we have that type info.
   651  			if changed {
   652  				typecheck1(cfg, n.Body, typeof, assign)
   653  			}
   654  
   655  		case *ast.TypeSwitchStmt:
   656  			// Type of variable changes for each case in type switch,
   657  			// but go/parser generates just one variable.
   658  			// Repeat type check for each case with more precise
   659  			// type information.
   660  			as, ok := n.Assign.(*ast.AssignStmt)
   661  			if !ok {
   662  				return
   663  			}
   664  			varx, ok := as.Lhs[0].(*ast.Ident)
   665  			if !ok {
   666  				return
   667  			}
   668  			t := typeof[varx]
   669  			for _, cas := range n.Body.List {
   670  				cas := cas.(*ast.CaseClause)
   671  				if len(cas.List) == 1 {
   672  					// Variable has specific type only when there is
   673  					// exactly one type in the case list.
   674  					if tt := typeof[cas.List[0]]; isType(tt) {
   675  						tt = getType(tt)
   676  						typeof[varx] = tt
   677  						typeof[varx.Obj] = tt
   678  						typecheck1(cfg, cas.Body, typeof, assign)
   679  					}
   680  				}
   681  			}
   682  			// Restore t.
   683  			typeof[varx] = t
   684  			typeof[varx.Obj] = t
   685  
   686  		case *ast.ReturnStmt:
   687  			if len(curfn) == 0 {
   688  				// Probably can't happen.
   689  				return
   690  			}
   691  			f := curfn[len(curfn)-1]
   692  			res := n.Results
   693  			if f.Results != nil {
   694  				t := split(typeof[f.Results])
   695  				for i := 0; i < len(res) && i < len(t); i++ {
   696  					set(res[i], t[i], false)
   697  				}
   698  			}
   699  
   700  		case *ast.BinaryExpr:
   701  			// Propagate types across binary ops that require two args of the same type.
   702  			switch n.Op {
   703  			case token.EQL, token.NEQ: // TODO: more cases. This is enough for the cftype fix.
   704  				if typeof[n.X] != "" && typeof[n.Y] == "" {
   705  					typeof[n.Y] = typeof[n.X]
   706  				}
   707  				if typeof[n.X] == "" && typeof[n.Y] != "" {
   708  					typeof[n.X] = typeof[n.Y]
   709  				}
   710  			}
   711  		}
   712  	}
   713  	walkBeforeAfter(f, before, after)
   714  }
   715  
   716  // Convert between function type strings and lists of types.
   717  // Using strings makes this a little harder, but it makes
   718  // a lot of the rest of the code easier. This will all go away
   719  // when we can use go/typechecker directly.
   720  
   721  // splitFunc splits "func(x,y,z) (a,b,c)" into ["x", "y", "z"] and ["a", "b", "c"].
   722  func splitFunc(s string) (in, out []string) {
   723  	if !strings.HasPrefix(s, "func(") {
   724  		return nil, nil
   725  	}
   726  
   727  	i := len("func(") // index of beginning of 'in' arguments
   728  	nparen := 0
   729  	for j := i; j < len(s); j++ {
   730  		switch s[j] {
   731  		case '(':
   732  			nparen++
   733  		case ')':
   734  			nparen--
   735  			if nparen < 0 {
   736  				// found end of parameter list
   737  				out := strings.TrimSpace(s[j+1:])
   738  				if len(out) >= 2 && out[0] == '(' && out[len(out)-1] == ')' {
   739  					out = out[1 : len(out)-1]
   740  				}
   741  				return split(s[i:j]), split(out)
   742  			}
   743  		}
   744  	}
   745  	return nil, nil
   746  }
   747  
   748  // joinFunc is the inverse of splitFunc.
   749  func joinFunc(in, out []string) string {
   750  	outs := ""
   751  	if len(out) == 1 {
   752  		outs = " " + out[0]
   753  	} else if len(out) > 1 {
   754  		outs = " (" + join(out) + ")"
   755  	}
   756  	return "func(" + join(in) + ")" + outs
   757  }
   758  
   759  // split splits "int, float" into ["int", "float"] and splits "" into [].
   760  func split(s string) []string {
   761  	out := []string{}
   762  	i := 0 // current type being scanned is s[i:j].
   763  	nparen := 0
   764  	for j := 0; j < len(s); j++ {
   765  		switch s[j] {
   766  		case ' ':
   767  			if i == j {
   768  				i++
   769  			}
   770  		case '(':
   771  			nparen++
   772  		case ')':
   773  			nparen--
   774  			if nparen < 0 {
   775  				// probably can't happen
   776  				return nil
   777  			}
   778  		case ',':
   779  			if nparen == 0 {
   780  				if i < j {
   781  					out = append(out, s[i:j])
   782  				}
   783  				i = j + 1
   784  			}
   785  		}
   786  	}
   787  	if nparen != 0 {
   788  		// probably can't happen
   789  		return nil
   790  	}
   791  	if i < len(s) {
   792  		out = append(out, s[i:])
   793  	}
   794  	return out
   795  }
   796  
   797  // join is the inverse of split.
   798  func join(x []string) string {
   799  	return strings.Join(x, ", ")
   800  }
   801  

View as plain text