bzoj2152 聪聪可可

题意:问一棵带权树上路径长度为$3$的倍数的点对有多少对,其中包括自身与自身,两点交换算两个点对。

点分治模板题,分治时记录模$3$的余数为$0,1,2$的分别有多少个点,分子树向前统计即可,最后加上重复和自反的点对。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
using namespace std;
inline int ()
{
char c;int tmp=0,x=1;c=getchar();
while(c>'9' || c<'0') {if(c=='-') x=-1;c=getchar();}
while(c>='0' && c<='9') {tmp=tmp*10+c-'0';c=getchar();}
return tmp*x;
}
inline int gcd(int a,int b){return b==0?a:gcd(b,a%b);}
const int maxn=20000+5;
int head[maxn],eg[maxn<<1],nxt[maxn<<1],W[maxn<<1],tot=0,n;
void addedge(int u,int v,int w)
{
eg[++tot]=v;nxt[tot]=head[u];W[tot]=w;head[u]=tot;
eg[++tot]=u;nxt[tot]=head[v];W[tot]=w;head[v]=tot;
}
bool Isc[maxn];
int siz[maxn],dis[3],par[3],Ans;
void Cal_Size(int v,int fa)
{
siz[v]=1;
for(int i=head[v];i;i=nxt[i]) {
int u=eg[i];
if(u!=fa && !Isc[u]) Cal_Size(u,v),siz[v]+=siz[u];
}
}
typedef pair<int ,int > pii;
#define fir first
#define sec second
#define MP make_pair
pii Find_Cent(int v,int fa,int N)
{
pii ret=MP(INT_MAX,-1);
int maxsiz=0,sum=1;
for(int i=head[v];i;i=nxt[i]) {
int u=eg[i];
if(u==fa || Isc[u]) continue;
ret=min(ret,Find_Cent(u,v,N));
maxsiz=max(maxsiz,siz[u]);
sum+=siz[u];
}
maxsiz=max(maxsiz,N-sum);
ret=min(ret,MP(maxsiz,v));
return ret;
}
void Cal_Dist(int v,int fa,int precost)
{
par[precost%3]++;
for(int i=head[v];i;i=nxt[i]) {
int u=eg[i];
if(u==fa || Isc[u]) continue;
Cal_Dist(u,v,precost+W[i]);
}
}
int upd()
{
int ret=0;
ret+=dis[0]*par[0];
ret+=dis[1]*par[2];
ret+=dis[2]*par[1];
return ret;
}
void solve(int v)
{
Cal_Size(v,-1);
int cv=Find_Cent(v,-1,siz[v]).sec;
Isc[cv]=true;
for(int i=head[cv];i;i=nxt[i]) {
int u=eg[i];
if(!Isc[u]) solve(u);
}
memset(dis,0,sizeof(dis));
dis[0]++;
for(int i=head[cv];i;i=nxt[i]) {
int u=eg[i];
if(!Isc[u]) {
memset(par,0,sizeof(par));
Cal_Dist(u,cv,W[i]%3);
Ans+=upd();
for(int j=0;j<3;j++) dis[j]+=par[j];
}
}
Isc[cv]=false;
}
int main()
{
n=readInt();
int u,v,w;
for(int i=1;i<=n-1;i++) {
u=readInt(),v=readInt(),w=readInt();
addedge(u,v,w);
}
solve(1);
int A=Ans*2+n,B=n*n;
int gd=gcd(A,B);
printf("%d/%dn",A/gd,B/gd);
return 0;
}