Add cli flag

This commit is contained in:
Timmy Welch 2024-09-02 15:35:36 -07:00
parent 1955444dcf
commit b1de95021a
2 changed files with 106 additions and 45 deletions

View File

@ -78,9 +78,6 @@ func (f Format) String() string {
return "Unknown"
}
type Encoder func(any) ([]byte, error)
type Decoder func([]byte, interface{}) error
func (f *Format) Set(s string) error {
if format, known := formatValues[strings.ToLower(s)]; known {
*f = format
@ -90,6 +87,45 @@ func (f *Format) Set(s string) error {
return nil
}
type Storage int
const (
Map = iota + 1
Sqlite
Sqlite3
)
var storageNames = map[Storage]string{
Map: "map",
Sqlite: "sqlite",
Sqlite3: "sqlite3",
}
var storageValues = map[string]Storage{
"map": Map,
"sqlite": Sqlite,
"sqlite3": Sqlite3,
}
func (f Storage) String() string {
if name, known := storageNames[f]; known {
return name
}
return "Unknown"
}
func (f *Storage) Set(s string) error {
if storage, known := storageValues[strings.ToLower(s)]; known {
*f = storage
} else {
return fmt.Errorf("Unknown storage type: %d", f)
}
return nil
}
type Encoder func(any) ([]byte, error)
type Decoder func([]byte, interface{}) error
type Opts struct {
cpuprofile string
coverPath string
@ -98,10 +134,11 @@ type Opts struct {
saveEmbeddedHashes bool
format Format
hashesPath string
storageType Storage
}
func main() {
opts := Opts{format: Msgpack} // flag is weird
opts := Opts{format: Msgpack, storageType: Map} // flag is weird
go func() {
log.Println(http.ListenAndServe("localhost:6060", nil))
}()
@ -113,6 +150,7 @@ func main() {
flag.BoolVar(&opts.saveEmbeddedHashes, "save-embedded-hashes", false, "Save hashes even if we loaded the embedded hashes")
flag.StringVar(&opts.hashesPath, "hashes", "hashes.gz", "Path to optionally gziped hashes in msgpack or json format. You must disable embedded hashes to use this option")
flag.Var(&opts.format, "save-format", "Specify the format to export hashes to (json, msgpack)")
flag.Var(&opts.storageType, "storage-type", "Specify the storage type used internally to search hashes (sqlite,sqlite3,map)")
flag.Parse()
if opts.coverPath != "" {
@ -122,7 +160,7 @@ func main() {
}
}
opts.sqlitePath, _ = filepath.Abs(opts.sqlitePath)
pretty.Logln(opts)
log.Println(pretty.Formatter(opts))
startServer(opts)
}
@ -515,10 +553,10 @@ func (s *Server) HashLocalImages(opts Opts) {
log.Println("Recieved quit")
}
err := s.httpServer.Shutdown(context.TODO())
fmt.Println("Err:", err)
log.Println("Err:", err)
return
}
fmt.Println("Hashing covers at ", opts.coverPath)
log.Println("Hashing covers at ", opts.coverPath)
start := time.Now()
err := filepath.WalkDir(opts.coverPath, func(path string, d fs.DirEntry, err error) error {
if err != nil {
@ -544,7 +582,7 @@ func (s *Server) HashLocalImages(opts Opts) {
return nil
})
elapsed := time.Since(start)
fmt.Println("Err:", err, "local hashing took", elapsed)
log.Println("Err:", err, "local hashing took", elapsed)
sig := <-s.signalQueue
if !alreadyQuit {
@ -555,6 +593,18 @@ func (s *Server) HashLocalImages(opts Opts) {
}()
}
func initializeStorage(opts Opts) (ch.HashStorage, error) {
switch opts.storageType {
case Map:
return ch.NewMapStorage()
case Sqlite:
return ch.NewSqliteStorage("sqlite", opts.sqlitePath)
case Sqlite3:
return ch.NewSqliteStorage("sqlite3", opts.sqlitePath)
}
return nil, errors.New("Unknown storage type provided")
}
func startServer(opts Opts) {
if opts.cpuprofile != "" {
f, err := os.Create(opts.cpuprofile)
@ -584,32 +634,32 @@ func startServer(opts Opts) {
}
Notify(server.signalQueue)
var err error
fmt.Println("init hashes")
server.hashes, err = ch.NewMapStorage()
log.Println("init hashes")
server.hashes, err = initializeStorage(opts)
if err != nil {
panic(err)
}
fmt.Println("init handlers")
log.Println("init handlers")
server.setupAppHandlers()
fmt.Println("init hashers")
log.Println("init hashers")
rwg := sync.WaitGroup{}
for i := range 10 {
rwg.Add(1)
go server.reader(i, func() { fmt.Println("Reader completed"); rwg.Done() })
go server.reader(i, func() { log.Println("Reader completed"); rwg.Done() })
}
hwg := sync.WaitGroup{}
for i := range 10 {
hwg.Add(1)
go server.hasher(i, func() { fmt.Println("Hasher completed"); hwg.Done() })
go server.hasher(i, func() { log.Println("Hasher completed"); hwg.Done() })
}
fmt.Println("init mapper")
log.Println("init mapper")
mwg := sync.WaitGroup{}
mwg.Add(1)
go server.mapper(func() { fmt.Println("Mapper completed"); mwg.Done() })
go server.mapper(func() { log.Println("Mapper completed"); mwg.Done() })
if opts.loadEmbeddedHashes && len(ch.Hashes) != 0 {
var err error
@ -658,32 +708,32 @@ func startServer(opts Opts) {
fmt.Printf("Loaded hashes from %q %s\n", opts.hashesPath, format)
} else {
if errors.Is(err, os.ErrNotExist) {
fmt.Println("No saved hashes to load")
log.Println("No saved hashes to load")
} else {
fmt.Println("Unable to load saved hashes", err)
log.Println("Unable to load saved hashes", err)
}
}
}
server.HashLocalImages(opts)
fmt.Println("Listening on ", server.httpServer.Addr)
log.Println("Listening on ", server.httpServer.Addr)
err = server.httpServer.ListenAndServe()
if err != nil {
fmt.Println(err)
log.Println(err)
}
close(server.readerQueue)
fmt.Println("waiting on readers")
log.Println("waiting on readers")
rwg.Wait()
for range server.readerQueue {
}
close(server.hashingQueue)
fmt.Println("waiting on hashers")
log.Println("waiting on hashers")
hwg.Wait()
for range server.hashingQueue {
}
close(server.mappingQueue)
fmt.Println("waiting on mapper")
log.Println("waiting on mapper")
mwg.Wait()
for range server.mappingQueue {
}
@ -698,14 +748,14 @@ func startServer(opts Opts) {
gzw := gzip.NewWriter(f)
_, err := gzw.Write(encodedHashes)
if err != nil {
fmt.Println("Failed to write hashes", err)
log.Println("Failed to write hashes", err)
} else {
fmt.Println("Successfully saved hashes")
log.Println("Successfully saved hashes")
}
gzw.Close()
f.Close()
} else {
fmt.Println("Unabled to save hashes", err)
log.Println("Unabled to save hashes", err)
}
} else {
fmt.Printf("Unable to encode hashes as %v: %v", opts.format, err)

View File

@ -128,18 +128,18 @@ func (s *sqliteStorage) findPartialHashes(max int, search_hash int64, kind goima
func (s *sqliteStorage) dropIndexes() error {
_, err := s.db.Exec(`
DROP INDEX IF EXISTS hash_index;
DROP INDEX IF EXISTS hash_1_index;
DROP INDEX IF EXISTS hash_2_index;
DROP INDEX IF EXISTS hash_3_index;
DROP INDEX IF EXISTS hash_4_index;
DROP INDEX IF EXISTS hash_5_index;
DROP INDEX IF EXISTS hash_6_index;
DROP INDEX IF EXISTS hash_7_index;
DROP INDEX IF EXISTS hash_8_index;
DROP INDEX IF EXISTS hash_index;
DROP INDEX IF EXISTS hash_1_index;
DROP INDEX IF EXISTS hash_2_index;
DROP INDEX IF EXISTS hash_3_index;
DROP INDEX IF EXISTS hash_4_index;
DROP INDEX IF EXISTS hash_5_index;
DROP INDEX IF EXISTS hash_6_index;
DROP INDEX IF EXISTS hash_7_index;
DROP INDEX IF EXISTS hash_8_index;
DROP INDEX IF EXISTS id_domain;
`)
DROP INDEX IF EXISTS id_domain;
`)
if err != nil {
return err
}
@ -220,35 +220,46 @@ func (s *sqliteStorage) GetMatches(hashes []Hash, max int, exactOnly bool) ([]Re
func (s *sqliteStorage) MapHashes(hash ImageHash) {
insertHashes, err := s.db.Prepare(`
INSERT INTO Hashes (hash,kind) VALUES (?,?) ON CONFLICT DO NOTHING;
INSERT INTO Hashes (hash,kind) VALUES (?,?) ON CONFLICT DO UPDATE SET hash=?1 RETURNING hashid;
`)
if err != nil {
panic(err)
}
IDInsertResult, err := s.db.Exec(`
INSERT INTO IDs (domain,id) VALUES (?,?) ON CONFLICT DO NOTHING;
`, hash.ID.Domain, hash.ID.Domain)
rows, err := s.db.Query(`
INSERT INTO IDs (domain,id) VALUES (?,?) ON CONFLICT DO UPDATE SET domain=?1 RETURNING idid;
`, hash.ID.Domain, hash.ID.ID)
if err != nil {
panic(err)
}
id_id, err := IDInsertResult.LastInsertId()
if !rows.Next() {
panic("Unable to insert IDs")
}
var id_id int64
err = rows.Scan(&id_id)
if err != nil {
panic(err)
}
rows.Close()
hash_ids := []int64{}
for _, hash := range hash.Hashes {
hashInsertResult, err := insertHashes.Exec(int64(hash.Hash), hash.Kind)
rows, err := insertHashes.Query(int64(hash.Hash), hash.Kind)
if err != nil {
panic(err)
}
id, err := hashInsertResult.LastInsertId()
if !rows.Next() {
panic("Unable to insert IDs")
}
var id int64
err = rows.Scan(&id)
rows.Close()
if err != nil {
panic(err)
}
hash_ids = append(hash_ids, id)
}
for _, hash_id := range hash_ids {
_, err = s.db.Exec(`INSERT INTO id_hash VALUES (?, ?) ON CONFLICT DO NOTHING;`, hash_id, id_id)
_, err = s.db.Exec(`INSERT INTO id_hash (hashid,idid) VALUES (?, ?) ON CONFLICT DO NOTHING;`, hash_id, id_id)
if err != nil {
panic(fmt.Errorf("Failed inserting: %v,%v: %w", hash.ID.Domain, hash.ID.ID, err))
}