Given a tree with n nodes (n can be as large as 2 * 10^5), where each node has a cost associated with it, let us define the following functions:
g(u, v) = the sum of all costs on the simple path from u to v
f(n) = the (n + 1)th Fibonacci number (n + 1 is not a typo)
The problem I'm working on requires me to compute the sum of f(g(u, v)) over all possible pairs of nodes in the tree, modulo 10^9 + 7.
As an example, let's take a tree with 3 nodes.
- without loss of generality, let's say node 1 is the root, and its children are 2 and 3
- cost[1] = 2, cost[2] = 1, cost[3] = 1
g(1, 1) = 2; f(2) = 2
g(2, 2) = 1; f(1) = 1
g(3, 3) = 1; f(1) = 1
g(1, 2) = 3; f(3) = 3
g(2, 1) = 3; f(3) = 3
g(1, 3) = 3; f(3) = 3
g(3, 1) = 3; f(3) = 3
g(2, 3) = 4; f(4) = 5
g(3, 2) = 4; f(4) = 5
Summing all of the values and taking the result modulo 10^9 + 7 gives 26 as the correct answer.
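For reference, the example can be reproduced with a quadratic brute force. This is only a sketch for checking small inputs, not a solution for n up to 2 * 10^5; the names fib_shifted and brute_force_sum are made up for illustration, and it assumes nodes are labelled 1..n and that every ordered pair, including (u, u), is counted, as in the table above.

```python
from collections import deque

MOD = 10**9 + 7

def fib_shifted(k):
    """f(k): the (k + 1)th Fibonacci number, with F(1) = F(2) = 1, modulo MOD."""
    a, b = 1, 1  # F(1), F(2)
    for _ in range(k):
        a, b = b, (a + b) % MOD
    return a

def brute_force_sum(n, edges, cost):
    """Sum f(g(u, v)) over all ordered pairs (u, v), including u == v."""
    adj = [[] for _ in range(n + 1)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    total = 0
    for src in range(1, n + 1):
        # BFS from src; dist[v] holds g(src, v), counting both endpoints.
        dist = [None] * (n + 1)
        dist[src] = cost[src]
        queue = deque([src])
        while queue:
            u = queue.popleft()
            total = (total + fib_shifted(dist[u])) % MOD
            for v in adj[u]:
                if dist[v] is None:
                    dist[v] = dist[u] + cost[v]
                    queue.append(v)
    return total

print(brute_force_sum(3, [(1, 2), (1, 3)], {1: 2, 2: 1, 3: 1}))  # 26
```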
My attempt:
I implemented an algorithm to compute g(u, v) in O(log n) by finding the lowest common ancestor using a sparse table.
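A minimal sketch of that kind of query, for concreteness. It uses binary lifting for the lowest common ancestor rather than the sparse table mentioned above, and combines it with root-to-node prefix sums; build_path_sum and its parameters are hypothetical names, and nodes are assumed to be labelled 1..n with cost[v] holding the cost of node v.

```python
from collections import deque

def build_path_sum(n, edges, cost, root=1):
    """Return a function g(u, v) giving the sum of costs on the path u..v."""
    adj = [[] for _ in range(n + 1)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    levels = max(1, n.bit_length())
    # up[j][v] is the 2^j-th ancestor of v (the root is its own ancestor).
    up = [[root] * (n + 1) for _ in range(levels)]
    depth = [0] * (n + 1)
    prefix = [0] * (n + 1)  # sum of costs from the root down to v, inclusive
    prefix[root] = cost[root]
    seen = [False] * (n + 1)
    seen[root] = True
    queue = deque([root])
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            if not seen[v]:
                seen[v] = True
                up[0][v] = u
                depth[v] = depth[u] + 1
                prefix[v] = prefix[u] + cost[v]
                queue.append(v)
    for j in range(1, levels):
        for v in range(1, n + 1):
            up[j][v] = up[j - 1][up[j - 1][v]]

    def lca(u, v):
        if depth[u] < depth[v]:
            u, v = v, u
        diff = depth[u] - depth[v]
        j = 0
        while diff:  # lift the deeper node to the same depth
            if diff & 1:
                u = up[j][u]
            diff >>= 1
            j += 1
        if u == v:
            return u
        for j in range(levels - 1, -1, -1):  # lift both to just below the LCA
            if up[j][u] != up[j][v]:
                u, v = up[j][u], up[j][v]
        return up[0][u]

    def g(u, v):
        a = lca(u, v)
        # Both prefixes include the path root..a, so remove it twice and add a back.
        return prefix[u] + prefix[v] - 2 * prefix[a] + cost[a]

    return g

g = build_path_sum(3, [(1, 2), (1, 3)], [0, 2, 1, 1])
print(g(2, 3))  # 4
```

Preprocessing here is O(n log n) and each query is O(log n), which matches the per-query bound above but does nothing about the number of pairs.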
To find the appropriate Fibonacci values, I tried two approaches: exponentiation of the matrix form, and exploiting the fact that the sequence modulo 10^9 + 7 is cyclic.
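As an illustration of the first approach, here is a sketch using fast doubling, which follows from the same matrix identity: F(2k) = F(k) * (2 * F(k+1) - F(k)) and F(2k+1) = F(k)^2 + F(k+1)^2. It is not the code from my attempt; fib_pair is a name chosen for this example, and f follows the definition above with F(1) = F(2) = 1.

```python
MOD = 10**9 + 7

def fib_pair(k):
    """Return (F(k), F(k + 1)) modulo MOD, with F(0) = 0 and F(1) = 1."""
    if k == 0:
        return (0, 1)
    a, b = fib_pair(k // 2)            # a = F(m), b = F(m + 1), m = k // 2
    c = a * ((2 * b - a) % MOD) % MOD  # F(2m)
    d = (a * a + b * b) % MOD          # F(2m + 1)
    if k % 2 == 0:
        return (c, d)
    return (d, (c + d) % MOD)

def f(k):
    """f(k) = the (k + 1)th Fibonacci number, modulo MOD."""
    return fib_pair(k + 1)[0]

print([f(k) for k in range(1, 5)])  # [1, 2, 3, 5]
```

Each call takes O(log k) multiplications, so evaluating f is cheap even for large path sums.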
Now comes the extremely tricky part. No matter how I do the above computations, I still end up visiting up to O(n^2) pairs when calculating the sum of all possible f(g(u, v)). I mean, there's the obvious improvement of only visiting n * (n - 1) / 2 pairs, but that's still quadratic.
What am I missing? I've been at it for several hours, but I can't see a way to get that sum without actually producing a quadratic algorithm.
To know how many times the cost of a node X is to be included in the total sum, we divide the other nodes into 3 (or more) groups:
When two nodes belong to different groups, their simple path goes through X. So the number of simple paths that go through X is:
So by counting the total number of nodes N, and the size of the subtrees under X, you can calculate how many times the cost of node X should be included in the total sum. Do this for every node and you have the total cost.
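A rough sketch of that counting idea, under some assumptions that are mine: the groups are taken to be the subtree under each child of X plus all nodes outside X's subtree, paths are unordered pairs of distinct nodes, and paths that start or end at X are counted as well. The names path_counts and sum_of_path_costs are invented for this example. Note that the result is the sum of the path costs g(u, v) themselves, before any function is applied to them.

```python
def path_counts(n, edges, root=1):
    """counts[x] = number of simple paths (unordered, distinct endpoints) containing x."""
    adj = [[] for _ in range(n + 1)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    # Iterative DFS to get parents and a parent-before-child visiting order.
    parent = [0] * (n + 1)  # 0 means "no parent"; nodes are labelled 1..n
    order = []
    stack = [root]
    while stack:
        u = stack.pop()
        order.append(u)
        for v in adj[u]:
            if v != parent[u]:
                parent[v] = u
                stack.append(v)

    # Subtree sizes, accumulated from the leaves upwards.
    size = [1] * (n + 1)
    for u in reversed(order):
        if parent[u]:
            size[parent[u]] += size[u]

    counts = [0] * (n + 1)
    others = n - 1
    for x in range(1, n + 1):
        groups = [size[c] for c in adj[x] if c != parent[x]]
        if x != root:
            groups.append(n - size[x])  # everything outside x's subtree
        # Pairs of other nodes that lie in different groups pass through x.
        cross = (others * others - sum(s * s for s in groups)) // 2
        counts[x] = cross + others  # plus the paths that have x as an endpoint
    return counts

def sum_of_path_costs(n, edges, cost, root=1):
    counts = path_counts(n, edges, root)
    return sum(cost[x] * counts[x] for x in range(1, n + 1))

print(sum_of_path_costs(3, [(1, 2), (1, 3)], [0, 2, 1, 1]))  # 10
```

For the 3-node example, the three paths between distinct nodes cost 3, 3 and 4, which agrees with the printed total of 10.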
The code for this could be straightforward. I'll assume that the total number of nodes N is known, and that you can add properties to the nodes (both of these assumptions simplify the algorithm, but it can be done without them).
We'll add a child_count to store the number of descendants of the node, and a path_count to store the number of simple paths that the node is part of; both are initialised to zero.
For each node, starting from the root: