Preserve custom attributes when pickling subclass of numpy array

np.ndarray uses __reduce__ to pickle itself. We can look at what that method actually returns to get an idea of what’s going on:

>>> obj = RealisticInfoArray([1, 2, 3], info='foo')
>>> obj.__reduce__()
(<built-in function _reconstruct>, (<class 'pick.RealisticInfoArray'>, (0,), 'b'), (1, (3,), dtype('int64'), False, '\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00'))

So, we get a 3-tuple back. The docs for __reduce__ describe what each element does:

When a tuple is returned, it must be between two and five elements
long. Optional elements can either be omitted, or None can be provided
as their value. The contents of this tuple are pickled as normal and
used to reconstruct the object at unpickling time. The semantics of
each element are:

  • A callable object that will be called to create the initial version of
    the object. The next element of the tuple will provide arguments for
    this callable, and later elements provide additional state information
    that will subsequently be used to fully reconstruct the pickled data.

    In the unpickling environment this object must be either a class, a
    callable registered as a “safe constructor” (see below), or it must
    have an attribute __safe_for_unpickling__ with a true value.
    Otherwise, an UnpicklingError will be raised in the unpickling
    environment. Note that as usual, the callable itself is pickled by
    name.

  • A tuple of arguments for the callable object.

  • Optionally, the object’s state, which will be passed to the object’s
    __setstate__() method as described in section Pickling and unpickling
    normal class instances. If the object has no __setstate__() method,
    then, as above, the value must be a dictionary and it will be added to
    the object’s __dict__.
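To make those three elements concrete, the protocol can be driven by hand on a plain ndarray: call the first element with the second as its arguments, then pass the third to __setstate__ on the result. This is a simplified sketch of what the unpickler does (it skips the actual serialization to bytes), assuming Python 3 and a reasonably recent NumPy; the variable names are only for illustration.

```python
import numpy as np

original = np.array([1, 2, 3])

# Unpack the 3-tuple returned by __reduce__.
func, args, state = original.__reduce__()

# Step 1: call the reconstructor with its arguments.
# This builds an empty placeholder array of the right class.
rebuilt = func(*args)

# Step 2: restore shape, dtype and raw data from the saved state.
# For ndarray the state is (version, shape, dtype, is_fortran, raw_bytes).
rebuilt.__setstate__(state)

print(func.__name__)                       # _reconstruct
print(len(state))                          # 5
print(np.array_equal(rebuilt, original))   # True
```

The point is that the state tuple is just data handed to __setstate__, so a subclass is free to extend it as long as it strips its own additions back off before delegating to ndarray.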

So, _reconstruct is the function called to rebuild the object, (<class 'pick.RealisticInfoArray'>, (0,), 'b') are the arguments passed to that function, and (1, (3,), dtype('int64'), False, '\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00') is the state that gets passed to the class’s __setstate__. This gives us an opportunity: we can override __reduce__ to append our own data to that state tuple, and then also override __setstate__ to restore our custom attribute when we unpickle. We just need to make sure we preserve all the data the parent class needs, and call the parent’s __setstate__, too:

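import numpy as np

class RealisticInfoArray(np.ndarray):
    def __new__(cls, input_array, info=None):
        # View the input as our subclass and attach the extra attribute
        obj = np.asarray(input_array).view(cls)
        obj.info = info
        return obj

    def __array_finalize__(self, obj):
        # Called for views, slices, etc.; obj is None for explicit construction
        if obj is None: return
        self.info = getattr(obj, 'info', None)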

    def __reduce__(self):
        # Get the parent's __reduce__ tuple
        pickled_state = super(RealisticInfoArray, self).__reduce__()
        # Create our own tuple to pass to __setstate__
        new_state = pickled_state[2] + (self.info,)
        # Return a tuple that replaces the parent's __setstate__ tuple with our own
        return (pickled_state[0], pickled_state[1], new_state)

    def __setstate__(self, state):
        self.info = state[-1]  # Set the info attribute
        # Call the parent's __setstate__ with the other tuple elements.
        super(RealisticInfoArray, self).__setstate__(state[0:-1])

Usage:

>>> obj = pick.RealisticInfoArray([1, 2, 3], info='foo')
>>> pickle_str = pickle.dumps(obj)
>>> pickle_str
"cnumpy.core.multiarray\n_reconstruct\np0\n(cpick\nRealisticInfoArray\np1\n(I0\ntp2\nS'b'\np3\ntp4\nRp5\n(I1\n(I3\ntp6\ncnumpy\ndtype\np7\n(S'i8'\np8\nI0\nI1\ntp9\nRp10\n(I3\nS'<'\np11\nNNNI-1\nI-1\nI0\ntp12\nbI00\nS'\\x01\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x02\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x03\\x00\\x00\\x00\\x00\\x00\\x00\\x00'\np13\nS'foo'\np14\ntp15\nb."
>>> new_obj = pickle.loads(pickle_str)
>>> new_obj.info
'foo'
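If a subclass carries several custom attributes, the same idea can be extended by appending a copy of the whole instance __dict__ instead of a single value. The following is a minimal sketch assuming Python 3 (so plain super() works); the class name MultiInfoArray and the attributes info and units are hypothetical, chosen only for illustration.

```python
import pickle

import numpy as np


class MultiInfoArray(np.ndarray):
    """Hypothetical ndarray subclass carrying arbitrary extra attributes."""

    def __new__(cls, input_array, **attrs):
        # View the input as our subclass, then attach every keyword attribute.
        obj = np.asarray(input_array).view(cls)
        obj.__dict__.update(attrs)
        return obj

    def __array_finalize__(self, obj):
        if obj is None:
            return
        # Copy custom attributes from the source array (e.g. for views and slices).
        # A plain ndarray has no __dict__, hence the {} fallback.
        self.__dict__.update(getattr(obj, '__dict__', {}))

    def __reduce__(self):
        func, args, state = super().__reduce__()
        # Append a copy of the instance dict to ndarray's own state tuple.
        return func, args, state + (dict(self.__dict__),)

    def __setstate__(self, state):
        # Restore our attributes, then hand the original state to ndarray.
        self.__dict__.update(state[-1])
        super().__setstate__(state[:-1])


arr = MultiInfoArray([1, 2, 3], info='foo', units='m')
restored = pickle.loads(pickle.dumps(arr))
print(type(restored).__name__, restored.info, restored.units)  # MultiInfoArray foo m
```

The trade-off is that everything in __dict__ must itself be picklable; restricting the tuple to a known set of attributes, as RealisticInfoArray does, is more explicit.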
