import java.util.*;

/*
public class ListNode {
    int val;
    ListNode next = null;

    ListNode(int val) {
        this.val = val;
    }
}*/
public class Partition {
    /**
     * Rearranges a singly linked list so that every node whose value is less than
     * {@code x} comes before every node whose value is greater than or equal to
     * {@code x}. The relative order of nodes within each group is preserved.
     * Existing nodes are relinked in place; the only nodes allocated are two
     * temporary sentinel heads.
     *
     * @param pHead head of the list to partition; may be {@code null}
     * @param x     pivot value used to split the list
     * @return head of the rearranged list, or {@code null} if {@code pHead} is {@code null}
     */
    public ListNode partition(ListNode pHead, int x) {
        // Sentinel heads let the first append to each sub-list skip any special case.
        ListNode smallSentinel = new ListNode(0);
        ListNode largeSentinel = new ListNode(0);
        ListNode smallTail = smallSentinel;
        ListNode largeTail = largeSentinel;

        // Appending a node only rewrites the previous tail's link, so node.next
        // is still the original successor when the loop advances.
        for (ListNode node = pHead; node != null; node = node.next) {
            if (node.val >= x) {
                largeTail.next = node;
                largeTail = node;
            } else {
                smallTail.next = node;
                smallTail = node;
            }
        }

        // The last ">= x" node may still point back into the original chain;
        // terminate it so the result cannot contain a cycle.
        largeTail.next = null;
        // Attach the ">= x" group after the "< x" group.
        smallTail.next = largeSentinel.next;
        return smallSentinel.next;
    }
}