I have added some instrumentation to the function:
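```python
def flatten(nested, depth=0):
    try:
        print("{}Iterate on {}".format(' '*depth, nested))
        # Raises TypeError if `nested` is not iterable (e.g. an int).
        for sublist in nested:
            # Recurse one level deeper and pass each leaf back up.
            for element in flatten(sublist, depth+1):
                print("{}got back {}".format(' '*depth, element))
                yield element
    except TypeError:
        # `nested` is a leaf: yield it unchanged.
        print('{}not iterable - return {}'.format(' '*depth, nested))
        yield nested
```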
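For comparison, here is a minimal sketch of the same recursion without the instrumentation. It uses `yield from` (Python 3.3+) in place of the inner `for`/`yield` loop, and the name `flatten_plain` is hypothetical, chosen so it does not shadow the instrumented version. It also narrows the `try` to the `iter()` call, which is a design choice of this sketch rather than the original.

```python
def flatten_plain(nested):
    """Yield the non-iterable leaves of ``nested``, depth first."""
    try:
        # iter() raises TypeError when ``nested`` is not iterable.
        iterator = iter(nested)
    except TypeError:
        # A leaf: hand it back to the caller unchanged.
        yield nested
        return
    for sublist in iterator:
        # Delegate to the recursive generator; each leaf it yields
        # is passed straight through to our caller.
        yield from flatten_plain(sublist)


print(list(flatten_plain([[1, 2], 3])))  # [1, 2, 3]
```

Because only `iter()` sits inside the `try`, a `TypeError` raised somewhere else cannot be mistaken for "this is a leaf". Like the original, it does not special-case strings: iterating `'a'` yields `'a'` again, so both versions recurse until `RecursionError`.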
Now calling

```python
list(flatten([[1,2],3]))
```

displays
```
Iterate on [[1, 2], 3]
 Iterate on [1, 2]
  Iterate on 1
  not iterable - return 1
 got back 1
got back 1
  Iterate on 2
  not iterable - return 2
 got back 2
got back 2
 Iterate on 3
 not iterable - return 3
got back 3
```
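The interleaving comes from generators being lazy: calling `flatten(...)` runs nothing, and each request for a value resumes the chain of suspended generators until a leaf has been yielded out through every level. A minimal sketch to check this, assuming a hypothetical variant `flatten_traced` that takes an extra `log` list and appends the trace lines to it instead of printing them:

```python
def flatten_traced(nested, log, depth=0):
    # Same logic as the instrumented flatten, but each trace line is
    # appended to the list ``log`` (hypothetical parameter) instead of
    # being printed.
    try:
        log.append("{}Iterate on {}".format(' ' * depth, nested))
        for sublist in nested:
            for element in flatten_traced(sublist, log, depth + 1):
                log.append("{}got back {}".format(' ' * depth, element))
                yield element
    except TypeError:
        log.append("{}not iterable - return {}".format(' ' * depth, nested))
        yield nested


trace = []
gen = flatten_traced([[1, 2], 3], trace)
print(len(trace))         # 0: creating the generator runs none of its body

first = next(gen)         # runs until 1 has passed through every level
print(first, len(trace))  # 1 6

rest = list(gen)          # resumes where it paused and drains the rest
print(rest, len(trace))   # [2, 3] 13
print("\n".join(trace))   # the same 13 lines as the output above
```

The first `next()` produces exactly the first six lines of the output. Each leaf is reported once per enclosing level on its way out: `1` and `2` sit two levels deep and get two `got back` lines each, while `3` sits one level deep and gets one.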