Black Lives Matter. Support the Equal Justice Initiative.

Source file src/cmd/fix/main.go

Documentation: cmd/fix

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"bytes"
     9  	"flag"
    10  	"fmt"
    11  	"go/ast"
    12  	"go/format"
    13  	"go/parser"
    14  	"go/scanner"
    15  	"go/token"
    16  	"io"
    17  	"io/fs"
    18  	"os"
    19  	"path/filepath"
    20  	"sort"
    21  	"strings"
    22  
    23  	"cmd/internal/diff"
    24  )
    25  
    26  var (
    27  	fset     = token.NewFileSet()
    28  	exitCode = 0
    29  )
    30  
    31  var allowedRewrites = flag.String("r", "",
    32  	"restrict the rewrites to this comma-separated list")
    33  
    34  var forceRewrites = flag.String("force", "",
    35  	"force these fixes to run even if the code looks updated")
    36  
    37  var allowed, force map[string]bool
    38  
    39  var doDiff = flag.Bool("diff", false, "display diffs instead of rewriting files")
    40  
    41  // enable for debugging fix failures
    42  const debug = false // display incorrectly reformatted source and exit
    43  
    44  func usage() {
    45  	fmt.Fprintf(os.Stderr, "usage: go tool fix [-diff] [-r fixname,...] [-force fixname,...] [path ...]\n")
    46  	flag.PrintDefaults()
    47  	fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n")
    48  	sort.Sort(byName(fixes))
    49  	for _, f := range fixes {
    50  		if f.disabled {
    51  			fmt.Fprintf(os.Stderr, "\n%s (disabled)\n", f.name)
    52  		} else {
    53  			fmt.Fprintf(os.Stderr, "\n%s\n", f.name)
    54  		}
    55  		desc := strings.TrimSpace(f.desc)
    56  		desc = strings.ReplaceAll(desc, "\n", "\n\t")
    57  		fmt.Fprintf(os.Stderr, "\t%s\n", desc)
    58  	}
    59  	os.Exit(2)
    60  }
    61  
    62  func main() {
    63  	flag.Usage = usage
    64  	flag.Parse()
    65  
    66  	sort.Sort(byDate(fixes))
    67  
    68  	if *allowedRewrites != "" {
    69  		allowed = make(map[string]bool)
    70  		for _, f := range strings.Split(*allowedRewrites, ",") {
    71  			allowed[f] = true
    72  		}
    73  	}
    74  
    75  	if *forceRewrites != "" {
    76  		force = make(map[string]bool)
    77  		for _, f := range strings.Split(*forceRewrites, ",") {
    78  			force[f] = true
    79  		}
    80  	}
    81  
    82  	if flag.NArg() == 0 {
    83  		if err := processFile("standard input", true); err != nil {
    84  			report(err)
    85  		}
    86  		os.Exit(exitCode)
    87  	}
    88  
    89  	for i := 0; i < flag.NArg(); i++ {
    90  		path := flag.Arg(i)
    91  		switch dir, err := os.Stat(path); {
    92  		case err != nil:
    93  			report(err)
    94  		case dir.IsDir():
    95  			walkDir(path)
    96  		default:
    97  			if err := processFile(path, false); err != nil {
    98  				report(err)
    99  			}
   100  		}
   101  	}
   102  
   103  	os.Exit(exitCode)
   104  }
   105  
   106  const parserMode = parser.ParseComments
   107  
   108  func gofmtFile(f *ast.File) ([]byte, error) {
   109  	var buf bytes.Buffer
   110  	if err := format.Node(&buf, fset, f); err != nil {
   111  		return nil, err
   112  	}
   113  	return buf.Bytes(), nil
   114  }
   115  
   116  func processFile(filename string, useStdin bool) error {
   117  	var f *os.File
   118  	var err error
   119  	var fixlog bytes.Buffer
   120  
   121  	if useStdin {
   122  		f = os.Stdin
   123  	} else {
   124  		f, err = os.Open(filename)
   125  		if err != nil {
   126  			return err
   127  		}
   128  		defer f.Close()
   129  	}
   130  
   131  	src, err := io.ReadAll(f)
   132  	if err != nil {
   133  		return err
   134  	}
   135  
   136  	file, err := parser.ParseFile(fset, filename, src, parserMode)
   137  	if err != nil {
   138  		return err
   139  	}
   140  
   141  	// Make sure file is in canonical format.
   142  	// This "fmt" pseudo-fix cannot be disabled.
   143  	newSrc, err := gofmtFile(file)
   144  	if err != nil {
   145  		return err
   146  	}
   147  	if !bytes.Equal(newSrc, src) {
   148  		newFile, err := parser.ParseFile(fset, filename, newSrc, parserMode)
   149  		if err != nil {
   150  			return err
   151  		}
   152  		file = newFile
   153  		fmt.Fprintf(&fixlog, " fmt")
   154  	}
   155  
   156  	// Apply all fixes to file.
   157  	newFile := file
   158  	fixed := false
   159  	for _, fix := range fixes {
   160  		if allowed != nil && !allowed[fix.name] {
   161  			continue
   162  		}
   163  		if fix.disabled && !force[fix.name] {
   164  			continue
   165  		}
   166  		if fix.f(newFile) {
   167  			fixed = true
   168  			fmt.Fprintf(&fixlog, " %s", fix.name)
   169  
   170  			// AST changed.
   171  			// Print and parse, to update any missing scoping
   172  			// or position information for subsequent fixers.
   173  			newSrc, err := gofmtFile(newFile)
   174  			if err != nil {
   175  				return err
   176  			}
   177  			newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode)
   178  			if err != nil {
   179  				if debug {
   180  					fmt.Printf("%s", newSrc)
   181  					report(err)
   182  					os.Exit(exitCode)
   183  				}
   184  				return err
   185  			}
   186  		}
   187  	}
   188  	if !fixed {
   189  		return nil
   190  	}
   191  	fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:])
   192  
   193  	// Print AST.  We did that after each fix, so this appears
   194  	// redundant, but it is necessary to generate gofmt-compatible
   195  	// source code in a few cases. The official gofmt style is the
   196  	// output of the printer run on a standard AST generated by the parser,
   197  	// but the source we generated inside the loop above is the
   198  	// output of the printer run on a mangled AST generated by a fixer.
   199  	newSrc, err = gofmtFile(newFile)
   200  	if err != nil {
   201  		return err
   202  	}
   203  
   204  	if *doDiff {
   205  		data, err := diff.Diff("go-fix", src, newSrc)
   206  		if err != nil {
   207  			return fmt.Errorf("computing diff: %s", err)
   208  		}
   209  		fmt.Printf("diff %s fixed/%s\n", filename, filename)
   210  		os.Stdout.Write(data)
   211  		return nil
   212  	}
   213  
   214  	if useStdin {
   215  		os.Stdout.Write(newSrc)
   216  		return nil
   217  	}
   218  
   219  	return os.WriteFile(f.Name(), newSrc, 0)
   220  }
   221  
   222  func gofmt(n interface{}) string {
   223  	var gofmtBuf bytes.Buffer
   224  	if err := format.Node(&gofmtBuf, fset, n); err != nil {
   225  		return "<" + err.Error() + ">"
   226  	}
   227  	return gofmtBuf.String()
   228  }
   229  
   230  func report(err error) {
   231  	scanner.PrintError(os.Stderr, err)
   232  	exitCode = 2
   233  }
   234  
   235  func walkDir(path string) {
   236  	filepath.WalkDir(path, visitFile)
   237  }
   238  
   239  func visitFile(path string, f fs.DirEntry, err error) error {
   240  	if err == nil && isGoFile(f) {
   241  		err = processFile(path, false)
   242  	}
   243  	if err != nil {
   244  		report(err)
   245  	}
   246  	return nil
   247  }
   248  
   249  func isGoFile(f fs.DirEntry) bool {
   250  	// ignore non-Go files
   251  	name := f.Name()
   252  	return !f.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go")
   253  }
   254  

View as plain text