mask.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. # -*- coding: utf-8 -*-
  2. from __future__ import unicode_literals, absolute_import
  3. import logging
  4. import re
  5. import six
  6. from collections import OrderedDict
  7. from inspect import isclass
  8. from .errors import RestError
# Module-level logger for this file.
log = logging.getLogger(__name__)
# Tokenizer for mask strings: yields "{", "}", "," or a field-name run made of
# word characters, ":", "-" and "*" ("*" is the wildcard key used by Mask).
# Any other character is simply not matched, i.e. ignored by Mask.parse.
LEXER = re.compile(r"\{|\}|\,|[\w_:\-\*]+")
  11. class MaskError(RestError):
  12. """Raised when an error occurs on mask"""
  13. pass
  14. class ParseError(MaskError):
  15. """Raised when the mask parsing failed"""
  16. pass
  17. class Mask(OrderedDict):
  18. """
  19. Hold a parsed mask.
  20. :param str|dict|Mask mask: A mask, parsed or not
  21. :param bool skip: If ``True``, missing fields won't appear in result
  22. """
  23. def __init__(self, mask=None, skip=False, **kwargs):
  24. self.skip = skip
  25. if isinstance(mask, six.string_types):
  26. super(Mask, self).__init__()
  27. self.parse(mask)
  28. elif isinstance(mask, (dict, OrderedDict)):
  29. super(Mask, self).__init__(mask, **kwargs)
  30. else:
  31. self.skip = skip
  32. super(Mask, self).__init__(**kwargs)
  33. def parse(self, mask):
  34. """
  35. Parse a fields mask.
  36. Expect something in the form::
  37. {field,nested{nested_field,another},last}
  38. External brackets are optionals so it can also be written::
  39. field,nested{nested_field,another},last
  40. All extras characters will be ignored.
  41. :param str mask: the mask string to parse
  42. :raises ParseError: when a mask is unparseable/invalid
  43. """
  44. if not mask:
  45. return
  46. mask = self.clean(mask)
  47. fields = self
  48. previous = None
  49. stack = []
  50. for token in LEXER.findall(mask):
  51. if token == "{":
  52. if previous not in fields:
  53. raise ParseError("Unexpected opening bracket")
  54. fields[previous] = Mask(skip=self.skip)
  55. stack.append(fields)
  56. fields = fields[previous]
  57. elif token == "}":
  58. if not stack:
  59. raise ParseError("Unexpected closing bracket")
  60. fields = stack.pop()
  61. elif token == ",":
  62. if previous in (",", "{", None):
  63. raise ParseError("Unexpected comma")
  64. else:
  65. fields[token] = True
  66. previous = token
  67. if stack:
  68. raise ParseError("Missing closing bracket")
  69. def clean(self, mask):
  70. """Remove unnecessary characters"""
  71. mask = mask.replace("\n", "").strip()
  72. # External brackets are optional
  73. if mask[0] == "{":
  74. if mask[-1] != "}":
  75. raise ParseError("Missing closing bracket")
  76. mask = mask[1:-1]
  77. return mask
  78. def apply(self, data):
  79. """
  80. Apply a fields mask to the data.
  81. :param data: The data or model to apply mask on
  82. :raises MaskError: when unable to apply the mask
  83. """
  84. from . import fields
  85. # Should handle lists
  86. if isinstance(data, (list, tuple, set)):
  87. return [self.apply(d) for d in data]
  88. elif isinstance(data, (fields.Nested, fields.List, fields.Polymorph)):
  89. return data.clone(self)
  90. elif type(data) == fields.Raw:
  91. return fields.Raw(default=data.default, attribute=data.attribute, mask=self)
  92. elif data == fields.Raw:
  93. return fields.Raw(mask=self)
  94. elif (
  95. isinstance(data, fields.Raw)
  96. or isclass(data)
  97. and issubclass(data, fields.Raw)
  98. ):
  99. # Not possible to apply a mask on these remaining fields types
  100. raise MaskError("Mask is inconsistent with model")
  101. # Should handle objects
  102. elif not isinstance(data, (dict, OrderedDict)) and hasattr(data, "__dict__"):
  103. data = data.__dict__
  104. return self.filter_data(data)
  105. def filter_data(self, data):
  106. """
  107. Handle the data filtering given a parsed mask
  108. :param dict data: the raw data to filter
  109. :param list mask: a parsed mask to filter against
  110. :param bool skip: whether or not to skip missing fields
  111. """
  112. out = {}
  113. for field, content in six.iteritems(self):
  114. if field == "*":
  115. continue
  116. elif isinstance(content, Mask):
  117. nested = data.get(field, None)
  118. if self.skip and nested is None:
  119. continue
  120. elif nested is None:
  121. out[field] = None
  122. else:
  123. out[field] = content.apply(nested)
  124. elif self.skip and field not in data:
  125. continue
  126. else:
  127. out[field] = data.get(field, None)
  128. if "*" in self.keys():
  129. for key, value in six.iteritems(data):
  130. if key not in out:
  131. out[key] = value
  132. return out
  133. def __str__(self):
  134. return "{{{0}}}".format(
  135. ",".join(
  136. [
  137. "".join((k, str(v))) if isinstance(v, Mask) else k
  138. for k, v in six.iteritems(self)
  139. ]
  140. )
  141. )
  142. def apply(data, mask, skip=False):
  143. """
  144. Apply a fields mask to the data.
  145. :param data: The data or model to apply mask on
  146. :param str|Mask mask: the mask (parsed or not) to apply on data
  147. :param bool skip: If rue, missing field won't appear in result
  148. :raises MaskError: when unable to apply the mask
  149. """
  150. return Mask(mask, skip).apply(data)