Commit custom grab

lordwelch 2020-12-09 13:29:14 -08:00
parent 17d26242d2
commit f1179ff06e
37 changed files with 4132 additions and 0 deletions

3
grab/.gitignore vendored Normal file

@@ -0,0 +1,3 @@
# ignore IDE project files
*.iml
.idea/

14
grab/.travis.yml Normal file

@@ -0,0 +1,14 @@
language: go
go:
- tip
- 1.10.x
- 1.9.x
- 1.8.x
- 1.7.x
script: make check
env:
- GOARCH=amd64
- GOARCH=386

26
grab/LICENSE Normal file

@@ -0,0 +1,26 @@
Copyright (c) 2017 Ryan Armstrong. All rights reserved.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its contributors
may be used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

29
grab/Makefile Normal file

@@ -0,0 +1,29 @@
GO = go
GOGET = $(GO) get -u

all: check lint

check:
	cd cmd/grab && $(MAKE) -B all
	$(GO) test -cover -race ./...

install:
	$(GO) install -v ./...

clean:
	$(GO) clean -x ./...
	rm -rvf ./.test*

lint:
	gofmt -l -e -s . || :
	go vet . || :
	golint . || :
	gocyclo -over 15 . || :
	misspell ./* || :

deps:
	$(GOGET) github.com/golang/lint/golint
	$(GOGET) github.com/fzipp/gocyclo
	$(GOGET) github.com/client9/misspell/cmd/misspell

.PHONY: all check install clean lint deps

127
grab/README.md Normal file

@@ -0,0 +1,127 @@
# grab

[![GoDoc](https://godoc.org/github.com/cavaliercoder/grab?status.svg)](https://godoc.org/github.com/cavaliercoder/grab) [![Build Status](https://travis-ci.org/cavaliercoder/grab.svg?branch=master)](https://travis-ci.org/cavaliercoder/grab) [![Go Report Card](https://goreportcard.com/badge/github.com/cavaliercoder/grab)](https://goreportcard.com/report/github.com/cavaliercoder/grab)

*Downloading the internet, one goroutine at a time!*

	$ go get github.com/cavaliercoder/grab

Grab is a Go package for downloading files from the internet with the following
rad features:

* Monitor download progress concurrently
* Auto-resume incomplete downloads
* Guess filename from content header or URL path
* Safely cancel downloads using context.Context
* Validate downloads using checksums
* Download batches of files concurrently
* Apply rate limiters

Requires Go v1.7+

## Example

The following example downloads a PDF copy of the free eBook, "An Introduction
to Programming in Go" into the current working directory.

```go
resp, err := grab.Get(".", "http://www.golang-book.com/public/pdf/gobook.pdf")
if err != nil {
	log.Fatal(err)
}

fmt.Println("Download saved to", resp.Filename)
```
The following, more complete example allows for more granular control and
periodically prints the download progress until it is complete.
The second time you run the example, it will auto-resume the previous download
and exit sooner.
```go
package main

import (
	"fmt"
	"os"
	"time"

	"github.com/cavaliercoder/grab"
)

func main() {
	// create client
	client := grab.NewClient()
	req, _ := grab.NewRequest(".", "http://www.golang-book.com/public/pdf/gobook.pdf")

	// start download
	fmt.Printf("Downloading %v...\n", req.URL())
	resp := client.Do(req)
	fmt.Printf("  %v\n", resp.HTTPResponse.Status)

	// start UI loop
	t := time.NewTicker(500 * time.Millisecond)
	defer t.Stop()

Loop:
	for {
		select {
		case <-t.C:
			fmt.Printf("  transferred %v / %v bytes (%.2f%%)\n",
				resp.BytesComplete(),
				resp.Size(),
				100*resp.Progress())

		case <-resp.Done:
			// download is complete
			break Loop
		}
	}

	// check for errors
	if err := resp.Err(); err != nil {
		fmt.Fprintf(os.Stderr, "Download failed: %v\n", err)
		os.Exit(1)
	}

	fmt.Printf("Download saved to ./%v \n", resp.Filename)

	// Output:
	// Downloading http://www.golang-book.com/public/pdf/gobook.pdf...
	//   200 OK
	//   transferred 42970 / 2893557 bytes (1.49%)
	//   transferred 1207474 / 2893557 bytes (41.73%)
	//   transferred 2758210 / 2893557 bytes (95.32%)
	// Download saved to ./gobook.pdf
}
```
## Design trade-offs

The primary use case for Grab is to concurrently download thousands of large
files from remote file repositories where the remote files are immutable.
Examples include operating system package repositories or ISO libraries.
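
The batch use case described above amounts to a bounded worker pool. The sketch below is not grab's implementation; it uses only the standard library, and `fetchAll`, `fetchFunc`, and `result` are hypothetical names chosen for illustration, with a fake fetch function standing in for a real download.

```go
package main

import (
	"fmt"
	"sync"
)

// fetchFunc stands in for a real download. It is a hypothetical placeholder
// that reports how many bytes were transferred for a URL.
type fetchFunc func(url string) (int, error)

// result carries the outcome of one fetch.
type result struct {
	URL   string
	Bytes int
	Err   error
}

// fetchAll runs fetch for every URL using at most `workers` goroutines. The
// returned channel is closed once every URL has been processed. If workers is
// less than one, one worker is started per URL.
func fetchAll(workers int, urls []string, fetch fetchFunc) <-chan result {
	if workers < 1 {
		workers = len(urls)
	}
	// Both channels are buffered to len(urls) so that neither queueing jobs
	// nor a slow receiver can block a worker.
	jobs := make(chan string, len(urls))
	results := make(chan result, len(urls))

	var wg sync.WaitGroup
	for i := 0; i < workers; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for u := range jobs {
				n, err := fetch(u)
				results <- result{URL: u, Bytes: n, Err: err}
			}
		}()
	}

	for _, u := range urls {
		jobs <- u
	}
	close(jobs)

	// Close the results channel only after every worker has finished.
	go func() {
		wg.Wait()
		close(results)
	}()
	return results
}

func main() {
	urls := []string{"a", "bb", "ccc"}
	fake := func(u string) (int, error) { return len(u), nil }

	total := 0
	for r := range fetchAll(2, urls, fake) {
		total += r.Bytes
	}
	fmt.Println("total bytes:", total)
}
```

Buffering the result channel to the number of requests is one way to keep a slow consumer from stalling workers; a smaller buffer trades memory for back-pressure.
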

Grab aims to provide robust, sane defaults. These are usually determined using
the HTTP specifications, or by mimicking the behavior of common web clients like
cURL, wget and common web browsers.

Grab aims to be stateless. The only state that exists is the remote files you
wish to download and the local copy which may be completed, partially completed
or not yet created. The advantage to this is that the local file system is not
cluttered unnecessarily with additional state files (like a `.crdownload` file).
The disadvantage of this approach is that grab must make assumptions about the
local and remote state; specifically, that they have not been modified by
another program.

If the local or remote file is modified outside of grab, and you download the
file again with resuming enabled, the local file will likely become corrupted.
In this case, you might consider making remote files immutable, or disabling
resume.
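
To make the corruption risk concrete, here is a minimal sketch of how a resume request can be formed, assuming the resume point is simply the current size of the local file. `newResumeRequest` and the example URL are hypothetical; the `bytes=N-` header format is the one used in `client.go`. Nothing in such a request lets the server confirm that the first N local bytes still match the remote file, which is why outside modifications go undetected.

```go
package main

import (
	"fmt"
	"net/http"
)

// newResumeRequest builds a GET request that asks the server to skip the
// first localSize bytes, which are assumed to already be on disk.
// This is an illustrative helper, not part of grab's API.
func newResumeRequest(url string, localSize int64) (*http.Request, error) {
	req, err := http.NewRequest("GET", url, nil)
	if err != nil {
		return nil, err
	}
	if localSize > 0 {
		// An open-ended byte range: everything from offset localSize onward.
		req.Header.Set("Range", fmt.Sprintf("bytes=%d-", localSize))
	}
	return req, nil
}

func main() {
	req, err := newResumeRequest("http://example.com/file.iso", 1024)
	if err != nil {
		panic(err)
	}
	fmt.Println(req.Header.Get("Range"))
}
```

The bytes returned for such a request are appended to the local file as they arrive, so the result is only correct if both copies are unchanged since the first attempt.
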

Grab aims to enable best-in-class functionality for more complex features
through extensible interfaces, rather than reimplementation. For example,
you can provide your own Hash algorithm to compute file checksums, or your
own rate limiter implementation (with all the associated trade-offs) to rate
limit downloads.

54
grab/bps/bps.go Normal file

@@ -0,0 +1,54 @@
/*
Package bps provides gauges for calculating the Bytes Per Second transfer rate
of data streams.
*/
package bps
import (
"context"
"time"
)
// Gauge is the common interface for all BPS gauges in this package. Given a
// set of samples over time, each gauge type can be used to measure the Bytes
// Per Second transfer rate of a data stream.
//
// All samples must monotonically increase in timestamp and value. Each sample
// should represent the total number of bytes sent in a stream, rather than
// accounting for the number sent since the last sample.
//
// To ensure a gauge can report progress as quickly as possible, take an initial
// sample when your stream first starts.
//
// All gauge implementations are safe for concurrent use.
type Gauge interface {
// Sample adds a new sample of the progress of the monitored stream.
Sample(t time.Time, n int64)
// BPS returns the calculated Bytes Per Second rate of the monitored stream.
BPS() float64
}
// SampleFunc is used by Watch to take periodic samples of a monitored stream.
type SampleFunc func() (n int64)
// Watch will periodically call the given SampleFunc to sample the progress of
// a monitored stream and update the given gauge. SampleFunc should return the
// total number of bytes transferred by the stream since it started.
//
// Watch is a blocking call and should typically be called in a new goroutine.
// To prevent the goroutine from leaking, make sure to cancel the given context
// once the stream is completed or canceled.
func Watch(ctx context.Context, g Gauge, f SampleFunc, interval time.Duration) {
g.Sample(time.Now(), f())
t := time.NewTicker(interval)
defer t.Stop()
for {
select {
case <-ctx.Done():
return
case now := <-t.C:
g.Sample(now, f())
}
}
}

81
grab/bps/sma.go Normal file

@@ -0,0 +1,81 @@
package bps
import (
"sync"
"time"
)
// NewSMA returns a gauge that uses a Simple Moving Average with the given
// number of samples to measure the bytes per second of a byte stream.
//
// BPS is computed using the timestamp of the most recent and oldest sample in
// the sample buffer. When a new sample is added, the oldest sample is dropped
// if the sample count exceeds maxSamples.
//
// The gauge does not account for any latency in arrival time of new samples or
// the desired window size. Any variance in the arrival of samples will result
// in a BPS measurement that is correct for the submitted samples, but over a
// varying time window.
//
// maxSamples should be equal to 1 + (window size / sampling interval) where
// window size is the number of seconds over which the moving average is
// smoothed and sampling interval is the number of seconds between each sample.
//
// For example, if you want a five second window, sampling once per second,
// maxSamples should be 1 + 5/1 = 6.
func NewSMA(maxSamples int) Gauge {
if maxSamples < 2 {
panic("sample count must be greater than 1")
}
return &sma{
maxSamples: uint64(maxSamples),
samples: make([]int64, maxSamples),
timestamps: make([]time.Time, maxSamples),
}
}
type sma struct {
mu sync.Mutex
index uint64
maxSamples uint64
sampleCount uint64
samples []int64
timestamps []time.Time
}
func (c *sma) Sample(t time.Time, n int64) {
c.mu.Lock()
defer c.mu.Unlock()
c.timestamps[c.index] = t
c.samples[c.index] = n
c.index = (c.index + 1) % c.maxSamples
// Prevent integer overflow in sampleCount. Values greater than or equal to
// maxSamples have the same semantic meaning.
c.sampleCount++
if c.sampleCount > c.maxSamples {
c.sampleCount = c.maxSamples
}
}
func (c *sma) BPS() float64 {
c.mu.Lock()
defer c.mu.Unlock()
// we need two samples to start
if c.sampleCount < 2 {
return 0
}
// First sample is always the oldest until ring buffer first overflows
oldest := c.index
if c.sampleCount < c.maxSamples {
oldest = 0
}
newest := (c.index + c.maxSamples - 1) % c.maxSamples
seconds := c.timestamps[newest].Sub(c.timestamps[oldest]).Seconds()
bytes := float64(c.samples[newest] - c.samples[oldest])
return bytes / seconds
}

55
grab/bps/sma_test.go Normal file

@@ -0,0 +1,55 @@
package bps
import (
"testing"
"time"
)
type Sample struct {
N int64
Expect float64
}
func getSimpleSamples(sampleCount, rate int) []Sample {
a := make([]Sample, sampleCount)
for i := 1; i < sampleCount; i++ {
a[i] = Sample{N: int64(i * rate), Expect: float64(rate)}
}
return a
}
type SampleSetTest struct {
Gauge Gauge
Interval time.Duration
Samples []Sample
}
func (c *SampleSetTest) Run(t *testing.T) {
ts := time.Unix(0, 0)
for i, sample := range c.Samples {
c.Gauge.Sample(ts, sample.N)
if actual := c.Gauge.BPS(); actual != sample.Expect {
t.Errorf("expected: Gauge.BPS() → %0.2f, got %0.2f in test %d", sample.Expect, actual, i+1)
}
ts = ts.Add(c.Interval)
}
}
func TestSMA_SimpleSteadyCase(t *testing.T) {
test := &SampleSetTest{
Interval: time.Second,
Samples: getSimpleSamples(100000, 3),
}
t.Run("SmallSampleSize", func(t *testing.T) {
test.Gauge = NewSMA(2)
test.Run(t)
})
t.Run("RegularSize", func(t *testing.T) {
test.Gauge = NewSMA(6)
test.Run(t)
})
t.Run("LargeSampleSize", func(t *testing.T) {
test.Gauge = NewSMA(1000)
test.Run(t)
})
}

570
grab/client.go Normal file

@@ -0,0 +1,570 @@
package grab
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"sync"
"sync/atomic"
"time"
)
// HTTPClient provides an interface allowing us to perform HTTP requests.
type HTTPClient interface {
Do(req *http.Request) (*http.Response, error)
}
// truncater is a private interface allowing different response
// Writers to be truncated
type truncater interface {
Truncate(size int64) error
}
// A Client is a file download client.
//
// Clients are safe for concurrent use by multiple goroutines.
type Client struct {
// HTTPClient specifies the http.Client which will be used for communicating
// with the remote server during the file transfer.
HTTPClient HTTPClient
// UserAgent specifies the User-Agent string which will be set in the
// headers of all requests made by this client.
//
// The user agent string may be overridden in the headers of each request.
UserAgent string
// BufferSize specifies the size in bytes of the buffer that is used for
// transferring all requested files. Larger buffers may result in faster
// throughput but will use more memory and result in less frequent updates
// to the transfer progress statistics. The BufferSize of each request can
// be overridden on each Request object. Default: 32KB.
BufferSize int
}
// NewClient returns a new file download Client, using default configuration.
func NewClient() *Client {
return &Client{
UserAgent: "grab",
HTTPClient: &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
},
},
}
}
// DefaultClient is the default client and is used by all Get convenience
// functions.
var DefaultClient = NewClient()
// Do sends a file transfer request and returns a file transfer response,
// following policy (e.g. redirects, cookies, auth) as configured on the
// client's HTTPClient.
//
// Like http.Get, Do blocks while the transfer is initiated, but returns as soon
// as the transfer has started transferring in a background goroutine, or if it
// failed early.
//
// An error is returned via Response.Err if caused by client policy (such as
// CheckRedirect), or if there was an HTTP protocol or IO error. Response.Err
// will block the caller until the transfer is completed, successfully or
// otherwise.
func (c *Client) Do(req *Request) *Response {
// cancel will be called on all code-paths via closeResponse
ctx, cancel := context.WithCancel(req.Context())
req = req.WithContext(ctx)
resp := &Response{
Request: req,
Start: time.Now(),
Done: make(chan struct{}, 0),
Filename: req.Filename,
ctx: ctx,
cancel: cancel,
bufferSize: req.BufferSize,
}
if resp.bufferSize == 0 {
// default to Client.BufferSize
resp.bufferSize = c.BufferSize
}
// Run state-machine while caller is blocked to initialize the file transfer.
// Must never transition to the copyFile state - this happens next in another
// goroutine.
c.run(resp, c.statFileInfo)
// Run copyFile in a new goroutine. copyFile will no-op if the transfer is
// already complete or failed.
go c.run(resp, c.copyFile)
return resp
}
// DoChannel executes all requests sent through the given Request channel, one
// at a time, until it is closed by another goroutine. The caller is blocked
// until the Request channel is closed and all transfers have completed. All
// responses are sent through the given Response channel as soon as they are
// received from the remote servers and can be used to track the progress of
// each download.
//
// Slow Response receivers will cause a worker to block and therefore delay the
// start of the transfer for an already initiated connection - potentially
// causing a server timeout. It is the caller's responsibility to ensure a
// sufficient buffer size is used for the Response channel to prevent this.
//
// If an error occurs during any of the file transfers it will be accessible via
// the associated Response.Err function.
func (c *Client) DoChannel(reqch <-chan *Request, respch chan<- *Response) {
// TODO: enable cancelling of batch jobs
for req := range reqch {
resp := c.Do(req)
respch <- resp
<-resp.Done
}
}
// DoBatch executes all the given requests using the given number of concurrent
// workers. Control is passed back to the caller as soon as the workers are
// initiated.
//
// If the requested number of workers is less than one, a worker will be created
// for every request. I.e. all requests will be executed concurrently.
//
// If an error occurs during any of the file transfers it will be accessible via
// call to the associated Response.Err.
//
// The returned Response channel is closed only after all of the given Requests
// have completed, successfully or otherwise.
func (c *Client) DoBatch(workers int, requests ...*Request) <-chan *Response {
if workers < 1 {
workers = len(requests)
}
reqch := make(chan *Request, len(requests))
respch := make(chan *Response, len(requests))
wg := sync.WaitGroup{}
for i := 0; i < workers; i++ {
wg.Add(1)
go func() {
c.DoChannel(reqch, respch)
wg.Done()
}()
}
// queue requests
go func() {
for _, req := range requests {
reqch <- req
}
close(reqch)
wg.Wait()
close(respch)
}()
return respch
}
// A stateFunc is an action that mutates the state of a Response and returns
// the next stateFunc to be called.
type stateFunc func(*Response) stateFunc
// run calls the given stateFunc function and all subsequent returned stateFuncs
// until a stateFunc returns nil or the Response.ctx is canceled. Each stateFunc
// should mutate the state of the given Response until it has completed
// downloading or failed.
func (c *Client) run(resp *Response, f stateFunc) {
for {
select {
case <-resp.ctx.Done():
if resp.IsComplete() {
return
}
resp.err = resp.ctx.Err()
f = c.closeResponse
default:
// keep working
}
if f = f(resp); f == nil {
return
}
}
}
// statFileInfo retrieves FileInfo for any local file matching
// Response.Filename.
//
// If the file does not exist, is a directory, or its name is unknown the next
// stateFunc is headRequest.
//
// If the file exists, Response.fi is set and the next stateFunc is
// validateLocal.
//
// If an error occurs, the next stateFunc is closeResponse.
func (c *Client) statFileInfo(resp *Response) stateFunc {
if resp.Request.NoStore || resp.Filename == "" { // No filename provided will guess
return c.headRequest
}
fi, err := os.Stat(resp.Filename)
if err != nil {
if os.IsNotExist(err) { // Filename does not exist, will download
return c.headRequest
}
resp.err = err
return c.closeResponse // Other PathError occurred
}
if fi.IsDir() { // resp.Request.Filename is a directory
// Will guess filename and append to resp.Request.Filename
resp.Filename = ""
return c.headRequest
}
resp.fi = fi
return c.validateLocal // Filename exists, validate file
}
// validateLocal compares a local copy of the downloaded file to the remote
// file.
//
// An error is returned if the local file is larger than the remote file, or
// Request.SkipExisting is true.
//
// If the existing file matches the length of the remote file, the next
// stateFunc is checksumFile.
//
// If the local file is smaller than the remote file and the remote server is
// known to support ranged requests, the next stateFunc is getRequest.
func (c *Client) validateLocal(resp *Response) stateFunc {
if resp.Request.SkipExisting {
resp.err = ErrFileExists
return c.closeResponse
}
// determine target file size
expectedSize := resp.Request.Size
if expectedSize == 0 && resp.HTTPResponse != nil {
expectedSize = resp.HTTPResponse.ContentLength
}
if expectedSize == 0 {
// size is either actually 0 or unknown
// if unknown, we ask the remote server
// if known to be 0, we proceed with a GET
return c.headRequest
}
if expectedSize == resp.fi.Size() {
// local file matches remote file size - wrap it up
resp.DidResume = true
resp.bytesResumed = resp.fi.Size()
return c.checksumFile
}
if resp.Request.NoResume {
// local file should be overwritten
return c.getRequest
}
if expectedSize >= 0 && expectedSize < resp.fi.Size() {
// remote size is known, is smaller than local size and we want to resume
resp.err = ErrBadLength
return c.closeResponse
}
if resp.CanResume {
// set resume range on GET request
resp.Request.HTTPRequest.Header.Set(
"Range",
fmt.Sprintf("bytes=%d-", resp.fi.Size()))
resp.DidResume = true
resp.bytesResumed = resp.fi.Size()
return c.getRequest
}
return c.headRequest
}
func (c *Client) checksumFile(resp *Response) stateFunc {
if resp.Request.hash == nil {
return c.closeResponse
}
if resp.Filename == "" {
panic("grab: developer error: filename not set")
}
if resp.Size() < 0 {
panic("grab: developer error: size unknown")
}
req := resp.Request
// compute checksum
var sum []byte
sum, resp.err = resp.checksumUnsafe()
if resp.err != nil {
return c.closeResponse
}
// compare checksum
if !bytes.Equal(sum, req.checksum) {
resp.err = ErrBadChecksum
if !resp.Request.NoStore && req.deleteOnError {
if err := os.Remove(resp.Filename); err != nil {
// err should be os.PathError and include file path
resp.err = fmt.Errorf(
"cannot remove downloaded file with checksum mismatch: %v",
err)
}
}
}
return c.closeResponse
}
// doHTTPRequest sends an HTTP request and returns the response.
func (c *Client) doHTTPRequest(req *http.Request) (*http.Response, error) {
if c.UserAgent != "" && req.Header.Get("User-Agent") == "" {
req.Header.Set("User-Agent", c.UserAgent)
}
return c.HTTPClient.Do(req)
}
func (c *Client) headRequest(resp *Response) stateFunc {
if resp.optionsKnown {
return c.getRequest
}
resp.optionsKnown = true
if resp.Request.NoResume {
return c.getRequest
}
if resp.Filename != "" && resp.fi == nil {
// destination path is already known and does not exist
return c.getRequest
}
hreq := new(http.Request)
*hreq = *resp.Request.HTTPRequest
hreq.Method = "HEAD"
resp.HTTPResponse, resp.err = c.doHTTPRequest(hreq)
if resp.err != nil {
return c.closeResponse
}
resp.HTTPResponse.Body.Close()
if resp.HTTPResponse.StatusCode != http.StatusOK {
return c.getRequest
}
// In case of redirects during HEAD, record the final URL and use it
// instead of the original URL when sending future requests.
// This way we avoid sending potentially unsupported requests to
// the original URL, e.g. "Range", since it was the final URL
// that advertised its support.
resp.Request.HTTPRequest.URL = resp.HTTPResponse.Request.URL
resp.Request.HTTPRequest.Host = resp.HTTPResponse.Request.Host
return c.readResponse
}
func (c *Client) getRequest(resp *Response) stateFunc {
resp.HTTPResponse, resp.err = c.doHTTPRequest(resp.Request.HTTPRequest)
if resp.err != nil {
return c.closeResponse
}
// TODO: check Content-Range
// check status code
if !resp.Request.IgnoreBadStatusCodes {
if resp.HTTPResponse.StatusCode < 200 || resp.HTTPResponse.StatusCode > 299 {
resp.err = StatusCodeError(resp.HTTPResponse.StatusCode)
return c.closeResponse
}
}
return c.readResponse
}
func (c *Client) readResponse(resp *Response) stateFunc {
if resp.HTTPResponse == nil {
panic("grab: developer error: Response.HTTPResponse is nil")
}
// check expected size
resp.sizeUnsafe = resp.HTTPResponse.ContentLength
if resp.sizeUnsafe >= 0 {
// remote size is known
resp.sizeUnsafe += resp.bytesResumed
if resp.Request.Size > 0 && resp.Request.Size != resp.sizeUnsafe {
resp.err = ErrBadLength
return c.closeResponse
}
}
// check filename
if resp.Filename == "" {
filename, err := guessFilename(resp.HTTPResponse)
if err != nil {
resp.err = err
return c.closeResponse
}
// Request.Filename will be empty or a directory
resp.Filename = filepath.Join(resp.Request.Filename, filename)
}
if !resp.Request.NoStore && resp.requestMethod() == "HEAD" {
if resp.HTTPResponse.Header.Get("Accept-Ranges") == "bytes" {
resp.CanResume = true
}
return c.statFileInfo
}
return c.openWriter
}
// openWriter opens the destination file for writing and seeks to the location
// from whence the file transfer will resume.
//
// Requires that Response.Filename and resp.DidResume are already set.
func (c *Client) openWriter(resp *Response) stateFunc {
if !resp.Request.NoStore && !resp.Request.NoCreateDirectories {
resp.err = mkdirp(resp.Filename)
if resp.err != nil {
return c.closeResponse
}
}
if resp.Request.NoStore {
resp.writer = &resp.storeBuffer
} else {
// compute write flags
flag := os.O_CREATE | os.O_WRONLY
if resp.fi != nil {
if resp.DidResume {
flag = os.O_APPEND | os.O_WRONLY
} else {
// truncate later in copyFile, if not cancelled
// by BeforeCopy hook
flag = os.O_WRONLY
}
}
// open file
f, err := os.OpenFile(resp.Filename, flag, 0666)
if err != nil {
resp.err = err
return c.closeResponse
}
resp.writer = f
// seek to start or end
whence := io.SeekStart
if resp.bytesResumed > 0 {
whence = io.SeekEnd
}
_, resp.err = f.Seek(0, whence)
if resp.err != nil {
return c.closeResponse
}
}
// init transfer
if resp.bufferSize < 1 {
resp.bufferSize = 32 * 1024
}
b := make([]byte, resp.bufferSize)
resp.transfer = newTransfer(
resp.Request.Context(),
resp.Request.RateLimiter,
resp.writer,
resp.HTTPResponse.Body,
b)
// next step is copyFile, but this will be called later in another goroutine
return nil
}
// copyFile transfers content for an HTTP connection established via Client.Do()
func (c *Client) copyFile(resp *Response) stateFunc {
if resp.IsComplete() {
return nil
}
// run BeforeCopy hook
if f := resp.Request.BeforeCopy; f != nil {
resp.err = f(resp)
if resp.err != nil {
return c.closeResponse
}
}
var bytesCopied int64
if resp.transfer == nil {
panic("grab: developer error: Response.transfer is nil")
}
// We waited to truncate the file in openWriter() to make sure
// the BeforeCopy didn't cancel the copy. If this was an existing
// file that is not going to be resumed, truncate the contents.
if t, ok := resp.writer.(truncater); ok && resp.fi != nil && !resp.DidResume {
if resp.err = t.Truncate(0); resp.err != nil {
return c.closeResponse
}
}
bytesCopied, resp.err = resp.transfer.copy()
if resp.err != nil {
return c.closeResponse
}
closeWriter(resp)
// set file timestamp
if !resp.Request.NoStore && !resp.Request.IgnoreRemoteTime {
resp.err = setLastModified(resp.HTTPResponse, resp.Filename)
if resp.err != nil {
return c.closeResponse
}
}
// update transfer size if previously unknown
if resp.Size() < 0 {
discoveredSize := resp.bytesResumed + bytesCopied
atomic.StoreInt64(&resp.sizeUnsafe, discoveredSize)
if resp.Request.Size > 0 && resp.Request.Size != discoveredSize {
resp.err = ErrBadLength
return c.closeResponse
}
}
// run AfterCopy hook
if f := resp.Request.AfterCopy; f != nil {
resp.err = f(resp)
if resp.err != nil {
return c.closeResponse
}
}
return c.checksumFile
}
func closeWriter(resp *Response) {
if closer, ok := resp.writer.(io.Closer); ok {
closer.Close()
}
resp.writer = nil
}
// closeResponse finalizes the Response
func (c *Client) closeResponse(resp *Response) stateFunc {
if resp.IsComplete() {
panic("grab: developer error: response already closed")
}
resp.fi = nil
closeWriter(resp)
resp.closeResponseBody()
resp.End = time.Now()
close(resp.Done)
if resp.cancel != nil {
resp.cancel()
}
return nil
}

915
grab/client_test.go Normal file

@@ -0,0 +1,915 @@
package grab
import (
"bytes"
"context"
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"errors"
"fmt"
"hash"
"io/ioutil"
"math/rand"
"net/http"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/cavaliercoder/grab/grabtest"
)
// TestFilenameResolution tests that the destination filename for Requests can
// be determined correctly, using an explicitly requested path,
// Content-Disposition headers or a URL path - with or without an existing
// target directory.
func TestFilenameResolution(t *testing.T) {
tests := []struct {
Name string
Filename string
URL string
AttachmentFilename string
Expect string
}{
{"Using Request.Filename", ".testWithFilename", "/url-filename", "header-filename", ".testWithFilename"},
{"Using Content-Disposition Header", "", "/url-filename", ".testWithHeaderFilename", ".testWithHeaderFilename"},
{"Using Content-Disposition Header with target directory", ".test", "/url-filename", "header-filename", ".test/header-filename"},
{"Using URL Path", "", "/.testWithURLFilename?params-filename", "", ".testWithURLFilename"},
{"Using URL Path with target directory", ".test", "/url-filename?garbage", "", ".test/url-filename"},
{"Failure", "", "", "", ""},
}
err := os.Mkdir(".test", 0777)
if err != nil {
panic(err)
}
defer os.RemoveAll(".test")
for _, test := range tests {
t.Run(test.Name, func(t *testing.T) {
opts := []grabtest.HandlerOption{}
if test.AttachmentFilename != "" {
opts = append(opts, grabtest.AttachmentFilename(test.AttachmentFilename))
}
grabtest.WithTestServer(t, func(url string) {
req := mustNewRequest(test.Filename, url+test.URL)
resp := DefaultClient.Do(req)
defer os.Remove(resp.Filename)
if err := resp.Err(); err != nil {
if test.Expect != "" || err != ErrNoFilename {
panic(err)
}
} else {
if test.Expect == "" {
t.Errorf("expected: %v, got: %v", ErrNoFilename, err)
}
}
if resp.Filename != test.Expect {
t.Errorf("Filename mismatch. Expected '%s', got '%s'.", test.Expect, resp.Filename)
}
testComplete(t, resp)
}, opts...)
})
}
}
// TestChecksums checks that checksum validation behaves as expected for valid
// and corrupted downloads.
func TestChecksums(t *testing.T) {
tests := []struct {
size int
hash hash.Hash
sum string
match bool
}{
{128, md5.New(), "37eff01866ba3f538421b30b7cbefcac", true},
{128, md5.New(), "37eff01866ba3f538421b30b7cbefcad", false},
{1024, md5.New(), "b2ea9f7fcea831a4a63b213f41a8855b", true},
{1024, md5.New(), "b2ea9f7fcea831a4a63b213f41a8855c", false},
{1048576, md5.New(), "c35cc7d8d91728a0cb052831bc4ef372", true},
{1048576, md5.New(), "c35cc7d8d91728a0cb052831bc4ef373", false},
{128, sha1.New(), "e6434bc401f98603d7eda504790c98c67385d535", true},
{128, sha1.New(), "e6434bc401f98603d7eda504790c98c67385d536", false},
{1024, sha1.New(), "5b00669c480d5cffbdfa8bdba99561160f2d1b77", true},
{1024, sha1.New(), "5b00669c480d5cffbdfa8bdba99561160f2d1b78", false},
{1048576, sha1.New(), "ecfc8e86fdd83811f9cc9bf500993b63069923be", true},
{1048576, sha1.New(), "ecfc8e86fdd83811f9cc9bf500993b63069923bf", false},
{128, sha256.New(), "471fb943aa23c511f6f72f8d1652d9c880cfa392ad80503120547703e56a2be5", true},
{128, sha256.New(), "471fb943aa23c511f6f72f8d1652d9c880cfa392ad80503120547703e56a2be4", false},
{1024, sha256.New(), "785b0751fc2c53dc14a4ce3d800e69ef9ce1009eb327ccf458afe09c242c26c9", true},
{1024, sha256.New(), "785b0751fc2c53dc14a4ce3d800e69ef9ce1009eb327ccf458afe09c242c26c8", false},
{1048576, sha256.New(), "fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83", true},
{1048576, sha256.New(), "fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c82", false},
{128, sha512.New(), "1dffd5e3adb71d45d2245939665521ae001a317a03720a45732ba1900ca3b8351fc5c9b4ca513eba6f80bc7b1d1fdad4abd13491cb824d61b08d8c0e1561b3f7", true},
{128, sha512.New(), "1dffd5e3adb71d45d2245939665521ae001a317a03720a45732ba1900ca3b8351fc5c9b4ca513eba6f80bc7b1d1fdad4abd13491cb824d61b08d8c0e1561b3f8", false},
{1024, sha512.New(), "37f652be867f28ed033269cbba201af2112c2b3fd334a89fd2f757938ddee815787cc61d6e24a8a33340d0f7e86ffc058816b88530766ba6e231620a130b566c", true},
{1024, sha512.New(), "37f652bf867f28ed033269cbba201af2112c2b3fd334a89fd2f757938ddee815787cc61d6e24a8a33340d0f7e86ffc058816b88530766ba6e231620a130b566d", false},
{1048576, sha512.New(), "ac1d097b4ea6f6ad7ba640275b9ac290e4828cd760a0ebf76d555463a4f505f95df4f611629539a2dd1848e7c1304633baa1826462b3c87521c0c6e3469b67af", true},
{1048576, sha512.New(), "ac1d097c4ea6f6ad7ba640275b9ac290e4828cd760a0ebf76d555463a4f505f95df4f611629539a2dd1848e7c1304633baa1826462b3c87521c0c6e3469b67af", false},
}
for _, test := range tests {
var expect error
comparison := "Match"
if !test.match {
comparison = "Mismatch"
expect = ErrBadChecksum
}
t.Run(fmt.Sprintf("With%s%s", comparison, test.sum[:8]), func(t *testing.T) {
filename := fmt.Sprintf(".testChecksum-%s-%s", comparison, test.sum[:8])
defer os.Remove(filename)
grabtest.WithTestServer(t, func(url string) {
req := mustNewRequest(filename, url)
req.SetChecksum(test.hash, grabtest.MustHexDecodeString(test.sum), true)
resp := DefaultClient.Do(req)
err := resp.Err()
if err != expect {
t.Errorf("expected error: %v, got: %v", expect, err)
}
// ensure mismatch file was deleted
if !test.match {
if _, err := os.Stat(filename); err == nil {
t.Errorf("checksum failure not cleaned up: %s", filename)
} else if !os.IsNotExist(err) {
panic(err)
}
}
testComplete(t, resp)
}, grabtest.ContentLength(test.size))
})
}
}
// TestContentLength ensures that ErrBadLength is returned if a server response
// does not match the requested length.
func TestContentLength(t *testing.T) {
size := int64(32768)
testCases := []struct {
Name string
NoHead bool
Size int64
Expect int64
Match bool
}{
{"Good size in HEAD request", false, size, size, true},
{"Good size in GET request", true, size, size, true},
{"Bad size in HEAD request", false, size - 1, size, false},
{"Bad size in GET request", true, size - 1, size, false},
}
for _, test := range testCases {
t.Run(test.Name, func(t *testing.T) {
opts := []grabtest.HandlerOption{
grabtest.ContentLength(int(test.Size)),
}
if test.NoHead {
opts = append(opts, grabtest.MethodWhitelist("GET"))
}
grabtest.WithTestServer(t, func(url string) {
req := mustNewRequest(".testSize-mismatch-head", url)
req.Size = size
resp := DefaultClient.Do(req)
defer os.Remove(resp.Filename)
err := resp.Err()
if test.Match {
if err == ErrBadLength {
t.Errorf("error: %v", err)
} else if err != nil {
panic(err)
} else if resp.Size() != size {
t.Errorf("expected %v bytes, got %v bytes", size, resp.Size())
}
} else {
if err == nil {
t.Errorf("expected: %v, got %v", ErrBadLength, err)
} else if err != ErrBadLength {
panic(err)
}
}
testComplete(t, resp)
}, opts...)
})
}
}
// TestAutoResume tests segmented downloading of a large file.
func TestAutoResume(t *testing.T) {
segs := 8
size := 1048576
sum := grabtest.DefaultHandlerSHA256ChecksumBytes //grabtest.MustHexDecodeString("fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83")
filename := ".testAutoResume"
defer os.Remove(filename)
for i := 0; i < segs; i++ {
segsize := (i + 1) * (size / segs)
t.Run(fmt.Sprintf("With%vBytes", segsize), func(t *testing.T) {
grabtest.WithTestServer(t, func(url string) {
req := mustNewRequest(filename, url)
if i == segs-1 {
req.SetChecksum(sha256.New(), sum, false)
}
resp := mustDo(req)
if i > 0 && !resp.DidResume {
t.Errorf("expected Response.DidResume to be true")
}
testComplete(t, resp)
},
grabtest.ContentLength(segsize),
)
})
}
t.Run("WithFailure", func(t *testing.T) {
grabtest.WithTestServer(t, func(url string) {
// request smaller segment
req := mustNewRequest(filename, url)
resp := DefaultClient.Do(req)
if err := resp.Err(); err != ErrBadLength {
t.Errorf("expected ErrBadLength for smaller request, got: %v", err)
}
},
grabtest.ContentLength(size-128),
)
})
t.Run("WithNoResume", func(t *testing.T) {
grabtest.WithTestServer(t, func(url string) {
req := mustNewRequest(filename, url)
req.NoResume = true
resp := mustDo(req)
if resp.DidResume {
t.Errorf("expected Response.DidResume to be false")
}
testComplete(t, resp)
},
grabtest.ContentLength(size+128),
)
})
t.Run("WithNoResumeAndTruncate", func(t *testing.T) {
size := size - 128
grabtest.WithTestServer(t, func(url string) {
req := mustNewRequest(filename, url)
req.NoResume = true
resp := mustDo(req)
if resp.DidResume {
t.Errorf("expected Response.DidResume to be false")
}
if v := resp.BytesComplete(); v != int64(size) {
t.Errorf("expected Response.BytesComplete: %d, got: %d", size, v)
}
testComplete(t, resp)
},
grabtest.ContentLength(size),
)
})
t.Run("WithNoContentLengthHeader", func(t *testing.T) {
grabtest.WithTestServer(t, func(url string) {
req := mustNewRequest(filename, url)
req.SetChecksum(sha256.New(), sum, false)
resp := mustDo(req)
if !resp.DidResume {
t.Errorf("expected Response.DidResume to be true")
}
if actual := resp.Size(); actual != int64(size) {
t.Errorf("expected Response.Size: %d, got: %d", size, actual)
}
testComplete(t, resp)
},
grabtest.ContentLength(size),
grabtest.HeaderBlacklist("Content-Length"),
)
})
t.Run("WithNoContentLengthHeaderAndChecksumFailure", func(t *testing.T) {
// ref: https://github.com/cavaliercoder/grab/pull/27
size := size * 2
grabtest.WithTestServer(t, func(url string) {
req := mustNewRequest(filename, url)
req.SetChecksum(sha256.New(), sum, false)
resp := DefaultClient.Do(req)
if err := resp.Err(); err != ErrBadChecksum {
t.Errorf("expected error: %v, got: %v", ErrBadChecksum, err)
}
if !resp.DidResume {
t.Errorf("expected Response.DidResume to be true")
}
if actual := resp.BytesComplete(); actual != int64(size) {
t.Errorf("expected Response.BytesComplete: %d, got: %d", size, actual)
}
if actual := resp.Size(); actual != int64(size) {
t.Errorf("expected Response.Size: %d, got: %d", size, actual)
}
testComplete(t, resp)
},
grabtest.ContentLength(size),
grabtest.HeaderBlacklist("Content-Length"),
)
})
// TODO: test when existing file is corrupted
}
func TestSkipExisting(t *testing.T) {
filename := ".testSkipExisting"
defer os.Remove(filename)
// download a file
grabtest.WithTestServer(t, func(url string) {
resp := mustDo(mustNewRequest(filename, url))
testComplete(t, resp)
})
// redownload
grabtest.WithTestServer(t, func(url string) {
resp := mustDo(mustNewRequest(filename, url))
testComplete(t, resp)
// ensure download was resumed
if !resp.DidResume {
t.Fatalf("Expected download to skip existing file, but it did not")
}
// ensure all bytes were resumed
if resp.Size() == 0 || resp.Size() != resp.bytesResumed {
t.Fatalf("Expected to skip %d bytes in redownload; got %d", resp.Size(), resp.bytesResumed)
}
})
// ensure checksum is performed on pre-existing file
grabtest.WithTestServer(t, func(url string) {
req := mustNewRequest(filename, url)
req.SetChecksum(sha256.New(), []byte{0x01, 0x02, 0x03, 0x04}, true)
resp := DefaultClient.Do(req)
if err := resp.Err(); err != ErrBadChecksum {
t.Fatalf("Expected checksum error, got: %v", err)
}
})
}
// TestBatch executes multiple requests simultaneously and validates the
// responses.
func TestBatch(t *testing.T) {
tests := 32
size := 32768
sum := grabtest.MustHexDecodeString("e11360251d1173650cdcd20f111d8f1ca2e412f572e8b36a4dc067121c1799b8")
// test with 4 workers and with one per request
grabtest.WithTestServer(t, func(url string) {
for _, workerCount := range []int{4, 0} {
// create requests
reqs := make([]*Request, tests)
for i := 0; i < len(reqs); i++ {
filename := fmt.Sprintf(".testBatch.%d", i+1)
reqs[i] = mustNewRequest(filename, url+fmt.Sprintf("/request_%d?", i+1))
reqs[i].Label = fmt.Sprintf("Test %d", i+1)
reqs[i].SetChecksum(sha256.New(), sum, false)
}
// batch run
responses := DefaultClient.DoBatch(workerCount, reqs...)
// listen for responses
Loop:
for i := 0; i < len(reqs); {
select {
case resp := <-responses:
if resp == nil {
break Loop
}
testComplete(t, resp)
if err := resp.Err(); err != nil {
t.Errorf("%s: %v", resp.Filename, err)
}
// remove test file
if resp.IsComplete() {
os.Remove(resp.Filename) // ignore errors
}
i++
}
}
}
},
grabtest.ContentLength(size),
)
}
// TestCancelContext tests that a batch of requests can be cancelled using a
// context.Context cancellation. Requests are cancelled in multiple states:
// in-progress and unstarted.
func TestCancelContext(t *testing.T) {
fileSize := 134217728
tests := 256
client := NewClient()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
grabtest.WithTestServer(t, func(url string) {
reqs := make([]*Request, tests)
for i := 0; i < tests; i++ {
req := mustNewRequest("", fmt.Sprintf("%s/.testCancelContext%d", url, i))
reqs[i] = req.WithContext(ctx)
}
respch := client.DoBatch(8, reqs...)
time.Sleep(time.Millisecond * 500)
cancel()
for resp := range respch {
defer os.Remove(resp.Filename)
// err should be context.Canceled or http.errRequestCanceled
if resp.Err() == nil || !strings.Contains(resp.Err().Error(), "canceled") {
t.Errorf("expected '%v', got '%v'", context.Canceled, resp.Err())
}
if resp.BytesComplete() >= int64(fileSize) {
t.Errorf("expected Response.BytesComplete: < %d, got: %d", fileSize, resp.BytesComplete())
}
}
},
grabtest.ContentLength(fileSize),
)
}
// TestCancelHangingResponse tests that a never ending request is terminated
// when the response is cancelled.
func TestCancelHangingResponse(t *testing.T) {
fileSize := 10
client := NewClient()
grabtest.WithTestServer(t, func(url string) {
req := mustNewRequest("", fmt.Sprintf("%s/.testCancelHangingResponse", url))
resp := client.Do(req)
defer os.Remove(resp.Filename)
// Wait for some bytes to be transferred
for resp.BytesComplete() == 0 {
time.Sleep(50 * time.Millisecond)
}
done := make(chan error)
go func() {
done <- resp.Cancel()
}()
select {
case err := <-done:
if err != context.Canceled {
t.Errorf("Expected context.Canceled error, got: %v", err)
}
case <-time.After(time.Second):
t.Fatal("response was not cancelled within 1s")
}
if resp.BytesComplete() == int64(fileSize) {
t.Error("download was not supposed to be complete")
}
fmt.Println("bye")
},
grabtest.RateLimiter(1),
grabtest.ContentLength(fileSize),
)
}
// TestNestedDirectory tests that missing subdirectories are created.
func TestNestedDirectory(t *testing.T) {
dir := "./.testNested/one/two/three"
filename := ".testNestedFile"
expect := dir + "/" + filename
t.Run("Create", func(t *testing.T) {
grabtest.WithTestServer(t, func(url string) {
resp := mustDo(mustNewRequest(expect, url+"/"+filename))
defer os.RemoveAll("./.testNested/")
if resp.Filename != expect {
t.Errorf("expected nested Request.Filename to be %v, got %v", expect, resp.Filename)
}
})
})
t.Run("No create", func(t *testing.T) {
grabtest.WithTestServer(t, func(url string) {
req := mustNewRequest(expect, url+"/"+filename)
req.NoCreateDirectories = true
resp := DefaultClient.Do(req)
err := resp.Err()
if !os.IsNotExist(err) {
t.Errorf("expected: %v, got: %v", os.ErrNotExist, err)
}
})
})
}
// TestRemoteTime tests that the timestamp of the downloaded file can be set
// according to the timestamp of the remote file.
func TestRemoteTime(t *testing.T) {
filename := "./.testRemoteTime"
defer os.Remove(filename)
// random time between epoch and now
expect := time.Unix(rand.Int63n(time.Now().Unix()), 0)
grabtest.WithTestServer(t, func(url string) {
resp := mustDo(mustNewRequest(filename, url))
fi, err := os.Stat(resp.Filename)
if err != nil {
panic(err)
}
actual := fi.ModTime()
if !actual.Equal(expect) {
t.Errorf("expected %v, got %v", expect, actual)
}
},
grabtest.LastModified(expect),
)
}
func TestResponseCode(t *testing.T) {
filename := "./.testResponseCode"
t.Run("With404", func(t *testing.T) {
defer os.Remove(filename)
grabtest.WithTestServer(t, func(url string) {
req := mustNewRequest(filename, url)
resp := DefaultClient.Do(req)
expect := StatusCodeError(http.StatusNotFound)
err := resp.Err()
if err != expect {
t.Errorf("expected %v, got '%v'", expect, err)
}
if !IsStatusCodeError(err) {
t.Errorf("expected IsStatusCodeError to return true for %T: %v", err, err)
}
},
grabtest.StatusCodeStatic(http.StatusNotFound),
)
})
t.Run("WithIgnoreNon2XX", func(t *testing.T) {
defer os.Remove(filename)
grabtest.WithTestServer(t, func(url string) {
req := mustNewRequest(filename, url)
req.IgnoreBadStatusCodes = true
resp := DefaultClient.Do(req)
if err := resp.Err(); err != nil {
t.Errorf("expected nil, got '%v'", err)
}
},
grabtest.StatusCodeStatic(http.StatusNotFound),
)
})
}
func TestBeforeCopyHook(t *testing.T) {
filename := "./.testBeforeCopy"
t.Run("Noop", func(t *testing.T) {
defer os.RemoveAll(filename)
grabtest.WithTestServer(t, func(url string) {
called := false
req := mustNewRequest(filename, url)
req.BeforeCopy = func(resp *Response) error {
called = true
if resp.IsComplete() {
t.Error("Response object passed to BeforeCopy hook has already been closed")
}
if resp.Progress() != 0 {
t.Error("Download progress already > 0 when BeforeCopy hook was called")
}
if resp.Duration() == 0 {
t.Error("Duration was zero when BeforeCopy was called")
}
if resp.BytesComplete() != 0 {
t.Error("BytesComplete already > 0 when BeforeCopy hook was called")
}
return nil
}
resp := DefaultClient.Do(req)
if err := resp.Err(); err != nil {
t.Errorf("unexpected error using BeforeCopy hook: %v", err)
}
testComplete(t, resp)
if !called {
t.Error("BeforeCopy hook was never called")
}
})
})
t.Run("WithError", func(t *testing.T) {
defer os.RemoveAll(filename)
grabtest.WithTestServer(t, func(url string) {
testError := errors.New("test")
req := mustNewRequest(filename, url)
req.BeforeCopy = func(resp *Response) error {
return testError
}
resp := DefaultClient.Do(req)
if err := resp.Err(); err != testError {
t.Errorf("expected error '%v', got '%v'", testError, err)
}
if resp.BytesComplete() != 0 {
t.Errorf("expected 0 bytes completed for canceled BeforeCopy hook, got %d",
resp.BytesComplete())
}
testComplete(t, resp)
})
})
// Assert that an existing local file will not be truncated before the
// BeforeCopy hook has a chance to cancel the request
t.Run("NoTruncate", func(t *testing.T) {
tfile, err := ioutil.TempFile("", "grab_client_test.*.file")
if err != nil {
t.Fatal(err)
}
defer os.Remove(tfile.Name())
const size = 128
_, err = tfile.Write(bytes.Repeat([]byte("x"), size))
if err != nil {
t.Fatal(err)
}
grabtest.WithTestServer(t, func(url string) {
called := false
req := mustNewRequest(tfile.Name(), url)
req.NoResume = true
req.BeforeCopy = func(resp *Response) error {
called = true
fi, err := tfile.Stat()
if err != nil {
t.Errorf("failed to stat temp file: %v", err)
return nil
}
if fi.Size() != size {
t.Errorf("expected existing file size of %d bytes "+
"prior to BeforeCopy hook, got %d", size, fi.Size())
}
return nil
}
resp := DefaultClient.Do(req)
if err := resp.Err(); err != nil {
t.Errorf("unexpected error using BeforeCopy hook: %v", err)
}
testComplete(t, resp)
if !called {
t.Error("BeforeCopy hook was never called")
}
})
})
}
func TestAfterCopyHook(t *testing.T) {
filename := "./.testAfterCopy"
t.Run("Noop", func(t *testing.T) {
defer os.RemoveAll(filename)
grabtest.WithTestServer(t, func(url string) {
called := false
req := mustNewRequest(filename, url)
req.AfterCopy = func(resp *Response) error {
called = true
if resp.IsComplete() {
t.Error("Response object passed to AfterCopy hook has already been closed")
}
if resp.Progress() <= 0 {
t.Error("Download progress was 0 when AfterCopy hook was called")
}
if resp.Duration() == 0 {
t.Error("Duration was zero when AfterCopy was called")
}
if resp.BytesComplete() <= 0 {
t.Error("BytesComplete was 0 when AfterCopy hook was called")
}
return nil
}
resp := DefaultClient.Do(req)
if err := resp.Err(); err != nil {
t.Errorf("unexpected error using AfterCopy hook: %v", err)
}
testComplete(t, resp)
if !called {
t.Error("AfterCopy hook was never called")
}
})
})
t.Run("WithError", func(t *testing.T) {
defer os.RemoveAll(filename)
grabtest.WithTestServer(t, func(url string) {
testError := errors.New("test")
req := mustNewRequest(filename, url)
req.AfterCopy = func(resp *Response) error {
return testError
}
resp := DefaultClient.Do(req)
if err := resp.Err(); err != testError {
t.Errorf("expected error '%v', got '%v'", testError, err)
}
if resp.BytesComplete() <= 0 {
t.Errorf("BytesComplete was %d after AfterCopy hook was called",
resp.BytesComplete())
}
testComplete(t, resp)
})
})
}
func TestIssue37(t *testing.T) {
// ref: https://github.com/cavaliercoder/grab/issues/37
filename := "./.testIssue37"
largeSize := int64(2097152)
smallSize := int64(1048576)
defer os.RemoveAll(filename)
// download large file
grabtest.WithTestServer(t, func(url string) {
resp := mustDo(mustNewRequest(filename, url))
if resp.Size() != largeSize {
t.Errorf("expected response size: %d, got: %d", largeSize, resp.Size())
}
}, grabtest.ContentLength(int(largeSize)))
// download new, smaller version of same file
grabtest.WithTestServer(t, func(url string) {
req := mustNewRequest(filename, url)
req.NoResume = true
resp := mustDo(req)
if resp.Size() != smallSize {
t.Errorf("expected response size: %d, got: %d", smallSize, resp.Size())
}
// local file should have been truncated and not resumed
if resp.DidResume {
t.Errorf("expected download to truncate, resumed instead")
}
}, grabtest.ContentLength(int(smallSize)))
fi, err := os.Stat(filename)
if err != nil {
t.Fatal(err)
}
if fi.Size() != int64(smallSize) {
t.Errorf("expected file size %d, got %d", smallSize, fi.Size())
}
}
// TestHeadBadStatus validates that HEAD requests that return non-200 can be
// ignored and succeed if the GET request succeeds.
//
// Fixes: https://github.com/cavaliercoder/grab/issues/43
func TestHeadBadStatus(t *testing.T) {
expect := http.StatusOK
filename := ".testIssue43"
statusFunc := func(r *http.Request) int {
if r.Method == "HEAD" {
return http.StatusForbidden
}
return http.StatusOK
}
grabtest.WithTestServer(t, func(url string) {
testURL := fmt.Sprintf("%s/%s", url, filename)
resp := mustDo(mustNewRequest("", testURL))
if resp.HTTPResponse.StatusCode != expect {
t.Errorf(
"expected status code: %d, got: %d",
expect,
resp.HTTPResponse.StatusCode)
}
},
grabtest.StatusCode(statusFunc),
)
}
// TestMissingContentLength ensures that the Response.Size is correct for
// transfers where the remote server does not send a Content-Length header.
//
// TestAutoResume also covers cases with checksum validation.
//
// Kudos to Setnička Jiří <Jiri.Setnicka@ysoft.com> for identifying and raising
// a solution to this issue. Ref: https://github.com/cavaliercoder/grab/pull/27
func TestMissingContentLength(t *testing.T) {
// expectSize must be sufficiently large that DefaultClient.Do won't prefetch
// the entire body and compute ContentLength before returning a Response.
expectSize := 1048576
opts := []grabtest.HandlerOption{
grabtest.ContentLength(expectSize),
grabtest.HeaderBlacklist("Content-Length"),
grabtest.TimeToFirstByte(time.Millisecond * 100), // delay for initial read
}
grabtest.WithTestServer(t, func(url string) {
req := mustNewRequest(".testMissingContentLength", url)
req.SetChecksum(
md5.New(),
grabtest.DefaultHandlerMD5ChecksumBytes,
false)
resp := DefaultClient.Do(req)
// ensure remote server is not sending content-length header
if v := resp.HTTPResponse.Header.Get("Content-Length"); v != "" {
panic(fmt.Sprintf("http header content length must be empty, got: %s", v))
}
if v := resp.HTTPResponse.ContentLength; v != -1 {
panic(fmt.Sprintf("http response content length must be -1, got: %d", v))
}
// before completion, response size should be -1
if resp.Size() != -1 {
t.Errorf("expected response size: -1, got: %d", resp.Size())
}
// block for completion
if err := resp.Err(); err != nil {
panic(err)
}
// on completion, response size should be actual transfer size
if resp.Size() != int64(expectSize) {
t.Errorf("expected response size: %d, got: %d", expectSize, resp.Size())
}
}, opts...)
}
func TestNoStore(t *testing.T) {
filename := ".testSubdir/testNoStore"
t.Run("DefaultCase", func(t *testing.T) {
grabtest.WithTestServer(t, func(url string) {
req := mustNewRequest(filename, url)
req.NoStore = true
req.SetChecksum(md5.New(), grabtest.DefaultHandlerMD5ChecksumBytes, true)
resp := mustDo(req)
// ensure Response.Bytes is correct and can be reread
b, err := resp.Bytes()
if err != nil {
panic(err)
}
grabtest.AssertSHA256Sum(
t,
grabtest.DefaultHandlerSHA256ChecksumBytes,
bytes.NewReader(b),
)
// ensure Response.Open stream is correct and can be reread
r, err := resp.Open()
if err != nil {
panic(err)
}
defer r.Close()
grabtest.AssertSHA256Sum(
t,
grabtest.DefaultHandlerSHA256ChecksumBytes,
r,
)
// Response.Filename should still be set
if resp.Filename != filename {
t.Errorf("expected Response.Filename: %s, got: %s", filename, resp.Filename)
}
// ensure no files were written
paths := []string{
filename,
filepath.Base(filename),
filepath.Dir(filename),
resp.Filename,
filepath.Base(resp.Filename),
filepath.Dir(resp.Filename),
}
for _, path := range paths {
_, err := os.Stat(path)
if !os.IsNotExist(err) {
t.Errorf(
"expect error: %v, got: %v, for path: %s",
os.ErrNotExist,
err,
path)
}
}
})
})
t.Run("ChecksumValidation", func(t *testing.T) {
grabtest.WithTestServer(t, func(url string) {
req := mustNewRequest("", url)
req.NoStore = true
req.SetChecksum(
md5.New(),
grabtest.MustHexDecodeString("deadbeefcafebabe"),
true)
resp := DefaultClient.Do(req)
if err := resp.Err(); err != ErrBadChecksum {
t.Errorf("expected error: %v, got: %v", ErrBadChecksum, err)
}
})
})
}

1
grab/cmd/grab/.gitignore vendored Normal file

@ -0,0 +1 @@
grab

18
grab/cmd/grab/Makefile Normal file

@ -0,0 +1,18 @@
SOURCES = main.go

all : grab

grab: $(SOURCES)
	go build -x -o grab $(SOURCES)

clean:
	go clean -x
	rm -vf grab

check:
	go test -v .

install:
	go install -v .

.PHONY: all clean check install

34
grab/cmd/grab/main.go Normal file

@ -0,0 +1,34 @@
package main
import (
"context"
"fmt"
"os"
"github.com/cavaliercoder/grab/grabui"
)
func main() {
// validate command args
if len(os.Args) < 2 {
fmt.Fprintf(os.Stderr, "usage: %s url...\n", os.Args[0])
os.Exit(1)
}
urls := os.Args[1:]
// download files
respch, err := grabui.GetBatch(context.Background(), 0, ".", urls...)
if err != nil {
fmt.Fprint(os.Stderr, err)
os.Exit(1)
}
// return the number of failed downloads as exit code
failed := 0
for resp := range respch {
if resp.Err() != nil {
failed++
}
}
os.Exit(failed)
}
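
One caveat with returning the failure count as the exit code: on POSIX systems only the low 8 bits of the exit status reach the parent process, so a count of exactly 256 would be reported as 0. A minimal sketch of a guard follows. The helper name `exitCode` is hypothetical and is not part of grab or of this commit.

```go
package main

import "fmt"

// exitCode is a hypothetical helper (not part of grab). It maps a count of
// failed downloads to a process exit status and caps it at 255, so a large
// count can never wrap around to 0 (success) on POSIX systems.
func exitCode(failed int) int {
	if failed < 0 {
		// A negative count is not meaningful; report a generic failure.
		return 1
	}
	if failed > 255 {
		return 255
	}
	return failed
}

func main() {
	for _, n := range []int{0, 3, 256, 1000} {
		fmt.Printf("failed=%d -> exit status %d\n", n, exitCode(n))
	}
}
```

With such a helper, the last line of `main` above could become `os.Exit(exitCode(failed))`.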

63
grab/doc.go Normal file

@ -0,0 +1,63 @@
/*
Package grab provides an HTTP download manager implementation.
Get is the simplest way to download a file:
resp, err := grab.Get("/tmp", "http://example.com/example.zip")
// ...
Get will download the given URL and save it to the given destination directory.
The destination filename will be determined automatically by grab using
Content-Disposition headers returned by the remote server, or by inspecting the
requested URL path.
An empty destination string or "." means the transfer will be stored in the
current working directory.
If a destination file already exists, grab will assume it is a complete or
partially complete download of the requested file. If the remote server supports
resuming interrupted downloads, grab will resume downloading from the end of the
partial file. If the server does not support resumed downloads, the file will be
retransferred in its entirety. If the file is already complete, grab will return
successfully.
For control over the HTTP client, destination path, auto-resume, checksum
validation and other settings, create a Client:
client := grab.NewClient()
client.HTTPClient.Transport.DisableCompression = true
req, err := grab.NewRequest("/tmp", "http://example.com/example.zip")
// ...
req.NoResume = true
req.HTTPRequest.Header.Set("Authorization", "Basic YWxhZGRpbjpvcGVuc2VzYW1l")
resp := client.Do(req)
// ...
You can monitor the progress of downloads while they are transferring:
client := grab.NewClient()
req, err := grab.NewRequest("", "http://example.com/example.zip")
// ...
resp := client.Do(req)
t := time.NewTicker(time.Second)
defer t.Stop()
for {
select {
case <-t.C:
fmt.Printf("%.02f%% complete\n", resp.Progress())
case <-resp.Done:
if err := resp.Err(); err != nil {
// ...
}
// ...
return
}
}
*/
package grab
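
The existing-file rules described in the package comment can be read as a small decision table. The sketch below is a simplified model inferred from that comment and from the tests in this commit (`TestAutoResume`, `TestSkipExisting`). It is not grab's actual implementation: the names `decide` and `action` are hypothetical, and it assumes the remote size is known.

```go
package main

import "fmt"

// action names the outcome chosen for a download. Hypothetical type, used
// only for this illustration.
type action string

const (
	actionDownload action = "download from scratch"
	actionResume   action = "resume partial file"
	actionSkip     action = "already complete"
	actionFail     action = "bad content length"
)

// decide is a hypothetical, simplified model of how an existing local file
// could be handled. localSize is the size of the existing file (0 if none),
// remoteSize is the size reported by the server (assumed known), canResume
// says whether the server supports ranged requests, and noResume mirrors the
// Request.NoResume option.
func decide(localSize, remoteSize int64, canResume, noResume bool) action {
	switch {
	case localSize == 0 || noResume:
		// Nothing to resume, or resuming was explicitly disabled.
		return actionDownload
	case localSize == remoteSize:
		return actionSkip
	case localSize > remoteSize:
		// The local file cannot be a partial copy of a smaller remote file.
		return actionFail
	case canResume:
		return actionResume
	default:
		// Server cannot resume: transfer the whole file again.
		return actionDownload
	}
}

func main() {
	fmt.Println(decide(0, 100, true, false))
	fmt.Println(decide(40, 100, true, false))
	fmt.Println(decide(40, 100, false, false))
	fmt.Println(decide(100, 100, true, false))
	fmt.Println(decide(200, 100, true, false))
}
```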

42
grab/error.go Normal file

@ -0,0 +1,42 @@
package grab
import (
"errors"
"fmt"
"net/http"
)
var (
// ErrBadLength indicates that the server response or an existing file does
// not match the expected content length.
ErrBadLength = errors.New("bad content length")
// ErrBadChecksum indicates that a downloaded file failed to pass checksum
// validation.
ErrBadChecksum = errors.New("checksum mismatch")
// ErrNoFilename indicates that a reasonable filename could not be
// automatically determined using the URL or response headers from a server.
ErrNoFilename = errors.New("no filename could be determined")
// ErrNoTimestamp indicates that a timestamp could not be automatically
// determined using the response headers from the remote server.
ErrNoTimestamp = errors.New("no timestamp could be determined for the remote file")
// ErrFileExists indicates that the destination path already exists.
ErrFileExists = errors.New("file exists")
)
// StatusCodeError indicates that the server response had a status code that
// was not in the 200-299 range (after following any redirects).
type StatusCodeError int
func (err StatusCodeError) Error() string {
return fmt.Sprintf("server returned %d %s", err, http.StatusText(int(err)))
}
// IsStatusCodeError returns true if the given error is of type StatusCodeError.
func IsStatusCodeError(err error) bool {
_, ok := err.(StatusCodeError)
return ok
}
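
`IsStatusCodeError` uses a direct type assertion, so it reports false for a status error that a caller has wrapped with `%w`. The sketch below shows one possible alternative based on `errors.As` (Go 1.13 or later). The lowercase `statusCodeError` type only mirrors the shape of the type above for illustration, and `wrap` is a hypothetical helper; neither is part of grab.

```go
package main

import (
	"errors"
	"fmt"
	"net/http"
)

// statusCodeError mirrors the shape of grab's StatusCodeError so that this
// example is self-contained. It is an illustration, not the library type.
type statusCodeError int

func (e statusCodeError) Error() string {
	return fmt.Sprintf("server returned %d %s", int(e), http.StatusText(int(e)))
}

// isStatusCodeError reports whether err, or any error it wraps, is a
// statusCodeError. errors.As walks the chain of wrapped errors.
func isStatusCodeError(err error) bool {
	var sce statusCodeError
	return errors.As(err, &sce)
}

// wrap is a hypothetical helper that adds context to an error using %w.
func wrap(err error) error {
	return fmt.Errorf("download failed: %w", err)
}

func main() {
	base := statusCodeError(http.StatusNotFound)
	fmt.Println(isStatusCodeError(base))                // direct value
	fmt.Println(isStatusCodeError(wrap(base)))          // wrapped value
	fmt.Println(isStatusCodeError(errors.New("other"))) // unrelated error
}
```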


@ -0,0 +1,95 @@
package grab
import (
"fmt"
"sync"
)
func ExampleClient_Do() {
client := NewClient()
req, err := NewRequest("/tmp", "http://example.com/example.zip")
if err != nil {
panic(err)
}
resp := client.Do(req)
if err := resp.Err(); err != nil {
panic(err)
}
fmt.Println("Download saved to", resp.Filename)
}
// This example uses DoChannel to create a Producer/Consumer model for
// downloading multiple files concurrently. This is similar to how DoBatch uses
// DoChannel under the hood except that it allows the caller to continually send
// new requests until they wish to close the request channel.
func ExampleClient_DoChannel() {
// create a request and a buffered response channel
reqch := make(chan *Request)
respch := make(chan *Response, 10)
// start 4 workers
client := NewClient()
wg := sync.WaitGroup{}
for i := 0; i < 4; i++ {
wg.Add(1)
go func() {
client.DoChannel(reqch, respch)
wg.Done()
}()
}
go func() {
// send requests
for i := 0; i < 10; i++ {
url := fmt.Sprintf("http://example.com/example%d.zip", i+1)
req, err := NewRequest("/tmp", url)
if err != nil {
panic(err)
}
reqch <- req
}
close(reqch)
// wait for workers to finish
wg.Wait()
close(respch)
}()
// check each response
for resp := range respch {
// block until complete
if err := resp.Err(); err != nil {
panic(err)
}
fmt.Printf("Downloaded %s to %s\n", resp.Request.URL(), resp.Filename)
}
}
func ExampleClient_DoBatch() {
// create multiple download requests
reqs := make([]*Request, 0)
for i := 0; i < 10; i++ {
url := fmt.Sprintf("http://example.com/example%d.zip", i+1)
req, err := NewRequest("/tmp", url)
if err != nil {
panic(err)
}
reqs = append(reqs, req)
}
// start downloads with 4 workers
client := NewClient()
respch := client.DoBatch(4, reqs...)
// check each response
for resp := range respch {
if err := resp.Err(); err != nil {
panic(err)
}
fmt.Printf("Downloaded %s to %s\n", resp.Request.URL(), resp.Filename)
}
}


@ -0,0 +1,52 @@
package grab
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"time"
)
func ExampleRequest_WithContext() {
// create context with a 100ms timeout
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
// create download request with context
req, err := NewRequest("", "http://example.com/example.zip")
if err != nil {
panic(err)
}
req = req.WithContext(ctx)
// send download request
resp := DefaultClient.Do(req)
if err := resp.Err(); err != nil {
fmt.Println("error: request cancelled")
}
// Output:
// error: request cancelled
}
func ExampleRequest_SetChecksum() {
// create download request
req, err := NewRequest("", "http://example.com/example.zip")
if err != nil {
panic(err)
}
// set request checksum
sum, err := hex.DecodeString("33daf4c03f86120fdfdc66bddf6bfff4661c7ca11c5da473e537f4d69b470e57")
if err != nil {
panic(err)
}
req.SetChecksum(sha256.New(), sum, true)
// download and validate file
resp := DefaultClient.Do(req)
if err := resp.Err(); err != nil {
panic(err)
}
}

5
grab/go.mod Normal file

@ -0,0 +1,5 @@
module github.com/cavaliercoder/grab
go 1.14
require github.com/lordwelch/pathvalidate v0.0.0-20201012043703-54efa7ea1308

5
grab/go.sum Normal file

@ -0,0 +1,5 @@
github.com/lordwelch/pathvalidate v0.0.0-20201012043703-54efa7ea1308 h1:CkcsZK6QYg59rc92eqU2h+FRjWltCIiplmEwIB05jfM=
github.com/lordwelch/pathvalidate v0.0.0-20201012043703-54efa7ea1308/go.mod h1:4I4r5Y/LkH+34KACiudU+Q27ooz7xSDyVEuWAVKeJEQ=
golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

64
grab/grab.go Normal file

@ -0,0 +1,64 @@
package grab
import (
"fmt"
"os"
)
// Get sends an HTTP request and downloads the content of the requested URL to
// the given destination file path. The caller is blocked until the download is
// completed, successfully or otherwise.
//
// An error is returned if caused by client policy (such as CheckRedirect), or
// if there was an HTTP protocol or IO error.
//
// For non-blocking calls or control over HTTP client headers, redirect policy,
// and other settings, create a Client instead.
func Get(dst, urlStr string) (*Response, error) {
req, err := NewRequest(dst, urlStr)
if err != nil {
return nil, err
}
resp := DefaultClient.Do(req)
return resp, resp.Err()
}
// GetBatch sends multiple HTTP requests and downloads the content of the
// requested URLs to the given destination directory using the given number of
// concurrent worker goroutines.
//
// The Response for each requested URL is sent through the returned Response
// channel, as soon as a worker receives a response from the remote server. The
// Response can then be used to track the progress of the download while it is
// in progress.
//
// The returned Response channel will be closed by Grab, only once all downloads
// have completed or failed.
//
// If an error occurs during any download, it will be available via a call to
// the associated Response.Err.
//
// For control over HTTP client headers, redirect policy, and other settings,
// create a Client instead.
func GetBatch(workers int, dst string, urlStrs ...string) (<-chan *Response, error) {
fi, err := os.Stat(dst)
if err != nil {
return nil, err
}
if !fi.IsDir() {
return nil, fmt.Errorf("destination is not a directory")
}
reqs := make([]*Request, len(urlStrs))
for i := 0; i < len(urlStrs); i++ {
req, err := NewRequest(dst, urlStrs[i])
if err != nil {
return nil, err
}
reqs[i] = req
}
ch := DefaultClient.DoBatch(workers, reqs...)
return ch, nil
}
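
The contract documented for GetBatch (a fixed number of workers, with a channel that is closed only after every job has finished) is a standard worker-pool pattern. The sketch below shows that pattern in isolation. `runBatch` is a hypothetical stand-in that processes strings rather than HTTP requests, and it is not how `DoBatch` is implemented. Treating a worker count below 1 as "one worker per job" is an assumption based on the comment in `TestBatch`.

```go
package main

import (
	"fmt"
	"sync"
)

// runBatch is a hypothetical stand-in for a batch API. It runs do on every
// job using the given number of worker goroutines and returns a channel that
// is closed once all jobs have been processed.
func runBatch(workers int, jobs []string, do func(string) string) <-chan string {
	// Assumption: fewer than 1 worker means one worker per job.
	if workers < 1 || workers > len(jobs) {
		workers = len(jobs)
	}
	in := make(chan string)
	out := make(chan string, len(jobs)) // buffered so workers never block

	var wg sync.WaitGroup
	for i := 0; i < workers; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := range in {
				out <- do(j)
			}
		}()
	}

	go func() {
		for _, j := range jobs {
			in <- j
		}
		close(in)  // no more work: lets workers exit their range loop
		wg.Wait()  // wait for in-flight jobs to finish
		close(out) // signal the consumer that the batch is complete
	}()
	return out
}

func main() {
	urls := []string{"a", "b", "c", "d"}
	results := runBatch(2, urls, func(u string) string { return "fetched " + u })
	for r := range results {
		fmt.Println(r) // order depends on scheduling
	}
}
```

Because the output channel is closed by the producer side, a consumer can use a plain `for ... range` loop, as `cmd/grab/main.go` does with the channel returned by `GetBatch`.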

74
grab/grab_test.go Normal file

@ -0,0 +1,74 @@
package grab
import (
"fmt"
"io/ioutil"
"log"
"os"
"testing"
"github.com/cavaliercoder/grab/grabtest"
)
func TestMain(m *testing.M) {
os.Exit(func() int {
// chdir to temp so test files downloaded to pwd are isolated and cleaned up
cwd, err := os.Getwd()
if err != nil {
panic(err)
}
tmpDir, err := ioutil.TempDir("", "grab-")
if err != nil {
panic(err)
}
if err := os.Chdir(tmpDir); err != nil {
panic(err)
}
defer func() {
os.Chdir(cwd)
if err := os.RemoveAll(tmpDir); err != nil {
panic(err)
}
}()
return m.Run()
}())
}
// TestGet tests grab.Get
func TestGet(t *testing.T) {
filename := ".testGet"
defer os.Remove(filename)
grabtest.WithTestServer(t, func(url string) {
resp, err := Get(filename, url)
if err != nil {
t.Fatalf("error in Get(): %v", err)
}
testComplete(t, resp)
})
}
func ExampleGet() {
// download a file to /tmp
resp, err := Get("/tmp", "http://example.com/example.zip")
if err != nil {
log.Fatal(err)
}
fmt.Println("Download saved to", resp.Filename)
}
func mustNewRequest(dst, urlStr string) *Request {
req, err := NewRequest(dst, urlStr)
if err != nil {
panic(err)
}
return req
}
func mustDo(req *Request) *Response {
resp := DefaultClient.Do(req)
if err := resp.Err(); err != nil {
panic(err)
}
return resp
}

104
grab/grabtest/assert.go Normal file

@ -0,0 +1,104 @@
package grabtest
import (
"bytes"
"crypto/sha256"
"fmt"
"io"
"io/ioutil"
"net/http"
"testing"
)
func AssertHTTPResponseStatusCode(t *testing.T, resp *http.Response, expect int) (ok bool) {
if resp.StatusCode != expect {
t.Errorf("expected status code: %d, got: %d", expect, resp.StatusCode)
return
}
ok = true
return true
}
func AssertHTTPResponseHeader(t *testing.T, resp *http.Response, key, format string, a ...interface{}) (ok bool) {
expect := fmt.Sprintf(format, a...)
actual := resp.Header.Get(key)
if actual != expect {
t.Errorf("expected header %s: %s, got: %s", key, expect, actual)
return
}
ok = true
return
}
func AssertHTTPResponseContentLength(t *testing.T, resp *http.Response, n int64) (ok bool) {
ok = true
if resp.ContentLength != n {
ok = false
t.Errorf("expected header Content-Length: %d, got: %d", n, resp.ContentLength)
}
if !AssertHTTPResponseBodyLength(t, resp, n) {
ok = false
}
return
}
func AssertHTTPResponseBodyLength(t *testing.T, resp *http.Response, n int64) (ok bool) {
defer func() {
if err := resp.Body.Close(); err != nil {
panic(err)
}
}()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
panic(err)
}
// ok must be set explicitly; the named result defaults to false
ok = int64(len(b)) == n
if !ok {
t.Errorf("expected body length: %d, got: %d", n, len(b))
}
return
}
func MustHTTPNewRequest(method, url string, body io.Reader) *http.Request {
req, err := http.NewRequest(method, url, body)
if err != nil {
panic(err)
}
return req
}
func MustHTTPDo(req *http.Request) *http.Response {
resp, err := http.DefaultClient.Do(req)
if err != nil {
panic(err)
}
return resp
}
func MustHTTPDoWithClose(req *http.Request) *http.Response {
resp := MustHTTPDo(req)
if _, err := io.Copy(ioutil.Discard, resp.Body); err != nil {
panic(err)
}
if err := resp.Body.Close(); err != nil {
panic(err)
}
return resp
}
func AssertSHA256Sum(t *testing.T, sum []byte, r io.Reader) (ok bool) {
h := sha256.New()
if _, err := io.Copy(h, r); err != nil {
panic(err)
}
computed := h.Sum(nil)
ok = bytes.Equal(sum, computed)
if !ok {
t.Errorf(
"expected checksum: %s, got: %s",
MustHexEncodeString(sum),
MustHexEncodeString(computed),
)
}
return
}

160
grab/grabtest/handler.go Normal file

@ -0,0 +1,160 @@
package grabtest
import (
"bufio"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
)
var (
DefaultHandlerContentLength = 1 << 20
DefaultHandlerMD5Checksum = "c35cc7d8d91728a0cb052831bc4ef372"
DefaultHandlerMD5ChecksumBytes = MustHexDecodeString(DefaultHandlerMD5Checksum)
DefaultHandlerSHA256Checksum = "fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83"
DefaultHandlerSHA256ChecksumBytes = MustHexDecodeString(DefaultHandlerSHA256Checksum)
)
type StatusCodeFunc func(req *http.Request) int
type handler struct {
statusCodeFunc StatusCodeFunc
methodWhitelist []string
headerBlacklist []string
contentLength int
acceptRanges bool
attachmentFilename string
lastModified time.Time
ttfb time.Duration
rateLimiter *time.Ticker
}
func NewHandler(options ...HandlerOption) (http.Handler, error) {
h := &handler{
statusCodeFunc: func(req *http.Request) int { return http.StatusOK },
methodWhitelist: []string{"GET", "HEAD"},
contentLength: DefaultHandlerContentLength,
acceptRanges: true,
}
for _, option := range options {
if err := option(h); err != nil {
return nil, err
}
}
return h, nil
}
func WithTestServer(t *testing.T, f func(url string), options ...HandlerOption) {
h, err := NewHandler(options...)
if err != nil {
t.Fatalf("unable to create test server handler: %v", err)
return
}
s := httptest.NewServer(h)
defer func() {
h.(*handler).close()
s.Close()
}()
f(s.URL)
}
func (h *handler) close() {
if h.rateLimiter != nil {
h.rateLimiter.Stop()
}
}
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// delay response
if h.ttfb > 0 {
time.Sleep(h.ttfb)
}
// validate request method
allowed := false
for _, m := range h.methodWhitelist {
if r.Method == m {
allowed = true
break
}
}
if !allowed {
httpError(w, http.StatusMethodNotAllowed)
return
}
// set server options
if h.acceptRanges {
w.Header().Set("Accept-Ranges", "bytes")
}
// set attachment filename
if h.attachmentFilename != "" {
w.Header().Set(
"Content-Disposition",
fmt.Sprintf("attachment;filename=\"%s\"", h.attachmentFilename),
)
}
// set last modified timestamp
lastMod := time.Now()
if !h.lastModified.IsZero() {
lastMod = h.lastModified
}
w.Header().Set("Last-Modified", lastMod.Format(http.TimeFormat))
// set content-length
offset := 0
if h.acceptRanges {
if reqRange := r.Header.Get("Range"); reqRange != "" {
if _, err := fmt.Sscanf(reqRange, "bytes=%d-", &offset); err != nil {
httpError(w, http.StatusBadRequest)
return
}
if offset < 0 || offset >= h.contentLength {
httpError(w, http.StatusRequestedRangeNotSatisfiable)
return
}
}
}
w.Header().Set("Content-Length", fmt.Sprintf("%d", h.contentLength-offset))
// apply header blacklist
for _, key := range h.headerBlacklist {
w.Header().Del(key)
}
// send header and status code
w.WriteHeader(h.statusCodeFunc(r))
// send body
if r.Method == "GET" {
// use buffered io to reduce overhead on the reader
bw := bufio.NewWriterSize(w, 4096)
for i := offset; !isRequestClosed(r) && i < h.contentLength; i++ {
bw.Write([]byte{byte(i)})
if h.rateLimiter != nil {
bw.Flush()
w.(http.Flusher).Flush() // force the server to send the data to the client
select {
case <-h.rateLimiter.C:
case <-r.Context().Done():
}
}
}
if !isRequestClosed(r) {
bw.Flush()
}
}
}
// isRequestClosed returns true if the client request has been canceled.
func isRequestClosed(r *http.Request) bool {
return r.Context().Err() != nil
}
func httpError(w http.ResponseWriter, code int) {
http.Error(w, http.StatusText(code), code)
}
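
The Range handling in ServeHTTP above can be exercised in isolation. The following is a minimal sketch and not part of grabtest: `parseRangeOffset`, `body` and `errRangeNotSatisfiable` are hypothetical names introduced here. It assumes the same open-ended `bytes=N-` form that the handler parses with `fmt.Sscanf` and the same `byte(i)` body pattern; this sketch also rejects negative offsets.

```go
package main

import (
	"errors"
	"fmt"
)

// errRangeNotSatisfiable is a hypothetical sentinel error standing in for the
// 416 status code that the test handler sends.
var errRangeNotSatisfiable = errors.New("requested range not satisfiable")

// parseRangeOffset is a hypothetical helper that extracts the start offset
// from an open-ended Range header such as "bytes=64-". An empty header means
// the whole body was requested.
func parseRangeOffset(header string, contentLength int) (int, error) {
	if header == "" {
		return 0, nil
	}
	offset := 0
	if _, err := fmt.Sscanf(header, "bytes=%d-", &offset); err != nil {
		return 0, err
	}
	if offset < 0 || offset >= contentLength {
		return 0, errRangeNotSatisfiable
	}
	return offset, nil
}

// body returns the bytes a server using the byte(i) pattern would send for
// the given offset: one byte per position, wrapping every 256 positions.
func body(offset, contentLength int) []byte {
	b := make([]byte, 0, contentLength-offset)
	for i := offset; i < contentLength; i++ {
		b = append(b, byte(i))
	}
	return b
}

func main() {
	offset, err := parseRangeOffset("bytes=64-", 128)
	fmt.Println(offset, err, len(body(offset, 128)))
}
```

Because the body is a pure function of the byte position, a resumed download can be checked against the same checksum as a full download.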

@ -0,0 +1,92 @@
package grabtest
import (
"errors"
"net/http"
"time"
)
type HandlerOption func(*handler) error
func StatusCodeStatic(code int) HandlerOption {
return func(h *handler) error {
return StatusCode(func(req *http.Request) int {
return code
})(h)
}
}
func StatusCode(f StatusCodeFunc) HandlerOption {
return func(h *handler) error {
if f == nil {
return errors.New("status code function cannot be nil")
}
h.statusCodeFunc = f
return nil
}
}
func MethodWhitelist(methods ...string) HandlerOption {
return func(h *handler) error {
h.methodWhitelist = methods
return nil
}
}
func HeaderBlacklist(headers ...string) HandlerOption {
return func(h *handler) error {
h.headerBlacklist = headers
return nil
}
}
func ContentLength(n int) HandlerOption {
return func(h *handler) error {
if n < 0 {
return errors.New("content length must be zero or greater")
}
h.contentLength = n
return nil
}
}
func AcceptRanges(enabled bool) HandlerOption {
return func(h *handler) error {
h.acceptRanges = enabled
return nil
}
}
func LastModified(t time.Time) HandlerOption {
return func(h *handler) error {
h.lastModified = t.UTC()
return nil
}
}
func TimeToFirstByte(d time.Duration) HandlerOption {
return func(h *handler) error {
if d < 1 {
return errors.New("time to first byte must be greater than zero")
}
h.ttfb = d
return nil
}
}
func RateLimiter(bps int) HandlerOption {
return func(h *handler) error {
if bps < 1 {
return errors.New("bytes per second must be greater than zero")
}
h.rateLimiter = time.NewTicker(time.Second / time.Duration(bps))
return nil
}
}
func AttachmentFilename(filename string) HandlerOption {
return func(h *handler) error {
h.attachmentFilename = filename
return nil
}
}
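
The HandlerOption functions above follow the functional options pattern: each option is a closure that mutates the handler and may reject invalid input. A minimal standalone sketch of the same pattern follows. The names `config`, `Option`, `WithContentLength`, `WithAcceptRanges` and `newConfig` are hypothetical and exist only for this illustration; they are not part of grabtest.

```go
package main

import (
	"errors"
	"fmt"
)

// config is a hypothetical settings struct standing in for grabtest's
// unexported handler type.
type config struct {
	contentLength int
	acceptRanges  bool
}

// Option mutates a config and returns an error for invalid values.
type Option func(*config) error

// WithContentLength is a hypothetical option that validates its argument
// before applying it.
func WithContentLength(n int) Option {
	return func(c *config) error {
		if n < 0 {
			return errors.New("content length must be zero or greater")
		}
		c.contentLength = n
		return nil
	}
}

// WithAcceptRanges is a hypothetical option that cannot fail.
func WithAcceptRanges(enabled bool) Option {
	return func(c *config) error {
		c.acceptRanges = enabled
		return nil
	}
}

// newConfig starts from defaults and applies each option in order, stopping
// at the first error.
func newConfig(options ...Option) (*config, error) {
	c := &config{contentLength: 1 << 20, acceptRanges: true}
	for _, option := range options {
		if err := option(c); err != nil {
			return nil, err
		}
	}
	return c, nil
}

func main() {
	c, err := newConfig(WithContentLength(128), WithAcceptRanges(false))
	fmt.Println(c.contentLength, c.acceptRanges, err)

	_, err = newConfig(WithContentLength(-1))
	fmt.Println(err)
}
```

Returning an error from each option lets the constructor report bad input at setup time instead of failing later while serving a request.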

@ -0,0 +1,150 @@
package grabtest
import (
"fmt"
"io/ioutil"
"net/http"
"testing"
"time"
)
func TestHandlerDefaults(t *testing.T) {
WithTestServer(t, func(url string) {
resp := MustHTTPDo(MustHTTPNewRequest("GET", url, nil))
AssertHTTPResponseStatusCode(t, resp, http.StatusOK)
AssertHTTPResponseContentLength(t, resp, 1048576)
AssertHTTPResponseHeader(t, resp, "Accept-Ranges", "bytes")
})
}
func TestHandlerMethodWhitelist(t *testing.T) {
tests := []struct {
Whitelist []string
Method string
ExpectStatusCode int
}{
{[]string{"GET", "HEAD"}, "GET", http.StatusOK},
{[]string{"GET", "HEAD"}, "HEAD", http.StatusOK},
{[]string{"GET"}, "HEAD", http.StatusMethodNotAllowed},
{[]string{"HEAD"}, "GET", http.StatusMethodNotAllowed},
}
for _, test := range tests {
WithTestServer(t, func(url string) {
resp := MustHTTPDoWithClose(MustHTTPNewRequest(test.Method, url, nil))
AssertHTTPResponseStatusCode(t, resp, test.ExpectStatusCode)
}, MethodWhitelist(test.Whitelist...))
}
}
func TestHandlerHeaderBlacklist(t *testing.T) {
contentLength := 4096
WithTestServer(t, func(url string) {
resp := MustHTTPDo(MustHTTPNewRequest("GET", url, nil))
defer resp.Body.Close()
if resp.ContentLength != -1 {
t.Errorf("expected Response.ContentLength: -1, got: %d", resp.ContentLength)
}
AssertHTTPResponseHeader(t, resp, "Content-Length", "")
AssertHTTPResponseBodyLength(t, resp, int64(contentLength))
},
ContentLength(contentLength),
HeaderBlacklist("Content-Length"),
)
}
func TestHandlerStatusCodeFuncs(t *testing.T) {
expect := 418 // I'm a teapot
WithTestServer(t, func(url string) {
resp := MustHTTPDo(MustHTTPNewRequest("GET", url, nil))
AssertHTTPResponseStatusCode(t, resp, expect)
},
StatusCode(func(req *http.Request) int { return expect }),
)
}
func TestHandlerContentLength(t *testing.T) {
tests := []struct {
Method string
ContentLength int
ExpectHeaderLen int64
ExpectBodyLen int
}{
{"GET", 321, 321, 321},
{"HEAD", 321, 321, 0},
{"GET", 0, 0, 0},
{"HEAD", 0, 0, 0},
}
for _, test := range tests {
WithTestServer(t, func(url string) {
resp := MustHTTPDo(MustHTTPNewRequest(test.Method, url, nil))
defer resp.Body.Close()
AssertHTTPResponseHeader(t, resp, "Content-Length", "%d", test.ExpectHeaderLen)
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
panic(err)
}
if len(b) != test.ExpectBodyLen {
t.Errorf(
"expected body length: %v, got: %v, in: %v",
test.ExpectBodyLen,
len(b),
test,
)
}
},
ContentLength(test.ContentLength),
)
}
}
func TestHandlerAcceptRanges(t *testing.T) {
header := "Accept-Ranges"
n := 128
t.Run("Enabled", func(t *testing.T) {
WithTestServer(t, func(url string) {
req := MustHTTPNewRequest("GET", url, nil)
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", n/2))
resp := MustHTTPDo(req)
AssertHTTPResponseHeader(t, resp, header, "bytes")
AssertHTTPResponseContentLength(t, resp, int64(n/2))
},
ContentLength(n),
)
})
t.Run("Disabled", func(t *testing.T) {
WithTestServer(t, func(url string) {
req := MustHTTPNewRequest("GET", url, nil)
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", n/2))
resp := MustHTTPDo(req)
AssertHTTPResponseHeader(t, resp, header, "")
AssertHTTPResponseContentLength(t, resp, int64(n))
},
AcceptRanges(false),
ContentLength(n),
)
})
}
func TestHandlerAttachmentFilename(t *testing.T) {
filename := "foo.pdf"
WithTestServer(t, func(url string) {
resp := MustHTTPDoWithClose(MustHTTPNewRequest("GET", url, nil))
AssertHTTPResponseHeader(t, resp, "Content-Disposition", `attachment;filename="%s"`, filename)
},
AttachmentFilename(filename),
)
}
func TestHandlerLastModified(t *testing.T) {
WithTestServer(t, func(url string) {
resp := MustHTTPDoWithClose(MustHTTPNewRequest("GET", url, nil))
AssertHTTPResponseHeader(t, resp, "Last-Modified", "Thu, 29 Nov 1973 21:33:09 GMT")
},
LastModified(time.Unix(123456789, 0)),
)
}

16
grab/grabtest/util.go Normal file

@ -0,0 +1,16 @@
package grabtest
import "encoding/hex"
func MustHexDecodeString(s string) (b []byte) {
var err error
b, err = hex.DecodeString(s)
if err != nil {
panic(err)
}
return
}
func MustHexEncodeString(b []byte) (s string) {
return hex.EncodeToString(b)
}

@ -0,0 +1,166 @@
package grabui
import (
"context"
"fmt"
"os"
"sync"
"time"
"github.com/cavaliercoder/grab"
)
type ConsoleClient struct {
mu sync.Mutex
client *grab.Client
succeeded, failed, inProgress int
responses []*grab.Response
}
func NewConsoleClient(client *grab.Client) *ConsoleClient {
return &ConsoleClient{
client: client,
}
}
func (c *ConsoleClient) Do(
ctx context.Context,
workers int,
reqs ...*grab.Request,
) <-chan *grab.Response {
// buffer size prevents slow receivers causing back pressure
pump := make(chan *grab.Response, len(reqs))
go func() {
c.mu.Lock()
defer c.mu.Unlock()
c.failed = 0
c.inProgress = 0
c.succeeded = 0
c.responses = make([]*grab.Response, 0, len(reqs))
if c.client == nil {
c.client = grab.DefaultClient
}
fmt.Printf("Downloading %d files...\n", len(reqs))
respch := c.client.DoBatch(workers, reqs...)
t := time.NewTicker(200 * time.Millisecond)
defer t.Stop()
Loop:
for {
select {
case <-ctx.Done():
break Loop
case resp := <-respch:
if resp != nil {
// a new response has been received and has started downloading
c.responses = append(c.responses, resp)
pump <- resp // send to caller
} else {
// channel is closed - all downloads are complete
break Loop
}
case <-t.C:
// update UI on clock tick
c.refresh()
}
}
c.refresh()
close(pump)
fmt.Printf(
"Finished %d successful, %d failed, %d incomplete.\n",
c.succeeded,
c.failed,
c.inProgress)
}()
return pump
}
// refresh prints the progress of all downloads to the terminal
func (c *ConsoleClient) refresh() {
// clear lines for incomplete downloads
if c.inProgress > 0 {
fmt.Printf("\033[%dA\033[K", c.inProgress)
}
// print newly completed downloads
for i, resp := range c.responses {
if resp != nil && resp.IsComplete() {
if resp.Err() != nil {
c.failed++
fmt.Fprintf(os.Stderr, "Error downloading %s: %v\n",
resp.Request.URL(),
resp.Err())
} else {
c.succeeded++
fmt.Printf("Finished %s %s / %s (%d%%)\n",
resp.Filename,
byteString(resp.BytesComplete()),
byteString(resp.Size()),
int(100*resp.Progress()))
}
c.responses[i] = nil
}
}
// print progress for incomplete downloads
c.inProgress = 0
for _, resp := range c.responses {
if resp != nil {
fmt.Printf("Downloading %s %s / %s (%d%%) - %s ETA: %s \033[K\n",
resp.Filename,
byteString(resp.BytesComplete()),
byteString(resp.Size()),
int(100*resp.Progress()),
bpsString(resp.BytesPerSecond()),
etaString(resp.ETA()))
c.inProgress++
}
}
}
func bpsString(n float64) string {
if n < 1e3 {
return fmt.Sprintf("%.02fB/s", n)
}
if n < 1e6 {
return fmt.Sprintf("%.02fKB/s", n/1e3)
}
if n < 1e9 {
return fmt.Sprintf("%.02fMB/s", n/1e6)
}
return fmt.Sprintf("%.02fGB/s", n/1e9)
}
func byteString(n int64) string {
if n < 1<<10 {
return fmt.Sprintf("%dB", n)
}
if n < 1<<20 {
return fmt.Sprintf("%dKB", n>>10)
}
if n < 1<<30 {
return fmt.Sprintf("%dMB", n>>20)
}
if n < 1<<40 {
return fmt.Sprintf("%dGB", n>>30)
}
return fmt.Sprintf("%dTB", n>>40)
}
func etaString(eta time.Time) string {
d := eta.Sub(time.Now())
if d < time.Second {
return "<1s"
}
// truncate to 1s resolution
d /= time.Second
d *= time.Second
return d.String()
}

27
grab/grabui/grabui.go Normal file

@ -0,0 +1,27 @@
package grabui
import (
"context"
"github.com/cavaliercoder/grab"
)
func GetBatch(
ctx context.Context,
workers int,
dst string,
urlStrs ...string,
) (<-chan *grab.Response, error) {
reqs := make([]*grab.Request, len(urlStrs))
for i := 0; i < len(urlStrs); i++ {
req, err := grab.NewRequest(dst, urlStrs[i])
if err != nil {
return nil, err
}
req = req.WithContext(ctx)
reqs[i] = req
}
ui := NewConsoleClient(grab.DefaultClient)
return ui.Do(ctx, workers, reqs...), nil
}

12
grab/rate_limiter.go Normal file

@ -0,0 +1,12 @@
package grab
import "context"
// RateLimiter is an interface that must be satisfied by any third-party rate
// limiters that may be used to limit download transfer speeds.
//
// A recommended token bucket implementation can be found at
// https://godoc.org/golang.org/x/time/rate#Limiter.
type RateLimiter interface {
WaitN(ctx context.Context, n int) (err error)
}

69
grab/rate_limiter_test.go Normal file

@ -0,0 +1,69 @@
package grab
import (
"context"
"log"
"os"
"testing"
"time"
"github.com/cavaliercoder/grab/grabtest"
)
// testRateLimiter is a naive rate limiter that limits throughput to r tokens
// per second. The total number of tokens issued is tracked as n.
type testRateLimiter struct {
r, n int
}
func NewLimiter(r int) RateLimiter {
return &testRateLimiter{r: r}
}
func (c *testRateLimiter) WaitN(ctx context.Context, n int) (err error) {
c.n += n
time.Sleep(
time.Duration(1.00 / float64(c.r) * float64(n) * float64(time.Second)))
return
}
func TestRateLimiter(t *testing.T) {
// download a 128 byte file, 8 bytes at a time, with a naive limiter set to
// 512 bytes per second; this should take at least 250ms (128 / 512 seconds)
filesize := 128
filename := ".testRateLimiter"
defer os.Remove(filename)
grabtest.WithTestServer(t, func(url string) {
// limit to 512bps
lim := &testRateLimiter{r: 512}
req := mustNewRequest(filename, url)
// ensure multiple trips to the rate limiter by downloading 8 bytes at a time
req.BufferSize = 8
req.RateLimiter = lim
resp := mustDo(req)
testComplete(t, resp)
if lim.n != filesize {
t.Errorf("expected %d bytes to pass through limiter, got %d", filesize, lim.n)
}
if resp.Duration().Seconds() < 0.25 {
// BUG: this test can pass if the transfer was slow for unrelated reasons
t.Errorf("expected transfer to take >250ms, took %v", resp.Duration())
}
}, grabtest.ContentLength(filesize))
}
func ExampleRateLimiter() {
req, err := NewRequest("", "http://www.golang-book.com/public/pdf/gobook.pdf")
if err != nil {
log.Fatal(err)
}
// Attach a 1 MiB/s (1048576 bytes per second) rate limiter, like the token
// bucket implementation from golang.org/x/time/rate.
req.RateLimiter = NewLimiter(1048576)
resp := DefaultClient.Do(req)
if err := resp.Err(); err != nil {
log.Fatal(err)
}
}

177
grab/request.go Normal file

@ -0,0 +1,177 @@
package grab
import (
"context"
"hash"
"net/http"
"net/url"
)
// A Hook is a user-provided callback function that can be called by grab at
// various stages of a request's lifecycle. If a hook returns an error, the
// associated request is canceled and the same error is returned on the Response
// object.
//
// Hook functions are called synchronously and should never block unnecessarily.
// Response methods that block until a download is complete, such as
// Response.Err, Response.Cancel or Response.Wait will deadlock. To cancel a
// download from a callback, simply return a non-nil error.
type Hook func(*Response) error
// A Request represents an HTTP file transfer request to be sent by a Client.
type Request struct {
// Label is an arbitrary string which may be used to label a Request with a
// user-friendly name.
Label string
// Tag is an arbitrary interface which may be used to relate a Request to
// other data.
Tag interface{}
// HTTPRequest specifies the http.Request to be sent to the remote server to
// initiate a file transfer. It includes request configuration such as URL,
// protocol version, HTTP method, request headers and authentication.
HTTPRequest *http.Request
// Filename specifies the path where the file transfer will be stored in
// local storage. If Filename is empty or a directory, the true Filename will
// be resolved using Content-Disposition headers or the request URL.
//
// An empty string means the transfer will be stored in the current working
// directory.
Filename string
// SkipExisting specifies that ErrFileExists should be returned if the
// destination path already exists. The existing file will not be checked for
// completeness.
SkipExisting bool
// NoResume specifies that a partially completed download will be restarted
// without attempting to resume any existing file. If the download is already
// completed in full, it will not be restarted.
NoResume bool
// NoStore specifies that grab should not write to the local file system.
// Instead, the download will be stored in memory and accessible only via
// Response.Open or Response.Bytes.
NoStore bool
// NoCreateDirectories specifies that any missing directories in the given
// Filename path should not be created automatically, if they do not already
// exist.
NoCreateDirectories bool
// IgnoreBadStatusCodes specifies that grab should accept any status code in
// the response from the remote server. Otherwise, grab expects the response
// status code to be within the 2XX range (after following redirects).
IgnoreBadStatusCodes bool
// IgnoreRemoteTime specifies that grab should not attempt to set the
// timestamp of the local file to match the remote file.
IgnoreRemoteTime bool
// Size specifies the expected size of the file transfer if known. If the
// server response size does not match, the transfer is cancelled and
// ErrBadLength returned.
Size int64
// BufferSize specifies the size in bytes of the buffer that is used for
// transferring the requested file. Larger buffers may result in faster
// throughput but will use more memory and result in less frequent updates
// to the transfer progress statistics. If a RateLimiter is configured,
// BufferSize should be much lower than the rate limit. Default: 32KB.
BufferSize int
// RateLimiter allows the transfer rate of a download to be limited. The given
// Request.BufferSize determines how frequently the RateLimiter will be
// polled.
RateLimiter RateLimiter
// BeforeCopy is a user provided callback that is called immediately before
// a request starts downloading. If BeforeCopy returns an error, the request
// is cancelled and the same error is returned on the Response object.
BeforeCopy Hook
// AfterCopy is a user provided callback that is called immediately after a
// request has finished downloading, before checksum validation and closure.
// This hook is only called if the transfer was successful. If AfterCopy
// returns an error, the request is canceled and the same error is returned on
// the Response object.
AfterCopy Hook
// hash, checksum and deleteOnError - set via SetChecksum.
hash hash.Hash
checksum []byte
deleteOnError bool
// Context for cancellation and timeout - set via WithContext
ctx context.Context
}
// NewRequest returns a new file transfer Request suitable for use with
// Client.Do.
func NewRequest(dst, urlStr string) (*Request, error) {
if dst == "" {
dst = "."
}
req, err := http.NewRequest("GET", urlStr, nil)
if err != nil {
return nil, err
}
return &Request{
HTTPRequest: req,
Filename: dst,
}, nil
}
// Context returns the request's context. To change the context, use
// WithContext.
//
// The returned context is always non-nil; it defaults to the background
// context.
//
// The context controls cancelation.
func (r *Request) Context() context.Context {
if r.ctx != nil {
return r.ctx
}
return context.Background()
}
// WithContext returns a shallow copy of r with its context changed
// to ctx. The provided ctx must be non-nil.
func (r *Request) WithContext(ctx context.Context) *Request {
if ctx == nil {
panic("nil context")
}
r2 := new(Request)
*r2 = *r
r2.ctx = ctx
r2.HTTPRequest = r2.HTTPRequest.WithContext(ctx)
return r2
}
// URL returns the URL to be downloaded.
func (r *Request) URL() *url.URL {
return r.HTTPRequest.URL
}
// SetChecksum sets the desired hashing algorithm and checksum value to validate
// a downloaded file. Once the download is complete, the given hashing algorithm
// will be used to compute the actual checksum of the downloaded file. If the
// checksums do not match, an error will be returned by the associated
// Response.Err method.
//
// If deleteOnError is true, the downloaded file will be deleted automatically
// if it fails checksum validation.
//
// To prevent corruption of the computed checksum, the given hash must not be
// used by any other request or goroutines.
//
// To disable checksum validation, call SetChecksum with a nil hash.
func (r *Request) SetChecksum(h hash.Hash, sum []byte, deleteOnError bool) {
r.hash = h
r.checksum = sum
r.deleteOnError = deleteOnError
}

258
grab/response.go Normal file

@ -0,0 +1,258 @@
package grab
import (
"bytes"
"context"
"io"
"io/ioutil"
"net/http"
"os"
"sync/atomic"
"time"
)
// Response represents the response to a completed or in-progress download
// request.
//
// A response may be returned as soon as an HTTP response is received from a remote
// server, but before the body content has started transferring.
//
// All Response method calls are thread-safe.
type Response struct {
// The Request that was submitted to obtain this Response.
Request *Request
// HTTPResponse represents the HTTP response received from an HTTP request.
//
// The response Body should not be used as it will be consumed and closed by
// grab.
HTTPResponse *http.Response
// Filename specifies the path where the file transfer is stored in local
// storage.
Filename string
// sizeUnsafe specifies the total expected size of the file transfer. Read it
// via the Size method, which loads the value atomically.
sizeUnsafe int64
// Start specifies the time at which the file transfer started.
Start time.Time
// End specifies the time at which the file transfer completed.
//
// This will return zero until the transfer has completed.
End time.Time
// CanResume specifies that the remote server advertised that it can resume
// previous downloads, as the 'Accept-Ranges: bytes' header is set.
CanResume bool
// DidResume specifies that the file transfer resumed a previously incomplete
// transfer.
DidResume bool
// Done is closed once the transfer is finalized, either successfully or with
// errors. Errors are available via Response.Err
Done chan struct{}
// ctx is a Context that controls cancelation of an in-progress transfer
ctx context.Context
// cancel is a cancel func that can be used to cancel the context of this
// Response.
cancel context.CancelFunc
// fi is the FileInfo for the destination file if it already existed before
// transfer started.
fi os.FileInfo
// optionsKnown indicates that a HEAD request has been completed and the
// capabilities of the remote server are known.
optionsKnown bool
// writer is the file handle used to write the downloaded file to local
// storage
writer io.Writer
// storeBuffer receives the contents of the transfer if Request.NoStore is
// enabled.
storeBuffer bytes.Buffer
// bytesResumed specifies the number of bytes which were already
// transferred before this transfer began.
bytesResumed int64
// transfer is responsible for copying data from the remote server to a local
// file, tracking progress and allowing for cancelation.
transfer *transfer
// bufferSize specifies the size in bytes of the transfer buffer.
bufferSize int
// err contains any error that may have occurred during the file transfer.
// This should not be read until IsComplete returns true.
err error
}
// IsComplete returns true if the download has completed. If an error occurred
// during the download, it can be returned via Err.
func (c *Response) IsComplete() bool {
select {
case <-c.Done:
return true
default:
return false
}
}
// Cancel cancels the file transfer by canceling the underlying Context for
// this Response. Cancel blocks until the transfer is closed and returns any
// error - typically context.Canceled.
func (c *Response) Cancel() error {
c.cancel()
return c.Err()
}
// Wait blocks until the download is completed.
func (c *Response) Wait() {
<-c.Done
}
// Err blocks the calling goroutine until the underlying file transfer is
// completed and returns any error that may have occurred. If the download is
// already completed, Err returns immediately.
func (c *Response) Err() error {
<-c.Done
return c.err
}
// Size returns the size of the file transfer. If the remote server does not
// specify the total size and the transfer is incomplete, the return value is
// -1.
func (c *Response) Size() int64 {
return atomic.LoadInt64(&c.sizeUnsafe)
}
// BytesComplete returns the total number of bytes which have been copied to
// the destination, including any bytes that were resumed from a previous
// download.
func (c *Response) BytesComplete() int64 {
return c.bytesResumed + c.transfer.N()
}
// BytesPerSecond returns the number of bytes per second transferred using a
// simple moving average of the last five seconds. If the download is already
// complete, the average bytes/sec for the life of the download is returned.
func (c *Response) BytesPerSecond() float64 {
if c.IsComplete() {
return float64(c.transfer.N()) / c.Duration().Seconds()
}
return c.transfer.BPS()
}
// Progress returns the ratio of total bytes that have been downloaded. Multiply
// the returned value by 100 to return the percentage completed.
func (c *Response) Progress() float64 {
size := c.Size()
if size <= 0 {
return 0
}
return float64(c.BytesComplete()) / float64(size)
}
// Duration returns the duration of a file transfer. If the transfer is in
// process, the duration will be between now and the start of the transfer. If
// the transfer is complete, the duration will be between the start and end of
// the completed transfer process.
func (c *Response) Duration() time.Duration {
if c.IsComplete() {
return c.End.Sub(c.Start)
}
return time.Since(c.Start)
}
// ETA returns the estimated time at which the download will complete, given
// the current BytesPerSecond. If the transfer has already completed, the actual
// end time will be returned.
func (c *Response) ETA() time.Time {
if c.IsComplete() {
return c.End
}
bt := c.BytesComplete()
bps := c.transfer.BPS()
if bps == 0 {
return time.Time{}
}
secs := float64(c.Size()-bt) / bps
return time.Now().Add(time.Duration(secs) * time.Second)
}
// Open blocks the calling goroutine until the underlying file transfer is
// completed and then opens the transferred file for reading. If Request.NoStore
// was enabled, the reader will read from memory.
//
// If an error occurred during the transfer, it will be returned.
//
// It is the caller's responsibility to close the returned file handle.
func (c *Response) Open() (io.ReadCloser, error) {
if err := c.Err(); err != nil {
return nil, err
}
return c.openUnsafe()
}
func (c *Response) openUnsafe() (io.ReadCloser, error) {
if c.Request.NoStore {
return ioutil.NopCloser(bytes.NewReader(c.storeBuffer.Bytes())), nil
}
return os.Open(c.Filename)
}
// Bytes blocks the calling goroutine until the underlying file transfer is
// completed and then reads all bytes from the completed transfer. If
// Request.NoStore was enabled, the bytes will be read from memory.
//
// If an error occurred during the transfer, it will be returned.
func (c *Response) Bytes() ([]byte, error) {
if err := c.Err(); err != nil {
return nil, err
}
if c.Request.NoStore {
return c.storeBuffer.Bytes(), nil
}
f, err := c.Open()
if err != nil {
return nil, err
}
defer f.Close()
return ioutil.ReadAll(f)
}
func (c *Response) requestMethod() string {
if c == nil || c.HTTPResponse == nil || c.HTTPResponse.Request == nil {
return ""
}
return c.HTTPResponse.Request.Method
}
func (c *Response) checksumUnsafe() ([]byte, error) {
f, err := c.openUnsafe()
if err != nil {
return nil, err
}
defer f.Close()
t := newTransfer(c.Request.Context(), nil, c.Request.hash, f, nil)
if _, err = t.copy(); err != nil {
return nil, err
}
sum := c.Request.hash.Sum(nil)
return sum, nil
}
func (c *Response) closeResponseBody() error {
if c.HTTPResponse == nil || c.HTTPResponse.Body == nil {
return nil
}
return c.HTTPResponse.Body.Close()
}

118
grab/response_test.go Normal file

@ -0,0 +1,118 @@
package grab
import (
"bytes"
"os"
"testing"
"time"
"github.com/cavaliercoder/grab/grabtest"
)
// testComplete validates that a completed Response has all the desired fields.
func testComplete(t *testing.T, resp *Response) {
<-resp.Done
if !resp.IsComplete() {
t.Errorf("Response.IsComplete returned false")
}
if resp.Start.IsZero() {
t.Errorf("Response.Start is zero")
}
if resp.End.IsZero() {
t.Error("Response.End is zero")
}
if eta := resp.ETA(); eta != resp.End {
t.Errorf("Response.ETA is not equal to Response.End: %v", eta)
}
// the following fields should only be set if no error occurred
if resp.Err() == nil {
if resp.Filename == "" {
t.Errorf("Response.Filename is empty")
}
if resp.Size() == 0 {
t.Error("Response.Size is zero")
}
if p := resp.Progress(); p != 1.00 {
t.Errorf("Response.Progress returned %v (%v/%v bytes), expected 1", p, resp.BytesComplete(), resp.Size())
}
}
}
// TestResponseProgress tests the functions which indicate the progress of an
// in-process file transfer.
func TestResponseProgress(t *testing.T) {
filename := ".testResponseProgress"
defer os.Remove(filename)
sleep := 300 * time.Millisecond
size := 1024 * 8 // bytes
grabtest.WithTestServer(t, func(url string) {
// request a slow transfer
req := mustNewRequest(filename, url)
resp := DefaultClient.Do(req)
// make sure transfer has not completed
if resp.IsComplete() {
t.Errorf("Transfer should not have completed yet")
}
if p := resp.Progress(); p != 0 {
t.Errorf("Transfer should not have started yet but progress is %v", p)
}
// wait for transfer to complete
<-resp.Done
// make sure transfer is complete
if p := resp.Progress(); p != 1 {
t.Errorf("Transfer is complete but progress is %v", p)
}
if s := resp.BytesComplete(); s != int64(size) {
t.Errorf("Expected to transfer %v bytes, got %v", size, s)
}
},
grabtest.TimeToFirstByte(sleep),
grabtest.ContentLength(size),
)
}
func TestResponseOpen(t *testing.T) {
grabtest.WithTestServer(t, func(url string) {
resp := mustDo(mustNewRequest("", url+"/someFilename"))
f, err := resp.Open()
if err != nil {
t.Error(err)
return
}
defer func() {
if err := f.Close(); err != nil {
t.Error(err)
}
}()
grabtest.AssertSHA256Sum(t, grabtest.DefaultHandlerSHA256ChecksumBytes, f)
})
}
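// TestResponseBytes checks that the bytes returned by Response.Bytes match
// the test server's default SHA-256 checksum.
func TestResponseBytes(t *testing.T) {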
grabtest.WithTestServer(t, func(url string) {
resp := mustDo(mustNewRequest("", url+"/someFilename"))
b, err := resp.Bytes()
if err != nil {
t.Error(err)
return
}
grabtest.AssertSHA256Sum(
t,
grabtest.DefaultHandlerSHA256ChecksumBytes,
bytes.NewReader(b),
)
})
}

111
grab/states.wsd Normal file

@ -0,0 +1,111 @@
@startuml
title Grab transfer state
legend
| # | Meaning |
| D | Destination path known |
| S | File size known |
| O | Server options known (Accept-Ranges) |
| R | Resume supported (Accept-Ranges) |
| Z | Local file empty or missing |
| P | Local file partially complete |
endlegend
[*] --> Empty
[*] --> D
[*] --> S
[*] --> DS
Empty : Filename: ""
Empty : Size: 0
Empty --> O : HEAD: Method not allowed
Empty --> DSO : HEAD: Range not supported
Empty --> DSOR : HEAD: Range supported
DS : Filename: "foo.bar"
DS : Size: > 0
DS --> DSZ : checkExisting(): File missing
DS --> DSP : checkExisting(): File partial
DS --> [*] : checkExisting(): File complete
DS --> ERROR
S : Filename: ""
S : Size: > 0
S --> SO : HEAD: Method not allowed
S --> DSO : HEAD: Range not supported
S --> DSOR : HEAD: Range supported
D : Filename: "foo.bar"
D : Size: 0
D --> DO : HEAD: Method not allowed
D --> DSO : HEAD: Range not supported
D --> DSOR : HEAD: Range supported
O : Filename: ""
O : Size: 0
O : CanResume: false
O --> DSO : GET: 200
O --> ERROR
SO : Filename: ""
SO : Size: > 0
SO : CanResume: false
SO --> DSO : GET: 200
SO --> ERROR
DO : Filename: "foo.bar"
DO : Size: 0
DO : CanResume: false
DO --> DSO : GET: 200
DO --> ERROR
DSZ : Filename: "foo.bar"
DSZ : Size: > 0
DSZ : File: empty
DSZ --> DSORZ : HEAD: Range supported
DSZ --> DSOZ : HEAD: 405 or Range unsupported
DSP : Filename: "foo.bar"
DSP : Size: > 0
DSP : File: partial
DSP --> DSORP : HEAD: Range supported
DSP --> DSOZ : HEAD: 405 or Range unsupported
DSO : Filename: "foo.bar"
DSO : Size: > 0
DSO : CanResume: false
DSO --> DSOZ : checkExisting(): File partial|missing
DSO --> [*] : checkExisting(): File complete
DSOR : Filename: "foo.bar"
DSOR : Size: > 0
DSOR : CanResume: true
DSOR --> DSORP : checkExisting(): File partial
DSOR --> DSORZ : checkExisting(): File missing
DSORP : Filename: "foo.bar"
DSORP : Size: > 0
DSORP : CanResume: true
DSORP : File: partial
DSORP --> Transferring
DSORZ : Filename: "foo.bar"
DSORZ : Size: > 0
DSORZ : CanResume: true
DSORZ : File: empty
DSORZ --> Transferring
DSOZ : Filename: "foo.bar"
DSOZ : Size: > 0
DSOZ : CanResume: false
DSOZ : File: empty
DSOZ --> Transferring
Transferring --> [*]
Transferring --> ERROR
ERROR : Something went wrong
ERROR --> [*]
@enduml

103
grab/transfer.go Normal file

@ -0,0 +1,103 @@
package grab
import (
"context"
"io"
"sync/atomic"
"time"
"github.com/cavaliercoder/grab/bps"
)
type transfer struct {
n int64 // must be 64bit aligned on 386
ctx context.Context
gauge bps.Gauge
lim RateLimiter
w io.Writer
r io.Reader
b []byte
}
func newTransfer(ctx context.Context, lim RateLimiter, dst io.Writer, src io.Reader, buf []byte) *transfer {
return &transfer{
ctx: ctx,
gauge: bps.NewSMA(6), // five-second moving average, sampled every second
lim: lim,
w: dst,
r: src,
b: buf,
}
}
// copy behaves similarly to io.CopyBuffer except that it checks for cancelation
// of the given context.Context, reports progress in a thread-safe manner and
// tracks the transfer rate.
func (c *transfer) copy() (written int64, err error) {
// maintain a bps gauge in another goroutine
ctx, cancel := context.WithCancel(c.ctx)
defer cancel()
go bps.Watch(ctx, c.gauge, c.N, time.Second)
// start the transfer
if c.b == nil {
c.b = make([]byte, 32*1024)
}
for {
select {
case <-c.ctx.Done():
err = c.ctx.Err()
return
default:
// keep working
}
nr, er := c.r.Read(c.b)
if nr > 0 {
nw, ew := c.w.Write(c.b[0:nr])
if nw > 0 {
written += int64(nw)
atomic.StoreInt64(&c.n, written)
}
if ew != nil {
err = ew
break
}
if nr != nw {
err = io.ErrShortWrite
break
}
// wait for rate limiter
if c.lim != nil {
err = c.lim.WaitN(c.ctx, nr)
if err != nil {
return
}
}
}
if er != nil {
if er != io.EOF {
err = er
}
break
}
}
return written, err
}
// N returns the number of bytes transferred.
func (c *transfer) N() (n int64) {
if c == nil {
return 0
}
n = atomic.LoadInt64(&c.n)
return
}
// BPS returns the current bytes per second transfer rate using a simple moving
// average.
func (c *transfer) BPS() (bps float64) {
if c == nil || c.gauge == nil {
return 0
}
return c.gauge.BPS()
}

75
grab/util.go Normal file

@ -0,0 +1,75 @@
package grab
import (
"fmt"
"mime"
"net/http"
"os"
"path"
"path/filepath"
"time"
"github.com/lordwelch/pathvalidate"
)
// setLastModified sets the last modified timestamp of a local file according to
// the Last-Modified header returned by a remote server.
func setLastModified(resp *http.Response, filename string) error {
// https://tools.ietf.org/html/rfc7232#section-2.2
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Last-Modified
header := resp.Header.Get("Last-Modified")
if header == "" {
return nil
}
lastmod, err := time.Parse(http.TimeFormat, header)
if err != nil {
// ignore a malformed Last-Modified header rather than reporting an error
return nil
}
return os.Chtimes(filename, lastmod, lastmod)
}
// mkdirp creates all missing parent directories for the destination file path.
func mkdirp(path string) error {
dir := filepath.Dir(path)
if fi, err := os.Stat(dir); err != nil {
if !os.IsNotExist(err) {
return fmt.Errorf("error checking destination directory: %v", err)
}
if err := os.MkdirAll(dir, 0777); err != nil {
return fmt.Errorf("error creating destination directory: %v", err)
}
} else if !fi.IsDir() {
panic("grab: developer error: destination path is not a directory")
}
return nil
}
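// guessFilename returns a filename for the given http.Response, preferring the
// filename directive of the Content-Disposition header and falling back to the
// request URL path. If none can be determined, ErrNoFilename is returned.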
//
// TODO: NoStore operations should not require a filename
func guessFilename(resp *http.Response) (string, error) {
var (
err error
filename = resp.Request.URL.Path
)
if cd := resp.Header.Get("Content-Disposition"); cd != "" {
if _, params, err := mime.ParseMediaType(cd); err == nil {
if val, ok := params["filename"]; ok {
filename = val
} // else the filename directive is missing; fall back to URL.Path
}
}
// sanitize
filename, err = pathvalidate.SanitizeFilename(path.Base(filename), '_')
if err != nil {
return "", fmt.Errorf("%w: %v", ErrNoFilename, err)
}
if filename == "" || filename == "." || filename == "/" {
return "", ErrNoFilename
}
return filename, nil
}

167
grab/util_test.go Normal file

@ -0,0 +1,167 @@
package grab
import (
"fmt"
"net/http"
"net/url"
"testing"
)
func TestURLFilenames(t *testing.T) {
t.Run("Valid", func(t *testing.T) {
expect := "filename"
testCases := []string{
"http://test.com/filename",
"http://test.com/path/filename",
"http://test.com/deep/path/filename",
"http://test.com/filename?with=args",
"http://test.com/filename#with-fragment",
"http://test.com/filename?with=args&and#with-fragment",
}
for _, tc := range testCases {
req, err := http.NewRequest("GET", tc, nil)
if err != nil {
t.Fatalf("error creating request for %v: %v", tc, err)
}
resp := &http.Response{
Request: req,
}
actual, err := guessFilename(resp)
if err != nil {
t.Errorf("%v", err)
}
if actual != expect {
t.Errorf("expected '%v', got '%v'", expect, actual)
}
}
})
t.Run("Invalid", func(t *testing.T) {
testCases := []string{
"http://test.com",
"http://test.com/",
"http://test.com/filename/",
"http://test.com/filename/?with=args",
"http://test.com/filename/#with-fragment",
"http://test.com/filename\x00",
}
for _, tc := range testCases {
t.Run(tc, func(t *testing.T) {
req, err := http.NewRequest("GET", tc, nil)
if err != nil {
if tc == "http://test.com/filename\x00" {
// Since go1.12, URLs containing invalid control characters return an error.
// See https://github.com/golang/go/commit/829c5df58694b3345cb5ea41206783c8ccf5c3ca
t.Skip()
}
// any other error would leave req nil and crash guessFilename below
t.Fatal(err)
}
resp := &http.Response{
Request: req,
}
_, err = guessFilename(resp)
if err != ErrNoFilename {
t.Errorf("expected '%v', got '%v'", ErrNoFilename, err)
}
})
}
})
}
func TestHeaderFilenames(t *testing.T) {
u, _ := url.ParseRequestURI("http://test.com/badfilename")
resp := &http.Response{
Request: &http.Request{
URL: u,
},
Header: http.Header{},
}
setFilename := func(resp *http.Response, filename string) {
resp.Header.Set("Content-Disposition", fmt.Sprintf("attachment;filename=\"%s\"", filename))
}
t.Run("Valid", func(t *testing.T) {
expect := "filename"
testCases := []string{
"filename",
"path/filename",
"/path/filename",
"../../filename",
"/path/../../filename",
"/../../././///filename",
}
for _, tc := range testCases {
setFilename(resp, tc)
actual, err := guessFilename(resp)
if err != nil {
t.Errorf("error (%v): %v", tc, err)
}
if actual != expect {
t.Errorf("expected '%v' (%v), got '%v'", expect, tc, actual)
}
}
})
t.Run("Invalid", func(t *testing.T) {
testCases := []string{
"",
"/",
".",
"/.",
"/./",
"..",
"../",
"/../",
"/path/",
"../path/",
"filename\x00",
"filename/",
"filename//",
"filename/..",
}
for _, tc := range testCases {
setFilename(resp, tc)
if actual, err := guessFilename(resp); err != ErrNoFilename {
t.Errorf("expected: %v (%v), got: %v (%v)", ErrNoFilename, tc, err, actual)
}
}
})
}
func TestHeaderWithMissingDirective(t *testing.T) {
u, _ := url.ParseRequestURI("http://test.com/filename")
resp := &http.Response{
Request: &http.Request{
URL: u,
},
Header: http.Header{},
}
setHeader := func(resp *http.Response, value string) {
resp.Header.Set("Content-Disposition", value)
}
t.Run("Valid", func(t *testing.T) {
expect := "filename"
testCases := []string{
"inline",
"attachment",
}
for _, tc := range testCases {
setHeader(resp, tc)
actual, err := guessFilename(resp)
if err != nil {
t.Errorf("error (%v): %v", tc, err)
}
if actual != expect {
t.Errorf("expected '%v' (%v), got '%v'", expect, tc, actual)
}
}
})
}