diff --git a/pyamf/tests/__init__.py b/pyamf/tests/__init__.py index 528ee4e5..7ea33bc1 100644 --- a/pyamf/tests/__init__.py +++ b/pyamf/tests/__init__.py @@ -7,39 +7,30 @@ @since: 0.1.0 """ -import unittest +import os.path -# some Python 2.3 unittest compatibility fixes -if not hasattr(unittest.TestCase, 'assertTrue'): - unittest.TestCase.assertTrue = unittest.TestCase.failUnless -if not hasattr(unittest.TestCase, 'assertFalse'): - unittest.TestCase.assertFalse = unittest.TestCase.failIf +try: + import unittest2 as unittest +except ImportError: + import unittest -mod_base = 'pyamf.tests' +def get_suite(): + """ + Return a unittest.TestSuite. + """ + loader = unittest.TestLoader() -def suite(): - import os.path - from glob import glob - - suite = unittest.TestSuite() - - for testcase in glob(os.path.join(os.path.dirname(__file__), 'test_*.py')): - mod_name = os.path.basename(testcase).split('.')[0] - full_name = '%s.%s' % (mod_base, mod_name) - - mod = __import__(full_name) - - for part in full_name.split('.')[1:]: - mod = getattr(mod, part) - - suite.addTest(mod.suite()) - - return suite + return loader.discover(os.path.dirname(__file__)) def main(): - unittest.main(defaultTest='suite') + """ + Run all of the tests when run as a module with -m. + """ + runner = unittest.TextTestRunner() + runner.run(get_suite()) + if __name__ == '__main__': main() diff --git a/setup.py b/setup.py index 810ad2dd..34fb0188 100644 --- a/setup.py +++ b/setup.py @@ -8,6 +8,7 @@ import sys, os.path from setuptools import setup, find_packages, Extension +from setuptools.command import test try: from Cython.Distutils import build_ext @@ -34,6 +35,17 @@ del sys.modules[k] +class TestCommand(test.test): + def run_tests(self): + import sys + + import unittest2 + + sys.modules['unittest'] = unittest2 + + return test.test.run_tests(self) + + def get_cpyamf_extensions(): """ Returns a list of all extensions for the cpyamf module. If for some reason @@ -133,10 +145,12 @@ def get_test_requirements(): ext_modules = get_extensions(), install_requires = get_install_requirements(), tests_require = get_test_requirements(), + test_suite = "pyamf.tests.get_suite", zip_safe = True, license = "MIT License", platforms = ["any"], cmdclass = { + 'test': TestCommand, 'build_ext': build_ext, }, extras_require = {