From b1de95021a81eb4101d8a674946e081920794bbb Mon Sep 17 00:00:00 2001
From: Timmy Welch
Date: Mon, 2 Sep 2024 15:35:36 -0700
Subject: [PATCH] Add -storage-type cli flag

Add a -storage-type flag for choosing the storage backend used to
search hashes (map, sqlite or sqlite3), switch server output from
fmt.Println to log.Println, and rework sqliteStorage.MapHashes to read
row ids back via ON CONFLICT ... RETURNING instead of LastInsertId.
---
 cmd/comic-hasher/main.go | 102 +++++++++++++++++++++++++++++----------
 sqlite.go                |  49 +++++++++++--------
 2 files changed, 106 insertions(+), 45 deletions(-)

diff --git a/cmd/comic-hasher/main.go b/cmd/comic-hasher/main.go
index a5a4c2c..7263d84 100644
--- a/cmd/comic-hasher/main.go
+++ b/cmd/comic-hasher/main.go
@@ -78,9 +78,6 @@ func (f Format) String() string {
 	return "Unknown"
 }
 
-type Encoder func(any) ([]byte, error)
-type Decoder func([]byte, interface{}) error
-
 func (f *Format) Set(s string) error {
 	if format, known := formatValues[strings.ToLower(s)]; known {
 		*f = format
@@ -90,6 +87,45 @@ func (f *Format) Set(s string) error {
 	return nil
 }
 
+type Storage int
+
+const (
+	Map Storage = iota + 1
+	Sqlite
+	Sqlite3
+)
+
+var storageNames = map[Storage]string{
+	Map:     "map",
+	Sqlite:  "sqlite",
+	Sqlite3: "sqlite3",
+}
+
+var storageValues = map[string]Storage{
+	"map":     Map,
+	"sqlite":  Sqlite,
+	"sqlite3": Sqlite3,
+}
+
+func (f Storage) String() string {
+	if name, known := storageNames[f]; known {
+		return name
+	}
+	return "Unknown"
+}
+
+func (f *Storage) Set(s string) error {
+	if storage, known := storageValues[strings.ToLower(s)]; known {
+		*f = storage
+	} else {
+		return fmt.Errorf("Unknown storage type: %s", s)
+	}
+	return nil
+}
+
+type Encoder func(any) ([]byte, error)
+type Decoder func([]byte, interface{}) error
+
 type Opts struct {
 	cpuprofile         string
 	coverPath          string
@@ -98,10 +134,11 @@ type Opts struct {
 	saveEmbeddedHashes bool
 	format             Format
 	hashesPath         string
+	storageType        Storage
 }
 
 func main() {
-	opts := Opts{format: Msgpack} // flag is weird
+	opts := Opts{format: Msgpack, storageType: Map} // flag is weird
 	go func() {
 		log.Println(http.ListenAndServe("localhost:6060", nil))
 	}()
@@ -113,6 +150,7 @@ func main() {
 	flag.BoolVar(&opts.saveEmbeddedHashes, "save-embedded-hashes", false, "Save hashes even if we loaded the embedded hashes")
 	flag.StringVar(&opts.hashesPath, "hashes", "hashes.gz", "Path to optionally gziped hashes in msgpack or json format. You must disable embedded hashes to use this option")
 	flag.Var(&opts.format, "save-format", "Specify the format to export hashes to (json, msgpack)")
+	flag.Var(&opts.storageType, "storage-type", "Specify the storage type used internally to search hashes (sqlite,sqlite3,map)")
 	flag.Parse()
 
 	if opts.coverPath != "" {
@@ -122,7 +160,7 @@
 		}
 	}
 	opts.sqlitePath, _ = filepath.Abs(opts.sqlitePath)
-	pretty.Logln(opts)
+	log.Println(pretty.Formatter(opts))
 	startServer(opts)
 }
 
@@ -515,10 +553,10 @@ func (s *Server) HashLocalImages(opts Opts) {
 				log.Println("Recieved quit")
 			}
 			err := s.httpServer.Shutdown(context.TODO())
-			fmt.Println("Err:", err)
+			log.Println("Err:", err)
 			return
 		}
-		fmt.Println("Hashing covers at ", opts.coverPath)
+		log.Println("Hashing covers at ", opts.coverPath)
 		start := time.Now()
 		err := filepath.WalkDir(opts.coverPath, func(path string, d fs.DirEntry, err error) error {
 			if err != nil {
@@ -544,7 +582,7 @@ func (s *Server) HashLocalImages(opts Opts) {
 			return nil
 		})
 		elapsed := time.Since(start)
-		fmt.Println("Err:", err, "local hashing took", elapsed)
+		log.Println("Err:", err, "local hashing took", elapsed)
 
 		sig := <-s.signalQueue
 		if !alreadyQuit {
@@ -555,6 +593,18 @@
 	}()
 }
 
+func initializeStorage(opts Opts) (ch.HashStorage, error) {
+	switch opts.storageType {
+	case Map:
+		return ch.NewMapStorage()
+	case Sqlite:
+		return ch.NewSqliteStorage("sqlite", opts.sqlitePath)
+	case Sqlite3:
+		return ch.NewSqliteStorage("sqlite3", opts.sqlitePath)
+	}
+	return nil, errors.New("Unknown storage type provided")
+}
+
 func startServer(opts Opts) {
 	if opts.cpuprofile != "" {
 		f, err := os.Create(opts.cpuprofile)
@@ -584,32 +634,32 @@ func startServer(opts Opts) {
 	}
 	Notify(server.signalQueue)
 	var err error
-	fmt.Println("init hashes")
-	server.hashes, err = ch.NewMapStorage()
+	log.Println("init hashes")
+	server.hashes, err = initializeStorage(opts)
 	if err != nil {
 		panic(err)
 	}
 
-	fmt.Println("init handlers")
+	log.Println("init handlers")
 	server.setupAppHandlers()
 
-	fmt.Println("init hashers")
+	log.Println("init hashers")
 	rwg := sync.WaitGroup{}
 	for i := range 10 {
 		rwg.Add(1)
-		go server.reader(i, func() { fmt.Println("Reader completed"); rwg.Done() })
+		go server.reader(i, func() { log.Println("Reader completed"); rwg.Done() })
 	}
 
 	hwg := sync.WaitGroup{}
 	for i := range 10 {
 		hwg.Add(1)
-		go server.hasher(i, func() { fmt.Println("Hasher completed"); hwg.Done() })
+		go server.hasher(i, func() { log.Println("Hasher completed"); hwg.Done() })
 	}
 
-	fmt.Println("init mapper")
+	log.Println("init mapper")
 	mwg := sync.WaitGroup{}
 	mwg.Add(1)
-	go server.mapper(func() { fmt.Println("Mapper completed"); mwg.Done() })
+	go server.mapper(func() { log.Println("Mapper completed"); mwg.Done() })
 
 	if opts.loadEmbeddedHashes && len(ch.Hashes) != 0 {
 		var err error
@@ -658,32 +708,32 @@ func startServer(opts Opts) {
 			fmt.Printf("Loaded hashes from %q %s\n", opts.hashesPath, format)
 		} else {
 			if errors.Is(err, os.ErrNotExist) {
-				fmt.Println("No saved hashes to load")
+				log.Println("No saved hashes to load")
 			} else {
-				fmt.Println("Unable to load saved hashes", err)
+				log.Println("Unable to load saved hashes", err)
 			}
 		}
 	}
 
 	server.HashLocalImages(opts)
 
-	fmt.Println("Listening on ", server.httpServer.Addr)
+	log.Println("Listening on ", server.httpServer.Addr)
 	err = server.httpServer.ListenAndServe()
 	if err != nil {
-		fmt.Println(err)
+		log.Println(err)
 	}
 	close(server.readerQueue)
-	fmt.Println("waiting on readers")
+	log.Println("waiting on readers")
 	rwg.Wait()
 	for range server.readerQueue {
 	}
 	close(server.hashingQueue)
-	fmt.Println("waiting on hashers")
+	log.Println("waiting on hashers")
 	hwg.Wait()
 	for range server.hashingQueue {
 	}
 	close(server.mappingQueue)
-	fmt.Println("waiting on mapper")
+	log.Println("waiting on mapper")
 	mwg.Wait()
 	for range server.mappingQueue {
 	}
@@ -698,14 +748,14 @@ func startServer(opts Opts) {
 			gzw := gzip.NewWriter(f)
 			_, err := gzw.Write(encodedHashes)
 			if err != nil {
-				fmt.Println("Failed to write hashes", err)
+				log.Println("Failed to write hashes", err)
 			} else {
-				fmt.Println("Successfully saved hashes")
+				log.Println("Successfully saved hashes")
 			}
 			gzw.Close()
 			f.Close()
 		} else {
-			fmt.Println("Unabled to save hashes", err)
+			log.Println("Unable to save hashes", err)
 		}
 	} else {
 		fmt.Printf("Unable to encode hashes as %v: %v", opts.format, err)
diff --git a/sqlite.go b/sqlite.go
index 44c7cb0..d2a2be5 100644
--- a/sqlite.go
+++ b/sqlite.go
@@ -128,18 +128,18 @@ func (s *sqliteStorage) findPartialHashes(max int, search_hash int64, kind goima
 
 func (s *sqliteStorage) dropIndexes() error {
 	_, err := s.db.Exec(`
-DROP INDEX IF EXISTS hash_index;
-DROP INDEX IF EXISTS hash_1_index;
-DROP INDEX IF EXISTS hash_2_index;
-DROP INDEX IF EXISTS hash_3_index;
-DROP INDEX IF EXISTS hash_4_index;
-DROP INDEX IF EXISTS hash_5_index;
-DROP INDEX IF EXISTS hash_6_index;
-DROP INDEX IF EXISTS hash_7_index;
-DROP INDEX IF EXISTS hash_8_index;
+	DROP INDEX IF EXISTS hash_index;
+	DROP INDEX IF EXISTS hash_1_index;
+	DROP INDEX IF EXISTS hash_2_index;
+	DROP INDEX IF EXISTS hash_3_index;
+	DROP INDEX IF EXISTS hash_4_index;
+	DROP INDEX IF EXISTS hash_5_index;
+	DROP INDEX IF EXISTS hash_6_index;
+	DROP INDEX IF EXISTS hash_7_index;
+	DROP INDEX IF EXISTS hash_8_index;
 
-DROP INDEX IF EXISTS id_domain;
-`)
+	DROP INDEX IF EXISTS id_domain;
+	`)
 	if err != nil {
 		return err
 	}
@@ -220,35 +220,46 @@ func (s *sqliteStorage) GetMatches(hashes []Hash, max int, exactOnly bool) ([]Re
 
 func (s *sqliteStorage) MapHashes(hash ImageHash) {
 	insertHashes, err := s.db.Prepare(`
-INSERT INTO Hashes (hash,kind) VALUES (?,?) ON CONFLICT DO NOTHING;
+INSERT INTO Hashes (hash,kind) VALUES (?,?) ON CONFLICT DO UPDATE SET hash=?1 RETURNING hashid;
 `)
 	if err != nil {
 		panic(err)
 	}
-	IDInsertResult, err := s.db.Exec(`
-INSERT INTO IDs (domain,id) VALUES (?,?) ON CONFLICT DO NOTHING;
-`, hash.ID.Domain, hash.ID.Domain)
+	rows, err := s.db.Query(`
+INSERT INTO IDs (domain,id) VALUES (?,?) ON CONFLICT DO UPDATE SET domain=?1 RETURNING idid;
+`, hash.ID.Domain, hash.ID.ID)
 	if err != nil {
 		panic(err)
 	}
-	id_id, err := IDInsertResult.LastInsertId()
+	if !rows.Next() {
+		panic("Unable to insert IDs")
+	}
+	var id_id int64
+	err = rows.Scan(&id_id)
 	if err != nil {
 		panic(err)
 	}
+	rows.Close()
 	hash_ids := []int64{}
 	for _, hash := range hash.Hashes {
-		hashInsertResult, err := insertHashes.Exec(int64(hash.Hash), hash.Kind)
+		rows, err := insertHashes.Query(int64(hash.Hash), hash.Kind)
 		if err != nil {
 			panic(err)
 		}
-		id, err := hashInsertResult.LastInsertId()
+
+		if !rows.Next() {
+			panic("Unable to insert hash")
+		}
+		var id int64
+		err = rows.Scan(&id)
+		rows.Close()
 		if err != nil {
 			panic(err)
 		}
 		hash_ids = append(hash_ids, id)
 	}
 	for _, hash_id := range hash_ids {
-		_, err = s.db.Exec(`INSERT INTO id_hash VALUES (?, ?) ON CONFLICT DO NOTHING;`, hash_id, id_id)
+		_, err = s.db.Exec(`INSERT INTO id_hash (hashid,idid) VALUES (?, ?) ON CONFLICT DO NOTHING;`, hash_id, id_id)
 		if err != nil {
 			panic(fmt.Errorf("Failed inserting: %v,%v: %w", hash.ID.Domain, hash.ID.ID, err))
 		}