|
| 1 | +__all__ = ['recordtype'] |
| 2 | + |
| 3 | +import sys |
| 4 | +from textwrap import dedent |
| 5 | +from keyword import iskeyword |
| 6 | + |
| 7 | + |
| 8 | +def recordtype(typename, field_names, verbose=False, **default_kwds): |
| 9 | + '''Returns a new class with named fields. |
| 10 | +
|
| 11 | + @keyword field_defaults: A mapping from (a subset of) field names to default |
| 12 | + values. |
| 13 | + @keyword default: If provided, the default value for all fields without an |
| 14 | + explicit default in `field_defaults`. |
| 15 | +
|
| 16 | + >>> Point = recordtype('Point', 'x y', default=0) |
| 17 | + >>> Point.__doc__ # docstring for the new class |
| 18 | + 'Point(x, y)' |
| 19 | + >>> Point() # instantiate with defaults |
| 20 | + Point(x=0, y=0) |
| 21 | + >>> p = Point(11, y=22) # instantiate with positional args or keywords |
| 22 | + >>> p[0] + p.y # accessible by name and index |
| 23 | + 33 |
| 24 | + >>> p.x = 100; p[1] =200 # modifiable by name and index |
| 25 | + >>> p |
| 26 | + Point(x=100, y=200) |
| 27 | + >>> x, y = p # unpack |
| 28 | + >>> x, y |
| 29 | + (100, 200) |
| 30 | + >>> d = p.todict() # convert to a dictionary |
| 31 | + >>> d['x'] |
| 32 | + 100 |
| 33 | + >>> Point(**d) == p # convert from a dictionary |
| 34 | + True |
| 35 | + ''' |
| 36 | + # Parse and validate the field names. Validation serves two purposes, |
| 37 | + # generating informative error messages and preventing template injection attacks. |
| 38 | + if isinstance(field_names, basestring): |
| 39 | + # names separated by whitespace and/or commas |
| 40 | + field_names = field_names.replace(',', ' ').split() |
| 41 | + field_names = tuple(map(str, field_names)) |
| 42 | + if not field_names: |
| 43 | + raise ValueError('Records must have at least one field') |
| 44 | + for name in (typename,) + field_names: |
| 45 | + if not min(c.isalnum() or c=='_' for c in name): |
| 46 | + raise ValueError('Type names and field names can only contain ' |
| 47 | + 'alphanumeric characters and underscores: %r' % name) |
| 48 | + if iskeyword(name): |
| 49 | + raise ValueError('Type names and field names cannot be a keyword: %r' |
| 50 | + % name) |
| 51 | + if name[0].isdigit(): |
| 52 | + raise ValueError('Type names and field names cannot start with a ' |
| 53 | + 'number: %r' % name) |
| 54 | + seen_names = set() |
| 55 | + for name in field_names: |
| 56 | + if name.startswith('_'): |
| 57 | + raise ValueError('Field names cannot start with an underscore: %r' |
| 58 | + % name) |
| 59 | + if name in seen_names: |
| 60 | + raise ValueError('Encountered duplicate field name: %r' % name) |
| 61 | + seen_names.add(name) |
| 62 | + # determine the func_defaults of __init__ |
| 63 | + field_defaults = default_kwds.pop('field_defaults', {}) |
| 64 | + if 'default' in default_kwds: |
| 65 | + default = default_kwds.pop('default') |
| 66 | + init_defaults = tuple(field_defaults.get(f,default) for f in field_names) |
| 67 | + elif not field_defaults: |
| 68 | + init_defaults = None |
| 69 | + else: |
| 70 | + default_fields = field_names[-len(field_defaults):] |
| 71 | + if set(default_fields) != set(field_defaults): |
| 72 | + raise ValueError('Missing default parameter values') |
| 73 | + init_defaults = tuple(field_defaults[f] for f in default_fields) |
| 74 | + if default_kwds: |
| 75 | + raise ValueError('Invalid keyword arguments: %s' % default_kwds) |
| 76 | + # Create and fill-in the class template |
| 77 | + numfields = len(field_names) |
| 78 | + argtxt = ', '.join(field_names) |
| 79 | + reprtxt = ', '.join('%s=%%r' % f for f in field_names) |
| 80 | + dicttxt = ', '.join('%r: self.%s' % (f,f) for f in field_names) |
| 81 | + tupletxt = repr(tuple('self.%s' % f for f in field_names)).replace("'",'') |
| 82 | + inittxt = '; '.join('self.%s=%s' % (f,f) for f in field_names) |
| 83 | + itertxt = '; '.join('yield self.%s' % f for f in field_names) |
| 84 | + eqtxt = ' and '.join('self.%s==other.%s' % (f,f) for f in field_names) |
| 85 | + template = dedent(''' |
| 86 | + class %(typename)s(object): |
| 87 | + '%(typename)s(%(argtxt)s)' |
| 88 | +
|
| 89 | + __slots__ = %(field_names)r |
| 90 | +
|
| 91 | + def __init__(self, %(argtxt)s): |
| 92 | + %(inittxt)s |
| 93 | +
|
| 94 | + def __len__(self): |
| 95 | + return %(numfields)d |
| 96 | +
|
| 97 | + def __iter__(self): |
| 98 | + %(itertxt)s |
| 99 | +
|
| 100 | + def __getitem__(self, index): |
| 101 | + return getattr(self, self.__slots__[index]) |
| 102 | +
|
| 103 | + def __setitem__(self, index, value): |
| 104 | + return setattr(self, self.__slots__[index], value) |
| 105 | +
|
| 106 | + def todict(self): |
| 107 | + 'Return a new dict which maps field names to their values' |
| 108 | + return {%(dicttxt)s} |
| 109 | +
|
| 110 | + def __repr__(self): |
| 111 | + return '%(typename)s(%(reprtxt)s)' %% %(tupletxt)s |
| 112 | +
|
| 113 | + def __eq__(self, other): |
| 114 | + return isinstance(other, self.__class__) and %(eqtxt)s |
| 115 | +
|
| 116 | + def __ne__(self, other): |
| 117 | + return not self==other |
| 118 | +
|
| 119 | + def __getstate__(self): |
| 120 | + return %(tupletxt)s |
| 121 | +
|
| 122 | + def __setstate__(self, state): |
| 123 | + %(tupletxt)s = state |
| 124 | + ''') % locals() |
| 125 | + # Execute the template string in a temporary namespace |
| 126 | + namespace = {} |
| 127 | + try: |
| 128 | + exec template in namespace |
| 129 | + if verbose: print template |
| 130 | + except SyntaxError, e: |
| 131 | + raise SyntaxError(e.message + ':\n' + template) |
| 132 | + cls = namespace[typename] |
| 133 | + cls.__init__.im_func.func_defaults = init_defaults |
| 134 | + # For pickling to work, the __module__ variable needs to be set to the frame |
| 135 | + # where the named tuple is created. Bypass this step in enviroments where |
| 136 | + # sys._getframe is not defined (Jython for example). |
| 137 | + if hasattr(sys, '_getframe') and sys.platform != 'cli': |
| 138 | + cls.__module__ = sys._getframe(1).f_globals['__name__'] |
| 139 | + return cls |
| 140 | + |
| 141 | + |
| 142 | +if __name__ == '__main__': |
| 143 | + import doctest |
| 144 | + TestResults = recordtype('TestResults', 'failed, attempted') |
| 145 | + print TestResults(*doctest.testmod()) |
| 146 | + |
0 commit comments