LeetCode - Count Complete Tree Nodes (Java)

Question:

Given a complete binary tree, count the number of nodes.

Definition of a complete binary tree from Wikipedia:
In a complete binary tree every level, except possibly the last, is completely filled, and all nodes in the last level are as far left as possible. It can have between 1 and 2^h nodes inclusive at the last level h.

Thinking:

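My first idea was simply to traverse the tree and count every node, which takes O(n) time. We can do better, because a complete tree has a special structure. Let h be the height of the root, measured along the leftmost path. Since the last level is filled from the left, this is the true height of the tree, and the right subtree's height is either h - 1 or h - 2. If the right child's height is the root's height minus 1, the last level reaches into the right subtree, so the left subtree must be a full tree of height h - 1 with 2^(h-1) - 1 nodes. Adding the root gives 2^(h-1), and we only need to count the remaining nodes in the right subtree. Otherwise, the right subtree is a full tree of height h - 2 with 2^(h-2) - 1 nodes. Adding the root gives 2^(h-2), and we recurse into the left subtree instead. Each step computes one height in O(log n) time and descends one level, so the whole algorithm runs in O(log^2 n) time.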
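The two cases can be written out as a short derivation. Here h is the height returned by the leftmost-path height function (number of levels), C(t) denotes the number of nodes in tree t, and the only fact used is that a full binary tree with k levels has 2^k - 1 nodes:

```latex
\begin{aligned}
\text{Case 1: } \operatorname{height}(\mathit{root.right}) = h-1:\quad
C(\mathit{root}) &= \underbrace{(2^{h-1}-1)}_{\text{full left subtree}} + \underbrace{1}_{\text{root}} + C(\mathit{root.right})
                  = 2^{h-1} + C(\mathit{root.right}) \\[4pt]
\text{Case 2: } \operatorname{height}(\mathit{root.right}) = h-2:\quad
C(\mathit{root}) &= \underbrace{(2^{h-2}-1)}_{\text{full right subtree}} + \underbrace{1}_{\text{root}} + C(\mathit{root.left})
                  = 2^{h-2} + C(\mathit{root.left})
\end{aligned}
```

The terms 2^(h-1) and 2^(h-2) are exactly what the solution computes as 1 << (h-1) and 1 << (h-2).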

Solution:

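// TreeNode is LeetCode's standard definition (int val; TreeNode left, right).
public int countNodes(TreeNode root) {
    int h = height(root);
    if (h == 0)
        return 0;
    if (h == height(root.right) + 1) {
        // Left subtree is full with h - 1 levels: 2^(h-1) - 1 nodes, plus the root.
        return (1 << (h - 1)) + countNodes(root.right);
    } else {
        // Right subtree is full with h - 2 levels: 2^(h-2) - 1 nodes, plus the root.
        return (1 << (h - 2)) + countNodes(root.left);
    }
}

// Height along the leftmost path, which is the true height of a complete tree.
private int height(TreeNode root) {
    return root == null ? 0 : 1 + height(root.left);
}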