reqparse.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455
  1. # -*- coding: utf-8 -*-
  2. from __future__ import unicode_literals
  3. import decimal
  4. import six
  5. try:
  6. from collections.abc import Hashable
  7. except ImportError:
  8. from collections import Hashable
  9. from copy import deepcopy
  10. from flask import current_app, request
  11. from werkzeug.datastructures import MultiDict, FileStorage
  12. from werkzeug import exceptions
  13. from .errors import abort, SpecsError
  14. from .marshalling import marshal
  15. from .model import Model
  16. from ._http import HTTPStatus
  17. class ParseResult(dict):
  18. """
  19. The default result container as an Object dict.
  20. """
  21. def __getattr__(self, name):
  22. try:
  23. return self[name]
  24. except KeyError:
  25. raise AttributeError(name)
  26. def __setattr__(self, name, value):
  27. self[name] = value
  28. _friendly_location = {
  29. "json": "the JSON body",
  30. "form": "the post body",
  31. "args": "the query string",
  32. "values": "the post body or the query string",
  33. "headers": "the HTTP headers",
  34. "cookies": "the request's cookies",
  35. "files": "an uploaded file",
  36. }
  37. #: Maps Flask-RESTX RequestParser locations to Swagger ones
  38. LOCATIONS = {
  39. "args": "query",
  40. "form": "formData",
  41. "headers": "header",
  42. "json": "body",
  43. "values": "query",
  44. "files": "formData",
  45. }
  46. #: Maps Python primitives types to Swagger ones
  47. PY_TYPES = {
  48. int: "integer",
  49. str: "string",
  50. bool: "boolean",
  51. float: "number",
  52. None: "void",
  53. }
  54. SPLIT_CHAR = ","
  55. text_type = lambda x: six.text_type(x) # noqa
  56. class Argument(object):
  57. """
  58. :param name: Either a name or a list of option strings, e.g. foo or -f, --foo.
  59. :param default: The value produced if the argument is absent from the request.
  60. :param dest: The name of the attribute to be added to the object
  61. returned by :meth:`~reqparse.RequestParser.parse_args()`.
  62. :param bool required: Whether or not the argument may be omitted (optionals only).
  63. :param string action: The basic type of action to be taken when this argument
  64. is encountered in the request. Valid options are "store" and "append".
  65. :param bool ignore: Whether to ignore cases where the argument fails type conversion
  66. :param type: The type to which the request argument should be converted.
  67. If a type raises an exception, the message in the error will be returned in the response.
  68. Defaults to :class:`unicode` in python2 and :class:`str` in python3.
  69. :param location: The attributes of the :class:`flask.Request` object
  70. to source the arguments from (ex: headers, args, etc.), can be an
  71. iterator. The last item listed takes precedence in the result set.
  72. :param choices: A container of the allowable values for the argument.
  73. :param help: A brief description of the argument, returned in the
  74. response when the argument is invalid. May optionally contain
  75. an "{error_msg}" interpolation token, which will be replaced with
  76. the text of the error raised by the type converter.
  77. :param bool case_sensitive: Whether argument values in the request are
  78. case sensitive or not (this will convert all values to lowercase)
  79. :param bool store_missing: Whether the arguments default value should
  80. be stored if the argument is missing from the request.
  81. :param bool trim: If enabled, trims whitespace around the argument.
  82. :param bool nullable: If enabled, allows null value in argument.
  83. """
  84. def __init__(
  85. self,
  86. name,
  87. default=None,
  88. dest=None,
  89. required=False,
  90. ignore=False,
  91. type=text_type,
  92. location=("json", "values",),
  93. choices=(),
  94. action="store",
  95. help=None,
  96. operators=("=",),
  97. case_sensitive=True,
  98. store_missing=True,
  99. trim=False,
  100. nullable=True,
  101. ):
  102. self.name = name
  103. self.default = default
  104. self.dest = dest
  105. self.required = required
  106. self.ignore = ignore
  107. self.location = location
  108. self.type = type
  109. self.choices = choices
  110. self.action = action
  111. self.help = help
  112. self.case_sensitive = case_sensitive
  113. self.operators = operators
  114. self.store_missing = store_missing
  115. self.trim = trim
  116. self.nullable = nullable
  117. def source(self, request):
  118. """
  119. Pulls values off the request in the provided location
  120. :param request: The flask request object to parse arguments from
  121. """
  122. if isinstance(self.location, six.string_types):
  123. value = getattr(request, self.location, MultiDict())
  124. if callable(value):
  125. value = value()
  126. if value is not None:
  127. return value
  128. else:
  129. values = MultiDict()
  130. for l in self.location:
  131. value = getattr(request, l, None)
  132. if callable(value):
  133. value = value()
  134. if value is not None:
  135. values.update(value)
  136. return values
  137. return MultiDict()
  138. def convert(self, value, op):
  139. # Don't cast None
  140. if value is None:
  141. if not self.nullable:
  142. raise ValueError("Must not be null!")
  143. return None
  144. elif isinstance(self.type, Model) and isinstance(value, dict):
  145. return marshal(value, self.type)
  146. # and check if we're expecting a filestorage and haven't overridden `type`
  147. # (required because the below instantiation isn't valid for FileStorage)
  148. elif isinstance(value, FileStorage) and self.type == FileStorage:
  149. return value
  150. try:
  151. return self.type(value, self.name, op)
  152. except TypeError:
  153. try:
  154. if self.type is decimal.Decimal:
  155. return self.type(str(value), self.name)
  156. else:
  157. return self.type(value, self.name)
  158. except TypeError:
  159. return self.type(value)
  160. def handle_validation_error(self, error, bundle_errors):
  161. """
  162. Called when an error is raised while parsing. Aborts the request
  163. with a 400 status and an error message
  164. :param error: the error that was raised
  165. :param bool bundle_errors: do not abort when first error occurs, return a
  166. dict with the name of the argument and the error message to be
  167. bundled
  168. """
  169. error_str = six.text_type(error)
  170. error_msg = (
  171. " ".join([six.text_type(self.help), error_str]) if self.help else error_str
  172. )
  173. errors = {self.name: error_msg}
  174. if bundle_errors:
  175. return ValueError(error), errors
  176. abort(HTTPStatus.BAD_REQUEST, "Input payload validation failed", errors=errors)
  177. def parse(self, request, bundle_errors=False):
  178. """
  179. Parses argument value(s) from the request, converting according to
  180. the argument's type.
  181. :param request: The flask request object to parse arguments from
  182. :param bool bundle_errors: do not abort when first error occurs, return a
  183. dict with the name of the argument and the error message to be
  184. bundled
  185. """
  186. bundle_errors = current_app.config.get("BUNDLE_ERRORS", False) or bundle_errors
  187. source = self.source(request)
  188. results = []
  189. # Sentinels
  190. _not_found = False
  191. _found = True
  192. for operator in self.operators:
  193. name = self.name + operator.replace("=", "", 1)
  194. if name in source:
  195. # Account for MultiDict and regular dict
  196. if hasattr(source, "getlist"):
  197. values = source.getlist(name)
  198. else:
  199. values = [source.get(name)]
  200. for value in values:
  201. if hasattr(value, "strip") and self.trim:
  202. value = value.strip()
  203. if hasattr(value, "lower") and not self.case_sensitive:
  204. value = value.lower()
  205. if hasattr(self.choices, "__iter__"):
  206. self.choices = [choice.lower() for choice in self.choices]
  207. try:
  208. if self.action == "split":
  209. value = [
  210. self.convert(v, operator)
  211. for v in value.split(SPLIT_CHAR)
  212. ]
  213. else:
  214. value = self.convert(value, operator)
  215. except Exception as error:
  216. if self.ignore:
  217. continue
  218. return self.handle_validation_error(error, bundle_errors)
  219. if self.choices and value not in self.choices:
  220. msg = "The value '{0}' is not a valid choice for '{1}'.".format(
  221. value, name
  222. )
  223. return self.handle_validation_error(msg, bundle_errors)
  224. if name in request.unparsed_arguments:
  225. request.unparsed_arguments.pop(name)
  226. results.append(value)
  227. if not results and self.required:
  228. if isinstance(self.location, six.string_types):
  229. location = _friendly_location.get(self.location, self.location)
  230. else:
  231. locations = [_friendly_location.get(loc, loc) for loc in self.location]
  232. location = " or ".join(locations)
  233. error_msg = "Missing required parameter in {0}".format(location)
  234. return self.handle_validation_error(error_msg, bundle_errors)
  235. if not results:
  236. if callable(self.default):
  237. return self.default(), _not_found
  238. else:
  239. return self.default, _not_found
  240. if self.action == "append":
  241. return results, _found
  242. if self.action == "store" or len(results) == 1:
  243. return results[0], _found
  244. return results, _found
  245. @property
  246. def __schema__(self):
  247. if self.location == "cookie":
  248. return
  249. param = {"name": self.name, "in": LOCATIONS.get(self.location, "query")}
  250. _handle_arg_type(self, param)
  251. if self.required:
  252. param["required"] = True
  253. if self.help:
  254. param["description"] = self.help
  255. if self.default is not None:
  256. param["default"] = (
  257. self.default() if callable(self.default) else self.default
  258. )
  259. if self.action == "append":
  260. param["items"] = {"type": param["type"]}
  261. param["type"] = "array"
  262. param["collectionFormat"] = "multi"
  263. if self.action == "split":
  264. param["items"] = {"type": param["type"]}
  265. param["type"] = "array"
  266. param["collectionFormat"] = "csv"
  267. if self.choices:
  268. param["enum"] = self.choices
  269. return param
  270. class RequestParser(object):
  271. """
  272. Enables adding and parsing of multiple arguments in the context of a single request.
  273. Ex::
  274. from flask_restx import RequestParser
  275. parser = RequestParser()
  276. parser.add_argument('foo')
  277. parser.add_argument('int_bar', type=int)
  278. args = parser.parse_args()
  279. :param bool trim: If enabled, trims whitespace on all arguments in this parser
  280. :param bool bundle_errors: If enabled, do not abort when first error occurs,
  281. return a dict with the name of the argument and the error message to be
  282. bundled and return all validation errors
  283. """
  284. def __init__(
  285. self,
  286. argument_class=Argument,
  287. result_class=ParseResult,
  288. trim=False,
  289. bundle_errors=False,
  290. ):
  291. self.args = []
  292. self.argument_class = argument_class
  293. self.result_class = result_class
  294. self.trim = trim
  295. self.bundle_errors = bundle_errors
  296. def add_argument(self, *args, **kwargs):
  297. """
  298. Adds an argument to be parsed.
  299. Accepts either a single instance of Argument or arguments to be passed
  300. into :class:`Argument`'s constructor.
  301. See :class:`Argument`'s constructor for documentation on the available options.
  302. """
  303. if len(args) == 1 and isinstance(args[0], self.argument_class):
  304. self.args.append(args[0])
  305. else:
  306. self.args.append(self.argument_class(*args, **kwargs))
  307. # Do not know what other argument classes are out there
  308. if self.trim and self.argument_class is Argument:
  309. # enable trim for appended element
  310. self.args[-1].trim = kwargs.get("trim", self.trim)
  311. return self
  312. def parse_args(self, req=None, strict=False):
  313. """
  314. Parse all arguments from the provided request and return the results as a ParseResult
  315. :param bool strict: if req includes args not in parser, throw 400 BadRequest exception
  316. :return: the parsed results as :class:`ParseResult` (or any class defined as :attr:`result_class`)
  317. :rtype: ParseResult
  318. """
  319. if req is None:
  320. req = request
  321. result = self.result_class()
  322. # A record of arguments not yet parsed; as each is found
  323. # among self.args, it will be popped out
  324. req.unparsed_arguments = (
  325. dict(self.argument_class("").source(req)) if strict else {}
  326. )
  327. errors = {}
  328. for arg in self.args:
  329. value, found = arg.parse(req, self.bundle_errors)
  330. if isinstance(value, ValueError):
  331. errors.update(found)
  332. found = None
  333. if found or arg.store_missing:
  334. result[arg.dest or arg.name] = value
  335. if errors:
  336. abort(
  337. HTTPStatus.BAD_REQUEST, "Input payload validation failed", errors=errors
  338. )
  339. if strict and req.unparsed_arguments:
  340. arguments = ", ".join(req.unparsed_arguments.keys())
  341. msg = "Unknown arguments: {0}".format(arguments)
  342. raise exceptions.BadRequest(msg)
  343. return result
  344. def copy(self):
  345. """Creates a copy of this RequestParser with the same set of arguments"""
  346. parser_copy = self.__class__(self.argument_class, self.result_class)
  347. parser_copy.args = deepcopy(self.args)
  348. parser_copy.trim = self.trim
  349. parser_copy.bundle_errors = self.bundle_errors
  350. return parser_copy
  351. def replace_argument(self, name, *args, **kwargs):
  352. """Replace the argument matching the given name with a new version."""
  353. new_arg = self.argument_class(name, *args, **kwargs)
  354. for index, arg in enumerate(self.args[:]):
  355. if new_arg.name == arg.name:
  356. del self.args[index]
  357. self.args.append(new_arg)
  358. break
  359. return self
  360. def remove_argument(self, name):
  361. """Remove the argument matching the given name."""
  362. for index, arg in enumerate(self.args[:]):
  363. if name == arg.name:
  364. del self.args[index]
  365. break
  366. return self
  367. @property
  368. def __schema__(self):
  369. params = []
  370. locations = set()
  371. for arg in self.args:
  372. param = arg.__schema__
  373. if param:
  374. params.append(param)
  375. locations.add(param["in"])
  376. if "body" in locations and "formData" in locations:
  377. raise SpecsError("Can't use formData and body at the same time")
  378. return params
  379. def _handle_arg_type(arg, param):
  380. if isinstance(arg.type, Hashable) and arg.type in PY_TYPES:
  381. param["type"] = PY_TYPES[arg.type]
  382. elif hasattr(arg.type, "__apidoc__"):
  383. param["type"] = arg.type.__apidoc__["name"]
  384. param["in"] = "body"
  385. elif hasattr(arg.type, "__schema__"):
  386. param.update(arg.type.__schema__)
  387. elif arg.location == "files":
  388. param["type"] = "file"
  389. else:
  390. param["type"] = "string"