
I was looking through the practice problems on CodeChef and found this one. I am new to Python. I wrote the code below in Python 3, but I keep getting a 'Time Limit Exceeded' error. I am looking for a way to optimise it.

The problem statement is as follows:

Given n and m, calculate 1^1 + 2^2 + 3^3 + ... + n^n modulo m.

Input: The first line contains t (1 ≤ t ≤ 10), the number of test cases. Then the test cases follow. Each test case consists of two integers n and m, with 1 ≤ n ≤ 10^18 and 1 ≤ m ≤ 200000.

Example

Input:
6
1 100000
2 100000
3 100000
4 100000
5 100000
6 100000

Output:
1
5
32
288
3413
50069
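For checking answers on small inputs, a minimal brute-force sketch can help; `brute_force` is a hypothetical name, not part of the problem. It uses Python's three-argument `pow(i, i, m)`, which reduces modulo m during exponentiation instead of building the full i**i, but it still loops over every i, so it cannot handle n anywhere near 10^18.

```python
def brute_force(n, m):
    """Return (1^1 + 2^2 + ... + n^n) % m by direct summation (hypothetical helper)."""
    total = 0
    for i in range(1, n + 1):
        # pow(i, i, m) keeps every intermediate value below m
        total = (total + pow(i, i, m)) % m
    return total


# Reproduces the sample output above (m = 100000).
print([brute_force(n, 100000) for n in range(1, 7)])  # [1, 5, 32, 288, 3413, 50069]
```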

And here is my code:

t = int(input())
for j in range(1, t + 1):
    ans = 0
    n, m = [int(x) for x in input().split()]
    for i in range(1, n + 1):
        # pow(i, i) builds the full integer i**i before the % m is applied
        ans = (ans + pow(i, i)) % m
    print(ans)

Thank you.

Shubham

2 Answers


Solution in O(m log n):

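def geometric(n, b, m):
    # Returns (1 + b + b^2 + ... + b^(n-1)) % m in O(log n) steps.
    T = 1        # 1 + b + ... + b^(2^j - 1) mod m, for the current bit j of n
    e = b % m    # b^(2^j) mod m
    total = 0    # 1 + b + ... + b^(L-1) mod m, where L = low bits of n seen so far
    while n > 0:
        if n & 1 == 1:
            total = (e * total + T) % m
        T = ((e + 1) * T) % m
        e = (e * e) % m
        n = n // 2
    return total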

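def efficient_solve(n, m):
    ans = 0
    # Group x holds every i <= n with i ≡ x (mod m); there are min(n, m) groups.
    for x in range(1, min(n, m) + 1):
        k = pow(x, m, m)   # ratio between consecutive terms of the group
        s = pow(x, x, m)   # first term of the group

        times = (n // m) + (x <= n % m)   # number of terms in the group

        ans += s * geometric(times, k, m)
        ans = ans % m
    return ans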

geometric(n, b, m) computes the geometric series (1 + b + b^2 + ... + b^(n-1)) % m. It is taken from https://stackoverflow.com/a/42033401/3308055
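A quick way to convince yourself of what `geometric` returns is to compare it with a direct sum of powers. In the sketch below, `naive_geometric` is a hypothetical helper added only for the comparison. The loop keeps T = 1 + b + ... + b^(2^j - 1) and e = b^(2^j) (both mod m) for the current bit j of n, and folds T into `total` whenever that bit is set, so it needs only O(log n) multiplications.

```python
def geometric(n, b, m):
    # Same function as in the answer: (1 + b + ... + b^(n-1)) % m.
    T = 1        # sum of the first 2^j powers of b, mod m
    e = b % m    # b^(2^j) mod m
    total = 0
    while n > 0:
        if n & 1 == 1:
            total = (e * total + T) % m
        T = ((e + 1) * T) % m
        e = (e * e) % m
        n = n // 2
    return total


def naive_geometric(n, b, m):
    # Hypothetical reference: add the n powers b^0 .. b^(n-1) one at a time.
    return sum(pow(b, j, m) for j in range(n)) % m


for n in range(0, 50):
    for b in range(0, 10):
        assert geometric(n, b, 97) == naive_geometric(n, b, 97)
print("geometric matches the direct sum")
```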

Explanation

n is too large to loop over, so we need a way to compute the contribution of many terms in a single operation.

Write x = mk + i with 0 ≤ i < m. Then x ^ x % m = (mk + i) ^ (mk + i) % m.

(mk + i) ^ (mk + i) % m = ((mk + i) * (mk + i) * ... * (mk + i)) % m, a product of (mk + i) factors.

If we expand that product, every term except one contains at least one factor of mk, and mk * (anything) % m is 0.

The only term without an mk factor is i * i * ... * i, again with (mk + i) factors; that is, i^(mk + i). So x^x % m = i^(mk + i) % m: the base can be reduced modulo m, but the exponent cannot.
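The property can be spot-checked with the built-in three-argument `pow`. This sketch (with a hypothetical helper `reduced_base_matches`) confirms that reducing the base modulo m, while keeping the full exponent, leaves x^x % m unchanged:

```python
def reduced_base_matches(x, m):
    """Check that x^x and (x % m)^x agree modulo m (hypothetical helper)."""
    k, i = divmod(x, m)  # x = m*k + i with 0 <= i < m
    return pow(x, x, m) == pow(i, x, m)


print(all(reduced_base_matches(x, m) for m in range(1, 30) for x in range(1, 200)))  # True
```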

So if n = 5 and m = 3, instead of computing (1^1 + 2^2 + 3^3 + 4^4 + 5^5) % 3, we can compute (1 ^ (0 + 1) + 2 ^ (0 + 2) + 0 ^ (3 + 0) + 1 ^ (3 + 1) + 2 ^ (3 + 2)) % 3.
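The same rewrite, done numerically for n = 5 and m = 3 (a minimal sketch using exact integers, which is fine because the numbers are small):

```python
n, m = 5, 3
original = sum(pow(x, x) for x in range(1, n + 1))
# Replace each base x by its residue x % m, keeping the full exponent x.
rewritten = sum(pow(x % m, x) for x in range(1, n + 1))
print(original, rewritten, original % m, rewritten % m)  # 3413 38 2 2
```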

This is better, but we still need O(n) operations. Let's group some of these terms. We group them by the residue i = x % m; with m = 3 there are 3 groups:

  • 1 ^ (0 + 1) + 1 ^ (3 + 1)
  • 2 ^ (0 + 2) + 2 ^ (3 + 2)
  • 0 ^ (3 + 0)
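Grouping the rewritten terms of the n = 5, m = 3 example by residue can be sketched with a `defaultdict`; each stored term is already reduced modulo m:

```python
from collections import defaultdict

n, m = 5, 3
groups = defaultdict(list)
for x in range(1, n + 1):
    # Term x contributes (x % m)^x; file it under its residue x % m.
    groups[x % m].append(pow(x % m, x, m))

for residue in sorted(groups):
    print(residue, groups[residue])

total = sum(sum(terms) for terms in groups.values()) % m
print(total)  # 2
```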

How can we compute each group efficiently? Within a group the base is the same, and the exponent grows by m from one term to the next. If we know the first term (1 ^ (0 + 1)), how does the next term (1 ^ (3 + 1)) relate to it modulo m?

1 ^ (3 + 1) % m = (1 ^ 1 % m) * (1 ^ 3 % m) % m. If n were larger and the group also contained 1 ^ (6 + 1), then 1 ^ (6 + 1) % m = (1 ^ 1 % m) * (1 ^ 3 % m) * (1 ^ 3 % m) % m. So each following term in the same group is the previous term multiplied by (1 ^ 3 % m). More generally, each term is the previous one multiplied by base ^ m % m.
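The multiply-by-(base ^ m % m) recurrence can be checked against direct computation. `group_terms` is a hypothetical helper that generates the first few terms of one group using only the recurrence:

```python
def group_terms(x, m, count):
    """First `count` terms of group x, built with the multiply-by-k recurrence."""
    k = pow(x, m, m)      # x^m % m: ratio between consecutive terms
    term = pow(x, x, m)   # first term: x^x % m
    terms = []
    for _ in range(count):
        terms.append(term)
        term = term * k % m
    return terms


x, m = 3, 7
# Direct values (x + j*m)^(x + j*m) % m for the same group.
direct = [pow(x + j * m, x + j * m, m) for j in range(5)]
print(group_terms(x, m, 5) == direct)  # True
```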

How many terms does each group have? Write n = q*m + r with q = n // m and r = n % m. The group with base x (1 ≤ x ≤ m) contains x, x + m, x + 2m, ... up to n, which is q + 1 terms if x ≤ r and q terms otherwise. That is, times = (n // m) + (x <= n % m).
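A brute-force check of the `times` formula, under the assumption 1 ≤ x ≤ m used in the answer's loop; `group_size` is a hypothetical name for the expression:

```python
def group_size(x, n, m):
    # Number of integers in 1..n congruent to x (mod m), assuming 1 <= x <= m.
    # (x <= n % m) is a bool, which Python adds as 0 or 1.
    return (n // m) + (x <= n % m)


n, m = 17, 5
for x in range(1, m + 1):
    members = [v for v in range(1, n + 1) if v % m == x % m]
    assert group_size(x, n, m) == len(members)
print([group_size(x, n, m) for x in range(1, m + 1)])  # [4, 4, 3, 3, 3]
```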

Let's call x the index of the group, which is also the base of the exponents (the group x = m holds the multiples of m, whose base is 0 modulo m). There are min(n, m) groups.

Let's call k the result of x ^ m % m, and s the result of x ^ x % m.

The sum of all the terms in this group (modulo m) is then:

s + s * k + s * k^2 + s * k^3 + ... + s * k^(times - 1)

This is equivalent to:

s * (1 + k + k^2 + k^3 + ... + k^(times - 1))

That is a geometric series, and geometric() computes it in O(log times) steps without any division, so it works even when k - 1 has no inverse modulo m. With this, we have everything we need to compute the answer to the problem.
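Putting the pieces together, here is one possible sketch of a complete program. `geometric` and `efficient_solve` follow the answer's code; `solve_input` is a hypothetical driver for the input format in the question (t, then t pairs n m).

```python
def geometric(n, b, m):
    # (1 + b + b^2 + ... + b^(n-1)) % m, using the binary digits of n.
    T, e, total = 1, b % m, 0
    while n > 0:
        if n & 1:
            total = (e * total + T) % m
        T = (e + 1) * T % m
        e = e * e % m
        n >>= 1
    return total


def efficient_solve(n, m):
    # Sum over groups x = 1..min(n, m) of s * (1 + k + ... + k^(times - 1)).
    ans = 0
    for x in range(1, min(n, m) + 1):
        k = pow(x, m, m)                  # ratio between terms in group x
        s = pow(x, x, m)                  # first term of group x
        times = (n // m) + (x <= n % m)   # number of terms in group x
        ans = (ans + s * geometric(times, k, m)) % m
    return ans


def solve_input(text):
    # Hypothetical driver: first token is t, then t pairs "n m".
    data = text.split()
    t = int(data[0])
    out = []
    for idx in range(t):
        n, m = int(data[1 + 2 * idx]), int(data[2 + 2 * idx])
        out.append(str(efficient_solve(n, m)))
    return "\n".join(out)


sample = "6\n1 100000\n2 100000\n3 100000\n4 100000\n5 100000\n6 100000\n"
print(solve_input(sample))
```

On the judge you would read standard input instead, for example with `import sys` and `print(solve_input(sys.stdin.read()))`. The work is roughly O(min(n, m) log n) per test case; whether that fits CodeChef's time limit in Python is not guaranteed.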

juvian

A numpy solution (it builds arrays of length n, so it is only practical for moderate n, not for n up to 10^18):

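import numpy as np

n = 6
m = 100000
x = np.arange(1, n + 1, dtype=np.int64)
# x**x % m by vectorised repeated squaring (np.power(x, x) overflows int64 for x >= 16)
base, exp, power = x % m, x.copy(), np.ones_like(x)
while exp.any():
    odd = (exp & 1) == 1
    power[odd] = power[odd] * base[odd] % m
    base = base * base % m
    exp >>= 1
result = np.cumsum(power) % m  # result[i] is the answer for n = i + 1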
bobrobbob