gloader/downloader.go
2020-10-25 12:01:27 -07:00

486 lines
11 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"mime"
"net"
"net/http"
"net/http/cookiejar"
"net/url"
"os"
"path"
"path/filepath"
"sort"
"strconv"
"strings"
"time"
"github.com/cavaliercoder/grab"
"github.com/lordwelch/pathvalidate"
"golang.org/x/net/publicsuffix"
)
// Package-level defaults shared by every Downloader instance.
var (
	// DefaultCookieJar is a shared cookie jar backed by the public suffix list.
	DefaultCookieJar = newCookieJar()
	// DefaultGrabClient is the shared grab download client.
	DefaultGrabClient = grab.NewClient()
	// DefaultMaxActiveDownloads caps concurrent downloads when a Downloader
	// does not configure its own limit (see Start).
	DefaultMaxActiveDownloads = 4
	// ErrUnsupportedScheme reports a URL scheme that cannot be downloaded.
	// NOTE(review): not referenced in this chunk — presumably used elsewhere.
	ErrUnsupportedScheme = errors.New("unsupported scheme")
)
// Priority orders queued downloads; lower values sort first (Highest == 0).
type Priority uint8

// Status tracks a Request through its download lifecycle.
type Status uint8

// Priority levels, from most to least urgent.
const (
	Highest Priority = iota
	High
	Medium
	Low
)

// Request states. Queued is the zero value, so a freshly decoded Request
// starts out queued.
const (
	Queued Status = iota
	Complete
	Stopped
	Paused
	Downloading
	Error
	Canceled
)
// Downloader coordinates queued HTTP downloads: it owns the on-disk
// directory layout, the shared HTTP state (cookie jar, grab client), the
// REST server, and the channels that feed the main download loop.
type Downloader struct {
	DataDir     string // root directory for all downloader state
	DownloadDir string // in-progress downloads; defaults to DataDir/Download
	CompleteDir string // finished files; defaults to DataDir/Complete
	InfoDir     string // NOTE(review): unused in this chunk — purpose unconfirmed

	Grab *grab.Client   // client performing the transfers
	Jar  http.CookieJar // cookie jar used for HEAD requests and downloads

	MaxActiveDownloads int // concurrent download limit; <1 means use default

	Server *http.Server // REST API server; created by Start when nil

	downloads RequestQueue // active/queued requests
	history   RequestQueue // completed or canceled requests

	NewRequest  chan Request  // inbound requests from the REST handler
	requestDone chan *Request // completion notifications from download goroutines
}
// Request describes a single download job and its runtime state.
// Fields tagged json:"-" are runtime-only and not accepted from clients.
type Request struct {
	URL           url.URL       `json:"url"`
	Cookies       []http.Cookie `json:"cookies"`
	ForceDownload bool          `json:"forceDownload"`
	Status        Status        `json:"-"`
	Priority      Priority      `json:"priority"`
	Filepath      string        `json:"filepath"` // final destination; derived if empty
	TempPath      string        `json:"tempPath"` // in-progress file; always server-assigned
	Response      *grab.Response `json:"-"`       // nil until the download starts
	Error         error          `json:"-"`
	// CompletedDate is when the download finished; used by DateSort ordering.
	CompletedDate time.Time
}
// RequestQueue is a sortable list of requests. Together with Less/Len/Swap
// it implements sort.Interface; URLSort and DateSort enable the optional
// tie-breaking criteria in Less.
type RequestQueue struct {
	queue    []*Request
	URLSort  bool
	DateSort bool
}
// Less reports whether the request at index i sorts before the one at j.
// The ordering is lexicographic: ForceDownload first (non-forced before
// forced, preserving the original ii < jj comparison), then Priority
// (Highest == 0 first), then optionally completion date and URL.
//
// Bug fixed: the original fell through to the next criterion even when the
// previous criterion strictly ordered j before i, so Less(i, j) and
// Less(j, i) could both report true — an inconsistent ordering, which is
// undefined behavior for sort.Stable.
func (rq RequestQueue) Less(i, j int) bool {
	ri, rj := rq.queue[i], rq.queue[j]
	ii, jj := 0, 0
	if ri.ForceDownload {
		ii = 1
	}
	if rj.ForceDownload {
		jj = 1
	}
	if ii != jj {
		return ii < jj
	}
	if ri.Priority != rj.Priority {
		return ri.Priority < rj.Priority
	}
	if rq.DateSort && !ri.CompletedDate.Equal(rj.CompletedDate) {
		return ri.CompletedDate.Before(rj.CompletedDate)
	}
	if rq.URLSort {
		return ri.URL.String() < rj.URL.String()
	}
	return false
}
// Len returns the number of requests currently held in the queue
// (part of sort.Interface).
func (rq RequestQueue) Len() int {
	return len(rq.queue)
}
// Swap exchanges the queue entries at i and j (part of sort.Interface).
func (rq RequestQueue) Swap(i, j int) {
	q := rq.queue
	q[i], q[j] = q[j], q[i]
}
// Pop removes and returns the request at index i, keeping the remaining
// entries in order and clearing the vacated slot so the GC can reclaim it.
func (rq *RequestQueue) Pop(i int) *Request {
	popped := rq.queue[i]
	last := len(rq.queue) - 1
	copy(rq.queue[i:], rq.queue[i+1:])
	rq.queue[last] = nil
	rq.queue = rq.queue[:last]
	return popped
}
// remove deletes the first entry identical (by pointer) to r, if present;
// the remaining entries keep their relative order.
func (rq *RequestQueue) remove(r *Request) {
	for i, candidate := range rq.queue {
		if candidate != r {
			continue
		}
		last := len(rq.queue) - 1
		copy(rq.queue[i:], rq.queue[i+1:])
		rq.queue[last] = nil
		rq.queue = rq.queue[:last]
		return
	}
}
// newCookieJar builds a cookie jar that scopes cookies using the embedded
// public suffix list. cookiejar.New cannot fail with these options, so the
// error is deliberately discarded.
func newCookieJar() http.CookieJar {
	jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
	return jar
}
// newDownloader returns a Downloader wired up with the package-level
// default cookie jar and grab client; everything else is filled in lazily
// by Start.
func newDownloader() *Downloader {
	d := &Downloader{
		Grab: DefaultGrabClient,
		Jar:  DefaultCookieJar,
	}
	return d
}
// Start initializes defaults (channels, download limit, HTTP server and
// directory layout), creates the data directories, registers the REST
// handlers, launches the main download loop, and serves the API on the
// given network/address. It blocks until the HTTP server stops.
//
// Fixed: errors from os.MkdirAll and http.Server.Serve were silently
// discarded. Directory-creation failure now panics — consistent with the
// existing net.Listen handling — and the serve error is logged.
func (d *Downloader) Start(network, address string) {
	mux := http.NewServeMux()
	if d.NewRequest == nil {
		d.NewRequest = make(chan Request, 64)
	}
	if d.requestDone == nil {
		d.requestDone = make(chan *Request, 64)
	}
	if d.MaxActiveDownloads < 1 {
		d.MaxActiveDownloads = DefaultMaxActiveDownloads
	}
	if d.Server == nil {
		d.Server = &http.Server{
			Addr:         address,
			Handler:      mux,
			ReadTimeout:  2 * time.Minute,
			WriteTimeout: 2 * time.Minute,
		}
	}
	if d.DataDir == "" {
		d.DataDir = "/perm/downloader"
	}
	if d.DownloadDir == "" {
		d.DownloadDir = path.Join(d.DataDir, "Download")
	}
	if d.CompleteDir == "" {
		d.CompleteDir = path.Join(d.DataDir, "Complete")
	}
	for _, dir := range []string{d.DataDir, d.DownloadDir, d.CompleteDir} {
		fmt.Println(dir)
		if err := os.MkdirAll(dir, 0777); err != nil {
			panic(err) // unusable data dir: nothing sensible to do but abort
		}
	}
	listener, err := net.Listen(network, address)
	if err != nil {
		panic(err)
	}
	fmt.Println("adding /add handler")
	// mux.HandleFunc("/", d.UI)
	mux.HandleFunc("/add", d.restAddDownload)
	fmt.Println("starting main go routine")
	go d.download()
	fmt.Println("serving http server")
	if err := d.Server.Serve(listener); err != nil {
		fmt.Println("http server stopped:", err)
	}
}
// restAddDownload handles POST /add: it decodes a JSON array of Requests
// from the body and feeds each one into the NewRequest channel. Non-POST
// methods receive 405; malformed JSON receives 400.
func (d *Downloader) restAddDownload(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		h := w.Header()
		h.Set("Content-Type", "text/plain; charset=utf-8")
		h.Set("X-Content-Type-Options", "nosniff")
		h.Add("Allow", http.MethodPost)
		w.WriteHeader(http.StatusMethodNotAllowed)
		fmt.Fprintln(w, "HTTP Error 405 Method Not Allowed\nOnly POST method is allowed")
		fmt.Println("HTTP Error 405 Method Not Allowed\nOnly POST method is allowed")
		return
	}
	// TODO fail only on individual requests
	var requests []Request
	if err := json.NewDecoder(r.Body).Decode(&requests); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	for _, req := range requests {
		req.TempPath = "" // clients never get to pick the temp path
		fmt.Println("adding request", req.URL.String())
		d.NewRequest <- req
	}
	w.WriteHeader(http.StatusOK)
}
// getContentDispsition issues a HEAD request for r's URL and returns the
// filename parameter of the Content-Disposition header, or "" on any
// failure. (The method name's spelling is kept as-is for its callers.)
func (d Downloader) getContentDispsition(r Request) string {
	client := &http.Client{
		Jar:     d.Jar,
		Timeout: 30 * time.Second,
	}
	resp, err := client.Head(r.URL.String())
	if err != nil {
		return ""
	}
	resp.Body.Close()
	_, params, err := mime.ParseMediaType(resp.Header.Get("Content-Disposition"))
	if err != nil {
		return ""
	}
	// A missing "filename" key yields the zero value "", matching the
	// original behavior.
	return params["filename"]
}
// getFilename determines the destination filename for r, trying in order:
// the client-supplied Filepath, the Content-Disposition filename from a
// HEAD request, then the basename of the URL path. The result is sanitized
// with github.com/lordwelch/pathvalidate and anchored inside d.DownloadDir.
func (d *Downloader) getFilename(r *Request) {
	fmt.Println("Determining filename")
	r.Filepath = filepath.Clean(r.Filepath)
	if r.Filepath == "." { // Clean("") == ".", i.e. no name was supplied
		fmt.Println("filename is empty, testing head request")
		r.Filepath = d.getContentDispsition(*r)
		fmt.Println("path from head request:", r.Filepath)
		if r.Filepath == "" {
			r.Filepath, _ = url.PathUnescape(filepath.Base(r.URL.Path))
		}
	}
	r.Filepath, _ = pathvalidate.SanitizeFilename(r.Filepath, '_')
	r.Filepath = filepath.Join(d.DownloadDir, r.Filepath)
	fmt.Println("result path:", r.Filepath)
}
func getNewFilename(dir, name string) string {
var (
err error
index = 1
)
fmt.Println("getfilename", dir, name)
ext := filepath.Ext(name)
base := strings.TrimSuffix(name, ext)
fmt.Println("stat", filepath.Join(dir, name))
_, err = os.Stat(filepath.Join(dir, name))
for err == nil {
name = strings.TrimRight(base+"."+strconv.Itoa(index)+ext, ".")
fmt.Println("stat", filepath.Join(dir, name))
_, err = os.Stat(filepath.Join(dir, name))
}
if os.IsNotExist(err) {
return filepath.Join(dir, name)
}
panic(err) // other path error
}
// getDownloadFilename assigns r a temporary in-progress path inside
// d.DownloadDir (if it does not already have one) and reserves the final
// destination by creating it with O_EXCL; an existing final file is left
// untouched.
//
// Bugs fixed: ioutil.TempFile already returns a file whose Name() is the
// full path, so joining it with d.DownloadDir again produced a bogus
// nested path; and a failed TempFile left f nil, panicking on f.Close() —
// the error path now returns early.
func (d Downloader) getDownloadFilename(r *Request) {
	if r.TempPath == "" {
		f, err := ioutil.TempFile(d.DownloadDir, filepath.Base(r.Filepath))
		if err != nil {
			fmt.Printf("request for %v failed: %v", r.URL.String(), err)
			return
		}
		f.Close()
		r.TempPath = f.Name()
	}
	// Reserve the final filename; O_EXCL means an existing file is kept.
	f, err := os.OpenFile(r.Filepath, os.O_CREATE|os.O_EXCL, 0666)
	if err != nil {
		return
	}
	f.Close()
}
// SearchDownloads returns the index of the active download whose URL
// matches u, or -1 if none does.
func (d Downloader) SearchDownloads(u url.URL) int {
	target := u.String()
	for i, req := range d.downloads.queue {
		if req.URL.String() == target {
			return i
		}
	}
	return -1
}
// SearchHistory returns the index of the historical (completed/canceled)
// request whose URL matches u, or -1 if none does.
func (d Downloader) SearchHistory(u url.URL) int {
	target := u.String()
	for i, req := range d.history.queue {
		if req.URL.String() == target {
			return i
		}
	}
	return -1
}
// FindRequest looks u up in the active downloads first, then in the
// history, returning the matching request or nil if it is unknown.
func (d Downloader) FindRequest(u url.URL) *Request {
	if i := d.SearchDownloads(u); i >= 0 {
		return d.downloads.queue[i]
	}
	i := d.SearchHistory(u)
	if i < 0 {
		return nil
	}
	return d.history.queue[i]
}
// addRequest registers a new download request. If the URL is already known
// and its destination file appears to exist, the existing entry is
// re-validated instead of queued again; otherwise the request is given a
// unique name in the completed dir and appended to the download queue.
// A download starts immediately when a slot is free.
func (d *Downloader) addRequest(r *Request) {
	fmt.Println("adding download for", r.URL.String())
	req := d.FindRequest(r.URL)
	d.getFilename(r)
	if req != nil { // url already added
		fmt.Println("URL is already added", r.URL.String())
		// The file is considered present when either the known request has the
		// same basename, or a non-empty file with that basename exists on disk.
		if fi, err := os.Stat(r.Filepath); filepath.Base(req.Filepath) == filepath.Base(r.Filepath) || (err == nil && fi.Name() == filepath.Base(r.Filepath) && fi.Size() != 0) { // filepath has been found, should this check for multiple downloads of the same url or let the download name increment automatically
			fmt.Println("file already exists", r.Filepath)
			//getNewFilename(d.CompleteDir, filepath.Base(r.Filepath))
			d.validate(*r) // TODO, should also check to see if it seems like it is similar, (check first k to see if it is the same file?? leave option to user)
			return
		}
	} else { // new request, download link
		r.Filepath = getNewFilename(d.CompleteDir, filepath.Base(r.Filepath))
		d.downloads.queue = append(d.downloads.queue, r)
	}
	if len(d.getRunningDownloads()) < d.MaxActiveDownloads {
		d.startDownload(r)
	}
}
// validate is a placeholder for verifying that an already-downloaded file
// matches the remote resource for r (e.g. comparing a content prefix).
// It currently does nothing.
func (d *Downloader) validate(r Request) {
	//TODO
}
// startDownload begins transferring r via the grab client and spawns a
// goroutine that waits for the transfer to finish and reports the request
// on d.requestDone. On a malformed request r is marked with Error status.
func (d *Downloader) startDownload(r *Request) {
	fmt.Println("starting download for", r.URL.String())
	d.getDownloadFilename(r)
	grabReq, err := grab.NewRequest(r.TempPath, r.URL.String())
	if err != nil {
		r.Error = err
		r.Status = Error
		return
	}
	r.Status = Downloading
	r.Response = d.Grab.Do(grabReq)
	go func(req *Request) {
		fmt.Println("wait for download")
		fmt.Println(req.Response.IsComplete())
		req.Response.Wait()
		fmt.Println("download completed for", req.URL)
		d.requestDone <- req
	}(r)
}
// getRunningDownloads collects the requests that are actively transferring:
// status Downloading with a live grab response attached.
func (d Downloader) getRunningDownloads() []*Request {
	active := make([]*Request, 0, d.MaxActiveDownloads)
	for _, req := range d.downloads.queue {
		if req.Status != Downloading || req.Response == nil {
			continue
		}
		active = append(active, req)
	}
	return active
}
// syncDownloads starts queued downloads while free slots remain and moves
// completed or canceled entries from the download queue into the history.
//
// Bug fixed: the start condition compared MaxActiveDownloads >= running,
// which kept starting downloads until one MORE than the configured maximum
// was active; it now starts only while running < MaxActiveDownloads.
func (d *Downloader) syncDownloads() {
	if len(d.getRunningDownloads()) >= d.MaxActiveDownloads {
		return
	}
	// Order the queue so the most urgent requests start first.
	sort.Stable(d.downloads)
	// Start new downloads until the active limit is reached.
	for _, req := range d.downloads.queue {
		if req.Status != Queued {
			continue
		}
		if len(d.getRunningDownloads()) >= d.MaxActiveDownloads {
			break
		}
		d.startDownload(req)
	}
	// Move completed/canceled downloads to the history. Pop shifts the
	// slice left, so the index is re-examined after each removal.
	for i := 0; i < d.downloads.Len(); i++ {
		status := d.downloads.queue[i].Status
		if status == Complete || status == Canceled {
			d.history.queue = append(d.history.queue, d.downloads.Pop(i))
			i--
		}
	}
}
// requestCompleted finalizes r once its grab response has finished: on
// success the temp file is renamed to its final path and the request moves
// to the history; on a transfer failure the request is marked with Error.
//
// Fixed: the os.Rename error was silently discarded — a failed rename is
// now logged and recorded on the request (which still moves to history so
// it is not lost) — and the error log message was cleaned up.
func (d *Downloader) requestCompleted(r *Request) {
	if err := r.Response.Err(); err != nil {
		r.Status = Error
		r.Error = err
		fmt.Println("download error:", r.Error)
		return
	}
	fmt.Println("removing from downloads")
	d.downloads.remove(r)
	r.Status = Complete
	fmt.Println(r.TempPath, "!=", r.Filepath)
	if r.TempPath != r.Filepath {
		fmt.Println("renaming download to the completed dir")
		if err := os.Rename(r.TempPath, r.Filepath); err != nil {
			// Record the failure but keep the request in history so it
			// is not silently lost.
			r.Status = Error
			r.Error = err
			fmt.Println("rename failed:", err)
		}
	}
	d.history.queue = append(d.history.queue, r)
}
// download is the main event loop: every 10 seconds it dumps queue status
// and reconciles active downloads; it also accepts new requests and
// finalizes completed ones. It never returns and is meant to run in its
// own goroutine (see Start).
//
// Bug fixed: the status dump called req.Response.ETA() even for requests
// that have not started yet (Response is nil until startDownload runs),
// causing a nil-pointer panic. A ticker also replaces the per-iteration
// time.After allocation.
func (d *Downloader) download() {
	tick := time.NewTicker(10 * time.Second)
	defer tick.Stop()
	for {
		select {
		case now := <-tick.C:
			fmt.Println(now)
			for _, req := range d.downloads.queue {
				fmt.Println(req.URL)
				fmt.Println(req.Status)
				if req.Response != nil { // nil until the download starts
					fmt.Println(req.Response.ETA())
				}
			}
			d.syncDownloads()
		case r := <-d.NewRequest:
			d.addRequest(&r)
		case r := <-d.requestDone:
			fmt.Println("finishing request for", r.URL)
			d.requestCompleted(r)
		}
	}
}