I tried to solve it with permutations and a math formula, but failed miserably.
I ended up reading the match editorial and realized there was a much easier and faster solution.
Here is the problem statement.
Problem Statement
You're given ints N and R. Count the number of sequences A_0, A_1, ..., A_{N-1} such that each A_i is an integer satisfying 0 ≤ A_i ≤ R and A_0 + A_1 + ... + A_{N-1} = A_0 | A_1 | ... | A_{N-1}. The '|' symbol stands for bitwise OR of the operands. Return the number of such sequences modulo 1,000,000,009.
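Definition
Class: YetAnotherORProblem2
Method: countSequences
Parameters: int, int
Returns: int
Method signature: int countSequences(int N, int R)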
Notes
- If a and b are single bits then a | b is defined as max(a, b). For two integers, A and B, in order to calculate A | B, they need to be represented in binary: A = (a_n...a_1)_2, B = (b_n...b_1)_2 (if the lengths of their representations are different, the shorter one is prepended with the necessary number of leading zeroes). Then A | B = C = (c_n...c_1)_2, where c_i = a_i | b_i. For example, 10 | 3 = (1010)_2 | (0011)_2 = (1011)_2 = 11.
Constraints
- N will be between 2 and 10, inclusive.
- R will be between 1 and 15000, inclusive.
Logic.
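The first thing to notice is the condition
(sum of all elements of the sequence) equals (OR of all elements of the sequence).
This holds exactly when no two elements share a set bit: each bit position may be 1 in at most one element. (For two numbers, a + b = (a | b) + (a & b), so a + b equals a | b only when a & b = 0; adding the elements one at a time extends this to the whole sequence.) Let's call it Rule #1.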
For example,
take N = 2 and R = 3, whose binary representation is (11).
Some valid combinations are
(2,1) ==> ( (10), (01) )
(3,0) ==> ( (11), (00) )
An invalid combination is
(3,1) ==> ((11), (01)) because 3 and 1 both have bit 0 (the least significant bit) set.
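Rule #1 is easy to check mechanically. The C++ sketch below (the function names are my own, not from the editorial) compares the original condition, sum == OR, with a disjoint-bits test, and checkRule1 confirms that the two agree on every triple of values below a small limit.

```cpp
#include <cassert>
#include <vector>

// Rule #1: no two elements share a set bit.
// Scans the elements once, keeping the OR of everything seen so far.
bool bitsAreDisjoint(const std::vector<int>& a) {
    int seen = 0;
    for (int x : a) {
        if (seen & x) return false;  // x reuses a bit already taken by an earlier element
        seen |= x;
    }
    return true;
}

// The condition from the problem statement: sum of the elements == OR of the elements.
bool sumEqualsOr(const std::vector<int>& a) {
    long long sum = 0;
    int orAll = 0;
    for (int x : a) {
        sum += x;
        orAll |= x;
    }
    return sum == orAll;
}

// Exhaustively checks that both conditions agree for every triple with values below `limit`.
void checkRule1(int limit) {
    for (int x = 0; x < limit; x++)
        for (int y = 0; y < limit; y++)
            for (int z = 0; z < limit; z++) {
                std::vector<int> a = {x, y, z};
                assert(bitsAreDisjoint(a) == sumEqualsOr(a));
            }
}
```

bitsAreDisjoint({2, 1}) and bitsAreDisjoint({3, 0}) return true, while bitsAreDisjoint({3, 1}) returns false, matching the example above.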
The second thing to notice is that the highest bit R can ever have set is bit 13 (indexed from 0), because the constraints say R ≤ 15000 and 2^14 = 16384 > 15000.
So there are only 14 bit positions, few enough to process them one at a time and count the possible outcomes.
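A quick way to confirm the bit count, as a small sketch (highestBit and columnsNeeded are illustrative helpers, not part of the solution): 15000 is (11101010011000)_2, fourteen binary digits, so bit 13 is the highest one R can set and 14 columns are enough.

```cpp
#include <cassert>

// Index of the highest set bit of r (r must be positive).
int highestBit(int r) {
    assert(r > 0);
    int bit = 0;
    while (r >> (bit + 1)) bit++;  // keep going while a higher bit is still set
    return bit;
}

// Number of bit columns needed to represent every value in [0, maxR].
int columnsNeeded(int maxR) {
    return highestBit(maxR) + 1;
}
```

With maxR = 15000 this gives 14 columns, which is why the recursion can start from col = 14.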
Since this is a counting problem, we can use dynamic programming, so we need to define a "state" and a "recurrence".
The state can be defined as
mem[col][n]
= the number of ways to place the remaining bits when "col" columns (bit positions) are left to set
and "n" rows (elements of the sequence) are still equal to R on all the bits placed so far.
Our recurrence then takes (col, n) as input and returns that count.
It will look like
long long rec(int col, int n)
Let's define the recurrence.
We start with n == N, meaning all rows are equal to R in the current state (no bits have been placed yet).
We process the bits from left to right (from the most significant to the least significant) and decide how to set the current bit in each of the N rows.
Specifically, for each bit we do one of the following:
option 1) Pick one row and set its current bit to 1. By Rule #1, the same bit of every other row must be 0.
option 2) Set the current bit of every row to 0.
So each recursive call looks like this:
rec(col, n) calls rec(col-1, z) where z <= n, because the number of rows equal to R can only decrease or stay the same.
Since col decreases by one on every call, the recursion always reaches the base case.
Now, let's look at the base case.
Base case:
If there are no more columns to set (col == 0) and all previous columns were set correctly, there is exactly one result (the bits placed so far), so return 1.
We start the recurrence with n == N, meaning all N rows are equal to R.
We go through each column (bit) from left to right (from the most significant to the least significant) and count the ways to set the bits of the N rows.
Let isZero be true if the current bit of R is 0, and false otherwise.
If (n == N and isZero)
We cannot set any bit to 1, because that would make the element bigger than R, which the problem statement forbids.
So we call rec(col-1, n) ==> move to the next bit; all rows are still equal to R.
If (n == N and not isZero)
We have two options.
1) Put 0 in this column for every row. Every row now has 0 where R has 1, so no row is equal to R anymore and n becomes 0.
==> rec(col-1, 0)
2) Put 1 in one of the N rows. By Rule #1, all other rows must have 0 in that bit.
The row that received the 1 becomes the only row still equal to R, and it can be any one of the N rows.
==> N * rec(col-1, 1)
Analyzing the n == N case shows that n only moves from N to N, 1, or 0, and the cases below never increase n, so n is always one of three values: 0, 1, or N.
We also know that the recursion ends after at most 14 steps, because R uses at most 14 bits (bits 0 through 13) and col decreases by 1 at each step.
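The claim that n stays in {0, 1, N} can be checked by simulation. The sketch below is my own generalization, not the author's code: transitions folds all the cases into one rule for an arbitrary n (give the column's 1 to nobody, to a row still equal to R, or to a row already below R), and reachableCounts collects every n reachable from n = N while scanning R's bits from the top.

```cpp
#include <cassert>
#include <set>
#include <utility>
#include <vector>

// One column of the recurrence, generalized to any n.
// n = rows still equal to R so far, N = total rows, rBit = current bit of R.
// Returns (multiplier, next n) pairs.
std::vector<std::pair<int, int>> transitions(int n, int N, int rBit) {
    assert(0 <= n && n <= N);
    std::vector<std::pair<int, int>> out;
    // Every row gets 0: equal rows stay equal only if R also has 0 here.
    out.push_back({1, rBit ? 0 : n});
    // An equal row gets the 1 (only allowed when R has 1); other equal rows drop below R.
    if (rBit && n > 0) out.push_back({n, 1});
    // One of the N - n rows already below R gets the 1; equal rows get 0.
    if (N - n > 0) out.push_back({N - n, rBit ? 0 : n});
    return out;
}

// Every value of n reachable from n = N when R's bits are scanned from bit numCols-1 down to 0.
std::set<int> reachableCounts(int N, int R, int numCols) {
    std::set<int> current = {N}, seen = {N};
    for (int col = numCols; col >= 1; col--) {
        int rBit = (R >> (col - 1)) & 1;
        std::set<int> next;
        for (int n : current)
            for (const auto& t : transitions(n, N, rBit)) next.insert(t.second);
        seen.insert(next.begin(), next.end());
        current = next;
    }
    return seen;
}
```

For N = 5 and R = 15000 the reachable set is {0, 1, 5}. Specializing transitions to n = N, n = 1, and n = 0 gives the same terms that rec in the source code below adds up.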
If (n == 1 and isZero)
1) Set the bit to 1 in one of the (N-1) rows that are already smaller than R. The row equal to R gets 0, which matches R's 0, so it stays equal.
=> (N-1) * rec(col-1, 1)
2) Set no bit to 1. The row equal to R still matches R, so n stays 1.
=> rec(col-1, 1)
If (n == 1 and not isZero)
1) Set the bit to 1 in the only row that is still equal to R.
=> rec(col-1, 1)
2) Set the bit to 1 in one of the (N-1) other rows, which are already smaller than R. By Rule #1, the row that was equal to R gets 0 here while R has 1, so it becomes smaller than R too.
=> (N-1) * rec(col-1, 0)
3) Set the current bit to 0 in all N rows. The row that was equal to R becomes smaller than R.
=> rec(col-1, 0)
If (n == 0)
No row is equal to R, so isZero does not affect the count.
1) Set the bit to 1 in one of the N rows, all of which are smaller than R.
=> N * rec(col-1, 0)
2) Set the bit to 0 in all N rows.
=> rec(col-1, 0)
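The cases above can also be evaluated bottom-up, from col = 0 upward, without recursion or a memo table. This is my own restatement, not the code from the editorial; countSequencesBottomUp is a hypothetical name, and it assumes N >= 2 so that the indices 0, 1, and N are distinct.

```cpp
#include <cassert>
#include <vector>

// Bottom-up evaluation of the recurrence. Before processing column `col`,
// ways[n] holds rec(col - 1, n); only n in {0, 1, N} is ever used.
long long countSequencesBottomUp(int N, int R, int numCols = 14) {
    const long long MOD = 1000000009LL;
    assert(N >= 2);
    std::vector<long long> ways(N + 1, 1);  // base case: rec(0, n) = 1
    for (int col = 1; col <= numCols; col++) {
        bool isZero = ((R >> (col - 1)) & 1) == 0;  // bit (col - 1) of R
        std::vector<long long> next(N + 1, 0);
        // n == 0: give the 1 to one of the N rows, or to none.
        next[0] = (N + 1) * ways[0] % MOD;
        // n == 1
        if (isZero) next[1] = N * ways[1] % MOD;             // all 0, or one of N-1 smaller rows
        else next[1] = (ways[1] + N * ways[0]) % MOD;        // equal row / all 0 / N-1 smaller rows
        // n == N
        if (isZero) next[N] = ways[N];                        // nobody may take a 1
        else next[N] = (ways[0] + N * ways[1]) % MOD;         // all 0, or one of N rows takes 1
        ways = next;
    }
    return ways[N];
}
```

For N = 2 and R = 3 it returns 9: each of the two bits goes to A_0, to A_1, or to neither, giving 3 * 3 = 9 sequences.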
Following is the source code (it passed the system tests).
#include <cstring>  // memset
using namespace std;
typedef long long ll;
class YetAnotherORProblem2 {
public:
    int countSequences(int, int);
};
// mem[col][n] = number of ways to fill the lowest `col` bit columns when `n`
// rows are still equal to R on the bits already placed; -1 means "not computed yet".
ll mem[15][11];
int N;
int R;
const int mod = 1000000009;
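ll rec(int col, int numEqualRows)
{
    ll& res = mem[col][numEqualRows];
    if (res != -1) return res;  // already memoized
    if (col == 0)
    {
        res = 1;  // every column has been placed: exactly one way
    }
    else
    {
        res = 0;
        // The current column is bit (col-1) of R.
        bool isZero = ((R & (1 << (col - 1))) == 0);
        if (numEqualRows == N)
        {
            if (isZero)
                res = rec(col - 1, N);  // no row may take a 1 here
            else
            {
                res = rec(col - 1, 0);       // all rows take 0: every row drops below R
                res += N * rec(col - 1, 1);  // one of N rows takes 1 and stays equal to R
            }
        }
        else if (numEqualRows == 1)
        {
            if (isZero)
                // all 0, or one of the N-1 smaller rows takes the 1; the equal row stays equal
                res = rec(col - 1, 1) + (N - 1) * rec(col - 1, 1);
            else
                // equal row takes the 1 / all 0 / one of the N-1 smaller rows takes the 1
                res = rec(col - 1, 1) + rec(col - 1, 0) + (N - 1) * rec(col - 1, 0);
        }
        else if (numEqualRows == 0)
        {
            res = rec(col - 1, 0);       // all rows take 0
            res += N * rec(col - 1, 0);  // one of N rows takes the 1
        }
    }
    res %= mod;
    return res;
}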
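// Parameter names avoid _N/_R: identifiers starting with an underscore and an uppercase letter are reserved.
int YetAnotherORProblem2::countSequences(int n, int r) {
    N = n;
    R = r;
    memset(mem, -1, sizeof(mem));  // every byte 0xFF, so every ll entry reads as -1
    // 14 columns cover bits 13..0 (enough for R <= 15000); all N rows start equal to R.
    return rec(14, N);
}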
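To cross-check the solution on tiny inputs, a brute-force counter helps. This is my own sketch, not part of the submission that passed; bruteForce is a hypothetical helper that enumerates all (R+1)^N sequences, so it is only practical for very small N and R.

```cpp
#include <cassert>
#include <vector>

// Brute-force reference: enumerates every sequence A_0..A_{N-1} with 0 <= A_i <= R
// and counts those whose elements never share a bit (equivalently, sum == OR).
long long bruteForce(int N, int R) {
    assert(N >= 1 && R >= 0);
    std::vector<int> a(N, 0);
    long long count = 0;
    while (true) {
        int seen = 0;
        bool ok = true;
        for (int x : a) {
            if (seen & x) { ok = false; break; }  // shared bit: violates Rule #1
            seen |= x;
        }
        if (ok) count++;
        // Advance `a` like an odometer in base R + 1.
        int i = 0;
        while (i < N && a[i] == R) {
            a[i] = 0;
            i++;
        }
        if (i == N) break;  // wrapped around: every sequence has been visited
        a[i]++;
    }
    return count;
}
```

For N = 2 and R = 2 it counts 7 sequences: (0,0), (0,1), (0,2), (1,0), (2,0), (1,2), (2,1).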