牛客 上海理工L:捡贝壳 状压线段树+节点合并

题目省流:长度 nn 的序列,每个值表示成 aiaiaiai 有且仅有四种 (ai[1,4]ai∈[1,4])。mm 次操作或询问,操作 11[l,r][l,r] 赋值成同一个值 xx,操作 22 询问 [l,r][l,r] 内是否存在最短的一段连续区间,其中包含四种 aiai,有则输出长度,否则输出 1-1

10
1 2 3 4 1 2 3 4 1 2
6
1 3 5 2
1 4 6 3
2 2 10
2 7 10
2 6 9
2 6 10
4
4
-1
4

思路:

考虑状压

首先区间问题,容易想到是否可以用线段树维护,但是显而易见的是,线段树维护的信息需要满足区间可加性,仅仅知道区间里的贝壳种类显然不够,我们更趋向于知道到底有哪些贝壳。

发现 aiai 的取值并不大,为了记录当前手上有几种贝壳种类,可以采用状态压缩的思想,区间合并的时候也容易,只需要将左右儿子的状压值起来就行。现在考虑的是,我们如何维护走几步的问题,类比“你能回答这些问题吗”问题(链接在后面),注意到答案必然只有三种可能: alt 维护的时候,我们注意到仅仅用一个值,表示一个方向过来的价值并不够,我们更趋向于不仅要知道走多少步,还要知道他手上的贝壳实际的种类,考虑从状压入手。我们不妨在每个节点里面申请两个数组 slslsrsrsl[state]sl[state] 表示从左往右走,得到的贝壳种类表示成 statestate 的最小步数,其中 statestate 是状态压缩后的值,同时我们还需要快速的知道,当前区间里所拥有的贝壳具体种类,我们用 hashas 表示。(srsr 类似)

因为 ai[1,4]ai∈[1,4],为了方便记录,我们对 ai1ai-1 再映射成二进制位。

struct seg {
	
	int l, r;
	int sl[17], sr[17]; // 四种贝壳, 最多 (1111)2
	
	int has, len, ans;
	mutable int lazy; // 避免冲突 -1表示没有懒标记 
	
	seg() {}
	
	seg (int x, int l_, int r_, int lz=-1) {
		
		memset(sl, 0x3f, sizeof sl); // 因为一开始没有,所以我们要初始化Inf
		memset(sr, 0x3f, sizeof sr);
		
		has = 1<<x;	// 具体种类
		
		l = l_, r = r_, len = r-l+1;
		lazy = lz;
		sl[has] = sr[has] = 1;
		ans = Inf;
	}
	
};

合并节点

维护 slslsrsr 显然是重头戏。合并节点的时候,除了直接拿来左右节点的状压数组,发现左节点提供的贡献有:左节点整块+右节点左边的散块,画图表示为:(右节点同理) alt

retret 表示合并后的节点,leftleft 表示左节点,rightright 表示右节点,即为(右边同理)

for (int i=0;i < 16;i ++ ) {
  	ret.sl[i] = min(ret.sl[i], sl[i]);
	ret.sl[has|i] = std::min(ret.sl[has|i], len+rhs.sl[i]);
}

枚举答案的时候,手上一定是拥有了全部贝壳,即为 (1111)2=15(1111)_2=15,我们暴力枚举左右儿子对 1515 的贡献 iijj,当有 (i|j)==15 意义着我们拿到了所有贝壳。iijj 允许为 00,当 ii00 时,等价右节点从左往右走的答案,当 jj00 时,等价于左节点从左往右走的答案,但是还少了左节点往右,右节点往左的贡献,所以我们在一开始需要对左右节点的答案 ansansminmin。为了方便,我们将合并节点的 pushup写成了重载+号的写法。

seg operator+(const seg& rhs) const {
		
	seg ret;
		
	memset(ret.sl, 0x3f, sizeof ret.sl);
	memset(ret.sr, 0x3f, sizeof ret.sr);
		
	ret.ans = Inf;
	ret.lazy = -1; // -1
	ret.l = l, ret.r = rhs.r;
	ret.len = (ret.r-ret.l+1);
	ret.has = has | rhs.has;
		
	for (int i=0;i < 16;i ++ ) {
		ret.sl[i] = std::min(ret.sl[i], sl[i]);
		ret.sr[i] = std::min(ret.sr[i], rhs.sr[i]);
			
		ret.sl[has|i] = std::min(ret.sl[has|i], len+rhs.sl[i]);
		// 注意这里是 rhs.has 调了好久 T^T 
		ret.sr[rhs.has|i] = std::min(ret.sr[rhs.has|i], sr[i]+rhs.len);
	}
	 
	ret.ans = std::min({ret.ans, ans, rhs.ans});
		
	for (int i=0;i < 16;i ++ ) {
		for (int j=0;j < 16;j ++ ) {
			if ((i|j) == 15) ret.ans = std::min(ret.ans, sr[i]+rhs.sl[j]);
		}
	} 
        
    return ret;
}

其他地方需要注意的就是 modify 的时候,递归到满足 l<=tr[u].l && tr[u].r <= r 的时候,我们直接将 tr[u] = seg (x, tr[u].l, tr[u].r, x)

因为 seg()seg+seg 我们都默认将 lazy 清零,所以每次操作节点 u 前都需要 pushdown(u)

void modify(int u, int l, int r, int x) {

	pushdown(u);
	
	if (l <= tr[u].l && tr[u].r <= r) {
		int trl = tr[u].l, trr = tr[u].r;
		tr[u] = seg (x, trl, trr, x);
		return ;
	}
	
	int mid = (tr[u].l+tr[u].r)/2;
	if (l <= mid) modify(u<<1, l, r, x);
	if (r > mid) modify(u<<1|1, l, r, x);
	
	pushup(u); 
}

CodingCoding 时间到!

#include <iostream>
#include <vector>
#include <cstring>
#include <assert.h>
#include <algorithm>
#define int long long

using i64 = long long ;

constexpr int N = 1e5 + 12, M = 17;
constexpr int Inf = 2e9;

int n, m;
std::vector<int> w(N);

struct seg {
	
	int l, r;
	int sl[17], sr[17];
	
	int has, len, ans;
	mutable int lazy; // 避免冲突 -1表示没有懒标记 
	
	seg() {}
	
	seg (int x, int l_, int r_, int lz=-1) {
		
		memset(sl, 0x3f, sizeof sl);
		memset(sr, 0x3f, sizeof sr);
		sl[0] = sr[0] = 0; 
		
		has = 1<<x;
		if (has > 15) assert(0);
		
		l = l_, r = r_, len = r-l+1;
		lazy = lz;
		sl[has] = sr[has] = 1;
		ans = Inf;
	}

	seg operator+(const seg& rhs) const {
		
		seg ret;
		
		memset(ret.sl, 0x3f, sizeof ret.sl);
		memset(ret.sr, 0x3f, sizeof ret.sr);
		
		ret.ans = Inf;
		ret.lazy = -1; // -1
		ret.l = l, ret.r = rhs.r;
		ret.len = (ret.r-ret.l+1);
		ret.has = has | rhs.has;
		
		for (int i=0;i < 16;i ++ ) {
			ret.sl[i] = std::min(ret.sl[i], sl[i]);
			ret.sr[i] = std::min(ret.sr[i], rhs.sr[i]);
			
			ret.sl[has|i] = std::min(ret.sl[has|i], len+rhs.sl[i]);
			// 注意这里是 rhs.has 调了好久 T^T 
			ret.sr[rhs.has|i] = std::min(ret.sr[rhs.has|i], sr[i]+rhs.len);
		}
		
		ret.ans = std::min(ret.sl[15], ret.sr[15]); 
		ret.ans = std::min({ret.ans, ans, rhs.ans});
		
		for (int i=0;i < 16;i ++ ) {
			for (int j=0;j < 16;j ++ ) {
				if ((i|j) == 15) ret.ans = std::min(ret.ans, sr[i]+rhs.sl[j]);
			}
		} 
        
        return ret;
	} 
	
}tr[N<<2];

inline void pushup(int u) {
	tr[u] = tr[u<<1] + tr[u<<1|1];
}

inline void pushdown(int p) {
	
	seg& u = tr[p], &l = tr[p<<1], &r=tr[p<<1|1];
	
	if (~u.lazy) {
		
        int mid = (u.l+u.r)/2;
		l = seg (u.lazy, u.l, mid);
		r = seg (u.lazy, mid+1, u.r);
		
		l.lazy = r.lazy = u.lazy;
		u.lazy = -1;	
	}	
}

void build(int u, int l, int r) {
	
	tr[u].l = l, tr[u].r = r;
	if (l == r) {
		return tr[u] = seg (w[l], l, r), void();
	}
	
	int mid = (tr[u].r+tr[u].l)/2;
	build(u<<1, l, mid);
	build(u<<1|1, mid+1, r);	
	pushup(u);
}

void modify(int u, int l, int r, int x) {

	pushdown(u);
	
	if (l <= tr[u].l && tr[u].r <= r) {
		int trl = tr[u].l, trr = tr[u].r;
		tr[u] = seg (x, trl, trr, x);
		return ;
	}
	
	int mid = (tr[u].l+tr[u].r)/2;
	if (l <= mid) modify(u<<1, l, r, x);
	if (r > mid) modify(u<<1|1, l, r, x);
	
	pushup(u); 
}

seg query(int u, int l, int r) {
	
	//std::cerr<<u<<" "<<tr[u].l<<' '<<tr[u].r<<" "<<l<<" "<<r<<"\n";
	pushdown(u);
	
	if (l <= tr[u].l && tr[u].r <= r) {
		return tr[u];
	} 
	
	int mid = (tr[u].l + tr[u].r)/2;
	if (r <= mid) return query(u<<1, l, r);
	else if (l > mid) return query(u<<1|1, l, r);
	else {
		return query(u<<1, l, r) + query(u<<1|1, l, r);	
	}
} 

int32_t main() {
	
	std::ios::sync_with_stdio(false);
	std::cin.tie(nullptr);
	
	std::cin>>n;
	for (int i=1;i <= n;i ++ ) {
		std::cin>>w[i], w[i] -- ;
	}
	
	build(1, 1, n);

	std::cin >> m;
	for (int i=1;i <= m;i ++ ) {
		int op, l, r; std::cin>>op>>l>>r;
		
		if (op&1) {
			int x; std::cin>>x, x -- ;
			modify(1, l, r, x);
		} else {
			auto ans = query(1, l, r);
			std::cout<<(ans.ans<=n?ans.ans:-1)<<"\n"; 
		}
	}
	
	return 0;
} 

类似的题目:

你能回答这些问题吗,一道典题,

https://www.acwing.com/problem/content/246/