diff --git a/nl80211.py b/nl80211.py index 6aaeeab..2e93dae 100755 --- a/nl80211.py +++ b/nl80211.py @@ -248,24 +248,33 @@ class GenericNetlink(object): 7: ('mcast_groups', Array(_mcast_grp_attr)), }) - _msgtypes = { - 'nlctrl': [_ctrl_attr, 0x10], - } + _msgtypes = [ + [0x10, 'nlctrl', _ctrl_attr], + ] CTRL_CMD_NEWFAMILY = 0x01 CTRL_CMD_GETFAMILY = 0x03 def __init__(self): self._netlink = Netlink() + self._UpdateMsgTypes() self.Send('nlctrl', self._netlink.NLMSG_F_DUMP, self.CTRL_CMD_GETFAMILY, 1, '') for msg in self.Recv(): msgtype, attrs = msg assert msgtype == self.CTRL_CMD_NEWFAMILY, msgtype family_name = attrs['family_name'].rstrip('\0') - self._msgtypes.setdefault(family_name, [None, None])[1] = attrs['family_id'] + if family_name in self._msgtypes_by_name: + assert attrs['family_id'] == self._msgtypes_by_name[family_name][0], attrs['family_id'] + else: + self._msgtypes.append([attrs['family_id'], family_name, None]) + self._UpdateMsgTypes() + + def _UpdateMsgTypes(self): + self._msgtypes_by_id = dict((i[0], i) for i in self._msgtypes) + self._msgtypes_by_name = dict((i[1], i) for i in self._msgtypes) def RegisterMsgType(self, family_name, parser): - self._msgtypes[family_name][0] = parser + self._msgtypes_by_name[family_name][2] = parser def Send(self, msgtype, flags, cmd, version, msg): accumulator = Accumulator() @@ -275,13 +284,13 @@ class GenericNetlink(object): version=version, reserved=0) accumulator.Append(msg) - msgtype_id = self._msgtypes[msgtype][1] + msgtype_id = self._msgtypes_by_name[msgtype][0] self._netlink.Send(msgtype_id, flags, str(accumulator)) def Recv(self): for msgtype, iterator in self._netlink.Recv(): genlhdr = self._genlmsghdr.Unpack(iterator) - parser = [v[0] for v in self._msgtypes.itervalues() if v[1] == msgtype][0] + parser = self._msgtypes_by_id[msgtype][2] yield (genlhdr['cmd'], parser.Unpack(iterator))