Reading and processing a CSV file with Scala

In this article we will create a simple but comprehensive Scala application responsible for reading and processing a CSV file in order to extract information out of it.

Although simple, this app touches on the following points:

  • Creating an application from scratch using SBT
  • Using traits, case classes and a few collection methods
  • Using the ScalaTest framework to write unit tests

About the application

The application will be responsible for reading a CSV file that is a subset of a public data set and can be downloaded here. The subset used in the application contains only 50 rows and looks like this:

Transaction_date,Product,Price,Payment_Type,Country
1/2/09 6:17,Product1,1200,Mastercard,United Kingdom
1/2/09 4:53,Product1,1200,Visa,United States
1/2/09 13:08,Product1,1200,Mastercard,United States
1/3/09 14:44,Product1,1200,Visa,Australia
...

Ultimately, we want to extract the following information from it:

  • Total number of sales
  • Average sale price grouped by payment type
  • Number of sales grouped by day
  • Total number of sales made outside the USA and their total price

So let’s get started.

Creating the basic structure

Create a folder named “scala-csvprocessor” and the following structure under it:

scala-csvprocessor/
  build.sbt
  src/main/resources
  src/main/scala
  src/test/resources
  src/test/scala

Edit the build.sbt file and add the content:

organization := "com.lucianomolinari"

name := "scala-csvprocessor"

version := "1.0.0-SNAPSHOT"

scalaVersion := "2.11.7"

libraryDependencies += "org.scalatest" %% "scalatest" % "2.2.6" % "test"

After that, you can import the project into your preferred IDE, be it IntelliJ or Eclipse.

Reading the CSV file

As we read the CSV file, we want to convert each line to an instance of a class, so we can manipulate it easily later. For that, we will use a case class. A case class is a regular class that comes with some additional features for free, like automatic implementation of toString, equals, hashCode and copy, and also a companion object to make it easier to create instances of it.
We also want to abstract the way we load the sales data, so that our computation logic is not coupled to it. For that, we can use a trait.
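To make the case-class features above concrete, here is a small sketch. It redeclares Sale exactly as it is declared in SalesReader.scala so the block compiles on its own; the object name CaseClassDemo is made up for illustration.

```scala
// Same shape as the Sale case class declared in SalesReader.scala.
case class Sale(date: String, product: String, price: Int, paymentType: String, country: String)

// Hypothetical demo object, not part of the application.
object CaseClassDemo {
  // The generated companion object's apply method lets us omit `new`.
  val original: Sale = Sale("1/2/09 6:17", "Product1", 1200, "Mastercard", "United Kingdom")

  // A second, separately created instance with identical field values.
  val sameValues: Sale = Sale("1/2/09 6:17", "Product1", 1200, "Mastercard", "United Kingdom")

  // copy builds a new instance with only the named fields changed; original is untouched.
  val paidWithVisa: Sale = original.copy(paymentType = "Visa")
}
```

Because equals and hashCode compare field values, original == sameValues holds even though they are different objects, and the generated toString prints the class name followed by the field values.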

SalesReader.scala
package com.lucianomolinari.csvprocessor

/**
  * Trait responsible for reading/loading [[Sale]].
  *
  * @author Luciano Molinari
  */
trait SalesReader {

  /**
    * @return A [[Seq]] containing all the sales.
    */
  def readSales(): Seq[Sale]

}

case class Sale(date: String, product: String, price: Int, paymentType: String, country: String)

Our trait has a single method, which is responsible for returning a Seq[Sale].
Now we need an implementation of this trait that reads our CSV file. Download the CSV file and copy it into the src/main/resources folder. You can see the code for the implementation below:

SalesCSVReader.scala
package com.lucianomolinari.csvprocessor

import scala.io.Source

/**
  * Implementation of [[SalesReader]] responsible for reading sales from a CSV file.
  *
  * @param fileName The name of the CSV file to be read.
  * @author Luciano Molinari
  */
class SalesCSVReader(val fileName: String) extends SalesReader {

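  override def readSales(): Seq[Sale] = {
    val source = Source.fromFile(fileName)
    // Close the file handle once all lines have been materialized into a Vector.
    try {
      for {
        line <- source.getLines().drop(1).toVector
        values = line.split(",").map(_.trim)
      } yield Sale(values(0), values(1), values(2).toInt, values(3), values(4))
    } finally {
      source.close()
    }
  }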

}

Let’s go through this code in detail.

In the declaration of this class we specify that it extends the trait SalesReader and receives the name of the file as a constructor parameter. In the readSales method we use a for comprehension to read the data and convert it to a Seq of Sale. In the first line inside the for, we read the lines of the file and skip (drop) the first one, as it contains the header and we are not interested in it. On the next line, we split the current CSV line on the “,” character and remove leading and trailing whitespace from each resulting element with the trim function. Finally, we create an instance of Sale for each line read from the file, converting the price column to an Int with toInt. To make sure it’s working as expected, let’s write a test using ScalaTest’s FunSuite.

SalesCSVReaderTest.scala
package com.lucianomolinari.csvprocessor

import org.scalatest.FunSuite
import org.scalatest.Matchers._

/**
  *
  * @author Luciano Molinari
  */
class SalesCSVReaderTest extends FunSuite {

  test("Load CSV file") {
    val sales = new SalesCSVReader("src/main/resources/SalesJan2009.csv").readSales

    sales.size shouldBe 50

    sales(0) shouldBe Sale("1/2/09 6:17", "Product1", 1200, "Mastercard", "United Kingdom")
    sales(49) shouldBe Sale("1/10/09 14:43", "Product1", 1200, "Diners", "Ireland")
  }

}

As you can see, the test class extends FunSuite. Each test case is declared with the test method, which takes a descriptive name as a parameter. ScalaTest provides a lot of matchers; in this case we use shouldBe to check that all 50 sales were loaded and that the first and last sales were loaded properly (we don’t need to check all of them!). Note that we can compare Sale objects even though we didn’t implement an equals method manually. That’s because Sale is a case class.
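The parsing steps described for readSales can also be tried without a file on disk. The sketch below is a variant for experimentation, not the article’s reader: it applies the same for comprehension to a scala.io.Source built with Source.fromString from the header and first two data rows shown earlier. InMemoryParsing and parse are hypothetical names.

```scala
import scala.io.Source

// Same shape as the article's Sale case class.
case class Sale(date: String, product: String, price: Int, paymentType: String, country: String)

// Hypothetical helper object for experimenting with the parsing logic.
object InMemoryParsing {
  // Header plus the first two data rows from the CSV excerpt.
  val csv: String =
    """Transaction_date,Product,Price,Payment_Type,Country
      |1/2/09 6:17,Product1,1200,Mastercard,United Kingdom
      |1/2/09 4:53,Product1,1200,Visa,United States""".stripMargin

  // Skip the header, split each line on ",", trim every field, build a Sale.
  // toVector forces all lines to be read before the source is closed.
  def parse(source: Source): Vector[Sale] =
    try {
      for {
        line <- source.getLines().drop(1).toVector
        values = line.split(",").map(_.trim)
      } yield Sale(values(0), values(1), values(2).toInt, values(3), values(4))
    } finally source.close()

  val sales: Vector[Sale] = parse(Source.fromString(csv))
}
```

Splitting on “,” works for rows like the ones shown here; a field containing a quoted comma would be split in the wrong place, so data like that would need a dedicated CSV parser.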

Processing the data

Now that our CSV parser is ready, we can process the data to extract the information mentioned above. The code responsible for that is shown below:

SalesStatisticsComputer.scala
package com.lucianomolinari.csvprocessor

/**
  * Responsible for computing a few statistics based on a [[Seq]] of [[Sale]] extracted from [[SalesReader]].
  *
  * @param salesReader The reader responsible for loading the sales
  * @author Luciano Molinari
  */
class SalesStatisticsComputer(val salesReader: SalesReader) {

  val sales: Seq[Sale] = salesReader.readSales()

  /**
    * @return The number of sales
    */
  def getTotalNumberOfSales(): Int = sales.size

  /**
    * Creates and returns a [[Map]] where the key is a payment type and the value is the
    * average sale price for all sales made using that payment type.
    *
    * @return The map with the data grouped by payment type.
    */
  def getAvgSalePricesGroupedByPaymentType(): Map[String, Double] = {
    // Convert to Double before dividing; Int / Int would truncate the average.
    def avg(salesOfAPaymentType: Seq[Sale]): Double =
      salesOfAPaymentType.map(_.price).sum.toDouble / salesOfAPaymentType.size

    sales.groupBy(_.paymentType).mapValues(avg(_))
  }

  /**
    * Creates and returns a [[Map]] where the key is a given day in the format month/day
    * and the value is the number of sales made in the day.
    *
    * @return The map with the data grouped by day.
    */
  def getNumberOfSalesGroupedByDay(): Map[String, Int] = {
    def extractDay(sale: Sale): String = {
      val parts = sale.date.split("/")
      parts(0) + "/" + parts(1)
    }

    sales.groupBy(extractDay(_)).mapValues(_.length)
  }

  /**
    * @return A tuple where the first value is the number of sales made outside the USA and
    *         the second value is the total price of these sales.
    */
  def getTotalNumberAndPriceOfSalesMadeAbroad(): (Int, Int) = {
    val filtered = sales.filter(_.country != "United States")
    (filtered.size, filtered.map(_.price).sum)
  }

}

This class receives a SalesReader in its constructor and defines sales, which holds all the sales read and is used to perform the computations. Let’s see how it is used in each method:

  • getTotalNumberOfSales: This one is simple, as we just return the size of the Seq sales.
  • getAvgSalePricesGroupedByPaymentType: This method returns a Map where the key is the payment type and the value is the average price of the sales made with that payment type. To compute this map, we first use the groupBy method, which groups the elements of the collection by the result of the given function and returns a Map. In this case, we end up with a map where the key is the paymentType and the value contains the sales made with it. After that, we go through the values of each entry and compute the average sale price, using the mapValues method and the helper function avg.
  • getNumberOfSalesGroupedByDay: Here we want a Map where the key is a String containing a day in the format month/day and the value is the number of sales made on that day. We again use groupBy to create a map keyed by day, and then use mapValues to count the sales for each day.
  • getTotalNumberAndPriceOfSalesMadeAbroad: This method returns a tuple with the total number of sales made outside the USA and their total price. We first use the filter method to create another Seq containing only the sales made outside the USA, then return a tuple with the size of that Seq and the sum of its prices.
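The same collection operations can be tried on a handful of in-memory sales. In this sketch the first four rows come from the CSV excerpt above and the fifth row is made up; GroupingSketch and its method names are hypothetical. It uses map over the grouped Map instead of mapValues (which is deprecated as of Scala 2.13) so it compiles on recent compilers, and it converts the sum to Double before dividing so the average is not truncated.

```scala
// Same shape as the article's Sale case class.
case class Sale(date: String, product: String, price: Int, paymentType: String, country: String)

object GroupingSketch {
  // Rows 1-4 come from the CSV excerpt; row 5 is hypothetical.
  val sales: Seq[Sale] = Seq(
    Sale("1/2/09 6:17", "Product1", 1200, "Mastercard", "United Kingdom"),
    Sale("1/2/09 4:53", "Product1", 1200, "Visa", "United States"),
    Sale("1/2/09 13:08", "Product1", 1200, "Mastercard", "United States"),
    Sale("1/3/09 14:44", "Product1", 1200, "Visa", "Australia"),
    Sale("1/3/09 9:30", "Product2", 3600, "Visa", "Canada") // made-up row
  )

  // groupBy yields Map[String, Seq[Sale]]; each group is then reduced to its average price.
  def avgByPaymentType(sales: Seq[Sale]): Map[String, Double] =
    sales.groupBy(_.paymentType).map { case (paymentType, group) =>
      paymentType -> (group.map(_.price).sum.toDouble / group.size)
    }

  // "1/2/09 6:17" is keyed as "1/2" (month/day); each group is reduced to its size.
  def salesByDay(sales: Seq[Sale]): Map[String, Int] =
    sales.groupBy { sale =>
      val parts = sale.date.split("/")
      parts(0) + "/" + parts(1)
    }.map { case (day, group) => day -> group.size }

  // (number of sales outside the USA, sum of their prices)
  def abroad(sales: Seq[Sale]): (Int, Int) = {
    val filtered = sales.filter(_.country != "United States")
    (filtered.size, filtered.map(_.price).sum)
  }
}
```

With this data, the Visa group holds 1200, 1200 and 3600, so its average is 2000.0, and three of the five sales were made outside the USA for a total of 6000.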

In order to make sure everything is working, we need to create a test class.

SalesStatisticsComputerTest.scala
package com.lucianomolinari.csvprocessor

import org.scalatest.Matchers._
import org.scalatest.{BeforeAndAfter, FunSuite}

/**
  *
  * @author Luciano Molinari
  */
class SalesStatisticsComputerTest extends FunSuite with BeforeAndAfter {

  var statistics: SalesStatisticsComputer = _

  before {
    val salesReader = new SalesCSVReader("src/main/resources/SalesJan2009.csv")
    statistics = new SalesStatisticsComputer(salesReader)
  }

  test("get total number of sales") {
    statistics.getTotalNumberOfSales() shouldBe 50
  }

  test("get average sales price grouped by payment type") {
    val avgPrices = statistics.getAvgSalePricesGroupedByPaymentType()

    avgPrices("Mastercard") shouldBe 1200
    avgPrices("Visa") shouldBe 1800
    avgPrices("Amex") shouldBe 2000
    avgPrices("Diners") shouldBe 1200
  }

  test("get number of sales grouped by day") {
    val salesByDay = statistics.getNumberOfSalesGroupedByDay()

    salesByDay should have size (10)
    salesByDay("1/1") shouldBe 3
    salesByDay("1/2") shouldBe 9
    salesByDay("1/3") shouldBe 5
    salesByDay("1/4") shouldBe 6
    salesByDay("1/5") shouldBe 7
    salesByDay("1/6") shouldBe 7
    salesByDay("1/7") shouldBe 4
    salesByDay("1/8") shouldBe 5
    salesByDay("1/9") shouldBe 2
    salesByDay("1/10") shouldBe 2
  }

  test("get total number of sales made abroad and total value of them") {
    val (numberOfSales, total) = statistics.getTotalNumberAndPriceOfSalesMadeAbroad()

    numberOfSales shouldBe 21
    total shouldBe 30000
  }

}

This time, besides extending FunSuite, we also mix in the BeforeAndAfter trait. This gives us the before block, which runs before each test case. Within this block, we create a SalesCSVReader instance and pass it to the SalesStatisticsComputer constructor. We then create a test case for each of the methods exposed by SalesStatisticsComputer.
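Because SalesStatisticsComputer depends only on the SalesReader trait, a test does not strictly have to read the CSV file. A minimal sketch, assuming a hypothetical InMemorySalesReader test double and a trimmed-down copy of SalesStatisticsComputer with just one of its methods, so the block compiles on its own:

```scala
// Same shapes as the article's Sale and SalesReader.
case class Sale(date: String, product: String, price: Int, paymentType: String, country: String)

trait SalesReader {
  def readSales(): Seq[Sale]
}

// Hypothetical test double: returns a fixed Seq instead of reading a file.
class InMemorySalesReader(sales: Seq[Sale]) extends SalesReader {
  override def readSales(): Seq[Sale] = sales
}

// Trimmed-down stand-in for the article's SalesStatisticsComputer (one method only).
class SalesStatisticsComputer(val salesReader: SalesReader) {
  val sales: Seq[Sale] = salesReader.readSales()

  def getTotalNumberAndPriceOfSalesMadeAbroad(): (Int, Int) = {
    val filtered = sales.filter(_.country != "United States")
    (filtered.size, filtered.map(_.price).sum)
  }
}

object InMemoryReaderDemo {
  // Two rows from the CSV excerpt: one sale abroad, one in the USA.
  val reader = new InMemorySalesReader(Seq(
    Sale("1/2/09 6:17", "Product1", 1200, "Mastercard", "United Kingdom"),
    Sale("1/2/09 4:53", "Product1", 1200, "Visa", "United States")
  ))
  val statistics = new SalesStatisticsComputer(reader)
}
```

Swapping the reader like this keeps a test’s data next to its assertions, at the cost of no longer exercising the CSV parsing itself, which SalesCSVReaderTest already covers.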

The complete source code can be seen here.
