Skip to content

Commit 366936e

Browse files
committed
tests: add unit tests for header and template funcs
Much of this is being tested with the existing file-based tests in testdata/*, but this moves us toward a much more targetted and simpler test structure.
1 parent ef04bb3 commit 366936e

File tree

3 files changed

+238
-1
lines changed

3 files changed

+238
-1
lines changed

main.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -334,5 +334,5 @@ func hasLicense(b []byte) bool {
334334
}
335335
return bytes.Contains(bytes.ToLower(b[:n]), []byte("copyright")) ||
336336
bytes.Contains(bytes.ToLower(b[:n]), []byte("mozilla public")) ||
337-
bytes.Contains(bytes.ToLower(b[:n]), []byte("SPDX-License-Identifier"))
337+
bytes.Contains(bytes.ToLower(b[:n]), []byte("spdx-license-identifier"))
338338
}

main_test.go

+189
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package main
1616

1717
import (
18+
"html/template"
1819
"io/ioutil"
1920
"os"
2021
"os/exec"
@@ -206,3 +207,191 @@ func TestMPL(t *testing.T) {
206207
t.Fatalf("%v\n%s", err, out)
207208
}
208209
}
210+
211+
func createTempFile(contents string, pattern string) (*os.File, error) {
212+
f, err := ioutil.TempFile("", pattern)
213+
if err != nil {
214+
return nil, err
215+
}
216+
217+
if err := ioutil.WriteFile(f.Name(), []byte(contents), 0644); err != nil {
218+
return nil, err
219+
}
220+
221+
return f, nil
222+
}
223+
224+
func TestAddLicense(t *testing.T) {
225+
tmpl := template.Must(template.New("").Parse("{{.Holder}}{{.Year}}{{.SPDXID}}"))
226+
data := licenseData{Holder: "H", Year: "Y", SPDXID: "S"}
227+
228+
tests := []struct {
229+
contents string
230+
wantContents string
231+
wantUpdated bool
232+
}{
233+
{"", "// HYS\n\n", true},
234+
{"content", "// HYS\n\ncontent", true},
235+
236+
// various headers that should be left intact. Many don't make
237+
// sense for our temp file extension, but that doesn't matter.
238+
{"#!/bin/bash\ncontent", "#!/bin/bash\n// HYS\n\ncontent", true},
239+
{"<?xml version='1.0'?>\ncontent", "<?xml version='1.0'?>\n// HYS\n\ncontent", true},
240+
{"<!doctype html>\ncontent", "<!doctype html>\n// HYS\n\ncontent", true},
241+
{"<!DOCTYPE HTML>\ncontent", "<!DOCTYPE HTML>\n// HYS\n\ncontent", true},
242+
{"# encoding: UTF-8\ncontent", "# encoding: UTF-8\n// HYS\n\ncontent", true},
243+
{"# frozen_string_literal: true\ncontent", "# frozen_string_literal: true\n// HYS\n\ncontent", true},
244+
{"<?php\ncontent", "<?php\n// HYS\n\ncontent", true},
245+
246+
// ensure files with existing license or generated files are
247+
// skipped. No need to test all permutations of these, since
248+
// there are specific tests below.
249+
{"// Copyright 2000 Acme\ncontent", "// Copyright 2000 Acme\ncontent", false},
250+
{"// Code generated by go generate; DO NOT EDIT.\ncontent", "// Code generated by go generate; DO NOT EDIT.\ncontent", false},
251+
}
252+
253+
for _, tt := range tests {
254+
// create temp file with contents
255+
f, err := createTempFile(tt.contents, "*.go")
256+
if err != nil {
257+
t.Error(err)
258+
}
259+
fi, err := f.Stat()
260+
if err != nil {
261+
t.Error(err)
262+
}
263+
264+
// run addlicense
265+
updated, err := addLicense(f.Name(), fi.Mode(), tmpl, data)
266+
if err != nil {
267+
t.Error(err)
268+
}
269+
270+
// check results
271+
if updated != tt.wantUpdated {
272+
t.Errorf("addLicense with contents %q returned updated: %t, want %t", tt.contents, updated, tt.wantUpdated)
273+
}
274+
gotContents, err := ioutil.ReadFile(f.Name())
275+
if err != nil {
276+
t.Error(err)
277+
}
278+
if got := string(gotContents); got != tt.wantContents {
279+
t.Errorf("addLicense with contents %q returned contents: %q, want %q", tt.contents, got, tt.wantContents)
280+
}
281+
282+
// if all tests passed, cleanup temp file
283+
if !t.Failed() {
284+
_ = os.Remove(f.Name())
285+
}
286+
}
287+
}
288+
289+
// Test that license headers are added using the appropriate prefix for
290+
// different filenames and extensions.
291+
func TestLicenseHeader(t *testing.T) {
292+
tpl := template.Must(template.New("").Parse("{{.Holder}}{{.Year}}{{.SPDXID}}"))
293+
data := licenseData{Holder: "H", Year: "Y", SPDXID: "S"}
294+
295+
tests := []struct {
296+
paths []string // paths passed to licenseHeader
297+
want string // expected result of executing template
298+
}{
299+
{
300+
[]string{"f.unknown"},
301+
"",
302+
},
303+
{
304+
[]string{"f.c", "f.h", "f.gv"},
305+
"/*\n * HYS\n */\n\n",
306+
},
307+
{
308+
[]string{"f.js", "f.mjs", "f.cjs", "f.jsx", "f.tsx", "f.css", "f.scss", "f.sass", "f.tf", "f.ts"},
309+
"/**\n * HYS\n */\n\n",
310+
},
311+
{
312+
[]string{"f.cc", "f.cpp", "f.cs", "f.go", "f.hh", "f.hpp", "f.java", "f.m", "f.mm", "f.proto",
313+
"f.rs", "f.scala", "f.swift", "f.dart", "f.groovy", "f.kt", "f.kts", "f.v", "f.sv", "f.php"},
314+
"// HYS\n\n",
315+
},
316+
{
317+
[]string{"f.py", "f.sh", "f.yaml", "f.yml", "f.dockerfile", "dockerfile", "f.rb", "gemfile", "f.tcl", "f.bzl"},
318+
"# HYS\n\n",
319+
},
320+
{
321+
[]string{"f.el", "f.lisp"},
322+
";; HYS\n\n",
323+
},
324+
{
325+
[]string{"f.erl"},
326+
"% HYS\n\n",
327+
},
328+
{
329+
[]string{"f.hs", "f.sql", "f.sdl"},
330+
"-- HYS\n\n",
331+
},
332+
{
333+
[]string{"f.html", "f.xml", "f.vue"},
334+
"<!--\n HYS\n-->\n\n",
335+
},
336+
{
337+
[]string{"f.ml", "f.mli", "f.mll", "f.mly"},
338+
"(**\n HYS\n*)\n\n",
339+
},
340+
}
341+
342+
for _, tt := range tests {
343+
for _, path := range tt.paths {
344+
header, _ := licenseHeader(path, tpl, data)
345+
if got := string(header); got != tt.want {
346+
t.Errorf("licenseHeader(%q) returned: %q, want: %q", path, got, tt.want)
347+
}
348+
}
349+
}
350+
}
351+
352+
// Test that generated files are properly recognized.
353+
func TestIsGenerated(t *testing.T) {
354+
tests := []struct {
355+
content string
356+
want bool
357+
}{
358+
{"", false},
359+
{"Generated", false},
360+
{"// Code generated by go generate; DO NOT EDIT.", true},
361+
{"/*\n* Code generated by go generate; DO NOT EDIT.\n*/\n", true},
362+
{"DO NOT EDIT! Replaced on runs of cargo-raze", true},
363+
}
364+
365+
for _, tt := range tests {
366+
b := []byte(tt.content)
367+
if got := isGenerated(b); got != tt.want {
368+
t.Errorf("isGenerated(%q) returned %v, want %v", tt.content, got, tt.want)
369+
}
370+
}
371+
}
372+
373+
// Test that existing license headers are identified.
374+
func TestHasLicense(t *testing.T) {
375+
tests := []struct {
376+
content string
377+
want bool
378+
}{
379+
{"", false},
380+
{"This is my license", false},
381+
{"This code is released into the public domain.", false},
382+
{"SPDX: MIT", false},
383+
384+
{"Copyright 2000", true},
385+
{"CoPyRiGhT 2000", true},
386+
{"Subject to the terms of the Mozilla Public License", true},
387+
{"SPDX-License-Identifier: MIT", true},
388+
{"spdx-license-identifier: MIT", true},
389+
}
390+
391+
for _, tt := range tests {
392+
b := []byte(tt.content)
393+
if got := hasLicense(b); got != tt.want {
394+
t.Errorf("hasLicense(%q) returned %v, want %v", tt.content, got, tt.want)
395+
}
396+
}
397+
}

tmpl_test.go

+48
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,51 @@ func TestFetchTemplate(t *testing.T) {
138138
})
139139
}
140140
}
141+
142+
func TestExecuteTemplate(t *testing.T) {
143+
tests := []struct {
144+
template string
145+
data licenseData
146+
top, mid, bot string
147+
want string
148+
}{
149+
{
150+
"",
151+
licenseData{},
152+
"", "", "",
153+
"\n",
154+
},
155+
{
156+
"{{.Holder}}{{.Year}}{{.SPDXID}}",
157+
licenseData{Holder: "H", Year: "Y", SPDXID: "S"},
158+
"", "", "",
159+
"HYS\n\n",
160+
},
161+
{
162+
"{{.Holder}}{{.Year}}{{.SPDXID}}",
163+
licenseData{Holder: "H", Year: "Y", SPDXID: "S"},
164+
"", "// ", "",
165+
"// HYS\n\n",
166+
},
167+
{
168+
"{{.Holder}}{{.Year}}{{.SPDXID}}",
169+
licenseData{Holder: "H", Year: "Y", SPDXID: "S"},
170+
"/*", " * ", "*/",
171+
"/*\n * HYS\n*/\n\n",
172+
},
173+
}
174+
175+
for _, tt := range tests {
176+
tpl, err := template.New("").Parse(tt.template)
177+
if err != nil {
178+
t.Errorf("error parsing template: %v", err)
179+
}
180+
got, err := executeTemplate(tpl, tt.data, tt.top, tt.mid, tt.bot)
181+
if err != nil {
182+
t.Errorf("executeTemplate(%q, %v, %q, %q, %q) returned error: %v", tt.template, tt.data, tt.top, tt.mid, tt.bot, err)
183+
}
184+
if string(got) != tt.want {
185+
t.Errorf("executeTemplate(%q, %v, %q, %q, %q) returned %q, want: %q", tt.template, tt.data, tt.top, tt.mid, tt.bot, string(got), tt.want)
186+
}
187+
}
188+
}

0 commit comments

Comments
 (0)