diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a913888..c956fc8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,7 +13,7 @@ repos: - id: go-imports args: [-w] - repo: https://github.com/golangci/golangci-lint - rev: v1.59.1 + rev: v1.60.3 hooks: - id: golangci-lint - repo: https://github.com/asottile/setup-cfg-fmt diff --git a/BasicMap.go b/BasicMap.go new file mode 100644 index 0000000..53ffed6 --- /dev/null +++ b/BasicMap.go @@ -0,0 +1,151 @@ +package ch + +import ( + "fmt" + "math/bits" + "sync" + + "gitea.narnian.us/lordwelch/goimagehash" +) + +type basicMapStorage struct { + hashMutex sync.RWMutex + + ids map[ID]*[]ID + hashes [3]map[uint64]*[]ID +} + +func (b *basicMapStorage) Atleast(hashKind goimagehash.Kind, maxDistance int, searchHash uint64) []Result { + hashType := int(hashKind) - 1 + matchingHashes := make([]Result, 0, 100) // hope that we don't need all of them + for storedHash, ids := range b.hashes[hashType] { + distance := bits.OnesCount64(searchHash ^ storedHash) + if distance <= maxDistance { + matchingHashes = append(matchingHashes, Result{ToIDList(*ids), distance, Hash{storedHash, hashKind}}) + } + } + return matchingHashes +} +func (b *basicMapStorage) GetMatches(hashes []Hash, max int, exactOnly bool) ([]Result, error) { + var foundMatches []Result + b.hashMutex.RLock() + defer b.hashMutex.RUnlock() + resetTime() + + if exactOnly { // exact matches are also found by partial matches. Don't bother with exact matches so we don't have to de-duplicate + for _, hash := range hashes { + hashType := int(hash.Kind) - 1 + ids := b.hashes[hashType][hash.Hash] + if ids != nil && len(*ids) > 0 { + foundMatches = append(foundMatches, Result{ + Distance: 0, + Hash: hash, + IDs: ToIDList(*ids), + }) + } + } + + // If we have exact matches don't bother with other matches + if len(foundMatches) > 0 && exactOnly { + return foundMatches, nil + } + logTime("Search Exact") + } + + foundHashes := make(map[uint64]struct{}) + totalPartialHashes := 0 + for _, hash := range hashes { + for _, match := range b.Atleast(hash.Kind, max, hash.Hash) { + _, alreadyMatched := foundHashes[match.Hash.Hash] + if alreadyMatched { + continue + } + foundHashes[match.Hash.Hash] = struct{}{} + foundMatches = append(foundMatches, match) + } + + } + fmt.Println("Total partial hashes tested:", totalPartialHashes, len(foundHashes)) + logTime("Search Complete") + go b.printSizes() + return foundMatches, nil +} + +func (b *basicMapStorage) MapHashes(hash ImageHash) { + for _, ih := range hash.Hashes { + hashType := int(ih.Kind) - 1 + ids, found := b.hashes[hashType][ih.Hash] + if !found { // first ID for this hash; create the list instead of dereferencing a nil pointer + ids = &[]ID{} + b.hashes[hashType][ih.Hash] = ids + } + *ids = InsertID(*ids, hash.ID) + } +} + +func (b *basicMapStorage) DecodeHashes(hashes SavedHashes) error { + for hashType, sourceHashes := range hashes.Hashes { + b.hashes[hashType] = make(map[uint64]*[]ID, len(sourceHashes)) + for savedHash, idlistLocation := range sourceHashes { + b.hashes[hashType][savedHash] = &hashes.IDs[idlistLocation] + } + } + b.printSizes() + return nil +} + +func (b *basicMapStorage) printSizes() { + // fmt.Println("Size of", "hashes:", size.Of(b.hashes)/1024/1024, "MB") + // fmt.Println("Size of", "ids:", size.Of(b.ids)/1024/1024, "MB") + // fmt.Println("Size of", "basicMapStorage:", size.Of(b)/1024/1024, "MB") + +} + +func (b *basicMapStorage) EncodeHashes() (SavedHashes, error) { + hashes := SavedHashes{} + idmap := map[*[]ID]int{} + for _, ids := range b.ids { + if _, ok := idmap[ids]; ok { + continue + } + hashes.IDs = append(hashes.IDs, *ids) + 
idmap[ids] = len(hashes.IDs) - 1 + } + for hashType, hashToID := range b.hashes { + hashes.Hashes[hashType] = make(map[uint64]int, len(hashToID)) + for hash, ids := range hashToID { + hashes.Hashes[hashType][hash] = idmap[ids] + } + } + return hashes, nil +} + +func (b *basicMapStorage) AssociateIDs(newids []NewIDs) { + for _, newid := range newids { + ids, found := b.ids[newid.OldID] + if !found { + msg := "No IDs belonging to " + newid.OldID.Domain + " exist on this server" + panic(msg) + } + *ids = InsertID(*ids, newid.NewID) + } +} + +func (b *basicMapStorage) GetIDs(id ID) IDList { + ids, found := b.ids[id] + if !found { + msg := "No IDs belonging to " + id.Domain + " exist on this server" + panic(msg) + } + return ToIDList(*ids) +} + +func NewBasicMapStorage() (HashStorage, error) { + storage := &basicMapStorage{ + hashMutex: sync.RWMutex{}, + + hashes: [3]map[uint64]*[]ID{ + make(map[uint64]*[]ID), + make(map[uint64]*[]ID), + make(map[uint64]*[]ID), + }, + } + return storage, nil +} diff --git a/cmd/comic-hasher/main.go b/cmd/comic-hasher/main.go index 7263d84..4daeea2 100644 --- a/cmd/comic-hasher/main.go +++ b/cmd/comic-hasher/main.go @@ -91,20 +91,26 @@ type Storage int const ( Map = iota + 1 + BasicMap Sqlite Sqlite3 + VPTree ) var storageNames = map[Storage]string{ - Map: "map", - Sqlite: "sqlite", - Sqlite3: "sqlite3", + Map: "map", + BasicMap: "basicmap", + Sqlite: "sqlite", + Sqlite3: "sqlite3", + VPTree: "vptree", } var storageValues = map[string]Storage{ - "map": Map, - "sqlite": Sqlite, - "sqlite3": Sqlite3, + "map": Map, + "basicmap": BasicMap, + "sqlite": Sqlite, + "sqlite3": Sqlite3, + "vptree": VPTree, } func (f Storage) String() string { @@ -138,7 +144,7 @@ type Opts struct { } func main() { - opts := Opts{format: Msgpack, storageType: Map} // flag is weird + opts := Opts{format: Msgpack, storageType: BasicMap} // flag is weird go func() { log.Println(http.ListenAndServe("localhost:6060", nil)) }() @@ -150,7 +156,7 @@ func main() { flag.BoolVar(&opts.saveEmbeddedHashes, "save-embedded-hashes", false, "Save hashes even if we loaded the embedded hashes") flag.StringVar(&opts.hashesPath, "hashes", "hashes.gz", "Path to optionally gzipped hashes in msgpack or json format. 
You must disable embedded hashes to use this option") flag.Var(&opts.format, "save-format", "Specify the format to export hashes to (json, msgpack)") - flag.Var(&opts.storageType, "storage-type", "Specify the storage type used internally to search hashes (sqlite,sqlite3,map)") + flag.Var(&opts.storageType, "storage-type", "Specify the storage type used internally to search hashes (sqlite,sqlite3,map,basicmap,vptree)") flag.Parse() if opts.coverPath != "" { @@ -350,6 +356,7 @@ func (s *Server) matchCoverHash(w http.ResponseWriter, r *http.Request) { max int = 8 max_tmp int err error + hashes []ch.Hash ) if ahash, err = strconv.ParseUint(ahashStr, 16, 64); err != nil && ahashStr != "" { @@ -357,16 +364,25 @@ func (s *Server) matchCoverHash(w http.ResponseWriter, r *http.Request) { writeJson(w, http.StatusBadRequest, result{Msg: "hash parse failed"}) return } + if ahash > 0 { + hashes = append(hashes, ch.Hash{ahash, goimagehash.AHash}) + } if dhash, err = strconv.ParseUint(dhashStr, 16, 64); err != nil && dhashStr != "" { log.Printf("could not parse dhash: %s", dhashStr) writeJson(w, http.StatusBadRequest, result{Msg: "hash parse failed"}) return } + if dhash > 0 { + hashes = append(hashes, ch.Hash{dhash, goimagehash.DHash}) + } if phash, err = strconv.ParseUint(phashStr, 16, 64); err != nil && phashStr != "" { log.Printf("could not parse phash: %s", phashStr) writeJson(w, http.StatusBadRequest, result{Msg: "hash parse failed"}) return } + if phash > 0 { + hashes = append(hashes, ch.Hash{phash, goimagehash.PHash}) + } if max_tmp, err = strconv.Atoi(maxStr); err != nil && maxStr != "" { log.Printf("Invalid Max: %s", maxStr) writeJson(w, http.StatusBadRequest, result{Msg: fmt.Sprintf("Invalid Max: %s", maxStr)}) @@ -381,7 +397,10 @@ func (s *Server) matchCoverHash(w http.ResponseWriter, r *http.Request) { writeJson(w, http.StatusBadRequest, result{Msg: fmt.Sprintf("Max must be less than 9: %d", max)}) return } - matches, err := s.hashes.GetMatches([]ch.Hash{{ahash, goimagehash.AHash}, {dhash, goimagehash.DHash}, {phash, goimagehash.PHash}}, max, exactOnly) + matches, err := s.hashes.GetMatches(hashes, max, exactOnly) + slices.SortFunc(matches, func(a ch.Result, b ch.Result) int { + return cmp.Compare(a.Distance, b.Distance) + }) log.Println(err) if len(matches) > 0 { var msg string = "" @@ -532,10 +551,15 @@ func (s *Server) DecodeHashes(format Format, hashes []byte) error { default: return fmt.Errorf("Unknown format: %v", format) } - loadedHashes := make(ch.SavedHashes) + loadedHashes := ch.SavedHashes{} err := decoder(hashes, &loadedHashes) - if err != nil { - return err + if err != nil || len(loadedHashes.IDs) == 0 { + fmt.Println("Failed to load hashes, checking if they are old hashes", err) + oldHashes := make(ch.OldSavedHashes) + if err = decoder(hashes, &oldHashes); err != nil { + return err + } + loadedHashes = ch.ConvertSavedHashes(oldHashes) } return s.hashes.DecodeHashes(loadedHashes) @@ -597,10 +621,14 @@ func initializeStorage(opts Opts) (ch.HashStorage, error) { switch opts.storageType { case Map: return ch.NewMapStorage() + case BasicMap: + return ch.NewBasicMapStorage() case Sqlite: return ch.NewSqliteStorage("sqlite", opts.sqlitePath) case Sqlite3: return ch.NewSqliteStorage("sqlite3", opts.sqlitePath) + case VPTree: + return ch.NewVPStorage() } return nil, errors.New("Unknown storage type provided") } diff --git a/cmd/comic-hasher/tmp.go b/cmd/comic-hasher/tmp.go new file mode 100644 index 0000000..cb5d1a2 --- /dev/null +++ b/cmd/comic-hasher/tmp.go @@ -0,0 +1,17 @@ 
+//go:build main + +package main + +import ( + "fmt" + "time" +) + +func main() { + tmp := make([]string, 0, 932456) + for range 932460 { + tmp = append(tmp, "comicvine.gamespot.com:123456") + } + fmt.Println(len(tmp)) + time.Sleep(time.Minute) +} diff --git a/go.mod b/go.mod index 3f7a260..a37a440 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,6 @@ module gitea.narnian.us/lordwelch/comic-hasher -go 1.22.1 - -toolchain go1.22.2 +go 1.23.0 require ( gitea.narnian.us/lordwelch/goimagehash v0.0.0-20240812025715-33ff96e45f00 @@ -10,8 +8,10 @@ require ( github.com/kr/pretty v0.1.0 github.com/mattn/go-sqlite3 v1.14.22 github.com/mholt/archiver/v4 v4.0.0-alpha.8 + github.com/ncruces/go-sqlite3 v0.18.1 golang.org/x/image v0.19.0 golang.org/x/text v0.17.0 + gonum.org/v1/gonum v0.15.1 modernc.org/sqlite v1.32.0 ) @@ -40,14 +40,16 @@ require ( github.com/kr/text v0.1.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect + github.com/ncruces/julianday v1.0.0 // indirect github.com/nwaples/rardecode/v2 v2.0.0-beta.2 // indirect github.com/pierrec/lz4/v4 v4.1.15 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/tetratelabs/wazero v1.8.0 // indirect github.com/therootcompany/xz v1.0.1 // indirect github.com/ulikunitz/xz v0.5.10 // indirect go4.org v0.0.0-20200411211856-f5505b9728dd // indirect golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa // indirect - golang.org/x/sys v0.22.0 // indirect + golang.org/x/sys v0.24.0 // indirect modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect modernc.org/libc v1.55.3 // indirect modernc.org/mathutil v1.6.0 // indirect diff --git a/go.sum b/go.sum index ae9d7be..acecc77 100644 --- a/go.sum +++ b/go.sum @@ -115,8 +115,12 @@ github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mholt/archiver/v4 v4.0.0-alpha.8 h1:tRGQuDVPh66WCOelqe6LIGh0gwmfwxUrSSDunscGsRM= github.com/mholt/archiver/v4 v4.0.0-alpha.8/go.mod h1:5f7FUYGXdJWUjESffJaYR4R60VhnHxb2X3T1teMyv5A= +github.com/ncruces/go-sqlite3 v0.18.1 h1:iN8IMZV5EMxpH88NUac9vId23eTKNFUhP7jgY0EBbNc= +github.com/ncruces/go-sqlite3 v0.18.1/go.mod h1:eEOyZnW1dGTJ+zDpMuzfYamEUBtdFz5zeYhqLBtHxvM= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M= +github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g= github.com/nwaples/rardecode/v2 v2.0.0-beta.2 h1:e3mzJFJs4k83GXBEiTaQ5HgSc/kOK8q0rDaRO0MPaOk= github.com/nwaples/rardecode/v2 v2.0.0-beta.2/go.mod h1:yntwv/HfMc/Hbvtq9I19D1n58te3h6KsqCf3GxyfBGY= github.com/pierrec/lz4/v4 v4.1.15 h1:MO0/ucJhngq7299dKLwIMtgTfbkoSPF6AoMYDd8Q4q0= @@ -133,6 +137,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/tetratelabs/wazero v1.8.0 h1:iEKu0d4c2Pd+QSRieYbnQC9yiFlMS9D+Jr0LsRmcF4g= +github.com/tetratelabs/wazero v1.8.0/go.mod h1:yAI0XTsMBhREkM/YDAK/zNou3GoiAce1P6+rp/wQhjs= github.com/therootcompany/xz 
v1.0.1 h1:CmOtsn1CbtmyYiusbfmhmkpAAETj0wBIH6kCYaX+xzw= github.com/therootcompany/xz v1.0.1/go.mod h1:3K3UH1yCKgBneZYhuQUvJ9HPD19UEXEI0BWbMn8qNMY= github.com/ulikunitz/xz v0.5.6/go.mod h1:2bypXElzHzzJZwzH67Y6wb67pO62Rzfn7BSiF4ABRW8= @@ -232,8 +238,8 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= -golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= +golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -273,6 +279,8 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gonum.org/v1/gonum v0.15.1 h1:FNy7N6OUZVUaWG9pTiD+jlhdQ3lMP+/LcTpJ6+a8sQ0= +gonum.org/v1/gonum v0.15.1/go.mod h1:eZTZuRFrzu5pcyjN5wJhcIhnUdNijYxX1T2IcrOGY0o= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= diff --git a/hashing.go b/hashing.go index f8551c5..2b96f05 100644 --- a/hashing.go +++ b/hashing.go @@ -77,7 +77,82 @@ type Hash struct { Kind goimagehash.Kind } -type SavedHashes map[Source]map[string][3]uint64 +// IDList is a map of domain to ID eg IDs["comicvine.gamespot.com"] = []string{"1235"} +// Maps are extremely expensive in go for small maps this should only be used to return info to a user no internal code should use this +type IDList map[Source][]string + +type OldSavedHashes map[Source]map[string][3]uint64 + +type SavedHashes struct { + IDs [][]ID + Hashes [3]map[uint64]int +} + +func ToIDList(ids []ID) IDList { + idlist := IDList{} + for _, id := range ids { + idlist[id.Domain] = Insert(idlist[id.Domain], id.ID) + } + return idlist +} +func InsertID(ids []ID, id ID) []ID { + index, itemFound := slices.BinarySearchFunc(ids, id, func(e ID, t ID) int { + return cmp.Or( + cmp.Compare(e.Domain, t.Domain), + cmp.Compare(e.ID, t.ID), + ) + }) + if itemFound { + return ids + } + return slices.Insert(ids, index, id) +} +func (s *SavedHashes) InsertHash(hash Hash, id ID) { + for i, h := range s.Hashes { + if h == nil { + s.Hashes[i] = make(map[uint64]int) + } + } + + hashType := int(hash.Kind) - 1 + idx, hashFound := s.Hashes[hashType][hash.Hash] + if !hashFound { + idx = len(s.IDs) + s.IDs = append(s.IDs, make([]ID, 0, 3)) + } + s.IDs[idx] = InsertID(s.IDs[idx], id) + s.Hashes[hashType][hash.Hash] = idx +} + +func ConvertSavedHashes(oldHashes 
OldSavedHashes) SavedHashes { + t := SavedHashes{} + idcount := 0 + for _, ids := range oldHashes { + idcount += len(ids) + } + t.IDs = make([][]ID, 0, idcount) + t.Hashes[0] = make(map[uint64]int, idcount) + t.Hashes[1] = make(map[uint64]int, idcount) + t.Hashes[2] = make(map[uint64]int, idcount) + for domain, sourceHashes := range oldHashes { + for id, hashes := range sourceHashes { + idx := len(t.IDs) + t.IDs = append(t.IDs, []ID{{domain, id}}) + for hashType, hash := range hashes { + t.Hashes[hashType][hash] = idx + } + } + } + fmt.Println("Expected number of IDs", idcount) + idcount = 0 + for _, ids := range t.IDs { + idcount += len(ids) + } + fmt.Println("length of hashes", len(t.Hashes[0])+len(t.Hashes[1])+len(t.Hashes[2])) + fmt.Println("Length of ID lists", len(t.IDs)) + fmt.Println("Total number of IDs", idcount) + return t +} type NewIDs struct { OldID ID @@ -171,5 +246,3 @@ func SplitHash(hash uint64) [8]uint8 { uint8((hash & H0) >> Shift0), } } - -type IDList map[Source][]string // IDs is a map of domain to ID eg IDs['comicvine.gamespot.com'] = []string{"1235"} diff --git a/map.go b/map.go index 7d4443a..09615e2 100644 --- a/map.go +++ b/map.go @@ -1,100 +1,32 @@ package ch import ( - "cmp" - "math/bits" + "fmt" "slices" "sync" - - "gitea.narnian.us/lordwelch/goimagehash" ) -type mapStorage struct { - hashMutex sync.RWMutex - partialHash [3][8]map[uint8][]int - // partialAhash [8]map[uint8][]int - // partialDhash [8]map[uint8][]int - // partialPhash [8]map[uint8][]int - - ids []ID - - idToHash map[int][3][]int - - hashes [3][]uint64 - // ahashes []uint64 - // dhashes []uint64 - // phashes []uint64 - - hashToID [3]map[int][]int - // ahashToID map[int][]int - // dhashToID map[int][]int - // phashToID map[int][]int +type MapStorage struct { + basicMapStorage + partialHash [3][8]map[uint8][]uint64 } -func (m *mapStorage) addID(id ID) int { - index, itemFound := slices.BinarySearchFunc(m.ids, id, func(existing, new ID) int { - return cmp.Or( - cmp.Compare(existing.Domain, new.Domain), - cmp.Compare(existing.ID, new.ID), - ) - }) - if itemFound { - return index - } - m.ids = slices.Insert(m.ids, index, id) - return index -} - -func (m *mapStorage) getID(id ID) (int, bool) { - return slices.BinarySearchFunc(m.ids, id, func(existing, new ID) int { - return cmp.Or( - cmp.Compare(existing.Domain, new.Domain), - cmp.Compare(existing.ID, new.ID), - ) - }) -} - -func (m *mapStorage) Atleast(hashKind goimagehash.Kind, maxDistance int, searchHash uint64, hashes []int) []Result { - hashType := int(hashKind) - 1 - matchingHashes := make([]Result, 0, len(hashes)/2) // hope that we don't need all of them - for _, idx := range hashes { - storedHash := m.hashes[hashType][idx] - distance := bits.OnesCount64(searchHash ^ storedHash) - if distance <= maxDistance { - ids := make(IDList) - for _, idLocation := range m.hashToID[hashType][idx] { - ids[m.ids[idLocation].Domain] = Insert(ids[m.ids[idLocation].Domain], m.ids[idLocation].ID) - } - matchingHashes = append(matchingHashes, Result{ids, distance, Hash{storedHash, hashKind}}) - } - } - return matchingHashes -} -func (m *mapStorage) GetMatches(hashes []Hash, max int, exactOnly bool) ([]Result, error) { +func (m *MapStorage) GetMatches(hashes []Hash, max int, exactOnly bool) ([]Result, error) { var foundMatches []Result m.hashMutex.RLock() defer m.hashMutex.RUnlock() + resetTime() if exactOnly { // exact matches are also found by partial matches. 
Don't bother with exact matches so we don't have to de-duplicate for _, hash := range hashes { hashType := int(hash.Kind) - 1 - if hashLocation, found := slices.BinarySearch(m.hashes[hashType], hash.Hash); found { - idlist := make(IDList) - for _, idLocation := range m.hashToID[hashType][hashLocation] { - - for _, hashLocation := range m.idToHash[idLocation][0] { - for _, foundIDLocation := range m.hashToID[hashType][hashLocation] { - foundID := m.ids[foundIDLocation] - idlist[foundID.Domain] = Insert(idlist[foundID.Domain], foundID.ID) - } - } - } - if len(idlist) > 0 { - foundMatches = append(foundMatches, Result{ - Distance: 0, - Hash: hash, - }) - } + idlist := m.hashes[hashType][hash.Hash] + if idlist != nil && len(*idlist) > 0 { + foundMatches = append(foundMatches, Result{ + Distance: 0, + Hash: hash, + IDs: ToIDList(*idlist), + }) } } @@ -102,173 +34,114 @@ func (m *mapStorage) GetMatches(hashes []Hash, max int, exactOnly bool) ([]Resul if len(foundMatches) > 0 && exactOnly { return foundMatches, nil } + logTime("Search Exact") } - foundHashes := make(map[uint64]struct{}) - for _, hash := range hashes { - if hash.Hash == 0 { - continue - } - hashType := int(hash.Kind) - 1 - for i, partialHash := range SplitHash(hash.Hash) { - for _, match := range m.Atleast(hash.Kind, max, hash.Hash, m.partialHash[hashType][i][partialHash]) { - _, alreadyMatched := foundHashes[match.Hash.Hash] - if alreadyMatched { - continue + totalPartialHashes := 0 + for _, searchHash := range hashes { + foundHashes := make(map[uint64]struct{}) + hashType := int(searchHash.Kind) - 1 + for i, partialHash := range SplitHash(searchHash.Hash) { + partialHashes := m.partialHash[hashType][i][partialHash] + totalPartialHashes += len(partialHashes) + for _, match := range Atleast(max, searchHash.Hash, partialHashes) { + _, alreadyMatched := foundHashes[match.Hash] + if matchedResults, ok := m.hashes[hashType][match.Hash]; ok && !alreadyMatched { + foundHashes[match.Hash] = struct{}{} + foundMatches = append(foundMatches, Result{IDs: ToIDList(*matchedResults), Distance: match.Distance, Hash: Hash{Hash: match.Hash, Kind: searchHash.Kind}}) } - foundMatches = append(foundMatches, match) } } } - + fmt.Println("Total partial hashes tested:", totalPartialHashes) + logTime("Search Complete") + go m.printSizes() return foundMatches, nil } -func (m *mapStorage) MapHashes(hash ImageHash) { - - idIndex := m.addID(hash.ID) - idHashes := m.idToHash[idIndex] +func (m *MapStorage) MapHashes(hash ImageHash) { + m.basicMapStorage.MapHashes(hash) for _, hash := range hash.Hashes { - var ( - hashIndex int - hashType = int(hash.Kind) - 1 - ) - m.hashes[hashType], hashIndex = InsertIdx(m.hashes[hashType], hash.Hash) + hashType := int(hash.Kind) - 1 for i, partialHash := range SplitHash(hash.Hash) { - m.partialHash[hashType][i][partialHash] = append(m.partialHash[hashType][i][partialHash], hashIndex) + m.partialHash[hashType][i][partialHash] = Insert(m.partialHash[hashType][i][partialHash], hash.Hash) } - idHashes[hashType] = Insert(idHashes[hashType], hashIndex) - m.hashToID[hashType][hashIndex] = Insert(m.hashToID[hashType][hashIndex], idIndex) } - m.idToHash[idIndex] = idHashes } -func (m *mapStorage) DecodeHashes(hashes SavedHashes) error { - - for _, sourceHashes := range hashes { - m.hashes[0] = make([]uint64, 0, len(sourceHashes)) - m.hashes[1] = make([]uint64, 0, len(sourceHashes)) - m.hashes[2] = make([]uint64, 0, len(sourceHashes)) - break +func (m *MapStorage) DecodeHashes(hashes SavedHashes) error { + for hashType, sourceHashes 
:= range hashes.Hashes { + m.hashes[hashType] = make(map[uint64]*[]ID, len(sourceHashes)) + for savedHash, idlistLocation := range sourceHashes { + for i, partialHash := range SplitHash(savedHash) { + m.partialHash[hashType][i][partialHash] = append(m.partialHash[hashType][i][partialHash], savedHash) + } + m.hashes[hashType][savedHash] = &hashes.IDs[idlistLocation] + } } - for domain, sourceHashes := range hashes { - for id, h := range sourceHashes { - m.ids = append(m.ids, ID{Domain: Source(domain), ID: id}) - - for _, hash := range []Hash{Hash{h[0], goimagehash.AHash}, Hash{h[1], goimagehash.DHash}, Hash{h[2], goimagehash.PHash}} { - var ( - hashType = int(hash.Kind) - 1 - ) - m.hashes[hashType] = append(m.hashes[hashType], hash.Hash) + m.printSizes() + for _, partialHashes := range m.partialHash { + for _, partMap := range partialHashes { + for part, hashes := range partMap { + slices.Sort(hashes) + partMap[part] = slices.Compact(hashes) } } } - slices.SortFunc(m.ids, func(existing, new ID) int { - return cmp.Or( - cmp.Compare(existing.Domain, new.Domain), - cmp.Compare(existing.ID, new.ID), - ) - }) - slices.Sort(m.hashes[0]) - slices.Sort(m.hashes[1]) - slices.Sort(m.hashes[2]) - for domain, sourceHashes := range hashes { - for id, h := range sourceHashes { - m.MapHashes(ImageHash{ - Hashes: []Hash{{h[0], goimagehash.AHash}, {h[1], goimagehash.DHash}, {h[2], goimagehash.PHash}}, - ID: ID{Domain: Source(domain), ID: id}, - }) - } - } + m.printSizes() return nil } -func (m *mapStorage) EncodeHashes() (SavedHashes, error) { - hashes := make(SavedHashes) - for idLocation, hashLocation := range m.idToHash { - id := m.ids[idLocation] - _, ok := hashes[id.Domain] - if !ok { - hashes[id.Domain] = make(map[string][3]uint64) - } - // TODO: Add all hashes. 
Currently saved hashes does not allow multiple IDs for a single hash - hashes[id.Domain][id.ID] = [3]uint64{ - m.hashes[0][hashLocation[0][0]], - m.hashes[1][hashLocation[1][0]], - m.hashes[2][hashLocation[2][0]], - } - } - return hashes, nil -} +func (m *MapStorage) printSizes() { + fmt.Println("Length of hashes:", len(m.hashes[0])+len(m.hashes[1])+len(m.hashes[2])) + // fmt.Println("Size of", "hashes:", size.Of(m.hashes)/1024/1024, "MB") + // fmt.Println("Size of", "ids:", size.Of(m.ids)/1024/1024, "MB") + // fmt.Println("Size of", "MapStorage:", size.Of(m)/1024/1024, "MB") -func (m *mapStorage) AssociateIDs(newids []NewIDs) { - for _, ids := range newids { - oldIDLocation, found := m.getID(ids.OldID) - if !found { - msg := "No IDs belonging to " + ids.OldID.Domain + "exist on this server" - panic(msg) - } - - newIDLocation := m.addID(ids.NewID) - - for _, hashType := range []int{int(goimagehash.AHash), int(goimagehash.DHash), int(goimagehash.PHash)} { - for _, hashLocation := range m.idToHash[oldIDLocation][hashType] { - m.hashToID[hashType][hashLocation] = Insert(m.hashToID[hashType][hashLocation], newIDLocation) - idHashes := m.idToHash[newIDLocation] - idHashes[hashType] = Insert(idHashes[hashType], hashLocation) - m.idToHash[newIDLocation] = idHashes - } - } - } -} - -func (m *mapStorage) GetIDs(id ID) IDList { - idIndex, found := m.getID(id) - if !found { - msg := "No IDs belonging to " + id.Domain + "exist on this server" - panic(msg) - } - ids := make(IDList) - - for _, hashLocation := range m.idToHash[idIndex][0] { - for _, foundIDLocation := range m.hashToID[0][hashLocation] { - foundID := m.ids[foundIDLocation] - ids[foundID.Domain] = Insert(ids[foundID.Domain], foundID.ID) - } - } - for _, hashLocation := range m.idToHash[idIndex][1] { - for _, foundIDLocation := range m.hashToID[1][hashLocation] { - foundID := m.ids[foundIDLocation] - ids[foundID.Domain] = Insert(ids[foundID.Domain], foundID.ID) - } - } - for _, hashLocation := range m.idToHash[idIndex][2] { - for _, foundIDLocation := range m.hashToID[2][hashLocation] { - foundID := m.ids[foundIDLocation] - ids[foundID.Domain] = Insert(ids[foundID.Domain], foundID.ID) - } - } - return ids } func NewMapStorage() (HashStorage, error) { - storage := &mapStorage{ - hashMutex: sync.RWMutex{}, - idToHash: make(map[int][3][]int), - hashToID: [3]map[int][]int{ - make(map[int][]int), - make(map[int][]int), - make(map[int][]int), + storage := &MapStorage{ + basicMapStorage: basicMapStorage{ + hashMutex: sync.RWMutex{}, + hashes: [3]map[uint64]*[]ID{ + make(map[uint64]*[]ID), + make(map[uint64]*[]ID), + make(map[uint64]*[]ID), + }, + }, + partialHash: [3][8]map[uint8][]uint64{ + { + make(map[uint8][]uint64), + make(map[uint8][]uint64), + make(map[uint8][]uint64), + make(map[uint8][]uint64), + make(map[uint8][]uint64), + make(map[uint8][]uint64), + make(map[uint8][]uint64), + make(map[uint8][]uint64), + }, + { + make(map[uint8][]uint64), + make(map[uint8][]uint64), + make(map[uint8][]uint64), + make(map[uint8][]uint64), + make(map[uint8][]uint64), + make(map[uint8][]uint64), + make(map[uint8][]uint64), + make(map[uint8][]uint64), + }, + { + make(map[uint8][]uint64), + make(map[uint8][]uint64), + make(map[uint8][]uint64), + make(map[uint8][]uint64), + make(map[uint8][]uint64), + make(map[uint8][]uint64), + make(map[uint8][]uint64), + make(map[uint8][]uint64), + }, }, - } - for i := range storage.partialHash[0] { - storage.partialHash[0][i] = make(map[uint8][]int) - } - for i := range storage.partialHash[1] { - storage.partialHash[1][i] 
= make(map[uint8][]int) - } - for i := range storage.partialHash[2] { - storage.partialHash[2][i] = make(map[uint8][]int) } return storage, nil } diff --git a/sqlite.go b/sqlite.go index d2a2be5..d937390 100644 --- a/sqlite.go +++ b/sqlite.go @@ -8,6 +8,7 @@ import ( "log" "math/bits" "strings" + "time" "gitea.narnian.us/lordwelch/goimagehash" _ "modernc.org/sqlite" @@ -67,11 +68,11 @@ func (s *sqliteStorage) findExactHashes(statement *sql.Stmt, items ...interface{ func (s *sqliteStorage) findPartialHashes(max int, search_hash int64, kind goimagehash.Kind) ([]sqliteHash, error) { // exact matches are also found by partial matches. Don't bother with exact matches so we don't have to de-duplicate hashes := []sqliteHash{} - statement, err := s.db.PrepareContext(context.Background(), `SELECT rowid,hash,kind FROM Hashes WHERE (kind=?) AND (((hash >> (0 * 8) & 0xFF)=(? >> (0 * 8) & 0xFF)) OR ((hash >> (1 * 8) & 0xFF)=(? >> (1 * 8) & 0xFF)) OR ((hash >> (2 * 8) & 0xFF)=(? >> (2 * 8) & 0xFF)) OR ((hash >> (3 * 8) & 0xFF)=(? >> (3 * 8) & 0xFF)) OR ((hash >> (4 * 8) & 0xFF)=(? >> (4 * 8) & 0xFF)) OR ((hash >> (5 * 8) & 0xFF)=(? >> (5 * 8) & 0xFF)) OR ((hash >> (6 * 8) & 0xFF)=(? >> (6 * 8) & 0xFF)) OR ((hash >> (7 * 8) & 0xFF)=(? >> (7 * 8) & 0xFF))) ORDER BY kind,hash;`) + statement, err := s.db.PrepareContext(context.Background(), `SELECT rowid,hash,kind FROM Hashes WHERE (kind=?) AND (((hash >> (0 * 8) & 0xFF)=(?2 >> (0 * 8) & 0xFF)) OR ((hash >> (1 * 8) & 0xFF)=(?2 >> (1 * 8) & 0xFF)) OR ((hash >> (2 * 8) & 0xFF)=(?2 >> (2 * 8) & 0xFF)) OR ((hash >> (3 * 8) & 0xFF)=(?2 >> (3 * 8) & 0xFF)) OR ((hash >> (4 * 8) & 0xFF)=(?2 >> (4 * 8) & 0xFF)) OR ((hash >> (5 * 8) & 0xFF)=(?2 >> (5 * 8) & 0xFF)) OR ((hash >> (6 * 8) & 0xFF)=(?2 >> (6 * 8) & 0xFF)) OR ((hash >> (7 * 8) & 0xFF)=(?2 >> (7 * 8) & 0xFF)));`) if err != nil { return hashes, err } - rows, err := statement.Query(kind, int64(search_hash), int64(search_hash), int64(search_hash), int64(search_hash), int64(search_hash), int64(search_hash), int64(search_hash), int64(search_hash)) + rows, err := statement.Query(kind, int64(search_hash)) if err != nil { return hashes, err } @@ -93,6 +94,7 @@ func (s *sqliteStorage) findPartialHashes(max int, search_hash int64, kind goima } } rows.Close() + logTime("Filter partial " + kind.String()) statement, err = s.db.PrepareContext(context.Background(), `SELECT DISTINCT IDS.domain, IDs.id, id_hash.hashid FROM IDs JOIN id_hash ON IDs.rowid = id_hash.idid WHERE (id_hash.hashid in (`+strings.TrimRight(strings.Repeat("?,", len(hashes)), ",")+`)) ORDER BY IDs.domain, IDs.ID;`) if err != nil { @@ -161,6 +163,7 @@ CREATE INDEX IF NOT EXISTS hash_8_index ON Hashes ((hash >> (7 * 8) & 0xFF)); CREATE INDEX IF NOT EXISTS id_domain ON IDs (domain, id); PRAGMA shrink_memory; +ANALYZE; `) if err != nil { return err @@ -168,15 +171,38 @@ PRAGMA shrink_memory; return nil } +var ( + total time.Duration + t = time.Now() +) + +func resetTime() { + total = 0 + t = time.Now() +} + +func logTime(log string) { + n := time.Now() + s := n.Sub(t) + t = n + total += s + fmt.Printf("total: %v, %s: %v\n", total, log, s) +} + func (s *sqliteStorage) GetMatches(hashes []Hash, max int, exactOnly bool) ([]Result, error) { - var foundMatches []Result + var ( + foundMatches []Result + ) + resetTime() if exactOnly { // exact matches are also found by partial matches. 
Don't bother with exact matches so we don't have to de-duplicate statement, err := s.db.Prepare(`SELECT rowid,hash,kind FROM Hashes WHERE ` + strings.TrimSuffix(strings.Repeat("(hash=? AND kind=?) OR", len(hashes)), "OR") + `ORDER BY kind,hash;`) if err != nil { + logTime("Fail exact") return foundMatches, err } + args := make([]interface{}, 0, len(hashes)*2) for _, hash := range hashes { if hash.Hash != 0 { @@ -195,6 +221,7 @@ func (s *sqliteStorage) GetMatches(hashes []Hash, max int, exactOnly bool) ([]Re if len(foundMatches) > 0 && exactOnly { return foundMatches, nil } + logTime("Search Exact") } foundHashes := make(map[uint64]struct{}) @@ -204,6 +231,7 @@ func (s *sqliteStorage) GetMatches(hashes []Hash, max int, exactOnly bool) ([]Re if err != nil { return foundMatches, err } + logTime("Search partial " + hash.Kind.String()) for _, hash := range hashes { if _, alreadyMatched := foundHashes[hash.Hash.Hash]; !alreadyMatched { @@ -219,14 +247,18 @@ func (s *sqliteStorage) GetMatches(hashes []Hash, max int, exactOnly bool) ([]Re } func (s *sqliteStorage) MapHashes(hash ImageHash) { - insertHashes, err := s.db.Prepare(` -INSERT INTO Hashes (hash,kind) VALUES (?,?) ON CONFLICT DO UPDATE SET hash=?1 RETURNING hashid; + tx, err := s.db.BeginTx(context.Background(), nil) + if err != nil { + panic(err) + } + insertHashes, err := tx.Prepare(` +INSERT INTO Hashes (hash,kind) VALUES (?,?) ON CONFLICT DO UPDATE SET hash=?1 RETURNING hashid `) if err != nil { panic(err) } - rows, err := s.db.Query(` -INSERT INTO IDs (domain,id) VALUES (?,?) ON CONFLICT DO UPDATE SET domain=?1 RETURNING idid; + rows, err := tx.Query(` +INSERT INTO IDs (domain,id) VALUES (?,?) ON CONFLICT DO UPDATE SET domain=?1 RETURNING idid `, hash.ID.Domain, hash.ID.ID) if err != nil { panic(err) @@ -258,12 +290,19 @@ INSERT INTO IDs (domain,id) VALUES (?,?) ON CONFLICT DO UPDATE SET domain=?1 RET } hash_ids = append(hash_ids, id) } + var ids []any for _, hash_id := range hash_ids { - _, err = s.db.Exec(`INSERT INTO id_hash (hashid,idid) VALUES (?, ?) ON CONFLICT DO NOTHING;`, hash_id, id_id) - if err != nil { - panic(fmt.Errorf("Failed inserting: %v,%v: %w", hash.ID.Domain, hash.ID.ID, err)) - } + ids = append(ids, hash_id, id_id) } + _, err = tx.Exec(`INSERT INTO id_hash (hashid,idid) VALUES `+strings.TrimSuffix(strings.Repeat("(?, ?),", len(hash_ids)), ",")+` ON CONFLICT DO NOTHING;`, ids...) 
+ if err != nil { + panic(fmt.Errorf("Failed inserting: %v,%v: %w", hash.ID.Domain, hash.ID.ID, err)) + } + err = tx.Commit() + if err != nil { + panic(err) + } + insertHashes.Close() } func (s *sqliteStorage) DecodeHashes(hashes SavedHashes) error { @@ -272,9 +311,15 @@ func (s *sqliteStorage) DecodeHashes(hashes SavedHashes) error { return err } - for domain, sourceHashes := range hashes { - for id, h := range sourceHashes { - s.MapHashes(ImageHash{[]Hash{{h[0], goimagehash.AHash}, {h[1], goimagehash.DHash}, {h[2], goimagehash.PHash}}, ID{domain, id}}) + for hashType, sourceHashes := range hashes.Hashes { + hashKind := goimagehash.Kind(hashType + 1) + for hash, idsLocations := range sourceHashes { + for _, id := range hashes.IDs[idsLocations] { + s.MapHashes(ImageHash{ + Hashes: []Hash{{hash, hashKind}}, + ID: id, + }) + } } } err = s.createIndexes() @@ -285,48 +330,27 @@ func (s *sqliteStorage) DecodeHashes(hashes SavedHashes) error { } func (s *sqliteStorage) EncodeHashes() (SavedHashes, error) { - hashes := make(SavedHashes) + hashes := SavedHashes{} conn, err := s.db.Conn(context.Background()) if err != nil { return hashes, err } defer conn.Close() - rows, err := conn.QueryContext(context.Background(), "SELECT DISTINCT (domain) FROM IDs ORDER BY domain;") + rows, err := conn.QueryContext(context.Background(), "SELECT IDs.domain,IDs.id,Hashes.hash,Hashes.kind FROM Hashes JOIN id_hash ON id_hash.hashid = hashes.rowid JOIN IDs ON IDs.rowid = id_hash.idid ORDER BY IDs.ID,Hashes.kind,Hashes.hash;") + if err != nil { + rows.Close() + return hashes, err + } + var ( + id ID + hash Hash + ) + err = rows.Scan(&id.Domain, &id.ID, &hash.Hash, &hash.Kind) if err != nil { return hashes, err } - sources := make([]string, 0, 10) - for rows.Next() { - var source string - if err = rows.Scan(&source); err != nil { - rows.Close() - return hashes, err - } - sources = append(sources, source) - } - for _, source := range sources { - rows, err = conn.QueryContext(context.Background(), "SELECT IDs.id,Hashes.hash,Hashes.kind FROM Hashes JOIN id_hash ON id_hash.hashid = hashes.rowid JOIN IDs ON IDs.rowid = id_hash.idid WHERE IDs.domain = ? 
ORDER BY IDs.ID,Hashes.kind,Hashes.hash;", source) - if err != nil { - rows.Close() - return hashes, err - } - var ( - id string - hash int64 - typ goimagehash.Kind - ) - err = rows.Scan(&id, &hash, &typ) - if err != nil { - return hashes, err - } - _, ok := hashes[Source(source)] - if !ok { - hashes[Source(source)] = make(map[string][3]uint64) - } - h := hashes[Source(source)][id] - h[typ-1] = uint64(hash) - hashes[Source(source)][id] = h - } + hashes.InsertHash(hash, id) + return hashes, nil } @@ -415,16 +439,6 @@ CREATE TABLE IF NOT EXISTS Hashes( UNIQUE(kind, hash) ); -CREATE INDEX IF NOT EXISTS hash_index ON Hashes (kind, hash); -CREATE INDEX IF NOT EXISTS hash_1_index ON Hashes ((hash >> (0 * 8) & 0xFF)); -CREATE INDEX IF NOT EXISTS hash_2_index ON Hashes ((hash >> (1 * 8) & 0xFF)); -CREATE INDEX IF NOT EXISTS hash_3_index ON Hashes ((hash >> (2 * 8) & 0xFF)); -CREATE INDEX IF NOT EXISTS hash_4_index ON Hashes ((hash >> (3 * 8) & 0xFF)); -CREATE INDEX IF NOT EXISTS hash_5_index ON Hashes ((hash >> (4 * 8) & 0xFF)); -CREATE INDEX IF NOT EXISTS hash_6_index ON Hashes ((hash >> (5 * 8) & 0xFF)); -CREATE INDEX IF NOT EXISTS hash_7_index ON Hashes ((hash >> (6 * 8) & 0xFF)); -CREATE INDEX IF NOT EXISTS hash_8_index ON Hashes ((hash >> (7 * 8) & 0xFF)); - CREATE TABLE IF NOT EXISTS IDs( id TEXT NOT NULL, domain TEXT NOT NULL, @@ -445,6 +459,7 @@ CREATE TABLE IF NOT EXISTS id_hash( if err != nil { panic(err) } + sqlite.createIndexes() sqlite.db.SetMaxOpenConns(1) return sqlite, nil } diff --git a/sqlite_no_cgo.go b/sqlite_no_cgo.go new file mode 100644 index 0000000..7ac3194 --- /dev/null +++ b/sqlite_no_cgo.go @@ -0,0 +1,8 @@ +//go:build !cgo + +package ch + +import ( + _ "github.com/ncruces/go-sqlite3/driver" + _ "github.com/ncruces/go-sqlite3/embed" +) diff --git a/vp-tree.go b/vp-tree.go new file mode 100644 index 0000000..f65937f --- /dev/null +++ b/vp-tree.go @@ -0,0 +1,105 @@ +package ch + +import ( + "errors" + "fmt" + "math/bits" + + "gitea.narnian.us/lordwelch/goimagehash" + "gonum.org/v1/gonum/spatial/vptree" +) + +type VPTree struct { + trees [3]*vptree.Tree + hashes [3][]vptree.Comparable +} +type VPHash struct { + Hash Hash + IDs []ID +} + +func (h *VPHash) Distance(c vptree.Comparable) float64 { + h2, ok := c.(*VPHash) + if !ok { + return -99 + } + return float64(bits.OnesCount64(h.Hash.Hash ^ h2.Hash.Hash)) +} + +func (v *VPTree) GetMatches(hashes []Hash, max int, exactOnly bool) ([]Result, error) { + var matches []Result + var exactMatches []Result + fmt.Println(hashes) + for _, hash := range hashes { + results := vptree.NewDistKeeper(float64(max)) + hashType := int(hash.Kind) - 1 + v.trees[hashType].NearestSet(results, &VPHash{Hash: hash}) + for _, result := range results.Heap { + vphash := result.Comparable.(*VPHash) + if result.Dist == 0 { + exactMatches = append(exactMatches, Result{ + IDs: ToIDList(vphash.IDs), + Distance: int(result.Dist), + Hash: vphash.Hash, + }) + } else { + matches = append(matches, Result{ + IDs: ToIDList(vphash.IDs), + Distance: int(result.Dist), + Hash: vphash.Hash, + }) + } + } + } + if len(exactMatches) > 0 && exactOnly { + return exactMatches, nil + } + matches = append(exactMatches[:len(exactMatches):len(exactMatches)], matches...) 
+ return matches, nil +} + +func (v *VPTree) MapHashes(ImageHash) { + panic("Not Implemented") +} + +func (v *VPTree) DecodeHashes(hashes SavedHashes) error { + var err error + for hashType, sourceHashes := range hashes.Hashes { + for hash, idsLocation := range sourceHashes { + var ( + hashKind = goimagehash.Kind(hashType + 1) + ) + hash := &VPHash{Hash{hash, hashKind}, hashes.IDs[idsLocation]} + v.hashes[hashType] = append(v.hashes[hashType], hash) + } + } + for hashType := range 3 { + v.trees[hashType], err = vptree.New(v.hashes[hashType], 3, nil) + if err != nil { + return err + } + } + return nil +} +func (v *VPTree) EncodeHashes() (SavedHashes, error) { + return SavedHashes{}, errors.New("Not Implemented") +} + +func (v *VPTree) AssociateIDs(newIDs []NewIDs) { + panic("Not Implemented") +} + +func (v *VPTree) GetIDs(id ID) IDList { + return nil +} + +func NewVPStorage() (HashStorage, error) { + + return &VPTree{ + hashes: [3][]vptree.Comparable{ + make([]vptree.Comparable, 0, 1_000_000), + make([]vptree.Comparable, 0, 1_000_000), + make([]vptree.Comparable, 0, 1_000_000), + }, + }, nil +}
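
The basicMapStorage.Atleast method added above ranks every stored hash by Hamming distance: XOR the stored hash with the query and count the differing bits. A minimal self-contained sketch of that distance test (the hash values below are made up for illustration):

package main

import (
	"fmt"
	"math/bits"
)

// hammingDistance mirrors the distance check in basicMapStorage.Atleast:
// XOR the two 64-bit perceptual hashes and count the bits that differ.
func hammingDistance(a, b uint64) int {
	return bits.OnesCount64(a ^ b)
}

func main() {
	search := uint64(0xF0F0F0F0F0F0F0F0) // hypothetical query hash
	stored := uint64(0xF0F0F0F0F0F0F0F3) // hypothetical stored hash
	if d := hammingDistance(search, stored); d <= 8 {
		fmt.Println("within max distance:", d) // prints "within max distance: 2"
	}
}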
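
MapStorage keeps the partialHash index so a search only has to Hamming-test stored hashes that share at least one byte with the query: by the pigeonhole principle any hash within distance 7 must share a byte, and at distance 8 it usually still does. A sketch of the byte split that feeds those buckets (this shows the idea behind the existing SplitHash helper, not necessarily its exact bit ordering):

package main

import "fmt"

// splitBytes breaks a 64-bit hash into its eight bytes; each byte is used as a
// bucket key so candidate hashes can be found without scanning the whole table.
func splitBytes(hash uint64) [8]uint8 {
	var parts [8]uint8
	for i := range parts {
		parts[i] = uint8(hash >> (8 * i))
	}
	return parts
}

func main() {
	fmt.Println(splitBytes(0x0123456789ABCDEF)) // bucket keys (decimal) for a hypothetical hash
}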
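
The new SavedHashes layout in hashing.go deduplicates ID lists: IDs holds each unique []ID once, and the three per-kind maps store only an index into that slice. A short sketch of how InsertHash builds that structure (the hash value and comic IDs are hypothetical):

package main

import (
	"fmt"

	ch "gitea.narnian.us/lordwelch/comic-hasher"
	"gitea.narnian.us/lordwelch/goimagehash"
)

func main() {
	s := ch.SavedHashes{}
	// Two covers that share the same average hash are appended to one ID list.
	s.InsertHash(ch.Hash{Hash: 0xDEADBEEF, Kind: goimagehash.AHash}, ch.ID{Domain: "comicvine.gamespot.com", ID: "1235"})
	s.InsertHash(ch.Hash{Hash: 0xDEADBEEF, Kind: goimagehash.AHash}, ch.ID{Domain: "comicvine.gamespot.com", ID: "4321"})
	fmt.Println(len(s.IDs), s.Hashes[0][0xDEADBEEF]) // 1 0: one ID list, referenced by index 0
}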
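
Server.DecodeHashes now tries the new SavedHashes layout first and only falls back to the legacy OldSavedHashes map, converting it on the fly. A hedged sketch of that fallback using encoding/json only (the server also accepts msgpack; loadHashes is a made-up helper name):

package main

import (
	"encoding/json"

	ch "gitea.narnian.us/lordwelch/comic-hasher"
)

// loadHashes decodes the new format if possible, otherwise converts the old one.
func loadHashes(data []byte) (ch.SavedHashes, error) {
	loaded := ch.SavedHashes{}
	if err := json.Unmarshal(data, &loaded); err == nil && len(loaded.IDs) > 0 {
		return loaded, nil
	}
	old := make(ch.OldSavedHashes)
	if err := json.Unmarshal(data, &old); err != nil {
		return ch.SavedHashes{}, err
	}
	return ch.ConvertSavedHashes(old), nil
}

func main() {
	if _, err := loadHashes([]byte(`{}`)); err != nil {
		panic(err)
	}
}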
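
The new VPTree storage delegates nearest-neighbour search to gonum's vptree package, with Distance defined as the Hamming distance between hashes. A small standalone sketch of the same pattern (the hash values are invented; the nil check skips the bounding entry the keeper seeds its heap with):

package main

import (
	"fmt"
	"math/bits"

	"gonum.org/v1/gonum/spatial/vptree"
)

// point is a 64-bit hash that satisfies vptree.Comparable using Hamming
// distance, the same idea as the VPHash type added in vp-tree.go.
type point uint64

func (p point) Distance(c vptree.Comparable) float64 {
	return float64(bits.OnesCount64(uint64(p) ^ uint64(c.(point))))
}

func main() {
	stored := []vptree.Comparable{point(0x00FF), point(0xFF00), point(0x00F0)}
	tree, err := vptree.New(stored, 3, nil)
	if err != nil {
		panic(err)
	}
	keeper := vptree.NewDistKeeper(8) // keep everything within Hamming distance 8
	tree.NearestSet(keeper, point(0x00FE))
	for _, r := range keeper.Heap {
		if r.Comparable == nil {
			continue // skip the keeper's initial bounding entry
		}
		fmt.Printf("%#x at distance %v\n", uint64(r.Comparable.(point)), r.Dist)
	}
}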
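
The rewritten partial-hash query in sqlite.go switches to SQLite's numbered parameters (?2), so the search hash is bound once and reused in all eight byte comparisons instead of being passed eight times. A minimal sketch of that parameter reuse, assuming the modernc.org/sqlite driver the repo already registers (the table-free SELECT exists only to show the binding):

package main

import (
	"database/sql"
	"fmt"

	_ "modernc.org/sqlite"
)

func main() {
	db, err := sql.Open("sqlite", ":memory:")
	if err != nil {
		panic(err)
	}
	defer db.Close()

	// ?2 refers to the second bound argument wherever it appears, so only two
	// arguments are needed even though the value is used more than once.
	var a, b, c int64
	if err := db.QueryRow(`SELECT ?1, ?2, ?2;`, 7, 42).Scan(&a, &b, &c); err != nil {
		panic(err)
	}
	fmt.Println(a, b, c) // 7 42 42
}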