model.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. # -*- coding: utf-8 -*-
  2. from __future__ import unicode_literals
  3. import copy
  4. import re
  5. import warnings
  6. from collections import OrderedDict
  7. try:
  8. from collections.abc import MutableMapping
  9. except ImportError:
  10. # TODO Remove this to drop Python2 support
  11. from collections import MutableMapping
  12. from six import iteritems, itervalues
  13. from werkzeug.utils import cached_property
  14. from .mask import Mask
  15. from .errors import abort
  16. from jsonschema import Draft4Validator
  17. from jsonschema.exceptions import ValidationError
  18. from .utils import not_none
  19. from ._http import HTTPStatus
  20. RE_REQUIRED = re.compile(r"u?\'(?P<name>.*)\' is a required property", re.I | re.U)
  21. def instance(cls):
  22. if isinstance(cls, type):
  23. return cls()
  24. return cls
  25. class ModelBase(object):
  26. """
  27. Handles validation and swagger style inheritance for both subclasses.
  28. Subclass must define `schema` attribute.
  29. :param str name: The model public name
  30. """
  31. def __init__(self, name, *args, **kwargs):
  32. super(ModelBase, self).__init__(*args, **kwargs)
  33. self.__apidoc__ = {"name": name}
  34. self.name = name
  35. self.__parents__ = []
  36. def instance_inherit(name, *parents):
  37. return self.__class__.inherit(name, self, *parents)
  38. self.inherit = instance_inherit
  39. @property
  40. def ancestors(self):
  41. """
  42. Return the ancestors tree
  43. """
  44. ancestors = [p.ancestors for p in self.__parents__]
  45. return set.union(set([self.name]), *ancestors)
  46. def get_parent(self, name):
  47. if self.name == name:
  48. return self
  49. else:
  50. for parent in self.__parents__:
  51. found = parent.get_parent(name)
  52. if found:
  53. return found
  54. raise ValueError("Parent " + name + " not found")
  55. @property
  56. def __schema__(self):
  57. schema = self._schema
  58. if self.__parents__:
  59. refs = [
  60. {"$ref": "#/definitions/{0}".format(parent.name)}
  61. for parent in self.__parents__
  62. ]
  63. return {"allOf": refs + [schema]}
  64. else:
  65. return schema
  66. @classmethod
  67. def inherit(cls, name, *parents):
  68. """
  69. Inherit this model (use the Swagger composition pattern aka. allOf)
  70. :param str name: The new model name
  71. :param dict fields: The new model extra fields
  72. """
  73. model = cls(name, parents[-1])
  74. model.__parents__ = parents[:-1]
  75. return model
  76. def validate(self, data, resolver=None, format_checker=None):
  77. validator = Draft4Validator(
  78. self.__schema__, resolver=resolver, format_checker=format_checker
  79. )
  80. try:
  81. validator.validate(data)
  82. except ValidationError:
  83. abort(
  84. HTTPStatus.BAD_REQUEST,
  85. message="Input payload validation failed",
  86. errors=dict(self.format_error(e) for e in validator.iter_errors(data)),
  87. )
  88. def format_error(self, error):
  89. path = list(error.path)
  90. if error.validator == "required":
  91. name = RE_REQUIRED.match(error.message).group("name")
  92. path.append(name)
  93. key = ".".join(str(p) for p in path)
  94. return key, error.message
  95. def __unicode__(self):
  96. return "Model({name},{{{fields}}})".format(
  97. name=self.name, fields=",".join(self.keys())
  98. )
  99. __str__ = __unicode__
  100. class RawModel(ModelBase):
  101. """
  102. A thin wrapper on ordered fields dict to store API doc metadata.
  103. Can also be used for response marshalling.
  104. :param str name: The model public name
  105. :param str mask: an optional default model mask
  106. :param bool strict: validation should raise error when there is param not provided in schema
  107. """
  108. wrapper = dict
  109. def __init__(self, name, *args, **kwargs):
  110. self.__mask__ = kwargs.pop("mask", None)
  111. self.__strict__ = kwargs.pop("strict", False)
  112. if self.__mask__ and not isinstance(self.__mask__, Mask):
  113. self.__mask__ = Mask(self.__mask__)
  114. super(RawModel, self).__init__(name, *args, **kwargs)
  115. def instance_clone(name, *parents):
  116. return self.__class__.clone(name, self, *parents)
  117. self.clone = instance_clone
  118. @property
  119. def _schema(self):
  120. properties = self.wrapper()
  121. required = set()
  122. discriminator = None
  123. for name, field in iteritems(self):
  124. field = instance(field)
  125. properties[name] = field.__schema__
  126. if field.required:
  127. required.add(name)
  128. if getattr(field, "discriminator", False):
  129. discriminator = name
  130. definition = {
  131. "required": sorted(list(required)) or None,
  132. "properties": properties,
  133. "discriminator": discriminator,
  134. "x-mask": str(self.__mask__) if self.__mask__ else None,
  135. "type": "object",
  136. }
  137. if self.__strict__:
  138. definition['additionalProperties'] = False
  139. return not_none(definition)
  140. @cached_property
  141. def resolved(self):
  142. """
  143. Resolve real fields before submitting them to marshal
  144. """
  145. # Duplicate fields
  146. resolved = copy.deepcopy(self)
  147. # Recursively copy parent fields if necessary
  148. for parent in self.__parents__:
  149. resolved.update(parent.resolved)
  150. # Handle discriminator
  151. candidates = [
  152. f for f in itervalues(resolved) if getattr(f, "discriminator", None)
  153. ]
  154. # Ensure the is only one discriminator
  155. if len(candidates) > 1:
  156. raise ValueError("There can only be one discriminator by schema")
  157. # Ensure discriminator always output the model name
  158. elif len(candidates) == 1:
  159. candidates[0].default = self.name
  160. return resolved
  161. def extend(self, name, fields):
  162. """
  163. Extend this model (Duplicate all fields)
  164. :param str name: The new model name
  165. :param dict fields: The new model extra fields
  166. :deprecated: since 0.9. Use :meth:`clone` instead.
  167. """
  168. warnings.warn(
  169. "extend is is deprecated, use clone instead",
  170. DeprecationWarning,
  171. stacklevel=2,
  172. )
  173. if isinstance(fields, (list, tuple)):
  174. return self.clone(name, *fields)
  175. else:
  176. return self.clone(name, fields)
  177. @classmethod
  178. def clone(cls, name, *parents):
  179. """
  180. Clone these models (Duplicate all fields)
  181. It can be used from the class
  182. >>> model = Model.clone(fields_1, fields_2)
  183. or from an Instanciated model
  184. >>> new_model = model.clone(fields_1, fields_2)
  185. :param str name: The new model name
  186. :param dict parents: The new model extra fields
  187. """
  188. fields = cls.wrapper()
  189. for parent in parents:
  190. fields.update(copy.deepcopy(parent))
  191. return cls(name, fields)
  192. def __deepcopy__(self, memo):
  193. obj = self.__class__(
  194. self.name,
  195. [(key, copy.deepcopy(value, memo)) for key, value in iteritems(self)],
  196. mask=self.__mask__,
  197. strict=self.__strict__,
  198. )
  199. obj.__parents__ = self.__parents__
  200. return obj
  201. class Model(RawModel, dict, MutableMapping):
  202. """
  203. A thin wrapper on fields dict to store API doc metadata.
  204. Can also be used for response marshalling.
  205. :param str name: The model public name
  206. :param str mask: an optional default model mask
  207. """
  208. pass
  209. class OrderedModel(RawModel, OrderedDict, MutableMapping):
  210. """
  211. A thin wrapper on ordered fields dict to store API doc metadata.
  212. Can also be used for response marshalling.
  213. :param str name: The model public name
  214. :param str mask: an optional default model mask
  215. """
  216. wrapper = OrderedDict
  217. class SchemaModel(ModelBase):
  218. """
  219. Stores API doc metadata based on a json schema.
  220. :param str name: The model public name
  221. :param dict schema: The json schema we are documenting
  222. """
  223. def __init__(self, name, schema=None):
  224. super(SchemaModel, self).__init__(name)
  225. self._schema = schema or {}
  226. def __unicode__(self):
  227. return "SchemaModel({name},{schema})".format(
  228. name=self.name, schema=self._schema
  229. )
  230. __str__ = __unicode__