part2.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. import argparse
  2. import functools
  3. from typing import NamedTuple, List, Dict, Set
  # Command-line interface: a single positional argument naming the puzzle
  # input; argparse opens it for reading and stores the file object in args.ifile.
  parser = argparse.ArgumentParser()
  parser.add_argument(
      "ifile",
      type=argparse.FileType(mode="r"),
  )
  args = parser.parse_args()
  class Pairing(NamedTuple):
      """One line of puzzle input, split at its ``|`` separator.

      ``signals`` are the scrambled wire patterns listed before the bar, each
      stored as a set of wire letters. ``digits`` are the patterns after the
      bar, kept as strings; these are the ones that get decoded.
      """

      signals: List[Set[str]]  # patterns before the '|', as sets of wire letters
      digits: List[str]  # patterns after the '|', most significant first
  # Lit-segment set of each decimal digit on an unscrambled display. Keys are
  # frozensets so a pattern's letters can appear in any order.
  SEGMENTS_TO_INT = {
      frozenset(lit_segments): digit
      for digit, lit_segments in enumerate((
          'abcefg',   # 0
          'cf',       # 1
          'acdeg',    # 2
          'acdfg',    # 3
          'bcdf',     # 4
          'abdfg',    # 5
          'abdefg',   # 6
          'acf',      # 7
          'abcdefg',  # 8
          'abcdfg',   # 9
      ))
  }
  # Parse the input. Each non-blank line is "<patterns> | <output patterns>",
  # with patterns separated by whitespace. The `with` block closes the file
  # that argparse opened once every line has been read.
  lines: List[Pairing] = []
  with args.ifile as input_file:
      for raw_line in input_file:
          if not raw_line.strip():
              continue  # tolerate blank lines, e.g. a trailing empty line
          # partition splits once and reports whether the bar was present.
          patterns, separator, outputs = raw_line.partition('|')
          if not separator:
              raise ValueError(f"input line has no '|' separator: {raw_line!r}")
          lines.append(
              Pairing([set(sig) for sig in patterns.split()], outputs.split())
          )
  def locateByLengthAndIntersect(
      intsersectWith: Set[str], length: int, signals: List[Set[str]]
  ) -> Set[str]:
      """Return the one pattern of ``length`` wires that contains all of ``intsersectWith``.

      Args:
          intsersectWith: wires that must all be present in the pattern. The
              misspelled name is kept so keyword callers keep working.
          length: required number of wires in the pattern.
          signals: candidate patterns, each a set of wire letters.

      Returns:
          The matching element of ``signals`` itself (not a copy).

      Raises:
          ValueError: if no pattern, or more than one pattern, qualifies.
      """
      candidates = [
          sig for sig in signals
          if len(sig) == length and intsersectWith.issubset(sig)
      ]
      # Raise instead of assert: asserts are stripped under `python -O`, which
      # would silently return an arbitrary first candidate (or hit IndexError).
      if len(candidates) != 1:
          raise ValueError(
              f"expected exactly one {length}-wire pattern containing "
              f"{sorted(intsersectWith)}, found {len(candidates)}"
          )
      return candidates[0]
  def locateByLengthAndDisjoint(
      disjointWith: Set[str], length: int, signals: List[Set[str]]
  ) -> Set[str]:
      """Return the one pattern of ``length`` wires sharing no wire with ``disjointWith``.

      Args:
          disjointWith: wires that must all be absent from the pattern.
          length: required number of wires in the pattern.
          signals: candidate patterns, each a set of wire letters.

      Returns:
          The matching element of ``signals`` itself (not a copy).

      Raises:
          ValueError: if no pattern, or more than one pattern, qualifies.
      """
      candidates = [
          sig for sig in signals
          if len(sig) == length and disjointWith.isdisjoint(sig)
      ]
      # Raise instead of assert: asserts are stripped under `python -O`, which
      # would silently return an arbitrary first candidate (or hit IndexError).
      if len(candidates) != 1:
          raise ValueError(
              f"expected exactly one {length}-wire pattern disjoint from "
              f"{sorted(disjointWith)}, found {len(candidates)}"
          )
      return candidates[0]
  def invertDictionary(dict: Dict):
      """Return a new dict that maps each value of the input back to its key.

      Values must be hashable. When several keys share a value, the key seen
      last in iteration order wins. The parameter name shadows the builtin
      ``dict`` inside this function; it is kept for interface compatibility.
      """
      swapped = {}
      for key, value in dict.items():
          swapped[value] = key
      return swapped
  def computeLinks(signals: List[Set[str]]) -> Dict[str, str]:
      """Work out which scrambled wire drives each display segment.

      Digits are identified one at a time by pattern length and by overlap
      with digits already found. Each segment's wire is then read off a set
      difference between two identified digits.

      Args:
          signals: the scrambled patterns observed for one display, each a
              collection of wire letters. The deduction requires the pattern
              of every digit 0-9 to be present exactly once. The list is not
              modified.

      Returns:
          Dict mapping each segment name 'a'..'g' to the wire letter that
          drives it.

      Raises:
          ValueError: if the 2-, 3- or 4-wire pattern (digits 1, 7, 4) or the
              pattern chosen for digit 5 is missing or not unique. The locate
              helpers apply their own uniqueness checks to the other digits.
      """
      # Private copy: identified patterns are removed as the search narrows,
      # and that removal must not leak into the caller's list.
      remaining = [set(sig) for sig in signals]

      def take_only(candidates: List[Set[str]], label: str) -> Set[str]:
          # Exactly one pattern may match; claim it so later searches skip it.
          if len(candidates) != 1:
              raise ValueError(
                  f"expected exactly one pattern for digit {label}, "
                  f"found {len(candidates)}"
              )
          remaining.remove(candidates[0])
          return candidates[0]

      def take_by_length(length: int, label: str) -> Set[str]:
          # Claim the unique remaining pattern with the given wire count.
          return take_only([sig for sig in remaining if len(sig) == length], label)

      # 1, 7 and 4 are the only digits drawn with 2, 3 and 4 segments.
      one = take_by_length(2, "1")
      seven = take_by_length(3, "7")
      four = take_by_length(4, "4")

      segment_to_wire: Dict[str, str] = {}
      # Segment a is lit in 7 but not in 1.
      segment_to_wire['a'] = (seven - one).pop()

      # 3 is the only five-wire digit containing both wires of 1.
      three = locateByLengthAndIntersect(one, 5, remaining)
      remaining.remove(three)
      # 3 and 4 share segments c, d and f; removing 1's wires (c, f) leaves d.
      wire_d = ((three & four) - one).pop()
      segment_to_wire['d'] = wire_d

      # Six-wire digits: 0 is the one without d. Of 6 and 9, only 9 contains
      # both wires of 1. The last six-wire pattern left is 6.
      zero = locateByLengthAndDisjoint({wire_d}, 6, remaining)
      remaining.remove(zero)
      nine = locateByLengthAndIntersect(one, 6, remaining)
      remaining.remove(nine)
      six = locateByLengthAndIntersect({wire_d}, 6, remaining)
      remaining.remove(six)

      # Of the five-wire digits still left (2 and 5), only 5 fits inside 6.
      five = take_only(
          [sig for sig in remaining if len(sig) == 5 and sig <= six], "5"
      )
      segment_to_wire['e'] = (six - five).pop()
      segment_to_wire['f'] = (one & five).pop()
      segment_to_wire['c'] = (one - {segment_to_wire['f']}).pop()
      segment_to_wire['b'] = (five - three).pop()
      segment_to_wire['g'] = (five - four - seven).pop()
      return segment_to_wire
  def calculateDisplay(digits: List[str], mapping: Dict[str, str]) -> int:
      """Decode scrambled output patterns into the number they display.

      Args:
          digits: output patterns, most significant first; each is a string
              of scrambled wire letters. Any number of patterns is accepted.
          mapping: scrambled wire letter -> canonical segment letter.

      Returns:
          The decimal number spelled by the decoded patterns (0 for none).

      Raises:
          KeyError: if a wire has no entry in ``mapping``, or the decoded
              segment set is not a digit in ``SEGMENTS_TO_INT``.
      """
      value = 0
      for pattern in digits:
          segments = frozenset(mapping[wire] for wire in pattern)
          # Horner's rule: shift the running value one decimal place and add
          # the new digit. This avoids hard-coding the number of digits.
          value = value * 10 + SEGMENTS_TO_INT[segments]
      return value
  # Decode every display and print the sum of the numbers they show.
  print(sum(
      calculateDisplay(pairing.digits, invertDictionary(computeLinks(pairing.signals)))
      for pairing in lines
  ))