import java.util.*;
import java.util.ArrayList;
public class Solution {
    /**
     * Returns the elements of {@code matrix} in clockwise spiral order, starting at the
     * top-left corner and moving right first.
     *
     * <p>The matrix is walked one rectangular ring at a time, from the outermost ring
     * inward. The column count is taken from the first row, so every row is expected to
     * have at least that many elements.
     *
     * @param matrix the matrix to traverse, indexed as {@code matrix[row][column]}
     * @return the elements in spiral order; an empty list when {@code matrix} is
     *         {@code null}, has no rows, or has no columns
     */
    public ArrayList<Integer> printMatrix(int [][] matrix) {
        // Guard against inputs that have no first row or no columns; reading
        // matrix[0].length on those would throw instead of yielding an empty result.
        if (matrix == null || matrix.length == 0
                || matrix[0] == null || matrix[0].length == 0) {
            return new ArrayList<>();
        }

        int top = 0;
        int bottom = matrix.length - 1;
        int left = 0;
        int right = matrix[0].length - 1;

        ArrayList<Integer> result = new ArrayList<>(matrix.length * matrix[0].length);

        while (top <= bottom && left <= right) {
            // Top edge, left to right (includes both corners).
            for (int col = left; col <= right; col++) {
                result.add(matrix[top][col]);
            }
            // Right edge, downward, skipping the corner already emitted.
            for (int row = top + 1; row <= bottom; row++) {
                result.add(matrix[row][right]);
            }
            // Bottom edge, right to left; skipped for a single-row ring so the
            // top edge is not emitted twice.
            if (top < bottom) {
                for (int col = right - 1; col >= left; col--) {
                    result.add(matrix[bottom][col]);
                }
            }
            // Left edge, upward; skipped for a single-column ring so the
            // right edge is not emitted twice.
            if (left < right) {
                for (int row = bottom - 1; row > top; row--) {
                    result.add(matrix[row][left]);
                }
            }

            // Shrink to the next inner ring.
            top++;
            bottom--;
            left++;
            right--;
        }

        return result;
    }
}