私はついにコードベースでこの問題に対処することに取り掛かり、ここで私の解決策を共有することにしました。簡単に説明すると、StateMachineTaskFactory<T> クラスがあり、T で有効な状態 (通常は列挙型) を定義します。このタスク ファクトリを使用すると、有効な遷移を登録し (たとえば、Started 状態への遷移を登録し、遷移の進行中は Starting 状態を使用する)、その遷移を実行できます。非同期 API を維持しながら、ステート マシンのセマンティクスを保証します。基本的には、元のコードに暗黙的に存在していたステート マシンを、堅牢で再利用可能な形で形式化したものです。
まず、私の質問で提示されたユースケースでどのように使用されるかの例を次に示します。
/// <summary>
/// Lifecycle states of the example service. Each "-ing" member is the in-progress state
/// registered for the transition that ends in the corresponding "-ed" member.
/// </summary>
public enum ServiceState
{
    /// <summary>Start state; this is <c>default(ServiceState)</c>.</summary>
    Uninitialized = 0,

    /// <summary>Initialization is in progress.</summary>
    Initializing = 1,

    /// <summary>Initialization completed.</summary>
    Initialized = 2,

    /// <summary>Start-up is in progress.</summary>
    Starting = 3,

    /// <summary>The service is running.</summary>
    Started = 4,

    /// <summary>Shut-down is in progress.</summary>
    Stopping = 5,

    /// <summary>The service has stopped.</summary>
    Stopped = 6,
}
/// <summary>
/// Example service whose Initialize/Start/Stop lifecycle is driven by a
/// <c>StateMachineTaskFactory&lt;ServiceState&gt;</c>.
/// </summary>
public class SomeService
{
    private readonly StateMachineTaskFactory<ServiceState> stateMachineTaskFactory;

    /// <summary>
    /// Creates the service in <see cref="ServiceState.Uninitialized"/> (the factory's default start state)
    /// and registers one transition per public lifecycle operation.
    /// </summary>
    public SomeService()
    {
        this.stateMachineTaskFactory = new StateMachineTaskFactory<ServiceState>();
        this.stateMachineTaskFactory.RegisterTransition(ServiceState.Initializing, ServiceState.Initialized, this.OnInitializeAsync);
        this.stateMachineTaskFactory.RegisterTransition(ServiceState.Starting, ServiceState.Started, this.OnStartAsync);
        this.stateMachineTaskFactory.RegisterTransition(ServiceState.Stopping, ServiceState.Stopped, this.OnStopAsync);
    }

    /// <summary>Transitions to <see cref="ServiceState.Initialized"/>.</summary>
    /// <remarks>Initialization deliberately does not support cancellation.</remarks>
    public Task InitializeAsync()
    {
        return this.stateMachineTaskFactory.TransitionTo(ServiceState.Initialized);
    }

    /// <summary>Transitions to <see cref="ServiceState.Started"/>.</summary>
    /// <param name="cancellationToken">Cancels the start; the state rolls back if it is canceled.</param>
    public Task StartAsync(CancellationToken cancellationToken = default(CancellationToken))
    {
        return this.stateMachineTaskFactory.TransitionTo(ServiceState.Started, cancellationToken);
    }

    /// <summary>Transitions to <see cref="ServiceState.Stopped"/>.</summary>
    /// <param name="cancellationToken">Cancels the stop; the state rolls back if it is canceled.</param>
    public Task StopAsync(CancellationToken cancellationToken = default(CancellationToken))
    {
        return this.stateMachineTaskFactory.TransitionTo(ServiceState.Stopped, cancellationToken);
    }

    // InitializeAsync passes no token to TransitionTo, so cancellationToken here is always
    // default(CancellationToken), i.e. CancellationToken.None.
    private Task OnInitializeAsync(CancellationToken cancellationToken, object state)
    {
        // TODO: return a Task that performs the actual work involved in initializing.
        throw new NotImplementedException();
    }

    private Task OnStartAsync(CancellationToken cancellationToken, object state)
    {
        // TODO: return a Task that performs the actual work involved in starting, passing on the cancellation token as relevant.
        throw new NotImplementedException();
    }

    private Task OnStopAsync(CancellationToken cancellationToken, object state)
    {
        // TODO: return a Task that performs the actual work involved in stopping, passing on the cancellation token as relevant.
        throw new NotImplementedException();
    }
}
上記の使用例から分かる以上の機能と柔軟性がありますが、おそらくこれが典型的な使い方です。
長いコードになってしまい申し訳ありません。読みやすくするために API ドキュメントは削除しました。ここに含まれていないユーティリティ クラス (ExceptionHelper や TaskUtil など) がいくつかありますが、名前から役割は明らかです。
/// <summary>
/// Raised (as the fault of a transition task) when a validate-transition callback rejects a requested transition.
/// </summary>
/// <typeparam name="T">The state type of the state machine.</typeparam>
[Serializable]
public sealed class StateTransitionForbiddenException<T> : InvalidOperationException
    where T : struct
{
    // Keys used in the serialized form.
    private const string TargetStateKey = "TargetState";
    private const string StateKey = "State";

    private readonly T requestedState;
    private readonly T stateAtRejection;

    public StateTransitionForbiddenException()
    {
    }

    public StateTransitionForbiddenException(string message)
        : base(message)
    {
    }

    public StateTransitionForbiddenException(string message, Exception innerException)
        : base(message, innerException)
    {
    }

    /// <summary>Creates the exception with the standard "forbidden by the validate transition callback" message.</summary>
    /// <param name="targetState">The state the caller asked to transition to.</param>
    /// <param name="state">The state the machine was in when the request was rejected.</param>
    public StateTransitionForbiddenException(T targetState, T state)
        : this("A transition to state '" + targetState + "' was forbidden by the validate transition callback.", targetState, state)
    {
    }

    /// <summary>Creates the exception with a custom message.</summary>
    /// <param name="message">The exception message.</param>
    /// <param name="targetState">The state the caller asked to transition to.</param>
    /// <param name="state">The state the machine was in when the request was rejected.</param>
    public StateTransitionForbiddenException(string message, T targetState, T state)
        : base(message)
    {
        this.requestedState = targetState;
        this.stateAtRejection = state;
    }

    private StateTransitionForbiddenException(SerializationInfo info, StreamingContext context)
        : base(info, context)
    {
        this.requestedState = (T)info.GetValue(TargetStateKey, typeof(T));
        this.stateAtRejection = (T)info.GetValue(StateKey, typeof(T));
    }

    /// <summary>Gets the state the caller asked to transition to.</summary>
    public T TargetState
    {
        get { return this.requestedState; }
    }

    /// <summary>Gets the state the machine was in when the transition was rejected.</summary>
    public T State
    {
        get { return this.stateAtRejection; }
    }

    /// <inheritdoc/>
    public override void GetObjectData(SerializationInfo info, StreamingContext context)
    {
        base.GetObjectData(info, context);
        info.AddValue(TargetStateKey, this.requestedState);
        info.AddValue(StateKey, this.stateAtRejection);
    }
}
/// <summary>
/// Event data describing a single change of a state machine's state.
/// </summary>
/// <typeparam name="T">The state type of the state machine.</typeparam>
[DebuggerDisplay("{OldState} -> {NewState}")]
public sealed class StateChangedEventArgs<T> : EventArgs
    where T : struct
{
    /// <summary>Creates event data for a change from <paramref name="oldState"/> to <paramref name="newState"/>.</summary>
    /// <param name="oldState">The state before the change.</param>
    /// <param name="newState">The state after the change.</param>
    public StateChangedEventArgs(T oldState, T newState)
    {
        OldState = oldState;
        NewState = newState;
    }

    /// <summary>Gets the state before the change.</summary>
    public T OldState { get; private set; }

    /// <summary>Gets the state after the change.</summary>
    public T NewState { get; private set; }
}
/// <summary>
/// Creates the task that performs the work of a registered state transition.
/// </summary>
/// <param name="cancellationToken">
/// The token that was passed to <c>StateMachineTaskFactory&lt;T&gt;.TransitionTo</c>
/// (<c>default(CancellationToken)</c> when the caller supplied none).
/// </param>
/// <param name="state">The arbitrary state object that was passed to <c>TransitionTo</c>, or <c>null</c>.</param>
/// <returns>
/// The task performing the transition work. It must not be <c>null</c>; <c>TransitionTo</c> throws if it is.
/// </returns>
/// <remarks>
/// <c>TransitionTo</c> invokes this synchronously while holding its internal lock, before it moves the machine
/// into the transition's "begin" state. If this throws, the exception propagates out of <c>TransitionTo</c>.
/// NOTE(review): the token comes first here, unlike the usual .NET convention of putting it last.
/// </remarks>
public delegate Task CreateTaskForTransitionCallback(CancellationToken cancellationToken, object state);
/// <summary>
/// Decides whether a requested transition may proceed from the machine's current state.
/// </summary>
/// <typeparam name="T">The state type of the state machine.</typeparam>
/// <param name="currentState">The machine's state at the moment the transition would begin.</param>
/// <returns>
/// <c>true</c> to allow the transition; <c>false</c> to forbid it, in which case the task returned by
/// <c>TransitionTo</c> faults with <c>StateTransitionForbiddenException&lt;T&gt;</c>.
/// </returns>
/// <remarks>
/// <c>TransitionTo</c> evaluates this only when the machine is neither already in, nor already transitioning to,
/// the requested state, and no other transition is in progress. A request made during another transition is
/// re-evaluated after that transition finishes.
/// </remarks>
public delegate bool ValidateTransitionCallback<T>(T currentState)
where T : struct;
/// <summary>
/// Formalizes an asynchronous state machine over the states defined by <typeparamref name="T"/>.
/// </summary>
/// <remarks>
/// <para>Register transitions with <see cref="RegisterTransition"/> and run them with <see cref="TransitionTo"/>.</para>
/// <para>
/// At most one transition is in progress at a time. A request for a different state made during a transition
/// is chained to run after it. A failed or canceled transition rolls the state back to where it was.
/// </para>
/// </remarks>
/// <typeparam name="T">The state type, typically an enum.</typeparam>
public class StateMachineTaskFactory<T> : TaskFactory
    where T : struct
{
    private static readonly ExceptionHelper exceptionHelper = new ExceptionHelper(typeof(StateMachineTaskFactory<>));

    // keyed by END state: each end state has at most one registered transition
    private readonly ConcurrentDictionary<T, TransitionRegistrationInfo> transitionRegistrations;

    // guards state, transitionToState and transitionToTask
    private readonly object stateSync;

    // the current state
    private T state;

    // the state to which we're currently transitioning (null when idle)
    private T? transitionToState;

    // the caller-facing task for the in-progress transition (null when idle)
    private Task transitionToTask;

    /// <summary>Creates a state machine whose start state is <c>default(T)</c>.</summary>
    public StateMachineTaskFactory()
        : this(default(T))
    {
    }

    /// <summary>Creates a state machine in the given start state.</summary>
    /// <param name="startState">The initial value of <see cref="State"/>.</param>
    public StateMachineTaskFactory(T startState)
    {
        this.transitionRegistrations = new ConcurrentDictionary<T, TransitionRegistrationInfo>();
        this.stateSync = new object();
        this.state = startState;
    }

    /// <summary>
    /// Raised whenever <see cref="State"/> actually changes: on entering a transition's begin state, on commit,
    /// and on rollback.
    /// </summary>
    /// <remarks>
    /// Raised synchronously on the thread that changes the state, while the internal lock is held.
    /// The thread may be a task continuation thread.
    /// </remarks>
    public event EventHandler<StateChangedEventArgs<T>> StateChanged;

    /// <summary>Gets the current state. The getter reads the field without taking the internal lock.</summary>
    public T State
    {
        get
        {
            return this.state;
        }

        private set
        {
            if (!EqualityComparer<T>.Default.Equals(this.state, value))
            {
                var oldState = this.state;
                this.state = value;
                this.OnStateChanged(new StateChangedEventArgs<T>(oldState, value));
            }
        }
    }

    /// <summary>
    /// Registers how to reach <paramref name="endTransitionState"/>.
    /// </summary>
    /// <param name="beginTransitionState">The state the machine is in while the transition's task runs.</param>
    /// <param name="endTransitionState">The state committed when the transition's task succeeds.</param>
    /// <param name="createTaskCallback">Creates the task that performs the transition work.</param>
    /// <exception cref="ArgumentNullException"><paramref name="createTaskCallback"/> is null.</exception>
    /// <exception cref="InvalidOperationException">A transition to <paramref name="endTransitionState"/> is already registered.</exception>
    public void RegisterTransition(T beginTransitionState, T endTransitionState, CreateTaskForTransitionCallback createTaskCallback)
    {
        // report the actual parameter name (was "factory")
        createTaskCallback.AssertNotNull("createTaskCallback");
        var transitionRegistrationInfo = new TransitionRegistrationInfo(beginTransitionState, createTaskCallback);
        var registered = this.transitionRegistrations.TryAdd(endTransitionState, transitionRegistrationInfo);
        exceptionHelper.ResolveAndThrowIf(!registered, "transitionAlreadyRegistered", endTransitionState);
    }

    /// <summary>
    /// Requests a transition to <paramref name="endTransitionState"/>.
    /// </summary>
    /// <param name="endTransitionState">The state to reach.</param>
    /// <param name="cancellationToken">
    /// Passed to the registered task callback. If it is canceled by the time that task finishes, the state rolls back.
    /// </param>
    /// <param name="validateTransitionCallback">Optional; returning false forbids the transition.</param>
    /// <param name="state">Arbitrary object passed through to the registered task callback.</param>
    /// <returns>
    /// <list type="bullet">
    /// <item>A completed task if the machine is already in <paramref name="endTransitionState"/>.</item>
    /// <item>
    /// The in-progress task if a transition to <paramref name="endTransitionState"/> is already running.
    /// This call's other arguments are then ignored.
    /// </item>
    /// <item>Otherwise, a task that completes when the state has been committed.</item>
    /// </list>
    /// The task faults with:
    /// <list type="bullet">
    /// <item>the transition task's own exceptions, unwrapped;</item>
    /// <item>an <see cref="OperationCanceledException"/> if the transition was canceled;</item>
    /// <item>a <see cref="StateTransitionForbiddenException{T}"/> if the validate callback rejected it.</item>
    /// </list>
    /// </returns>
    /// <exception cref="InvalidOperationException">
    /// No transition to <paramref name="endTransitionState"/> is registered, or its task callback returned null.
    /// </exception>
    public Task TransitionTo(T endTransitionState, CancellationToken cancellationToken = default(CancellationToken), ValidateTransitionCallback<T> validateTransitionCallback = null, object state = null)
    {
        lock (this.stateSync)
        {
            if (EqualityComparer<T>.Default.Equals(this.state, endTransitionState))
            {
                // already in the requested state - nothing to do
                return TaskUtil.FromResult(true);
            }
            else if (this.transitionToState.HasValue && EqualityComparer<T>.Default.Equals(this.transitionToState.Value, endTransitionState))
            {
                // already in the process of transitioning to the requested state - return same transition task
                return this.transitionToTask;
            }
            else if (this.transitionToTask != null)
            {
                // another transition is in progress, so come back to this request once it's done
                return this.transitionToTask.Then(x => this.TransitionTo(endTransitionState, cancellationToken, validateTransitionCallback, state));
            }
            else if (validateTransitionCallback != null && !validateTransitionCallback(this.State))
            {
                // transition is forbidden, so return a failing task to that effect
                var taskCompletionSource = new TaskCompletionSource<bool>();
                var exception = new StateTransitionForbiddenException<T>(endTransitionState, this.State);
                taskCompletionSource.TrySetException(exception);
                return taskCompletionSource.Task;
            }

            // else, need to transition to the chosen state
            TransitionRegistrationInfo transitionRegistrationInfo;
            var result = this.transitionRegistrations.TryGetValue(endTransitionState, out transitionRegistrationInfo);
            exceptionHelper.ResolveAndThrowIf(!result, "transitionNotRegistered", endTransitionState);

            var task = transitionRegistrationInfo.TaskFactory(cancellationToken, state);
            exceptionHelper.ResolveAndThrowIf(task == null, "taskFactoryReturnedNull", endTransitionState);

            var previousState = this.State;

            // Callers get this task, completed only after commit/rollback. A TaskCompletionSource lets
            // failures surface unwrapped, instead of as an AggregateException nested inside another one.
            var completion = new TaskCompletionSource<bool>();

            this.State = transitionRegistrationInfo.BeginTransitionState;
            this.transitionToState = endTransitionState;

            // Published before the continuation is registered, so the field is correct even if the continuation runs inline.
            this.transitionToTask = completion.Task;

            // The continuation takes stateSync, so it cannot commit or roll back until this method releases the lock.
            task.ContinueWith(x => this.CompleteTransition(x, completion, previousState, endTransitionState, cancellationToken));

            return completion.Task;
        }
    }

    /// <summary>Raises <see cref="StateChanged"/>.</summary>
    /// <param name="e">The event data.</param>
    protected virtual void OnStateChanged(StateChangedEventArgs<T> e)
    {
        this.StateChanged.Raise(this, e);
    }

    // Commits (on success) or rolls back (on fault/cancel) once the transition's task has finished,
    // then completes the caller-facing task outside the lock.
    private void CompleteTransition(Task transitionTask, TaskCompletionSource<bool> completion, T previousState, T endTransitionState, CancellationToken cancellationToken)
    {
        // Reading Exception also marks a fault as observed.
        var fault = transitionTask.Exception;

        // A task canceled by its own token (not the caller's) must not count as success.
        var succeeded = fault == null && !transitionTask.IsCanceled && !cancellationToken.IsCancellationRequested;

        try
        {
            lock (this.stateSync)
            {
                try
                {
                    this.State = succeeded ? endTransitionState : previousState;
                }
                finally
                {
                    // Always leave the "in transition" state, even if a StateChanged handler throws;
                    // otherwise every later request would be chained forever.
                    this.transitionToState = null;
                    this.transitionToTask = null;
                }
            }
        }
        catch (Exception ex)
        {
            // A StateChanged handler threw; report it to callers rather than leaving them waiting forever.
            completion.TrySetException(ex);
            return;
        }

        if (succeeded)
        {
            completion.TrySetResult(true);
        }
        else if (fault != null)
        {
            completion.TrySetException(fault.InnerExceptions);
        }
        else
        {
            completion.TrySetException(new OperationCanceledException(cancellationToken));
        }
    }

    // A registered transition: the in-progress state and the callback that creates the work task.
    private struct TransitionRegistrationInfo
    {
        private readonly T beginTransitionState;
        private readonly CreateTaskForTransitionCallback taskFactory;

        public TransitionRegistrationInfo(T beginTransitionState, CreateTaskForTransitionCallback taskFactory)
        {
            this.beginTransitionState = beginTransitionState;
            this.taskFactory = taskFactory;
        }

        public T BeginTransitionState
        {
            get { return this.beginTransitionState; }
        }

        public CreateTaskForTransitionCallback TaskFactory
        {
            get { return this.taskFactory; }
        }
    }
}
最後に、完全を期すために単体テストを示します。
/// <summary>
/// xUnit tests for <c>StateMachineTaskFactory&lt;T&gt;</c>.
/// </summary>
public sealed class StateMachineTaskFactoryFixture
{
    // Upper bound on how long a test waits for any single transition before failing.
    private static readonly TimeSpan WaitTimeout = TimeSpan.FromSeconds(3);

    #region Supporting Enums

    private enum State
    {
        Undefined,
        Starting,
        Started,
        Stopping,
        Stopped
    }

    #endregion

    [Fact]
    public void default_ctor_uses_default_value_for_start_state()
    {
        var factory = new StateMachineTaskFactory<State>();
        Assert.Equal(State.Undefined, factory.State);
    }

    [Fact]
    public void ctor_can_set_start_state()
    {
        var factory = new StateMachineTaskFactory<State>(State.Stopped);
        Assert.Equal(State.Stopped, factory.State);
    }

    [Fact]
    public void register_transition_throws_if_factory_is_null()
    {
        var factory = new StateMachineTaskFactory<State>();
        Assert.Throws<ArgumentNullException>(() => factory.RegisterTransition(State.Starting, State.Started, null));
    }

    [Fact]
    public void register_transition_throws_if_transition_already_registered()
    {
        var factory = new StateMachineTaskFactory<State>();
        factory.RegisterTransition(State.Starting, State.Started, (ct, o) => TaskUtil.FromResult(true));

        var ex = Assert.Throws<InvalidOperationException>(() => factory.RegisterTransition(State.Starting, State.Started, (ct, o) => TaskUtil.FromResult(true)));

        Assert.Equal("A transition to state 'Started' has already been registered.", ex.Message);
    }

    [Fact]
    public void transition_to_throws_if_no_transition_registered_for_state()
    {
        var factory = new StateMachineTaskFactory<State>();

        var ex = Assert.Throws<InvalidOperationException>(() => factory.TransitionTo(State.Started));

        Assert.Equal("No transition to state 'Started' has been registered.", ex.Message);
    }

    [Fact]
    public void transition_to_throws_if_task_factory_returns_null()
    {
        var factory = new StateMachineTaskFactory<State>();
        factory.RegisterTransition(State.Starting, State.Started, (ct, o) => null);

        var ex = Assert.Throws<InvalidOperationException>(() => factory.TransitionTo(State.Started));

        Assert.Equal("Task factory for end state 'Started' returned null.", ex.Message);
    }

    [Fact]
    public void transition_to_returns_same_task_if_called_multiple_times_whilst_initial_task_is_still_in_progress()
    {
        var factory = new StateMachineTaskFactory<State>();
        factory.RegisterTransition(State.Starting, State.Started, (ct, o) => TaskUtil.Delay(TimeSpan.FromMilliseconds(250)));

        var initialTask = factory.TransitionTo(State.Started);

        // Task does not override Equals, so identity is what is really being asserted.
        Assert.Same(initialTask, factory.TransitionTo(State.Started));
        Assert.Same(initialTask, factory.TransitionTo(State.Started));
        Assert.Same(initialTask, factory.TransitionTo(State.Started));
        Assert.True(initialTask.Wait(WaitTimeout));
    }

    [Fact]
    public void transition_to_returns_completed_task_if_already_in_desired_state()
    {
        var factory = new StateMachineTaskFactory<State>();
        factory.RegisterTransition(State.Starting, State.Started, (ct, o) => TaskUtil.FromResult(true));
        Assert.True(factory.TransitionTo(State.Started).Wait(WaitTimeout));

        Assert.Equal(TaskStatus.RanToCompletion, factory.TransitionTo(State.Started).Status);
    }

    [Fact]
    public void transition_to_passes_any_state_to_task_creation_function()
    {
        var factory = new StateMachineTaskFactory<State>();
        string receivedState = null;
        factory.RegisterTransition(
            State.Starting,
            State.Started,
            (ct, o) =>
            {
                receivedState = o as string;
                return TaskUtil.FromResult(true);
            });

        Assert.True(factory.TransitionTo(State.Started, CancellationToken.None, null, "here is the state").Wait(WaitTimeout));

        Assert.Equal("here is the state", receivedState);
    }

    [Fact]
    [SuppressMessage("Microsoft.Naming", "CA2204", Justification = "It's not a word - it's a format string!")]
    public void transition_to_ensures_previous_transition_is_first_completed_before_starting_subsequent_transition()
    {
        var factory = new StateMachineTaskFactory<State>();
        factory.RegisterTransition(State.Starting, State.Started, (ct, o) => TaskUtil.Delay(TimeSpan.FromMilliseconds(10)));
        factory.RegisterTransition(State.Stopping, State.Stopped, (ct, o) => TaskUtil.Delay(TimeSpan.FromMilliseconds(10)));

        // Record completion ORDER with a counter rather than DateTime.UtcNow timestamps: the system clock's
        // resolution (~15 ms on Windows) is coarser than the 10 ms delays, which made timestamp comparison flaky.
        var sequence = 0;
        var startedOrder = 0;
        var stoppedOrder = 0;

        var startedTask = factory.TransitionTo(State.Started).ContinueWith(x => startedOrder = Interlocked.Increment(ref sequence), TaskContinuationOptions.ExecuteSynchronously);
        var stoppedTask = factory.TransitionTo(State.Stopped).ContinueWith(x => stoppedOrder = Interlocked.Increment(ref sequence), TaskContinuationOptions.ExecuteSynchronously);

        Assert.True(Task.WaitAll(new Task[] { startedTask, stoppedTask }, WaitTimeout), "Timed out waiting for tasks to complete.");
        Assert.True(stoppedOrder > startedOrder, "stoppedOrder is " + stoppedOrder + " and startedOrder is " + startedOrder + ".");
    }

    [Fact]
    public void transition_to_can_be_canceled_before_transition_takes_place()
    {
        var factory = new StateMachineTaskFactory<State>();
        var cancellationTokenSource = new CancellationTokenSource();
        factory.RegisterTransition(State.Starting, State.Started, (ct, o) => TaskUtil.FromResult(true));
        cancellationTokenSource.Cancel();

        var startedTask = factory.TransitionTo(State.Started, cancellationTokenSource.Token);

        AssertFaultsWith<OperationCanceledException>(startedTask);
    }

    [Fact]
    public void transition_to_can_be_canceled()
    {
        var factory = new StateMachineTaskFactory<State>();
        var cancellationTokenSource = new CancellationTokenSource();
        factory.RegisterTransition(State.Starting, State.Started, (ct, o) => TaskUtil.FromResult(true));
        factory.RegisterTransition(State.Stopping, State.Stopped, (ct, o) => TaskUtil.Delay(TimeSpan.FromMilliseconds(150)));

        var startedTask = factory.TransitionTo(State.Started, cancellationTokenSource.Token);
        startedTask.ContinueWith(x => cancellationTokenSource.Cancel());
        var stoppedTask = factory.TransitionTo(State.Stopped, cancellationTokenSource.Token);

        Assert.True(startedTask.Wait(WaitTimeout));
        AssertFaultsWith<OperationCanceledException>(stoppedTask);
    }

    [Fact]
    public void transition_to_can_be_forbidden()
    {
        var factory = new StateMachineTaskFactory<State>();
        factory.RegisterTransition(State.Starting, State.Started, (ct, o) => TaskUtil.FromResult(true));
        factory.RegisterTransition(State.Stopping, State.Stopped, (ct, o) => TaskUtil.FromResult(true));

        var startedTask = factory.TransitionTo(State.Started, CancellationToken.None, x => x == State.Undefined);
        var stoppedTask = factory.TransitionTo(State.Stopped, CancellationToken.None, x => x != State.Started);

        Assert.True(startedTask.Wait(WaitTimeout));
        var ex = AssertFaultsWith<StateTransitionForbiddenException<State>>(stoppedTask);
        Assert.Equal(State.Stopped, ex.TargetState);
        Assert.Equal(State.Started, ex.State);
        Assert.Equal("A transition to state 'Stopped' was forbidden by the validate transition callback.", ex.Message);
    }

    [Fact]
    public void canceled_transition_reverts_back_to_original_state()
    {
        var factory = new StateMachineTaskFactory<State>();
        var cancellationTokenSource = new CancellationTokenSource();
        factory.RegisterTransition(State.Starting, State.Started, (ct, o) => TaskUtil.FromResult(true));
        factory.RegisterTransition(State.Stopping, State.Stopped, (ct, o) => TaskUtil.Delay(TimeSpan.FromSeconds(3), cancellationTokenSource.Token));
        factory.StateChanged += (s, e) =>
        {
            if (e.NewState == State.Stopping)
            {
                // cancel the stop
                cancellationTokenSource.Cancel();
            }
        };

        var startedTask = factory.TransitionTo(State.Started);
        var stoppedTask = factory.TransitionTo(State.Stopped, cancellationTokenSource.Token);

        Assert.True(startedTask.Wait(WaitTimeout));
        AssertFaultsWith<OperationCanceledException>(stoppedTask);
        Assert.Equal(State.Started, factory.State);
    }

    [Fact]
    public void failed_transition_reverts_back_to_original_state()
    {
        var factory = new StateMachineTaskFactory<State>();
        factory.RegisterTransition(State.Starting, State.Started, (ct, o) => TaskUtil.FromResult(true));
        factory.RegisterTransition(State.Stopping, State.Stopped, (ct, o) => { throw new InvalidOperationException("Something went wrong"); });

        var startedTask = factory.TransitionTo(State.Started);
        var stoppedTask = factory.TransitionTo(State.Stopped);

        Assert.True(startedTask.Wait(WaitTimeout));
        AssertFaultsWith<InvalidOperationException>(stoppedTask);
        Assert.Equal(State.Started, factory.State);
    }

    [Fact]
    public void state_change_is_raised_as_state_changes()
    {
        var factory = new StateMachineTaskFactory<State>(State.Stopped);
        factory.RegisterTransition(State.Starting, State.Started, (ct, o) => TaskUtil.FromResult(true));
        factory.RegisterTransition(State.Stopping, State.Stopped, (ct, o) => TaskUtil.FromResult(true));
        var stateChanges = new List<StateChangedEventArgs<State>>();
        factory.StateChanged += (s, e) => stateChanges.Add(e);

        Assert.True(factory.TransitionTo(State.Started).Wait(WaitTimeout));
        Assert.True(factory.TransitionTo(State.Stopped).Wait(WaitTimeout));
        Assert.True(factory.TransitionTo(State.Started).Wait(WaitTimeout));
        Assert.True(factory.TransitionTo(State.Stopped).Wait(WaitTimeout));

        Assert.Equal(8, stateChanges.Count);
        Assert.Equal(State.Stopped, stateChanges[0].OldState);
        Assert.Equal(State.Starting, stateChanges[0].NewState);
        Assert.Equal(State.Starting, stateChanges[1].OldState);
        Assert.Equal(State.Started, stateChanges[1].NewState);
        Assert.Equal(State.Started, stateChanges[2].OldState);
        Assert.Equal(State.Stopping, stateChanges[2].NewState);
        Assert.Equal(State.Stopping, stateChanges[3].OldState);
        Assert.Equal(State.Stopped, stateChanges[3].NewState);
        Assert.Equal(State.Stopped, stateChanges[4].OldState);
        Assert.Equal(State.Starting, stateChanges[4].NewState);
        Assert.Equal(State.Starting, stateChanges[5].OldState);
        Assert.Equal(State.Started, stateChanges[5].NewState);
        Assert.Equal(State.Started, stateChanges[6].OldState);
        Assert.Equal(State.Stopping, stateChanges[6].NewState);
        Assert.Equal(State.Stopping, stateChanges[7].OldState);
        Assert.Equal(State.Stopped, stateChanges[7].NewState);
    }

    [Fact]
    public void state_gets_the_current_state()
    {
        var factory = new StateMachineTaskFactory<State>(State.Stopped);
        factory.RegisterTransition(State.Starting, State.Started, (ct, o) => TaskUtil.Delay(TimeSpan.FromMilliseconds(100)));
        factory.RegisterTransition(State.Stopping, State.Stopped, (ct, o) => TaskUtil.Delay(TimeSpan.FromMilliseconds(100)));

        var task = factory.TransitionTo(State.Started);
        Assert.Equal(State.Starting, factory.State);
        Assert.True(task.Wait(WaitTimeout));
        Assert.Equal(State.Started, factory.State);

        task = factory.TransitionTo(State.Stopped);
        Assert.Equal(State.Stopping, factory.State);
        Assert.True(task.Wait(WaitTimeout));
        Assert.Equal(State.Stopped, factory.State);
    }

    // Waits for the task to fault and returns its single inner exception, asserting it is exactly TException.
    // A timeout counts as a failure because no exception is thrown.
    private static TException AssertFaultsWith<TException>(Task task)
        where TException : Exception
    {
        var aggregate = Assert.Throws<AggregateException>(() => { task.Wait(WaitTimeout); });
        Assert.Equal(1, aggregate.InnerExceptions.Count);
        return Assert.IsType<TException>(aggregate.InnerExceptions[0]);
    }
}