@@ -18,6 +18,7 @@ package main
18
18
19
19
import (
20
20
"bytes"
21
+ "errors"
21
22
"flag"
22
23
"fmt"
23
24
"html/template"
@@ -26,8 +27,9 @@ import (
26
27
"os"
27
28
"path/filepath"
28
29
"strings"
29
- "sync"
30
30
"time"
31
+
32
+ "golang.org/x/sync/errgroup"
31
33
)
32
34
33
35
const helpText = `Usage: addlicense [flags] pattern [pattern ...]
@@ -45,11 +47,12 @@ Flags:
45
47
`
46
48
47
49
var (
48
- holder = flag .String ("c" , "Google LLC" , "copyright holder" )
49
- license = flag .String ("l" , "apache" , "license type: apache, bsd, mit" )
50
- licensef = flag .String ("f" , "" , "license file" )
51
- year = flag .String ("y" , fmt .Sprint (time .Now ().Year ()), "copyright year(s)" )
52
- verbose = flag .Bool ("v" , false , "verbose mode: print the name of the files that are modified" )
50
+ holder = flag .String ("c" , "Google LLC" , "copyright holder" )
51
+ license = flag .String ("l" , "apache" , "license type: apache, bsd, mit" )
52
+ licensef = flag .String ("f" , "" , "license file" )
53
+ year = flag .String ("y" , fmt .Sprint (time .Now ().Year ()), "copyright year(s)" )
54
+ verbose = flag .Bool ("v" , false , "verbose mode: print the name of the files that are modified" )
55
+ checkonly = flag .Bool ("check" , false , "check only mode: verify presence of license headers and exit with non-zero code if missing" )
53
56
)
54
57
55
58
func main () {
@@ -92,23 +95,48 @@ func main() {
92
95
ch := make (chan * file , 1000 )
93
96
done := make (chan struct {})
94
97
go func () {
95
- var wg sync. WaitGroup
98
+ var wg errgroup. Group
96
99
for f := range ch {
97
- wg .Add (1 )
98
- go func (f * file ) {
99
- defer wg .Done ()
100
- modified , err := addLicense (f .path , f .mode , t , data )
101
- if err != nil {
102
- log .Printf ("%s: %v" , f .path , err )
103
- return
104
- }
105
- if * verbose && modified {
106
- log .Printf ("%s modified" , f .path )
100
+ f := f // https://golang.org/doc/faq#closures_and_goroutines
101
+ wg .Go (func () error {
102
+ if * checkonly {
103
+ // Check if file extension is known
104
+ lic , err := licenseHeader (f .path , t , data )
105
+ if err != nil {
106
+ log .Printf ("%s: %v" , f .path , err )
107
+ return err
108
+ }
109
+ if lic == nil { // Unknown fileExtension
110
+ return nil
111
+ }
112
+ // Check if file has a license
113
+ isMissingLicenseHeader , err := fileHasLicense (f .path )
114
+ if err != nil {
115
+ log .Printf ("%s: %v" , f .path , err )
116
+ return err
117
+ }
118
+ if isMissingLicenseHeader {
119
+ fmt .Printf ("%s\n " , f .path )
120
+ return errors .New ("missing license header" )
121
+ }
122
+ } else {
123
+ modified , err := addLicense (f .path , f .mode , t , data )
124
+ if err != nil {
125
+ log .Printf ("%s: %v" , f .path , err )
126
+ return err
127
+ }
128
+ if * verbose && modified {
129
+ log .Printf ("%s modified" , f .path )
130
+ }
107
131
}
108
- }(f )
132
+ return nil
133
+ })
109
134
}
110
- wg .Wait ()
135
+ err := wg .Wait ()
111
136
close (done )
137
+ if err != nil {
138
+ os .Exit (1 )
139
+ }
112
140
}()
113
141
114
142
for _ , d := range flag .Args () {
@@ -138,11 +166,45 @@ func walk(ch chan<- *file, start string) {
138
166
}
139
167
140
168
func addLicense (path string , fmode os.FileMode , tmpl * template.Template , data * copyrightData ) (bool , error ) {
169
+ var lic []byte
170
+ var err error
171
+ lic , err = licenseHeader (path , tmpl , data )
172
+ if err != nil || lic == nil {
173
+ return false , err
174
+ }
175
+
176
+ b , err := ioutil .ReadFile (path )
177
+ if err != nil || hasLicense (b ) {
178
+ return false , err
179
+ }
180
+
181
+ line := hashBang (b )
182
+ if len (line ) > 0 {
183
+ b = b [len (line ):]
184
+ if line [len (line )- 1 ] != '\n' {
185
+ line = append (line , '\n' )
186
+ }
187
+ lic = append (line , lic ... )
188
+ }
189
+ b = append (lic , b ... )
190
+ return true , ioutil .WriteFile (path , b , fmode )
191
+ }
192
+
193
+ // fileHasLicense reports whether the file at path contains a license header.
194
+ func fileHasLicense (path string ) (bool , error ) {
195
+ b , err := ioutil .ReadFile (path )
196
+ if err != nil || hasLicense (b ) {
197
+ return false , err
198
+ }
199
+ return true , nil
200
+ }
201
+
202
+ func licenseHeader (path string , tmpl * template.Template , data * copyrightData ) ([]byte , error ) {
141
203
var lic []byte
142
204
var err error
143
205
switch fileExtension (path ) {
144
206
default :
145
- return false , nil
207
+ return nil , nil
146
208
case ".c" , ".h" :
147
209
lic , err = prefix (tmpl , data , "/*" , " * " , " */" )
148
210
case ".js" , ".jsx" , ".tsx" , ".css" , ".tf" , ".ts" :
@@ -164,25 +226,7 @@ func addLicense(path string, fmode os.FileMode, tmpl *template.Template, data *c
164
226
case ".ml" , ".mli" , ".mll" , ".mly" :
165
227
lic , err = prefix (tmpl , data , "(**" , " " , "*)" )
166
228
}
167
- if err != nil || lic == nil {
168
- return false , err
169
- }
170
-
171
- b , err := ioutil .ReadFile (path )
172
- if err != nil || hasLicense (b ) {
173
- return false , err
174
- }
175
-
176
- line := hashBang (b )
177
- if len (line ) > 0 {
178
- b = b [len (line ):]
179
- if line [len (line )- 1 ] != '\n' {
180
- line = append (line , '\n' )
181
- }
182
- lic = append (line , lic ... )
183
- }
184
- b = append (lic , b ... )
185
- return true , ioutil .WriteFile (path , b , fmode )
229
+ return lic , err
186
230
}
187
231
188
232
func fileExtension (name string ) string {
0 commit comments