#!/usr/bin python3 """ Utilities for working with images """ from __future__ import annotations import json import logging import os import struct import typing as T from ast import literal_eval from concurrent import futures from queue import Empty as QueueEmpty, Full as QueueFull, Queue from threading import current_thread, main_thread from zlib import crc32 import cv2 import numpy as np from lib.align.objects import PNGHeader from lib.logger import parse_class_init from lib.multithreading import FSThread from lib.utils import FaceswapError, get_image_paths, get_module_objects from lib.video import check_for_video, VideoReader if T.TYPE_CHECKING: import numpy.typing as npt from lib.multithreading import ErrorState logger = logging.getLogger(__name__) # Image I/O @T.overload def read_image(filename: str, raise_error: T.Literal[False] = False, with_metadata: T.Literal[False] = False) -> npt.NDArray[np.uint8] | None: ... @T.overload def read_image(filename: str, raise_error: T.Literal[True], with_metadata: T.Literal[False] = False) -> npt.NDArray[np.uint8]: ... @T.overload def read_image(filename: str, raise_error: T.Literal[False] = False, *, with_metadata: T.Literal[True]) -> tuple[npt.NDArray[np.uint8], PNGHeader]: ... @T.overload def read_image(filename: str, raise_error: T.Literal[True], with_metadata: T.Literal[True]) -> npt.NDArray[np.uint8]: ... def read_image(filename: str, # noqa[C901] # pylint:disable=too-many-statements,too-many-branches raise_error: bool = False, with_metadata: bool = False ) -> np.ndarray | None | tuple[npt.NDArray[np.uint8], PNGHeader]: """Read an image file from a file location. Extends the functionality of :func:`cv2.imread()` by ensuring that an image was actually loaded. Errors can be logged and ignored so that the process can continue on an image load failure. Parameters ---------- filename Full path to the image to be loaded. raise_error If ``True`` then any failures (including the returned image being ``None``) will be raised. 
If ``False`` then an error message will be logged, but the error will not be raised. Default: ``False`` with_metadata Only returns a value if the images loaded are extracted Faceswap faces. If ``True`` then returns the Faceswap metadata stored with in a Face images .png EXIF header. Default: ``False`` Returns ------- image The image in `BGR` channel order as UINT8 for the corresponding :attr:`filename` metadata The faceswap metadata corresponding to the image. Only returned if `with_metadata` is ``True`` Example ------- >>> image_file = "/path/to/image.png" >>> try: >>> image = read_image(image_file, raise_error=True, with_metadata=False) >>> except: >>> raise ValueError("There was an error") """ logger.trace("Requested image: '%s'", filename) # type:ignore[attr-defined] success = True image = None retval: np.ndarray | tuple[np.ndarray, PNGHeader] | None = None try: with open(filename, "rb") as in_file: raw_file = in_file.read() image = cv2.imdecode(np.frombuffer(raw_file, dtype=np.uint8), cv2.IMREAD_UNCHANGED) if image is None: raise ValueError("Image is None") if image.ndim == 2: # Convert grayscale to BGR image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) elif image.ndim == 2 and image.shape[2] == 4: # Strip mask image = image[:, :, :3] if np.issubdtype(image.dtype, np.integer): info = np.iinfo(T.cast(np.integer, image.dtype)) # Scale non UINT8 INT images to UINT8 if info.max != 255: image = image.astype(np.float32) / info.max * 255.0 elif np.issubdtype(image.dtype, np.floating): # Just naively clip floating images to 0-1 for now image = (np.clip(image, 0.0, 1.0) * 255.).astype(np.float32) if image.dtype != np.uint8: image = np.clip(image, 0, 255).astype(np.uint8) if with_metadata: metadata = png_read_meta(raw_file) assert isinstance(metadata, PNGHeader) retval = (image, metadata) else: retval = image except TypeError as err: success = False msg = f"Error while reading image (TypeError): '{filename}'" msg += f". 
Original error message: {str(err)}" logger.error(msg) if raise_error: raise TypeError(msg) from err except ValueError as err: success = False msg = ("Error while reading image. This can be caused by special characters in the " f"filename or a corrupt image file: '{filename}'") msg += f". Original error message: {str(err)}" logger.error(msg) if raise_error: raise ValueError(msg) from err except Exception as err: # pylint:disable=broad-except success = False msg = f"Failed to load image '{filename}'. Original Error: {str(err)}" logger.error(msg) if raise_error: raise Exception(msg) from err # pylint:disable=broad-exception-raised logger.trace("Loaded image: '%s'. Success: %s", filename, success) # type:ignore[attr-defined] return retval @T.overload def read_image_batch(filenames: list[str], with_metadata: T.Literal[False] = False ) -> np.ndarray: ... @T.overload def read_image_batch(filenames: list[str], with_metadata: T.Literal[True] ) -> tuple[np.ndarray, list[PNGHeader]]: ... def read_image_batch(filenames: list[str], with_metadata: bool = False ) -> np.ndarray | tuple[np.ndarray, list[PNGHeader]]: """Load a batch of images from the given file locations. Leverages multi-threading to load multiple images from disk at the same time leading to vastly reduced image read times. Parameters ---------- filenames A of full paths to the images to be loaded. with_metadata Only returns a value if the images loaded are extracted Faceswap faces. If ``True`` then returns the Faceswap metadata stored within each Face's .png exif header. Default: ``False`` Returns ------- batch The batch of images in `BGR` channel order returned in the order of :attr:`filenames` metadata The faceswap metadata corresponding to each image in the batch. 
Only returned if `with_metadata` is ``True``

    Notes
    -----
    As the images are compiled into a batch, they should all be of the same dimensions. Images of
    differing dimensions cannot be stacked into a regular array by :func:`numpy.array`
    (depending on the numpy version this either raises an error or produces a ragged object
    array). Images are read with ``raise_error=True``, so any image that fails to load raises.

    Example
    -------
    >>> image_filenames = ["/path/to/image_1.png", "/path/to/image_2.png", "/path/to/image_3.png"]
    >>> images = read_image_batch(image_filenames)
    >>> print(images.shape)
    ... (3, 64, 64, 3)
    >>> images, metadata = read_image_batch(image_filenames, with_metadata=True)
    >>> print(images.shape)
    ... (3, 64, 64, 3)
    >>> print(len(metadata))
    ... 3
    """
    logger.trace("Requested batch: '%s'", filenames)  # type:ignore[attr-defined]
    batch: list[np.ndarray | None] = [None for _ in range(len(filenames))]
    meta: list[PNGHeader | None] = [None for _ in range(len(filenames))]

    with futures.ThreadPoolExecutor() as executor:
        # Map each future to its filename's index: as_completed yields in completion order, so
        # the index is used to put each result back in input order
        images = {executor.submit(  # NOTE submit strips positionals, breaking type-checking
            read_image,  # type:ignore[arg-type]
            filename,
            raise_error=True,  # pyright:ignore[reportArgumentType]
            with_metadata=with_metadata): idx  # pyright:ignore[reportArgumentType]
            for idx, filename in enumerate(filenames)}
        for future in futures.as_completed(images):
            result = T.cast(np.ndarray | tuple[np.ndarray, "PNGHeader"], future.result())
            ret_idx = images[future]
            if with_metadata:
                assert isinstance(result, tuple)
                batch[ret_idx], meta[ret_idx] = result
            else:
                assert isinstance(result, np.ndarray)
                batch[ret_idx] = result

    arr_batch = np.array(batch)
    retval: np.ndarray | tuple[np.ndarray, list[PNGHeader]]
    if with_metadata:
        retval = (arr_batch, T.cast(list["PNGHeader"], meta))
    else:
        retval = arr_batch
    logger.trace(  # type:ignore[attr-defined]
        "Returning images: (filenames: %s, batch shape: %s, with_metadata: %s)",
        filenames, arr_batch.shape, with_metadata)
    return retval


def read_image_meta(filename):
    """ Read the Faceswap metadata stored in an extracted face's png header.

    For non-png files only the image dimensions are returned.

    Parameters
    ----------
    filename: str
        Full path to the image to retrieve the meta information for.
Returns ------- dict The output dictionary will contain the `width` and `height` of the png image as well as any `itxt` information. Example ------- >>> image_file = "/path/to/image.png" >>> metadata = read_image_meta(image_file) >>> width = metadata["width] >>> height = metadata["height"] >>> faceswap_info = metadata["itxt"] """ retval = {} if os.path.splitext(filename)[-1].lower() != ".png": # Get the dimensions directly from the image for non-png logger.trace( # type:ignore[attr-defined] "Non png found. Loading file for dimensions: '%s'", filename) img = cv2.imread(filename) assert img is not None retval["height"], retval["width"] = img.shape[:2] return retval with open(filename, "rb") as in_file: try: chunk = in_file.read(8) except PermissionError as exc: raise PermissionError(f"PermissionError while reading: {filename}") from exc if chunk != b"\x89PNG\r\n\x1a\n": raise ValueError(f"Invalid header found in png: {filename}") while True: chunk = in_file.read(8) length, field = struct.unpack(">I4s", chunk) logger.trace( # type:ignore[attr-defined] "Read chunk: (chunk: %s, length: %s, field: %s", chunk, length, field) if not chunk or field == b"IDAT": break if field == b"IHDR": # Get dimensions chunk = in_file.read(8) retval["width"], retval["height"] = struct.unpack(">II", chunk) length -= 8 elif field == b"iTXt": keyword, value = in_file.read(length).split(b"\0", 1) if keyword == b"faceswap": retval["itxt"] = literal_eval(value[4:].decode("utf-8", errors="replace")) break logger.trace("Skipping iTXt chunk: '%s'", # type:ignore[attr-defined] keyword.decode("latin-1", errors="ignore")) length = 0 # Reset marker for next chunk in_file.seek(length + 4, 1) logger.trace("filename: %s, metadata: %s", filename, retval) # type:ignore[attr-defined] return retval def read_image_meta_batch(filenames): """ Read the Faceswap metadata stored in a batch extracted faces' exif headers. 
Leverages multi-threading to load multiple images from disk at the same time
    leading to vastly reduced image read times. Creates a generator to retrieve filenames
    with their metadata as they are calculated.

    Notes
    -----
    The order of returned values is non-deterministic so will most likely not be returned in the
    same order as the filenames. Any exception raised while reading a file is re-raised when its
    result is retrieved.

    Parameters
    ----------
    filenames: list
        A list of ``str`` full paths to the images to be loaded.

    Yields
    -------
    tuple
        (**filename** (`str`), **metadata** (`dict`) )

    Example
    -------
    >>> image_filenames = ["/path/to/image_1.png", "/path/to/image_2.png", "/path/to/image_3.png"]
    >>> for filename, meta in read_image_meta_batch(image_filenames):
    >>>     ...  # process each (filename, metadata) pair as it completes
    """
    logger.trace("Requested batch: '%s'", filenames)  # type:ignore[attr-defined]
    executor = futures.ThreadPoolExecutor()
    with executor:
        logger.debug("Submitting %s items to executor", len(filenames))
        read_meta = {executor.submit(read_image_meta, filename): filename
                     for filename in filenames}
        logger.debug("Successfully submitted %s items to executor", len(filenames))

        # Yield results as soon as each read finishes rather than in input order
        for future in futures.as_completed(read_meta):
            retval = (read_meta[future], future.result())
            logger.trace("Yielding: %s", retval)  # type:ignore[attr-defined]
            yield retval


def pack_to_itxt(metadata: PNGHeader | dict[str, T.Any] | bytes) -> bytes:
    """ Pack the given metadata dictionary to a PNG iTXt header field.

    Parameters
    ----------
    metadata
        The dictionary to write to the header. Can be pre-encoded as utf-8.
Returns ------- A byte encoded PNG iTXt field, including chunk header and CRC """ if isinstance(metadata, PNGHeader): metadata = metadata.to_dict() if not isinstance(metadata, bytes): metadata = str(metadata).encode("utf-8", "strict") key = "faceswap".encode("latin-1", "strict") chunk = key + b"\0\0\0\0\0" + metadata crc = struct.pack(">I", crc32(chunk, crc32(b"iTXt")) & 0xFFFFFFFF) length = struct.pack(">I", len(chunk)) retval = length + b"iTXt" + chunk + crc return retval def update_existing_metadata(filename: str, metadata: PNGHeader | bytes) -> None: """ Update the png header metadata for an existing .png extracted face file on the filesystem. Parameters ---------- filename The full path to the face to be updated metadata The dictionary to write to the header. Can be pre-encoded as utf-8. """ if not isinstance(metadata, bytes): metadata = str(metadata.to_dict()).encode("utf-8", errors="strict") tmp_filename = filename + "~" with open(filename, "rb") as png, open(tmp_filename, "wb") as tmp: chunk = png.read(8) if chunk != b"\x89PNG\r\n\x1a\n": raise ValueError(f"Invalid header found in png: {filename}") tmp.write(chunk) while True: chunk = png.read(8) length, field = struct.unpack(">I4s", chunk) logger.trace( # type:ignore[attr-defined] "Read chunk: (chunk: %s, length: %s, field: %s)", chunk, length, field) if field == b"IDAT": # Write out all remaining data logger.trace("Writing image data and closing png") # type:ignore[attr-defined] tmp.write(chunk + png.read()) break if field != b"iTXt": # Write non iTXt chunk straight out logger.trace("Copying existing chunk") # type:ignore[attr-defined] tmp.write(chunk + png.read(length + 4)) # Header + CRC continue keyword, value = png.read(length).split(b"\0", 1) if keyword != b"faceswap": # Write existing non fs-iTXt data + CRC logger.trace("Copying non-faceswap iTXt chunk: %s", # type:ignore[attr-defined] keyword) tmp.write(keyword + b"\0" + value + png.read(4)) continue logger.trace("Updating faceswap iTXt chunk") # 
type:ignore[attr-defined] tmp.write(pack_to_itxt(metadata)) png.seek(4, 1) # Skip old CRC os.replace(tmp_filename, filename) def encode_image(image: np.ndarray, extension: str, encoding_args: tuple[int, ...] | None = None, metadata: PNGHeader | dict[str, T.Any] | bytes | None = None) -> bytes: """Encode an image. Parameters ---------- image The image to be encoded in `BGR` channel order. extension A compatible `cv2` image file extension that the final image is to be saved to. encoding_args Any encoding arguments to pass to cv2's imencode function metadata Metadata for the image. If provided, and the extension is png or tiff, this information will be written to the PNG itxt header. Default:``None`` Can be provided as a python dict or pre-encoded Returns ------- encoded_image: bytes The image encoded into the correct file format as bytes Example ------- >>> image_file = "/path/to/image.png" >>> image = read_image(image_file) >>> encoded_image = encode_image(image, ".jpg") """ if metadata and extension.lower() not in (".png", ".tif"): raise ValueError("Metadata is only supported for .png and .tif images") args = tuple() if encoding_args is None else encoding_args retval = cv2.imencode(extension, image, args)[1].tobytes() if metadata: func = {".png": png_write_meta, ".tif": tiff_write_meta}[extension] retval = func(retval, metadata) return retval def png_write_meta(image: bytes, data: PNGHeader | dict[str, T.Any] | bytes) -> bytes: """Write Faceswap information to a png's iTXt field. Parameters ---------- image The bytes encoded png file to write header data to data The dictionary to write to the header. Can be pre-encoded as utf-8. Notes ----- This is a fairly stripped down and non-robust header writer to fit a very specific task. OpenCV will not write any iTXt headers to the PNG file, so we make the assumption that the only iTXt header that exists is the one that we created for storing alignments. 
References ---------- PNG Specification: https://www.w3.org/TR/2003/REC-PNG-20031110/ """ split = image.find(b"IDAT") - 4 retval = image[:split] + pack_to_itxt(data) + image[split:] return retval def tiff_write_meta(image: bytes, # pylint:disable=too-many-locals data: PNGHeader | dict[str, T.Any] | bytes) -> bytes: """Write Faceswap information to a tiff's image_description field. Parameters ---------- image The bytes encoded tiff file to write header data to data The data to write to the image-description field. If provided as a dict, then it should be a json serializable object, otherwise it should be data encoded as ascii bytes Notes ----- This handles a very specific task of adding, and populating, an ImageDescription field in a Tiff file generated by OpenCV. For any other use cases it will likely fail """ if isinstance(data, PNGHeader): data = data.to_dict() if not isinstance(data, bytes): data = json.dumps(data, ensure_ascii=True).encode("ascii") assert image[:2] == b"II", "Not a supported TIFF file" assert struct.unpack(" 270: insert_idx = i # Log insert location of image description if size <= 4: # value in offset column ifd += tag continue ifd += tag[:8] tag_offset = struct.unpack(" dict[str, T.Any]: # pylint:disable=too-many-locals """ Read information stored in a Tiff's Image Description field Returns ------- dict[str, Any] Any arbitrary information stored in the TIFF header (for example matrix information for the patch writer) """ assert image[:2] == b"II", "Not a supported TIFF file" assert struct.unpack(" PNGHeader | dict[str, T.Any]: """ Read the Faceswap information stored in a png's iTXt field. Parameters ---------- image The bytes encoded png file to read header data from Returns ------- The Faceswap information stored in the PNG header.
This will either be a PNGHeader object if an extracted face, or other arbitrary information
    (for example for the Patch Writer)

    Notes
    -----
    This is a very stripped down, non-robust and non-secure header reader to fit a very specific
    task. OpenCV will not write any iTXt headers to the PNG file, so we make the assumption that
    the only iTXt header that exists is the one that Faceswap created for storing alignments.

    NOTE(review): As written, only a :class:`PNGHeader` is ever produced, and the final ``assert``
    fails if no ``faceswap`` iTXt chunk is found. Confirm whether any caller expects a plain dict
    from this function.
    """
    retval: PNGHeader | dict[str, T.Any] | None = None
    pointer = 0
    while True:
        # Step back 4 bytes from the chunk type to the chunk's length field. ``find`` returns -1
        # when no further iTXt chunk exists, which makes the pointer negative
        pointer = image.find(b"iTXt", pointer) - 4
        if pointer < 0:
            logger.trace("No metadata in png")  # type:ignore[attr-defined]
            break
        length = struct.unpack(">I", image[pointer:pointer + 4])[0]
        pointer += 8  # Skip the length and type fields to the start of the chunk data
        keyword, value = image[pointer:pointer + length].split(b"\0", 1)
        if keyword == b"faceswap":
            # value[4:] skips the compression flag, compression method and the empty language tag
            # and translated keyword terminators, leaving the text payload
            retval = PNGHeader.from_dict(literal_eval(value[4:].decode("utf-8", errors="ignore")))
            break
        logger.trace("Skipping iTXt chunk: '%s'",  # type:ignore[attr-defined]
                     keyword.decode("latin-1", errors="ignore"))
        pointer += length + 4  # Skip the chunk data and its CRC
    assert retval is not None
    return retval


def generate_thumbnail(image, size=96, quality=60):
    """ Generate a jpg thumbnail for the given image.
Parameters ---------- image: :class:`numpy.ndarray` Three channel BGR image to convert to a jpg thumbnail size: int The width and height, in pixels, that the thumbnail should be generated at quality: int The jpg quality setting to use Returns ------- :class:`numpy.ndarray` The given image encoded to a jpg at the given size and quality settings """ logger.trace("Input shape: %s, size: %s, quality: %s", # type:ignore[attr-defined] image.shape, size, quality) orig_size = image.shape[0] if orig_size != size: interpolator = cv2.INTER_AREA if orig_size > size else cv2.INTER_CUBIC image = cv2.resize(image, (size, size), interpolation=interpolator) retval = cv2.imencode(".jpg", image, [cv2.IMWRITE_JPEG_QUALITY, quality])[1] logger.trace("Output shape: %s", retval.shape) # type:ignore[attr-defined] return retval def batch_convert_color(batch, color_space): """ Convert a batch of images from one color space to another. Converts a batch of images by reshaping the batch prior to conversion rather than iterating over the images. This leads to a significant speed up in the convert process. Parameters ---------- batch: numpy.ndarray A batch of images. color_space: str The OpenCV Color Conversion Code suffix. For example for BGR to LAB this would be ``'BGR2LAB'``. See https://docs.opencv.org/4.1.1/d8/d01/group__imgproc__color__conversions.html for a full list of color codes. Returns ------- numpy.ndarray The batch converted to the requested color space. Example ------- >>> images_bgr = numpy.array([image1, image2, image3]) >>> images_lab = batch_convert_color(images_bgr, "BGR2LAB") Notes ----- This function is only compatible for color space conversions that have the same image shape for source and destination color spaces. If you use :func:`batch_convert_color` with 8-bit images, the conversion will have some information lost. 
For many cases, this will not be noticeable but it is recommended
    to use 32-bit images in cases that need the full range of colors or that convert an image
    before an operation and then convert back.
    """
    logger.trace(  # type:ignore[attr-defined]
        "Batch converting: (batch shape: %s, color_space: %s)", batch.shape, color_space)
    original_shape = batch.shape
    # Fold the batch axis into the first image axis so cv2 converts one tall image in a single
    # call, then restore the original shape afterwards
    batch = batch.reshape((original_shape[0] * original_shape[1], *original_shape[2:]))
    batch = cv2.cvtColor(batch, getattr(cv2, f"COLOR_{color_space}"))
    return batch.reshape(original_shape)


def hex_to_rgb(hex_code):
    """ Convert a hex number to its RGB counterpart.

    Parameters
    ----------
    hex_code: str
        The hex code to convert (e.g. `"#0d25ac"`)

    Returns
    -------
    tuple
        The hex code as a 3 integer (`R`, `G`, `B`) tuple

    Notes
    -----
    After stripping any leading ``#``, each channel is read from ``len(code) // 3`` hex digits. A
    6 digit code therefore gives 0-255 values. Shorthand 3 digit codes (e.g. ``"#fff"``) are not
    expanded and give 0-15 values.
    """
    value = hex_code.lstrip("#")
    chars = len(value)
    return tuple(int(value[i:i + chars // 3], 16) for i in range(0, chars, chars // 3))


def rgb_to_hex(rgb):
    """ Convert an RGB tuple to its hex counterpart.

    Parameters
    ----------
    rgb: tuple
        The (`R`, `G`, `B`) integer values to convert (e.g. `(0, 255, 255)`)

    Returns
    -------
    str:
        The 6 digit hex code with leading `#` applied
    """
    return f"#{rgb[0]:02x}{rgb[1]:02x}{rgb[2]:02x}"


# ################### #
# <<< VIDEO UTILS >>> #
# ################### #


class ImageIO():
    """ Perform disk IO for images or videos in a background thread.

    This is the parent class for :class:`ImagesLoader` and :class:`ImagesSaver` and should not
    be called directly.

    Parameters
    ----------
    path: str or list
        The path to load or save images to/from. For loading this can be a folder which contains
        images, video file or a list of image files. For saving this must be an existing folder.
    queue_size: int
        The amount of images to hold in the internal buffer.
    args: tuple, optional
        The arguments to be passed to the loader or saver thread. Default: ``None``

    See Also
    --------
    lib.image.ImagesLoader : Background Image Loader inheriting from this class.
lib.image.ImagesSaver : Background Image Saver inheriting from this class.
    """

    def __init__(self, path, queue_size, args=None):
        logger.debug(parse_class_init(locals()))
        self._name = self.__class__.__name__
        self._args = tuple() if args is None else args
        self._location = path
        self._check_location_exists()

        # Bounded buffer shared with the background thread
        self._queue = Queue(maxsize=queue_size)
        self._thread = None
        # Populated from the background thread's error state in _set_thread
        self._error_state: ErrorState | None = None

    @property
    def location(self):
        """ str: The folder or video that was passed in as the :attr:`path` parameter. """
        return self._location

    def _check_location_exists(self):
        """ Check whether the input location exists.

        Raises
        ------
        FaceswapError
            If the given location does not exist, or if a list/tuple of locations is given and
            any of them does not exist
        """
        if isinstance(self.location, str) and not os.path.exists(self.location):
            raise FaceswapError(f"The location '{self.location}' does not exist")
        if isinstance(self.location, (list, tuple)) and not all(os.path.exists(location)
                                                                for location in self.location):
            raise FaceswapError("Not all locations in the input list exist")

    def _set_thread(self):
        """ Set the background thread for the load and save iterators and launch it.

        Does nothing if a thread has already been set and is still alive.
        """
        logger.trace("[%s] Setting thread", self._name)  # type:ignore[attr-defined]
        if self._thread is not None and self._thread.is_alive():
            logger.trace("[%s] Thread pre-exists and is alive: %s",  # type:ignore[attr-defined]
                         self._name, self._thread)
            return
        self._thread = FSThread(self._process,
                                name=self.__class__.__name__,
                                args=(self._queue, ))
        self._error_state = self._thread.error_state
        logger.debug("[%s] Set thread: %s", self._name, self._thread)
        self._thread.start()

    def _process(self, queue):
        """ Image IO process to be run in a thread. Override for loader/saver process.
Parameters
        ----------
        queue: queue.Queue()
            The ImageIO Queue
        """
        raise NotImplementedError

    def close(self):
        """ Closes down and joins the internal threads """
        logger.debug("[%s] Received Close", self._name)
        if self._thread is not None:
            self._thread.join()
        del self._thread
        self._thread = None
        logger.debug("[%s] Closed", self._name)


class ImagesLoader(ImageIO):
    """Perform image loading from a folder of images or a video.

    Images will be loaded and returned in the order that they appear in the folder, or in the video
    to ensure deterministic ordering. Loading occurs in a background thread, caching up to
    ``queue_size`` images (8 by default) so that other processes do not need to wait on disk reads.

    See also :class:`ImageIO` for additional attributes.

    Parameters
    ----------
    path
        The path to load images from. This can be a folder which contains images, a video file or
        a list of image files.
    queue_size
        The amount of images to hold in the internal buffer. Default: 8.
    fast_count
        When loading from video, the video needs to be parsed frame by frame to get an accurate
        count. This can be done quite quickly without guaranteed accuracy, or slower with
        guaranteed accuracy. Set to ``True`` to count quickly, or ``False`` to count slower
        but accurately. Default: ``True``.
    skip_list
        Optional list of frame/image indices to not load. Any indices provided here will be skipped
        when executing the :func:`load` function from the given location. Default: ``None``
    count
        If the number of images that the loader will encounter is already known, it can be passed
        in here to skip the image counting step, which can save time at launch. Set to ``None`` if
        the count is not already known. Default: ``None``
    pts
        The Presentation Timestamps if the source is a video and they are available.
        Default: ``None``
    keyframes
        The Keyframes if the source is a video and they are available.
Default: ``None``

    Examples
    --------
    Loading from a video file:

    >>> loader = ImagesLoader('/path/to/video.mp4')
    >>> for filename, image in loader.load():
    >>>     ...  # process each frame
    """

    def __init__(self,
                 path: str | list[str],
                 queue_size: int = 8,
                 fast_count: bool = True,
                 skip_list: list[int] | None = None,
                 count: int | None = None,
                 pts: list[int] | None = None,
                 keyframes: list[int] | None = None) -> None:
        logger.debug(parse_class_init(locals()))
        super().__init__(path, queue_size=queue_size)
        # Stored as a set for fast membership tests while iterating
        self._skip_list = set() if skip_list is None else set(skip_list)
        self._is_video = check_for_video(self.location)
        self._count: int | None = None
        self._file_list: list[str] = []
        # A VideoReader is only created for video sources
        self._reader = VideoReader(self.location,
                                   fast_count=fast_count,
                                   pts=pts,
                                   keyframes=keyframes) if self._is_video else None
        self._get_count_and_filelist(count)

    @property
    def count(self) -> int:
        """The number of images or video frames in the source location. This count includes any
        files that will ultimately be skipped if a :attr:`skip_list` has been provided. See also
        :attr:`process_count`"""
        assert self._count is not None
        return self._count

    @property
    def process_count(self) -> int:
        """The number of images or video frames to be processed (IE the total count less items
        that are to be skipped from the :attr:`skip_list`). Every skip-list entry is subtracted,
        so this assumes all skip-list indices fall within :attr:`count`"""
        return self.count - len(self._skip_list)

    @property
    def is_video(self) -> bool:
        """``True`` if the input is a video, ``False`` if it is not"""
        return self._is_video

    @property
    def file_list(self) -> list[str]:
        """A full list of files in the source location. This includes any files that will
        ultimately be skipped if a :attr:`skip_list` has been provided.
If the input is a video then this is a list of dummy filenames as corresponding to an
        alignments file
        """
        return self._file_list

    @property
    def processed_file_list(self) -> list[str]:
        """A list of files in the source location with any files that will be skipped removed"""
        return [f for i, f in enumerate(self._file_list) if i not in self._skip_list]

    def add_skip_list(self, skip_list: list[int]) -> None:
        """Add a skip list to this :class:`ImagesLoader`. Any existing skip list is replaced.

        Parameters
        ----------
        skip_list
            A list of indices corresponding to the frame indices that should be skipped by the
            :func:`load` function.
        """
        logger.debug("[%s] skip_list: %s", self._name, skip_list)
        self._skip_list = set(skip_list)

    def _get_count_and_filelist(self, count: int | None) -> None:
        """Set the count of images to be processed and set the file list.

        If the input is a video, a dummy file list is created for checking against an
        alignments file, otherwise it will be a list of full filenames.

        Parameters
        ----------
        count: int
            The number of images that the loader will encounter if already known, otherwise
            ``None``
        """
        if self._is_video:
            assert self._reader is not None
            self._count = len(self._reader)
            self._file_list = [self._dummy_video_frame_name(i) for i in range(self.count)]
        else:
            if isinstance(self.location, (list, tuple)):
                self._file_list = list(self.location)
            else:
                self._file_list = get_image_paths(self.location)
            self._count = len(self.file_list) if count is None else count

        logger.debug("[%s] count: %s", self._name, self.count)
        logger.trace("[%s] file_list: %s", self._name, self.file_list)  # type:ignore[attr-defined]

    def _process(self, queue: Queue) -> None:
        """The load thread.
Loads from a folder of images or from a video and puts to a queue. An ``"EOF"`` sentinel
        is put once all items have been queued.

        Parameters
        ----------
        queue
            The ImageIO Queue
        """
        iterator = self._from_video if self._is_video else self._from_folder
        logger.debug("[%s] Load iterator: %s", self._name, iterator)
        assert self._error_state is not None
        for retval in iterator():
            filename, image = retval[:2]
            if image is None or (not image.any() and image.ndim not in (2, 3)):
                # All black frames will return not numpy.any() so check dims too
                logger.warning("Unable to open image. Skipping: '%s'", filename)
                continue
            logger.trace("[%s] Putting to queue: %s",  # type:ignore[attr-defined]
                         self._name, [v.shape if isinstance(v, np.ndarray) else v for v in retval])
            # Put with a short timeout so that an error flagged in the error state can still
            # stop this thread while the queue is full
            while True:
                if self._error_state.has_error:
                    logger.debug("[%s] Thread error detected in worker thread", self._name)
                    return
                try:
                    queue.put(retval, timeout=0.2)
                    break
                except QueueFull:
                    logger.trace("[%s] Queue full. Waiting",  # type:ignore[attr-defined]
                                 self._name)
                    continue
        logger.trace("[%s] Putting EOF", self._name)  # type:ignore[attr-defined]
        queue.put("EOF")

    def _dummy_video_frame_name(self, index: int) -> str:
        """Return a dummy filename for video files. The file name is made up of the video's base
        name, an underscore, the 1-based frame number zero padded to 6 digits and the video's
        extension, e.g. ``my_video_000001.mp4``

        Notes
        -----
        Indexes start at 0, frame numbers start at 1, so index is incremented by 1 when creating
        the filename

        Parameters
        ----------
        index
            The index number for the frame in the video file

        Returns
        -------
        A dummied filename for a video frame
        """
        vid_name, ext = os.path.splitext(os.path.basename(self.location))
        return f"{vid_name}_{index + 1:06d}{ext}"

    def _from_video(self) -> T.Generator[tuple[str, npt.NDArray[np.uint8]], None, None]:
        """Generator for loading frames from a video

        Yields
        ------
        filename
            The dummy filename of the loaded video frame.
        image
            The loaded video frame.
""" assert self._reader is not None logger.debug("[%s] Loading frames from video: '%s'", self._name, self.location) for idx, frame in enumerate(self._reader): if idx in self._skip_list: logger.trace( # type:ignore[attr-defined] "[%s] Skipping frame %s due to skip list", self._name, idx) continue image = T.cast("npt.NDArray[np.uint8]", frame.to_ndarray(channel_last=True, format="bgr24")) filename = self._dummy_video_frame_name(idx) logger.trace("[%s] Loading video frame: '%s'", # type:ignore[attr-defined] self._name, filename) yield filename, image def _from_folder(self) -> T.Generator[tuple[str, npt.NDArray[np.uint8]] | tuple[str, npt.NDArray[np.uint8], PNGHeader], None, None]: """Generator for loading images from a folder Yields ------ filename The filename of the loaded image. image The loaded image. metadata The Faceswap metadata associated with the loaded image. (:class:`FacesLoader` only) """ logger.debug("[%s] Loading frames from folder: '%s'", self._name, self.location) for idx, filename in enumerate(self.file_list): if idx in self._skip_list: logger.trace( # type:ignore[attr-defined] "[%s] Skipping frame %s due to skip list", self._name, filename) continue image_read = read_image(filename, raise_error=False) if image_read is None: logger.warning("Frame not loaded: '%s'", filename) continue yield filename, image_read def load(self) -> T.Generator[tuple[str, npt.NDArray[np.uint8]] | tuple[str, npt.NDArray[np.uint8], PNGHeader], None, None]: """Generator for loading images from the given :attr:`location` If :class:`FacesLoader` is in use then the Faceswap metadata of the image stored in the image exif file is added as the final item in the output `tuple`. Yields ------ filename The filename of the loaded image. image The loaded image. metadata The Faceswap metadata associated with the loaded image. 
(:class:`FacesLoader` only) """ logger.debug("[%s] Initializing Load Generator", self._name) self._set_thread() assert self._error_state is not None while True: if self._error_state.has_error: current = current_thread() if current is main_thread(): self._error_state.re_raise() else: logger.debug("[%s.%s] Thread error detected in worker thread", current.name, self._name) break try: retval = self._queue.get(True, 1) except QueueEmpty: continue if retval == "EOF": logger.trace("[%s] Got EOF", self._name) # type:ignore[attr-defined] break logger.trace("[%s] Yielding: %s", # type:ignore[attr-defined] self._name, [v.shape if isinstance(v, np.ndarray) else v for v in retval]) yield retval logger.debug("[%s] Closing Load Generator", self._name) self.close() class FacesLoader(ImagesLoader): """ Loads faces from a faces folder along with the face's Faceswap metadata. Examples -------- Loading faces with their Faceswap metadata: >>> loader = FacesLoader('/path/to/faces/folder') >>> for filename, face, metadata in loader.load(): >>> """ def __init__(self, path, skip_list=None, count=None): logger.debug(parse_class_init(locals())) super().__init__(path, queue_size=8, skip_list=skip_list, count=count) def _get_count_and_filelist(self, count): """ Override default implementation to only return png files from the source folder Parameters ---------- count: int The number of images that the loader will encounter if already known, otherwise ``None`` """ if isinstance(self.location, (list, tuple)): file_list = self.location else: file_list = get_image_paths(self.location) self._file_list = [fname for fname in file_list if os.path.splitext(fname)[-1].lower() == ".png"] self._count = len(self.file_list) if count is None else count logger.debug("[%s] count: %s", self._name, self.count) logger.trace("[%s] file_list: %s", self._name, self.file_list) # type:ignore[attr-defined] def _from_folder(self): """ Generator for loading images from a folder Faces will only ever be loaded from a 
folder, so this is the only function requiring an override Yields ------ filename: str The filename of the loaded image. image: numpy.ndarray The loaded image. metadata: dict The Faceswap metadata associated with the loaded image. """ logger.debug("[%s] Loading images from folder: '%s'", self._name, self.location) for idx, filename in enumerate(self.file_list): if idx in self._skip_list: logger.trace( # type:ignore[attr-defined] "[%s] Skipping face %s due to skip list", self._name, idx) continue image_read = read_image(filename, raise_error=False, with_metadata=True) retval = filename, *image_read if retval[1] is None: logger.warning("Face not loaded: '%s'", filename) continue yield retval class SingleFrameLoader(ImagesLoader): """Allows direct access to a frame by filename or frame index. As we are interested in instant access to frames, there is no requirement to process in a background thread, as either way we need to wait for the frame to load. Parameters ---------- path Full path to the input media video_meta_data Existing video meta information containing the pts_time and is_key flags for the given video. Used in conjunction with single_frame_reader for faster seeks. Providing this means that the video does not need to be scanned again. Set to ``None`` if the video is to be scanned. 
Default: ``None`` """ def __init__(self, path: str, video_meta_data: dict[T.Literal["pts_time", "keyframes"], list[int]] | None = None ) -> None: logger.debug(parse_class_init(locals())) self._video_meta_data: dict[T.Literal["pts_time", "keyframes"], list[int]] | None = video_meta_data pts = None if video_meta_data is None else video_meta_data["pts_time"] keyframes = None if video_meta_data is None else video_meta_data["keyframes"] super().__init__(path, queue_size=1, fast_count=False, pts=pts, keyframes=keyframes) @property def video_meta_data(self) -> dict[T.Literal["pts_time", "keyframes"], list[int]] | None: """For videos contains the keys `frame_pts` holding a list of time stamps for each frame and `keyframes` holding the frame index of each key frame. Notes ----- Only populated if the input is a video and single frame reader is being used, otherwise returns ``None``. """ if self._reader is None: return None return {"pts_time": self._reader.info.pts.tolist(), "keyframes": self._reader.info.keyframes.tolist()} def image_from_index(self, index: int) -> tuple[str, npt.NDArray[np.uint8]]: """Return a single image from :attr:`file_list` for the given index. We do not use a background thread for this task, as it is assumed that requesting an image by index will be done when required. Parameters ---------- index: int The index number (frame number) of the frame to retrieve. 
NB: The first frame is index `0` Returns ------- filename: str The filename of the returned image image: :class:`numpy.ndarray` The image for the given index """ if self.is_video: assert self._reader is not None image = T.cast("npt.NDArray[np.uint8]", self._reader.get(index).to_ndarray(channel_last=True, format="bgr24")) filename = self._dummy_video_frame_name(index) else: file_list = [f for idx, f in enumerate(self._file_list) if idx not in self._skip_list] if self._skip_list else self._file_list filename = file_list[index] image = read_image(filename, raise_error=True) filename = os.path.basename(filename) logger.trace( # type:ignore[attr-defined] "[%s] index: %s, filename: %s image shape: %s", self._name, index, filename, image.shape) return filename, image def close(self) -> None: """Shut down the video reader""" if self._reader is not None: self._reader.close() super().close() class ImagesSaver(ImageIO): """ Perform image saving to a destination folder. Images are saved in a background ThreadPoolExecutor to allow for concurrent saving. See also :class:`ImageIO` for additional attributes. Parameters ---------- path: str The folder to save images to. This must be an existing folder. queue_size: int, optional The amount of images to hold in the internal buffer. Default: 8. as_bytes: bool, optional ``True`` if the image is already encoded to bytes, ``False`` if the image is a :class:`numpy.ndarray`. Default: ``False``. 
Examples -------- >>> saver = ImagesSaver('/path/to/save/folder') >>> for filename, image in : >>> saver.save(filename, image) >>> saver.close() """ def __init__(self, path, queue_size=8, as_bytes=False): logger.debug(parse_class_init(locals())) super().__init__(path, queue_size=queue_size) self._as_bytes = as_bytes def _check_location_exists(self): """ Check whether the output location exists and is a folder Raises ------ FaceswapError If the given location does not exist or the location is not a folder """ if not isinstance(self.location, str): raise FaceswapError("The output location must be a string not a " f"{type(self.location)}") super()._check_location_exists() if not os.path.isdir(self.location): raise FaceswapError(f"The output location '{self.location}' is not a folder") def _process(self, queue): """ Saves images from the save queue to the given :attr:`location` inside a thread. Parameters ---------- queue: queue.Queue() The ImageIO Queue """ executor = futures.ThreadPoolExecutor(thread_name_prefix=self.__class__.__name__) assert self._error_state is not None while True: if self._error_state.has_error: logger.debug("[%s] Thread error detected in worker thread", self._name) executor.shutdown(cancel_futures=True) return item = queue.get() if item == "EOF": logger.debug("[%s] EOF received", self._name) break logger.trace("[%s] Submitting: '%s'", self._name, item[0]) # type:ignore[attr-defined] executor.submit(self._save, *item) executor.shutdown() def _save(self, filename: str, image: bytes | np.ndarray, sub_folder: str | None) -> None: """ Save a single image inside a ThreadPoolExecutor Parameters ---------- filename: str The filename of the image to be saved. NB: Any folders passed in with the filename will be stripped and replaced with :attr:`location`. 
image: bytes or :class:`numpy.ndarray` The encoded image or numpy array to be saved subfolder: str or ``None`` If the file should be saved in a subfolder in the output location, the subfolder should be provided here. ``None`` for no subfolder. """ location = os.path.join(self.location, sub_folder) if sub_folder else self._location if sub_folder and not os.path.exists(location): os.makedirs(location) filename = os.path.join(location, os.path.basename(filename)) try: if self._as_bytes: assert isinstance(image, bytes) with open(filename, "wb") as out_file: out_file.write(image) else: assert isinstance(image, np.ndarray) cv2.imwrite(filename, image) logger.trace("[%s] Saved image: '%s'", # type:ignore[attr-defined] self._name, filename) except Exception as err: # pylint:disable=broad-except logger.error("Failed to save image '%s'. Original Error: %s", filename, str(err)) del image del filename def save(self, filename: str, image: bytes | np.ndarray, sub_folder: str | None = None) -> None: """ Save the given image in the background thread Ensure that :func:`close` is called once all save operations are complete. Parameters ---------- filename: str The filename of the image to be saved. NB: Any folders passed in with the filename will be stripped and replaced with :attr:`location`. image: bytes The encoded image to be saved subfolder: str, optional If the file should be saved in a subfolder in the output location, the subfolder should be provided here. ``None`` for no subfolder. Default: ``None`` """ if self._error_state is not None and self._error_state.has_error: logger.debug("[%s.%s] Thread error detected in worker thread. 
Not putting", current_thread().name, self._name) return self._set_thread() logger.trace("[%s] Putting to save queue: '%s'", # type:ignore[attr-defined] self._name, filename) self._queue.put((filename, image, sub_folder)) def close(self): """ Signal to the Save Threads that they should be closed and cleanly shutdown the saver """ logger.debug("[%s] Putting EOF to save queue", self._name) self._queue.put("EOF") super().close() __all__ = get_module_objects(__name__)