part2.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. import argparse
  2. import functools
  3. from typing import Callable, List, Dict
  4. from dataclasses import dataclass
  5. from enum import Enum
  6. from typing import Protocol
  def hexToPaddedBin(hex: str):
      """Expand a string of hexadecimal digits into '0'/'1' characters.

      Each hex digit becomes exactly four bits with leading zeros kept, so the
      result is always ``4 * len(hex)`` characters long.

      :param hex: hexadecimal digits (upper or lower case).
      :returns: the concatenated 4-bit groups.
      :raises ValueError: if a character is not a hexadecimal digit.
      """
      nibbles = (f"{int(digit, 16):04b}" for digit in hex)
      return "".join(nibbles)
  class DataBuffer:
      """Front-to-back reader over a string of '0'/'1' characters.

      Reads advance an internal cursor instead of re-slicing the remaining
      string on every call, so a read costs O(num_bits) rather than
      O(remaining length) and parsing a whole transmission stays linear.

      Reads past the end are *short*: they return only the bits that remain
      (possibly none) rather than raising.
      """

      def __init__(self, data_str, already_binary=False):
          """Create a buffer.

          :param data_str: hex digits, or '0'/'1' characters when
              ``already_binary`` is true.
          :param already_binary: skip the hex-to-binary expansion.
          """
          self._bits = data_str if already_binary else hexToPaddedBin(data_str)
          self._pos = 0  # index of the next unread bit in self._bits

      @property
      def data(self) -> str:
          """The bits not yet consumed (kept so code reading ``.data`` still works)."""
          return self._bits[self._pos:]

      @data.setter
      def data(self, value: str) -> None:
          # Replacing the contents restarts reading from the new string's start.
          self._bits = value
          self._pos = 0

      def _window(self, num_bits: int) -> str:
          """Return up to ``num_bits`` unread bits without consuming them."""
          if num_bits < 0:
              raise ValueError(f"num_bits must be non-negative, got {num_bits}")
          return self._bits[self._pos:self._pos + num_bits]

      def readInt(self, num_bits) -> int:
          """Consume up to ``num_bits`` bits and return them as an unsigned integer.

          :raises ValueError: if no bits are available to read, or
              ``num_bits`` is negative.
          """
          chunk = self._window(num_bits)
          value = int(chunk, 2)  # parse before advancing so a failure consumes nothing
          self._pos += len(chunk)
          return value

      def readRaw(self, num_bits) -> str:
          """Consume and return up to ``num_bits`` bits as a string."""
          chunk = self._window(num_bits)
          self._pos += len(chunk)
          return chunk

      def hasData(self, bits=0) -> bool:
          """Return True if at least ``bits`` unread bits remain.

          With the default ``bits=0`` this is always True.
          """
          return len(self._bits) - self._pos >= bits

      def peek(self) -> str:
          """Return the next bit without consuming it (IndexError if exhausted)."""
          return self._bits[self._pos]

      def pop(self) -> str:
          """Consume and return the next bit (IndexError if exhausted)."""
          bit = self._bits[self._pos]
          self._pos += 1
          return bit

      def __str__(self) -> str:
          return self.data
  def readIntoNewBuffer(buffer: DataBuffer, num_bits: int) -> DataBuffer:
      """Split the next ``num_bits`` bits off ``buffer`` into a new buffer.

      The bits are consumed from ``buffer``. The returned buffer holds them
      as-is (already binary). If fewer bits remain, the new buffer is
      correspondingly shorter.
      """
      return DataBuffer(buffer.readRaw(num_bits), already_binary=True)
  # Smallest possible packet: a 3-bit version and 3-bit type header followed by
  # one 5-bit literal group (1 continuation bit + 4 value bits).
  MIN_PACKET_SIZE = 3 + 3 + 5
  class PacketType(Enum):
      """Packet type IDs, decoded from the 3-bit type field of a packet header.

      Members are listed in their original order, which is also the enum's
      iteration order.
      """

      LITERAL = 4       # leaf carrying a single integer value
      SUM = 0           # sum of the children's values
      PRODUCT = 1       # product of the children's values
      MINIMUM = 2       # smallest child value
      MAXIMUM = 3       # largest child value
      GREATER_THAN = 5  # 1 if first child > second child, else 0
      LESS_THAN = 6     # 1 if first child < second child, else 0
      EQUALS = 7        # 1 if first child == second child, else 0
  @dataclass
  class ProtoPacket:
      """The 6-bit header every packet starts with: version and type ID."""

      version: int  # 3-bit version number
      type: PacketType  # decoded 3-bit type ID
  class Packet(Protocol):
      """Structural interface shared by every decoded packet."""

      proto: ProtoPacket  # the packet's decoded version/type header

      def sumVersions(self) -> int:
          """Return the sum of the version numbers of this packet and all descendants."""
          # NotImplementedError is the exception to raise here; the
          # NotImplemented sentinel is not an exception and raising it is a TypeError.
          raise NotImplementedError

      def calculate(self) -> int:
          """Evaluate this packet's expression and return its integer value."""
          raise NotImplementedError

      def __repr__(self):
          # Reuse the concrete class's human-readable __str__ for repr().
          return self.__str__()
  class LiteralPacket(Packet):
      """Leaf packet that carries a single decoded integer."""

      def __init__(self, proto: ProtoPacket, value: int):
          self.proto = proto
          self.value = value

      def sumVersions(self):
          # A leaf has no children, so only its own version counts.
          return self.proto.version

      def calculate(self) -> int:
          return self.value

      def __str__(self):
          header = self.proto
          return f"{header.type} v{header.version} = {self.value}"
  class OperatorPacket(Packet):
      """Interior packet that combines the values of its child packets."""

      def __init__(self, proto: ProtoPacket, children: List[Packet]):
          self.proto = proto
          self.children = children

      def calculate(self) -> int:
          """Evaluate this operator over its children's values.

          SUM and PRODUCT of zero children give 0 and 1. MINIMUM and MAXIMUM
          need at least one child. GREATER_THAN, LESS_THAN and EQUALS compare
          exactly two children and return 1 (true) or 0 (false).

          :raises ValueError: if the number of children does not suit the
              operator, or the packet type is not an operator type
              (e.g. LITERAL).
          """
          kind = self.proto.type
          values = [child.calculate() for child in self.children]

          if kind == PacketType.SUM:
              return sum(values)
          if kind == PacketType.PRODUCT:
              return functools.reduce(lambda acc, value: acc * value, values, 1)

          if kind in (PacketType.MINIMUM, PacketType.MAXIMUM):
              if not values:
                  raise ValueError(f"{kind} packet has no children")
              return min(values) if kind == PacketType.MINIMUM else max(values)

          if kind in (PacketType.GREATER_THAN, PacketType.LESS_THAN, PacketType.EQUALS):
              if len(values) != 2:
                  raise ValueError(
                      f"{kind} packet needs exactly 2 children, got {len(values)}"
                  )
              left, right = values
              if kind == PacketType.GREATER_THAN:
                  return int(left > right)
              if kind == PacketType.LESS_THAN:
                  return int(left < right)
              return int(left == right)

          # Fail loudly rather than falling off the end and returning None.
          raise ValueError(f"{kind} is not an operator packet type")

      def sumVersions(self) -> int:
          """Return this packet's version plus the versions of all descendants."""
          return self.proto.version + sum(child.sumVersions() for child in self.children)

      def __str__(self):
          rendered = ", ".join(str(child) for child in self.children)
          return f"{self.proto.type} v{self.proto.version} # {len(self.children)} [{rendered}]"
  def parseLiteral(proto: ProtoPacket, buffer: DataBuffer) -> Packet:
      """Decode the body of a literal packet whose header is already consumed.

      The value is stored as 5-bit groups. Each group is a flag bit
      ('0' = last group, anything else = more follow) and then 4 value bits.
      The 4-bit payloads are concatenated and read as one unsigned integer.

      :raises ValueError: if the buffer ends before a complete final group.
      """
      nibbles = []
      while True:
          # Reading past the end yields '' (never '0'), so without this check a
          # truncated literal would make the loop spin forever.
          if not buffer.hasData(5):
              raise ValueError("truncated literal packet: expected another 5-bit group")
          is_last = buffer.readRaw(1) == "0"
          nibbles.append(buffer.readRaw(4))
          if is_last:
              break
      return LiteralPacket(proto, int("".join(nibbles), 2))
  def parseRawOperator(proto: ProtoPacket, buffer: DataBuffer) -> Packet:
      """Build an operator whose children are the packets contained in ``buffer``.

      Used for the length-delimited form, where ``buffer`` is a sub-buffer
      holding just the operator's sub-packet bits. Children are parsed until
      fewer than MIN_PACKET_SIZE bits remain.
      """
      return OperatorPacket(proto, parseBuffer(buffer))
  def parseNestedOperator(proto: ProtoPacket, buffer: DataBuffer, numPackets: int) -> Packet:
      """Build an operator from the next ``numPackets`` packets in ``buffer``.

      Used for the count-delimited form, where the child count is known.
      """
      return OperatorPacket(proto, parseNPacketsFromBuffer(buffer, numPackets))
  def parsePacket(protoPacket: ProtoPacket, buffer: DataBuffer) -> Packet:
      """Decode the body of one packet whose header has already been read.

      Literal packets go to ``parseLiteral``. For operators, a 1-bit
      length-type ID comes next:

      * '0': the next 15 bits give the total bit length of the sub-packets.
      * otherwise: the next 11 bits give the number of sub-packets.
      """
      if protoPacket.type == PacketType.LITERAL:
          return parseLiteral(protoPacket, buffer)

      length_type = buffer.readRaw(1)
      if length_type == "0":
          bit_length = buffer.readInt(15)
          return parseRawOperator(protoPacket, readIntoNewBuffer(buffer, bit_length))

      packet_count = buffer.readInt(11)
      return parseNestedOperator(protoPacket, buffer, packet_count)
  def _readProto(buffer: DataBuffer) -> ProtoPacket:
      """Consume a packet's 6-bit header: 3 version bits, then 3 type bits."""
      # Keyword arguments are evaluated left to right, so the version is read first.
      return ProtoPacket(
          version=buffer.readInt(3),
          type=PacketType(buffer.readInt(3)),
      )
  def parseNPacketsFromBuffer(buffer: DataBuffer, numPackets: int) -> List[Packet]:
      """Parse exactly ``numPackets`` consecutive packets (header + body) from ``buffer``."""
      return [parsePacket(_readProto(buffer), buffer) for _ in range(numPackets)]
  def parseBuffer(buffer: DataBuffer) -> List[Packet]:
      """Parse packets from ``buffer`` until fewer than MIN_PACKET_SIZE bits remain.

      Leftover bits too short to form a packet are left unread.
      """
      parsed: List[Packet] = []
      while True:
          if not buffer.hasData(MIN_PACKET_SIZE):
              return parsed
          header = _readProto(buffer)
          parsed.append(parsePacket(header, buffer))
  def main(buffer: DataBuffer):
      """Decode the outermost packet in ``buffer``, then print it and its value.

      Exactly one outermost packet is parsed. Any bits after it (for example,
      zero bits that only fill out the final hex digit) are ignored. They are
      not decoded as further packets, which for a run of padding zeros would
      show up as spurious empty SUM packets.

      :raises ValueError: if ``buffer`` is too short to hold any packet.
      """
      if not buffer.hasData(MIN_PACKET_SIZE):
          raise ValueError("input is too short to contain a packet")
      packets = parseNPacketsFromBuffer(buffer, 1)
      print(packets)
      print(packets[0].calculate())
  if __name__ == "__main__":
      parser = argparse.ArgumentParser()
      parser.add_argument(
          "ifile",
          type=argparse.FileType('r'),
          help="file whose first line holds the hex-encoded transmission",
      )
      args = parser.parse_args()
      # argparse.FileType opened the file; close it once the single line is read.
      with args.ifile as ifile:
          superpacket = ifile.readline().strip()
      # DataBuffer expands the hex digits into a string of '0'/'1' bits.
      main(DataBuffer(superpacket))