关于DP的优化方法有很多种,低级的有矩阵快速幂,高级一点的比如四边形不等式优化、斜率优化等等。
因为在动态规划中,有这样的一类问题
状态转移方程 dp[i][j]=min{dp[i][k-1]+dp[k][j]}+w[i][j] (k>i&&k<=j) 时间复杂度为 O(n^3)
且有如下一些定义和定理:
如果一个函数w[i][j],满足 w[i][j]+w[i'][j']<=w[i][j']+w[i'][j]( i<=i'<=j<=j') 则称w满足凸四边形不等式
如果一个函数w[i][j],满足 w[i'][j]<=w[i][j'] ( i<=i'<=j<=j' )则称w关于区间包含关系单调
定理1:如果w同时满足四边形不等式和区间单调关系,则dp也满足四边形不等式
定理2:如果定理1条件满足时让dp[i][j]取最小值的k为K[i][j],则K[i][j-1]<=K[i][j]<=K[i+1][j]
注:定理2是四边形不等式优化的关键所在,它说明了决策量具有单调性,然后我们可以据此来缩小决策枚举的区间,进行优化
定理3:w为凸当且仅当 w[i][j]+w[i+1][j+1]<=w[i+1][j]+w[i][j+1]
注:定理3其实就是验证w是否为凸的方法,就是固定一个变量,然后看成是一个一元函数,进而判断单调性。
如,我们可以固定j算出w[i][j+1]-w[i][j]关于i的表达式,看它是关于i递增还是递减,如果是递减,则w为凸
以上三个定理来自于黑书,具体证明过程太啰嗦,各种乱走符号,所以懒得贴上来,记住结论会用就差不多了。
这种优化方法的一般步骤是:
先证明w[i][j+1]-w[i][j]关于i的表达式的单调性,如果递减,则w满足凸四边形不等式,再证明w是否同时满足区间关系单调性,
如果两条都满足,则推出dp也满足凸四边形不等式,所以状态转移方程dp[i][j]=min{dp[i][k-1]+dp[k][j]}+w[i][j] (k>i&&k<=j)中的决策量s[i][j](也就是k)满足s[i][j-1]<=s[i][j]<=s[i+1][j],因此s[i][j]的枚举区间由(i+1,j)缩小为(s[i][j-1],s[i+1,j]),使复杂度从O(n^3)下降到O(n^2)。
PS:实际操作中,我们往往并不需要进行烦躁的证明,而只需要打表,然后观察就行了
如w[i][j],dp[i][j]是否满足四边形不等式啊,w[i][j]是否单调啊,决策函数K[i][j]是否满足定理2的不等式关系啊,都可以通过打表来搞
抽象的理论太烦躁,看一下具体的例子。。。。
HDU2829
Lawrence
Time Limit: 2000/1000 MS (Java/Others) Memory Limit: 32768/32768 K (Java/Others)Total Submission(s): 2202 Accepted Submission(s): 965
容易求出dp方程:f[m][n] = Min{f[m-1][k]+w[k+1][n]}
令w[i][j](i<=j)表示(a[i]*a[i+1]+a[i]*a[i+2]+....+a[i]*a[j])+(a[i+1]*a[i+2]+a[i+1]=a[i+3]+....)+....+a[j-1]*a[j]——(1)
则(1)式可以表示为[(a[i]+a[i+1]+...+a[j])*(a[i]+a[i+1]+...+a[j])-a[i]^2-a[i+1]^2-....-a[j]^2] / 2
即(1) = [(sum1[j]-sum1[i-1])^2-(sum2[j]-sum2[i-1])]/2。
其中sum1[i]表示从0到i元素的和,sum2[i]表示从0到i元素的平方的和,可以用O(n)的时间预处理。
现在我们固定j,看一看函数 w[i][j+1]-w[i][j] 随i增加的增减性,很明显如果i加1的话,w[i][j+1]比w[i][j]减少的要多,所以该函数递减,所以w为凸;
而区间包含关系是很显然的,所以w函数满足四边形不等式,所以可以推得f[m][n]也满足四边形不等式。
决策变量范围:s[m-1][n]<=s[m][n]<=s[m][n+1]。
1 #include2 #include 3 #include 4 usingnamespace std; 5 6 constint N =1010; 7 8 typedef longlong llg; 9 10 int n, m, a[N], sum1[N], sum2[N], s[N][N];11 llg f[N][N];12 13 void dp()14 {15 int i, j, k, a, b;16 llg tmp;17 memset(f, -1, sizeof(f));18 for(i =0; i <= n; i++)19 {20 tmp = sum1[i];21 f[0][i] = (tmp*tmp - sum2[i]) /2;22 s[0][i] =0;23 }24 for(i =1; i <= m; i++)25 for(j = n; j >= i; j--)26 {27 s[i][n+1] = n;28 a = s[i-1][j];29 b = s[i][j+1];30 for(k = a; k <= b; k++)31 {32 tmp = sum1[j] - sum1[k];33 tmp = f[i-1][k]+(tmp*tmp - (sum2[j]-sum2[k])) /2;34 if(f[i][j]==-1|| f[i][j]>tmp)35 {36 f[i][j] = tmp;37 s[i][j] = k;38 }39 }40 }41 }42 43 int main()44 {45 int i;46 while(scanf("%d%d", &n, &m) != EOF)47 {48 if(n==0&& m==0) break;49 sum1[0] = sum2[0] =0;50 for(i =1; i <= n; i++)51 {52 scanf("%d", a+i);53 sum1[i] = sum1[i-1] + a[i];54 sum2[i] = sum2[i-1] + a[i]*a[i];55 }56 dp();57 printf("%lld\n", f[m][n]);58 }59 return0;60 }