import java.util.*;

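/*
 * Definition of the external ListNode type this solution relies on:
 *
 * public class ListNode {
 *   int val;
 *   ListNode next = null;
 *
 *   public ListNode(int val) {
 *     this.val = val;
 *   }
 * }
 */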

public class Solution {
    /**
     * Reorders a singly linked list so that every node whose value is strictly
     * less than {@code x} precedes every node whose value is greater than or
     * equal to {@code x}. Relative order inside each of the two groups is kept.
     * The existing nodes are relinked in place; only two sentinel nodes are allocated.
     *
     * @param head first node of the list; may be {@code null}
     * @param x    pivot value that splits the list into two groups
     * @return head of the reordered list, or {@code null} when {@code head} is {@code null}
     */
    public ListNode partition(ListNode head, int x) {
        // Zero or one node: nothing to reorder.
        if (head == null || head.next == null) {
            return head;
        }

        // Sentinel heads for the "< x" chain and the ">= x" chain.
        ListNode lowHead = new ListNode(0);
        ListNode highHead = new ListNode(0);
        ListNode lowTail = lowHead;
        ListNode highTail = highHead;

        for (ListNode cur = head; cur != null; cur = cur.next) {
            if (cur.val >= x) {
                highTail.next = cur;
                highTail = cur;
            } else {
                lowTail.next = cur;
                lowTail = cur;
            }
        }

        // The last ">= x" node may still link to a later node of the original
        // list; cut it so the result terminates.
        highTail.next = null;
        // Append the ">= x" chain after the "< x" chain.
        lowTail.next = highHead.next;
        return lowHead.next;
    }
}