class Solution {
public:
    /**
     * Returns the elements of a matrix in clockwise spiral order, starting at
     * the top-left corner and peeling rings from the outside in.
     *
     * @param matrix  row-major matrix; every row is assumed to have the same
     *                length as matrix[0] (this is not checked).
     * @return        all elements in spiral order; empty when the matrix has
     *                no rows or no columns.
     */
    vector<int> printMatrix(vector<vector<int> > matrix) {
        vector<int> res;
        // Guard first: indexing matrix[0] on an empty matrix is undefined behaviour.
        if (matrix.empty() || matrix[0].empty()) {
            return res;
        }
        const int m = static_cast<int>(matrix.size());
        const int n = static_cast<int>(matrix[0].size());
        res.reserve(matrix.size() * matrix[0].size());

        // Number of concentric rings. When a dimension is odd, the innermost
        // "ring" degenerates into a single row or a single column.
        const int layer = min((m + 1) / 2, (n + 1) / 2);
        for (int k = 0; k < layer; k++) {
            const int top = k;
            const int bottom = m - 1 - k;
            const int left = k;
            const int right = n - 1 - k;

            if (top == bottom) {
                // Only one row remains: walk it left to right.
                for (int j = left; j <= right; j++) {
                    res.push_back(matrix[top][j]);
                }
            } else if (left == right) {
                // Only one column remains: walk it top to bottom.
                for (int i = top; i <= bottom; i++) {
                    res.push_back(matrix[i][right]);
                }
            } else {
                // Full ring. Each side stops one short of its far corner, so
                // every corner is emitted exactly once (by the side it starts).
                for (int j = left; j < right; j++) {
                    res.push_back(matrix[top][j]);
                }
                for (int i = top; i < bottom; i++) {
                    res.push_back(matrix[i][right]);
                }
                for (int j = right; j > left; j--) {
                    res.push_back(matrix[bottom][j]);
                }
                for (int i = bottom; i > top; i--) {
                    res.push_back(matrix[i][left]);
                }
            }
        }
        return res;
    }
};