#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This module converts an AWS API Gateway proxied request to a WSGI request.

Inspired by: https://github.com/miserlou/zappa

Author: Logan Raarup
"""
import base64
import io
import json
import os
import sys

from werkzeug.datastructures import Headers, iter_multi_items, MultiDict
from werkzeug.wrappers import Response
from werkzeug.urls import url_encode, url_unquote, url_unquote_plus
from werkzeug.http import HTTP_STATUS_CODES

# List of MIME types that should not be base64 encoded. MIME types within `text/*`
# are included by default.
TEXT_MIME_TYPES = [
    "application/json",
    "application/javascript",
    "application/xml",
    "application/vnd.api+json",
    "image/svg+xml",
]


def all_casings(input_string):
    """
    Permute all casings of a given string.

    Characters without distinct lower/upper forms are emitted unchanged, so
    they do not multiply the number of results.

    A pretty algorithm, via @Amber
    http://stackoverflow.com/questions/6792803/finding-all-possible-case-permutations-in-python
    """
    if not input_string:
        yield ""
    else:
        first = input_string[:1]
        if first.lower() == first.upper():
            for sub_casing in all_casings(input_string[1:]):
                yield first + sub_casing
        else:
            for sub_casing in all_casings(input_string[1:]):
                yield first.lower() + sub_casing
                yield first.upper() + sub_casing


def split_headers(headers):
    """
    If there are multiple occurrences of headers, create case-mutated variations
    in order to pass them through APIGW. This is a hack that's currently needed.

    See: https://github.com/logandk/serverless-wsgi/issues/11
    Source: https://github.com/Miserlou/Zappa/blob/master/zappa/middleware.py

    Returns a plain dict with one entry per header value. Because values are
    paired with casings via `zip`, any values beyond the number of available
    case permutations of the header name are dropped.
    """
    new_headers = {}

    for key in headers.keys():
        values = headers.get_all(key)
        if len(values) > 1:
            for value, casing in zip(values, all_casings(key)):
                new_headers[casing] = value
        elif len(values) == 1:
            new_headers[key] = values[0]

    return new_headers


def group_headers(headers):
    """
    Return a dict mapping each header name to the list of all its values.

    Used for responses to events that carry `multiValueHeaders`.
    """
    new_headers = {}

    for key in headers.keys():
        new_headers[key] = headers.get_all(key)

    return new_headers


def is_alb_event(event):
    """
    Return the `elb` entry of the event's `requestContext` (truthy for events
    that carry one), or None when it is absent.
    """
    return event.get("requestContext", {}).get("elb")


def encode_query_string(event):
    """
    Build the URL-encoded query string for an event.

    Parameters are taken from the first non-empty of
    `multiValueQueryStringParameters`, `queryStringParameters` and `query`.
    For ALB events the keys and values are URL-unquoted first, before
    everything is encoded again.
    """
    params = event.get("multiValueQueryStringParameters")
    if not params:
        params = event.get("queryStringParameters")
    if not params:
        params = event.get("query")
    if not params:
        params = ""
    if is_alb_event(event):
        params = MultiDict(
            (url_unquote_plus(k), url_unquote_plus(v))
            for k, v in iter_multi_items(params)
        )
    return url_encode(params)


def get_script_name(headers, request_context):
    """
    Return the WSGI SCRIPT_NAME for the request.

    The result is "/<stage>" when the Host header contains "amazonaws.com" and
    the STRIP_STAGE_PATH environment variable is not set to a truthy value
    ("yes", "y", "true", "t" or "1", case-insensitive); otherwise it is "".
    """
    strip_stage_path = os.environ.get("STRIP_STAGE_PATH", "").lower().strip() in [
        "yes",
        "y",
        "true",
        "t",
        "1",
    ]

    if "amazonaws.com" in headers.get("Host", "") and not strip_stage_path:
        script_name = "/{}".format(request_context.get("stage", ""))
    else:
        script_name = ""
    return script_name


def get_body_bytes(event, body):
    """
    Return the request body as bytes.

    The body is base64-decoded when the event is flagged `isBase64Encoded`,
    and encoded as UTF-8 when it is (still) a str.
    """
    if event.get("isBase64Encoded", False):
        body = base64.b64decode(body)
    if isinstance(body, str):
        body = body.encode("utf-8")
    return body


def setup_environ_items(environ, headers):
    """
    Finalize a WSGI environ dict in place and return it.

    Existing str values are re-expressed as their UTF-8 bytes decoded as
    latin1, and every request header is added as an `HTTP_*` key, except
    Content-Type and Content-Length which already have dedicated keys.
    """
    # Only existing keys are reassigned here, so the dict size does not change
    # during iteration.
    for key, value in environ.items():
        if isinstance(value, str):
            environ[key] = value.encode("utf-8").decode("latin1", "replace")

    for key, value in headers.items():
        key = "HTTP_" + key.upper().replace("-", "_")
        if key not in ("HTTP_CONTENT_TYPE", "HTTP_CONTENT_LENGTH"):
            environ[key] = value
    return environ


def generate_response(response, event):
    """
    Convert a werkzeug response into the dict returned to Lambda.

    Headers are returned as `multiValueHeaders` when the event carried that
    key, otherwise as case-split `headers`. ALB events additionally get a
    `statusDescription`. Bodies with a text MIME type and no Content-Encoding
    are returned as text; all other non-empty bodies are base64 encoded.
    """
    returndict = {"statusCode": response.status_code}

    if "multiValueHeaders" in event:
        returndict["multiValueHeaders"] = group_headers(response.headers)
    else:
        returndict["headers"] = split_headers(response.headers)

    if is_alb_event(event):
        # If the request comes from ALB we need to add a status description.
        # Fall back to "UNKNOWN" rather than raising KeyError for status codes
        # that are not in the lookup table.
        returndict["statusDescription"] = "%d %s" % (
            response.status_code,
            HTTP_STATUS_CODES.get(response.status_code, "UNKNOWN"),
        )

    if response.data:
        mimetype = response.mimetype or "text/plain"
        if (
            mimetype.startswith("text/") or mimetype in TEXT_MIME_TYPES
        ) and not response.headers.get("Content-Encoding", ""):
            returndict["body"] = response.get_data(as_text=True)
            returndict["isBase64Encoded"] = False
        else:
            returndict["body"] = base64.b64encode(response.data).decode("utf-8")
            returndict["isBase64Encoded"] = True

    return returndict


def handle_request(app, event, context):
    """
    Entry point: run the WSGI `app` for a Lambda `event` and return the result.

    Warming events return an empty dict without invoking the app. Events with
    neither `version` nor `isBase64Encoded`, and which are not ALB events, are
    handled as Lambda integration requests; events with version "2.0" use the
    v2 payload handler; everything else uses the v1 payload handler.
    """
    if event.get("source") in ["aws.events", "serverless-plugin-warmup"]:
        print("Lambda warming event received, skipping handler")
        return {}

    if (
        event.get("version") is None
        and event.get("isBase64Encoded") is None
        and not is_alb_event(event)
    ):
        return handle_lambda_integration(app, event, context)

    if event.get("version") == "2.0":
        return handle_payload_v2(app, event, context)

    return handle_payload_v1(app, event, context)


def handle_payload_v1(app, event, context):
    """
    Handle a payload format 1.0 proxy event (also the fallback used for ALB
    events) and return the response dict.
    """
    if "multiValueHeaders" in event:
        headers = Headers(event["multiValueHeaders"])
    else:
        headers = Headers(event["headers"])

    script_name = get_script_name(headers, event.get("requestContext", {}))

    # If a user is using a custom domain on API Gateway, they may have a base
    # path in their URL. This allows us to strip it out via an optional
    # environment variable.
    path_info = event["path"]
    base_path = os.environ.get("API_GATEWAY_BASE_PATH")
    if base_path:
        script_name = "/" + base_path

        if path_info.startswith(script_name):
            path_info = path_info[len(script_name) :]

    body = event["body"] or ""
    body = get_body_bytes(event, body)

    environ = {
        "CONTENT_LENGTH": str(len(body)),
        "CONTENT_TYPE": headers.get("Content-Type", ""),
        "PATH_INFO": url_unquote(path_info),
        "QUERY_STRING": encode_query_string(event),
        "REMOTE_ADDR": event.get("requestContext", {})
        .get("identity", {})
        .get("sourceIp", ""),
        "REMOTE_USER": event.get("requestContext", {})
        .get("authorizer", {})
        .get("principalId", ""),
        # Default to an empty string (not a dict) so the value is always a str.
        "REQUEST_METHOD": event.get("httpMethod", ""),
        "SCRIPT_NAME": script_name,
        "SERVER_NAME": headers.get("Host", "lambda"),
        "SERVER_PORT": headers.get("X-Forwarded-Port", "80"),
        "SERVER_PROTOCOL": "HTTP/1.1",
        "wsgi.errors": sys.stderr,
        "wsgi.input": io.BytesIO(body),
        "wsgi.multiprocess": False,
        "wsgi.multithread": False,
        "wsgi.run_once": False,
        "wsgi.url_scheme": headers.get("X-Forwarded-Proto", "http"),
        "wsgi.version": (1, 0),
        "serverless.authorizer": event.get("requestContext", {}).get("authorizer"),
        "serverless.event": event,
        "serverless.context": context,
    }

    environ = setup_environ_items(environ, headers)

    response = Response.from_app(app, environ)
    returndict = generate_response(response, event)

    return returndict


def handle_payload_v2(app, event, context):
    """
    Handle a payload format 2.0 event and return the response dict.

    The path and query string are taken from `rawPath` and `rawQueryString`,
    and the event's `cookies` list is joined into a single Cookie header.
    """
    headers = Headers(event["headers"])

    script_name = get_script_name(headers, event.get("requestContext", {}))

    path_info = event["rawPath"]

    # `body` may be missing or null; treat both as an empty body.
    body = event.get("body") or ""
    body = get_body_bytes(event, body)

    # `cookies` may be missing or null; the Cookie header is set either way.
    headers["Cookie"] = "; ".join(event.get("cookies") or [])

    environ = {
        "CONTENT_LENGTH": str(len(body)),
        "CONTENT_TYPE": headers.get("Content-Type", ""),
        "PATH_INFO": url_unquote(path_info),
        "QUERY_STRING": event.get("rawQueryString", ""),
        "REMOTE_ADDR": event.get("requestContext", {})
        .get("http", {})
        .get("sourceIp", ""),
        "REMOTE_USER": event.get("requestContext", {})
        .get("authorizer", {})
        .get("principalId", ""),
        "REQUEST_METHOD": event.get("requestContext", {})
        .get("http", {})
        .get("method", ""),
        "SCRIPT_NAME": script_name,
        "SERVER_NAME": headers.get("Host", "lambda"),
        "SERVER_PORT": headers.get("X-Forwarded-Port", "80"),
        "SERVER_PROTOCOL": "HTTP/1.1",
        "wsgi.errors": sys.stderr,
        "wsgi.input": io.BytesIO(body),
        "wsgi.multiprocess": False,
        "wsgi.multithread": False,
        "wsgi.run_once": False,
        "wsgi.url_scheme": headers.get("X-Forwarded-Proto", "http"),
        "wsgi.version": (1, 0),
        "serverless.authorizer": event.get("requestContext", {}).get("authorizer"),
        "serverless.event": event,
        "serverless.context": context,
    }

    environ = setup_environ_items(environ, headers)

    response = Response.from_app(app, environ)

    returndict = generate_response(response, event)

    return returndict


def handle_lambda_integration(app, event, context):
    """
    Handle a Lambda (non-proxy) integration event and return the response dict.

    Path parameters from `event["path"]` are substituted into the
    `requestPath` template, and a non-empty body is serialized back to JSON
    before being passed to the app.

    Raises:
        RuntimeError: when the app responds with a status code >= 300; the
            message is the JSON-encoded response dict.
            NOTE(review): presumably so the integration's response mapping
            can match on the error -- confirm against the deployment config.
    """
    headers = Headers(event["headers"])

    # The event itself is passed as the request context, so the stage is read
    # from the top-level "stage" key.
    script_name = get_script_name(headers, event)

    path_info = event["requestPath"]

    for key, value in event.get("path", {}).items():
        path_info = path_info.replace("{%s}" % key, value)
        path_info = path_info.replace("{%s+}" % key, value)

    body = event.get("body", {})
    body = json.dumps(body) if body else ""
    body = get_body_bytes(event, body)

    environ = {
        "CONTENT_LENGTH": str(len(body)),
        "CONTENT_TYPE": headers.get("Content-Type", ""),
        "PATH_INFO": url_unquote(path_info),
        "QUERY_STRING": url_encode(event.get("query", {})),
        "REMOTE_ADDR": event.get("identity", {}).get("sourceIp", ""),
        "REMOTE_USER": event.get("principalId", ""),
        "REQUEST_METHOD": event.get("method", ""),
        "SCRIPT_NAME": script_name,
        "SERVER_NAME": headers.get("Host", "lambda"),
        "SERVER_PORT": headers.get("X-Forwarded-Port", "80"),
        "SERVER_PROTOCOL": "HTTP/1.1",
        "wsgi.errors": sys.stderr,
        "wsgi.input": io.BytesIO(body),
        "wsgi.multiprocess": False,
        "wsgi.multithread": False,
        "wsgi.run_once": False,
        "wsgi.url_scheme": headers.get("X-Forwarded-Proto", "http"),
        "wsgi.version": (1, 0),
        "serverless.authorizer": event.get("enhancedAuthContext"),
        "serverless.event": event,
        "serverless.context": context,
    }

    environ = setup_environ_items(environ, headers)

    response = Response.from_app(app, environ)

    returndict = generate_response(response, event)

    if response.status_code >= 300:
        raise RuntimeError(json.dumps(returndict))

    return returndict