123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455 |
- # -*- coding: utf-8 -*-
- from __future__ import unicode_literals
- import decimal
- import six
- try:
- from collections.abc import Hashable
- except ImportError:
- from collections import Hashable
- from copy import deepcopy
- from flask import current_app, request
- from werkzeug.datastructures import MultiDict, FileStorage
- from werkzeug import exceptions
- from .errors import abort, SpecsError
- from .marshalling import marshal
- from .model import Model
- from ._http import HTTPStatus
class ParseResult(dict):
    """
    Dictionary of parsed request arguments whose keys are also reachable as
    attributes (``result.foo`` is ``result["foo"]``).
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails; fall back to the
        # mapping and translate a missing key into the AttributeError that
        # getattr()/hasattr() expect.
        try:
            found = self[name]
        except KeyError:
            raise AttributeError(name)
        return found

    def __setattr__(self, name, value):
        # Attribute assignment writes into the mapping, never the instance dict.
        self[name] = value
#: Human-readable description of each request location, used to build the
#: "Missing required parameter in ..." validation message.
_friendly_location = {
    "json": "the JSON body",
    "form": "the post body",
    "args": "the query string",
    "values": "the post body or the query string",
    "headers": "the HTTP headers",
    "cookies": "the request's cookies",
    "files": "an uploaded file",
}

#: Maps Flask-RESTX RequestParser locations to Swagger ones
LOCATIONS = {
    "args": "query",
    "form": "formData",
    "headers": "header",
    "json": "body",
    "values": "query",
    "files": "formData",
}

#: Maps Python primitives types to Swagger ones
PY_TYPES = {
    int: "integer",
    str: "string",
    bool: "boolean",
    float: "number",
    None: "void",
}

#: Separator used to split a single raw value when ``action="split"``.
SPLIT_CHAR = ","


def text_type(x):
    """Default argument converter: coerce ``x`` to ``six.text_type``."""
    return six.text_type(x)
class Argument(object):
    """
    A single argument to pull off a :class:`flask.Request`, convert and validate.

    :param name: Either a name or a list of option strings, e.g. foo or -f, --foo.
    :param default: The value produced if the argument is absent from the request.
        If callable, it is called to produce the value.
    :param dest: The name of the attribute to be added to the object
        returned by :meth:`~reqparse.RequestParser.parse_args()`.
    :param bool required: Whether or not the argument may be omitted (optionals only).
    :param string action: The basic type of action to be taken when this argument
        is encountered in the request. Valid options are "store", "append" and
        "split" (split a single value on ``SPLIT_CHAR`` and convert each part).
    :param bool ignore: Whether to ignore cases where the argument fails type conversion
    :param type: The type to which the request argument should be converted.
        If a type raises an exception, the message in the error will be returned in the response.
        Defaults to :class:`unicode` in python2 and :class:`str` in python3.
    :param location: The attributes of the :class:`flask.Request` object
        to source the arguments from (ex: headers, args, etc.), can be an
        iterator. The last item listed takes precedence in the result set.
    :param choices: A container of the allowable values for the argument.
    :param help: A brief description of the argument, returned in the
        response when the argument is invalid. May optionally contain
        an "{error_msg}" interpolation token, which will be replaced with
        the text of the error raised by the type converter. Without the
        token, the error text is appended after the help text.
    :param operators: Each operator, with its first ``=`` removed, is appended
        to ``name`` to form a lookup key (the default ``("=",)`` looks up
        ``name`` itself). The operator is also passed to ``type`` as a third
        argument when the converter accepts one.
    :param bool case_sensitive: Whether argument values in the request are
        case sensitive or not (this will convert all values to lowercase)
    :param bool store_missing: Whether the arguments default value should
        be stored if the argument is missing from the request.
    :param bool trim: If enabled, trims whitespace around the argument.
    :param bool nullable: If enabled, allows null value in argument.
    """

    def __init__(
        self,
        name,
        default=None,
        dest=None,
        required=False,
        ignore=False,
        type=text_type,
        location=("json", "values"),
        choices=(),
        action="store",
        help=None,
        operators=("=",),
        case_sensitive=True,
        store_missing=True,
        trim=False,
        nullable=True,
    ):
        self.name = name
        self.default = default
        self.dest = dest
        self.required = required
        self.ignore = ignore
        self.location = location
        self.type = type
        self.choices = choices
        self.action = action
        self.help = help
        self.case_sensitive = case_sensitive
        self.operators = operators
        self.store_missing = store_missing
        self.trim = trim
        self.nullable = nullable

    def source(self, request):
        """
        Pulls values off the request in the provided location

        :param request: The flask request object to parse arguments from
        :return: for a single string location, the (non-``None``) value of that
            request attribute; for an iterable of locations, a
            :class:`MultiDict` updated with every non-``None`` source in
            order. Falls back to an empty :class:`MultiDict`.
        """
        if isinstance(self.location, six.string_types):
            value = getattr(request, self.location, MultiDict())
            # A location may name a request method rather than a property.
            if callable(value):
                value = value()
            if value is not None:
                return value
        else:
            values = MultiDict()
            for loc in self.location:
                value = getattr(request, loc, None)
                if callable(value):
                    value = value()
                if value is not None:
                    values.update(value)
            return values

        return MultiDict()

    def convert(self, value, op):
        """
        Convert one raw request value using ``self.type``.

        :param value: the raw value pulled from the request
        :param op: the operator this value was found under
        :raises ValueError: if ``value`` is ``None`` and the argument is not nullable
        """
        # Don't cast None
        if value is None:
            if not self.nullable:
                raise ValueError("Must not be null!")
            return None

        elif isinstance(self.type, Model) and isinstance(value, dict):
            return marshal(value, self.type)

        # and check if we're expecting a filestorage and haven't overridden `type`
        # (required because the below instantiation isn't valid for FileStorage)
        elif isinstance(value, FileStorage) and self.type == FileStorage:
            return value

        # Try the richest converter signature first, then progressively simpler ones.
        try:
            return self.type(value, self.name, op)
        except TypeError:
            try:
                if self.type is decimal.Decimal:
                    # Decimal's optional second argument is a Context, so
                    # Decimal(str(value), self.name) always raised TypeError and
                    # the fallback below built the Decimal from the raw value
                    # (e.g. an imprecise float). Go through str() explicitly.
                    return self.type(str(value))
                return self.type(value, self.name)
            except TypeError:
                return self.type(value)

    def handle_validation_error(self, error, bundle_errors):
        """
        Called when an error is raised while parsing. Aborts the request
        with a 400 status and an error message

        :param error: the error that was raised
        :param bool bundle_errors: do not abort when first error occurs, return a
            dict with the name of the argument and the error message to be
            bundled
        :return: ``(ValueError(error), {name: message})`` when bundling;
            otherwise ``abort`` is called with a 400 status.
        """
        error_str = six.text_type(error)
        if self.help:
            help_str = six.text_type(self.help)
            if "{error_msg}" in help_str:
                # Documented interpolation token. str.replace (not format) so
                # other braces in user help text cannot raise KeyError.
                error_msg = help_str.replace("{error_msg}", error_str)
            else:
                error_msg = " ".join([help_str, error_str])
        else:
            error_msg = error_str
        errors = {self.name: error_msg}

        if bundle_errors:
            return ValueError(error), errors
        abort(HTTPStatus.BAD_REQUEST, "Input payload validation failed", errors=errors)

    def _comparison_choices(self):
        """
        Return the choices that parsed values are checked against.

        For case-insensitive arguments the string choices are lowercased to match
        the lowercased request values; non-string choices are kept unchanged.
        ``self.choices`` itself is never mutated, so the Swagger ``enum`` stays
        stable and instances are not altered while serving requests.
        """
        if self.case_sensitive or not hasattr(self.choices, "__iter__"):
            return self.choices
        return [
            choice.lower() if hasattr(choice, "lower") else choice
            for choice in self.choices
        ]

    def parse(self, request, bundle_errors=False):
        """
        Parses argument value(s) from the request, converting according to
        the argument's type.

        :param request: The flask request object to parse arguments from
        :param bool bundle_errors: do not abort when first error occurs, return a
            dict with the name of the argument and the error message to be
            bundled
        :return: a ``(value, found)`` pair, or the ``(ValueError, errors)`` pair
            from :meth:`handle_validation_error` when errors are bundled.
        """
        bundle_errors = current_app.config.get("BUNDLE_ERRORS", False) or bundle_errors
        source = self.source(request)
        choices = self._comparison_choices()

        results = []

        # Sentinels
        _not_found = False
        _found = True

        for operator in self.operators:
            name = self.name + operator.replace("=", "", 1)
            if name in source:
                # Account for MultiDict and regular dict
                if hasattr(source, "getlist"):
                    values = source.getlist(name)
                else:
                    values = [source.get(name)]

                for value in values:
                    if hasattr(value, "strip") and self.trim:
                        value = value.strip()
                    if hasattr(value, "lower") and not self.case_sensitive:
                        value = value.lower()

                    try:
                        if self.action == "split":
                            value = [
                                self.convert(v, operator)
                                for v in value.split(SPLIT_CHAR)
                            ]
                        else:
                            value = self.convert(value, operator)
                    except Exception as error:
                        if self.ignore:
                            continue
                        return self.handle_validation_error(error, bundle_errors)

                    if choices and value not in choices:
                        msg = "The value '{0}' is not a valid choice for '{1}'.".format(
                            value, name
                        )
                        return self.handle_validation_error(msg, bundle_errors)

                    # Mark this key as consumed for strict-mode parsing.
                    if name in request.unparsed_arguments:
                        request.unparsed_arguments.pop(name)
                    results.append(value)

        if not results and self.required:
            if isinstance(self.location, six.string_types):
                location = _friendly_location.get(self.location, self.location)
            else:
                locations = [_friendly_location.get(loc, loc) for loc in self.location]
                location = " or ".join(locations)
            error_msg = "Missing required parameter in {0}".format(location)
            return self.handle_validation_error(error_msg, bundle_errors)

        if not results:
            if callable(self.default):
                return self.default(), _not_found
            else:
                return self.default, _not_found

        if self.action == "append":
            return results, _found

        if self.action == "store" or len(results) == 1:
            return results[0], _found
        return results, _found

    @property
    def __schema__(self):
        """
        Swagger 2.0 parameter object for this argument, or ``None`` for cookie
        arguments (Swagger 2.0 has no cookie parameter location).
        """
        # The request attribute is ``cookies``; keep the singular spelling too
        # for backward compatibility.
        if self.location in ("cookie", "cookies"):
            return
        # Only a single string location maps onto a Swagger ``in``. Iterable
        # locations fall back to "query" (as tuples already did) instead of
        # raising TypeError for an unhashable list key.
        if isinstance(self.location, six.string_types):
            param_in = LOCATIONS.get(self.location, "query")
        else:
            param_in = "query"
        param = {"name": self.name, "in": param_in}
        _handle_arg_type(self, param)
        if self.required:
            param["required"] = True
        if self.help:
            param["description"] = self.help
        if self.default is not None:
            param["default"] = (
                self.default() if callable(self.default) else self.default
            )
        if self.action == "append":
            param["items"] = {"type": param["type"]}
            param["type"] = "array"
            param["collectionFormat"] = "multi"
        if self.action == "split":
            param["items"] = {"type": param["type"]}
            param["type"] = "array"
            param["collectionFormat"] = "csv"
        if self.choices:
            param["enum"] = self.choices
        return param
class RequestParser(object):
    """
    Collects :class:`Argument` definitions and parses them together against a
    single request.

    Ex::

        from flask_restx import RequestParser

        parser = RequestParser()
        parser.add_argument('foo')
        parser.add_argument('int_bar', type=int)
        args = parser.parse_args()

    :param bool trim: If enabled, trims whitespace on all arguments in this parser
    :param bool bundle_errors: If enabled, do not abort when first error occurs,
        return a dict with the name of the argument and the error message to be
        bundled and return all validation errors
    """

    def __init__(
        self,
        argument_class=Argument,
        result_class=ParseResult,
        trim=False,
        bundle_errors=False,
    ):
        # Registered arguments, kept in the order they are parsed.
        self.args = []
        self.argument_class, self.result_class = argument_class, result_class
        self.trim, self.bundle_errors = trim, bundle_errors

    def add_argument(self, *args, **kwargs):
        """
        Register an argument to be parsed.

        Accepts either a ready-made instance of ``argument_class`` or the
        positional/keyword arguments for its constructor (see
        :class:`Argument` for the available options).

        :return: this parser, so calls can be chained
        """
        is_instance = len(args) == 1 and isinstance(args[0], self.argument_class)
        new_arg = args[0] if is_instance else self.argument_class(*args, **kwargs)
        self.args.append(new_arg)

        # Parser-wide trim is only pushed onto the stock Argument class; other
        # argument classes may not understand the attribute.
        if self.trim and self.argument_class is Argument:
            new_arg.trim = kwargs.get("trim", self.trim)
        return self

    def parse_args(self, req=None, strict=False):
        """
        Parse every registered argument from ``req`` (the current Flask request
        by default).

        :param bool strict: if req includes args not in parser, throw 400 BadRequest exception
        :return: the parsed results as :class:`ParseResult` (or any class defined as :attr:`result_class`)
        :rtype: ParseResult
        """
        req = request if req is None else req
        parsed = self.result_class()

        # In strict mode every incoming key starts out "unparsed"; arguments pop
        # the keys they consume, so anything left over is unknown.
        if strict:
            req.unparsed_arguments = dict(self.argument_class("").source(req))
        else:
            req.unparsed_arguments = {}

        bundled = {}
        for argument in self.args:
            value, found = argument.parse(req, self.bundle_errors)
            if isinstance(value, ValueError):
                # Bundled failure: ``found`` carries the {name: message} dict.
                bundled.update(found)
                found = None
            if found or argument.store_missing:
                parsed[argument.dest or argument.name] = value

        if bundled:
            abort(
                HTTPStatus.BAD_REQUEST, "Input payload validation failed", errors=bundled
            )

        if strict and req.unparsed_arguments:
            unknown = ", ".join(req.unparsed_arguments)
            raise exceptions.BadRequest("Unknown arguments: {0}".format(unknown))

        return parsed

    def copy(self):
        """Return a new parser of the same class with deep copies of every argument."""
        clone = self.__class__(self.argument_class, self.result_class)
        clone.args = deepcopy(self.args)
        clone.trim, clone.bundle_errors = self.trim, self.bundle_errors
        return clone

    def replace_argument(self, name, *args, **kwargs):
        """
        Swap the first argument whose name matches for a freshly built one.

        The replacement is appended at the end of the argument list. Nothing is
        added when no argument matches.
        """
        replacement = self.argument_class(name, *args, **kwargs)
        position = next(
            (i for i, existing in enumerate(self.args) if replacement.name == existing.name),
            None,
        )
        if position is not None:
            del self.args[position]
            self.args.append(replacement)
        return self

    def remove_argument(self, name):
        """Drop the first argument whose name equals ``name`` (no-op if absent)."""
        position = next(
            (i for i, existing in enumerate(self.args) if name == existing.name),
            None,
        )
        if position is not None:
            del self.args[position]
        return self

    @property
    def __schema__(self):
        """
        Swagger parameter list for all arguments that produce a schema.

        :raises SpecsError: if both ``body`` and ``formData`` parameters appear
        """
        params = [schema for schema in (arg.__schema__ for arg in self.args) if schema]
        used_locations = {schema["in"] for schema in params}
        if "body" in used_locations and "formData" in used_locations:
            raise SpecsError("Can't use formData and body at the same time")
        return params
def _handle_arg_type(arg, param):
    """
    Fill in the Swagger type information of ``param`` from ``arg.type``.

    Checked in order:
    - a hashable Python primitive listed in ``PY_TYPES``;
    - a type carrying ``__apidoc__``, which also forces ``in: body``;
    - a type carrying ``__schema__``, merged into ``param``;
    - otherwise ``file`` for the ``files`` location, else ``string``.
    """
    arg_type = arg.type
    # Hashable guard: unhashable converters cannot be dict keys.
    if isinstance(arg_type, Hashable) and arg_type in PY_TYPES:
        param["type"] = PY_TYPES[arg_type]
        return
    if hasattr(arg_type, "__apidoc__"):
        param["type"] = arg_type.__apidoc__["name"]
        param["in"] = "body"
        return
    if hasattr(arg_type, "__schema__"):
        param.update(arg_type.__schema__)
        return
    param["type"] = "file" if arg.location == "files" else "string"