Solution in O(m log n):
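def geometric(n, b, m):
    # Returns (1 + b + b^2 + ... + b^(n - 1)) % m in O(log n) steps.
    T = 1
    e = b % m
    total = 0
    while n > 0:
        if n & 1 == 1:
            total = (e * total + T) % m
        T = ((e + 1) * T) % m
        e = (e * e) % m
        n = n // 2
    return total

def efficient_solve(n, m):
    # Returns (1^1 + 2^2 + ... + n^n) % m.
    ans = 0
    for x in range(1, min(n, m) + 1):
        k = pow(x, m, m)  # ratio between consecutive terms of group x
        s = pow(x, x, m)  # first term of group x
        times = (n // m) + (x <= n % m)  # number of terms in group x
        ans += s * geometric(times, k, m)
        ans = ans % m
    return ans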
geometric(n, b, m) returns (1 + b + b^2 + ... + b^(n - 1)) % m, that is, the geometric series modulo m. It is taken from https://stackoverflow.com/a/42033401/3308055
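As a sanity check on what geometric is expected to return, here is a minimal reference sketch that adds the powers one at a time. The name geometric_naive is hypothetical (it is not part of the linked answer), and it takes O(n) steps, so it is only meant for checking small cases.

```python
def geometric_naive(n, b, m):
    """Return (1 + b + b^2 + ... + b^(n - 1)) % m by direct summation."""
    total = 0
    power = 1 % m  # b^0, reduced so that m == 1 also works
    for _ in range(n):
        total = (total + power) % m
        power = (power * b) % m
    return total


print(geometric_naive(3, 2, 7))  # 1 + 2 + 4 = 7, and 7 % 7 = 0
```

For small n, its result should agree with geometric(n, b, m).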
Explanation
n can be too large to evaluate every term one by one, so we need a way to compute many terms of the sum in a single operation.
Write x = mq + i with 0 <= i < m, so that x ^ x % m = (mq + i) ^ (mq + i) % m.
(mq + i) ^ (mq + i) = (mq + i) * (mq + i) * ... * (mq + i), with (mq + i) factors.
If we expand that product, every term except one contains at least one factor mq, and mq * anything % m is 0.
The only term without an mq factor is i * i * ... * i, with (mq + i) factors. That is, i ^ (mq + i).
Hence x ^ x % m = i ^ (mq + i) % m.
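The reduction above can be checked numerically on small values. This is only an illustrative sketch; reduced_power is a hypothetical helper name, and in practice pow(x, x, m) gives the same value directly.

```python
def reduced_power(x, m):
    """Compute x^x % m using only the remainder i = x % m as the base."""
    i = x % m
    return pow(i, x, m)


# Compare against direct evaluation of x^x for small values.
for m in range(1, 20):
    for x in range(1, 200):
        assert reduced_power(x, m) == (x ** x) % m

print(reduced_power(5, 3))  # 5 % 3 = 2, and 2^5 % 3 = 32 % 3 = 2
```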
So if n = 5 and m = 3, instead of solving (1^1 + 2^2 + 3^3 + 4^4 + 5^5) % 3 we can solve (1 ^ (0 + 1) + 2 ^ (0 + 2) + 0 ^ (3 + 0) + 1 ^ (3 + 1) + 2 ^ (3 + 2)) % 3.
This is good, but we still need O(n) operations. Let's group these terms by their base i = x % m. Here we have 3 groups:
1 ^ (0 + 1) + 1 ^ (3 + 1)
2 ^ (0 + 2) + 2 ^ (3 + 2)
0 ^ (3 + 0)
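To make the grouping concrete, the following sketch lists, for each base i = x % m, the exponents that share it. The function name group_terms is hypothetical and the loop is O(n), so it only illustrates the idea; it is not part of the fast solution.

```python
from collections import defaultdict


def group_terms(n, m):
    """Map each base i = x % m to the list of exponents x in 1..n that share it."""
    groups = defaultdict(list)
    for x in range(1, n + 1):
        groups[x % m].append(x)
    return dict(groups)


print(group_terms(5, 3))  # {1: [1, 4], 2: [2, 5], 0: [3]}
```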
How can we calculate the result of each group efficiently?
Note that within each group the base is the same and the exponent increases by m from one term to the next. If we know the first term (1 ^ (0 + 1)), how does the next term (1 ^ (3 + 1)) relate to it modulo m?
1 ^ (3 + 1) % m = (1 ^ 1 % m) * (1 ^ 3 % m) % m. If n were larger and the group also contained 1 ^ (6 + 1), then 1 ^ (6 + 1) % m = (1 ^ 1 % m) * (1 ^ 3 % m) * (1 ^ 3 % m) % m. So each following term in the same group is the previous one multiplied by (1 ^ 3 % m). More generally, we multiply by base ^ m % m.
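As a small illustration of that step, the sketch below builds the terms of one group by repeated multiplication. The helper name group_terms_mod and its parameters are assumptions made for this example only.

```python
def group_terms_mod(base, first_exp, count, m):
    """Return the first `count` terms base^(first_exp + j*m) % m of a group,
    deriving each term from the previous one by multiplying by base^m % m."""
    step = pow(base, m, m)
    term = pow(base, first_exp, m)
    terms = []
    for _ in range(count):
        terms.append(term)
        term = (term * step) % m
    return terms


# Group with base 2 for m = 3: exponents 2, 5, 8, 11.
print(group_terms_mod(2, 2, 4, 3))  # [1, 2, 1, 2]
```

Each entry matches pow(2, e, 3) for e = 2, 5, 8, 11.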
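How many terms does each group have? Group x contains x, x + m, x + 2m, ... up to n: one term for every complete block of m numbers in 1..n, plus one more if x falls inside the final partial block. That is, times = (n // m) + (x <= n % m), where the comparison counts as 1 when it is true.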
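The counting formula can be compared against a direct count. This is a sketch for checking small inputs; count_terms and count_terms_naive are hypothetical names, and 1 <= x <= m is assumed.

```python
def count_terms(x, n, m):
    """Number of values in 1..n of the form x + j*m, assuming 1 <= x <= m."""
    return (n // m) + (x <= n % m)  # the bool adds as 0 or 1


def count_terms_naive(x, n, m):
    """Same count, obtained by enumerating x, x + m, x + 2m, ... up to n."""
    return len(range(x, n + 1, m))


for n in range(1, 40):
    for m in range(1, 12):
        for x in range(1, m + 1):
            assert count_terms(x, n, m) == count_terms_naive(x, n, m)

print(count_terms(1, 5, 3), count_terms(3, 5, 3))  # 2 1
```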
Let's call x the index of the group, which is also the base of every power in it. There are min(n, m) groups. Let's call k the result of x ^ m % m, and s the result of x ^ x % m.
The sum of all the terms in this group is:
s + s * k + s * k^2 + s * k^3 + ... + s * k^(times - 1)
This is equivalent to:
s * (1 + k + k^2 + k^3 + ... + k^(times - 1))
And there we have a geometric series, which geometric computes in O(log times) steps. With this, we have everything we need to calculate the answer to the problem.
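To validate the grouped approach on small inputs, a direct summation can serve as a reference. This is a minimal sketch; brute_force is a hypothetical name, and it needs O(n) modular exponentiations, so it is only practical when n is small.

```python
def brute_force(n, m):
    """Return (1^1 + 2^2 + ... + n^n) % m by evaluating one term at a time."""
    return sum(pow(x, x, m) for x in range(1, n + 1)) % m


print(brute_force(5, 3))  # (1 + 4 + 27 + 256 + 3125) % 3 = 2
```

For small n, brute_force(n, m) should match efficient_solve(n, m).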