You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
484 lines
16 KiB
484 lines
16 KiB
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License"). You
|
|
# may not use this file except in compliance with the License. A copy of
|
|
# the License is located at
|
|
#
|
|
# http://aws.amazon.com/apache2.0/
|
|
#
|
|
# or in the "license" file accompanying this file. This file is
|
|
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
|
# ANY KIND, either express or implied. See the License for the specific
|
|
# language governing permissions and limitations under the License.
|
|
|
|
""" The interfaces in this module are not intended for public use.
|
|
|
|
This module defines interfaces for applying checksums to HTTP requests within
|
|
the context of botocore. This involves both resolving the checksum to be used
|
|
based on client configuration and environment, as well as application of the
|
|
checksum to the request.
|
|
"""
|
|
import base64
|
|
import io
|
|
import logging
|
|
from binascii import crc32
|
|
from hashlib import sha1, sha256
|
|
|
|
from botocore.compat import HAS_CRT
|
|
from botocore.exceptions import (
|
|
AwsChunkedWrapperError,
|
|
FlexibleChecksumError,
|
|
MissingDependencyException,
|
|
)
|
|
from botocore.response import StreamingBody
|
|
from botocore.utils import (
|
|
conditionally_calculate_md5,
|
|
determine_content_length,
|
|
)
|
|
|
|
# The CRT checksum bindings are optional: start from a ``None`` placeholder
# and bind the real module only when botocore reports the CRT as available.
crt_checksums = None
if HAS_CRT:
    from awscrt import checksums as crt_checksums

logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BaseChecksum:
    """Common interface for the incremental checksum implementations.

    Subclasses override ``update`` and ``digest``; this base class supplies
    the base64 rendering of the digest and the logic for feeding a whole
    request body (in-memory bytes or a file-like object) through ``update``.
    """

    # Number of bytes requested per read() when consuming a file-like body.
    _CHUNK_SIZE = 1024 * 1024

    def update(self, chunk):
        """Fold ``chunk`` into the running checksum (no-op in the base)."""

    def digest(self):
        """Return the raw checksum bytes (the base returns ``None``)."""

    def b64digest(self):
        """Return ``digest()`` base64-encoded as an ASCII ``str``."""
        return base64.b64encode(self.digest()).decode("ascii")

    def _handle_fileobj(self, fileobj):
        """Checksum ``fileobj`` from its current position to the end.

        The stream is rewound to the position it had on entry so the same
        object can still be read afterwards.
        """
        origin = fileobj.tell()
        while True:
            block = fileobj.read(self._CHUNK_SIZE)
            # An empty bytes read marks the end of the stream.
            if block == b"":
                break
            self.update(block)
        fileobj.seek(origin)

    def handle(self, body):
        """Checksum ``body`` and return the base64-encoded digest.

        :param body: ``bytes``/``bytearray``, or an object exposing
            ``tell``, ``read`` and ``seek``.
        """
        if not isinstance(body, (bytes, bytearray)):
            self._handle_fileobj(body)
        else:
            self.update(body)
        return self.b64digest()
|
|
|
|
|
|
class Crc32Checksum(BaseChecksum):
    """CRC32 checksum computed with ``binascii.crc32``."""

    def __init__(self):
        # Running CRC value, kept as an unsigned 32-bit integer.
        self._int_crc32 = 0

    def update(self, chunk):
        """Continue the running CRC32 over ``chunk``."""
        running = crc32(chunk, self._int_crc32)
        self._int_crc32 = running & 0xFFFFFFFF

    def digest(self):
        """Return the CRC as four big-endian bytes."""
        return self._int_crc32.to_bytes(length=4, byteorder="big")
|
|
|
|
|
|
class CrtCrc32Checksum(BaseChecksum):
    """CRC32 checksum computed with ``crt_checksums.crc32``."""

    # Note: This class is only used if the CRT is available
    def __init__(self):
        # Running CRC value, kept as an unsigned 32-bit integer.
        self._int_crc32 = 0

    def update(self, chunk):
        """Continue the running CRC32 over ``chunk``."""
        self._int_crc32 = (
            crt_checksums.crc32(chunk, self._int_crc32) & 0xFFFFFFFF
        )

    def digest(self):
        """Return the CRC as four big-endian bytes."""
        return self._int_crc32.to_bytes(length=4, byteorder="big")
|
|
|
|
|
|
class CrtCrc32cChecksum(BaseChecksum):
    """CRC32C checksum computed with ``crt_checksums.crc32c``."""

    # Note: This class is only used if the CRT is available
    def __init__(self):
        # Running CRC value, kept as an unsigned 32-bit integer.
        self._int_crc32c = 0

    def update(self, chunk):
        """Continue the running CRC32C over ``chunk``."""
        self._int_crc32c = (
            crt_checksums.crc32c(chunk, self._int_crc32c) & 0xFFFFFFFF
        )

    def digest(self):
        """Return the CRC as four big-endian bytes."""
        return self._int_crc32c.to_bytes(length=4, byteorder="big")
|
|
|
|
|
|
class Sha1Checksum(BaseChecksum):
    """SHA-1 checksum that delegates to a ``hashlib.sha1`` object."""

    def __init__(self):
        self._checksum = sha1()

    def update(self, chunk):
        """Feed ``chunk`` into the underlying hash object."""
        hasher = self._checksum
        hasher.update(chunk)

    def digest(self):
        """Return the raw digest bytes of everything fed in so far."""
        hasher = self._checksum
        return hasher.digest()
|
|
|
|
|
|
class Sha256Checksum(BaseChecksum):
    """SHA-256 checksum that delegates to a ``hashlib.sha256`` object."""

    def __init__(self):
        self._checksum = sha256()

    def update(self, chunk):
        """Feed ``chunk`` into the underlying hash object."""
        hasher = self._checksum
        hasher.update(chunk)

    def digest(self):
        """Return the raw digest bytes of everything fed in so far."""
        hasher = self._checksum
        return hasher.digest()
|
|
|
|
|
|
class AwsChunkedWrapper:
    """Readable wrapper that re-frames a raw stream as a series of chunks.

    Every chunk is emitted as ``<hex length>\\r\\n<data>\\r\\n``. The stream
    ends with a zero-length chunk; when a checksum class was supplied that
    final chunk carries a ``<checksum_name>:<base64 digest>`` trailer.

    :param raw: Object with ``read(size)`` and ``seek(0)``.
    :param checksum_cls: Optional zero-argument factory for an object with
        ``update(chunk)`` and ``b64digest()``.
    :param checksum_name: Name written in front of the trailer value.
    :param chunk_size: Bytes requested from ``raw`` per chunk; defaults to
        ``_DEFAULT_CHUNK_SIZE``.
    """

    _DEFAULT_CHUNK_SIZE = 1024 * 1024

    def __init__(
        self,
        raw,
        checksum_cls=None,
        checksum_name="x-amz-checksum",
        chunk_size=None,
    ):
        self._raw = raw
        self._checksum_name = checksum_name
        self._checksum_cls = checksum_cls
        self._reset()
        self._chunk_size = (
            self._DEFAULT_CHUNK_SIZE if chunk_size is None else chunk_size
        )

    def _reset(self):
        """Drop buffered output and start over with a fresh checksum."""
        self._remaining = b""
        self._complete = False
        self._checksum = self._checksum_cls() if self._checksum_cls else None

    def seek(self, offset, whence=0):
        """Rewind to the start of the stream; any other target is rejected.

        :raises AwsChunkedWrapperError: if ``offset`` or ``whence`` is
            anything other than ``0``.
        """
        if not (offset == 0 and whence == 0):
            raise AwsChunkedWrapperError(
                error_msg="Can only seek to start of stream"
            )
        self._reset()
        self._raw.seek(0)

    def read(self, size=None):
        """Return up to ``size`` encoded bytes.

        ``None`` or any non-positive ``size`` means "everything that is
        left". Returns ``b""`` once the stream is exhausted.
        """
        read_all = size is None or size <= 0

        # Underlying body finished and nothing buffered: end of stream.
        if self._complete and not self._remaining:
            return b""

        # Keep encoding chunks until enough bytes are buffered (or the raw
        # stream runs dry).
        while not self._complete and (
            read_all or size > len(self._remaining)
        ):
            self._remaining += self._make_chunk()

        cut = len(self._remaining) if read_all else size
        to_return, self._remaining = (
            self._remaining[:cut],
            self._remaining[cut:],
        )
        return to_return

    def _make_chunk(self):
        """Read one chunk from the raw stream and return it encoded."""
        # NOTE: Chunk size is not deterministic as read could return less.
        # This means we cannot know the content length of the encoded
        # aws-chunked stream ahead of time without ensuring a consistent
        # chunk size
        payload = self._raw.read(self._chunk_size)
        size_line = format(len(payload), "x").encode("ascii")
        # An empty read means the raw stream is exhausted.
        self._complete = not payload

        if self._checksum:
            self._checksum.update(payload)
            if self._complete:
                # Final chunk: zero length followed by the checksum trailer.
                trailer_name = self._checksum_name.encode("ascii")
                trailer_value = self._checksum.b64digest().encode("ascii")
                return b"0\r\n%s:%s\r\n\r\n" % (trailer_name, trailer_value)

        return b"%s\r\n%s\r\n" % (size_line, payload)

    def __iter__(self):
        """Yield encoded chunks until the raw stream is exhausted."""
        while not self._complete:
            yield self._make_chunk()
|
|
|
|
|
|
class StreamingChecksumBody(StreamingBody):
    """Streaming body that checksums what is read and validates it.

    :param raw_stream: Forwarded to ``StreamingBody.__init__``.
    :param content_length: Forwarded to ``StreamingBody.__init__``.
    :param checksum: Object with ``update``, ``digest`` and ``b64digest``.
    :param expected: Base64-encoded checksum the body must match.
    """

    def __init__(self, raw_stream, content_length, checksum, expected):
        super().__init__(raw_stream, content_length)
        self._checksum = checksum
        self._expected = expected

    def read(self, amt=None):
        """Read from the parent stream, updating the running checksum.

        Validation runs when the whole body was requested (``amt`` is
        ``None``) or when a positive-sized read came back empty.
        """
        chunk = super().read(amt=amt)
        self._checksum.update(chunk)
        reached_end = amt is None or (amt > 0 and not chunk)
        if reached_end:
            self._validate_checksum()
        return chunk

    def _validate_checksum(self):
        """Compare the calculated digest with the expected one.

        :raises FlexibleChecksumError: if the two digests differ.
        """
        calculated = self._checksum.digest()
        if calculated == base64.b64decode(self._expected):
            return
        error_msg = (
            f"Expected checksum {self._expected} did not match calculated "
            f"checksum: {self._checksum.b64digest()}"
        )
        raise FlexibleChecksumError(error_msg=error_msg)
|
|
|
|
|
|
def resolve_checksum_context(request, operation_model, params):
    """Resolve the request checksum and then the response checksums.

    Both resolvers receive the same arguments and record their results
    under ``request["context"]["checksum"]``.
    """
    resolvers = (
        resolve_request_checksum_algorithm,
        resolve_response_checksum_algorithms,
    )
    for resolver in resolvers:
        resolver(request, operation_model, params)
|
|
|
|
|
|
def resolve_request_checksum_algorithm(
    request,
    operation_model,
    params,
    supported_algorithms=None,
):
    """Decide which checksum to apply to the request and record it.

    The decision is stored as
    ``request["context"]["checksum"]["request_algorithm"]``: a dict with
    ``algorithm``, ``in`` and ``name`` keys for a flexible checksum, or the
    string ``"conditional-md5"`` for the Content-MD5 fallback.

    :param request: Dict with ``url``, ``headers`` and ``context`` entries.
    :param operation_model: Object exposing ``http_checksum``,
        ``http_checksum_required`` and ``has_streaming_input``.
    :param params: Operation parameters supplied by the caller.
    :param supported_algorithms: Accepted algorithm names; defaults to
        ``_SUPPORTED_CHECKSUM_ALGORITHMS``.
    :raises MissingDependencyException: if the algorithm needs the CRT and
        the CRT is not available.
    :raises FlexibleChecksumError: if the algorithm is otherwise unsupported.
    """
    http_checksum = operation_model.http_checksum
    algorithm_member = http_checksum.get("requestAlgorithmMember")

    if not (algorithm_member and algorithm_member in params):
        # No flexible checksum requested: apply the old http checksum
        # behavior via Content-MD5 when the operation requires a checksum.
        if operation_model.http_checksum_required or http_checksum.get(
            "requestChecksumRequired"
        ):
            md5_context = request["context"].get("checksum", {})
            md5_context["request_algorithm"] = "conditional-md5"
            request["context"]["checksum"] = md5_context
        return

    # The client has opted into flexible checksums and the request supports
    # it, so use that instead of checksum required.
    if supported_algorithms is None:
        supported_algorithms = _SUPPORTED_CHECKSUM_ALGORITHMS

    algorithm_name = params[algorithm_member].lower()
    if algorithm_name not in supported_algorithms:
        if not HAS_CRT and algorithm_name in _CRT_CHECKSUM_ALGORITHMS:
            raise MissingDependencyException(
                msg=(
                    f"Using {algorithm_name.upper()} requires an "
                    "additional dependency. You will need to pip install "
                    "botocore[crt] before proceeding."
                )
            )
        raise FlexibleChecksumError(
            error_msg="Unsupported checksum algorithm: %s" % algorithm_name
        )

    # Operations with streaming input must support trailers. We only
    # support unsigned trailer checksums currently. As this disables
    # payload signing we'll only use trailers over TLS.
    use_trailer = operation_model.has_streaming_input and request[
        "url"
    ].startswith("https:")
    header_name = "x-amz-checksum-%s" % algorithm_name
    algorithm = {
        "algorithm": algorithm_name,
        "in": "trailer" if use_trailer else "header",
        "name": header_name,
    }

    if header_name in request["headers"]:
        # If the header is already set by the customer, skip calculation
        return

    flexible_context = request["context"].get("checksum", {})
    flexible_context["request_algorithm"] = algorithm
    request["context"]["checksum"] = flexible_context
|
|
|
|
|
|
def apply_request_checksum(request):
    """Apply the checksum previously resolved for ``request``, if any.

    Reads ``request["context"]["checksum"]["request_algorithm"]`` and
    dispatches to the Content-MD5, header or trailer implementation. Does
    nothing when no algorithm was recorded.

    :raises FlexibleChecksumError: if the recorded location is neither
        ``"header"`` nor ``"trailer"``.
    """
    algorithm = (
        request.get("context", {}).get("checksum", {}).get("request_algorithm")
    )
    if not algorithm:
        return

    if algorithm == "conditional-md5":
        # Special case to handle the http checksum required trait
        conditionally_calculate_md5(request)
        return

    location = algorithm["in"]
    if location == "header":
        _apply_request_header_checksum(request)
    elif location == "trailer":
        _apply_request_trailer_checksum(request)
    else:
        raise FlexibleChecksumError(
            error_msg="Unknown checksum variant: %s" % algorithm["in"]
        )
|
|
|
|
|
|
def _apply_request_header_checksum(request):
    """Compute the body checksum and store it in the request headers.

    The header name and algorithm come from
    ``request["context"]["checksum"]["request_algorithm"]``.
    """
    algorithm = (
        request.get("context", {}).get("checksum", {}).get("request_algorithm")
    )
    header_name = algorithm["name"]
    headers = request["headers"]
    if header_name in headers:
        # If the header is already set by the customer, skip calculation
        return
    calculator = _CHECKSUM_CLS.get(algorithm["algorithm"])()
    headers[header_name] = calculator.handle(request["body"])
|
|
|
|
|
|
def _apply_request_trailer_checksum(request):
    """Send the request body aws-chunked with a trailing checksum.

    The header name and algorithm come from
    ``request["context"]["checksum"]["request_algorithm"]``. Unless the
    caller already set the checksum header, this:

    * sets ``Transfer-Encoding``, ``Content-Encoding`` and ``X-Amz-Trailer``;
    * records the decoded body length in ``X-Amz-Decoded-Content-Length``
      when it can be determined, and drops any ``Content-Length`` header;
    * replaces ``request["body"]`` with an ``AwsChunkedWrapper``.
    """
    checksum_context = request.get("context", {}).get("checksum", {})
    algorithm = checksum_context.get("request_algorithm")
    location_name = algorithm["name"]
    checksum_cls = _CHECKSUM_CLS.get(algorithm["algorithm"])

    headers = request["headers"]
    body = request["body"]

    if location_name in headers:
        # If the header is already set by the customer, skip calculation
        return

    headers["Transfer-Encoding"] = "chunked"
    if "Content-Encoding" in headers:
        # We need to preserve the existing content encoding and add
        # aws-chunked as a new content encoding.
        headers["Content-Encoding"] += ",aws-chunked"
    else:
        headers["Content-Encoding"] = "aws-chunked"
    headers["X-Amz-Trailer"] = location_name

    content_length = determine_content_length(body)
    if content_length is not None:
        # Send the decoded content length if we can determine it. Some
        # services such as S3 may require the decoded content length
        headers["X-Amz-Decoded-Content-Length"] = str(content_length)

        if "Content-Length" in headers:
            # A message must not carry both Content-Length and
            # Transfer-Encoding, and a caller-supplied Content-Length
            # describes the decoded body rather than the aws-chunked
            # stream that is actually sent.
            del headers["Content-Length"]
            logger.debug(
                "Removing the Content-Length header since 'chunked' is "
                "specified for Transfer-Encoding."
            )

    if isinstance(body, (bytes, bytearray)):
        # The wrapper needs a readable, seekable stream.
        body = io.BytesIO(body)

    request["body"] = AwsChunkedWrapper(
        body,
        checksum_cls=checksum_cls,
        checksum_name=location_name,
    )
|
|
|
|
|
|
def resolve_response_checksum_algorithms(
    request, operation_model, params, supported_algorithms=None
):
    """Record which response checksum algorithms may be validated.

    Only acts when the operation models a ``requestValidationModeMember``
    and that member is present in ``params``. The result, ordered by
    ``_ALGORITHMS_PRIORITY_LIST``, is stored as
    ``request["context"]["checksum"]["response_algorithms"]``.

    :param supported_algorithms: Accepted algorithm names; defaults to
        ``_SUPPORTED_CHECKSUM_ALGORITHMS``.
    """
    http_checksum = operation_model.http_checksum
    mode_member = http_checksum.get("requestValidationModeMember")
    if not (mode_member and mode_member in params):
        return

    if supported_algorithms is None:
        supported_algorithms = _SUPPORTED_CHECKSUM_ALGORITHMS
    modeled = {
        name.lower() for name in http_checksum.get("responseAlgorithms", [])
    }

    # Keep the algorithms that are both modeled for the response and
    # supported locally, in priority order.
    usable_algorithms = [
        candidate
        for candidate in _ALGORITHMS_PRIORITY_LIST
        if candidate in modeled and candidate in supported_algorithms
    ]

    context_entry = request["context"].get("checksum", {})
    context_entry["response_algorithms"] = usable_algorithms
    request["context"]["checksum"] = context_entry
|
|
|
|
|
|
def handle_checksum_body(http_response, response, context, operation_model):
    """Validate the response body against the first usable checksum header.

    Candidate algorithms come from
    ``context["checksum"]["response_algorithms"]`` and are tried in order.
    For the first algorithm with a usable header, ``response["body"]`` is
    replaced by the validated (or validating) body and the algorithm is
    recorded as ``response["context"]["checksum"]["response_algorithm"]``.
    If none matches, an informational message is logged.
    """
    headers = response["headers"]
    algorithms = context.get("checksum", {}).get("response_algorithms")
    if not algorithms:
        return

    for algorithm in algorithms:
        header_name = "x-amz-checksum-%s" % algorithm
        # If the header is not found, check the next algorithm
        if header_name not in headers:
            continue

        # If a - is in the checksum this is not valid Base64. S3 returns
        # checksums that include a -# suffix to indicate a checksum derived
        # from the hash of all part checksums. We cannot wrap this response
        if "-" in headers[header_name]:
            continue

        if operation_model.has_streaming_output:
            body_handler = _handle_streaming_response
        else:
            body_handler = _handle_bytes_response
        response["body"] = body_handler(http_response, response, algorithm)

        # Expose metadata that the checksum check actually occurred
        validated = response["context"].get("checksum", {})
        validated["response_algorithm"] = algorithm
        response["context"]["checksum"] = validated
        return

    logger.info(
        f'Skipping checksum validation. Response did not contain one of the '
        f'following algorithms: {algorithms}.'
    )
|
|
|
|
|
|
def _handle_streaming_response(http_response, response, algorithm):
    """Wrap the raw response stream in a ``StreamingChecksumBody``.

    The expected checksum is taken from the ``x-amz-checksum-<algorithm>``
    response header.
    """
    checksum_cls = _CHECKSUM_CLS.get(algorithm)
    headers = response["headers"]
    raw_stream = http_response.raw
    content_length = headers.get("content-length")
    calculator = checksum_cls()
    expected = headers["x-amz-checksum-%s" % algorithm]
    return StreamingChecksumBody(
        raw_stream, content_length, calculator, expected
    )
|
|
|
|
|
|
def _handle_bytes_response(http_response, response, algorithm):
    """Validate a fully-read response body and return it.

    The expected checksum is taken from the ``x-amz-checksum-<algorithm>``
    response header.

    :raises FlexibleChecksumError: if the calculated digest differs from
        the expected one.
    """
    payload = http_response.content
    calculator = _CHECKSUM_CLS.get(algorithm)()
    calculator.update(payload)
    expected = response["headers"]["x-amz-checksum-%s" % algorithm]
    if calculator.digest() == base64.b64decode(expected):
        return payload
    raise FlexibleChecksumError(
        error_msg=(
            "Expected checksum %s did not match calculated checksum: %s"
            % (expected, calculator.b64digest())
        )
    )
|
|
|
|
|
|
# Algorithm name -> checksum class. Built without any CRT implementation
# first; CRT-backed classes are merged in below when available.
_CHECKSUM_CLS = dict(
    crc32=Crc32Checksum,
    sha1=Sha1Checksum,
    sha256=Sha256Checksum,
)
# Algorithm names that have a CRT-backed implementation.
_CRT_CHECKSUM_ALGORITHMS = ["crc32", "crc32c"]
if HAS_CRT:
    # Use CRT checksum implementations if available
    _CRT_CHECKSUM_CLS = dict(
        crc32=CrtCrc32Checksum,
        crc32c=CrtCrc32cChecksum,
    )
    _CHECKSUM_CLS.update(_CRT_CHECKSUM_CLS)
    # Validate this list isn't out of sync with _CRT_CHECKSUM_CLS keys
    assert set(_CRT_CHECKSUM_CLS).issubset(_CRT_CHECKSUM_ALGORITHMS)
# Every algorithm name that currently has an implementation registered.
_SUPPORTED_CHECKSUM_ALGORITHMS = list(_CHECKSUM_CLS)
# Order in which response checksum algorithms are tried.
_ALGORITHMS_PRIORITY_LIST = ["crc32c", "crc32", "sha1", "sha256"]
|