"""A python class to manage caching of metadata from comic sources"""
#
# Copyright 2012-2014 ComicTagger Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import datetime
import logging
import os
import pathlib
import sqlite3
from typing import Any

from typing_extensions import NamedTuple

logger = logging.getLogger(__name__)


class Series(NamedTuple):
    id: str
    data: bytes


class Issue(NamedTuple):
    id: str
    series_id: str
    data: bytes

class ComicCacher:
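    """Local SQLite cache of series/issue metadata fetched from comic sources.

    Records are keyed by (id, source); the whole cache is discarded when the
    supplied version string changes, and stale rows are purged after roughly
    a week by expire_stale_records().
    """
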
    def __init__(self, cache_folder: pathlib.Path, version: str) -> None:
        self.cache_folder = cache_folder
        self.db_file = cache_folder / "comic_cache.db"
        self.version_file = cache_folder / "cache_version.txt"
        self.version = version

        # verify that the cache is from the same version as this one
        data = ""
        try:
            with open(self.version_file, "rb") as f:
                data = f.read().decode("utf-8")
        except Exception:
            pass
        if data != version:
            self.clear_cache()

        if not os.path.exists(self.db_file):
            self.create_cache_db()

    def clear_cache(self) -> None:
        try:
            os.unlink(self.db_file)
        except Exception:
            pass
        try:
            os.unlink(self.version_file)
        except Exception:
            pass

    def create_cache_db(self) -> None:
        # create the version file
        with open(self.version_file, "w", encoding="utf-8") as f:
            f.write(self.version)

        # this will wipe out any existing DB contents
        open(self.db_file, "wb").close()

        con = sqlite3.connect(self.db_file)
        con.row_factory = sqlite3.Row

        # create tables
        with con:
            cur = con.cursor()
            cur.execute(
                """CREATE TABLE SeriesSearchCache(
                    timestamp DATE DEFAULT (datetime('now','localtime')),
                    id TEXT NOT NULL,
                    source TEXT NOT NULL,
                    search_term TEXT,
                    PRIMARY KEY (id, source, search_term))"""
            )
            cur.execute("CREATE TABLE Source(id TEXT NOT NULL, name TEXT NOT NULL, PRIMARY KEY (id))")

            cur.execute(
                """CREATE TABLE Series(
                    timestamp DATE DEFAULT (datetime('now','localtime')),
                    id TEXT NOT NULL,
                    source TEXT NOT NULL,
                    data BLOB,
                    complete BOOL,
                    PRIMARY KEY (id, source))"""
            )

            cur.execute(
                """CREATE TABLE Issues(
                    timestamp DATE DEFAULT (datetime('now','localtime')),
                    id TEXT NOT NULL,
                    source TEXT NOT NULL,
                    series_id TEXT,
                    data BLOB,
                    complete BOOL,
                    PRIMARY KEY (id, source))"""
            )

    def expire_stale_records(self, cur: sqlite3.Cursor, table: str) -> None:
        # purge cache records in the given table that are more than a week old
        a_week_ago = datetime.datetime.today() - datetime.timedelta(days=7)
        cur.execute(f"DELETE FROM {table} WHERE timestamp < ?", [str(a_week_ago)])

    def add_search_results(self, source: str, search_term: str, series_list: list[Series], complete: bool) -> None:
        with sqlite3.connect(self.db_file) as con:
            con.row_factory = sqlite3.Row
            con.text_factory = str
            cur = con.cursor()

            # remove all previous entries with this search term
            cur.execute(
                "DELETE FROM SeriesSearchCache WHERE search_term = ? AND source = ?",
                [search_term.casefold(), source],
            )

            # now add in new results
            for series in series_list:
                cur.execute(
                    "INSERT INTO SeriesSearchCache (source, search_term, id) VALUES(?, ?, ?)",
                    (source, search_term.casefold(), series.id),
                )
                data = {
                    "id": series.id,
                    "source": source,
                    "data": series.data,
                    "complete": complete,
                }
                self.upsert(cur, "series", data)

    def add_series_info(self, source: str, series: Series, complete: bool) -> None:
        with sqlite3.connect(self.db_file) as con:
            con.row_factory = sqlite3.Row
            cur = con.cursor()

            data = {
                "id": series.id,
                "source": source,
                "data": series.data,
                "complete": complete,
            }
            self.upsert(cur, "series", data)

    def add_issues_info(self, source: str, issues: list[Issue], complete: bool) -> None:
        with sqlite3.connect(self.db_file) as con:
            con.row_factory = sqlite3.Row
            cur = con.cursor()

            for issue in issues:
                data = {
                    "id": issue.id,
                    "series_id": issue.series_id,
                    "data": issue.data,
                    "source": source,
                    "complete": complete,
                }
                self.upsert(cur, "issues", data)

    def get_search_results(self, source: str, search_term: str, expire_stale: bool = True) -> list[tuple[Series, bool]]:
        results = []
        with sqlite3.connect(self.db_file) as con:
            con.row_factory = sqlite3.Row
            con.text_factory = str
            cur = con.cursor()

            if expire_stale:
                self.expire_stale_records(cur, "SeriesSearchCache")
                self.expire_stale_records(cur, "Series")

            cur.execute(
                """SELECT * FROM SeriesSearchCache INNER JOIN Series on
                SeriesSearchCache.id=Series.id AND SeriesSearchCache.source=Series.source
                WHERE search_term=? AND SeriesSearchCache.source=?""",
                [search_term.casefold(), source],
            )

            rows = cur.fetchall()

            for record in rows:
                result = Series(id=record["id"], data=record["data"])
                results.append((result, record["complete"]))

        return results

    def get_series_info(self, series_id: str, source: str, expire_stale: bool = True) -> tuple[Series, bool] | None:
        with sqlite3.connect(self.db_file) as con:
            con.row_factory = sqlite3.Row
            cur = con.cursor()
            con.text_factory = str

            if expire_stale:
                self.expire_stale_records(cur, "Series")

            # fetch
            cur.execute("SELECT * FROM Series WHERE id=? AND source=?", [series_id, source])

            row = cur.fetchone()

            if row is None:
                return None

            result = Series(id=row["id"], data=row["data"])

            return (result, row["complete"])

    def get_series_issues_info(
        self, series_id: str, source: str, expire_stale: bool = True
    ) -> list[tuple[Issue, bool]]:
        with sqlite3.connect(self.db_file) as con:
            con.row_factory = sqlite3.Row
            cur = con.cursor()
            con.text_factory = str

            if expire_stale:
                self.expire_stale_records(cur, "Issues")

            # fetch
            results: list[tuple[Issue, bool]] = []

            cur.execute("SELECT * FROM Issues WHERE series_id=? AND source=?", [series_id, source])
            rows = cur.fetchall()

            # now process the results
            for row in rows:
                record = (Issue(id=row["id"], series_id=row["series_id"], data=row["data"]), row["complete"])
                results.append(record)

        return results

    def get_issue_info(self, issue_id: str, source: str, expire_stale: bool = True) -> tuple[Issue, bool] | None:
        with sqlite3.connect(self.db_file) as con:
            con.row_factory = sqlite3.Row
            cur = con.cursor()
            con.text_factory = str

            if expire_stale:
                self.expire_stale_records(cur, "Issues")

            cur.execute("SELECT * FROM Issues WHERE id=? AND source=?", [issue_id, source])
            row = cur.fetchone()

            record = None

            if row:
                record = (Issue(id=row["id"], series_id=row["series_id"], data=row["data"]), row["complete"])

        return record

    def upsert(self, cur: sqlite3.Cursor, tablename: str, data: dict[str, Any]) -> None:
        """This does an insert if the given PK doesn't exist, and an
        update if it does

        TODO: should the cursor be created here, and not up the stack?
        """

        keys = ""
        vals = []
        ins_slots = ""
        set_slots = ""  # built alongside the others but not used below

        for key in data:
            if data[key] is None:
                continue

            if keys != "":
                keys += ", "
            if ins_slots != "":
                ins_slots += ", "
            if set_slots != "":
                set_slots += ", "

            keys += key
            vals.append(data[key])
            ins_slots += "?"
            set_slots += key + " = ?"

        sql_ins = f"INSERT OR REPLACE INTO {tablename} ({keys}) VALUES ({ins_slots})"
        cur.execute(sql_ins, vals)
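

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: it exercises a
    # cache round-trip in a throwaway temp directory with a hypothetical
    # source name ("example-source"); only names defined above are used.
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        cacher = ComicCacher(pathlib.Path(tmp), version="1")
        cacher.add_series_info("example-source", Series(id="42", data=b'{"name": "Example"}'), complete=True)
        print(cacher.get_series_info("42", "example-source"))  # -> (Series(id='42', data=...), 1)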