Tree query

A tree is a simple graph in which every two vertices are connected by exactly one path. You are given a rooted tree with n vertices and a lamp is placed on each vertex of the tree. 

You are given q queries of the following two types:

  • 1 v: You switch the lamp placed on the vertex v, that is, either from On to Off or Off to On.
  • 2 v: Determine the number of vertices connected to the subtree of v if you only consider the lamps that are in On state. In other words, determine the number of vertices in the subtree of v, such as u, that can reach from v by using only the vertices that have lamps in the On state.

Note: Initially, all the lamps are turned On and the tree is rooted from vertex number 1.


Example: Input

5 4 1 2 2 3 1 4 4 5 1 3 2 2 1 3 2 2

Output

1
2

Approach

Java


import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.ArrayList;

public class TreeQuery {
    static int max = 5_00_001;
    static ArrayList<Integer>[] adj;
    static int[] st = new int[max];
    static int[] ft = new int[max];
    static int[] is = new int[max];
    static int max2 = 2 * (1 << (intMath.ceil(Math.log(max) / Math.log(2)));
    static int[] t = new int[max2];
    static int[] mn = new int[max2];
    static int[] add = new int[max2];
    static int next = -1;

    public static void main(String[] args) {
        new Thread(null, TreeQuery::solve, "1"1 << 26).start();
    }

    public static void solve() {
        FastIO f = new FastIO();
        PrintWriter pw = new PrintWriter(System.out);
        int n = f.nextInt();
        int q = f.nextInt();
        adj = new ArrayList[n];
        for (int i = 0; i < n; i++) {
            adj[i] = new ArrayList<>();
        }
        for (int i = 0; i < n - 1; i++) {
            int u = f.nextInt() - 1;
            int v = f.nextInt() - 1;
            adj[u].add(v);
            adj[v].add(u);
        }
        dfs(0, -1);
        build(0, n - 11);
        while (q-- > 0) {
            int a = f.nextInt();
            int b = f.nextInt() - 1;
            if (a == 1) {
                add(st[b], ft[b], is[b] == 0 ? 1 : -10, n - 11);
                is[b] ^= 1;
            } else {
                if (is[b] == 1) {
                    pw.println(0);
                } else {
                    pw.println(get(st[b], ft[b], 0, n - 11).b);

                }
            }
        }
        pw.flush();
    }

    private static Pair get(int lint rint sint eint v) {
        Pair p = new Pair();
        if (l <= s && e <= r) {
            p.a = mn[v];
            p.b = t[v];
            return p;
        }
        if (r < s || e < l)
            return p;
        int m = (s + e) / 2;
        Pair p1 = get(l, r, s, m, 2 * v);
        Pair p2 = get(l, r, m + 1, e, 2 * v + 1);
        p.a = Math.min(p1.ap2.a);
        if (p.a == p1.a)
            p.b += p1.b;
        if (p.a == p2.a)
            p.b += p2.b;
        p.a += add[v];
        return p;
    }

    private static void dfs(int currint parent) {
        st[curr] = ++next;
        for (int i = 0; i < adj[curr].size(); i++) {
            if (adj[curr].get(i) != parent) {
                dfs(adj[curr].get(i), curr);
            }
        }
        ft[curr] = next;
    }

    static void build(int sint eint v) {
        t[v] = e - s + 1;
        if (s < e) {
            int mid = (s + e) / 2;
            build(s, mid, 2 * v);
            build(mid + 1, e, 2 * v + 1);
        }

    }

    static void add(int lint rint valint sint eint v) {
        if (l <= s && e <= r) {
            mn[v] += val;
            add[v] += val;
            return;
        }
        if (r < s || e < l) {
            return;
        }
        int m = (s + e) / 2;
        add(l, r, val, s, m, 2 * v);
        add(l, r, val, m + 1, e, 2 * v + 1);
        mn[v] = Math.min(mn[2 * v], mn[2 * v + 1]);
        t[v] = 0;
        if (mn[v] == mn[2 * v]) {
            t[v] += t[2 * v];
        }
        if (mn[v] == mn[2 * v + 1]) {
            t[v] += t[2 * v + 1];
        }
        mn[v] += add[v];
    }
}

class FastIO {
    private final InputStream is;
    private final byte[] buf = new byte[1024];
    private int curChar;

    private int numChars;

    public FastIO() {
        this(System.in);
    }

    public FastIO(final InputStream is) {
        this.is = is;
    }

    public int[] nextArray(final int n) {
        final int[] a = new int[n];
        for (int i = 0; i < n; i++) {
            a[i] = nextInt();
        }
        return a;
    }

    public int read() {
        if (numChars == -1) {
            throw new RuntimeException();
        }
        if (curChar >= numChars) {
            curChar = 0;
            try {
                numChars = is.read(buf);
            } catch (final IOException e) {
                throw new RuntimeException();
            }
            if (numChars <= 0) {
                return -1;
            }
        }
        return buf[curChar++];
    }

    public String nextLine() {
        return readLine();
    }

    public String readLine() {
        int c = read();
        while (isSpaceChar(c)) {
            c = read();
        }
        final StringBuilder sb = new StringBuilder();
        do {
            sb.append((char) c);
            c = read();
        } while (!isEndOfLine(c));
        return sb.toString();
    }

    public String next() {
        int c = read();
        while (isSpaceChar(c)) {
            c = read();
        }
        final StringBuilder sb = new StringBuilder();
        do {
            sb.append((char) c);
            c = read();
        } while (!isSpaceChar(c));
        return sb.toString();
    }

    public long nextLong() {
        int c = read();
        while (isSpaceChar(c))
            c = read();
        int sgn = 1;
        if (c == '-') {
            sgn = -1;
            c = read();
        }
        long res = 0;
        do {
            res *= 10;
            res += c - '0';
            c = read();
        } while (!isSpaceChar(c));
        return res * sgn;
    }

    public int nextInt() {
        int c = read();
        while (isSpaceChar(c))
            c = read();
        int sgn = 1;
        if (c == '-') {
            sgn = -1;
            c = read();
        }
        int res = 0;
        do {
            res *= 10;
            res += c - '0';
            c = read();
        } while (!isSpaceChar(c));
        return res * sgn;
    }

    public boolean isSpaceChar(final int c) {
        return (c == ' ') || (c == '\n') || (c == '\r') || (c == '\t') || (c == -1);
    }

    public boolean isEndOfLine(final int c) {
        return (c == '\n') || (c == '\r') || (c == -1);
    }

}

class Pair {
    int a = Integer.MAX_VALUE, b = 0;
}


No comments:

Post a Comment