# Plugins

!!! warning
    Plugins can execute arbitrary code. Be careful when using plugins from a third party.
`safestructures` utilizes a plugin-based architecture to process different data types.
Plugins are based on two main classes: `safestructures.DataProcessor` and `safestructures.TensorProcessor`.
`DataProcessor` serves as an abstract base class that you can extend to handle serialization and deserialization of specific data types. By subclassing `DataProcessor`, you can define how your custom data type is converted into a format that `safetensors` can store, and how to load it back from `safetensors` metadata. This class is subclassed directly for basic data types and data containers.
`TensorProcessor` is a special subclass of `DataProcessor` that uses the capabilities of `safetensors` to serialize and deserialize tensors.
## Basic types
Basic types, such as `str`, `int`, and `float`, use subclasses of `DataProcessor`.
These are considered "atomic" data types: the base case where no further serialization is needed.

If you have a custom atomic data type, especially one that is not covered by core `safestructures` capabilities, subclass `DataProcessor` (called `MyTypeProcessor` in the example below) and follow three main steps:
- Define `MyTypeProcessor.data_type`, a class attribute.
- Implement `MyTypeProcessor.serialize`, the serialization method.
    - Input: a value of your custom type.
    - Returns: a string. Strings are required to be compatible with `safetensors` metadata.
- Implement `MyTypeProcessor.deserialize`, the deserialization method.
    - Input: the string representation of the value.
    - Returns: the original value as your custom type.
For example, take a custom class like the following (a minimal sketch, assuming the class wraps a single integer `value` as used by the processor below):
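```python
class MyCustomType:
    """Minimal example class (assumed): wraps a single integer value."""

    def __init__(self, value: int):
        self.value = value
```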
The data processor would be:
```python
from safestructures import DataProcessor


class MyTypeProcessor(DataProcessor):
    data_type = MyCustomType  # Set this to your custom data type

    def serialize(self, data: MyCustomType) -> str:
        # Convert MyCustomType to a format safetensors can store
        return str(data.value)

    def deserialize(self, serialized: str) -> MyCustomType:
        # Reconstruct MyCustomType from the serialized form
        return MyCustomType(int(serialized))
```
You can then serialize your object, or a data container holding it, by passing the `plugins` keyword argument to `save_file` and `load_file`:
```python
from safestructures import save_file, load_file

list_obj = [MyCustomType(42), MyCustomType(88)]
file_path = "my_custom_objs.safestructures"

save_file(list_obj, file_path, plugins=[MyTypeProcessor])
loaded_obj = load_file(file_path, plugins=[MyTypeProcessor])
```
However, if the custom object you want to serialize is a container (possibly nested) that houses atomic data types, see the Containers section.
### Optional: storing other metadata to aid in deserialization
It may be helpful to store data that cannot be captured in the single string representation returned by `DataProcessor.serialize`, but that is needed to properly deserialize your custom data type.
`DataProcessor.serialize_extra` helps with this by giving you the option to provide extra metadata.
It accepts the value you want to serialize, and your implementation must return a dictionary of only string values, which is passed as keyword arguments to `MyTypeProcessor.deserialize`.
Note that `MyTypeProcessor.deserialize` must accept these keyword arguments.
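For instance, here is a minimal sketch (the `ComplexProcessor` below is hypothetical; it assumes `serialize_extra` receives the same value passed to `serialize`, as described above):

```python
from safestructures import DataProcessor


class ComplexProcessor(DataProcessor):
    """Hypothetical processor storing a complex number as two strings."""

    data_type = complex

    def serialize(self, data: complex) -> str:
        # The main string representation holds the real part.
        return str(data.real)

    def serialize_extra(self, data: complex) -> dict:
        # Extra metadata; all values must be strings.
        return {"imag": str(data.imag)}

    def deserialize(self, serialized: str, imag: str) -> complex:
        # The extra metadata comes back as keyword arguments.
        return complex(float(serialized), float(imag))
```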
A core example is `DictProcessor` (source code in `src/safestructures/processors/iterable.py`).
!!! note
    `DictProcessor` is a container plugin. See the Containers section for details on container plugins.
## Tensors
`TensorProcessor` is a special `DataProcessor`.
The `serialize` and `deserialize` methods do not need to be overridden in a subclass/plugin, but there are still two main steps:
- Define `MyTensorProcessor.data_type`, a class attribute.
    - This is the tensor class of the ML framework.
- Implement `MyTensorProcessor.to_numpy`, a processing method to convert to NumPy.
    - Input: the ML framework tensor.
    - Returns: the tensor as a `numpy.ndarray`. The implementation should:
        - provide a contiguous array;
        - cast float tensors to FP32 for maximum compatibility.
An example is the core `TorchProcessor` (source code in `src/safestructures/processors/tensor.py`).
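As a sketch of what such a plugin might look like for PyTorch (this follows the steps above but is not the actual `TorchProcessor` source):

```python
import numpy as np
import torch

from safestructures import TensorProcessor


class MyTorchProcessor(TensorProcessor):
    """Hypothetical TensorProcessor plugin for PyTorch tensors."""

    data_type = torch.Tensor

    def to_numpy(self, tensor: torch.Tensor) -> np.ndarray:
        # Detach from the autograd graph and move to CPU.
        tensor = tensor.detach().cpu()
        # Cast float tensors to FP32 for maximum compatibility.
        if tensor.is_floating_point():
            tensor = tensor.float()
        # Return a contiguous NumPy array.
        return tensor.contiguous().numpy()
```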
## Containers
Processors for containers such as lists and dictionaries are still `DataProcessor` subclasses, but they are recursive in nature: `safestructures` uses recursion to traverse a data structure.
Unless some values need to be serialized at the container level, such as dictionary keys, no actual values are serialized or deserialized once the container object is reached.
A `DataProcessor` for a container merely iterates through the object and uses `self.serializer.serialize` or `self.serializer.deserialize` to further serialize or deserialize child values, respectively.
Implementing a `DataProcessor` subclass to handle your custom container class follows the same steps as for basic types, but with the following extra considerations:

- The `DataProcessor.serialize` method must use `self.serializer.serialize` to further serialize the values as you iterate through the container.
    - Input: the data container to serialize.
    - Output: must be a `builtin` container. Consider which `builtin` container (such as `dict` or `list`) best fits your custom container class. For example, `safestructures` uses `dict` to help serialize `dataclasses.dataclass` objects.
- The `DataProcessor.deserialize` method must use `self.serializer.deserialize` to further deserialize the values as you iterate through the `builtin` serialized container.
    - Input: a `builtin` container with serialized values. For example, an object serialized via a dictionary would have a `dict` provided to `DataProcessor.deserialize`, and a list/tuple/set-like object would have a `list` provided to `DataProcessor.deserialize`.
    - Output: the deserialized custom object.
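As a minimal sketch of these considerations (the `Pair` class and its processor below are hypothetical):

```python
from safestructures import DataProcessor


class Pair:
    """Hypothetical two-slot container used for illustration."""

    def __init__(self, first, second):
        self.first = first
        self.second = second


class PairProcessor(DataProcessor):
    data_type = Pair

    def serialize(self, data: Pair) -> list:
        # A builtin `list` best fits this container; child values are
        # handed back to the serializer for further processing.
        return [
            self.serializer.serialize(data.first),
            self.serializer.serialize(data.second),
        ]

    def deserialize(self, serialized: list) -> Pair:
        # `serialized` arrives as a `list` of serialized child values.
        first, second = (self.serializer.deserialize(v) for v in serialized)
        return Pair(first, second)
```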
The `DictProcessor` implementation referenced above is a good example of these concepts, and it also exercises `DataProcessor.serialize_extra` to handle values that require serialization at the container level.
In `safestructures`, the core container processors are iterable in nature, so all of them live in the `safestructures.processors.iterable` submodule:

- `ListProcessor`
- `SetProcessor`
- `TupleProcessor`
- `DictProcessor`
- `DataclassProcessor`
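Because these ship with `safestructures`, builtin containers should round-trip without any `plugins` argument; a minimal sketch:

```python
from safestructures import save_file, load_file

# Nested builtin containers are handled by the core processors,
# so no plugins are needed here.
data = {"name": "demo", "values": [1, 2.5, (3, 4)]}
save_file(data, "demo.safestructures")
restored = load_file("demo.safestructures")
```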
For example, let's create a plugin for `transformers.modeling_outputs.ModelOutput` objects.
Since `ModelOutput` objects are similar to `dataclasses.dataclass` objects, we'll subclass `safestructures.processors.iterable.DataclassProcessor`:
```python
from safestructures.processors.iterable import DataclassProcessor
from transformers.modeling_outputs import ModelOutput


class ModelOutputProcessor(DataclassProcessor):
    """Processor for `transformers`'s ModelOutput."""

    data_type = ModelOutput

    def deserialize(self, serialized: dict, **kwargs) -> ModelOutput:
        """Overload DataclassProcessor.deserialize.

        This is so the proper ModelOutput is provided.
        """
        mo_kwargs = {}
        for k, v in serialized.items():
            mo_kwargs[k] = self.serializer.deserialize(v)
        model_output_instance = self.data_type(**mo_kwargs)
        return model_output_instance
```
Since most `ModelOutput` objects from a `transformers` model are subclasses of `ModelOutput`, we can simply subclass `ModelOutputProcessor` as below:
```python
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
)


class BertOutputProcessor(ModelOutputProcessor):
    """Processor for BERT model outputs."""

    data_type = BaseModelOutputWithPoolingAndCrossAttentions


class BertEncoderOutputProcessor(ModelOutputProcessor):
    """Processor for BERT encoder outputs."""

    data_type = BaseModelOutputWithPastAndCrossAttentions
```
We can then serialize outputs of the model. Below is a PyTorch-based example:
```python
import torch
from transformers import BertConfig, BertModel

from safestructures import save_file, load_file

config = BertConfig()
model = BertModel(config)

test_plugins = [BertOutputProcessor, BertEncoderOutputProcessor]
results = {}


def _store_encoder_output(module, args, kwargs, output):
    # Forward hook to capture the encoder's output alongside the model's.
    results["encoder_output"] = output
    return output


model.encoder.register_forward_hook(_store_encoder_output, with_kwargs=True)

test_input_ids = torch.tensor([[0] * 128])
test_output = model(test_input_ids)
results["model_output"] = test_output

test_filepath = "test.safestructures"
save_file(results, test_filepath, plugins=test_plugins)
deserialized = load_file(test_filepath, plugins=test_plugins)
```
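After loading, `deserialized["model_output"]` should come back as a `BaseModelOutputWithPoolingAndCrossAttentions` instance and `deserialized["encoder_output"]` as a `BaseModelOutputWithPastAndCrossAttentions` instance, with their tensor fields restored.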