[题解] JXOI2018 守卫

题目描述

输入一个整数 $n$ 和一个整数 $p$ ,你需要求出 $\sum_{i=1}^n \sum_{j=1}^n ij \gcd(i,j) \bmod p$,其中 $\gcd(a, b)$ 表示 $a$ 与 $b$ 的最大公约数

输入格式

一行两个整数 $p$ 和 $n$

输出格式

一行一个整数,表示答案

样例输入

998244353 2000

样例输出

883968974

数据范围及提示

对于 $20\%$的数据,$n \leq 1000​$。

对于 $30\%$ 的数据,$n \leq 5000$。

对于 $60\%$ 的数据,$n \leq 10^6$,时限1s。

对于另外 $20\%$ 的数据,$n \leq 10^9$,时限3s。

对于最后 $20\%$ 的数据,$n \leq 10^{10}$,时限6s。

对于 $100\%$ 的数据,$5 \times 10^8 \leq p \leq 1.1 \times 10^9$ 且 $p$ 为质数。

解题思路

$$
\begin{aligned}
& \sum_{i=1}^n \sum_{j=1}^n ij \gcd(i,j)\
= & \sum_{i=1}^n \sum_{j=1}^n ij \sum_{d \mid i} \sum_{d \mid j} \boldsymbol{\varphi} (d) \
= & \sum_{d=1}^n \boldsymbol{\varphi}(d) \sum_{d \mid i} \sum_{d \mid j} ij \
= & \sum_{d=1}^n \boldsymbol{\varphi}(d) d^2 \left( \sum_{i=1}^{n/k} i \right)^2 \
= & \sum_{d=1}^n \boldsymbol{\varphi}(d) d^2 \sum_{i=1}^{n/k} i^3
\end{aligned}
$$

其中最后一步是由
$$
\left( \sum_{i=1}^n i \right)^2 = \sum_{i=1}^n i^3
$$
得到的,同时
$$
\sum_{i=1}^n i^2 = \frac{n(n+1)(2n+1)}{6}
$$
所以我们可以通过杜教筛筛出 $\sum_{d=1}^n \boldsymbol{\varphi}(d) d^2$ 的值,然后使用整除分块,在 $O(n^{\frac{5}{6}})$ 的时间内算出答案

代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>

using namespace std;
const int MAXN = 10000000 + 10;
const int MAXM = 2500 + 10;
const int SIZ = 10000000;

long long n, MOD, INV;
long long pri[MAXN], cnt;
long long phi[MAXN], sum1[MAXN];
bool isNotPrime[MAXN];

inline long long mul(long long a, long long b) {return a * b % MOD;}
inline long long add(long long a, long long b) {return (a + b) % MOD;}

inline long long powx(long long a, long long b) {
    long long ans = 1;
    for (; b; b >>= 1) {
        if (b & 1) ans = mul(ans, a);
        a = mul(a, a);
    }
    return ans;
}

inline void init() {
    INV = powx(6, MOD - 2);
    phi[1] = 1;
    for (register int i=2; i<=SIZ; i++) {
        if (!isNotPrime[i]) {
            pri[++cnt] = i;
            phi[i] = i - 1;
        }

        for (register int j=1; j<=cnt; j++) {
            long long m = i * pri[j];
            if (m > SIZ) break;
            isNotPrime[m] = true;

            if (i % pri[j] == 0) {
                phi[m] = phi[i] * pri[j];
                break;
            } else {
                phi[m] = phi[i] * (pri[j] - 1);
            }
        }
    }

    for (register int i=1; i<=SIZ; i++) {
        sum1[i] = add(sum1[i-1], (mul(phi[i], mul(i, i))));
    }
}

inline long long calc2(long long x) {
    x %= MOD;
    return mul(mul(mul(x, x+1), 2*x+1), INV);
}

inline long long calc3(long long x) {
    x %= MOD;
    return mul((x * (x + 1) / 2) % MOD, (x * (x + 1) / 2) % MOD);
}

long long sum2[MAXM];
bool vis[MAXM];

inline long long calc(long long x) {
    if (x < SIZ) return sum1[x];

    int k = n / x;
    if (vis[k]) return sum2[k];

    long long &ans = sum2[k];
    long long r = 0;

    vis[k] = true;
    ans = calc3(x);

    for (register long long l=2; l<=x; l=r+1) {
        r = x / (x / l);
        ans = (ans - ((((calc2(r) - calc2(l-1)) % MOD + MOD) % MOD ) * calc(x / l) % MOD)) % MOD;
        if (ans < 0) ans += MOD;
    }

    return ans;
}

inline long long work () {
    long long ans = 0;
    long long r = 0;

    for (register long long l=1; l<=n; l=r+1){
        r = n / (n / l);
        ans = (ans + ((calc(r) - calc(l-1) + MOD) % MOD * calc3(n / l)) % MOD) % MOD; 
    }

    return ans;
}

int main(){
    scanf("%lld%lld", &MOD, &n);
    init();
    printf("%lld", work());
    return 0;
}