1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112
|
// Counts unordered pairs of distinct tree vertices whose path length (number
// of edges) is exactly k, using centroid decomposition.
//
// Input (stdin):  "n k", then n-1 lines "a b" giving undirected edges between
//                 1-based vertices a and b.
// Output (stdout): the number of such pairs, printed without a trailing newline.
#include <algorithm>
#include <cstdio>
#include <vector>

namespace {

int n = 0;                              // number of vertices
int k = 0;                              // target path length
std::vector<std::vector<int>> adj;      // 1-based adjacency lists
std::vector<int> sz;                    // subtree sizes from the latest findCentroid DFS
std::vector<int> heaviest;              // largest remaining part if the vertex were removed
std::vector<char> removed;              // vertex has already been used as a centroid
std::vector<int> dists;                 // distances gathered for the current count
int root = 0;                           // best centroid candidate found so far (0 = none)
int compSize = 0;                       // size of the component being searched for a centroid
long long answer = 0;

// DFS over the not-yet-removed component containing u (entered from fa).
// Fills sz[] for that rooting and leaves in `root` the vertex whose largest
// remaining part is smallest, i.e. a centroid of the component.
void findCentroid(int u, int fa) {
    sz[u] = 1;
    heaviest[u] = 0;
    for (int w : adj[u]) {
        if (w == fa || removed[w]) continue;
        findCentroid(w, u);
        sz[u] += sz[w];
        heaviest[u] = std::max(heaviest[u], sz[w]);
    }
    // The part "above" u in this DFS also becomes a separate piece when u is removed.
    heaviest[u] = std::max(heaviest[u], compSize - sz[u]);
    if (root == 0 || heaviest[u] < heaviest[root]) root = u;
}

// Appends to `dists` the distance d of u and of every vertex reachable from u
// without crossing removed vertices or going back through fa.
void collectDistances(int u, int d, int fa) {
    // Distances are non-negative, so a vertex farther than k can never be part
    // of a pair summing to k; the same holds for everything below it.
    if (d > k) return;
    dists.push_back(d);
    for (int w : adj[u]) {
        if (w == fa || removed[w]) continue;
        collectDistances(w, d + 1, u);
    }
}

// Sorts `dists` and returns the number of index pairs i < j with
// dists[i] + dists[j] == k.
long long countPairs() {
    std::sort(dists.begin(), dists.end());
    long long pairs = 0;
    const int m = static_cast<int>(dists.size());
    for (int i = 0; i + 1 < m; ++i) {
        const int need = k - dists[i];
        // Later entries are >= dists[i] > need, so no further pair can sum to k.
        if (need < dists[i]) break;
        const auto range = std::equal_range(dists.begin() + i + 1, dists.end(), need);
        pairs += range.second - range.first;
    }
    return pairs;
}

// Counts all pairs whose path passes through centroid u, then recurses into
// each remaining sub-component. Expects compSize to hold the size of u's
// component and sz[] to come from the findCentroid call that selected u.
void divide(int u) {
    const int total = compSize;
    removed[u] = 1;

    dists.clear();
    collectDistances(u, 0, 0);
    answer += countPairs();

    for (int w : adj[u]) {
        if (removed[w]) continue;

        // Pairs lying entirely inside w's sub-component were counted above as
        // if their path went through u; subtract them (their distances to u
        // start at 1).
        dists.clear();
        collectDistances(w, 1, u);
        answer -= countPairs();

        // sz[] was computed by a DFS that may not have been rooted at u: if w
        // was u's parent there, sz[w] > sz[u] and its piece is total - sz[u].
        compSize = sz[w] < sz[u] ? sz[w] : total - sz[u];
        root = 0;
        findCentroid(w, u);
        divide(root);
    }
}

}  // namespace

int main() {
    if (std::scanf("%d %d", &n, &k) != 2) return 1;
    if (n < 1) {
        std::printf("%lld", 0LL);
        return 0;
    }

    adj.assign(n + 1, {});
    for (int i = 1; i < n; ++i) {
        int a = 0;
        int b = 0;
        if (std::scanf("%d %d", &a, &b) != 2 || a < 1 || a > n || b < 1 || b > n) return 1;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    sz.assign(n + 1, 0);
    heaviest.assign(n + 1, 0);
    removed.assign(n + 1, 0);
    dists.reserve(n);

    compSize = n;
    root = 0;
    findCentroid(1, 0);
    divide(root);

    std::printf("%lld", answer);
    return 0;
}
|
近期评论