Add cli flag

parent 1955444dcf
commit b1de95021a
@@ -78,9 +78,6 @@ func (f Format) String() string {
 	return "Unknown"
 }
 
-type Encoder func(any) ([]byte, error)
-type Decoder func([]byte, interface{}) error
-
 func (f *Format) Set(s string) error {
 	if format, known := formatValues[strings.ToLower(s)]; known {
 		*f = format
@@ -90,6 +87,45 @@ func (f *Format) Set(s string) error {
 	return nil
 }
 
+type Storage int
+
+const (
+	Map = iota + 1
+	Sqlite
+	Sqlite3
+)
+
+var storageNames = map[Storage]string{
+	Map: "map",
+	Sqlite: "sqlite",
+	Sqlite3: "sqlite3",
+}
+
+var storageValues = map[string]Storage{
+	"map": Map,
+	"sqlite": Sqlite,
+	"sqlite3": Sqlite3,
+}
+
+func (f Storage) String() string {
+	if name, known := storageNames[f]; known {
+		return name
+	}
+	return "Unknown"
+}
+
+func (f *Storage) Set(s string) error {
+	if storage, known := storageValues[strings.ToLower(s)]; known {
+		*f = storage
+	} else {
+		return fmt.Errorf("Unknown storage type: %d", f)
+	}
+	return nil
+}
+
+type Encoder func(any) ([]byte, error)
+type Decoder func([]byte, interface{}) error
+
 type Opts struct {
 	cpuprofile string
 	coverPath string
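The new Storage type follows Go's flag.Value pattern: any type with String() string and Set(string) error methods can be registered with flag.Var. A minimal standalone sketch of that pattern (it mirrors the names above, but is not the project's code):

package main

import (
	"flag"
	"fmt"
	"strings"
)

type Storage int

const (
	Map Storage = iota + 1
	Sqlite
)

var (
	storageNames  = map[Storage]string{Map: "map", Sqlite: "sqlite"}
	storageValues = map[string]Storage{"map": Map, "sqlite": Sqlite}
)

// String and Set together satisfy flag.Value, so *Storage can be handed to flag.Var.
func (s Storage) String() string { return storageNames[s] }

func (s *Storage) Set(v string) error {
	if st, ok := storageValues[strings.ToLower(v)]; ok {
		*s = st
		return nil
	}
	return fmt.Errorf("unknown storage type: %q", v)
}

func main() {
	// flag.Var takes no default argument; the default is whatever the variable
	// holds at registration time, which is why the diff seeds Opts with storageType: Map.
	storage := Map
	flag.Var(&storage, "storage-type", "storage backend (map, sqlite)")
	flag.Parse()
	fmt.Println("using storage:", storage)
}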
@@ -98,10 +134,11 @@ type Opts struct {
 	saveEmbeddedHashes bool
 	format Format
 	hashesPath string
+	storageType Storage
 }
 
 func main() {
-	opts := Opts{format: Msgpack} // flag is weird
+	opts := Opts{format: Msgpack, storageType: Map} // flag is weird
 	go func() {
 		log.Println(http.ListenAndServe("localhost:6060", nil))
 	}()
@@ -113,6 +150,7 @@ func main() {
 	flag.BoolVar(&opts.saveEmbeddedHashes, "save-embedded-hashes", false, "Save hashes even if we loaded the embedded hashes")
 	flag.StringVar(&opts.hashesPath, "hashes", "hashes.gz", "Path to optionally gziped hashes in msgpack or json format. You must disable embedded hashes to use this option")
 	flag.Var(&opts.format, "save-format", "Specify the format to export hashes to (json, msgpack)")
+	flag.Var(&opts.storageType, "storage-type", "Specify the storage type used internally to search hashes (sqlite,sqlite3,map)")
 	flag.Parse()
 
 	if opts.coverPath != "" {
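With the flag registered, the backend can be picked at start-up, for example (the binary name is a placeholder):

    ./server -storage-type sqlite3 -save-format json

If the value is not one of the known names, Set returns an error and the flag package prints it along with the usage text and exits.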
@@ -122,7 +160,7 @@ func main() {
 		}
 	}
 	opts.sqlitePath, _ = filepath.Abs(opts.sqlitePath)
-	pretty.Logln(opts)
+	log.Println(pretty.Formatter(opts))
 	startServer(opts)
 }
 
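pretty.Logln is replaced with the standard logger plus pretty.Formatter. Assuming the pretty import is github.com/kr/pretty, Formatter wraps a value in a fmt.Formatter so the pretty-printed dump flows through log's usual timestamped output; a small illustrative sketch (struct and values are made up):

package main

import (
	"log"

	"github.com/kr/pretty"
)

type Opts struct {
	Format     string
	HashesPath string
}

func main() {
	opts := Opts{Format: "msgpack", HashesPath: "hashes.gz"}
	// pretty.Formatter defers formatting to the log call, so the output keeps
	// the logger's timestamp prefix and destination (stderr by default).
	log.Println(pretty.Formatter(opts))
}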
@@ -515,10 +553,10 @@ func (s *Server) HashLocalImages(opts Opts) {
 			log.Println("Recieved quit")
 		}
 		err := s.httpServer.Shutdown(context.TODO())
-		fmt.Println("Err:", err)
+		log.Println("Err:", err)
 		return
 	}
-	fmt.Println("Hashing covers at ", opts.coverPath)
+	log.Println("Hashing covers at ", opts.coverPath)
 	start := time.Now()
 	err := filepath.WalkDir(opts.coverPath, func(path string, d fs.DirEntry, err error) error {
 		if err != nil {
@@ -544,7 +582,7 @@ func (s *Server) HashLocalImages(opts Opts) {
 		return nil
 	})
 	elapsed := time.Since(start)
-	fmt.Println("Err:", err, "local hashing took", elapsed)
+	log.Println("Err:", err, "local hashing took", elapsed)
 
 	sig := <-s.signalQueue
 	if !alreadyQuit {
@@ -555,6 +593,18 @@ func (s *Server) HashLocalImages(opts Opts) {
 	}()
 }
 
+func initializeStorage(opts Opts) (ch.HashStorage, error) {
+	switch opts.storageType {
+	case Map:
+		return ch.NewMapStorage()
+	case Sqlite:
+		return ch.NewSqliteStorage("sqlite", opts.sqlitePath)
+	case Sqlite3:
+		return ch.NewSqliteStorage("sqlite3", opts.sqlitePath)
+	}
+	return nil, errors.New("Unknown storage type provided")
+}
+
 func startServer(opts Opts) {
 	if opts.cpuprofile != "" {
 		f, err := os.Create(opts.cpuprofile)
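initializeStorage hands back the ch.HashStorage interface, so the server code below stays agnostic to which backend was selected. A rough standalone sketch of that factory shape (the interface and constructors are stand-ins, not the project's ch package):

package main

import (
	"errors"
	"fmt"
)

// HashStorage stands in for the project's ch.HashStorage interface.
type HashStorage interface {
	Get(hash int64) (string, bool)
}

type mapStorage map[int64]string

func (m mapStorage) Get(h int64) (string, bool) { v, ok := m[h]; return v, ok }

type Storage int

const (
	Map Storage = iota + 1
	Sqlite
	Sqlite3
)

// newStorage mirrors the shape of initializeStorage: pick a backend, and fall
// through to an error for anything unrecognized (including the zero value).
func newStorage(kind Storage) (HashStorage, error) {
	switch kind {
	case Map:
		return mapStorage{}, nil
	case Sqlite, Sqlite3:
		return nil, errors.New("sqlite backends omitted from this sketch")
	}
	return nil, errors.New("unknown storage type provided")
}

func main() {
	hashes, err := newStorage(Map)
	if err != nil {
		panic(err)
	}
	fmt.Println(hashes.Get(42))
}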
@@ -584,32 +634,32 @@ func startServer(opts Opts) {
 	}
 	Notify(server.signalQueue)
 	var err error
-	fmt.Println("init hashes")
-	server.hashes, err = ch.NewMapStorage()
+	log.Println("init hashes")
+	server.hashes, err = initializeStorage(opts)
 	if err != nil {
 		panic(err)
 	}
 
-	fmt.Println("init handlers")
+	log.Println("init handlers")
 	server.setupAppHandlers()
 
-	fmt.Println("init hashers")
+	log.Println("init hashers")
 	rwg := sync.WaitGroup{}
 	for i := range 10 {
 		rwg.Add(1)
-		go server.reader(i, func() { fmt.Println("Reader completed"); rwg.Done() })
+		go server.reader(i, func() { log.Println("Reader completed"); rwg.Done() })
 	}
 
 	hwg := sync.WaitGroup{}
 	for i := range 10 {
 		hwg.Add(1)
-		go server.hasher(i, func() { fmt.Println("Hasher completed"); hwg.Done() })
+		go server.hasher(i, func() { log.Println("Hasher completed"); hwg.Done() })
 	}
 
-	fmt.Println("init mapper")
+	log.Println("init mapper")
 	mwg := sync.WaitGroup{}
 	mwg.Add(1)
-	go server.mapper(func() { fmt.Println("Mapper completed"); mwg.Done() })
+	go server.mapper(func() { log.Println("Mapper completed"); mwg.Done() })
 
 	if opts.loadEmbeddedHashes && len(ch.Hashes) != 0 {
 		var err error
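All of the fmt.Println progress messages in startServer move to log.Println, which writes to stderr and prefixes a timestamp by default. A trivial standalone comparison:

package main

import (
	"fmt"
	"log"
)

func main() {
	fmt.Println("init hashes") // stdout, no prefix
	log.Println("init hashes") // stderr, e.g. "2009/11/10 23:00:00 init hashes"
}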
@@ -658,32 +708,32 @@ func startServer(opts Opts) {
 			fmt.Printf("Loaded hashes from %q %s\n", opts.hashesPath, format)
 		} else {
 			if errors.Is(err, os.ErrNotExist) {
-				fmt.Println("No saved hashes to load")
+				log.Println("No saved hashes to load")
 			} else {
-				fmt.Println("Unable to load saved hashes", err)
+				log.Println("Unable to load saved hashes", err)
 			}
 		}
 	}
 
 	server.HashLocalImages(opts)
 
-	fmt.Println("Listening on ", server.httpServer.Addr)
+	log.Println("Listening on ", server.httpServer.Addr)
 	err = server.httpServer.ListenAndServe()
 	if err != nil {
-		fmt.Println(err)
+		log.Println(err)
 	}
 	close(server.readerQueue)
-	fmt.Println("waiting on readers")
+	log.Println("waiting on readers")
 	rwg.Wait()
 	for range server.readerQueue {
 	}
 	close(server.hashingQueue)
-	fmt.Println("waiting on hashers")
+	log.Println("waiting on hashers")
 	hwg.Wait()
 	for range server.hashingQueue {
 	}
 	close(server.mappingQueue)
-	fmt.Println("waiting on mapper")
+	log.Println("waiting on mapper")
 	mwg.Wait()
 	for range server.mappingQueue {
 	}
@@ -698,14 +748,14 @@ func startServer(opts Opts) {
 			gzw := gzip.NewWriter(f)
 			_, err := gzw.Write(encodedHashes)
 			if err != nil {
-				fmt.Println("Failed to write hashes", err)
+				log.Println("Failed to write hashes", err)
 			} else {
-				fmt.Println("Successfully saved hashes")
+				log.Println("Successfully saved hashes")
 			}
 			gzw.Close()
 			f.Close()
 		} else {
-			fmt.Println("Unabled to save hashes", err)
+			log.Println("Unabled to save hashes", err)
 		}
 	} else {
 		fmt.Printf("Unable to encode hashes as %v: %v", opts.format, err)
sqlite.go (49 changed lines)
@@ -128,18 +128,18 @@ func (s *sqliteStorage) findPartialHashes(max int, search_hash int64, kind goima
 func (s *sqliteStorage) dropIndexes() error {
 	_, err := s.db.Exec(`
 
 DROP INDEX IF EXISTS hash_index;
 DROP INDEX IF EXISTS hash_1_index;
 DROP INDEX IF EXISTS hash_2_index;
 DROP INDEX IF EXISTS hash_3_index;
 DROP INDEX IF EXISTS hash_4_index;
 DROP INDEX IF EXISTS hash_5_index;
 DROP INDEX IF EXISTS hash_6_index;
 DROP INDEX IF EXISTS hash_7_index;
 DROP INDEX IF EXISTS hash_8_index;
 
 DROP INDEX IF EXISTS id_domain;
 `)
 	if err != nil {
 		return err
 	}
@@ -220,35 +220,46 @@ func (s *sqliteStorage) GetMatches(hashes []Hash, max int, exactOnly bool) ([]Re
 
 func (s *sqliteStorage) MapHashes(hash ImageHash) {
 	insertHashes, err := s.db.Prepare(`
-		INSERT INTO Hashes (hash,kind) VALUES (?,?) ON CONFLICT DO NOTHING;
+		INSERT INTO Hashes (hash,kind) VALUES (?,?) ON CONFLICT DO UPDATE SET hash=?1 RETURNING hashid;
 	`)
 	if err != nil {
 		panic(err)
 	}
-	IDInsertResult, err := s.db.Exec(`
-		INSERT INTO IDs (domain,id) VALUES (?,?) ON CONFLICT DO NOTHING;
-	`, hash.ID.Domain, hash.ID.Domain)
+	rows, err := s.db.Query(`
+		INSERT INTO IDs (domain,id) VALUES (?,?) ON CONFLICT DO UPDATE SET domain=?1 RETURNING idid;
+	`, hash.ID.Domain, hash.ID.ID)
 	if err != nil {
 		panic(err)
 	}
-	id_id, err := IDInsertResult.LastInsertId()
+	if !rows.Next() {
+		panic("Unable to insert IDs")
+	}
+	var id_id int64
+	err = rows.Scan(&id_id)
 	if err != nil {
 		panic(err)
 	}
+	rows.Close()
 	hash_ids := []int64{}
 	for _, hash := range hash.Hashes {
-		hashInsertResult, err := insertHashes.Exec(int64(hash.Hash), hash.Kind)
+		rows, err := insertHashes.Query(int64(hash.Hash), hash.Kind)
 		if err != nil {
 			panic(err)
 		}
-		id, err := hashInsertResult.LastInsertId()
+
+		if !rows.Next() {
+			panic("Unable to insert IDs")
+		}
+		var id int64
+		err = rows.Scan(&id)
+		rows.Close()
 		if err != nil {
 			panic(err)
 		}
 		hash_ids = append(hash_ids, id)
 	}
 	for _, hash_id := range hash_ids {
-		_, err = s.db.Exec(`INSERT INTO id_hash VALUES (?, ?) ON CONFLICT DO NOTHING;`, hash_id, id_id)
+		_, err = s.db.Exec(`INSERT INTO id_hash (hashid,idid) VALUES (?, ?) ON CONFLICT DO NOTHING;`, hash_id, id_id)
 		if err != nil {
 			panic(fmt.Errorf("Failed inserting: %v,%v: %w", hash.ID.Domain, hash.ID.ID, err))
 		}
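The MapHashes rewrite swaps Exec plus LastInsertId for upserts that return the key directly: ON CONFLICT DO UPDATE ... RETURNING yields the row id whether the row was just inserted or already existed, which LastInsertId cannot do on the conflict path. A rough standalone sketch of the same pattern (the driver, schema, and values here are illustrative assumptions, not the project's):

package main

import (
	"database/sql"
	"log"

	_ "modernc.org/sqlite" // any database/sql SQLite driver registered as "sqlite"
)

func upsertID(db *sql.DB, domain, id string) (int64, error) {
	// RETURNING produces a row on both the insert and the conflict branch,
	// so the caller always gets the row's key back.
	row := db.QueryRow(`
		INSERT INTO IDs (domain, id) VALUES (?, ?)
		ON CONFLICT (domain, id) DO UPDATE SET domain = excluded.domain
		RETURNING idid;`, domain, id)
	var idid int64
	err := row.Scan(&idid)
	return idid, err
}

func main() {
	db, err := sql.Open("sqlite", ":memory:")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()
	db.SetMaxOpenConns(1) // keep the single in-memory database on one connection
	if _, err := db.Exec(`CREATE TABLE IDs (idid INTEGER PRIMARY KEY, domain TEXT, id TEXT, UNIQUE(domain, id));`); err != nil {
		log.Fatal(err)
	}
	first, _ := upsertID(db, "example.com", "1")
	second, _ := upsertID(db, "example.com", "1") // conflict path, same key comes back
	log.Println(first, second)
}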