import java.util.*;

public class Solution {

    /**
     * Rotates a square matrix a quarter turn clockwise, modifying it in place.
     *
     * <p>The matrix is handled as concentric rings, working from the outermost ring inward.
     * When {@code n} is odd, the single centre cell is never visited because rotation leaves it
     * where it is.
     *
     * @param mat the matrix to rotate; its rows are overwritten directly
     * @param n   the side length of {@code mat}
     * @return {@code mat} itself after rotation, or a newly allocated empty array when
     *         {@code n} is 0
     */
    public int[][] rotateMatrix(int[][] mat, int n) {
        if (n == 0) {
            return new int[][]{};
        }
        // `near` and `far` close in from both edges; once they meet or cross, no ring is left.
        for (int near = 0, far = n - 1; near < far; near++, far--) {
            process(mat, near, near, far, far);
        }
        return mat;
    }

    /**
     * Rotates one ring of {@code mat} a quarter turn clockwise.
     *
     * <p>The ring is bounded by row {@code fr} (top), row {@code lr} (bottom), column
     * {@code fc} (left) and column {@code lc} (right). For each offset {@code k} in
     * {@code [0, lc - fc)}, four cells are cycled: the left-edge cell moves to the top edge,
     * the bottom-edge cell moves to the left edge, the right-edge cell moves to the bottom
     * edge, and the original top-edge cell moves to the right edge.
     *
     * <p>Nothing happens if the bounds do not describe a ring, that is when
     * {@code fr >= lr} or {@code fc >= lc}.
     *
     * @param mat the matrix whose ring is rotated in place
     * @param fr  index of the top row of the ring
     * @param fc  index of the left column of the ring
     * @param lr  index of the bottom row of the ring
     * @param lc  index of the right column of the ring
     */
    public static void process(int[][] mat, int fr, int fc, int lr, int lc) {
        boolean hasRing = fr < lr && fc < lc;
        if (!hasRing) {
            return;
        }
        int span = lc - fc;
        int k = 0;
        while (k < span) {
            int topCol = fc + k;
            int leftRow = lr - k;
            int bottomCol = lc - k;
            int rightRow = fr + k;

            int savedTop = mat[fr][topCol];
            mat[fr][topCol] = mat[leftRow][fc];
            mat[leftRow][fc] = mat[lr][bottomCol];
            mat[lr][bottomCol] = mat[rightRow][lc];
            mat[rightRow][lc] = savedTop;
            k++;
        }
    }
}