package grab

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"sync"
	"sync/atomic"
	"time"
)

// HTTPClient provides an interface allowing us to perform HTTP requests.
type HTTPClient interface {
	Do(req *http.Request) (*http.Response, error)
}

// truncater is a private interface allowing different response
// Writers to be truncated.
type truncater interface {
	Truncate(size int64) error
}

// A Client is a file download client.
//
// Clients are safe for concurrent use by multiple goroutines.
type Client struct {
	// HTTPClient specifies the http.Client which will be used for communicating
	// with the remote server during the file transfer.
	HTTPClient HTTPClient

	// UserAgent specifies the User-Agent string which will be set in the
	// headers of all requests made by this client.
	//
	// The user agent string may be overridden in the headers of each request.
	UserAgent string

	// BufferSize specifies the size in bytes of the buffer that is used for
	// transferring all requested files. Larger buffers may result in faster
	// throughput but will use more memory and result in less frequent updates
	// to the transfer progress statistics. The BufferSize of each request can
	// be overridden on each Request object. Default: 32KB.
	BufferSize int
}

// NewClient returns a new file download Client, using default configuration.
func NewClient() *Client {
	return &Client{
		UserAgent: "grab",
		HTTPClient: &http.Client{
			Transport: &http.Transport{
				Proxy: http.ProxyFromEnvironment,
			},
		},
	}
}

// DefaultClient is the default client and is used by all Get convenience
// functions.
var DefaultClient = NewClient()

// Do sends a file transfer request and returns a file transfer response,
// following policy (e.g. redirects, cookies, auth) as configured on the
// client's HTTPClient.
//
// Like http.Get, Do blocks while the transfer is initiated, but returns as soon
// as the transfer has started in a background goroutine, or if it failed early.
//
// An error is returned via Response.Err if caused by client policy (such as
// CheckRedirect), or if there was an HTTP protocol or IO error. Response.Err
// will block the caller until the transfer is completed, successfully or
// otherwise.
func (c *Client) Do(req *Request) *Response {
	// cancel will be called on all code-paths via closeResponse
	ctx, cancel := context.WithCancel(req.Context())
	req = req.WithContext(ctx)
	resp := &Response{
		Request:    req,
		Start:      time.Now(),
		Done:       make(chan struct{}),
		Filename:   req.Filename,
		ctx:        ctx,
		cancel:     cancel,
		bufferSize: req.BufferSize,
	}
	if resp.bufferSize == 0 {
		// default to Client.BufferSize
		resp.bufferSize = c.BufferSize
	}

	// Run the state machine while the caller is blocked to initialize the file
	// transfer. Must never transition to the copyFile state - this happens next
	// in another goroutine.
	c.run(resp, c.statFileInfo)

	// Run copyFile in a new goroutine. copyFile will no-op if the transfer is
	// already complete or failed.
	go c.run(resp, c.copyFile)
	return resp
}
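// doUsageSketch is an illustrative usage sketch, not part of the grab API: it
// shows a typical Client.Do call. The destination path and URL below are
// placeholders.
func doUsageSketch() error {
	client := NewClient()
	req, err := NewRequest("/tmp/example.zip", "https://example.com/example.zip")
	if err != nil {
		return err
	}
	resp := client.Do(req)

	// Err blocks until the background transfer has completed or failed.
	if err := resp.Err(); err != nil {
		return err
	}
	fmt.Println("downloaded to", resp.Filename)
	return nil
}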
// DoChannel executes all requests sent through the given Request channel, one
// at a time, until it is closed by another goroutine. The caller is blocked
// until the Request channel is closed and all transfers have completed. All
// responses are sent through the given Response channel as soon as they are
// received from the remote servers and can be used to track the progress of
// each download.
//
// Slow Response receivers will cause a worker to block and therefore delay the
// start of the transfer for an already initiated connection - potentially
// causing a server timeout. It is the caller's responsibility to ensure a
// sufficient buffer size is used for the Response channel to prevent this.
//
// If an error occurs during any of the file transfers it will be accessible via
// the associated Response.Err function.
func (c *Client) DoChannel(reqch <-chan *Request, respch chan<- *Response) {
	// TODO: enable cancelling of batch jobs
	for req := range reqch {
		resp := c.Do(req)
		respch <- resp
		<-resp.Done
	}
}

// DoBatch executes all the given requests using the given number of concurrent
// workers. Control is passed back to the caller as soon as the workers are
// initiated.
//
// If the requested number of workers is less than one, a worker will be created
// for every request, i.e. all requests will be executed concurrently.
//
// If an error occurs during any of the file transfers it will be accessible via
// a call to the associated Response.Err.
//
// The returned Response channel is closed only after all of the given Requests
// have completed, successfully or otherwise.
func (c *Client) DoBatch(workers int, requests ...*Request) <-chan *Response {
	if workers < 1 {
		workers = len(requests)
	}
	reqch := make(chan *Request, len(requests))
	respch := make(chan *Response, len(requests))
	wg := sync.WaitGroup{}
	for i := 0; i < workers; i++ {
		wg.Add(1)
		go func() {
			c.DoChannel(reqch, respch)
			wg.Done()
		}()
	}

	// queue requests
	go func() {
		for _, req := range requests {
			reqch <- req
		}
		close(reqch)
		wg.Wait()
		close(respch)
	}()
	return respch
}
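// doBatchUsageSketch is an illustrative usage sketch, not part of the grab
// API: it downloads the given URLs into the current directory with at most
// three concurrent workers. The worker count is arbitrary.
func doBatchUsageSketch(urls []string) {
	reqs := make([]*Request, 0, len(urls))
	for _, u := range urls {
		req, err := NewRequest(".", u)
		if err != nil {
			fmt.Fprintf(os.Stderr, "invalid request for %s: %v\n", u, err)
			continue
		}
		reqs = append(reqs, req)
	}

	// The returned channel is closed once every request has completed.
	respch := DefaultClient.DoBatch(3, reqs...)
	for resp := range respch {
		if err := resp.Err(); err != nil {
			fmt.Fprintf(os.Stderr, "download failed: %v\n", err)
			continue
		}
		fmt.Println("downloaded to", resp.Filename)
	}
}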
// A stateFunc is an action that mutates the state of a Response and returns
// the next stateFunc to be called.
type stateFunc func(*Response) stateFunc

// run calls the given stateFunc function and all subsequent returned stateFuncs
// until a stateFunc returns nil or the Response.ctx is canceled. Each stateFunc
// should mutate the state of the given Response until it has completed
// downloading or failed.
func (c *Client) run(resp *Response, f stateFunc) {
	for {
		select {
		case <-resp.ctx.Done():
			if resp.IsComplete() {
				return
			}
			resp.err = resp.ctx.Err()
			f = c.closeResponse

		default:
			// keep working
		}
		if f = f(resp); f == nil {
			return
		}
	}
}

// statFileInfo retrieves FileInfo for any local file matching
// Response.Filename.
//
// If the file does not exist, is a directory, or its name is unknown the next
// stateFunc is headRequest.
//
// If the file exists, Response.fi is set and the next stateFunc is
// validateLocal.
//
// If an error occurs, the next stateFunc is closeResponse.
func (c *Client) statFileInfo(resp *Response) stateFunc {
	if resp.Request.NoStore || resp.Filename == "" {
		// no filename provided; will guess one from the remote server
		return c.headRequest
	}
	fi, err := os.Stat(resp.Filename)
	if err != nil {
		if os.IsNotExist(err) {
			// Filename does not exist; will download
			return c.headRequest
		}
		resp.err = err
		return c.closeResponse // other PathError occurred
	}
	if fi.IsDir() {
		// resp.Request.Filename is a directory.
		// Will guess filename and append it to resp.Request.Filename.
		resp.Filename = ""
		return c.headRequest
	}
	resp.fi = fi
	return c.validateLocal // Filename exists, validate file
}

// validateLocal compares a local copy of the downloaded file to the remote
// file.
//
// An error is returned if the local file is larger than the remote file, or if
// Request.SkipExisting is true.
//
// If the existing file matches the length of the remote file, the next
// stateFunc is checksumFile.
//
// If the local file is smaller than the remote file and the remote server is
// known to support ranged requests, the next stateFunc is getRequest.
func (c *Client) validateLocal(resp *Response) stateFunc {
	if resp.Request.SkipExisting {
		resp.err = ErrFileExists
		return c.closeResponse
	}

	// determine target file size
	expectedSize := resp.Request.Size
	if expectedSize == 0 && resp.HTTPResponse != nil {
		expectedSize = resp.HTTPResponse.ContentLength
	}

	if expectedSize == 0 {
		// size is either actually 0 or unknown
		// if unknown, we ask the remote server
		// if known to be 0, we proceed with a GET
		return c.headRequest
	}

	if expectedSize == resp.fi.Size() {
		// local file matches remote file size - wrap it up
		resp.DidResume = true
		resp.bytesResumed = resp.fi.Size()
		return c.checksumFile
	}
	if resp.Request.NoResume {
		// local file should be overwritten
		return c.getRequest
	}
	if expectedSize >= 0 && expectedSize < resp.fi.Size() {
		// remote size is known and is smaller than the local size
		resp.err = ErrBadLength
		return c.closeResponse
	}
	if resp.CanResume {
		// set resume range on GET request
		resp.Request.HTTPRequest.Header.Set(
			"Range",
			fmt.Sprintf("bytes=%d-", resp.fi.Size()))
		resp.DidResume = true
		resp.bytesResumed = resp.fi.Size()
		return c.getRequest
	}
	return c.headRequest
}

// checksumFile computes the checksum of the downloaded file and compares it to
// the expected checksum set on the Request, deleting the file on mismatch if
// requested.
func (c *Client) checksumFile(resp *Response) stateFunc {
	if resp.Request.hash == nil {
		return c.closeResponse
	}
	if resp.Filename == "" {
		panic("grab: developer error: filename not set")
	}
	if resp.Size() < 0 {
		panic("grab: developer error: size unknown")
	}
	req := resp.Request

	// compute checksum
	var sum []byte
	sum, resp.err = resp.checksumUnsafe()
	if resp.err != nil {
		return c.closeResponse
	}

	// compare checksum
	if !bytes.Equal(sum, req.checksum) {
		resp.err = ErrBadChecksum
		if !resp.Request.NoStore && req.deleteOnError {
			if err := os.Remove(resp.Filename); err != nil {
				// err should be os.PathError and include the file path
				resp.err = fmt.Errorf(
					"cannot remove downloaded file with checksum mismatch: %v",
					err)
			}
		}
	}
	return c.closeResponse
}
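// Checksum verification sketch (illustrative only): checksumFile above is
// armed by Request.SetChecksum, defined elsewhere in this package. The hash
// constructor, expected digest and paths below are placeholders, and the
// import is not part of this file:
//
//	import "crypto/sha256"
//
//	req, _ := grab.NewRequest("/tmp/example.zip", "https://example.com/example.zip")
//	expectedSum := []byte{ /* expected SHA-256 digest of the remote file */ }
//	req.SetChecksum(sha256.New(), expectedSum, true) // delete the file on mismatch
//	resp := grab.DefaultClient.Do(req)
//	if err := resp.Err(); err != nil {
//		// err is ErrBadChecksum if verification failed
//	}
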
// doHTTPRequest sends an HTTP request and returns the response.
func (c *Client) doHTTPRequest(req *http.Request) (*http.Response, error) {
	if c.UserAgent != "" && req.Header.Get("User-Agent") == "" {
		req.Header.Set("User-Agent", c.UserAgent)
	}
	return c.HTTPClient.Do(req)
}

// headRequest sends a HEAD request to discover the remote file's size,
// filename and whether the server supports ranged requests, before the file is
// requested with getRequest.
func (c *Client) headRequest(resp *Response) stateFunc {
	if resp.optionsKnown {
		return c.getRequest
	}
	resp.optionsKnown = true

	if resp.Request.NoResume {
		return c.getRequest
	}

	if resp.Filename != "" && resp.fi == nil {
		// destination path is already known and does not exist
		return c.getRequest
	}

	hreq := new(http.Request)
	*hreq = *resp.Request.HTTPRequest
	hreq.Method = "HEAD"

	resp.HTTPResponse, resp.err = c.doHTTPRequest(hreq)
	if resp.err != nil {
		return c.closeResponse
	}
	resp.HTTPResponse.Body.Close()

	if resp.HTTPResponse.StatusCode != http.StatusOK {
		return c.getRequest
	}

	// In case of redirects during HEAD, record the final URL and use it
	// instead of the original URL when sending future requests.
	// This way we avoid sending potentially unsupported requests to
	// the original URL, e.g. "Range", since it was the final URL
	// that advertised its support.
	resp.Request.HTTPRequest.URL = resp.HTTPResponse.Request.URL
	resp.Request.HTTPRequest.Host = resp.HTTPResponse.Request.Host

	return c.readResponse
}

// getRequest sends the GET request that transfers the remote file and checks
// the response status code.
func (c *Client) getRequest(resp *Response) stateFunc {
	resp.HTTPResponse, resp.err = c.doHTTPRequest(resp.Request.HTTPRequest)
	if resp.err != nil {
		return c.closeResponse
	}

	// TODO: check Content-Range

	// check status code
	if !resp.Request.IgnoreBadStatusCodes {
		if resp.HTTPResponse.StatusCode < 200 || resp.HTTPResponse.StatusCode > 299 {
			resp.err = StatusCodeError(resp.HTTPResponse.StatusCode)
			return c.closeResponse
		}
	}
	return c.readResponse
}

// readResponse reads the headers of the most recent HTTP response to determine
// the expected transfer size and the destination filename.
func (c *Client) readResponse(resp *Response) stateFunc {
	if resp.HTTPResponse == nil {
		panic("grab: developer error: Response.HTTPResponse is nil")
	}

	// check expected size
	resp.sizeUnsafe = resp.HTTPResponse.ContentLength
	if resp.sizeUnsafe >= 0 {
		// remote size is known
		resp.sizeUnsafe += resp.bytesResumed
		if resp.Request.Size > 0 && resp.Request.Size != resp.sizeUnsafe {
			resp.err = ErrBadLength
			return c.closeResponse
		}
	}

	// check filename
	if resp.Filename == "" {
		filename, err := guessFilename(resp.HTTPResponse)
		if err != nil {
			resp.err = err
			return c.closeResponse
		}
		// Request.Filename will be empty or a directory
		resp.Filename = filepath.Join(resp.Request.Filename, filename)
	}

	if !resp.Request.NoStore && resp.requestMethod() == "HEAD" {
		if resp.HTTPResponse.Header.Get("Accept-Ranges") == "bytes" {
			resp.CanResume = true
		}
		return c.statFileInfo
	}
	return c.openWriter
}

// openWriter opens the destination file for writing and seeks to the location
// from whence the file transfer will resume.
//
// Requires that Response.Filename and resp.DidResume are already set.
func (c *Client) openWriter(resp *Response) stateFunc {
	if !resp.Request.NoStore && !resp.Request.NoCreateDirectories {
		resp.err = mkdirp(resp.Filename)
		if resp.err != nil {
			return c.closeResponse
		}
	}

	if resp.Request.NoStore {
		resp.writer = &resp.storeBuffer
	} else {
		// compute write flags
		flag := os.O_CREATE | os.O_WRONLY
		if resp.fi != nil {
			if resp.DidResume {
				flag = os.O_APPEND | os.O_WRONLY
			} else {
				// truncate later in copyFile, if not cancelled
				// by the BeforeCopy hook
				flag = os.O_WRONLY
			}
		}

		// open file
		f, err := os.OpenFile(resp.Filename, flag, 0666)
		if err != nil {
			resp.err = err
			return c.closeResponse
		}
		resp.writer = f

		// seek to start or end
		whence := io.SeekStart
		if resp.bytesResumed > 0 {
			whence = io.SeekEnd
		}
		_, resp.err = f.Seek(0, whence)
		if resp.err != nil {
			return c.closeResponse
		}
	}

	// init transfer
	if resp.bufferSize < 1 {
		resp.bufferSize = 32 * 1024
	}
	b := make([]byte, resp.bufferSize)
	resp.transfer = newTransfer(
		resp.Request.Context(),
		resp.Request.RateLimiter,
		resp.writer,
		resp.HTTPResponse.Body,
		b)

	// next step is copyFile, but it will be called later, in another goroutine
	return nil
}
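// setSizeLimitHookSketch is an illustrative sketch, not part of the grab API:
// it installs a BeforeCopy hook, which copyFile below runs after response
// headers are received but before any data is written. The 100 MB limit is an
// arbitrary placeholder.
func setSizeLimitHookSketch(req *Request) {
	req.BeforeCopy = func(resp *Response) error {
		// Returning a non-nil error cancels the transfer; the error is then
		// surfaced to the caller via Response.Err.
		if resp.Size() > 100*1024*1024 {
			return fmt.Errorf("refusing to transfer %d bytes", resp.Size())
		}
		return nil
	}
}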
// copyFile transfers content for an HTTP connection established via Client.Do.
func (c *Client) copyFile(resp *Response) stateFunc {
	if resp.IsComplete() {
		return nil
	}

	// run BeforeCopy hook
	if f := resp.Request.BeforeCopy; f != nil {
		resp.err = f(resp)
		if resp.err != nil {
			return c.closeResponse
		}
	}

	var bytesCopied int64
	if resp.transfer == nil {
		panic("grab: developer error: Response.transfer is nil")
	}

	// We waited to truncate the file in openWriter() to make sure the
	// BeforeCopy hook didn't cancel the copy. If this is an existing file that
	// is not going to be resumed, truncate the contents.
	if t, ok := resp.writer.(truncater); ok && resp.fi != nil && !resp.DidResume {
		if err := t.Truncate(0); err != nil {
			resp.err = err
			return c.closeResponse
		}
	}

	bytesCopied, resp.err = resp.transfer.copy()
	if resp.err != nil {
		return c.closeResponse
	}
	closeWriter(resp)

	// set file timestamp
	if !resp.Request.NoStore && !resp.Request.IgnoreRemoteTime {
		resp.err = setLastModified(resp.HTTPResponse, resp.Filename)
		if resp.err != nil {
			return c.closeResponse
		}
	}

	// update transfer size if previously unknown
	if resp.Size() < 0 {
		discoveredSize := resp.bytesResumed + bytesCopied
		atomic.StoreInt64(&resp.sizeUnsafe, discoveredSize)
		if resp.Request.Size > 0 && resp.Request.Size != discoveredSize {
			resp.err = ErrBadLength
			return c.closeResponse
		}
	}

	// run AfterCopy hook
	if f := resp.Request.AfterCopy; f != nil {
		resp.err = f(resp)
		if resp.err != nil {
			return c.closeResponse
		}
	}

	return c.checksumFile
}

// closeWriter closes the destination writer, if it implements io.Closer.
func closeWriter(resp *Response) {
	if closer, ok := resp.writer.(io.Closer); ok {
		closer.Close()
	}
	resp.writer = nil
}

// closeResponse finalizes the Response and releases its resources.
func (c *Client) closeResponse(resp *Response) stateFunc {
	if resp.IsComplete() {
		panic("grab: developer error: response already closed")
	}

	resp.fi = nil
	closeWriter(resp)
	resp.closeResponseBody()

	resp.End = time.Now()
	close(resp.Done)
	if resp.cancel != nil {
		resp.cancel()
	}

	return nil
}
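// progressMonitorSketch is an illustrative sketch, not part of the grab API:
// it polls transfer progress while the background goroutine started by Do is
// still copying. The polling interval is arbitrary.
func progressMonitorSketch(resp *Response) {
	t := time.NewTicker(500 * time.Millisecond)
	defer t.Stop()
	for {
		select {
		case <-t.C:
			fmt.Printf("transferred %d / %d bytes (%.2f%%)\n",
				resp.BytesComplete(), resp.Size(), 100*resp.Progress())

		case <-resp.Done:
			// transfer completed or failed; resp.Err reports the outcome
			return
		}
	}
}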