Source file
src/cmd/fix/main.go
Documentation: cmd/fix
1
2
3
4
5 package main
6
7 import (
8 "bytes"
9 "flag"
10 "fmt"
11 "go/ast"
12 "go/format"
13 "go/parser"
14 "go/scanner"
15 "go/token"
16 "io"
17 "io/fs"
18 "os"
19 "path/filepath"
20 "sort"
21 "strings"
22
23 "cmd/internal/diff"
24 )
25
// Process-wide state shared by every file this command handles.
var (
	// fset records positions for every file this command parses and prints.
	fset = token.NewFileSet()
	// exitCode is the status handed to os.Exit at the end of main;
	// report raises it to 2.
	exitCode = 0
)

// Command-line flags, registered in the order -r, -force, -diff.
var (
	// allowedRewrites (-r) limits the run to the listed fixes.
	allowedRewrites = flag.String("r", "",
		"restrict the rewrites to this comma-separated list")
	// forceRewrites (-force) lists fixes to run even when they are disabled.
	forceRewrites = flag.String("force", "",
		"force these fixes to run even if the code looks updated")
	// doDiff (-diff) makes processFile print a diff instead of writing output.
	doDiff = flag.Bool("diff", false, "display diffs instead of rewriting files")
)

// Fix-name sets built by main from -r and -force. processFile treats a nil
// allowed set as "every fix is eligible"; a nil force set forces nothing.
var (
	allowed map[string]bool
	force   map[string]bool
)

// debug, when true, makes processFile print the source it failed to re-parse
// after a fix, report the error, and exit instead of returning the error.
const debug = false
43
44 func usage() {
45 fmt.Fprintf(os.Stderr, "usage: go tool fix [-diff] [-r fixname,...] [-force fixname,...] [path ...]\n")
46 flag.PrintDefaults()
47 fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n")
48 sort.Sort(byName(fixes))
49 for _, f := range fixes {
50 if f.disabled {
51 fmt.Fprintf(os.Stderr, "\n%s (disabled)\n", f.name)
52 } else {
53 fmt.Fprintf(os.Stderr, "\n%s\n", f.name)
54 }
55 desc := strings.TrimSpace(f.desc)
56 desc = strings.ReplaceAll(desc, "\n", "\n\t")
57 fmt.Fprintf(os.Stderr, "\t%s\n", desc)
58 }
59 os.Exit(2)
60 }
61
62 func main() {
63 flag.Usage = usage
64 flag.Parse()
65
66 sort.Sort(byDate(fixes))
67
68 if *allowedRewrites != "" {
69 allowed = make(map[string]bool)
70 for _, f := range strings.Split(*allowedRewrites, ",") {
71 allowed[f] = true
72 }
73 }
74
75 if *forceRewrites != "" {
76 force = make(map[string]bool)
77 for _, f := range strings.Split(*forceRewrites, ",") {
78 force[f] = true
79 }
80 }
81
82 if flag.NArg() == 0 {
83 if err := processFile("standard input", true); err != nil {
84 report(err)
85 }
86 os.Exit(exitCode)
87 }
88
89 for i := 0; i < flag.NArg(); i++ {
90 path := flag.Arg(i)
91 switch dir, err := os.Stat(path); {
92 case err != nil:
93 report(err)
94 case dir.IsDir():
95 walkDir(path)
96 default:
97 if err := processFile(path, false); err != nil {
98 report(err)
99 }
100 }
101 }
102
103 os.Exit(exitCode)
104 }
105
// parserMode is the mode for every parser.ParseFile call in this command;
// comments are retained so they survive the print/re-parse cycle.
const parserMode parser.Mode = parser.ParseComments
107
108 func gofmtFile(f *ast.File) ([]byte, error) {
109 var buf bytes.Buffer
110 if err := format.Node(&buf, fset, f); err != nil {
111 return nil, err
112 }
113 return buf.Bytes(), nil
114 }
115
116 func processFile(filename string, useStdin bool) error {
117 var f *os.File
118 var err error
119 var fixlog bytes.Buffer
120
121 if useStdin {
122 f = os.Stdin
123 } else {
124 f, err = os.Open(filename)
125 if err != nil {
126 return err
127 }
128 defer f.Close()
129 }
130
131 src, err := io.ReadAll(f)
132 if err != nil {
133 return err
134 }
135
136 file, err := parser.ParseFile(fset, filename, src, parserMode)
137 if err != nil {
138 return err
139 }
140
141
142
143 newSrc, err := gofmtFile(file)
144 if err != nil {
145 return err
146 }
147 if !bytes.Equal(newSrc, src) {
148 newFile, err := parser.ParseFile(fset, filename, newSrc, parserMode)
149 if err != nil {
150 return err
151 }
152 file = newFile
153 fmt.Fprintf(&fixlog, " fmt")
154 }
155
156
157 newFile := file
158 fixed := false
159 for _, fix := range fixes {
160 if allowed != nil && !allowed[fix.name] {
161 continue
162 }
163 if fix.disabled && !force[fix.name] {
164 continue
165 }
166 if fix.f(newFile) {
167 fixed = true
168 fmt.Fprintf(&fixlog, " %s", fix.name)
169
170
171
172
173 newSrc, err := gofmtFile(newFile)
174 if err != nil {
175 return err
176 }
177 newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode)
178 if err != nil {
179 if debug {
180 fmt.Printf("%s", newSrc)
181 report(err)
182 os.Exit(exitCode)
183 }
184 return err
185 }
186 }
187 }
188 if !fixed {
189 return nil
190 }
191 fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:])
192
193
194
195
196
197
198
199 newSrc, err = gofmtFile(newFile)
200 if err != nil {
201 return err
202 }
203
204 if *doDiff {
205 data, err := diff.Diff("go-fix", src, newSrc)
206 if err != nil {
207 return fmt.Errorf("computing diff: %s", err)
208 }
209 fmt.Printf("diff %s fixed/%s\n", filename, filename)
210 os.Stdout.Write(data)
211 return nil
212 }
213
214 if useStdin {
215 os.Stdout.Write(newSrc)
216 return nil
217 }
218
219 return os.WriteFile(f.Name(), newSrc, 0)
220 }
221
222 func gofmt(n interface{}) string {
223 var gofmtBuf bytes.Buffer
224 if err := format.Node(&gofmtBuf, fset, n); err != nil {
225 return "<" + err.Error() + ">"
226 }
227 return gofmtBuf.String()
228 }
229
230 func report(err error) {
231 scanner.PrintError(os.Stderr, err)
232 exitCode = 2
233 }
234
235 func walkDir(path string) {
236 filepath.WalkDir(path, visitFile)
237 }
238
239 func visitFile(path string, f fs.DirEntry, err error) error {
240 if err == nil && isGoFile(f) {
241 err = processFile(path, false)
242 }
243 if err != nil {
244 report(err)
245 }
246 return nil
247 }
248
// isGoFile reports whether the directory entry f should be processed as Go
// source. The entry must not be a directory, and its name must not start
// with "." and must end in ".go". All other entries are ignored.
func isGoFile(f fs.DirEntry) bool {
	if f.IsDir() {
		return false
	}
	base := f.Name()
	switch {
	case strings.HasPrefix(base, "."):
		return false
	case !strings.HasSuffix(base, ".go"):
		return false
	}
	return true
}
254
View as plain text