# Source code for dffml.source.idx1

# SPDX-License-Identifier: MIT
# Copyright (c) 2019 Intel Corporation
"""
Loads records from an IDX1 file
"""
import struct

from ..record import Record
from ..base import config, field
from .memory import MemorySource
from .file import BinaryFileSource
from ..util.entrypoint import entrypoint


@config
class IDXSourceConfig:
    """
    Configuration shared by sources that load records from IDX files.
    """

    # Path of the file the source is configured with.
    filename: str
    # Key under which each loaded value is stored in a record's features.
    feature: str = field("Name of the feature the data will be referenced as")
    # Defaults to False. NOTE(review): presumably consumed by the file
    # source base class to permit writing back -- confirm there.
    readwrite: bool = False
    # Defaults to False. NOTE(review): presumably consumed by the file
    # source base class to tolerate a missing/empty file -- confirm there.
    allowempty: bool = False
class IDX1SourceConfig(IDXSourceConfig):
    """
    Configuration for the IDX1 source.

    Declares no fields of its own; everything is inherited from
    ``IDXSourceConfig``.
    """
@entrypoint("idx1")
class IDX1Source(BinaryFileSource, MemorySource):
    """
    Source to read files in IDX1 format (such as MNIST digit label dataset).

    The data is read as two big-endian unsigned 32-bit integers (a magic
    number followed by the number of items) and then one byte per item.
    Each item becomes a record keyed by its zero-based position (as a
    string), with the value stored under the configured feature name.
    """

    CONFIG = IDX1SourceConfig

    async def load_fd(self, xfile):
        """
        Replace ``self.mem`` with one record per item read from ``xfile``.

        ``xfile`` must be a binary file-like object positioned at the start
        of the IDX1 data.

        Raises ``struct.error`` if the header is shorter than 8 bytes or if
        the data ends before the number of items announced by the header
        has been read.
        """
        # Header: magic number (read but not validated) and the item count.
        _magic, size = struct.unpack(">II", xfile.read(8))

        feature = self.config.feature
        # Read the payload in bounded chunks rather than one byte per call,
        # without trusting the header enough to allocate ``size`` bytes in
        # a single read.
        chunk_size = 65536

        self.mem = {}
        index = 0
        while index < size:
            chunk = xfile.read(min(size - index, chunk_size))
            if not chunk:
                # Same exception type a short single-byte read would have
                # produced from struct.unpack.
                raise struct.error(
                    "IDX1 data ended after %d of %d items" % (index, size)
                )
            # ">b" decodes every item as a signed byte.
            # NOTE(review): IDX label data is commonly unsigned; confirm
            # whether values above 127 can occur before changing this.
            for value in struct.unpack(">%db" % len(chunk), chunk):
                key = str(index)
                self.mem[key] = Record(
                    key, data={"features": {feature: value}}
                )
                index += 1

        self.logger.debug("%r loaded %d records", self, len(self.mem))

    async def dump_fd(self, fd):
        """
        Not implemented: always raises ``NotImplementedError``.
        """
        raise NotImplementedError