import java.util.*;

/*
public class ListNode {
    int val;
    ListNode next = null;

    ListNode(int val) {
        this.val = val;
    }
}*/
public class Partition {

    /**
     * Rearranges a singly linked list so that every node whose value is less
     * than {@code x} comes before every node whose value is greater than or
     * equal to {@code x}. Within each of the two groups, the original relative
     * order of the nodes is kept.
     *
     * <p>No nodes are created. The existing nodes are relinked in place by
     * rewriting their {@code next} pointers, so the input list is modified.
     *
     * @param pHead head of the list to partition; may be {@code null}
     * @param x     pivot value; nodes equal to {@code x} go to the second group
     * @return the head of the rearranged list, or {@code null} if
     *         {@code pHead} was {@code null}
     */
    public ListNode partition(ListNode pHead, int x) {
        // Chain of nodes with val < x.
        ListNode lowHead = null;
        ListNode lowTail = null;
        // Chain of nodes with val >= x.
        ListNode highHead = null;
        ListNode highTail = null;

        // Body only rewrites the *previous* tail's next, never cur.next,
        // so advancing via cur.next afterwards still follows the original list.
        for (ListNode cur = pHead; cur != null; cur = cur.next) {
            if (cur.val >= x) {
                if (highTail != null) {
                    highTail.next = cur;
                } else {
                    highHead = cur;
                }
                highTail = cur;
            } else {
                if (lowTail != null) {
                    lowTail.next = cur;
                } else {
                    lowHead = cur;
                }
                lowTail = cur;
            }
        }

        // The last ">= x" node may still point at a "< x" node further down
        // the original list; cut it off so the result has no cycle.
        if (highTail != null) {
            highTail.next = null;
        }

        if (lowHead == null) {
            return highHead;
        }

        // Splice the ">= x" chain after the "< x" chain. When the ">= x" chain
        // is empty this writes null, which is already the case for the list tail.
        lowTail.next = highHead;
        return lowHead;
    }

}