import argparse
import functools
from typing import Callable, List, Dict
from dataclasses import dataclass
from enum import Enum
from typing import Protocol


def hexToPaddedBin(hex: str):
    """Convert a hex string to a bit string with exactly 4 bits per hex digit.

    Each digit is padded on its own, so leading zero nibbles are preserved
    (``"0F"`` -> ``"00001111"``).
    """
    return "".join(f"{int(char, 16):04b}" for char in hex)


class DataBuffer:
    """A string of '0'/'1' characters that is consumed from the front."""

    def __init__(self, data_str, already_binary=False):
        # data_str is either hex text (expanded to bits here) or, when
        # already_binary is set, an existing bit string used as-is.
        if already_binary:
            self.data = data_str
        else:
            self.data = hexToPaddedBin(data_str)

    def readInt(self, num_bits) -> int:
        """Consume ``num_bits`` bits and return them as an unsigned integer."""
        return int(self.readRaw(num_bits), 2)

    def readRaw(self, num_bits) -> str:
        """Consume and return the next ``num_bits`` bits as a string.

        Raises:
            ValueError: if fewer than ``num_bits`` bits remain. A silent short
                read used to make ``parseLiteral`` loop forever on truncated
                input (an empty read never equals "0").
        """
        if num_bits > len(self.data):
            raise ValueError(
                f"requested {num_bits} bits but only {len(self.data)} remain"
            )
        ret_val = self.data[:num_bits]
        self.data = self.data[num_bits:]
        return ret_val

    def hasData(self, bits=0) -> bool:
        """Return True if at least ``bits`` bits remain unread."""
        return len(self.data) >= bits

    def peek(self) -> str:
        """Return the next bit without consuming it."""
        return self.data[0]

    def pop(self) -> str:
        """Consume and return the next bit."""
        ret_val = self.data[0]
        self.data = self.data[1:]
        return ret_val

    def __str__(self) -> str:
        return self.data


def readIntoNewBuffer(buffer: DataBuffer, num_bits: int) -> DataBuffer:
    """Move the next ``num_bits`` bits of ``buffer`` into a new DataBuffer."""
    payload = buffer.readRaw(num_bits)
    return DataBuffer(payload, True)


# Smallest possible packet: 6 header bits (3 version + 3 type) plus a single
# 5-bit literal group.
MIN_PACKET_SIZE = 6 + 5


class PacketType(Enum):
    """The 3-bit packet type id read from each packet header."""

    LITERAL = 4
    SUM = 0
    PRODUCT = 1
    MINIMUM = 2
    MAXIMUM = 3
    GREATER_THAN = 5
    LESS_THAN = 6
    EQUALS = 7


@dataclass
class ProtoPacket:
    """The decoded packet header: version number and packet type."""

    version: int
    type: PacketType


class Packet(Protocol):
    """Interface shared by LiteralPacket and OperatorPacket."""

    proto: ProtoPacket

    def sumVersions(self) -> int:
        """Return the sum of the versions of this packet and all packets nested in it."""
        # NotImplemented is a constant, not an exception; raising it would
        # produce a confusing TypeError.
        raise NotImplementedError

    def calculate(self) -> int:
        """Return the integer value this packet evaluates to."""
        raise NotImplementedError

    def __repr__(self):
        return self.__str__()


class LiteralPacket(Packet):
    """A leaf packet carrying a single integer value."""

    def __init__(self, proto: ProtoPacket, value: int):
        self.proto = proto
        self.value = value

    def sumVersions(self):
        return self.proto.version

    def calculate(self) -> int:
        return self.value

    def __str__(self):
        return f"{self.proto.type} v{self.proto.version} = {self.value}"


class OperatorPacket(Packet):
    """A packet that combines the values of its child packets."""

    def __init__(self, proto: ProtoPacket, children: List[Packet]):
        self.proto = proto
        self.children = children

    def calculate(self) -> int:
        """Apply this packet's operator to the values of its children.

        The comparison operators (LESS_THAN, GREATER_THAN, EQUALS) use the
        first two children and return 1 or 0.

        Raises:
            ValueError: if the packet type is not an operator type
                (previously this silently returned None).
        """
        values = [child.calculate() for child in self.children]
        kind = self.proto.type
        if kind == PacketType.SUM:
            return sum(values)
        if kind == PacketType.PRODUCT:
            return functools.reduce(lambda acc, val: acc * val, values, 1)
        if kind == PacketType.MINIMUM:
            return min(values)
        if kind == PacketType.MAXIMUM:
            return max(values)
        if kind == PacketType.LESS_THAN:
            return int(values[0] < values[1])
        if kind == PacketType.GREATER_THAN:
            return int(values[0] > values[1])
        if kind == PacketType.EQUALS:
            return int(values[0] == values[1])
        raise ValueError(f"{kind} is not an operator packet type")

    def sumVersions(self) -> int:
        return sum([child.sumVersions() for child in self.children] + [self.proto.version])

    def __str__(self):
        return f"{self.proto.type} v{self.proto.version} # {len(self.children)} [{', '.join([str(child) for child in self.children])}]"


def parseLiteral(proto: ProtoPacket, buffer: DataBuffer) -> Packet:
    """Read a literal value made of 5-bit groups.

    Each group is a continuation bit followed by 4 value bits. A leading "1"
    means more groups follow; "0" marks the last group.
    """
    sub_buffer = ""
    while buffer.readRaw(1) != "0":
        sub_buffer += buffer.readRaw(4)
    # The terminating group still carries 4 value bits.
    sub_buffer += buffer.readRaw(4)
    value = int(sub_buffer, 2)
    return LiteralPacket(proto, value)


def parseRawOperator(proto: ProtoPacket, buffer: DataBuffer) -> Packet:
    """Build an operator whose children fill the whole of ``buffer``."""
    # `buffer` holds exactly this operator's sub-packet bits, so parse all of it.
    children = parseBuffer(buffer)
    return OperatorPacket(proto, children)


def parseNestedOperator(proto: ProtoPacket, buffer: DataBuffer, numPackets: int) -> Packet:
    """Build an operator by reading exactly ``numPackets`` child packets."""
    children = parseNPacketsFromBuffer(buffer, numPackets)
    return OperatorPacket(proto, children)


def parsePacket(protoPacket: ProtoPacket, buffer: DataBuffer) -> Packet:
    """Parse the body of a packet whose header has already been read.

    For operators, one length-type bit follows the header:
    "0" -> a 15-bit count of sub-packet *bits*;
    "1" -> an 11-bit count of sub-packets.
    """
    if protoPacket.type == PacketType.LITERAL:
        return parseLiteral(protoPacket, buffer)
    if buffer.readRaw(1) == "0":
        bit_length = buffer.readInt(15)
        return parseRawOperator(protoPacket, readIntoNewBuffer(buffer, bit_length))
    packet_count = buffer.readInt(11)
    return parseNestedOperator(protoPacket, buffer, packet_count)


def _readProto(buffer: DataBuffer) -> ProtoPacket:
    """Read a packet header: a 3-bit version followed by a 3-bit type id."""
    new_ver = buffer.readInt(3)
    new_type = PacketType(buffer.readInt(3))
    return ProtoPacket(new_ver, new_type)


def parseNPacketsFromBuffer(buffer: DataBuffer, numPackets: int) -> List[Packet]:
    """Parse exactly ``numPackets`` consecutive packets from ``buffer``."""
    packets = []
    for _ in range(numPackets):
        proto = _readProto(buffer)
        packets.append(parsePacket(proto, buffer))
    return packets


def parseBuffer(buffer: DataBuffer) -> List[Packet]:
    """Parse packets until fewer than MIN_PACKET_SIZE bits remain.

    Leftover bits shorter than the smallest packet (e.g. trailing zero
    padding) are ignored.
    """
    packets: List[Packet] = []
    while buffer.hasData(MIN_PACKET_SIZE):
        proto = _readProto(buffer)
        packets.append(parsePacket(proto, buffer))
    return packets


def main(buffer: DataBuffer):
    """Print the parsed packets and the value of the first (outermost) one."""
    packets = parseBuffer(buffer)
    print(packets)
    print(packets[0].calculate())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("ifile", type=argparse.FileType('r'))
    args = parser.parse_args()
    with args.ifile:
        superpacket = args.ifile.readline().strip()
    # The first line is the whole transmission as hex text; DataBuffer
    # expands it to a bit string.
    main(DataBuffer(superpacket))