gloader/grab/grabtest/handler.go
2020-12-09 13:29:14 -08:00

161 lines
3.7 KiB
Go

package grabtest
import (
"bufio"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
)
var (
DefaultHandlerContentLength = 1 << 20
DefaultHandlerMD5Checksum = "c35cc7d8d91728a0cb052831bc4ef372"
DefaultHandlerMD5ChecksumBytes = MustHexDecodeString(DefaultHandlerMD5Checksum)
DefaultHandlerSHA256Checksum = "fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83"
DefaultHandlerSHA256ChecksumBytes = MustHexDecodeString(DefaultHandlerSHA256Checksum)
)
type StatusCodeFunc func(req *http.Request) int
type handler struct {
statusCodeFunc StatusCodeFunc
methodWhitelist []string
headerBlacklist []string
contentLength int
acceptRanges bool
attachmentFilename string
lastModified time.Time
ttfb time.Duration
rateLimiter *time.Ticker
}
func NewHandler(options ...HandlerOption) (http.Handler, error) {
h := &handler{
statusCodeFunc: func(req *http.Request) int { return http.StatusOK },
methodWhitelist: []string{"GET", "HEAD"},
contentLength: DefaultHandlerContentLength,
acceptRanges: true,
}
for _, option := range options {
if err := option(h); err != nil {
return nil, err
}
}
return h, nil
}
func WithTestServer(t *testing.T, f func(url string), options ...HandlerOption) {
h, err := NewHandler(options...)
if err != nil {
t.Fatalf("unable to create test server handler: %v", err)
return
}
s := httptest.NewServer(h)
defer func() {
h.(*handler).close()
s.Close()
}()
f(s.URL)
}
func (h *handler) close() {
if h.rateLimiter != nil {
h.rateLimiter.Stop()
}
}
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// delay response
if h.ttfb > 0 {
time.Sleep(h.ttfb)
}
// validate request method
allowed := false
for _, m := range h.methodWhitelist {
if r.Method == m {
allowed = true
break
}
}
if !allowed {
httpError(w, http.StatusMethodNotAllowed)
return
}
// set server options
if h.acceptRanges {
w.Header().Set("Accept-Ranges", "bytes")
}
// set attachment filename
if h.attachmentFilename != "" {
w.Header().Set(
"Content-Disposition",
fmt.Sprintf("attachment;filename=\"%s\"", h.attachmentFilename),
)
}
// set last modified timestamp
lastMod := time.Now()
if !h.lastModified.IsZero() {
lastMod = h.lastModified
}
w.Header().Set("Last-Modified", lastMod.Format(http.TimeFormat))
// set content-length
offset := 0
if h.acceptRanges {
if reqRange := r.Header.Get("Range"); reqRange != "" {
if _, err := fmt.Sscanf(reqRange, "bytes=%d-", &offset); err != nil {
httpError(w, http.StatusBadRequest)
return
}
if offset >= h.contentLength {
httpError(w, http.StatusRequestedRangeNotSatisfiable)
return
}
}
}
w.Header().Set("Content-Length", fmt.Sprintf("%d", h.contentLength-offset))
// apply header blacklist
for _, key := range h.headerBlacklist {
w.Header().Del(key)
}
// send header and status code
w.WriteHeader(h.statusCodeFunc(r))
// send body
if r.Method == "GET" {
// use buffered io to reduce overhead on the reader
bw := bufio.NewWriterSize(w, 4096)
for i := offset; !isRequestClosed(r) && i < h.contentLength; i++ {
bw.Write([]byte{byte(i)})
if h.rateLimiter != nil {
bw.Flush()
w.(http.Flusher).Flush() // force the server to send the data to the client
select {
case <-h.rateLimiter.C:
case <-r.Context().Done():
}
}
}
if !isRequestClosed(r) {
bw.Flush()
}
}
}
// isRequestClosed returns true if the client request has been canceled.
func isRequestClosed(r *http.Request) bool {
return r.Context().Err() != nil
}
func httpError(w http.ResponseWriter, code int) {
http.Error(w, http.StatusText(code), code)
}