You can pass plt.scatter a c argument, which lets you select the colors. The code below defines a colors dictionary to map your diamond colors to the plotting colors.
import matplotlib.pyplot as plt
import pandas as pd
carat = [5, 10, 20, 30, 5, 10, 20, 30, 5, 10, 20, 30]
price = [100, 100, 200, 200, 300, 300, 400, 400, 500, 500, 600, 600]
color =['D', 'D', 'D', 'E', 'E', 'E', 'F', 'F', 'F', 'G', 'G', 'G',]
df = pd.DataFrame(dict(carat=carat, price=price, color=color))
fig, ax = plt.subplots()
colors = {'D':'red', 'E':'blue', 'F':'green', 'G':'black'}
ax.scatter(df['carat'], df['price'], c=df['color'].apply(lambda x: colors[x]))
plt.show()
df['color'].apply(lambda x: colors[x]) effectively maps the colours from "diamond" to "plotting".
(Forgive me for not putting another example image up, I think 2 is enough :P)
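As a side note, the same lookup can be written with pandas' Series.map, which accepts a dictionary directly. This is only a minimal sketch, not a change to the code above: the diamond_colors series is a made-up sample, and the 'Z' grade is a hypothetical value included to show what happens when a key is missing.

```python
import pandas as pd

# Same diamond-color -> plotting-color dictionary as in the answer.
colors = {'D': 'red', 'E': 'blue', 'F': 'green', 'G': 'black'}

# Hypothetical sample of diamond colors, for illustration only.
diamond_colors = pd.Series(['D', 'E', 'F', 'G', 'D'])

# Series.map looks each value up in the dictionary.
plot_colors = diamond_colors.map(colors)
print(plot_colors.tolist())  # ['red', 'blue', 'green', 'black', 'red']

# A grade that is not in the dictionary becomes NaN instead of raising
# KeyError, which is what the lambda version would do.
unmapped = pd.Series(['D', 'Z']).map(colors)
print(unmapped.isna().tolist())  # [False, True]
```

Whether a silent NaN or a loud KeyError is preferable depends on how much you trust your data, so treat this as an option rather than a replacement.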
seaborn
You can use seaborn, which is a wrapper around matplotlib that makes it look prettier by default (rather opinion-based, I know :P) but also adds some plotting functions.
For this you could use seaborn.lmplot with fit_reg=False (which stops it from automatically fitting and drawing a regression line).
The code below uses an example dataset. By selecting hue='color' you tell seaborn to split your dataframe up based on your colours and then plot each one.
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
carat = [5, 10, 20, 30, 5, 10, 20, 30, 5, 10, 20, 30]
price = [100, 100, 200, 200, 300, 300, 400, 400, 500, 500, 600, 600]
color =['D', 'D', 'D', 'E', 'E', 'E', 'F', 'F', 'F', 'G', 'G', 'G',]
df = pd.DataFrame(dict(carat=carat, price=price, color=color))
# x and y are passed as keywords; recent seaborn versions no longer accept them positionally
sns.lmplot(x='carat', y='price', data=df, hue='color', fit_reg=False)
plt.show()
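If you want the hue split but also want to control which plotting color each diamond color gets, one option is seaborn.scatterplot with a palette dictionary. This is a sketch under assumptions rather than part of the original approach: it needs seaborn 0.9 or newer, and it reuses the colors dictionary from the matplotlib example.

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Same example data as above.
carat = [5, 10, 20, 30, 5, 10, 20, 30, 5, 10, 20, 30]
price = [100, 100, 200, 200, 300, 300, 400, 400, 500, 500, 600, 600]
color = ['D', 'D', 'D', 'E', 'E', 'E', 'F', 'F', 'F', 'G', 'G', 'G']
df = pd.DataFrame(dict(carat=carat, price=price, color=color))

# Diamond color -> plotting color, as in the matplotlib example.
colors = {'D': 'red', 'E': 'blue', 'F': 'green', 'G': 'black'}

fig, ax = plt.subplots()
# hue splits the points by diamond color; palette fixes the color of each group.
sns.scatterplot(data=df, x='carat', y='price', hue='color', palette=colors, ax=ax)
```

Call plt.show() afterwards to display the figure. Unlike lmplot, scatterplot draws onto an existing Axes, so it combines easily with other matplotlib code.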
Without seaborn, using pandas.groupby
If you don't want to use seaborn, you can use pandas.groupby to split the DataFrame by color and then plot each group using just matplotlib. You'll have to assign the colors manually as you go. I've added an example below:
fig, ax = plt.subplots()
colors = {'D':'red', 'E':'blue', 'F':'green', 'G':'black'}
grouped = df.groupby('color')
for key, group in grouped:
    group.plot(ax=ax, kind='scatter', x='carat', y='price', label=key, color=colors[key])
plt.show()
This code assumes the same DataFrame as above and groups it by color. It then iterates over these groups, plotting each one. To select a color I've created a colors dictionary, which maps the diamond color (for instance D) to a real color (for instance red).
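The same group-by-group idea can also be written with ax.scatter directly instead of DataFrame.plot. The following is just a sketch of that variant, assuming the same example data; the axis labels and legend title are my own additions.

```python
import matplotlib.pyplot as plt
import pandas as pd

# Same example data as above.
carat = [5, 10, 20, 30, 5, 10, 20, 30, 5, 10, 20, 30]
price = [100, 100, 200, 200, 300, 300, 400, 400, 500, 500, 600, 600]
color = ['D', 'D', 'D', 'E', 'E', 'E', 'F', 'F', 'F', 'G', 'G', 'G']
df = pd.DataFrame(dict(carat=carat, price=price, color=color))

colors = {'D': 'red', 'E': 'blue', 'F': 'green', 'G': 'black'}

fig, ax = plt.subplots()
# groupby yields (key, sub-DataFrame) pairs, one per diamond color.
for key, group in df.groupby('color'):
    # One scatter call per group, so each group gets its own legend entry.
    ax.scatter(group['carat'], group['price'], color=colors[key], label=key)

ax.set_xlabel('carat')
ax.set_ylabel('price')
ax.legend(title='color')
```

Call plt.show() afterwards to display the figure. Plotting one group per call is what makes the legend work; a single scatter call with a list of colors, as in the first example, gives you the colors but no per-color legend entries.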