Improve performance when re-tagging file based tags in zip archives

This commit is contained in:
Timmy Welch 2024-07-27 22:18:02 -07:00
parent d089c4bb6a
commit b8728c5eed
2 changed files with 89 additions and 54 deletions

View File

@ -4,7 +4,6 @@ import logging
import os
import pathlib
import shutil
import struct
import tempfile
import zipfile
from typing import cast
@ -59,13 +58,13 @@ class ZipArchiver(Archiver):
# zip archive w/o the indicated file. Very sucky, but maybe
# another solution can be found
files = self.get_filename_list()
if archive_file in files:
if not self.rebuild([archive_file]):
return False
try:
# now just add the archive file as a new one
with zipfile.ZipFile(self.path, mode="a", allowZip64=True, compression=zipfile.ZIP_DEFLATED) as zf:
_patch_zipfile(zf)
if archive_file in files:
zf.remove(archive_file) # type: ignore
zf.writestr(archive_file, data)
return True
except (zipfile.BadZipfile, OSError) as e:
@ -125,7 +124,7 @@ class ZipArchiver(Archiver):
# preserve the old comment
comment = other_archive.get_comment()
if comment is not None:
if not self.write_zip_comment(self.path, comment):
if not self.set_comment(comment):
return False
except Exception as e:
logger.error("Error while copying to zip archive [%s]: from %s to %s", e, other_archive.path, self.path)
@ -146,59 +145,95 @@ class ZipArchiver(Archiver):
def is_valid(cls, path: pathlib.Path) -> bool:
return zipfile.is_zipfile(path)
def write_zip_comment(self, filename: pathlib.Path | str, comment: str) -> bool:
"""
This is a custom function for writing a comment to a zip file,
since the built-in one doesn't seem to work on Windows and Mac OS/X
Fortunately, the zip comment is at the end of the file, and it's
easy to manipulate. See this website for more info:
see: http://en.wikipedia.org/wiki/Zip_(file_format)#Structure
"""
def _patch_zipfile(zf): # type: ignore
zf.remove = _zip_remove.__get__(zf, zipfile.ZipFile)
zf._remove_members = _zip_remove_members.__get__(zf, zipfile.ZipFile)
# get file size
statinfo = os.stat(filename)
file_length = statinfo.st_size
def _zip_remove(self, zinfo_or_arcname): # type: ignore
"""Remove a member from the archive."""
if self.mode not in ("w", "x", "a"):
raise ValueError("remove() requires mode 'w', 'x', or 'a'")
if not self.fp:
raise ValueError("Attempt to write to ZIP archive that was already closed")
if self._writing:
raise ValueError("Can't write to ZIP archive while an open writing handle exists")
# Make sure we have an existing info object
if isinstance(zinfo_or_arcname, zipfile.ZipInfo):
zinfo = zinfo_or_arcname
# make sure zinfo exists
if zinfo not in self.filelist:
raise KeyError("There is no item %r in the archive" % zinfo_or_arcname)
else:
# get the info object
zinfo = self.getinfo(zinfo_or_arcname)
return self._remove_members({zinfo})
def _zip_remove_members(self, members, *, remove_physical=True, chunk_size=2**20): # type: ignore
"""Remove members in a zip file.
All members (as zinfo) should exist in the zip; otherwise the zip file
will erroneously end in an inconsistent state.
"""
fp = self.fp
entry_offset = 0
member_seen = False
# get a sorted filelist by header offset, in case the dir order
# doesn't match the actual entry order
filelist = sorted(self.filelist, key=lambda x: x.header_offset)
for i in range(len(filelist)):
info = filelist[i]
is_member = info in members
if not (member_seen or is_member):
continue
# get the total size of the entry
try:
with open(filename, mode="r+b") as file:
# the starting position, relative to EOF
pos = -4
found = False
offset = filelist[i + 1].header_offset
except IndexError:
offset = self.start_dir
entry_size = offset - info.header_offset
# walk backwards to find the "End of Central Directory" record
while (not found) and (-pos != file_length):
# seek, relative to EOF
file.seek(pos, 2)
value = file.read(4)
if is_member:
member_seen = True
entry_offset += entry_size
# look for the end of central directory signature
if bytearray(value) == bytearray([0x50, 0x4B, 0x05, 0x06]):
found = True
else:
# not found, step back another byte
pos = pos - 1
# update caches
self.filelist.remove(info)
try:
del self.NameToInfo[info.filename]
except KeyError:
pass
continue
if found:
# now skip forward 20 bytes to the comment length word
pos += 20
file.seek(pos, 2)
# update the header and move entry data to the new position
if remove_physical:
old_header_offset = info.header_offset
info.header_offset -= entry_offset
read_size = 0
while read_size < entry_size:
fp.seek(old_header_offset + read_size)
data = fp.read(min(entry_size - read_size, chunk_size))
fp.seek(info.header_offset + read_size)
fp.write(data)
fp.flush()
read_size += len(data)
# Pack the length of the comment string
fmt = "H" # one 2-byte integer
comment_length = struct.pack(fmt, len(comment)) # pack integer in a binary string
# Avoid missing entry if entries have a duplicated name.
# Reverse the order as NameToInfo normally stores the last added one.
for info in reversed(self.filelist):
self.NameToInfo.setdefault(info.filename, info)
# write out the length
file.write(comment_length)
file.seek(pos + 2, 2)
# update state
if remove_physical:
self.start_dir -= entry_offset
self._didModify = True
# write out the comment itself
file.write(comment.encode("utf-8"))
file.truncate()
else:
raise Exception("Could not find the End of Central Directory record!")
except Exception as e:
logger.error("Error writing comment to zip archive [%s]: %s", e, self.path)
return False
else:
return True
# seek to the start of the central dir
fp.seek(self.start_dir)

View File

@ -89,9 +89,9 @@ def configure_locale() -> None:
os.environ["LANG"] = f"{code}.utf-8"
locale.setlocale(locale.LC_ALL, "")
sys.stdout.reconfigure(encoding=sys.getdefaultencoding()) # type: ignore[attr-defined]
sys.stderr.reconfigure(encoding=sys.getdefaultencoding()) # type: ignore[attr-defined]
sys.stdin.reconfigure(encoding=sys.getdefaultencoding()) # type: ignore[attr-defined]
sys.stdout.reconfigure(encoding=sys.getdefaultencoding()) # type: ignore[union-attr]
sys.stderr.reconfigure(encoding=sys.getdefaultencoding()) # type: ignore[union-attr]
sys.stdin.reconfigure(encoding=sys.getdefaultencoding()) # type: ignore[union-attr]
def update_publishers(config: settngs.Config[ct_ns]) -> None: