1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
|
from sys import stdin
def index(m, n, i, j, c):
    """Return cell (i, j) of the n x n grid *m* rotated *c* quarter turns clockwise.

    Args:
        m: square grid indexable as ``m[row][col]`` (e.g. a tuple of
            equal-length strings).
        n: side length of *m*.
        i, j: row and column in the *rotated* grid.
        c: number of clockwise quarter turns; any integer is accepted and
            taken modulo 4.

    Returns:
        The element of *m* that lands at (i, j) after the rotation.
    """
    c %= 4
    # Only the lookup for the requested rotation is evaluated (the previous
    # tuple-based form performed all four lookups on every call).
    if c == 0:
        return m[i][j]
    if c == 1:
        return m[n - j - 1][i]
    if c == 2:
        return m[n - i - 1][n - j - 1]
    return m[j][n - i - 1]
def rot(m, c):
    """Yield every anti-diagonal of square grid *m* after *c* clockwise quarter turns.

    The rotated grid's anti-diagonals (cells whose row + column is constant)
    are produced in order of increasing row + column, 2n - 1 strings in all.
    Each string is read from its bottom-left cell up to its top-right cell.
    Calling this for c = 0..3 therefore covers all four diagonal directions.

    Args:
        m: sequence of equal-length strings forming an n x n grid.
        c: number of clockwise quarter turns applied before walking.

    Yields:
        One string per anti-diagonal; nothing at all for an empty grid.

    Raises:
        ValueError: if any row's length differs from the number of rows
            (raised on first iteration, since this is a generator).
    """
    n = len(m)
    if n == 0:
        return  # empty grid has no diagonals (m[0] would not exist)
    # Check every row, not just the first: a ragged grid would otherwise
    # crash inside index() or silently produce wrong diagonals.
    if any(len(row) != n for row in m):
        raise ValueError('rot() requires a square grid')
    # Diagonals starting on the left edge: (i, 0) up to (0, i).
    for i in range(n):
        yield ''.join(index(m, n, i - j, j, c) for j in range(i + 1))
    # Diagonals starting on the bottom edge: (n-1, i) up to (i, n-1).
    for i in range(1, n):
        yield ''.join(index(m, n, n - j - 1, i + j, c) for j in range(n - i))
def reverse_all(strings):
    """Return a lazy iterator over each string in *strings* reversed.

    Used to scan rows and columns in the opposite direction.
    """
    return (s[::-1] for s in strings)
def count(strings, word='XMAS'):
    """Return how many times *word* occurs across all of *strings*.

    Overlapping matches are counted (as a word search requires, e.g. "AA"
    occurs twice in "AAA"). For the default 'XMAS', which cannot overlap
    itself, this equals summing ``str.count``.

    Args:
        strings: iterable of strings to scan (consumed once).
        word: the substring to look for; defaults to 'XMAS'.

    Returns:
        Total number of (possibly overlapping) occurrences.
    """
    total = 0
    for s in strings:
        pos = s.find(word)
        while pos != -1:
            total += 1
            # restart one character later so overlapping matches are found
            pos = s.find(word, pos + 1)
    return total
# Read the grid, one row per line. Blank lines (e.g. an extra empty line at
# the end of the input) are dropped: an empty row would make zip() below
# truncate every column to nothing and make rot() reject the grid.
lines = tuple(filter(None, map(str.strip, stdin.readlines())))
# Transpose the rows so vertical words can be scanned as strings.
columns = tuple(''.join(column) for column in zip(*lines))
# rot() walks a single diagonal direction; the four quarter turns cover all
# four diagonal directions. Skip it for an empty grid.
diagonals = sum(count(rot(lines, c)) for c in range(4)) if lines else 0
print(count(lines) + count(reverse_all(lines))
      + count(columns) + count(reverse_all(columns))
      + diagonals)
|