"""
Loads files from a directory
"""
import os
import glob
import pathlib
from typing import List
from ..record import Record
from ..base import config, field
from .memory import MemorySource
from ..util.entrypoint import entrypoint
from ..source.source import BaseSource
from ..configloader.configloader import ConfigLoaders
from ..high_level.source import save
class FolderNotFoundError(Exception):
    """Raised when the requested folder does not exist."""
@config
class DirectorySourceConfig:
    """
    Configuration for :class:`DirectorySource`.

    Every field carries a ``field(...)`` description, like ``feature`` and
    ``labels`` already did, so each option is documented consistently.
    """

    foldername: str = field("Folder to load files from")
    feature: str = field("Name of the feature the data will be referenced as")
    labels: List[str] = field(
        "Image labels", default_factory=lambda: ["unlabelled"]
    )
    save: BaseSource = field(
        "Source to save the loaded records to when the source is closed",
        default=None,
    )
@entrypoint("dir")
class DirectorySource(MemorySource):
    """
    Source to read files in a folder.

    Every file found is loaded through ``CONFIG_LOADER`` and stored in
    ``self.mem`` as a :class:`Record`. The record key is the file name,
    prefixed with ``"<label>/"`` when labels are in use. The loaded data is
    stored under the feature named by ``config.feature``. When labels are in
    use, the label is also stored under the ``"label"`` feature.

    Labels can be given in three ways:

    - ``["unlabelled"]`` (the default): every file directly inside
      ``foldername`` is loaded and no ``"label"`` feature is set.
    - A list of label names: files are read from ``foldername/<label>/``.
    - A single path to an existing file containing comma separated label
      names.
    """

    CONFIG = DirectorySourceConfig
    CONFIG_LOADER = ConfigLoaders()

    def __init__(self, config):
        super().__init__(config)
        # Accept a plain string for ``foldername`` but work with Path objects
        # internally.
        if isinstance(getattr(self.config, "foldername", None), str):
            with self.config.no_enforce_immutable():
                self.config.foldername = pathlib.Path(self.config.foldername)

    async def __aenter__(self) -> "BaseSourceContext":
        await self._open()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self._close()

    def _is_unlabelled(self) -> bool:
        """Return True when the default ``["unlabelled"]`` labels are in use."""
        return self.config.labels == ["unlabelled"]

    async def _open(self):
        """
        Validate the folder, resolve the labels and load every file.

        Raises:
            FolderNotFoundError: if ``foldername`` is not an existing
                directory.
        """
        # A path that exists but is a regular file is not a usable folder
        # either, so test for a directory directly.
        if not os.path.isdir(self.config.foldername):
            raise FolderNotFoundError(f"Folder path: {self.config.foldername}")
        self._resolve_labels()
        await self.load_fd()

    def _resolve_labels(self):
        """Expand a labels file and warn about label folders not listed."""
        if self._is_unlabelled():
            return
        if len(self.config.labels) == 1 and os.path.isfile(
            self.config.labels[0]
        ):
            # Update labels with the comma separated list read from the file.
            # Strip surrounding whitespace (e.g. a trailing newline) and drop
            # empty entries so they do not become bogus label folder names.
            text = pathlib.Path(self.config.labels[0]).read_text()
            with self.config.no_enforce_immutable():
                self.config.labels = [
                    label.strip() for label in text.split(",") if label.strip()
                ]
            if self._is_unlabelled():
                return
        label_folders = sorted(
            entry
            for entry in os.listdir(self.config.foldername)
            if os.path.isdir(os.path.join(self.config.foldername, entry))
        )
        # Warn when any label folder present on disk is missing from
        # ``labels``: its files will not be loaded.
        if set(label_folders) - set(self.config.labels):
            self.logger.warning(
                "All labels not specified. Folders present: %s \nLabels entered: %s",
                label_folders,
                self.config.labels,
            )

    async def _close(self):
        if self.config.save:
            await save(self.config.save, self.mem)

    @staticmethod
    def _list_files(folder):
        """Return the regular files directly inside ``folder``, sorted."""
        # Escape the folder part so glob metacharacters in the path (``[``,
        # ``*``, ``?``) are matched literally. glob's ``*`` skips dotfiles,
        # as the original pattern did.
        pattern = os.path.join(glob.escape(str(folder)), "*")
        return sorted(
            pathlib.Path(path)
            for path in glob.glob(pattern)
            if os.path.isfile(path)
        )

    async def load_fd(self):
        """Replace ``self.mem`` with one Record per file found on disk."""
        self.mem = {}
        unlabelled = self._is_unlabelled()
        base = pathlib.Path(self.config.foldername)
        # Enter the config loader once for the whole load instead of once per
        # file.
        async with self.CONFIG_LOADER as cfgl:
            for label in self.config.labels:
                folder = base if unlabelled else base.joinpath(label)
                for image_filename in self._list_files(folder):
                    _, feature_data = await cfgl.load_file(image_filename)
                    file_name = image_filename.name
                    features = {self.config.feature: feature_data}
                    if not unlabelled:
                        file_name = label + "/" + file_name
                        # Only labelled data gets a "label" feature.
                        features["label"] = label
                    self.mem[file_name] = Record(
                        file_name, data={"features": features}
                    )
        self.logger.debug("%r loaded %d records", self, len(self.mem))