EDIT3:
結局のところ、考えられるすべてのバリエーションを考慮することは、一見したよりも複雑です。私のコードのこの 3 回目の繰り返しは、考えられるすべての入力に対して正しいはずです。複雑さが増したため、ベクトル化された numpy バリアントを破棄しました。ジェネレーターのバージョンは次のとおりです。
def overlapping_sectors3(sectors, interval):
"""
Yields overlapping radial intervals.
Returns the overlapping intervals between each of the sector-intervals
and the comparison-interval.
Args:
sectors: List of intervals.
Interval borders must be in [0, 2*pi).
interval: Single interval aginst which the overlap is calculated.
Interval borders must be in [0, 2*pi).
Yields:
A list of intervals marking the overlaping areas.
Interval borders are guaranteed to be in [0, 2*pi).
"""
i_lhs, i_rhs = interval
if i_lhs > i_rhs:
for s_lhs, s_rhs in sectors:
if s_lhs > s_rhs:
# CASE 1
o_lhs = max(s_lhs, i_lhs)
# o_rhs = min(s_rhs+2*np.pi, i_rhs+2*np.pi)
o_rhs = min(s_rhs, i_rhs)
# since o_rhs > 2pi > o_lhs
yield o_lhs, o_rhs
#o_lhs = max(s_lhs+2pi, i_lhs)
# o_rhs = min(s_rhs+4pi, i_rhs+2pi)
# since o_lhs and o_rhs > 2pi
o_lhs = s_lhs
o_rhs = i_rhs
if o_lhs < o_rhs:
yield o_lhs, o_rhs
else:
# CASE 2
o_lhs = max(s_lhs, i_lhs)
# o_rhs = min(s_rhs, i_rhs+2*np.pi)
o_rhs = s_rhs # since i_rhs + 2pi > 2pi > s_rhs
if o_lhs < o_rhs:
yield o_lhs, o_rhs
# o_lhs = max(s_lhs+2pi, i_lhs)
# o_rhs = min(s_rhs+2pi, i_rhs+2pi)
# since s_lhs+2pi > 2pi > i_lhs and both o_lhs and o_rhs > 2pi
o_lhs = s_lhs
o_rhs = min(s_rhs, i_rhs)
if o_lhs < o_rhs:
yield o_lhs, o_rhs
else:
for s_lhs, s_rhs in sectors:
if s_lhs > s_rhs:
# CASE 3
o_lhs = max(s_lhs, i_lhs)
o_rhs = i_rhs
if o_lhs < o_rhs:
yield o_lhs, o_rhs
o_lhs = i_lhs
o_rhs = min(s_rhs, i_rhs)
if o_lhs < o_rhs:
yield o_lhs, o_rhs
else:
# CASE 4
o_lhs = max(s_lhs, i_lhs)
o_rhs = min(s_rhs, i_rhs)
if o_lhs < o_rhs:
yield o_lhs, o_rhs
次の方法でテストできます。
import numpy as np
from collections import namedtuple
TestCase = namedtuple('TestCase', ['sectors', 'interval', 'expected', 'remark'])
testcases = []
def newcase(sectors, interval, expected, remark=None):
testcases.append( TestCase(sectors, interval, expected, remark) )
newcase(
[[280,70]],
[270,90],
[[280,70]],
"type 1"
)
newcase(
[[10,150]],
[270,90],
[[10,90]],
"type 2"
)
newcase(
[[10,150]],
[270,350],
[],
"type 4"
)
newcase(
[[50,350]],
[10,90],
[[50,90]],
"type 4"
)
newcase(
[[30,0]],
[300,60],
[[30,60],[300,0]],
"type 1"
)
newcase(
[[30,5]],
[300,60],
[[30,60],[300,5]],
"type 1"
)
newcase(
[[30,355]],
[300,60],
[[30,60],[300,355]],
"type 3"
)
def isequal(A,B):
if len(A) != len(B):
return False
A = np.array(A).round()
B = np.array(B).round()
a = set(map(tuple, A))
b = set(map(tuple, B))
return a == b
for caseindex, case in enumerate(testcases):
print("### testcase %2d ###" % caseindex)
print("sectors : %s" % case.sectors)
print("interval: %s" % case.interval)
if case.remark:
print(case.remark)
sectors = np.array(case.sectors)/180*np.pi
interval = np.array(case.interval)/180*np.pi
result = overlapping_sectors3(sectors, interval)
result = np.array(list(result))*180/np.pi
if isequal(case.expected, result):
print('PASS')
else:
print('FAIL')
print('\texp: %s' % case.expected)
print('\tgot: %s' % result)
その背後にあるロジックを理解するには、次のことを考慮してください。
- 各間隔には左側 (lhs) と右側 (rhs) があります。
- lhs > rhs の場合、間隔は「ラップラウンド」します。つまり、実際には間隔 [lhs, rhs+2pi] になります。
- 現在のセクターと比較間隔を比較する場合、4 つのケースを考慮する必要があります。
- 両方ともラップラウンド
- 比較間隔のみが折り返されます
- セクター間隔のみがラップラウンドします
- どちらも丸めない
- 通常の間隔では、重複する間隔は
[o_lhs, o_rhs]
witho_lhs=max(lhs_1, lhs2)
および o _rhs=min(rhs_1, rhs_2)
iffです。o_lhs < o_rhs
2pi
rhs iff に間隔を追加することによってすべての間隔を「歪ませない」と、次の間隔rhs<lhs
が得られます[0, 4*np.pi)
[0,2*pi)
1[2*pi, 4*pi)
番目、2 番目、3 番目の周回を呼びます[4*pi, 6*pi)
。
4 つのケース:
- ケース 4: どちらの区間も回り込まないため、すべての境界は最初の周回内にあります。任意の元の間隔と同様に、オーバーラップを計算するだけです。
- ケース 2 および 3: ちょうど 1 つの間隔がラップします。つまり、1 つの間隔 (a と呼びます) は完全に最初の周回内にあり、2 番目の間隔 (b と呼びます) は最初と 2 番目の周回の両方を生成します。つまり、a は最初と 2 番目の周回の両方で b と交差できます。まず、最初の周回を考えます。a_lhs、a_rhs、および b_lhs が含まれています。b の右側は「ラップされていない」と見なされるため、 になり
b_rhs+2pi
ます。これにより と が得られo_lhs=max(a_lhs, b_lhs)
ますo_rhs=a_rhs
。次に、2 番目の周回を考えます。これには、b at の右辺だけでなく、a atb_rhs+2pi
の周期的な繰り返しも含まれ[a_lhs+2pi, a_rhs+2pi]
ます。これにより、o_lhs=max(a_lhs+2pi, b_lhs)
とが得られo_rhs=min(a_rhs+2pi, b_rhs+2pi)
ます。モジュロはと2pi
の両方にシフトします。o_lhs=a_lhs
o_rhs=min(a_rhs, b_rhs)
- ケース 1: 両方の間隔で周回 1 と 2 が生成されます。最初の交点は
[0, 4pi)
2 番目の区間内にあり、間隔の 1 つを定期的に繰り返す必要があるため、 内にあります[2pi,6pi)
。
古い回答、非推奨:
ここで、numpy ベクトル操作を使用した私のバージョンを示します。np.where などのより抽象的な numpy 関数を利用することで改善される可能性があります。
別のアイデアは、numpy を無視して、ある種のイテレータ/ジェネレータ関数を使用することです。おそらく、次はそのようなことを試してみます。
import numpy as np
sectors = np.array( [[5.23,0.50], [0.7,1.8], [1.9,3.71],[4.1,5.11]] )
interval = np.array([5.7,2.15])
def normalize_sectors(sectors):
# normalize might not be the best word here
idx = sectors[...,0] > sectors[...,1]
sectors[idx,1] += 2*np.pi
return sectors
def overlapping_sectors(sectors, interval):
# 'reverse' modulo 2*pi, so that rhs is always larger than lhs"
sectors = normalize_sectors(sectors)
interval = normalize_sectors(interval.reshape(1,2)).squeeze()
# when comparing two intervals A and B, the intersection is
# [max(A.left, B.left), min(A.right, B.right)
left = np.maximum(sectors[:,0], interval[0])
right = np.minimum(sectors[:,1], interval[1])
# construct overlapping intervals
res = np.hstack([left,right]).reshape((2,-1)).T
# neither empty (lhs=rhs) nor 'reversed' lhs>rhs intervals are allowed
res = res[res[:,0] < res[:,1]]
#reapply modulo
res = res % (2*np.pi)
return res
print(overlapping_sectors(sectors, interval))
編集:
ここではイテレータ ベースのバージョンです。同様に機能しますが、数値的にはやや劣っているようです。
def overlapping_sectors2(sectors, interval):
i_lhs, i_rhs = interval
if i_lhs>i_rhs:
i_rhs += 2*np.pi
for s_lhs, s_rhs in sectors:
if s_lhs>s_rhs:
s_rhs += 2*np.pi
o_lhs = max(s_lhs, i_lhs)
o_rhs = min(s_rhs, i_rhs)
if o_lhs < o_rhs:
yield o_lhs % (2*np.pi), o_rhs % (2*np.pi)
print(list(overlapping_sectors2(sectors, interval)))
EDIT2:
2 つの場所でオーバーラップする間隔をサポートするようになりました。
sectors = np.array( [[30,330]] )/180*np.pi
interval = np.array( [300,60] )/180*np.pi
def normalize_sectors(sectors):
# normalize might not be the best word here
idx = sectors[...,0] > sectors[...,1]
sectors[idx,1] += 2*np.pi
return sectors
def overlapping_sectors(sectors, interval):
# 'reverse' modulo 2*pi, so that rhs is always larger than lhs"
sectors = normalize_sectors(sectors)
# if interval rhs is smaller than lhs, the interval crosses 360 degrees
# and we have to consider it as two intervals
if interval[0] > interval[1]:
interval_1 = np.array([interval[0], 2*np.pi])
interval_2 = np.array([0, interval[1]])
res_1 = _overlapping_sectors(sectors, interval_1)
res_2 = _overlapping_sectors(sectors, interval_2)
res = np.vstack((res_1, res_2))
else:
res = _overlapping_sectors(sectors, interval)
#reapply modulo
res = res % (2*np.pi)
return res
def _overlapping_sectors(sector, interval):
# when comparing two intervals A and B, the intersection is
# [max(A.left, B.left), min(A.right, B.right)
left = np.maximum(sectors[:,0], interval[0])
right = np.minimum(sectors[:,1], interval[1])
# construct overlapping intervals
res = np.hstack([left,right]).reshape((2,-1)).T
# neither empty (lhs=rhs) nor 'reversed' lhs>rhs intervals are allowed
res = res[res[:,0] < res[:,1]]
return res
print(overlapping_sectors(sectors, interval)*180/np.pi)
def overlapping_sectors2(sectors, interval):
i_lhs, i_rhs = interval
for s_lhs, s_rhs in sectors:
if s_lhs>s_rhs:
s_rhs += 2*np.pi
if i_lhs > i_rhs:
o_lhs = max(s_lhs, i_lhs)
o_rhs = min(s_rhs, 2*np.pi)
if o_lhs < o_rhs:
yield o_lhs % (2*np.pi), o_rhs % (2*np.pi)
o_lhs = max(s_lhs, 0)
o_rhs = min(s_rhs, i_rhs)
if o_lhs < o_rhs:
yield o_lhs % (2*np.pi), o_rhs % (2*np.pi)
else:
o_lhs = max(s_lhs, i_lhs)
o_rhs = min(s_rhs, i_rhs)
if o_lhs < o_rhs:
yield o_lhs % (2*np.pi), o_rhs % (2*np.pi)
print(np.array(list(overlapping_sectors2(sectors, interval)))*180/np.pi)