A tree is a simple graph in which every two vertices are connected by exactly one path. You are given a rooted tree with $N$ vertices, and a lamp is placed on each vertex of the tree.
You are given $Q$ queries of the following two types:
- `1 v`: Switch the lamp placed on vertex $v$, that is, turn it either from On to Off or from Off to On.
- `2 v`: Determine the number of vertices in the subtree of $v$ that are connected to $v$ when only the lamps in the On state are considered. In other words, count the vertices $u$ in the subtree of $v$ such that $u$ can reach $v$ using only vertices whose lamps are On (if the lamp at $v$ itself is Off, the answer is $0$).
Note: Initially, all the lamps are On, and the tree is rooted at vertex $1$.
Example (the first line holds $N$ and $Q$, followed by $N-1$ edges and then $Q$ queries):

Input
5 4
1 2
2 3
1 4
4 5
1 3
2 2
1 3
2 2

Output
1
2
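Approach

Flatten the tree with an Euler tour, so that the subtree of every vertex $v$ occupies a contiguous index range $[st_v, ft_v]$. For every vertex $u$, let $c_u$ be the number of Off lamps on the path from the root to $u$ (inclusive). Store these counters in a segment tree that supports range addition and a query returning both the minimum and how many times it occurs.

- Switching the lamp at $v$ adds $+1$ (On → Off) or $-1$ (Off → On) to the range of $v$.
- For a query on an On vertex $v$, every $u$ in its subtree satisfies $c_u \ge c_v$. Also, $u$ can reach $v$ through On vertices exactly when $c_u = c_v$. Hence the answer is the number of positions in the range of $v$ that attain the minimum.
- If $v$ itself is Off, the answer is $0$.

Each operation takes $O(\log N)$ time.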
Java
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.io.UncheckedIOException;
import java.util.ArrayList;

/**
 * Answers lamp-toggle and connected-subtree-size queries on a rooted tree.
 *
 * <p>The tree is flattened with an Euler tour so every subtree becomes the contiguous index range
 * [st[v], ft[v]]. A segment tree over those indices stores, for every vertex u, the number of Off
 * lamps on the path from the root to u (inclusive). Toggling the lamp at v adds +1 or -1 to the
 * whole subtree range of v. For an On vertex v, every counter in its subtree is at least v's own
 * counter, and exactly the vertices with an equal counter reach v through On vertices, so the
 * answer is the number of positions holding the range minimum.
 */
public class TreeQuery {
    /** Adjacency lists of the tree, 0-indexed. */
    static ArrayList<Integer>[] adj;
    /** Euler-tour index at which each vertex is entered. */
    static int[] st;
    /** Largest Euler-tour index inside each vertex's subtree. */
    static int[] ft;
    /** Lamp state per vertex: 0 = On (initial state), 1 = Off. */
    static int[] is;
    /** Segment tree: number of positions in the node's range attaining its minimum. */
    static int[] t;
    /** Segment tree: minimum counter in the node's range, including the node's own pending add. */
    static int[] mn;
    /** Segment tree: amount added to the node's whole range that is not pushed to its children. */
    static int[] add;
    /** Last Euler-tour index handed out by {@link #dfs}. */
    static int next = -1;

    public static void main(String[] args) {
        // Runs on a thread with a large stack. dfs is iterative now, so this is only a safety
        // margin for the O(log n)-deep segment-tree recursion.
        new Thread(null, TreeQuery::solve, "1", 1 << 26).start();
    }

    /**
     * Reads one test case from standard input and prints one answer line per type-2 query.
     *
     * <p>Input: n and q, then n - 1 edges "u v" (1-indexed), then q queries "a b". a = 1 toggles the
     * lamp at b. Any other a prints how many vertices of b's subtree reach b through On vertices
     * (0 if b is Off). All state is reallocated on each call, so the method can be called more
     * than once.
     */
    public static void solve() {
        FastIO f = new FastIO();
        PrintWriter pw = new PrintWriter(System.out);
        int n = f.nextInt();
        int q = f.nextInt();
        init(n);
        for (int i = 0; i < n - 1; i++) {
            int u = f.nextInt() - 1;
            int v = f.nextInt() - 1;
            adj[u].add(v);
            adj[v].add(u);
        }
        dfs(0, -1);
        build(0, n - 1, 1);
        while (q-- > 0) {
            int a = f.nextInt();
            int b = f.nextInt() - 1;
            if (a == 1) {
                // Turning a lamp Off raises every counter in its subtree; turning it On lowers them.
                add(st[b], ft[b], is[b] == 0 ? 1 : -1, 0, n - 1, 1);
                is[b] ^= 1;
            } else if (is[b] == 1) {
                // b itself is Off, so no vertex can reach it.
                pw.println(0);
            } else {
                pw.println(get(st[b], ft[b], 0, n - 1, 1).b);
            }
        }
        pw.flush();
    }

    /** Allocates and resets all per-test-case state for a tree with n vertices. */
    @SuppressWarnings("unchecked")
    private static void init(int n) {
        adj = new ArrayList[n];
        for (int i = 0; i < n; i++) {
            adj[i] = new ArrayList<>();
        }
        st = new int[n];
        ft = new int[n];
        is = new int[n];
        // Recursive segment tree on [0, n - 1] (root 1, children 2v and 2v + 1) uses node
        // indices below 2 * (smallest power of two >= n).
        int pow2 = 1;
        while (pow2 < n) {
            pow2 <<= 1;
        }
        t = new int[2 * pow2];
        mn = new int[2 * pow2];
        add = new int[2 * pow2];
        next = -1;
    }

    /**
     * Queries node v, which covers [s, e], for the minimum counter over [l, r].
     *
     * @return a Pair whose a is the minimum and b the number of positions attaining it. When
     *     [l, r] and [s, e] are disjoint, a is Integer.MAX_VALUE and b is 0.
     */
    private static Pair get(int l, int r, int s, int e, int v) {
        Pair p = new Pair();
        if (l <= s && e <= r) {
            p.a = mn[v];
            p.b = t[v];
            return p;
        }
        if (r < s || e < l) {
            return p;
        }
        int m = (s + e) / 2;
        Pair p1 = get(l, r, s, m, 2 * v);
        Pair p2 = get(l, r, m + 1, e, 2 * v + 1);
        p.a = Math.min(p1.a, p2.a);
        if (p.a == p1.a) {
            p.b += p1.b;
        }
        if (p.a == p2.a) {
            p.b += p2.b;
        }
        // Child results exclude this node's pending addition. A partial overlap guarantees that
        // at least one child is non-empty, so p.a is finite here.
        p.a += add[v];
        return p;
    }

    /**
     * Assigns Euler-tour indices (st entry index, ft last index) to the subtree rooted at curr,
     * whose parent is parent (-1 for the root).
     *
     * <p>Iterative, so tree depth is not limited by the thread's stack size. Children are visited
     * in adjacency-list order, giving the same numbering as a recursive pre-order traversal.
     */
    private static void dfs(int curr, int parent) {
        int n = adj.length;
        int[] parentOf = new int[n];
        int[] edgeIdx = new int[n]; // next adjacency position to examine, per vertex
        int[] stack = new int[n];
        int top = 0;
        parentOf[curr] = parent;
        st[curr] = ++next;
        stack[top++] = curr;
        while (top > 0) {
            int v = stack[top - 1];
            if (edgeIdx[v] < adj[v].size()) {
                int u = adj[v].get(edgeIdx[v]++);
                if (u != parentOf[v]) {
                    parentOf[u] = v;
                    st[u] = ++next;
                    stack[top++] = u;
                }
            } else {
                // All children finished: the subtree ends at the last index handed out.
                ft[v] = next;
                top--;
            }
        }
    }

    /**
     * Initializes node v covering [s, e]. All counters start at 0 (mn and add are
     * zero-initialized), so every one of the e - s + 1 positions attains the minimum.
     */
    static void build(int s, int e, int v) {
        t[v] = e - s + 1;
        if (s < e) {
            int mid = (s + e) / 2;
            build(s, mid, 2 * v);
            build(mid + 1, e, 2 * v + 1);
        }
    }

    /**
     * Adds val to every counter in [l, r] under node v, which covers [s, e], and refreshes the
     * node's minimum and minimum count. (The method shares its name with the {@code add} array.)
     */
    static void add(int l, int r, int val, int s, int e, int v) {
        if (l <= s && e <= r) {
            mn[v] += val;
            add[v] += val;
            return;
        }
        if (r < s || e < l) {
            return;
        }
        int m = (s + e) / 2;
        add(l, r, val, s, m, 2 * v);
        add(l, r, val, m + 1, e, 2 * v + 1);
        mn[v] = Math.min(mn[2 * v], mn[2 * v + 1]);
        t[v] = 0;
        if (mn[v] == mn[2 * v]) {
            t[v] += t[2 * v];
        }
        if (mn[v] == mn[2 * v + 1]) {
            t[v] += t[2 * v + 1];
        }
        // Re-apply this node's own pending addition on top of the children's minimum.
        mn[v] += add[v];
    }
}

/** Minimal buffered reader for whitespace-separated tokens from an InputStream. */
class FastIO {
    private final InputStream is;
    private final byte[] buf = new byte[1024];
    private int curChar;
    private int numChars;

    /** Reads from standard input. */
    public FastIO() {
        this(System.in);
    }

    public FastIO(final InputStream is) {
        this.is = is;
    }

    /** Reads n integers into a new array. */
    public int[] nextArray(final int n) {
        final int[] a = new int[n];
        for (int i = 0; i < n; i++) {
            a[i] = nextInt();
        }
        return a;
    }

    /**
     * Returns the next byte as a value in 0..255, or -1 at end of input.
     *
     * @throws RuntimeException if called again after the stream reported end of input
     * @throws UncheckedIOException if the underlying stream fails
     */
    public int read() {
        if (numChars == -1) {
            throw new RuntimeException("read past end of input");
        }
        if (curChar >= numChars) {
            curChar = 0;
            try {
                numChars = is.read(buf);
            } catch (final IOException e) {
                throw new UncheckedIOException(e);
            }
            if (numChars <= 0) {
                return -1;
            }
        }
        // Mask so byte 0xFF is not sign-extended into -1 and mistaken for end of input.
        return buf[curChar++] & 0xFF;
    }

    public String nextLine() {
        return readLine();
    }

    /** Skips leading whitespace, then returns characters up to the end of the line. */
    public String readLine() {
        int c = read();
        while (isSpaceChar(c)) {
            c = read();
        }
        final StringBuilder sb = new StringBuilder();
        do {
            sb.append((char) c);
            c = read();
        } while (!isEndOfLine(c));
        return sb.toString();
    }

    /** Returns the next whitespace-delimited token. */
    public String next() {
        int c = read();
        while (isSpaceChar(c)) {
            c = read();
        }
        final StringBuilder sb = new StringBuilder();
        do {
            sb.append((char) c);
            c = read();
        } while (!isSpaceChar(c));
        return sb.toString();
    }

    /** Parses the next token as an optionally negative decimal long (no validation). */
    public long nextLong() {
        int c = read();
        while (isSpaceChar(c)) {
            c = read();
        }
        int sgn = 1;
        if (c == '-') {
            sgn = -1;
            c = read();
        }
        long res = 0;
        do {
            res *= 10;
            res += c - '0';
            c = read();
        } while (!isSpaceChar(c));
        return res * sgn;
    }

    /** Parses the next token as an optionally negative decimal int (no validation). */
    public int nextInt() {
        int c = read();
        while (isSpaceChar(c)) {
            c = read();
        }
        int sgn = 1;
        if (c == '-') {
            sgn = -1;
            c = read();
        }
        int res = 0;
        do {
            res *= 10;
            res += c - '0';
            c = read();
        } while (!isSpaceChar(c));
        return res * sgn;
    }

    /** True for ASCII whitespace and for the end-of-input marker -1. */
    public boolean isSpaceChar(final int c) {
        return (c == ' ') || (c == '\n') || (c == '\r') || (c == '\t') || (c == -1);
    }

    /** True for line terminators and for the end-of-input marker -1. */
    public boolean isEndOfLine(final int c) {
        return (c == '\n') || (c == '\r') || (c == -1);
    }
}

/** Result of a segment-tree range query: minimum value a and its multiplicity b. */
class Pair {
    /** Minimum value; Integer.MAX_VALUE stands for an empty range. */
    int a = Integer.MAX_VALUE;
    /** Number of positions attaining a. */
    int b = 0;
}
No comments:
Post a Comment