"""Decode scrambled seven-segment display entries and sum their output values.

Each input line holds space-separated signal patterns (scrambled wire letters),
a ``|`` separator, then the space-separated output patterns to decode.
"""
import argparse
import functools
from typing import Callable, Dict, Iterable, List, NamedTuple, Set


class Pairing(NamedTuple):
    """One input entry: the observed signal patterns and the output patterns."""

    signals: List[Set[str]]  # each unique pattern as a set of scrambled wire letters
    digits: List[str]  # output patterns, most significant digit first


# Canonical (unscrambled) lit segments for each decimal digit.
SEGMENTS_TO_INT = {
    frozenset('abcefg'): 0,
    frozenset('cf'): 1,
    frozenset('acdeg'): 2,
    frozenset('acdfg'): 3,
    frozenset('bcdf'): 4,
    frozenset('abdfg'): 5,
    frozenset('abdefg'): 6,
    frozenset('acf'): 7,
    frozenset('abcdefg'): 8,
    frozenset('abcdfg'): 9,
}


def parseLine(line: str) -> Pairing:
    """Parse one ``patterns | outputs`` line into a Pairing.

    ``str.split()`` (no argument) is used so repeated or trailing whitespace
    never yields empty patterns.

    Raises:
        ValueError: if the line has no ``|`` separator.
    """
    patterns, sep, outputs = line.partition('|')
    if not sep:
        raise ValueError(f"missing '|' separator in line: {line!r}")
    return Pairing([set(sig) for sig in patterns.split()], outputs.split())


def _locateUnique(predicate: Callable[[Set[str]], bool], length: int,
                  signals: Iterable[Set[str]]) -> Set[str]:
    """Return the single pattern with ``length`` segments satisfying ``predicate``.

    Raises:
        ValueError: if zero or several patterns match (malformed entry).
    """
    candidates = [sig for sig in signals if len(sig) == length and predicate(sig)]
    if len(candidates) != 1:
        raise ValueError(
            f"expected exactly one {length}-segment pattern, found {len(candidates)}")
    return candidates[0]


def locateByLengthAndIntersect(intersectWith: Set[str], length: int,
                               signals: Iterable[Set[str]]) -> Set[str]:
    """Return the unique ``length``-segment pattern containing all of ``intersectWith``.

    Raises:
        ValueError: if the match is missing or ambiguous.
    """
    return _locateUnique(intersectWith.issubset, length, signals)


def locateByLengthAndDisjoint(disjointWith: Set[str], length: int,
                              signals: Iterable[Set[str]]) -> Set[str]:
    """Return the unique ``length``-segment pattern sharing nothing with ``disjointWith``.

    Raises:
        ValueError: if the match is missing or ambiguous.
    """
    return _locateUnique(disjointWith.isdisjoint, length, signals)


def invertDictionary(mapping: Dict) -> Dict:
    """Return a new dict with keys and values swapped (values must be hashable)."""
    return {value: key for key, value in mapping.items()}


def computeLinks(signals: List[Set[str]]) -> Dict[str, str]:
    """Deduce which scrambled wire drives each real segment.

    Args:
        signals: the unique patterns of one entry. The list is NOT modified.

    Returns:
        A dict mapping each real segment letter ('a'..'g') to the scrambled
        wire letter that lights it.

    Raises:
        ValueError: if the patterns do not identify the needed digits uniquely.
    """
    # Work on a private copy: identified patterns are removed so that later,
    # looser searches stay unambiguous, and the caller's list must stay intact.
    remaining = [set(sig) for sig in signals]

    def claim(sig: Set[str]) -> Set[str]:
        """Remove an identified pattern from the search pool and return it."""
        remaining.remove(sig)
        return sig

    def anyPattern(sig: Set[str]) -> bool:
        return True

    # 1, 7 and 4 are the only digits with 2, 3 and 4 lit segments.
    # (8 lights every segment, so it carries no information and is skipped.)
    one = claim(_locateUnique(anyPattern, 2, remaining))
    seven = claim(_locateUnique(anyPattern, 3, remaining))
    four = claim(_locateUnique(anyPattern, 4, remaining))

    link: Dict[str, str] = {}
    # 'a' is the only segment of 7 that 1 lacks.
    link['a'] = (seven - one).pop()

    # 3 is the only 5-segment digit containing both segments of 1.
    three = claim(locateByLengthAndIntersect(one, 5, remaining))
    # 3 and 4 share c, d, f; removing 1's c and f leaves d.
    link['d'] = ((three & four) - one).pop()

    # Among 6-segment digits only 0 lacks 'd'; remove it so 1 ⊆ pattern picks 9.
    claim(locateByLengthAndDisjoint({link['d']}, 6, remaining))
    # 9 is the remaining 6-segment digit containing 1; remove it so 6 is unique.
    claim(locateByLengthAndIntersect(one, 6, remaining))
    six = claim(locateByLengthAndIntersect({link['d']}, 6, remaining))
    # 5 is the only remaining 5-segment digit entirely contained in 6.
    five = claim(_locateUnique(lambda sig: sig <= six, 5, remaining))

    link['e'] = (six - five).pop()
    link['f'] = (one & five).pop()
    link['c'] = (one - {link['f']}).pop()
    link['b'] = (five - three).pop()
    link['g'] = (five - four - seven).pop()
    return link


def calculateDisplay(digits: List[str], mapping: Dict[str, str]) -> int:
    """Decode output patterns into the number they display.

    Args:
        digits: output patterns, most significant first; any count is accepted.
        mapping: scrambled wire letter -> real segment letter.

    Returns:
        The decoded base-10 value (0 for an empty list).

    Raises:
        KeyError: if a pattern uses an unmapped wire or is not a valid digit.
    """
    values = (SEGMENTS_TO_INT[frozenset(mapping[wire] for wire in pattern)]
              for pattern in digits)
    return functools.reduce(lambda acc, value: acc * 10 + value, values, 0)


def main() -> None:
    """Read the puzzle file named on the command line and print the total."""
    parser = argparse.ArgumentParser()
    parser.add_argument("ifile", type=argparse.FileType('r'))
    args = parser.parse_args()
    with args.ifile as ifile:
        lines = [parseLine(line) for line in ifile if line.strip()]
    print(sum(calculateDisplay(pairing.digits, invertDictionary(computeLinks(pairing.signals)))
              for pairing in lines))


if __name__ == "__main__":
    main()