It’s actually pretty simple: assign the new column with a tuple key, then sort the columns (FWIW, I originally thought to do it your way):
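```python
# df is the DataFrame from the question, with columns 'bar'/'baz' over 'one'/'two'
df['bar', 'three'] = [0, 1, 2]   # the new column is appended at the far right
df = df.sort_index(axis=1)       # sorting the columns groups it back under 'bar'
print(df)
```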
```
        bar                       baz
        one       two three       one       two
A -0.212901  0.503615     0 -1.660945  0.446778
B -0.803926 -0.417570     1 -0.336827  0.989343
C  3.400885 -0.214245     2  0.895745  1.011671
```
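The question's `df` isn't reproduced here, so the sketch below builds a hypothetical stand-in with random values and checks the two steps above: the tuple assignment appends `('bar', 'three')` at the end, and `sort_index(axis=1)` groups it with the other `bar` columns. The order of the sub-columns *within* `bar` after sorting may depend on your pandas version (it could come out as `one, three, two` rather than the order shown above), so the sketch also shows an explicit `reindex` as an alternative when an exact order matters. That alternative and the `wanted` list are my own assumption, not part of the answer above.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for the question's DataFrame: top-level groups
# 'bar' and 'baz', each with sub-columns 'one' and 'two', rows A-C.
columns = pd.MultiIndex.from_product([["bar", "baz"], ["one", "two"]])
values = np.random.default_rng(0).standard_normal((3, 4))
df = pd.DataFrame(values, index=["A", "B", "C"], columns=columns)

# A tuple key creates the new column, but pandas appends it at the far
# right, after the 'baz' columns.
df["bar", "three"] = [0, 1, 2]
appended = df.columns.tolist()

# Sorting the column index brings ('bar', 'three') back next to the
# other 'bar' columns; the sub-label order within 'bar' may vary by version.
sorted_df = df.sort_index(axis=1)
top_level = sorted_df.columns.get_level_values(0).tolist()

# Alternative (an assumption, not from the answer): spell out the exact
# column order and reindex when "one, two, three" specifically is needed.
wanted = [("bar", "one"), ("bar", "two"), ("bar", "three"),
          ("baz", "one"), ("baz", "two")]
ordered_df = df.reindex(columns=pd.MultiIndex.from_tuples(wanted))

print(sorted_df)
print(ordered_df)
```

`sort_index` is the shortest route when you only need the new column grouped under its parent; `reindex` costs you a hand-written list but makes the final order independent of how pandas sorts the labels.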