def main():
import sys
input = sys.stdin.read().split()
ptr = 0
n = int(input[ptr])
ptr += 1
m = int(input[ptr])
ptr += 1
r = list(map(int, input[ptr:ptr+n]))
ptr += n
c = list(map(int, input[ptr:ptr+m]))
ptr += m
# 合并并标记类型:(数值, 类型),类型0为行,1为列
arr = []
for num in r:
arr.append((-num, 0)) # 用负数实现降序排序(避免用reverse,更快)
for num in c:
arr.append((-num, 1))
# 排序
arr.sort()
# 若行 i 的数值 r [i] > 列 j 的数值 c [j],则让行 i 最后操作(覆盖列 j)更优;反之则让列 j 最后操作。
# 将所有行和列的数值合并排序,按从大到小处理:
# 处理一个数值时,若它是行值,则它能覆盖所有未被更大列值覆盖的列;
# 若它是列值,则它能覆盖所有未被更大行值覆盖的行。
remaining_rows = n
remaining_cols = m
total = 0
for val, typ in arr:
val = -val # 恢复原数值
if typ == 0:
# 行操作:覆盖剩余的列
total += val * remaining_cols
#能覆盖的行数-1
remaining_rows -= 1
else:
# 列操作:覆盖剩余的行
total += val * remaining_rows
#剩余的列-1
remaining_cols -= 1
print(total)
if __name__ == "__main__":
main()