LeetCode - Minimum Height Trees (Java)

Question:

For an undirected graph with tree characteristics, we can choose any node as the root. The resulting graph is then a rooted tree. Among all possible rooted trees, those with minimum height are called minimum height trees (MHTs). Given such a graph, write a function to find all the MHTs and return a list of their root labels.

Format
The graph contains n nodes which are labeled from 0 to n - 1. You will be given the number n and a list of undirected edges (each edge is a pair of labels).

You can assume that no duplicate edges will appear in edges. Since all edges are undirected, [0, 1] is the same as [1, 0] and thus will not appear together in edges.

Example 1:

Given n = 4, edges = [[1, 0], [1, 2], [1, 3]]

  0
  |
  1
 / \
2   3

return [1]

Example 2:

Given n = 6, edges = [[0, 3], [1, 3], [2, 3], [4, 3], [5, 4]]

0  1  2
 \ | /
   3
   |
   4
   |
   5

return [3, 4]

Hint:

How many MHTs can a graph have at most?
Note:

(1) According to the definition of tree on Wikipedia: “a tree is an undirected graph in which any two vertices are connected by exactly one path. In other words, any connected graph without simple cycles is a tree.”

(2) The height of a rooted tree is the number of edges on the longest downward path between the root and a leaf.
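To make definition (2) concrete, here is a minimal sketch that measures the height of the tree rooted at a chosen node: a BFS assigns every node its depth (number of edges from the root), and the height is the largest depth. The class name TreeHeight and the method height are hypothetical helpers, not part of the LeetCode signature.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class TreeHeight {
    // Height of the tree rooted at `root`: the number of edges on the
    // longest downward path, i.e. the largest BFS depth.
    static int height(int n, int[][] edges, int root) {
        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        for (int[] e : edges) {
            adj.get(e[0]).add(e[1]);
            adj.get(e[1]).add(e[0]);
        }
        int[] depth = new int[n];
        Arrays.fill(depth, -1);          // -1 marks "not visited yet"
        depth[root] = 0;
        ArrayDeque<Integer> queue = new ArrayDeque<>();
        queue.add(root);
        int maxDepth = 0;
        while (!queue.isEmpty()) {
            int u = queue.poll();
            maxDepth = Math.max(maxDepth, depth[u]);
            for (int v : adj.get(u)) {
                if (depth[v] == -1) {    // each child is one edge deeper
                    depth[v] = depth[u] + 1;
                    queue.add(v);
                }
            }
        }
        return maxDepth;
    }

    public static void main(String[] args) {
        int[][] edges = {{1, 0}, {1, 2}, {1, 3}};
        System.out.println(height(4, edges, 1)); // 1
        System.out.println(height(4, edges, 0)); // 2
    }
}
```

On Example 1, rooting at node 1 gives height 1 while rooting at node 0 gives height 2, which is why the answer is [1].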

Thinking:

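We peel off the leaves (nodes with exactly one neighbor) layer by layer until only one or two nodes are left. Each round shortens every longest path by one node at both ends, so the survivors are the middle node(s) of the longest path, and those are the roots of the minimum height trees. This also answers the hint: a tree has at most two MHTs. Since every node and edge is removed only once, the whole process takes O(n) time.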
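The "middle of the longest path" can also be computed directly, as an alternative to leaf peeling: a BFS from any node ends at one endpoint u of a longest path (a diameter), a second BFS from u ends at the other endpoint v, and the MHT roots are the middle node of the u-v path (odd number of nodes) or its two middle nodes (even number). This is a swapped-in technique, not the solution from the reference; the names TreeCenters, farthest, and centers are hypothetical, and the sketch assumes the input is a valid tree.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class TreeCenters {
    // BFS from src; records parent pointers and returns the last node dequeued,
    // which is a node farthest from src (BFS dequeues in nondecreasing distance).
    private static int farthest(List<List<Integer>> adj, int src, int[] parent) {
        int[] dist = new int[adj.size()];
        Arrays.fill(dist, -1);
        dist[src] = 0;
        parent[src] = -1;
        ArrayDeque<Integer> queue = new ArrayDeque<>();
        queue.add(src);
        int last = src;
        while (!queue.isEmpty()) {
            int u = queue.poll();
            last = u;
            for (int v : adj.get(u)) {
                if (dist[v] == -1) {
                    dist[v] = dist[u] + 1;
                    parent[v] = u;
                    queue.add(v);
                }
            }
        }
        return last;
    }

    static List<Integer> centers(int n, int[][] edges) {
        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        for (int[] e : edges) {
            adj.get(e[0]).add(e[1]);
            adj.get(e[1]).add(e[0]);
        }
        int[] parent = new int[n];
        int u = farthest(adj, 0, parent); // one end of a longest path
        int v = farthest(adj, u, parent); // other end; parent[] now leads back to u
        List<Integer> path = new ArrayList<>();
        for (int x = v; x != -1; x = parent[x]) path.add(x);
        int len = path.size();            // number of nodes on the longest path
        List<Integer> res = new ArrayList<>();
        res.add(path.get((len - 1) / 2)); // middle node
        if (len % 2 == 0) res.add(path.get(len / 2)); // second middle node
        return res;
    }

    public static void main(String[] args) {
        System.out.println(centers(4, new int[][]{{1, 0}, {1, 2}, {1, 3}}));                 // [1]
        System.out.println(centers(6, new int[][]{{0, 3}, {1, 3}, {2, 3}, {4, 3}, {5, 4}})); // [3, 4]
    }
}
```

Both this and leaf peeling run in O(n) time; peeling avoids rebuilding the path, while the diameter version makes the "middle of the graph" intuition explicit.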

Solution:

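import java.util.*;

class Solution {
    public List<Integer> findMinHeightTrees(int n, int[][] edges) {
        // A single node is its own MHT root.
        if (n == 1) {
            List<Integer> res = new ArrayList<Integer>();
            res.add(0);
            return res;
        }
        // Build the adjacency sets.
        List<Set<Integer>> adj = new ArrayList<Set<Integer>>(n);
        for (int i = 0; i < n; i++)
            adj.add(new HashSet<Integer>());
        for (int[] edge : edges) {
            adj.get(edge[0]).add(edge[1]);
            adj.get(edge[1]).add(edge[0]);
        }
        // First layer of leaves: nodes with exactly one neighbor.
        List<Integer> leaves = new ArrayList<Integer>();
        for (int i = 0; i < n; i++) {
            if (adj.get(i).size() == 1)
                leaves.add(i);
        }

        // Remove one whole layer of leaves per round until at most two nodes remain.
        while (n > 2) {
            n -= leaves.size();
            List<Integer> newLeaves = new ArrayList<Integer>();
            for (int i : leaves) {
                // A leaf has exactly one remaining neighbor j; detach i from it.
                int j = adj.get(i).iterator().next();
                adj.get(j).remove(i);
                // j joins the next layer once it has a single neighbor left.
                if (adj.get(j).size() == 1)
                    newLeaves.add(j);
            }
            leaves = newLeaves;
        }

        // The one or two remaining nodes are the MHT roots.
        return leaves;
    }
}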
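A common variant of the same peeling loop keeps an int degree array instead of deleting entries from HashSets: removing a leaf only decrements its neighbors' degrees, and a neighbor whose degree drops to exactly 1 joins the next layer. An already-removed leaf can only drop from degree 1 to 0 afterwards, so it is never added again. This is a sketch of my own restructuring, not the reference solution; the class name MinHeightTreesDegrees is hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

public class MinHeightTreesDegrees {
    static List<Integer> findMinHeightTrees(int n, int[][] edges) {
        List<Integer> leaves = new ArrayList<>();
        if (n == 1) {              // a single node is its own MHT root
            leaves.add(0);
            return leaves;
        }
        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        int[] degree = new int[n];
        for (int[] e : edges) {
            adj.get(e[0]).add(e[1]);
            adj.get(e[1]).add(e[0]);
            degree[e[0]]++;
            degree[e[1]]++;
        }
        for (int i = 0; i < n; i++) {
            if (degree[i] == 1) leaves.add(i);
        }
        int remaining = n;
        while (remaining > 2) {
            remaining -= leaves.size();
            List<Integer> next = new ArrayList<>();
            for (int leaf : leaves) {
                for (int nb : adj.get(leaf)) {
                    // Decrementing to exactly 1 means nb just became a leaf.
                    if (--degree[nb] == 1) next.add(nb);
                }
            }
            leaves = next;
        }
        return leaves;             // the one or two centers
    }

    public static void main(String[] args) {
        System.out.println(findMinHeightTrees(4, new int[][]{{1, 0}, {1, 2}, {1, 3}}));                 // [1]
        System.out.println(findMinHeightTrees(6, new int[][]{{0, 3}, {1, 3}, {2, 3}, {4, 3}, {5, 4}})); // [3, 4]
    }
}
```

The int array avoids hashing and boxing on every removal, at the cost of walking each leaf's full original neighbor list.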

Reference: https://leetcode.com/discuss/71763/share-some-thoughts

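My previous code (TLE, Time Limit Exceeded). It runs a BFS from every node, and each BFS scans the whole edge list for every node it dequeues, so the total time is O(n^3):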

import java.util.*;

class Solution {
    public List<Integer> findMinHeightTrees(int n, int[][] edges) {
        // Try every node as the root and keep all roots with the smallest height.
        List<Integer> res = new ArrayList<Integer>();
        int min = Integer.MAX_VALUE;

        for (int i = 0; i < n; i++) {
            int temp = bfs(n, i, edges);
            if (temp < min) {
                // A strictly better root: discard the roots collected so far.
                min = temp;
                res.clear();
                res.add(i);
            } else if (temp == min) {
                res.add(i);
            }
        }

        return res;
    }

    // Counts the BFS levels from root i (height + 1, which is fine for comparing roots).
    private int bfs(int n, int i, int[][] edges) {
        int height = 0;
        // used[j] marks edge j as traversed, so the BFS never walks back up the tree.
        boolean[] used = new boolean[edges.length];
        Queue<Integer> q = new LinkedList<Integer>();
        q.add(i);

        while (!q.isEmpty()) {
            int num = q.size();
            // Expand one whole BFS level.
            while (num > 0) {
                int temp = q.poll();
                // Scanning every edge for each node is what makes this too slow.
                for (int j = 0; j < edges.length; j++) {
                    if (!used[j] && (edges[j][0] == temp || edges[j][1] == temp)) {
                        q.add(edges[j][0] == temp ? edges[j][1] : edges[j][0]);
                        used[j] = true;
                    }
                }
                num--;
            }
            height++;
        }

        return height;
    }
}