Add CLI flag

Timmy Welch 2024-09-02 15:35:36 -07:00
parent 1955444dcf
commit b1de95021a
2 changed files with 106 additions and 45 deletions

View File

@@ -78,9 +78,6 @@ func (f Format) String() string {
	return "Unknown"
}

-type Encoder func(any) ([]byte, error)
-type Decoder func([]byte, interface{}) error
-
func (f *Format) Set(s string) error {
	if format, known := formatValues[strings.ToLower(s)]; known {
		*f = format
@@ -90,6 +87,45 @@ func (f *Format) Set(s string) error {
	return nil
}
+
+type Storage int
+
+const (
+	Map = iota + 1
+	Sqlite
+	Sqlite3
+)
+
+var storageNames = map[Storage]string{
+	Map: "map",
+	Sqlite: "sqlite",
+	Sqlite3: "sqlite3",
+}
+
+var storageValues = map[string]Storage{
+	"map": Map,
+	"sqlite": Sqlite,
+	"sqlite3": Sqlite3,
+}
+
+func (f Storage) String() string {
+	if name, known := storageNames[f]; known {
+		return name
+	}
+	return "Unknown"
+}
+
+func (f *Storage) Set(s string) error {
+	if storage, known := storageValues[strings.ToLower(s)]; known {
+		*f = storage
+	} else {
return fmt.Errorf("Unknown storage type: %d", f)
+	}
+	return nil
+}
+
+type Encoder func(any) ([]byte, error)
+type Decoder func([]byte, interface{}) error

type Opts struct {
	cpuprofile string
	coverPath string
@@ -98,10 +134,11 @@ type Opts struct {
	saveEmbeddedHashes bool
	format Format
	hashesPath string
+	storageType Storage
}

func main() {
-	opts := Opts{format: Msgpack} // flag is weird
+	opts := Opts{format: Msgpack, storageType: Map} // flag is weird
	go func() {
		log.Println(http.ListenAndServe("localhost:6060", nil))
	}()
@@ -113,6 +150,7 @@ func main() {
	flag.BoolVar(&opts.saveEmbeddedHashes, "save-embedded-hashes", false, "Save hashes even if we loaded the embedded hashes")
	flag.StringVar(&opts.hashesPath, "hashes", "hashes.gz", "Path to optionally gziped hashes in msgpack or json format. You must disable embedded hashes to use this option")
	flag.Var(&opts.format, "save-format", "Specify the format to export hashes to (json, msgpack)")
+	flag.Var(&opts.storageType, "storage-type", "Specify the storage type used internally to search hashes (sqlite,sqlite3,map)")
	flag.Parse()

	if opts.coverPath != "" {
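For context on how the new -storage-type flag is parsed: *Storage satisfies the standard library's flag.Value interface (String() string and Set(string) error), so flag.Var hands the raw argument to Set during flag.Parse. A minimal, self-contained sketch of the same pattern follows; the Mode type and program are hypothetical illustrations, not part of this commit:

package main

import (
	"flag"
	"fmt"
	"strings"
)

// Mode mirrors the Storage pattern above: an int-backed enum whose pointer
// implements flag.Value, so flag.Parse can fill it from the command line.
type Mode int

const (
	Fast Mode = iota + 1
	Safe
)

var modeValues = map[string]Mode{"fast": Fast, "safe": Safe}

func (m Mode) String() string {
	for name, v := range modeValues {
		if v == m {
			return name
		}
	}
	return "Unknown"
}

func (m *Mode) Set(s string) error {
	if v, known := modeValues[strings.ToLower(s)]; known {
		*m = v
		return nil
	}
	return fmt.Errorf("Unknown mode: %s", s)
}

func main() {
	mode := Fast // default set before registration, like Opts{storageType: Map}
	flag.Var(&mode, "mode", "fast or safe")
	flag.Parse() // "-mode safe" on the command line calls (*Mode).Set("safe")
	fmt.Println("mode:", mode)
}

flag.Var records value.String() at registration time as the displayed default, which is likely why the default storage type is set in the Opts literal before the flags are defined (the "// flag is weird" comment above).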
@@ -122,7 +160,7 @@ func main() {
		}
	}

	opts.sqlitePath, _ = filepath.Abs(opts.sqlitePath)
-	pretty.Logln(opts)
+	log.Println(pretty.Formatter(opts))
	startServer(opts)
}
@@ -515,10 +553,10 @@ func (s *Server) HashLocalImages(opts Opts) {
			log.Println("Recieved quit")
		}
		err := s.httpServer.Shutdown(context.TODO())
-		fmt.Println("Err:", err)
+		log.Println("Err:", err)
		return
	}
-	fmt.Println("Hashing covers at ", opts.coverPath)
+	log.Println("Hashing covers at ", opts.coverPath)
	start := time.Now()
	err := filepath.WalkDir(opts.coverPath, func(path string, d fs.DirEntry, err error) error {
		if err != nil {
@@ -544,7 +582,7 @@ func (s *Server) HashLocalImages(opts Opts) {
		return nil
	})
	elapsed := time.Since(start)
-	fmt.Println("Err:", err, "local hashing took", elapsed)
+	log.Println("Err:", err, "local hashing took", elapsed)

	sig := <-s.signalQueue
	if !alreadyQuit {
@@ -555,6 +593,18 @@ func (s *Server) HashLocalImages(opts Opts) {
	}()
}

+func initializeStorage(opts Opts) (ch.HashStorage, error) {
+	switch opts.storageType {
+	case Map:
+		return ch.NewMapStorage()
+	case Sqlite:
+		return ch.NewSqliteStorage("sqlite", opts.sqlitePath)
+	case Sqlite3:
+		return ch.NewSqliteStorage("sqlite3", opts.sqlitePath)
+	}
+	return nil, errors.New("Unknown storage type provided")
+}
+
func startServer(opts Opts) {
	if opts.cpuprofile != "" {
		f, err := os.Create(opts.cpuprofile)
@@ -584,32 +634,32 @@ func startServer(opts Opts) {
	}
	Notify(server.signalQueue)
	var err error
-	fmt.Println("init hashes")
-	server.hashes, err = ch.NewMapStorage()
+	log.Println("init hashes")
+	server.hashes, err = initializeStorage(opts)
	if err != nil {
		panic(err)
	}

-	fmt.Println("init handlers")
+	log.Println("init handlers")
	server.setupAppHandlers()

-	fmt.Println("init hashers")
+	log.Println("init hashers")
	rwg := sync.WaitGroup{}
	for i := range 10 {
		rwg.Add(1)
-		go server.reader(i, func() { fmt.Println("Reader completed"); rwg.Done() })
+		go server.reader(i, func() { log.Println("Reader completed"); rwg.Done() })
	}

	hwg := sync.WaitGroup{}
	for i := range 10 {
		hwg.Add(1)
-		go server.hasher(i, func() { fmt.Println("Hasher completed"); hwg.Done() })
+		go server.hasher(i, func() { log.Println("Hasher completed"); hwg.Done() })
	}

-	fmt.Println("init mapper")
+	log.Println("init mapper")
	mwg := sync.WaitGroup{}
	mwg.Add(1)
-	go server.mapper(func() { fmt.Println("Mapper completed"); mwg.Done() })
+	go server.mapper(func() { log.Println("Mapper completed"); mwg.Done() })

	if opts.loadEmbeddedHashes && len(ch.Hashes) != 0 {
		var err error
@@ -658,32 +708,32 @@ func startServer(opts Opts) {
			fmt.Printf("Loaded hashes from %q %s\n", opts.hashesPath, format)
		} else {
			if errors.Is(err, os.ErrNotExist) {
-				fmt.Println("No saved hashes to load")
+				log.Println("No saved hashes to load")
			} else {
-				fmt.Println("Unable to load saved hashes", err)
+				log.Println("Unable to load saved hashes", err)
			}
		}
	}

	server.HashLocalImages(opts)

-	fmt.Println("Listening on ", server.httpServer.Addr)
+	log.Println("Listening on ", server.httpServer.Addr)
	err = server.httpServer.ListenAndServe()
	if err != nil {
-		fmt.Println(err)
+		log.Println(err)
	}
	close(server.readerQueue)
-	fmt.Println("waiting on readers")
+	log.Println("waiting on readers")
	rwg.Wait()
	for range server.readerQueue {
	}
	close(server.hashingQueue)
-	fmt.Println("waiting on hashers")
+	log.Println("waiting on hashers")
	hwg.Wait()
	for range server.hashingQueue {
	}
	close(server.mappingQueue)
-	fmt.Println("waiting on mapper")
+	log.Println("waiting on mapper")
	mwg.Wait()
	for range server.mappingQueue {
	}
@@ -698,14 +748,14 @@ func startServer(opts Opts) {
			gzw := gzip.NewWriter(f)
			_, err := gzw.Write(encodedHashes)
			if err != nil {
-				fmt.Println("Failed to write hashes", err)
+				log.Println("Failed to write hashes", err)
			} else {
-				fmt.Println("Successfully saved hashes")
+				log.Println("Successfully saved hashes")
			}
			gzw.Close()
			f.Close()
		} else {
-			fmt.Println("Unabled to save hashes", err)
+			log.Println("Unabled to save hashes", err)
		}
	} else {
		fmt.Printf("Unable to encode hashes as %v: %v", opts.format, err)

View File

@@ -220,35 +220,46 @@ func (s *sqliteStorage) GetMatches(hashes []Hash, max int, exactOnly bool) ([]Re
func (s *sqliteStorage) MapHashes(hash ImageHash) {
	insertHashes, err := s.db.Prepare(`
-INSERT INTO Hashes (hash,kind) VALUES (?,?) ON CONFLICT DO NOTHING;
+INSERT INTO Hashes (hash,kind) VALUES (?,?) ON CONFLICT DO UPDATE SET hash=?1 RETURNING hashid;
`)
	if err != nil {
		panic(err)
	}

-	IDInsertResult, err := s.db.Exec(`
-INSERT INTO IDs (domain,id) VALUES (?,?) ON CONFLICT DO NOTHING;
-`, hash.ID.Domain, hash.ID.Domain)
+	rows, err := s.db.Query(`
+INSERT INTO IDs (domain,id) VALUES (?,?) ON CONFLICT DO UPDATE SET domain=?1 RETURNING idid;
+`, hash.ID.Domain, hash.ID.ID)
	if err != nil {
		panic(err)
	}
-	id_id, err := IDInsertResult.LastInsertId()
+	if !rows.Next() {
+		panic("Unable to insert IDs")
+	}
+	var id_id int64
+	err = rows.Scan(&id_id)
	if err != nil {
		panic(err)
	}
+	rows.Close()
+
	hash_ids := []int64{}
	for _, hash := range hash.Hashes {
-		hashInsertResult, err := insertHashes.Exec(int64(hash.Hash), hash.Kind)
+		rows, err := insertHashes.Query(int64(hash.Hash), hash.Kind)
		if err != nil {
			panic(err)
		}
-		id, err := hashInsertResult.LastInsertId()
+		if !rows.Next() {
panic("Unable to insert IDs")
+		}
+		var id int64
+		err = rows.Scan(&id)
+		rows.Close()
		if err != nil {
			panic(err)
		}
		hash_ids = append(hash_ids, id)
	}
	for _, hash_id := range hash_ids {
-		_, err = s.db.Exec(`INSERT INTO id_hash VALUES (?, ?) ON CONFLICT DO NOTHING;`, hash_id, id_id)
+		_, err = s.db.Exec(`INSERT INTO id_hash (hashid,idid) VALUES (?, ?) ON CONFLICT DO NOTHING;`, hash_id, id_id)
		if err != nil {
			panic(fmt.Errorf("Failed inserting: %v,%v: %w", hash.ID.Domain, hash.ID.ID, err))
		}
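The storage change above swaps Exec + LastInsertId for an upsert that always reports the row id: ON CONFLICT DO UPDATE ... RETURNING hands back the existing row's id even when nothing new is inserted, which LastInsertId cannot do after a DO NOTHING conflict. Below is a hedged sketch of the same pattern using database/sql's QueryRow; the upsertID helper and package name are illustrative assumptions, and the commit itself uses Query/Next/Scan/Close instead:

package storage

import "database/sql"

// upsertID inserts a (domain, id) pair into IDs and returns its idid whether
// or not the pair already existed. Table and column names follow the diff
// above; everything else here is assumed for illustration.
func upsertID(db *sql.DB, domain, id string) (int64, error) {
	var idid int64
	// DO UPDATE (rather than DO NOTHING) guarantees the conflicting row is
	// still returned, so RETURNING always yields an id. RETURNING needs
	// SQLite 3.35 or newer.
	err := db.QueryRow(`
INSERT INTO IDs (domain,id) VALUES (?,?)
ON CONFLICT DO UPDATE SET domain=?1
RETURNING idid;
`, domain, id).Scan(&idid)
	return idid, err
}

QueryRow folds the Next/Scan/Close bookkeeping into one call and returns sql.ErrNoRows when the statement produces no row, which may be a simpler shape for MapHashes than panicking on !rows.Next().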