# [题解] JXOI2018 守卫

### 样例输入

998244353 2000


### 样例输出

883968974


### 解题思路

$$\begin{aligned} & \sum_{i=1}^n \sum_{j=1}^n ij \gcd(i,j) \\ = & \sum_{i=1}^n \sum_{j=1}^n ij \sum_{d \mid i,\ d \mid j} \boldsymbol{\varphi}(d) \\ = & \sum_{d=1}^n \boldsymbol{\varphi}(d) \sum_{d \mid i} \sum_{d \mid j} ij \\ = & \sum_{d=1}^n \boldsymbol{\varphi}(d) d^2 \left( \sum_{i=1}^{\lfloor n/d \rfloor} i \right)^2 \\ = & \sum_{d=1}^n \boldsymbol{\varphi}(d) d^2 \sum_{i=1}^{\lfloor n/d \rfloor} i^3 \end{aligned}$$

$$\left( \sum_{i=1}^n i \right)^2 = \sum_{i=1}^n i^3$$

$$\sum_{i=1}^n i^2 = \frac{n(n+1)(2n+1)}{6}$$

### 代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>

using namespace std;
// Upper bound (inclusive) of the indices tabulated by the sieve in init().
const int SIZ = 10000000;
// Capacity of the sieve tables: SIZ plus a small amount of slack.
const int MAXN = SIZ + 10;
// Capacity of the memo tables used by calc() for arguments not covered by the sieve.
const int MAXM = 2510;

// Problem input (read in main): the upper bound n and the modulus MOD.
long long n;
long long MOD;
// Set by init() to powx(6, MOD - 2); calc2() multiplies by it in place of dividing by 6.
long long INV;

// pri[1..cnt] holds the primes found by the sieve, in increasing order.
long long pri[MAXN];
long long cnt;

// phi[i]: Euler's totient of i, filled for 1 <= i <= SIZ.
long long phi[MAXN];
// sum1[i]: running sum over d <= i of phi[d] * d * d, reduced modulo MOD.
long long sum1[MAXN];

// isNotPrime[i] is set once i has been marked composite by the sieve.
bool isNotPrime[MAXN];

// Modular product: multiplies a and b in long long, then reduces by the file-level MOD.
inline long long mul(long long a, long long b) {
    const long long product = a * b;
    return product % MOD;
}
// Modular sum: adds a and b, then reduces by the file-level MOD.
inline long long add(long long a, long long b) {
    const long long total = a + b;
    return total % MOD;
}

// Square-and-multiply exponentiation built on mul(), so every product is reduced modulo MOD.
// The accumulator starts at 1, so b == 0 returns 1.
inline long long powx(long long a, long long b) {
    long long result = 1;
    long long base = a;
    long long exponent = b;

    while (exponent) {
        // Fold in the current power of the base when the lowest exponent bit is set.
        if (exponent & 1) {
            result = mul(result, base);
        }
        base = mul(base, base);
        exponent >>= 1;
    }

    return result;
}

// Fills every table the rest of the program reads. Must run after MOD has been read.
//   * INV        = powx(6, MOD - 2); acts as the inverse of 6 when MOD is a prime
//                  larger than 3 (assumed from the problem statement -- TODO confirm).
//   * phi[1..SIZ]  Euler's totient, computed with a linear sieve.
//   * sum1[i]    = sum over d <= i of phi[d] * d^2, reduced modulo MOD.
// The loop counters are plain ints: the `register` specifier used previously was
// removed from the language in C++17 and had no effect on the result.
inline void init() {
    INV = powx(6, MOD - 2);
    phi[1] = 1;

    for (int i = 2; i <= SIZ; i++) {
        if (!isNotPrime[i]) {
            // i was never marked, so it is prime.
            pri[++cnt] = i;
            phi[i] = i - 1;
        }

        for (int j = 1; j <= cnt; j++) {
            // pri[j] is long long, so this product is evaluated in long long.
            long long m = i * pri[j];
            if (m > SIZ) break;
            isNotPrime[m] = true;

            if (i % pri[j] == 0) {
                // pri[j] already divides i: phi picks up a factor of pri[j].
                // Stopping here means each composite is marked by its smallest prime only.
                phi[m] = phi[i] * pri[j];
                break;
            } else {
                // pri[j] is coprime to i: phi is multiplicative.
                phi[m] = phi[i] * (pri[j] - 1);
            }
        }
    }

    for (int i = 1; i <= SIZ; i++) {
        sum1[i] = add(sum1[i - 1], mul(phi[i], mul(i, i)));
    }
}

// Returns x(x+1)(2x+1) * INV modulo MOD, i.e. 1^2 + 2^2 + ... + x^2 (mod MOD)
// when INV is the modular inverse of 6.
inline long long calc2(long long x) {
    const long long r = x % MOD;
    const long long firstTwo = mul(r, r + 1);
    const long long allThree = mul(firstTwo, 2 * r + 1);
    return mul(allThree, INV);
}

// Returns (x(x+1)/2)^2 modulo MOD, which equals 1^3 + 2^3 + ... + x^3 (mod MOD).
// NOTE: x is reduced modulo MOD before the exact division by 2; that agrees with the
// true triangular number modulo MOD only when MOD is odd.
inline long long calc3(long long x) {
    const long long r = x % MOD;
    const long long triangular = (r * (r + 1) / 2) % MOD;
    return mul(triangular, triangular);
}

// Memo tables for calc(), indexed by k = n / x.
// vis[k] is set once sum2[k] has been assigned for that key.
bool vis[MAXM];
long long sum2[MAXM];

// Returns S(x) = sum over d <= x of phi(d) * d^2, reduced modulo MOD.
//
// Arguments below SIZ are answered from the sieve table sum1. Larger arguments use
//     S(x) = calc3(x) - sum_{l=2..x} l^2 * S(x / l)
// where the l-range is walked in blocks [l, r] on which x / l is constant, so the
// l^2 terms of a block collapse to calc2(r) - calc2(l - 1).
//
// Results for large x are memoised in sum2 under the key k = n / x. This assumes x is
// always of the form n / t (true for the calls made from work() and for the recursive
// calls here), so distinct x map to distinct keys, and that k stays below MAXM.
//
// The loop counter is a plain long long: the `register` specifier used previously was
// removed from the language in C++17.
inline long long calc(long long x) {
    if (x < SIZ) return sum1[x];

    int k = n / x;
    if (vis[k]) return sum2[k];

    long long &ans = sum2[k];
    vis[k] = true;
    ans = calc3(x);

    long long r = 0;
    for (long long l = 2; l <= x; l = r + 1) {
        // r is the largest index sharing the quotient x / l with l.
        r = x / (x / l);

        // Sum of squares over [l, r], normalised into [0, MOD).
        const long long weight = ((calc2(r) - calc2(l - 1)) % MOD + MOD) % MOD;

        ans = (ans - weight * calc(x / l) % MOD) % MOD;
        if (ans < 0) ans += MOD;
    }

    return ans;
}

// Evaluates sum over d <= n of phi(d) * d^2 * calc3(n / d), modulo MOD.
// The range 1..n is walked in blocks [l, r] on which n / l is constant; each block
// contributes (calc(r) - calc(l - 1)) * calc3(n / l).
// The loop counter is a plain long long: the `register` specifier used previously was
// removed from the language in C++17.
inline long long work() {
    long long total = 0;
    long long r = 0;

    for (long long l = 1; l <= n; l = r + 1) {
        // r is the largest index sharing the quotient n / l with l.
        r = n / (n / l);

        // Prefix-sum difference over the block, shifted into [0, MOD).
        const long long blockWeight = (calc(r) - calc(l - 1) + MOD) % MOD;

        total = (total + blockWeight * calc3(n / l) % MOD) % MOD;
    }

    return total;
}

// Reads the modulus and then n from standard input, builds the tables, and prints the answer.
int main() {
    // Stop early on malformed input or a non-positive modulus: with MOD left at 0 every
    // `% MOD` would divide by zero, and powx(6, MOD - 2) would never terminate for a
    // negative exponent.
    if (scanf("%lld%lld", &MOD, &n) != 2 || MOD <= 0) {
        return 1;
    }

    init();
    printf("%lld", work());
    return 0;
}