基本上這個問題就是給你一些線段(格式通常為兩個端點),你要找出這些線段的交點。直觀的做法兩兩進行計算會花上\ord{n^2}的時間,但大多數的情況下交點不會很多。為了解決這個問題,修改自Shamos–Hoey演算法的Bentley–Ottmann演算法可以在\ord{(n+k)\log n}的時間內找出所有交點,其中k是交點數量。
這裡附上實作時需要用到的基本資料結構:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
template <typename T> struct point { | |
T x, y; | |
point() {} | |
point(const T &x, const T &y) : x(x), y(y) {} | |
point operator+(const point &b) const { return point(x + b.x, y + b.y); } | |
point operator-(const point &b) const { return point(x - b.x, y - b.y); } | |
point operator*(const T &b) const { return point(x * b, y * b); } | |
bool operator==(const point &b) const { return x == b.x && y == b.y; } | |
T dot(const point &b) const { return x * b.x + y * b.y; } | |
T cross(const point &b) const { return x * b.y - y * b.x; } | |
}; | |
template <typename T> struct line { | |
line() {} | |
point<T> p1, p2; | |
T a, b, c; // ax+by+c=0 | |
line(const point<T> &p1, const point<T> &p2) : p1(p1), p2(p2) {} | |
void pton() { | |
a = p1.y - p2.y; | |
b = p2.x - p1.x; | |
c = -a * p1.x - b * p1.y; | |
} | |
T ori(const point<T> &p) const { return (p2 - p1).cross(p - p1); } | |
T btw(const point<T> &p) const { return (p1 - p).dot(p2 - p); } | |
int seg_intersect(const line &l) const { | |
// -1: Infinitude of intersections | |
// 0: No intersection | |
// 1: One intersection | |
// 2: Collinear and intersect at p1 | |
// 3: Collinear and intersect at p2 | |
T c1 = ori(l.p1), c2 = ori(l.p2); | |
T c3 = l.ori(p1), c4 = l.ori(p2); | |
if (c1 == 0 && c2 == 0) { | |
bool b1 = btw(l.p1) >= 0, b2 = btw(l.p2) >= 0; | |
T a3 = l.btw(p1), a4 = l.btw(p2); | |
if (b1 && b2 && a3 == 0 && a4 >= 0) | |
return 2; | |
if (b1 && b2 && a3 >= 0 && a4 == 0) | |
return 3; | |
if (b1 && b2 && a3 >= 0 && a4 >= 0) | |
return 0; | |
return -1; | |
} else if (c1 * c2 <= 0 && c3 * c4 <= 0) | |
return 1; | |
return 0; | |
} | |
point<T> line_intersection(const line &l) const { | |
point<T> a = p2 - p1, b = l.p2 - l.p1, s = l.p1 - p1; | |
// if(a.cross(b)==0) return INF; | |
return p1 + a * (s.cross(b) / a.cross(b)); | |
} | |
}; | |
template <typename T> using segment = line<T>; |
演算法使用掃描線進行。掃描線是一條垂直線從左邊掃到右邊(有些實作是水平線從上面掃到下面),並且在遇到事件點的時候進行相關處理。
線段的兩端點以及交點都作為事件點被紀錄在最終結果中。對於每個事件點P,我們會計算三個集合:
- U集合:所有以P為起始點的線段集合
- C集合:所有包含P的線段集合
- L集合:所有以P為結束點的線段集合
當然要先保證每條線段的起始點移動會在結束點的左方,只要得到線段後稍微判斷一下就可以做到了。每個事件點找出這三個集合後就可以很容易的判斷相交資訊,但要注意的是會有以下的退化情形:
- 線段退化成點:這種情況該點的U和L都會包含該線段。
- 兩線段重合:只有重合處的兩端點會被紀錄為事件點,可以根據UCL判斷出是否線段重合
- 垂直線段:排序點和線段時如果x一樣就按照y來比較
最後是掃描線的資料結構,需要一棵平衡的BST根據當前掃描線和各個線段切點的y值進行排序,但這件事是可以用STL做到的!我們把當前事件點傳進比較函數裡面進行計算,因為在任何一個時刻BST中的資料都是根據當前的比較函數由小排到大的,應該不算undefined behavior。另外該演算法的浮點數誤差很大,建議使用時套上處理誤差的模板或是直接用分數計算:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm> | |
#include <map> | |
#include <set> | |
#include <vector> | |
template <typename T> struct SegmentIntersection { | |
struct PointInfo { | |
std::vector<const segment<T> *> U, C, L; | |
// U: the set of segments start at the point. | |
// C: the set of segments contain the point. | |
// L: the set of segments end at the point. | |
}; | |
private: | |
struct PointCmp { | |
bool operator()(const point<T> &a, const point<T> &b) const { | |
return (a.x < b.x) || (a.x == b.x && a.y < b.y); | |
} | |
}; | |
struct LineCmp { | |
const point<T> &P; | |
LineCmp(const point<T> &P) : P(P) {} | |
T getY(const segment<T> *s) const { | |
if (s->b == 0) | |
return P.y; | |
return (s->a * P.x + s->c) / -s->b; | |
} | |
bool operator()(const line<T> *a, const line<T> *b) const { | |
return getY(a) < getY(b); | |
} | |
}; | |
const std::vector<segment<T>> Segs; | |
std::map<point<T>, PointInfo, PointCmp> PointInfoMap; | |
std::set<point<T>, PointCmp> Queue; | |
point<T> Current; | |
std::multiset<const segment<T> *, LineCmp> BST; | |
std::vector<segment<T>> initSegs(std::vector<segment<T>> &&Segs) { | |
for (auto &S : Segs) { | |
if (!PointCmp()(S.p1, S.p2)) | |
std::swap(S.p1, S.p2); | |
S.pton(); | |
} | |
return Segs; | |
} | |
void init() { | |
for (auto &S : Segs) { | |
PointInfoMap[S.p1].U.emplace_back(&S); | |
PointInfoMap[S.p2].L.emplace_back(&S); | |
Queue.emplace(S.p1); | |
Queue.emplace(S.p2); | |
} | |
} | |
void FindNewEvent(const segment<T> *A, const segment<T> *B) { | |
auto Type = A->seg_intersect(*B); | |
if (Type <= 0) | |
return; | |
point<T> P; | |
if (Type == 2) | |
P = A->p1; | |
else if (Type == 3) | |
P = A->p2; | |
else | |
P = A->line_intersection(*B); | |
if (PointCmp()(Current, P)) | |
Queue.emplace(P); | |
} | |
void HandleEventPoint() { | |
auto &Info = PointInfoMap[Current]; | |
segment<T> Tmp(Current, Current); | |
auto LBound = BST.lower_bound(&Tmp); | |
auto UBound = BST.upper_bound(&Tmp); | |
std::copy_if( | |
LBound, UBound, std::back_inserter(Info.C), | |
[&](const segment<T> *S) -> bool { return !(S->p2 == Current); }); | |
BST.erase(LBound, UBound); | |
auto UC = Info.U; | |
UC.insert(UC.end(), Info.C.begin(), Info.C.end()); | |
UC.erase(std::remove_if( | |
UC.begin(), UC.end(), | |
[&](const segment<T> *S) -> bool { return S->p1 == S->p2; }), | |
UC.end()); | |
std::sort(UC.begin(), UC.end(), | |
[&](const segment<T> *A, const segment<T> *B) -> bool { | |
return (A->p2 - Current).cross(B->p2 - Current) > 0; | |
}); | |
if (UC.empty()) { | |
if (UBound != BST.end() && UBound != BST.begin()) { | |
auto Sr = *UBound; | |
auto Sl = *(--UBound); | |
FindNewEvent(Sl, Sr); | |
} | |
} else { | |
if (UBound != BST.end()) { | |
auto Sr = *UBound; | |
FindNewEvent(UC.back(), Sr); | |
} | |
if (UBound != BST.begin()) { | |
auto Sl = *(--UBound); | |
FindNewEvent(Sl, UC.front()); | |
} | |
} | |
for (auto S : UC) | |
BST.emplace(S); | |
} | |
public: | |
SegmentIntersection(std::vector<segment<T>> Segs) | |
: Segs(initSegs(std::move(Segs))), BST(Current) { | |
init(); | |
while (Queue.size()) { | |
auto It = Queue.begin(); | |
Current = *It; | |
Queue.erase(It); | |
HandleEventPoint(); | |
} | |
} | |
const std::vector<segment<T>> getSegments() const { return Segs; } | |
const std::map<point<T>, PointInfo, PointCmp> &getPointInfoMap() const { | |
return PointInfoMap; | |
} | |
}; |
最後是測試的部分,以下圖做為測試範例:
將該圖轉換成我們接受的input如下:
10-2 7 2 0-2 7 -2 0-2 6 2 5-2 6 2 2-2 4 2 7-2 4 2 2-2 4 4 1-2 0 2 20 1 0 10 3 4 1
最後附上測試程式碼,需要的話可以自己執行看看:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cmath> | |
#include <iostream> | |
const double EPS = 1e-9; | |
struct Double { | |
double d; | |
Double(double d = 0) : d(d) {} | |
Double operator-() const { return -d; } | |
Double operator+(const Double &b) const { return d + b.d; } | |
Double operator-(const Double &b) const { return d - b.d; } | |
Double operator*(const Double &b) const { return d * b.d; } | |
Double operator/(const Double &b) const { return d / b.d; } | |
Double operator+=(const Double &b) { return d += b.d; } | |
Double operator-=(const Double &b) { return d -= b.d; } | |
Double operator*=(const Double &b) { return d *= b.d; } | |
Double operator/=(const Double &b) { return d /= b.d; } | |
bool operator<(const Double &b) const { return d - b.d < -EPS; } | |
bool operator>(const Double &b) const { return d - b.d > EPS; } | |
bool operator==(const Double &b) const { return fabs(d - b.d) <= EPS; } | |
bool operator!=(const Double &b) const { return fabs(d - b.d) > EPS; } | |
bool operator<=(const Double &b) const { return d - b.d <= EPS; } | |
bool operator>=(const Double &b) const { return d - b.d >= -EPS; } | |
friend std::ostream &operator<<(std::ostream &os, const Double &db) { | |
return os << db.d; | |
} | |
friend std::istream &operator>>(std::istream &is, Double &db) { | |
return is >> db.d; | |
} | |
}; | |
void getResult(const SegmentIntersection<Double> &SI) { | |
std::cout << R"(-----------------Info--------------------- | |
U: the set of segments start at the point. | |
C: the set of segments contain the point. | |
L: the set of segments end at the point. | |
------------------------------------------ | |
)"; | |
for (auto p : SI.getPointInfoMap()) { | |
std::cout << "(" << p.first.x << ',' << p.first.y << "):\n"; | |
std::cout << " U:\n"; | |
for (auto s : p.second.U) { | |
std::cout << " (" << s->p1.x << ',' << s->p1.y << ") (" << s->p2.x | |
<< ',' << s->p2.y << ")\n"; | |
} | |
std::cout << " C:\n"; | |
for (auto s : p.second.C) { | |
std::cout << " (" << s->p1.x << ',' << s->p1.y << ") (" << s->p2.x | |
<< ',' << s->p2.y << ")\n"; | |
} | |
std::cout << " L:\n"; | |
for (auto s : p.second.L) { | |
std::cout << " (" << s->p1.x << ',' << s->p1.y << ") (" << s->p2.x | |
<< ',' << s->p2.y << ")\n"; | |
} | |
std::cout << '\n'; | |
} | |
} | |
int main() { | |
int n; | |
std::cin >> n; | |
std::vector<segment<Double>> segments; | |
while (n--) { | |
point<Double> a, b; | |
std::cin >> a.x >> a.y >> b.x >> b.y; | |
segments.emplace_back(a, b); | |
} | |
SegmentIntersection<Double> SI(segments); | |
getResult(SI); | |
return 0; | |
} |