Black Lives Matter. Support the Equal Justice Initiative.

Source file src/cmd/fix/fix.go

Documentation: cmd/fix

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"fmt"
     9  	"go/ast"
    10  	"go/token"
    11  	"path"
    12  	"strconv"
    13  )
    14  
    15  type fix struct {
    16  	name     string
    17  	date     string // date that fix was introduced, in YYYY-MM-DD format
    18  	f        func(*ast.File) bool
    19  	desc     string
    20  	disabled bool // whether this fix should be disabled by default
    21  }
    22  
    23  // main runs sort.Sort(byName(fixes)) before printing list of fixes.
    24  type byName []fix
    25  
    26  func (f byName) Len() int           { return len(f) }
    27  func (f byName) Swap(i, j int)      { f[i], f[j] = f[j], f[i] }
    28  func (f byName) Less(i, j int) bool { return f[i].name < f[j].name }
    29  
    30  // main runs sort.Sort(byDate(fixes)) before applying fixes.
    31  type byDate []fix
    32  
    33  func (f byDate) Len() int           { return len(f) }
    34  func (f byDate) Swap(i, j int)      { f[i], f[j] = f[j], f[i] }
    35  func (f byDate) Less(i, j int) bool { return f[i].date < f[j].date }
    36  
    37  var fixes []fix
    38  
    39  func register(f fix) {
    40  	fixes = append(fixes, f)
    41  }
    42  
    43  // walk traverses the AST x, calling visit(y) for each node y in the tree but
    44  // also with a pointer to each ast.Expr, ast.Stmt, and *ast.BlockStmt,
    45  // in a bottom-up traversal.
    46  func walk(x interface{}, visit func(interface{})) {
    47  	walkBeforeAfter(x, nop, visit)
    48  }
    49  
    50  func nop(interface{}) {}
    51  
    52  // walkBeforeAfter is like walk but calls before(x) before traversing
    53  // x's children and after(x) afterward.
    54  func walkBeforeAfter(x interface{}, before, after func(interface{})) {
    55  	before(x)
    56  
    57  	switch n := x.(type) {
    58  	default:
    59  		panic(fmt.Errorf("unexpected type %T in walkBeforeAfter", x))
    60  
    61  	case nil:
    62  
    63  	// pointers to interfaces
    64  	case *ast.Decl:
    65  		walkBeforeAfter(*n, before, after)
    66  	case *ast.Expr:
    67  		walkBeforeAfter(*n, before, after)
    68  	case *ast.Spec:
    69  		walkBeforeAfter(*n, before, after)
    70  	case *ast.Stmt:
    71  		walkBeforeAfter(*n, before, after)
    72  
    73  	// pointers to struct pointers
    74  	case **ast.BlockStmt:
    75  		walkBeforeAfter(*n, before, after)
    76  	case **ast.CallExpr:
    77  		walkBeforeAfter(*n, before, after)
    78  	case **ast.FieldList:
    79  		walkBeforeAfter(*n, before, after)
    80  	case **ast.FuncType:
    81  		walkBeforeAfter(*n, before, after)
    82  	case **ast.Ident:
    83  		walkBeforeAfter(*n, before, after)
    84  	case **ast.BasicLit:
    85  		walkBeforeAfter(*n, before, after)
    86  
    87  	// pointers to slices
    88  	case *[]ast.Decl:
    89  		walkBeforeAfter(*n, before, after)
    90  	case *[]ast.Expr:
    91  		walkBeforeAfter(*n, before, after)
    92  	case *[]*ast.File:
    93  		walkBeforeAfter(*n, before, after)
    94  	case *[]*ast.Ident:
    95  		walkBeforeAfter(*n, before, after)
    96  	case *[]ast.Spec:
    97  		walkBeforeAfter(*n, before, after)
    98  	case *[]ast.Stmt:
    99  		walkBeforeAfter(*n, before, after)
   100  
   101  	// These are ordered and grouped to match ../../go/ast/ast.go
   102  	case *ast.Field:
   103  		walkBeforeAfter(&n.Names, before, after)
   104  		walkBeforeAfter(&n.Type, before, after)
   105  		walkBeforeAfter(&n.Tag, before, after)
   106  	case *ast.FieldList:
   107  		for _, field := range n.List {
   108  			walkBeforeAfter(field, before, after)
   109  		}
   110  	case *ast.BadExpr:
   111  	case *ast.Ident:
   112  	case *ast.Ellipsis:
   113  		walkBeforeAfter(&n.Elt, before, after)
   114  	case *ast.BasicLit:
   115  	case *ast.FuncLit:
   116  		walkBeforeAfter(&n.Type, before, after)
   117  		walkBeforeAfter(&n.Body, before, after)
   118  	case *ast.CompositeLit:
   119  		walkBeforeAfter(&n.Type, before, after)
   120  		walkBeforeAfter(&n.Elts, before, after)
   121  	case *ast.ParenExpr:
   122  		walkBeforeAfter(&n.X, before, after)
   123  	case *ast.SelectorExpr:
   124  		walkBeforeAfter(&n.X, before, after)
   125  	case *ast.IndexExpr:
   126  		walkBeforeAfter(&n.X, before, after)
   127  		walkBeforeAfter(&n.Index, before, after)
   128  	case *ast.SliceExpr:
   129  		walkBeforeAfter(&n.X, before, after)
   130  		if n.Low != nil {
   131  			walkBeforeAfter(&n.Low, before, after)
   132  		}
   133  		if n.High != nil {
   134  			walkBeforeAfter(&n.High, before, after)
   135  		}
   136  	case *ast.TypeAssertExpr:
   137  		walkBeforeAfter(&n.X, before, after)
   138  		walkBeforeAfter(&n.Type, before, after)
   139  	case *ast.CallExpr:
   140  		walkBeforeAfter(&n.Fun, before, after)
   141  		walkBeforeAfter(&n.Args, before, after)
   142  	case *ast.StarExpr:
   143  		walkBeforeAfter(&n.X, before, after)
   144  	case *ast.UnaryExpr:
   145  		walkBeforeAfter(&n.X, before, after)
   146  	case *ast.BinaryExpr:
   147  		walkBeforeAfter(&n.X, before, after)
   148  		walkBeforeAfter(&n.Y, before, after)
   149  	case *ast.KeyValueExpr:
   150  		walkBeforeAfter(&n.Key, before, after)
   151  		walkBeforeAfter(&n.Value, before, after)
   152  
   153  	case *ast.ArrayType:
   154  		walkBeforeAfter(&n.Len, before, after)
   155  		walkBeforeAfter(&n.Elt, before, after)
   156  	case *ast.StructType:
   157  		walkBeforeAfter(&n.Fields, before, after)
   158  	case *ast.FuncType:
   159  		walkBeforeAfter(&n.Params, before, after)
   160  		if n.Results != nil {
   161  			walkBeforeAfter(&n.Results, before, after)
   162  		}
   163  	case *ast.InterfaceType:
   164  		walkBeforeAfter(&n.Methods, before, after)
   165  	case *ast.MapType:
   166  		walkBeforeAfter(&n.Key, before, after)
   167  		walkBeforeAfter(&n.Value, before, after)
   168  	case *ast.ChanType:
   169  		walkBeforeAfter(&n.Value, before, after)
   170  
   171  	case *ast.BadStmt:
   172  	case *ast.DeclStmt:
   173  		walkBeforeAfter(&n.Decl, before, after)
   174  	case *ast.EmptyStmt:
   175  	case *ast.LabeledStmt:
   176  		walkBeforeAfter(&n.Stmt, before, after)
   177  	case *ast.ExprStmt:
   178  		walkBeforeAfter(&n.X, before, after)
   179  	case *ast.SendStmt:
   180  		walkBeforeAfter(&n.Chan, before, after)
   181  		walkBeforeAfter(&n.Value, before, after)
   182  	case *ast.IncDecStmt:
   183  		walkBeforeAfter(&n.X, before, after)
   184  	case *ast.AssignStmt:
   185  		walkBeforeAfter(&n.Lhs, before, after)
   186  		walkBeforeAfter(&n.Rhs, before, after)
   187  	case *ast.GoStmt:
   188  		walkBeforeAfter(&n.Call, before, after)
   189  	case *ast.DeferStmt:
   190  		walkBeforeAfter(&n.Call, before, after)
   191  	case *ast.ReturnStmt:
   192  		walkBeforeAfter(&n.Results, before, after)
   193  	case *ast.BranchStmt:
   194  	case *ast.BlockStmt:
   195  		walkBeforeAfter(&n.List, before, after)
   196  	case *ast.IfStmt:
   197  		walkBeforeAfter(&n.Init, before, after)
   198  		walkBeforeAfter(&n.Cond, before, after)
   199  		walkBeforeAfter(&n.Body, before, after)
   200  		walkBeforeAfter(&n.Else, before, after)
   201  	case *ast.CaseClause:
   202  		walkBeforeAfter(&n.List, before, after)
   203  		walkBeforeAfter(&n.Body, before, after)
   204  	case *ast.SwitchStmt:
   205  		walkBeforeAfter(&n.Init, before, after)
   206  		walkBeforeAfter(&n.Tag, before, after)
   207  		walkBeforeAfter(&n.Body, before, after)
   208  	case *ast.TypeSwitchStmt:
   209  		walkBeforeAfter(&n.Init, before, after)
   210  		walkBeforeAfter(&n.Assign, before, after)
   211  		walkBeforeAfter(&n.Body, before, after)
   212  	case *ast.CommClause:
   213  		walkBeforeAfter(&n.Comm, before, after)
   214  		walkBeforeAfter(&n.Body, before, after)
   215  	case *ast.SelectStmt:
   216  		walkBeforeAfter(&n.Body, before, after)
   217  	case *ast.ForStmt:
   218  		walkBeforeAfter(&n.Init, before, after)
   219  		walkBeforeAfter(&n.Cond, before, after)
   220  		walkBeforeAfter(&n.Post, before, after)
   221  		walkBeforeAfter(&n.Body, before, after)
   222  	case *ast.RangeStmt:
   223  		walkBeforeAfter(&n.Key, before, after)
   224  		walkBeforeAfter(&n.Value, before, after)
   225  		walkBeforeAfter(&n.X, before, after)
   226  		walkBeforeAfter(&n.Body, before, after)
   227  
   228  	case *ast.ImportSpec:
   229  	case *ast.ValueSpec:
   230  		walkBeforeAfter(&n.Type, before, after)
   231  		walkBeforeAfter(&n.Values, before, after)
   232  		walkBeforeAfter(&n.Names, before, after)
   233  	case *ast.TypeSpec:
   234  		walkBeforeAfter(&n.Type, before, after)
   235  
   236  	case *ast.BadDecl:
   237  	case *ast.GenDecl:
   238  		walkBeforeAfter(&n.Specs, before, after)
   239  	case *ast.FuncDecl:
   240  		if n.Recv != nil {
   241  			walkBeforeAfter(&n.Recv, before, after)
   242  		}
   243  		walkBeforeAfter(&n.Type, before, after)
   244  		if n.Body != nil {
   245  			walkBeforeAfter(&n.Body, before, after)
   246  		}
   247  
   248  	case *ast.File:
   249  		walkBeforeAfter(&n.Decls, before, after)
   250  
   251  	case *ast.Package:
   252  		walkBeforeAfter(&n.Files, before, after)
   253  
   254  	case []*ast.File:
   255  		for i := range n {
   256  			walkBeforeAfter(&n[i], before, after)
   257  		}
   258  	case []ast.Decl:
   259  		for i := range n {
   260  			walkBeforeAfter(&n[i], before, after)
   261  		}
   262  	case []ast.Expr:
   263  		for i := range n {
   264  			walkBeforeAfter(&n[i], before, after)
   265  		}
   266  	case []*ast.Ident:
   267  		for i := range n {
   268  			walkBeforeAfter(&n[i], before, after)
   269  		}
   270  	case []ast.Stmt:
   271  		for i := range n {
   272  			walkBeforeAfter(&n[i], before, after)
   273  		}
   274  	case []ast.Spec:
   275  		for i := range n {
   276  			walkBeforeAfter(&n[i], before, after)
   277  		}
   278  	}
   279  	after(x)
   280  }
   281  
   282  // imports reports whether f imports path.
   283  func imports(f *ast.File, path string) bool {
   284  	return importSpec(f, path) != nil
   285  }
   286  
   287  // importSpec returns the import spec if f imports path,
   288  // or nil otherwise.
   289  func importSpec(f *ast.File, path string) *ast.ImportSpec {
   290  	for _, s := range f.Imports {
   291  		if importPath(s) == path {
   292  			return s
   293  		}
   294  	}
   295  	return nil
   296  }
   297  
   298  // importPath returns the unquoted import path of s,
   299  // or "" if the path is not properly quoted.
   300  func importPath(s *ast.ImportSpec) string {
   301  	t, err := strconv.Unquote(s.Path.Value)
   302  	if err == nil {
   303  		return t
   304  	}
   305  	return ""
   306  }
   307  
   308  // declImports reports whether gen contains an import of path.
   309  func declImports(gen *ast.GenDecl, path string) bool {
   310  	if gen.Tok != token.IMPORT {
   311  		return false
   312  	}
   313  	for _, spec := range gen.Specs {
   314  		impspec := spec.(*ast.ImportSpec)
   315  		if importPath(impspec) == path {
   316  			return true
   317  		}
   318  	}
   319  	return false
   320  }
   321  
   322  // isTopName reports whether n is a top-level unresolved identifier with the given name.
   323  func isTopName(n ast.Expr, name string) bool {
   324  	id, ok := n.(*ast.Ident)
   325  	return ok && id.Name == name && id.Obj == nil
   326  }
   327  
   328  // renameTop renames all references to the top-level name old.
   329  // It reports whether it makes any changes.
   330  func renameTop(f *ast.File, old, new string) bool {
   331  	var fixed bool
   332  
   333  	// Rename any conflicting imports
   334  	// (assuming package name is last element of path).
   335  	for _, s := range f.Imports {
   336  		if s.Name != nil {
   337  			if s.Name.Name == old {
   338  				s.Name.Name = new
   339  				fixed = true
   340  			}
   341  		} else {
   342  			_, thisName := path.Split(importPath(s))
   343  			if thisName == old {
   344  				s.Name = ast.NewIdent(new)
   345  				fixed = true
   346  			}
   347  		}
   348  	}
   349  
   350  	// Rename any top-level declarations.
   351  	for _, d := range f.Decls {
   352  		switch d := d.(type) {
   353  		case *ast.FuncDecl:
   354  			if d.Recv == nil && d.Name.Name == old {
   355  				d.Name.Name = new
   356  				d.Name.Obj.Name = new
   357  				fixed = true
   358  			}
   359  		case *ast.GenDecl:
   360  			for _, s := range d.Specs {
   361  				switch s := s.(type) {
   362  				case *ast.TypeSpec:
   363  					if s.Name.Name == old {
   364  						s.Name.Name = new
   365  						s.Name.Obj.Name = new
   366  						fixed = true
   367  					}
   368  				case *ast.ValueSpec:
   369  					for _, n := range s.Names {
   370  						if n.Name == old {
   371  							n.Name = new
   372  							n.Obj.Name = new
   373  							fixed = true
   374  						}
   375  					}
   376  				}
   377  			}
   378  		}
   379  	}
   380  
   381  	// Rename top-level old to new, both unresolved names
   382  	// (probably defined in another file) and names that resolve
   383  	// to a declaration we renamed.
   384  	walk(f, func(n interface{}) {
   385  		id, ok := n.(*ast.Ident)
   386  		if ok && isTopName(id, old) {
   387  			id.Name = new
   388  			fixed = true
   389  		}
   390  		if ok && id.Obj != nil && id.Name == old && id.Obj.Name == new {
   391  			id.Name = id.Obj.Name
   392  			fixed = true
   393  		}
   394  	})
   395  
   396  	return fixed
   397  }
   398  
   399  // matchLen returns the length of the longest prefix shared by x and y.
   400  func matchLen(x, y string) int {
   401  	i := 0
   402  	for i < len(x) && i < len(y) && x[i] == y[i] {
   403  		i++
   404  	}
   405  	return i
   406  }
   407  
   408  // addImport adds the import path to the file f, if absent.
   409  func addImport(f *ast.File, ipath string) (added bool) {
   410  	if imports(f, ipath) {
   411  		return false
   412  	}
   413  
   414  	// Determine name of import.
   415  	// Assume added imports follow convention of using last element.
   416  	_, name := path.Split(ipath)
   417  
   418  	// Rename any conflicting top-level references from name to name_.
   419  	renameTop(f, name, name+"_")
   420  
   421  	newImport := &ast.ImportSpec{
   422  		Path: &ast.BasicLit{
   423  			Kind:  token.STRING,
   424  			Value: strconv.Quote(ipath),
   425  		},
   426  	}
   427  
   428  	// Find an import decl to add to.
   429  	var (
   430  		bestMatch  = -1
   431  		lastImport = -1
   432  		impDecl    *ast.GenDecl
   433  		impIndex   = -1
   434  	)
   435  	for i, decl := range f.Decls {
   436  		gen, ok := decl.(*ast.GenDecl)
   437  		if ok && gen.Tok == token.IMPORT {
   438  			lastImport = i
   439  			// Do not add to import "C", to avoid disrupting the
   440  			// association with its doc comment, breaking cgo.
   441  			if declImports(gen, "C") {
   442  				continue
   443  			}
   444  
   445  			// Compute longest shared prefix with imports in this block.
   446  			for j, spec := range gen.Specs {
   447  				impspec := spec.(*ast.ImportSpec)
   448  				n := matchLen(importPath(impspec), ipath)
   449  				if n > bestMatch {
   450  					bestMatch = n
   451  					impDecl = gen
   452  					impIndex = j
   453  				}
   454  			}
   455  		}
   456  	}
   457  
   458  	// If no import decl found, add one after the last import.
   459  	if impDecl == nil {
   460  		impDecl = &ast.GenDecl{
   461  			Tok: token.IMPORT,
   462  		}
   463  		f.Decls = append(f.Decls, nil)
   464  		copy(f.Decls[lastImport+2:], f.Decls[lastImport+1:])
   465  		f.Decls[lastImport+1] = impDecl
   466  	}
   467  
   468  	// Ensure the import decl has parentheses, if needed.
   469  	if len(impDecl.Specs) > 0 && !impDecl.Lparen.IsValid() {
   470  		impDecl.Lparen = impDecl.Pos()
   471  	}
   472  
   473  	insertAt := impIndex + 1
   474  	if insertAt == 0 {
   475  		insertAt = len(impDecl.Specs)
   476  	}
   477  	impDecl.Specs = append(impDecl.Specs, nil)
   478  	copy(impDecl.Specs[insertAt+1:], impDecl.Specs[insertAt:])
   479  	impDecl.Specs[insertAt] = newImport
   480  	if insertAt > 0 {
   481  		// Assign same position as the previous import,
   482  		// so that the sorter sees it as being in the same block.
   483  		prev := impDecl.Specs[insertAt-1]
   484  		newImport.Path.ValuePos = prev.Pos()
   485  		newImport.EndPos = prev.Pos()
   486  	}
   487  
   488  	f.Imports = append(f.Imports, newImport)
   489  	return true
   490  }
   491  
   492  // deleteImport deletes the import path from the file f, if present.
   493  func deleteImport(f *ast.File, path string) (deleted bool) {
   494  	oldImport := importSpec(f, path)
   495  
   496  	// Find the import node that imports path, if any.
   497  	for i, decl := range f.Decls {
   498  		gen, ok := decl.(*ast.GenDecl)
   499  		if !ok || gen.Tok != token.IMPORT {
   500  			continue
   501  		}
   502  		for j, spec := range gen.Specs {
   503  			impspec := spec.(*ast.ImportSpec)
   504  			if oldImport != impspec {
   505  				continue
   506  			}
   507  
   508  			// We found an import spec that imports path.
   509  			// Delete it.
   510  			deleted = true
   511  			copy(gen.Specs[j:], gen.Specs[j+1:])
   512  			gen.Specs = gen.Specs[:len(gen.Specs)-1]
   513  
   514  			// If this was the last import spec in this decl,
   515  			// delete the decl, too.
   516  			if len(gen.Specs) == 0 {
   517  				copy(f.Decls[i:], f.Decls[i+1:])
   518  				f.Decls = f.Decls[:len(f.Decls)-1]
   519  			} else if len(gen.Specs) == 1 {
   520  				gen.Lparen = token.NoPos // drop parens
   521  			}
   522  			if j > 0 {
   523  				// We deleted an entry but now there will be
   524  				// a blank line-sized hole where the import was.
   525  				// Close the hole by making the previous
   526  				// import appear to "end" where this one did.
   527  				gen.Specs[j-1].(*ast.ImportSpec).EndPos = impspec.End()
   528  			}
   529  			break
   530  		}
   531  	}
   532  
   533  	// Delete it from f.Imports.
   534  	for i, imp := range f.Imports {
   535  		if imp == oldImport {
   536  			copy(f.Imports[i:], f.Imports[i+1:])
   537  			f.Imports = f.Imports[:len(f.Imports)-1]
   538  			break
   539  		}
   540  	}
   541  
   542  	return
   543  }
   544  
   545  // rewriteImport rewrites any import of path oldPath to path newPath.
   546  func rewriteImport(f *ast.File, oldPath, newPath string) (rewrote bool) {
   547  	for _, imp := range f.Imports {
   548  		if importPath(imp) == oldPath {
   549  			rewrote = true
   550  			// record old End, because the default is to compute
   551  			// it using the length of imp.Path.Value.
   552  			imp.EndPos = imp.End()
   553  			imp.Path.Value = strconv.Quote(newPath)
   554  		}
   555  	}
   556  	return
   557  }
   558  

View as plain text