safestructures.DataProcessor

Bases: ABC

Base class for data processors other than tensors.

Source code in src/safestructures/processors/base.py
class DataProcessor(ABC):
    """Base class for data processors other than tensors."""

    data_type: Type[Any] = Any

    def __init__(self, serializer: Serializer):
        """Initialize the DataProcessor.

        Args:
            serializer (Serializer): The serializer.
                This may be used to recursively process nested data.
        """
        self.serializer = serializer

    @property
    def schema_type(self) -> str:
        """Provide the data type in schema-compatible format.

        Returns:
            str: The import path of the data type.
        """
        return get_import_path(self.data_type)

    @abstractmethod
    def serialize(self, data: Any) -> Union[str, None, bool, list, dict]:
        """Serialize the data.

        Args:
            data (Any): The data to serialize.

        Returns:
            Union[str, None, bool, list, dict]: The serialized value.
                Safetensors metadata only accepts strings; json.dumps is
                applied, so None and booleans are handled. Lists and
                dictionaries of the accepted types, including nested ones,
                are also acceptable.
        """
        pass

    def serialize_extra(self, data: Any) -> dict:
        """Provide extra serialization details to the schema.

        Args:
            data (Any): The data from which to generate the additional schema.

        Returns:
            dict: The additional schema to add onto the data's schema.
                The keys must be strings and must not conflict with
                TYPE_FIELD or VALUE_FIELD.
                The values must also be of a type acceptable to Safetensors
                metadata. See `DataProcessor.serialize`.
        """
        return {}

    @abstractmethod
    def deserialize(
        self, serialized: Union[str, None, bool, list, dict], **kwargs
    ) -> Any:
        """Deserialize the schema into data.

        Args:
            serialized (Union[str, None, bool, list, dict]): The serialized value.

        Any additional schema fields other than TYPE_FIELD and VALUE_FIELD
        are passed as keyword arguments.

        Returns:
            Any: The loaded value.
        """
        pass

    def __call__(self, data_or_schema: Any) -> Any:
        """Process the data or schema.

        Args:
            data_or_schema (Any): The data (Any) or schema (dict).

        Returns:
            The schema if the serializer is in save mode,
            The loaded data if the serializer is in load mode.
        """
        mode = self.serializer.mode
        if mode == Mode.SAVE:
            schema = {TYPE_FIELD: self.schema_type}

            schema[VALUE_FIELD] = self.serialize(data_or_schema)
            extra = self.serialize_extra(data_or_schema)

            if not isinstance(extra, dict):
                raise TypeError(
                    f"{type(self)}.serialize_extra must return a dictionary."
                )
            if TYPE_FIELD in extra:
                raise KeyError(
                    f"{type(self)}.serialize_extra must not have a {TYPE_FIELD} key."
                )
            if VALUE_FIELD in extra:
                raise KeyError(
                    f"{type(self)}.serialize_extra must not have a {VALUE_FIELD} key."
                )

            for k in extra.keys():
                if not isinstance(k, str):
                    raise TypeError(
                        (
                            f"Dictionary returned by {type(self)}.serialize_extra"
                            " must have string keys only."
                        )
                    )
            schema.update(extra)

            return schema

        elif mode == Mode.LOAD:
            kwargs = {}
            for k in data_or_schema:
                if k not in [TYPE_FIELD, VALUE_FIELD]:
                    kwargs[k] = data_or_schema[k]
            return self.deserialize(data_or_schema[VALUE_FIELD], **kwargs)

        else:
            raise ValueError(f"Mode {mode} not recognized.")
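
A concrete processor only needs to set data_type and implement serialize and deserialize. The following is a minimal sketch, not taken from the library: the import path of DataProcessor is assumed, and ComplexProcessor is a hypothetical processor for Python's built-in complex type.

from safestructures.processors.base import DataProcessor  # assumed import path


class ComplexProcessor(DataProcessor):
    """Hypothetical processor for Python's built-in complex type."""

    data_type = complex

    def serialize(self, data: complex) -> list:
        # Lists of accepted types are allowed, so store both parts directly.
        return [data.real, data.imag]

    def deserialize(self, serialized: list, **kwargs) -> complex:
        # Rebuild the value; any extra schema fields are ignored via **kwargs.
        real, imag = serialized
        return complex(real, imag)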

schema_type property

Provide the data type in schema-compatible format.

Returns:

    str: The import path of the data type.

__call__(data_or_schema)

Process the data or schema.

Parameters:

    data_or_schema (Any, required): The data (Any) or schema (dict).

Returns:

    Any: The schema if the serializer is in save mode,
        or the loaded data if the serializer is in load mode.

Source code in src/safestructures/processors/base.py
def __call__(self, data_or_schema: Any) -> Any:
    """Process the data or schema.

    Args:
        data_or_schema (Any): The data (Any) or schema (dict).

    Returns:
        The schema if the serializer is in save mode,
        The loaded data if the serializer is in load mode.
    """
    mode = self.serializer.mode
    if mode == Mode.SAVE:
        schema = {TYPE_FIELD: self.schema_type}

        schema[VALUE_FIELD] = self.serialize(data_or_schema)
        extra = self.serialize_extra(data_or_schema)

        if not isinstance(extra, dict):
            raise TypeError(
                f"{type(self)}.serialize_extra must return a dictionary."
            )
        if TYPE_FIELD in extra:
            raise KeyError(
                f"{type(self)}.serialize_extra must not have a {TYPE_FIELD} key."
            )
        if VALUE_FIELD in extra:
            raise KeyError(
                f"{type(self)}.serialize_extra must not have a {VALUE_FIELD} key."
            )

        for k in extra.keys():
            if not isinstance(k, str):
                raise TypeError(
                    (
                        f"Dictionary returned by {type(self)}.serialize_extra"
                        " must have string keys only."
                    )
                )
        schema.update(extra)

        return schema

    elif mode == Mode.LOAD:
        kwargs = {}
        for k in data_or_schema:
            if k not in [TYPE_FIELD, VALUE_FIELD]:
                kwargs[k] = data_or_schema[k]
        return self.deserialize(data_or_schema[VALUE_FIELD], **kwargs)

    else:
        raise ValueError(f"Mode {mode} not recognized.")
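
As a usage sketch only: __call__ reads nothing from the serializer beyond its mode, so a stand-in object with a mode attribute is enough to exercise both branches with the hypothetical ComplexProcessor sketched earlier. The import location of Mode is an assumption.

from types import SimpleNamespace

from safestructures.constants import Mode  # assumed import path

# Stand-in exposing only the `mode` attribute that __call__ reads;
# a real Serializer would normally be passed in.
stub = SimpleNamespace(mode=Mode.SAVE)
processor = ComplexProcessor(stub)

schema = processor(3 + 4j)  # dict with TYPE_FIELD and VALUE_FIELD keys

stub.mode = Mode.LOAD
value = processor(schema)  # -> (3+4j)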

__init__(serializer)

Initialize the DataProcessor.

Parameters:

    serializer (Serializer, required): The serializer.
        This may be used to recursively process nested data.
Source code in src/safestructures/processors/base.py
def __init__(self, serializer: Serializer):
    """Initialize the DataProcessor.

    Args:
        serializer (Serializer): The serializer.
            This may be used to recursively process nested data.
    """
    self.serializer = serializer

deserialize(serialized, **kwargs) abstractmethod

Deserialize the schema into data.

Parameters:

    serialized (Union[str, None, bool, list, dict], required): The serialized value.

Any additional schema fields other than TYPE_FIELD and VALUE_FIELD are passed as keyword arguments.

Returns:

    Any: The loaded value.

Source code in src/safestructures/processors/base.py
@abstractmethod
def deserialize(
    self, serialized: Union[str, None, bool, list, dict], **kwargs
) -> Any:
    """Deserialize the schema into data.

    Args:
        serialized (Union[str, None, bool, list, dict]): The serialized value.

    Any additional schema fields other than TYPE_FIELD and VALUE_FIELD
    are passed as keyword arguments.

    Returns:
        Any: The loaded value.
    """
    pass
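
A hypothetical override, shown as a bare method for brevity: it assumes a matching serialize that stored a set as a list (see the sketch under serialize below), and extra schema fields it does not need are absorbed by **kwargs.

def deserialize(self, serialized: list, **kwargs) -> set:
    # Inverse of a serialize that stored the set as a list; unneeded
    # extra schema fields are swallowed by **kwargs.
    return set(serialized)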

serialize(data) abstractmethod

Serialize the data.

Parameters:

    data (Any, required): The data to serialize.

Returns:

    Union[str, None, bool, list, dict]: The serialized value.
        Safetensors metadata only accepts strings; json.dumps is applied,
        so None and booleans are handled. Lists and dictionaries of the
        accepted types, including nested ones, are also acceptable.

Source code in src/safestructures/processors/base.py
@abstractmethod
def serialize(self, data: Any) -> Union[str, None, bool, list, dict]:
    """Serialize the data.

    Args:
        data (Any): The data to serialize.

    Returns:
        Union[str, None, bool, list, dict]: The serialized value.
            Safetensors metadata only accepts strings; json.dumps is
            applied, so None and booleans are handled. Lists and
            dictionaries of the accepted types, including nested ones,
            are also acceptable.
    """
    pass
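
A hypothetical counterpart to the deserialize sketch above, again shown as a bare method: lists of accepted types are valid return values, so a set of JSON-safe scalars needs no manual string conversion.

def serialize(self, data: set) -> list:
    # Sorted by repr for a deterministic serialized form; the list itself
    # survives the serializer's json.dumps step.
    return sorted(data, key=repr)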

serialize_extra(data)

Provide extra serialization details to the schema.

Parameters:

    data (Any, required): The data from which to generate the additional schema.

Returns:

    dict: The additional schema to add onto the data's schema.
        The keys must be strings and must not conflict with TYPE_FIELD or
        VALUE_FIELD. The values must also be of a type acceptable to
        Safetensors metadata; see `DataProcessor.serialize`.

Source code in src/safestructures/processors/base.py
def serialize_extra(self, data: Any) -> dict:
    """Provide extra serialization details to the schema.

    Args:
        data (Any): The data from which to generate the additional schema.

    Returns:
        dict: The additional schema to add onto the data's schema.
            The keys must be strings and must not conflict with
            TYPE_FIELD or VALUE_FIELD.
            The values must also be of a type acceptable to Safetensors
            metadata. See `DataProcessor.serialize`.
    """
    return {}
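
To see serialize_extra end to end, here is a hedged sketch of a hypothetical bytes processor (not taken from the library; the "encoding" field name and the import path are assumptions): serialize_extra adds a string-keyed field to the schema, and deserialize receives it back as a keyword argument.

import base64

from safestructures.processors.base import DataProcessor  # assumed import path


class BytesProcessor(DataProcessor):
    """Hypothetical processor for raw bytes."""

    data_type = bytes

    def serialize(self, data: bytes) -> str:
        # Encode to a plain string, which Safetensors metadata can always hold.
        return base64.b64encode(data).decode("ascii")

    def serialize_extra(self, data: bytes) -> dict:
        # Extra string-keyed fields are merged into the schema next to
        # TYPE_FIELD and VALUE_FIELD, then handed back to deserialize.
        return {"encoding": "base64"}

    def deserialize(self, serialized: str, encoding: str = "base64", **kwargs) -> bytes:
        if encoding != "base64":
            raise ValueError(f"Unsupported encoding: {encoding}")
        return base64.b64decode(serialized)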