import java.util.*;
import java.util.List;
import java.util.ArrayList;
import java.util.Stack;

/*
 * public class ListNode {
 *   int val;
 *   ListNode next = null;
 * }
 */

public class Solution {
    /**
     * Reverses the nodes of a singly linked list in consecutive groups of {@code k}.
     *
     * <p>Each full group of {@code k} nodes is reversed in place; a trailing group with fewer
     * than {@code k} nodes keeps its original order. Only {@code next} pointers are rewired, so
     * no nodes are allocated and the auxiliary space is O(1); time is O(n).
     *
     * @param head the first node of the list; may be {@code null}
     * @param k the group size; values {@code <= 1} leave the list unchanged
     * @return the head of the rearranged list ({@code head} itself when no group is reversed)
     */
    public ListNode reverseKGroup(ListNode head, int k) {
        // A group size of 1 cannot reorder anything, and k <= 0 has no meaningful grouping,
        // so both return the list untouched instead of looping forever or dereferencing null.
        if (head == null || head.next == null || k <= 1) {
            return head;
        }
        ListNode newHead = null;
        // Last node of the most recently reversed group; it must be linked to the next group's head.
        ListNode prevGroupTail = null;
        ListNode groupStart = head;
        while (true) {
            // Locate the k-th node of the current group; null means fewer than k nodes remain.
            ListNode groupEnd = groupStart;
            for (int i = 1; i < k && groupEnd != null; i++) {
                groupEnd = groupEnd.next;
            }
            if (groupEnd == null) {
                break;
            }
            ListNode nextGroupStart = groupEnd.next;

            // Reverse [groupStart, groupEnd]. Seeding prev with nextGroupStart makes the group's
            // new tail (groupStart) point directly at the untouched remainder of the list.
            ListNode prev = nextGroupStart;
            ListNode cur = groupStart;
            while (cur != nextGroupStart) {
                ListNode following = cur.next;
                cur.next = prev;
                prev = cur;
                cur = following;
            }

            // groupEnd is now the head of this group and groupStart its tail.
            if (prevGroupTail == null) {
                newHead = groupEnd;
            } else {
                prevGroupTail.next = groupEnd;
            }
            prevGroupTail = groupStart;
            groupStart = nextGroupStart;
        }
        // newHead stays null only when the list is shorter than k (nothing was reversed).
        return newHead == null ? head : newHead;
    }

    /**
     * Counts the nodes in a singly linked list.
     *
     * @param node the first node of the list; may be {@code null}
     * @return the number of nodes reachable from {@code node} by following {@code next},
     *     or 0 when {@code node} is {@code null}
     */
    public int getNodeSize(ListNode node) {
        int count = 0;
        for (ListNode cur = node; cur != null; cur = cur.next) {
            count++;
        }
        return count;
    }
}