import argparse
import binascii
import functools
from dataclasses import dataclass
from enum import Enum
from typing import List, Protocol


def hexToPaddedBin(hex: str) -> str:
    """Convert a hex string into a string of '0'/'1' characters.

    Each hex digit expands to exactly four bits, so leading zeros are kept
    (e.g. "D2" -> "11010010").

    Raises:
        ValueError: if ``hex`` contains a character that is not a hex digit.
    """
    # The parameter name shadows the builtin ``hex``; it is kept unchanged so
    # keyword callers keep working.
    return "".join(f"{int(char, 16):04b}" for char in hex)


class DataBuffer:
    """A consumable stream of bits stored as a string of '0'/'1' characters.

    Every read removes the bits it returns from the front of ``data``.
    """

    def __init__(self, data_str: str, already_binary: bool = False):
        # ``data_str`` is hex unless ``already_binary`` says it is already a
        # bit string.
        if already_binary:
            self.data = data_str
        else:
            self.data = hexToPaddedBin(data_str)

    def readRaw(self, num_bits: int) -> str:
        """Consume and return the next ``num_bits`` bits as a string.

        Raises:
            ValueError: if fewer than ``num_bits`` bits remain. A short read
                would otherwise silently misparse truncated input.
        """
        if num_bits > len(self.data):
            raise ValueError(
                f"requested {num_bits} bits but only {len(self.data)} remain"
            )
        ret_val = self.data[:num_bits]
        self.data = self.data[num_bits:]
        return ret_val

    def readInt(self, num_bits: int) -> int:
        """Consume the next ``num_bits`` bits and return them as an unsigned int.

        Raises:
            ValueError: if fewer than ``num_bits`` bits remain, or if
                ``num_bits`` is 0 (an empty bit string is not a number).
        """
        return int(self.readRaw(num_bits), 2)

    def hasData(self, bits: int = 0) -> bool:
        """Return True if at least ``bits`` bits remain."""
        return len(self.data) >= bits

    def onlyPaddingRemains(self) -> bool:
        """Return True if every remaining bit is '0' (or nothing remains).

        Trailing zero bits after the last packet are padding, not a packet.
        """
        return "1" not in self.data

    def peek(self) -> str:
        """Return the next bit without consuming it (IndexError if empty)."""
        return self.data[0]

    def pop(self) -> str:
        """Consume and return the next bit (IndexError if empty)."""
        ret_val = self.data[0]
        self.data = self.data[1:]
        return ret_val

    def __str__(self) -> str:
        return self.data


def readIntoNewBuffer(buffer: DataBuffer, num_bits: int) -> DataBuffer:
    """Consume ``num_bits`` bits from ``buffer`` and wrap them in a new DataBuffer.

    Raises:
        ValueError: if ``buffer`` holds fewer than ``num_bits`` bits.
    """
    return DataBuffer(buffer.readRaw(num_bits), True)


# Smallest possible packet: a 6-bit header (3-bit version + 3-bit type ID)
# followed by a single 5-bit literal group.
MIN_PACKET_SIZE = 6 + 5


class PacketType(Enum):
    # Type ID 4 is a literal value; every other type ID is decoded as an
    # operator by _readProto.
    LITERAL = 4
    OPERATOR = 0


@dataclass
class ProtoPacket:
    """The header shared by every packet."""

    version: int  # value of the 3-bit version field
    type: PacketType  # decoded from the 3-bit type ID field


class Packet(Protocol):
    """Interface shared by literal and operator packets."""

    proto: ProtoPacket

    def sumVersions(self) -> int:
        """Return this packet's version plus the versions of all nested packets."""
        # ``NotImplemented`` is a comparison sentinel, not an exception;
        # ``raise NotImplemented`` would itself fail with TypeError.
        raise NotImplementedError

    def __repr__(self):
        return self.__str__()


class LiteralPacket(Packet):
    """A packet carrying a single integer value."""

    def __init__(self, proto: ProtoPacket, value: int):
        self.proto = proto
        self.value = value

    def sumVersions(self) -> int:
        # A literal has no children, so only its own version counts.
        return self.proto.version

    def __str__(self):
        return f"{self.proto.type} v{self.proto.version} = {self.value}"


class OperatorPacket(Packet):
    """A packet containing an ordered list of child packets."""

    def __init__(self, proto: ProtoPacket, children: List[Packet]):
        self.proto = proto
        self.children = children

    def sumVersions(self) -> int:
        return self.proto.version + sum(child.sumVersions() for child in self.children)

    def __str__(self):
        inner = ", ".join(str(child) for child in self.children)
        return f"{self.proto.type} v{self.proto.version} # {len(self.children)} [{inner}]"


def parseLiteral(proto: ProtoPacket, buffer: DataBuffer) -> Packet:
    """Parse a literal body: 5-bit groups, each a continuation bit plus 4 value bits.

    Groups are read while the continuation bit is not '0'. The 4-bit payloads
    are concatenated and interpreted as one unsigned binary number.

    Raises:
        ValueError: if the buffer ends before the final group (previously an
            infinite loop, because an empty read never equals "0").
    """
    groups: List[str] = []
    more = True
    while more:
        more = buffer.readRaw(1) != "0"
        groups.append(buffer.readRaw(4))
    return LiteralPacket(proto, int("".join(groups), 2))


def parseRawOperator(proto: ProtoPacket, buffer: DataBuffer) -> Packet:
    """Build an operator whose children fill ``buffer`` (a length-delimited sub-buffer)."""
    children = parseBuffer(buffer)
    return OperatorPacket(proto, children)


def parseNestedOperator(proto: ProtoPacket, buffer: DataBuffer, numPackets: int) -> Packet:
    """Build an operator from exactly ``numPackets`` child packets read from ``buffer``."""
    children = parseNPacketsFromBuffer(buffer, numPackets)
    return OperatorPacket(proto, children)


def parsePacket(protoPacket: ProtoPacket, buffer: DataBuffer) -> Packet:
    """Parse a packet body from ``buffer``, given its already-read header."""
    if protoPacket.type == PacketType.LITERAL:
        return parseLiteral(protoPacket, buffer)
    # Operators start with a length type ID bit. '0' means the next 15 bits give
    # the total bit length of the children. '1' means the next 11 bits give the
    # number of child packets.
    if buffer.readRaw(1) == "0":
        return parseRawOperator(protoPacket, readIntoNewBuffer(buffer, buffer.readInt(15)))
    return parseNestedOperator(protoPacket, buffer, buffer.readInt(11))


def _readProto(buffer: DataBuffer) -> ProtoPacket:
    """Read the 6-bit header (3-bit version, 3-bit type ID) from ``buffer``."""
    new_ver = buffer.readInt(3)
    new_type = PacketType.LITERAL if buffer.readInt(3) == 4 else PacketType.OPERATOR
    return ProtoPacket(new_ver, new_type)


def parseNPacketsFromBuffer(buffer: DataBuffer, numPackets: int) -> List[Packet]:
    """Parse exactly ``numPackets`` consecutive packets from ``buffer``."""
    return [parsePacket(_readProto(buffer), buffer) for _ in range(numPackets)]


def parseBuffer(buffer: DataBuffer) -> List[Packet]:
    """Parse back-to-back packets until ``buffer`` holds only padding.

    Parsing stops when fewer than MIN_PACKET_SIZE bits remain, or when every
    remaining bit is '0'. Trailing zero bits are padding, and parsing them
    would produce bogus packets.
    """
    packets: List[Packet] = []
    while buffer.hasData(MIN_PACKET_SIZE) and not buffer.onlyPaddingRemains():
        proto = _readProto(buffer)
        packets.append(parsePacket(proto, buffer))
    return packets


def main(buffer: DataBuffer) -> None:
    """Parse every packet in ``buffer``, print them, then print the sum of all versions."""
    packets = parseBuffer(buffer)
    print(packets)
    print(sum(packet.sumVersions() for packet in packets))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("ifile", type=argparse.FileType('r'))
    args = parser.parse_args()
    # argparse opens the file; close it as soon as the single line is read.
    with args.ifile as ifile:
        superpacket = ifile.readline().strip()
    # DataBuffer expands the hex transmission into a string of '0'/'1' bits.
    main(DataBuffer(superpacket))