From 4922ceb678225bd26ba02cf2c9f9834039e893d8 Mon Sep 17 00:00:00 2001 From: Timmy Welch Date: Wed, 16 Oct 2024 17:56:19 -0700 Subject: [PATCH] Fix locking for map storage --- BasicMap.go | 42 +++++++++++++++++++++++++++--------------- map.go | 2 +- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/BasicMap.go b/BasicMap.go index b2e8b9d..04981f3 100644 --- a/BasicMap.go +++ b/BasicMap.go @@ -12,7 +12,7 @@ import ( ) type basicMapStorage struct { - hashMutex sync.RWMutex + hashMutex *sync.RWMutex ids map[ID]*[]ID hashes [3][]structHash @@ -26,6 +26,8 @@ type structHash struct { func (b *basicMapStorage) Atleast(hashKind goimagehash.Kind, maxDistance int, searchHash uint64) []Result { hashType := int(hashKind) - 1 matchingHashes := make([]Result, 0, 100) // hope that we don't need all of them + b.hashMutex.RLock() + defer b.hashMutex.RUnlock() for _, storedHash := range b.hashes[hashType] { distance := bits.OnesCount64(searchHash ^ storedHash.hash) if distance <= maxDistance { @@ -36,14 +38,13 @@ func (b *basicMapStorage) Atleast(hashKind goimagehash.Kind, maxDistance int, se } func (b *basicMapStorage) GetMatches(hashes []Hash, max int, exactOnly bool) ([]Result, error) { var foundMatches []Result - b.hashMutex.RLock() - defer b.hashMutex.RUnlock() resetTime() defer logTime(fmt.Sprintf("Search Complete: max: %v ExactOnly: %v", max, exactOnly)) if exactOnly { // exact matches are also found by partial matches. Don't bother with exact matches so we don't have to de-duplicate for _, hash := range hashes { hashType := int(hash.Kind) - 1 + b.hashMutex.RLock() index, hashFound := b.findHash(hashType, hash.Hash) if hashFound { foundMatches = append(foundMatches, Result{ @@ -52,6 +53,7 @@ func (b *basicMapStorage) GetMatches(hashes []Hash, max int, exactOnly bool) ([] IDs: ToIDList(*b.hashes[hashType][index].ids), }) } + b.hashMutex.RUnlock() } logTime("Search Exact") @@ -75,21 +77,27 @@ func (b *basicMapStorage) GetMatches(hashes []Hash, max int, exactOnly bool) ([] } fmt.Println("Total partial hashes tested:", totalPartialHashes, len(foundHashes)) - go b.printSizes() return foundMatches, nil } +// findHash must have a read lock before using func (b *basicMapStorage) findHash(hashType int, hash uint64) (int, bool) { return slices.BinarySearchFunc(b.hashes[hashType], hash, func(e structHash, t uint64) int { return cmp.Compare(e.hash, t) }) } -func (b *basicMapStorage) InsertHash(hashType int, hash uint64, ids *[]ID) { + +// insertHash will take a write lock if the hash is not found +func (b *basicMapStorage) insertHash(hashType int, hash uint64, ids *[]ID) { + b.hashMutex.RLock() index, hashFound := b.findHash(hashType, hash) + b.hashMutex.RUnlock() if hashFound { return } + b.hashMutex.Lock() b.hashes[hashType] = slices.Insert(b.hashes[hashType], index, structHash{hash, ids}) + b.hashMutex.Unlock() } func (b *basicMapStorage) MapHashes(hash ImageHash) { @@ -97,16 +105,21 @@ func (b *basicMapStorage) MapHashes(hash ImageHash) { var ( hashType = int(ih.Kind) - 1 ) + b.hashMutex.RLock() ids, ok := b.ids[hash.ID] + b.hashMutex.RUnlock() if !ok { + b.hashMutex.Lock() ids = &[]ID{hash.ID} b.ids[hash.ID] = ids + b.hashMutex.Unlock() } - b.InsertHash(hashType, ih.Hash, ids) + b.insertHash(hashType, ih.Hash, ids) } } +// DecodeHashes should already have a lock func (b *basicMapStorage) DecodeHashes(hashes SavedHashes) error { for hashType, sourceHashes := range hashes.Hashes { b.hashes[hashType] = make([]structHash, len(sourceHashes)) @@ -122,17 +135,10 @@ func (b *basicMapStorage) DecodeHashes(hashes SavedHashes) error { return cmp.Compare(a.hash, b.hash) }) } - b.printSizes() return nil } -func (b *basicMapStorage) printSizes() { - // fmt.Println("Size of", "hashes:", size.Of(b.hashes)/1024/1024, "MB") - // fmt.Println("Size of", "ids:", size.Of(b.ids)/1024/1024, "MB") - // fmt.Println("Size of", "basicMapStorage:", size.Of(b)/1024/1024, "MB") - -} - +// EncodeHashes should already have a lock func (b *basicMapStorage) EncodeHashes() (SavedHashes, error) { hashes := SavedHashes{ Hashes: [3]map[uint64]int{ @@ -161,17 +167,23 @@ func (b *basicMapStorage) EncodeHashes() (SavedHashes, error) { func (b *basicMapStorage) AssociateIDs(newids []NewIDs) error { for _, newid := range newids { + b.hashMutex.RLock() ids, found := b.ids[newid.OldID] + b.hashMutex.RUnlock() if !found { msg := "No IDs belonging to " + string(newid.OldID.Domain) + " exist on this server" return errors.New(msg) } + b.hashMutex.Lock() *ids = InsertID(*ids, newid.NewID) + b.hashMutex.Unlock() } return nil } func (b *basicMapStorage) GetIDs(id ID) IDList { + b.hashMutex.RLock() + defer b.hashMutex.RUnlock() ids, found := b.ids[id] if !found { return nil @@ -181,7 +193,7 @@ func (b *basicMapStorage) GetIDs(id ID) IDList { func NewBasicMapStorage() (HashStorage, error) { storage := &basicMapStorage{ - hashMutex: sync.RWMutex{}, + hashMutex: &sync.RWMutex{}, ids: make(map[ID]*[]ID), hashes: [3][]structHash{}, } diff --git a/map.go b/map.go index 09720ff..031273d 100644 --- a/map.go +++ b/map.go @@ -106,7 +106,7 @@ func (m *MapStorage) printSizes() { func NewMapStorage() (HashStorage, error) { storage := &MapStorage{ basicMapStorage: basicMapStorage{ - hashMutex: sync.RWMutex{}, + hashMutex: &sync.RWMutex{}, hashes: [3][]structHash{ []structHash{}, []structHash{},