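I tried some of the solutions here but ended up writing my own. It reorders a correlation matrix so that the first column stays first and the remaining columns follow in descending order of their absolute correlation with it. I hope it's useful to the next reader, so I'm sharing it here: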
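def sort_correlation_matrix(correlation_matrix):
    # Rank by absolute value so strong negative correlations count too.
    cor = correlation_matrix.abs()
    # First column, skipping its first row (the column's correlation with itself).
    top_col = cor[cor.columns[0]].iloc[1:]
    top_col = top_col.sort_values(ascending=False)
    ordered_columns = [cor.columns[0]] + top_col.index.tolist()
    # Apply the same order to columns and rows; the original signs are kept.
    return correlation_matrix[ordered_columns].reindex(ordered_columns)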
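
Here is a quick way to try it out. This is only a sketch: the matrix and its labels (`target`, `a`, `b`, `c`) are made up by hand for illustration, and I'm assuming the matrix is square with the same labels, in the same order, on rows and columns (which is what `DataFrame.corr()` gives you). The function is repeated so the snippet runs on its own.

```python
import pandas as pd


def sort_correlation_matrix(correlation_matrix):
    cor = correlation_matrix.abs()
    top_col = cor[cor.columns[0]].iloc[1:].sort_values(ascending=False)
    ordered_columns = [cor.columns[0]] + top_col.index.tolist()
    return correlation_matrix[ordered_columns].reindex(ordered_columns)


# Hypothetical, hand-written correlation matrix (not real data).
labels = ["target", "a", "b", "c"]
corr = pd.DataFrame(
    [
        [1.0, 0.2, -0.9, 0.5],
        [0.2, 1.0, 0.1, 0.3],
        [-0.9, 0.1, 1.0, -0.4],
        [0.5, 0.3, -0.4, 1.0],
    ],
    index=labels,
    columns=labels,
)

sorted_corr = sort_correlation_matrix(corr)

# "b" comes right after "target" because |-0.9| is the largest value.
print(sorted_corr.columns.tolist())  # ['target', 'b', 'c', 'a']
```

Note that the sign is preserved in the result: `sorted_corr.loc["b", "target"]` is still `-0.9`, only the ordering uses the absolute value.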