说明

实现TreeSet和TreeMap的一个<mark>重要问题是提供对迭代器类的支持</mark>。
当然,在内部,迭代器保留到迭代中“当前”节点的一个链接。
<font color="#ff0011">困难部分是高效地推进到下一个节点。</font>

<mark>存在几种可能的解决方案</mark>,其中的一些方案叙述如下:

  1. 在构造迭代器时,让每个迭代器包含每个TreeSet项的数组作为该迭代器的数据存储。<mark>不足在于,我们完全可以使用toArray,不需要迭代器。</mark>
  2. 让迭代器保留存储通向当前节点的路径上的节点的一个栈。根据这一信息,可以推出迭代器中的下一个节点。<mark>不足在于,迭代器会"变大",代码变得臃肿。</mark>
  3. 让查找树中的每个节点除存储子节点外还要存储它的父节点。<mark>这时候,迭代器不至于那么大,但是在每个节点上需要额外的内存,并且代码依然臃肿。</mark>
    (↑↑↑ <font color="#ff0011">按照练习的要求,本文要用这种方法实现</font> ↑↑↑)
  4. 让每个节点保留两个附加的链:一个通向下一个更小的节点,另一个通向下一个更大的节点。<mark>这要占用空间,不过迭代器做起来非常简单(代码不臃肿),并且保留这些链也很容易</mark>。
    (↑↑↑实现:https://blog.csdn.net/LawssssCat/article/details/103099359↑↑↑)
  5. 只对那些具有null左链或null右链的节点保留附加的链。通过使用附加的boolean变量使得这些例程判断一个左链是正在被用作标准的二叉树的左链,还是一个通向下一个更小的节点的链。类似地,对右链也有类似的判断(见练习4.50)。<mark>这种做法叫做线索树(threaded tree),用于许多平衡二叉查找树的实现中</mark>。


要求

编写TreeSet类的实现程序,其中相关的迭代器使用二叉查找树。
<mark>在每个节点上添加一个指向其父节点的链</mark>



代码

构造方法&常规方法/属性

package cn.edut.tree;
import java.util.Iterator;
import java.util.LinkedList;
import org.junit.Test;
public class MyTreeSet<T extends Comparable<? super T>> {
	/** Root of the search tree; {@code null} while the set holds nothing. */
	private BinaryNode<T> root;
	/** Number of elements currently stored. */
	private int size;
	/** Modification counter; incremented by {@link #doClear()}. */
	private int modCount;

	/** Creates an empty set whose modification counter starts at zero. */
	public MyTreeSet() {
		// Same end state as doClear() followed by resetting the counter.
		root = null;
		size = 0;
		modCount = 0;
	}

	/** @return {@code true} when no element is stored */
	public boolean isEmpty() {
		return size == 0;
	}

	/** @return the number of stored elements */
	public int size() {
		return size;
	}

	/** Discards every element. */
	public void makeEmpty() {
		doClear();
	}

	/** Drops the whole tree, zeroes the size and bumps the modification counter. */
	private void doClear() {
		root = null;
		size = 0;
		modCount++;
	}
}



二叉节点

/** Node of the binary search tree: one element plus links to both children and to the parent. @param <T> element type */
	private static class BinaryNode<T> {
		// Payload stored in this node.
		T elementData;
		// Child links; null when the child is absent.
		private BinaryNode<T> left;
		private BinaryNode<T> right;
		// Link back to the parent node; null for the root.
		private BinaryNode<T> father;

		/** Builds a node from its element and all three links. */
		public BinaryNode(T e, BinaryNode<T> left, BinaryNode<T> right, BinaryNode<T> father) {
			this.father = father;
			this.right = right;
			this.left = left;
			this.elementData = e;
		}
	}


添加

	/**
	 * Inserts {@code e} into the set.
	 *
	 * @param e the element to insert
	 */
	public void add(T e) {
		// The root has no parent, hence the null third argument.
		BinaryNode<T> updatedRoot = add(e, root, null);
		root = updatedRoot;
	}
	/**
	 * Recursive insertion helper.
	 *
	 * @param e           the element to insert
	 * @param currentNode root of the subtree being searched, may be {@code null}
	 * @param fatherNode  parent of {@code currentNode}, recorded in a newly created node
	 * @return the root of the subtree after the insertion
	 */
	private BinaryNode<T> add(T e, BinaryNode<T> currentNode, BinaryNode<T> fatherNode) {
		if (currentNode != null) {
			int order = e.compareTo(currentNode.elementData);
			if (order > 0) {
				currentNode.right = add(e, currentNode.right, currentNode);
			} else if (order < 0) {
				currentNode.left = add(e, currentNode.left, currentNode);
			}
			// order == 0: the element is already present, nothing to do.
			return currentNode;
		}
		// Reached an empty slot: this is where the new leaf belongs.
		modCount++;
		size++;
		return new BinaryNode<T>(e, null, null, fatherNode);
	}


找最小

	/**
	 * Returns the leftmost node of the subtree rooted at {@code currentNode}.
	 *
	 * @param currentNode subtree root; must not be {@code null}
	 * @return the node reached by following left links until none remains
	 */
	private BinaryNode<T> finMinNode(BinaryNode<T> currentNode) {
		BinaryNode<T> leftmost = currentNode;
		while (leftmost.left != null) {
			leftmost = leftmost.left;
		}
		return leftmost;
	}



删除

	/**
	 * Removes {@code e} from the set.
	 *
	 * @param e the element to remove
	 */
	public void remove(T e) {
		BinaryNode<T> updatedRoot = remove(e, root);
		root = updatedRoot;
	}

	/**
	 * Recursive removal helper.
	 *
	 * @param e           the element to remove
	 * @param currentNode root of the subtree being searched, may be {@code null}
	 * @return the root of the subtree after the removal; unchanged when {@code e} is not found
	 */
	private BinaryNode<T> remove(T e, BinaryNode<T> currentNode) {
		if (currentNode == null) {
			// Not found: nothing to remove.
			return null;
		}
		int flag = e.compareTo(currentNode.elementData);

		if (flag < 0) {
			currentNode.left = remove(e, currentNode.left);
		} else if (flag > 0) {
			currentNode.right = remove(e, currentNode.right);
		} else if (currentNode.right != null && currentNode.left != null) {
			// Two children: take over the successor's element, then delete the
			// successor from the right subtree (size/modCount change there).
			currentNode.elementData = finMinNode(currentNode.right).elementData;
			currentNode.right = remove(currentNode.elementData, currentNode.right);
		} else {
			// At most one child: splice this node out of the tree.
			BinaryNode<T> child = currentNode.left != null ? currentNode.left : currentNode.right;
			if (child != null) {
				// Keep the parent link consistent; otherwise an upward walk from
				// the child would pass through the removed node.
				child.father = currentNode.father;
			}
			currentNode = child;
			size--;
			modCount++;
		}
		return currentNode;
	}



<font color="#ff0011"><b>迭代器</b></font>


	

	/** @return a new iterator over this set, backed by the inner {@code Itr} class */
	public Iterator<T> iterator() {
		Iterator<T> itr = new Itr();
		return itr;
	}
	private class Itr implements Iterator<T> {
		private int index;
		private BinaryNode<T> currentNode;
		private BinaryNode<T> prevNode;
		private int expectModCount;
		private boolean okToRemove;

		/** Positions the iterator on the smallest element and snapshots the modification counter. */
		public Itr() {
			okToRemove = false;
			index = 0;
			// An empty set has no root; finMinNode would dereference null.
			currentNode = root == null ? null : finMinNode(root);
			expectModCount = modCount;
		}

		/** @return {@code true} while fewer elements have been consumed than the set holds */
		@Override
		public boolean hasNext() {
			int total = size();
			return index < total;
		}

		/**
		 * Finds the in-order successor of a node.
		 *
		 * @param currentNode the node whose successor is wanted; must not be {@code null}
		 * @return the successor node, or {@code null} when there is none
		 */
		private BinaryNode<T> next(BinaryNode<T> currentNode) {
			BinaryNode<T> rightChild = currentNode.right;
			if (rightChild != null) {
				// The successor is the smallest node of the right subtree.
				return finMinNode(rightChild);
			}
			// Otherwise climb the parent links to the first ancestor holding a larger element.
			for (BinaryNode<T> ancestor = currentNode.father; ancestor != null; ancestor = ancestor.father) {
				if (ancestor.elementData.compareTo(currentNode.elementData) > 0) {
					return ancestor;
				}
			}
			return null;
		}

		/**
		 * Returns the current element and advances to its in-order successor.
		 *
		 * @return the next element in ascending order
		 * @throws java.util.NoSuchElementException if the iteration has no more elements
		 */
		@Override
		public T next() {
			isExpectMod();
			if (!hasNext() || currentNode == null) {
				// Previously this fell through to a NullPointerException.
				throw new java.util.NoSuchElementException();
			}
			prevNode = currentNode;
			currentNode = next(currentNode);
			index++;
			okToRemove = true;
			return prevNode.elementData;
		}

		/**
		 * Removes the element most recently returned by {@link #next()}.
		 *
		 * @throws IllegalStateException if {@code next()} has not been called since
		 *                               the last removal
		 */
		@Override
		public void remove() {
			isExpectMod();
			if (prevNode == null || !okToRemove) {
				throw new IllegalStateException();
			}
			// A node with two children is not unlinked: it takes over its successor's
			// element and the successor node (the iterator's cursor) is unlinked instead.
			boolean hadTwoChildren = prevNode.left != null && prevNode.right != null;

			root = MyTreeSet.this.remove(prevNode.elementData, root);
			if (hadTwoChildren) {
				// prevNode now holds the next element to return.
				currentNode = prevNode;
			}
			// Resync with the counter instead of assuming exactly one increment.
			expectModCount = modCount;
			index--;
			okToRemove = false;
		}

		/**
		 * Fails fast when the set's modification counter differs from this
		 * iterator's snapshot.
		 *
		 * @throws java.util.ConcurrentModificationException on a mismatch
		 */
		private void isExpectMod() {
			if (expectModCount != modCount) {
				throw new java.util.ConcurrentModificationException();
			}
		}





测试

两种打印模式

	/**
	 * Renders the elements in ascending order in list form.
	 *
	 * @return {@code "[]"} for an empty set, otherwise the in-order element list
	 */
	@Override
	public String toString() {
		if (root == null) {
			// listNode dereferences its argument, so the empty set is handled here.
			return "[]";
		}
		return listNode(root).toString();
	}
	/** Prints the tree to standard output, or a placeholder line when the set is empty. */
	public void listAll() {
		if (root == null) {
			System.out.println("数据为空");
			return;
		}
		listAll_mid(root, 0);
	}
/**
	 * Collects the elements of a subtree in ascending (in-order) sequence.
	 *
	 * @param currentNode subtree root; must not be {@code null}
	 * @return a new list holding the subtree's elements
	 */
	private LinkedList<T> listNode(BinaryNode<T> currentNode) {
		LinkedList<T> ordered = new LinkedList<>();
		BinaryNode<T> smaller = currentNode.left;
		BinaryNode<T> larger = currentNode.right;
		if (smaller != null) {
			ordered.addAll(listNode(smaller));
		}
		ordered.add(currentNode.elementData);
		if (larger != null) {
			ordered.addAll(listNode(larger));
		}
		return ordered;
	}

	/**
	 * Prints a subtree sideways: right subtree first, then this node, then the
	 * left subtree, one node per line with one {@code "~~ "} marker per level.
	 *
	 * @param currentNode subtree root; must not be {@code null}
	 * @param depth       depth of {@code currentNode} below the tree root
	 */
	private void listAll_mid(BinaryNode<T> currentNode, int depth) {
		int childDepth = depth + 1;
		if (currentNode.right != null) {
			listAll_mid(currentNode.right, childDepth);
		}
		// Assemble the whole line first, then emit it in one call.
		StringBuilder line = new StringBuilder("[");
		int level = 0;
		while (level < depth) {
			line.append("~~ ");
			level++;
		}
		line.append(currentNode.elementData).append(":");
		System.out.println(line.toString());
		if (currentNode.left != null) {
			listAll_mid(currentNode.left, childDepth);
		}
	}

	


测试方法

/** Smoke test: isEmpty(), add, toString, listAll, remove and the iterator; results are printed, not asserted. */
	@Test
	public void test001() {
		// --- Setup: create the set and show its initial state ---
		MyTreeSet<Integer> set = new MyTreeSet<>();
		System.out.println("size:" + set.size);
		System.out.println(set.isEmpty());

		// --- Insert (including many duplicates) and print ---
		set.add(6);
		int i = 0;
		while (i < 10) {
			set.add(i / 2);
			set.add(i * 3);
			set.add(i);
			i++;
		}
		System.out.println("打印元素");
		set.listAll();
		System.out.println("---toString");
		System.out.println(set.toString());
		System.out.println("size:" + set.size());

		// --- Remove: 6, then 27, then 222 (never inserted) ---
		System.out.println("删除6:");
		set.remove(6);
		set.listAll();
		System.out.println("size:" + set.size());

		System.out.println("删除27:");
		set.remove(27);
		set.listAll();
		System.out.println("size:" + set.size());

		System.out.println("删除222:");
		set.remove(222);
		set.listAll();
		System.out.println("size:" + set.size());

		System.out.println("---");
		System.out.println("size:" + set.size());

		// --- Iterate and remove every element through the iterator ---
		System.out.println("用Iterator遍历:");
		Iterator<Integer> cursor = set.iterator();
		while (cursor.hasNext()) {
			Integer value = cursor.next();
			System.out.println(value);
			cursor.remove();
			set.listAll();
		}
		System.out.println("size:" + set.size());

	}



完整代码

package cn.edut.tree;

import java.util.Iterator;
import java.util.LinkedList;

import org.junit.Test;


/**
 * A sorted set backed by an unbalanced binary search tree whose nodes also
 * store a link to their parent, so that an iterator can step to the in-order
 * successor without an auxiliary stack.
 *
 * @param <T> element type, ordered by its natural ordering
 */
public class MyTreeSet<T extends Comparable<? super T>> {
	/** Smoke test: isEmpty(), add, toString, listAll, remove and the iterator; results are printed, not asserted. */
	@Test
	public void test001() {
		// --- Setup: create the set and show its initial state ---
		MyTreeSet<Integer> mts = new MyTreeSet<>();
		System.out.println("size:" + mts.size);
		System.out.println(mts.isEmpty());

		// --- Insert (including many duplicates) and print ---
		mts.add(6);
		for (int i = 0; i < 10; i++) {
			mts.add(i / 2);
			mts.add(i * 3);
			mts.add(i);
		}
		System.out.println("打印元素");
		mts.listAll();
		System.out.println("---toString");
		System.out.println(mts.toString());
		System.out.println("size:" + mts.size());

		// --- Remove: 6, then 27, then 222 (never inserted) ---
		System.out.println("删除6:");
		mts.remove(6);
		mts.listAll();
		System.out.println("size:" + mts.size());

		System.out.println("删除27:");
		mts.remove(27);
		mts.listAll();
		System.out.println("size:" + mts.size());

		System.out.println("删除222:");
		mts.remove(222);
		mts.listAll();
		System.out.println("size:" + mts.size());

		System.out.println("---");
		System.out.println("size:" + mts.size());

		// --- Iterate and remove every element through the iterator ---
		System.out.println("用Iterator遍历:");
		Iterator<Integer> it = mts.iterator();
		while (it.hasNext()) {
			System.out.println(it.next());
			it.remove();
			mts.listAll();
		}
		System.out.println("size:" + mts.size());

	}

	/** Root of the search tree; {@code null} while the set is empty. */
	private BinaryNode<T> root;
	/** Number of elements currently stored. */
	private int size;
	/** Incremented on every insertion, removal and clear; iterators compare against it. */
	private int modCount;

	/** Creates an empty set. */
	public MyTreeSet() {
		doClear();
		// doClear() bumps the counter; a fresh set starts at zero.
		modCount = 0;
	}

	/**
	 * Inserts {@code e}; an element that is already present leaves the set unchanged.
	 *
	 * @param e the element to insert
	 */
	public void add(T e) {
		root = add(e, root, null);
	}

	/** Prints the tree sideways to standard output, or a placeholder line when the set is empty. */
	public void listAll() {
		if (root != null) {
			listAll_mid(root, 0);
		} else {
			System.out.println("数据为空");
		}
	}

	/** @return an iterator over the elements in ascending order */
	public Iterator<T> iterator() {
		return new Itr();
	}

	/**
	 * Removes {@code e}; an element that is not present leaves the set unchanged.
	 *
	 * @param e the element to remove
	 */
	public void remove(T e) {
		root = remove(e, root);
	}

	/** @return {@code true} when no element is stored */
	public boolean isEmpty() {
		return size == 0;
	}

	/** @return the number of stored elements */
	public int size() {
		return size;
	}

	/** Discards every element. */
	public void makeEmpty() {
		doClear();
	}

	/**
	 * Renders the elements in ascending order in list form.
	 *
	 * @return {@code "[]"} for an empty set, otherwise the in-order element list
	 */
	@Override
	public String toString() {
		// listNode copes with a null root, so the empty set no longer throws.
		return listNode(root).toString();
	}

	/**
	 * Collects the elements of a subtree in ascending order.
	 *
	 * @param currentNode subtree root, may be {@code null}
	 * @return a new list holding the subtree's elements
	 */
	private LinkedList<T> listNode(BinaryNode<T> currentNode) {
		LinkedList<T> ordered = new LinkedList<>();
		collectInOrder(currentNode, ordered);
		return ordered;
	}

	/** Appends the subtree's elements to {@code out} in order, without intermediate lists. */
	private void collectInOrder(BinaryNode<T> node, LinkedList<T> out) {
		if (node == null) {
			return;
		}
		collectInOrder(node.left, out);
		out.add(node.elementData);
		collectInOrder(node.right, out);
	}

	/**
	 * Recursive removal helper.
	 *
	 * @param e           the element to remove
	 * @param currentNode root of the subtree being searched, may be {@code null}
	 * @return the root of the subtree after the removal
	 */
	private BinaryNode<T> remove(T e, BinaryNode<T> currentNode) {
		if (currentNode == null) {
			// Not found: nothing to remove.
			return null;
		}
		int flag = e.compareTo(currentNode.elementData);

		if (flag < 0) {
			currentNode.left = remove(e, currentNode.left);
		} else if (flag > 0) {
			currentNode.right = remove(e, currentNode.right);
		} else if (currentNode.right != null && currentNode.left != null) {
			// Two children: take over the successor's element, then delete the
			// successor from the right subtree (size/modCount change there).
			currentNode.elementData = finMinNode(currentNode.right).elementData;
			currentNode.right = remove(currentNode.elementData, currentNode.right);
		} else {
			// At most one child: splice this node out of the tree.
			BinaryNode<T> child = currentNode.left != null ? currentNode.left : currentNode.right;
			if (child != null) {
				// Keep the parent link consistent; otherwise the iterator's upward
				// walk would pass through the removed node.
				child.father = currentNode.father;
			}
			currentNode = child;
			size--;
			modCount++;
		}
		return currentNode;
	}

	/**
	 * Returns the leftmost node of a subtree.
	 *
	 * @param currentNode subtree root; must not be {@code null}
	 */
	private BinaryNode<T> finMinNode(BinaryNode<T> currentNode) {
		return currentNode.left != null ? finMinNode(currentNode.left) : currentNode;
	}

	/** Drops the whole tree, zeroes the size and bumps the modification counter. */
	private void doClear() {
		size = 0;
		root = null;
		modCount++;
	}

	/**
	 * Recursive insertion helper.
	 *
	 * @param e           the element to insert
	 * @param currentNode root of the subtree being searched, may be {@code null}
	 * @param fatherNode  parent of {@code currentNode}, recorded in a newly created node
	 * @return the root of the subtree after the insertion
	 */
	private BinaryNode<T> add(T e, BinaryNode<T> currentNode, BinaryNode<T> fatherNode) {
		if (currentNode == null) {
			// Reached an empty slot: this is where the new leaf belongs.
			size++;
			modCount++;
			return new BinaryNode<T>(e, null, null, fatherNode);
		}
		int flag = e.compareTo(currentNode.elementData);
		if (flag < 0) {
			currentNode.left = add(e, currentNode.left, currentNode);
		} else if (flag > 0) {
			currentNode.right = add(e, currentNode.right, currentNode);
		}
		// flag == 0: already present, nothing to do.
		return currentNode;
	}

	/**
	 * Prints a subtree sideways: right subtree first, then this node, then the
	 * left subtree, one node per line with one {@code "~~ "} marker per level.
	 *
	 * @param currentNode subtree root; must not be {@code null}
	 * @param depth       depth of {@code currentNode} below the tree root
	 */
	private void listAll_mid(BinaryNode<T> currentNode, int depth) {
		if (currentNode.right != null) {
			listAll_mid(currentNode.right, depth + 1);
		}
		System.out.print("[");
		for (int i = 0; i < depth; i++) {
			System.out.print("~~ ");
		}
		System.out.println(currentNode.elementData + ":");
		if (currentNode.left != null) {
			listAll_mid(currentNode.left, depth + 1);
		}
	}

	/** In-order iterator that advances through the tree using the parent links. */
	private class Itr implements Iterator<T> {
		/** Node whose element the next call to {@link #next()} returns; {@code null} at the end. */
		private BinaryNode<T> currentNode;
		/** Node whose element was returned most recently. */
		private BinaryNode<T> prevNode;
		/** Snapshot of the set's modification counter. */
		private int expectModCount;
		/** True only between a call to next() and the following remove(). */
		private boolean okToRemove;

		public Itr() {
			okToRemove = false;
			// An empty set has no root; finMinNode would dereference null.
			currentNode = root == null ? null : finMinNode(root);
			expectModCount = modCount;
		}

		@Override
		public boolean hasNext() {
			return currentNode != null;
		}

		/**
		 * Finds the in-order successor of a node.
		 *
		 * @param node the node whose successor is wanted; must not be {@code null}
		 * @return the successor node, or {@code null} when there is none
		 */
		private BinaryNode<T> next(BinaryNode<T> node) {
			if (node.right != null) {
				return finMinNode(node.right);
			}
			// Climb to the first ancestor holding a larger element.
			BinaryNode<T> father = node;
			while ((father = father.father) != null) {
				if (father.elementData.compareTo(node.elementData) > 0) {
					return father;
				}
			}
			return null;
		}

		/**
		 * @return the next element in ascending order
		 * @throws java.util.NoSuchElementException        if the iteration is exhausted
		 * @throws java.util.ConcurrentModificationException if the set was modified
		 *                                                 other than through this iterator
		 */
		@Override
		public T next() {
			isExpectMod();
			if (!hasNext()) {
				throw new java.util.NoSuchElementException();
			}
			prevNode = currentNode;
			currentNode = next(currentNode);
			okToRemove = true;
			return prevNode.elementData;
		}

		/**
		 * Removes the element most recently returned by {@link #next()}.
		 *
		 * @throws IllegalStateException if {@code next()} has not been called since
		 *                               the last removal
		 */
		@Override
		public void remove() {
			isExpectMod();
			if (prevNode == null || !okToRemove) {
				throw new IllegalStateException();
			}
			// A node with two children is not unlinked: it takes over its successor's
			// element and the successor node (our cursor) is unlinked instead.
			boolean hadTwoChildren = prevNode.left != null && prevNode.right != null;

			MyTreeSet.this.remove(prevNode.elementData);
			if (hadTwoChildren) {
				// prevNode now holds the next element to return.
				currentNode = prevNode;
			}
			expectModCount = modCount;
			okToRemove = false;
		}

		/** Fails fast when the set changed behind this iterator's back. */
		private void isExpectMod() {
			if (expectModCount != modCount) {
				throw new java.util.ConcurrentModificationException();
			}
		}

	}

	/**
	 * Node of the binary search tree.
	 *
	 * @param <T> element type
	 */
	private static class BinaryNode<T> {
		T elementData;
		private BinaryNode<T> left;
		private BinaryNode<T> right;
		/** Parent node; {@code null} for the root. */
		private BinaryNode<T> father;

		public BinaryNode(T e, BinaryNode<T> left, BinaryNode<T> right, BinaryNode<T> father) {
			this.elementData = e;
			this.left = left;
			this.right = right;
			this.father = father;
		}
	}
}



书上的答案

package cn.edut.tree;

/** Thrown by {@code findMin}/{@code findMax} when the tree is empty. */
class UnderflowException extends Exception {
}

/**
 * Textbook-style sorted set: an unbalanced binary search tree whose nodes keep
 * a parent link so that the iterator can find in-order successors.
 *
 * @param <AnyType> element type, ordered by its natural ordering
 */
public class MyTreeSet2<AnyType extends Comparable<? super AnyType>> {
	/** Tree node holding one element plus child and parent links. */
	private static class BinaryNode<AnyType> {
		BinaryNode(AnyType theElement) {
			this(theElement, null, null, null);
		}

		BinaryNode(AnyType theElement, BinaryNode<AnyType> lt, BinaryNode<AnyType> rt, BinaryNode<AnyType> pt) {
			element = theElement;
			left = lt;
			right = rt;
			parent = pt;
		}

		AnyType element;
		BinaryNode<AnyType> left;
		BinaryNode<AnyType> right;
		BinaryNode<AnyType> parent;
	}

	/** @return an iterator over the elements in ascending order */
	public java.util.Iterator<AnyType> iterator() {
		return new MyTreeSet2Iterator();
	}

	/** In-order iterator that advances through the tree using the parent links. */
	private class MyTreeSet2Iterator implements java.util.Iterator<AnyType> {
		/** Node returned by the next call to next(); null once exhausted or when the tree is empty. */
		private BinaryNode<AnyType> current = findMin(root);
		private BinaryNode<AnyType> previous;
		private int expectedModCount = modCount;
		private boolean okToRemove = false;

		@Override
		public boolean hasNext() {
			// Derived from the cursor so that an empty tree reports no elements.
			return current != null;
		}

		/**
		 * @return the next element in ascending order
		 * @throws java.util.ConcurrentModificationException if the set was modified
		 *                                                 other than through this iterator
		 * @throws java.util.NoSuchElementException        if the iteration is exhausted
		 */
		@Override
		public AnyType next() {
			if (modCount != expectedModCount)
				throw new java.util.ConcurrentModificationException();
			if (!hasNext())
				throw new java.util.NoSuchElementException();
			AnyType nextItem = current.element;

			previous = current;
			// if there is a right child, next node is min in right subtree
			if (current.right != null) {
				current = findMin(current.right);
			} else {
				// else, find ancestor that it is left of
				BinaryNode<AnyType> child = current;
				current = current.parent;
				while (current != null && current.left != child) {
					child = current;
					current = current.parent;
				}
			}
			okToRemove = true;
			return nextItem;
		}

		/**
		 * Removes the element most recently returned by {@link #next()}.
		 *
		 * @throws IllegalStateException if {@code next()} has not been called since
		 *                               the last removal
		 */
		@Override
		public void remove() {
			if (modCount != expectedModCount)
				throw new java.util.ConcurrentModificationException();
			if (!okToRemove)
				throw new IllegalStateException();
			// A node with two children is not unlinked: it takes over its successor's
			// element and the successor node (our cursor) is unlinked instead.
			boolean hadTwoChildren = previous.left != null && previous.right != null;
			MyTreeSet2.this.remove(previous.element);
			if (hadTwoChildren)
				current = previous;
			// Our own removal must not look like a concurrent modification.
			expectedModCount = modCount;
			okToRemove = false;
		}
	}

	/** Creates an empty set. */
	public MyTreeSet2() {
		root = null;
	}

	/** Discards every element. */
	public void makeEmpty() {
		modCount++;
		root = null;
	}

	/** @return {@code true} when the tree has no root */
	public boolean isEmpty() {
		return root == null;
	}

	/** @return {@code true} if {@code x} is stored in the set */
	public boolean contains(AnyType x) {
		return contains(x, root);
	}

	/**
	 * @return the smallest element
	 * @throws UnderflowException if the set is empty
	 */
	public AnyType findMin() throws UnderflowException {
		if (isEmpty())
			throw new UnderflowException();
		else
			return findMin(root).element;
	}

	/**
	 * @return the largest element
	 * @throws UnderflowException if the set is empty
	 */
	public AnyType findMax() throws UnderflowException {
		if (isEmpty())
			throw new UnderflowException();
		else
			return findMax(root).element;
	}

	/** Inserts {@code x}; duplicates are ignored. */
	public void insert(AnyType x) {
		root = insert(x, root, null);
	}

	/** Removes {@code x}; does nothing when it is not present. */
	public void remove(AnyType x) {
		root = remove(x, root);
	}

	/** Prints the elements in ascending order, one per line. */
	public void printTree() {
		if (isEmpty())
			System.out.println("Empty tree");
		else
			printTree(root);
	}

	private void printTree(BinaryNode<AnyType> t) {
		if (t != null) {
			printTree(t.left);
			System.out.println(t.element);
			printTree(t.right);
		}
	}

	private boolean contains(AnyType x, BinaryNode<AnyType> t) {
		if (t == null)
			return false;
		int compareResult = x.compareTo(t.element);
		if (compareResult < 0)
			return contains(x, t.left);
		else if (compareResult > 0)
			return contains(x, t.right);
		else
			return true; // match
	}

	/** @return the leftmost node of the subtree, or {@code null} for an empty subtree */
	private BinaryNode<AnyType> findMin(BinaryNode<AnyType> t) {
		if (t == null)
			return null;
		else if (t.left == null)
			return t;
		return findMin(t.left);
	}

	/** @return the rightmost node of the subtree, or {@code null} for an empty subtree */
	private BinaryNode<AnyType> findMax(BinaryNode<AnyType> t) {
		if (t == null)
			return null;
		else if (t.right == null)
			return t;
		return findMax(t.right);
	}

	/**
	 * Recursive insertion helper.
	 *
	 * @param x  the element to insert
	 * @param t  root of the subtree being searched, may be {@code null}
	 * @param pt parent of {@code t}, recorded in a newly created node
	 * @return the root of the subtree after the insertion
	 */
	private BinaryNode<AnyType> insert(AnyType x, BinaryNode<AnyType> t, BinaryNode<AnyType> pt) {
		if (t == null) {
			modCount++;
			return new BinaryNode<AnyType>(x, null, null, pt);
		}
		int compareResult = x.compareTo(t.element);
		if (compareResult < 0)
			t.left = insert(x, t.left, t);
		else if (compareResult > 0)
			t.right = insert(x, t.right, t);
		// compareResult == 0: duplicate, nothing to do
		return t;
	}

	/**
	 * Recursive removal helper.
	 *
	 * @param x the element to remove
	 * @param t root of the subtree being searched, may be {@code null}
	 * @return the root of the subtree after the removal
	 */
	private BinaryNode<AnyType> remove(AnyType x, BinaryNode<AnyType> t) {
		if (t == null)
			return t; // not found
		int compareResult = x.compareTo(t.element);
		if (compareResult < 0)
			t.left = remove(x, t.left);
		else if (compareResult > 0)
			t.right = remove(x, t.right);
		else if (t.left != null && t.right != null) // two children
		{
			t.element = findMin(t.right).element;
			t.right = remove(t.element, t.right);
		} else {
			modCount++;
			BinaryNode<AnyType> oneChild;
			oneChild = (t.left != null) ? t.left : t.right;
			// A leaf has no child to re-parent; dereferencing it used to throw.
			if (oneChild != null)
				oneChild.parent = t.parent; // update parent link
			t = oneChild;
		}
		return t;
	}

	private BinaryNode<AnyType> root;
	int modCount = 0;
}