星星博客 »  > 

Codeforces Round #721 (Div. 2) E. Partition Game 线段树优化DP

原题链接:https://codeforces.ml/contest/1527/problem/E

题意

给了一个长度为n的序列,其中 C o s t ( t ) = ∑ x ∈ s e t ( t ) l a s t ( x ) − f i r s t ( x ) Cost(t)=\sum_{x∈set(t)}last(x)-first(x) Cost(t)=xset(t)last(x)first(x),我们可以将序列分成k段,问minCost是多少

分析

不难想到dp的状态 d p [ i ] [ j ] dp[i][j] dp[i][j]代表前i个数分成j组时的最小花费,然后先推出一个暴力的DP方程

d p [ i ] [ j ] = m i n ( d [ k ] [ j − 1 ] + v a l ( k + 1 , i ) ) k ∈ [ 0 , j − 1 ] dp[i][j] = min(d[k][j-1]+val(k+1,i))k∈[0, j-1] dp[i][j]=min(d[k][j1]+val(k+1,i))k[0,j1]

v a l ( i , j ) 代 表 从 [ i , j ] 区 间 的 花 费 val(i,j)代表从[i,j]区间的花费 val(i,j)[i,j]

如果直接暴力去找肯定是超时的,这时候就可以用数据结构去优化DP。首先考虑怎么算区间内的花费,我们记录一个last[x]表示这个数前一次出现的位置,然后存入 i − l a s t [ x ] i-last[x] ilast[x],统计区间和,这样就可以算出每个数最晚出现和最早出现之差,但这样是有问题的,因为有些数的 l a s t [ x ] < k + 1 last[x]<k+1 last[x]<k+1,这样的值对于区域是没有贡献的,因此我们倒过来考虑,去累加当前x对哪些区间有影响。 l a s t [ x ] > = k + 1 last[x]>=k+1 last[x]>=k+1推出 k < = l a s t [ x ] − 1 k<=last[x]-1 k<=last[x]1,也就是说当前i对 [ 0 , l a s t [ x ] − 1 ] [0,last[x]-1] [0,last[x]1] i − l a s t [ x ] i-last[x] ilast[x]的贡献。

Code

#include <bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define re register
typedef long long ll;
typedef pair<ll, ll> PII;
typedef unsigned long long ull;
const int N = 35005 + 20, M = 1e6 + 5, INF = 0x3f3f3f3f;
const int MOD = 1e9+9;
int dp[N];
int a[N], last[N], pre[N];
struct node {
    int l, r;
    int sum, tag;
}t[N<<2];
void push_up(int u) {
    t[u].sum = min(t[u<<1].sum, t[u<<1|1].sum);
}
void push_down(int u) {
    if (t[u].tag) {
        t[u<<1].tag += t[u].tag;
        t[u<<1|1].tag += t[u].tag;
        t[u<<1].sum += t[u].tag;
        t[u<<1|1].sum += t[u].tag;
        t[u].tag = 0;
    }
}
void build(int u, int l, int r) {
    t[u].l = l, t[u].r = r, t[u].tag = 0, t[u].sum = INF;
    if (l == r) {
        t[u].sum = dp[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(u<<1, l, mid);
    build(u<<1|1, mid+1, r);
    push_up(u);
}
void modify(int u, int ql, int qr, int val) {
    if (ql <= t[u].l && qr >= t[u].r) {
        t[u].sum += val;
        t[u].tag += val;
        return;
    }
    push_down(u);
    int mid = (t[u].l + t[u].r) >> 1;
    if (ql <= mid) modify(u<<1, ql, qr, val);
    if (qr > mid) modify(u<<1|1, ql, qr, val);
    push_up(u);
}
int query(int u, int ql, int qr) {
    if (ql <= t[u].l && qr >= t[u].r) return t[u].sum;
    int mid = (t[u].l + t[u].r) >> 1;
    int ans = INF;
    push_down(u);
    if (ql <= mid) ans = min(ans, query(u<<1, ql, qr));
    if (qr > mid) ans = min(ans, query(u<<1|1, ql, qr));
    return ans;
}
void solve() {
    int n, k; cin >> n >> k;
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
        if (!last[a[i]]) pre[i] = i;
        else pre[i] = last[a[i]];
        last[a[i]] = i;
    }
    memset(dp, 0x3f, sizeof dp);
    dp[0] = 0;
    for (int i = 1; i <= k; i++) {
        build(1, 0, n);
        for (int j = 1; j <= n; j++) {
            int p = pre[j];
            int val = j - p;
            modify(1, 0, p-1, val);
            dp[j] = query(1, 0, j-1);
        }
    }
    cout << dp[n] << endl;
}

signed main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
#ifdef ACM_LOCAL
    freopen("input", "r", stdin);
    freopen("output", "w", stdout);
#endif
    solve();
}

相关文章