comic-hasher/main.go

540 lines
13 KiB
Go
Raw Normal View History

2024-05-01 18:09:02 -07:00
package main
import (
2024-07-31 11:35:17 -07:00
"bufio"
"context"
2024-05-01 18:09:02 -07:00
"encoding/json"
2024-07-31 11:35:17 -07:00
"flag"
2024-05-01 18:09:02 -07:00
"fmt"
"image"
_ "image/gif"
_ "image/jpeg"
_ "image/png"
2024-07-31 11:35:17 -07:00
"io/fs"
2024-05-01 18:09:02 -07:00
"log"
"net/http"
"net/url"
"os"
2024-07-31 11:35:17 -07:00
"os/signal"
"path/filepath"
"runtime"
"runtime/pprof"
2024-05-01 18:09:02 -07:00
"strconv"
"strings"
"time"
_ "golang.org/x/image/tiff"
_ "golang.org/x/image/vp8"
_ "golang.org/x/image/vp8l"
_ "golang.org/x/image/webp"
2024-07-31 11:35:17 -07:00
"gitea.narnian.us/lordwelch/goimagehash"
2024-05-01 18:09:02 -07:00
"github.com/google/uuid"
"github.com/zitadel/oidc/pkg/client/rp"
httphelper "github.com/zitadel/oidc/pkg/http"
"github.com/zitadel/oidc/pkg/oidc"
)
const (
h_1 uint64 = 0b11111111 << (8 * iota)
h_2
h_3
h_4
h_5
h_6
h_7
h_8
)
const (
shift_1 = (8 * iota)
shift_2
shift_3
shift_4
shift_5
shift_6
shift_7
shift_8
)
type Cover map[string][]string // IDs is a map of domain to ID eg IDs['comicvine.gamespot.com'] = []string{"1235"}
// type Cover struct {
// AHash uint64
// DHash uint64
// PHash uint64
// IDs map[string][]string // IDs is a map of domain to ID eg IDs['comicvine.gamespot.com'] = []string{"1235"}
// }
type Server struct {
2024-07-31 11:35:17 -07:00
httpServer *http.Server
mux *http.ServeMux
BaseURL *url.URL
token chan<- *oidc.Tokens
ahash [8]map[uint8]string
dhash [8]map[uint8]string
phash [8]map[uint8]string
fAhash map[uint64]string
fDhash map[uint64]string
fPhash map[uint64]string
IDToCover map[string]string // IDToCover is a map of domain:id to an index to covers eg IDToCover['comicvine.gamespot.com:12345'] = 0
covers []Cover
hashingQueue chan im
mappingQueue chan hash
2024-05-01 18:09:02 -07:00
// hashes are a uint64 split into 8 pieces or a unint64 for quick lookup, the value is an index to covers
}
2024-07-31 11:35:17 -07:00
type im struct {
im image.Image
format string
domain string
id string
}
type hash struct {
ahash *goimagehash.ImageHash
dhash *goimagehash.ImageHash
phash *goimagehash.ImageHash
domain string
id string
}
2024-05-01 18:09:02 -07:00
var key = []byte(uuid.New().String())[:16]
func main() {
// mustDropPrivileges()
2024-07-31 11:35:17 -07:00
cover_path := flag.String("cover_path", "", "path to covers to add to hash database")
flag.Parse()
if *cover_path == "" {
log.Fatal("You must supply a path")
}
st, err := os.Stat(*cover_path)
if err != nil {
panic(err)
}
fmt.Println(st)
startServer(*cover_path)
2024-05-01 18:09:02 -07:00
}
func (s *Server) authenticated(w http.ResponseWriter, r *http.Request) (string, bool) {
return strings.TrimSpace("lordwelch"), true
}
func (s *Server) setupOauthHandlers() error {
redirectURI := *s.BaseURL
redirectURI.Path = "/oauth/callback"
successURI := *s.BaseURL
successURI.Path = "/success"
failURI := *s.BaseURL
failURI.RawQuery = url.Values{"auth": []string{"fail"}}.Encode()
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
options := []rp.Option{
rp.WithCookieHandler(cookieHandler),
rp.WithVerifierOpts(rp.WithIssuedAtOffset(5 * time.Second)),
}
provider, err := rp.NewRelyingPartyOIDC(os.Getenv("COMICHASHER_PROVIDER_URL"), os.Getenv("COMICHASHER_CLIENT_ID"), os.Getenv("COMICHASHER_CLIENT_SECRET"), redirectURI.String(), strings.Split(os.Getenv("COMICHASHER_SCOPES"), ","), options...)
if err != nil {
return fmt.Errorf("error creating provider: %w", err)
}
// generate some state (representing the state of the user in your application,
// e.g. the page where he was before sending him to login
state := func() string {
return uuid.New().String()
}
// register the AuthURLHandler at your preferred path
// the AuthURLHandler creates the auth request and redirects the user to the auth server
// including state handling with secure cookie and the possibility to use PKCE
s.mux.Handle("/login", rp.AuthURLHandler(state, provider))
// for demonstration purposes the returned userinfo response is written as JSON object onto response
marshalUserinfo := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty) {
s.token <- tokens
w.Header().Add("location", successURI.String())
w.WriteHeader(301)
}
// register the CodeExchangeHandler at the callbackPath
// the CodeExchangeHandler handles the auth response, creates the token request and calls the callback function
// with the returned tokens from the token endpoint
s.mux.Handle(redirectURI.Path, rp.CodeExchangeHandler(marshalUserinfo, provider))
return nil
}
func (s *Server) setupAppHandlers() {
s.mux.HandleFunc("/add_cover", s.add_cover)
s.mux.HandleFunc("/get_cover", s.get_cover)
s.mux.HandleFunc("/match_cover_hash", s.match_cover_hash)
}
func (s *Server) get_cover(w http.ResponseWriter, r *http.Request) {
user, authed := s.authenticated(w, r)
if !authed || user == "" {
http.Error(w, "Invalid Auth", http.StatusForbidden)
return
}
var (
values = r.URL.Query()
domain = strings.TrimSpace(values.Get("domain"))
id = strings.TrimSpace(values.Get("id"))
)
if id == "" {
log.Println("No ID Provided")
http.Error(w, "No ID Provided", http.StatusBadRequest)
return
}
if domain == "" {
log.Println("No domain Provided")
http.Error(w, "No domain Provided", http.StatusBadRequest)
return
}
2024-07-31 11:35:17 -07:00
// if index, ok := s.IDToCover[domain+":"+id]; ok {
// covers, err := json.Marshal(s.covers[index])
// if err == nil {
// w.Header().Add("Content-Type", "application/json")
// w.Write(covers)
// return
// }
// }
fmt.Fprintln(w, "Not implemented")
2024-05-01 18:09:02 -07:00
}
func (s *Server) match_cover_hash(w http.ResponseWriter, r *http.Request) {
user, authed := s.authenticated(w, r)
if !authed || user == "" {
http.Error(w, "Invalid Auth", http.StatusForbidden)
return
}
var (
values = r.URL.Query()
ahashStr = strings.TrimSpace(values.Get("ahash"))
dhashStr = strings.TrimSpace(values.Get("dhash"))
phashStr = strings.TrimSpace(values.Get("phash"))
ahash uint64
dhash uint64
phash uint64
err error
)
if ahash, err = strconv.ParseUint(ahashStr, 16, 64); err != nil && ahashStr != "" {
log.Printf("could not parse ahash: %s", ahashStr)
http.Error(w, "parse fail", http.StatusBadRequest)
return
}
if dhash, err = strconv.ParseUint(dhashStr, 16, 64); err != nil && dhashStr != "" {
log.Printf("could not parse dhash: %s", dhashStr)
http.Error(w, "parse fail", http.StatusBadRequest)
return
}
if phash, err = strconv.ParseUint(phashStr, 16, 64); err != nil && phashStr != "" {
log.Printf("could not parse phash: %s", phashStr)
http.Error(w, "parse fail", http.StatusBadRequest)
return
}
2024-07-31 11:35:17 -07:00
type res struct {
Domain string
Id string
}
2024-05-01 18:09:02 -07:00
if index, ok := s.fAhash[ahash]; ok {
2024-07-31 11:35:17 -07:00
covers, err := json.Marshal(res{
strings.Split(index, ":")[0],
strings.Split(index, ":")[1],
})
2024-05-01 18:09:02 -07:00
if err == nil {
w.Header().Add("Content-Type", "application/json")
w.Write(covers)
2024-07-31 11:35:17 -07:00
w.Write([]byte{'\n'})
2024-05-01 18:09:02 -07:00
return
2024-07-31 11:35:17 -07:00
} else {
fmt.Println(err)
2024-05-01 18:09:02 -07:00
}
}
if index, ok := s.fDhash[dhash]; ok {
2024-07-31 11:35:17 -07:00
covers, err := json.Marshal(res{
strings.Split(index, ":")[0],
strings.Split(index, ":")[1],
})
2024-05-01 18:09:02 -07:00
if err == nil {
w.Header().Add("Content-Type", "application/json")
w.Write(covers)
2024-07-31 11:35:17 -07:00
w.Write([]byte{'\n'})
2024-05-01 18:09:02 -07:00
return
2024-07-31 11:35:17 -07:00
} else {
fmt.Println(err)
2024-05-01 18:09:02 -07:00
}
}
if index, ok := s.fPhash[phash]; ok {
2024-07-31 11:35:17 -07:00
covers, err := json.Marshal(res{
strings.Split(index, ":")[0],
strings.Split(index, ":")[1],
})
2024-05-01 18:09:02 -07:00
if err == nil {
w.Header().Add("Content-Type", "application/json")
w.Write(covers)
2024-07-31 11:35:17 -07:00
w.Write([]byte{'\n'})
2024-05-01 18:09:02 -07:00
return
2024-07-31 11:35:17 -07:00
} else {
fmt.Println(err)
2024-05-01 18:09:02 -07:00
}
}
w.Header().Add("Content-Type", "application/json")
fmt.Fprintln(w, "{\"msg\":\"No hashes found\"}")
}
func (s *Server) add_cover(w http.ResponseWriter, r *http.Request) {
user, authed := s.authenticated(w, r)
if !authed || user == "" {
http.Error(w, "Invalid Auth", http.StatusForbidden)
return
}
var (
values = r.URL.Query()
domain = strings.TrimSpace(values.Get("domain"))
id = strings.TrimSpace(values.Get("id"))
)
if id == "" {
log.Println("No ID Provided")
http.Error(w, "No ID Provided", http.StatusBadRequest)
return
}
if domain == "" {
log.Println("No domain Provided")
http.Error(w, "No domain Provided", http.StatusBadRequest)
return
}
2024-07-31 11:35:17 -07:00
i, format, err := image.Decode(r.Body)
2024-05-01 18:09:02 -07:00
if err != nil {
msg := fmt.Sprintf("Failed to decode Image: %s", err)
log.Println(msg)
http.Error(w, msg, http.StatusBadRequest)
return
}
log.Printf("Decoded %s image from %s", format, user)
2024-07-31 11:35:17 -07:00
s.hashingQueue <- im{i, format, domain, id}
fmt.Fprintln(w, "Success")
}
func (s *Server) mapHashes(id string, ahash, dhash, phash *goimagehash.ImageHash) {
s.fAhash[ahash.GetHash()] = id
s.fDhash[dhash.GetHash()] = id
s.fPhash[phash.GetHash()] = id
for i, partial_hash := range SplitHash(ahash.GetHash()) {
s.ahash[i][partial_hash] = id
}
for i, partial_hash := range SplitHash(dhash.GetHash()) {
s.dhash[i][partial_hash] = id
}
for i, partial_hash := range SplitHash(phash.GetHash()) {
s.phash[i][partial_hash] = id
}
}
func (s *Server) initHashes() {
for i := range s.ahash {
s.ahash[i] = make(map[uint8]string)
}
for i := range s.dhash {
s.dhash[i] = make(map[uint8]string)
}
for i := range s.phash {
s.phash[i] = make(map[uint8]string)
}
s.fAhash = make(map[uint64]string)
s.fDhash = make(map[uint64]string)
s.fPhash = make(map[uint64]string)
s.IDToCover = make(map[string]string)
}
func MemStats() uint64 {
var m runtime.MemStats
runtime.ReadMemStats(&m)
return m.Alloc
}
func (s *Server) mapper() {
var total uint64 = 0
for {
select {
case hash := <-s.mappingQueue:
if total%1000 == 0 {
mem := MemStats()
if mem > 10*1024*1024*1024 {
fmt.Println("Forcing gc", mem, "G")
runtime.GC()
}
}
total++
s.mapHashes(hash.domain+":"+hash.id, hash.ahash, hash.dhash, hash.phash)
}
}
}
func hashImage(i im) hash {
if i.format == "webp" {
i.im = goimagehash.FancyUpscale(i.im.(*image.YCbCr))
}
2024-05-01 18:09:02 -07:00
var (
2024-07-31 11:35:17 -07:00
err error = nil
2024-05-01 18:09:02 -07:00
ahash *goimagehash.ImageHash
dhash *goimagehash.ImageHash
phash *goimagehash.ImageHash
)
2024-07-31 11:35:17 -07:00
ahash, err = goimagehash.AverageHash(i.im)
2024-05-01 18:09:02 -07:00
if err != nil {
msg := fmt.Sprintf("Failed to ahash Image: %s", err)
log.Println(msg)
2024-07-31 11:35:17 -07:00
return hash{}
2024-05-01 18:09:02 -07:00
}
2024-07-31 11:35:17 -07:00
dhash, err = goimagehash.DifferenceHash(i.im)
2024-05-01 18:09:02 -07:00
if err != nil {
msg := fmt.Sprintf("Failed to dhash Image: %s", err)
log.Println(msg)
2024-07-31 11:35:17 -07:00
return hash{}
2024-05-01 18:09:02 -07:00
}
2024-07-31 11:35:17 -07:00
phash, err = goimagehash.PerceptionHash(i.im)
2024-05-01 18:09:02 -07:00
if err != nil {
msg := fmt.Sprintf("Failed to phash Image: %s", err)
log.Println(msg)
2024-07-31 11:35:17 -07:00
return hash{}
2024-05-01 18:09:02 -07:00
}
2024-07-31 11:35:17 -07:00
return hash{
ahash: ahash,
dhash: dhash,
phash: phash,
domain: i.domain,
id: i.id,
2024-05-01 18:09:02 -07:00
}
}
2024-07-31 11:35:17 -07:00
func (s *Server) hasher(workerID int) {
for {
select {
case i := <-s.hashingQueue:
start := time.Now()
hash := hashImage(i)
if hash.domain ==""{
continue
}
s.mappingQueue <- hash
t := time.Now()
elapsed := t.Sub(start)
// fmt.Printf("%#064b\n", ahash.GetHash())
// fmt.Printf("%#064b\n", dhash.GetHash())
// fmt.Printf("%#064b\n", phash.GetHash())
fmt.Printf("Hashing took %v: worker: %v\n", elapsed, workerID)
}
2024-05-01 18:09:02 -07:00
}
}
func SplitHash(hash uint64) [8]uint8 {
return [8]uint8{
uint8((hash & h_8) >> shift_8),
uint8((hash & h_7) >> shift_7),
uint8((hash & h_6) >> shift_6),
uint8((hash & h_5) >> shift_5),
uint8((hash & h_4) >> shift_4),
uint8((hash & h_3) >> shift_3),
uint8((hash & h_2) >> shift_2),
uint8((hash & h_1) >> shift_1),
}
}
// func (s *Server) CoverByID(id string) uint32 {
// v,ok :=s.IDToCover[id]
// return 0
// }
func (s *Server) FindHashes() {
}
2024-07-31 11:35:17 -07:00
func startServer(cover_path string) {
sig := make(chan os.Signal, 1)
signal.Notify(sig, os.Interrupt)
fmt.Println("Fuck")
2024-05-01 18:09:02 -07:00
mux := http.NewServeMux()
server := Server{
2024-07-31 11:35:17 -07:00
token: make(chan *oidc.Tokens),
hashingQueue: make(chan im, 5000),
mappingQueue: make(chan hash, 5000000),
mux: mux,
2024-05-01 18:09:02 -07:00
httpServer: &http.Server{
Addr: ":8080",
Handler: mux,
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
MaxHeaderBytes: 1 << 20,
},
}
2024-07-31 11:35:17 -07:00
fmt.Println("init hashes")
2024-05-01 18:09:02 -07:00
server.initHashes()
// server.setupOauthHandlers()
2024-07-31 11:35:17 -07:00
fmt.Println("init handlers")
2024-05-01 18:09:02 -07:00
server.setupAppHandlers()
2024-07-31 11:35:17 -07:00
fmt.Println("init hashers")
go server.hasher(1)
go server.hasher(2)
go server.hasher(3)
go server.hasher(4)
go server.hasher(5)
go server.hasher(6)
go server.hasher(7)
go server.hasher(8)
go server.hasher(9)
go server.hasher(10)
fmt.Println("init mapper")
go server.mapper()
fmt.Println("Starting local hashing go routine")
go func() {
fmt.Println("Hashing covers at ", cover_path)
err := filepath.WalkDir(cover_path, func(path string, d fs.DirEntry, err error) error {
select {
case s := <-sig:
server.httpServer.Shutdown(context.TODO())
return fmt.Errorf("Signal: %v", s)
default:
}
if d.IsDir() {
return nil
}
file, err := os.Open(path)
if err != nil {
panic(err)
}
i, format, err := image.Decode(bufio.NewReader(file))
if err != nil {
return nil // skip this image
}
file.Close()
// fmt.Println("Hashing", path)
server.hashingQueue <- im{i, format, "comicvine.gamespot.com", filepath.Base(filepath.Dir(path))}
return nil
})
fmt.Println(err)
}()
fmt.Println("Listening on ", server.httpServer.Addr)
2024-05-01 18:09:02 -07:00
err := server.httpServer.ListenAndServe()
if err != nil {
2024-07-31 11:35:17 -07:00
fmt.Println(err)
}
f, er := os.Create("memprofile")
if er != nil {
fmt.Println("Error in creating file for writing memory profile to: ", er)
return
}
defer f.Close()
runtime.GC()
if e := pprof.WriteHeapProfile(f); e != nil {
fmt.Println("Error in writing memory profile: ", e)
return
2024-05-01 18:09:02 -07:00
}
}