diff --git a/comicapi/archivers/zip.py b/comicapi/archivers/zip.py index 3c5a44d..f8b6ca3 100644 --- a/comicapi/archivers/zip.py +++ b/comicapi/archivers/zip.py @@ -4,7 +4,6 @@ import logging import os import pathlib import shutil -import struct import tempfile import zipfile from typing import cast @@ -59,13 +58,13 @@ class ZipArchiver(Archiver): # zip archive w/o the indicated file. Very sucky, but maybe # another solution can be found files = self.get_filename_list() - if archive_file in files: - if not self.rebuild([archive_file]): - return False try: # now just add the archive file as a new one with zipfile.ZipFile(self.path, mode="a", allowZip64=True, compression=zipfile.ZIP_DEFLATED) as zf: + _patch_zipfile(zf) + if archive_file in files: + zf.remove(archive_file) # type: ignore zf.writestr(archive_file, data) return True except (zipfile.BadZipfile, OSError) as e: @@ -125,7 +124,7 @@ class ZipArchiver(Archiver): # preserve the old comment comment = other_archive.get_comment() if comment is not None: - if not self.write_zip_comment(self.path, comment): + if not self.set_comment(comment): return False except Exception as e: logger.error("Error while copying to zip archive [%s]: from %s to %s", e, other_archive.path, self.path) @@ -146,59 +145,95 @@ class ZipArchiver(Archiver): def is_valid(cls, path: pathlib.Path) -> bool: return zipfile.is_zipfile(path) - def write_zip_comment(self, filename: pathlib.Path | str, comment: str) -> bool: - """ - This is a custom function for writing a comment to a zip file, - since the built-in one doesn't seem to work on Windows and Mac OS/X - Fortunately, the zip comment is at the end of the file, and it's - easy to manipulate. See this website for more info: - see: http://en.wikipedia.org/wiki/Zip_(file_format)#Structure - """ +def _patch_zipfile(zf): # type: ignore + zf.remove = _zip_remove.__get__(zf, zipfile.ZipFile) + zf._remove_members = _zip_remove_members.__get__(zf, zipfile.ZipFile) - # get file size - statinfo = os.stat(filename) - file_length = statinfo.st_size +def _zip_remove(self, zinfo_or_arcname): # type: ignore + """Remove a member from the archive.""" + + if self.mode not in ("w", "x", "a"): + raise ValueError("remove() requires mode 'w', 'x', or 'a'") + if not self.fp: + raise ValueError("Attempt to write to ZIP archive that was already closed") + if self._writing: + raise ValueError("Can't write to ZIP archive while an open writing handle exists") + + # Make sure we have an existing info object + if isinstance(zinfo_or_arcname, zipfile.ZipInfo): + zinfo = zinfo_or_arcname + # make sure zinfo exists + if zinfo not in self.filelist: + raise KeyError("There is no item %r in the archive" % zinfo_or_arcname) + else: + # get the info object + zinfo = self.getinfo(zinfo_or_arcname) + + return self._remove_members({zinfo}) + + +def _zip_remove_members(self, members, *, remove_physical=True, chunk_size=2**20): # type: ignore + """Remove members in a zip file. + All members (as zinfo) should exist in the zip; otherwise the zip file + will erroneously end in an inconsistent state. + """ + fp = self.fp + entry_offset = 0 + member_seen = False + + # get a sorted filelist by header offset, in case the dir order + # doesn't match the actual entry order + filelist = sorted(self.filelist, key=lambda x: x.header_offset) + for i in range(len(filelist)): + info = filelist[i] + is_member = info in members + + if not (member_seen or is_member): + continue + + # get the total size of the entry try: - with open(filename, mode="r+b") as file: - # the starting position, relative to EOF - pos = -4 - found = False + offset = filelist[i + 1].header_offset + except IndexError: + offset = self.start_dir + entry_size = offset - info.header_offset - # walk backwards to find the "End of Central Directory" record - while (not found) and (-pos != file_length): - # seek, relative to EOF - file.seek(pos, 2) - value = file.read(4) + if is_member: + member_seen = True + entry_offset += entry_size - # look for the end of central directory signature - if bytearray(value) == bytearray([0x50, 0x4B, 0x05, 0x06]): - found = True - else: - # not found, step back another byte - pos = pos - 1 + # update caches + self.filelist.remove(info) + try: + del self.NameToInfo[info.filename] + except KeyError: + pass + continue - if found: - # now skip forward 20 bytes to the comment length word - pos += 20 - file.seek(pos, 2) + # update the header and move entry data to the new position + if remove_physical: + old_header_offset = info.header_offset + info.header_offset -= entry_offset + read_size = 0 + while read_size < entry_size: + fp.seek(old_header_offset + read_size) + data = fp.read(min(entry_size - read_size, chunk_size)) + fp.seek(info.header_offset + read_size) + fp.write(data) + fp.flush() + read_size += len(data) - # Pack the length of the comment string - fmt = "H" # one 2-byte integer - comment_length = struct.pack(fmt, len(comment)) # pack integer in a binary string + # Avoid missing entry if entries have a duplicated name. + # Reverse the order as NameToInfo normally stores the last added one. + for info in reversed(self.filelist): + self.NameToInfo.setdefault(info.filename, info) - # write out the length - file.write(comment_length) - file.seek(pos + 2, 2) + # update state + if remove_physical: + self.start_dir -= entry_offset + self._didModify = True - # write out the comment itself - file.write(comment.encode("utf-8")) - file.truncate() - else: - raise Exception("Could not find the End of Central Directory record!") - except Exception as e: - logger.error("Error writing comment to zip archive [%s]: %s", e, self.path) - return False - else: - return True + # seek to the start of the central dir + fp.seek(self.start_dir) diff --git a/comictaggerlib/main.py b/comictaggerlib/main.py index 7282847..cc61c76 100644 --- a/comictaggerlib/main.py +++ b/comictaggerlib/main.py @@ -89,9 +89,9 @@ def configure_locale() -> None: os.environ["LANG"] = f"{code}.utf-8" locale.setlocale(locale.LC_ALL, "") - sys.stdout.reconfigure(encoding=sys.getdefaultencoding()) # type: ignore[attr-defined] - sys.stderr.reconfigure(encoding=sys.getdefaultencoding()) # type: ignore[attr-defined] - sys.stdin.reconfigure(encoding=sys.getdefaultencoding()) # type: ignore[attr-defined] + sys.stdout.reconfigure(encoding=sys.getdefaultencoding()) # type: ignore[union-attr] + sys.stderr.reconfigure(encoding=sys.getdefaultencoding()) # type: ignore[union-attr] + sys.stdin.reconfigure(encoding=sys.getdefaultencoding()) # type: ignore[union-attr] def update_publishers(config: settngs.Config[ct_ns]) -> None: