def main():
    import sys
    input = sys.stdin.read().split()
    ptr = 0
    n = int(input[ptr])
    ptr += 1
    m = int(input[ptr])
    ptr += 1
    
    r = list(map(int, input[ptr:ptr+n]))
    ptr += n
    c = list(map(int, input[ptr:ptr+m]))
    ptr += m
    
    # 合并并标记类型:(数值, 类型),类型0为行,1为列
    arr = []
    for num in r:
        arr.append((-num, 0))  # 用负数实现降序排序(避免用reverse,更快)
    for num in c:
        arr.append((-num, 1))
    
    # 排序
    arr.sort()
    # 若行 i 的数值 r [i] > 列 j 的数值 c [j],则让行 i 最后操作(覆盖列 j)更优;反之则让列 j 最后操作。
    # 将所有行和列的数值合并排序,按从大到小处理:
    # 处理一个数值时,若它是行值,则它能覆盖所有未被更大列值覆盖的列;
    # 若它是列值,则它能覆盖所有未被更大行值覆盖的行。
    remaining_rows = n
    remaining_cols = m
    total = 0
    
    for val, typ in arr:
        val = -val  # 恢复原数值
        if typ == 0:
            # 行操作:覆盖剩余的列
            total += val * remaining_cols
            #能覆盖的行数-1
            remaining_rows -= 1
        else:
            # 列操作:覆盖剩余的行
            total += val * remaining_rows
            #剩余的列-1
            remaining_cols -= 1
    
    print(total)

if __name__ == "__main__":
    main()