import java.util.*;

/*
 * public class ListNode {
 *   int val;
 *   ListNode next = null;
 * }
 */

public class Solution {
    /**
     *
     * @param head ListNode类
     * @param k int整型
     * @return ListNode类
     */
    public ListNode reverseKGroup (ListNode head, int k) {
        // 先处理特殊情况
        if (head == null || head.next == null || k == 1) return head;
        //得到链表的长度
        int length = getLength(head);
        if (k > length) {
            return head;
        }
        //处理常规情况
        int groupCount = length / k;
        //最后一组是否满k个
        boolean isLastGroupLessThanK = false;
        if (length % k != 0) {
            groupCount += 1;
            isLastGroupLessThanK = true;
        }
        ListNode newHead = null;
        //用来记录上一组的尾部
        ListNode lastTail  = null;
        for (int i = 0; i < groupCount; i++) {
            //第一组和最后一组要特殊处理
            if (i == 0) {
                //第一组会得到头部
                //翻转k次
                ListNode pre = null;
                lastTail = head;
                for (int j = 0; j < k; j++) {
                    ListNode next = head.next;
                    head.next = pre;
                    pre = head;
                    head = next;
                }
                //执行到这里时 pre 是新头部,head 是下一组的头部
                newHead = pre;
            } else if (i == groupCount - 1 && isLastGroupLessThanK) {
                //最后一组 并且是不够K个要特殊处理
                lastTail.next = head;
            } else {
                ListNode pre = null;
                ListNode willTail = head;
                for (int j = 0; j < k; j++) {
                    ListNode next = head.next;
                    head.next = pre;
                    pre = head;
                    head = next;
                }
                //执行到这里后 pre 是当前组的头部 需要和上一组进行拼接
                lastTail.next = pre;
                lastTail = willTail;
            }
        }
        return newHead;
    }

    private int getLength(ListNode head) {
        int length = 0;
        ListNode temp = head;
        while (temp != null) {
            length ++;
            temp = temp.next;
        }
        return length;
    }
}