diff --git a/app/models/enterprise_relationship.rb b/app/models/enterprise_relationship.rb index 09fff8ba03..26bf8fc012 100644 --- a/app/models/enterprise_relationship.rb +++ b/app/models/enterprise_relationship.rb @@ -9,4 +9,8 @@ class EnterpriseRelationship < ActiveRecord::Base joins('LEFT JOIN enterprises AS parent_enterprises ON parent_enterprises.id = enterprise_relationships.parent_id'). joins('LEFT JOIN enterprises AS child_enterprises ON child_enterprises.id = enterprise_relationships.child_id') scope :by_name, with_enterprises.order('parent_enterprises.name, child_enterprises.name') + + scope :involving_enterprises, ->(enterprises) { + where('parent_id IN (?) OR child_id IN (?)', enterprises, enterprises) + } end diff --git a/spec/models/enterprise_relationship_spec.rb b/spec/models/enterprise_relationship_spec.rb index 7cca07472e..b715674594 100644 --- a/spec/models/enterprise_relationship_spec.rb +++ b/spec/models/enterprise_relationship_spec.rb @@ -2,15 +2,32 @@ require 'spec_helper' describe EnterpriseRelationship do describe "scopes" do + let(:e1) { create(:enterprise, name: 'A') } + let(:e2) { create(:enterprise, name: 'B') } + let(:e3) { create(:enterprise, name: 'C') } + it "sorts by parent, child enterprise name" do - e1 = create(:enterprise, name: 'A') - e2 = create(:enterprise, name: 'B') - e3 = create(:enterprise, name: 'C') er1 = create(:enterprise_relationship, parent: e1, child: e3) er2 = create(:enterprise_relationship, parent: e2, child: e1) er3 = create(:enterprise_relationship, parent: e1, child: e2) EnterpriseRelationship.by_name.should == [er3, er1, er2] end + + describe "finding relationships involving some enterprises" do + let!(:er) { create(:enterprise_relationship, parent: e1, child: e2) } + + it "returns relationships where an enterprise is the parent" do + EnterpriseRelationship.involving_enterprises([e1]).should == [er] + end + + it "returns relationships where an enterprise is the child" do + EnterpriseRelationship.involving_enterprises([e2]).should == [er] + end + + it "does not return other relationships" do + EnterpriseRelationship.involving_enterprises([e3]).should == [] + end + end end end