message.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. # -*- coding: utf-8 -*-
  2. #
  3. # This file is part of gunicorn released under the MIT license.
  4. # See the NOTICE for more information.
  5. import io
  6. import re
  7. import socket
  8. from gunicorn.http.body import ChunkedReader, LengthReader, EOFReader, Body
  9. from gunicorn.http.errors import (
  10. InvalidHeader, InvalidHeaderName, NoMoreData,
  11. InvalidRequestLine, InvalidRequestMethod, InvalidHTTPVersion,
  12. LimitRequestLine, LimitRequestHeaders,
  13. )
  14. from gunicorn.http.errors import InvalidProxyLine, ForbiddenProxyRequest
  15. from gunicorn.http.errors import InvalidSchemeHeaders
  16. from gunicorn.util import bytes_to_str, split_request_uri
  # Upper bound for the request-line length limit; Request.__init__ falls
  # back to this value when the configured limit is negative or not below it.
  MAX_REQUEST_LINE = 8190

  # Upper bound for the number of header fields; Message.__init__ falls back
  # to this value when the configured limit is <= 0 or larger than it.
  MAX_HEADERS = 32768

  # Per-field size used when the configured header field size limit is
  # negative, and to size the header buffer when that limit is 0.
  DEFAULT_MAX_HEADERFIELD_SIZE = 8190

  # Matches any character that is rejected inside a header field name.
  HEADER_RE = re.compile(r"[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\"]")

  # Prefix check for the request method: 3 to 20 characters drawn from
  # A-Z, 0-9, "$", "-", "_" and ".".  The hyphen is escaped on purpose: an
  # unescaped "$-_" is a *range* (0x24-0x5F) that also lets through
  # characters such as "(", "/", ";", "<" and "@".
  METH_RE = re.compile(r"[A-Z0-9$\-_.]{3,20}")

  # Captures the major and minor numbers of an "HTTP/x.y" version string.
  VERSION_RE = re.compile(r"HTTP/(\d+)\.(\d+)")
  class Message(object):
      """Base class for an HTTP message read from an unreader.

      Constructing an instance immediately calls ``parse()`` (which
      subclasses must implement), hands the bytes ``parse()`` did not
      consume back to the unreader, and then selects a body reader from the
      parsed headers.
      """

      def __init__(self, cfg, unreader, peer_addr):
          """Parse a message from ``unreader``.

          :param cfg: configuration object; this class reads ``is_ssl``,
              ``limit_request_fields``, ``limit_request_field_size``,
              ``forwarded_allow_ips``, ``secure_scheme_headers`` and
              ``strip_header_spaces`` from it.
          :param unreader: source of bytes offering ``read()`` / ``unread()``.
          :param peer_addr: address of the remote peer.
          """
          self.cfg = cfg
          self.unreader = unreader
          self.peer_addr = peer_addr
          self.version = None
          self.headers = []
          self.trailers = []
          self.body = None
          self.scheme = "https" if cfg.is_ssl else "http"

          # set headers limits
          self.limit_request_fields = cfg.limit_request_fields
          if (self.limit_request_fields <= 0
                  or self.limit_request_fields > MAX_HEADERS):
              self.limit_request_fields = MAX_HEADERS
          self.limit_request_field_size = cfg.limit_request_field_size
          if self.limit_request_field_size < 0:
              self.limit_request_field_size = DEFAULT_MAX_HEADERFIELD_SIZE

          # set max header buffer size; a field size limit of 0 is sized
          # with the default so the buffer bound stays finite
          max_header_field_size = (self.limit_request_field_size
                                   or DEFAULT_MAX_HEADERFIELD_SIZE)
          self.max_buffer_headers = (
              self.limit_request_fields * (max_header_field_size + 2) + 4)

          unused = self.parse(self.unreader)
          self.unreader.unread(unused)
          self.set_body_reader()

      def parse(self, unreader):
          """Parse the message head; must be implemented by subclasses.

          ``__init__`` passes the return value to ``unreader.unread()``.
          """
          raise NotImplementedError()

      def parse_headers(self, data):
          """Parse a raw header block into a list of ``(NAME, value)`` pairs.

          Names are upper-cased; continuation lines (starting with a space
          or tab) are folded into the preceding value.  When the peer is
          trusted for forwarded headers, a header listed in
          ``cfg.secure_scheme_headers`` also updates ``self.scheme``.

          :param data: header block as bytes, lines separated by CRLF.
          :return: list of ``(name, value)`` string tuples.
          :raises LimitRequestHeaders: too many fields, or a field too long.
          :raises InvalidHeader: a line has no ``:`` separator.
          :raises InvalidHeaderName: the name contains a forbidden character.
          :raises InvalidSchemeHeaders: scheme headers contradict each other.
          """
          cfg = self.cfg
          headers = []

          # Split lines on \r\n keeping the \r\n on each line
          lines = [bytes_to_str(line) + "\r\n" for line in data.split(b"\r\n")]
          # Reverse once so lines are consumed with O(1) pop() from the end
          # instead of an O(n) pop(0) from the front.
          lines.reverse()

          # handle scheme headers: they are only honoured when every peer is
          # allowed ('*'), the peer address is not a tuple, or the peer's
          # host is in the allow list
          scheme_header = False
          secure_scheme_headers = {}
          if ('*' in cfg.forwarded_allow_ips
                  or not isinstance(self.peer_addr, tuple)
                  or self.peer_addr[0] in cfg.forwarded_allow_ips):
              secure_scheme_headers = cfg.secure_scheme_headers

          # Parse headers into key/value pairs paying attention
          # to continuation lines.
          while lines:
              if len(headers) >= self.limit_request_fields:
                  raise LimitRequestHeaders("limit request headers fields")

              # Parse initial header name : value pair.
              curr = lines.pop()
              header_length = len(curr)
              if ":" not in curr:
                  raise InvalidHeader(curr.strip())
              name, value = curr.split(":", 1)
              if cfg.strip_header_spaces:
                  name = name.rstrip(" \t").upper()
              else:
                  name = name.upper()
              if HEADER_RE.search(name):
                  raise InvalidHeaderName(name)

              name, value = name.strip(), [value.lstrip()]

              # Consume value continuation lines
              while lines and lines[-1].startswith((" ", "\t")):
                  curr = lines.pop()
                  header_length += len(curr)
                  if header_length > self.limit_request_field_size > 0:
                      raise LimitRequestHeaders("limit request headers "
                                                "fields size")
                  value.append(curr)
              value = ''.join(value).rstrip()

              # a limit of 0 disables the field size check
              if header_length > self.limit_request_field_size > 0:
                  raise LimitRequestHeaders("limit request headers fields size")

              if name in secure_scheme_headers:
                  secure = value == secure_scheme_headers[name]
                  scheme = "https" if secure else "http"
                  if scheme_header:
                      # a later scheme header must agree with the first one
                      if scheme != self.scheme:
                          raise InvalidSchemeHeaders()
                  else:
                      scheme_header = True
                      self.scheme = scheme

              headers.append((name, value))

          return headers

      @staticmethod
      def _parse_content_length(value):
          """Convert a Content-Length header value to an ``int``.

          Only a non-empty run of ASCII digits is accepted.  ``int()`` on
          its own is too lenient for a framing header: it also accepts a
          sign (``"+5"``, ``"-0"``) and digit-group underscores (``"1_0"``).

          :raises ValueError: if ``value`` is not made of ASCII digits only.
          """
          if not re.fullmatch(r"[0-9]+", value):
              raise ValueError(value)
          return int(value)

      def set_body_reader(self):
          """Choose ``self.body`` from the parsed headers.

          A ``Transfer-Encoding`` of exactly ``chunked`` selects a chunked
          reader and takes precedence over ``Content-Length``; otherwise a
          ``Content-Length`` selects a fixed-length reader; otherwise the
          body is read until EOF.

          :raises InvalidHeader: ``Content-Length`` is repeated, or (when it
              is the header in effect) is not a plain decimal number.
          """
          chunked = False
          content_length = None
          for (name, value) in self.headers:
              if name == "CONTENT-LENGTH":
                  if content_length is not None:
                      raise InvalidHeader("CONTENT-LENGTH", req=self)
                  content_length = value
              elif name == "TRANSFER-ENCODING":
                  if value.lower() == "chunked":
                      chunked = True

          if chunked:
              self.body = Body(ChunkedReader(self, self.unreader))
          elif content_length is not None:
              try:
                  content_length = self._parse_content_length(content_length)
              except ValueError:
                  raise InvalidHeader("CONTENT-LENGTH", req=self)
              self.body = Body(LengthReader(self.unreader, content_length))
          else:
              self.body = Body(EOFReader(self.unreader))

      def should_close(self):
          """Return True when the connection should be closed afterwards.

          Only the first ``Connection`` header is considered: ``close``
          gives True and ``keep-alive`` gives False.  For any other value,
          or when the header is absent, the connection is closed for HTTP
          versions up to and including 1.0.
          """
          for (h, v) in self.headers:
              if h == "CONNECTION":
                  v = v.lower().strip()
                  if v == "close":
                      return True
                  elif v == "keep-alive":
                      return False
                  break
          return self.version <= (1, 0)
  class Request(Message):
      """An HTTP request: request line, optional PROXY line and headers."""

      def __init__(self, cfg, unreader, peer_addr, req_number=1):
          """Parse one request from ``unreader``.

          :param cfg: configuration object; in addition to what ``Message``
              reads, this class uses ``limit_request_line``,
              ``proxy_protocol`` and ``proxy_allow_ips``.
          :param unreader: source of bytes offering ``read()`` / ``unread()``.
          :param peer_addr: address of the remote peer.
          :param req_number: position of this request on its connection; a
              PROXY protocol line is only looked for when it is 1.
          """
          self.method = None
          self.uri = None
          self.path = None
          self.query = None
          self.fragment = None

          # get max request line size
          self.limit_request_line = cfg.limit_request_line
          if (self.limit_request_line < 0
                  or self.limit_request_line >= MAX_REQUEST_LINE):
              self.limit_request_line = MAX_REQUEST_LINE

          self.req_number = req_number
          self.proxy_protocol_info = None
          # The base initializer runs parse(), so every attribute parse()
          # relies on has to be set before this call.
          super().__init__(cfg, unreader, peer_addr)

      def get_data(self, unreader, buf, stop=False):
          """Read one chunk from ``unreader`` and append it to ``buf``.

          :raises StopIteration: nothing was read and ``stop`` is true.
          :raises NoMoreData: nothing was read and ``stop`` is false; the
              exception carries what has been buffered so far.
          """
          data = unreader.read()
          if not data:
              if stop:
                  raise StopIteration()
              raise NoMoreData(buf.getvalue())
          buf.write(data)

      def parse(self, unreader):
          """Parse the request line and headers.

          :return: the bytes following the header block, or ``b""`` when the
              request has no headers (in which case the remainder is pushed
              back onto ``self.unreader`` directly).
          :raises LimitRequestHeaders: the header block exceeds
              ``self.max_buffer_headers`` before its end is found.
          """
          buf = io.BytesIO()
          self.get_data(unreader, buf, stop=True)

          # get request line
          line, rbuf = self.read_line(unreader, buf, self.limit_request_line)

          # proxy protocol
          if self.proxy_protocol(bytes_to_str(line)):
              # get next request line
              buf = io.BytesIO()
              buf.write(rbuf)
              line, rbuf = self.read_line(unreader, buf, self.limit_request_line)

          self.parse_request_line(line)

          # Start a fresh buffer with the residue.  It is written (not
          # passed to the constructor) so the position is left at the end
          # and later reads are appended.
          buf = io.BytesIO()
          buf.write(rbuf)

          # Headers
          data = buf.getvalue()
          while True:
              idx = data.find(b"\r\n\r\n")
              # an immediate CRLF means the request carries no headers
              done = data[:2] == b"\r\n"
              if idx >= 0 or done:
                  break
              self.get_data(unreader, buf)
              data = buf.getvalue()
              if len(data) > self.max_buffer_headers:
                  raise LimitRequestHeaders("max buffer headers")

          if done:
              self.unreader.unread(data[2:])
              return b""

          self.headers = self.parse_headers(data[:idx])
          return data[idx + 4:]

      def read_line(self, unreader, buf, limit=0):
          """Read from ``unreader`` until ``buf`` holds a CRLF-terminated line.

          :param limit: maximum line length; 0 disables the check.
          :return: ``(line, residue)`` - the line without its CRLF and the
              bytes buffered after it.
          :raises LimitRequestLine: the line is longer than ``limit``.
          """
          data = buf.getvalue()

          while True:
              idx = data.find(b"\r\n")
              if idx >= 0:
                  # check if the request line is too large
                  if idx > limit > 0:
                      raise LimitRequestLine(idx, limit)
                  break
              # no terminator yet: give up once the buffered data is
              # already over the limit
              if len(data) - 2 > limit > 0:
                  raise LimitRequestLine(len(data), limit)
              self.get_data(unreader, buf)
              data = buf.getvalue()

          return (data[:idx],  # request line,
                  data[idx + 2:])  # residue in the buffer, skip \r\n

      def proxy_protocol(self, line):
          """\
          Detect, check and parse proxy protocol.

          :raises: ForbiddenProxyRequest, InvalidProxyLine.
          :return: True for proxy protocol line else False
          """
          if not self.cfg.proxy_protocol:
              return False

          if self.req_number != 1:
              return False

          if not line.startswith("PROXY"):
              return False

          self.proxy_protocol_access_check()
          self.parse_proxy_protocol(line)

          return True

      def proxy_protocol_access_check(self):
          """Reject a PROXY line sent by a peer that is not allowed to.

          The check only applies when the peer address is a tuple and
          ``"*"`` is not in ``cfg.proxy_allow_ips``.

          :raises ForbiddenProxyRequest: the peer host is not in the list.
          """
          # check in allow list
          if ("*" not in self.cfg.proxy_allow_ips and
                  isinstance(self.peer_addr, tuple) and
                  self.peer_addr[0] not in self.cfg.proxy_allow_ips):
              raise ForbiddenProxyRequest(self.peer_addr[0])

      def parse_proxy_protocol(self, line):
          """Validate a PROXY line and store it in ``proxy_protocol_info``.

          The line must have six whitespace-separated fields: the ``PROXY``
          keyword, the protocol (``TCP4`` or ``TCP6``), source and
          destination addresses, and source and destination ports.

          :raises InvalidProxyLine: wrong field count, unsupported protocol,
              malformed address, or a port that is not an integer in
              0..65535.
          """
          bits = line.split()

          if len(bits) != 6:
              raise InvalidProxyLine(line)

          # Extract data
          proto = bits[1]
          s_addr = bits[2]
          d_addr = bits[3]

          # Validation
          families = {"TCP4": socket.AF_INET, "TCP6": socket.AF_INET6}
          family = families.get(proto)
          if family is None:
              raise InvalidProxyLine("protocol '%s' not supported" % proto)
          try:
              socket.inet_pton(family, s_addr)
              socket.inet_pton(family, d_addr)
          except (socket.error, ValueError):
              # inet_pton raises ValueError (not OSError) for an address
              # with an embedded NUL character; treat that as a malformed
              # line too instead of letting it escape as an unhandled error.
              raise InvalidProxyLine(line)

          try:
              s_port = int(bits[4])
              d_port = int(bits[5])
          except ValueError:
              raise InvalidProxyLine("invalid port %s" % line)

          if not ((0 <= s_port <= 65535) and (0 <= d_port <= 65535)):
              raise InvalidProxyLine("invalid port %s" % line)

          # Set data
          self.proxy_protocol_info = {
              "proxy_protocol": proto,
              "client_addr": s_addr,
              "client_port": s_port,
              "proxy_addr": d_addr,
              "proxy_port": d_port
          }

      def parse_request_line(self, line_bytes):
          """Split the request line into method, URI parts and version.

          Sets ``method`` (upper-cased), ``uri``, ``path``, ``query``,
          ``fragment`` and ``version`` (a ``(major, minor)`` int tuple).

          :raises InvalidRequestLine: not three fields, or the URI cannot
              be split.
          :raises InvalidRequestMethod: the method fails ``METH_RE``.
          :raises InvalidHTTPVersion: the version fails ``VERSION_RE``.
          """
          bits = [bytes_to_str(bit) for bit in line_bytes.split(None, 2)]
          if len(bits) != 3:
              raise InvalidRequestLine(bytes_to_str(line_bytes))

          # Method
          if not METH_RE.match(bits[0]):
              raise InvalidRequestMethod(bits[0])
          self.method = bits[0].upper()

          # URI
          self.uri = bits[1]

          try:
              parts = split_request_uri(self.uri)
          except ValueError:
              raise InvalidRequestLine(bytes_to_str(line_bytes))
          self.path = parts.path or ""
          self.query = parts.query or ""
          self.fragment = parts.fragment or ""

          # Version
          match = VERSION_RE.match(bits[2])
          if match is None:
              raise InvalidHTTPVersion(bits[2])
          self.version = (int(match.group(1)), int(match.group(2)))

      def set_body_reader(self):
          """Select the body reader, treating an unframed request as empty.

          When the base class falls back to reading until EOF (neither
          chunked encoding nor ``Content-Length``), the body is replaced by
          a zero-length reader.
          """
          super().set_body_reader()
          if isinstance(self.body.reader, EOFReader):
              self.body = Body(LengthReader(self.unreader, 0))