每辆火车都会进出站一次,n 辆火车就是 n 次进站,n 次出站,以下一次是 进站 还是 出站 进行回溯
import java.util.*;
public class Train {
/**
* 最终的所有结果
*/
private static final List<String> RES = new ArrayList<>();
/**
* 存储出站的结果
*/
private static final List<String> TEMP = new ArrayList<>();
/**
* 存储在火车站内的火车
*/
private static final Deque<String> STATION = new LinkedList<>();
/**
* 回溯求解
*
* @param arr 输入的数组
* @param idx 当前处理的数组元素下标
* @param n 数组长度
* @param pushCount 进站的次数
* @param popCount 出站的次数
*/
public static void bak(String[] arr, int idx, int n, int pushCount, int popCount) {
// 每辆火车都会进出一次,进总次数和出总次数 == n 跳出
if (pushCount == n && popCount == n) {
RES.add(String.join(" ", TEMP));
return;
}
// pushCount < n 的时候可以进站
if (pushCount < n) {
STATION.push(arr[idx]);
bak(arr, idx + 1, n, pushCount + 1, popCount);
STATION.pop();
}
// station 不为空,隐含了一个条件 popCount < pushCount
if (!STATION.isEmpty() ) {
String out = STATION.pop();
TEMP.add(out);
bak(arr, idx, n, pushCount, popCount + 1);
STATION.push(out);
TEMP.remove(TEMP.size() - 1);
}
}
public static void main(String[] args) {
Scanner in = new Scanner(System.in);
while (in.hasNext()) {
int len = in.nextInt();
String[] arr = new String[len];
for (int i = 0; i < len; i++) {
arr[i] = in.next();
}
bak(arr, 0, arr.length, 0, 0);
Collections.sort(RES);
for (String s : RES) {
System.out.println(s);
}
}
in.close();
}
} 下面是 Python 的代码,一样的道理
from typing import List
class Solution:
def __init__(self):
self._res: List[str] = []
self._temp: List[str] = []
self._station: List[str] = []
def _bak(self,
arr: List[str],
idx: int,
n: int,
pushCount: int,
popCount: int):
if pushCount == n and popCount == n:
self._res.append(' '.join(self._temp))
return
if pushCount < n:
self._station.append(arr[idx])
self._bak(arr, idx + 1, n, pushCount + 1, popCount)
self._station.pop()
if len(self._station) > 0:
out: str = self._station.pop()
self._temp.append(out)
self._bak(arr, idx, n, pushCount, popCount + 1)
self._station.append(out)
self._temp.pop()
def process(self, n: int, arr: List[str]) -> List[str]:
self._bak(arr, 0, n, 0, 0)
self._res.sort()
return self._res
while True:
try:
n = int(input())
arr = input().strip().split(' ')
res = Solution().process(n, arr)
print('\n'.join(res))
except:
break 
京公网安备 11010502036488号