diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
index 54a5569b3d9f2964431bde26913ddf9abfea3e0c..b4fcc9229bd7446a2d015ec1038a7ac3c2225147 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
@@ -48,16 +48,17 @@ class PrimitiveVector[@specialized(Long, Int, Double) V: ClassManifest](initialS
 
   def size: Int = _numElements
 
-  /** Get the underlying array backing this vector. */
+  /** Gets the underlying array backing this vector. */
   def array: Array[V] = _array
 
   /** Trims this vector so that the capacity is equal to the size. */
-  def trim(): Unit = resize(size)
+  def trim(): PrimitiveVector[V] = resize(size)
 
   /** Resizes the array, dropping elements if the total length decreases. */
-  def resize(newLength: Int) {
+  def resize(newLength: Int): PrimitiveVector[V] = {
     val newArray = new Array[V](newLength)
     _array.copyToArray(newArray)
     _array = newArray
+    this
   }
 }