Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,9 @@ def __init__(self, amz_setting: str | AmazonSPAPISettings) -> None:

self.amz_setting = amz_setting
self.instance_params = dict(
iam_arn=self.amz_setting.iam_arn,
client_id=self.amz_setting.client_id,
client_secret=self.amz_setting.get_password("client_secret"),
refresh_token=self.amz_setting.refresh_token,
aws_access_key=self.amz_setting.aws_access_key,
aws_secret_key=self.amz_setting.get_password("aws_secret_key"),
country_code=self.amz_setting.country,
)

Expand Down Expand Up @@ -488,22 +485,16 @@ def get_catalog_items_instance(self) -> CatalogItems:

def validate_amazon_sp_api_credentials(**args) -> None:
api = SPAPI(
iam_arn=args.get("iam_arn"),
client_id=args.get("client_id"),
client_secret=args.get("client_secret"),
refresh_token=args.get("refresh_token"),
aws_access_key=args.get("aws_access_key"),
aws_secret_key=args.get("aws_secret_key"),
country_code=args.get("country"),
)

try:
# validate client_id, client_secret and refresh_token.
api.get_access_token()

# validate aws_access_key, aws_secret_key, region and iam_arn.
api.get_auth()

except SPAPIError as e:
msg = f"<b>Error:</b> {e.error}<br/><b>Error Description:</b> {e.error_description}"
frappe.throw(msg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,7 @@
# For license information, please see license.txt


import datetime
import hashlib
import hmac

import boto3
from requests import request
from requests.auth import AuthBase
from requests.compat import urlparse

__all__ = [
"SPAPIError",
Expand Down Expand Up @@ -68,132 +61,6 @@
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.


class AWSSigV4(AuthBase):
def __init__(self, service, **kwargs):
"""Create authentication mechanism

:param service: AWS Service identifier, for example `ec2`. This is required.
:param region: AWS Region, for example `us-east-1`. If not provided, it will be set using
the environment variables `AWS_DEFAULT_REGION` or using boto3, if available.
:param session: If boto3 is available, will attempt to get credentials using boto3,
unless passed explicitly. If using boto3, the provided session will be used or a new
session will be created.

"""

self.service = service
self.region = kwargs.get("region")
self.aws_access_key_id = kwargs.get("aws_access_key_id")
self.aws_session_token = kwargs.get("aws_session_token")
self.aws_secret_access_key = kwargs.get("aws_secret_access_key")

if not self.aws_access_key_id or not self.aws_secret_access_key:
raise KeyError("AWS Access Key ID and Secret Access Key are required.")

if self.region is None:
raise KeyError("Region is required.")

def __call__(self, request):
"""Called to add authentication information to request

:param request: `requests.models.PreparedRequest` object to modify

:returns: `requests.models.PreparedRequest`, modified to add authentication

"""

# Create a date for headers and the credential string.
time = datetime.datetime.utcnow()
self.amzdate = time.strftime("%Y%m%dT%H%M%SZ")
self.datestamp = time.strftime("%Y%m%d")

# Parse request to get URL parts.
parsed_url = urlparse(request.url)
host = parsed_url.hostname
uri = parsed_url.path

if len(parsed_url.query) > 0:
query_string = dict(map(lambda i: i.split("="), parsed_url.query.split("&")))
else:
query_string = dict()

# Setup Headers.
if "Host" not in request.headers:
request.headers["Host"] = host
if "Content-Type" not in request.headers:
request.headers["Content-Type"] = "application/x-www-form-urlencoded; charset=utf-8"
if "User-Agent" not in request.headers:
request.headers["User-Agent"] = "python-amazon-mws/0.0.1 (Language=Python)"
if self.aws_session_token:
request.headers["x-amz-security-token"] = self.aws_session_token
request.headers["X-AMZ-Date"] = self.amzdate

# ************* TASK 1: CREATE A CANONICAL REQUEST *************
# http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html

# Query string values must be URL-encoded (space=%20) and be sorted by name.
canonical_query_string = "&".join(map(lambda p: "=".join(p), sorted(query_string.items())))

# Create payload hash (hash of the request body content).
if request.method == "GET":
payload_hash = hashlib.sha256(b"").hexdigest()
else:
if request.body:
if isinstance(request.body, bytes):
payload_hash = hashlib.sha256(request.body).hexdigest()
else:
payload_hash = hashlib.sha256(request.body.encode("utf-8")).hexdigest()
else:
payload_hash = hashlib.sha256(b"").hexdigest()
request.headers["x-amz-content-sha256"] = payload_hash

# Create the canonical headers and signed headers. Header names
# must be trimmed and lowercase, and sorted in code point order from
# low to high. Note that there is a trailing \n.
headers_to_sign = sorted(
filter(
lambda h: h.startswith("x-amz-") or h == "host",
map(lambda H: H.lower(), request.headers.keys()),
)
)
canonical_headers = "".join([f"{h}:{request.headers[h]}\n" for h in headers_to_sign])
signed_headers = ";".join(headers_to_sign)

# Combine elements to create canonical request.
canonical_request = "\n".join(
[request.method, uri, canonical_query_string, canonical_headers, signed_headers, payload_hash]
)

# ************* TASK 2: CREATE THE STRING TO SIGN*************
credential_scope = "/".join([self.datestamp, self.region, self.service, "aws4_request"])
string_to_sign = "\n".join(
[
"AWS4-HMAC-SHA256",
self.amzdate,
credential_scope,
hashlib.sha256(canonical_request.encode("utf-8")).hexdigest(),
]
)

# ************* TASK 3: CALCULATE THE SIGNATURE *************
def sign(key, msg):
return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()

key_date = sign(("AWS4" + self.aws_secret_access_key).encode("utf-8"), self.datestamp)
key_region = sign(key_date, self.region)
k_service = sign(key_region, self.service)
key_signing = sign(k_service, "aws4_request")
signature = hmac.new(key_signing, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest()

# ************* TASK 4: ADD SIGNING INFORMATION TO THE REQUEST *************
request.headers["Authorization"] = (
f"AWS4-HMAC-SHA256 Credential={self.aws_access_key_id}/{credential_scope},"
f" SignedHeaders={signed_headers}, Signature={signature}"
)

return request


class SPAPIError(Exception):
"""
Main SP-API Exception class
Expand All @@ -215,20 +82,14 @@ class SPAPI:

def __init__(
self,
iam_arn: str,
client_id: str,
client_secret: str,
refresh_token: str,
aws_access_key: str,
aws_secret_key: str,
country_code: str = "US",
) -> None:
self.iam_arn = iam_arn
self.client_id = client_id
self.client_secret = client_secret
self.refresh_token = refresh_token
self.aws_access_key = aws_access_key
self.aws_secret_key = aws_secret_key
self.country_code = country_code
self.region, self.endpoint, self.marketplace_id = Util.get_marketplace_data(country_code)

Expand All @@ -247,32 +108,6 @@ def get_access_token(self) -> str:
exception = SPAPIError(error=result.get("error"), error_description=result.get("error_description"))
raise exception

def get_auth(self) -> AWSSigV4:
try:
client = boto3.client(
"sts",
aws_access_key_id=self.aws_access_key,
aws_secret_access_key=self.aws_secret_key,
region_name=self.region,
)

response = client.assume_role(RoleArn=self.iam_arn, RoleSessionName="SellingPartnerAPI")

credentials = response["Credentials"]
access_key_id = credentials["AccessKeyId"]
secret_access_key = credentials["SecretAccessKey"]
session_token = credentials["SessionToken"]

return AWSSigV4(
service="execute-api",
aws_access_key_id=access_key_id,
aws_secret_access_key=secret_access_key,
aws_session_token=session_token,
region=self.region,
)
except Exception as e:
raise SPAPIError(error="invalid_aws_credentials", error_description=e)

def get_headers(self) -> dict:
return {"x-amz-access-token": self.get_access_token()}

Expand All @@ -296,7 +131,6 @@ def make_request(
params=params,
data=data,
headers=self.get_headers(),
auth=self.get_auth(),
)
return response.json()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,10 @@
"field_order": [
"is_active",
"section_break_4",
"iam_arn",
"refresh_token",
"column_break_1",
"client_id",
"client_secret",
"section_break_1",
"aws_access_key",
"aws_secret_key",
"column_break_2",
"country",
"section_break_2",
Expand Down Expand Up @@ -55,13 +51,6 @@
"fieldtype": "Section Break",
"label": "Seller Central Credentials"
},
{
"fieldname": "iam_arn",
"fieldtype": "Data",
"in_list_view": 1,
"label": "IAM ARN",
"reqd": 1
},
{
"fieldname": "refresh_token",
"fieldtype": "Text",
Expand All @@ -87,25 +76,6 @@
"label": "Client Secret",
"reqd": 1
},
{
"collapsible": 1,
"fieldname": "section_break_1",
"fieldtype": "Section Break",
"label": "AWS Credentials"
},
{
"fieldname": "aws_access_key",
"fieldtype": "Data",
"label": "AWS Access Key",
"reqd": 1,
"unique": 1
},
{
"fieldname": "aws_secret_key",
"fieldtype": "Password",
"label": "AWS Secret Key",
"reqd": 1
},
{
"fieldname": "column_break_2",
"fieldtype": "Column Break"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,9 @@ def validate_credentials(self):
)

validate_amazon_sp_api_credentials(
iam_arn=self.get("iam_arn"),
client_id=self.get("client_id"),
client_secret=self.get_password("client_secret"),
refresh_token=self.get("refresh_token"),
aws_access_key=self.get("aws_access_key"),
aws_secret_key=self.get_password("aws_secret_key"),
country=self.get("country"),
)

Expand Down