二叉搜索树的中序遍历是有序的,如果二叉搜索树中两个节点被互换了,那么其中序遍历中必定有两个节点“错位”,因此中序遍历是解题的关键。中序遍历本身不难,但是题目要求常数级别的空间复杂度,因此想到了线索二叉树。

总结下来两种思路:

  1. 空间复杂度为O(n)——线索二叉树
  2. 空间复杂度为O(logn)——递归,这里用了指针的引用这一特殊语法

线索二叉树

//
// Created by jt on 2020/8/22.
//
#include <cstdio>
using namespace std;

class Solution {
public:
    void recoverTree(TreeNode *root) {
        // pre保存上一个节点或者第一个错位节点
        // chg保存第二个错位节点, detected用来指示是否检测到了第一个错位节点
        TreeNode *pre = nullptr, *chg = nullptr, *cur = root;
        bool detected = false;
        while (cur) {
            if (cur->left) {
                // 如果存在左子树
                TreeNode *p = cur->left;
                while (p->right && p->right != cur) p = p->right;
                if (p->right == cur) {
                    // 如果已经建立过线索,更新pre和chg,然后消除线索
                    if (pre && pre->val > cur->val) {chg = cur; detected = true;}
                    if (!pre || !detected) pre = cur;
                    p->right = nullptr;
                    cur = cur->right;
                } else {
                    p->right = cur;
                    cur = cur->left;
                }
            } else {
                // 如果不存在左子树,根据线索访问
                if (pre && pre->val > cur->val) {chg = cur; detected = true;}
                if (!pre || !detected) pre = cur;
                cur = cur->right;
            }
        }
        if (pre && chg) {
            int tmp = chg->val;
            chg->val = pre->val;
            pre->val = tmp;
        }
    }
};

递归实现

//
// Created by jt on 2020/8/22.
//
#include <cstdio>
using namespace std;

class Solution {
public:
    void recoverTree(TreeNode *root) {
        TreeNode *prev = nullptr, *current = nullptr;
        inOrder(root, prev, current);
        if (prev && current) {
            int tmp = prev->val;
            prev->val = current->val;
            current->val = tmp;
        }
    }

    void inOrder(TreeNode *root, TreeNode *&prev, TreeNode *&current) {
        if (!root) return;
        if (root->left) inOrder(root->left, prev, current);

        if (prev && prev->val > root->val) current = root;
        if (!prev || !current) prev = root;

        if (root->right) inOrder(root->right, prev, current);
        return;
    }
};