part1.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. import argparse
  2. import functools
  3. from typing import List
  4. import binascii
  5. from dataclasses import dataclass
  6. from enum import Enum
  7. from typing import Protocol
  8. def hexToPaddedBin(hex: str):
  9. return "".join([bin(int(char, 16))[2:].zfill(4) for char in hex])
  10. class DataBuffer:
  11. def __init__(self, data_str, already_binary=False):
  12. if already_binary:
  13. self.data = data_str
  14. else:
  15. self.data = hexToPaddedBin(data_str)
  16. def readInt(self, num_bits) -> int:
  17. ret_val = int(self.data[0:num_bits], 2)
  18. self.data = self.data[num_bits:]
  19. return ret_val
  20. def readRaw(self, num_bits) -> str:
  21. ret_val = self.data[0:num_bits]
  22. self.data = self.data[num_bits:]
  23. return ret_val
  24. def hasData(self, bits=0) -> bool:
  25. return len(self.data) >= bits
  26. def peek(self) -> str:
  27. return self.data[0]
  28. def pop(self) -> str:
  29. ret_val = self.data[0]
  30. self.data = self.data[1:]
  31. return ret_val
  32. def __str__(self) -> str:
  33. return self.data
  34. def readIntoNewBuffer(buffer: DataBuffer, num_bits: int) -> DataBuffer:
  35. payload = buffer.readRaw(num_bits)
  36. return DataBuffer(payload, True)
  37. MIN_PACKET_SIZE = 6 + 5
  38. class PacketType(Enum):
  39. LITERAL = 4
  40. OPERATOR = 0
  41. @dataclass
  42. class ProtoPacket:
  43. version: int
  44. type: PacketType
  45. class Packet(Protocol):
  46. proto: ProtoPacket
  47. def sumVersions(self) -> int:
  48. raise NotImplemented
  49. def __repr__(self):
  50. return self.__str__()
  51. class LiteralPacket(Packet):
  52. def __init__(self, proto: ProtoPacket, value: int):
  53. self.proto = proto
  54. self.value = value
  55. def sumVersions(self):
  56. return self.proto.version
  57. def __str__(self):
  58. return f"{self.proto.type} v{self.proto.version} = {self.value}"
  59. class OperatorPacket(Packet):
  60. def __init__(self, proto: ProtoPacket, children: List[Packet]):
  61. self.proto = proto
  62. self.children = children
  63. def sumVersions(self) -> int:
  64. return sum([child.sumVersions() for child in self.children] + [self.proto.version])
  65. def __str__(self):
  66. return f"{self.proto.type} v{self.proto.version} # {len(self.children)} [{', '.join([str(child) for child in self.children])}]"
  67. def parseLiteral(proto: ProtoPacket, buffer: DataBuffer) -> Packet:
  68. sub_buffer = ""
  69. while buffer.readRaw(1) != "0":
  70. sub_buffer += buffer.readRaw(4)
  71. sub_buffer += buffer.readRaw(4)
  72. value = int(sub_buffer, 2)
  73. return LiteralPacket(proto, value)
  74. def parseRawOperator(proto: ProtoPacket, buffer: DataBuffer) -> Packet:
  75. # pull children for our operator out of the buffer
  76. children = parseBuffer(buffer)
  77. return OperatorPacket(proto, children)
  78. def parseNestedOperator(proto: ProtoPacket, buffer: DataBuffer, numPackets: int) -> Packet:
  79. children = parseNPacketsFromBuffer(buffer, numPackets)
  80. return OperatorPacket(proto, children)
  81. def parsePacket(protoPacket: ProtoPacket, buffer: DataBuffer) -> Packet:
  82. if protoPacket.type == PacketType.LITERAL:
  83. return parseLiteral(protoPacket, buffer)
  84. else:
  85. if buffer.readRaw(1) == "0":
  86. return parseRawOperator(protoPacket, readIntoNewBuffer(buffer, buffer.readInt(15)))
  87. else:
  88. return parseNestedOperator(protoPacket, buffer, buffer.readInt(11))
  89. def _readProto(buffer: DataBuffer) -> ProtoPacket:
  90. new_ver = buffer.readInt(3)
  91. new_type = PacketType.LITERAL if buffer.readInt(3) == 4 else PacketType.OPERATOR
  92. return ProtoPacket(new_ver, new_type)
  93. def parseNPacketsFromBuffer(buffer: DataBuffer, numPackets: int) -> List[Packet]:
  94. packets = []
  95. for _ in range(0, numPackets):
  96. proto = _readProto(buffer)
  97. packets.append(parsePacket(proto, buffer))
  98. return packets
  99. def parseBuffer(buffer: DataBuffer) -> List[Packet]:
  100. packets: List[Packet] = []
  101. while buffer.hasData(MIN_PACKET_SIZE):
  102. proto = _readProto(buffer)
  103. packets.append(parsePacket(proto, buffer))
  104. return packets
  105. def main(buffer: DataBuffer):
  106. packets = parseBuffer(buffer)
  107. print(packets)
  108. print(sum([packet.sumVersions() for packet in packets]))
  109. if __name__ == "__main__":
  110. parser = argparse.ArgumentParser()
  111. parser.add_argument("ifile", type=argparse.FileType('r'))
  112. args = parser.parse_args()
  113. superpacket = args.ifile.readline().strip()
  114. # convert hex to bin, then trim the leading python binary sig
  115. main(DataBuffer(superpacket))