236 lines
5.3 KiB
Go
236 lines
5.3 KiB
Go
package dkim
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io/fs"
|
|
"net"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/google/go-cmp/cmp"
|
|
)
|
|
|
|
func TestFromFiles(t *testing.T) {
|
|
msgfs, err := filepath.Glob("testdata/*.msg")
|
|
if err != nil {
|
|
t.Fatalf("error finding test files: %v", err)
|
|
}
|
|
|
|
for _, msgf := range msgfs {
|
|
base := strings.TrimSuffix(msgf, filepath.Ext(msgf))
|
|
t.Run(base, func(t *testing.T) { testOne(t, base) })
|
|
}
|
|
}
|
|
|
|
// This is the same as TestFromFiles, but it runs the private test files,
|
|
// which are not included in the git repository.
|
|
// This is useful for running tests on your own machine, with emails that you
|
|
// don't necessarily want to share publicly.
|
|
func TestFromPrivateFiles(t *testing.T) {
|
|
msgfs, err := filepath.Glob("testdata/private/*/*.msg")
|
|
if err != nil {
|
|
t.Fatalf("error finding private test files: %v", err)
|
|
}
|
|
|
|
for _, msgf := range msgfs {
|
|
base := strings.TrimSuffix(msgf, filepath.Ext(msgf))
|
|
t.Run(base, func(t *testing.T) { testOne(t, base) })
|
|
}
|
|
}
|
|
|
|
func testOne(t *testing.T, base string) {
|
|
ctx := context.Background()
|
|
ctx = WithTraceFunc(ctx, t.Logf)
|
|
|
|
ctx = loadDNS(t, ctx, base+".dns")
|
|
msg := toCRLF(mustReadFile(t, base+".msg"))
|
|
wantResult := loadResult(t, base+".result")
|
|
wantError := loadError(t, base+".error")
|
|
|
|
t.Logf("Message: %.60q", msg)
|
|
t.Logf("Want result: %+v", wantResult)
|
|
t.Logf("Want error: %v", wantError)
|
|
|
|
res, err := VerifyMessage(ctx, msg)
|
|
|
|
// Write the results out for easy updating.
|
|
writeResults(t, base, res, err)
|
|
|
|
diff := cmp.Diff(wantResult, res, cmp.Comparer(equalErrors))
|
|
if diff != "" {
|
|
t.Errorf("VerifyMessage result diff (-want +got):\n%s", diff)
|
|
}
|
|
|
|
// We need to compare them by hand because cmp.Diff won't use our comparer
|
|
// for top-level errors.
|
|
if !equalErrors(wantError, err) {
|
|
diff := cmp.Diff(wantError, err)
|
|
t.Errorf("VerifyMessage error diff (-want +got):\n%s", diff)
|
|
}
|
|
}
|
|
|
|
// equalErrors reports whether a and b are both nil, or both non-nil with
// identical Error() messages. It is used so that cmp.Diff compares errors by
// their messages. This is obviously not great, but it's good enough for this
// test.
func equalErrors(a, b error) bool {
	switch {
	case a == nil || b == nil:
		// At least one is nil: equal only if both are.
		return a == nil && b == nil
	default:
		return a.Error() == b.Error()
	}
}
|
|
|
|
// mustReadFile returns the contents of the file at path as a string. A file
// that does not exist yields "", so optional test-case files can simply be
// omitted; any other read error fails the test immediately.
func mustReadFile(t *testing.T, path string) string {
	t.Helper()

	data, err := os.ReadFile(path)
	switch {
	case errors.Is(err, fs.ErrNotExist):
		return ""
	case err != nil:
		t.Fatalf("error reading %q: %v", path, err)
	}
	return string(data)
}
|
|
|
|
func loadDNS(t *testing.T, ctx context.Context, path string) context.Context {
|
|
t.Helper()
|
|
|
|
results := map[string][]string{}
|
|
errors := map[string]error{}
|
|
txtFunc := func(ctx context.Context, domain string) ([]string, error) {
|
|
return results[domain], errors[domain]
|
|
}
|
|
ctx = WithLookupTXTFunc(ctx, txtFunc)
|
|
|
|
c := mustReadFile(t, path)
|
|
|
|
// Unfold \-terminated lines.
|
|
c = strings.ReplaceAll(c, "\\\n", "")
|
|
|
|
for _, line := range strings.Split(c, "\n") {
|
|
if line == "" || strings.HasPrefix(line, "#") {
|
|
continue
|
|
}
|
|
|
|
domain, txt, ok := strings.Cut(line, ":")
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
domain = strings.TrimSpace(domain)
|
|
|
|
switch strings.TrimSpace(txt) {
|
|
case "TEMPERROR":
|
|
errors[domain] = &net.DNSError{
|
|
Err: "temporary error (for testing)",
|
|
IsTemporary: true,
|
|
}
|
|
case "PERMERROR":
|
|
errors[domain] = &net.DNSError{
|
|
Err: "permanent error (for testing)",
|
|
IsTemporary: false,
|
|
}
|
|
case "NOTFOUND":
|
|
errors[domain] = &net.DNSError{
|
|
Err: "domain not found (for testing)",
|
|
IsNotFound: true,
|
|
}
|
|
default:
|
|
results[domain] = append(results[domain], txt)
|
|
}
|
|
}
|
|
|
|
t.Logf("Loaded DNS results: %#v", results)
|
|
t.Logf("Loaded DNS errors: %v", errors)
|
|
return ctx
|
|
}
|
|
|
|
func loadResult(t *testing.T, path string) *VerifyResult {
|
|
t.Helper()
|
|
|
|
res := &VerifyResult{}
|
|
c := mustReadFile(t, path)
|
|
if c == "" {
|
|
return nil
|
|
}
|
|
|
|
err := json.Unmarshal([]byte(c), res)
|
|
if err != nil {
|
|
t.Fatalf("error unmarshalling %q: %v", path, err)
|
|
}
|
|
return res
|
|
}
|
|
|
|
func loadError(t *testing.T, path string) error {
|
|
t.Helper()
|
|
|
|
c := strings.TrimSpace(mustReadFile(t, path))
|
|
if c == "" || c == "nil" || c == "<nil>" {
|
|
return nil
|
|
}
|
|
return errors.New(c)
|
|
}
|
|
|
|
// mustWriteFile writes c to the file at path using os.WriteFile with mode
// 0644. Any write error fails the test immediately.
func mustWriteFile(t *testing.T, path string, c []byte) {
	t.Helper()

	if err := os.WriteFile(path, c, 0o644); err != nil {
		t.Fatalf("error writing %q: %v", path, err)
	}
}
|
|
|
|
func writeResults(t *testing.T, base string, res *VerifyResult, err error) {
|
|
t.Helper()
|
|
|
|
mustWriteFile(t, base+".error.got", []byte(fmt.Sprintf("%v", err)))
|
|
|
|
c, err := json.MarshalIndent(res, "", "\t")
|
|
if err != nil {
|
|
t.Fatalf("error marshalling result: %v", err)
|
|
}
|
|
mustWriteFile(t, base+".result.got", c)
|
|
}
|
|
|
|
// Custom json marshaller so we can write errors as strings.
|
|
func (or *OneResult) MarshalJSON() ([]byte, error) {
|
|
// We use an alias to avoid infinite recursion.
|
|
type Alias OneResult
|
|
aux := &struct {
|
|
Error string `json:""`
|
|
*Alias
|
|
}{
|
|
Alias: (*Alias)(or),
|
|
}
|
|
if or.Error != nil {
|
|
aux.Error = or.Error.Error()
|
|
}
|
|
|
|
return json.Marshal(aux)
|
|
}
|
|
|
|
// Custom json unmarshaller so we can read errors as strings.
|
|
func (or *OneResult) UnmarshalJSON(b []byte) error {
|
|
// We use an alias to avoid infinite recursion.
|
|
type Alias OneResult
|
|
aux := &struct {
|
|
Error string `json:""`
|
|
*Alias
|
|
}{
|
|
Alias: (*Alias)(or),
|
|
}
|
|
if err := json.Unmarshal(b, aux); err != nil {
|
|
return err
|
|
}
|
|
|
|
if aux.Error != "" {
|
|
or.Error = errors.New(aux.Error)
|
|
}
|
|
return nil
|
|
}
|