marshalling.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. # -*- coding: utf-8 -*-
  2. from __future__ import unicode_literals
  3. from collections import OrderedDict
  4. from functools import wraps
  5. from six import iteritems
  6. from flask import request, current_app, has_app_context
  7. from .mask import Mask, apply as apply_mask
  8. from .utils import unpack
  9. def make(cls):
  10. if isinstance(cls, type):
  11. return cls()
  12. return cls
  13. def marshal(data, fields, envelope=None, skip_none=False, mask=None, ordered=False):
  14. """Takes raw data (in the form of a dict, list, object) and a dict of
  15. fields to output and filters the data based on those fields.
  16. :param data: the actual object(s) from which the fields are taken from
  17. :param fields: a dict of whose keys will make up the final serialized
  18. response output
  19. :param envelope: optional key that will be used to envelop the serialized
  20. response
  21. :param bool skip_none: optional key will be used to eliminate fields
  22. which value is None or the field's key not
  23. exist in data
  24. :param bool ordered: Wether or not to preserve order
  25. >>> from flask_restx import fields, marshal
  26. >>> data = { 'a': 100, 'b': 'foo', 'c': None }
  27. >>> mfields = { 'a': fields.Raw, 'c': fields.Raw, 'd': fields.Raw }
  28. >>> marshal(data, mfields)
  29. {'a': 100, 'c': None, 'd': None}
  30. >>> marshal(data, mfields, envelope='data')
  31. {'data': {'a': 100, 'c': None, 'd': None}}
  32. >>> marshal(data, mfields, skip_none=True)
  33. {'a': 100}
  34. >>> marshal(data, mfields, ordered=True)
  35. OrderedDict([('a', 100), ('c', None), ('d', None)])
  36. >>> marshal(data, mfields, envelope='data', ordered=True)
  37. OrderedDict([('data', OrderedDict([('a', 100), ('c', None), ('d', None)]))])
  38. >>> marshal(data, mfields, skip_none=True, ordered=True)
  39. OrderedDict([('a', 100)])
  40. """
  41. out, has_wildcards = _marshal(data, fields, envelope, skip_none, mask, ordered)
  42. if has_wildcards:
  43. # ugly local import to avoid dependency loop
  44. from .fields import Wildcard
  45. items = []
  46. keys = []
  47. for dkey, val in fields.items():
  48. key = dkey
  49. if isinstance(val, dict):
  50. value = marshal(data, val, skip_none=skip_none, ordered=ordered)
  51. else:
  52. field = make(val)
  53. is_wildcard = isinstance(field, Wildcard)
  54. # exclude already parsed keys from the wildcard
  55. if is_wildcard:
  56. field.reset()
  57. if keys:
  58. field.exclude |= set(keys)
  59. keys = []
  60. value = field.output(dkey, data, ordered=ordered)
  61. if is_wildcard:
  62. def _append(k, v):
  63. if skip_none and (v is None or v == OrderedDict() or v == {}):
  64. return
  65. items.append((k, v))
  66. key = field.key or dkey
  67. _append(key, value)
  68. while True:
  69. value = field.output(dkey, data, ordered=ordered)
  70. if value is None or value == field.container.format(
  71. field.default
  72. ):
  73. break
  74. key = field.key
  75. _append(key, value)
  76. continue
  77. keys.append(key)
  78. if skip_none and (value is None or value == OrderedDict() or value == {}):
  79. continue
  80. items.append((key, value))
  81. items = tuple(items)
  82. out = OrderedDict(items) if ordered else dict(items)
  83. if envelope:
  84. out = OrderedDict([(envelope, out)]) if ordered else {envelope: out}
  85. return out
  86. return out
  87. def _marshal(data, fields, envelope=None, skip_none=False, mask=None, ordered=False):
  88. """Takes raw data (in the form of a dict, list, object) and a dict of
  89. fields to output and filters the data based on those fields.
  90. :param data: the actual object(s) from which the fields are taken from
  91. :param fields: a dict of whose keys will make up the final serialized
  92. response output
  93. :param envelope: optional key that will be used to envelop the serialized
  94. response
  95. :param bool skip_none: optional key will be used to eliminate fields
  96. which value is None or the field's key not
  97. exist in data
  98. :param bool ordered: Wether or not to preserve order
  99. >>> from flask_restx import fields, marshal
  100. >>> data = { 'a': 100, 'b': 'foo', 'c': None }
  101. >>> mfields = { 'a': fields.Raw, 'c': fields.Raw, 'd': fields.Raw }
  102. >>> marshal(data, mfields)
  103. {'a': 100, 'c': None, 'd': None}
  104. >>> marshal(data, mfields, envelope='data')
  105. {'data': {'a': 100, 'c': None, 'd': None}}
  106. >>> marshal(data, mfields, skip_none=True)
  107. {'a': 100}
  108. >>> marshal(data, mfields, ordered=True)
  109. OrderedDict([('a', 100), ('c', None), ('d', None)])
  110. >>> marshal(data, mfields, envelope='data', ordered=True)
  111. OrderedDict([('data', OrderedDict([('a', 100), ('c', None), ('d', None)]))])
  112. >>> marshal(data, mfields, skip_none=True, ordered=True)
  113. OrderedDict([('a', 100)])
  114. """
  115. # ugly local import to avoid dependency loop
  116. from .fields import Wildcard
  117. mask = mask or getattr(fields, "__mask__", None)
  118. fields = getattr(fields, "resolved", fields)
  119. if mask:
  120. fields = apply_mask(fields, mask, skip=True)
  121. if isinstance(data, (list, tuple)):
  122. out = [marshal(d, fields, skip_none=skip_none, ordered=ordered) for d in data]
  123. if envelope:
  124. out = OrderedDict([(envelope, out)]) if ordered else {envelope: out}
  125. return out, False
  126. has_wildcards = {"present": False}
  127. def __format_field(key, val):
  128. field = make(val)
  129. if isinstance(field, Wildcard):
  130. has_wildcards["present"] = True
  131. value = field.output(key, data, ordered=ordered)
  132. return (key, value)
  133. items = (
  134. (k, marshal(data, v, skip_none=skip_none, ordered=ordered))
  135. if isinstance(v, dict)
  136. else __format_field(k, v)
  137. for k, v in iteritems(fields)
  138. )
  139. if skip_none:
  140. items = (
  141. (k, v) for k, v in items if v is not None and v != OrderedDict() and v != {}
  142. )
  143. out = OrderedDict(items) if ordered else dict(items)
  144. if envelope:
  145. out = OrderedDict([(envelope, out)]) if ordered else {envelope: out}
  146. return out, has_wildcards["present"]
  147. class marshal_with(object):
  148. """A decorator that apply marshalling to the return values of your methods.
  149. >>> from flask_restx import fields, marshal_with
  150. >>> mfields = { 'a': fields.Raw }
  151. >>> @marshal_with(mfields)
  152. ... def get():
  153. ... return { 'a': 100, 'b': 'foo' }
  154. ...
  155. ...
  156. >>> get()
  157. OrderedDict([('a', 100)])
  158. >>> @marshal_with(mfields, envelope='data')
  159. ... def get():
  160. ... return { 'a': 100, 'b': 'foo' }
  161. ...
  162. ...
  163. >>> get()
  164. OrderedDict([('data', OrderedDict([('a', 100)]))])
  165. >>> mfields = { 'a': fields.Raw, 'c': fields.Raw, 'd': fields.Raw }
  166. >>> @marshal_with(mfields, skip_none=True)
  167. ... def get():
  168. ... return { 'a': 100, 'b': 'foo', 'c': None }
  169. ...
  170. ...
  171. >>> get()
  172. OrderedDict([('a', 100)])
  173. see :meth:`flask_restx.marshal`
  174. """
  175. def __init__(
  176. self, fields, envelope=None, skip_none=False, mask=None, ordered=False
  177. ):
  178. """
  179. :param fields: a dict of whose keys will make up the final
  180. serialized response output
  181. :param envelope: optional key that will be used to envelop the serialized
  182. response
  183. """
  184. self.fields = fields
  185. self.envelope = envelope
  186. self.skip_none = skip_none
  187. self.ordered = ordered
  188. self.mask = Mask(mask, skip=True)
  189. def __call__(self, f):
  190. @wraps(f)
  191. def wrapper(*args, **kwargs):
  192. resp = f(*args, **kwargs)
  193. mask = self.mask
  194. if has_app_context():
  195. mask_header = current_app.config["RESTX_MASK_HEADER"]
  196. mask = request.headers.get(mask_header) or mask
  197. if isinstance(resp, tuple):
  198. data, code, headers = unpack(resp)
  199. return (
  200. marshal(
  201. data,
  202. self.fields,
  203. self.envelope,
  204. self.skip_none,
  205. mask,
  206. self.ordered,
  207. ),
  208. code,
  209. headers,
  210. )
  211. else:
  212. return marshal(
  213. resp, self.fields, self.envelope, self.skip_none, mask, self.ordered
  214. )
  215. return wrapper
  216. class marshal_with_field(object):
  217. """
  218. A decorator that formats the return values of your methods with a single field.
  219. >>> from flask_restx import marshal_with_field, fields
  220. >>> @marshal_with_field(fields.List(fields.Integer))
  221. ... def get():
  222. ... return ['1', 2, 3.0]
  223. ...
  224. >>> get()
  225. [1, 2, 3]
  226. see :meth:`flask_restx.marshal_with`
  227. """
  228. def __init__(self, field):
  229. """
  230. :param field: a single field with which to marshal the output.
  231. """
  232. if isinstance(field, type):
  233. self.field = field()
  234. else:
  235. self.field = field
  236. def __call__(self, f):
  237. @wraps(f)
  238. def wrapper(*args, **kwargs):
  239. resp = f(*args, **kwargs)
  240. if isinstance(resp, tuple):
  241. data, code, headers = unpack(resp)
  242. return self.field.format(data), code, headers
  243. return self.field.format(resp)
  244. return wrapper