123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191 |
- # -*- coding: utf-8 -*-
- from __future__ import unicode_literals, absolute_import
- import logging
- import re
- import six
- from collections import OrderedDict
- from inspect import isclass
- from .errors import RestError
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Tokenizer for mask strings. Each match is one of "{", "}", "," or a run of
# word characters, ":", "-" and "*" (a field name). Because it is used with
# ``findall``, any character that matches none of these alternatives is skipped.
LEXER = re.compile(r"\{|\}|\,|[\w_:\-\*]+")
class MaskError(RestError):
    """Error raised when something goes wrong while handling a fields mask."""
class ParseError(MaskError):
    """Error raised when a mask string cannot be parsed."""
class Mask(OrderedDict):
    """
    Hold a parsed mask.

    A mask is an ordered mapping keyed by field name. When built from a
    string, each plain field maps to ``True`` and each field followed by a
    ``{...}`` group maps to a nested :class:`Mask`.

    :param str|dict|Mask mask: A mask, parsed or not
    :param bool skip: If ``True``, missing fields won't appear in result
    """

    def __init__(self, mask=None, skip=False, **kwargs):
        # Set before parsing: ``parse`` propagates ``self.skip`` to the
        # nested masks it creates.
        self.skip = skip
        if isinstance(mask, six.string_types):
            super(Mask, self).__init__()
            self.parse(mask)
        elif isinstance(mask, dict):
            # Covers dict, OrderedDict and Mask (all are dict subclasses);
            # the entries are copied as they are.
            super(Mask, self).__init__(mask, **kwargs)
        else:
            super(Mask, self).__init__(**kwargs)

    def parse(self, mask):
        """
        Parse a fields mask and store the result in this instance.

        Expect something in the form::

            {field,nested{nested_field,another},last}

        External brackets are optionals so it can also be written::

            field,nested{nested_field,another},last

        All extras characters will be ignored.
        An empty or whitespace-only mask leaves the instance unchanged.

        :param str mask: the mask string to parse
        :raises ParseError: when a mask is unparseable/invalid
        """
        if not mask:
            return

        mask = self.clean(mask)
        fields = self  # mapping currently being filled
        previous = None  # last token seen, used to validate "{" and ","
        stack = []  # enclosing mappings, one entry per open bracket

        for token in LEXER.findall(mask):
            if token == "{":
                # A bracket must directly follow a field of the current level.
                if previous not in fields:
                    raise ParseError("Unexpected opening bracket")
                fields[previous] = Mask(skip=self.skip)
                stack.append(fields)
                fields = fields[previous]
            elif token == "}":
                if not stack:
                    raise ParseError("Unexpected closing bracket")
                fields = stack.pop()
            elif token == ",":
                if previous in (",", "{", None):
                    raise ParseError("Unexpected comma")
            else:
                fields[token] = True

            previous = token

        if stack:
            raise ParseError("Missing closing bracket")

    def clean(self, mask):
        """
        Remove unnecessary characters.

        Newlines and surrounding whitespace are dropped, then the optional
        external brackets are stripped. A mask that is empty once stripped
        is returned as an empty string.

        :param str mask: the raw mask string
        :returns: the cleaned mask string
        :raises ParseError: when the mask opens with a bracket but does not
            end with the matching closing one
        """
        mask = mask.replace("\n", "").strip()
        # External brackets are optional.
        # ``startswith``/``endswith`` are safe on an empty string, whereas
        # indexing ``mask[0]`` raised IndexError for whitespace-only masks.
        if mask.startswith("{"):
            if not mask.endswith("}"):
                raise ParseError("Missing closing bracket")
            mask = mask[1:-1]
        return mask

    def apply(self, data):
        """
        Apply a fields mask to the data.

        Lists, tuples and sets are processed item by item and returned as a
        list. Field definitions are returned as masked copies. Any other
        object is filtered as a mapping, using its ``__dict__`` when it is
        not a dict itself.

        :param data: The data or model to apply mask on
        :raises MaskError: when unable to apply the mask
        """
        # Imported here rather than at module level
        # (presumably to avoid a circular import -- TODO confirm).
        from . import fields

        # Should handle lists
        if isinstance(data, (list, tuple, set)):
            return [self.apply(d) for d in data]
        elif isinstance(data, (fields.Nested, fields.List, fields.Polymorph)):
            return data.clone(self)
        elif type(data) == fields.Raw:
            # Exact type check on purpose: instances of Raw subclasses fall
            # through to the MaskError branch below.
            return fields.Raw(default=data.default, attribute=data.attribute, mask=self)
        elif data == fields.Raw:
            return fields.Raw(mask=self)
        elif isinstance(data, fields.Raw) or (
            isclass(data) and issubclass(data, fields.Raw)
        ):
            # Not possible to apply a mask on these remaining fields types
            raise MaskError("Mask is inconsistent with model")
        # Should handle objects
        elif not isinstance(data, dict) and hasattr(data, "__dict__"):
            data = data.__dict__

        return self.filter_data(data)

    def filter_data(self, data):
        """
        Handle the data filtering given a parsed mask.

        Only the fields listed in this mask are kept. A ``*`` entry also
        copies every key of ``data`` that is not already in the result.
        Missing fields are set to ``None``, or left out when ``self.skip``
        is true.

        :param dict data: the raw data to filter
        :returns: a new ``dict`` holding the filtered data
        """
        out = {}
        for field, content in six.iteritems(self):
            if field == "*":
                # The wildcard is expanded once the explicit fields are done.
                continue
            elif isinstance(content, Mask):
                nested = data.get(field, None)
                if nested is None:
                    # Nothing to recurse into: emit None unless skipping.
                    if not self.skip:
                        out[field] = None
                else:
                    out[field] = content.apply(nested)
            elif self.skip and field not in data:
                continue
            else:
                out[field] = data.get(field, None)

        if "*" in self:
            for key, value in six.iteritems(data):
                if key not in out:
                    out[key] = value
        return out

    def __str__(self):
        """Render the mask in its bracketed string form, e.g. ``{a,b{c}}``."""
        return "{{{0}}}".format(
            ",".join(
                [
                    "".join((k, str(v))) if isinstance(v, Mask) else k
                    for k, v in six.iteritems(self)
                ]
            )
        )
def apply(data, mask, skip=False):
    """
    Filter ``data`` through a fields mask.

    Builds a :class:`Mask` from ``mask`` and ``skip`` and returns the result
    of its ``apply`` method.

    :param data: The data or model to apply mask on
    :param str|Mask mask: the mask (parsed or not) to apply on data
    :param bool skip: If ``True``, missing fields won't appear in result
    :raises MaskError: when unable to apply the mask
    """
    compiled = Mask(mask, skip)
    return compiled.apply(data)
|