11"""Mypy static type checker plugin for Pytest"""
22
3+ import json
34import os
5+ from tempfile import NamedTemporaryFile
46
7+ from filelock import FileLock
58import pytest
69import mypy .api
710
@@ -20,11 +23,44 @@ def pytest_addoption(parser):
2023 help = "suppresses error messages about imports that cannot be resolved" )
2124
2225
26+ def _is_master (config ):
27+ """
28+ True if the code running the given pytest.config object is running in
29+ an xdist master node or not running xdist at all.
30+ """
31+ return not hasattr (config , 'slaveinput' )
32+
33+
2334def pytest_configure (config ):
2435 """
25- Register a custom marker for MypyItems,
36+ Initialize the path used to cache mypy results,
37+ register a custom marker for MypyItems,
2638 and configure the plugin based on the CLI.
2739 """
40+ if _is_master (config ):
41+
42+ # Get the path to a temporary file and delete it.
43+ # The first MypyItem to run will see the file does not exist,
44+ # and it will run and parse mypy results to create it.
45+ # Subsequent MypyItems will see the file exists,
46+ # and they will read the parsed results.
47+ with NamedTemporaryFile (delete = True ) as tmp_f :
48+ config ._mypy_results_path = tmp_f .name
49+
50+ # If xdist is enabled, then the results path should be exposed to
51+ # the slaves so that they know where to read parsed results from.
52+ if config .pluginmanager .getplugin ('xdist' ):
53+ class _MypyXdistPlugin :
54+ def pytest_configure_node (self , node ): # xdist hook
55+ """Pass config._mypy_results_path to workers."""
56+ node .slaveinput ['_mypy_results_path' ] = \
57+ node .config ._mypy_results_path
58+ config .pluginmanager .register (_MypyXdistPlugin ())
59+
60+ # pytest_terminal_summary cannot accept config before pytest 4.2.
61+ global _pytest_terminal_summary_config
62+ _pytest_terminal_summary_config = config
63+
2864 config .addinivalue_line (
2965 'markers' ,
3066 '{marker}: mark tests to be checked by mypy.' .format (
@@ -45,46 +81,6 @@ def pytest_collect_file(path, parent):
4581 return None
4682
4783
48- def pytest_runtestloop (session ):
49- """Run mypy on collected MypyItems, then sort the output."""
50- mypy_items = {
51- os .path .abspath (str (item .fspath )): item
52- for item in session .items
53- if isinstance (item , MypyItem )
54- }
55- if mypy_items :
56-
57- terminal = session .config .pluginmanager .getplugin ('terminalreporter' )
58- terminal .write (
59- '\n Running {command} on {file_count} files... ' .format (
60- command = ' ' .join (['mypy' ] + mypy_argv ),
61- file_count = len (mypy_items ),
62- ),
63- )
64- stdout , stderr , status = mypy .api .run (
65- mypy_argv + [str (item .fspath ) for item in mypy_items .values ()],
66- )
67- terminal .write ('done with status {status}\n ' .format (status = status ))
68-
69- unmatched_lines = []
70- for line in stdout .split ('\n ' ):
71- if not line :
72- continue
73- mypy_path , _ , error = line .partition (':' )
74- try :
75- item = mypy_items [os .path .abspath (mypy_path )]
76- except KeyError :
77- unmatched_lines .append (line )
78- else :
79- item .mypy_errors .append (error )
80- if any (unmatched_lines ):
81- color = {"red" : True } if status != 0 else {"green" : True }
82- terminal .write_line ('\n ' .join (unmatched_lines ), ** color )
83-
84- if stderr :
85- terminal .write_line (stderr , red = True )
86-
87-
8884class MypyItem (pytest .Item , pytest .File ):
8985
9086 """A File that Mypy Runs On."""
@@ -94,12 +90,28 @@ class MypyItem(pytest.Item, pytest.File):
9490 def __init__ (self , * args , ** kwargs ):
9591 super ().__init__ (* args , ** kwargs )
9692 self .add_marker (self .MARKER )
97- self .mypy_errors = []
9893
9994 def runtest (self ):
10095 """Raise an exception if mypy found errors for this item."""
101- if self .mypy_errors :
102- raise MypyError ('\n ' .join (self .mypy_errors ))
96+ results = _cached_json_results (
97+ results_path = (
98+ self .config ._mypy_results_path
99+ if _is_master (self .config ) else
100+ self .config .slaveinput ['_mypy_results_path' ]
101+ ),
102+ results_factory = lambda :
103+ _mypy_results_factory (
104+ abspaths = [
105+ os .path .abspath (str (item .fspath ))
106+ for item in self .session .items
107+ if isinstance (item , MypyItem )
108+ ],
109+ )
110+ )
111+ abspath = os .path .abspath (str (self .fspath ))
112+ errors = results ['abspath_errors' ].get (abspath )
113+ if errors :
114+ raise MypyError ('\n ' .join (errors ))
103115
104116 def reportinfo (self ):
105117 """Produce a heading for the test report."""
@@ -119,8 +131,70 @@ def repr_failure(self, excinfo):
119131 return super ().repr_failure (excinfo )
120132
121133
134+ def _cached_json_results (results_path , results_factory = None ):
135+ """
136+ Read results from results_path if it exists;
137+ otherwise, produce them with results_factory,
138+ and write them to results_path.
139+ """
140+ with FileLock (results_path + '.lock' ):
141+ try :
142+ with open (results_path , mode = 'r' ) as results_f :
143+ results = json .load (results_f )
144+ except FileNotFoundError :
145+ if not results_factory :
146+ raise
147+ results = results_factory ()
148+ with open (results_path , mode = 'w' ) as results_f :
149+ json .dump (results , results_f )
150+ return results
151+
152+
153+ def _mypy_results_factory (abspaths ):
154+ """Run mypy on abspaths and return the results as a JSON-able dict."""
155+
156+ stdout , stderr , status = mypy .api .run (mypy_argv + abspaths )
157+
158+ abspath_errors , unmatched_lines = {}, []
159+ for line in stdout .split ('\n ' ):
160+ if not line :
161+ continue
162+ path , _ , error = line .partition (':' )
163+ abspath = os .path .abspath (path )
164+ if abspath in abspaths :
165+ abspath_errors [abspath ] = abspath_errors .get (abspath , []) + [error ]
166+ else :
167+ unmatched_lines .append (line )
168+
169+ return {
170+ 'stdout' : stdout ,
171+ 'stderr' : stderr ,
172+ 'status' : status ,
173+ 'abspath_errors' : abspath_errors ,
174+ 'unmatched_stdout' : '\n ' .join (unmatched_lines ),
175+ }
176+
177+
122178class MypyError (Exception ):
123179 """
124180 An error caught by mypy, e.g a type checker violation
125181 or a syntax error.
126182 """
183+
184+
185+ def pytest_terminal_summary (terminalreporter ):
186+ """Report stderr and unrecognized lines from stdout."""
187+ config = _pytest_terminal_summary_config
188+ try :
189+ results = _cached_json_results (config ._mypy_results_path )
190+ except FileNotFoundError :
191+ # No MypyItems executed.
192+ return
193+ if results ['unmatched_stdout' ] or results ['stderr' ]:
194+ terminalreporter .section ('mypy' )
195+ if results ['unmatched_stdout' ]:
196+ color = {'red' : True } if results ['status' ] else {'green' : True }
197+ terminalreporter .write_line (results ['unmatched_stdout' ], ** color )
198+ if results ['stderr' ]:
199+ terminalreporter .write_line (results ['stderr' ], yellow = True )
200+ os .remove (config ._mypy_results_path )
0 commit comments