Commit custom grab
This commit is contained in:
parent
17d26242d2
commit
f1179ff06e
3
grab/.gitignore
vendored
Normal file
3
grab/.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
# ignore IDE project files
|
||||
*.iml
|
||||
.idea/
|
14
grab/.travis.yml
Normal file
14
grab/.travis.yml
Normal file
@ -0,0 +1,14 @@
|
||||
language: go
|
||||
|
||||
go:
|
||||
- tip
|
||||
- 1.10.x
|
||||
- 1.9.x
|
||||
- 1.8.x
|
||||
- 1.7.x
|
||||
|
||||
script: make check
|
||||
|
||||
env:
|
||||
- GOARCH=amd64
|
||||
- GOARCH=386
|
26
grab/LICENSE
Normal file
26
grab/LICENSE
Normal file
@ -0,0 +1,26 @@
|
||||
Copyright (c) 2017 Ryan Armstrong. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification,
|
||||
are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its contributors
|
||||
may be used to endorse or promote products derived from this software without
|
||||
specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
|
||||
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
|
||||
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
29
grab/Makefile
Normal file
29
grab/Makefile
Normal file
@ -0,0 +1,29 @@
|
||||
GO = go
|
||||
GOGET = $(GO) get -u
|
||||
|
||||
all: check lint
|
||||
|
||||
check:
|
||||
cd cmd/grab && $(MAKE) -B all
|
||||
$(GO) test -cover -race ./...
|
||||
|
||||
install:
|
||||
$(GO) install -v ./...
|
||||
|
||||
clean:
|
||||
$(GO) clean -x ./...
|
||||
rm -rvf ./.test*
|
||||
|
||||
lint:
|
||||
gofmt -l -e -s . || :
|
||||
go vet . || :
|
||||
golint . || :
|
||||
gocyclo -over 15 . || :
|
||||
misspell ./* || :
|
||||
|
||||
deps:
|
||||
$(GOGET) github.com/golang/lint/golint
|
||||
$(GOGET) github.com/fzipp/gocyclo
|
||||
$(GOGET) github.com/client9/misspell/cmd/misspell
|
||||
|
||||
.PHONY: all check install clean lint deps
|
127
grab/README.md
Normal file
127
grab/README.md
Normal file
@ -0,0 +1,127 @@
|
||||
# grab
|
||||
|
||||
[](https://godoc.org/github.com/cavaliercoder/grab) [](https://travis-ci.org/cavaliercoder/grab) [](https://goreportcard.com/report/github.com/cavaliercoder/grab)
|
||||
|
||||
*Downloading the internet, one goroutine at a time!*
|
||||
|
||||
$ go get github.com/cavaliercoder/grab
|
||||
|
||||
Grab is a Go package for downloading files from the internet with the following
|
||||
rad features:
|
||||
|
||||
* Monitor download progress concurrently
|
||||
* Auto-resume incomplete downloads
|
||||
* Guess filename from content header or URL path
|
||||
* Safely cancel downloads using context.Context
|
||||
* Validate downloads using checksums
|
||||
* Download batches of files concurrently
|
||||
* Apply rate limiters
|
||||
|
||||
Requires Go v1.7+
|
||||
|
||||
## Example
|
||||
|
||||
The following example downloads a PDF copy of the free eBook, "An Introduction
|
||||
to Programming in Go" into the current working directory.
|
||||
|
||||
```go
|
||||
resp, err := grab.Get(".", "http://www.golang-book.com/public/pdf/gobook.pdf")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Println("Download saved to", resp.Filename)
|
||||
```
|
||||
|
||||
The following, more complete example allows for more granular control and
|
||||
periodically prints the download progress until it is complete.
|
||||
|
||||
The second time you run the example, it will auto-resume the previous download
|
||||
and exit sooner.
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/cavaliercoder/grab"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// create client
|
||||
client := grab.NewClient()
|
||||
req, _ := grab.NewRequest(".", "http://www.golang-book.com/public/pdf/gobook.pdf")
|
||||
|
||||
// start download
|
||||
fmt.Printf("Downloading %v...\n", req.URL())
|
||||
resp := client.Do(req)
|
||||
fmt.Printf(" %v\n", resp.HTTPResponse.Status)
|
||||
|
||||
// start UI loop
|
||||
t := time.NewTicker(500 * time.Millisecond)
|
||||
defer t.Stop()
|
||||
|
||||
Loop:
|
||||
for {
|
||||
select {
|
||||
case <-t.C:
|
||||
fmt.Printf(" transferred %v / %v bytes (%.2f%%)\n",
|
||||
resp.BytesComplete(),
|
||||
resp.Size(),
|
||||
100*resp.Progress())
|
||||
|
||||
case <-resp.Done:
|
||||
// download is complete
|
||||
break Loop
|
||||
}
|
||||
}
|
||||
|
||||
// check for errors
|
||||
if err := resp.Err(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Download failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Printf("Download saved to ./%v \n", resp.Filename)
|
||||
|
||||
// Output:
|
||||
// Downloading http://www.golang-book.com/public/pdf/gobook.pdf...
|
||||
// 200 OK
|
||||
// transferred 42970 / 2893557 bytes (1.49%)
|
||||
// transferred 1207474 / 2893557 bytes (41.73%)
|
||||
// transferred 2758210 / 2893557 bytes (95.32%)
|
||||
// Download saved to ./gobook.pdf
|
||||
}
|
||||
```
|
||||
|
||||
## Design trade-offs
|
||||
|
||||
The primary use case for Grab is to concurrently downloading thousands of large
|
||||
files from remote file repositories where the remote files are immutable.
|
||||
Examples include operating system package repositories or ISO libraries.
|
||||
|
||||
Grab aims to provide robust, sane defaults. These are usually determined using
|
||||
the HTTP specifications, or by mimicking the behavior of common web clients like
|
||||
cURL, wget and common web browsers.
|
||||
|
||||
Grab aims to be stateless. The only state that exists is the remote files you
|
||||
wish to download and the local copy which may be completed, partially completed
|
||||
or not yet created. The advantage to this is that the local file system is not
|
||||
cluttered unnecessarily with addition state files (like a `.crdownload` file).
|
||||
The disadvantage of this approach is that grab must make assumptions about the
|
||||
local and remote state; specifically, that they have not been modified by
|
||||
another program.
|
||||
|
||||
If the local or remote file are modified outside of grab, and you download the
|
||||
file again with resuming enabled, the local file will likely become corrupted.
|
||||
In this case, you might consider making remote files immutable, or disabling
|
||||
resume.
|
||||
|
||||
Grab aims to enable best-in-class functionality for more complex features
|
||||
through extensible interfaces, rather than reimplementation. For example,
|
||||
you can provide your own Hash algorithm to compute file checksums, or your
|
||||
own rate limiter implementation (with all the associated trade-offs) to rate
|
||||
limit downloads.
|
54
grab/bps/bps.go
Normal file
54
grab/bps/bps.go
Normal file
@ -0,0 +1,54 @@
|
||||
/*
|
||||
Package bps provides gauges for calculating the Bytes Per Second transfer rate
|
||||
of data streams.
|
||||
*/
|
||||
package bps
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Gauge is the common interface for all BPS gauges in this package. Given a
|
||||
// set of samples over time, each gauge type can be used to measure the Bytes
|
||||
// Per Second transfer rate of a data stream.
|
||||
//
|
||||
// All samples must monotonically increase in timestamp and value. Each sample
|
||||
// should represent the total number of bytes sent in a stream, rather than
|
||||
// accounting for the number sent since the last sample.
|
||||
//
|
||||
// To ensure a gauge can report progress as quickly as possible, take an initial
|
||||
// sample when your stream first starts.
|
||||
//
|
||||
// All gauge implementations are safe for concurrent use.
|
||||
type Gauge interface {
|
||||
// Sample adds a new sample of the progress of the monitored stream.
|
||||
Sample(t time.Time, n int64)
|
||||
|
||||
// BPS returns the calculated Bytes Per Second rate of the monitored stream.
|
||||
BPS() float64
|
||||
}
|
||||
|
||||
// SampleFunc is used by Watch to take periodic samples of a monitored stream.
|
||||
type SampleFunc func() (n int64)
|
||||
|
||||
// Watch will periodically call the given SampleFunc to sample the progress of
|
||||
// a monitored stream and update the given gauge. SampleFunc should return the
|
||||
// total number of bytes transferred by the stream since it started.
|
||||
//
|
||||
// Watch is a blocking call and should typically be called in a new goroutine.
|
||||
// To prevent the goroutine from leaking, make sure to cancel the given context
|
||||
// once the stream is completed or canceled.
|
||||
func Watch(ctx context.Context, g Gauge, f SampleFunc, interval time.Duration) {
|
||||
g.Sample(time.Now(), f())
|
||||
t := time.NewTicker(interval)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case now := <-t.C:
|
||||
g.Sample(now, f())
|
||||
}
|
||||
}
|
||||
}
|
81
grab/bps/sma.go
Normal file
81
grab/bps/sma.go
Normal file
@ -0,0 +1,81 @@
|
||||
package bps
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NewSMA returns a gauge that uses a Simple Moving Average with the given
|
||||
// number of samples to measure the bytes per second of a byte stream.
|
||||
//
|
||||
// BPS is computed using the timestamp of the most recent and oldest sample in
|
||||
// the sample buffer. When a new sample is added, the oldest sample is dropped
|
||||
// if the sample count exceeds maxSamples.
|
||||
//
|
||||
// The gauge does not account for any latency in arrival time of new samples or
|
||||
// the desired window size. Any variance in the arrival of samples will result
|
||||
// in a BPS measurement that is correct for the submitted samples, but over a
|
||||
// varying time window.
|
||||
//
|
||||
// maxSamples should be equal to 1 + (window size / sampling interval) where
|
||||
// window size is the number of seconds over which the moving average is
|
||||
// smoothed and sampling interval is the number of seconds between each sample.
|
||||
//
|
||||
// For example, if you want a five second window, sampling once per second,
|
||||
// maxSamples should be 1 + 5/1 = 6.
|
||||
func NewSMA(maxSamples int) Gauge {
|
||||
if maxSamples < 2 {
|
||||
panic("sample count must be greater than 1")
|
||||
}
|
||||
return &sma{
|
||||
maxSamples: uint64(maxSamples),
|
||||
samples: make([]int64, maxSamples),
|
||||
timestamps: make([]time.Time, maxSamples),
|
||||
}
|
||||
}
|
||||
|
||||
type sma struct {
|
||||
mu sync.Mutex
|
||||
index uint64
|
||||
maxSamples uint64
|
||||
sampleCount uint64
|
||||
samples []int64
|
||||
timestamps []time.Time
|
||||
}
|
||||
|
||||
func (c *sma) Sample(t time.Time, n int64) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.timestamps[c.index] = t
|
||||
c.samples[c.index] = n
|
||||
c.index = (c.index + 1) % c.maxSamples
|
||||
|
||||
// prevent integer overflow in sampleCount. Values greater or equal to
|
||||
// maxSamples have the same semantic meaning.
|
||||
c.sampleCount++
|
||||
if c.sampleCount > c.maxSamples {
|
||||
c.sampleCount = c.maxSamples
|
||||
}
|
||||
}
|
||||
|
||||
func (c *sma) BPS() float64 {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// we need two samples to start
|
||||
if c.sampleCount < 2 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// First sample is always the oldest until ring buffer first overflows
|
||||
oldest := c.index
|
||||
if c.sampleCount < c.maxSamples {
|
||||
oldest = 0
|
||||
}
|
||||
|
||||
newest := (c.index + c.maxSamples - 1) % c.maxSamples
|
||||
seconds := c.timestamps[newest].Sub(c.timestamps[oldest]).Seconds()
|
||||
bytes := float64(c.samples[newest] - c.samples[oldest])
|
||||
return bytes / seconds
|
||||
}
|
55
grab/bps/sma_test.go
Normal file
55
grab/bps/sma_test.go
Normal file
@ -0,0 +1,55 @@
|
||||
package bps
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Sample struct {
|
||||
N int64
|
||||
Expect float64
|
||||
}
|
||||
|
||||
func getSimpleSamples(sampleCount, rate int) []Sample {
|
||||
a := make([]Sample, sampleCount)
|
||||
for i := 1; i < sampleCount; i++ {
|
||||
a[i] = Sample{N: int64(i * rate), Expect: float64(rate)}
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
type SampleSetTest struct {
|
||||
Gauge Gauge
|
||||
Interval time.Duration
|
||||
Samples []Sample
|
||||
}
|
||||
|
||||
func (c *SampleSetTest) Run(t *testing.T) {
|
||||
ts := time.Unix(0, 0)
|
||||
for i, sample := range c.Samples {
|
||||
c.Gauge.Sample(ts, sample.N)
|
||||
if actual := c.Gauge.BPS(); actual != sample.Expect {
|
||||
t.Errorf("expected: Gauge.BPS() → %0.2f, got %0.2f in test %d", sample.Expect, actual, i+1)
|
||||
}
|
||||
ts = ts.Add(c.Interval)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSMA_SimpleSteadyCase(t *testing.T) {
|
||||
test := &SampleSetTest{
|
||||
Interval: time.Second,
|
||||
Samples: getSimpleSamples(100000, 3),
|
||||
}
|
||||
t.Run("SmallSampleSize", func(t *testing.T) {
|
||||
test.Gauge = NewSMA(2)
|
||||
test.Run(t)
|
||||
})
|
||||
t.Run("RegularSize", func(t *testing.T) {
|
||||
test.Gauge = NewSMA(6)
|
||||
test.Run(t)
|
||||
})
|
||||
t.Run("LargeSampleSize", func(t *testing.T) {
|
||||
test.Gauge = NewSMA(1000)
|
||||
test.Run(t)
|
||||
})
|
||||
}
|
570
grab/client.go
Normal file
570
grab/client.go
Normal file
@ -0,0 +1,570 @@
|
||||
package grab
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HTTPClient provides an interface allowing us to perform HTTP requests.
|
||||
type HTTPClient interface {
|
||||
Do(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
// truncater is a private interface allowing different response
|
||||
// Writers to be truncated
|
||||
type truncater interface {
|
||||
Truncate(size int64) error
|
||||
}
|
||||
|
||||
// A Client is a file download client.
|
||||
//
|
||||
// Clients are safe for concurrent use by multiple goroutines.
|
||||
type Client struct {
|
||||
// HTTPClient specifies the http.Client which will be used for communicating
|
||||
// with the remote server during the file transfer.
|
||||
HTTPClient HTTPClient
|
||||
|
||||
// UserAgent specifies the User-Agent string which will be set in the
|
||||
// headers of all requests made by this client.
|
||||
//
|
||||
// The user agent string may be overridden in the headers of each request.
|
||||
UserAgent string
|
||||
|
||||
// BufferSize specifies the size in bytes of the buffer that is used for
|
||||
// transferring all requested files. Larger buffers may result in faster
|
||||
// throughput but will use more memory and result in less frequent updates
|
||||
// to the transfer progress statistics. The BufferSize of each request can
|
||||
// be overridden on each Request object. Default: 32KB.
|
||||
BufferSize int
|
||||
}
|
||||
|
||||
// NewClient returns a new file download Client, using default configuration.
|
||||
func NewClient() *Client {
|
||||
return &Client{
|
||||
UserAgent: "grab",
|
||||
HTTPClient: &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultClient is the default client and is used by all Get convenience
|
||||
// functions.
|
||||
var DefaultClient = NewClient()
|
||||
|
||||
// Do sends a file transfer request and returns a file transfer response,
|
||||
// following policy (e.g. redirects, cookies, auth) as configured on the
|
||||
// client's HTTPClient.
|
||||
//
|
||||
// Like http.Get, Do blocks while the transfer is initiated, but returns as soon
|
||||
// as the transfer has started transferring in a background goroutine, or if it
|
||||
// failed early.
|
||||
//
|
||||
// An error is returned via Response.Err if caused by client policy (such as
|
||||
// CheckRedirect), or if there was an HTTP protocol or IO error. Response.Err
|
||||
// will block the caller until the transfer is completed, successfully or
|
||||
// otherwise.
|
||||
func (c *Client) Do(req *Request) *Response {
|
||||
// cancel will be called on all code-paths via closeResponse
|
||||
ctx, cancel := context.WithCancel(req.Context())
|
||||
req = req.WithContext(ctx)
|
||||
resp := &Response{
|
||||
Request: req,
|
||||
Start: time.Now(),
|
||||
Done: make(chan struct{}, 0),
|
||||
Filename: req.Filename,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
bufferSize: req.BufferSize,
|
||||
}
|
||||
if resp.bufferSize == 0 {
|
||||
// default to Client.BufferSize
|
||||
resp.bufferSize = c.BufferSize
|
||||
}
|
||||
|
||||
// Run state-machine while caller is blocked to initialize the file transfer.
|
||||
// Must never transition to the copyFile state - this happens next in another
|
||||
// goroutine.
|
||||
c.run(resp, c.statFileInfo)
|
||||
|
||||
// Run copyFile in a new goroutine. copyFile will no-op if the transfer is
|
||||
// already complete or failed.
|
||||
go c.run(resp, c.copyFile)
|
||||
return resp
|
||||
}
|
||||
|
||||
// DoChannel executes all requests sent through the given Request channel, one
|
||||
// at a time, until it is closed by another goroutine. The caller is blocked
|
||||
// until the Request channel is closed and all transfers have completed. All
|
||||
// responses are sent through the given Response channel as soon as they are
|
||||
// received from the remote servers and can be used to track the progress of
|
||||
// each download.
|
||||
//
|
||||
// Slow Response receivers will cause a worker to block and therefore delay the
|
||||
// start of the transfer for an already initiated connection - potentially
|
||||
// causing a server timeout. It is the caller's responsibility to ensure a
|
||||
// sufficient buffer size is used for the Response channel to prevent this.
|
||||
//
|
||||
// If an error occurs during any of the file transfers it will be accessible via
|
||||
// the associated Response.Err function.
|
||||
func (c *Client) DoChannel(reqch <-chan *Request, respch chan<- *Response) {
|
||||
// TODO: enable cancelling of batch jobs
|
||||
for req := range reqch {
|
||||
resp := c.Do(req)
|
||||
respch <- resp
|
||||
<-resp.Done
|
||||
}
|
||||
}
|
||||
|
||||
// DoBatch executes all the given requests using the given number of concurrent
|
||||
// workers. Control is passed back to the caller as soon as the workers are
|
||||
// initiated.
|
||||
//
|
||||
// If the requested number of workers is less than one, a worker will be created
|
||||
// for every request. I.e. all requests will be executed concurrently.
|
||||
//
|
||||
// If an error occurs during any of the file transfers it will be accessible via
|
||||
// call to the associated Response.Err.
|
||||
//
|
||||
// The returned Response channel is closed only after all of the given Requests
|
||||
// have completed, successfully or otherwise.
|
||||
func (c *Client) DoBatch(workers int, requests ...*Request) <-chan *Response {
|
||||
if workers < 1 {
|
||||
workers = len(requests)
|
||||
}
|
||||
reqch := make(chan *Request, len(requests))
|
||||
respch := make(chan *Response, len(requests))
|
||||
wg := sync.WaitGroup{}
|
||||
for i := 0; i < workers; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
c.DoChannel(reqch, respch)
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
// queue requests
|
||||
go func() {
|
||||
for _, req := range requests {
|
||||
reqch <- req
|
||||
}
|
||||
close(reqch)
|
||||
wg.Wait()
|
||||
close(respch)
|
||||
}()
|
||||
return respch
|
||||
}
|
||||
|
||||
// An stateFunc is an action that mutates the state of a Response and returns
|
||||
// the next stateFunc to be called.
|
||||
type stateFunc func(*Response) stateFunc
|
||||
|
||||
// run calls the given stateFunc function and all subsequent returned stateFuncs
|
||||
// until a stateFunc returns nil or the Response.ctx is canceled. Each stateFunc
|
||||
// should mutate the state of the given Response until it has completed
|
||||
// downloading or failed.
|
||||
func (c *Client) run(resp *Response, f stateFunc) {
|
||||
for {
|
||||
select {
|
||||
case <-resp.ctx.Done():
|
||||
if resp.IsComplete() {
|
||||
return
|
||||
}
|
||||
resp.err = resp.ctx.Err()
|
||||
f = c.closeResponse
|
||||
|
||||
default:
|
||||
// keep working
|
||||
}
|
||||
if f = f(resp); f == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// statFileInfo retrieves FileInfo for any local file matching
|
||||
// Response.Filename.
|
||||
//
|
||||
// If the file does not exist, is a directory, or its name is unknown the next
|
||||
// stateFunc is headRequest.
|
||||
//
|
||||
// If the file exists, Response.fi is set and the next stateFunc is
|
||||
// validateLocal.
|
||||
//
|
||||
// If an error occurs, the next stateFunc is closeResponse.
|
||||
func (c *Client) statFileInfo(resp *Response) stateFunc {
|
||||
if resp.Request.NoStore || resp.Filename == "" { // No filename provided will guess
|
||||
return c.headRequest
|
||||
}
|
||||
fi, err := os.Stat(resp.Filename)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) { // Filename does not exist, will download
|
||||
return c.headRequest
|
||||
}
|
||||
resp.err = err
|
||||
return c.closeResponse // Other PathError occured
|
||||
}
|
||||
if fi.IsDir() { // resp.Request.Filename is a directory
|
||||
// Will guess filename and append to resp.Request.Filename
|
||||
resp.Filename = ""
|
||||
return c.headRequest
|
||||
}
|
||||
resp.fi = fi
|
||||
return c.validateLocal // Filename exists, validate file
|
||||
}
|
||||
|
||||
// validateLocal compares a local copy of the downloaded file to the remote
|
||||
// file.
|
||||
//
|
||||
// An error is returned if the local file is larger than the remote file, or
|
||||
// Request.SkipExisting is true.
|
||||
//
|
||||
// If the existing file matches the length of the remote file, the next
|
||||
// stateFunc is checksumFile.
|
||||
//
|
||||
// If the local file is smaller than the remote file and the remote server is
|
||||
// known to support ranged requests, the next stateFunc is getRequest.
|
||||
func (c *Client) validateLocal(resp *Response) stateFunc {
|
||||
if resp.Request.SkipExisting {
|
||||
resp.err = ErrFileExists
|
||||
return c.closeResponse
|
||||
}
|
||||
|
||||
// determine target file size
|
||||
expectedSize := resp.Request.Size
|
||||
if expectedSize == 0 && resp.HTTPResponse != nil {
|
||||
expectedSize = resp.HTTPResponse.ContentLength
|
||||
}
|
||||
|
||||
if expectedSize == 0 {
|
||||
// size is either actually 0 or unknown
|
||||
// if unknown, we ask the remote server
|
||||
// if known to be 0, we proceed with a GET
|
||||
return c.headRequest
|
||||
}
|
||||
|
||||
if expectedSize == resp.fi.Size() {
|
||||
// local file matches remote file size - wrap it up
|
||||
resp.DidResume = true
|
||||
resp.bytesResumed = resp.fi.Size()
|
||||
return c.checksumFile
|
||||
}
|
||||
|
||||
if resp.Request.NoResume {
|
||||
// local file should be overwritten
|
||||
return c.getRequest
|
||||
}
|
||||
|
||||
if expectedSize >= 0 && expectedSize < resp.fi.Size() {
|
||||
// remote size is known, is smaller than local size and we want to resume
|
||||
resp.err = ErrBadLength
|
||||
return c.closeResponse
|
||||
}
|
||||
|
||||
if resp.CanResume {
|
||||
// set resume range on GET request
|
||||
resp.Request.HTTPRequest.Header.Set(
|
||||
"Range",
|
||||
fmt.Sprintf("bytes=%d-", resp.fi.Size()))
|
||||
resp.DidResume = true
|
||||
resp.bytesResumed = resp.fi.Size()
|
||||
return c.getRequest
|
||||
}
|
||||
return c.headRequest
|
||||
}
|
||||
|
||||
func (c *Client) checksumFile(resp *Response) stateFunc {
|
||||
if resp.Request.hash == nil {
|
||||
return c.closeResponse
|
||||
}
|
||||
if resp.Filename == "" {
|
||||
panic("grab: developer error: filename not set")
|
||||
}
|
||||
if resp.Size() < 0 {
|
||||
panic("grab: developer error: size unknown")
|
||||
}
|
||||
req := resp.Request
|
||||
|
||||
// compute checksum
|
||||
var sum []byte
|
||||
sum, resp.err = resp.checksumUnsafe()
|
||||
if resp.err != nil {
|
||||
return c.closeResponse
|
||||
}
|
||||
|
||||
// compare checksum
|
||||
if !bytes.Equal(sum, req.checksum) {
|
||||
resp.err = ErrBadChecksum
|
||||
if !resp.Request.NoStore && req.deleteOnError {
|
||||
if err := os.Remove(resp.Filename); err != nil {
|
||||
// err should be os.PathError and include file path
|
||||
resp.err = fmt.Errorf(
|
||||
"cannot remove downloaded file with checksum mismatch: %v",
|
||||
err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return c.closeResponse
|
||||
}
|
||||
|
||||
// doHTTPRequest sends a HTTP Request and returns the response
|
||||
func (c *Client) doHTTPRequest(req *http.Request) (*http.Response, error) {
|
||||
if c.UserAgent != "" && req.Header.Get("User-Agent") == "" {
|
||||
req.Header.Set("User-Agent", c.UserAgent)
|
||||
}
|
||||
return c.HTTPClient.Do(req)
|
||||
}
|
||||
|
||||
func (c *Client) headRequest(resp *Response) stateFunc {
|
||||
if resp.optionsKnown {
|
||||
return c.getRequest
|
||||
}
|
||||
resp.optionsKnown = true
|
||||
|
||||
if resp.Request.NoResume {
|
||||
return c.getRequest
|
||||
}
|
||||
|
||||
if resp.Filename != "" && resp.fi == nil {
|
||||
// destination path is already known and does not exist
|
||||
return c.getRequest
|
||||
}
|
||||
|
||||
hreq := new(http.Request)
|
||||
*hreq = *resp.Request.HTTPRequest
|
||||
hreq.Method = "HEAD"
|
||||
|
||||
resp.HTTPResponse, resp.err = c.doHTTPRequest(hreq)
|
||||
if resp.err != nil {
|
||||
return c.closeResponse
|
||||
}
|
||||
resp.HTTPResponse.Body.Close()
|
||||
|
||||
if resp.HTTPResponse.StatusCode != http.StatusOK {
|
||||
return c.getRequest
|
||||
}
|
||||
|
||||
// In case of redirects during HEAD, record the final URL and use it
|
||||
// instead of the original URL when sending future requests.
|
||||
// This way we avoid sending potentially unsupported requests to
|
||||
// the original URL, e.g. "Range", since it was the final URL
|
||||
// that advertised its support.
|
||||
resp.Request.HTTPRequest.URL = resp.HTTPResponse.Request.URL
|
||||
resp.Request.HTTPRequest.Host = resp.HTTPResponse.Request.Host
|
||||
|
||||
return c.readResponse
|
||||
}
|
||||
|
||||
func (c *Client) getRequest(resp *Response) stateFunc {
|
||||
resp.HTTPResponse, resp.err = c.doHTTPRequest(resp.Request.HTTPRequest)
|
||||
if resp.err != nil {
|
||||
return c.closeResponse
|
||||
}
|
||||
|
||||
// TODO: check Content-Range
|
||||
|
||||
// check status code
|
||||
if !resp.Request.IgnoreBadStatusCodes {
|
||||
if resp.HTTPResponse.StatusCode < 200 || resp.HTTPResponse.StatusCode > 299 {
|
||||
resp.err = StatusCodeError(resp.HTTPResponse.StatusCode)
|
||||
return c.closeResponse
|
||||
}
|
||||
}
|
||||
|
||||
return c.readResponse
|
||||
}
|
||||
|
||||
func (c *Client) readResponse(resp *Response) stateFunc {
|
||||
if resp.HTTPResponse == nil {
|
||||
panic("grab: developer error: Response.HTTPResponse is nil")
|
||||
}
|
||||
|
||||
// check expected size
|
||||
resp.sizeUnsafe = resp.HTTPResponse.ContentLength
|
||||
if resp.sizeUnsafe >= 0 {
|
||||
// remote size is known
|
||||
resp.sizeUnsafe += resp.bytesResumed
|
||||
if resp.Request.Size > 0 && resp.Request.Size != resp.sizeUnsafe {
|
||||
resp.err = ErrBadLength
|
||||
return c.closeResponse
|
||||
}
|
||||
}
|
||||
|
||||
// check filename
|
||||
if resp.Filename == "" {
|
||||
filename, err := guessFilename(resp.HTTPResponse)
|
||||
if err != nil {
|
||||
resp.err = err
|
||||
return c.closeResponse
|
||||
}
|
||||
// Request.Filename will be empty or a directory
|
||||
resp.Filename = filepath.Join(resp.Request.Filename, filename)
|
||||
}
|
||||
|
||||
if !resp.Request.NoStore && resp.requestMethod() == "HEAD" {
|
||||
if resp.HTTPResponse.Header.Get("Accept-Ranges") == "bytes" {
|
||||
resp.CanResume = true
|
||||
}
|
||||
return c.statFileInfo
|
||||
}
|
||||
return c.openWriter
|
||||
}
|
||||
|
||||
// openWriter opens the destination file for writing and seeks to the location
|
||||
// from whence the file transfer will resume.
|
||||
//
|
||||
// Requires that Response.Filename and resp.DidResume are already be set.
|
||||
func (c *Client) openWriter(resp *Response) stateFunc {
|
||||
if !resp.Request.NoStore && !resp.Request.NoCreateDirectories {
|
||||
resp.err = mkdirp(resp.Filename)
|
||||
if resp.err != nil {
|
||||
return c.closeResponse
|
||||
}
|
||||
}
|
||||
|
||||
if resp.Request.NoStore {
|
||||
resp.writer = &resp.storeBuffer
|
||||
} else {
|
||||
// compute write flags
|
||||
flag := os.O_CREATE | os.O_WRONLY
|
||||
if resp.fi != nil {
|
||||
if resp.DidResume {
|
||||
flag = os.O_APPEND | os.O_WRONLY
|
||||
} else {
|
||||
// truncate later in copyFile, if not cancelled
|
||||
// by BeforeCopy hook
|
||||
flag = os.O_WRONLY
|
||||
}
|
||||
}
|
||||
|
||||
// open file
|
||||
f, err := os.OpenFile(resp.Filename, flag, 0666)
|
||||
if err != nil {
|
||||
resp.err = err
|
||||
return c.closeResponse
|
||||
}
|
||||
resp.writer = f
|
||||
|
||||
// seek to start or end
|
||||
whence := os.SEEK_SET
|
||||
if resp.bytesResumed > 0 {
|
||||
whence = os.SEEK_END
|
||||
}
|
||||
_, resp.err = f.Seek(0, whence)
|
||||
if resp.err != nil {
|
||||
return c.closeResponse
|
||||
}
|
||||
}
|
||||
|
||||
// init transfer
|
||||
if resp.bufferSize < 1 {
|
||||
resp.bufferSize = 32 * 1024
|
||||
}
|
||||
b := make([]byte, resp.bufferSize)
|
||||
resp.transfer = newTransfer(
|
||||
resp.Request.Context(),
|
||||
resp.Request.RateLimiter,
|
||||
resp.writer,
|
||||
resp.HTTPResponse.Body,
|
||||
b)
|
||||
|
||||
// next step is copyFile, but this will be called later in another goroutine
|
||||
return nil
|
||||
}
|
||||
|
||||
// copy transfers content for a HTTP connection established via Client.do()
|
||||
func (c *Client) copyFile(resp *Response) stateFunc {
|
||||
if resp.IsComplete() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// run BeforeCopy hook
|
||||
if f := resp.Request.BeforeCopy; f != nil {
|
||||
resp.err = f(resp)
|
||||
if resp.err != nil {
|
||||
return c.closeResponse
|
||||
}
|
||||
}
|
||||
|
||||
var bytesCopied int64
|
||||
if resp.transfer == nil {
|
||||
panic("grab: developer error: Response.transfer is nil")
|
||||
}
|
||||
|
||||
// We waited to truncate the file in openWriter() to make sure
|
||||
// the BeforeCopy didn't cancel the copy. If this was an existing
|
||||
// file that is not going to be resumed, truncate the contents.
|
||||
if t, ok := resp.writer.(truncater); ok && resp.fi != nil && !resp.DidResume {
|
||||
t.Truncate(0)
|
||||
}
|
||||
|
||||
bytesCopied, resp.err = resp.transfer.copy()
|
||||
if resp.err != nil {
|
||||
return c.closeResponse
|
||||
}
|
||||
closeWriter(resp)
|
||||
|
||||
// set file timestamp
|
||||
if !resp.Request.NoStore && !resp.Request.IgnoreRemoteTime {
|
||||
resp.err = setLastModified(resp.HTTPResponse, resp.Filename)
|
||||
if resp.err != nil {
|
||||
return c.closeResponse
|
||||
}
|
||||
}
|
||||
|
||||
// update transfer size if previously unknown
|
||||
if resp.Size() < 0 {
|
||||
discoveredSize := resp.bytesResumed + bytesCopied
|
||||
atomic.StoreInt64(&resp.sizeUnsafe, discoveredSize)
|
||||
if resp.Request.Size > 0 && resp.Request.Size != discoveredSize {
|
||||
resp.err = ErrBadLength
|
||||
return c.closeResponse
|
||||
}
|
||||
}
|
||||
|
||||
// run AfterCopy hook
|
||||
if f := resp.Request.AfterCopy; f != nil {
|
||||
resp.err = f(resp)
|
||||
if resp.err != nil {
|
||||
return c.closeResponse
|
||||
}
|
||||
}
|
||||
|
||||
return c.checksumFile
|
||||
}
|
||||
|
||||
func closeWriter(resp *Response) {
|
||||
if closer, ok := resp.writer.(io.Closer); ok {
|
||||
closer.Close()
|
||||
}
|
||||
resp.writer = nil
|
||||
}
|
||||
|
||||
// close finalizes the Response
|
||||
func (c *Client) closeResponse(resp *Response) stateFunc {
|
||||
if resp.IsComplete() {
|
||||
panic("grab: developer error: response already closed")
|
||||
}
|
||||
|
||||
resp.fi = nil
|
||||
closeWriter(resp)
|
||||
resp.closeResponseBody()
|
||||
|
||||
resp.End = time.Now()
|
||||
close(resp.Done)
|
||||
if resp.cancel != nil {
|
||||
resp.cancel()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
915
grab/client_test.go
Normal file
915
grab/client_test.go
Normal file
@ -0,0 +1,915 @@
|
||||
package grab
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cavaliercoder/grab/grabtest"
|
||||
)
|
||||
|
||||
// TestFilenameResolutions tests that the destination filename for Requests can
|
||||
// be determined correctly, using an explicitly requested path,
|
||||
// Content-Disposition headers or a URL path - with or without an existing
|
||||
// target directory.
|
||||
func TestFilenameResolution(t *testing.T) {
|
||||
tests := []struct {
|
||||
Name string
|
||||
Filename string
|
||||
URL string
|
||||
AttachmentFilename string
|
||||
Expect string
|
||||
}{
|
||||
{"Using Request.Filename", ".testWithFilename", "/url-filename", "header-filename", ".testWithFilename"},
|
||||
{"Using Content-Disposition Header", "", "/url-filename", ".testWithHeaderFilename", ".testWithHeaderFilename"},
|
||||
{"Using Content-Disposition Header with target directory", ".test", "/url-filename", "header-filename", ".test/header-filename"},
|
||||
{"Using URL Path", "", "/.testWithURLFilename?params-filename", "", ".testWithURLFilename"},
|
||||
{"Using URL Path with target directory", ".test", "/url-filename?garbage", "", ".test/url-filename"},
|
||||
{"Failure", "", "", "", ""},
|
||||
}
|
||||
|
||||
err := os.Mkdir(".test", 0777)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer os.RemoveAll(".test")
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.Name, func(t *testing.T) {
|
||||
opts := []grabtest.HandlerOption{}
|
||||
if test.AttachmentFilename != "" {
|
||||
opts = append(opts, grabtest.AttachmentFilename(test.AttachmentFilename))
|
||||
}
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
req := mustNewRequest(test.Filename, url+test.URL)
|
||||
resp := DefaultClient.Do(req)
|
||||
defer os.Remove(resp.Filename)
|
||||
if err := resp.Err(); err != nil {
|
||||
if test.Expect != "" || err != ErrNoFilename {
|
||||
panic(err)
|
||||
}
|
||||
} else {
|
||||
if test.Expect == "" {
|
||||
t.Errorf("expected: %v, got: %v", ErrNoFilename, err)
|
||||
}
|
||||
}
|
||||
if resp.Filename != test.Expect {
|
||||
t.Errorf("Filename mismatch. Expected '%s', got '%s'.", test.Expect, resp.Filename)
|
||||
}
|
||||
testComplete(t, resp)
|
||||
}, opts...)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestChecksums checks that checksum validation behaves as expected for valid
|
||||
// and corrupted downloads.
|
||||
func TestChecksums(t *testing.T) {
|
||||
tests := []struct {
|
||||
size int
|
||||
hash hash.Hash
|
||||
sum string
|
||||
match bool
|
||||
}{
|
||||
{128, md5.New(), "37eff01866ba3f538421b30b7cbefcac", true},
|
||||
{128, md5.New(), "37eff01866ba3f538421b30b7cbefcad", false},
|
||||
{1024, md5.New(), "b2ea9f7fcea831a4a63b213f41a8855b", true},
|
||||
{1024, md5.New(), "b2ea9f7fcea831a4a63b213f41a8855c", false},
|
||||
{1048576, md5.New(), "c35cc7d8d91728a0cb052831bc4ef372", true},
|
||||
{1048576, md5.New(), "c35cc7d8d91728a0cb052831bc4ef373", false},
|
||||
{128, sha1.New(), "e6434bc401f98603d7eda504790c98c67385d535", true},
|
||||
{128, sha1.New(), "e6434bc401f98603d7eda504790c98c67385d536", false},
|
||||
{1024, sha1.New(), "5b00669c480d5cffbdfa8bdba99561160f2d1b77", true},
|
||||
{1024, sha1.New(), "5b00669c480d5cffbdfa8bdba99561160f2d1b78", false},
|
||||
{1048576, sha1.New(), "ecfc8e86fdd83811f9cc9bf500993b63069923be", true},
|
||||
{1048576, sha1.New(), "ecfc8e86fdd83811f9cc9bf500993b63069923bf", false},
|
||||
{128, sha256.New(), "471fb943aa23c511f6f72f8d1652d9c880cfa392ad80503120547703e56a2be5", true},
|
||||
{128, sha256.New(), "471fb943aa23c511f6f72f8d1652d9c880cfa392ad80503120547703e56a2be4", false},
|
||||
{1024, sha256.New(), "785b0751fc2c53dc14a4ce3d800e69ef9ce1009eb327ccf458afe09c242c26c9", true},
|
||||
{1024, sha256.New(), "785b0751fc2c53dc14a4ce3d800e69ef9ce1009eb327ccf458afe09c242c26c8", false},
|
||||
{1048576, sha256.New(), "fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83", true},
|
||||
{1048576, sha256.New(), "fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c82", false},
|
||||
{128, sha512.New(), "1dffd5e3adb71d45d2245939665521ae001a317a03720a45732ba1900ca3b8351fc5c9b4ca513eba6f80bc7b1d1fdad4abd13491cb824d61b08d8c0e1561b3f7", true},
|
||||
{128, sha512.New(), "1dffd5e3adb71d45d2245939665521ae001a317a03720a45732ba1900ca3b8351fc5c9b4ca513eba6f80bc7b1d1fdad4abd13491cb824d61b08d8c0e1561b3f8", false},
|
||||
{1024, sha512.New(), "37f652be867f28ed033269cbba201af2112c2b3fd334a89fd2f757938ddee815787cc61d6e24a8a33340d0f7e86ffc058816b88530766ba6e231620a130b566c", true},
|
||||
{1024, sha512.New(), "37f652bf867f28ed033269cbba201af2112c2b3fd334a89fd2f757938ddee815787cc61d6e24a8a33340d0f7e86ffc058816b88530766ba6e231620a130b566d", false},
|
||||
{1048576, sha512.New(), "ac1d097b4ea6f6ad7ba640275b9ac290e4828cd760a0ebf76d555463a4f505f95df4f611629539a2dd1848e7c1304633baa1826462b3c87521c0c6e3469b67af", true},
|
||||
{1048576, sha512.New(), "ac1d097c4ea6f6ad7ba640275b9ac290e4828cd760a0ebf76d555463a4f505f95df4f611629539a2dd1848e7c1304633baa1826462b3c87521c0c6e3469b67af", false},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
var expect error
|
||||
comparison := "Match"
|
||||
if !test.match {
|
||||
comparison = "Mismatch"
|
||||
expect = ErrBadChecksum
|
||||
}
|
||||
|
||||
t.Run(fmt.Sprintf("With%s%s", comparison, test.sum[:8]), func(t *testing.T) {
|
||||
filename := fmt.Sprintf(".testChecksum-%s-%s", comparison, test.sum[:8])
|
||||
defer os.Remove(filename)
|
||||
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
req := mustNewRequest(filename, url)
|
||||
req.SetChecksum(test.hash, grabtest.MustHexDecodeString(test.sum), true)
|
||||
|
||||
resp := DefaultClient.Do(req)
|
||||
err := resp.Err()
|
||||
if err != expect {
|
||||
t.Errorf("expected error: %v, got: %v", expect, err)
|
||||
}
|
||||
|
||||
// ensure mismatch file was deleted
|
||||
if !test.match {
|
||||
if _, err := os.Stat(filename); err == nil {
|
||||
t.Errorf("checksum failure not cleaned up: %s", filename)
|
||||
} else if !os.IsNotExist(err) {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
testComplete(t, resp)
|
||||
}, grabtest.ContentLength(test.size))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestContentLength ensures that ErrBadLength is returned if a server response
|
||||
// does not match the requested length.
|
||||
func TestContentLength(t *testing.T) {
|
||||
size := int64(32768)
|
||||
testCases := []struct {
|
||||
Name string
|
||||
NoHead bool
|
||||
Size int64
|
||||
Expect int64
|
||||
Match bool
|
||||
}{
|
||||
{"Good size in HEAD request", false, size, size, true},
|
||||
{"Good size in GET request", true, size, size, true},
|
||||
{"Bad size in HEAD request", false, size - 1, size, false},
|
||||
{"Bad size in GET request", true, size - 1, size, false},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.Name, func(t *testing.T) {
|
||||
opts := []grabtest.HandlerOption{
|
||||
grabtest.ContentLength(int(test.Size)),
|
||||
}
|
||||
if test.NoHead {
|
||||
opts = append(opts, grabtest.MethodWhitelist("GET"))
|
||||
}
|
||||
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
req := mustNewRequest(".testSize-mismatch-head", url)
|
||||
req.Size = size
|
||||
resp := DefaultClient.Do(req)
|
||||
defer os.Remove(resp.Filename)
|
||||
err := resp.Err()
|
||||
if test.Match {
|
||||
if err == ErrBadLength {
|
||||
t.Errorf("error: %v", err)
|
||||
} else if err != nil {
|
||||
panic(err)
|
||||
} else if resp.Size() != size {
|
||||
t.Errorf("expected %v bytes, got %v bytes", size, resp.Size())
|
||||
}
|
||||
} else {
|
||||
if err == nil {
|
||||
t.Errorf("expected: %v, got %v", ErrBadLength, err)
|
||||
} else if err != ErrBadLength {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
testComplete(t, resp)
|
||||
}, opts...)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAutoResume tests segmented downloading of a large file.
|
||||
func TestAutoResume(t *testing.T) {
|
||||
segs := 8
|
||||
size := 1048576
|
||||
sum := grabtest.DefaultHandlerSHA256ChecksumBytes //grabtest.MustHexDecodeString("fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83")
|
||||
filename := ".testAutoResume"
|
||||
|
||||
defer os.Remove(filename)
|
||||
|
||||
for i := 0; i < segs; i++ {
|
||||
segsize := (i + 1) * (size / segs)
|
||||
t.Run(fmt.Sprintf("With%vBytes", segsize), func(t *testing.T) {
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
req := mustNewRequest(filename, url)
|
||||
if i == segs-1 {
|
||||
req.SetChecksum(sha256.New(), sum, false)
|
||||
}
|
||||
resp := mustDo(req)
|
||||
if i > 0 && !resp.DidResume {
|
||||
t.Errorf("expected Response.DidResume to be true")
|
||||
}
|
||||
testComplete(t, resp)
|
||||
},
|
||||
grabtest.ContentLength(segsize),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("WithFailure", func(t *testing.T) {
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
// request smaller segment
|
||||
req := mustNewRequest(filename, url)
|
||||
resp := DefaultClient.Do(req)
|
||||
if err := resp.Err(); err != ErrBadLength {
|
||||
t.Errorf("expected ErrBadLength for smaller request, got: %v", err)
|
||||
}
|
||||
},
|
||||
grabtest.ContentLength(size-128),
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("WithNoResume", func(t *testing.T) {
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
req := mustNewRequest(filename, url)
|
||||
req.NoResume = true
|
||||
resp := mustDo(req)
|
||||
if resp.DidResume {
|
||||
t.Errorf("expected Response.DidResume to be false")
|
||||
}
|
||||
testComplete(t, resp)
|
||||
},
|
||||
grabtest.ContentLength(size+128),
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("WithNoResumeAndTruncate", func(t *testing.T) {
|
||||
size := size - 128
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
req := mustNewRequest(filename, url)
|
||||
req.NoResume = true
|
||||
resp := mustDo(req)
|
||||
if resp.DidResume {
|
||||
t.Errorf("expected Response.DidResume to be false")
|
||||
}
|
||||
if v := resp.BytesComplete(); v != int64(size) {
|
||||
t.Errorf("expected Response.BytesComplete: %d, got: %d", size, v)
|
||||
}
|
||||
testComplete(t, resp)
|
||||
},
|
||||
grabtest.ContentLength(size),
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("WithNoContentLengthHeader", func(t *testing.T) {
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
req := mustNewRequest(filename, url)
|
||||
req.SetChecksum(sha256.New(), sum, false)
|
||||
resp := mustDo(req)
|
||||
if !resp.DidResume {
|
||||
t.Errorf("expected Response.DidResume to be true")
|
||||
}
|
||||
if actual := resp.Size(); actual != int64(size) {
|
||||
t.Errorf("expected Response.Size: %d, got: %d", size, actual)
|
||||
}
|
||||
testComplete(t, resp)
|
||||
},
|
||||
grabtest.ContentLength(size),
|
||||
grabtest.HeaderBlacklist("Content-Length"),
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("WithNoContentLengthHeaderAndChecksumFailure", func(t *testing.T) {
|
||||
// ref: https://github.com/cavaliercoder/grab/pull/27
|
||||
size := size * 2
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
req := mustNewRequest(filename, url)
|
||||
req.SetChecksum(sha256.New(), sum, false)
|
||||
resp := DefaultClient.Do(req)
|
||||
if err := resp.Err(); err != ErrBadChecksum {
|
||||
t.Errorf("expected error: %v, got: %v", ErrBadChecksum, err)
|
||||
}
|
||||
if !resp.DidResume {
|
||||
t.Errorf("expected Response.DidResume to be true")
|
||||
}
|
||||
if actual := resp.BytesComplete(); actual != int64(size) {
|
||||
t.Errorf("expected Response.BytesComplete: %d, got: %d", size, actual)
|
||||
}
|
||||
if actual := resp.Size(); actual != int64(size) {
|
||||
t.Errorf("expected Response.Size: %d, got: %d", size, actual)
|
||||
}
|
||||
testComplete(t, resp)
|
||||
},
|
||||
grabtest.ContentLength(size),
|
||||
grabtest.HeaderBlacklist("Content-Length"),
|
||||
)
|
||||
})
|
||||
// TODO: test when existing file is corrupted
|
||||
}
|
||||
|
||||
func TestSkipExisting(t *testing.T) {
|
||||
filename := ".testSkipExisting"
|
||||
defer os.Remove(filename)
|
||||
|
||||
// download a file
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
resp := mustDo(mustNewRequest(filename, url))
|
||||
testComplete(t, resp)
|
||||
})
|
||||
|
||||
// redownload
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
resp := mustDo(mustNewRequest(filename, url))
|
||||
testComplete(t, resp)
|
||||
|
||||
// ensure download was resumed
|
||||
if !resp.DidResume {
|
||||
t.Fatalf("Expected download to skip existing file, but it did not")
|
||||
}
|
||||
|
||||
// ensure all bytes were resumed
|
||||
if resp.Size() == 0 || resp.Size() != resp.bytesResumed {
|
||||
t.Fatalf("Expected to skip %d bytes in redownload; got %d", resp.Size(), resp.bytesResumed)
|
||||
}
|
||||
})
|
||||
|
||||
// ensure checksum is performed on pre-existing file
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
req := mustNewRequest(filename, url)
|
||||
req.SetChecksum(sha256.New(), []byte{0x01, 0x02, 0x03, 0x04}, true)
|
||||
resp := DefaultClient.Do(req)
|
||||
if err := resp.Err(); err != ErrBadChecksum {
|
||||
t.Fatalf("Expected checksum error, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestBatch executes multiple requests simultaneously and validates the
|
||||
// responses.
|
||||
func TestBatch(t *testing.T) {
|
||||
tests := 32
|
||||
size := 32768
|
||||
sum := grabtest.MustHexDecodeString("e11360251d1173650cdcd20f111d8f1ca2e412f572e8b36a4dc067121c1799b8")
|
||||
|
||||
// test with 4 workers and with one per request
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
for _, workerCount := range []int{4, 0} {
|
||||
// create requests
|
||||
reqs := make([]*Request, tests)
|
||||
for i := 0; i < len(reqs); i++ {
|
||||
filename := fmt.Sprintf(".testBatch.%d", i+1)
|
||||
reqs[i] = mustNewRequest(filename, url+fmt.Sprintf("/request_%d?", i+1))
|
||||
reqs[i].Label = fmt.Sprintf("Test %d", i+1)
|
||||
reqs[i].SetChecksum(sha256.New(), sum, false)
|
||||
}
|
||||
|
||||
// batch run
|
||||
responses := DefaultClient.DoBatch(workerCount, reqs...)
|
||||
|
||||
// listen for responses
|
||||
Loop:
|
||||
for i := 0; i < len(reqs); {
|
||||
select {
|
||||
case resp := <-responses:
|
||||
if resp == nil {
|
||||
break Loop
|
||||
}
|
||||
testComplete(t, resp)
|
||||
if err := resp.Err(); err != nil {
|
||||
t.Errorf("%s: %v", resp.Filename, err)
|
||||
}
|
||||
|
||||
// remove test file
|
||||
if resp.IsComplete() {
|
||||
os.Remove(resp.Filename) // ignore errors
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
grabtest.ContentLength(size),
|
||||
)
|
||||
}
|
||||
|
||||
// TestCancelContext tests that a batch of requests can be cancel using a
|
||||
// context.Context cancellation. Requests are cancelled in multiple states:
|
||||
// in-progress and unstarted.
|
||||
func TestCancelContext(t *testing.T) {
|
||||
fileSize := 134217728
|
||||
tests := 256
|
||||
client := NewClient()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
reqs := make([]*Request, tests)
|
||||
for i := 0; i < tests; i++ {
|
||||
req := mustNewRequest("", fmt.Sprintf("%s/.testCancelContext%d", url, i))
|
||||
reqs[i] = req.WithContext(ctx)
|
||||
}
|
||||
|
||||
respch := client.DoBatch(8, reqs...)
|
||||
time.Sleep(time.Millisecond * 500)
|
||||
cancel()
|
||||
for resp := range respch {
|
||||
defer os.Remove(resp.Filename)
|
||||
|
||||
// err should be context.Canceled or http.errRequestCanceled
|
||||
if resp.Err() == nil || !strings.Contains(resp.Err().Error(), "canceled") {
|
||||
t.Errorf("expected '%v', got '%v'", context.Canceled, resp.Err())
|
||||
}
|
||||
if resp.BytesComplete() >= int64(fileSize) {
|
||||
t.Errorf("expected Response.BytesComplete: < %d, got: %d", fileSize, resp.BytesComplete())
|
||||
}
|
||||
}
|
||||
},
|
||||
grabtest.ContentLength(fileSize),
|
||||
)
|
||||
}
|
||||
|
||||
// TestCancelHangingResponse tests that a never ending request is terminated
|
||||
// when the response is cancelled.
|
||||
func TestCancelHangingResponse(t *testing.T) {
|
||||
fileSize := 10
|
||||
client := NewClient()
|
||||
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
req := mustNewRequest("", fmt.Sprintf("%s/.testCancelHangingResponse", url))
|
||||
|
||||
resp := client.Do(req)
|
||||
defer os.Remove(resp.Filename)
|
||||
|
||||
// Wait for some bytes to be transferred
|
||||
for resp.BytesComplete() == 0 {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
done <- resp.Cancel()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != context.Canceled {
|
||||
t.Errorf("Expected context.Canceled error, go: %v", err)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("response was not cancelled within 1s")
|
||||
}
|
||||
if resp.BytesComplete() == int64(fileSize) {
|
||||
t.Error("download was not supposed to be complete")
|
||||
}
|
||||
fmt.Println("bye")
|
||||
},
|
||||
grabtest.RateLimiter(1),
|
||||
grabtest.ContentLength(fileSize),
|
||||
)
|
||||
}
|
||||
|
||||
// TestNestedDirectory tests that missing subdirectories are created.
|
||||
func TestNestedDirectory(t *testing.T) {
|
||||
dir := "./.testNested/one/two/three"
|
||||
filename := ".testNestedFile"
|
||||
expect := dir + "/" + filename
|
||||
|
||||
t.Run("Create", func(t *testing.T) {
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
resp := mustDo(mustNewRequest(expect, url+"/"+filename))
|
||||
defer os.RemoveAll("./.testNested/")
|
||||
if resp.Filename != expect {
|
||||
t.Errorf("expected nested Request.Filename to be %v, got %v", expect, resp.Filename)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("No create", func(t *testing.T) {
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
req := mustNewRequest(expect, url+"/"+filename)
|
||||
req.NoCreateDirectories = true
|
||||
resp := DefaultClient.Do(req)
|
||||
err := resp.Err()
|
||||
if !os.IsNotExist(err) {
|
||||
t.Errorf("expected: %v, got: %v", os.ErrNotExist, err)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// TestRemoteTime tests that the timestamp of the downloaded file can be set
|
||||
// according to the timestamp of the remote file.
|
||||
func TestRemoteTime(t *testing.T) {
|
||||
filename := "./.testRemoteTime"
|
||||
defer os.Remove(filename)
|
||||
|
||||
// random time between epoch and now
|
||||
expect := time.Unix(rand.Int63n(time.Now().Unix()), 0)
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
resp := mustDo(mustNewRequest(filename, url))
|
||||
fi, err := os.Stat(resp.Filename)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
actual := fi.ModTime()
|
||||
if !actual.Equal(expect) {
|
||||
t.Errorf("expected %v, got %v", expect, actual)
|
||||
}
|
||||
},
|
||||
grabtest.LastModified(expect),
|
||||
)
|
||||
}
|
||||
|
||||
func TestResponseCode(t *testing.T) {
|
||||
filename := "./.testResponseCode"
|
||||
|
||||
t.Run("With404", func(t *testing.T) {
|
||||
defer os.Remove(filename)
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
req := mustNewRequest(filename, url)
|
||||
resp := DefaultClient.Do(req)
|
||||
expect := StatusCodeError(http.StatusNotFound)
|
||||
err := resp.Err()
|
||||
if err != expect {
|
||||
t.Errorf("expected %v, got '%v'", expect, err)
|
||||
}
|
||||
if !IsStatusCodeError(err) {
|
||||
t.Errorf("expected IsStatusCodeError to return true for %T: %v", err, err)
|
||||
}
|
||||
},
|
||||
grabtest.StatusCodeStatic(http.StatusNotFound),
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("WithIgnoreNon2XX", func(t *testing.T) {
|
||||
defer os.Remove(filename)
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
req := mustNewRequest(filename, url)
|
||||
req.IgnoreBadStatusCodes = true
|
||||
resp := DefaultClient.Do(req)
|
||||
if err := resp.Err(); err != nil {
|
||||
t.Errorf("expected nil, got '%v'", err)
|
||||
}
|
||||
},
|
||||
grabtest.StatusCodeStatic(http.StatusNotFound),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBeforeCopyHook(t *testing.T) {
|
||||
filename := "./.testBeforeCopy"
|
||||
t.Run("Noop", func(t *testing.T) {
|
||||
defer os.RemoveAll(filename)
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
called := false
|
||||
req := mustNewRequest(filename, url)
|
||||
req.BeforeCopy = func(resp *Response) error {
|
||||
called = true
|
||||
if resp.IsComplete() {
|
||||
t.Error("Response object passed to BeforeCopy hook has already been closed")
|
||||
}
|
||||
if resp.Progress() != 0 {
|
||||
t.Error("Download progress already > 0 when BeforeCopy hook was called")
|
||||
}
|
||||
if resp.Duration() == 0 {
|
||||
t.Error("Duration was zero when BeforeCopy was called")
|
||||
}
|
||||
if resp.BytesComplete() != 0 {
|
||||
t.Error("BytesComplete already > 0 when BeforeCopy hook was called")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
resp := DefaultClient.Do(req)
|
||||
if err := resp.Err(); err != nil {
|
||||
t.Errorf("unexpected error using BeforeCopy hook: %v", err)
|
||||
}
|
||||
testComplete(t, resp)
|
||||
if !called {
|
||||
t.Error("BeforeCopy hook was never called")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("WithError", func(t *testing.T) {
|
||||
defer os.RemoveAll(filename)
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
testError := errors.New("test")
|
||||
req := mustNewRequest(filename, url)
|
||||
req.BeforeCopy = func(resp *Response) error {
|
||||
return testError
|
||||
}
|
||||
resp := DefaultClient.Do(req)
|
||||
if err := resp.Err(); err != testError {
|
||||
t.Errorf("expected error '%v', got '%v'", testError, err)
|
||||
}
|
||||
if resp.BytesComplete() != 0 {
|
||||
t.Errorf("expected 0 bytes completed for canceled BeforeCopy hook, got %d",
|
||||
resp.BytesComplete())
|
||||
}
|
||||
testComplete(t, resp)
|
||||
})
|
||||
})
|
||||
|
||||
// Assert that an existing local file will not be truncated prior to the
|
||||
// BeforeCopy hook has a chance to cancel the request
|
||||
t.Run("NoTruncate", func(t *testing.T) {
|
||||
tfile, err := ioutil.TempFile("", "grab_client_test.*.file")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(tfile.Name())
|
||||
|
||||
const size = 128
|
||||
_, err = tfile.Write(bytes.Repeat([]byte("x"), size))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
called := false
|
||||
req := mustNewRequest(tfile.Name(), url)
|
||||
req.NoResume = true
|
||||
req.BeforeCopy = func(resp *Response) error {
|
||||
called = true
|
||||
fi, err := tfile.Stat()
|
||||
if err != nil {
|
||||
t.Errorf("failed to stat temp file: %v", err)
|
||||
return nil
|
||||
}
|
||||
if fi.Size() != size {
|
||||
t.Errorf("expected existing file size of %d bytes "+
|
||||
"prior to BeforeCopy hook, got %d", size, fi.Size())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
resp := DefaultClient.Do(req)
|
||||
if err := resp.Err(); err != nil {
|
||||
t.Errorf("unexpected error using BeforeCopy hook: %v", err)
|
||||
}
|
||||
testComplete(t, resp)
|
||||
if !called {
|
||||
t.Error("BeforeCopy hook was never called")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestAfterCopyHook(t *testing.T) {
|
||||
filename := "./.testAfterCopy"
|
||||
t.Run("Noop", func(t *testing.T) {
|
||||
defer os.RemoveAll(filename)
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
called := false
|
||||
req := mustNewRequest(filename, url)
|
||||
req.AfterCopy = func(resp *Response) error {
|
||||
called = true
|
||||
if resp.IsComplete() {
|
||||
t.Error("Response object passed to AfterCopy hook has already been closed")
|
||||
}
|
||||
if resp.Progress() <= 0 {
|
||||
t.Error("Download progress was 0 when AfterCopy hook was called")
|
||||
}
|
||||
if resp.Duration() == 0 {
|
||||
t.Error("Duration was zero when AfterCopy was called")
|
||||
}
|
||||
if resp.BytesComplete() <= 0 {
|
||||
t.Error("BytesComplete was 0 when AfterCopy hook was called")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
resp := DefaultClient.Do(req)
|
||||
if err := resp.Err(); err != nil {
|
||||
t.Errorf("unexpected error using AfterCopy hook: %v", err)
|
||||
}
|
||||
testComplete(t, resp)
|
||||
if !called {
|
||||
t.Error("AfterCopy hook was never called")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("WithError", func(t *testing.T) {
|
||||
defer os.RemoveAll(filename)
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
testError := errors.New("test")
|
||||
req := mustNewRequest(filename, url)
|
||||
req.AfterCopy = func(resp *Response) error {
|
||||
return testError
|
||||
}
|
||||
resp := DefaultClient.Do(req)
|
||||
if err := resp.Err(); err != testError {
|
||||
t.Errorf("expected error '%v', got '%v'", testError, err)
|
||||
}
|
||||
if resp.BytesComplete() <= 0 {
|
||||
t.Errorf("ByteCompleted was %d after AfterCopy hook was called",
|
||||
resp.BytesComplete())
|
||||
}
|
||||
testComplete(t, resp)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestIssue37(t *testing.T) {
|
||||
// ref: https://github.com/cavaliercoder/grab/issues/37
|
||||
filename := "./.testIssue37"
|
||||
largeSize := int64(2097152)
|
||||
smallSize := int64(1048576)
|
||||
defer os.RemoveAll(filename)
|
||||
|
||||
// download large file
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
resp := mustDo(mustNewRequest(filename, url))
|
||||
if resp.Size() != largeSize {
|
||||
t.Errorf("expected response size: %d, got: %d", largeSize, resp.Size())
|
||||
}
|
||||
}, grabtest.ContentLength(int(largeSize)))
|
||||
|
||||
// download new, smaller version of same file
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
req := mustNewRequest(filename, url)
|
||||
req.NoResume = true
|
||||
resp := mustDo(req)
|
||||
if resp.Size() != smallSize {
|
||||
t.Errorf("expected response size: %d, got: %d", smallSize, resp.Size())
|
||||
}
|
||||
|
||||
// local file should have truncated and not resumed
|
||||
if resp.DidResume {
|
||||
t.Errorf("expected download to truncate, resumed instead")
|
||||
}
|
||||
}, grabtest.ContentLength(int(smallSize)))
|
||||
|
||||
fi, err := os.Stat(filename)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if fi.Size() != int64(smallSize) {
|
||||
t.Errorf("expected file size %d, got %d", smallSize, fi.Size())
|
||||
}
|
||||
}
|
||||
|
||||
// TestHeadBadStatus validates that HEAD requests that return non-200 can be
|
||||
// ignored and the download still succeeds if the GET request succeeds.
|
||||
//
|
||||
// Fixes: https://github.com/cavaliercoder/grab/issues/43
|
||||
func TestHeadBadStatus(t *testing.T) {
|
||||
expect := http.StatusOK
|
||||
filename := ".testIssue43"
|
||||
|
||||
statusFunc := func(r *http.Request) int {
|
||||
if r.Method == "HEAD" {
|
||||
return http.StatusForbidden
|
||||
}
|
||||
return http.StatusOK
|
||||
}
|
||||
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
testURL := fmt.Sprintf("%s/%s", url, filename)
|
||||
resp := mustDo(mustNewRequest("", testURL))
|
||||
if resp.HTTPResponse.StatusCode != expect {
|
||||
t.Errorf(
|
||||
"expected status code: %d, got:% d",
|
||||
expect,
|
||||
resp.HTTPResponse.StatusCode)
|
||||
}
|
||||
},
|
||||
grabtest.StatusCode(statusFunc),
|
||||
)
|
||||
}
|
||||
|
||||
// TestMissingContentLength ensures that the Response.Size is correct for
|
||||
// transfers where the remote server does not send a Content-Length header.
|
||||
//
|
||||
// TestAutoResume also covers cases with checksum validation.
|
||||
//
|
||||
// Kudos to Setnička Jiří <Jiri.Setnicka@ysoft.com> for identifying and raising
|
||||
// a solution to this issue. Ref: https://github.com/cavaliercoder/grab/pull/27
|
||||
func TestMissingContentLength(t *testing.T) {
|
||||
// expectSize must be sufficiently large that DefaultClient.Do won't prefetch
|
||||
// the entire body and compute ContentLength before returning a Response.
|
||||
expectSize := 1048576
|
||||
opts := []grabtest.HandlerOption{
|
||||
grabtest.ContentLength(expectSize),
|
||||
grabtest.HeaderBlacklist("Content-Length"),
|
||||
grabtest.TimeToFirstByte(time.Millisecond * 100), // delay for initial read
|
||||
}
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
req := mustNewRequest(".testMissingContentLength", url)
|
||||
req.SetChecksum(
|
||||
md5.New(),
|
||||
grabtest.DefaultHandlerMD5ChecksumBytes,
|
||||
false)
|
||||
resp := DefaultClient.Do(req)
|
||||
|
||||
// ensure remote server is not sending content-length header
|
||||
if v := resp.HTTPResponse.Header.Get("Content-Length"); v != "" {
|
||||
panic(fmt.Sprintf("http header content length must be empty, got: %s", v))
|
||||
}
|
||||
if v := resp.HTTPResponse.ContentLength; v != -1 {
|
||||
panic(fmt.Sprintf("http response content length must be -1, got: %d", v))
|
||||
}
|
||||
|
||||
// before completion, response size should be -1
|
||||
if resp.Size() != -1 {
|
||||
t.Errorf("expected response size: -1, got: %d", resp.Size())
|
||||
}
|
||||
|
||||
// block for completion
|
||||
if err := resp.Err(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// on completion, response size should be actual transfer size
|
||||
if resp.Size() != int64(expectSize) {
|
||||
t.Errorf("expected response size: %d, got: %d", expectSize, resp.Size())
|
||||
}
|
||||
}, opts...)
|
||||
}
|
||||
|
||||
func TestNoStore(t *testing.T) {
|
||||
filename := ".testSubdir/testNoStore"
|
||||
t.Run("DefaultCase", func(t *testing.T) {
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
req := mustNewRequest(filename, url)
|
||||
req.NoStore = true
|
||||
req.SetChecksum(md5.New(), grabtest.DefaultHandlerMD5ChecksumBytes, true)
|
||||
resp := mustDo(req)
|
||||
|
||||
// ensure Response.Bytes is correct and can be reread
|
||||
b, err := resp.Bytes()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
grabtest.AssertSHA256Sum(
|
||||
t,
|
||||
grabtest.DefaultHandlerSHA256ChecksumBytes,
|
||||
bytes.NewReader(b),
|
||||
)
|
||||
|
||||
// ensure Response.Open stream is correct and can be reread
|
||||
r, err := resp.Open()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer r.Close()
|
||||
grabtest.AssertSHA256Sum(
|
||||
t,
|
||||
grabtest.DefaultHandlerSHA256ChecksumBytes,
|
||||
r,
|
||||
)
|
||||
|
||||
// Response.Filename should still be set
|
||||
if resp.Filename != filename {
|
||||
t.Errorf("expected Response.Filename: %s, got: %s", filename, resp.Filename)
|
||||
}
|
||||
|
||||
// ensure no files were written
|
||||
paths := []string{
|
||||
filename,
|
||||
filepath.Base(filename),
|
||||
filepath.Dir(filename),
|
||||
resp.Filename,
|
||||
filepath.Base(resp.Filename),
|
||||
filepath.Dir(resp.Filename),
|
||||
}
|
||||
for _, path := range paths {
|
||||
_, err := os.Stat(path)
|
||||
if !os.IsNotExist(err) {
|
||||
t.Errorf(
|
||||
"expect error: %v, got: %v, for path: %s",
|
||||
os.ErrNotExist,
|
||||
err,
|
||||
path)
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("ChecksumValidation", func(t *testing.T) {
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
req := mustNewRequest("", url)
|
||||
req.NoStore = true
|
||||
req.SetChecksum(
|
||||
md5.New(),
|
||||
grabtest.MustHexDecodeString("deadbeefcafebabe"),
|
||||
true)
|
||||
resp := DefaultClient.Do(req)
|
||||
if err := resp.Err(); err != ErrBadChecksum {
|
||||
t.Errorf("expected error: %v, got: %v", ErrBadChecksum, err)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
1
grab/cmd/grab/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
grab
|
18
grab/cmd/grab/Makefile
Normal file
@ -0,0 +1,18 @@
|
||||
SOURCES = main.go
|
||||
|
||||
all : grab
|
||||
|
||||
grab: $(SOURCES)
|
||||
go build -x -o grab $(SOURCES)
|
||||
|
||||
clean:
|
||||
go clean -x
|
||||
rm -vf grab
|
||||
|
||||
check:
|
||||
go test -v .
|
||||
|
||||
install:
|
||||
go install -v .
|
||||
|
||||
.PHONY: all clean check install
|
34
grab/cmd/grab/main.go
Normal file
@ -0,0 +1,34 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/cavaliercoder/grab/grabui"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// validate command args
|
||||
if len(os.Args) < 2 {
|
||||
fmt.Fprintf(os.Stderr, "usage: %s url...\n", os.Args[0])
|
||||
os.Exit(1)
|
||||
}
|
||||
urls := os.Args[1:]
|
||||
|
||||
// download files
|
||||
respch, err := grabui.GetBatch(context.Background(), 0, ".", urls...)
|
||||
if err != nil {
|
||||
fmt.Fprint(os.Stderr, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// return the number of failed downloads as exit code
|
||||
failed := 0
|
||||
for resp := range respch {
|
||||
if resp.Err() != nil {
|
||||
failed++
|
||||
}
|
||||
}
|
||||
os.Exit(failed)
|
||||
}
|
63
grab/doc.go
Normal file
@ -0,0 +1,63 @@
|
||||
/*
|
||||
Package grab provides an HTTP download manager implementation.
|
||||
|
||||
Get is the most simple way to download a file:
|
||||
|
||||
resp, err := grab.Get("/tmp", "http://example.com/example.zip")
|
||||
// ...
|
||||
|
||||
Get will download the given URL and save it to the given destination directory.
|
||||
The destination filename will be determined automatically by grab using
|
||||
Content-Disposition headers returned by the remote server, or by inspecting the
|
||||
requested URL path.
|
||||
|
||||
An empty destination string or "." means the transfer will be stored in the
|
||||
current working directory.
|
||||
|
||||
If a destination file already exists, grab will assume it is a complete or
|
||||
partially complete download of the requested file. If the remote server supports
|
||||
resuming interrupted downloads, grab will resume downloading from the end of the
|
||||
partial file. If the server does not support resumed downloads, the file will be
|
||||
retransferred in its entirety. If the file is already complete, grab will return
|
||||
successfully.
|
||||
|
||||
For control over the HTTP client, destination path, auto-resume, checksum
|
||||
validation and other settings, create a Client:
|
||||
|
||||
client := grab.NewClient()
|
||||
client.HTTPClient.Transport.DisableCompression = true
|
||||
|
||||
req, err := grab.NewRequest("/tmp", "http://example.com/example.zip")
|
||||
// ...
|
||||
req.NoResume = true
|
||||
req.HTTPRequest.Header.Set("Authorization", "Basic YWxhZGRpbjpvcGVuc2VzYW1l")
|
||||
|
||||
resp := client.Do(req)
|
||||
// ...
|
||||
|
||||
You can monitor the progress of downloads while they are transferring:
|
||||
|
||||
client := grab.NewClient()
|
||||
req, err := grab.NewRequest("", "http://example.com/example.zip")
|
||||
// ...
|
||||
resp := client.Do(req)
|
||||
|
||||
t := time.NewTicker(time.Second)
|
||||
defer t.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-t.C:
|
||||
fmt.Printf("%.02f%% complete\n", resp.Progress())
|
||||
|
||||
case <-resp.Done:
|
||||
if err := resp.Err(); err != nil {
|
||||
// ...
|
||||
}
|
||||
|
||||
// ...
|
||||
return
|
||||
}
|
||||
}
|
||||
*/
|
||||
package grab
|
42
grab/error.go
Normal file
@ -0,0 +1,42 @@
|
||||
package grab
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrBadLength indicates that the server response or an existing file does
|
||||
// not match the expected content length.
|
||||
ErrBadLength = errors.New("bad content length")
|
||||
|
||||
// ErrBadChecksum indicates that a downloaded file failed to pass checksum
|
||||
// validation.
|
||||
ErrBadChecksum = errors.New("checksum mismatch")
|
||||
|
||||
// ErrNoFilename indicates that a reasonable filename could not be
|
||||
// automatically determined using the URL or response headers from a server.
|
||||
ErrNoFilename = errors.New("no filename could be determined")
|
||||
|
||||
// ErrNoTimestamp indicates that a timestamp could not be automatically
|
||||
// determined using the response headers from the remote server.
|
||||
ErrNoTimestamp = errors.New("no timestamp could be determined for the remote file")
|
||||
|
||||
// ErrFileExists indicates that the destination path already exists.
|
||||
ErrFileExists = errors.New("file exists")
|
||||
)
|
||||
|
||||
// StatusCodeError indicates that the server response had a status code that
|
||||
// was not in the 200-299 range (after following any redirects).
|
||||
type StatusCodeError int
|
||||
|
||||
func (err StatusCodeError) Error() string {
|
||||
return fmt.Sprintf("server returned %d %s", err, http.StatusText(int(err)))
|
||||
}
|
||||
|
||||
// IsStatusCodeError returns true if the given error is of type StatusCodeError.
|
||||
func IsStatusCodeError(err error) bool {
|
||||
_, ok := err.(StatusCodeError)
|
||||
return ok
|
||||
}
|
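As a quick illustration of the StatusCodeError type and IsStatusCodeError helper above, a caller might inspect a failed download for a bad status code. This is a minimal sketch only, assuming the import path used elsewhere in this commit; the URL is a placeholder.

package main

import (
	"fmt"
	"log"

	"github.com/cavaliercoder/grab"
)

func main() {
	// Get blocks until the transfer finishes; a non-2XX status surfaces
	// as a StatusCodeError via the returned error.
	_, err := grab.Get(".", "http://example.com/missing.zip")
	if grab.IsStatusCodeError(err) {
		code := err.(grab.StatusCodeError)
		fmt.Printf("server rejected the request with status %d\n", int(code))
		return
	}
	if err != nil {
		log.Fatal(err)
	}
}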
95
grab/example_client_test.go
Normal file
@ -0,0 +1,95 @@
|
||||
package grab
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
func ExampleClient_Do() {
|
||||
client := NewClient()
|
||||
req, err := NewRequest("/tmp", "http://example.com/example.zip")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
resp := client.Do(req)
|
||||
if err := resp.Err(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Println("Download saved to", resp.Filename)
|
||||
}
|
||||
|
||||
// This example uses DoChannel to create a Producer/Consumer model for
|
||||
// downloading multiple files concurrently. This is similar to how DoBatch uses
|
||||
// DoChannel under the hood except that it allows the caller to continually send
|
||||
// new requests until they wish to close the request channel.
|
||||
func ExampleClient_DoChannel() {
|
||||
// create a request and a buffered response channel
|
||||
reqch := make(chan *Request)
|
||||
respch := make(chan *Response, 10)
|
||||
|
||||
// start 4 workers
|
||||
client := NewClient()
|
||||
wg := sync.WaitGroup{}
|
||||
for i := 0; i < 4; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
client.DoChannel(reqch, respch)
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
go func() {
|
||||
// send requests
|
||||
for i := 0; i < 10; i++ {
|
||||
url := fmt.Sprintf("http://example.com/example%d.zip", i+1)
|
||||
req, err := NewRequest("/tmp", url)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
reqch <- req
|
||||
}
|
||||
close(reqch)
|
||||
|
||||
// wait for workers to finish
|
||||
wg.Wait()
|
||||
close(respch)
|
||||
}()
|
||||
|
||||
// check each response
|
||||
for resp := range respch {
|
||||
// block until complete
|
||||
if err := resp.Err(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Printf("Downloaded %s to %s\n", resp.Request.URL(), resp.Filename)
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleClient_DoBatch() {
|
||||
// create multiple download requests
|
||||
reqs := make([]*Request, 0)
|
||||
for i := 0; i < 10; i++ {
|
||||
url := fmt.Sprintf("http://example.com/example%d.zip", i+1)
|
||||
req, err := NewRequest("/tmp", url)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
reqs = append(reqs, req)
|
||||
}
|
||||
|
||||
// start downloads with 4 workers
|
||||
client := NewClient()
|
||||
respch := client.DoBatch(4, reqs...)
|
||||
|
||||
// check each response
|
||||
for resp := range respch {
|
||||
if err := resp.Err(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Printf("Downloaded %s to %s\n", resp.Request.URL(), resp.Filename)
|
||||
}
|
||||
}
|
52
grab/example_request_test.go
Normal file
@ -0,0 +1,52 @@
|
||||
package grab
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
func ExampleRequest_WithContext() {
|
||||
// create context with a 100ms timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
// create download request with context
|
||||
req, err := NewRequest("", "http://example.com/example.zip")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
// send download request
|
||||
resp := DefaultClient.Do(req)
|
||||
if err := resp.Err(); err != nil {
|
||||
fmt.Println("error: request cancelled")
|
||||
}
|
||||
|
||||
// Output:
|
||||
// error: request cancelled
|
||||
}
|
||||
|
||||
func ExampleRequest_SetChecksum() {
|
||||
// create download request
|
||||
req, err := NewRequest("", "http://example.com/example.zip")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// set request checksum
|
||||
sum, err := hex.DecodeString("33daf4c03f86120fdfdc66bddf6bfff4661c7ca11c5da473e537f4d69b470e57")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
req.SetChecksum(sha256.New(), sum, true)
|
||||
|
||||
// download and validate file
|
||||
resp := DefaultClient.Do(req)
|
||||
if err := resp.Err(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
5
grab/go.mod
Normal file
@ -0,0 +1,5 @@
|
||||
module github.com/cavaliercoder/grab
|
||||
|
||||
go 1.14
|
||||
|
||||
require github.com/lordwelch/pathvalidate v0.0.0-20201012043703-54efa7ea1308
|
5
grab/go.sum
Normal file
@ -0,0 +1,5 @@
|
||||
github.com/lordwelch/pathvalidate v0.0.0-20201012043703-54efa7ea1308 h1:CkcsZK6QYg59rc92eqU2h+FRjWltCIiplmEwIB05jfM=
|
||||
github.com/lordwelch/pathvalidate v0.0.0-20201012043703-54efa7ea1308/go.mod h1:4I4r5Y/LkH+34KACiudU+Q27ooz7xSDyVEuWAVKeJEQ=
|
||||
golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
64
grab/grab.go
Normal file
@ -0,0 +1,64 @@
|
||||
package grab
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
// Get sends an HTTP request and downloads the content of the requested URL to
|
||||
// the given destination file path. The caller is blocked until the download is
|
||||
// completed, successfully or otherwise.
|
||||
//
|
||||
// An error is returned if caused by client policy (such as CheckRedirect), or
|
||||
// if there was an HTTP protocol or IO error.
|
||||
//
|
||||
// For non-blocking calls or control over HTTP client headers, redirect policy,
|
||||
// and other settings, create a Client instead.
|
||||
func Get(dst, urlStr string) (*Response, error) {
|
||||
req, err := NewRequest(dst, urlStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp := DefaultClient.Do(req)
|
||||
return resp, resp.Err()
|
||||
}
|
||||
|
||||
// GetBatch sends multiple HTTP requests and downloads the content of the
|
||||
// requested URLs to the given destination directory using the given number of
|
||||
// concurrent worker goroutines.
|
||||
//
|
||||
// The Response for each requested URL is sent through the returned Response
|
||||
// channel, as soon as a worker receives a response from the remote server. The
|
||||
// Response can then be used to track the progress of the download while it is
|
||||
// in progress.
|
||||
//
|
||||
// The returned Response channel will be closed by Grab only once all downloads
|
||||
// have completed or failed.
|
||||
//
|
||||
// If an error occurs during any download, it will be available via call to the
|
||||
// associated Response.Err.
|
||||
//
|
||||
// For control over HTTP client headers, redirect policy, and other settings,
|
||||
// create a Client instead.
|
||||
func GetBatch(workers int, dst string, urlStrs ...string) (<-chan *Response, error) {
|
||||
fi, err := os.Stat(dst)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !fi.IsDir() {
|
||||
return nil, fmt.Errorf("destination is not a directory")
|
||||
}
|
||||
|
||||
reqs := make([]*Request, len(urlStrs))
|
||||
for i := 0; i < len(urlStrs); i++ {
|
||||
req, err := NewRequest(dst, urlStrs[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reqs[i] = req
|
||||
}
|
||||
|
||||
ch := DefaultClient.DoBatch(workers, reqs...)
|
||||
return ch, nil
|
||||
}
|
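For completeness, a minimal sketch of calling the GetBatch helper above and draining its channel. The worker count, destination directory and URLs are placeholders.

package main

import (
	"fmt"
	"log"

	"github.com/cavaliercoder/grab"
)

func main() {
	// download two files into /tmp using 2 concurrent workers
	respch, err := grab.GetBatch(2, "/tmp",
		"http://example.com/example1.zip",
		"http://example.com/example2.zip",
	)
	if err != nil {
		log.Fatal(err)
	}

	// the channel is closed once every transfer has finished or failed
	for resp := range respch {
		if err := resp.Err(); err != nil {
			fmt.Println("failed:", resp.Request.URL(), err)
			continue
		}
		fmt.Println("saved:", resp.Filename)
	}
}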
74
grab/grab_test.go
Normal file
@ -0,0 +1,74 @@
|
||||
package grab
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/cavaliercoder/grab/grabtest"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
os.Exit(func() int {
|
||||
// chdir to temp so test files downloaded to pwd are isolated and cleaned up
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
tmpDir, err := ioutil.TempDir("", "grab-")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := os.Chdir(tmpDir); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer func() {
|
||||
os.Chdir(cwd)
|
||||
if err := os.RemoveAll(tmpDir); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
return m.Run()
|
||||
}())
|
||||
}
|
||||
|
||||
// TestGet tests grab.Get
|
||||
func TestGet(t *testing.T) {
|
||||
filename := ".testGet"
|
||||
defer os.Remove(filename)
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
resp, err := Get(filename, url)
|
||||
if err != nil {
|
||||
t.Fatalf("error in Get(): %v", err)
|
||||
}
|
||||
testComplete(t, resp)
|
||||
})
|
||||
}
|
||||
|
||||
func ExampleGet() {
|
||||
// download a file to /tmp
|
||||
resp, err := Get("/tmp", "http://example.com/example.zip")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Println("Download saved to", resp.Filename)
|
||||
}
|
||||
|
||||
func mustNewRequest(dst, urlStr string) *Request {
|
||||
req, err := NewRequest(dst, urlStr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
func mustDo(req *Request) *Response {
|
||||
resp := DefaultClient.Do(req)
|
||||
if err := resp.Err(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return resp
|
||||
}
|
104
grab/grabtest/assert.go
Normal file
@ -0,0 +1,104 @@
|
||||
package grabtest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func AssertHTTPResponseStatusCode(t *testing.T, resp *http.Response, expect int) (ok bool) {
|
||||
if resp.StatusCode != expect {
|
||||
t.Errorf("expected status code: %d, got: %d", expect, resp.StatusCode)
|
||||
return
|
||||
}
|
||||
ok = true
|
||||
return true
|
||||
}
|
||||
|
||||
func AssertHTTPResponseHeader(t *testing.T, resp *http.Response, key, format string, a ...interface{}) (ok bool) {
|
||||
expect := fmt.Sprintf(format, a...)
|
||||
actual := resp.Header.Get(key)
|
||||
if actual != expect {
|
||||
t.Errorf("expected header %s: %s, got: %s", key, expect, actual)
|
||||
return
|
||||
}
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
|
||||
func AssertHTTPResponseContentLength(t *testing.T, resp *http.Response, n int64) (ok bool) {
|
||||
ok = true
|
||||
if resp.ContentLength != n {
|
||||
ok = false
|
||||
t.Errorf("expected header Content-Length: %d, got: %d", n, resp.ContentLength)
|
||||
}
|
||||
if !AssertHTTPResponseBodyLength(t, resp, n) {
|
||||
ok = false
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func AssertHTTPResponseBodyLength(t *testing.T, resp *http.Response, n int64) (ok bool) {
|
||||
defer func() {
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if int64(len(b)) != n {
|
||||
ok = false
|
||||
t.Errorf("expected body length: %d, got: %d", n, len(b))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func MustHTTPNewRequest(method, url string, body io.Reader) *http.Request {
|
||||
req, err := http.NewRequest(method, url, body)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
func MustHTTPDo(req *http.Request) *http.Response {
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func MustHTTPDoWithClose(req *http.Request) *http.Response {
|
||||
resp := MustHTTPDo(req)
|
||||
if _, err := io.Copy(ioutil.Discard, resp.Body); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func AssertSHA256Sum(t *testing.T, sum []byte, r io.Reader) (ok bool) {
|
||||
h := sha256.New()
|
||||
if _, err := io.Copy(h, r); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
computed := h.Sum(nil)
|
||||
ok = bytes.Equal(sum, computed)
|
||||
if !ok {
|
||||
t.Errorf(
|
||||
"expected checksum: %s, got: %s",
|
||||
MustHexEncodeString(sum),
|
||||
MustHexEncodeString(computed),
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
160
grab/grabtest/handler.go
Normal file
@ -0,0 +1,160 @@
|
||||
package grabtest
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
DefaultHandlerContentLength = 1 << 20
|
||||
DefaultHandlerMD5Checksum = "c35cc7d8d91728a0cb052831bc4ef372"
|
||||
DefaultHandlerMD5ChecksumBytes = MustHexDecodeString(DefaultHandlerMD5Checksum)
|
||||
DefaultHandlerSHA256Checksum = "fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83"
|
||||
DefaultHandlerSHA256ChecksumBytes = MustHexDecodeString(DefaultHandlerSHA256Checksum)
|
||||
)
|
||||
|
||||
type StatusCodeFunc func(req *http.Request) int
|
||||
|
||||
type handler struct {
|
||||
statusCodeFunc StatusCodeFunc
|
||||
methodWhitelist []string
|
||||
headerBlacklist []string
|
||||
contentLength int
|
||||
acceptRanges bool
|
||||
attachmentFilename string
|
||||
lastModified time.Time
|
||||
ttfb time.Duration
|
||||
rateLimiter *time.Ticker
|
||||
}
|
||||
|
||||
func NewHandler(options ...HandlerOption) (http.Handler, error) {
|
||||
h := &handler{
|
||||
statusCodeFunc: func(req *http.Request) int { return http.StatusOK },
|
||||
methodWhitelist: []string{"GET", "HEAD"},
|
||||
contentLength: DefaultHandlerContentLength,
|
||||
acceptRanges: true,
|
||||
}
|
||||
for _, option := range options {
|
||||
if err := option(h); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return h, nil
|
||||
}
|
||||
|
||||
func WithTestServer(t *testing.T, f func(url string), options ...HandlerOption) {
|
||||
h, err := NewHandler(options...)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create test server handler: %v", err)
|
||||
return
|
||||
}
|
||||
s := httptest.NewServer(h)
|
||||
defer func() {
|
||||
h.(*handler).close()
|
||||
s.Close()
|
||||
}()
|
||||
f(s.URL)
|
||||
}
|
||||
|
||||
func (h *handler) close() {
|
||||
if h.rateLimiter != nil {
|
||||
h.rateLimiter.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// delay response
|
||||
if h.ttfb > 0 {
|
||||
time.Sleep(h.ttfb)
|
||||
}
|
||||
|
||||
// validate request method
|
||||
allowed := false
|
||||
for _, m := range h.methodWhitelist {
|
||||
if r.Method == m {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
httpError(w, http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// set server options
|
||||
if h.acceptRanges {
|
||||
w.Header().Set("Accept-Ranges", "bytes")
|
||||
}
|
||||
|
||||
// set attachment filename
|
||||
if h.attachmentFilename != "" {
|
||||
w.Header().Set(
|
||||
"Content-Disposition",
|
||||
fmt.Sprintf("attachment;filename=\"%s\"", h.attachmentFilename),
|
||||
)
|
||||
}
|
||||
|
||||
// set last modified timestamp
|
||||
lastMod := time.Now()
|
||||
if !h.lastModified.IsZero() {
|
||||
lastMod = h.lastModified
|
||||
}
|
||||
w.Header().Set("Last-Modified", lastMod.Format(http.TimeFormat))
|
||||
|
||||
// set content-length
|
||||
offset := 0
|
||||
if h.acceptRanges {
|
||||
if reqRange := r.Header.Get("Range"); reqRange != "" {
|
||||
if _, err := fmt.Sscanf(reqRange, "bytes=%d-", &offset); err != nil {
|
||||
httpError(w, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if offset >= h.contentLength {
|
||||
httpError(w, http.StatusRequestedRangeNotSatisfiable)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", h.contentLength-offset))
|
||||
|
||||
// apply header blacklist
|
||||
for _, key := range h.headerBlacklist {
|
||||
w.Header().Del(key)
|
||||
}
|
||||
|
||||
// send header and status code
|
||||
w.WriteHeader(h.statusCodeFunc(r))
|
||||
|
||||
// send body
|
||||
if r.Method == "GET" {
|
||||
// use buffered io to reduce overhead on the reader
|
||||
bw := bufio.NewWriterSize(w, 4096)
|
||||
for i := offset; !isRequestClosed(r) && i < h.contentLength; i++ {
|
||||
bw.Write([]byte{byte(i)})
|
||||
if h.rateLimiter != nil {
|
||||
bw.Flush()
|
||||
w.(http.Flusher).Flush() // force the server to send the data to the client
|
||||
select {
|
||||
case <-h.rateLimiter.C:
|
||||
case <-r.Context().Done():
|
||||
}
|
||||
}
|
||||
}
|
||||
if !isRequestClosed(r) {
|
||||
bw.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isRequestClosed returns true if the client request has been canceled.
|
||||
func isRequestClosed(r *http.Request) bool {
|
||||
return r.Context().Err() != nil
|
||||
}
|
||||
|
||||
func httpError(w http.ResponseWriter, code int) {
|
||||
http.Error(w, http.StatusText(code), code)
|
||||
}
|
92
grab/grabtest/handler_option.go
Normal file
@ -0,0 +1,92 @@
|
||||
package grabtest
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type HandlerOption func(*handler) error
|
||||
|
||||
func StatusCodeStatic(code int) HandlerOption {
|
||||
return func(h *handler) error {
|
||||
return StatusCode(func(req *http.Request) int {
|
||||
return code
|
||||
})(h)
|
||||
}
|
||||
}
|
||||
|
||||
func StatusCode(f StatusCodeFunc) HandlerOption {
|
||||
return func(h *handler) error {
|
||||
if f == nil {
|
||||
return errors.New("status code function cannot be nil")
|
||||
}
|
||||
h.statusCodeFunc = f
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func MethodWhitelist(methods ...string) HandlerOption {
|
||||
return func(h *handler) error {
|
||||
h.methodWhitelist = methods
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func HeaderBlacklist(headers ...string) HandlerOption {
|
||||
return func(h *handler) error {
|
||||
h.headerBlacklist = headers
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func ContentLength(n int) HandlerOption {
|
||||
return func(h *handler) error {
|
||||
if n < 0 {
|
||||
return errors.New("content length must be zero or greater")
|
||||
}
|
||||
h.contentLength = n
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func AcceptRanges(enabled bool) HandlerOption {
|
||||
return func(h *handler) error {
|
||||
h.acceptRanges = enabled
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func LastModified(t time.Time) HandlerOption {
|
||||
return func(h *handler) error {
|
||||
h.lastModified = t.UTC()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func TimeToFirstByte(d time.Duration) HandlerOption {
|
||||
return func(h *handler) error {
|
||||
if d < 1 {
|
||||
return errors.New("time to first byte must be greater than zero")
|
||||
}
|
||||
h.ttfb = d
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func RateLimiter(bps int) HandlerOption {
|
||||
return func(h *handler) error {
|
||||
if bps < 1 {
|
||||
return errors.New("bytes per second must be greater than zero")
|
||||
}
|
||||
h.rateLimiter = time.NewTicker(time.Second / time.Duration(bps))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func AttachmentFilename(filename string) HandlerOption {
|
||||
return func(h *handler) error {
|
||||
h.attachmentFilename = filename
|
||||
return nil
|
||||
}
|
||||
}
|
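The handler options above compose when passed to WithTestServer. A hedged sketch of a test that combines several of them follows; the test name and the chosen values are illustrative only.

package grabtest

import (
	"net/http"
	"testing"
	"time"
)

func TestHandlerOptionComposition(t *testing.T) {
	WithTestServer(t, func(url string) {
		// the handler should honor all options at once
		resp := MustHTTPDo(MustHTTPNewRequest("GET", url, nil))
		AssertHTTPResponseStatusCode(t, resp, http.StatusOK)
		AssertHTTPResponseHeader(t, resp, "Content-Disposition", `attachment;filename="example.bin"`)
		AssertHTTPResponseContentLength(t, resp, 256)
	},
		ContentLength(256),
		AttachmentFilename("example.bin"),
		TimeToFirstByte(10*time.Millisecond),
	)
}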
150
grab/grabtest/handler_test.go
Normal file
@ -0,0 +1,150 @@
|
||||
package grabtest
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestHandlerDefaults(t *testing.T) {
|
||||
WithTestServer(t, func(url string) {
|
||||
resp := MustHTTPDo(MustHTTPNewRequest("GET", url, nil))
|
||||
AssertHTTPResponseStatusCode(t, resp, http.StatusOK)
|
||||
AssertHTTPResponseContentLength(t, resp, 1048576)
|
||||
AssertHTTPResponseHeader(t, resp, "Accept-Ranges", "bytes")
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandlerMethodWhitelist(t *testing.T) {
|
||||
tests := []struct {
|
||||
Whitelist []string
|
||||
Method string
|
||||
ExpectStatusCode int
|
||||
}{
|
||||
{[]string{"GET", "HEAD"}, "GET", http.StatusOK},
|
||||
{[]string{"GET", "HEAD"}, "HEAD", http.StatusOK},
|
||||
{[]string{"GET"}, "HEAD", http.StatusMethodNotAllowed},
|
||||
{[]string{"HEAD"}, "GET", http.StatusMethodNotAllowed},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
WithTestServer(t, func(url string) {
|
||||
resp := MustHTTPDoWithClose(MustHTTPNewRequest(test.Method, url, nil))
|
||||
AssertHTTPResponseStatusCode(t, resp, test.ExpectStatusCode)
|
||||
}, MethodWhitelist(test.Whitelist...))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerHeaderBlacklist(t *testing.T) {
|
||||
contentLength := 4096
|
||||
WithTestServer(t, func(url string) {
|
||||
resp := MustHTTPDo(MustHTTPNewRequest("GET", url, nil))
|
||||
defer resp.Body.Close()
|
||||
if resp.ContentLength != -1 {
|
||||
t.Errorf("expected Response.ContentLength: -1, got: %d", resp.ContentLength)
|
||||
}
|
||||
AssertHTTPResponseHeader(t, resp, "Content-Length", "")
|
||||
AssertHTTPResponseBodyLength(t, resp, int64(contentLength))
|
||||
},
|
||||
ContentLength(contentLength),
|
||||
HeaderBlacklist("Content-Length"),
|
||||
)
|
||||
}
|
||||
|
||||
func TestHandlerStatusCodeFuncs(t *testing.T) {
|
||||
expect := 418 // I'm a teapot
|
||||
WithTestServer(t, func(url string) {
|
||||
resp := MustHTTPDo(MustHTTPNewRequest("GET", url, nil))
|
||||
AssertHTTPResponseStatusCode(t, resp, expect)
|
||||
},
|
||||
StatusCode(func(req *http.Request) int { return expect }),
|
||||
)
|
||||
}
|
||||
|
||||
func TestHandlerContentLength(t *testing.T) {
|
||||
tests := []struct {
|
||||
Method string
|
||||
ContentLength int
|
||||
ExpectHeaderLen int64
|
||||
ExpectBodyLen int
|
||||
}{
|
||||
{"GET", 321, 321, 321},
|
||||
{"HEAD", 321, 321, 0},
|
||||
{"GET", 0, 0, 0},
|
||||
{"HEAD", 0, 0, 0},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
WithTestServer(t, func(url string) {
|
||||
resp := MustHTTPDo(MustHTTPNewRequest(test.Method, url, nil))
|
||||
defer resp.Body.Close()
|
||||
|
||||
AssertHTTPResponseHeader(t, resp, "Content-Length", "%d", test.ExpectHeaderLen)
|
||||
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if len(b) != test.ExpectBodyLen {
|
||||
t.Errorf(
|
||||
"expected body length: %v, got: %v, in: %v",
|
||||
test.ExpectBodyLen,
|
||||
len(b),
|
||||
test,
|
||||
)
|
||||
}
|
||||
},
|
||||
ContentLength(test.ContentLength),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerAcceptRanges(t *testing.T) {
|
||||
header := "Accept-Ranges"
|
||||
n := 128
|
||||
t.Run("Enabled", func(t *testing.T) {
|
||||
WithTestServer(t, func(url string) {
|
||||
req := MustHTTPNewRequest("GET", url, nil)
|
||||
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", n/2))
|
||||
resp := MustHTTPDo(req)
|
||||
AssertHTTPResponseHeader(t, resp, header, "bytes")
|
||||
AssertHTTPResponseContentLength(t, resp, int64(n/2))
|
||||
},
|
||||
ContentLength(n),
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("Disabled", func(t *testing.T) {
|
||||
WithTestServer(t, func(url string) {
|
||||
req := MustHTTPNewRequest("GET", url, nil)
|
||||
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", n/2))
|
||||
resp := MustHTTPDo(req)
|
||||
AssertHTTPResponseHeader(t, resp, header, "")
|
||||
AssertHTTPResponseContentLength(t, resp, int64(n))
|
||||
},
|
||||
AcceptRanges(false),
|
||||
ContentLength(n),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandlerAttachmentFilename(t *testing.T) {
|
||||
filename := "foo.pdf"
|
||||
WithTestServer(t, func(url string) {
|
||||
resp := MustHTTPDoWithClose(MustHTTPNewRequest("GET", url, nil))
|
||||
AssertHTTPResponseHeader(t, resp, "Content-Disposition", `attachment;filename="%s"`, filename)
|
||||
},
|
||||
AttachmentFilename(filename),
|
||||
)
|
||||
}
|
||||
|
||||
func TestHandlerLastModified(t *testing.T) {
|
||||
WithTestServer(t, func(url string) {
|
||||
resp := MustHTTPDoWithClose(MustHTTPNewRequest("GET", url, nil))
|
||||
AssertHTTPResponseHeader(t, resp, "Last-Modified", "Thu, 29 Nov 1973 21:33:09 GMT")
|
||||
},
|
||||
LastModified(time.Unix(123456789, 0)),
|
||||
)
|
||||
}
|
16
grab/grabtest/util.go
Normal file
@ -0,0 +1,16 @@
|
||||
package grabtest
|
||||
|
||||
import "encoding/hex"
|
||||
|
||||
func MustHexDecodeString(s string) (b []byte) {
|
||||
var err error
|
||||
b, err = hex.DecodeString(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func MustHexEncodeString(b []byte) (s string) {
|
||||
return hex.EncodeToString(b)
|
||||
}
|
166
grab/grabui/console_client.go
Normal file
@ -0,0 +1,166 @@
|
||||
package grabui
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cavaliercoder/grab"
|
||||
)
|
||||
|
||||
type ConsoleClient struct {
|
||||
mu sync.Mutex
|
||||
client *grab.Client
|
||||
succeeded, failed, inProgress int
|
||||
responses []*grab.Response
|
||||
}
|
||||
|
||||
func NewConsoleClient(client *grab.Client) *ConsoleClient {
|
||||
return &ConsoleClient{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ConsoleClient) Do(
|
||||
ctx context.Context,
|
||||
workers int,
|
||||
reqs ...*grab.Request,
|
||||
) <-chan *grab.Response {
|
||||
// buffer size prevents slow receivers from causing back pressure
|
||||
pump := make(chan *grab.Response, len(reqs))
|
||||
|
||||
go func() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.failed = 0
|
||||
c.inProgress = 0
|
||||
c.succeeded = 0
|
||||
c.responses = make([]*grab.Response, 0, len(reqs))
|
||||
if c.client == nil {
|
||||
c.client = grab.DefaultClient
|
||||
}
|
||||
|
||||
fmt.Printf("Downloading %d files...\n", len(reqs))
|
||||
respch := c.client.DoBatch(workers, reqs...)
|
||||
t := time.NewTicker(200 * time.Millisecond)
|
||||
defer t.Stop()
|
||||
|
||||
Loop:
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
break Loop
|
||||
|
||||
case resp := <-respch:
|
||||
if resp != nil {
|
||||
// a new response has been received and has started downloading
|
||||
c.responses = append(c.responses, resp)
|
||||
pump <- resp // send to caller
|
||||
} else {
|
||||
// channel is closed - all downloads are complete
|
||||
break Loop
|
||||
}
|
||||
|
||||
case <-t.C:
|
||||
// update UI on clock tick
|
||||
c.refresh()
|
||||
}
|
||||
}
|
||||
|
||||
c.refresh()
|
||||
close(pump)
|
||||
|
||||
fmt.Printf(
|
||||
"Finished %d successful, %d failed, %d incomplete.\n",
|
||||
c.succeeded,
|
||||
c.failed,
|
||||
c.inProgress)
|
||||
}()
|
||||
return pump
|
||||
}
|
||||
|
||||
// refresh prints the progress of all downloads to the terminal
|
||||
func (c *ConsoleClient) refresh() {
|
||||
// clear lines for incomplete downloads
|
||||
if c.inProgress > 0 {
|
||||
fmt.Printf("\033[%dA\033[K", c.inProgress)
|
||||
}
|
||||
|
||||
// print newly completed downloads
|
||||
for i, resp := range c.responses {
|
||||
if resp != nil && resp.IsComplete() {
|
||||
if resp.Err() != nil {
|
||||
c.failed++
|
||||
fmt.Fprintf(os.Stderr, "Error downloading %s: %v\n",
|
||||
resp.Request.URL(),
|
||||
resp.Err())
|
||||
} else {
|
||||
c.succeeded++
|
||||
fmt.Printf("Finished %s %s / %s (%d%%)\n",
|
||||
resp.Filename,
|
||||
byteString(resp.BytesComplete()),
|
||||
byteString(resp.Size()),
|
||||
int(100*resp.Progress()))
|
||||
}
|
||||
c.responses[i] = nil
|
||||
}
|
||||
}
|
||||
|
||||
// print progress for incomplete downloads
|
||||
c.inProgress = 0
|
||||
for _, resp := range c.responses {
|
||||
if resp != nil {
|
||||
fmt.Printf("Downloading %s %s / %s (%d%%) - %s ETA: %s \033[K\n",
|
||||
resp.Filename,
|
||||
byteString(resp.BytesComplete()),
|
||||
byteString(resp.Size()),
|
||||
int(100*resp.Progress()),
|
||||
bpsString(resp.BytesPerSecond()),
|
||||
etaString(resp.ETA()))
|
||||
c.inProgress++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func bpsString(n float64) string {
|
||||
if n < 1e3 {
|
||||
return fmt.Sprintf("%.02fBps", n)
|
||||
}
|
||||
if n < 1e6 {
|
||||
return fmt.Sprintf("%.02fKB/s", n/1e3)
|
||||
}
|
||||
if n < 1e9 {
|
||||
return fmt.Sprintf("%.02fMB/s", n/1e6)
|
||||
}
|
||||
return fmt.Sprintf("%.02fGB/s", n/1e9)
|
||||
}
|
||||
|
||||
func byteString(n int64) string {
|
||||
if n < 1<<10 {
|
||||
return fmt.Sprintf("%dB", n)
|
||||
}
|
||||
if n < 1<<20 {
|
||||
return fmt.Sprintf("%dKB", n>>10)
|
||||
}
|
||||
if n < 1<<30 {
|
||||
return fmt.Sprintf("%dMB", n>>20)
|
||||
}
|
||||
if n < 1<<40 {
|
||||
return fmt.Sprintf("%dGB", n>>30)
|
||||
}
|
||||
return fmt.Sprintf("%dTB", n>>40)
|
||||
}
|
||||
|
||||
func etaString(eta time.Time) string {
|
||||
d := eta.Sub(time.Now())
|
||||
if d < time.Second {
|
||||
return "<1s"
|
||||
}
|
||||
// truncate to 1s resolution
|
||||
d /= time.Second
|
||||
d *= time.Second
|
||||
return d.String()
|
||||
}
|
27
grab/grabui/grabui.go
Normal file
@ -0,0 +1,27 @@
|
||||
package grabui
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cavaliercoder/grab"
|
||||
)
|
||||
|
||||
func GetBatch(
|
||||
ctx context.Context,
|
||||
workers int,
|
||||
dst string,
|
||||
urlStrs ...string,
|
||||
) (<-chan *grab.Response, error) {
|
||||
reqs := make([]*grab.Request, len(urlStrs))
|
||||
for i := 0; i < len(urlStrs); i++ {
|
||||
req, err := grab.NewRequest(dst, urlStrs[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req = req.WithContext(ctx)
|
||||
reqs[i] = req
|
||||
}
|
||||
|
||||
ui := NewConsoleClient(grab.DefaultClient)
|
||||
return ui.Do(ctx, workers, reqs...), nil
|
||||
}
|
12
grab/rate_limiter.go
Normal file
@ -0,0 +1,12 @@
|
||||
package grab
|
||||
|
||||
import "context"
|
||||
|
||||
// RateLimiter is an interface that must be satisfied by any third-party rate
|
||||
// limiters that may be used to limit download transfer speeds.
|
||||
//
|
||||
// A recommended token bucket implementation can be found at
|
||||
// https://godoc.org/golang.org/x/time/rate#Limiter.
|
||||
type RateLimiter interface {
|
||||
WaitN(ctx context.Context, n int) (err error)
|
||||
}
|
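The token bucket limiter from golang.org/x/time/rate mentioned above satisfies this interface as-is, since *rate.Limiter provides WaitN(ctx, n). A minimal sketch of wiring it into a Request; the URL, rate and buffer size are placeholders.

package main

import (
	"log"

	"github.com/cavaliercoder/grab"
	"golang.org/x/time/rate"
)

func main() {
	req, err := grab.NewRequest("", "http://example.com/example.zip")
	if err != nil {
		log.Fatal(err)
	}

	// *rate.Limiter implements WaitN(ctx, n) and so satisfies grab.RateLimiter.
	// Limit to ~256 KB/s with a burst equal to the transfer buffer size.
	req.BufferSize = 32 * 1024
	req.RateLimiter = rate.NewLimiter(rate.Limit(256*1024), req.BufferSize)

	resp := grab.DefaultClient.Do(req)
	if err := resp.Err(); err != nil {
		log.Fatal(err)
	}
	log.Println("saved to", resp.Filename)
}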
69
grab/rate_limiter_test.go
Normal file
@ -0,0 +1,69 @@
|
||||
package grab
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cavaliercoder/grab/grabtest"
|
||||
)
|
||||
|
||||
// testRateLimiter is a naive rate limiter that limits throughput to r tokens
|
||||
// per second. The total number of tokens issued is tracked as n.
|
||||
type testRateLimiter struct {
|
||||
r, n int
|
||||
}
|
||||
|
||||
func NewLimiter(r int) RateLimiter {
|
||||
return &testRateLimiter{r: r}
|
||||
}
|
||||
|
||||
func (c *testRateLimiter) WaitN(ctx context.Context, n int) (err error) {
|
||||
c.n += n
|
||||
time.Sleep(
|
||||
time.Duration(1.00 / float64(c.r) * float64(n) * float64(time.Second)))
|
||||
return
|
||||
}
|
||||
|
||||
func TestRateLimiter(t *testing.T) {
|
||||
// download a 128 byte file, 8 bytes at a time, with a naive 512bps limiter
|
||||
// should take > 250ms
|
||||
filesize := 128
|
||||
filename := ".testRateLimiter"
|
||||
defer os.Remove(filename)
|
||||
|
||||
grabtest.WithTestServer(t, func(url string) {
|
||||
// limit to 512bps
|
||||
lim := &testRateLimiter{r: 512}
|
||||
req := mustNewRequest(filename, url)
|
||||
|
||||
// ensure multiple trips to the rate limiter by downloading 8 bytes at a time
|
||||
req.BufferSize = 8
|
||||
req.RateLimiter = lim
|
||||
|
||||
resp := mustDo(req)
|
||||
testComplete(t, resp)
|
||||
if lim.n != filesize {
|
||||
t.Errorf("expected %d bytes to pass through limiter, got %d", filesize, lim.n)
|
||||
}
|
||||
if resp.Duration().Seconds() < 0.25 {
|
||||
// BUG: this test can pass if the transfer was slow for unrelated reasons
|
||||
t.Errorf("expected transfer to take >250ms, took %v", resp.Duration())
|
||||
}
|
||||
}, grabtest.ContentLength(filesize))
|
||||
}
|
||||
|
||||
func ExampleRateLimiter() {
|
||||
req, _ := NewRequest("", "http://www.golang-book.com/public/pdf/gobook.pdf")
|
||||
|
||||
// Attach a 1Mbps rate limiter, like the token bucket implementation from
|
||||
// golang.org/x/time/rate.
|
||||
req.RateLimiter = NewLimiter(1048576)
|
||||
|
||||
resp := DefaultClient.Do(req)
|
||||
if err := resp.Err(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
177
grab/request.go
Normal file
@ -0,0 +1,177 @@
|
||||
package grab
|
||||
|
||||
import (
|
||||
"context"
|
||||
"hash"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// A Hook is a user provided callback function that can be called by grab at
|
||||
// various stages of a request's lifecycle. If a hook returns an error, the
|
||||
// associated request is canceled and the same error is returned on the Response
|
||||
// object.
|
||||
//
|
||||
// Hook functions are called synchronously and should never block unnecessarily.
|
||||
// Response methods that block until a download is complete, such as
|
||||
// Response.Err, Response.Cancel or Response.Wait will deadlock. To cancel a
|
||||
// download from a callback, simply return a non-nil error.
|
||||
type Hook func(*Response) error
|
||||
|
||||
// A Request represents an HTTP file transfer request to be sent by a Client.
|
||||
type Request struct {
|
||||
// Label is an arbitrary string which may be used to label a Request with a
|
||||
// user friendly name.
|
||||
Label string
|
||||
|
||||
// Tag is an arbitrary interface which may be used to relate a Request to
|
||||
// other data.
|
||||
Tag interface{}
|
||||
|
||||
// HTTPRequest specifies the http.Request to be sent to the remote server to
|
||||
// initiate a file transfer. It includes request configuration such as URL,
|
||||
// protocol version, HTTP method, request headers and authentication.
|
||||
HTTPRequest *http.Request
|
||||
|
||||
// Filename specifies the path where the file transfer will be stored in
|
||||
// local storage. If Filename is empty or a directory, the true Filename will
|
||||
// be resolved using Content-Disposition headers or the request URL.
|
||||
//
|
||||
// An empty string means the transfer will be stored in the current working
|
||||
// directory.
|
||||
Filename string
|
||||
|
||||
// SkipExisting specifies that ErrFileExists should be returned if the
|
||||
// destination path already exists. The existing file will not be checked for
|
||||
// completeness.
|
||||
SkipExisting bool
|
||||
|
||||
// NoResume specifies that a partially completed download will be restarted
|
||||
// without attempting to resume any existing file. If the download is already
|
||||
// completed in full, it will not be restarted.
|
||||
NoResume bool
|
||||
|
||||
// NoStore specifies that grab should not write to the local file system.
|
||||
// Instead, the download will be stored in memory and accessible only via
|
||||
// Response.Open or Response.Bytes.
|
||||
NoStore bool
|
||||
|
||||
// NoCreateDirectories specifies that any missing directories in the given
|
||||
// Filename path should not be created automatically, if they do not already
|
||||
// exist.
|
||||
NoCreateDirectories bool
|
||||
|
||||
// IgnoreBadStatusCodes specifies that grab should accept any status code in
|
||||
// the response from the remote server. Otherwise, grab expects the response
|
||||
// status code to be within the 2XX range (after following redirects).
|
||||
IgnoreBadStatusCodes bool
|
||||
|
||||
// IgnoreRemoteTime specifies that grab should not attempt to set the
|
||||
// timestamp of the local file to match the remote file.
|
||||
IgnoreRemoteTime bool
|
||||
|
||||
// Size specifies the expected size of the file transfer if known. If the
|
||||
// server response size does not match, the transfer is cancelled and
|
||||
// ErrBadLength is returned.
|
||||
Size int64
|
||||
|
||||
// BufferSize specifies the size in bytes of the buffer that is used for
|
||||
// transferring the requested file. Larger buffers may result in faster
|
||||
// throughput but will use more memory and result in less frequent updates
|
||||
// to the transfer progress statistics. If a RateLimiter is configured,
|
||||
// BufferSize should be much lower than the rate limit. Default: 32KB.
|
||||
BufferSize int
|
||||
|
||||
// RateLimiter allows the transfer rate of a download to be limited. The given
|
||||
// Request.BufferSize determines how frequently the RateLimiter will be
|
||||
// polled.
|
||||
RateLimiter RateLimiter
|
||||
|
||||
// BeforeCopy is a user provided callback that is called immediately before
|
||||
// a request starts downloading. If BeforeCopy returns an error, the request
|
||||
// is cancelled and the same error is returned on the Response object.
|
||||
BeforeCopy Hook
|
||||
|
||||
// AfterCopy is a user provided callback that is called immediately after a
|
||||
// request has finished downloading, before checksum validation and closure.
|
||||
// This hook is only called if the transfer was successful. If AfterCopy
|
||||
// returns an error, the request is canceled and the same error is returned on
|
||||
// the Response object.
|
||||
AfterCopy Hook
|
||||
|
||||
// hash, checksum and deleteOnError - set via SetChecksum.
|
||||
hash hash.Hash
|
||||
checksum []byte
|
||||
deleteOnError bool
|
||||
|
||||
// Context for cancellation and timeout - set via WithContext
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewRequest returns a new file transfer Request suitable for use with
|
||||
// Client.Do.
|
||||
func NewRequest(dst, urlStr string) (*Request, error) {
|
||||
if dst == "" {
|
||||
dst = "."
|
||||
}
|
||||
req, err := http.NewRequest("GET", urlStr, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Request{
|
||||
HTTPRequest: req,
|
||||
Filename: dst,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Context returns the request's context. To change the context, use
|
||||
// WithContext.
|
||||
//
|
||||
// The returned context is always non-nil; it defaults to the background
|
||||
// context.
|
||||
//
|
||||
// The context controls cancelation.
|
||||
func (r *Request) Context() context.Context {
|
||||
if r.ctx != nil {
|
||||
return r.ctx
|
||||
}
|
||||
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
// WithContext returns a shallow copy of r with its context changed
|
||||
// to ctx. The provided ctx must be non-nil.
|
||||
func (r *Request) WithContext(ctx context.Context) *Request {
|
||||
if ctx == nil {
|
||||
panic("nil context")
|
||||
}
|
||||
r2 := new(Request)
|
||||
*r2 = *r
|
||||
r2.ctx = ctx
|
||||
r2.HTTPRequest = r2.HTTPRequest.WithContext(ctx)
|
||||
return r2
|
||||
}
|
||||
|
||||
// URL returns the URL to be downloaded.
|
||||
func (r *Request) URL() *url.URL {
|
||||
return r.HTTPRequest.URL
|
||||
}
|
||||
|
||||
// SetChecksum sets the desired hashing algorithm and checksum value to validate
|
||||
// a downloaded file. Once the download is complete, the given hashing algorithm
|
||||
// will be used to compute the actual checksum of the downloaded file. If the
|
||||
// checksums do not match, an error will be returned by the associated
|
||||
// Response.Err method.
|
||||
//
|
||||
// If deleteOnError is true, the downloaded file will be deleted automatically
|
||||
// if it fails checksum validation.
|
||||
//
|
||||
// To prevent corruption of the computed checksum, the given hash must not be
|
||||
// used by any other request or goroutines.
|
||||
//
|
||||
// To disable checksum validation, call SetChecksum with a nil hash.
|
||||
func (r *Request) SetChecksum(h hash.Hash, sum []byte, deleteOnError bool) {
|
||||
r.hash = h
|
||||
r.checksum = sum
|
||||
r.deleteOnError = deleteOnError
|
||||
}
|
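Because BeforeCopy and AfterCopy are plain func(*Response) error values, a caller can use them to gate or audit a transfer before and after the body is copied. A hedged sketch; the size limit, URL and destination are illustrative only.

package main

import (
	"fmt"
	"log"

	"github.com/cavaliercoder/grab"
)

func main() {
	req, err := grab.NewRequest("/tmp", "http://example.com/example.zip")
	if err != nil {
		log.Fatal(err)
	}

	// Abort before any bytes are copied if the advertised size is too large.
	const maxSize = 100 << 20 // 100 MB, illustrative only
	req.BeforeCopy = func(resp *grab.Response) error {
		if resp.Size() > maxSize {
			return fmt.Errorf("refusing to download %d bytes", resp.Size())
		}
		return nil
	}

	// Log the completed transfer before checksum validation and closure run.
	req.AfterCopy = func(resp *grab.Response) error {
		log.Printf("copied %d bytes to %s", resp.BytesComplete(), resp.Filename)
		return nil
	}

	resp := grab.DefaultClient.Do(req)
	if err := resp.Err(); err != nil {
		log.Fatal(err)
	}
}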
258
grab/response.go
Normal file
@ -0,0 +1,258 @@
|
||||
package grab
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Response represents the response to a completed or in-progress download
|
||||
// request.
|
||||
//
|
||||
// A response may be returned as soon as an HTTP response is received from a remote
|
||||
// server, but before the body content has started transferring.
|
||||
//
|
||||
// All Response method calls are thread-safe.
|
||||
type Response struct {
|
||||
// The Request that was submitted to obtain this Response.
|
||||
Request *Request
|
||||
|
||||
// HTTPResponse represents the HTTP response received from an HTTP request.
|
||||
//
|
||||
// The response Body should not be used as it will be consumed and closed by
|
||||
// grab.
|
||||
HTTPResponse *http.Response
|
||||
|
||||
// Filename specifies the path where the file transfer is stored in local
|
||||
// storage.
|
||||
Filename string
|
||||
|
||||
// sizeUnsafe holds the total expected size of the file transfer (read via Size).
|
||||
sizeUnsafe int64
|
||||
|
||||
// Start specifies the time at which the file transfer started.
|
||||
Start time.Time
|
||||
|
||||
// End specifies the time at which the file transfer completed.
|
||||
//
|
||||
// This will return zero until the transfer has completed.
|
||||
End time.Time
|
||||
|
||||
// CanResume specifies that the remote server advertised that it can resume
|
||||
// previous downloads, as the 'Accept-Ranges: bytes' header is set.
|
||||
CanResume bool
|
||||
|
||||
// DidResume specifies that the file transfer resumed a previously incomplete
|
||||
// transfer.
|
||||
DidResume bool
|
||||
|
||||
// Done is closed once the transfer is finalized, either successfully or with
|
||||
// errors. Errors are available via Response.Err
|
||||
Done chan struct{}
|
||||
|
||||
// ctx is a Context that controls cancelation of an in-progress transfer
|
||||
ctx context.Context
|
||||
|
||||
// cancel is a cancel func that can be used to cancel the context of this
|
||||
// Response.
|
||||
cancel context.CancelFunc
|
||||
|
||||
// fi is the FileInfo for the destination file if it already existed before
|
||||
// transfer started.
|
||||
fi os.FileInfo
|
||||
|
||||
// optionsKnown indicates that a HEAD request has been completed and the
|
||||
// capabilities of the remote server are known.
|
||||
optionsKnown bool
|
||||
|
||||
// writer is the file handle used to write the downloaded file to local
|
||||
// storage
|
||||
writer io.Writer
|
||||
|
||||
// storeBuffer receives the contents of the transfer if Request.NoStore is
|
||||
// enabled.
|
||||
storeBuffer bytes.Buffer
|
||||
|
||||
// bytesResumed specifies the number of bytes which were already
|
||||
// transferred before this transfer began.
|
||||
bytesResumed int64
|
||||
|
||||
// transfer is responsible for copying data from the remote server to a local
|
||||
// file, tracking progress and allowing for cancelation.
|
||||
transfer *transfer
|
||||
|
||||
// bufferSize specifies the size in bytes of the transfer buffer.
|
||||
bufferSize int
|
||||
|
||||
// err contains any error that may have occurred during the file transfer.
|
||||
// This should not be read until IsComplete returns true.
|
||||
err error
|
||||
}
|
||||
|
||||
// IsComplete returns true if the download has completed. If an error occurred
|
||||
// during the download, it can be returned via Err.
|
||||
func (c *Response) IsComplete() bool {
|
||||
select {
|
||||
case <-c.Done:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Cancel cancels the file transfer by canceling the underlying Context for
|
||||
// this Response. Cancel blocks until the transfer is closed and returns any
|
||||
// error - typically context.Canceled.
|
||||
func (c *Response) Cancel() error {
|
||||
c.cancel()
|
||||
return c.Err()
|
||||
}
|
||||
|
||||
// Wait blocks until the download is completed.
|
||||
func (c *Response) Wait() {
|
||||
<-c.Done
|
||||
}
|
||||
|
||||
// Err blocks the calling goroutine until the underlying file transfer is
|
||||
// completed and returns any error that may have occurred. If the download is
|
||||
// already completed, Err returns immediately.
|
||||
func (c *Response) Err() error {
|
||||
<-c.Done
|
||||
return c.err
|
||||
}
|
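A minimal sketch of how Done, Wait, Err, IsComplete and Cancel are intended to be combined in client code; the URL is a placeholder and the upstream import path is assumed:

```go
package main

import (
	"context"
	"log"
	"time"

	"github.com/cavaliercoder/grab"
)

func main() {
	req, err := grab.NewRequest("", "https://example.com/big-file.bin") // placeholder URL
	if err != nil {
		log.Fatal(err)
	}

	resp := grab.DefaultClient.Do(req)
	log.Println("in progress:", !resp.IsComplete())

	// Wait for the transfer to finalize, but give up after ten minutes.
	select {
	case <-resp.Done:
		// finalized; Err will no longer block
	case <-time.After(10 * time.Minute):
		resp.Cancel() // Err should now report context.Canceled
	}

	if err := resp.Err(); err != nil && err != context.Canceled {
		log.Fatalf("download failed: %v", err)
	}
	log.Println("finished:", resp.IsComplete())
}
```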

// Size returns the size of the file transfer. If the remote server does not
// specify the total size and the transfer is incomplete, the return value is
// -1.
func (c *Response) Size() int64 {
	return atomic.LoadInt64(&c.sizeUnsafe)
}

// BytesComplete returns the total number of bytes which have been copied to
// the destination, including any bytes that were resumed from a previous
// download.
func (c *Response) BytesComplete() int64 {
	return c.bytesResumed + c.transfer.N()
}

// BytesPerSecond returns the number of bytes per second transferred using a
// simple moving average of the last five seconds. If the download is already
// complete, the average bytes/sec for the life of the download is returned.
func (c *Response) BytesPerSecond() float64 {
	if c.IsComplete() {
		return float64(c.transfer.N()) / c.Duration().Seconds()
	}
	return c.transfer.BPS()
}

// Progress returns the ratio of total bytes that have been downloaded. Multiply
// the returned value by 100 to obtain the percentage completed.
func (c *Response) Progress() float64 {
	size := c.Size()
	if size <= 0 {
		return 0
	}
	return float64(c.BytesComplete()) / float64(size)
}

// Duration returns the duration of a file transfer. If the transfer is in
// progress, the duration will be between now and the start of the transfer. If
// the transfer is complete, the duration will be between the start and end of
// the completed transfer process.
func (c *Response) Duration() time.Duration {
	if c.IsComplete() {
		return c.End.Sub(c.Start)
	}

	return time.Now().Sub(c.Start)
}

// ETA returns the estimated time at which the download will complete, given
// the current BytesPerSecond. If the transfer has already completed, the actual
// end time will be returned.
func (c *Response) ETA() time.Time {
	if c.IsComplete() {
		return c.End
	}
	bt := c.BytesComplete()
	bps := c.transfer.BPS()
	if bps == 0 {
		return time.Time{}
	}
	secs := float64(c.Size()-bt) / bps
	return time.Now().Add(time.Duration(secs) * time.Second)
}
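To illustrate how Progress, BytesComplete, BytesPerSecond and ETA can drive a simple progress display (placeholder URL, upstream import path assumed):

```go
package main

import (
	"fmt"
	"log"
	"time"

	"github.com/cavaliercoder/grab"
)

func main() {
	req, err := grab.NewRequest("", "https://example.com/big-file.bin") // placeholder URL
	if err != nil {
		log.Fatal(err)
	}
	resp := grab.DefaultClient.Do(req)

	t := time.NewTicker(500 * time.Millisecond)
	defer t.Stop()

	for !resp.IsComplete() {
		select {
		case <-resp.Done:
			// finalized; the loop condition exits on the next pass
		case <-t.C:
			fmt.Printf("%d/%d bytes (%.1f%%) at %.0f B/s, ETA %s\n",
				resp.BytesComplete(), resp.Size(),
				100*resp.Progress(), resp.BytesPerSecond(),
				resp.ETA().Format(time.Kitchen))
		}
	}

	if err := resp.Err(); err != nil {
		log.Fatal(err)
	}
	fmt.Println("completed in", resp.Duration().Round(time.Millisecond))
}
```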

// Open blocks the calling goroutine until the underlying file transfer is
// completed and then opens the transferred file for reading. If Request.NoStore
// was enabled, the reader will read from memory.
//
// If an error occurred during the transfer, it will be returned.
//
// It is the caller's responsibility to close the returned file handle.
func (c *Response) Open() (io.ReadCloser, error) {
	if err := c.Err(); err != nil {
		return nil, err
	}
	return c.openUnsafe()
}

func (c *Response) openUnsafe() (io.ReadCloser, error) {
	if c.Request.NoStore {
		return ioutil.NopCloser(bytes.NewReader(c.storeBuffer.Bytes())), nil
	}
	return os.Open(c.Filename)
}

// Bytes blocks the calling goroutine until the underlying file transfer is
// completed and then reads all bytes from the completed transfer. If
// Request.NoStore was enabled, the bytes will be read from memory.
//
// If an error occurred during the transfer, it will be returned.
func (c *Response) Bytes() ([]byte, error) {
	if err := c.Err(); err != nil {
		return nil, err
	}
	if c.Request.NoStore {
		return c.storeBuffer.Bytes(), nil
	}
	f, err := c.Open()
	if err != nil {
		return nil, err
	}
	defer f.Close()
	return ioutil.ReadAll(f)
}
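A sketch of reading a completed transfer back through Open or Bytes. NoStore is used here as shown in the Response code above; the URL is a placeholder:

```go
package main

import (
	"io"
	"log"
	"os"

	"github.com/cavaliercoder/grab"
)

func main() {
	req, err := grab.NewRequest("", "https://example.com/report.pdf") // placeholder URL
	if err != nil {
		log.Fatal(err)
	}
	// Keep the body in memory instead of writing it to disk.
	req.NoStore = true

	resp := grab.DefaultClient.Do(req)

	// Bytes blocks until the transfer is done, then returns the in-memory body.
	b, err := resp.Bytes()
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("downloaded %d bytes", len(b))

	// Open returns an io.ReadCloser over the same content, useful for streaming.
	r, err := resp.Open()
	if err != nil {
		log.Fatal(err)
	}
	defer r.Close()
	if _, err := io.Copy(os.Stdout, r); err != nil {
		log.Fatal(err)
	}
}
```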

func (c *Response) requestMethod() string {
	if c == nil || c.HTTPResponse == nil || c.HTTPResponse.Request == nil {
		return ""
	}
	return c.HTTPResponse.Request.Method
}

func (c *Response) checksumUnsafe() ([]byte, error) {
	f, err := c.openUnsafe()
	if err != nil {
		return nil, err
	}
	defer f.Close()
	t := newTransfer(c.Request.Context(), nil, c.Request.hash, f, nil)
	if _, err = t.copy(); err != nil {
		return nil, err
	}
	sum := c.Request.hash.Sum(nil)
	return sum, nil
}

func (c *Response) closeResponseBody() error {
	if c.HTTPResponse == nil || c.HTTPResponse.Body == nil {
		return nil
	}
	return c.HTTPResponse.Body.Close()
}
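checksumUnsafe re-reads the stored file through the request's hash.Hash, which works because hash.Hash is an io.Writer. A standalone sketch of that pattern, with a placeholder file path:

```go
package main

import (
	"crypto/sha256"
	"fmt"
	"io"
	"log"
	"os"
)

func main() {
	f, err := os.Open("/tmp/example.iso") // placeholder path
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	h := sha256.New()
	// hash.Hash implements io.Writer, so copying the file into it accumulates
	// the digest, much like checksumUnsafe does by using the hash as the
	// transfer's destination writer.
	if _, err := io.Copy(h, f); err != nil {
		log.Fatal(err)
	}
	fmt.Printf("%x\n", h.Sum(nil))
}
```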
118
grab/response_test.go
Normal file
118
grab/response_test.go
Normal file
@ -0,0 +1,118 @@
package grab

import (
	"bytes"
	"os"
	"testing"
	"time"

	"github.com/cavaliercoder/grab/grabtest"
)

// testComplete validates that a completed Response has all the desired fields.
func testComplete(t *testing.T, resp *Response) {
	<-resp.Done
	if !resp.IsComplete() {
		t.Errorf("Response.IsComplete returned false")
	}

	if resp.Start.IsZero() {
		t.Errorf("Response.Start is zero")
	}

	if resp.End.IsZero() {
		t.Error("Response.End is zero")
	}

	if eta := resp.ETA(); eta != resp.End {
		t.Errorf("Response.ETA is not equal to Response.End: %v", eta)
	}

	// the following fields should only be set if no error occurred
	if resp.Err() == nil {
		if resp.Filename == "" {
			t.Errorf("Response.Filename is empty")
		}

		if resp.Size() == 0 {
			t.Error("Response.Size is zero")
		}

		if p := resp.Progress(); p != 1.00 {
			t.Errorf("Response.Progress returned %v (%v/%v bytes), expected 1", p, resp.BytesComplete(), resp.Size())
		}
	}
}

// TestResponseProgress tests the functions which indicate the progress of an
// in-progress file transfer.
func TestResponseProgress(t *testing.T) {
	filename := ".testResponseProgress"
	defer os.Remove(filename)

	sleep := 300 * time.Millisecond
	size := 1024 * 8 // bytes

	grabtest.WithTestServer(t, func(url string) {
		// request a slow transfer
		req := mustNewRequest(filename, url)
		resp := DefaultClient.Do(req)

		// make sure transfer has not started
		if resp.IsComplete() {
			t.Errorf("Transfer should not have started")
		}

		if p := resp.Progress(); p != 0 {
			t.Errorf("Transfer should not have started yet but progress is %v", p)
		}

		// wait for transfer to complete
		<-resp.Done

		// make sure transfer is complete
		if p := resp.Progress(); p != 1 {
			t.Errorf("Transfer is complete but progress is %v", p)
		}

		if s := resp.BytesComplete(); s != int64(size) {
			t.Errorf("Expected to transfer %v bytes, got %v", size, s)
		}
	},
		grabtest.TimeToFirstByte(sleep),
		grabtest.ContentLength(size),
	)
}

func TestResponseOpen(t *testing.T) {
	grabtest.WithTestServer(t, func(url string) {
		resp := mustDo(mustNewRequest("", url+"/someFilename"))
		f, err := resp.Open()
		if err != nil {
			t.Error(err)
			return
		}
		defer func() {
			if err := f.Close(); err != nil {
				t.Error(err)
			}
		}()
		grabtest.AssertSHA256Sum(t, grabtest.DefaultHandlerSHA256ChecksumBytes, f)
	})
}

func TestResponseBytes(t *testing.T) {
	grabtest.WithTestServer(t, func(url string) {
		resp := mustDo(mustNewRequest("", url+"/someFilename"))
		b, err := resp.Bytes()
		if err != nil {
			t.Error(err)
			return
		}
		grabtest.AssertSHA256Sum(
			t,
			grabtest.DefaultHandlerSHA256ChecksumBytes,
			bytes.NewReader(b),
		)
	})
}
111
grab/states.wsd
Normal file
111
grab/states.wsd
Normal file
@ -0,0 +1,111 @@
@startuml
title Grab transfer state

legend
| # | Meaning |
| D | Destination path known |
| S | File size known |
| O | Server options known (Accept-Ranges) |
| R | Resume supported (Accept-Ranges) |
| Z | Local file empty or missing |
| P | Local file partially complete |
endlegend

[*] --> Empty
[*] --> D
[*] --> S
[*] --> DS

Empty : Filename: ""
Empty : Size: 0
Empty --> O : HEAD: Method not allowed
Empty --> DSO : HEAD: Range not supported
Empty --> DSOR : HEAD: Range supported

DS : Filename: "foo.bar"
DS : Size: > 0
DS --> DSZ : checkExisting(): File missing
DS --> DSP : checkExisting(): File partial
DS --> [*] : checkExisting(): File complete
DS --> ERROR

S : Filename: ""
S : Size: > 0
S --> SO : HEAD: Method not allowed
S --> DSO : HEAD: Range not supported
S --> DSOR : HEAD: Range supported

D : Filename: "foo.bar"
D : Size: 0
D --> DO : HEAD: Method not allowed
D --> DSO : HEAD: Range not supported
D --> DSOR : HEAD: Range supported

O : Filename: ""
O : Size: 0
O : CanResume: false
O --> DSO : GET: 200
O --> ERROR

SO : Filename: ""
SO : Size: > 0
SO : CanResume: false
SO --> DSO : GET: 200
SO --> ERROR

DO : Filename: "foo.bar"
DO : Size: 0
DO : CanResume: false
DO --> DSO : GET: 200
DO --> ERROR

DSZ : Filename: "foo.bar"
DSZ : Size: > 0
DSZ : File: empty
DSZ --> DSORZ : HEAD: Range supported
DSZ --> DSOZ : HEAD: 405 or Range unsupported

DSP : Filename: "foo.bar"
DSP : Size: > 0
DSP : File: partial
DSP --> DSORP : HEAD: Range supported
DSP --> DSOZ : HEAD: 405 or Range unsupported

DSO : Filename: "foo.bar"
DSO : Size: > 0
DSO : CanResume: false
DSO --> DSOZ : checkExisting(): File partial|missing
DSO --> [*] : checkExisting(): File complete

DSOR : Filename: "foo.bar"
DSOR : Size: > 0
DSOR : CanResume: true
DSOR --> DSORP : CheckLocal: File partial
DSOR --> DSORZ : CheckLocal: File missing

DSORP : Filename: "foo.bar"
DSORP : Size: > 0
DSORP : CanResume: true
DSORP : File: partial
DSORP --> Transferring

DSORZ : Filename: "foo.bar"
DSORZ : Size: > 0
DSORZ : CanResume: true
DSORZ : File: empty
DSORZ --> Transferring

DSOZ : Filename: "foo.bar"
DSOZ : Size: > 0
DSOZ : CanResume: false
DSOZ : File: empty
DSOZ --> Transferring

Transferring --> [*]
Transferring --> ERROR

ERROR : Something went wrong
ERROR --> [*]

@enduml
103
grab/transfer.go
Normal file
103
grab/transfer.go
Normal file
@ -0,0 +1,103 @@
package grab

import (
	"context"
	"io"
	"sync/atomic"
	"time"

	"github.com/cavaliercoder/grab/bps"
)

type transfer struct {
	n     int64 // must be 64bit aligned on 386
	ctx   context.Context
	gauge bps.Gauge
	lim   RateLimiter
	w     io.Writer
	r     io.Reader
	b     []byte
}

func newTransfer(ctx context.Context, lim RateLimiter, dst io.Writer, src io.Reader, buf []byte) *transfer {
	return &transfer{
		ctx:   ctx,
		gauge: bps.NewSMA(6), // five second moving average sampling every second
		lim:   lim,
		w:     dst,
		r:     src,
		b:     buf,
	}
}
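The copy loop further down only ever calls lim.WaitN(ctx, n), and *rate.Limiter from golang.org/x/time/rate has exactly that method. Assuming the RateLimiter type declared elsewhere in this package is an interface matching that call, an in-package sketch of wiring a limiter into a transfer might look like this:

```go
package grab

import (
	"bytes"
	"context"
	"fmt"
	"strings"

	"golang.org/x/time/rate"
)

// exampleRateLimitedCopy is a sketch only: it copies 64 KiB through a
// transfer capped at roughly 32 KiB/s via golang.org/x/time/rate.
func exampleRateLimitedCopy() {
	src := strings.NewReader(strings.Repeat("x", 64*1024))
	dst := &bytes.Buffer{}

	// Burst matches the default 32 KiB copy buffer so WaitN never exceeds it.
	lim := rate.NewLimiter(rate.Limit(32*1024), 32*1024)
	t := newTransfer(context.Background(), lim, dst, src, nil)

	n, err := t.copy()
	fmt.Println(n, err)
}
```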

// copy behaves similarly to io.CopyBuffer except that it checks for cancelation
// of the given context.Context, reports progress in a thread-safe manner and
// tracks the transfer rate.
func (c *transfer) copy() (written int64, err error) {
	// maintain a bps gauge in another goroutine
	ctx, cancel := context.WithCancel(c.ctx)
	defer cancel()
	go bps.Watch(ctx, c.gauge, c.N, time.Second)

	// start the transfer
	if c.b == nil {
		c.b = make([]byte, 32*1024)
	}
	for {
		select {
		case <-c.ctx.Done():
			err = c.ctx.Err()
			return
		default:
			// keep working
		}
		nr, er := c.r.Read(c.b)
		if nr > 0 {
			nw, ew := c.w.Write(c.b[0:nr])
			if nw > 0 {
				written += int64(nw)
				atomic.StoreInt64(&c.n, written)
			}
			if ew != nil {
				err = ew
				break
			}
			if nr != nw {
				err = io.ErrShortWrite
				break
			}
			// wait for rate limiter
			if c.lim != nil {
				err = c.lim.WaitN(c.ctx, nr)
				if err != nil {
					return
				}
			}
		}
		if er != nil {
			if er != io.EOF {
				err = er
			}
			break
		}
	}
	return written, err
}
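The gauge goroutine started at the top of copy is easy to miss: bps.Watch samples the byte counter on an interval and feeds the simple moving average gauge. A standalone in-package sketch of that wiring, using only the bps calls already shown above (the counter and loop are illustrative only):

```go
package grab

import (
	"context"
	"fmt"
	"sync/atomic"
	"time"

	"github.com/cavaliercoder/grab/bps"
)

// exampleGaugeWatch is a sketch only: it feeds a byte counter into an SMA
// gauge, the same way copy wires c.N into bps.Watch.
func exampleGaugeWatch() {
	var n int64
	counter := func() int64 { return atomic.LoadInt64(&n) }

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	g := bps.NewSMA(6) // six samples, one per second, as in newTransfer
	go bps.Watch(ctx, g, counter, time.Second)

	for i := 0; i < 5; i++ {
		atomic.AddInt64(&n, 32*1024) // pretend we copied 32 KiB
		time.Sleep(time.Second)
		fmt.Printf("%.0f B/s\n", g.BPS())
	}
}
```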

// N returns the number of bytes transferred.
func (c *transfer) N() (n int64) {
	if c == nil {
		return 0
	}
	n = atomic.LoadInt64(&c.n)
	return
}

// BPS returns the current bytes per second transfer rate using a simple moving
// average.
func (c *transfer) BPS() (bps float64) {
	if c == nil || c.gauge == nil {
		return 0
	}
	return c.gauge.BPS()
}
75
grab/util.go
Normal file
75
grab/util.go
Normal file
@ -0,0 +1,75 @@
package grab

import (
	"fmt"
	"mime"
	"net/http"
	"os"
	"path"
	"path/filepath"
	"time"

	"github.com/lordwelch/pathvalidate"
)

// setLastModified sets the last modified timestamp of a local file according to
// the Last-Modified header returned by a remote server.
func setLastModified(resp *http.Response, filename string) error {
	// https://tools.ietf.org/html/rfc7232#section-2.2
	// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Last-Modified
	header := resp.Header.Get("Last-Modified")
	if header == "" {
		return nil
	}
	lastmod, err := time.Parse(http.TimeFormat, header)
	if err != nil {
		return nil
	}
	return os.Chtimes(filename, lastmod, lastmod)
}

// mkdirp creates all missing parent directories for the destination file path.
func mkdirp(path string) error {
	dir := filepath.Dir(path)
	if fi, err := os.Stat(dir); err != nil {
		if !os.IsNotExist(err) {
			return fmt.Errorf("error checking destination directory: %v", err)
		}
		if err := os.MkdirAll(dir, 0777); err != nil {
			return fmt.Errorf("error creating destination directory: %v", err)
		}
	} else if !fi.IsDir() {
		panic("grab: developer error: destination path is not directory")
	}
	return nil
}

// guessFilename returns a filename for the given http.Response. If none can be
// determined ErrNoFilename is returned.
//
// TODO: NoStore operations should not require a filename
func guessFilename(resp *http.Response) (string, error) {
	var (
		err      error
		filename = resp.Request.URL.Path
	)
	if cd := resp.Header.Get("Content-Disposition"); cd != "" {
		if _, params, err := mime.ParseMediaType(cd); err == nil {
			if val, ok := params["filename"]; ok {
				filename = val
			} // else filename directive is missing. fallback to URL.Path
		}
	}

	// sanitize
	filename, err = pathvalidate.SanitizeFilename(path.Base(filename), '_')
	if err != nil {
		return "", fmt.Errorf("%w: %v", ErrNoFilename, err)
	}

	if filename == "" || filename == "." || filename == "/" {
		return "", ErrNoFilename
	}

	return filename, nil
}
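For reference, a small sketch of what mime.ParseMediaType yields for a typical Content-Disposition value (the header value here is made up), which is where guessFilename reads the filename parameter before sanitizing it:

```go
package main

import (
	"fmt"
	"mime"
)

func main() {
	// A typical Content-Disposition value (made-up example).
	cd := `attachment; filename="report-2018.pdf"`

	mediatype, params, err := mime.ParseMediaType(cd)
	if err != nil {
		panic(err)
	}
	fmt.Println(mediatype)          // attachment
	fmt.Println(params["filename"]) // report-2018.pdf
}
```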
167
grab/util_test.go
Normal file
167
grab/util_test.go
Normal file
@ -0,0 +1,167 @@
package grab

import (
	"fmt"
	"net/http"
	"net/url"
	"testing"
)

func TestURLFilenames(t *testing.T) {
	t.Run("Valid", func(t *testing.T) {
		expect := "filename"
		testCases := []string{
			"http://test.com/filename",
			"http://test.com/path/filename",
			"http://test.com/deep/path/filename",
			"http://test.com/filename?with=args",
			"http://test.com/filename#with-fragment",
			"http://test.com/filename?with=args&and#with-fragment",
		}

		for _, tc := range testCases {
			req, _ := http.NewRequest("GET", tc, nil)
			resp := &http.Response{
				Request: req,
			}
			actual, err := guessFilename(resp)
			if err != nil {
				t.Errorf("%v", err)
			}

			if actual != expect {
				t.Errorf("expected '%v', got '%v'", expect, actual)
			}
		}
	})

	t.Run("Invalid", func(t *testing.T) {
		testCases := []string{
			"http://test.com",
			"http://test.com/",
			"http://test.com/filename/",
			"http://test.com/filename/?with=args",
			"http://test.com/filename/#with-fragment",
			"http://test.com/filename\x00",
		}

		for _, tc := range testCases {
			t.Run(tc, func(t *testing.T) {
				req, err := http.NewRequest("GET", tc, nil)
				if err != nil {
					if tc == "http://test.com/filename\x00" {
						// Since go1.12, urls with invalid control character return an error
						// See https://github.com/golang/go/commit/829c5df58694b3345cb5ea41206783c8ccf5c3ca
						t.Skip()
					}
				}
				resp := &http.Response{
					Request: req,
				}

				_, err = guessFilename(resp)
				if err != ErrNoFilename {
					t.Errorf("expected '%v', got '%v'", ErrNoFilename, err)
				}
			})
		}
	})
}

func TestHeaderFilenames(t *testing.T) {
	u, _ := url.ParseRequestURI("http://test.com/badfilename")
	resp := &http.Response{
		Request: &http.Request{
			URL: u,
		},
		Header: http.Header{},
	}

	setFilename := func(resp *http.Response, filename string) {
		resp.Header.Set("Content-Disposition", fmt.Sprintf("attachment;filename=\"%s\"", filename))
	}

	t.Run("Valid", func(t *testing.T) {
		expect := "filename"
		testCases := []string{
			"filename",
			"path/filename",
			"/path/filename",
			"../../filename",
			"/path/../../filename",
			"/../../././///filename",
		}

		for _, tc := range testCases {
			setFilename(resp, tc)
			actual, err := guessFilename(resp)
			if err != nil {
				t.Errorf("error (%v): %v", tc, err)
			}

			if actual != expect {
				t.Errorf("expected '%v' (%v), got '%v'", expect, tc, actual)
			}
		}
	})

	t.Run("Invalid", func(t *testing.T) {
		testCases := []string{
			"",
			"/",
			".",
			"/.",
			"/./",
			"..",
			"../",
			"/../",
			"/path/",
			"../path/",
			"filename\x00",
			"filename/",
			"filename//",
			"filename/..",
		}

		for _, tc := range testCases {
			setFilename(resp, tc)
			if actual, err := guessFilename(resp); err != ErrNoFilename {
				t.Errorf("expected: %v (%v), got: %v (%v)", ErrNoFilename, tc, err, actual)
			}
		}
	})
}

func TestHeaderWithMissingDirective(t *testing.T) {
	u, _ := url.ParseRequestURI("http://test.com/filename")
	resp := &http.Response{
		Request: &http.Request{
			URL: u,
		},
		Header: http.Header{},
	}

	setHeader := func(resp *http.Response, value string) {
		resp.Header.Set("Content-Disposition", value)
	}

	t.Run("Valid", func(t *testing.T) {
		expect := "filename"
		testCases := []string{
			"inline",
			"attachment",
		}

		for _, tc := range testCases {
			setHeader(resp, tc)
			actual, err := guessFilename(resp)
			if err != nil {
				t.Errorf("error (%v): %v", tc, err)
			}

			if actual != expect {
				t.Errorf("expected '%v' (%v), got '%v'", expect, tc, actual)
			}
		}
	})
}