[scala] How do I check for equality using Spark Dataframe without SQL Query?

I want to select the rows where a column equals a certain value. I am doing this in Scala and having a little trouble.

Here's my code:

df.select(df("state")==="TX").show()

This returns the state column with Boolean values instead of just TX.

I've also tried

df.select(df("state")=="TX").show()

but this doesn't work either.

This question is related to scala apache-spark dataframe apache-spark-sql

The answer is

We can chain multiple filter/where conditions on a DataFrame.

For example:

import org.apache.spark.sql.functions.not
import spark.implicits._  // enables the $"..." column syntax

df
  .filter($"Col_1_name" === "buddy")              // equal to a string
  .filter($"Col_2_name" === "A")
  .filter(not($"Col_2_name".contains(" .sql")))   // drop strings that are not relevant
  .filter("Col_2_name is not null")               // drop nulls

In Spark 2.4

To compare with one value:

df.filter(lower(trim($"col_name")) === "<value>").show()

To compare with a collection of values:

df.filter($"col_name".isInCollection(Seq("value1", "value2"))).show()
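As a rough illustration of the two comparisons above, here is a small sketch that can be pasted into spark-shell or run as a script with Spark on the classpath. The sample values and the names texas and coastal are made up for this example, and isInCollection assumes Spark 2.4 or later.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{lower, trim}

// Reuses the existing session in spark-shell, otherwise starts a local one.
val spark = SparkSession.builder().master("local[*]").appName("normalized-equality").getOrCreate()
import spark.implicits._

// Hypothetical sample with inconsistent casing and padding.
val df = Seq(" TX", "tx ", "CA", "NY").toDF("state")

// Normalize the column before comparing it with a single value.
val texas = df.filter(lower(trim($"state")) === "tx")

// Keep rows whose value appears in a Scala collection (Spark 2.4+).
val coastal = df.filter($"state".isInCollection(Seq("CA", "NY")))

texas.show()
coastal.show()
```

Without lower and trim, a plain === "TX" would match neither of the first two sample rows, since equality on strings is exact.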

Let's create a sample dataset and do a deep dive into exactly why OP's code didn't work.

Here's our sample data:

val df = Seq(
  ("Rockets", 2, "TX"),
  ("Warriors", 6, "CA"),
  ("Spurs", 5, "TX"),
  ("Knicks", 2, "NY")
).toDF("team_name", "num_championships", "state")

We can pretty print our dataset with the show() method:

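+---------+-----------------+-----+
|team_name|num_championships|state|
+---------+-----------------+-----+
|  Rockets|                2|   TX|
| Warriors|                6|   CA|
|    Spurs|                5|   TX|
|   Knicks|                2|   NY|
+---------+-----------------+-----+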

Let's examine the results of df.select(df("state")==="TX").show():

+------------+
|(state = TX)|
+------------+
|        true|
|       false|
|        true|
|       false|
+------------+

It's easier to understand this result by simply appending a column - df.withColumn("is_state_tx", df("state")==="TX").show():

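+---------+-----------------+-----+-----------+
|team_name|num_championships|state|is_state_tx|
+---------+-----------------+-----+-----------+
|  Rockets|                2|   TX|       true|
| Warriors|                6|   CA|      false|
|    Spurs|                5|   TX|       true|
|   Knicks|                2|   NY|      false|
+---------+-----------------+-----+-----------+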

The other code OP tried (df.select(df("state")=="TX").show()) returns this error:

<console>:27: error: overloaded method value select with alternatives:
  [U1](c1: org.apache.spark.sql.TypedColumn[org.apache.spark.sql.Row,U1])org.apache.spark.sql.Dataset[U1] <and>
  (col: String,cols: String*)org.apache.spark.sql.DataFrame <and>
  (cols: org.apache.spark.sql.Column*)org.apache.spark.sql.DataFrame
 cannot be applied to (Boolean)

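The === operator is defined in the Column class and returns another Column. The Column class doesn't define its own == operator, so == falls back to ordinary Scala equality and returns a Boolean. None of the select overloads accepts a Boolean, which is why this code errors out. Read this blog for more background information about the Spark Column class.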
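To make the difference concrete without needing Spark, here is a toy sketch in plain Scala. The Col class below is a hypothetical stand-in written for this example, not Spark's actual Column implementation; it only mimics the idea that === builds an expression while == compares two objects on the spot.

```scala
// Toy stand-in for illustration only; this is not Spark's Column class.
final case class Col(expr: String) {
  // Returns a new Col that describes the comparison instead of evaluating it.
  def ===(other: Any): Col = Col(s"($expr = $other)")
}

val state = Col("state")
val literal: Any = "TX"

// === builds an expression object that a query engine could evaluate per row.
val built: Col = state === literal

// == is ordinary Scala equality between two objects and yields a Boolean immediately.
val compared: Boolean = state == literal

println(built.expr) // (state = TX)
println(compared)   // false
```

The second result is a single Boolean for the whole expression, which mirrors why select rejects df("state")=="TX": there is no column left to project.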

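Here's the accepted answer that works:

df.filter(df("state") === "TX").show

+---------+-----------------+-----+
|team_name|num_championships|state|
+---------+-----------------+-----+
|  Rockets|                2|   TX|
|    Spurs|                5|   TX|
+---------+-----------------+-----+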

As other posters have mentioned, the === method takes an argument with an Any type, so this isn't the only solution that works. This works too for example:

df.filter(df("state") === lit("TX")).show

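+---------+-----------------+-----+
|team_name|num_championships|state|
+---------+-----------------+-----+
|  Rockets|                2|   TX|
|    Spurs|                5|   TX|
+---------+-----------------+-----+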

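The Column equalTo method can also be used:

df.filter(df("state").equalTo("TX")).show

+---------+-----------------+-----+
|team_name|num_championships|state|
+---------+-----------------+-----+
|  Rockets|                2|   TX|
|    Spurs|                5|   TX|
+---------+-----------------+-----+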

It's worthwhile studying this example in detail. Scala's syntax seems magical at times, especially when methods are invoked without dot notation. It's hard for the untrained eye to see that === is a method defined in the Column class!

See this blog post if you'd like even more details on Spark Column equality.

To get the negation, wrap the expression in not (from org.apache.spark.sql.functions):

df.filter(not( ..expression.. ))

For example:

df.filter(not($"state" === "TX"))
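One detail worth checking when negating is how null values behave. The sketch below uses made-up sample rows and variable names; it compares not(... === ...), the =!= operator (Spark 2.0+), and the null-safe <=> operator. Treat it as an illustration under those assumptions rather than a recommendation for any particular dataset.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.not

// Reuses the existing session in spark-shell, otherwise starts a local one.
val spark = SparkSession.builder().master("local[*]").appName("negation").getOrCreate()
import spark.implicits._

// Hypothetical sample that includes a missing state.
val df = Seq[(String, String)](
  ("Rockets", "TX"),
  ("Warriors", "CA"),
  ("Unknown", null)
).toDF("team_name", "state")

// null === "TX" evaluates to null, and not(null) is still null, so the null row is dropped.
val notTx = df.filter(not($"state" === "TX"))

// =!= is the inequality operator and treats null the same way.
val notTxOperator = df.filter($"state" =!= "TX")

// <=> is null-safe equality: null <=> "TX" is false, so negating it keeps the null row.
val notTxKeepingNulls = df.filter(not($"state" <=> "TX"))

notTx.show()
notTxOperator.show()
notTxKeepingNulls.show()
```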

df.filter($"state" like "T%") for pattern matching

df.filter($"state" === "TX") or df.filter("state = 'TX'") for equality
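The three forms above can be tried side by side. This is a minimal sketch with invented sample values and variable names, meant for spark-shell or a script with Spark on the classpath.

```scala
import org.apache.spark.sql.SparkSession

// Reuses the existing session in spark-shell, otherwise starts a local one.
val spark = SparkSession.builder().master("local[*]").appName("like-and-equality").getOrCreate()
import spark.implicits._

// Hypothetical sample values.
val df = Seq("TX", "TN", "CA").toDF("state")

// SQL LIKE pattern: % matches any sequence of characters.
val startsWithT = df.filter($"state".like("T%"))

// Equality through the Column API and through a SQL expression string.
val viaColumn = df.filter($"state" === "TX")
val viaSqlString = df.filter("state = 'TX'")

startsWithT.show()
viaColumn.show()
viaSqlString.show()
```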

Here is a complete example using Spark 2.2+ that reads the data from JSON:

val myjson = "[{\"name\":\"Alabama\",\"abbreviation\":\"AL\"},{\"name\":\"Alaska\",\"abbreviation\":\"AK\"},{\"name\":\"American Samoa\",\"abbreviation\":\"AS\"},{\"name\":\"Arizona\",\"abbreviation\":\"AZ\"},{\"name\":\"Arkansas\",\"abbreviation\":\"AR\"},{\"name\":\"California\",\"abbreviation\":\"CA\"},{\"name\":\"Colorado\",\"abbreviation\":\"CO\"},{\"name\":\"Connecticut\",\"abbreviation\":\"CT\"},{\"name\":\"Delaware\",\"abbreviation\":\"DE\"},{\"name\":\"District Of Columbia\",\"abbreviation\":\"DC\"},{\"name\":\"Federated States Of Micronesia\",\"abbreviation\":\"FM\"},{\"name\":\"Florida\",\"abbreviation\":\"FL\"},{\"name\":\"Georgia\",\"abbreviation\":\"GA\"},{\"name\":\"Guam\",\"abbreviation\":\"GU\"},{\"name\":\"Hawaii\",\"abbreviation\":\"HI\"},{\"name\":\"Idaho\",\"abbreviation\":\"ID\"},{\"name\":\"Illinois\",\"abbreviation\":\"IL\"},{\"name\":\"Indiana\",\"abbreviation\":\"IN\"},{\"name\":\"Iowa\",\"abbreviation\":\"IA\"},{\"name\":\"Kansas\",\"abbreviation\":\"KS\"},{\"name\":\"Kentucky\",\"abbreviation\":\"KY\"},{\"name\":\"Louisiana\",\"abbreviation\":\"LA\"},{\"name\":\"Maine\",\"abbreviation\":\"ME\"},{\"name\":\"Marshall Islands\",\"abbreviation\":\"MH\"},{\"name\":\"Maryland\",\"abbreviation\":\"MD\"},{\"name\":\"Massachusetts\",\"abbreviation\":\"MA\"},{\"name\":\"Michigan\",\"abbreviation\":\"MI\"},{\"name\":\"Minnesota\",\"abbreviation\":\"MN\"},{\"name\":\"Mississippi\",\"abbreviation\":\"MS\"},{\"name\":\"Missouri\",\"abbreviation\":\"MO\"},{\"name\":\"Montana\",\"abbreviation\":\"MT\"},{\"name\":\"Nebraska\",\"abbreviation\":\"NE\"},{\"name\":\"Nevada\",\"abbreviation\":\"NV\"},{\"name\":\"New Hampshire\",\"abbreviation\":\"NH\"},{\"name\":\"New Jersey\",\"abbreviation\":\"NJ\"},{\"name\":\"New Mexico\",\"abbreviation\":\"NM\"},{\"name\":\"New York\",\"abbreviation\":\"NY\"},{\"name\":\"North Carolina\",\"abbreviation\":\"NC\"},{\"name\":\"North Dakota\",\"abbreviation\":\"ND\"},{\"name\":\"Northern Mariana Islands\",\"abbreviation\":\"MP\"},{\"name\":\"Ohio\",\"abbreviation\":\"OH\"},{\"name\":\"Oklahoma\",\"abbreviation\":\"OK\"},{\"name\":\"Oregon\",\"abbreviation\":\"OR\"},{\"name\":\"Palau\",\"abbreviation\":\"PW\"},{\"name\":\"Pennsylvania\",\"abbreviation\":\"PA\"},{\"name\":\"Puerto Rico\",\"abbreviation\":\"PR\"},{\"name\":\"Rhode Island\",\"abbreviation\":\"RI\"},{\"name\":\"South Carolina\",\"abbreviation\":\"SC\"},{\"name\":\"South Dakota\",\"abbreviation\":\"SD\"},{\"name\":\"Tennessee\",\"abbreviation\":\"TN\"},{\"name\":\"Texas\",\"abbreviation\":\"TX\"},{\"name\":\"Utah\",\"abbreviation\":\"UT\"},{\"name\":\"Vermont\",\"abbreviation\":\"VT\"},{\"name\":\"Virgin Islands\",\"abbreviation\":\"VI\"},{\"name\":\"Virginia\",\"abbreviation\":\"VA\"},{\"name\":\"Washington\",\"abbreviation\":\"WA\"},{\"name\":\"West Virginia\",\"abbreviation\":\"WV\"},{\"name\":\"Wisconsin\",\"abbreviation\":\"WI\"},{\"name\":\"Wyoming\",\"abbreviation\":\"WY\"}]"
import spark.implicits._
import org.apache.spark.sql.functions.{lit, not}  // used by the filters below
val df = spark.read.json(Seq(myjson).toDS)

    scala> df.show
    +------------+--------------------+
    |abbreviation|                name|
    +------------+--------------------+
    |          AL|             Alabama|
    |          AK|              Alaska|
    |          AS|      American Samoa|
    |          AZ|             Arizona|
    |          AR|            Arkansas|
    |          CA|          California|
    |          CO|            Colorado|
    |          CT|         Connecticut|
    |          DE|            Delaware|
    |          DC|District Of Columbia|
    |          FM|Federated States ...|
    |          FL|             Florida|
    |          GA|             Georgia|
    |          GU|                Guam|
    |          HI|              Hawaii|
    |          ID|               Idaho|
    |          IL|            Illinois|
    |          IN|             Indiana|
    |          IA|                Iowa|
    |          KS|              Kansas|
    +------------+--------------------+
    only showing top 20 rows

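    // equals matching
    scala> df.filter(df("abbreviation") === "TX").show
    +------------+-----+
    |abbreviation| name|
    +------------+-----+
    |          TX|Texas|
    +------------+-----+

    // or using lit
    scala> df.filter(df("abbreviation") === lit("TX")).show
    +------------+-----+
    |abbreviation| name|
    +------------+-----+
    |          TX|Texas|
    +------------+-----+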

    // not expression
    scala> df.filter(not(df("abbreviation") === "TX")).show
    +------------+--------------------+
    |abbreviation|                name|
    +------------+--------------------+
    |          AL|             Alabama|
    |          AK|              Alaska|
    |          AS|      American Samoa|
    |          AZ|             Arizona|
    |          AR|            Arkansas|
    |          CA|          California|
    |          CO|            Colorado|
    |          CT|         Connecticut|
    |          DE|            Delaware|
    |          DC|District Of Columbia|
    |          FM|Federated States ...|
    |          FL|             Florida|
    |          GA|             Georgia|
    |          GU|                Guam|
    |          HI|              Hawaii|
    |          ID|               Idaho|
    |          IL|            Illinois|
    |          IN|             Indiana|
    |          IA|                Iowa|
    |          KS|              Kansas|
    +------------+--------------------+
    only showing top 20 rows

Worked on Spark V2.*

import sqlContext.implicits._
df.filter($"state" === "TX")

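If the column needs to be compared against a variable (named stateValue here, since var is a reserved word in Scala):

import sqlContext.implicits._
df.filter($"state" === stateValue)

Note: the $"..." syntax requires import sqlContext.implicits._ (or import spark.implicits._ when you have a SparkSession named spark).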
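When the value comes from a variable, it can be convenient to wrap the filter in a small function. The sketch below assumes a helper named rowsForState and sample rows that are invented for this example; it uses col from org.apache.spark.sql.functions so the helper does not depend on the implicits import.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.col

// Reuses the existing session in spark-shell, otherwise starts a local one.
val spark = SparkSession.builder().master("local[*]").appName("filter-by-variable").getOrCreate()
import spark.implicits._

// Hypothetical helper: keeps the rows whose "state" column equals the given value.
def rowsForState(df: DataFrame, stateValue: String): DataFrame =
  df.filter(col("state") === stateValue)

// Hypothetical sample rows.
val teams = Seq(("Rockets", "TX"), ("Warriors", "CA"), ("Spurs", "TX")).toDF("team_name", "state")

val wanted = "TX"
val texasTeams = rowsForState(teams, wanted)
texasTeams.show()
```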

There is another simple SQL-like option. It should also work with Spark 1.6:

df.filter("state = 'TX'")

This passes the filter as a SQL expression string. For a full list of supported operators, check out this class.

You should be using where. select is a projection that returns the output of the expression, which is why you get Boolean values. where is a filter that keeps the structure of the DataFrame but retains only the rows for which the condition is true.
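The difference between projecting and filtering can be seen by comparing row counts and columns. This is a minimal sketch with invented sample rows and variable names, intended for spark-shell or a script with Spark on the classpath.

```scala
import org.apache.spark.sql.SparkSession

// Reuses the existing session in spark-shell, otherwise starts a local one.
val spark = SparkSession.builder().master("local[*]").appName("select-vs-where").getOrCreate()
import spark.implicits._

// Hypothetical sample rows.
val df = Seq(("Rockets", "TX"), ("Warriors", "CA"), ("Spurs", "TX")).toDF("team_name", "state")

// Projection: every row is kept, but the only column is the Boolean result of the comparison.
val projected = df.select($"state" === "TX")

// Filter: the original columns are kept, but only the matching rows remain.
val filtered = df.where($"state" === "TX")

projected.show()
filtered.show()
```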

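Along the same lines, per the documentation, you can write this in three different ways:

// The following are equivalent:
peopleDf.filter($"age" > 15)
peopleDf.where($"age" > 15)
peopleDf($"age" > 15)  // listed in older documentation; in Spark 2.x, apply takes a column name (String), so this form may not compile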

