Source code for numcodecs.categorize

from .abc import Codec
from .compat import ensure_ndarray, ndarray_copy, ensure_text


import numpy as np


class Categorize(Codec):
    """Codec that represents categorical string data as integer codes.

    Every label is given the code ``position + 1`` according to its position
    in `labels`. The code 0 is reserved for values that match none of the
    labels; on decoding, any code that corresponds to no label is returned as
    the empty string.

    Parameters
    ----------
    labels : sequence of strings
        Category labels.
    dtype : dtype
        Data type to use for decoded data.
    astype : dtype, optional
        Data type to use for encoded data.

    Raises
    ------
    TypeError
        If `dtype` is neither a unicode ('U') nor an object ('O') dtype, or
        if `astype` is the object dtype.

    Examples
    --------
    >>> import numcodecs
    >>> import numpy as np
    >>> x = np.array(['male', 'female', 'female', 'male', 'unexpected'], dtype=object)
    >>> x
    array(['male', 'female', 'female', 'male', 'unexpected'],
          dtype=object)
    >>> codec = numcodecs.Categorize(labels=['female', 'male'], dtype=object)
    >>> y = codec.encode(x)
    >>> y
    array([2, 1, 1, 2, 0], dtype=uint8)
    >>> z = codec.decode(y)
    >>> z
    array(['male', 'female', 'female', 'male', ''], dtype=object)

    """

    codec_id = 'categorize'

    def __init__(self, labels, dtype, astype='u1'):
        # Decoded side: only text-like dtypes make sense for category labels.
        self.dtype = np.dtype(dtype)
        if self.dtype.kind not in 'UO':
            raise TypeError(
                "only unicode ('U') and object ('O') dtypes are "
                "supported"
            )

        # Store the labels as text so they compare equal to decoded values.
        self.labels = list(map(ensure_text, labels))

        # Encoded side: must be a concrete (non-object) dtype.
        self.astype = np.dtype(astype)
        if self.astype == object:
            raise TypeError('encoding as object array not supported')

    def encode(self, buf):
        """Return a flat array of integer codes for the values in `buf`.

        Values equal to ``labels[k]`` are encoded as ``k + 1``; everything
        else is left at 0. The result is one-dimensional and has dtype
        `astype`.
        """
        # Object data is converted element-wise; otherwise the buffer is
        # reinterpreted as the configured decoded dtype.
        if self.dtype == object:
            values = np.asarray(buf, dtype=object)
        else:
            values = ensure_ndarray(buf).view(self.dtype)

        # Work on a 1-D array to keep the implementation simple.
        values = values.reshape(-1, order='A')

        # Start from all zeros: 0 is the code for "not one of the labels".
        codes = np.zeros_like(values, dtype=self.astype)
        for code, label in enumerate(self.labels, start=1):
            codes[values == label] = code

        return codes

    def decode(self, buf, out=None):
        """Return a flat array of labels for the integer codes in `buf`.

        Code ``k + 1`` is decoded to ``labels[k]``; any other code is decoded
        to the empty string. `out` is forwarded to ``ndarray_copy`` together
        with the decoded array, and that call's result is returned.
        """
        # Reinterpret the buffer as the encoded dtype and flatten it.
        codes = ensure_ndarray(buf).view(self.astype).reshape(-1, order='A')

        # Pre-fill with '' so codes without a label decode to empty strings.
        decoded = np.full_like(codes, fill_value='', dtype=self.dtype)
        for code, label in enumerate(self.labels, start=1):
            decoded[codes == code] = label

        return ndarray_copy(decoded, out)

    def get_config(self):
        """Return a dict holding the codec id, labels and both dtype strings."""
        return {
            'id': self.codec_id,
            'labels': self.labels,
            'dtype': self.dtype.str,
            'astype': self.astype.str,
        }

    def __repr__(self):
        # Show at most the first three labels so the repr stays short; an
        # ellipsis is spliced in before the closing bracket when truncated.
        preview = repr(self.labels[:3])
        if len(self.labels) > 3:
            preview = preview[:-1] + ', ...]'
        return '{}(dtype={!r}, astype={!r}, labels={})'.format(
            type(self).__name__, self.dtype.str, self.astype.str, preview
        )