Commit custom grab

This commit is contained in: parent 17d26242d2, commit f1179ff06e
grab/.gitignore  vendored  Normal file  (+3)
@@ -0,0 +1,3 @@
# ignore IDE project files
*.iml
.idea/
grab/.travis.yml  Normal file  (+14)
@@ -0,0 +1,14 @@
language: go

go:
- tip
- 1.10.x
- 1.9.x
- 1.8.x
- 1.7.x

script: make check

env:
- GOARCH=amd64
- GOARCH=386
grab/LICENSE  Normal file  (+26)
@@ -0,0 +1,26 @@
Copyright (c) 2017 Ryan Armstrong. All rights reserved.

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its contributors
   may be used to endorse or promote products derived from this software without
   specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
grab/Makefile  Normal file  (+29)
@@ -0,0 +1,29 @@
GO = go
GOGET = $(GO) get -u

all: check lint

check:
	cd cmd/grab && $(MAKE) -B all
	$(GO) test -cover -race ./...

install:
	$(GO) install -v ./...

clean:
	$(GO) clean -x ./...
	rm -rvf ./.test*

lint:
	gofmt -l -e -s . || :
	go vet . || :
	golint . || :
	gocyclo -over 15 . || :
	misspell ./* || :

deps:
	$(GOGET) github.com/golang/lint/golint
	$(GOGET) github.com/fzipp/gocyclo
	$(GOGET) github.com/client9/misspell/cmd/misspell

.PHONY: all check install clean lint deps
grab/README.md  Normal file  (+127)
@@ -0,0 +1,127 @@
# grab

[GoDoc](https://godoc.org/github.com/cavaliercoder/grab) [Build Status](https://travis-ci.org/cavaliercoder/grab) [Go Report Card](https://goreportcard.com/report/github.com/cavaliercoder/grab)

*Downloading the internet, one goroutine at a time!*

    $ go get github.com/cavaliercoder/grab

Grab is a Go package for downloading files from the internet with the following
rad features:

* Monitor download progress concurrently
* Auto-resume incomplete downloads
* Guess filename from content header or URL path
* Safely cancel downloads using context.Context
* Validate downloads using checksums
* Download batches of files concurrently
* Apply rate limiters

Requires Go v1.7+

## Example

The following example downloads a PDF copy of the free eBook, "An Introduction
to Programming in Go" into the current working directory.

```go
resp, err := grab.Get(".", "http://www.golang-book.com/public/pdf/gobook.pdf")
if err != nil {
	log.Fatal(err)
}

fmt.Println("Download saved to", resp.Filename)
```

The following, more complete example allows for more granular control and
periodically prints the download progress until it is complete.

The second time you run the example, it will auto-resume the previous download
and exit sooner.

```go
package main

import (
	"fmt"
	"os"
	"time"

	"github.com/cavaliercoder/grab"
)

func main() {
	// create client
	client := grab.NewClient()
	req, _ := grab.NewRequest(".", "http://www.golang-book.com/public/pdf/gobook.pdf")

	// start download
	fmt.Printf("Downloading %v...\n", req.URL())
	resp := client.Do(req)
	fmt.Printf("  %v\n", resp.HTTPResponse.Status)

	// start UI loop
	t := time.NewTicker(500 * time.Millisecond)
	defer t.Stop()

Loop:
	for {
		select {
		case <-t.C:
			fmt.Printf("  transferred %v / %v bytes (%.2f%%)\n",
				resp.BytesComplete(),
				resp.Size(),
				100*resp.Progress())

		case <-resp.Done:
			// download is complete
			break Loop
		}
	}

	// check for errors
	if err := resp.Err(); err != nil {
		fmt.Fprintf(os.Stderr, "Download failed: %v\n", err)
		os.Exit(1)
	}

	fmt.Printf("Download saved to ./%v \n", resp.Filename)

	// Output:
	// Downloading http://www.golang-book.com/public/pdf/gobook.pdf...
	//   200 OK
	//   transferred 42970 / 2893557 bytes (1.49%)
	//   transferred 1207474 / 2893557 bytes (41.73%)
	//   transferred 2758210 / 2893557 bytes (95.32%)
	// Download saved to ./gobook.pdf
}
```

## Design trade-offs

The primary use case for Grab is concurrently downloading thousands of large
files from remote file repositories where the remote files are immutable.
Examples include operating system package repositories or ISO libraries.

Grab aims to provide robust, sane defaults. These are usually determined using
the HTTP specifications, or by mimicking the behavior of common web clients like
cURL, wget and common web browsers.

Grab aims to be stateless. The only state that exists is the remote files you
wish to download and the local copy, which may be completed, partially completed
or not yet created. The advantage of this is that the local file system is not
cluttered unnecessarily with additional state files (like a `.crdownload` file).
The disadvantage of this approach is that grab must make assumptions about the
local and remote state; specifically, that they have not been modified by
another program.

If the local or remote file is modified outside of grab, and you download the
file again with resuming enabled, the local file will likely become corrupted.
In this case, you might consider making remote files immutable, or disabling
resume.

Grab aims to enable best-in-class functionality for more complex features
through extensible interfaces, rather than reimplementation. For example,
you can provide your own Hash algorithm to compute file checksums, or your
own rate limiter implementation (with all the associated trade-offs) to rate
limit downloads.
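As a concrete illustration of the extensible checksum interface mentioned above, here is a minimal sketch (not part of this commit) that validates a download against a known SHA-256 digest. It relies only on `Request.SetChecksum` as exercised in the package's own tests; the URL and digest value are hypothetical placeholders. A custom rate limiter would similarly be attached to the Request's RateLimiter field.

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"log"

	"github.com/cavaliercoder/grab"
)

func main() {
	req, err := grab.NewRequest(".", "http://example.com/some-file.iso") // hypothetical URL
	if err != nil {
		log.Fatal(err)
	}

	// expectedSum is a hypothetical, hex-encoded SHA-256 digest of the remote file.
	expectedSum, _ := hex.DecodeString(
		"fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83")

	// Any hash.Hash may be supplied; the final argument asks grab to delete the
	// downloaded file if the computed checksum does not match.
	req.SetChecksum(sha256.New(), expectedSum, true)

	resp := grab.NewClient().Do(req)
	if err := resp.Err(); err != nil {
		log.Fatalf("download failed: %v", err)
	}
	log.Println("verified download saved to", resp.Filename)
}
```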
grab/bps/bps.go  Normal file  (+54)
@@ -0,0 +1,54 @@
/*
Package bps provides gauges for calculating the Bytes Per Second transfer rate
of data streams.
*/
package bps

import (
	"context"
	"time"
)

// Gauge is the common interface for all BPS gauges in this package. Given a
// set of samples over time, each gauge type can be used to measure the Bytes
// Per Second transfer rate of a data stream.
//
// All samples must monotonically increase in timestamp and value. Each sample
// should represent the total number of bytes sent in a stream, rather than
// accounting for the number sent since the last sample.
//
// To ensure a gauge can report progress as quickly as possible, take an initial
// sample when your stream first starts.
//
// All gauge implementations are safe for concurrent use.
type Gauge interface {
	// Sample adds a new sample of the progress of the monitored stream.
	Sample(t time.Time, n int64)

	// BPS returns the calculated Bytes Per Second rate of the monitored stream.
	BPS() float64
}

// SampleFunc is used by Watch to take periodic samples of a monitored stream.
type SampleFunc func() (n int64)

// Watch will periodically call the given SampleFunc to sample the progress of
// a monitored stream and update the given gauge. SampleFunc should return the
// total number of bytes transferred by the stream since it started.
//
// Watch is a blocking call and should typically be called in a new goroutine.
// To prevent the goroutine from leaking, make sure to cancel the given context
// once the stream is completed or canceled.
func Watch(ctx context.Context, g Gauge, f SampleFunc, interval time.Duration) {
	g.Sample(time.Now(), f())
	t := time.NewTicker(interval)
	defer t.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case now := <-t.C:
			g.Sample(now, f())
		}
	}
}
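The Gauge/Watch contract above is easiest to see with a small sketch (not part of this commit, assuming the import path github.com/cavaliercoder/grab/bps): a loop advances a byte counter while Watch samples it once per second into an SMA gauge. The counter and loop are illustrative only.

```go
package main

import (
	"context"
	"fmt"
	"sync/atomic"
	"time"

	"github.com/cavaliercoder/grab/bps"
)

func main() {
	var total int64 // total bytes "transferred" so far (illustrative counter)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel() // stops Watch and prevents a goroutine leak

	// Six samples ≈ a five second window at a one second sampling interval.
	g := bps.NewSMA(6)

	// Watch blocks, so run it in its own goroutine. The SampleFunc returns the
	// running total, as required by the Gauge contract.
	go bps.Watch(ctx, g, func() int64 { return atomic.LoadInt64(&total) }, time.Second)

	for i := 0; i < 10; i++ {
		atomic.AddInt64(&total, 1024) // pretend we moved another 1 KiB
		time.Sleep(time.Second)
		fmt.Printf("current rate: %.0f B/s\n", g.BPS())
	}
}
```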
grab/bps/sma.go  Normal file  (+81)
@@ -0,0 +1,81 @@
package bps

import (
	"sync"
	"time"
)

// NewSMA returns a gauge that uses a Simple Moving Average with the given
// number of samples to measure the bytes per second of a byte stream.
//
// BPS is computed using the timestamp of the most recent and oldest sample in
// the sample buffer. When a new sample is added, the oldest sample is dropped
// if the sample count exceeds maxSamples.
//
// The gauge does not account for any latency in arrival time of new samples or
// the desired window size. Any variance in the arrival of samples will result
// in a BPS measurement that is correct for the submitted samples, but over a
// varying time window.
//
// maxSamples should be equal to 1 + (window size / sampling interval) where
// window size is the number of seconds over which the moving average is
// smoothed and sampling interval is the number of seconds between each sample.
//
// For example, if you want a five second window, sampling once per second,
// maxSamples should be 1 + 5/1 = 6.
func NewSMA(maxSamples int) Gauge {
	if maxSamples < 2 {
		panic("sample count must be greater than 1")
	}
	return &sma{
		maxSamples: uint64(maxSamples),
		samples:    make([]int64, maxSamples),
		timestamps: make([]time.Time, maxSamples),
	}
}

type sma struct {
	mu          sync.Mutex
	index       uint64
	maxSamples  uint64
	sampleCount uint64
	samples     []int64
	timestamps  []time.Time
}

func (c *sma) Sample(t time.Time, n int64) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.timestamps[c.index] = t
	c.samples[c.index] = n
	c.index = (c.index + 1) % c.maxSamples

	// prevent integer overflow in sampleCount. Values greater or equal to
	// maxSamples have the same semantic meaning.
	c.sampleCount++
	if c.sampleCount > c.maxSamples {
		c.sampleCount = c.maxSamples
	}
}

func (c *sma) BPS() float64 {
	c.mu.Lock()
	defer c.mu.Unlock()

	// we need two samples to start
	if c.sampleCount < 2 {
		return 0
	}

	// First sample is always the oldest until ring buffer first overflows
	oldest := c.index
	if c.sampleCount < c.maxSamples {
		oldest = 0
	}

	newest := (c.index + c.maxSamples - 1) % c.maxSamples
	seconds := c.timestamps[newest].Sub(c.timestamps[oldest]).Seconds()
	bytes := float64(c.samples[newest] - c.samples[oldest])
	return bytes / seconds
}
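The maxSamples sizing rule in the NewSMA comment can be captured as a tiny helper (illustrative only, not part of this commit); it simply evaluates 1 + window/interval.

```go
package main

import (
	"fmt"
	"time"
)

// smaSampleCount returns the maxSamples argument for bps.NewSMA, following the
// rule above: 1 + (window size / sampling interval).
func smaSampleCount(window, interval time.Duration) int {
	return 1 + int(window/interval)
}

func main() {
	// A five second window sampled once per second needs 1 + 5/1 = 6 samples.
	fmt.Println(smaSampleCount(5*time.Second, time.Second)) // 6
}
```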
grab/bps/sma_test.go  Normal file  (+55)
@@ -0,0 +1,55 @@
package bps

import (
	"testing"
	"time"
)

type Sample struct {
	N      int64
	Expect float64
}

func getSimpleSamples(sampleCount, rate int) []Sample {
	a := make([]Sample, sampleCount)
	for i := 1; i < sampleCount; i++ {
		a[i] = Sample{N: int64(i * rate), Expect: float64(rate)}
	}
	return a
}

type SampleSetTest struct {
	Gauge    Gauge
	Interval time.Duration
	Samples  []Sample
}

func (c *SampleSetTest) Run(t *testing.T) {
	ts := time.Unix(0, 0)
	for i, sample := range c.Samples {
		c.Gauge.Sample(ts, sample.N)
		if actual := c.Gauge.BPS(); actual != sample.Expect {
			t.Errorf("expected: Gauge.BPS() → %0.2f, got %0.2f in test %d", sample.Expect, actual, i+1)
		}
		ts = ts.Add(c.Interval)
	}
}

func TestSMA_SimpleSteadyCase(t *testing.T) {
	test := &SampleSetTest{
		Interval: time.Second,
		Samples:  getSimpleSamples(100000, 3),
	}
	t.Run("SmallSampleSize", func(t *testing.T) {
		test.Gauge = NewSMA(2)
		test.Run(t)
	})
	t.Run("RegularSize", func(t *testing.T) {
		test.Gauge = NewSMA(6)
		test.Run(t)
	})
	t.Run("LargeSampleSize", func(t *testing.T) {
		test.Gauge = NewSMA(1000)
		test.Run(t)
	})
}
grab/client.go  Normal file  (+570)
@@ -0,0 +1,570 @@
package grab

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"sync"
	"sync/atomic"
	"time"
)

// HTTPClient provides an interface allowing us to perform HTTP requests.
type HTTPClient interface {
	Do(req *http.Request) (*http.Response, error)
}

// truncater is a private interface allowing different response
// Writers to be truncated
type truncater interface {
	Truncate(size int64) error
}

// A Client is a file download client.
//
// Clients are safe for concurrent use by multiple goroutines.
type Client struct {
	// HTTPClient specifies the http.Client which will be used for communicating
	// with the remote server during the file transfer.
	HTTPClient HTTPClient

	// UserAgent specifies the User-Agent string which will be set in the
	// headers of all requests made by this client.
	//
	// The user agent string may be overridden in the headers of each request.
	UserAgent string

	// BufferSize specifies the size in bytes of the buffer that is used for
	// transferring all requested files. Larger buffers may result in faster
	// throughput but will use more memory and result in less frequent updates
	// to the transfer progress statistics. The BufferSize of each request can
	// be overridden on each Request object. Default: 32KB.
	BufferSize int
}

// NewClient returns a new file download Client, using default configuration.
func NewClient() *Client {
	return &Client{
		UserAgent: "grab",
		HTTPClient: &http.Client{
			Transport: &http.Transport{
				Proxy: http.ProxyFromEnvironment,
			},
		},
	}
}

// DefaultClient is the default client and is used by all Get convenience
// functions.
var DefaultClient = NewClient()

// Do sends a file transfer request and returns a file transfer response,
// following policy (e.g. redirects, cookies, auth) as configured on the
// client's HTTPClient.
//
// Like http.Get, Do blocks while the transfer is initiated, but returns as soon
// as the transfer has started transferring in a background goroutine, or if it
// failed early.
//
// An error is returned via Response.Err if caused by client policy (such as
// CheckRedirect), or if there was an HTTP protocol or IO error. Response.Err
// will block the caller until the transfer is completed, successfully or
// otherwise.
func (c *Client) Do(req *Request) *Response {
	// cancel will be called on all code-paths via closeResponse
	ctx, cancel := context.WithCancel(req.Context())
	req = req.WithContext(ctx)
	resp := &Response{
		Request:    req,
		Start:      time.Now(),
		Done:       make(chan struct{}, 0),
		Filename:   req.Filename,
		ctx:        ctx,
		cancel:     cancel,
		bufferSize: req.BufferSize,
	}
	if resp.bufferSize == 0 {
		// default to Client.BufferSize
		resp.bufferSize = c.BufferSize
	}

	// Run state-machine while caller is blocked to initialize the file transfer.
	// Must never transition to the copyFile state - this happens next in another
	// goroutine.
	c.run(resp, c.statFileInfo)

	// Run copyFile in a new goroutine. copyFile will no-op if the transfer is
	// already complete or failed.
	go c.run(resp, c.copyFile)
	return resp
}

// DoChannel executes all requests sent through the given Request channel, one
// at a time, until it is closed by another goroutine. The caller is blocked
// until the Request channel is closed and all transfers have completed. All
// responses are sent through the given Response channel as soon as they are
// received from the remote servers and can be used to track the progress of
// each download.
//
// Slow Response receivers will cause a worker to block and therefore delay the
// start of the transfer for an already initiated connection - potentially
// causing a server timeout. It is the caller's responsibility to ensure a
// sufficient buffer size is used for the Response channel to prevent this.
//
// If an error occurs during any of the file transfers it will be accessible via
// the associated Response.Err function.
func (c *Client) DoChannel(reqch <-chan *Request, respch chan<- *Response) {
	// TODO: enable cancelling of batch jobs
	for req := range reqch {
		resp := c.Do(req)
		respch <- resp
		<-resp.Done
	}
}

// DoBatch executes all the given requests using the given number of concurrent
// workers. Control is passed back to the caller as soon as the workers are
// initiated.
//
// If the requested number of workers is less than one, a worker will be created
// for every request. I.e. all requests will be executed concurrently.
//
// If an error occurs during any of the file transfers it will be accessible via
// call to the associated Response.Err.
//
// The returned Response channel is closed only after all of the given Requests
// have completed, successfully or otherwise.
func (c *Client) DoBatch(workers int, requests ...*Request) <-chan *Response {
	if workers < 1 {
		workers = len(requests)
	}
	reqch := make(chan *Request, len(requests))
	respch := make(chan *Response, len(requests))
	wg := sync.WaitGroup{}
	for i := 0; i < workers; i++ {
		wg.Add(1)
		go func() {
			c.DoChannel(reqch, respch)
			wg.Done()
		}()
	}

	// queue requests
	go func() {
		for _, req := range requests {
			reqch <- req
		}
		close(reqch)
		wg.Wait()
		close(respch)
	}()
	return respch
}

// A stateFunc is an action that mutates the state of a Response and returns
// the next stateFunc to be called.
type stateFunc func(*Response) stateFunc

// run calls the given stateFunc function and all subsequent returned stateFuncs
// until a stateFunc returns nil or the Response.ctx is canceled. Each stateFunc
// should mutate the state of the given Response until it has completed
// downloading or failed.
func (c *Client) run(resp *Response, f stateFunc) {
	for {
		select {
		case <-resp.ctx.Done():
			if resp.IsComplete() {
				return
			}
			resp.err = resp.ctx.Err()
			f = c.closeResponse

		default:
			// keep working
		}
		if f = f(resp); f == nil {
			return
		}
	}
}

// statFileInfo retrieves FileInfo for any local file matching
// Response.Filename.
//
// If the file does not exist, is a directory, or its name is unknown the next
// stateFunc is headRequest.
//
// If the file exists, Response.fi is set and the next stateFunc is
// validateLocal.
//
// If an error occurs, the next stateFunc is closeResponse.
func (c *Client) statFileInfo(resp *Response) stateFunc {
	if resp.Request.NoStore || resp.Filename == "" { // No filename provided will guess
		return c.headRequest
	}
	fi, err := os.Stat(resp.Filename)
	if err != nil {
		if os.IsNotExist(err) { // Filename does not exist, will download
			return c.headRequest
		}
		resp.err = err
		return c.closeResponse // Other PathError occurred
	}
	if fi.IsDir() { // resp.Request.Filename is a directory
		// Will guess filename and append to resp.Request.Filename
		resp.Filename = ""
		return c.headRequest
	}
	resp.fi = fi
	return c.validateLocal // Filename exists, validate file
}

// validateLocal compares a local copy of the downloaded file to the remote
// file.
//
// An error is returned if the local file is larger than the remote file, or
// Request.SkipExisting is true.
//
// If the existing file matches the length of the remote file, the next
// stateFunc is checksumFile.
//
// If the local file is smaller than the remote file and the remote server is
// known to support ranged requests, the next stateFunc is getRequest.
func (c *Client) validateLocal(resp *Response) stateFunc {
	if resp.Request.SkipExisting {
		resp.err = ErrFileExists
		return c.closeResponse
	}

	// determine target file size
	expectedSize := resp.Request.Size
	if expectedSize == 0 && resp.HTTPResponse != nil {
		expectedSize = resp.HTTPResponse.ContentLength
	}

	if expectedSize == 0 {
		// size is either actually 0 or unknown
		// if unknown, we ask the remote server
		// if known to be 0, we proceed with a GET
		return c.headRequest
	}

	if expectedSize == resp.fi.Size() {
		// local file matches remote file size - wrap it up
		resp.DidResume = true
		resp.bytesResumed = resp.fi.Size()
		return c.checksumFile
	}

	if resp.Request.NoResume {
		// local file should be overwritten
		return c.getRequest
	}

	if expectedSize >= 0 && expectedSize < resp.fi.Size() {
		// remote size is known, is smaller than local size and we want to resume
		resp.err = ErrBadLength
		return c.closeResponse
	}

	if resp.CanResume {
		// set resume range on GET request
		resp.Request.HTTPRequest.Header.Set(
			"Range",
			fmt.Sprintf("bytes=%d-", resp.fi.Size()))
		resp.DidResume = true
		resp.bytesResumed = resp.fi.Size()
		return c.getRequest
	}
	return c.headRequest
}

func (c *Client) checksumFile(resp *Response) stateFunc {
	if resp.Request.hash == nil {
		return c.closeResponse
	}
	if resp.Filename == "" {
		panic("grab: developer error: filename not set")
	}
	if resp.Size() < 0 {
		panic("grab: developer error: size unknown")
	}
	req := resp.Request

	// compute checksum
	var sum []byte
	sum, resp.err = resp.checksumUnsafe()
	if resp.err != nil {
		return c.closeResponse
	}

	// compare checksum
	if !bytes.Equal(sum, req.checksum) {
		resp.err = ErrBadChecksum
		if !resp.Request.NoStore && req.deleteOnError {
			if err := os.Remove(resp.Filename); err != nil {
				// err should be os.PathError and include file path
				resp.err = fmt.Errorf(
					"cannot remove downloaded file with checksum mismatch: %v",
					err)
			}
		}
	}
	return c.closeResponse
}

// doHTTPRequest sends an HTTP request and returns the response
func (c *Client) doHTTPRequest(req *http.Request) (*http.Response, error) {
	if c.UserAgent != "" && req.Header.Get("User-Agent") == "" {
		req.Header.Set("User-Agent", c.UserAgent)
	}
	return c.HTTPClient.Do(req)
}

func (c *Client) headRequest(resp *Response) stateFunc {
	if resp.optionsKnown {
		return c.getRequest
	}
	resp.optionsKnown = true

	if resp.Request.NoResume {
		return c.getRequest
	}

	if resp.Filename != "" && resp.fi == nil {
		// destination path is already known and does not exist
		return c.getRequest
	}

	hreq := new(http.Request)
	*hreq = *resp.Request.HTTPRequest
	hreq.Method = "HEAD"

	resp.HTTPResponse, resp.err = c.doHTTPRequest(hreq)
	if resp.err != nil {
		return c.closeResponse
	}
	resp.HTTPResponse.Body.Close()

	if resp.HTTPResponse.StatusCode != http.StatusOK {
		return c.getRequest
	}

	// In case of redirects during HEAD, record the final URL and use it
	// instead of the original URL when sending future requests.
	// This way we avoid sending potentially unsupported requests to
	// the original URL, e.g. "Range", since it was the final URL
	// that advertised its support.
	resp.Request.HTTPRequest.URL = resp.HTTPResponse.Request.URL
	resp.Request.HTTPRequest.Host = resp.HTTPResponse.Request.Host

	return c.readResponse
}

func (c *Client) getRequest(resp *Response) stateFunc {
	resp.HTTPResponse, resp.err = c.doHTTPRequest(resp.Request.HTTPRequest)
	if resp.err != nil {
		return c.closeResponse
	}

	// TODO: check Content-Range

	// check status code
	if !resp.Request.IgnoreBadStatusCodes {
		if resp.HTTPResponse.StatusCode < 200 || resp.HTTPResponse.StatusCode > 299 {
			resp.err = StatusCodeError(resp.HTTPResponse.StatusCode)
			return c.closeResponse
		}
	}

	return c.readResponse
}

func (c *Client) readResponse(resp *Response) stateFunc {
	if resp.HTTPResponse == nil {
		panic("grab: developer error: Response.HTTPResponse is nil")
	}

	// check expected size
	resp.sizeUnsafe = resp.HTTPResponse.ContentLength
	if resp.sizeUnsafe >= 0 {
		// remote size is known
		resp.sizeUnsafe += resp.bytesResumed
		if resp.Request.Size > 0 && resp.Request.Size != resp.sizeUnsafe {
			resp.err = ErrBadLength
			return c.closeResponse
		}
	}

	// check filename
	if resp.Filename == "" {
		filename, err := guessFilename(resp.HTTPResponse)
		if err != nil {
			resp.err = err
			return c.closeResponse
		}
		// Request.Filename will be empty or a directory
		resp.Filename = filepath.Join(resp.Request.Filename, filename)
	}

	if !resp.Request.NoStore && resp.requestMethod() == "HEAD" {
		if resp.HTTPResponse.Header.Get("Accept-Ranges") == "bytes" {
			resp.CanResume = true
		}
		return c.statFileInfo
	}
	return c.openWriter
}

// openWriter opens the destination file for writing and seeks to the location
// from whence the file transfer will resume.
//
// Requires that Response.Filename and resp.DidResume are already set.
func (c *Client) openWriter(resp *Response) stateFunc {
	if !resp.Request.NoStore && !resp.Request.NoCreateDirectories {
		resp.err = mkdirp(resp.Filename)
		if resp.err != nil {
			return c.closeResponse
		}
	}

	if resp.Request.NoStore {
		resp.writer = &resp.storeBuffer
	} else {
		// compute write flags
		flag := os.O_CREATE | os.O_WRONLY
		if resp.fi != nil {
			if resp.DidResume {
				flag = os.O_APPEND | os.O_WRONLY
			} else {
				// truncate later in copyFile, if not cancelled
				// by BeforeCopy hook
				flag = os.O_WRONLY
			}
		}

		// open file
		f, err := os.OpenFile(resp.Filename, flag, 0666)
		if err != nil {
			resp.err = err
			return c.closeResponse
		}
		resp.writer = f

		// seek to start or end
		whence := os.SEEK_SET
		if resp.bytesResumed > 0 {
			whence = os.SEEK_END
		}
		_, resp.err = f.Seek(0, whence)
		if resp.err != nil {
			return c.closeResponse
		}
	}

	// init transfer
	if resp.bufferSize < 1 {
		resp.bufferSize = 32 * 1024
	}
	b := make([]byte, resp.bufferSize)
	resp.transfer = newTransfer(
		resp.Request.Context(),
		resp.Request.RateLimiter,
		resp.writer,
		resp.HTTPResponse.Body,
		b)

	// next step is copyFile, but this will be called later in another goroutine
	return nil
}

// copyFile transfers content for an HTTP connection established via Client.do()
func (c *Client) copyFile(resp *Response) stateFunc {
	if resp.IsComplete() {
		return nil
	}

	// run BeforeCopy hook
	if f := resp.Request.BeforeCopy; f != nil {
		resp.err = f(resp)
		if resp.err != nil {
			return c.closeResponse
		}
	}

	var bytesCopied int64
	if resp.transfer == nil {
		panic("grab: developer error: Response.transfer is nil")
	}

	// We waited to truncate the file in openWriter() to make sure
	// the BeforeCopy didn't cancel the copy. If this was an existing
	// file that is not going to be resumed, truncate the contents.
	if t, ok := resp.writer.(truncater); ok && resp.fi != nil && !resp.DidResume {
		t.Truncate(0)
	}

	bytesCopied, resp.err = resp.transfer.copy()
	if resp.err != nil {
		return c.closeResponse
	}
	closeWriter(resp)

	// set file timestamp
	if !resp.Request.NoStore && !resp.Request.IgnoreRemoteTime {
		resp.err = setLastModified(resp.HTTPResponse, resp.Filename)
		if resp.err != nil {
			return c.closeResponse
		}
	}

	// update transfer size if previously unknown
	if resp.Size() < 0 {
		discoveredSize := resp.bytesResumed + bytesCopied
		atomic.StoreInt64(&resp.sizeUnsafe, discoveredSize)
		if resp.Request.Size > 0 && resp.Request.Size != discoveredSize {
			resp.err = ErrBadLength
			return c.closeResponse
		}
	}

	// run AfterCopy hook
	if f := resp.Request.AfterCopy; f != nil {
		resp.err = f(resp)
		if resp.err != nil {
			return c.closeResponse
		}
	}

	return c.checksumFile
}

func closeWriter(resp *Response) {
	if closer, ok := resp.writer.(io.Closer); ok {
		closer.Close()
	}
	resp.writer = nil
}

// closeResponse finalizes the Response
func (c *Client) closeResponse(resp *Response) stateFunc {
	if resp.IsComplete() {
		panic("grab: developer error: response already closed")
	}

	resp.fi = nil
	closeWriter(resp)
	resp.closeResponseBody()

	resp.End = time.Now()
	close(resp.Done)
	if resp.cancel != nil {
		resp.cancel()
	}

	return nil
}
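To make the DoChannel/DoBatch worker model above concrete, here is a minimal batch-download sketch (not part of this commit); the URL list and destination directory are hypothetical.

```go
package main

import (
	"fmt"
	"os"

	"github.com/cavaliercoder/grab"
)

func main() {
	// Hypothetical list of remote files to mirror into ./downloads.
	urls := []string{
		"http://example.com/a.iso",
		"http://example.com/b.iso",
		"http://example.com/c.iso",
	}

	reqs := make([]*grab.Request, 0, len(urls))
	for _, u := range urls {
		req, err := grab.NewRequest("./downloads", u)
		if err != nil {
			fmt.Fprintln(os.Stderr, err)
			continue
		}
		reqs = append(reqs, req)
	}

	// Two workers; the returned channel is closed once every request has
	// completed, successfully or otherwise.
	respch := grab.DefaultClient.DoBatch(2, reqs...)

	failed := 0
	for resp := range respch {
		if err := resp.Err(); err != nil {
			failed++
			fmt.Fprintf(os.Stderr, "%s failed: %v\n", resp.Request.URL(), err)
			continue
		}
		fmt.Println("saved", resp.Filename)
	}
	fmt.Printf("done, %d failed\n", failed)
}
```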
grab/client_test.go  Normal file  (+915)
@@ -0,0 +1,915 @@
package grab
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/md5"
|
||||||
|
"crypto/sha1"
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/sha512"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"hash"
|
||||||
|
"io/ioutil"
|
||||||
|
"math/rand"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cavaliercoder/grab/grabtest"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestFilenameResolutions tests that the destination filename for Requests can
|
||||||
|
// be determined correctly, using an explicitly requested path,
|
||||||
|
// Content-Disposition headers or a URL path - with or without an existing
|
||||||
|
// target directory.
|
||||||
|
func TestFilenameResolution(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
Name string
|
||||||
|
Filename string
|
||||||
|
URL string
|
||||||
|
AttachmentFilename string
|
||||||
|
Expect string
|
||||||
|
}{
|
||||||
|
{"Using Request.Filename", ".testWithFilename", "/url-filename", "header-filename", ".testWithFilename"},
|
||||||
|
{"Using Content-Disposition Header", "", "/url-filename", ".testWithHeaderFilename", ".testWithHeaderFilename"},
|
||||||
|
{"Using Content-Disposition Header with target directory", ".test", "/url-filename", "header-filename", ".test/header-filename"},
|
||||||
|
{"Using URL Path", "", "/.testWithURLFilename?params-filename", "", ".testWithURLFilename"},
|
||||||
|
{"Using URL Path with target directory", ".test", "/url-filename?garbage", "", ".test/url-filename"},
|
||||||
|
{"Failure", "", "", "", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := os.Mkdir(".test", 0777)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(".test")
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.Name, func(t *testing.T) {
|
||||||
|
opts := []grabtest.HandlerOption{}
|
||||||
|
if test.AttachmentFilename != "" {
|
||||||
|
opts = append(opts, grabtest.AttachmentFilename(test.AttachmentFilename))
|
||||||
|
}
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
req := mustNewRequest(test.Filename, url+test.URL)
|
||||||
|
resp := DefaultClient.Do(req)
|
||||||
|
defer os.Remove(resp.Filename)
|
||||||
|
if err := resp.Err(); err != nil {
|
||||||
|
if test.Expect != "" || err != ErrNoFilename {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if test.Expect == "" {
|
||||||
|
t.Errorf("expected: %v, got: %v", ErrNoFilename, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if resp.Filename != test.Expect {
|
||||||
|
t.Errorf("Filename mismatch. Expected '%s', got '%s'.", test.Expect, resp.Filename)
|
||||||
|
}
|
||||||
|
testComplete(t, resp)
|
||||||
|
}, opts...)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestChecksums checks that checksum validation behaves as expected for valid
|
||||||
|
// and corrupted downloads.
|
||||||
|
func TestChecksums(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
size int
|
||||||
|
hash hash.Hash
|
||||||
|
sum string
|
||||||
|
match bool
|
||||||
|
}{
|
||||||
|
{128, md5.New(), "37eff01866ba3f538421b30b7cbefcac", true},
|
||||||
|
{128, md5.New(), "37eff01866ba3f538421b30b7cbefcad", false},
|
||||||
|
{1024, md5.New(), "b2ea9f7fcea831a4a63b213f41a8855b", true},
|
||||||
|
{1024, md5.New(), "b2ea9f7fcea831a4a63b213f41a8855c", false},
|
||||||
|
{1048576, md5.New(), "c35cc7d8d91728a0cb052831bc4ef372", true},
|
||||||
|
{1048576, md5.New(), "c35cc7d8d91728a0cb052831bc4ef373", false},
|
||||||
|
{128, sha1.New(), "e6434bc401f98603d7eda504790c98c67385d535", true},
|
||||||
|
{128, sha1.New(), "e6434bc401f98603d7eda504790c98c67385d536", false},
|
||||||
|
{1024, sha1.New(), "5b00669c480d5cffbdfa8bdba99561160f2d1b77", true},
|
||||||
|
{1024, sha1.New(), "5b00669c480d5cffbdfa8bdba99561160f2d1b78", false},
|
||||||
|
{1048576, sha1.New(), "ecfc8e86fdd83811f9cc9bf500993b63069923be", true},
|
||||||
|
{1048576, sha1.New(), "ecfc8e86fdd83811f9cc9bf500993b63069923bf", false},
|
||||||
|
{128, sha256.New(), "471fb943aa23c511f6f72f8d1652d9c880cfa392ad80503120547703e56a2be5", true},
|
||||||
|
{128, sha256.New(), "471fb943aa23c511f6f72f8d1652d9c880cfa392ad80503120547703e56a2be4", false},
|
||||||
|
{1024, sha256.New(), "785b0751fc2c53dc14a4ce3d800e69ef9ce1009eb327ccf458afe09c242c26c9", true},
|
||||||
|
{1024, sha256.New(), "785b0751fc2c53dc14a4ce3d800e69ef9ce1009eb327ccf458afe09c242c26c8", false},
|
||||||
|
{1048576, sha256.New(), "fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83", true},
|
||||||
|
{1048576, sha256.New(), "fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c82", false},
|
||||||
|
{128, sha512.New(), "1dffd5e3adb71d45d2245939665521ae001a317a03720a45732ba1900ca3b8351fc5c9b4ca513eba6f80bc7b1d1fdad4abd13491cb824d61b08d8c0e1561b3f7", true},
|
||||||
|
{128, sha512.New(), "1dffd5e3adb71d45d2245939665521ae001a317a03720a45732ba1900ca3b8351fc5c9b4ca513eba6f80bc7b1d1fdad4abd13491cb824d61b08d8c0e1561b3f8", false},
|
||||||
|
{1024, sha512.New(), "37f652be867f28ed033269cbba201af2112c2b3fd334a89fd2f757938ddee815787cc61d6e24a8a33340d0f7e86ffc058816b88530766ba6e231620a130b566c", true},
|
||||||
|
{1024, sha512.New(), "37f652bf867f28ed033269cbba201af2112c2b3fd334a89fd2f757938ddee815787cc61d6e24a8a33340d0f7e86ffc058816b88530766ba6e231620a130b566d", false},
|
||||||
|
{1048576, sha512.New(), "ac1d097b4ea6f6ad7ba640275b9ac290e4828cd760a0ebf76d555463a4f505f95df4f611629539a2dd1848e7c1304633baa1826462b3c87521c0c6e3469b67af", true},
|
||||||
|
{1048576, sha512.New(), "ac1d097c4ea6f6ad7ba640275b9ac290e4828cd760a0ebf76d555463a4f505f95df4f611629539a2dd1848e7c1304633baa1826462b3c87521c0c6e3469b67af", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
var expect error
|
||||||
|
comparison := "Match"
|
||||||
|
if !test.match {
|
||||||
|
comparison = "Mismatch"
|
||||||
|
expect = ErrBadChecksum
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run(fmt.Sprintf("With%s%s", comparison, test.sum[:8]), func(t *testing.T) {
|
||||||
|
filename := fmt.Sprintf(".testChecksum-%s-%s", comparison, test.sum[:8])
|
||||||
|
defer os.Remove(filename)
|
||||||
|
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
req := mustNewRequest(filename, url)
|
||||||
|
req.SetChecksum(test.hash, grabtest.MustHexDecodeString(test.sum), true)
|
||||||
|
|
||||||
|
resp := DefaultClient.Do(req)
|
||||||
|
err := resp.Err()
|
||||||
|
if err != expect {
|
||||||
|
t.Errorf("expected error: %v, got: %v", expect, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensure mismatch file was deleted
|
||||||
|
if !test.match {
|
||||||
|
if _, err := os.Stat(filename); err == nil {
|
||||||
|
t.Errorf("checksum failure not cleaned up: %s", filename)
|
||||||
|
} else if !os.IsNotExist(err) {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
testComplete(t, resp)
|
||||||
|
}, grabtest.ContentLength(test.size))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestContentLength ensures that ErrBadLength is returned if a server response
|
||||||
|
// does not match the requested length.
|
||||||
|
func TestContentLength(t *testing.T) {
|
||||||
|
size := int64(32768)
|
||||||
|
testCases := []struct {
|
||||||
|
Name string
|
||||||
|
NoHead bool
|
||||||
|
Size int64
|
||||||
|
Expect int64
|
||||||
|
Match bool
|
||||||
|
}{
|
||||||
|
{"Good size in HEAD request", false, size, size, true},
|
||||||
|
{"Good size in GET request", true, size, size, true},
|
||||||
|
{"Bad size in HEAD request", false, size - 1, size, false},
|
||||||
|
{"Bad size in GET request", true, size - 1, size, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range testCases {
|
||||||
|
t.Run(test.Name, func(t *testing.T) {
|
||||||
|
opts := []grabtest.HandlerOption{
|
||||||
|
grabtest.ContentLength(int(test.Size)),
|
||||||
|
}
|
||||||
|
if test.NoHead {
|
||||||
|
opts = append(opts, grabtest.MethodWhitelist("GET"))
|
||||||
|
}
|
||||||
|
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
req := mustNewRequest(".testSize-mismatch-head", url)
|
||||||
|
req.Size = size
|
||||||
|
resp := DefaultClient.Do(req)
|
||||||
|
defer os.Remove(resp.Filename)
|
||||||
|
err := resp.Err()
|
||||||
|
if test.Match {
|
||||||
|
if err == ErrBadLength {
|
||||||
|
t.Errorf("error: %v", err)
|
||||||
|
} else if err != nil {
|
||||||
|
panic(err)
|
||||||
|
} else if resp.Size() != size {
|
||||||
|
t.Errorf("expected %v bytes, got %v bytes", size, resp.Size())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected: %v, got %v", ErrBadLength, err)
|
||||||
|
} else if err != ErrBadLength {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
testComplete(t, resp)
|
||||||
|
}, opts...)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAutoResume tests segmented downloading of a large file.
|
||||||
|
func TestAutoResume(t *testing.T) {
|
||||||
|
segs := 8
|
||||||
|
size := 1048576
|
||||||
|
sum := grabtest.DefaultHandlerSHA256ChecksumBytes //grabtest.MustHexDecodeString("fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83")
|
||||||
|
filename := ".testAutoResume"
|
||||||
|
|
||||||
|
defer os.Remove(filename)
|
||||||
|
|
||||||
|
for i := 0; i < segs; i++ {
|
||||||
|
segsize := (i + 1) * (size / segs)
|
||||||
|
t.Run(fmt.Sprintf("With%vBytes", segsize), func(t *testing.T) {
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
req := mustNewRequest(filename, url)
|
||||||
|
if i == segs-1 {
|
||||||
|
req.SetChecksum(sha256.New(), sum, false)
|
||||||
|
}
|
||||||
|
resp := mustDo(req)
|
||||||
|
if i > 0 && !resp.DidResume {
|
||||||
|
t.Errorf("expected Response.DidResume to be true")
|
||||||
|
}
|
||||||
|
testComplete(t, resp)
|
||||||
|
},
|
||||||
|
grabtest.ContentLength(segsize),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("WithFailure", func(t *testing.T) {
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
// request smaller segment
|
||||||
|
req := mustNewRequest(filename, url)
|
||||||
|
resp := DefaultClient.Do(req)
|
||||||
|
if err := resp.Err(); err != ErrBadLength {
|
||||||
|
t.Errorf("expected ErrBadLength for smaller request, got: %v", err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
grabtest.ContentLength(size-128),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("WithNoResume", func(t *testing.T) {
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
req := mustNewRequest(filename, url)
|
||||||
|
req.NoResume = true
|
||||||
|
resp := mustDo(req)
|
||||||
|
if resp.DidResume {
|
||||||
|
t.Errorf("expected Response.DidResume to be false")
|
||||||
|
}
|
||||||
|
testComplete(t, resp)
|
||||||
|
},
|
||||||
|
grabtest.ContentLength(size+128),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("WithNoResumeAndTruncate", func(t *testing.T) {
|
||||||
|
size := size - 128
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
req := mustNewRequest(filename, url)
|
||||||
|
req.NoResume = true
|
||||||
|
resp := mustDo(req)
|
||||||
|
if resp.DidResume {
|
||||||
|
t.Errorf("expected Response.DidResume to be false")
|
||||||
|
}
|
||||||
|
if v := resp.BytesComplete(); v != int64(size) {
|
||||||
|
t.Errorf("expected Response.BytesComplete: %d, got: %d", size, v)
|
||||||
|
}
|
||||||
|
testComplete(t, resp)
|
||||||
|
},
|
||||||
|
grabtest.ContentLength(size),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("WithNoContentLengthHeader", func(t *testing.T) {
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
req := mustNewRequest(filename, url)
|
||||||
|
req.SetChecksum(sha256.New(), sum, false)
|
||||||
|
resp := mustDo(req)
|
||||||
|
if !resp.DidResume {
|
||||||
|
t.Errorf("expected Response.DidResume to be true")
|
||||||
|
}
|
||||||
|
if actual := resp.Size(); actual != int64(size) {
|
||||||
|
t.Errorf("expected Response.Size: %d, got: %d", size, actual)
|
||||||
|
}
|
||||||
|
testComplete(t, resp)
|
||||||
|
},
|
||||||
|
grabtest.ContentLength(size),
|
||||||
|
grabtest.HeaderBlacklist("Content-Length"),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("WithNoContentLengthHeaderAndChecksumFailure", func(t *testing.T) {
|
||||||
|
// ref: https://github.com/cavaliercoder/grab/pull/27
|
||||||
|
size := size * 2
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
req := mustNewRequest(filename, url)
|
||||||
|
req.SetChecksum(sha256.New(), sum, false)
|
||||||
|
resp := DefaultClient.Do(req)
|
||||||
|
if err := resp.Err(); err != ErrBadChecksum {
|
||||||
|
t.Errorf("expected error: %v, got: %v", ErrBadChecksum, err)
|
||||||
|
}
|
||||||
|
if !resp.DidResume {
|
||||||
|
t.Errorf("expected Response.DidResume to be true")
|
||||||
|
}
|
||||||
|
if actual := resp.BytesComplete(); actual != int64(size) {
|
||||||
|
t.Errorf("expected Response.BytesComplete: %d, got: %d", size, actual)
|
||||||
|
}
|
||||||
|
if actual := resp.Size(); actual != int64(size) {
|
||||||
|
t.Errorf("expected Response.Size: %d, got: %d", size, actual)
|
||||||
|
}
|
||||||
|
testComplete(t, resp)
|
||||||
|
},
|
||||||
|
grabtest.ContentLength(size),
|
||||||
|
grabtest.HeaderBlacklist("Content-Length"),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
// TODO: test when existing file is corrupted
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSkipExisting(t *testing.T) {
|
||||||
|
filename := ".testSkipExisting"
|
||||||
|
defer os.Remove(filename)
|
||||||
|
|
||||||
|
// download a file
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
resp := mustDo(mustNewRequest(filename, url))
|
||||||
|
testComplete(t, resp)
|
||||||
|
})
|
||||||
|
|
||||||
|
// redownload
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
resp := mustDo(mustNewRequest(filename, url))
|
||||||
|
testComplete(t, resp)
|
||||||
|
|
||||||
|
// ensure download was resumed
|
||||||
|
if !resp.DidResume {
|
||||||
|
t.Fatalf("Expected download to skip existing file, but it did not")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensure all bytes were resumed
|
||||||
|
if resp.Size() == 0 || resp.Size() != resp.bytesResumed {
|
||||||
|
t.Fatalf("Expected to skip %d bytes in redownload; got %d", resp.Size(), resp.bytesResumed)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// ensure checksum is performed on pre-existing file
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
req := mustNewRequest(filename, url)
|
||||||
|
req.SetChecksum(sha256.New(), []byte{0x01, 0x02, 0x03, 0x04}, true)
|
||||||
|
resp := DefaultClient.Do(req)
|
||||||
|
if err := resp.Err(); err != ErrBadChecksum {
|
||||||
|
t.Fatalf("Expected checksum error, got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBatch executes multiple requests simultaneously and validates the
|
||||||
|
// responses.
|
||||||
|
func TestBatch(t *testing.T) {
|
||||||
|
tests := 32
|
||||||
|
size := 32768
|
||||||
|
sum := grabtest.MustHexDecodeString("e11360251d1173650cdcd20f111d8f1ca2e412f572e8b36a4dc067121c1799b8")
|
||||||
|
|
||||||
|
// test with 4 workers and with one per request
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
for _, workerCount := range []int{4, 0} {
|
||||||
|
// create requests
|
||||||
|
reqs := make([]*Request, tests)
|
||||||
|
for i := 0; i < len(reqs); i++ {
|
||||||
|
filename := fmt.Sprintf(".testBatch.%d", i+1)
|
||||||
|
reqs[i] = mustNewRequest(filename, url+fmt.Sprintf("/request_%d?", i+1))
|
||||||
|
reqs[i].Label = fmt.Sprintf("Test %d", i+1)
|
||||||
|
reqs[i].SetChecksum(sha256.New(), sum, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// batch run
|
||||||
|
responses := DefaultClient.DoBatch(workerCount, reqs...)
|
||||||
|
|
||||||
|
// listen for responses
|
||||||
|
Loop:
|
||||||
|
for i := 0; i < len(reqs); {
|
||||||
|
select {
|
||||||
|
case resp := <-responses:
|
||||||
|
if resp == nil {
|
||||||
|
break Loop
|
||||||
|
}
|
||||||
|
testComplete(t, resp)
|
||||||
|
if err := resp.Err(); err != nil {
|
||||||
|
t.Errorf("%s: %v", resp.Filename, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove test file
|
||||||
|
if resp.IsComplete() {
|
||||||
|
os.Remove(resp.Filename) // ignore errors
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
grabtest.ContentLength(size),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCancelContext tests that a batch of requests can be cancel using a
|
||||||
|
// context.Context cancellation. Requests are cancelled in multiple states:
|
||||||
|
// in-progress and unstarted.
|
||||||
|
func TestCancelContext(t *testing.T) {
|
||||||
|
fileSize := 134217728
|
||||||
|
tests := 256
|
||||||
|
client := NewClient()
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
reqs := make([]*Request, tests)
|
||||||
|
for i := 0; i < tests; i++ {
|
||||||
|
req := mustNewRequest("", fmt.Sprintf("%s/.testCancelContext%d", url, i))
|
||||||
|
reqs[i] = req.WithContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
respch := client.DoBatch(8, reqs...)
|
||||||
|
time.Sleep(time.Millisecond * 500)
|
||||||
|
cancel()
|
||||||
|
for resp := range respch {
|
||||||
|
defer os.Remove(resp.Filename)
|
||||||
|
|
||||||
|
// err should be context.Canceled or http.errRequestCanceled
|
||||||
|
if resp.Err() == nil || !strings.Contains(resp.Err().Error(), "canceled") {
|
||||||
|
t.Errorf("expected '%v', got '%v'", context.Canceled, resp.Err())
|
||||||
|
}
|
||||||
|
if resp.BytesComplete() >= int64(fileSize) {
|
||||||
|
t.Errorf("expected Response.BytesComplete: < %d, got: %d", fileSize, resp.BytesComplete())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
grabtest.ContentLength(fileSize),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCancelHangingResponse tests that a never ending request is terminated
|
||||||
|
// when the response is cancelled.
|
||||||
|
func TestCancelHangingResponse(t *testing.T) {
|
||||||
|
fileSize := 10
|
||||||
|
client := NewClient()
|
||||||
|
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
req := mustNewRequest("", fmt.Sprintf("%s/.testCancelHangingResponse", url))
|
||||||
|
|
||||||
|
resp := client.Do(req)
|
||||||
|
defer os.Remove(resp.Filename)
|
||||||
|
|
||||||
|
// Wait for some bytes to be transferred
|
||||||
|
for resp.BytesComplete() == 0 {
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan error)
|
||||||
|
go func() {
|
||||||
|
done <- resp.Cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-done:
|
||||||
|
if err != context.Canceled {
|
||||||
|
t.Errorf("Expected context.Canceled error, go: %v", err)
|
||||||
|
}
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("response was not cancelled within 1s")
|
||||||
|
}
|
||||||
|
if resp.BytesComplete() == int64(fileSize) {
|
||||||
|
t.Error("download was not supposed to be complete")
|
||||||
|
}
|
||||||
|
fmt.Println("bye")
|
||||||
|
},
|
||||||
|
grabtest.RateLimiter(1),
|
||||||
|
grabtest.ContentLength(fileSize),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNestedDirectory tests that missing subdirectories are created.
|
||||||
|
func TestNestedDirectory(t *testing.T) {
|
||||||
|
dir := "./.testNested/one/two/three"
|
||||||
|
filename := ".testNestedFile"
|
||||||
|
expect := dir + "/" + filename
|
||||||
|
|
||||||
|
t.Run("Create", func(t *testing.T) {
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
resp := mustDo(mustNewRequest(expect, url+"/"+filename))
|
||||||
|
defer os.RemoveAll("./.testNested/")
|
||||||
|
if resp.Filename != expect {
|
||||||
|
t.Errorf("expected nested Request.Filename to be %v, got %v", expect, resp.Filename)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("No create", func(t *testing.T) {
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
req := mustNewRequest(expect, url+"/"+filename)
|
||||||
|
req.NoCreateDirectories = true
|
||||||
|
resp := DefaultClient.Do(req)
|
||||||
|
err := resp.Err()
|
||||||
|
if !os.IsNotExist(err) {
|
||||||
|
t.Errorf("expected: %v, got: %v", os.ErrNotExist, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRemoteTime tests that the timestamp of the downloaded file can be set
|
||||||
|
// according to the timestamp of the remote file.
|
||||||
|
func TestRemoteTime(t *testing.T) {
|
||||||
|
filename := "./.testRemoteTime"
|
||||||
|
defer os.Remove(filename)
|
||||||
|
|
||||||
|
// random time between epoch and now
|
||||||
|
expect := time.Unix(rand.Int63n(time.Now().Unix()), 0)
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
resp := mustDo(mustNewRequest(filename, url))
|
||||||
|
fi, err := os.Stat(resp.Filename)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
actual := fi.ModTime()
|
||||||
|
if !actual.Equal(expect) {
|
||||||
|
t.Errorf("expected %v, got %v", expect, actual)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
grabtest.LastModified(expect),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseCode(t *testing.T) {
|
||||||
|
filename := "./.testResponseCode"
|
||||||
|
|
||||||
|
t.Run("With404", func(t *testing.T) {
|
||||||
|
defer os.Remove(filename)
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
req := mustNewRequest(filename, url)
|
||||||
|
resp := DefaultClient.Do(req)
|
||||||
|
expect := StatusCodeError(http.StatusNotFound)
|
||||||
|
err := resp.Err()
|
||||||
|
if err != expect {
|
||||||
|
t.Errorf("expected %v, got '%v'", expect, err)
|
||||||
|
}
|
||||||
|
if !IsStatusCodeError(err) {
|
||||||
|
t.Errorf("expected IsStatusCodeError to return true for %T: %v", err, err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
grabtest.StatusCodeStatic(http.StatusNotFound),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("WithIgnoreNon2XX", func(t *testing.T) {
|
||||||
|
defer os.Remove(filename)
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
req := mustNewRequest(filename, url)
|
||||||
|
req.IgnoreBadStatusCodes = true
|
||||||
|
resp := DefaultClient.Do(req)
|
||||||
|
if err := resp.Err(); err != nil {
|
||||||
|
t.Errorf("expected nil, got '%v'", err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
grabtest.StatusCodeStatic(http.StatusNotFound),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBeforeCopyHook(t *testing.T) {
|
||||||
|
filename := "./.testBeforeCopy"
|
||||||
|
t.Run("Noop", func(t *testing.T) {
|
||||||
|
defer os.RemoveAll(filename)
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
called := false
|
||||||
|
req := mustNewRequest(filename, url)
|
||||||
|
req.BeforeCopy = func(resp *Response) error {
|
||||||
|
called = true
|
||||||
|
if resp.IsComplete() {
|
||||||
|
t.Error("Response object passed to BeforeCopy hook has already been closed")
|
||||||
|
}
|
||||||
|
if resp.Progress() != 0 {
|
||||||
|
t.Error("Download progress already > 0 when BeforeCopy hook was called")
|
||||||
|
}
|
||||||
|
if resp.Duration() == 0 {
|
||||||
|
t.Error("Duration was zero when BeforeCopy was called")
|
||||||
|
}
|
||||||
|
if resp.BytesComplete() != 0 {
|
||||||
|
t.Error("BytesComplete already > 0 when BeforeCopy hook was called")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
resp := DefaultClient.Do(req)
|
||||||
|
if err := resp.Err(); err != nil {
|
||||||
|
t.Errorf("unexpected error using BeforeCopy hook: %v", err)
|
||||||
|
}
|
||||||
|
testComplete(t, resp)
|
||||||
|
if !called {
|
||||||
|
t.Error("BeforeCopy hook was never called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("WithError", func(t *testing.T) {
|
||||||
|
defer os.RemoveAll(filename)
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
testError := errors.New("test")
|
||||||
|
req := mustNewRequest(filename, url)
|
||||||
|
req.BeforeCopy = func(resp *Response) error {
|
||||||
|
return testError
|
||||||
|
}
|
||||||
|
resp := DefaultClient.Do(req)
|
||||||
|
if err := resp.Err(); err != testError {
|
||||||
|
t.Errorf("expected error '%v', got '%v'", testError, err)
|
||||||
|
}
|
||||||
|
if resp.BytesComplete() != 0 {
|
||||||
|
t.Errorf("expected 0 bytes completed for canceled BeforeCopy hook, got %d",
|
||||||
|
resp.BytesComplete())
|
||||||
|
}
|
||||||
|
testComplete(t, resp)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// Assert that an existing local file will not be truncated prior to the
|
||||||
|
// BeforeCopy hook has a chance to cancel the request
|
||||||
|
t.Run("NoTruncate", func(t *testing.T) {
|
||||||
|
tfile, err := ioutil.TempFile("", "grab_client_test.*.file")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer os.Remove(tfile.Name())
|
||||||
|
|
||||||
|
const size = 128
|
||||||
|
_, err = tfile.Write(bytes.Repeat([]byte("x"), size))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
called := false
|
||||||
|
req := mustNewRequest(tfile.Name(), url)
|
||||||
|
req.NoResume = true
|
||||||
|
req.BeforeCopy = func(resp *Response) error {
|
||||||
|
called = true
|
||||||
|
fi, err := tfile.Stat()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to stat temp file: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if fi.Size() != size {
|
||||||
|
t.Errorf("expected existing file size of %d bytes "+
|
||||||
|
"prior to BeforeCopy hook, got %d", size, fi.Size())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
resp := DefaultClient.Do(req)
|
||||||
|
if err := resp.Err(); err != nil {
|
||||||
|
t.Errorf("unexpected error using BeforeCopy hook: %v", err)
|
||||||
|
}
|
||||||
|
testComplete(t, resp)
|
||||||
|
if !called {
|
||||||
|
t.Error("BeforeCopy hook was never called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAfterCopyHook(t *testing.T) {
|
||||||
|
filename := "./.testAfterCopy"
|
||||||
|
t.Run("Noop", func(t *testing.T) {
|
||||||
|
defer os.RemoveAll(filename)
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
called := false
|
||||||
|
req := mustNewRequest(filename, url)
|
||||||
|
req.AfterCopy = func(resp *Response) error {
|
||||||
|
called = true
|
||||||
|
if resp.IsComplete() {
|
||||||
|
t.Error("Response object passed to AfterCopy hook has already been closed")
|
||||||
|
}
|
||||||
|
if resp.Progress() <= 0 {
|
||||||
|
t.Error("Download progress was 0 when AfterCopy hook was called")
|
||||||
|
}
|
||||||
|
if resp.Duration() == 0 {
|
||||||
|
t.Error("Duration was zero when AfterCopy was called")
|
||||||
|
}
|
||||||
|
if resp.BytesComplete() <= 0 {
|
||||||
|
t.Error("BytesComplete was 0 when AfterCopy hook was called")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
resp := DefaultClient.Do(req)
|
||||||
|
if err := resp.Err(); err != nil {
|
||||||
|
t.Errorf("unexpected error using AfterCopy hook: %v", err)
|
||||||
|
}
|
||||||
|
testComplete(t, resp)
|
||||||
|
if !called {
|
||||||
|
t.Error("AfterCopy hook was never called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("WithError", func(t *testing.T) {
|
||||||
|
defer os.RemoveAll(filename)
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
testError := errors.New("test")
|
||||||
|
req := mustNewRequest(filename, url)
|
||||||
|
req.AfterCopy = func(resp *Response) error {
|
||||||
|
return testError
|
||||||
|
}
|
||||||
|
resp := DefaultClient.Do(req)
|
||||||
|
if err := resp.Err(); err != testError {
|
||||||
|
t.Errorf("expected error '%v', got '%v'", testError, err)
|
||||||
|
}
|
||||||
|
if resp.BytesComplete() <= 0 {
|
||||||
|
t.Errorf("ByteCompleted was %d after AfterCopy hook was called",
|
||||||
|
resp.BytesComplete())
|
||||||
|
}
|
||||||
|
testComplete(t, resp)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIssue37(t *testing.T) {
|
||||||
|
// ref: https://github.com/cavaliercoder/grab/issues/37
|
||||||
|
filename := "./.testIssue37"
|
||||||
|
largeSize := int64(2097152)
|
||||||
|
smallSize := int64(1048576)
|
||||||
|
defer os.RemoveAll(filename)
|
||||||
|
|
||||||
|
// download large file
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
resp := mustDo(mustNewRequest(filename, url))
|
||||||
|
if resp.Size() != largeSize {
|
||||||
|
t.Errorf("expected response size: %d, got: %d", largeSize, resp.Size())
|
||||||
|
}
|
||||||
|
}, grabtest.ContentLength(int(largeSize)))
|
||||||
|
|
||||||
|
// download new, smaller version of same file
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
req := mustNewRequest(filename, url)
|
||||||
|
req.NoResume = true
|
||||||
|
resp := mustDo(req)
|
||||||
|
if resp.Size() != smallSize {
|
||||||
|
t.Errorf("expected response size: %d, got: %d", smallSize, resp.Size())
|
||||||
|
}
|
||||||
|
|
||||||
|
// local file should have truncated and not resumed
|
||||||
|
if resp.DidResume {
|
||||||
|
t.Errorf("expected download to truncate, resumed instead")
|
||||||
|
}
|
||||||
|
}, grabtest.ContentLength(int(smallSize)))
|
||||||
|
|
||||||
|
fi, err := os.Stat(filename)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if fi.Size() != int64(smallSize) {
|
||||||
|
t.Errorf("expected file size %d, got %d", smallSize, fi.Size())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHeadBadStatus validates that HEAD requests that return non-200 can be
// ignored and succeed if the GET request succeeds.
//
// Fixes: https://github.com/cavaliercoder/grab/issues/43
func TestHeadBadStatus(t *testing.T) {
	expect := http.StatusOK
	filename := ".testIssue43"

	statusFunc := func(r *http.Request) int {
		if r.Method == "HEAD" {
			return http.StatusForbidden
		}
		return http.StatusOK
	}

	grabtest.WithTestServer(t, func(url string) {
		testURL := fmt.Sprintf("%s/%s", url, filename)
		resp := mustDo(mustNewRequest("", testURL))
		if resp.HTTPResponse.StatusCode != expect {
			t.Errorf(
				"expected status code: %d, got: %d",
				expect,
				resp.HTTPResponse.StatusCode)
		}
	},
		grabtest.StatusCode(statusFunc),
	)
}

// TestMissingContentLength ensures that the Response.Size is correct for
|
||||||
|
// transfers where the remote server does not send a Content-Length header.
|
||||||
|
//
|
||||||
|
// TestAutoResume also covers cases with checksum validation.
|
||||||
|
//
|
||||||
|
// Kudos to Setnička Jiří <Jiri.Setnicka@ysoft.com> for identifying and raising
|
||||||
|
// a solution to this issue. Ref: https://github.com/cavaliercoder/grab/pull/27
|
||||||
|
func TestMissingContentLength(t *testing.T) {
|
||||||
|
// expectSize must be sufficiently large that DefaultClient.Do won't prefetch
|
||||||
|
// the entire body and compute ContentLength before returning a Response.
|
||||||
|
expectSize := 1048576
|
||||||
|
opts := []grabtest.HandlerOption{
|
||||||
|
grabtest.ContentLength(expectSize),
|
||||||
|
grabtest.HeaderBlacklist("Content-Length"),
|
||||||
|
grabtest.TimeToFirstByte(time.Millisecond * 100), // delay for initial read
|
||||||
|
}
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
req := mustNewRequest(".testMissingContentLength", url)
|
||||||
|
req.SetChecksum(
|
||||||
|
md5.New(),
|
||||||
|
grabtest.DefaultHandlerMD5ChecksumBytes,
|
||||||
|
false)
|
||||||
|
resp := DefaultClient.Do(req)
|
||||||
|
|
||||||
|
// ensure remote server is not sending content-length header
|
||||||
|
if v := resp.HTTPResponse.Header.Get("Content-Length"); v != "" {
|
||||||
|
panic(fmt.Sprintf("http header content length must be empty, got: %s", v))
|
||||||
|
}
|
||||||
|
if v := resp.HTTPResponse.ContentLength; v != -1 {
|
||||||
|
panic(fmt.Sprintf("http response content length must be -1, got: %d", v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// before completion, response size should be -1
|
||||||
|
if resp.Size() != -1 {
|
||||||
|
t.Errorf("expected response size: -1, got: %d", resp.Size())
|
||||||
|
}
|
||||||
|
|
||||||
|
// block for completion
|
||||||
|
if err := resp.Err(); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// on completion, response size should be actual transfer size
|
||||||
|
if resp.Size() != int64(expectSize) {
|
||||||
|
t.Errorf("expected response size: %d, got: %d", expectSize, resp.Size())
|
||||||
|
}
|
||||||
|
}, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNoStore(t *testing.T) {
|
||||||
|
filename := ".testSubdir/testNoStore"
|
||||||
|
t.Run("DefaultCase", func(t *testing.T) {
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
req := mustNewRequest(filename, url)
|
||||||
|
req.NoStore = true
|
||||||
|
req.SetChecksum(md5.New(), grabtest.DefaultHandlerMD5ChecksumBytes, true)
|
||||||
|
resp := mustDo(req)
|
||||||
|
|
||||||
|
// ensure Response.Bytes is correct and can be reread
|
||||||
|
b, err := resp.Bytes()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
grabtest.AssertSHA256Sum(
|
||||||
|
t,
|
||||||
|
grabtest.DefaultHandlerSHA256ChecksumBytes,
|
||||||
|
bytes.NewReader(b),
|
||||||
|
)
|
||||||
|
|
||||||
|
// ensure Response.Open stream is correct and can be reread
|
||||||
|
r, err := resp.Open()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer r.Close()
|
||||||
|
grabtest.AssertSHA256Sum(
|
||||||
|
t,
|
||||||
|
grabtest.DefaultHandlerSHA256ChecksumBytes,
|
||||||
|
r,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Response.Filename should still be set
|
||||||
|
if resp.Filename != filename {
|
||||||
|
t.Errorf("expected Response.Filename: %s, got: %s", filename, resp.Filename)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensure no files were written
|
||||||
|
paths := []string{
|
||||||
|
filename,
|
||||||
|
filepath.Base(filename),
|
||||||
|
filepath.Dir(filename),
|
||||||
|
resp.Filename,
|
||||||
|
filepath.Base(resp.Filename),
|
||||||
|
filepath.Dir(resp.Filename),
|
||||||
|
}
|
||||||
|
for _, path := range paths {
|
||||||
|
_, err := os.Stat(path)
|
||||||
|
if !os.IsNotExist(err) {
|
||||||
|
t.Errorf(
|
||||||
|
"expect error: %v, got: %v, for path: %s",
|
||||||
|
os.ErrNotExist,
|
||||||
|
err,
|
||||||
|
path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ChecksumValidation", func(t *testing.T) {
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
req := mustNewRequest("", url)
|
||||||
|
req.NoStore = true
|
||||||
|
req.SetChecksum(
|
||||||
|
md5.New(),
|
||||||
|
grabtest.MustHexDecodeString("deadbeefcafebabe"),
|
||||||
|
true)
|
||||||
|
resp := DefaultClient.Do(req)
|
||||||
|
if err := resp.Err(); err != ErrBadChecksum {
|
||||||
|
t.Errorf("expected error: %v, got: %v", ErrBadChecksum, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
1
grab/cmd/grab/.gitignore
vendored
Normal file
@ -0,0 +1 @@
grab
18
grab/cmd/grab/Makefile
Normal file
@ -0,0 +1,18 @@
SOURCES = main.go

all : grab

grab: $(SOURCES)
	go build -x -o grab $(SOURCES)

clean:
	go clean -x
	rm -vf grab

check:
	go test -v .

install:
	go install -v .

.PHONY: all clean check install
34
grab/cmd/grab/main.go
Normal file
@ -0,0 +1,34 @@
package main

import (
	"context"
	"fmt"
	"os"

	"github.com/cavaliercoder/grab/grabui"
)

func main() {
	// validate command args
	if len(os.Args) < 2 {
		fmt.Fprintf(os.Stderr, "usage: %s url...\n", os.Args[0])
		os.Exit(1)
	}
	urls := os.Args[1:]

	// download files
	respch, err := grabui.GetBatch(context.Background(), 0, ".", urls...)
	if err != nil {
		fmt.Fprint(os.Stderr, err)
		os.Exit(1)
	}

	// return the number of failed downloads as exit code
	failed := 0
	for resp := range respch {
		if resp.Err() != nil {
			failed++
		}
	}
	os.Exit(failed)
}
63
grab/doc.go
Normal file
@ -0,0 +1,63 @@
/*
Package grab provides an HTTP download manager implementation.

Get is the simplest way to download a file:

	resp, err := grab.Get("/tmp", "http://example.com/example.zip")
	// ...

Get will download the given URL and save it to the given destination directory.
The destination filename will be determined automatically by grab using
Content-Disposition headers returned by the remote server, or by inspecting the
requested URL path.

An empty destination string or "." means the transfer will be stored in the
current working directory.

If a destination file already exists, grab will assume it is a complete or
partially complete download of the requested file. If the remote server supports
resuming interrupted downloads, grab will resume downloading from the end of the
partial file. If the server does not support resumed downloads, the file will be
retransferred in its entirety. If the file is already complete, grab will return
successfully.

For control over the HTTP client, destination path, auto-resume, checksum
validation and other settings, create a Client:

	client := grab.NewClient()
	client.HTTPClient.Transport.DisableCompression = true

	req, err := grab.NewRequest("/tmp", "http://example.com/example.zip")
	// ...
	req.NoResume = true
	req.HTTPRequest.Header.Set("Authorization", "Basic YWxhZGRpbjpvcGVuc2VzYW1l")

	resp := client.Do(req)
	// ...

You can monitor the progress of downloads while they are transferring:

	client := grab.NewClient()
	req, err := grab.NewRequest("", "http://example.com/example.zip")
	// ...
	resp := client.Do(req)

	t := time.NewTicker(time.Second)
	defer t.Stop()

	for {
		select {
		case <-t.C:
			fmt.Printf("%.02f%% complete\n", resp.Progress())

		case <-resp.Done:
			if err := resp.Err(); err != nil {
				// ...
			}

			// ...
			return
		}
	}
*/
package grab
42
grab/error.go
Normal file
@ -0,0 +1,42 @@
package grab

import (
	"errors"
	"fmt"
	"net/http"
)

var (
	// ErrBadLength indicates that the server response or an existing file does
	// not match the expected content length.
	ErrBadLength = errors.New("bad content length")

	// ErrBadChecksum indicates that a downloaded file failed to pass checksum
	// validation.
	ErrBadChecksum = errors.New("checksum mismatch")

	// ErrNoFilename indicates that a reasonable filename could not be
	// automatically determined using the URL or response headers from a server.
	ErrNoFilename = errors.New("no filename could be determined")

	// ErrNoTimestamp indicates that a timestamp could not be automatically
	// determined using the response headers from the remote server.
	ErrNoTimestamp = errors.New("no timestamp could be determined for the remote file")

	// ErrFileExists indicates that the destination path already exists.
	ErrFileExists = errors.New("file exists")
)

// StatusCodeError indicates that the server response had a status code that
// was not in the 200-299 range (after following any redirects).
type StatusCodeError int

func (err StatusCodeError) Error() string {
	return fmt.Sprintf("server returned %d %s", err, http.StatusText(int(err)))
}

// IsStatusCodeError returns true if the given error is of type StatusCodeError.
func IsStatusCodeError(err error) bool {
	_, ok := err.(StatusCodeError)
	return ok
}
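As a minimal usage sketch (not part of this commit), a caller outside the package might branch on StatusCodeError as follows; the module path follows go.mod later in the diff and the URL is a placeholder:

package main

import (
	"fmt"
	"log"

	"github.com/cavaliercoder/grab"
)

func main() {
	resp, err := grab.Get(".", "http://example.com/missing.zip") // placeholder URL
	if err != nil {
		if grab.IsStatusCodeError(err) {
			// StatusCodeError is an int, so the HTTP status is recoverable.
			log.Fatalf("server refused the request with status %d", int(err.(grab.StatusCodeError)))
		}
		log.Fatalf("transfer failed: %v", err)
	}
	fmt.Println("saved to", resp.Filename)
}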
95
grab/example_client_test.go
Normal file
95
grab/example_client_test.go
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
package grab
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ExampleClient_Do() {
|
||||||
|
client := NewClient()
|
||||||
|
req, err := NewRequest("/tmp", "http://example.com/example.zip")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := client.Do(req)
|
||||||
|
if err := resp.Err(); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Download saved to", resp.Filename)
|
||||||
|
}
|
||||||
|
|
||||||
|
// This example uses DoChannel to create a Producer/Consumer model for
|
||||||
|
// downloading multiple files concurrently. This is similar to how DoBatch uses
|
||||||
|
// DoChannel under the hood except that it allows the caller to continually send
|
||||||
|
// new requests until they wish to close the request channel.
|
||||||
|
func ExampleClient_DoChannel() {
|
||||||
|
// create a request and a buffered response channel
|
||||||
|
reqch := make(chan *Request)
|
||||||
|
respch := make(chan *Response, 10)
|
||||||
|
|
||||||
|
// start 4 workers
|
||||||
|
client := NewClient()
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
for i := 0; i < 4; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
client.DoChannel(reqch, respch)
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
// send requests
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
url := fmt.Sprintf("http://example.com/example%d.zip", i+1)
|
||||||
|
req, err := NewRequest("/tmp", url)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
reqch <- req
|
||||||
|
}
|
||||||
|
close(reqch)
|
||||||
|
|
||||||
|
// wait for workers to finish
|
||||||
|
wg.Wait()
|
||||||
|
close(respch)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// check each response
|
||||||
|
for resp := range respch {
|
||||||
|
// block until complete
|
||||||
|
if err := resp.Err(); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Downloaded %s to %s\n", resp.Request.URL(), resp.Filename)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleClient_DoBatch() {
|
||||||
|
// create multiple download requests
|
||||||
|
reqs := make([]*Request, 0)
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
url := fmt.Sprintf("http://example.com/example%d.zip", i+1)
|
||||||
|
req, err := NewRequest("/tmp", url)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
reqs = append(reqs, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// start downloads with 4 workers
|
||||||
|
client := NewClient()
|
||||||
|
respch := client.DoBatch(4, reqs...)
|
||||||
|
|
||||||
|
// check each response
|
||||||
|
for resp := range respch {
|
||||||
|
if err := resp.Err(); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Downloaded %s to %s\n", resp.Request.URL(), resp.Filename)
|
||||||
|
}
|
||||||
|
}
|
52
grab/example_request_test.go
Normal file
52
grab/example_request_test.go
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
package grab
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ExampleRequest_WithContext() {
|
||||||
|
// create context with a 100ms timeout
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// create download request with context
|
||||||
|
req, err := NewRequest("", "http://example.com/example.zip")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
|
// send download request
|
||||||
|
resp := DefaultClient.Do(req)
|
||||||
|
if err := resp.Err(); err != nil {
|
||||||
|
fmt.Println("error: request cancelled")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Output:
|
||||||
|
// error: request cancelled
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleRequest_SetChecksum() {
|
||||||
|
// create download request
|
||||||
|
req, err := NewRequest("", "http://example.com/example.zip")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// set request checksum
|
||||||
|
sum, err := hex.DecodeString("33daf4c03f86120fdfdc66bddf6bfff4661c7ca11c5da473e537f4d69b470e57")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
req.SetChecksum(sha256.New(), sum, true)
|
||||||
|
|
||||||
|
// download and validate file
|
||||||
|
resp := DefaultClient.Do(req)
|
||||||
|
if err := resp.Err(); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
5
grab/go.mod
Normal file
@ -0,0 +1,5 @@
module github.com/cavaliercoder/grab

go 1.14

require github.com/lordwelch/pathvalidate v0.0.0-20201012043703-54efa7ea1308
5
grab/go.sum
Normal file
@ -0,0 +1,5 @@
github.com/lordwelch/pathvalidate v0.0.0-20201012043703-54efa7ea1308 h1:CkcsZK6QYg59rc92eqU2h+FRjWltCIiplmEwIB05jfM=
github.com/lordwelch/pathvalidate v0.0.0-20201012043703-54efa7ea1308/go.mod h1:4I4r5Y/LkH+34KACiudU+Q27ooz7xSDyVEuWAVKeJEQ=
golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
64
grab/grab.go
Normal file
@ -0,0 +1,64 @@
package grab

import (
	"fmt"
	"os"
)

// Get sends an HTTP request and downloads the content of the requested URL to
// the given destination file path. The caller is blocked until the download is
// completed, successfully or otherwise.
//
// An error is returned if caused by client policy (such as CheckRedirect), or
// if there was an HTTP protocol or IO error.
//
// For non-blocking calls or control over HTTP client headers, redirect policy,
// and other settings, create a Client instead.
func Get(dst, urlStr string) (*Response, error) {
	req, err := NewRequest(dst, urlStr)
	if err != nil {
		return nil, err
	}

	resp := DefaultClient.Do(req)
	return resp, resp.Err()
}

// GetBatch sends multiple HTTP requests and downloads the content of the
// requested URLs to the given destination directory using the given number of
// concurrent worker goroutines.
//
// The Response for each requested URL is sent through the returned Response
// channel, as soon as a worker receives a response from the remote server. The
// Response can then be used to track the progress of the download while it is
// in progress.
//
// The returned Response channel will be closed by Grab only once all downloads
// have completed or failed.
//
// If an error occurs during any download, it will be available via a call to
// the associated Response.Err.
//
// For control over HTTP client headers, redirect policy, and other settings,
// create a Client instead.
func GetBatch(workers int, dst string, urlStrs ...string) (<-chan *Response, error) {
	fi, err := os.Stat(dst)
	if err != nil {
		return nil, err
	}
	if !fi.IsDir() {
		return nil, fmt.Errorf("destination is not a directory")
	}

	reqs := make([]*Request, len(urlStrs))
	for i := 0; i < len(urlStrs); i++ {
		req, err := NewRequest(dst, urlStrs[i])
		if err != nil {
			return nil, err
		}
		reqs[i] = req
	}

	ch := DefaultClient.DoBatch(workers, reqs...)
	return ch, nil
}
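A short sketch (not part of this commit) of draining the channel that GetBatch returns, as the comment above describes; the URLs are placeholders:

package main

import (
	"fmt"
	"log"

	"github.com/cavaliercoder/grab"
)

func main() {
	// three workers, saving into the current directory
	respch, err := grab.GetBatch(3, ".",
		"http://example.com/a.zip", // placeholder URLs
		"http://example.com/b.zip",
		"http://example.com/c.zip",
	)
	if err != nil {
		log.Fatal(err)
	}

	// the channel is closed only once every download has completed or failed
	for resp := range respch {
		if err := resp.Err(); err != nil {
			fmt.Printf("failed %s: %v\n", resp.Request.URL(), err)
			continue
		}
		fmt.Println("saved to", resp.Filename)
	}
}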
74
grab/grab_test.go
Normal file
74
grab/grab_test.go
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
package grab
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/cavaliercoder/grab/grabtest"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
os.Exit(func() int {
|
||||||
|
// chdir to temp so test files downloaded to pwd are isolated and cleaned up
|
||||||
|
cwd, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
tmpDir, err := ioutil.TempDir("", "grab-")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if err := os.Chdir(tmpDir); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
os.Chdir(cwd)
|
||||||
|
if err := os.RemoveAll(tmpDir); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return m.Run()
|
||||||
|
}())
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGet tests grab.Get
|
||||||
|
func TestGet(t *testing.T) {
|
||||||
|
filename := ".testGet"
|
||||||
|
defer os.Remove(filename)
|
||||||
|
grabtest.WithTestServer(t, func(url string) {
|
||||||
|
resp, err := Get(filename, url)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error in Get(): %v", err)
|
||||||
|
}
|
||||||
|
testComplete(t, resp)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleGet() {
|
||||||
|
// download a file to /tmp
|
||||||
|
resp, err := Get("/tmp", "http://example.com/example.zip")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Download saved to", resp.Filename)
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustNewRequest(dst, urlStr string) *Request {
|
||||||
|
req, err := NewRequest(dst, urlStr)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustDo(req *Request) *Response {
|
||||||
|
resp := DefaultClient.Do(req)
|
||||||
|
if err := resp.Err(); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
}
|
104
grab/grabtest/assert.go
Normal file
104
grab/grabtest/assert.go
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
package grabtest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/sha256"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func AssertHTTPResponseStatusCode(t *testing.T, resp *http.Response, expect int) (ok bool) {
|
||||||
|
if resp.StatusCode != expect {
|
||||||
|
t.Errorf("expected status code: %d, got: %d", expect, resp.StatusCode)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ok = true
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func AssertHTTPResponseHeader(t *testing.T, resp *http.Response, key, format string, a ...interface{}) (ok bool) {
|
||||||
|
expect := fmt.Sprintf(format, a...)
|
||||||
|
actual := resp.Header.Get(key)
|
||||||
|
if actual != expect {
|
||||||
|
t.Errorf("expected header %s: %s, got: %s", key, expect, actual)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ok = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func AssertHTTPResponseContentLength(t *testing.T, resp *http.Response, n int64) (ok bool) {
|
||||||
|
ok = true
|
||||||
|
if resp.ContentLength != n {
|
||||||
|
ok = false
|
||||||
|
t.Errorf("expected header Content-Length: %d, got: %d", n, resp.ContentLength)
|
||||||
|
}
|
||||||
|
if !AssertHTTPResponseBodyLength(t, resp, n) {
|
||||||
|
ok = false
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func AssertHTTPResponseBodyLength(t *testing.T, resp *http.Response, n int64) (ok bool) {
|
||||||
|
defer func() {
|
||||||
|
if err := resp.Body.Close(); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
b, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if int64(len(b)) != n {
|
||||||
|
ok = false
|
||||||
|
t.Errorf("expected body length: %d, got: %d", n, len(b))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func MustHTTPNewRequest(method, url string, body io.Reader) *http.Request {
|
||||||
|
req, err := http.NewRequest(method, url, body)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
func MustHTTPDo(req *http.Request) *http.Response {
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
func MustHTTPDoWithClose(req *http.Request) *http.Response {
|
||||||
|
resp := MustHTTPDo(req)
|
||||||
|
if _, err := io.Copy(ioutil.Discard, resp.Body); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if err := resp.Body.Close(); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
func AssertSHA256Sum(t *testing.T, sum []byte, r io.Reader) (ok bool) {
|
||||||
|
h := sha256.New()
|
||||||
|
if _, err := io.Copy(h, r); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
computed := h.Sum(nil)
|
||||||
|
ok = bytes.Equal(sum, computed)
|
||||||
|
if !ok {
|
||||||
|
t.Errorf(
|
||||||
|
"expected checksum: %s, got: %s",
|
||||||
|
MustHexEncodeString(sum),
|
||||||
|
MustHexEncodeString(computed),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
160
grab/grabtest/handler.go
Normal file
160
grab/grabtest/handler.go
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
package grabtest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
DefaultHandlerContentLength = 1 << 20
|
||||||
|
DefaultHandlerMD5Checksum = "c35cc7d8d91728a0cb052831bc4ef372"
|
||||||
|
DefaultHandlerMD5ChecksumBytes = MustHexDecodeString(DefaultHandlerMD5Checksum)
|
||||||
|
DefaultHandlerSHA256Checksum = "fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83"
|
||||||
|
DefaultHandlerSHA256ChecksumBytes = MustHexDecodeString(DefaultHandlerSHA256Checksum)
|
||||||
|
)
|
||||||
|
|
||||||
|
type StatusCodeFunc func(req *http.Request) int
|
||||||
|
|
||||||
|
type handler struct {
|
||||||
|
statusCodeFunc StatusCodeFunc
|
||||||
|
methodWhitelist []string
|
||||||
|
headerBlacklist []string
|
||||||
|
contentLength int
|
||||||
|
acceptRanges bool
|
||||||
|
attachmentFilename string
|
||||||
|
lastModified time.Time
|
||||||
|
ttfb time.Duration
|
||||||
|
rateLimiter *time.Ticker
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHandler(options ...HandlerOption) (http.Handler, error) {
|
||||||
|
h := &handler{
|
||||||
|
statusCodeFunc: func(req *http.Request) int { return http.StatusOK },
|
||||||
|
methodWhitelist: []string{"GET", "HEAD"},
|
||||||
|
contentLength: DefaultHandlerContentLength,
|
||||||
|
acceptRanges: true,
|
||||||
|
}
|
||||||
|
for _, option := range options {
|
||||||
|
if err := option(h); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return h, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithTestServer(t *testing.T, f func(url string), options ...HandlerOption) {
|
||||||
|
h, err := NewHandler(options...)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to create test server handler: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s := httptest.NewServer(h)
|
||||||
|
defer func() {
|
||||||
|
h.(*handler).close()
|
||||||
|
s.Close()
|
||||||
|
}()
|
||||||
|
f(s.URL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) close() {
|
||||||
|
if h.rateLimiter != nil {
|
||||||
|
h.rateLimiter.Stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// delay response
|
||||||
|
if h.ttfb > 0 {
|
||||||
|
time.Sleep(h.ttfb)
|
||||||
|
}
|
||||||
|
|
||||||
|
// validate request method
|
||||||
|
allowed := false
|
||||||
|
for _, m := range h.methodWhitelist {
|
||||||
|
if r.Method == m {
|
||||||
|
allowed = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !allowed {
|
||||||
|
httpError(w, http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// set server options
|
||||||
|
if h.acceptRanges {
|
||||||
|
w.Header().Set("Accept-Ranges", "bytes")
|
||||||
|
}
|
||||||
|
|
||||||
|
// set attachment filename
|
||||||
|
if h.attachmentFilename != "" {
|
||||||
|
w.Header().Set(
|
||||||
|
"Content-Disposition",
|
||||||
|
fmt.Sprintf("attachment;filename=\"%s\"", h.attachmentFilename),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// set last modified timestamp
|
||||||
|
lastMod := time.Now()
|
||||||
|
if !h.lastModified.IsZero() {
|
||||||
|
lastMod = h.lastModified
|
||||||
|
}
|
||||||
|
w.Header().Set("Last-Modified", lastMod.Format(http.TimeFormat))
|
||||||
|
|
||||||
|
// set content-length
|
||||||
|
offset := 0
|
||||||
|
if h.acceptRanges {
|
||||||
|
if reqRange := r.Header.Get("Range"); reqRange != "" {
|
||||||
|
if _, err := fmt.Sscanf(reqRange, "bytes=%d-", &offset); err != nil {
|
||||||
|
httpError(w, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if offset >= h.contentLength {
|
||||||
|
httpError(w, http.StatusRequestedRangeNotSatisfiable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", h.contentLength-offset))
|
||||||
|
|
||||||
|
// apply header blacklist
|
||||||
|
for _, key := range h.headerBlacklist {
|
||||||
|
w.Header().Del(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// send header and status code
|
||||||
|
w.WriteHeader(h.statusCodeFunc(r))
|
||||||
|
|
||||||
|
// send body
|
||||||
|
if r.Method == "GET" {
|
||||||
|
// use buffered io to reduce overhead on the reader
|
||||||
|
bw := bufio.NewWriterSize(w, 4096)
|
||||||
|
for i := offset; !isRequestClosed(r) && i < h.contentLength; i++ {
|
||||||
|
bw.Write([]byte{byte(i)})
|
||||||
|
if h.rateLimiter != nil {
|
||||||
|
bw.Flush()
|
||||||
|
w.(http.Flusher).Flush() // force the server to send the data to the client
|
||||||
|
select {
|
||||||
|
case <-h.rateLimiter.C:
|
||||||
|
case <-r.Context().Done():
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !isRequestClosed(r) {
|
||||||
|
bw.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isRequestClosed returns true if the client request has been canceled.
|
||||||
|
func isRequestClosed(r *http.Request) bool {
|
||||||
|
return r.Context().Err() != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func httpError(w http.ResponseWriter, code int) {
|
||||||
|
http.Error(w, http.StatusText(code), code)
|
||||||
|
}
|
92
grab/grabtest/handler_option.go
Normal file
92
grab/grabtest/handler_option.go
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
package grabtest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type HandlerOption func(*handler) error
|
||||||
|
|
||||||
|
func StatusCodeStatic(code int) HandlerOption {
|
||||||
|
return func(h *handler) error {
|
||||||
|
return StatusCode(func(req *http.Request) int {
|
||||||
|
return code
|
||||||
|
})(h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func StatusCode(f StatusCodeFunc) HandlerOption {
|
||||||
|
return func(h *handler) error {
|
||||||
|
if f == nil {
|
||||||
|
return errors.New("status code function cannot be nil")
|
||||||
|
}
|
||||||
|
h.statusCodeFunc = f
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func MethodWhitelist(methods ...string) HandlerOption {
|
||||||
|
return func(h *handler) error {
|
||||||
|
h.methodWhitelist = methods
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func HeaderBlacklist(headers ...string) HandlerOption {
|
||||||
|
return func(h *handler) error {
|
||||||
|
h.headerBlacklist = headers
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ContentLength(n int) HandlerOption {
|
||||||
|
return func(h *handler) error {
|
||||||
|
if n < 0 {
|
||||||
|
return errors.New("content length must be zero or greater")
|
||||||
|
}
|
||||||
|
h.contentLength = n
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func AcceptRanges(enabled bool) HandlerOption {
|
||||||
|
return func(h *handler) error {
|
||||||
|
h.acceptRanges = enabled
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LastModified(t time.Time) HandlerOption {
|
||||||
|
return func(h *handler) error {
|
||||||
|
h.lastModified = t.UTC()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TimeToFirstByte(d time.Duration) HandlerOption {
|
||||||
|
return func(h *handler) error {
|
||||||
|
if d < 1 {
|
||||||
|
return errors.New("time to first byte must be greater than zero")
|
||||||
|
}
|
||||||
|
h.ttfb = d
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func RateLimiter(bps int) HandlerOption {
|
||||||
|
return func(h *handler) error {
|
||||||
|
if bps < 1 {
|
||||||
|
return errors.New("bytes per second must be greater than zero")
|
||||||
|
}
|
||||||
|
h.rateLimiter = time.NewTicker(time.Second / time.Duration(bps))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func AttachmentFilename(filename string) HandlerOption {
|
||||||
|
return func(h *handler) error {
|
||||||
|
h.attachmentFilename = filename
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
150
grab/grabtest/handler_test.go
Normal file
150
grab/grabtest/handler_test.go
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
package grabtest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHandlerDefaults(t *testing.T) {
|
||||||
|
WithTestServer(t, func(url string) {
|
||||||
|
resp := MustHTTPDo(MustHTTPNewRequest("GET", url, nil))
|
||||||
|
AssertHTTPResponseStatusCode(t, resp, http.StatusOK)
|
||||||
|
AssertHTTPResponseContentLength(t, resp, 1048576)
|
||||||
|
AssertHTTPResponseHeader(t, resp, "Accept-Ranges", "bytes")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerMethodWhitelist(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
Whitelist []string
|
||||||
|
Method string
|
||||||
|
ExpectStatusCode int
|
||||||
|
}{
|
||||||
|
{[]string{"GET", "HEAD"}, "GET", http.StatusOK},
|
||||||
|
{[]string{"GET", "HEAD"}, "HEAD", http.StatusOK},
|
||||||
|
{[]string{"GET"}, "HEAD", http.StatusMethodNotAllowed},
|
||||||
|
{[]string{"HEAD"}, "GET", http.StatusMethodNotAllowed},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
WithTestServer(t, func(url string) {
|
||||||
|
resp := MustHTTPDoWithClose(MustHTTPNewRequest(test.Method, url, nil))
|
||||||
|
AssertHTTPResponseStatusCode(t, resp, test.ExpectStatusCode)
|
||||||
|
}, MethodWhitelist(test.Whitelist...))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerHeaderBlacklist(t *testing.T) {
|
||||||
|
contentLength := 4096
|
||||||
|
WithTestServer(t, func(url string) {
|
||||||
|
resp := MustHTTPDo(MustHTTPNewRequest("GET", url, nil))
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.ContentLength != -1 {
|
||||||
|
t.Errorf("expected Response.ContentLength: -1, got: %d", resp.ContentLength)
|
||||||
|
}
|
||||||
|
AssertHTTPResponseHeader(t, resp, "Content-Length", "")
|
||||||
|
AssertHTTPResponseBodyLength(t, resp, int64(contentLength))
|
||||||
|
},
|
||||||
|
ContentLength(contentLength),
|
||||||
|
HeaderBlacklist("Content-Length"),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerStatusCodeFuncs(t *testing.T) {
|
||||||
|
expect := 418 // I'm a teapot
|
||||||
|
WithTestServer(t, func(url string) {
|
||||||
|
resp := MustHTTPDo(MustHTTPNewRequest("GET", url, nil))
|
||||||
|
AssertHTTPResponseStatusCode(t, resp, expect)
|
||||||
|
},
|
||||||
|
StatusCode(func(req *http.Request) int { return expect }),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerContentLength(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
Method string
|
||||||
|
ContentLength int
|
||||||
|
ExpectHeaderLen int64
|
||||||
|
ExpectBodyLen int
|
||||||
|
}{
|
||||||
|
{"GET", 321, 321, 321},
|
||||||
|
{"HEAD", 321, 321, 0},
|
||||||
|
{"GET", 0, 0, 0},
|
||||||
|
{"HEAD", 0, 0, 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
WithTestServer(t, func(url string) {
|
||||||
|
resp := MustHTTPDo(MustHTTPNewRequest(test.Method, url, nil))
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
AssertHTTPResponseHeader(t, resp, "Content-Length", "%d", test.ExpectHeaderLen)
|
||||||
|
|
||||||
|
b, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if len(b) != test.ExpectBodyLen {
|
||||||
|
t.Errorf(
|
||||||
|
"expected body length: %v, got: %v, in: %v",
|
||||||
|
test.ExpectBodyLen,
|
||||||
|
len(b),
|
||||||
|
test,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
ContentLength(test.ContentLength),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerAcceptRanges(t *testing.T) {
|
||||||
|
header := "Accept-Ranges"
|
||||||
|
n := 128
|
||||||
|
t.Run("Enabled", func(t *testing.T) {
|
||||||
|
WithTestServer(t, func(url string) {
|
||||||
|
req := MustHTTPNewRequest("GET", url, nil)
|
||||||
|
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", n/2))
|
||||||
|
resp := MustHTTPDo(req)
|
||||||
|
AssertHTTPResponseHeader(t, resp, header, "bytes")
|
||||||
|
AssertHTTPResponseContentLength(t, resp, int64(n/2))
|
||||||
|
},
|
||||||
|
ContentLength(n),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Disabled", func(t *testing.T) {
|
||||||
|
WithTestServer(t, func(url string) {
|
||||||
|
req := MustHTTPNewRequest("GET", url, nil)
|
||||||
|
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", n/2))
|
||||||
|
resp := MustHTTPDo(req)
|
||||||
|
AssertHTTPResponseHeader(t, resp, header, "")
|
||||||
|
AssertHTTPResponseContentLength(t, resp, int64(n))
|
||||||
|
},
|
||||||
|
AcceptRanges(false),
|
||||||
|
ContentLength(n),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerAttachmentFilename(t *testing.T) {
|
||||||
|
filename := "foo.pdf"
|
||||||
|
WithTestServer(t, func(url string) {
|
||||||
|
resp := MustHTTPDoWithClose(MustHTTPNewRequest("GET", url, nil))
|
||||||
|
AssertHTTPResponseHeader(t, resp, "Content-Disposition", `attachment;filename="%s"`, filename)
|
||||||
|
},
|
||||||
|
AttachmentFilename(filename),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerLastModified(t *testing.T) {
|
||||||
|
WithTestServer(t, func(url string) {
|
||||||
|
resp := MustHTTPDoWithClose(MustHTTPNewRequest("GET", url, nil))
|
||||||
|
AssertHTTPResponseHeader(t, resp, "Last-Modified", "Thu, 29 Nov 1973 21:33:09 GMT")
|
||||||
|
},
|
||||||
|
LastModified(time.Unix(123456789, 0)),
|
||||||
|
)
|
||||||
|
}
|
16
grab/grabtest/util.go
Normal file
@ -0,0 +1,16 @@
package grabtest

import "encoding/hex"

func MustHexDecodeString(s string) (b []byte) {
	var err error
	b, err = hex.DecodeString(s)
	if err != nil {
		panic(err)
	}
	return
}

func MustHexEncodeString(b []byte) (s string) {
	return hex.EncodeToString(b)
}
166
grab/grabui/console_client.go
Normal file
166
grab/grabui/console_client.go
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
package grabui
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cavaliercoder/grab"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConsoleClient struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
client *grab.Client
|
||||||
|
succeeded, failed, inProgress int
|
||||||
|
responses []*grab.Response
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConsoleClient(client *grab.Client) *ConsoleClient {
|
||||||
|
return &ConsoleClient{
|
||||||
|
client: client,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConsoleClient) Do(
|
||||||
|
ctx context.Context,
|
||||||
|
workers int,
|
||||||
|
reqs ...*grab.Request,
|
||||||
|
) <-chan *grab.Response {
|
||||||
|
// buffer size prevents slow receivers causing back pressure
|
||||||
|
pump := make(chan *grab.Response, len(reqs))
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
c.failed = 0
|
||||||
|
c.inProgress = 0
|
||||||
|
c.succeeded = 0
|
||||||
|
c.responses = make([]*grab.Response, 0, len(reqs))
|
||||||
|
if c.client == nil {
|
||||||
|
c.client = grab.DefaultClient
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Downloading %d files...\n", len(reqs))
|
||||||
|
respch := c.client.DoBatch(workers, reqs...)
|
||||||
|
t := time.NewTicker(200 * time.Millisecond)
|
||||||
|
defer t.Stop()
|
||||||
|
|
||||||
|
Loop:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
break Loop
|
||||||
|
|
||||||
|
case resp := <-respch:
|
||||||
|
if resp != nil {
|
||||||
|
// a new response has been received and has started downloading
|
||||||
|
c.responses = append(c.responses, resp)
|
||||||
|
pump <- resp // send to caller
|
||||||
|
} else {
|
||||||
|
// channel is closed - all downloads are complete
|
||||||
|
break Loop
|
||||||
|
}
|
||||||
|
|
||||||
|
case <-t.C:
|
||||||
|
// update UI on clock tick
|
||||||
|
c.refresh()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.refresh()
|
||||||
|
close(pump)
|
||||||
|
|
||||||
|
fmt.Printf(
|
||||||
|
"Finished %d successful, %d failed, %d incomplete.\n",
|
||||||
|
c.succeeded,
|
||||||
|
c.failed,
|
||||||
|
c.inProgress)
|
||||||
|
}()
|
||||||
|
return pump
|
||||||
|
}
|
||||||
|
|
||||||
|
// refresh prints the progress of all downloads to the terminal
|
||||||
|
func (c *ConsoleClient) refresh() {
|
||||||
|
// clear lines for incomplete downloads
|
||||||
|
if c.inProgress > 0 {
|
||||||
|
fmt.Printf("\033[%dA\033[K", c.inProgress)
|
||||||
|
}
|
||||||
|
|
||||||
|
// print newly completed downloads
|
||||||
|
for i, resp := range c.responses {
|
||||||
|
if resp != nil && resp.IsComplete() {
|
||||||
|
if resp.Err() != nil {
|
||||||
|
c.failed++
|
||||||
|
fmt.Fprintf(os.Stderr, "Error downloading %s: %v\n",
|
||||||
|
resp.Request.URL(),
|
||||||
|
resp.Err())
|
||||||
|
} else {
|
||||||
|
c.succeeded++
|
||||||
|
fmt.Printf("Finished %s %s / %s (%d%%)\n",
|
||||||
|
resp.Filename,
|
||||||
|
byteString(resp.BytesComplete()),
|
||||||
|
byteString(resp.Size()),
|
||||||
|
int(100*resp.Progress()))
|
||||||
|
}
|
||||||
|
c.responses[i] = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// print progress for incomplete downloads
|
||||||
|
c.inProgress = 0
|
||||||
|
for _, resp := range c.responses {
|
||||||
|
if resp != nil {
|
||||||
|
fmt.Printf("Downloading %s %s / %s (%d%%) - %s ETA: %s \033[K\n",
|
||||||
|
resp.Filename,
|
||||||
|
byteString(resp.BytesComplete()),
|
||||||
|
byteString(resp.Size()),
|
||||||
|
int(100*resp.Progress()),
|
||||||
|
bpsString(resp.BytesPerSecond()),
|
||||||
|
etaString(resp.ETA()))
|
||||||
|
c.inProgress++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func bpsString(n float64) string {
|
||||||
|
if n < 1e3 {
|
||||||
|
return fmt.Sprintf("%.02fBps", n)
|
||||||
|
}
|
||||||
|
if n < 1e6 {
|
||||||
|
return fmt.Sprintf("%.02fKB/s", n/1e3)
|
||||||
|
}
|
||||||
|
if n < 1e9 {
|
||||||
|
return fmt.Sprintf("%.02fMB/s", n/1e6)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%.02fGB/s", n/1e9)
|
||||||
|
}
|
||||||
|
|
||||||
|
func byteString(n int64) string {
|
||||||
|
if n < 1<<10 {
|
||||||
|
return fmt.Sprintf("%dB", n)
|
||||||
|
}
|
||||||
|
if n < 1<<20 {
|
||||||
|
return fmt.Sprintf("%dKB", n>>10)
|
||||||
|
}
|
||||||
|
if n < 1<<30 {
|
||||||
|
return fmt.Sprintf("%dMB", n>>20)
|
||||||
|
}
|
||||||
|
if n < 1<<40 {
|
||||||
|
return fmt.Sprintf("%dGB", n>>30)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%dTB", n>>40)
|
||||||
|
}
|
||||||
|
|
||||||
|
func etaString(eta time.Time) string {
|
||||||
|
d := eta.Sub(time.Now())
|
||||||
|
if d < time.Second {
|
||||||
|
return "<1s"
|
||||||
|
}
|
||||||
|
// truncate to 1s resolution
|
||||||
|
d /= time.Second
|
||||||
|
d *= time.Second
|
||||||
|
return d.String()
|
||||||
|
}
|
27
grab/grabui/grabui.go
Normal file
@ -0,0 +1,27 @@
package grabui

import (
	"context"

	"github.com/cavaliercoder/grab"
)

func GetBatch(
	ctx context.Context,
	workers int,
	dst string,
	urlStrs ...string,
) (<-chan *grab.Response, error) {
	reqs := make([]*grab.Request, len(urlStrs))
	for i := 0; i < len(urlStrs); i++ {
		req, err := grab.NewRequest(dst, urlStrs[i])
		if err != nil {
			return nil, err
		}
		req = req.WithContext(ctx)
		reqs[i] = req
	}

	ui := NewConsoleClient(grab.DefaultClient)
	return ui.Do(ctx, workers, reqs...), nil
}
12
grab/rate_limiter.go
Normal file
@ -0,0 +1,12 @@
package grab

import "context"

// RateLimiter is an interface that must be satisfied by any third-party rate
// limiters that may be used to limit download transfer speeds.
//
// A recommended token bucket implementation can be found at
// https://godoc.org/golang.org/x/time/rate#Limiter.
type RateLimiter interface {
	WaitN(ctx context.Context, n int) (err error)
}
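The doc comment above points at golang.org/x/time/rate; a minimal sketch (not part of this commit) of plugging that token-bucket limiter into a Request, assuming the x/time/rate module is available and using a placeholder URL:

package main

import (
	"log"

	"github.com/cavaliercoder/grab"
	"golang.org/x/time/rate"
)

func main() {
	req, err := grab.NewRequest("", "http://example.com/example.zip") // placeholder URL
	if err != nil {
		log.Fatal(err)
	}

	// *rate.Limiter implements WaitN(ctx, n) and therefore satisfies grab.RateLimiter.
	// Limit to ~32 KB/s; the burst is sized to the read buffer on the assumption
	// that the client never requests more than BufferSize bytes per WaitN call.
	req.BufferSize = 32 * 1024
	req.RateLimiter = rate.NewLimiter(rate.Limit(32*1024), req.BufferSize)

	resp := grab.DefaultClient.Do(req)
	if err := resp.Err(); err != nil {
		log.Fatal(err)
	}
	log.Println("saved to", resp.Filename)
}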
69
grab/rate_limiter_test.go
Normal file
69
grab/rate_limiter_test.go
Normal file
@ -0,0 +1,69 @@
package grab

import (
	"context"
	"log"
	"os"
	"testing"
	"time"

	"github.com/cavaliercoder/grab/grabtest"
)

// testRateLimiter is a naive rate limiter that limits throughput to r tokens
// per second. The total number of tokens issued is tracked as n.
type testRateLimiter struct {
	r, n int
}

func NewLimiter(r int) RateLimiter {
	return &testRateLimiter{r: r}
}

func (c *testRateLimiter) WaitN(ctx context.Context, n int) (err error) {
	c.n += n
	time.Sleep(
		time.Duration(1.00 / float64(c.r) * float64(n) * float64(time.Second)))
	return
}

func TestRateLimiter(t *testing.T) {
	// download a 128 byte file, 8 bytes at a time, with a naive 512 B/s limiter
	// should take > 250ms
	filesize := 128
	filename := ".testRateLimiter"
	defer os.Remove(filename)

	grabtest.WithTestServer(t, func(url string) {
		// limit to 512 B/s
		lim := &testRateLimiter{r: 512}
		req := mustNewRequest(filename, url)

		// ensure multiple trips to the rate limiter by downloading 8 bytes at a time
		req.BufferSize = 8
		req.RateLimiter = lim

		resp := mustDo(req)
		testComplete(t, resp)
		if lim.n != filesize {
			t.Errorf("expected %d bytes to pass through limiter, got %d", filesize, lim.n)
		}
		if resp.Duration().Seconds() < 0.25 {
			// BUG: this test can pass if the transfer was slow for unrelated reasons
			t.Errorf("expected transfer to take >250ms, took %v", resp.Duration())
		}
	}, grabtest.ContentLength(filesize))
}

func ExampleRateLimiter() {
	req, _ := NewRequest("", "http://www.golang-book.com/public/pdf/gobook.pdf")

	// Attach a 1 MB/s (1,048,576 B/s) rate limiter, like the token bucket
	// implementation from golang.org/x/time/rate.
	req.RateLimiter = NewLimiter(1048576)

	resp := DefaultClient.Do(req)
	if err := resp.Err(); err != nil {
		log.Fatal(err)
	}
}
177
grab/request.go
Normal file
@ -0,0 +1,177 @@
package grab

import (
	"context"
	"hash"
	"net/http"
	"net/url"
)

// A Hook is a user provided callback function that can be called by grab at
// various stages of a request's lifecycle. If a hook returns an error, the
// associated request is canceled and the same error is returned on the Response
// object.
//
// Hook functions are called synchronously and should never block unnecessarily.
// Response methods that block until a download is complete, such as
// Response.Err, Response.Cancel or Response.Wait, will deadlock. To cancel a
// download from a callback, simply return a non-nil error.
type Hook func(*Response) error

// A Request represents an HTTP file transfer request to be sent by a Client.
type Request struct {
	// Label is an arbitrary string which may be used to label a Request with a
	// user-friendly name.
	Label string

	// Tag is an arbitrary interface which may be used to relate a Request to
	// other data.
	Tag interface{}

	// HTTPRequest specifies the http.Request to be sent to the remote server to
	// initiate a file transfer. It includes request configuration such as URL,
	// protocol version, HTTP method, request headers and authentication.
	HTTPRequest *http.Request

	// Filename specifies the path where the file transfer will be stored in
	// local storage. If Filename is empty or a directory, the true Filename will
	// be resolved using Content-Disposition headers or the request URL.
	//
	// An empty string means the transfer will be stored in the current working
	// directory.
	Filename string

	// SkipExisting specifies that ErrFileExists should be returned if the
	// destination path already exists. The existing file will not be checked for
	// completeness.
	SkipExisting bool

	// NoResume specifies that a partially completed download will be restarted
	// without attempting to resume any existing file. If the download is already
	// completed in full, it will not be restarted.
	NoResume bool

	// NoStore specifies that grab should not write to the local file system.
	// Instead, the download will be stored in memory and accessible only via
	// Response.Open or Response.Bytes.
	NoStore bool

	// NoCreateDirectories specifies that any missing directories in the given
	// Filename path should not be created automatically, if they do not already
	// exist.
	NoCreateDirectories bool

	// IgnoreBadStatusCodes specifies that grab should accept any status code in
	// the response from the remote server. Otherwise, grab expects the response
	// status code to be within the 2XX range (after following redirects).
	IgnoreBadStatusCodes bool

	// IgnoreRemoteTime specifies that grab should not attempt to set the
	// timestamp of the local file to match the remote file.
	IgnoreRemoteTime bool

	// Size specifies the expected size of the file transfer if known. If the
	// server response size does not match, the transfer is canceled and
	// ErrBadLength is returned.
	Size int64

	// BufferSize specifies the size in bytes of the buffer that is used for
	// transferring the requested file. Larger buffers may result in faster
	// throughput but will use more memory and result in less frequent updates
	// to the transfer progress statistics. If a RateLimiter is configured,
	// BufferSize should be much lower than the rate limit. Default: 32KB.
	BufferSize int

	// RateLimiter allows the transfer rate of a download to be limited. The given
	// Request.BufferSize determines how frequently the RateLimiter will be
	// polled.
	RateLimiter RateLimiter

	// BeforeCopy is a user provided callback that is called immediately before
	// a request starts downloading. If BeforeCopy returns an error, the request
	// is canceled and the same error is returned on the Response object.
	BeforeCopy Hook

	// AfterCopy is a user provided callback that is called immediately after a
	// request has finished downloading, before checksum validation and closure.
	// This hook is only called if the transfer was successful. If AfterCopy
	// returns an error, the request is canceled and the same error is returned on
	// the Response object.
	AfterCopy Hook

	// hash, checksum and deleteOnError - set via SetChecksum.
	hash          hash.Hash
	checksum      []byte
	deleteOnError bool

	// Context for cancellation and timeout - set via WithContext
	ctx context.Context
}

// NewRequest returns a new file transfer Request suitable for use with
// Client.Do.
func NewRequest(dst, urlStr string) (*Request, error) {
	if dst == "" {
		dst = "."
	}
	req, err := http.NewRequest("GET", urlStr, nil)
	if err != nil {
		return nil, err
	}
	return &Request{
		HTTPRequest: req,
		Filename:    dst,
	}, nil
}

// Context returns the request's context. To change the context, use
// WithContext.
//
// The returned context is always non-nil; it defaults to the background
// context.
//
// The context controls cancelation.
func (r *Request) Context() context.Context {
	if r.ctx != nil {
		return r.ctx
	}

	return context.Background()
}

// WithContext returns a shallow copy of r with its context changed
// to ctx. The provided ctx must be non-nil.
func (r *Request) WithContext(ctx context.Context) *Request {
	if ctx == nil {
		panic("nil context")
	}
	r2 := new(Request)
	*r2 = *r
	r2.ctx = ctx
	r2.HTTPRequest = r2.HTTPRequest.WithContext(ctx)
	return r2
}

// URL returns the URL to be downloaded.
func (r *Request) URL() *url.URL {
	return r.HTTPRequest.URL
}

// SetChecksum sets the desired hashing algorithm and checksum value to validate
// a downloaded file. Once the download is complete, the given hashing algorithm
// will be used to compute the actual checksum of the downloaded file. If the
// checksums do not match, an error will be returned by the associated
// Response.Err method.
//
// If deleteOnError is true, the downloaded file will be deleted automatically
// if it fails checksum validation.
//
// To prevent corruption of the computed checksum, the given hash must not be
// used by any other request or goroutines.
//
// To disable checksum validation, call SetChecksum with a nil hash.
func (r *Request) SetChecksum(h hash.Hash, sum []byte, deleteOnError bool) {
	r.hash = h
	r.checksum = sum
	r.deleteOnError = deleteOnError
}
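SetChecksum pairs naturally with the standard library hash implementations. The following is a hedged sketch of a checksum-verified download; the URL and the all-zero hex digest are placeholders, and the grab import path is assumed to match the one used elsewhere in this commit.

package main

import (
	"crypto/sha256"
	"encoding/hex"
	"log"

	"github.com/cavaliercoder/grab"
)

func main() {
	req, err := grab.NewRequest(".", "http://example.com/release.tar.gz")
	if err != nil {
		log.Fatal(err)
	}

	// Expected SHA-256 digest of the payload (placeholder value).
	sum, err := hex.DecodeString(
		"0000000000000000000000000000000000000000000000000000000000000000")
	if err != nil {
		log.Fatal(err)
	}

	// Delete the downloaded file automatically if the checksum does not match.
	req.SetChecksum(sha256.New(), sum, true)

	resp := grab.DefaultClient.Do(req)
	if err := resp.Err(); err != nil {
		// A mismatch surfaces here as a checksum error.
		log.Fatal(err)
	}
	log.Println("verified and saved to", resp.Filename)
}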
258
grab/response.go
Normal file
@ -0,0 +1,258 @@
package grab

import (
	"bytes"
	"context"
	"io"
	"io/ioutil"
	"net/http"
	"os"
	"sync/atomic"
	"time"
)

// Response represents the response to a completed or in-progress download
// request.
//
// A response may be returned as soon as an HTTP response is received from a
// remote server, but before the body content has started transferring.
//
// All Response method calls are thread-safe.
type Response struct {
	// The Request that was submitted to obtain this Response.
	Request *Request

	// HTTPResponse represents the HTTP response received from an HTTP request.
	//
	// The response Body should not be used as it will be consumed and closed by
	// grab.
	HTTPResponse *http.Response

	// Filename specifies the path where the file transfer is stored in local
	// storage.
	Filename string

	// sizeUnsafe is the total expected size of the file transfer; read it
	// atomically via Size().
	sizeUnsafe int64

	// Start specifies the time at which the file transfer started.
	Start time.Time

	// End specifies the time at which the file transfer completed.
	//
	// This will return zero until the transfer has completed.
	End time.Time

	// CanResume specifies that the remote server advertised that it can resume
	// previous downloads, as the 'Accept-Ranges: bytes' header is set.
	CanResume bool

	// DidResume specifies that the file transfer resumed a previously incomplete
	// transfer.
	DidResume bool

	// Done is closed once the transfer is finalized, either successfully or with
	// errors. Errors are available via Response.Err.
	Done chan struct{}

	// ctx is a Context that controls cancelation of an in-progress transfer.
	ctx context.Context

	// cancel is a cancel func that can be used to cancel the context of this
	// Response.
	cancel context.CancelFunc

	// fi is the FileInfo for the destination file if it already existed before
	// transfer started.
	fi os.FileInfo

	// optionsKnown indicates that a HEAD request has been completed and the
	// capabilities of the remote server are known.
	optionsKnown bool

	// writer is the file handle used to write the downloaded file to local
	// storage.
	writer io.Writer

	// storeBuffer receives the contents of the transfer if Request.NoStore is
	// enabled.
	storeBuffer bytes.Buffer

	// bytesResumed specifies the number of bytes which were already
	// transferred before this transfer began.
	bytesResumed int64

	// transfer is responsible for copying data from the remote server to a local
	// file, tracking progress and allowing for cancelation.
	transfer *transfer

	// bufferSize specifies the size in bytes of the transfer buffer.
	bufferSize int

	// err contains any error that may have occurred during the file transfer.
	// It should not be read until IsComplete returns true.
	err error
}

// IsComplete returns true if the download has completed. If an error occurred
// during the download, it can be returned via Err.
func (c *Response) IsComplete() bool {
	select {
	case <-c.Done:
		return true
	default:
		return false
	}
}

// Cancel cancels the file transfer by canceling the underlying Context for
// this Response. Cancel blocks until the transfer is closed and returns any
// error - typically context.Canceled.
func (c *Response) Cancel() error {
	c.cancel()
	return c.Err()
}

// Wait blocks until the download is completed.
func (c *Response) Wait() {
	<-c.Done
}

// Err blocks the calling goroutine until the underlying file transfer is
// completed and returns any error that may have occurred. If the download is
// already completed, Err returns immediately.
func (c *Response) Err() error {
	<-c.Done
	return c.err
}

// Size returns the size of the file transfer. If the remote server does not
// specify the total size and the transfer is incomplete, the return value is
// -1.
func (c *Response) Size() int64 {
	return atomic.LoadInt64(&c.sizeUnsafe)
}

// BytesComplete returns the total number of bytes which have been copied to
// the destination, including any bytes that were resumed from a previous
// download.
func (c *Response) BytesComplete() int64 {
	return c.bytesResumed + c.transfer.N()
}

// BytesPerSecond returns the number of bytes per second transferred using a
// simple moving average of the last five seconds. If the download is already
// complete, the average bytes/sec for the life of the download is returned.
func (c *Response) BytesPerSecond() float64 {
	if c.IsComplete() {
		return float64(c.transfer.N()) / c.Duration().Seconds()
	}
	return c.transfer.BPS()
}

// Progress returns the ratio of total bytes that have been downloaded. Multiply
// the returned value by 100 to return the percentage completed.
func (c *Response) Progress() float64 {
	size := c.Size()
	if size <= 0 {
		return 0
	}
	return float64(c.BytesComplete()) / float64(size)
}

// Duration returns the duration of a file transfer. If the transfer is in
// process, the duration will be between now and the start of the transfer. If
// the transfer is complete, the duration will be between the start and end of
// the completed transfer process.
func (c *Response) Duration() time.Duration {
	if c.IsComplete() {
		return c.End.Sub(c.Start)
	}

	return time.Now().Sub(c.Start)
}

// ETA returns the estimated time at which the download will complete, given
// the current BytesPerSecond. If the transfer has already completed, the actual
// end time will be returned.
func (c *Response) ETA() time.Time {
	if c.IsComplete() {
		return c.End
	}
	bt := c.BytesComplete()
	bps := c.transfer.BPS()
	if bps == 0 {
		return time.Time{}
	}
	secs := float64(c.Size()-bt) / bps
	return time.Now().Add(time.Duration(secs) * time.Second)
}

// Open blocks the calling goroutine until the underlying file transfer is
// completed and then opens the transferred file for reading. If Request.NoStore
// was enabled, the reader will read from memory.
//
// If an error occurred during the transfer, it will be returned.
//
// It is the caller's responsibility to close the returned file handle.
func (c *Response) Open() (io.ReadCloser, error) {
	if err := c.Err(); err != nil {
		return nil, err
	}
	return c.openUnsafe()
}

func (c *Response) openUnsafe() (io.ReadCloser, error) {
	if c.Request.NoStore {
		return ioutil.NopCloser(bytes.NewReader(c.storeBuffer.Bytes())), nil
	}
	return os.Open(c.Filename)
}

// Bytes blocks the calling goroutine until the underlying file transfer is
// completed and then reads all bytes from the completed transfer. If
// Request.NoStore was enabled, the bytes will be read from memory.
//
// If an error occurred during the transfer, it will be returned.
func (c *Response) Bytes() ([]byte, error) {
	if err := c.Err(); err != nil {
		return nil, err
	}
	if c.Request.NoStore {
		return c.storeBuffer.Bytes(), nil
	}
	f, err := c.Open()
	if err != nil {
		return nil, err
	}
	defer f.Close()
	return ioutil.ReadAll(f)
}

func (c *Response) requestMethod() string {
	if c == nil || c.HTTPResponse == nil || c.HTTPResponse.Request == nil {
		return ""
	}
	return c.HTTPResponse.Request.Method
}

func (c *Response) checksumUnsafe() ([]byte, error) {
	f, err := c.openUnsafe()
	if err != nil {
		return nil, err
	}
	defer f.Close()
	t := newTransfer(c.Request.Context(), nil, c.Request.hash, f, nil)
	if _, err = t.copy(); err != nil {
		return nil, err
	}
	sum := c.Request.hash.Sum(nil)
	return sum, nil
}

func (c *Response) closeResponseBody() error {
	if c.HTTPResponse == nil || c.HTTPResponse.Body == nil {
		return nil
	}
	return c.HTTPResponse.Body.Close()
}
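The Response accessors above (IsComplete, Progress, BytesPerSecond, ETA) are safe to poll from another goroutine while the transfer runs, and Done signals finalization. A hedged sketch of a progress loop follows; the URL is a placeholder and the import path is assumed to match the one used elsewhere in this commit.

package main

import (
	"fmt"
	"log"
	"time"

	"github.com/cavaliercoder/grab"
)

func main() {
	req, err := grab.NewRequest(".", "http://example.com/large.bin")
	if err != nil {
		log.Fatal(err)
	}
	resp := grab.DefaultClient.Do(req)

	// Poll the in-flight transfer; Done is closed once it is finalized.
	t := time.NewTicker(500 * time.Millisecond)
	defer t.Stop()
	for {
		select {
		case <-t.C:
			fmt.Printf("%.1f%% at %.0f B/s, ETA %s\n",
				100*resp.Progress(), resp.BytesPerSecond(), resp.ETA().Format(time.Kitchen))
		case <-resp.Done:
			if err := resp.Err(); err != nil {
				log.Fatal(err)
			}
			fmt.Println("saved to", resp.Filename)
			return
		}
	}
}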
118
grab/response_test.go
Normal file
@ -0,0 +1,118 @@
package grab

import (
	"bytes"
	"os"
	"testing"
	"time"

	"github.com/cavaliercoder/grab/grabtest"
)

// testComplete validates that a completed Response has all the desired fields.
func testComplete(t *testing.T, resp *Response) {
	<-resp.Done
	if !resp.IsComplete() {
		t.Errorf("Response.IsComplete returned false")
	}

	if resp.Start.IsZero() {
		t.Errorf("Response.Start is zero")
	}

	if resp.End.IsZero() {
		t.Error("Response.End is zero")
	}

	if eta := resp.ETA(); eta != resp.End {
		t.Errorf("Response.ETA is not equal to Response.End: %v", eta)
	}

	// the following fields should only be set if no error occurred
	if resp.Err() == nil {
		if resp.Filename == "" {
			t.Errorf("Response.Filename is empty")
		}

		if resp.Size() == 0 {
			t.Error("Response.Size is zero")
		}

		if p := resp.Progress(); p != 1.00 {
			t.Errorf("Response.Progress returned %v (%v/%v bytes), expected 1", p, resp.BytesComplete(), resp.Size())
		}
	}
}

// TestResponseProgress tests the functions which indicate the progress of an
// in-process file transfer.
func TestResponseProgress(t *testing.T) {
	filename := ".testResponseProgress"
	defer os.Remove(filename)

	sleep := 300 * time.Millisecond
	size := 1024 * 8 // bytes

	grabtest.WithTestServer(t, func(url string) {
		// request a slow transfer
		req := mustNewRequest(filename, url)
		resp := DefaultClient.Do(req)

		// make sure transfer has not started
		if resp.IsComplete() {
			t.Errorf("Transfer should not have started")
		}

		if p := resp.Progress(); p != 0 {
			t.Errorf("Transfer should not have started yet but progress is %v", p)
		}

		// wait for transfer to complete
		<-resp.Done

		// make sure transfer is complete
		if p := resp.Progress(); p != 1 {
			t.Errorf("Transfer is complete but progress is %v", p)
		}

		if s := resp.BytesComplete(); s != int64(size) {
			t.Errorf("Expected to transfer %v bytes, got %v", size, s)
		}
	},
		grabtest.TimeToFirstByte(sleep),
		grabtest.ContentLength(size),
	)
}

func TestResponseOpen(t *testing.T) {
	grabtest.WithTestServer(t, func(url string) {
		resp := mustDo(mustNewRequest("", url+"/someFilename"))
		f, err := resp.Open()
		if err != nil {
			t.Error(err)
			return
		}
		defer func() {
			if err := f.Close(); err != nil {
				t.Error(err)
			}
		}()
		grabtest.AssertSHA256Sum(t, grabtest.DefaultHandlerSHA256ChecksumBytes, f)
	})
}

func TestResponseBytes(t *testing.T) {
	grabtest.WithTestServer(t, func(url string) {
		resp := mustDo(mustNewRequest("", url+"/someFilename"))
		b, err := resp.Bytes()
		if err != nil {
			t.Error(err)
			return
		}
		grabtest.AssertSHA256Sum(
			t,
			grabtest.DefaultHandlerSHA256ChecksumBytes,
			bytes.NewReader(b),
		)
	})
}
111
grab/states.wsd
Normal file
@ -0,0 +1,111 @@
@startuml
title Grab transfer state

legend
| # | Meaning |
| D | Destination path known |
| S | File size known |
| O | Server options known (Accept-Ranges) |
| R | Resume supported (Accept-Ranges) |
| Z | Local file empty or missing |
| P | Local file partially complete |
endlegend

[*] --> Empty
[*] --> D
[*] --> S
[*] --> DS

Empty : Filename: ""
Empty : Size: 0
Empty --> O : HEAD: Method not allowed
Empty --> DSO : HEAD: Range not supported
Empty --> DSOR : HEAD: Range supported

DS : Filename: "foo.bar"
DS : Size: > 0
DS --> DSZ : checkExisting(): File missing
DS --> DSP : checkExisting(): File partial
DS --> [*] : checkExisting(): File complete
DS --> ERROR

S : Filename: ""
S : Size: > 0
S --> SO : HEAD: Method not allowed
S --> DSO : HEAD: Range not supported
S --> DSOR : HEAD: Range supported

D : Filename: "foo.bar"
D : Size: 0
D --> DO : HEAD: Method not allowed
D --> DSO : HEAD: Range not supported
D --> DSOR : HEAD: Range supported


O : Filename: ""
O : Size: 0
O : CanResume: false
O --> DSO : GET 200
O --> ERROR

SO : Filename: ""
SO : Size: > 0
SO : CanResume: false
SO --> DSO : GET: 200
SO --> ERROR

DO : Filename: "foo.bar"
DO : Size: 0
DO : CanResume: false
DO --> DSO : GET 200
DO --> ERROR

DSZ : Filename: "foo.bar"
DSZ : Size: > 0
DSZ : File: empty
DSZ --> DSORZ : HEAD: Range supported
DSZ --> DSOZ : HEAD 405 or Range unsupported

DSP : Filename: "foo.bar"
DSP : Size: > 0
DSP : File: partial
DSP --> DSORP : HEAD: Range supported
DSP --> DSOZ : HEAD: 405 or Range unsupported

DSO : Filename: "foo.bar"
DSO : Size: > 0
DSO : CanResume: false
DSO --> DSOZ : checkExisting(): File partial|missing
DSO --> [*] : checkExisting(): File complete

DSOR : Filename: "foo.bar"
DSOR : Size: > 0
DSOR : CanResume: true
DSOR --> DSORP : CheckLocal: File partial
DSOR --> DSORZ : CheckLocal: File missing

DSORP : Filename: "foo.bar"
DSORP : Size: > 0
DSORP : CanResume: true
DSORP : File: partial
DSORP --> Transferring

DSORZ : Filename: "foo.bar"
DSORZ : Size: > 0
DSORZ : CanResume: true
DSORZ : File: empty
DSORZ --> Transferring

DSOZ : Filename: "foo.bar"
DSOZ : Size: > 0
DSOZ : CanResume: false
DSOZ : File: empty
DSOZ --> Transferring

Transferring --> [*]
Transferring --> ERROR

ERROR : Something went wrong
ERROR --> [*]

@enduml
103
grab/transfer.go
Normal file
@ -0,0 +1,103 @@
package grab

import (
	"context"
	"io"
	"sync/atomic"
	"time"

	"github.com/cavaliercoder/grab/bps"
)

type transfer struct {
	n     int64 // must be 64bit aligned on 386
	ctx   context.Context
	gauge bps.Gauge
	lim   RateLimiter
	w     io.Writer
	r     io.Reader
	b     []byte
}

func newTransfer(ctx context.Context, lim RateLimiter, dst io.Writer, src io.Reader, buf []byte) *transfer {
	return &transfer{
		ctx:   ctx,
		gauge: bps.NewSMA(6), // five second moving average sampling every second
		lim:   lim,
		w:     dst,
		r:     src,
		b:     buf,
	}
}

// copy behaves similarly to io.CopyBuffer except that it checks for cancelation
// of the given context.Context, reports progress in a thread-safe manner and
// tracks the transfer rate.
func (c *transfer) copy() (written int64, err error) {
	// maintain a bps gauge in another goroutine
	ctx, cancel := context.WithCancel(c.ctx)
	defer cancel()
	go bps.Watch(ctx, c.gauge, c.N, time.Second)

	// start the transfer
	if c.b == nil {
		c.b = make([]byte, 32*1024)
	}
	for {
		select {
		case <-c.ctx.Done():
			err = c.ctx.Err()
			return
		default:
			// keep working
		}
		nr, er := c.r.Read(c.b)
		if nr > 0 {
			nw, ew := c.w.Write(c.b[0:nr])
			if nw > 0 {
				written += int64(nw)
				atomic.StoreInt64(&c.n, written)
			}
			if ew != nil {
				err = ew
				break
			}
			if nr != nw {
				err = io.ErrShortWrite
				break
			}
			// wait for rate limiter
			if c.lim != nil {
				err = c.lim.WaitN(c.ctx, nr)
				if err != nil {
					return
				}
			}
		}
		if er != nil {
			if er != io.EOF {
				err = er
			}
			break
		}
	}
	return written, err
}

// N returns the number of bytes transferred.
func (c *transfer) N() (n int64) {
	if c == nil {
		return 0
	}
	n = atomic.LoadInt64(&c.n)
	return
}

// BPS returns the current bytes per second transfer rate using a simple moving
// average.
func (c *transfer) BPS() (bps float64) {
	if c == nil || c.gauge == nil {
		return 0
	}
	return c.gauge.BPS()
}
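Because copy calls WaitN once per buffer read with n equal to the bytes just written, the request's BufferSize sets the cadence of the rate limiter, which is why the Request docs suggest keeping BufferSize well below the rate limit. A small back-of-the-envelope sketch (not part of the commit) reproduces the timing expected by TestRateLimiter under those assumptions.

package main

import (
	"fmt"
	"time"
)

func main() {
	// Cadence of the copy loop above: with an 8-byte buffer and a 512 B/s
	// limiter, WaitN(ctx, 8) is called once per iteration and each call
	// should block for roughly n/r seconds.
	const (
		bufferSize = 8   // bytes read per iteration (Request.BufferSize)
		limit      = 512 // bytes per second allowed by the limiter
		filesize   = 128 // bytes, as in TestRateLimiter
	)
	perWait := time.Duration(float64(bufferSize) / float64(limit) * float64(time.Second))
	total := time.Duration(filesize/bufferSize) * perWait
	fmt.Println(perWait, total) // 15.625ms 250ms
}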
75
grab/util.go
Normal file
@ -0,0 +1,75 @@
package grab

import (
	"fmt"
	"mime"
	"net/http"
	"os"
	"path"
	"path/filepath"
	"time"

	"github.com/lordwelch/pathvalidate"
)

// setLastModified sets the last modified timestamp of a local file according to
// the Last-Modified header returned by a remote server.
func setLastModified(resp *http.Response, filename string) error {
	// https://tools.ietf.org/html/rfc7232#section-2.2
	// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Last-Modified
	header := resp.Header.Get("Last-Modified")
	if header == "" {
		return nil
	}
	lastmod, err := time.Parse(http.TimeFormat, header)
	if err != nil {
		return nil
	}
	return os.Chtimes(filename, lastmod, lastmod)
}

// mkdirp creates all missing parent directories for the destination file path.
func mkdirp(path string) error {
	dir := filepath.Dir(path)
	if fi, err := os.Stat(dir); err != nil {
		if !os.IsNotExist(err) {
			return fmt.Errorf("error checking destination directory: %v", err)
		}
		if err := os.MkdirAll(dir, 0777); err != nil {
			return fmt.Errorf("error creating destination directory: %v", err)
		}
	} else if !fi.IsDir() {
		panic("grab: developer error: destination path is not directory")
	}
	return nil
}

// guessFilename returns a filename for the given http.Response. If none can be
// determined ErrNoFilename is returned.
//
// TODO: NoStore operations should not require a filename
func guessFilename(resp *http.Response) (string, error) {
	var (
		err      error
		filename = resp.Request.URL.Path
	)
	if cd := resp.Header.Get("Content-Disposition"); cd != "" {
		if _, params, err := mime.ParseMediaType(cd); err == nil {
			if val, ok := params["filename"]; ok {
				filename = val
			} // else filename directive is missing. fallback to URL.Path
		}
	}

	// sanitize
	filename, err = pathvalidate.SanitizeFilename(path.Base(filename), '_')
	if err != nil {
		return "", fmt.Errorf("%w: %v", ErrNoFilename, err)
	}

	if filename == "" || filename == "." || filename == "/" {
		return "", ErrNoFilename
	}

	return filename, nil
}
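The filename resolution above leans on mime.ParseMediaType before sanitizing with pathvalidate. A standalone sketch of that same parsing step follows; the header value is a placeholder and the snippet is only meant to show what params["filename"] yields.

package main

import (
	"fmt"
	"mime"
)

func main() {
	// The parsing step guessFilename relies on: extract the filename
	// parameter from a Content-Disposition header, if present.
	cd := `attachment; filename="report-2021.pdf"`
	mediatype, params, err := mime.ParseMediaType(cd)
	if err != nil {
		fmt.Println("fall back to the URL path:", err)
		return
	}
	fmt.Println(mediatype) // attachment
	if name, ok := params["filename"]; ok {
		fmt.Println(name) // report-2021.pdf
	}
}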
167
grab/util_test.go
Normal file
@ -0,0 +1,167 @@
package grab

import (
	"fmt"
	"net/http"
	"net/url"
	"testing"
)

func TestURLFilenames(t *testing.T) {
	t.Run("Valid", func(t *testing.T) {
		expect := "filename"
		testCases := []string{
			"http://test.com/filename",
			"http://test.com/path/filename",
			"http://test.com/deep/path/filename",
			"http://test.com/filename?with=args",
			"http://test.com/filename#with-fragment",
			"http://test.com/filename?with=args&and#with-fragment",
		}

		for _, tc := range testCases {
			req, _ := http.NewRequest("GET", tc, nil)
			resp := &http.Response{
				Request: req,
			}
			actual, err := guessFilename(resp)
			if err != nil {
				t.Errorf("%v", err)
			}

			if actual != expect {
				t.Errorf("expected '%v', got '%v'", expect, actual)
			}
		}
	})

	t.Run("Invalid", func(t *testing.T) {
		testCases := []string{
			"http://test.com",
			"http://test.com/",
			"http://test.com/filename/",
			"http://test.com/filename/?with=args",
			"http://test.com/filename/#with-fragment",
			"http://test.com/filename\x00",
		}

		for _, tc := range testCases {
			t.Run(tc, func(t *testing.T) {
				req, err := http.NewRequest("GET", tc, nil)
				if err != nil {
					if tc == "http://test.com/filename\x00" {
						// Since go1.12, urls with invalid control character return an error
						// See https://github.com/golang/go/commit/829c5df58694b3345cb5ea41206783c8ccf5c3ca
						t.Skip()
					}
				}
				resp := &http.Response{
					Request: req,
				}

				_, err = guessFilename(resp)
				if err != ErrNoFilename {
					t.Errorf("expected '%v', got '%v'", ErrNoFilename, err)
				}
			})
		}
	})
}

func TestHeaderFilenames(t *testing.T) {
	u, _ := url.ParseRequestURI("http://test.com/badfilename")
	resp := &http.Response{
		Request: &http.Request{
			URL: u,
		},
		Header: http.Header{},
	}

	setFilename := func(resp *http.Response, filename string) {
		resp.Header.Set("Content-Disposition", fmt.Sprintf("attachment;filename=\"%s\"", filename))
	}

	t.Run("Valid", func(t *testing.T) {
		expect := "filename"
		testCases := []string{
			"filename",
			"path/filename",
			"/path/filename",
			"../../filename",
			"/path/../../filename",
			"/../../././///filename",
		}

		for _, tc := range testCases {
			setFilename(resp, tc)
			actual, err := guessFilename(resp)
			if err != nil {
				t.Errorf("error (%v): %v", tc, err)
			}

			if actual != expect {
				t.Errorf("expected '%v' (%v), got '%v'", expect, tc, actual)
			}
		}
	})

	t.Run("Invalid", func(t *testing.T) {
		testCases := []string{
			"",
			"/",
			".",
			"/.",
			"/./",
			"..",
			"../",
			"/../",
			"/path/",
			"../path/",
			"filename\x00",
			"filename/",
			"filename//",
			"filename/..",
		}

		for _, tc := range testCases {
			setFilename(resp, tc)
			if actual, err := guessFilename(resp); err != ErrNoFilename {
				t.Errorf("expected: %v (%v), got: %v (%v)", ErrNoFilename, tc, err, actual)
			}
		}
	})
}

func TestHeaderWithMissingDirective(t *testing.T) {
	u, _ := url.ParseRequestURI("http://test.com/filename")
	resp := &http.Response{
		Request: &http.Request{
			URL: u,
		},
		Header: http.Header{},
	}

	setHeader := func(resp *http.Response, value string) {
		resp.Header.Set("Content-Disposition", value)
	}

	t.Run("Valid", func(t *testing.T) {
		expect := "filename"
		testCases := []string{
			"inline",
			"attachment",
		}

		for _, tc := range testCases {
			setHeader(resp, tc)
			actual, err := guessFilename(resp)
			if err != nil {
				t.Errorf("error (%v): %v", tc, err)
			}

			if actual != expect {
				t.Errorf("expected '%v' (%v), got '%v'", expect, tc, actual)
			}
		}
	})
}