Optimize memory usage

Add a basic map storage that does manual searches to conserve memory
Change saved hashes format to allow multiple hashes for a given ID
Add a vptree storage

Maps in Go take up a huge amount of space changing IDList to []ID took
  memory from over 1GB down to 200MB (note this was on aarch64 MacOS
  which for some reason uses less memory than aarch64 Linux).
  Exhaustive searches using slices took about 30 ms search now takes
  50-60 ms as it takes longer to iterate a map. Partial hashes will
  speed up searches to 8 ms at the cost of 700MB initial memory usage
  and 400MB idle (though this is on MacOS, which for some reason uses
  less memory that aarch64 Linux so probably more like
  900MB initial -> 600 MB idle on an RPI running Linux)
This commit is contained in:
Timmy Welch 2024-09-07 14:51:18 -07:00
parent b1de95021a
commit 0928ed6ccf
11 changed files with 581 additions and 301 deletions

View File

@ -13,7 +13,7 @@ repos:
- id: go-imports
args: [-w]
- repo: https://github.com/golangci/golangci-lint
rev: v1.59.1
rev: v1.60.3
hooks:
- id: golangci-lint
- repo: https://github.com/asottile/setup-cfg-fmt

151
BasicMap.go Normal file
View File

@ -0,0 +1,151 @@
package ch
import (
"fmt"
"math/bits"
"sync"
"gitea.narnian.us/lordwelch/goimagehash"
)
type basicMapStorage struct {
hashMutex sync.RWMutex
ids map[ID]*[]ID
hashes [3]map[uint64]*[]ID
}
func (b *basicMapStorage) Atleast(hashKind goimagehash.Kind, maxDistance int, searchHash uint64) []Result {
hashType := int(hashKind) - 1
matchingHashes := make([]Result, 0, 100) // hope that we don't need all of them
for storedHash, ids := range b.hashes[hashType] {
distance := bits.OnesCount64(searchHash ^ storedHash)
if distance <= maxDistance {
matchingHashes = append(matchingHashes, Result{ToIDList(*ids), distance, Hash{storedHash, hashKind}})
}
}
return matchingHashes
}
func (b *basicMapStorage) GetMatches(hashes []Hash, max int, exactOnly bool) ([]Result, error) {
var foundMatches []Result
b.hashMutex.RLock()
defer b.hashMutex.RUnlock()
resetTime()
if exactOnly { // exact matches are also found by partial matches. Don't bother with exact matches so we don't have to de-duplicate
for _, hash := range hashes {
hashType := int(hash.Kind) - 1
ids := b.hashes[hashType][hash.Hash]
if ids != nil && len(*ids) > 0 {
foundMatches = append(foundMatches, Result{
Distance: 0,
Hash: hash,
IDs: ToIDList(*ids),
})
}
}
// If we have exact matches don't bother with other matches
if len(foundMatches) > 0 && exactOnly {
return foundMatches, nil
}
logTime("Search Exact")
}
foundHashes := make(map[uint64]struct{})
totalPartialHashes := 0
for _, hash := range hashes {
for _, match := range b.Atleast(hash.Kind, max, hash.Hash) {
_, alreadyMatched := foundHashes[match.Hash.Hash]
if alreadyMatched {
continue
}
foundHashes[match.Hash.Hash] = struct{}{}
foundMatches = append(foundMatches, match)
}
}
fmt.Println("Total partial hashes tested:", totalPartialHashes, len(foundHashes))
logTime("Search Complete")
go b.printSizes()
return foundMatches, nil
}
func (b *basicMapStorage) MapHashes(hash ImageHash) {
for _, ih := range hash.Hashes {
var (
hashType = int(ih.Kind) - 1
)
*b.hashes[hashType][ih.Hash] = InsertID((*b.hashes[hashType][ih.Hash]), hash.ID)
}
}
func (b *basicMapStorage) DecodeHashes(hashes SavedHashes) error {
for hashType, sourceHashes := range hashes.Hashes {
b.hashes[hashType] = make(map[uint64]*[]ID, len(sourceHashes))
for savedHash, idlistLocation := range sourceHashes {
b.hashes[hashType][savedHash] = &hashes.IDs[idlistLocation]
}
}
b.printSizes()
return nil
}
func (b *basicMapStorage) printSizes() {
// fmt.Println("Size of", "hashes:", size.Of(b.hashes)/1024/1024, "MB")
// fmt.Println("Size of", "ids:", size.Of(b.ids)/1024/1024, "MB")
// fmt.Println("Size of", "basicMapStorage:", size.Of(b)/1024/1024, "MB")
}
func (b *basicMapStorage) EncodeHashes() (SavedHashes, error) {
hashes := SavedHashes{}
idmap := map[*[]ID]int{}
for _, ids := range b.ids {
if _, ok := idmap[ids]; ok {
continue
}
hashes.IDs = append(hashes.IDs, *ids)
idmap[ids] = len(hashes.IDs)
}
for hashType, hashToID := range b.hashes {
for hash, ids := range hashToID {
hashes.Hashes[hashType][hash] = idmap[ids]
}
}
return hashes, nil
}
func (b *basicMapStorage) AssociateIDs(newids []NewIDs) {
for _, newid := range newids {
ids, found := b.ids[newid.OldID]
if !found {
msg := "No IDs belonging to " + newid.OldID.Domain + "exist on this server"
panic(msg)
}
*ids = InsertID(*ids, newid.NewID)
}
}
func (b *basicMapStorage) GetIDs(id ID) IDList {
ids, found := b.ids[id]
if !found {
msg := "No IDs belonging to " + id.Domain + "exist on this server"
panic(msg)
}
return ToIDList(*ids)
}
func NewBasicMapStorage() (HashStorage, error) {
storage := &basicMapStorage{
hashMutex: sync.RWMutex{},
hashes: [3]map[uint64]*[]ID{
make(map[uint64]*[]ID),
make(map[uint64]*[]ID),
make(map[uint64]*[]ID),
},
}
return storage, nil
}

View File

@ -91,20 +91,26 @@ type Storage int
const (
Map = iota + 1
BasicMap
Sqlite
Sqlite3
VPTree
)
var storageNames = map[Storage]string{
Map: "map",
Sqlite: "sqlite",
Sqlite3: "sqlite3",
Map: "map",
BasicMap: "basicmap",
Sqlite: "sqlite",
Sqlite3: "sqlite3",
VPTree: "vptree",
}
var storageValues = map[string]Storage{
"map": Map,
"sqlite": Sqlite,
"sqlite3": Sqlite3,
"map": Map,
"basicmap": BasicMap,
"sqlite": Sqlite,
"sqlite3": Sqlite3,
"vptree": VPTree,
}
func (f Storage) String() string {
@ -138,7 +144,7 @@ type Opts struct {
}
func main() {
opts := Opts{format: Msgpack, storageType: Map} // flag is weird
opts := Opts{format: Msgpack, storageType: BasicMap} // flag is weird
go func() {
log.Println(http.ListenAndServe("localhost:6060", nil))
}()
@ -150,7 +156,7 @@ func main() {
flag.BoolVar(&opts.saveEmbeddedHashes, "save-embedded-hashes", false, "Save hashes even if we loaded the embedded hashes")
flag.StringVar(&opts.hashesPath, "hashes", "hashes.gz", "Path to optionally gziped hashes in msgpack or json format. You must disable embedded hashes to use this option")
flag.Var(&opts.format, "save-format", "Specify the format to export hashes to (json, msgpack)")
flag.Var(&opts.storageType, "storage-type", "Specify the storage type used internally to search hashes (sqlite,sqlite3,map)")
flag.Var(&opts.storageType, "storage-type", "Specify the storage type used internally to search hashes (sqlite,sqlite3,map,basicmap,vptree)")
flag.Parse()
if opts.coverPath != "" {
@ -350,6 +356,7 @@ func (s *Server) matchCoverHash(w http.ResponseWriter, r *http.Request) {
max int = 8
max_tmp int
err error
hashes []ch.Hash
)
if ahash, err = strconv.ParseUint(ahashStr, 16, 64); err != nil && ahashStr != "" {
@ -357,16 +364,25 @@ func (s *Server) matchCoverHash(w http.ResponseWriter, r *http.Request) {
writeJson(w, http.StatusBadRequest, result{Msg: "hash parse failed"})
return
}
if ahash > 0 {
hashes = append(hashes, ch.Hash{ahash, goimagehash.AHash})
}
if dhash, err = strconv.ParseUint(dhashStr, 16, 64); err != nil && dhashStr != "" {
log.Printf("could not parse dhash: %s", dhashStr)
writeJson(w, http.StatusBadRequest, result{Msg: "hash parse failed"})
return
}
if dhash > 0 {
hashes = append(hashes, ch.Hash{dhash, goimagehash.DHash})
}
if phash, err = strconv.ParseUint(phashStr, 16, 64); err != nil && phashStr != "" {
log.Printf("could not parse phash: %s", phashStr)
writeJson(w, http.StatusBadRequest, result{Msg: "hash parse failed"})
return
}
if phash > 0 {
hashes = append(hashes, ch.Hash{phash, goimagehash.PHash})
}
if max_tmp, err = strconv.Atoi(maxStr); err != nil && maxStr != "" {
log.Printf("Invalid Max: %s", maxStr)
writeJson(w, http.StatusBadRequest, result{Msg: fmt.Sprintf("Invalid Max: %s", maxStr)})
@ -381,7 +397,10 @@ func (s *Server) matchCoverHash(w http.ResponseWriter, r *http.Request) {
writeJson(w, http.StatusBadRequest, result{Msg: fmt.Sprintf("Max must be less than 9: %d", max)})
return
}
matches, err := s.hashes.GetMatches([]ch.Hash{{ahash, goimagehash.AHash}, {dhash, goimagehash.DHash}, {phash, goimagehash.PHash}}, max, exactOnly)
matches, err := s.hashes.GetMatches(hashes, max, exactOnly)
slices.SortFunc(matches, func(a ch.Result, b ch.Result) int {
return cmp.Compare(a.Distance, b.Distance)
})
log.Println(err)
if len(matches) > 0 {
var msg string = ""
@ -532,10 +551,15 @@ func (s *Server) DecodeHashes(format Format, hashes []byte) error {
default:
return fmt.Errorf("Unknown format: %v", format)
}
loadedHashes := make(ch.SavedHashes)
loadedHashes := ch.SavedHashes{}
err := decoder(hashes, &loadedHashes)
if err != nil {
return err
if err != nil || len(loadedHashes.IDs) == 0 {
fmt.Println("Failed to load hashes, checking if they are old hashes", err)
oldHashes := make(ch.OldSavedHashes)
if err = decoder(hashes, &oldHashes); err != nil {
return err
}
loadedHashes = ch.ConvertSavedHashes(oldHashes)
}
return s.hashes.DecodeHashes(loadedHashes)
@ -597,10 +621,14 @@ func initializeStorage(opts Opts) (ch.HashStorage, error) {
switch opts.storageType {
case Map:
return ch.NewMapStorage()
case BasicMap:
return ch.NewBasicMapStorage()
case Sqlite:
return ch.NewSqliteStorage("sqlite", opts.sqlitePath)
case Sqlite3:
return ch.NewSqliteStorage("sqlite3", opts.sqlitePath)
case VPTree:
return ch.NewVPStorage()
}
return nil, errors.New("Unknown storage type provided")
}

17
cmd/comic-hasher/tmp.go Normal file
View File

@ -0,0 +1,17 @@
//go:build main
package main
import (
"fmt"
"time"
)
func main() {
tmp := make([]string, 0, 932456)
for range 932460 {
tmp = append(tmp, "comicvine.gamespot.com:123456")
}
fmt.Println(len(tmp))
time.Sleep(time.Minute)
}

10
go.mod
View File

@ -1,8 +1,6 @@
module gitea.narnian.us/lordwelch/comic-hasher
go 1.22.1
toolchain go1.22.2
go 1.23.0
require (
gitea.narnian.us/lordwelch/goimagehash v0.0.0-20240812025715-33ff96e45f00
@ -10,8 +8,10 @@ require (
github.com/kr/pretty v0.1.0
github.com/mattn/go-sqlite3 v1.14.22
github.com/mholt/archiver/v4 v4.0.0-alpha.8
github.com/ncruces/go-sqlite3 v0.18.1
golang.org/x/image v0.19.0
golang.org/x/text v0.17.0
gonum.org/v1/gonum v0.15.1
modernc.org/sqlite v1.32.0
)
@ -40,14 +40,16 @@ require (
github.com/kr/text v0.1.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/ncruces/julianday v1.0.0 // indirect
github.com/nwaples/rardecode/v2 v2.0.0-beta.2 // indirect
github.com/pierrec/lz4/v4 v4.1.15 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/tetratelabs/wazero v1.8.0 // indirect
github.com/therootcompany/xz v1.0.1 // indirect
github.com/ulikunitz/xz v0.5.10 // indirect
go4.org v0.0.0-20200411211856-f5505b9728dd // indirect
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa // indirect
golang.org/x/sys v0.22.0 // indirect
golang.org/x/sys v0.24.0 // indirect
modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect
modernc.org/libc v1.55.3 // indirect
modernc.org/mathutil v1.6.0 // indirect

12
go.sum
View File

@ -115,8 +115,12 @@ github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mholt/archiver/v4 v4.0.0-alpha.8 h1:tRGQuDVPh66WCOelqe6LIGh0gwmfwxUrSSDunscGsRM=
github.com/mholt/archiver/v4 v4.0.0-alpha.8/go.mod h1:5f7FUYGXdJWUjESffJaYR4R60VhnHxb2X3T1teMyv5A=
github.com/ncruces/go-sqlite3 v0.18.1 h1:iN8IMZV5EMxpH88NUac9vId23eTKNFUhP7jgY0EBbNc=
github.com/ncruces/go-sqlite3 v0.18.1/go.mod h1:eEOyZnW1dGTJ+zDpMuzfYamEUBtdFz5zeYhqLBtHxvM=
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M=
github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/nwaples/rardecode/v2 v2.0.0-beta.2 h1:e3mzJFJs4k83GXBEiTaQ5HgSc/kOK8q0rDaRO0MPaOk=
github.com/nwaples/rardecode/v2 v2.0.0-beta.2/go.mod h1:yntwv/HfMc/Hbvtq9I19D1n58te3h6KsqCf3GxyfBGY=
github.com/pierrec/lz4/v4 v4.1.15 h1:MO0/ucJhngq7299dKLwIMtgTfbkoSPF6AoMYDd8Q4q0=
@ -133,6 +137,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/tetratelabs/wazero v1.8.0 h1:iEKu0d4c2Pd+QSRieYbnQC9yiFlMS9D+Jr0LsRmcF4g=
github.com/tetratelabs/wazero v1.8.0/go.mod h1:yAI0XTsMBhREkM/YDAK/zNou3GoiAce1P6+rp/wQhjs=
github.com/therootcompany/xz v1.0.1 h1:CmOtsn1CbtmyYiusbfmhmkpAAETj0wBIH6kCYaX+xzw=
github.com/therootcompany/xz v1.0.1/go.mod h1:3K3UH1yCKgBneZYhuQUvJ9HPD19UEXEI0BWbMn8qNMY=
github.com/ulikunitz/xz v0.5.6/go.mod h1:2bypXElzHzzJZwzH67Y6wb67pO62Rzfn7BSiF4ABRW8=
@ -232,8 +238,8 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg=
golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@ -273,6 +279,8 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.15.1 h1:FNy7N6OUZVUaWG9pTiD+jlhdQ3lMP+/LcTpJ6+a8sQ0=
gonum.org/v1/gonum v0.15.1/go.mod h1:eZTZuRFrzu5pcyjN5wJhcIhnUdNijYxX1T2IcrOGY0o=
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M=
google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg=

View File

@ -77,7 +77,82 @@ type Hash struct {
Kind goimagehash.Kind
}
type SavedHashes map[Source]map[string][3]uint64
// IDList is a map of domain to ID eg IDs["comicvine.gamespot.com"] = []string{"1235"}
// Maps are extremely expensive in go for small maps this should only be used to return info to a user no internal code should use this
type IDList map[Source][]string
type OldSavedHashes map[Source]map[string][3]uint64
type SavedHashes struct {
IDs [][]ID
Hashes [3]map[uint64]int
}
func ToIDList(ids []ID) IDList {
idlist := IDList{}
for _, id := range ids {
idlist[id.Domain] = Insert(idlist[id.Domain], id.ID)
}
return idlist
}
func InsertID(ids []ID, id ID) []ID {
index, itemFound := slices.BinarySearchFunc(ids, id, func(e ID, t ID) int {
return cmp.Or(
cmp.Compare(e.Domain, t.Domain),
cmp.Compare(e.ID, t.ID),
)
})
if itemFound {
return ids
}
return slices.Insert(ids, index, id)
}
func (s *SavedHashes) InsertHash(hash Hash, id ID) {
for i, h := range s.Hashes {
if h == nil {
s.Hashes[i] = make(map[uint64]int)
}
}
hashType := int(hash.Kind) - 1
idx, hashFound := s.Hashes[hashType][hash.Hash]
if !hashFound {
idx = len(s.IDs)
s.IDs = append(s.IDs, make([]ID, 0, 3))
}
s.IDs[idx] = InsertID(s.IDs[idx], id)
s.Hashes[hashType][hash.Hash] = idx
}
func ConvertSavedHashes(oldHashes OldSavedHashes) SavedHashes {
t := SavedHashes{}
idcount := 0
for _, ids := range oldHashes {
idcount += len(ids)
}
t.IDs = make([][]ID, 0, idcount)
t.Hashes[0] = make(map[uint64]int, idcount)
t.Hashes[1] = make(map[uint64]int, idcount)
t.Hashes[2] = make(map[uint64]int, idcount)
for domain, sourceHashes := range oldHashes {
for id, hashes := range sourceHashes {
idx := len(t.IDs)
t.IDs = append(t.IDs, []ID{{domain, id}})
for hashType, hash := range hashes {
t.Hashes[hashType][hash] = idx
}
}
}
fmt.Println("Expected number of IDs", idcount)
idcount = 0
for _, ids := range t.IDs {
idcount += len(ids)
}
fmt.Println("length of hashes", len(t.Hashes[0])+len(t.Hashes[1])+len(t.Hashes[2]))
fmt.Println("Length of ID lists", len(t.IDs))
fmt.Println("Total number of IDs", idcount)
return t
}
type NewIDs struct {
OldID ID
@ -171,5 +246,3 @@ func SplitHash(hash uint64) [8]uint8 {
uint8((hash & H0) >> Shift0),
}
}
type IDList map[Source][]string // IDs is a map of domain to ID eg IDs['comicvine.gamespot.com'] = []string{"1235"}

315
map.go
View File

@ -1,100 +1,32 @@
package ch
import (
"cmp"
"math/bits"
"fmt"
"slices"
"sync"
"gitea.narnian.us/lordwelch/goimagehash"
)
type mapStorage struct {
hashMutex sync.RWMutex
partialHash [3][8]map[uint8][]int
// partialAhash [8]map[uint8][]int
// partialDhash [8]map[uint8][]int
// partialPhash [8]map[uint8][]int
ids []ID
idToHash map[int][3][]int
hashes [3][]uint64
// ahashes []uint64
// dhashes []uint64
// phashes []uint64
hashToID [3]map[int][]int
// ahashToID map[int][]int
// dhashToID map[int][]int
// phashToID map[int][]int
type MapStorage struct {
basicMapStorage
partialHash [3][8]map[uint8][]uint64
}
func (m *mapStorage) addID(id ID) int {
index, itemFound := slices.BinarySearchFunc(m.ids, id, func(existing, new ID) int {
return cmp.Or(
cmp.Compare(existing.Domain, new.Domain),
cmp.Compare(existing.ID, new.ID),
)
})
if itemFound {
return index
}
m.ids = slices.Insert(m.ids, index, id)
return index
}
func (m *mapStorage) getID(id ID) (int, bool) {
return slices.BinarySearchFunc(m.ids, id, func(existing, new ID) int {
return cmp.Or(
cmp.Compare(existing.Domain, new.Domain),
cmp.Compare(existing.ID, new.ID),
)
})
}
func (m *mapStorage) Atleast(hashKind goimagehash.Kind, maxDistance int, searchHash uint64, hashes []int) []Result {
hashType := int(hashKind) - 1
matchingHashes := make([]Result, 0, len(hashes)/2) // hope that we don't need all of them
for _, idx := range hashes {
storedHash := m.hashes[hashType][idx]
distance := bits.OnesCount64(searchHash ^ storedHash)
if distance <= maxDistance {
ids := make(IDList)
for _, idLocation := range m.hashToID[hashType][idx] {
ids[m.ids[idLocation].Domain] = Insert(ids[m.ids[idLocation].Domain], m.ids[idLocation].ID)
}
matchingHashes = append(matchingHashes, Result{ids, distance, Hash{storedHash, hashKind}})
}
}
return matchingHashes
}
func (m *mapStorage) GetMatches(hashes []Hash, max int, exactOnly bool) ([]Result, error) {
func (m *MapStorage) GetMatches(hashes []Hash, max int, exactOnly bool) ([]Result, error) {
var foundMatches []Result
m.hashMutex.RLock()
defer m.hashMutex.RUnlock()
resetTime()
if exactOnly { // exact matches are also found by partial matches. Don't bother with exact matches so we don't have to de-duplicate
for _, hash := range hashes {
hashType := int(hash.Kind) - 1
if hashLocation, found := slices.BinarySearch(m.hashes[hashType], hash.Hash); found {
idlist := make(IDList)
for _, idLocation := range m.hashToID[hashType][hashLocation] {
for _, hashLocation := range m.idToHash[idLocation][0] {
for _, foundIDLocation := range m.hashToID[hashType][hashLocation] {
foundID := m.ids[foundIDLocation]
idlist[foundID.Domain] = Insert(idlist[foundID.Domain], foundID.ID)
}
}
}
if len(idlist) > 0 {
foundMatches = append(foundMatches, Result{
Distance: 0,
Hash: hash,
})
}
idlist := m.hashes[hashType][hash.Hash]
if idlist != nil && len(*idlist) > 0 {
foundMatches = append(foundMatches, Result{
Distance: 0,
Hash: hash,
IDs: ToIDList(*idlist),
})
}
}
@ -102,173 +34,114 @@ func (m *mapStorage) GetMatches(hashes []Hash, max int, exactOnly bool) ([]Resul
if len(foundMatches) > 0 && exactOnly {
return foundMatches, nil
}
logTime("Search Exact")
}
foundHashes := make(map[uint64]struct{})
for _, hash := range hashes {
if hash.Hash == 0 {
continue
}
hashType := int(hash.Kind) - 1
for i, partialHash := range SplitHash(hash.Hash) {
for _, match := range m.Atleast(hash.Kind, max, hash.Hash, m.partialHash[hashType][i][partialHash]) {
_, alreadyMatched := foundHashes[match.Hash.Hash]
if alreadyMatched {
continue
totalPartialHashes := 0
for _, searchHash := range hashes {
foundHashes := make(map[uint64]struct{})
hashType := int(searchHash.Kind) - 1
for i, partialHash := range SplitHash(searchHash.Hash) {
partialHashes := m.partialHash[hashType][i][partialHash]
totalPartialHashes += len(partialHashes)
for _, match := range Atleast(max, searchHash.Hash, partialHashes) {
_, alreadyMatched := foundHashes[match.Hash]
if matchedResults, ok := m.hashes[hashType][match.Hash]; ok && !alreadyMatched {
foundHashes[match.Hash] = struct{}{}
foundMatches = append(foundMatches, Result{IDs: ToIDList(*matchedResults), Distance: match.Distance, Hash: Hash{Hash: match.Hash, Kind: searchHash.Kind}})
}
foundMatches = append(foundMatches, match)
}
}
}
fmt.Println("Total partial hashes tested:", totalPartialHashes)
logTime("Search Complete")
go m.printSizes()
return foundMatches, nil
}
func (m *mapStorage) MapHashes(hash ImageHash) {
idIndex := m.addID(hash.ID)
idHashes := m.idToHash[idIndex]
func (m *MapStorage) MapHashes(hash ImageHash) {
m.basicMapStorage.MapHashes(hash)
for _, hash := range hash.Hashes {
var (
hashIndex int
hashType = int(hash.Kind) - 1
)
m.hashes[hashType], hashIndex = InsertIdx(m.hashes[hashType], hash.Hash)
hashType := int(hash.Kind) - 1
for i, partialHash := range SplitHash(hash.Hash) {
m.partialHash[hashType][i][partialHash] = append(m.partialHash[hashType][i][partialHash], hashIndex)
m.partialHash[hashType][i][partialHash] = Insert(m.partialHash[hashType][i][partialHash], hash.Hash)
}
idHashes[hashType] = Insert(idHashes[hashType], hashIndex)
m.hashToID[hashType][hashIndex] = Insert(m.hashToID[hashType][hashIndex], idIndex)
}
m.idToHash[idIndex] = idHashes
}
func (m *mapStorage) DecodeHashes(hashes SavedHashes) error {
for _, sourceHashes := range hashes {
m.hashes[0] = make([]uint64, 0, len(sourceHashes))
m.hashes[1] = make([]uint64, 0, len(sourceHashes))
m.hashes[2] = make([]uint64, 0, len(sourceHashes))
break
func (m *MapStorage) DecodeHashes(hashes SavedHashes) error {
for hashType, sourceHashes := range hashes.Hashes {
m.hashes[hashType] = make(map[uint64]*[]ID, len(sourceHashes))
for savedHash, idlistLocation := range sourceHashes {
for i, partialHash := range SplitHash(savedHash) {
m.partialHash[hashType][i][partialHash] = append(m.partialHash[hashType][i][partialHash], savedHash)
}
m.hashes[hashType][savedHash] = &hashes.IDs[idlistLocation]
}
}
for domain, sourceHashes := range hashes {
for id, h := range sourceHashes {
m.ids = append(m.ids, ID{Domain: Source(domain), ID: id})
for _, hash := range []Hash{Hash{h[0], goimagehash.AHash}, Hash{h[1], goimagehash.DHash}, Hash{h[2], goimagehash.PHash}} {
var (
hashType = int(hash.Kind) - 1
)
m.hashes[hashType] = append(m.hashes[hashType], hash.Hash)
m.printSizes()
for _, partialHashes := range m.partialHash {
for _, partMap := range partialHashes {
for part, hashes := range partMap {
slices.Sort(hashes)
partMap[part] = slices.Compact(hashes)
}
}
}
slices.SortFunc(m.ids, func(existing, new ID) int {
return cmp.Or(
cmp.Compare(existing.Domain, new.Domain),
cmp.Compare(existing.ID, new.ID),
)
})
slices.Sort(m.hashes[0])
slices.Sort(m.hashes[1])
slices.Sort(m.hashes[2])
for domain, sourceHashes := range hashes {
for id, h := range sourceHashes {
m.MapHashes(ImageHash{
Hashes: []Hash{{h[0], goimagehash.AHash}, {h[1], goimagehash.DHash}, {h[2], goimagehash.PHash}},
ID: ID{Domain: Source(domain), ID: id},
})
}
}
m.printSizes()
return nil
}
func (m *mapStorage) EncodeHashes() (SavedHashes, error) {
hashes := make(SavedHashes)
for idLocation, hashLocation := range m.idToHash {
id := m.ids[idLocation]
_, ok := hashes[id.Domain]
if !ok {
hashes[id.Domain] = make(map[string][3]uint64)
}
// TODO: Add all hashes. Currently saved hashes does not allow multiple IDs for a single hash
hashes[id.Domain][id.ID] = [3]uint64{
m.hashes[0][hashLocation[0][0]],
m.hashes[1][hashLocation[1][0]],
m.hashes[2][hashLocation[2][0]],
}
}
return hashes, nil
}
func (m *MapStorage) printSizes() {
fmt.Println("Length of hashes:", len(m.hashes[0])+len(m.hashes[1])+len(m.hashes[2]))
// fmt.Println("Size of", "hashes:", size.Of(m.hashes)/1024/1024, "MB")
// fmt.Println("Size of", "ids:", size.Of(m.ids)/1024/1024, "MB")
// fmt.Println("Size of", "MapStorage:", size.Of(m)/1024/1024, "MB")
func (m *mapStorage) AssociateIDs(newids []NewIDs) {
for _, ids := range newids {
oldIDLocation, found := m.getID(ids.OldID)
if !found {
msg := "No IDs belonging to " + ids.OldID.Domain + "exist on this server"
panic(msg)
}
newIDLocation := m.addID(ids.NewID)
for _, hashType := range []int{int(goimagehash.AHash), int(goimagehash.DHash), int(goimagehash.PHash)} {
for _, hashLocation := range m.idToHash[oldIDLocation][hashType] {
m.hashToID[hashType][hashLocation] = Insert(m.hashToID[hashType][hashLocation], newIDLocation)
idHashes := m.idToHash[newIDLocation]
idHashes[hashType] = Insert(idHashes[hashType], hashLocation)
m.idToHash[newIDLocation] = idHashes
}
}
}
}
func (m *mapStorage) GetIDs(id ID) IDList {
idIndex, found := m.getID(id)
if !found {
msg := "No IDs belonging to " + id.Domain + "exist on this server"
panic(msg)
}
ids := make(IDList)
for _, hashLocation := range m.idToHash[idIndex][0] {
for _, foundIDLocation := range m.hashToID[0][hashLocation] {
foundID := m.ids[foundIDLocation]
ids[foundID.Domain] = Insert(ids[foundID.Domain], foundID.ID)
}
}
for _, hashLocation := range m.idToHash[idIndex][1] {
for _, foundIDLocation := range m.hashToID[1][hashLocation] {
foundID := m.ids[foundIDLocation]
ids[foundID.Domain] = Insert(ids[foundID.Domain], foundID.ID)
}
}
for _, hashLocation := range m.idToHash[idIndex][2] {
for _, foundIDLocation := range m.hashToID[2][hashLocation] {
foundID := m.ids[foundIDLocation]
ids[foundID.Domain] = Insert(ids[foundID.Domain], foundID.ID)
}
}
return ids
}
func NewMapStorage() (HashStorage, error) {
storage := &mapStorage{
hashMutex: sync.RWMutex{},
idToHash: make(map[int][3][]int),
hashToID: [3]map[int][]int{
make(map[int][]int),
make(map[int][]int),
make(map[int][]int),
storage := &MapStorage{
basicMapStorage: basicMapStorage{
hashMutex: sync.RWMutex{},
hashes: [3]map[uint64]*[]ID{
make(map[uint64]*[]ID),
make(map[uint64]*[]ID),
make(map[uint64]*[]ID),
},
},
partialHash: [3][8]map[uint8][]uint64{
{
make(map[uint8][]uint64),
make(map[uint8][]uint64),
make(map[uint8][]uint64),
make(map[uint8][]uint64),
make(map[uint8][]uint64),
make(map[uint8][]uint64),
make(map[uint8][]uint64),
make(map[uint8][]uint64),
},
{
make(map[uint8][]uint64),
make(map[uint8][]uint64),
make(map[uint8][]uint64),
make(map[uint8][]uint64),
make(map[uint8][]uint64),
make(map[uint8][]uint64),
make(map[uint8][]uint64),
make(map[uint8][]uint64),
},
{
make(map[uint8][]uint64),
make(map[uint8][]uint64),
make(map[uint8][]uint64),
make(map[uint8][]uint64),
make(map[uint8][]uint64),
make(map[uint8][]uint64),
make(map[uint8][]uint64),
make(map[uint8][]uint64),
},
},
}
for i := range storage.partialHash[0] {
storage.partialHash[0][i] = make(map[uint8][]int)
}
for i := range storage.partialHash[1] {
storage.partialHash[1][i] = make(map[uint8][]int)
}
for i := range storage.partialHash[2] {
storage.partialHash[2][i] = make(map[uint8][]int)
}
return storage, nil
}

131
sqlite.go
View File

@ -8,6 +8,7 @@ import (
"log"
"math/bits"
"strings"
"time"
"gitea.narnian.us/lordwelch/goimagehash"
_ "modernc.org/sqlite"
@ -67,11 +68,11 @@ func (s *sqliteStorage) findExactHashes(statement *sql.Stmt, items ...interface{
func (s *sqliteStorage) findPartialHashes(max int, search_hash int64, kind goimagehash.Kind) ([]sqliteHash, error) { // exact matches are also found by partial matches. Don't bother with exact matches so we don't have to de-duplicate
hashes := []sqliteHash{}
statement, err := s.db.PrepareContext(context.Background(), `SELECT rowid,hash,kind FROM Hashes WHERE (kind=?) AND (((hash >> (0 * 8) & 0xFF)=(? >> (0 * 8) & 0xFF)) OR ((hash >> (1 * 8) & 0xFF)=(? >> (1 * 8) & 0xFF)) OR ((hash >> (2 * 8) & 0xFF)=(? >> (2 * 8) & 0xFF)) OR ((hash >> (3 * 8) & 0xFF)=(? >> (3 * 8) & 0xFF)) OR ((hash >> (4 * 8) & 0xFF)=(? >> (4 * 8) & 0xFF)) OR ((hash >> (5 * 8) & 0xFF)=(? >> (5 * 8) & 0xFF)) OR ((hash >> (6 * 8) & 0xFF)=(? >> (6 * 8) & 0xFF)) OR ((hash >> (7 * 8) & 0xFF)=(? >> (7 * 8) & 0xFF))) ORDER BY kind,hash;`)
statement, err := s.db.PrepareContext(context.Background(), `SELECT rowid,hash,kind FROM Hashes WHERE (kind=?) AND (((hash >> (0 * 8) & 0xFF)=(?2 >> (0 * 8) & 0xFF)) OR ((hash >> (1 * 8) & 0xFF)=(?2 >> (1 * 8) & 0xFF)) OR ((hash >> (2 * 8) & 0xFF)=(?2 >> (2 * 8) & 0xFF)) OR ((hash >> (3 * 8) & 0xFF)=(?2 >> (3 * 8) & 0xFF)) OR ((hash >> (4 * 8) & 0xFF)=(?2 >> (4 * 8) & 0xFF)) OR ((hash >> (5 * 8) & 0xFF)=(?2 >> (5 * 8) & 0xFF)) OR ((hash >> (6 * 8) & 0xFF)=(?2 >> (6 * 8) & 0xFF)) OR ((hash >> (7 * 8) & 0xFF)=(?2 >> (7 * 8) & 0xFF)));`)
if err != nil {
return hashes, err
}
rows, err := statement.Query(kind, int64(search_hash), int64(search_hash), int64(search_hash), int64(search_hash), int64(search_hash), int64(search_hash), int64(search_hash), int64(search_hash))
rows, err := statement.Query(kind, int64(search_hash))
if err != nil {
return hashes, err
}
@ -93,6 +94,7 @@ func (s *sqliteStorage) findPartialHashes(max int, search_hash int64, kind goima
}
}
rows.Close()
logTime("Filter partial " + kind.String())
statement, err = s.db.PrepareContext(context.Background(), `SELECT DISTINCT IDS.domain, IDs.id, id_hash.hashid FROM IDs JOIN id_hash ON IDs.rowid = id_hash.idid WHERE (id_hash.hashid in (`+strings.TrimRight(strings.Repeat("?,", len(hashes)), ",")+`)) ORDER BY IDs.domain, IDs.ID;`)
if err != nil {
@ -161,6 +163,7 @@ CREATE INDEX IF NOT EXISTS hash_8_index ON Hashes ((hash >> (7 * 8) & 0xFF));
CREATE INDEX IF NOT EXISTS id_domain ON IDs (domain, id);
PRAGMA shrink_memory;
ANALYZE;
`)
if err != nil {
return err
@ -168,15 +171,38 @@ PRAGMA shrink_memory;
return nil
}
var (
total time.Duration
t = time.Now()
)
func resetTime() {
total = 0
t = time.Now()
}
func logTime(log string) {
n := time.Now()
s := n.Sub(t)
t = n
total += s
fmt.Printf("total: %v, %s: %v\n", total, log, s)
}
func (s *sqliteStorage) GetMatches(hashes []Hash, max int, exactOnly bool) ([]Result, error) {
var foundMatches []Result
var (
foundMatches []Result
)
resetTime()
if exactOnly { // exact matches are also found by partial matches. Don't bother with exact matches so we don't have to de-duplicate
statement, err := s.db.Prepare(`SELECT rowid,hash,kind FROM Hashes WHERE ` + strings.TrimSuffix(strings.Repeat("(hash=? AND kind=?) OR", len(hashes)), "OR") + `ORDER BY kind,hash;`)
if err != nil {
logTime("Fail exact")
return foundMatches, err
}
args := make([]interface{}, 0, len(hashes)*2)
for _, hash := range hashes {
if hash.Hash != 0 {
@ -195,6 +221,7 @@ func (s *sqliteStorage) GetMatches(hashes []Hash, max int, exactOnly bool) ([]Re
if len(foundMatches) > 0 && exactOnly {
return foundMatches, nil
}
logTime("Search Exact")
}
foundHashes := make(map[uint64]struct{})
@ -204,6 +231,7 @@ func (s *sqliteStorage) GetMatches(hashes []Hash, max int, exactOnly bool) ([]Re
if err != nil {
return foundMatches, err
}
logTime("Search partial " + hash.Kind.String())
for _, hash := range hashes {
if _, alreadyMatched := foundHashes[hash.Hash.Hash]; !alreadyMatched {
@ -219,14 +247,18 @@ func (s *sqliteStorage) GetMatches(hashes []Hash, max int, exactOnly bool) ([]Re
}
func (s *sqliteStorage) MapHashes(hash ImageHash) {
insertHashes, err := s.db.Prepare(`
INSERT INTO Hashes (hash,kind) VALUES (?,?) ON CONFLICT DO UPDATE SET hash=?1 RETURNING hashid;
tx, err := s.db.BeginTx(context.Background(), nil)
if err != nil {
panic(err)
}
insertHashes, err := tx.Prepare(`
INSERT INTO Hashes (hash,kind) VALUES (?,?) ON CONFLICT DO UPDATE SET hash=?1 RETURNING hashid
`)
if err != nil {
panic(err)
}
rows, err := s.db.Query(`
INSERT INTO IDs (domain,id) VALUES (?,?) ON CONFLICT DO UPDATE SET domain=?1 RETURNING idid;
rows, err := tx.Query(`
INSERT INTO IDs (domain,id) VALUES (?,?) ON CONFLICT DO UPDATE SET domain=?1 RETURNING idid
`, hash.ID.Domain, hash.ID.ID)
if err != nil {
panic(err)
@ -258,12 +290,19 @@ INSERT INTO IDs (domain,id) VALUES (?,?) ON CONFLICT DO UPDATE SET domain=?1 RET
}
hash_ids = append(hash_ids, id)
}
var ids []any
for _, hash_id := range hash_ids {
_, err = s.db.Exec(`INSERT INTO id_hash (hashid,idid) VALUES (?, ?) ON CONFLICT DO NOTHING;`, hash_id, id_id)
if err != nil {
panic(fmt.Errorf("Failed inserting: %v,%v: %w", hash.ID.Domain, hash.ID.ID, err))
}
ids = append(ids, hash_id, id_id)
}
_, err = tx.Exec(`INSERT INTO id_hash (hashid,idid) VALUES `+strings.TrimSuffix(strings.Repeat("(?, ?),", len(hash_ids)), ",")+` ON CONFLICT DO NOTHING;`, ids...)
if err != nil {
panic(fmt.Errorf("Failed inserting: %v,%v: %w", hash.ID.Domain, hash.ID.ID, err))
}
err = tx.Commit()
if err != nil {
panic(err)
}
insertHashes.Close()
}
func (s *sqliteStorage) DecodeHashes(hashes SavedHashes) error {
@ -272,9 +311,15 @@ func (s *sqliteStorage) DecodeHashes(hashes SavedHashes) error {
return err
}
for domain, sourceHashes := range hashes {
for id, h := range sourceHashes {
s.MapHashes(ImageHash{[]Hash{{h[0], goimagehash.AHash}, {h[1], goimagehash.DHash}, {h[2], goimagehash.PHash}}, ID{domain, id}})
for hashType, sourceHashes := range hashes.Hashes {
hashKind := goimagehash.Kind(hashType + 1)
for hash, idsLocations := range sourceHashes {
for _, id := range hashes.IDs[idsLocations] {
s.MapHashes(ImageHash{
Hashes: []Hash{{hash, hashKind}},
ID: id,
})
}
}
}
err = s.createIndexes()
@ -285,48 +330,27 @@ func (s *sqliteStorage) DecodeHashes(hashes SavedHashes) error {
}
func (s *sqliteStorage) EncodeHashes() (SavedHashes, error) {
hashes := make(SavedHashes)
hashes := SavedHashes{}
conn, err := s.db.Conn(context.Background())
if err != nil {
return hashes, err
}
defer conn.Close()
rows, err := conn.QueryContext(context.Background(), "SELECT DISTINCT (domain) FROM IDs ORDER BY domain;")
rows, err := conn.QueryContext(context.Background(), "SELECT IDs.domain,IDs.id,Hashes.hash,Hashes.kind FROM Hashes JOIN id_hash ON id_hash.hashid = hashes.rowid JOIN IDs ON IDs.rowid = id_hash.idid ORDER BY IDs.ID,Hashes.kind,Hashes.hash;")
if err != nil {
rows.Close()
return hashes, err
}
var (
id ID
hash Hash
)
err = rows.Scan(&id.Domain, &id.ID, &hash.Hash, &hash.Kind)
if err != nil {
return hashes, err
}
sources := make([]string, 0, 10)
for rows.Next() {
var source string
if err = rows.Scan(&source); err != nil {
rows.Close()
return hashes, err
}
sources = append(sources, source)
}
for _, source := range sources {
rows, err = conn.QueryContext(context.Background(), "SELECT IDs.id,Hashes.hash,Hashes.kind FROM Hashes JOIN id_hash ON id_hash.hashid = hashes.rowid JOIN IDs ON IDs.rowid = id_hash.idid WHERE IDs.domain = ? ORDER BY IDs.ID,Hashes.kind,Hashes.hash;", source)
if err != nil {
rows.Close()
return hashes, err
}
var (
id string
hash int64
typ goimagehash.Kind
)
err = rows.Scan(&id, &hash, &typ)
if err != nil {
return hashes, err
}
_, ok := hashes[Source(source)]
if !ok {
hashes[Source(source)] = make(map[string][3]uint64)
}
h := hashes[Source(source)][id]
h[typ-1] = uint64(hash)
hashes[Source(source)][id] = h
}
hashes.InsertHash(hash, id)
return hashes, nil
}
@ -415,16 +439,6 @@ CREATE TABLE IF NOT EXISTS Hashes(
UNIQUE(kind, hash)
);
CREATE INDEX IF NOT EXISTS hash_index ON Hashes (kind, hash);
CREATE INDEX IF NOT EXISTS hash_1_index ON Hashes ((hash >> (0 * 8) & 0xFF));
CREATE INDEX IF NOT EXISTS hash_2_index ON Hashes ((hash >> (1 * 8) & 0xFF));
CREATE INDEX IF NOT EXISTS hash_3_index ON Hashes ((hash >> (2 * 8) & 0xFF));
CREATE INDEX IF NOT EXISTS hash_4_index ON Hashes ((hash >> (3 * 8) & 0xFF));
CREATE INDEX IF NOT EXISTS hash_5_index ON Hashes ((hash >> (4 * 8) & 0xFF));
CREATE INDEX IF NOT EXISTS hash_6_index ON Hashes ((hash >> (5 * 8) & 0xFF));
CREATE INDEX IF NOT EXISTS hash_7_index ON Hashes ((hash >> (6 * 8) & 0xFF));
CREATE INDEX IF NOT EXISTS hash_8_index ON Hashes ((hash >> (7 * 8) & 0xFF));
CREATE TABLE IF NOT EXISTS IDs(
id TEXT NOT NULL,
domain TEXT NOT NULL,
@ -445,6 +459,7 @@ CREATE TABLE IF NOT EXISTS id_hash(
if err != nil {
panic(err)
}
sqlite.createIndexes()
sqlite.db.SetMaxOpenConns(1)
return sqlite, nil
}

8
sqlite_no_cgo.go Normal file
View File

@ -0,0 +1,8 @@
//go:build !cgo
package ch
import (
_ "github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
)

105
vp-tree.go Normal file
View File

@ -0,0 +1,105 @@
package ch
import (
"errors"
"fmt"
"math/bits"
"gitea.narnian.us/lordwelch/goimagehash"
"gonum.org/v1/gonum/spatial/vptree"
)
type VPTree struct {
trees [3]*vptree.Tree
hashes [3][]vptree.Comparable
}
type VPHash struct {
Hash Hash
IDs []ID
}
func (h *VPHash) Distance(c vptree.Comparable) float64 {
h2, ok := c.(*VPHash)
if !ok {
return -99
}
return float64(bits.OnesCount64(h.Hash.Hash ^ h2.Hash.Hash))
}
func (v *VPTree) GetMatches(hashes []Hash, max int, exactOnly bool) ([]Result, error) {
var matches []Result
var exactMatches []Result
fmt.Println(hashes)
for _, hash := range hashes {
results := vptree.NewDistKeeper(float64(max))
hashType := int(hash.Kind) - 1
v.trees[hashType].NearestSet(results, &VPHash{Hash: hash})
for _, result := range results.Heap {
vphash := result.Comparable.(*VPHash)
if result.Dist == 0 {
exactMatches = append(exactMatches, Result{
IDs: ToIDList(vphash.IDs),
Distance: int(result.Dist),
Hash: vphash.Hash,
})
} else {
matches = append(matches, Result{
IDs: ToIDList(vphash.IDs),
Distance: int(result.Dist),
Hash: vphash.Hash,
})
}
}
}
if len(exactMatches) > 0 && exactOnly {
return exactMatches, nil
}
matches = append(exactMatches[:len(exactMatches):len(exactMatches)], matches...)
return matches, nil
}
func (v *VPTree) MapHashes(ImageHash) {
panic("Not Implemented")
}
func (v *VPTree) DecodeHashes(hashes SavedHashes) error {
var err error
for hashType, sourceHashes := range hashes.Hashes {
for hash, idsLocation := range sourceHashes {
var (
hashKind = goimagehash.Kind(hashType + 1)
)
hash := &VPHash{Hash{hash, hashKind}, hashes.IDs[idsLocation]}
v.hashes[hashType] = append(v.hashes[hashType], hash)
}
}
for hashType := range 3 {
v.trees[hashType], err = vptree.New(v.hashes[hashType], 3, nil)
if err != nil {
return err
}
}
return nil
}
func (v *VPTree) EncodeHashes() (SavedHashes, error) {
return SavedHashes{}, errors.New("Not Implemented")
}
func (v *VPTree) AssociateIDs(newIDs []NewIDs) {
panic("Not Implemented")
}
func (v *VPTree) GetIDs(id ID) IDList {
return nil
}
func NewVPStorage() (HashStorage, error) {
return &VPTree{
hashes: [3][]vptree.Comparable{
make([]vptree.Comparable, 0, 1_000_000),
make([]vptree.Comparable, 0, 1_000_000),
make([]vptree.Comparable, 0, 1_000_000),
},
}, nil
}