123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193 |
- import argparse
- import functools
- from typing import Callable, List, Dict
- from dataclasses import dataclass
- from enum import Enum
- from typing import Protocol
def hexToPaddedBin(hex: str):
    """Expand a hexadecimal string into its bit string.

    Every hex digit becomes exactly four binary characters, so leading
    zeros of each nibble are preserved.
    """
    nibbles = (format(int(digit, 16), "04b") for digit in hex)
    return "".join(nibbles)
class DataBuffer:
    """Consumable string of '0'/'1' characters.

    Every read removes the returned bits from the front of ``data``.
    """

    def __init__(self, data_str, already_binary=False):
        """Store ``data_str`` as bits, converting from hex unless it is already binary."""
        self.data = data_str if already_binary else hexToPaddedBin(data_str)

    def readInt(self, num_bits) -> int:
        """Consume ``num_bits`` bits and return them as an unsigned integer."""
        # Parse before advancing so a failed parse leaves the buffer untouched.
        number = int(self.data[:num_bits], 2)
        self.data = self.data[num_bits:]
        return number

    def readRaw(self, num_bits) -> str:
        """Consume ``num_bits`` bits and return them as a bit string."""
        chunk, self.data = self.data[:num_bits], self.data[num_bits:]
        return chunk

    def hasData(self, bits=0) -> bool:
        """Return True when at least ``bits`` bits remain."""
        remaining = len(self.data)
        return remaining >= bits

    def peek(self) -> str:
        """Return the next bit without consuming it."""
        return self.data[0]

    def pop(self) -> str:
        """Consume and return the next single bit."""
        first = self.data[0]
        self.data = self.data[1:]
        return first

    def __str__(self) -> str:
        """Return the bits that have not been consumed yet."""
        return self.data
def readIntoNewBuffer(buffer: DataBuffer, num_bits: int) -> DataBuffer:
    """Consume ``num_bits`` bits from ``buffer`` and wrap them in a new DataBuffer."""
    # The removed bits are already binary, so no hex conversion is requested.
    return DataBuffer(buffer.readRaw(num_bits), True)
# Smallest number of bits that can still hold a packet: the header
# (3-bit version + 3-bit type) plus one literal group (1 flag bit + 4 value bits).
_HEADER_BITS = 6
_LITERAL_GROUP_BITS = 5
MIN_PACKET_SIZE = _HEADER_BITS + _LITERAL_GROUP_BITS
class PacketType(Enum):
    """Packet type IDs as they appear in the 3-bit type field of a header."""

    # Plain value packet.
    LITERAL = 4

    # Aggregating operators over any number of children.
    SUM = 0
    PRODUCT = 1
    MINIMUM = 2
    MAXIMUM = 3

    # Comparison operators over the first two children; they evaluate to 1 or 0.
    GREATER_THAN = 5
    LESS_THAN = 6
    EQUALS = 7
@dataclass
class ProtoPacket:
    """Header shared by every packet: its version and its type."""

    # Version number taken from the packet header.
    version: int
    # Decides whether the packet is a literal or which operator it is.
    type: PacketType
class Packet(Protocol):
    """Interface implemented by every decoded packet.

    Concrete packets must provide ``sumVersions`` and ``calculate``; the
    defaults here only signal that the method was not overridden.
    """

    # Header of the packet. The annotation is quoted so evaluating this class
    # body does not require ProtoPacket to be bound yet.
    proto: "ProtoPacket"

    def sumVersions(self) -> int:
        """Return the version total for this packet and everything nested in it.

        Raises:
            NotImplementedError: when a subclass does not override this method.
        """
        # NotImplemented is a comparison sentinel, not an exception; raising it
        # would produce a confusing TypeError instead.
        raise NotImplementedError

    def calculate(self) -> int:
        """Return the integer value this packet evaluates to.

        Raises:
            NotImplementedError: when a subclass does not override this method.
        """
        raise NotImplementedError

    def __repr__(self):
        """Reuse ``__str__`` so packets inside containers print readably."""
        return self.__str__()
class LiteralPacket(Packet):
    """Leaf packet that carries a single integer value."""

    def __init__(self, proto: ProtoPacket, value: int):
        """Remember the packet header and the decoded literal value."""
        self.value = value
        self.proto = proto

    def __str__(self):
        """Render as '<type> v<version> = <value>'."""
        return f"{self.proto.type} v{self.proto.version} = {self.value}"

    def calculate(self) -> int:
        """A literal evaluates to its own value."""
        return self.value

    def sumVersions(self):
        """A literal has no children, so only its own version counts."""
        return self.proto.version
class OperatorPacket(Packet):
    """Packet whose value is computed from the values of its child packets."""

    def __init__(self, proto: ProtoPacket, children: List[Packet]):
        """Remember the packet header and the already-parsed child packets."""
        self.proto = proto
        self.children = children

    def calculate(self) -> int:
        """Evaluate this operator over its children.

        SUM, PRODUCT, MINIMUM and MAXIMUM fold over all children; the three
        comparison operators compare the first child with the second and
        yield 1 or 0.

        Raises:
            ValueError: if the packet type is not an operator, or MINIMUM /
                MAXIMUM is applied to an empty child list.
            IndexError: if a comparison operator has fewer than two children.
        """
        kind = self.proto.type

        if kind == PacketType.SUM:
            return sum(child.calculate() for child in self.children)
        if kind == PacketType.PRODUCT:
            return functools.reduce(
                lambda acc, child: acc * child.calculate(), self.children, 1
            )
        # min()/max() accept an iterable, so a single child needs no special case.
        if kind == PacketType.MINIMUM:
            return min(child.calculate() for child in self.children)
        if kind == PacketType.MAXIMUM:
            return max(child.calculate() for child in self.children)

        if kind in (PacketType.LESS_THAN, PacketType.GREATER_THAN, PacketType.EQUALS):
            left = self.children[0].calculate()
            right = self.children[1].calculate()
            if kind == PacketType.LESS_THAN:
                return int(left < right)
            if kind == PacketType.GREATER_THAN:
                return int(left > right)
            return int(left == right)

        # Falling through used to return None silently, breaking the -> int contract.
        raise ValueError(f"Unsupported operator packet type: {kind}")

    def sumVersions(self) -> int:
        """Return this packet's version plus the version totals of all children."""
        return self.proto.version + sum(child.sumVersions() for child in self.children)

    def __str__(self):
        """Render as '<type> v<version> # <child count> [<children>]'."""
        return f"{self.proto.type} v{self.proto.version} # {len(self.children)} [{', '.join([str(child) for child in self.children])}]"
def parseLiteral(proto: "ProtoPacket", buffer: "DataBuffer") -> "Packet":
    """Decode the body of a literal packet from ``buffer``.

    The body is a sequence of 5-bit groups: one continuation flag followed by
    four value bits. A flag of "0" marks the last group. The value bits of all
    groups, concatenated, form the literal as a binary number.

    Args:
        proto: Header already read for this packet.
        buffer: Bit source positioned just after the header; the literal's
            groups are consumed from it.

    Returns:
        A LiteralPacket holding ``proto`` and the decoded value.

    Raises:
        ValueError: if the buffer runs out before the final group is complete.
    """
    groups = []
    more_groups = True
    while more_groups:
        flag = buffer.readRaw(1)
        bits = buffer.readRaw(4)
        # An exhausted buffer yields short or empty reads. Without this check
        # an empty flag never equals "0" and the loop would spin forever.
        if flag == "" or len(bits) != 4:
            raise ValueError("Truncated literal packet: ran out of bits")
        groups.append(bits)
        more_groups = flag != "0"
    return LiteralPacket(proto, int("".join(groups), 2))
def parseRawOperator(proto: ProtoPacket, buffer: DataBuffer) -> Packet:
    """Build an operator packet whose children are every packet found in ``buffer``."""
    # The buffer is parsed until too few bits remain for another packet.
    return OperatorPacket(proto, parseBuffer(buffer))
def parseNestedOperator(proto: ProtoPacket, buffer: DataBuffer, numPackets: int) -> Packet:
    """Build an operator packet from exactly ``numPackets`` packets read off ``buffer``."""
    return OperatorPacket(proto, parseNPacketsFromBuffer(buffer, numPackets))
def parsePacket(protoPacket: ProtoPacket, buffer: DataBuffer) -> Packet:
    """Parse the body that follows an already-read packet header.

    Literal packets are decoded directly. For operators, the next bit selects
    how the children are delimited: "0" means a 15-bit total bit length
    follows, anything else means an 11-bit child count follows.
    """
    if protoPacket.type == PacketType.LITERAL:
        return parseLiteral(protoPacket, buffer)

    length_type = buffer.readRaw(1)
    if length_type == "0":
        # Carve the children's bits out into their own buffer first.
        payload_bits = buffer.readInt(15)
        payload = readIntoNewBuffer(buffer, payload_bits)
        return parseRawOperator(protoPacket, payload)

    child_count = buffer.readInt(11)
    return parseNestedOperator(protoPacket, buffer, child_count)
def _readProto(buffer: DataBuffer) -> ProtoPacket:
    """Consume a packet header: a 3-bit version followed by a 3-bit type ID."""
    version = buffer.readInt(3)
    type_id = buffer.readInt(3)
    return ProtoPacket(version, PacketType(type_id))
def parseNPacketsFromBuffer(buffer: DataBuffer, numPackets: int) -> List[Packet]:
    """Parse exactly ``numPackets`` consecutive packets from ``buffer``."""
    # Each iteration reads the header first, then the matching body.
    return [parsePacket(_readProto(buffer), buffer) for _ in range(numPackets)]
def parseBuffer(buffer: DataBuffer) -> List[Packet]:
    """Parse packets from ``buffer`` until fewer than MIN_PACKET_SIZE bits remain."""
    parsed: List[Packet] = []
    while True:
        # Leftover bits too short to form a packet are ignored.
        if not buffer.hasData(MIN_PACKET_SIZE):
            return parsed
        header = _readProto(buffer)
        parsed.append(parsePacket(header, buffer))
def main(buffer: DataBuffer):
    """Parse ``buffer``, print the packet list, then print the first packet's value."""
    decoded = parseBuffer(buffer)
    print(decoded)
    outermost = decoded[0]
    print(outermost.calculate())
if __name__ == "__main__":
    # Usage: pass the path of a file whose first line is the hex transmission.
    parser = argparse.ArgumentParser()
    parser.add_argument("ifile", type=argparse.FileType('r'))
    args = parser.parse_args()
    # argparse opened the file for us; close it once the first line is read.
    with args.ifile as ifile:
        superpacket = ifile.readline().strip()
    # DataBuffer converts the hex text to a bit string itself.
    main(DataBuffer(superpacket))
|