Source code for dffml.source.idx3

# SPDX-License-Identifier: MIT
# Copyright (c) 2019 Intel Corporation
"""
Loads records from an IDX3 file
"""
import struct

from ..record import Record
from ..util.entrypoint import entrypoint
from .idx1 import IDX1Source, IDXSourceConfig


class IDX3SourceConfig(IDXSourceConfig):
    """
    Configuration for :class:`IDX3Source`.

    Declares no fields of its own; everything comes from
    ``IDXSourceConfig``.
    """
@entrypoint("idx3")
class IDX3Source(IDX1Source):
    """
    Source which reads records out of an IDX3 formatted binary file
    (the MNIST digit image dataset is distributed in this format).
    """

    CONFIG = IDX3SourceConfig

    async def load_fd(self, xfile):
        """
        Populate ``self.mem`` with one record per entry found in ``xfile``.

        ``xfile`` must be a binary file object positioned at the start of
        the IDX3 data. Records are keyed by their zero-based position in
        the file (as a string) and carry a single feature, named by
        ``self.config.feature``, holding the entry's bytes as a tuple of
        integers.
        """
        # The header is two pairs of big-endian unsigned 32-bit integers:
        # (magic, record count) followed by (rows, columns).
        pair = struct.Struct(">II")
        magic, count = pair.unpack(xfile.read(pair.size))
        nrows, ncols = pair.unpack(xfile.read(pair.size))

        self.mem = {}

        # Every record is a flat run of rows * columns unsigned bytes
        values_per_record = nrows * ncols
        record_format = f">{values_per_record}B"

        for index in range(count):
            key = str(index)
            feature_name = self.config.feature
            values = struct.unpack(
                record_format, xfile.read(values_per_record)
            )
            self.mem[key] = Record(
                key, data={"features": {feature_name: values}}
            )

        self.logger.debug("%r loaded %d records", self, len(self.mem))