import java.util.*;

/*
 * public class ListNode {
 *   int val;
 *   ListNode next = null;
 * }
 */

public class Solution {
    /**
     * Re-links a singly linked list so that every node whose value is strictly less than
     * {@code x} appears before every node whose value is greater than or equal to {@code x}.
     * Within each of the two groups the original relative order of the nodes is kept.
     * The existing nodes are re-linked in place; the only allocations are two temporary
     * sentinel nodes that are not part of the returned list.
     *
     * @param head first node of the list to partition; may be {@code null}
     * @param x    pivot value that separates the two groups
     * @return the first node of the partitioned list, or {@code null} if {@code head} is
     *         {@code null}
     */
    public ListNode partition(ListNode head, int x) {
        if (head == null) {
            return null;
        }

        // Sentinels let both sublists be appended to without special-casing the first node.
        ListNode lowSentinel = new ListNode(0);
        ListNode highSentinel = new ListNode(0);
        ListNode lowTail = lowSentinel;
        ListNode highTail = highSentinel;

        // Linking `node` onto a tail changes the tail's successor, not node.next,
        // so advancing via node.next after the body still walks the original list.
        for (ListNode node = head; node != null; node = node.next) {
            if (node.val >= x) {
                highTail.next = node;
                highTail = node;
            } else {
                lowTail.next = node;
                lowTail = node;
            }
        }

        // The last "high" node may still point at a later "low" node; cut it so the
        // result has no cycle.
        highTail.next = null;
        // Append the ">= x" group after the "< x" group.
        lowTail.next = highSentinel.next;
        return lowSentinel.next;
    }
}