"""Read two integers from stdin and print the sum of their GCD and LCM."""

import math


def gcd_plus_lcm(n, m):
    """Return gcd(n, m) + lcm(n, m).

    The GCD is the non-negative value returned by ``math.gcd``.  The LCM is
    ``|n * m| // gcd``, which is 0 whenever either argument is 0.

    >>> gcd_plus_lcm(4, 6)
    14
    """
    g = math.gcd(n, m)
    if g == 0:
        # Only possible when n == m == 0; gcd and lcm are both 0 there,
        # and this avoids the division by zero below.
        return 0
    lcm = abs(n * m) // g
    return g + lcm


if __name__ == "__main__":
    n, m = map(int, input().split())
    print(gcd_plus_lcm(n, m))