Skip to content

Commit 478c249

Browse files
committed
handle unicode objects and utf-8 strings in url and params and encode them to utf-8 when serializing
1 parent 63bd405 commit 478c249

File tree

2 files changed

+148
-29
lines changed

2 files changed

+148
-29
lines changed

oauth2/__init__.py

Lines changed: 53 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -90,16 +90,47 @@ def build_xoauth_string(url, consumer, token=None):
9090

9191
def to_unicode(s):
9292
""" Convert to unicode, raise exception with instructive error
93-
message if s is not unicode or ascii. """
93+
message if s is not unicode, ascii, or utf-8. """
9494
if not isinstance(s, unicode):
9595
if not isinstance(s, str):
9696
raise TypeError('You are required to pass either unicode or string here, not: %r (%s)' % (type(s), s))
9797
try:
98-
s = s.decode('ascii')
98+
s = s.decode('utf-8')
9999
except UnicodeDecodeError, le:
100-
raise TypeError('You are required to pass either a unicode object or an ascii string here. You passed a Python string object which contained non-ascii: %r. The UnicodeDecodeError that resulted from attempting to interpret it as ascii was: %s' % (s, le,))
100+
raise TypeError('You are required to pass either a unicode object or a utf-8 string here. You passed a Python string object which contained non-utf-8: %r. The UnicodeDecodeError that resulted from attempting to interpret it as utf-8 was: %s' % (s, le,))
101101
return s
102102

103+
def to_utf8(s):
104+
return to_unicode(s).encode('utf-8')
105+
106+
def to_unicode_if_string(s):
107+
if isinstance(s, basestring):
108+
return to_unicode(s)
109+
else:
110+
return s
111+
112+
def to_utf8_if_string(s):
113+
if isinstance(s, basestring):
114+
return to_utf8(s)
115+
else:
116+
return s
117+
118+
def to_unicode_optional_iterator(x):
119+
"""
120+
Raise TypeError if x is a str containing non-utf8 bytes or if x is
121+
an iterable which contains such a str.
122+
"""
123+
if isinstance(x, basestring):
124+
return to_unicode(x)
125+
126+
try:
127+
l = list(x)
128+
except TypeError, e:
129+
assert 'is not iterable' in str(e)
130+
return x
131+
else:
132+
return [ to_unicode(e) for e in l ]
133+
103134
def escape(s):
104135
"""Escape a URL including any /."""
105136
s = to_unicode(s)
@@ -292,8 +323,12 @@ def __init__(self, method=HTTP_METHOD, url=None, parameters=None):
292323
self.url = to_unicode(url)
293324
self.method = method
294325
if parameters is not None:
295-
self.update(parameters)
296-
326+
for k, v in parameters.iteritems():
327+
k = to_unicode(k)
328+
v = to_unicode_optional_iterator(v)
329+
self[k] = v
330+
331+
297332
@setter
298333
def url(self, value):
299334
self.__dict__['url'] = value
@@ -383,7 +418,7 @@ def get_parameter(self, parameter):
383418
raise Error('Parameter not found: %s' % parameter)
384419

385420
return ret
386-
421+
387422
def get_normalized_parameters(self):
388423
"""Return a string that contains the parameters that must be signed."""
389424
items = []
@@ -392,16 +427,22 @@ def get_normalized_parameters(self):
392427
continue
393428
# 1.0a/9.1.1 states that kvp must be sorted by key, then by value,
394429
# so we unpack sequence values into multiple items for sorting.
395-
if hasattr(value, '__iter__'):
396-
items.extend((key, item) for item in value)
430+
if isinstance(value, basestring):
431+
items.append((to_utf8_if_string(key), to_utf8(value)))
397432
else:
398-
items.append((key, value))
433+
try:
434+
value = list(value)
435+
except TypeError, e:
436+
assert 'is not iterable' in str(e)
437+
items.append((to_utf8_if_string(key), to_utf8_if_string(value)))
438+
else:
439+
items.extend((to_utf8_if_string(key), to_utf8_if_string(item)) for item in value)
399440

400441
# Include any query string parameters from the provided URL
401442
query = urlparse.urlparse(self.url)[4]
402-
443+
403444
url_items = self._split_url_string(query).items()
404-
non_oauth_url_items = list([(k, v) for k, v in url_items if not k.startswith('oauth_')])
445+
non_oauth_url_items = list([(to_utf8(k), to_utf8(v)) for k, v in url_items if not k.startswith('oauth_')])
405446
items.extend(non_oauth_url_items)
406447

407448
encoded_str = urllib.urlencode(sorted(items))
@@ -410,7 +451,7 @@ def get_normalized_parameters(self):
410451
# (http://tools.ietf.org/html/draft-hammer-oauth-07#section-3.6)
411452
# Spaces must be encoded with "%20" instead of "+"
412453
return encoded_str.replace('+', '%20').replace('%7E', '~')
413-
454+
414455
def sign_request(self, signature_method, consumer, token):
415456
"""Set the signature parameter to the result of sign."""
416457

tests/test_oauth.py

Lines changed: 95 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,24 @@ def test_from_string(self):
248248
new = oauth.Token.from_string(string)
249249
self._compare_tokens(new)
250250

251-
class TestRequest(unittest.TestCase):
251+
class ReallyEqualMixin:
252+
def failUnlessReallyEqual(self, a, b, msg=None):
253+
self.failUnlessEqual(a, b, msg=msg)
254+
self.failUnlessEqual(type(a), type(b), msg="a :: %r, b :: %r, %r" % (a, b, msg))
255+
256+
class TestFuncs(unittest.TestCase):
257+
def test_to_unicode(self):
258+
self.failUnlessRaises(TypeError, oauth.to_unicode, '\xae')
259+
self.failUnlessRaises(TypeError, oauth.to_unicode_optional_iterator, '\xae')
260+
self.failUnlessRaises(TypeError, oauth.to_unicode_optional_iterator, ['\xae'])
261+
262+
self.failUnlessEqual(oauth.to_unicode(':-)'), u':-)')
263+
self.failUnlessEqual(oauth.to_unicode(u'\u00ae'), u'\u00ae')
264+
self.failUnlessEqual(oauth.to_unicode('\xc2\xae'), u'\u00ae')
265+
self.failUnlessEqual(oauth.to_unicode_optional_iterator([':-)']), [u':-)'])
266+
self.failUnlessEqual(oauth.to_unicode_optional_iterator([u'\u00ae']), [u'\u00ae'])
267+
268+
class TestRequest(unittest.TestCase, ReallyEqualMixin):
252269
def test_setter(self):
253270
url = "http://example.com"
254271
method = "GET"
@@ -342,9 +359,11 @@ def test_get_nonoauth_parameters(self):
342359
}
343360

344361
other_params = {
345-
'foo': 'baz',
346-
'bar': 'foo',
347-
'multi': ['FOO','BAR']
362+
u'foo': u'baz',
363+
u'bar': u'foo',
364+
u'multi': [u'FOO',u'BAR'],
365+
u'uni_utf8': u'\xae',
366+
u'uni_unicode': u'\u00ae'
348367
}
349368

350369
params = oauth_params
@@ -461,6 +480,24 @@ def test_to_url_with_query(self):
461480
self.assertEquals(b['max-contacts'], ['10'])
462481
self.assertEquals(a, b)
463482

483+
def test_signature_base_string_nonascii(self):
484+
consumer = oauth.Consumer('consumer_token', 'consumer_secret')
485+
486+
url = "http://api.simplegeo.com:80/1.0/places/address.json?q=monkeys&category=animal&address=41+Decatur+St%2C+San+Francisc%E2%9D%A6%2C+CA"
487+
req = oauth.Request("GET", url)
488+
self.failUnlessReallyEqual(req.normalized_url, u'http://api.simplegeo.com/1.0/places/address.json')
489+
self.assertEquals(req.url, u'http://api.simplegeo.com:80/1.0/places/address.json?q=monkeys&category=animal&address=41+Decatur+St%2C+San+Francisc%E2%9D%A6%2C+CA')
490+
req.sign_request(oauth.SignatureMethod_HMAC_SHA1(), consumer, None)
491+
492+
def test_signature_base_string_nonascii_nonutf8(self):
493+
consumer = oauth.Consumer('consumer_token', 'consumer_secret')
494+
495+
url = "http://api.simplegeo.com:80/1.0/places/address.json?q=monkeys&category=animal&address=41+Decatur+St%2C+San+Francisc%E2%9D%A6%2C+CA"
496+
req = oauth.Request("GET", url)
497+
self.failUnlessReallyEqual(req.normalized_url, u'http://api.simplegeo.com/1.0/places/address.json')
498+
self.assertEquals(req.url, u'http://api.simplegeo.com:80/1.0/places/address.json?q=monkeys&category=animal&address=41+Decatur+St%2C+San+Francisc%E2%9D%A6%2C+CA')
499+
req.sign_request(oauth.SignatureMethod_HMAC_SHA1(), consumer, None)
500+
464501
def test_signature_base_string_with_query(self):
465502
url = "https://www.google.com/m8/feeds/contacts/default/full/?alt=json&max-contacts=10"
466503
params = {
@@ -495,16 +532,18 @@ def test_get_normalized_parameters(self):
495532
'oauth_consumer_key': "0685bd9184jfhq22",
496533
'oauth_signature_method': "HMAC-SHA1",
497534
'oauth_token': "ad180jjd733klru7",
498-
'multi': ['FOO','BAR'],
535+
'multi': ['FOO','BAR', u'\u00ae', '\xc2\xae'],
536+
'uni_utf8': '\xc2\xae',
537+
'uni_unicode': u'\u00ae'
499538
}
500539

501540
req = oauth.Request("GET", url, params)
502541

503542
res = req.get_normalized_parameters()
504-
505-
srtd = [(k, v if type(v) != ListType else sorted(v)) for k,v in sorted(params.items())]
506543

507-
self.assertEquals(urllib.urlencode(srtd, True), res)
544+
expected='multi=BAR&multi=FOO&multi=%C2%AE&multi=%C2%AE&oauth_consumer_key=0685bd9184jfhq22&oauth_nonce=4572616e48616d6d65724c61686176&oauth_signature_method=HMAC-SHA1&oauth_timestamp=137131200&oauth_token=ad180jjd733klru7&oauth_version=1.0&uni_unicode=%C2%AE&uni_utf8=%C2%AE'
545+
546+
self.assertEquals(expected, res)
508547

509548
def test_get_normalized_parameters_ignores_auth_signature(self):
510549
url = "http://sp.example.com/"
@@ -558,24 +597,63 @@ def test_get_normalized_string_escapes_spaces_properly(self):
558597
expected = urllib.urlencode(sorted(params.items())).replace('+', '%20')
559598
self.assertEqual(expected, res)
560599

561-
def test_request_nonascii_bytes(self):
562-
# If someone has a sequence of bytes which is not ascii, we'll
563-
# raise an exception as early as possible.
564-
url = "http://sp.example.com/\x92"
600+
@mock.patch('oauth2.Request.make_timestamp')
601+
@mock.patch('oauth2.Request.make_nonce')
602+
def test_request_nonascii_bytes(self, mock_make_nonce, mock_make_timestamp):
603+
mock_make_nonce.return_value = 5
604+
mock_make_timestamp.return_value = 6
565605

606+
tok = oauth.Token(key="tok-test-key", secret="tok-test-secret")
607+
con = oauth.Consumer(key="con-test-key", secret="con-test-secret")
566608
params = {
567609
'oauth_version': "1.0",
568610
'oauth_nonce': "4572616e48616d6d65724c61686176",
569-
'oauth_timestamp': "137131200"
611+
'oauth_timestamp': "137131200",
612+
'oauth_token': tok.key,
613+
'oauth_consumer_key': con.key
570614
}
571615

572-
tok = oauth.Token(key="tok-test-key", secret="tok-test-secret")
573-
con = oauth.Consumer(key="con-test-key", secret="con-test-secret")
616+
# If someone passes a sequence of bytes which is neither ascii nor
617+
# valid utf-8 for url, we'll raise an exception as early as possible.
618+
url = "http://sp.example.com/\x92" # It's actually cp1252-encoding...
619+
self.assertRaises(TypeError, oauth.Request, method="GET", url=url, parameters=params)
574620

575-
params['oauth_token'] = tok.key
576-
params['oauth_consumer_key'] = con.key
621+
# And if they pass a unicode, then we'll use it.
622+
url = u'http://sp.example.com/\u2019'
623+
req = oauth.Request(method="GET", url=url, parameters=params)
624+
req.sign_request(oauth.SignatureMethod_HMAC_SHA1(), con, None)
625+
self.failUnlessReallyEqual(req['oauth_signature'], '/DgF7cY2friC01cmOAFdu8S0z+A=')
626+
627+
# And if it is a utf-8-encoded-then-percent-encoded non-ascii
628+
# thing, we'll decode it and use it.
629+
url = "http://sp.example.com/%E2%80%99"
630+
req = oauth.Request(method="GET", url=url, parameters=params)
631+
req.sign_request(oauth.SignatureMethod_HMAC_SHA1(), con, None)
632+
self.failUnlessReallyEqual(req['oauth_signature'], 'anzjnpdqCUJWvePgDiwMb7Q8g28=')
633+
634+
# Same thing with the params.
635+
url = "http://sp.example.com/"
636+
637+
# If someone passes a sequence of bytes which is neither ascii nor
638+
# valid utf-8 in params, we'll raise an exception as early as possible.
639+
params['non_oauth_thing'] = '\xae', # It's actually cp1252-encoding...
577640
self.assertRaises(TypeError, oauth.Request, method="GET", url=url, parameters=params)
578641

642+
# And if they pass a unicode, then we'll use it.
643+
params['non_oauth_thing'] = u'\u2019'
644+
req = oauth.Request(method="GET", url=url, parameters=params)
645+
req.sign_request(oauth.SignatureMethod_HMAC_SHA1(), con, None)
646+
self.failUnlessReallyEqual(req['oauth_signature'], 'QcgQMe9XzNxDWpechlQKFCd2orw=')
647+
648+
# And if it is a utf-8-encoded non-ascii thing, we'll decode
649+
# it and use it.
650+
params['non_oauth_thing'] = '\xc2\xae'
651+
req = oauth.Request(method="GET", url=url, parameters=params)
652+
req.sign_request(oauth.SignatureMethod_HMAC_SHA1(), con, None)
653+
self.failUnlessReallyEqual(req['oauth_signature'], 'OuMkgNFhlgcmEA1gIMII7aWLDgE=')
654+
655+
656+
579657
def test_sign_request(self):
580658
url = "http://sp.example.com/"
581659

0 commit comments

Comments
 (0)