Add cli flag
This commit is contained in:
parent
1955444dcf
commit
b1de95021a
@ -78,9 +78,6 @@ func (f Format) String() string {
|
||||
return "Unknown"
|
||||
}
|
||||
|
||||
type Encoder func(any) ([]byte, error)
|
||||
type Decoder func([]byte, interface{}) error
|
||||
|
||||
func (f *Format) Set(s string) error {
|
||||
if format, known := formatValues[strings.ToLower(s)]; known {
|
||||
*f = format
|
||||
@ -90,6 +87,45 @@ func (f *Format) Set(s string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type Storage int
|
||||
|
||||
const (
|
||||
Map = iota + 1
|
||||
Sqlite
|
||||
Sqlite3
|
||||
)
|
||||
|
||||
var storageNames = map[Storage]string{
|
||||
Map: "map",
|
||||
Sqlite: "sqlite",
|
||||
Sqlite3: "sqlite3",
|
||||
}
|
||||
|
||||
var storageValues = map[string]Storage{
|
||||
"map": Map,
|
||||
"sqlite": Sqlite,
|
||||
"sqlite3": Sqlite3,
|
||||
}
|
||||
|
||||
func (f Storage) String() string {
|
||||
if name, known := storageNames[f]; known {
|
||||
return name
|
||||
}
|
||||
return "Unknown"
|
||||
}
|
||||
|
||||
func (f *Storage) Set(s string) error {
|
||||
if storage, known := storageValues[strings.ToLower(s)]; known {
|
||||
*f = storage
|
||||
} else {
|
||||
return fmt.Errorf("Unknown storage type: %d", f)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Encoder func(any) ([]byte, error)
|
||||
type Decoder func([]byte, interface{}) error
|
||||
|
||||
type Opts struct {
|
||||
cpuprofile string
|
||||
coverPath string
|
||||
@ -98,10 +134,11 @@ type Opts struct {
|
||||
saveEmbeddedHashes bool
|
||||
format Format
|
||||
hashesPath string
|
||||
storageType Storage
|
||||
}
|
||||
|
||||
func main() {
|
||||
opts := Opts{format: Msgpack} // flag is weird
|
||||
opts := Opts{format: Msgpack, storageType: Map} // flag is weird
|
||||
go func() {
|
||||
log.Println(http.ListenAndServe("localhost:6060", nil))
|
||||
}()
|
||||
@ -113,6 +150,7 @@ func main() {
|
||||
flag.BoolVar(&opts.saveEmbeddedHashes, "save-embedded-hashes", false, "Save hashes even if we loaded the embedded hashes")
|
||||
flag.StringVar(&opts.hashesPath, "hashes", "hashes.gz", "Path to optionally gziped hashes in msgpack or json format. You must disable embedded hashes to use this option")
|
||||
flag.Var(&opts.format, "save-format", "Specify the format to export hashes to (json, msgpack)")
|
||||
flag.Var(&opts.storageType, "storage-type", "Specify the storage type used internally to search hashes (sqlite,sqlite3,map)")
|
||||
flag.Parse()
|
||||
|
||||
if opts.coverPath != "" {
|
||||
@ -122,7 +160,7 @@ func main() {
|
||||
}
|
||||
}
|
||||
opts.sqlitePath, _ = filepath.Abs(opts.sqlitePath)
|
||||
pretty.Logln(opts)
|
||||
log.Println(pretty.Formatter(opts))
|
||||
startServer(opts)
|
||||
}
|
||||
|
||||
@ -515,10 +553,10 @@ func (s *Server) HashLocalImages(opts Opts) {
|
||||
log.Println("Recieved quit")
|
||||
}
|
||||
err := s.httpServer.Shutdown(context.TODO())
|
||||
fmt.Println("Err:", err)
|
||||
log.Println("Err:", err)
|
||||
return
|
||||
}
|
||||
fmt.Println("Hashing covers at ", opts.coverPath)
|
||||
log.Println("Hashing covers at ", opts.coverPath)
|
||||
start := time.Now()
|
||||
err := filepath.WalkDir(opts.coverPath, func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
@ -544,7 +582,7 @@ func (s *Server) HashLocalImages(opts Opts) {
|
||||
return nil
|
||||
})
|
||||
elapsed := time.Since(start)
|
||||
fmt.Println("Err:", err, "local hashing took", elapsed)
|
||||
log.Println("Err:", err, "local hashing took", elapsed)
|
||||
|
||||
sig := <-s.signalQueue
|
||||
if !alreadyQuit {
|
||||
@ -555,6 +593,18 @@ func (s *Server) HashLocalImages(opts Opts) {
|
||||
}()
|
||||
}
|
||||
|
||||
func initializeStorage(opts Opts) (ch.HashStorage, error) {
|
||||
switch opts.storageType {
|
||||
case Map:
|
||||
return ch.NewMapStorage()
|
||||
case Sqlite:
|
||||
return ch.NewSqliteStorage("sqlite", opts.sqlitePath)
|
||||
case Sqlite3:
|
||||
return ch.NewSqliteStorage("sqlite3", opts.sqlitePath)
|
||||
}
|
||||
return nil, errors.New("Unknown storage type provided")
|
||||
}
|
||||
|
||||
func startServer(opts Opts) {
|
||||
if opts.cpuprofile != "" {
|
||||
f, err := os.Create(opts.cpuprofile)
|
||||
@ -584,32 +634,32 @@ func startServer(opts Opts) {
|
||||
}
|
||||
Notify(server.signalQueue)
|
||||
var err error
|
||||
fmt.Println("init hashes")
|
||||
server.hashes, err = ch.NewMapStorage()
|
||||
log.Println("init hashes")
|
||||
server.hashes, err = initializeStorage(opts)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Println("init handlers")
|
||||
log.Println("init handlers")
|
||||
server.setupAppHandlers()
|
||||
|
||||
fmt.Println("init hashers")
|
||||
log.Println("init hashers")
|
||||
rwg := sync.WaitGroup{}
|
||||
for i := range 10 {
|
||||
rwg.Add(1)
|
||||
go server.reader(i, func() { fmt.Println("Reader completed"); rwg.Done() })
|
||||
go server.reader(i, func() { log.Println("Reader completed"); rwg.Done() })
|
||||
}
|
||||
|
||||
hwg := sync.WaitGroup{}
|
||||
for i := range 10 {
|
||||
hwg.Add(1)
|
||||
go server.hasher(i, func() { fmt.Println("Hasher completed"); hwg.Done() })
|
||||
go server.hasher(i, func() { log.Println("Hasher completed"); hwg.Done() })
|
||||
}
|
||||
|
||||
fmt.Println("init mapper")
|
||||
log.Println("init mapper")
|
||||
mwg := sync.WaitGroup{}
|
||||
mwg.Add(1)
|
||||
go server.mapper(func() { fmt.Println("Mapper completed"); mwg.Done() })
|
||||
go server.mapper(func() { log.Println("Mapper completed"); mwg.Done() })
|
||||
|
||||
if opts.loadEmbeddedHashes && len(ch.Hashes) != 0 {
|
||||
var err error
|
||||
@ -658,32 +708,32 @@ func startServer(opts Opts) {
|
||||
fmt.Printf("Loaded hashes from %q %s\n", opts.hashesPath, format)
|
||||
} else {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
fmt.Println("No saved hashes to load")
|
||||
log.Println("No saved hashes to load")
|
||||
} else {
|
||||
fmt.Println("Unable to load saved hashes", err)
|
||||
log.Println("Unable to load saved hashes", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
server.HashLocalImages(opts)
|
||||
|
||||
fmt.Println("Listening on ", server.httpServer.Addr)
|
||||
log.Println("Listening on ", server.httpServer.Addr)
|
||||
err = server.httpServer.ListenAndServe()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
log.Println(err)
|
||||
}
|
||||
close(server.readerQueue)
|
||||
fmt.Println("waiting on readers")
|
||||
log.Println("waiting on readers")
|
||||
rwg.Wait()
|
||||
for range server.readerQueue {
|
||||
}
|
||||
close(server.hashingQueue)
|
||||
fmt.Println("waiting on hashers")
|
||||
log.Println("waiting on hashers")
|
||||
hwg.Wait()
|
||||
for range server.hashingQueue {
|
||||
}
|
||||
close(server.mappingQueue)
|
||||
fmt.Println("waiting on mapper")
|
||||
log.Println("waiting on mapper")
|
||||
mwg.Wait()
|
||||
for range server.mappingQueue {
|
||||
}
|
||||
@ -698,14 +748,14 @@ func startServer(opts Opts) {
|
||||
gzw := gzip.NewWriter(f)
|
||||
_, err := gzw.Write(encodedHashes)
|
||||
if err != nil {
|
||||
fmt.Println("Failed to write hashes", err)
|
||||
log.Println("Failed to write hashes", err)
|
||||
} else {
|
||||
fmt.Println("Successfully saved hashes")
|
||||
log.Println("Successfully saved hashes")
|
||||
}
|
||||
gzw.Close()
|
||||
f.Close()
|
||||
} else {
|
||||
fmt.Println("Unabled to save hashes", err)
|
||||
log.Println("Unabled to save hashes", err)
|
||||
}
|
||||
} else {
|
||||
fmt.Printf("Unable to encode hashes as %v: %v", opts.format, err)
|
||||
|
49
sqlite.go
49
sqlite.go
@ -128,18 +128,18 @@ func (s *sqliteStorage) findPartialHashes(max int, search_hash int64, kind goima
|
||||
func (s *sqliteStorage) dropIndexes() error {
|
||||
_, err := s.db.Exec(`
|
||||
|
||||
DROP INDEX IF EXISTS hash_index;
|
||||
DROP INDEX IF EXISTS hash_1_index;
|
||||
DROP INDEX IF EXISTS hash_2_index;
|
||||
DROP INDEX IF EXISTS hash_3_index;
|
||||
DROP INDEX IF EXISTS hash_4_index;
|
||||
DROP INDEX IF EXISTS hash_5_index;
|
||||
DROP INDEX IF EXISTS hash_6_index;
|
||||
DROP INDEX IF EXISTS hash_7_index;
|
||||
DROP INDEX IF EXISTS hash_8_index;
|
||||
DROP INDEX IF EXISTS hash_index;
|
||||
DROP INDEX IF EXISTS hash_1_index;
|
||||
DROP INDEX IF EXISTS hash_2_index;
|
||||
DROP INDEX IF EXISTS hash_3_index;
|
||||
DROP INDEX IF EXISTS hash_4_index;
|
||||
DROP INDEX IF EXISTS hash_5_index;
|
||||
DROP INDEX IF EXISTS hash_6_index;
|
||||
DROP INDEX IF EXISTS hash_7_index;
|
||||
DROP INDEX IF EXISTS hash_8_index;
|
||||
|
||||
DROP INDEX IF EXISTS id_domain;
|
||||
`)
|
||||
DROP INDEX IF EXISTS id_domain;
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -220,35 +220,46 @@ func (s *sqliteStorage) GetMatches(hashes []Hash, max int, exactOnly bool) ([]Re
|
||||
|
||||
func (s *sqliteStorage) MapHashes(hash ImageHash) {
|
||||
insertHashes, err := s.db.Prepare(`
|
||||
INSERT INTO Hashes (hash,kind) VALUES (?,?) ON CONFLICT DO NOTHING;
|
||||
INSERT INTO Hashes (hash,kind) VALUES (?,?) ON CONFLICT DO UPDATE SET hash=?1 RETURNING hashid;
|
||||
`)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
IDInsertResult, err := s.db.Exec(`
|
||||
INSERT INTO IDs (domain,id) VALUES (?,?) ON CONFLICT DO NOTHING;
|
||||
`, hash.ID.Domain, hash.ID.Domain)
|
||||
rows, err := s.db.Query(`
|
||||
INSERT INTO IDs (domain,id) VALUES (?,?) ON CONFLICT DO UPDATE SET domain=?1 RETURNING idid;
|
||||
`, hash.ID.Domain, hash.ID.ID)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
id_id, err := IDInsertResult.LastInsertId()
|
||||
if !rows.Next() {
|
||||
panic("Unable to insert IDs")
|
||||
}
|
||||
var id_id int64
|
||||
err = rows.Scan(&id_id)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
rows.Close()
|
||||
hash_ids := []int64{}
|
||||
for _, hash := range hash.Hashes {
|
||||
hashInsertResult, err := insertHashes.Exec(int64(hash.Hash), hash.Kind)
|
||||
rows, err := insertHashes.Query(int64(hash.Hash), hash.Kind)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
id, err := hashInsertResult.LastInsertId()
|
||||
|
||||
if !rows.Next() {
|
||||
panic("Unable to insert IDs")
|
||||
}
|
||||
var id int64
|
||||
err = rows.Scan(&id)
|
||||
rows.Close()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
hash_ids = append(hash_ids, id)
|
||||
}
|
||||
for _, hash_id := range hash_ids {
|
||||
_, err = s.db.Exec(`INSERT INTO id_hash VALUES (?, ?) ON CONFLICT DO NOTHING;`, hash_id, id_id)
|
||||
_, err = s.db.Exec(`INSERT INTO id_hash (hashid,idid) VALUES (?, ?) ON CONFLICT DO NOTHING;`, hash_id, id_id)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("Failed inserting: %v,%v: %w", hash.ID.Domain, hash.ID.ID, err))
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user