跳到主要内容

P8110 矩阵

看似矩阵快速幂,实则推式子 + 普通快速幂。

难度普及+/提高
算法快速幂
日期2025-09-24

题意简述

给定两个长度为 nn 的数列 {a1,a2,,an},{b1,b2,,bn}\{a_1,a_2,\dots,a_n\},\{b_1,b_2,\dots,b_n\}n×nn\times n 的矩阵 AA 满足 Aij=ai×bjA_{ij}=a_i\times b_j

求:(i=1nj=1nAijk)modM(\sum_{i=1}^n\sum_{j=1}^nA_{ij}^k)\bmod M,其中 M=998244353M=998244353

范围:1n1051\leq n\leq10^50kM0\leq k\leq Mai,biZa_i,b_i\in\mathbb{Z}ai,bi109|a_i|,|b_i|\leq10^9

M=998244353109M=998244353\approx10^9

本题中 A0A^0 为单位矩阵。

思路

如果用常规矩阵快速幂,求一次矩阵乘法 O(n2)O(n^2),总共 O(nlogn2)O(n\log n^2),显然不行。

考虑推式子。


Aij2=Aij×Aij=k=1naibk×akbj=aibjk=1nakbkA^{2}_{ij}=A_{ij}\times A_{ij}=\sum_{k=1}^n a_ib_k\times a_kb_j=a_ib_j\sum_{k=1}^na_kb_k

不妨记 σ=k=1nakbk\sigma=\sum_{k=1}^na_kb_k

Aij2=aibjσA^{2}_{ij}=a_ib_j\sigma

类似地,Aij4=Aij2×Aij2=k=1naibkσ×akbjσ=aibjσ2k=1naibj=aibjσ3A^{4}_{ij}=A^{2}_{ij}\times A^{2}_{ij}=\sum_{k=1}^na_ib_k\sigma\times a_kb_j\sigma=a_ib_j\sigma^2\sum_{k=1}^na_ib_j=a_ib_j\sigma^3

不难发现,Aijk=aibjσk1A^k_{ij}=a_ib_j\sigma^{k-1}

故:

i=1nj=1nAijk=i=1nj=1naibjσk1=(i=1nai)(j=1nbj)(i=1naibi)k1.\sum_{i=1}^n\sum_{j=1}^nA_{ij}^k=\sum_{i=1}^n\sum_{j=1}^na_ib_j\sigma^{k-1}=(\sum_{i=1}^na_i)\cdot(\sum_{j=1}^nb_j)\cdot(\sum_{i=1}^na_ib_i)^{k-1}.

所以我们只需要用 O(n)O(n) 的时间统计三个 \sum 的值,然后对 σ\sigma 用快速幂,总计 O(n+logn)=O(n)O(n+\log n)=O(n),可以通过本题。

编码

  • 注意 ai,bia_i,b_i 可能小于零,需要在输入时取一次 ai(aimodM+M)modMa_i\gets(a_i\bmod M+M)\bmod M

  • 考虑边界情况,k=0k=0ans=n\text{ans}=n

用时 73 ms 内存 1.92 MB
/*
* P8110 [Cnoi2021] 矩阵
*/
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll MOD = 998244353;
const int MAXN = 1e5 + 1;
ll a[MAXN], b[MAXN];
ll qpow(ll a, ll n) {
ll res = 1;
while (n) {
if (n & 1) res = (res * a) % MOD;
a = (a * a) % MOD;
n >>= 1;
}
return res;
}
int main() {
ll n, k, o;
cin >> n >> k;
ll suma = 0, sumb = 0, sumab = 0;
for (ll i = 1; i <= n; i++) {
cin >> o;
a[i] = (o % MOD + MOD) % MOD;
suma = (suma + a[i]) % MOD;
}
for (ll i = 1; i <= n; i++) {
cin >> o;
b[i] = (o % MOD + MOD) % MOD;
sumb = (sumb + b[i]) % MOD;
}
for (ll i = 1; i <= n; i++) {
sumab = (sumab + ((a[i] * b[i]) % MOD)) % MOD;
}
if (k == 0) { // 必须特判
cout << n << endl;
} else {
sumab = qpow(sumab, k - 1);
ll ans = (suma * sumb) % MOD;
cout << (ans * sumab) % MOD << endl;
}
return 0;
}